From f14f4e1e9b85f3cc549530e041c0e36a0c5484ae Mon Sep 17 00:00:00 2001 From: jxxghp Date: Wed, 30 Jul 2025 13:18:45 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E5=BC=82=E6=AD=A5=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E5=BA=93=E6=94=AF=E6=8C=81=EF=BC=8C=E6=9B=B4=E6=96=B0?= =?UTF-8?q?=E7=9B=B8=E5=85=B3=E6=A8=A1=E5=9E=8B=E5=92=8C=E4=BC=9A=E8=AF=9D?= =?UTF-8?q?=E7=AE=A1=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/endpoints/site.py | 13 +- app/db/__init__.py | 325 +++++++++++++++++++++++------ app/db/async_adapter.py | 136 ------------ app/db/models/siteicon.py | 11 +- app/startup/lifecycle.py | 5 +- app/startup/modules_initializer.py | 6 +- docs/async-database-guide.md | 171 --------------- requirements.in | 1 + 8 files changed, 286 insertions(+), 382 deletions(-) delete mode 100644 app/db/async_adapter.py delete mode 100644 docs/async-database-guide.md diff --git a/app/api/endpoints/site.py b/app/api/endpoints/site.py index 0bfd1838..f712c3d3 100644 --- a/app/api/endpoints/site.py +++ b/app/api/endpoints/site.py @@ -1,6 +1,7 @@ from typing import List, Any, Dict, Optional from fastapi import APIRouter, Depends, HTTPException +from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Session from starlette.background import BackgroundTasks @@ -12,7 +13,7 @@ from app.command import Command from app.core.event import EventManager from app.core.plugin import PluginManager from app.core.security import verify_token -from app.db import get_db +from app.db import get_db, get_async_db from app.db.models import User from app.db.models.site import Site from app.db.models.siteicon import SiteIcon @@ -242,19 +243,19 @@ def test_site(site_id: int, @router.get("/icon/{site_id}", summary="站点图标", response_model=schemas.Response) -def site_icon(site_id: int, - db: Session = Depends(get_db), - _: schemas.TokenPayload = Depends(verify_token)) -> Any: +async def site_icon(site_id: int, + db: AsyncSession = Depends(get_async_db), + _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 获取站点图标:base64或者url """ - site = Site.get(db, site_id) + site = await Site.async_get(db, site_id) if not site: raise HTTPException( status_code=404, detail=f"站点 {site_id} 不存在", ) - icon = SiteIcon.get_by_domain(db, site.domain) + icon = await SiteIcon.async_get_by_domain(db, site.domain) if not icon: return schemas.Response(success=False, message="站点图标不存在!") return schemas.Response(success=True, data={ diff --git a/app/db/__init__.py b/app/db/__init__.py index c7e473fe..357197d1 100644 --- a/app/db/__init__.py +++ b/app/db/__init__.py @@ -1,45 +1,106 @@ -from typing import Any, Generator, List, Optional, Self, Tuple +import asyncio +from typing import Any, Generator, List, Optional, Self, Tuple, AsyncGenerator, Sequence from sqlalchemy import NullPool, QueuePool, and_, create_engine, inspect, text +from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker from sqlalchemy.orm import Session, as_declarative, declared_attr, scoped_session, sessionmaker from app.core.config import settings -# 根据池类型设置 poolclass 和相关参数 -pool_class = NullPool if settings.DB_POOL_TYPE == "NullPool" else QueuePool -connect_args = { - "timeout": settings.DB_TIMEOUT -} -# 启用 WAL 模式时的额外配置 -if settings.DB_WAL_ENABLE: - connect_args["check_same_thread"] = False -db_kwargs = { - "url": f"sqlite:///{settings.CONFIG_PATH}/user.db", - "pool_pre_ping": settings.DB_POOL_PRE_PING, - "echo": settings.DB_ECHO, - "poolclass": pool_class, - "pool_recycle": settings.DB_POOL_RECYCLE, - "connect_args": connect_args -} -# 当使用 QueuePool 时,添加 QueuePool 特有的参数 -if pool_class == QueuePool: - db_kwargs.update({ - "pool_size": settings.CONF.dbpool, - "pool_timeout": settings.DB_POOL_TIMEOUT, - "max_overflow": settings.CONF.dbpooloverflow - }) -# 创建数据库引擎 -Engine = create_engine(**db_kwargs) -# 根据配置设置日志模式 -journal_mode = "WAL" if settings.DB_WAL_ENABLE else "DELETE" -with Engine.connect() as connection: - current_mode = connection.execute(text(f"PRAGMA journal_mode={journal_mode};")).scalar() - print(f"Database journal mode set to: {current_mode}") -# 会话工厂 +def _get_database_engine(is_async: bool = False): + """ + 获取数据库连接参数并设置WAL模式 + :param is_async: 是否创建异步引擎,True - 异步引擎, False - 同步引擎 + :return: 返回对应的数据库引擎 + """ + # 连接参数 + _connect_args = { + "timeout": settings.DB_TIMEOUT, + } + # 启用 WAL 模式时的额外配置 + if settings.DB_WAL_ENABLE: + _connect_args["check_same_thread"] = False + + # 创建同步引擎 + if not is_async: + # 根据池类型设置 poolclass 和相关参数 + _pool_class = NullPool if settings.DB_POOL_TYPE == "NullPool" else QueuePool + + # 数据库参数 + _db_kwargs = { + "url": f"sqlite:///{settings.CONFIG_PATH}/user.db", + "pool_pre_ping": settings.DB_POOL_PRE_PING, + "echo": settings.DB_ECHO, + "poolclass": _pool_class, + "pool_recycle": settings.DB_POOL_RECYCLE, + "connect_args": _connect_args + } + + # 当使用 QueuePool 时,添加 QueuePool 特有的参数 + if _pool_class == QueuePool: + _db_kwargs.update({ + "pool_size": settings.CONF.dbpool, + "pool_timeout": settings.DB_POOL_TIMEOUT, + "max_overflow": settings.CONF.dbpooloverflow + }) + + # 创建数据库引擎 + engine = create_engine(**_db_kwargs) + + # 设置WAL模式 + _journal_mode = "WAL" if settings.DB_WAL_ENABLE else "DELETE" + with engine.connect() as connection: + current_mode = connection.execute(text(f"PRAGMA journal_mode={_journal_mode};")).scalar() + print(f"Database journal mode set to: {current_mode}") + + return engine + else: + # 数据库参数,只能使用 NullPool + _db_kwargs = { + "url": f"sqlite+aiosqlite:///{settings.CONFIG_PATH}/user.db", + "pool_pre_ping": settings.DB_POOL_PRE_PING, + "echo": settings.DB_ECHO, + "poolclass": NullPool, + "pool_recycle": settings.DB_POOL_RECYCLE, + "connect_args": _connect_args + } + # 创建异步数据库引擎 + async_engine = create_async_engine(**_db_kwargs) + + # 设置WAL模式 + _journal_mode = "WAL" if settings.DB_WAL_ENABLE else "DELETE" + + async def set_async_wal_mode(): + """ + 设置异步引擎的WAL模式 + """ + async with async_engine.connect() as _connection: + result = await _connection.execute(text(f"PRAGMA journal_mode={_journal_mode};")) + _current_mode = result.scalar() + print(f"Async database journal mode set to: {_current_mode}") + + try: + asyncio.run(set_async_wal_mode()) + except Exception as e: + print(f"Failed to set async WAL mode: {e}") + + return async_engine + + +# 同步数据库引擎 +Engine = _get_database_engine(is_async=False) + +# 异步数据库引擎 +AsyncEngine = _get_database_engine(is_async=True) + +# 同步会话工厂 SessionFactory = sessionmaker(bind=Engine) -# 多线程全局使用的数据库会话 +# 异步会话工厂 +AsyncSessionFactory = async_sessionmaker(bind=AsyncEngine, class_=AsyncSession) + +# 同步多线程全局使用的数据库会话 ScopedSession = scoped_session(SessionFactory) @@ -57,37 +118,32 @@ def get_db() -> Generator: db.close() -def perform_checkpoint(mode: str = "PASSIVE"): +async def get_async_db() -> AsyncGenerator[AsyncSession, None]: """ - 执行 SQLite 的 checkpoint 操作,将 WAL 文件内容写回主数据库 - :param mode: checkpoint 模式,可选值包括 "PASSIVE"、"FULL"、"RESTART"、"TRUNCATE" - 默认为 "PASSIVE",即不锁定 WAL 文件的轻量级同步 + 获取异步数据库会话,用于WEB请求 + :return: AsyncSession """ - if not settings.DB_WAL_ENABLE: - return - valid_modes = {"PASSIVE", "FULL", "RESTART", "TRUNCATE"} - if mode.upper() not in valid_modes: - raise ValueError(f"Invalid checkpoint mode '{mode}'. Must be one of {valid_modes}") - try: - # 使用指定的 checkpoint 模式,确保 WAL 文件数据被正确写回主数据库 - with Engine.connect() as conn: - conn.execute(text(f"PRAGMA wal_checkpoint({mode.upper()});")) - except Exception as e: - print(f"Error during WAL checkpoint: {e}") + async with AsyncSessionFactory() as session: + try: + yield session + finally: + await session.close() -def close_database(): +async def close_database(): """ 关闭所有数据库连接并清理资源 """ try: - # 释放连接池,SQLite 会自动清空 WAL 文件,这里不单独再调用 checkpoint - Engine.dispose() - except Exception as e: - print(f"Error while disposing database connections: {e}") + # 释放同步连接池 + Engine.dispose() # noqa + # 释放异步连接池 + await AsyncEngine.dispose() + except Exception as err: + print(f"Error while disposing database connections: {err}") -def get_args_db(args: tuple, kwargs: dict) -> Optional[Session]: +def _get_args_db(args: tuple, kwargs: dict) -> Optional[Session]: """ 从参数中获取数据库Session对象 """ @@ -105,7 +161,25 @@ def get_args_db(args: tuple, kwargs: dict) -> Optional[Session]: return db -def update_args_db(args: tuple, kwargs: dict, db: Session) -> Tuple[tuple, dict]: +def _get_args_async_db(args: tuple, kwargs: dict) -> Optional[AsyncSession]: + """ + 从参数中获取异步数据库AsyncSession对象 + """ + db = None + if args: + for arg in args: + if isinstance(arg, AsyncSession): + db = arg + break + if kwargs: + for key, value in kwargs.items(): + if isinstance(value, AsyncSession): + db = value + break + return db + + +def _update_args_db(args: tuple, kwargs: dict, db: Session) -> Tuple[tuple, dict]: """ 更新参数中的数据库Session对象,关键字传参时更新db的值,否则更新第1或第2个参数 """ @@ -119,6 +193,20 @@ def update_args_db(args: tuple, kwargs: dict, db: Session) -> Tuple[tuple, dict] return args, kwargs +def _update_args_async_db(args: tuple, kwargs: dict, db: AsyncSession) -> Tuple[tuple, dict]: + """ + 更新参数中的异步数据库AsyncSession对象,关键字传参时更新db的值,否则更新第1或第2个参数 + """ + if kwargs and 'db' in kwargs: + kwargs['db'] = db + elif args: + if args[0] is None: + args = (db, *args[1:]) + else: + args = (args[0], db, *args[2:]) + return args, kwargs + + def db_update(func): """ 数据库更新类操作装饰器,第一个参数必须是数据库会话或存在db参数 @@ -128,14 +216,14 @@ def db_update(func): # 是否关闭数据库会话 _close_db = False # 从参数中获取数据库会话 - db = get_args_db(args, kwargs) + db = _get_args_db(args, kwargs) if not db: # 如果没有获取到数据库会话,创建一个 db = ScopedSession() # 标记需要关闭数据库会话 _close_db = True # 更新参数中的数据库会话 - args, kwargs = update_args_db(args, kwargs, db) + args, kwargs = _update_args_db(args, kwargs, db) try: # 执行函数 result = func(*args, **kwargs) @@ -154,6 +242,41 @@ def db_update(func): return wrapper +def async_db_update(func): + """ + 异步数据库更新类操作装饰器,第一个参数必须是异步数据库会话或存在db参数 + """ + + async def wrapper(*args, **kwargs): + # 是否关闭数据库会话 + _close_db = False + # 从参数中获取异步数据库会话 + db = _get_args_async_db(args, kwargs) + if not db: + # 如果没有获取到异步数据库会话,创建一个 + db = AsyncSessionFactory() + # 标记需要关闭数据库会话 + _close_db = True + # 更新参数中的异步数据库会话 + args, kwargs = _update_args_async_db(args, kwargs, db) + try: + # 执行函数 + result = await func(*args, **kwargs) + # 提交事务 + await db.commit() + except Exception as err: + # 回滚事务 + await db.rollback() + raise err + finally: + # 关闭数据库会话 + if _close_db: + await db.close() + return result + + return wrapper + + def db_query(func): """ 数据库查询操作装饰器,第一个参数必须是数据库会话或存在db参数 @@ -164,14 +287,14 @@ def db_query(func): # 是否关闭数据库会话 _close_db = False # 从参数中获取数据库会话 - db = get_args_db(args, kwargs) + db = _get_args_db(args, kwargs) if not db: # 如果没有获取到数据库会话,创建一个 db = ScopedSession() # 标记需要关闭数据库会话 _close_db = True # 更新参数中的数据库会话 - args, kwargs = update_args_db(args, kwargs, db) + args, kwargs = _update_args_db(args, kwargs, db) try: # 执行函数 result = func(*args, **kwargs) @@ -186,6 +309,38 @@ def db_query(func): return wrapper +def async_db_query(func): + """ + 异步数据库查询操作装饰器,第一个参数必须是异步数据库会话或存在db参数 + 注意:db.query列表数据时,需要转换为list返回 + """ + + async def wrapper(*args, **kwargs): + # 是否关闭数据库会话 + _close_db = False + # 从参数中获取异步数据库会话 + db = _get_args_async_db(args, kwargs) + if not db: + # 如果没有获取到异步数据库会话,创建一个 + db = AsyncSessionFactory() + # 标记需要关闭数据库会话 + _close_db = True + # 更新参数中的异步数据库会话 + args, kwargs = _update_args_async_db(args, kwargs, db) + try: + # 执行函数 + result = await func(*args, **kwargs) + except Exception as err: + raise err + finally: + # 关闭数据库会话 + if _close_db: + await db.close() + return result + + return wrapper + + @as_declarative() class Base: id: Any @@ -195,11 +350,22 @@ class Base: def create(self, db: Session): db.add(self) + @async_db_update + async def async_create(self, db: AsyncSession): + db.add(self) + @classmethod @db_query def get(cls, db: Session, rid: int) -> Self: return db.query(cls).filter(and_(cls.id == rid)).first() + @classmethod + @async_db_query + async def async_get(cls, db: AsyncSession, rid: int) -> Self: + from sqlalchemy import select + result = await db.execute(select(cls).where(and_(cls.id == rid))) + return result.scalars().first() + @db_update def update(self, db: Session, payload: dict): payload = {k: v for k, v in payload.items() if v is not None} @@ -208,23 +374,53 @@ class Base: if inspect(self).detached: db.add(self) + @async_db_update + async def async_update(self, db: AsyncSession, payload: dict): + payload = {k: v for k, v in payload.items() if v is not None} + for key, value in payload.items(): + setattr(self, key, value) + if inspect(self).detached: + db.add(self) + @classmethod @db_update def delete(cls, db: Session, rid): db.query(cls).filter(and_(cls.id == rid)).delete() + @classmethod + @async_db_update + async def async_delete(cls, db: AsyncSession, rid): + from sqlalchemy import select + result = await db.execute(select(cls).where(and_(cls.id == rid))) + user = result.scalars().first() + if user: + await db.delete(user) + @classmethod @db_update def truncate(cls, db: Session): db.query(cls).delete() + @classmethod + @async_db_update + async def async_truncate(cls, db: AsyncSession): + from sqlalchemy import delete + await db.execute(delete(cls)) + @classmethod @db_query def list(cls, db: Session) -> List[Self]: return db.query(cls).all() + @classmethod + @async_db_query + async def async_list(cls, db: AsyncSession) -> Sequence[Self]: + from sqlalchemy import select + result = await db.execute(select(cls)) + return result.scalars().all() + def to_dict(self): - return {c.name: getattr(self, c.name, None) for c in self.__table__.columns} # noqa + return {c.name: getattr(self, c.name, None) for c in self.__table__.columns} # noqa @declared_attr def __tablename__(self) -> str: @@ -238,3 +434,12 @@ class DbOper: def __init__(self, db: Session = None): self._db = db + + +class AsyncDbOper: + """ + 异步数据库操作基类 + """ + + def __init__(self, db: AsyncSession = None): + self._db = db diff --git a/app/db/async_adapter.py b/app/db/async_adapter.py deleted file mode 100644 index 6ce86d35..00000000 --- a/app/db/async_adapter.py +++ /dev/null @@ -1,136 +0,0 @@ -import asyncio -from concurrent.futures import ThreadPoolExecutor -from functools import wraps, partial -from typing import Any, Callable, Coroutine, TypeVar - -from sqlalchemy.orm import Session - -from app.db import ScopedSession - -T = TypeVar('T') - -# 全局线程池,用于执行同步数据库操作 -_db_executor = ThreadPoolExecutor(max_workers=10, thread_name_prefix="db_async") - - -def async_db_operation(func: Callable[..., T]) -> Callable[..., Coroutine[Any, Any, T]]: - """ - 将同步数据库操作转换为异步操作的装饰器 - """ - - @wraps(func) - async def wrapper(*args, **kwargs) -> T: - # 在线程池中执行同步数据库操作 - loop = asyncio.get_event_loop() - partial_func = partial(func, *args, **kwargs) - return await loop.run_in_executor(_db_executor, partial_func, ()) - - return wrapper - - -def async_db_session(func: Callable[..., T]) -> Callable[..., Coroutine[Any, Any, T]]: - """ - 为异步操作提供数据库会话的装饰器 - """ - - @wraps(func) - async def wrapper(*args, **kwargs) -> T: - # 创建数据库会话 - db = ScopedSession() - try: - # 将数据库会话添加到参数中 - if 'db' not in kwargs: - kwargs['db'] = db - # 在线程池中执行同步数据库操作 - loop = asyncio.get_event_loop() - partial_func = partial(func, *args, **kwargs) - return await loop.run_in_executor(_db_executor, partial_func, ()) - finally: - db.close() - - return wrapper - - -class AsyncDbOper: - """ - 异步数据库操作基类 - """ - - def __init__(self, db: Session = None): - self._db = db - - async def _get_db(self) -> Session: - """ - 获取数据库会话 - """ - if self._db: - return self._db - return ScopedSession() - - @staticmethod - async def _execute_sync(func: Callable[..., T], *args, **kwargs) -> T: - """ - 在线程池中执行同步数据库操作 - """ - loop = asyncio.get_event_loop() - partial_func = partial(func, *args, **kwargs) - return await loop.run_in_executor(_db_executor, partial_func, ()) - - -def to_async_db_oper(sync_oper_class): - """ - 将同步数据库操作类转换为异步版本的装饰器 - """ - - class AsyncOperClass(AsyncDbOper): - def __init__(self, db: Session = None): - super().__init__(db) - self._sync_oper = sync_oper_class(db) - - def __getattr__(self, name): - """动态获取同步操作类的方法并转换为异步""" - if hasattr(self._sync_oper, name): - method = getattr(self._sync_oper, name) - if callable(method): - return async_db_operation(method) - raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'") - - return AsyncOperClass - - -# 异步数据库会话获取函数 -async def get_async_db(): - """ - 异步获取数据库会话 - """ - - def _get_db(): - return ScopedSession() - - loop = asyncio.get_event_loop() - return await loop.run_in_executor(_db_executor, _get_db) # type: ignore - - -# 异步上下文管理器 -class AsyncDbSession: - """ - 异步数据库会话上下文管理器 - """ - - def __init__(self): - self.db = None - - async def __aenter__(self): - self.db = await get_async_db() - return self.db - - async def __aexit__(self, exc_type, exc_val, exc_tb): - if self.db: - self.db.close() - - -def shutdown_db_executor(): - """关闭数据库线程池""" - global _db_executor - if _db_executor: - _db_executor.shutdown(wait=True) diff --git a/app/db/models/siteicon.py b/app/db/models/siteicon.py index 770cd37e..024378b9 100644 --- a/app/db/models/siteicon.py +++ b/app/db/models/siteicon.py @@ -1,7 +1,8 @@ -from sqlalchemy import Column, Integer, String, Sequence +from sqlalchemy import Column, Integer, String, Sequence, select +from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Session -from app.db import db_query, Base +from app.db import db_query, Base, async_db_query class SiteIcon(Base): @@ -22,3 +23,9 @@ class SiteIcon(Base): @db_query def get_by_domain(db: Session, domain: str): return db.query(SiteIcon).filter(SiteIcon.domain == domain).first() + + @classmethod + @async_db_query + async def async_get_by_domain(cls, db: AsyncSession, domain: str): + result = await db.execute(select(cls).where(cls.domain == domain)) + return result.scalar_one_or_none() diff --git a/app/startup/lifecycle.py b/app/startup/lifecycle.py index 04ac2a78..94561002 100644 --- a/app/startup/lifecycle.py +++ b/app/startup/lifecycle.py @@ -4,7 +4,6 @@ from contextlib import asynccontextmanager from fastapi import FastAPI from app.chain.system import SystemChain -from app.db.async_adapter import shutdown_db_executor from app.helper.system import SystemHelper from app.startup.command_initializer import init_command, stop_command, restart_command from app.startup.modules_initializer import init_modules, stop_modules @@ -80,6 +79,4 @@ async def lifespan(app: FastAPI): # 停止插件 stop_plugins() # 停止模块 - stop_modules() - # 关闭数据库异步执行器 - shutdown_db_executor() + await stop_modules() diff --git a/app/startup/modules_initializer.py b/app/startup/modules_initializer.py index 9f6301e5..496fd7f7 100644 --- a/app/startup/modules_initializer.py +++ b/app/startup/modules_initializer.py @@ -1,11 +1,11 @@ import sys +from app.command import CommandChain from app.core.cache import close_cache from app.core.config import settings from app.core.module import ModuleManager from app.log import logger from app.utils.system import SystemUtils -from app.command import CommandChain # SitesHelper涉及资源包拉取,提前引入并容错提示 try: @@ -105,7 +105,7 @@ def check_auth(): ) -def stop_modules(): +async def stop_modules(): """ 服务关闭 """ @@ -120,7 +120,7 @@ def stop_modules(): # 停止缓存连接 close_cache() # 停止数据库连接 - close_database() + await close_database() # 停止前端服务 stop_frontend() # 清理临时文件 diff --git a/docs/async-database-guide.md b/docs/async-database-guide.md deleted file mode 100644 index 947690ad..00000000 --- a/docs/async-database-guide.md +++ /dev/null @@ -1,171 +0,0 @@ -# 异步数据库操作指南 - -## 概述 - -本指南介绍如何在MoviePilot项目中实现异步数据库操作,而无需重写现有的同步数据库模块。 - -## 方案优势 - -1. **最小化改动**: 不需要重写现有的同步数据库操作代码 -2. **渐进式迁移**: 可以逐步将需要异步的操作迁移到异步版本 -3. **向后兼容**: 现有的同步代码继续正常工作 -4. **性能提升**: 通过线程池执行同步数据库操作,避免阻塞事件循环 - -## 核心组件 - -### 1. 异步适配器 (`app/db/async_adapter.py`) - -提供以下功能: -- `async_db_operation`: 将同步函数包装为异步函数 -- `async_db_session`: 为异步操作提供数据库会话 -- `AsyncDbOper`: 异步数据库操作基类 -- `to_async_db_oper`: 将同步操作类转换为异步版本 -- `AsyncDbSession`: 异步数据库会话上下文管理器 - -### 2. 异步操作类 (`app/db/async_user_oper.py`) - -展示如何创建异步版本的数据库操作类。 - -## 使用方法 - -### 方法1: 使用装饰器自动转换 - -```python -from app.db.user_oper import UserOper -from app.db.async_adapter import to_async_db_oper - -# 自动将同步类转换为异步版本 -AsyncUserOper = to_async_db_oper(UserOper) - -# 使用异步版本 -async def example(): - async_user_oper = AsyncUserOper() - users = await async_user_oper.list() - return users -``` - -### 方法2: 手动创建异步方法 - -```python -from app.db.async_adapter import async_db_operation - -class AsyncUserOperManual: - def __init__(self, db: Session = None): - self._sync_oper = UserOper(db) - - @async_db_operation - def list(self) -> List[User]: - return self._sync_oper.list() - - @async_db_operation - def add(self, **kwargs): - return self._sync_oper.add(**kwargs) -``` - -### 方法3: 使用异步上下文管理器 - -```python -from app.db.async_adapter import AsyncDbSession - -async def example(): - async with AsyncDbSession() as db: - # 直接使用同步操作,但在线程池中执行 - from functools import partial - users = await asyncio.get_event_loop().run_in_executor( - None, partial(lambda db: db.query(User).all(), db) - ) - return users -``` - -### 方法4: 装饰器包装单个函数 - -```python -from app.db.async_adapter import async_db_operation - -@async_db_operation -def get_user_by_name_sync(db: Session, name: str) -> User: - return db.query(User).filter(User.name == name).first() - -async def example(): - async with AsyncDbSession() as db: - user = await get_user_by_name_sync(db, "admin") - return user -``` - -## 在FastAPI中使用 - -### 依赖注入 - -```python -async def get_current_user_async_dependency( - token_data: schemas.TokenPayload = Depends(verify_token) -) -> User: - async with AsyncDbSession() as db: - from functools import partial - user = await asyncio.get_event_loop().run_in_executor( - None, partial(User.get, db, rid=token_data.sub) - ) - if not user: - raise HTTPException(status_code=403, detail="用户不存在") - return user - -@router.get("/me/async") -async def get_current_user_async( - current_user: User = Depends(get_current_user_async_dependency) -): - return current_user -``` - -### API端点 - -```python -@router.get("/users/async") -async def get_users_async(): - async_user_oper = AsyncUserOper() - users = await async_user_oper.list() - return users -``` - -## 并发操作 - -### 批量操作 - -```python -@router.post("/users/batch-async") -async def create_users_batch_async(users_data: List[schemas.UserCreate]): - async_user_oper = AsyncUserOper() - - # 并发创建用户 - tasks = [] - for user_data in users_data: - task = async_user_oper.add(**user_data.dict()) - tasks.append(task) - - # 等待所有任务完成 - await asyncio.gather(*tasks) - - # 获取创建的用户列表 - all_users = await async_user_oper.list() - return all_users[-len(users_data):] -``` - -### 复杂查询 - -```python -@router.get("/users/stats/async") -async def get_user_stats_async(): - async_user_oper = AsyncUserOper() - - # 并发执行多个数据库查询 - users_task = async_user_oper.list() - active_users_task = async_user_oper.get_active_users() - - # 等待所有查询完成 - users, active_users = await asyncio.gather(users_task, active_users_task) - - return { - "total_users": len(users), - "active_users": len(active_users), - "inactive_users": len(users) - len(active_users) - } -``` \ No newline at end of file diff --git a/requirements.in b/requirements.in index 177d3d27..34944860 100644 --- a/requirements.in +++ b/requirements.in @@ -60,6 +60,7 @@ Pinyin2Hanzi~=0.1.1 pywebpush~=2.0.3 python-cookietools==0.0.4 aiofiles~=24.1.0 +aiosqlite~=0.21.0 jieba~=0.42.1 rsa~=4.9 redis~=6.2.0