fix async oper

This commit is contained in:
jxxghp
2025-07-30 18:48:50 +08:00
parent edec18cacb
commit 48d353aa90
6 changed files with 40 additions and 61 deletions

View File

@@ -15,7 +15,7 @@ from app.core.workflow import WorkFlowManager
from app.db import get_async_db, get_db
from app.db.models import Workflow
from app.db.systemconfig_oper import SystemConfigOper
from app.db.workflow_oper import AsyncWorkflowOper, WorkflowOper
from app.db.workflow_oper import WorkflowOper
from app.helper.workflow import WorkflowHelper
from app.scheduler import Scheduler
from app.schemas.types import EventType, EVENT_TYPE_NAMES
@@ -29,7 +29,7 @@ async def list_workflows(db: AsyncSession = Depends(get_async_db),
"""
获取工作流列表
"""
return await AsyncWorkflowOper(db).list()
return await WorkflowOper(db).async_list()
@router.post("/", summary="创建工作流", response_model=schemas.Response)
@@ -39,7 +39,7 @@ async def create_workflow(workflow: schemas.Workflow,
"""
创建工作流
"""
if workflow.name and await AsyncWorkflowOper(db).get_by_name(workflow.name):
if workflow.name and await WorkflowOper(db).async_get_by_name(workflow.name):
return schemas.Response(success=False, message="已存在相同名称的工作流")
if not workflow.add_time:
workflow.add_time = datetime.strftime(datetime.now(), "%Y-%m-%d %H:%M:%S")
@@ -149,8 +149,8 @@ async def workflow_fork(
}
# 检查名称是否重复
workflow_oper = AsyncWorkflowOper(db)
if await workflow_oper.get_by_name(workflow_dict["name"]):
workflow_oper = WorkflowOper(db)
if await workflow_oper.async_get_by_name(workflow_dict["name"]):
return schemas.Response(success=False, message="已存在相同名称的工作流")
# 创建新工作流
@@ -158,7 +158,7 @@ async def workflow_fork(
await workflow_obj.async_create(db)
# 获取工作流ID在数据库会话有效时
workflow = await workflow_oper.get_by_name(workflow_dict["name"])
workflow = await workflow_oper.async_get_by_name(workflow_dict["name"])
# 更新复用次数
if workflow:
@@ -244,7 +244,7 @@ async def reset_workflow(workflow_id: int,
"""
重置工作流
"""
workflow = await AsyncWorkflowOper(db).get(workflow_id)
workflow = await WorkflowOper(db).async_get(workflow_id)
if not workflow:
return schemas.Response(success=False, message="工作流不存在")
# 停止工作流
@@ -263,7 +263,7 @@ async def get_workflow(workflow_id: int,
"""
获取工作流详情
"""
return await AsyncWorkflowOper(db).get(workflow_id)
return await WorkflowOper(db).async_get(workflow_id)
@router.put("/{workflow_id}", summary="更新工作流", response_model=schemas.Response)

View File

@@ -1,5 +1,5 @@
import asyncio
from typing import Any, Generator, List, Optional, Self, Tuple, AsyncGenerator, Sequence
from typing import Any, Generator, List, Optional, Self, Tuple, AsyncGenerator, Sequence, Union
from sqlalchemy import NullPool, QueuePool, and_, create_engine, inspect, text, select, delete
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
@@ -428,14 +428,5 @@ class DbOper:
数据库操作基类
"""
def __init__(self, db: Session = None):
self._db = db
class AsyncDbOper:
"""
异步数据库操作基类
"""
def __init__(self, db: AsyncSession = None):
def __init__(self, db: Union[Session, AsyncSession] = None):
self._db = db

View File

@@ -2,7 +2,7 @@ import time
from typing import Tuple, List, Optional
from app.core.context import MediaInfo
from app.db import DbOper, AsyncDbOper
from app.db import DbOper
from app.db.models.subscribe import Subscribe
from app.db.models.subscribehistory import SubscribeHistory
@@ -68,6 +68,12 @@ class SubscribeOper(DbOper):
"""
return Subscribe.get(self._db, rid=sid)
async def async_get(self, sid: int) -> Subscribe:
"""
获取订阅
"""
return await Subscribe.async_get(self._db, id=sid)
def list(self, state: Optional[str] = None) -> List[Subscribe]:
"""
获取订阅列表
@@ -136,15 +142,3 @@ class SubscribeOper(DbOper):
elif doubanid:
return True if SubscribeHistory.exists(self._db, doubanid=doubanid) else False
return False
class AsyncSubscribeOper(AsyncDbOper):
"""
异步订阅管理
"""
async def get(self, sid: int) -> Subscribe:
"""
获取订阅
"""
return await Subscribe.async_get(self._db, id=sid)

View File

@@ -1,6 +1,6 @@
from typing import List, Tuple, Optional, Any, Coroutine, Sequence
from app.db import DbOper, AsyncDbOper
from app.db import DbOper
from app.db.models.workflow import Workflow
@@ -25,12 +25,24 @@ class WorkflowOper(DbOper):
"""
return Workflow.get(self._db, wid)
async def async_get(self, wid: int) -> Workflow:
"""
异步查询单个工作流
"""
return await Workflow.async_get(self._db, wid)
def list(self) -> List[Workflow]:
"""
获取所有工作流列表
"""
return Workflow.list(self._db)
async def async_list(self) -> Coroutine[Any, Any, Sequence[Any]]:
"""
异步获取所有工作流列表
"""
return await Workflow.async_list(self._db)
def list_enabled(self) -> List[Workflow]:
"""
获取启用的工作流列表
@@ -55,6 +67,12 @@ class WorkflowOper(DbOper):
"""
return Workflow.get_by_name(self._db, name)
async def async_get_by_name(self, name: str) -> Workflow:
"""
异步按名称获取工作流
"""
return await Workflow.async_get_by_name(self._db, name)
def start(self, wid: int) -> bool:
"""
启动
@@ -84,27 +102,3 @@ class WorkflowOper(DbOper):
重置
"""
return Workflow.reset(self._db, wid, reset_count=reset_count)
class AsyncWorkflowOper(AsyncDbOper):
"""
异步工作流管理
"""
async def get(self, wid: int) -> Workflow:
"""
异步查询单个工作流
"""
return await Workflow.async_get(self._db, wid)
async def list(self) -> Coroutine[Any, Any, Sequence[Any]]:
"""
异步获取所有工作流列表
"""
return await Workflow.async_list(self._db)
async def get_by_name(self, name: str) -> Workflow:
"""
异步按名称获取工作流
"""
return await Workflow.async_get_by_name(self._db, name)

View File

@@ -3,7 +3,7 @@ from typing import List, Tuple, Optional
from app.core.cache import cached, cache_backend
from app.core.config import settings
from app.db.subscribe_oper import SubscribeOper, AsyncSubscribeOper
from app.db.subscribe_oper import SubscribeOper
from app.db.systemconfig_oper import SystemConfigOper
from app.log import logger
from app.schemas.types import SystemConfigKey
@@ -190,7 +190,7 @@ class SubscribeHelper(metaclass=WeakSingleton):
"""
if not settings.SUBSCRIBE_STATISTIC_SHARE:
return False, "当前没有开启订阅数据共享功能"
subscribe = await AsyncSubscribeOper().get(subscribe_id)
subscribe = await SubscribeOper().async_get(subscribe_id)
if not subscribe:
return False, "订阅不存在"
subscribe_dict = subscribe.to_dict()

View File

@@ -3,7 +3,7 @@ from typing import List, Tuple, Optional
from app.core.cache import cached, cache_backend
from app.core.config import settings
from app.db.workflow_oper import WorkflowOper, AsyncWorkflowOper
from app.db.workflow_oper import WorkflowOper
from app.log import logger
from app.utils.http import RequestUtils, AsyncRequestUtils
from app.utils.singleton import WeakSingleton
@@ -79,7 +79,7 @@ class WorkflowHelper(metaclass=WeakSingleton):
return False, "当前没有开启工作流数据共享功能"
# 获取工作流信息
workflow = await AsyncWorkflowOper().get(workflow_id)
workflow = await WorkflowOper().async_get(workflow_id)
if not workflow:
return False, "工作流不存在"