fix async oper

This commit is contained in:
jxxghp
2025-07-30 18:48:50 +08:00
parent edec18cacb
commit 48d353aa90
6 changed files with 40 additions and 61 deletions

View File

@@ -1,5 +1,5 @@
import asyncio
from typing import Any, Generator, List, Optional, Self, Tuple, AsyncGenerator, Sequence
from typing import Any, Generator, List, Optional, Self, Tuple, AsyncGenerator, Sequence, Union
from sqlalchemy import NullPool, QueuePool, and_, create_engine, inspect, text, select, delete
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
@@ -428,14 +428,5 @@ class DbOper:
数据库操作基类
"""
def __init__(self, db: Session = None):
self._db = db
class AsyncDbOper:
"""
异步数据库操作基类
"""
def __init__(self, db: AsyncSession = None):
def __init__(self, db: Union[Session, AsyncSession] = None):
self._db = db

View File

@@ -2,7 +2,7 @@ import time
from typing import Tuple, List, Optional
from app.core.context import MediaInfo
from app.db import DbOper, AsyncDbOper
from app.db import DbOper
from app.db.models.subscribe import Subscribe
from app.db.models.subscribehistory import SubscribeHistory
@@ -68,6 +68,12 @@ class SubscribeOper(DbOper):
"""
return Subscribe.get(self._db, rid=sid)
async def async_get(self, sid: int) -> Subscribe:
"""
获取订阅
"""
return await Subscribe.async_get(self._db, id=sid)
def list(self, state: Optional[str] = None) -> List[Subscribe]:
"""
获取订阅列表
@@ -136,15 +142,3 @@ class SubscribeOper(DbOper):
elif doubanid:
return True if SubscribeHistory.exists(self._db, doubanid=doubanid) else False
return False
class AsyncSubscribeOper(AsyncDbOper):
"""
异步订阅管理
"""
async def get(self, sid: int) -> Subscribe:
"""
获取订阅
"""
return await Subscribe.async_get(self._db, id=sid)

View File

@@ -1,6 +1,6 @@
from typing import List, Tuple, Optional, Any, Coroutine, Sequence
from app.db import DbOper, AsyncDbOper
from app.db import DbOper
from app.db.models.workflow import Workflow
@@ -25,12 +25,24 @@ class WorkflowOper(DbOper):
"""
return Workflow.get(self._db, wid)
async def async_get(self, wid: int) -> Workflow:
"""
异步查询单个工作流
"""
return await Workflow.async_get(self._db, wid)
def list(self) -> List[Workflow]:
"""
获取所有工作流列表
"""
return Workflow.list(self._db)
async def async_list(self) -> Coroutine[Any, Any, Sequence[Any]]:
"""
异步获取所有工作流列表
"""
return await Workflow.async_list(self._db)
def list_enabled(self) -> List[Workflow]:
"""
获取启用的工作流列表
@@ -55,6 +67,12 @@ class WorkflowOper(DbOper):
"""
return Workflow.get_by_name(self._db, name)
async def async_get_by_name(self, name: str) -> Workflow:
"""
异步按名称获取工作流
"""
return await Workflow.async_get_by_name(self._db, name)
def start(self, wid: int) -> bool:
"""
启动
@@ -84,27 +102,3 @@ class WorkflowOper(DbOper):
重置
"""
return Workflow.reset(self._db, wid, reset_count=reset_count)
class AsyncWorkflowOper(AsyncDbOper):
"""
异步工作流管理
"""
async def get(self, wid: int) -> Workflow:
"""
异步查询单个工作流
"""
return await Workflow.async_get(self._db, wid)
async def list(self) -> Coroutine[Any, Any, Sequence[Any]]:
"""
异步获取所有工作流列表
"""
return await Workflow.async_list(self._db)
async def get_by_name(self, name: str) -> Workflow:
"""
异步按名称获取工作流
"""
return await Workflow.async_get_by_name(self._db, name)