mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-04-13 17:52:28 +08:00
fix actions execute
This commit is contained in:
@@ -45,10 +45,11 @@ class AddDownloadAction(BaseAction):
|
||||
def success(self) -> bool:
|
||||
return True if self._added_downloads else False
|
||||
|
||||
async def execute(self, params: AddDownloadParams, context: ActionContext) -> ActionContext:
|
||||
def execute(self, params: dict, context: ActionContext) -> ActionContext:
|
||||
"""
|
||||
将上下文中的torrents添加到下载任务中
|
||||
"""
|
||||
params = AddDownloadParams(**params)
|
||||
for t in context.torrents:
|
||||
if not t.meta_info:
|
||||
t.meta_info = MetaInfo(title=t.title, subtitle=t.description)
|
||||
|
||||
@@ -41,7 +41,7 @@ class AddSubscribeAction(BaseAction):
|
||||
def success(self) -> bool:
|
||||
return True if self._added_subscribes else False
|
||||
|
||||
async def execute(self, params: AddSubscribeParams, context: ActionContext) -> ActionContext:
|
||||
def execute(self, params: dict, context: ActionContext) -> ActionContext:
|
||||
"""
|
||||
将medias中的信息添加订阅,如果订阅不存在的话
|
||||
"""
|
||||
|
||||
@@ -39,7 +39,7 @@ class FetchDownloadsAction(BaseAction):
|
||||
return True
|
||||
return True if all([d.completed for d in self._downloads]) else False
|
||||
|
||||
async def execute(self, params: FetchDownloadsParams, context: ActionContext) -> ActionContext:
|
||||
def execute(self, params: dict, context: ActionContext) -> ActionContext:
|
||||
"""
|
||||
更新downloads中的下载任务状态
|
||||
"""
|
||||
|
||||
@@ -117,10 +117,11 @@ class FetchMediasAction(BaseAction):
|
||||
return s
|
||||
return None
|
||||
|
||||
async def execute(self, params: FetchMediasParams, context: ActionContext) -> ActionContext:
|
||||
def execute(self, params: dict, context: ActionContext) -> ActionContext:
|
||||
"""
|
||||
获取媒体数据,填充到medias
|
||||
"""
|
||||
params = FetchMediasParams(**params)
|
||||
for name in params.sources:
|
||||
source = self.__get_source(name)
|
||||
if not source:
|
||||
|
||||
@@ -51,10 +51,11 @@ class FetchRssAction(BaseAction):
|
||||
def success(self) -> bool:
|
||||
return True if self._rss_torrents else False
|
||||
|
||||
async def execute(self, params: FetchRssParams, context: ActionContext) -> ActionContext:
|
||||
def execute(self, params: dict, context: ActionContext) -> ActionContext:
|
||||
"""
|
||||
请求RSS地址获取数据,并解析为资源列表
|
||||
"""
|
||||
params = FetchRssParams(**params)
|
||||
if not params.url:
|
||||
return context
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ class FetchTorrentsParams(ActionParams):
|
||||
获取站点资源参数
|
||||
"""
|
||||
name: str = Field(None, description="资源名称")
|
||||
year: Optional[int] = Field(None, description="年份")
|
||||
year: Optional[str] = Field(None, description="年份")
|
||||
type: Optional[str] = Field(None, description="资源类型 (电影/电视剧)")
|
||||
season: Optional[int] = Field(None, description="季度")
|
||||
sites: Optional[List[int]] = Field([], description="站点列表")
|
||||
@@ -46,10 +46,11 @@ class FetchTorrentsAction(BaseAction):
|
||||
def success(self) -> bool:
|
||||
return True if self._torrents else False
|
||||
|
||||
async def execute(self, params: FetchTorrentsParams, context: ActionContext) -> ActionContext:
|
||||
def execute(self, params: dict, context: ActionContext) -> ActionContext:
|
||||
"""
|
||||
搜索站点,获取资源列表
|
||||
"""
|
||||
params = FetchTorrentsParams(**params)
|
||||
torrents = self.searchchain.search_by_title(title=params.name, sites=params.sites)
|
||||
for torrent in torrents:
|
||||
if params.year and torrent.meta_info.year != params.year:
|
||||
|
||||
@@ -40,10 +40,11 @@ class FilterMediasAction(BaseAction):
|
||||
def success(self) -> bool:
|
||||
return True if self.__medias else False
|
||||
|
||||
async def execute(self, params: FilterMediasParams, context: ActionContext) -> ActionContext:
|
||||
def execute(self, params: dict, context: ActionContext) -> ActionContext:
|
||||
"""
|
||||
过滤medias中媒体数据
|
||||
"""
|
||||
params = FilterMediasParams(**params)
|
||||
for media in context.medias:
|
||||
if params.type and media.type != MediaType(params.type):
|
||||
continue
|
||||
|
||||
@@ -48,10 +48,11 @@ class FilterTorrentsAction(BaseAction):
|
||||
def success(self) -> bool:
|
||||
return self.done
|
||||
|
||||
async def execute(self, params: FilterTorrentsParams, context: ActionContext) -> ActionContext:
|
||||
def execute(self, params: dict, context: ActionContext) -> ActionContext:
|
||||
"""
|
||||
过滤torrents中的资源
|
||||
"""
|
||||
params = FilterTorrentsParams(**params)
|
||||
for torrent in context.torrents:
|
||||
if self.torrenthelper.filter_torrent(
|
||||
torrent_info=torrent.torrent_info,
|
||||
|
||||
@@ -43,7 +43,7 @@ class ScrapeFileAction(BaseAction):
|
||||
def success(self) -> bool:
|
||||
return True if self.__scraped_files else False
|
||||
|
||||
async def execute(self, params: ScrapeFileParams, context: ActionContext) -> ActionContext:
|
||||
def execute(self, params: dict, context: ActionContext) -> ActionContext:
|
||||
"""
|
||||
刮削fileitems中的所有文件
|
||||
"""
|
||||
|
||||
@@ -33,7 +33,7 @@ class SendEventAction(BaseAction):
|
||||
def success(self) -> bool:
|
||||
return self.done
|
||||
|
||||
async def execute(self, params: SendEventParams, context: ActionContext) -> ActionContext:
|
||||
def execute(self, params: dict, context: ActionContext) -> ActionContext:
|
||||
"""
|
||||
发送events中的事件
|
||||
"""
|
||||
|
||||
@@ -40,7 +40,7 @@ class SendMessageAction(BaseAction):
|
||||
def success(self) -> bool:
|
||||
return self.done
|
||||
|
||||
async def execute(self, params: SendMessageParams, context: ActionContext) -> ActionContext:
|
||||
def execute(self, params: dict, context: ActionContext) -> ActionContext:
|
||||
"""
|
||||
发送messages中的消息
|
||||
"""
|
||||
|
||||
@@ -42,7 +42,7 @@ class TransferFileAction(BaseAction):
|
||||
def success(self) -> bool:
|
||||
return True if self.__fileitems else False
|
||||
|
||||
async def execute(self, params: TransferFileParams, context: ActionContext) -> ActionContext:
|
||||
def execute(self, params: dict, context: ActionContext) -> ActionContext:
|
||||
"""
|
||||
从downloads中整理文件,记录到fileitems
|
||||
"""
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
import base64
|
||||
import pickle
|
||||
import threading
|
||||
from collections import defaultdict, deque
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from time import sleep
|
||||
from typing import List, Tuple
|
||||
|
||||
from pydantic.fields import Callable
|
||||
@@ -60,8 +63,11 @@ class WorkflowExecutor:
|
||||
self.indegree[action_id] = 0
|
||||
|
||||
# 初始上下文
|
||||
if workflow.current_action:
|
||||
self.context = ActionContext(**workflow.context)
|
||||
if workflow.current_action and workflow.context:
|
||||
# Base64解码
|
||||
decoded_data = base64.b64decode(workflow.context["content"])
|
||||
# 反序列化数据
|
||||
self.context = pickle.loads(decoded_data)
|
||||
else:
|
||||
self.context = ActionContext()
|
||||
|
||||
@@ -79,19 +85,24 @@ class WorkflowExecutor:
|
||||
# 退出条件:队列为空且无运行任务
|
||||
if not self.queue and self.running_tasks == 0:
|
||||
break
|
||||
# 退出条件:出现了错误
|
||||
if not self.success:
|
||||
break
|
||||
if not self.queue:
|
||||
sleep(1)
|
||||
continue
|
||||
# 取出队首节点
|
||||
node_id = self.queue.popleft()
|
||||
# 标记任务开始
|
||||
self.running_tasks += 1
|
||||
|
||||
# 已停机
|
||||
if not global_vars.is_system_stopped:
|
||||
if global_vars.is_system_stopped:
|
||||
break
|
||||
|
||||
# 已执行的跳过
|
||||
if (self.workflow.current_action
|
||||
and str(node_id) in self.workflow.current_action.split(',')):
|
||||
and node_id in self.workflow.current_action.split(',')):
|
||||
continue
|
||||
|
||||
# 提交任务到线程池
|
||||
@@ -119,7 +130,11 @@ class WorkflowExecutor:
|
||||
# 节点执行失败
|
||||
if not state:
|
||||
self.success = False
|
||||
self.errmsg = f"{action.name} 执行失败"
|
||||
self.errmsg = f"{action.name} 失败"
|
||||
# 标记任务完成
|
||||
with self.lock:
|
||||
self.running_tasks -= 1
|
||||
|
||||
return
|
||||
|
||||
with self.lock:
|
||||
@@ -145,12 +160,9 @@ class WorkflowExecutor:
|
||||
"""
|
||||
合并上下文
|
||||
"""
|
||||
# 遍历上下文,补充缺失的字段
|
||||
self_context_dict = self.context.dict()
|
||||
for key, value in context.dict().items():
|
||||
if key not in self_context_dict:
|
||||
self_context_dict[key] = value
|
||||
self.context = ActionContext(**self_context_dict)
|
||||
if not getattr(self.context, key, None):
|
||||
setattr(self.context, key, value)
|
||||
|
||||
|
||||
class WorkflowChain(ChainBase):
|
||||
@@ -173,7 +185,13 @@ class WorkflowChain(ChainBase):
|
||||
"""
|
||||
保存上下文到数据库
|
||||
"""
|
||||
self.workflowoper.step(workflow_id, action_id=action.id, context=context.dict())
|
||||
# 序列化数据
|
||||
serialized_data = pickle.dumps(context)
|
||||
# 使用Base64编码字节流
|
||||
encoded_data = base64.b64encode(serialized_data).decode('utf-8')
|
||||
self.workflowoper.step(workflow_id, action_id=action.id, context={
|
||||
"content": encoded_data
|
||||
})
|
||||
|
||||
# 重置工作流
|
||||
if from_begin:
|
||||
@@ -198,8 +216,8 @@ class WorkflowChain(ChainBase):
|
||||
executor = WorkflowExecutor(workflow, step_callback=save_step)
|
||||
executor.execute()
|
||||
|
||||
if executor.success:
|
||||
logger.info(f"工作流 {workflow.name} 执行失败:{executor.errmsg}")
|
||||
if not executor.success:
|
||||
logger.info(f"工作流 {workflow.name} 执行失败:{executor.errmsg}!")
|
||||
self.workflowoper.fail(workflow_id, result=executor.errmsg)
|
||||
return False, executor.errmsg
|
||||
else:
|
||||
|
||||
@@ -65,16 +65,22 @@ class WorkFlowManager(metaclass=Singleton):
|
||||
action_obj = self._actions[action.type]
|
||||
# 执行
|
||||
logger.info(f"执行动作: {action.id} - {action.name}")
|
||||
result_context = action_obj.execute(action.data, context)
|
||||
try:
|
||||
result_context = action_obj.execute(action.data, context)
|
||||
except Exception as err:
|
||||
logger.error(f"{action.name} 执行失败: {err}")
|
||||
return False, context
|
||||
if action_obj.success:
|
||||
logger.info(f"{action.name} 执行成功")
|
||||
else:
|
||||
logger.error(f"{action.name} 执行失败")
|
||||
if action.data.loop and action.data.loop_interval:
|
||||
loop = action.data.get("loop")
|
||||
loop_interval = action.data.get("loop_interval")
|
||||
if loop and loop_interval:
|
||||
while not action_obj.done:
|
||||
# 等待
|
||||
logger.info(f"{action.name} 等待 {action.data.loop_interval} 秒后继续执行 ...")
|
||||
sleep(action.data.loop_interval)
|
||||
logger.info(f"{action.name} 等待 {loop_interval} 秒后继续执行 ...")
|
||||
sleep(loop_interval)
|
||||
# 执行
|
||||
logger.info(f"继续执行动作: {action.id} - {action.name}")
|
||||
result_context = action_obj.execute(action.data, result_context)
|
||||
|
||||
@@ -48,8 +48,8 @@ class Action(BaseModel):
|
||||
type: Optional[str] = Field(None, description="动作类型 (类名)")
|
||||
name: Optional[str] = Field(None, description="动作名称")
|
||||
description: Optional[str] = Field(None, description="动作描述")
|
||||
data: Optional[ActionParams] = Field({}, description="参数")
|
||||
position: Optional[dict] = Field({}, description="位置")
|
||||
data: Optional[dict] = Field({}, description="参数")
|
||||
|
||||
|
||||
class ActionContext(BaseModel):
|
||||
|
||||
Reference in New Issue
Block a user