add db异步转换器

This commit is contained in:
jxxghp
2025-07-30 08:54:04 +08:00
parent f077a9684b
commit 0053d31f84
2 changed files with 140 additions and 1 deletions

136
app/db/async_adapter.py Normal file
View File

@@ -0,0 +1,136 @@
import asyncio
from concurrent.futures import ThreadPoolExecutor
from functools import wraps, partial
from typing import Any, Callable, Coroutine, TypeVar
from sqlalchemy.orm import Session
from app.db import ScopedSession
T = TypeVar('T')
# 全局线程池,用于执行同步数据库操作
_db_executor = ThreadPoolExecutor(max_workers=10, thread_name_prefix="db_async")
def async_db_operation(func: Callable[..., T]) -> Callable[..., Coroutine[Any, Any, T]]:
"""
将同步数据库操作转换为异步操作的装饰器
"""
@wraps(func)
async def wrapper(*args, **kwargs) -> T:
# 在线程池中执行同步数据库操作
loop = asyncio.get_event_loop()
partial_func = partial(func, *args, **kwargs)
return await loop.run_in_executor(_db_executor, partial_func, ())
return wrapper
def async_db_session(func: Callable[..., T]) -> Callable[..., Coroutine[Any, Any, T]]:
"""
为异步操作提供数据库会话的装饰器
"""
@wraps(func)
async def wrapper(*args, **kwargs) -> T:
# 创建数据库会话
db = ScopedSession()
try:
# 将数据库会话添加到参数中
if 'db' not in kwargs:
kwargs['db'] = db
# 在线程池中执行同步数据库操作
loop = asyncio.get_event_loop()
partial_func = partial(func, *args, **kwargs)
return await loop.run_in_executor(_db_executor, partial_func, ())
finally:
db.close()
return wrapper
class AsyncDbOper:
"""
异步数据库操作基类
"""
def __init__(self, db: Session = None):
self._db = db
async def _get_db(self) -> Session:
"""
获取数据库会话
"""
if self._db:
return self._db
return ScopedSession()
@staticmethod
async def _execute_sync(func: Callable[..., T], *args, **kwargs) -> T:
"""
在线程池中执行同步数据库操作
"""
loop = asyncio.get_event_loop()
partial_func = partial(func, *args, **kwargs)
return await loop.run_in_executor(_db_executor, partial_func, ())
def to_async_db_oper(sync_oper_class):
"""
将同步数据库操作类转换为异步版本的装饰器
"""
class AsyncOperClass(AsyncDbOper):
def __init__(self, db: Session = None):
super().__init__(db)
self._sync_oper = sync_oper_class(db)
def __getattr__(self, name):
"""动态获取同步操作类的方法并转换为异步"""
if hasattr(self._sync_oper, name):
method = getattr(self._sync_oper, name)
if callable(method):
return async_db_operation(method)
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
return AsyncOperClass
# 异步数据库会话获取函数
async def get_async_db():
"""
异步获取数据库会话
"""
def _get_db():
return ScopedSession()
loop = asyncio.get_event_loop()
return await loop.run_in_executor(_db_executor, _get_db) # type: ignore
# 异步上下文管理器
class AsyncDbSession:
"""
异步数据库会话上下文管理器
"""
def __init__(self):
self.db = None
async def __aenter__(self):
self.db = await get_async_db()
return self.db
async def __aexit__(self, exc_type, exc_val, exc_tb):
if self.db:
self.db.close()
def shutdown_db_executor():
"""关闭数据库线程池"""
global _db_executor
if _db_executor:
_db_executor.shutdown(wait=True)

View File

@@ -4,6 +4,8 @@ from contextlib import asynccontextmanager
from fastapi import FastAPI
from app.chain.system import SystemChain
from app.db.async_adapter import shutdown_db_executor
from app.helper.system import SystemHelper
from app.startup.command_initializer import init_command, stop_command, restart_command
from app.startup.modules_initializer import init_modules, stop_modules
from app.startup.monitor_initializer import stop_monitor, init_monitor
@@ -11,7 +13,6 @@ from app.startup.plugins_initializer import init_plugins, stop_plugins, sync_plu
from app.startup.routers_initializer import init_routers
from app.startup.scheduler_initializer import stop_scheduler, init_scheduler, init_plugin_scheduler
from app.startup.workflow_initializer import init_workflow, stop_workflow
from app.helper.system import SystemHelper
async def init_extra():
@@ -80,3 +81,5 @@ async def lifespan(app: FastAPI):
stop_plugins()
# 停止模块
stop_modules()
# 关闭数据库异步执行器
shutdown_db_executor()