diff --git a/app/api/endpoints/workflow.py b/app/api/endpoints/workflow.py index a1652b95..ef8bd11a 100644 --- a/app/api/endpoints/workflow.py +++ b/app/api/endpoints/workflow.py @@ -15,7 +15,7 @@ from app.core.workflow import WorkFlowManager from app.db import get_async_db, get_db from app.db.models import Workflow from app.db.systemconfig_oper import SystemConfigOper -from app.db.workflow_oper import AsyncWorkflowOper, WorkflowOper +from app.db.workflow_oper import WorkflowOper from app.helper.workflow import WorkflowHelper from app.scheduler import Scheduler from app.schemas.types import EventType, EVENT_TYPE_NAMES @@ -29,7 +29,7 @@ async def list_workflows(db: AsyncSession = Depends(get_async_db), """ 获取工作流列表 """ - return await AsyncWorkflowOper(db).list() + return await WorkflowOper(db).async_list() @router.post("/", summary="创建工作流", response_model=schemas.Response) @@ -39,7 +39,7 @@ async def create_workflow(workflow: schemas.Workflow, """ 创建工作流 """ - if workflow.name and await AsyncWorkflowOper(db).get_by_name(workflow.name): + if workflow.name and await WorkflowOper(db).async_get_by_name(workflow.name): return schemas.Response(success=False, message="已存在相同名称的工作流") if not workflow.add_time: workflow.add_time = datetime.strftime(datetime.now(), "%Y-%m-%d %H:%M:%S") @@ -149,8 +149,8 @@ async def workflow_fork( } # 检查名称是否重复 - workflow_oper = AsyncWorkflowOper(db) - if await workflow_oper.get_by_name(workflow_dict["name"]): + workflow_oper = WorkflowOper(db) + if await workflow_oper.async_get_by_name(workflow_dict["name"]): return schemas.Response(success=False, message="已存在相同名称的工作流") # 创建新工作流 @@ -158,7 +158,7 @@ async def workflow_fork( await workflow_obj.async_create(db) # 获取工作流ID(在数据库会话有效时) - workflow = await workflow_oper.get_by_name(workflow_dict["name"]) + workflow = await workflow_oper.async_get_by_name(workflow_dict["name"]) # 更新复用次数 if workflow: @@ -244,7 +244,7 @@ async def reset_workflow(workflow_id: int, """ 重置工作流 """ - workflow = await AsyncWorkflowOper(db).get(workflow_id) + workflow = await WorkflowOper(db).async_get(workflow_id) if not workflow: return schemas.Response(success=False, message="工作流不存在") # 停止工作流 @@ -263,7 +263,7 @@ async def get_workflow(workflow_id: int, """ 获取工作流详情 """ - return await AsyncWorkflowOper(db).get(workflow_id) + return await WorkflowOper(db).async_get(workflow_id) @router.put("/{workflow_id}", summary="更新工作流", response_model=schemas.Response) diff --git a/app/db/__init__.py b/app/db/__init__.py index e8c0b172..3f07ce64 100644 --- a/app/db/__init__.py +++ b/app/db/__init__.py @@ -1,5 +1,5 @@ import asyncio -from typing import Any, Generator, List, Optional, Self, Tuple, AsyncGenerator, Sequence +from typing import Any, Generator, List, Optional, Self, Tuple, AsyncGenerator, Sequence, Union from sqlalchemy import NullPool, QueuePool, and_, create_engine, inspect, text, select, delete from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker @@ -428,14 +428,5 @@ class DbOper: 数据库操作基类 """ - def __init__(self, db: Session = None): - self._db = db - - -class AsyncDbOper: - """ - 异步数据库操作基类 - """ - - def __init__(self, db: AsyncSession = None): + def __init__(self, db: Union[Session, AsyncSession] = None): self._db = db diff --git a/app/db/subscribe_oper.py b/app/db/subscribe_oper.py index 69b64237..1d4b2a36 100644 --- a/app/db/subscribe_oper.py +++ b/app/db/subscribe_oper.py @@ -2,7 +2,7 @@ import time from typing import Tuple, List, Optional from app.core.context import MediaInfo -from app.db import DbOper, AsyncDbOper +from app.db import DbOper from app.db.models.subscribe import Subscribe from app.db.models.subscribehistory import SubscribeHistory @@ -68,6 +68,12 @@ class SubscribeOper(DbOper): """ return Subscribe.get(self._db, rid=sid) + async def async_get(self, sid: int) -> Subscribe: + """ + 获取订阅 + """ + return await Subscribe.async_get(self._db, id=sid) + def list(self, state: Optional[str] = None) -> List[Subscribe]: """ 获取订阅列表 @@ -136,15 +142,3 @@ class SubscribeOper(DbOper): elif doubanid: return True if SubscribeHistory.exists(self._db, doubanid=doubanid) else False return False - - -class AsyncSubscribeOper(AsyncDbOper): - """ - 异步订阅管理 - """ - - async def get(self, sid: int) -> Subscribe: - """ - 获取订阅 - """ - return await Subscribe.async_get(self._db, id=sid) diff --git a/app/db/workflow_oper.py b/app/db/workflow_oper.py index 66ea26c9..0175dbb0 100644 --- a/app/db/workflow_oper.py +++ b/app/db/workflow_oper.py @@ -1,6 +1,6 @@ from typing import List, Tuple, Optional, Any, Coroutine, Sequence -from app.db import DbOper, AsyncDbOper +from app.db import DbOper from app.db.models.workflow import Workflow @@ -25,12 +25,24 @@ class WorkflowOper(DbOper): """ return Workflow.get(self._db, wid) + async def async_get(self, wid: int) -> Workflow: + """ + 异步查询单个工作流 + """ + return await Workflow.async_get(self._db, wid) + def list(self) -> List[Workflow]: """ 获取所有工作流列表 """ return Workflow.list(self._db) + async def async_list(self) -> Coroutine[Any, Any, Sequence[Any]]: + """ + 异步获取所有工作流列表 + """ + return await Workflow.async_list(self._db) + def list_enabled(self) -> List[Workflow]: """ 获取启用的工作流列表 @@ -55,6 +67,12 @@ class WorkflowOper(DbOper): """ return Workflow.get_by_name(self._db, name) + async def async_get_by_name(self, name: str) -> Workflow: + """ + 异步按名称获取工作流 + """ + return await Workflow.async_get_by_name(self._db, name) + def start(self, wid: int) -> bool: """ 启动 @@ -84,27 +102,3 @@ class WorkflowOper(DbOper): 重置 """ return Workflow.reset(self._db, wid, reset_count=reset_count) - - -class AsyncWorkflowOper(AsyncDbOper): - """ - 异步工作流管理 - """ - - async def get(self, wid: int) -> Workflow: - """ - 异步查询单个工作流 - """ - return await Workflow.async_get(self._db, wid) - - async def list(self) -> Coroutine[Any, Any, Sequence[Any]]: - """ - 异步获取所有工作流列表 - """ - return await Workflow.async_list(self._db) - - async def get_by_name(self, name: str) -> Workflow: - """ - 异步按名称获取工作流 - """ - return await Workflow.async_get_by_name(self._db, name) diff --git a/app/helper/subscribe.py b/app/helper/subscribe.py index 369230e7..d0c21bfe 100644 --- a/app/helper/subscribe.py +++ b/app/helper/subscribe.py @@ -3,7 +3,7 @@ from typing import List, Tuple, Optional from app.core.cache import cached, cache_backend from app.core.config import settings -from app.db.subscribe_oper import SubscribeOper, AsyncSubscribeOper +from app.db.subscribe_oper import SubscribeOper from app.db.systemconfig_oper import SystemConfigOper from app.log import logger from app.schemas.types import SystemConfigKey @@ -190,7 +190,7 @@ class SubscribeHelper(metaclass=WeakSingleton): """ if not settings.SUBSCRIBE_STATISTIC_SHARE: return False, "当前没有开启订阅数据共享功能" - subscribe = await AsyncSubscribeOper().get(subscribe_id) + subscribe = await SubscribeOper().async_get(subscribe_id) if not subscribe: return False, "订阅不存在" subscribe_dict = subscribe.to_dict() diff --git a/app/helper/workflow.py b/app/helper/workflow.py index 3e9f7e40..c889f1f5 100644 --- a/app/helper/workflow.py +++ b/app/helper/workflow.py @@ -3,7 +3,7 @@ from typing import List, Tuple, Optional from app.core.cache import cached, cache_backend from app.core.config import settings -from app.db.workflow_oper import WorkflowOper, AsyncWorkflowOper +from app.db.workflow_oper import WorkflowOper from app.log import logger from app.utils.http import RequestUtils, AsyncRequestUtils from app.utils.singleton import WeakSingleton @@ -79,7 +79,7 @@ class WorkflowHelper(metaclass=WeakSingleton): return False, "当前没有开启工作流数据共享功能" # 获取工作流信息 - workflow = await AsyncWorkflowOper().get(workflow_id) + workflow = await WorkflowOper().async_get(workflow_id) if not workflow: return False, "工作流不存在"