fix actions execute

This commit is contained in:
jxxghp
2025-02-27 20:39:42 +08:00
parent 37926b4c19
commit f8ed16666c
15 changed files with 61 additions and 31 deletions

View File

@@ -45,10 +45,11 @@ class AddDownloadAction(BaseAction):
def success(self) -> bool:
return True if self._added_downloads else False
async def execute(self, params: AddDownloadParams, context: ActionContext) -> ActionContext:
def execute(self, params: dict, context: ActionContext) -> ActionContext:
"""
将上下文中的torrents添加到下载任务中
"""
params = AddDownloadParams(**params)
for t in context.torrents:
if not t.meta_info:
t.meta_info = MetaInfo(title=t.title, subtitle=t.description)

View File

@@ -41,7 +41,7 @@ class AddSubscribeAction(BaseAction):
def success(self) -> bool:
return True if self._added_subscribes else False
async def execute(self, params: AddSubscribeParams, context: ActionContext) -> ActionContext:
def execute(self, params: dict, context: ActionContext) -> ActionContext:
"""
将medias中的信息添加订阅如果订阅不存在的话
"""

View File

@@ -39,7 +39,7 @@ class FetchDownloadsAction(BaseAction):
return True
return True if all([d.completed for d in self._downloads]) else False
async def execute(self, params: FetchDownloadsParams, context: ActionContext) -> ActionContext:
def execute(self, params: dict, context: ActionContext) -> ActionContext:
"""
更新downloads中的下载任务状态
"""

View File

@@ -117,10 +117,11 @@ class FetchMediasAction(BaseAction):
return s
return None
async def execute(self, params: FetchMediasParams, context: ActionContext) -> ActionContext:
def execute(self, params: dict, context: ActionContext) -> ActionContext:
"""
获取媒体数据填充到medias
"""
params = FetchMediasParams(**params)
for name in params.sources:
source = self.__get_source(name)
if not source:

View File

@@ -51,10 +51,11 @@ class FetchRssAction(BaseAction):
def success(self) -> bool:
return True if self._rss_torrents else False
async def execute(self, params: FetchRssParams, context: ActionContext) -> ActionContext:
def execute(self, params: dict, context: ActionContext) -> ActionContext:
"""
请求RSS地址获取数据并解析为资源列表
"""
params = FetchRssParams(**params)
if not params.url:
return context

View File

@@ -13,7 +13,7 @@ class FetchTorrentsParams(ActionParams):
获取站点资源参数
"""
name: str = Field(None, description="资源名称")
year: Optional[int] = Field(None, description="年份")
year: Optional[str] = Field(None, description="年份")
type: Optional[str] = Field(None, description="资源类型 (电影/电视剧)")
season: Optional[int] = Field(None, description="季度")
sites: Optional[List[int]] = Field([], description="站点列表")
@@ -46,10 +46,11 @@ class FetchTorrentsAction(BaseAction):
def success(self) -> bool:
return True if self._torrents else False
async def execute(self, params: FetchTorrentsParams, context: ActionContext) -> ActionContext:
def execute(self, params: dict, context: ActionContext) -> ActionContext:
"""
搜索站点,获取资源列表
"""
params = FetchTorrentsParams(**params)
torrents = self.searchchain.search_by_title(title=params.name, sites=params.sites)
for torrent in torrents:
if params.year and torrent.meta_info.year != params.year:

View File

@@ -40,10 +40,11 @@ class FilterMediasAction(BaseAction):
def success(self) -> bool:
return True if self.__medias else False
async def execute(self, params: FilterMediasParams, context: ActionContext) -> ActionContext:
def execute(self, params: dict, context: ActionContext) -> ActionContext:
"""
过滤medias中媒体数据
"""
params = FilterMediasParams(**params)
for media in context.medias:
if params.type and media.type != MediaType(params.type):
continue

View File

@@ -48,10 +48,11 @@ class FilterTorrentsAction(BaseAction):
def success(self) -> bool:
return self.done
async def execute(self, params: FilterTorrentsParams, context: ActionContext) -> ActionContext:
def execute(self, params: dict, context: ActionContext) -> ActionContext:
"""
过滤torrents中的资源
"""
params = FilterTorrentsParams(**params)
for torrent in context.torrents:
if self.torrenthelper.filter_torrent(
torrent_info=torrent.torrent_info,

View File

@@ -43,7 +43,7 @@ class ScrapeFileAction(BaseAction):
def success(self) -> bool:
return True if self.__scraped_files else False
async def execute(self, params: ScrapeFileParams, context: ActionContext) -> ActionContext:
def execute(self, params: dict, context: ActionContext) -> ActionContext:
"""
刮削fileitems中的所有文件
"""

View File

@@ -33,7 +33,7 @@ class SendEventAction(BaseAction):
def success(self) -> bool:
return self.done
async def execute(self, params: SendEventParams, context: ActionContext) -> ActionContext:
def execute(self, params: dict, context: ActionContext) -> ActionContext:
"""
发送events中的事件
"""

View File

@@ -40,7 +40,7 @@ class SendMessageAction(BaseAction):
def success(self) -> bool:
return self.done
async def execute(self, params: SendMessageParams, context: ActionContext) -> ActionContext:
def execute(self, params: dict, context: ActionContext) -> ActionContext:
"""
发送messages中的消息
"""

View File

@@ -42,7 +42,7 @@ class TransferFileAction(BaseAction):
def success(self) -> bool:
return True if self.__fileitems else False
async def execute(self, params: TransferFileParams, context: ActionContext) -> ActionContext:
def execute(self, params: dict, context: ActionContext) -> ActionContext:
"""
从downloads中整理文件记录到fileitems
"""

View File

@@ -1,6 +1,9 @@
import base64
import pickle
import threading
from collections import defaultdict, deque
from concurrent.futures import ThreadPoolExecutor
from time import sleep
from typing import List, Tuple
from pydantic.fields import Callable
@@ -60,8 +63,11 @@ class WorkflowExecutor:
self.indegree[action_id] = 0
# 初始上下文
if workflow.current_action:
self.context = ActionContext(**workflow.context)
if workflow.current_action and workflow.context:
# Base64解码
decoded_data = base64.b64decode(workflow.context["content"])
# 反序列化数据
self.context = pickle.loads(decoded_data)
else:
self.context = ActionContext()
@@ -79,19 +85,24 @@ class WorkflowExecutor:
# 退出条件:队列为空且无运行任务
if not self.queue and self.running_tasks == 0:
break
# 退出条件:出现了错误
if not self.success:
break
if not self.queue:
sleep(1)
continue
# 取出队首节点
node_id = self.queue.popleft()
# 标记任务开始
self.running_tasks += 1
# 已停机
if not global_vars.is_system_stopped:
if global_vars.is_system_stopped:
break
# 已执行的跳过
if (self.workflow.current_action
and str(node_id) in self.workflow.current_action.split(',')):
and node_id in self.workflow.current_action.split(',')):
continue
# 提交任务到线程池
@@ -119,7 +130,11 @@ class WorkflowExecutor:
# 节点执行失败
if not state:
self.success = False
self.errmsg = f"{action.name} 执行失败"
self.errmsg = f"{action.name} 失败"
# 标记任务完成
with self.lock:
self.running_tasks -= 1
return
with self.lock:
@@ -145,12 +160,9 @@ class WorkflowExecutor:
"""
合并上下文
"""
# 遍历上下文,补充缺失的字段
self_context_dict = self.context.dict()
for key, value in context.dict().items():
if key not in self_context_dict:
self_context_dict[key] = value
self.context = ActionContext(**self_context_dict)
if not getattr(self.context, key, None):
setattr(self.context, key, value)
class WorkflowChain(ChainBase):
@@ -173,7 +185,13 @@ class WorkflowChain(ChainBase):
"""
保存上下文到数据库
"""
self.workflowoper.step(workflow_id, action_id=action.id, context=context.dict())
# 序列化数据
serialized_data = pickle.dumps(context)
# 使用Base64编码字节流
encoded_data = base64.b64encode(serialized_data).decode('utf-8')
self.workflowoper.step(workflow_id, action_id=action.id, context={
"content": encoded_data
})
# 重置工作流
if from_begin:
@@ -198,8 +216,8 @@ class WorkflowChain(ChainBase):
executor = WorkflowExecutor(workflow, step_callback=save_step)
executor.execute()
if executor.success:
logger.info(f"工作流 {workflow.name} 执行失败:{executor.errmsg}")
if not executor.success:
logger.info(f"工作流 {workflow.name} 执行失败:{executor.errmsg}")
self.workflowoper.fail(workflow_id, result=executor.errmsg)
return False, executor.errmsg
else:

View File

@@ -65,16 +65,22 @@ class WorkFlowManager(metaclass=Singleton):
action_obj = self._actions[action.type]
# 执行
logger.info(f"执行动作: {action.id} - {action.name}")
result_context = action_obj.execute(action.data, context)
try:
result_context = action_obj.execute(action.data, context)
except Exception as err:
logger.error(f"{action.name} 执行失败: {err}")
return False, context
if action_obj.success:
logger.info(f"{action.name} 执行成功")
else:
logger.error(f"{action.name} 执行失败")
if action.data.loop and action.data.loop_interval:
loop = action.data.get("loop")
loop_interval = action.data.get("loop_interval")
if loop and loop_interval:
while not action_obj.done:
# 等待
logger.info(f"{action.name} 等待 {action.data.loop_interval} 秒后继续执行 ...")
sleep(action.data.loop_interval)
logger.info(f"{action.name} 等待 {loop_interval} 秒后继续执行 ...")
sleep(loop_interval)
# 执行
logger.info(f"继续执行动作: {action.id} - {action.name}")
result_context = action_obj.execute(action.data, result_context)

View File

@@ -48,8 +48,8 @@ class Action(BaseModel):
type: Optional[str] = Field(None, description="动作类型 (类名)")
name: Optional[str] = Field(None, description="动作名称")
description: Optional[str] = Field(None, description="动作描述")
data: Optional[ActionParams] = Field({}, description="参数")
position: Optional[dict] = Field({}, description="位置")
data: Optional[dict] = Field({}, description="参数")
class ActionContext(BaseModel):