mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-03-20 03:57:30 +08:00
重构工作流相关API,支持异步操作并引入异步数据库管理
This commit is contained in:
@@ -1,9 +1,10 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import Column, Integer, JSON, Sequence, String, and_, or_
|
||||
from sqlalchemy import Column, Integer, JSON, Sequence, String, and_, or_, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import Base, db_query, db_update
|
||||
from app.db import Base, db_query, db_update, async_db_query, async_db_update
|
||||
|
||||
|
||||
class Workflow(Base):
|
||||
@@ -48,11 +49,23 @@ class Workflow(Base):
|
||||
def list(db):
|
||||
return db.query(Workflow).all()
|
||||
|
||||
@staticmethod
|
||||
@async_db_query
|
||||
async def async_list(db: AsyncSession):
|
||||
result = await db.execute(select(Workflow))
|
||||
return result.scalars().all()
|
||||
|
||||
@staticmethod
|
||||
@db_query
|
||||
def get_enabled_workflows(db):
|
||||
return db.query(Workflow).filter(Workflow.state != 'P').all()
|
||||
|
||||
@staticmethod
|
||||
@async_db_query
|
||||
async def async_get_enabled_workflows(db: AsyncSession):
|
||||
result = await db.execute(select(Workflow).where(Workflow.state != 'P'))
|
||||
return result.scalars().all()
|
||||
|
||||
@staticmethod
|
||||
@db_query
|
||||
def get_timer_triggered_workflows(db):
|
||||
@@ -67,6 +80,21 @@ class Workflow(Base):
|
||||
)
|
||||
).all()
|
||||
|
||||
@staticmethod
|
||||
@async_db_query
|
||||
async def async_get_timer_triggered_workflows(db: AsyncSession):
|
||||
"""异步获取定时触发的工作流"""
|
||||
result = await db.execute(select(Workflow).where(
|
||||
and_(
|
||||
or_(
|
||||
Workflow.trigger_type == 'timer',
|
||||
not Workflow.trigger_type
|
||||
),
|
||||
Workflow.state != 'P'
|
||||
)
|
||||
))
|
||||
return result.scalars().all()
|
||||
|
||||
@staticmethod
|
||||
@db_query
|
||||
def get_event_triggered_workflows(db):
|
||||
@@ -78,17 +106,42 @@ class Workflow(Base):
|
||||
)
|
||||
).all()
|
||||
|
||||
@staticmethod
|
||||
@async_db_query
|
||||
async def async_get_event_triggered_workflows(db: AsyncSession):
|
||||
"""异步获取事件触发的工作流"""
|
||||
result = await db.execute(select(Workflow).where(
|
||||
and_(
|
||||
Workflow.trigger_type == 'event',
|
||||
Workflow.state != 'P'
|
||||
)
|
||||
))
|
||||
return result.scalars().all()
|
||||
|
||||
@staticmethod
|
||||
@db_query
|
||||
def get_by_name(db, name: str):
|
||||
return db.query(Workflow).filter(Workflow.name == name).first()
|
||||
|
||||
@staticmethod
|
||||
@async_db_query
|
||||
async def async_get_by_name(db: AsyncSession, name: str):
|
||||
result = await db.execute(select(Workflow).where(Workflow.name == name))
|
||||
return result.scalars().first()
|
||||
|
||||
@staticmethod
|
||||
@db_update
|
||||
def update_state(db, wid: int, state: str):
|
||||
db.query(Workflow).filter(Workflow.id == wid).update({"state": state})
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
@async_db_update
|
||||
async def async_update_state(db: AsyncSession, wid: int, state: str):
|
||||
from sqlalchemy import update
|
||||
await db.execute(update(Workflow).where(Workflow.id == wid).values(state=state))
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
@db_update
|
||||
def start(db, wid: int):
|
||||
@@ -97,6 +150,13 @@ class Workflow(Base):
|
||||
})
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
@async_db_update
|
||||
async def async_start(db: AsyncSession, wid: int):
|
||||
from sqlalchemy import update
|
||||
await db.execute(update(Workflow).where(Workflow.id == wid).values(state='R'))
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
@db_update
|
||||
def fail(db, wid: int, result: str):
|
||||
@@ -107,6 +167,19 @@ class Workflow(Base):
|
||||
})
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
@async_db_update
|
||||
async def async_fail(db: AsyncSession, wid: int, result: str):
|
||||
from sqlalchemy import update
|
||||
await db.execute(update(Workflow).where(
|
||||
and_(Workflow.id == wid, Workflow.state != "P")
|
||||
).values(
|
||||
state='F',
|
||||
result=result,
|
||||
last_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S')
|
||||
))
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
@db_update
|
||||
def success(db, wid: int, result: Optional[str] = None):
|
||||
@@ -118,6 +191,20 @@ class Workflow(Base):
|
||||
})
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
@async_db_update
|
||||
async def async_success(db: AsyncSession, wid: int, result: Optional[str] = None):
|
||||
from sqlalchemy import update
|
||||
await db.execute(update(Workflow).where(
|
||||
and_(Workflow.id == wid, Workflow.state != "P")
|
||||
).values(
|
||||
state='S',
|
||||
result=result,
|
||||
run_count=Workflow.run_count + 1,
|
||||
last_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S')
|
||||
))
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
@db_update
|
||||
def reset(db, wid: int, reset_count: Optional[bool] = False):
|
||||
@@ -129,6 +216,18 @@ class Workflow(Base):
|
||||
})
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
@async_db_update
|
||||
async def async_reset(db: AsyncSession, wid: int, reset_count: Optional[bool] = False):
|
||||
from sqlalchemy import update
|
||||
await db.execute(update(Workflow).where(Workflow.id == wid).values(
|
||||
state='W',
|
||||
result=None,
|
||||
current_action=None,
|
||||
run_count=0 if reset_count else Workflow.run_count,
|
||||
))
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
@db_update
|
||||
def update_current_action(db, wid: int, action_id: str, context: dict):
|
||||
@@ -137,3 +236,18 @@ class Workflow(Base):
|
||||
"context": context
|
||||
})
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
@async_db_update
|
||||
async def async_update_current_action(db: AsyncSession, wid: int, action_id: str, context: dict):
|
||||
from sqlalchemy import update
|
||||
# 先获取当前current_action
|
||||
result = await db.execute(select(Workflow.current_action).where(Workflow.id == wid))
|
||||
current_action = result.scalar()
|
||||
new_current_action = current_action + f",{action_id}" if current_action else action_id
|
||||
|
||||
await db.execute(update(Workflow).where(Workflow.id == wid).values(
|
||||
current_action=new_current_action,
|
||||
context=context
|
||||
))
|
||||
return True
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import List, Tuple, Optional
|
||||
from typing import List, Tuple, Optional, Any, Coroutine, Sequence
|
||||
|
||||
from app.db import DbOper
|
||||
from app.db import DbOper, AsyncDbOper
|
||||
from app.db.models.workflow import Workflow
|
||||
|
||||
|
||||
@@ -84,3 +84,27 @@ class WorkflowOper(DbOper):
|
||||
重置
|
||||
"""
|
||||
return Workflow.reset(self._db, wid, reset_count=reset_count)
|
||||
|
||||
|
||||
class AsyncWorkflowOper(AsyncDbOper):
|
||||
"""
|
||||
异步工作流管理
|
||||
"""
|
||||
|
||||
async def get(self, wid: int) -> Workflow:
|
||||
"""
|
||||
异步查询单个工作流
|
||||
"""
|
||||
return await Workflow.async_get(self._db, wid)
|
||||
|
||||
async def list(self) -> Coroutine[Any, Any, Sequence[Any]]:
|
||||
"""
|
||||
异步获取所有工作流列表
|
||||
"""
|
||||
return await Workflow.async_list(self._db)
|
||||
|
||||
async def get_by_name(self, name: str) -> Workflow:
|
||||
"""
|
||||
异步按名称获取工作流
|
||||
"""
|
||||
return await Workflow.async_get_by_name(self._db, name)
|
||||
|
||||
Reference in New Issue
Block a user