diff --git a/app/db/async_adapter.py b/app/db/async_adapter.py new file mode 100644 index 00000000..6ce86d35 --- /dev/null +++ b/app/db/async_adapter.py @@ -0,0 +1,136 @@ +import asyncio +from concurrent.futures import ThreadPoolExecutor +from functools import wraps, partial +from typing import Any, Callable, Coroutine, TypeVar + +from sqlalchemy.orm import Session + +from app.db import ScopedSession + +T = TypeVar('T') + +# 全局线程池,用于执行同步数据库操作 +_db_executor = ThreadPoolExecutor(max_workers=10, thread_name_prefix="db_async") + + +def async_db_operation(func: Callable[..., T]) -> Callable[..., Coroutine[Any, Any, T]]: + """ + 将同步数据库操作转换为异步操作的装饰器 + """ + + @wraps(func) + async def wrapper(*args, **kwargs) -> T: + # 在线程池中执行同步数据库操作 + loop = asyncio.get_event_loop() + partial_func = partial(func, *args, **kwargs) + return await loop.run_in_executor(_db_executor, partial_func, ()) + + return wrapper + + +def async_db_session(func: Callable[..., T]) -> Callable[..., Coroutine[Any, Any, T]]: + """ + 为异步操作提供数据库会话的装饰器 + """ + + @wraps(func) + async def wrapper(*args, **kwargs) -> T: + # 创建数据库会话 + db = ScopedSession() + try: + # 将数据库会话添加到参数中 + if 'db' not in kwargs: + kwargs['db'] = db + # 在线程池中执行同步数据库操作 + loop = asyncio.get_event_loop() + partial_func = partial(func, *args, **kwargs) + return await loop.run_in_executor(_db_executor, partial_func, ()) + finally: + db.close() + + return wrapper + + +class AsyncDbOper: + """ + 异步数据库操作基类 + """ + + def __init__(self, db: Session = None): + self._db = db + + async def _get_db(self) -> Session: + """ + 获取数据库会话 + """ + if self._db: + return self._db + return ScopedSession() + + @staticmethod + async def _execute_sync(func: Callable[..., T], *args, **kwargs) -> T: + """ + 在线程池中执行同步数据库操作 + """ + loop = asyncio.get_event_loop() + partial_func = partial(func, *args, **kwargs) + return await loop.run_in_executor(_db_executor, partial_func, ()) + + +def to_async_db_oper(sync_oper_class): + """ + 将同步数据库操作类转换为异步版本的装饰器 + """ + + class AsyncOperClass(AsyncDbOper): + def __init__(self, db: Session = None): + super().__init__(db) + self._sync_oper = sync_oper_class(db) + + def __getattr__(self, name): + """动态获取同步操作类的方法并转换为异步""" + if hasattr(self._sync_oper, name): + method = getattr(self._sync_oper, name) + if callable(method): + return async_db_operation(method) + raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'") + + return AsyncOperClass + + +# 异步数据库会话获取函数 +async def get_async_db(): + """ + 异步获取数据库会话 + """ + + def _get_db(): + return ScopedSession() + + loop = asyncio.get_event_loop() + return await loop.run_in_executor(_db_executor, _get_db) # type: ignore + + +# 异步上下文管理器 +class AsyncDbSession: + """ + 异步数据库会话上下文管理器 + """ + + def __init__(self): + self.db = None + + async def __aenter__(self): + self.db = await get_async_db() + return self.db + + async def __aexit__(self, exc_type, exc_val, exc_tb): + if self.db: + self.db.close() + + +def shutdown_db_executor(): + """关闭数据库线程池""" + global _db_executor + if _db_executor: + _db_executor.shutdown(wait=True) diff --git a/app/startup/lifecycle.py b/app/startup/lifecycle.py index 3a50e8be..04ac2a78 100644 --- a/app/startup/lifecycle.py +++ b/app/startup/lifecycle.py @@ -4,6 +4,8 @@ from contextlib import asynccontextmanager from fastapi import FastAPI from app.chain.system import SystemChain +from app.db.async_adapter import shutdown_db_executor +from app.helper.system import SystemHelper from app.startup.command_initializer import init_command, stop_command, restart_command from app.startup.modules_initializer import init_modules, stop_modules from app.startup.monitor_initializer import stop_monitor, init_monitor @@ -11,7 +13,6 @@ from app.startup.plugins_initializer import init_plugins, stop_plugins, sync_plu from app.startup.routers_initializer import init_routers from app.startup.scheduler_initializer import stop_scheduler, init_scheduler, init_plugin_scheduler from app.startup.workflow_initializer import init_workflow, stop_workflow -from app.helper.system import SystemHelper async def init_extra(): @@ -80,3 +81,5 @@ async def lifespan(app: FastAPI): stop_plugins() # 停止模块 stop_modules() + # 关闭数据库异步执行器 + shutdown_db_executor()