mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-03-20 03:57:30 +08:00
fix workflow executor
This commit is contained in:
@@ -3,6 +3,8 @@ from collections import defaultdict, deque
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import List, Tuple
|
||||
|
||||
from pydantic.fields import Callable
|
||||
|
||||
from app.chain import ChainBase
|
||||
from app.core.config import global_vars
|
||||
from app.core.workflow import WorkFlowManager
|
||||
@@ -17,20 +19,32 @@ class WorkflowExecutor:
|
||||
工作流执行器
|
||||
"""
|
||||
|
||||
def __init__(self, workflow: Workflow):
|
||||
def __init__(self, workflow: Workflow, step_callback: Callable = None):
|
||||
"""
|
||||
初始化工作流执行器
|
||||
:param workflow: 工作流对象
|
||||
:param step_callback: 步骤回调函数
|
||||
"""
|
||||
self.workflowoper = WorkflowOper()
|
||||
self.workflowmanager = WorkFlowManager()
|
||||
|
||||
# 工作流数据
|
||||
self.workflow = workflow
|
||||
self.step_callback = step_callback
|
||||
self.actions = {action['id']: Action(**action) for action in workflow.actions}
|
||||
self.flows = [ActionFlow(**flow) for flow in workflow.flows]
|
||||
|
||||
self.success = True
|
||||
self.errmsg = ""
|
||||
|
||||
# 工作流管理器
|
||||
self.workflowmanager = WorkFlowManager()
|
||||
# 线程安全队列
|
||||
self.queue = deque()
|
||||
# 锁用于保证线程安全
|
||||
self.lock = threading.Lock()
|
||||
# 线程池
|
||||
self.executor = ThreadPoolExecutor()
|
||||
# 跟踪运行中的任务数
|
||||
self.running_tasks = 0
|
||||
|
||||
# 构建邻接表、入度表
|
||||
self.adjacency = defaultdict(list)
|
||||
self.indegree = defaultdict(int)
|
||||
@@ -51,15 +65,6 @@ class WorkflowExecutor:
|
||||
else:
|
||||
self.context = ActionContext()
|
||||
|
||||
# 线程安全队列
|
||||
self.queue = deque()
|
||||
# 锁用于保证线程安全
|
||||
self.lock = threading.Lock()
|
||||
# 线程池
|
||||
self.executor = ThreadPoolExecutor()
|
||||
# 跟踪运行中的任务数
|
||||
self.running_tasks = 0
|
||||
|
||||
# 初始化队列:入度为0的节点
|
||||
for action_id in self.actions:
|
||||
if self.indegree[action_id] == 0:
|
||||
@@ -102,7 +107,7 @@ class WorkflowExecutor:
|
||||
执行单个节点操作,返回修改后的上下文和节点ID
|
||||
"""
|
||||
action = self.actions[node_id]
|
||||
state, result_ctx = self.workflowmanager.excute(action, context)
|
||||
state, result_ctx = self.workflowmanager.excute(action, context=context)
|
||||
return action, state, result_ctx
|
||||
|
||||
def on_node_complete(self, future):
|
||||
@@ -117,17 +122,18 @@ class WorkflowExecutor:
|
||||
self.errmsg = f"{action.name} 执行失败"
|
||||
return
|
||||
|
||||
# 更新主上下文
|
||||
with self.lock:
|
||||
# 更新主上下文
|
||||
self.merge_context(result_ctx)
|
||||
self.save_step(action)
|
||||
# 回调
|
||||
if self.step_callback:
|
||||
self.step_callback(action, self.context)
|
||||
|
||||
# 处理后继节点
|
||||
successors = self.adjacency.get(action.id, [])
|
||||
for succ_id in successors:
|
||||
with self.lock:
|
||||
self.indegree[succ_id] -= 1
|
||||
print(f"节点 {succ_id} 入度减至 {self.indegree[succ_id]}")
|
||||
if self.indegree[succ_id] == 0:
|
||||
self.queue.append(succ_id)
|
||||
|
||||
@@ -146,12 +152,6 @@ class WorkflowExecutor:
|
||||
self_context_dict[key] = value
|
||||
self.context = ActionContext(**self_context_dict)
|
||||
|
||||
def save_step(self, node_id: int):
|
||||
"""
|
||||
保存上下文到数据库
|
||||
"""
|
||||
self.workflowoper.step(self.workflow.id, action_id=node_id, context=self.context.dict())
|
||||
|
||||
|
||||
class WorkflowChain(ChainBase):
|
||||
"""
|
||||
@@ -159,8 +159,8 @@ class WorkflowChain(ChainBase):
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.workflowoper = WorkflowOper()
|
||||
super().__init__()
|
||||
self.workflowoper = WorkflowOper()
|
||||
|
||||
def process(self, workflow_id: int, from_begin: bool = True) -> Tuple[bool, str]:
|
||||
"""
|
||||
@@ -168,6 +168,13 @@ class WorkflowChain(ChainBase):
|
||||
:param workflow_id: 工作流ID
|
||||
:param from_begin: 是否从头开始,默认为True
|
||||
"""
|
||||
|
||||
def save_step(action: Action, context: ActionContext):
|
||||
"""
|
||||
保存上下文到数据库
|
||||
"""
|
||||
self.workflowoper.step(workflow_id, action_id=action.id, context=context.dict())
|
||||
|
||||
# 重置工作流
|
||||
if from_begin:
|
||||
self.workflowoper.reset(workflow_id)
|
||||
@@ -188,7 +195,7 @@ class WorkflowChain(ChainBase):
|
||||
self.workflowoper.start(workflow_id)
|
||||
|
||||
# 执行工作流
|
||||
executor = WorkflowExecutor(workflow)
|
||||
executor = WorkflowExecutor(workflow, step_callback=save_step)
|
||||
executor.execute()
|
||||
|
||||
if executor.success:
|
||||
|
||||
Reference in New Issue
Block a user