From a7c4161f91806b4b7ad1dc7579f8f80f3c2fba8e Mon Sep 17 00:00:00 2001 From: jxxghp Date: Wed, 26 Feb 2025 12:57:57 +0800 Subject: [PATCH] fix workflow executor --- app/chain/workflow.py | 57 ++++++++++++++++++++++++------------------- 1 file changed, 32 insertions(+), 25 deletions(-) diff --git a/app/chain/workflow.py b/app/chain/workflow.py index 3ea0e4b7..f558f737 100644 --- a/app/chain/workflow.py +++ b/app/chain/workflow.py @@ -3,6 +3,8 @@ from collections import defaultdict, deque from concurrent.futures import ThreadPoolExecutor from typing import List, Tuple +from pydantic.fields import Callable + from app.chain import ChainBase from app.core.config import global_vars from app.core.workflow import WorkFlowManager @@ -17,20 +19,32 @@ class WorkflowExecutor: 工作流执行器 """ - def __init__(self, workflow: Workflow): + def __init__(self, workflow: Workflow, step_callback: Callable = None): """ 初始化工作流执行器 + :param workflow: 工作流对象 + :param step_callback: 步骤回调函数 """ - self.workflowoper = WorkflowOper() - self.workflowmanager = WorkFlowManager() - + # 工作流数据 self.workflow = workflow + self.step_callback = step_callback self.actions = {action['id']: Action(**action) for action in workflow.actions} self.flows = [ActionFlow(**flow) for flow in workflow.flows] self.success = True self.errmsg = "" + # 工作流管理器 + self.workflowmanager = WorkFlowManager() + # 线程安全队列 + self.queue = deque() + # 锁用于保证线程安全 + self.lock = threading.Lock() + # 线程池 + self.executor = ThreadPoolExecutor() + # 跟踪运行中的任务数 + self.running_tasks = 0 + # 构建邻接表、入度表 self.adjacency = defaultdict(list) self.indegree = defaultdict(int) @@ -51,15 +65,6 @@ class WorkflowExecutor: else: self.context = ActionContext() - # 线程安全队列 - self.queue = deque() - # 锁用于保证线程安全 - self.lock = threading.Lock() - # 线程池 - self.executor = ThreadPoolExecutor() - # 跟踪运行中的任务数 - self.running_tasks = 0 - # 初始化队列:入度为0的节点 for action_id in self.actions: if self.indegree[action_id] == 0: @@ -102,7 +107,7 @@ class WorkflowExecutor: 执行单个节点操作,返回修改后的上下文和节点ID """ action = self.actions[node_id] - state, result_ctx = self.workflowmanager.excute(action, context) + state, result_ctx = self.workflowmanager.excute(action, context=context) return action, state, result_ctx def on_node_complete(self, future): @@ -117,17 +122,18 @@ class WorkflowExecutor: self.errmsg = f"{action.name} 执行失败" return - # 更新主上下文 with self.lock: + # 更新主上下文 self.merge_context(result_ctx) - self.save_step(action) + # 回调 + if self.step_callback: + self.step_callback(action, self.context) # 处理后继节点 successors = self.adjacency.get(action.id, []) for succ_id in successors: with self.lock: self.indegree[succ_id] -= 1 - print(f"节点 {succ_id} 入度减至 {self.indegree[succ_id]}") if self.indegree[succ_id] == 0: self.queue.append(succ_id) @@ -146,12 +152,6 @@ class WorkflowExecutor: self_context_dict[key] = value self.context = ActionContext(**self_context_dict) - def save_step(self, node_id: int): - """ - 保存上下文到数据库 - """ - self.workflowoper.step(self.workflow.id, action_id=node_id, context=self.context.dict()) - class WorkflowChain(ChainBase): """ @@ -159,8 +159,8 @@ class WorkflowChain(ChainBase): """ def __init__(self): - self.workflowoper = WorkflowOper() super().__init__() + self.workflowoper = WorkflowOper() def process(self, workflow_id: int, from_begin: bool = True) -> Tuple[bool, str]: """ @@ -168,6 +168,13 @@ class WorkflowChain(ChainBase): :param workflow_id: 工作流ID :param from_begin: 是否从头开始,默认为True """ + + def save_step(action: Action, context: ActionContext): + """ + 保存上下文到数据库 + """ + self.workflowoper.step(workflow_id, action_id=action.id, context=context.dict()) + # 重置工作流 if from_begin: self.workflowoper.reset(workflow_id) @@ -188,7 +195,7 @@ class WorkflowChain(ChainBase): self.workflowoper.start(workflow_id) # 执行工作流 - executor = WorkflowExecutor(workflow) + executor = WorkflowExecutor(workflow, step_callback=save_step) executor.execute() if executor.success: