mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-03-20 03:57:30 +08:00
添加异步数据库支持,更新相关模型和会话管理
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
from typing import List, Any, Dict, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import Session
|
||||
from starlette.background import BackgroundTasks
|
||||
|
||||
@@ -12,7 +13,7 @@ from app.command import Command
|
||||
from app.core.event import EventManager
|
||||
from app.core.plugin import PluginManager
|
||||
from app.core.security import verify_token
|
||||
from app.db import get_db
|
||||
from app.db import get_db, get_async_db
|
||||
from app.db.models import User
|
||||
from app.db.models.site import Site
|
||||
from app.db.models.siteicon import SiteIcon
|
||||
@@ -242,19 +243,19 @@ def test_site(site_id: int,
|
||||
|
||||
|
||||
@router.get("/icon/{site_id}", summary="站点图标", response_model=schemas.Response)
|
||||
def site_icon(site_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
async def site_icon(site_id: int,
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
获取站点图标:base64或者url
|
||||
"""
|
||||
site = Site.get(db, site_id)
|
||||
site = await Site.async_get(db, site_id)
|
||||
if not site:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"站点 {site_id} 不存在",
|
||||
)
|
||||
icon = SiteIcon.get_by_domain(db, site.domain)
|
||||
icon = await SiteIcon.async_get_by_domain(db, site.domain)
|
||||
if not icon:
|
||||
return schemas.Response(success=False, message="站点图标不存在!")
|
||||
return schemas.Response(success=True, data={
|
||||
|
||||
@@ -1,45 +1,106 @@
|
||||
from typing import Any, Generator, List, Optional, Self, Tuple
|
||||
import asyncio
|
||||
from typing import Any, Generator, List, Optional, Self, Tuple, AsyncGenerator, Sequence
|
||||
|
||||
from sqlalchemy import NullPool, QueuePool, and_, create_engine, inspect, text
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
|
||||
from sqlalchemy.orm import Session, as_declarative, declared_attr, scoped_session, sessionmaker
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
# 根据池类型设置 poolclass 和相关参数
|
||||
pool_class = NullPool if settings.DB_POOL_TYPE == "NullPool" else QueuePool
|
||||
connect_args = {
|
||||
"timeout": settings.DB_TIMEOUT
|
||||
|
||||
def _get_database_engine(is_async: bool = False):
|
||||
"""
|
||||
获取数据库连接参数并设置WAL模式
|
||||
:param is_async: 是否创建异步引擎,True - 异步引擎, False - 同步引擎
|
||||
:return: 返回对应的数据库引擎
|
||||
"""
|
||||
# 连接参数
|
||||
_connect_args = {
|
||||
"timeout": settings.DB_TIMEOUT,
|
||||
}
|
||||
# 启用 WAL 模式时的额外配置
|
||||
if settings.DB_WAL_ENABLE:
|
||||
connect_args["check_same_thread"] = False
|
||||
db_kwargs = {
|
||||
_connect_args["check_same_thread"] = False
|
||||
|
||||
# 创建同步引擎
|
||||
if not is_async:
|
||||
# 根据池类型设置 poolclass 和相关参数
|
||||
_pool_class = NullPool if settings.DB_POOL_TYPE == "NullPool" else QueuePool
|
||||
|
||||
# 数据库参数
|
||||
_db_kwargs = {
|
||||
"url": f"sqlite:///{settings.CONFIG_PATH}/user.db",
|
||||
"pool_pre_ping": settings.DB_POOL_PRE_PING,
|
||||
"echo": settings.DB_ECHO,
|
||||
"poolclass": pool_class,
|
||||
"poolclass": _pool_class,
|
||||
"pool_recycle": settings.DB_POOL_RECYCLE,
|
||||
"connect_args": connect_args
|
||||
"connect_args": _connect_args
|
||||
}
|
||||
|
||||
# 当使用 QueuePool 时,添加 QueuePool 特有的参数
|
||||
if pool_class == QueuePool:
|
||||
db_kwargs.update({
|
||||
if _pool_class == QueuePool:
|
||||
_db_kwargs.update({
|
||||
"pool_size": settings.CONF.dbpool,
|
||||
"pool_timeout": settings.DB_POOL_TIMEOUT,
|
||||
"max_overflow": settings.CONF.dbpooloverflow
|
||||
})
|
||||
|
||||
# 创建数据库引擎
|
||||
Engine = create_engine(**db_kwargs)
|
||||
# 根据配置设置日志模式
|
||||
journal_mode = "WAL" if settings.DB_WAL_ENABLE else "DELETE"
|
||||
with Engine.connect() as connection:
|
||||
current_mode = connection.execute(text(f"PRAGMA journal_mode={journal_mode};")).scalar()
|
||||
engine = create_engine(**_db_kwargs)
|
||||
|
||||
# 设置WAL模式
|
||||
_journal_mode = "WAL" if settings.DB_WAL_ENABLE else "DELETE"
|
||||
with engine.connect() as connection:
|
||||
current_mode = connection.execute(text(f"PRAGMA journal_mode={_journal_mode};")).scalar()
|
||||
print(f"Database journal mode set to: {current_mode}")
|
||||
|
||||
# 会话工厂
|
||||
return engine
|
||||
else:
|
||||
# 数据库参数,只能使用 NullPool
|
||||
_db_kwargs = {
|
||||
"url": f"sqlite+aiosqlite:///{settings.CONFIG_PATH}/user.db",
|
||||
"pool_pre_ping": settings.DB_POOL_PRE_PING,
|
||||
"echo": settings.DB_ECHO,
|
||||
"poolclass": NullPool,
|
||||
"pool_recycle": settings.DB_POOL_RECYCLE,
|
||||
"connect_args": _connect_args
|
||||
}
|
||||
# 创建异步数据库引擎
|
||||
async_engine = create_async_engine(**_db_kwargs)
|
||||
|
||||
# 设置WAL模式
|
||||
_journal_mode = "WAL" if settings.DB_WAL_ENABLE else "DELETE"
|
||||
|
||||
async def set_async_wal_mode():
|
||||
"""
|
||||
设置异步引擎的WAL模式
|
||||
"""
|
||||
async with async_engine.connect() as _connection:
|
||||
result = await _connection.execute(text(f"PRAGMA journal_mode={_journal_mode};"))
|
||||
_current_mode = result.scalar()
|
||||
print(f"Async database journal mode set to: {_current_mode}")
|
||||
|
||||
try:
|
||||
asyncio.run(set_async_wal_mode())
|
||||
except Exception as e:
|
||||
print(f"Failed to set async WAL mode: {e}")
|
||||
|
||||
return async_engine
|
||||
|
||||
|
||||
# 同步数据库引擎
|
||||
Engine = _get_database_engine(is_async=False)
|
||||
|
||||
# 异步数据库引擎
|
||||
AsyncEngine = _get_database_engine(is_async=True)
|
||||
|
||||
# 同步会话工厂
|
||||
SessionFactory = sessionmaker(bind=Engine)
|
||||
|
||||
# 多线程全局使用的数据库会话
|
||||
# 异步会话工厂
|
||||
AsyncSessionFactory = async_sessionmaker(bind=AsyncEngine, class_=AsyncSession)
|
||||
|
||||
# 同步多线程全局使用的数据库会话
|
||||
ScopedSession = scoped_session(SessionFactory)
|
||||
|
||||
|
||||
@@ -57,37 +118,32 @@ def get_db() -> Generator:
|
||||
db.close()
|
||||
|
||||
|
||||
def perform_checkpoint(mode: str = "PASSIVE"):
|
||||
async def get_async_db() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""
|
||||
执行 SQLite 的 checkpoint 操作,将 WAL 文件内容写回主数据库
|
||||
:param mode: checkpoint 模式,可选值包括 "PASSIVE"、"FULL"、"RESTART"、"TRUNCATE"
|
||||
默认为 "PASSIVE",即不锁定 WAL 文件的轻量级同步
|
||||
获取异步数据库会话,用于WEB请求
|
||||
:return: AsyncSession
|
||||
"""
|
||||
if not settings.DB_WAL_ENABLE:
|
||||
return
|
||||
valid_modes = {"PASSIVE", "FULL", "RESTART", "TRUNCATE"}
|
||||
if mode.upper() not in valid_modes:
|
||||
raise ValueError(f"Invalid checkpoint mode '{mode}'. Must be one of {valid_modes}")
|
||||
async with AsyncSessionFactory() as session:
|
||||
try:
|
||||
# 使用指定的 checkpoint 模式,确保 WAL 文件数据被正确写回主数据库
|
||||
with Engine.connect() as conn:
|
||||
conn.execute(text(f"PRAGMA wal_checkpoint({mode.upper()});"))
|
||||
except Exception as e:
|
||||
print(f"Error during WAL checkpoint: {e}")
|
||||
yield session
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
|
||||
def close_database():
|
||||
async def close_database():
|
||||
"""
|
||||
关闭所有数据库连接并清理资源
|
||||
"""
|
||||
try:
|
||||
# 释放连接池,SQLite 会自动清空 WAL 文件,这里不单独再调用 checkpoint
|
||||
Engine.dispose()
|
||||
except Exception as e:
|
||||
print(f"Error while disposing database connections: {e}")
|
||||
# 释放同步连接池
|
||||
Engine.dispose() # noqa
|
||||
# 释放异步连接池
|
||||
await AsyncEngine.dispose()
|
||||
except Exception as err:
|
||||
print(f"Error while disposing database connections: {err}")
|
||||
|
||||
|
||||
def get_args_db(args: tuple, kwargs: dict) -> Optional[Session]:
|
||||
def _get_args_db(args: tuple, kwargs: dict) -> Optional[Session]:
|
||||
"""
|
||||
从参数中获取数据库Session对象
|
||||
"""
|
||||
@@ -105,7 +161,25 @@ def get_args_db(args: tuple, kwargs: dict) -> Optional[Session]:
|
||||
return db
|
||||
|
||||
|
||||
def update_args_db(args: tuple, kwargs: dict, db: Session) -> Tuple[tuple, dict]:
|
||||
def _get_args_async_db(args: tuple, kwargs: dict) -> Optional[AsyncSession]:
|
||||
"""
|
||||
从参数中获取异步数据库AsyncSession对象
|
||||
"""
|
||||
db = None
|
||||
if args:
|
||||
for arg in args:
|
||||
if isinstance(arg, AsyncSession):
|
||||
db = arg
|
||||
break
|
||||
if kwargs:
|
||||
for key, value in kwargs.items():
|
||||
if isinstance(value, AsyncSession):
|
||||
db = value
|
||||
break
|
||||
return db
|
||||
|
||||
|
||||
def _update_args_db(args: tuple, kwargs: dict, db: Session) -> Tuple[tuple, dict]:
|
||||
"""
|
||||
更新参数中的数据库Session对象,关键字传参时更新db的值,否则更新第1或第2个参数
|
||||
"""
|
||||
@@ -119,6 +193,20 @@ def update_args_db(args: tuple, kwargs: dict, db: Session) -> Tuple[tuple, dict]
|
||||
return args, kwargs
|
||||
|
||||
|
||||
def _update_args_async_db(args: tuple, kwargs: dict, db: AsyncSession) -> Tuple[tuple, dict]:
|
||||
"""
|
||||
更新参数中的异步数据库AsyncSession对象,关键字传参时更新db的值,否则更新第1或第2个参数
|
||||
"""
|
||||
if kwargs and 'db' in kwargs:
|
||||
kwargs['db'] = db
|
||||
elif args:
|
||||
if args[0] is None:
|
||||
args = (db, *args[1:])
|
||||
else:
|
||||
args = (args[0], db, *args[2:])
|
||||
return args, kwargs
|
||||
|
||||
|
||||
def db_update(func):
|
||||
"""
|
||||
数据库更新类操作装饰器,第一个参数必须是数据库会话或存在db参数
|
||||
@@ -128,14 +216,14 @@ def db_update(func):
|
||||
# 是否关闭数据库会话
|
||||
_close_db = False
|
||||
# 从参数中获取数据库会话
|
||||
db = get_args_db(args, kwargs)
|
||||
db = _get_args_db(args, kwargs)
|
||||
if not db:
|
||||
# 如果没有获取到数据库会话,创建一个
|
||||
db = ScopedSession()
|
||||
# 标记需要关闭数据库会话
|
||||
_close_db = True
|
||||
# 更新参数中的数据库会话
|
||||
args, kwargs = update_args_db(args, kwargs, db)
|
||||
args, kwargs = _update_args_db(args, kwargs, db)
|
||||
try:
|
||||
# 执行函数
|
||||
result = func(*args, **kwargs)
|
||||
@@ -154,6 +242,41 @@ def db_update(func):
|
||||
return wrapper
|
||||
|
||||
|
||||
def async_db_update(func):
|
||||
"""
|
||||
异步数据库更新类操作装饰器,第一个参数必须是异步数据库会话或存在db参数
|
||||
"""
|
||||
|
||||
async def wrapper(*args, **kwargs):
|
||||
# 是否关闭数据库会话
|
||||
_close_db = False
|
||||
# 从参数中获取异步数据库会话
|
||||
db = _get_args_async_db(args, kwargs)
|
||||
if not db:
|
||||
# 如果没有获取到异步数据库会话,创建一个
|
||||
db = AsyncSessionFactory()
|
||||
# 标记需要关闭数据库会话
|
||||
_close_db = True
|
||||
# 更新参数中的异步数据库会话
|
||||
args, kwargs = _update_args_async_db(args, kwargs, db)
|
||||
try:
|
||||
# 执行函数
|
||||
result = await func(*args, **kwargs)
|
||||
# 提交事务
|
||||
await db.commit()
|
||||
except Exception as err:
|
||||
# 回滚事务
|
||||
await db.rollback()
|
||||
raise err
|
||||
finally:
|
||||
# 关闭数据库会话
|
||||
if _close_db:
|
||||
await db.close()
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def db_query(func):
|
||||
"""
|
||||
数据库查询操作装饰器,第一个参数必须是数据库会话或存在db参数
|
||||
@@ -164,14 +287,14 @@ def db_query(func):
|
||||
# 是否关闭数据库会话
|
||||
_close_db = False
|
||||
# 从参数中获取数据库会话
|
||||
db = get_args_db(args, kwargs)
|
||||
db = _get_args_db(args, kwargs)
|
||||
if not db:
|
||||
# 如果没有获取到数据库会话,创建一个
|
||||
db = ScopedSession()
|
||||
# 标记需要关闭数据库会话
|
||||
_close_db = True
|
||||
# 更新参数中的数据库会话
|
||||
args, kwargs = update_args_db(args, kwargs, db)
|
||||
args, kwargs = _update_args_db(args, kwargs, db)
|
||||
try:
|
||||
# 执行函数
|
||||
result = func(*args, **kwargs)
|
||||
@@ -186,6 +309,38 @@ def db_query(func):
|
||||
return wrapper
|
||||
|
||||
|
||||
def async_db_query(func):
|
||||
"""
|
||||
异步数据库查询操作装饰器,第一个参数必须是异步数据库会话或存在db参数
|
||||
注意:db.query列表数据时,需要转换为list返回
|
||||
"""
|
||||
|
||||
async def wrapper(*args, **kwargs):
|
||||
# 是否关闭数据库会话
|
||||
_close_db = False
|
||||
# 从参数中获取异步数据库会话
|
||||
db = _get_args_async_db(args, kwargs)
|
||||
if not db:
|
||||
# 如果没有获取到异步数据库会话,创建一个
|
||||
db = AsyncSessionFactory()
|
||||
# 标记需要关闭数据库会话
|
||||
_close_db = True
|
||||
# 更新参数中的异步数据库会话
|
||||
args, kwargs = _update_args_async_db(args, kwargs, db)
|
||||
try:
|
||||
# 执行函数
|
||||
result = await func(*args, **kwargs)
|
||||
except Exception as err:
|
||||
raise err
|
||||
finally:
|
||||
# 关闭数据库会话
|
||||
if _close_db:
|
||||
await db.close()
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@as_declarative()
|
||||
class Base:
|
||||
id: Any
|
||||
@@ -195,11 +350,22 @@ class Base:
|
||||
def create(self, db: Session):
|
||||
db.add(self)
|
||||
|
||||
@async_db_update
|
||||
async def async_create(self, db: AsyncSession):
|
||||
db.add(self)
|
||||
|
||||
@classmethod
|
||||
@db_query
|
||||
def get(cls, db: Session, rid: int) -> Self:
|
||||
return db.query(cls).filter(and_(cls.id == rid)).first()
|
||||
|
||||
@classmethod
|
||||
@async_db_query
|
||||
async def async_get(cls, db: AsyncSession, rid: int) -> Self:
|
||||
from sqlalchemy import select
|
||||
result = await db.execute(select(cls).where(and_(cls.id == rid)))
|
||||
return result.scalars().first()
|
||||
|
||||
@db_update
|
||||
def update(self, db: Session, payload: dict):
|
||||
payload = {k: v for k, v in payload.items() if v is not None}
|
||||
@@ -208,21 +374,51 @@ class Base:
|
||||
if inspect(self).detached:
|
||||
db.add(self)
|
||||
|
||||
@async_db_update
|
||||
async def async_update(self, db: AsyncSession, payload: dict):
|
||||
payload = {k: v for k, v in payload.items() if v is not None}
|
||||
for key, value in payload.items():
|
||||
setattr(self, key, value)
|
||||
if inspect(self).detached:
|
||||
db.add(self)
|
||||
|
||||
@classmethod
|
||||
@db_update
|
||||
def delete(cls, db: Session, rid):
|
||||
db.query(cls).filter(and_(cls.id == rid)).delete()
|
||||
|
||||
@classmethod
|
||||
@async_db_update
|
||||
async def async_delete(cls, db: AsyncSession, rid):
|
||||
from sqlalchemy import select
|
||||
result = await db.execute(select(cls).where(and_(cls.id == rid)))
|
||||
user = result.scalars().first()
|
||||
if user:
|
||||
await db.delete(user)
|
||||
|
||||
@classmethod
|
||||
@db_update
|
||||
def truncate(cls, db: Session):
|
||||
db.query(cls).delete()
|
||||
|
||||
@classmethod
|
||||
@async_db_update
|
||||
async def async_truncate(cls, db: AsyncSession):
|
||||
from sqlalchemy import delete
|
||||
await db.execute(delete(cls))
|
||||
|
||||
@classmethod
|
||||
@db_query
|
||||
def list(cls, db: Session) -> List[Self]:
|
||||
return db.query(cls).all()
|
||||
|
||||
@classmethod
|
||||
@async_db_query
|
||||
async def async_list(cls, db: AsyncSession) -> Sequence[Self]:
|
||||
from sqlalchemy import select
|
||||
result = await db.execute(select(cls))
|
||||
return result.scalars().all()
|
||||
|
||||
def to_dict(self):
|
||||
return {c.name: getattr(self, c.name, None) for c in self.__table__.columns} # noqa
|
||||
|
||||
@@ -238,3 +434,12 @@ class DbOper:
|
||||
|
||||
def __init__(self, db: Session = None):
|
||||
self._db = db
|
||||
|
||||
|
||||
class AsyncDbOper:
|
||||
"""
|
||||
异步数据库操作基类
|
||||
"""
|
||||
|
||||
def __init__(self, db: AsyncSession = None):
|
||||
self._db = db
|
||||
|
||||
@@ -1,136 +0,0 @@
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from functools import wraps, partial
|
||||
from typing import Any, Callable, Coroutine, TypeVar
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.db import ScopedSession
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
# 全局线程池,用于执行同步数据库操作
|
||||
_db_executor = ThreadPoolExecutor(max_workers=10, thread_name_prefix="db_async")
|
||||
|
||||
|
||||
def async_db_operation(func: Callable[..., T]) -> Callable[..., Coroutine[Any, Any, T]]:
|
||||
"""
|
||||
将同步数据库操作转换为异步操作的装饰器
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs) -> T:
|
||||
# 在线程池中执行同步数据库操作
|
||||
loop = asyncio.get_event_loop()
|
||||
partial_func = partial(func, *args, **kwargs)
|
||||
return await loop.run_in_executor(_db_executor, partial_func, ())
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def async_db_session(func: Callable[..., T]) -> Callable[..., Coroutine[Any, Any, T]]:
|
||||
"""
|
||||
为异步操作提供数据库会话的装饰器
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs) -> T:
|
||||
# 创建数据库会话
|
||||
db = ScopedSession()
|
||||
try:
|
||||
# 将数据库会话添加到参数中
|
||||
if 'db' not in kwargs:
|
||||
kwargs['db'] = db
|
||||
# 在线程池中执行同步数据库操作
|
||||
loop = asyncio.get_event_loop()
|
||||
partial_func = partial(func, *args, **kwargs)
|
||||
return await loop.run_in_executor(_db_executor, partial_func, ())
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class AsyncDbOper:
|
||||
"""
|
||||
异步数据库操作基类
|
||||
"""
|
||||
|
||||
def __init__(self, db: Session = None):
|
||||
self._db = db
|
||||
|
||||
async def _get_db(self) -> Session:
|
||||
"""
|
||||
获取数据库会话
|
||||
"""
|
||||
if self._db:
|
||||
return self._db
|
||||
return ScopedSession()
|
||||
|
||||
@staticmethod
|
||||
async def _execute_sync(func: Callable[..., T], *args, **kwargs) -> T:
|
||||
"""
|
||||
在线程池中执行同步数据库操作
|
||||
"""
|
||||
loop = asyncio.get_event_loop()
|
||||
partial_func = partial(func, *args, **kwargs)
|
||||
return await loop.run_in_executor(_db_executor, partial_func, ())
|
||||
|
||||
|
||||
def to_async_db_oper(sync_oper_class):
|
||||
"""
|
||||
将同步数据库操作类转换为异步版本的装饰器
|
||||
"""
|
||||
|
||||
class AsyncOperClass(AsyncDbOper):
|
||||
def __init__(self, db: Session = None):
|
||||
super().__init__(db)
|
||||
self._sync_oper = sync_oper_class(db)
|
||||
|
||||
def __getattr__(self, name):
|
||||
"""动态获取同步操作类的方法并转换为异步"""
|
||||
if hasattr(self._sync_oper, name):
|
||||
method = getattr(self._sync_oper, name)
|
||||
if callable(method):
|
||||
return async_db_operation(method)
|
||||
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
|
||||
|
||||
return AsyncOperClass
|
||||
|
||||
|
||||
# 异步数据库会话获取函数
|
||||
async def get_async_db():
|
||||
"""
|
||||
异步获取数据库会话
|
||||
"""
|
||||
|
||||
def _get_db():
|
||||
return ScopedSession()
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(_db_executor, _get_db) # type: ignore
|
||||
|
||||
|
||||
# 异步上下文管理器
|
||||
class AsyncDbSession:
|
||||
"""
|
||||
异步数据库会话上下文管理器
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.db = None
|
||||
|
||||
async def __aenter__(self):
|
||||
self.db = await get_async_db()
|
||||
return self.db
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
if self.db:
|
||||
self.db.close()
|
||||
|
||||
|
||||
def shutdown_db_executor():
|
||||
"""关闭数据库线程池"""
|
||||
global _db_executor
|
||||
if _db_executor:
|
||||
_db_executor.shutdown(wait=True)
|
||||
@@ -1,7 +1,8 @@
|
||||
from sqlalchemy import Column, Integer, String, Sequence
|
||||
from sqlalchemy import Column, Integer, String, Sequence, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.db import db_query, Base
|
||||
from app.db import db_query, Base, async_db_query
|
||||
|
||||
|
||||
class SiteIcon(Base):
|
||||
@@ -22,3 +23,9 @@ class SiteIcon(Base):
|
||||
@db_query
|
||||
def get_by_domain(db: Session, domain: str):
|
||||
return db.query(SiteIcon).filter(SiteIcon.domain == domain).first()
|
||||
|
||||
@classmethod
|
||||
@async_db_query
|
||||
async def async_get_by_domain(cls, db: AsyncSession, domain: str):
|
||||
result = await db.execute(select(cls).where(cls.domain == domain))
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
@@ -4,7 +4,6 @@ from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI
|
||||
|
||||
from app.chain.system import SystemChain
|
||||
from app.db.async_adapter import shutdown_db_executor
|
||||
from app.helper.system import SystemHelper
|
||||
from app.startup.command_initializer import init_command, stop_command, restart_command
|
||||
from app.startup.modules_initializer import init_modules, stop_modules
|
||||
@@ -80,6 +79,4 @@ async def lifespan(app: FastAPI):
|
||||
# 停止插件
|
||||
stop_plugins()
|
||||
# 停止模块
|
||||
stop_modules()
|
||||
# 关闭数据库异步执行器
|
||||
shutdown_db_executor()
|
||||
await stop_modules()
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
import sys
|
||||
|
||||
from app.command import CommandChain
|
||||
from app.core.cache import close_cache
|
||||
from app.core.config import settings
|
||||
from app.core.module import ModuleManager
|
||||
from app.log import logger
|
||||
from app.utils.system import SystemUtils
|
||||
from app.command import CommandChain
|
||||
|
||||
# SitesHelper涉及资源包拉取,提前引入并容错提示
|
||||
try:
|
||||
@@ -105,7 +105,7 @@ def check_auth():
|
||||
)
|
||||
|
||||
|
||||
def stop_modules():
|
||||
async def stop_modules():
|
||||
"""
|
||||
服务关闭
|
||||
"""
|
||||
@@ -120,7 +120,7 @@ def stop_modules():
|
||||
# 停止缓存连接
|
||||
close_cache()
|
||||
# 停止数据库连接
|
||||
close_database()
|
||||
await close_database()
|
||||
# 停止前端服务
|
||||
stop_frontend()
|
||||
# 清理临时文件
|
||||
|
||||
@@ -1,171 +0,0 @@
|
||||
# 异步数据库操作指南
|
||||
|
||||
## 概述
|
||||
|
||||
本指南介绍如何在MoviePilot项目中实现异步数据库操作,而无需重写现有的同步数据库模块。
|
||||
|
||||
## 方案优势
|
||||
|
||||
1. **最小化改动**: 不需要重写现有的同步数据库操作代码
|
||||
2. **渐进式迁移**: 可以逐步将需要异步的操作迁移到异步版本
|
||||
3. **向后兼容**: 现有的同步代码继续正常工作
|
||||
4. **性能提升**: 通过线程池执行同步数据库操作,避免阻塞事件循环
|
||||
|
||||
## 核心组件
|
||||
|
||||
### 1. 异步适配器 (`app/db/async_adapter.py`)
|
||||
|
||||
提供以下功能:
|
||||
- `async_db_operation`: 将同步函数包装为异步函数
|
||||
- `async_db_session`: 为异步操作提供数据库会话
|
||||
- `AsyncDbOper`: 异步数据库操作基类
|
||||
- `to_async_db_oper`: 将同步操作类转换为异步版本
|
||||
- `AsyncDbSession`: 异步数据库会话上下文管理器
|
||||
|
||||
### 2. 异步操作类 (`app/db/async_user_oper.py`)
|
||||
|
||||
展示如何创建异步版本的数据库操作类。
|
||||
|
||||
## 使用方法
|
||||
|
||||
### 方法1: 使用装饰器自动转换
|
||||
|
||||
```python
|
||||
from app.db.user_oper import UserOper
|
||||
from app.db.async_adapter import to_async_db_oper
|
||||
|
||||
# 自动将同步类转换为异步版本
|
||||
AsyncUserOper = to_async_db_oper(UserOper)
|
||||
|
||||
# 使用异步版本
|
||||
async def example():
|
||||
async_user_oper = AsyncUserOper()
|
||||
users = await async_user_oper.list()
|
||||
return users
|
||||
```
|
||||
|
||||
### 方法2: 手动创建异步方法
|
||||
|
||||
```python
|
||||
from app.db.async_adapter import async_db_operation
|
||||
|
||||
class AsyncUserOperManual:
|
||||
def __init__(self, db: Session = None):
|
||||
self._sync_oper = UserOper(db)
|
||||
|
||||
@async_db_operation
|
||||
def list(self) -> List[User]:
|
||||
return self._sync_oper.list()
|
||||
|
||||
@async_db_operation
|
||||
def add(self, **kwargs):
|
||||
return self._sync_oper.add(**kwargs)
|
||||
```
|
||||
|
||||
### 方法3: 使用异步上下文管理器
|
||||
|
||||
```python
|
||||
from app.db.async_adapter import AsyncDbSession
|
||||
|
||||
async def example():
|
||||
async with AsyncDbSession() as db:
|
||||
# 直接使用同步操作,但在线程池中执行
|
||||
from functools import partial
|
||||
users = await asyncio.get_event_loop().run_in_executor(
|
||||
None, partial(lambda db: db.query(User).all(), db)
|
||||
)
|
||||
return users
|
||||
```
|
||||
|
||||
### 方法4: 装饰器包装单个函数
|
||||
|
||||
```python
|
||||
from app.db.async_adapter import async_db_operation
|
||||
|
||||
@async_db_operation
|
||||
def get_user_by_name_sync(db: Session, name: str) -> User:
|
||||
return db.query(User).filter(User.name == name).first()
|
||||
|
||||
async def example():
|
||||
async with AsyncDbSession() as db:
|
||||
user = await get_user_by_name_sync(db, "admin")
|
||||
return user
|
||||
```
|
||||
|
||||
## 在FastAPI中使用
|
||||
|
||||
### 依赖注入
|
||||
|
||||
```python
|
||||
async def get_current_user_async_dependency(
|
||||
token_data: schemas.TokenPayload = Depends(verify_token)
|
||||
) -> User:
|
||||
async with AsyncDbSession() as db:
|
||||
from functools import partial
|
||||
user = await asyncio.get_event_loop().run_in_executor(
|
||||
None, partial(User.get, db, rid=token_data.sub)
|
||||
)
|
||||
if not user:
|
||||
raise HTTPException(status_code=403, detail="用户不存在")
|
||||
return user
|
||||
|
||||
@router.get("/me/async")
|
||||
async def get_current_user_async(
|
||||
current_user: User = Depends(get_current_user_async_dependency)
|
||||
):
|
||||
return current_user
|
||||
```
|
||||
|
||||
### API端点
|
||||
|
||||
```python
|
||||
@router.get("/users/async")
|
||||
async def get_users_async():
|
||||
async_user_oper = AsyncUserOper()
|
||||
users = await async_user_oper.list()
|
||||
return users
|
||||
```
|
||||
|
||||
## 并发操作
|
||||
|
||||
### 批量操作
|
||||
|
||||
```python
|
||||
@router.post("/users/batch-async")
|
||||
async def create_users_batch_async(users_data: List[schemas.UserCreate]):
|
||||
async_user_oper = AsyncUserOper()
|
||||
|
||||
# 并发创建用户
|
||||
tasks = []
|
||||
for user_data in users_data:
|
||||
task = async_user_oper.add(**user_data.dict())
|
||||
tasks.append(task)
|
||||
|
||||
# 等待所有任务完成
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
# 获取创建的用户列表
|
||||
all_users = await async_user_oper.list()
|
||||
return all_users[-len(users_data):]
|
||||
```
|
||||
|
||||
### 复杂查询
|
||||
|
||||
```python
|
||||
@router.get("/users/stats/async")
|
||||
async def get_user_stats_async():
|
||||
async_user_oper = AsyncUserOper()
|
||||
|
||||
# 并发执行多个数据库查询
|
||||
users_task = async_user_oper.list()
|
||||
active_users_task = async_user_oper.get_active_users()
|
||||
|
||||
# 等待所有查询完成
|
||||
users, active_users = await asyncio.gather(users_task, active_users_task)
|
||||
|
||||
return {
|
||||
"total_users": len(users),
|
||||
"active_users": len(active_users),
|
||||
"inactive_users": len(users) - len(active_users)
|
||||
}
|
||||
```
|
||||
@@ -60,6 +60,7 @@ Pinyin2Hanzi~=0.1.1
|
||||
pywebpush~=2.0.3
|
||||
python-cookietools==0.0.4
|
||||
aiofiles~=24.1.0
|
||||
aiosqlite~=0.21.0
|
||||
jieba~=0.42.1
|
||||
rsa~=4.9
|
||||
redis~=6.2.0
|
||||
|
||||
Reference in New Issue
Block a user