mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-04-05 03:38:36 +08:00
fix async oper
This commit is contained in:
@@ -15,7 +15,7 @@ from app.core.workflow import WorkFlowManager
|
||||
from app.db import get_async_db, get_db
|
||||
from app.db.models import Workflow
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.db.workflow_oper import AsyncWorkflowOper, WorkflowOper
|
||||
from app.db.workflow_oper import WorkflowOper
|
||||
from app.helper.workflow import WorkflowHelper
|
||||
from app.scheduler import Scheduler
|
||||
from app.schemas.types import EventType, EVENT_TYPE_NAMES
|
||||
@@ -29,7 +29,7 @@ async def list_workflows(db: AsyncSession = Depends(get_async_db),
|
||||
"""
|
||||
获取工作流列表
|
||||
"""
|
||||
return await AsyncWorkflowOper(db).list()
|
||||
return await WorkflowOper(db).async_list()
|
||||
|
||||
|
||||
@router.post("/", summary="创建工作流", response_model=schemas.Response)
|
||||
@@ -39,7 +39,7 @@ async def create_workflow(workflow: schemas.Workflow,
|
||||
"""
|
||||
创建工作流
|
||||
"""
|
||||
if workflow.name and await AsyncWorkflowOper(db).get_by_name(workflow.name):
|
||||
if workflow.name and await WorkflowOper(db).async_get_by_name(workflow.name):
|
||||
return schemas.Response(success=False, message="已存在相同名称的工作流")
|
||||
if not workflow.add_time:
|
||||
workflow.add_time = datetime.strftime(datetime.now(), "%Y-%m-%d %H:%M:%S")
|
||||
@@ -149,8 +149,8 @@ async def workflow_fork(
|
||||
}
|
||||
|
||||
# 检查名称是否重复
|
||||
workflow_oper = AsyncWorkflowOper(db)
|
||||
if await workflow_oper.get_by_name(workflow_dict["name"]):
|
||||
workflow_oper = WorkflowOper(db)
|
||||
if await workflow_oper.async_get_by_name(workflow_dict["name"]):
|
||||
return schemas.Response(success=False, message="已存在相同名称的工作流")
|
||||
|
||||
# 创建新工作流
|
||||
@@ -158,7 +158,7 @@ async def workflow_fork(
|
||||
await workflow_obj.async_create(db)
|
||||
|
||||
# 获取工作流ID(在数据库会话有效时)
|
||||
workflow = await workflow_oper.get_by_name(workflow_dict["name"])
|
||||
workflow = await workflow_oper.async_get_by_name(workflow_dict["name"])
|
||||
|
||||
# 更新复用次数
|
||||
if workflow:
|
||||
@@ -244,7 +244,7 @@ async def reset_workflow(workflow_id: int,
|
||||
"""
|
||||
重置工作流
|
||||
"""
|
||||
workflow = await AsyncWorkflowOper(db).get(workflow_id)
|
||||
workflow = await WorkflowOper(db).async_get(workflow_id)
|
||||
if not workflow:
|
||||
return schemas.Response(success=False, message="工作流不存在")
|
||||
# 停止工作流
|
||||
@@ -263,7 +263,7 @@ async def get_workflow(workflow_id: int,
|
||||
"""
|
||||
获取工作流详情
|
||||
"""
|
||||
return await AsyncWorkflowOper(db).get(workflow_id)
|
||||
return await WorkflowOper(db).async_get(workflow_id)
|
||||
|
||||
|
||||
@router.put("/{workflow_id}", summary="更新工作流", response_model=schemas.Response)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import asyncio
|
||||
from typing import Any, Generator, List, Optional, Self, Tuple, AsyncGenerator, Sequence
|
||||
from typing import Any, Generator, List, Optional, Self, Tuple, AsyncGenerator, Sequence, Union
|
||||
|
||||
from sqlalchemy import NullPool, QueuePool, and_, create_engine, inspect, text, select, delete
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
|
||||
@@ -428,14 +428,5 @@ class DbOper:
|
||||
数据库操作基类
|
||||
"""
|
||||
|
||||
def __init__(self, db: Session = None):
|
||||
self._db = db
|
||||
|
||||
|
||||
class AsyncDbOper:
|
||||
"""
|
||||
异步数据库操作基类
|
||||
"""
|
||||
|
||||
def __init__(self, db: AsyncSession = None):
|
||||
def __init__(self, db: Union[Session, AsyncSession] = None):
|
||||
self._db = db
|
||||
|
||||
@@ -2,7 +2,7 @@ import time
|
||||
from typing import Tuple, List, Optional
|
||||
|
||||
from app.core.context import MediaInfo
|
||||
from app.db import DbOper, AsyncDbOper
|
||||
from app.db import DbOper
|
||||
from app.db.models.subscribe import Subscribe
|
||||
from app.db.models.subscribehistory import SubscribeHistory
|
||||
|
||||
@@ -68,6 +68,12 @@ class SubscribeOper(DbOper):
|
||||
"""
|
||||
return Subscribe.get(self._db, rid=sid)
|
||||
|
||||
async def async_get(self, sid: int) -> Subscribe:
|
||||
"""
|
||||
获取订阅
|
||||
"""
|
||||
return await Subscribe.async_get(self._db, id=sid)
|
||||
|
||||
def list(self, state: Optional[str] = None) -> List[Subscribe]:
|
||||
"""
|
||||
获取订阅列表
|
||||
@@ -136,15 +142,3 @@ class SubscribeOper(DbOper):
|
||||
elif doubanid:
|
||||
return True if SubscribeHistory.exists(self._db, doubanid=doubanid) else False
|
||||
return False
|
||||
|
||||
|
||||
class AsyncSubscribeOper(AsyncDbOper):
|
||||
"""
|
||||
异步订阅管理
|
||||
"""
|
||||
|
||||
async def get(self, sid: int) -> Subscribe:
|
||||
"""
|
||||
获取订阅
|
||||
"""
|
||||
return await Subscribe.async_get(self._db, id=sid)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import List, Tuple, Optional, Any, Coroutine, Sequence
|
||||
|
||||
from app.db import DbOper, AsyncDbOper
|
||||
from app.db import DbOper
|
||||
from app.db.models.workflow import Workflow
|
||||
|
||||
|
||||
@@ -25,12 +25,24 @@ class WorkflowOper(DbOper):
|
||||
"""
|
||||
return Workflow.get(self._db, wid)
|
||||
|
||||
async def async_get(self, wid: int) -> Workflow:
|
||||
"""
|
||||
异步查询单个工作流
|
||||
"""
|
||||
return await Workflow.async_get(self._db, wid)
|
||||
|
||||
def list(self) -> List[Workflow]:
|
||||
"""
|
||||
获取所有工作流列表
|
||||
"""
|
||||
return Workflow.list(self._db)
|
||||
|
||||
async def async_list(self) -> Coroutine[Any, Any, Sequence[Any]]:
|
||||
"""
|
||||
异步获取所有工作流列表
|
||||
"""
|
||||
return await Workflow.async_list(self._db)
|
||||
|
||||
def list_enabled(self) -> List[Workflow]:
|
||||
"""
|
||||
获取启用的工作流列表
|
||||
@@ -55,6 +67,12 @@ class WorkflowOper(DbOper):
|
||||
"""
|
||||
return Workflow.get_by_name(self._db, name)
|
||||
|
||||
async def async_get_by_name(self, name: str) -> Workflow:
|
||||
"""
|
||||
异步按名称获取工作流
|
||||
"""
|
||||
return await Workflow.async_get_by_name(self._db, name)
|
||||
|
||||
def start(self, wid: int) -> bool:
|
||||
"""
|
||||
启动
|
||||
@@ -84,27 +102,3 @@ class WorkflowOper(DbOper):
|
||||
重置
|
||||
"""
|
||||
return Workflow.reset(self._db, wid, reset_count=reset_count)
|
||||
|
||||
|
||||
class AsyncWorkflowOper(AsyncDbOper):
|
||||
"""
|
||||
异步工作流管理
|
||||
"""
|
||||
|
||||
async def get(self, wid: int) -> Workflow:
|
||||
"""
|
||||
异步查询单个工作流
|
||||
"""
|
||||
return await Workflow.async_get(self._db, wid)
|
||||
|
||||
async def list(self) -> Coroutine[Any, Any, Sequence[Any]]:
|
||||
"""
|
||||
异步获取所有工作流列表
|
||||
"""
|
||||
return await Workflow.async_list(self._db)
|
||||
|
||||
async def get_by_name(self, name: str) -> Workflow:
|
||||
"""
|
||||
异步按名称获取工作流
|
||||
"""
|
||||
return await Workflow.async_get_by_name(self._db, name)
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import List, Tuple, Optional
|
||||
|
||||
from app.core.cache import cached, cache_backend
|
||||
from app.core.config import settings
|
||||
from app.db.subscribe_oper import SubscribeOper, AsyncSubscribeOper
|
||||
from app.db.subscribe_oper import SubscribeOper
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.log import logger
|
||||
from app.schemas.types import SystemConfigKey
|
||||
@@ -190,7 +190,7 @@ class SubscribeHelper(metaclass=WeakSingleton):
|
||||
"""
|
||||
if not settings.SUBSCRIBE_STATISTIC_SHARE:
|
||||
return False, "当前没有开启订阅数据共享功能"
|
||||
subscribe = await AsyncSubscribeOper().get(subscribe_id)
|
||||
subscribe = await SubscribeOper().async_get(subscribe_id)
|
||||
if not subscribe:
|
||||
return False, "订阅不存在"
|
||||
subscribe_dict = subscribe.to_dict()
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import List, Tuple, Optional
|
||||
|
||||
from app.core.cache import cached, cache_backend
|
||||
from app.core.config import settings
|
||||
from app.db.workflow_oper import WorkflowOper, AsyncWorkflowOper
|
||||
from app.db.workflow_oper import WorkflowOper
|
||||
from app.log import logger
|
||||
from app.utils.http import RequestUtils, AsyncRequestUtils
|
||||
from app.utils.singleton import WeakSingleton
|
||||
@@ -79,7 +79,7 @@ class WorkflowHelper(metaclass=WeakSingleton):
|
||||
return False, "当前没有开启工作流数据共享功能"
|
||||
|
||||
# 获取工作流信息
|
||||
workflow = await AsyncWorkflowOper().get(workflow_id)
|
||||
workflow = await WorkflowOper().async_get(workflow_id)
|
||||
if not workflow:
|
||||
return False, "工作流不存在"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user