From 1b83abe15521dae85164d99c08ed376da218ec8e Mon Sep 17 00:00:00 2001 From: jxxghp Date: Fri, 12 Jun 2026 08:43:17 +0800 Subject: [PATCH] fix: implement tool execution timeout handling and improve blocking call management --- app/agent/__init__.py | 12 +- app/agent/tools/base.py | 78 ++++++++++-- app/agent/tools/impl/_plugin_tool_utils.py | 5 +- app/agent/tools/impl/browse_webpage.py | 28 ++--- app/agent/tools/impl/execute_command.py | 3 + app/agent/tools/impl/search_web.py | 4 +- app/agent/tools/impl/send_voice_message.py | 5 +- app/agent/tools/manager.py | 17 ++- tests/test_agent_plugin_tools.py | 15 +-- tests/test_agent_tool_timeouts.py | 140 +++++++++++++++++++++ tests/test_execute_command_tool.py | 34 +++++ 11 files changed, 296 insertions(+), 45 deletions(-) create mode 100644 tests/test_agent_tool_timeouts.py diff --git a/app/agent/__init__.py b/app/agent/__init__.py index 0edf7609..2e6d6338 100644 --- a/app/agent/__init__.py +++ b/app/agent/__init__.py @@ -279,6 +279,15 @@ class MoviePilotAgent: except (TypeError, ValueError): return None + @staticmethod + def _get_recursion_limit() -> int: + """读取 LangGraph 递归上限,防止模型持续循环调用工具。""" + try: + limit = int(settings.LLM_MAX_ITERATIONS or 0) + except (TypeError, ValueError): + limit = 0 + return limit if limit > 0 else 128 + @classmethod def _get_model_name(cls, model: Any) -> Optional[str]: return ( @@ -1024,7 +1033,8 @@ class MoviePilotAgent: agent_config = { "configurable": { "thread_id": self.session_id, - } + }, + "recursion_limit": self._get_recursion_limit(), } # 判断是否启用流式输出 diff --git a/app/agent/tools/base.py b/app/agent/tools/base.py index 86e981a1..1a85b76f 100644 --- a/app/agent/tools/base.py +++ b/app/agent/tools/base.py @@ -76,6 +76,7 @@ def format_tool_result_for_agent( # 将常见的阻塞调用按能力域拆分到独立线程池,避免外部慢 IO 抢占同一批 worker。 _BLOCKING_BUCKET_LIMITS = { + "command": 4, "default": 4, "config": 2, "db": 4, @@ -86,6 +87,7 @@ _BLOCKING_BUCKET_LIMITS = { "site": 4, "storage": 4, "subscribe": 2, + "web": 2, "workflow": 2, } _blocking_semaphores = { @@ -112,6 +114,54 @@ def _get_blocking_executor(bucket: str) -> ThreadPoolExecutor: return executor +class ToolExecutionTimeoutError(TimeoutError): + """Agent 工具执行超时异常。""" + + +def _get_tool_timeout_seconds() -> Optional[float]: + """读取工具执行超时时间,配置为 0 或负数时表示不限制。""" + try: + timeout = float(settings.LLM_TOOL_TIMEOUT or 0) + except (TypeError, ValueError): + timeout = 0 + return timeout if timeout > 0 else None + + +async def run_agent_blocking( + bucket: str, func: Callable[..., Any], *args: Any, **kwargs: Any +) -> Any: + """ + 在受控线程池中运行阻塞型同步代码。 + + 调用方被取消时不会提前释放并发名额,避免底层阻塞调用仍在运行时继续接纳 + 新任务,把同一类慢 IO 的线程池持续打满。 + """ + bucket_name = bucket if bucket in _BLOCKING_BUCKET_LIMITS else "default" + semaphore = _blocking_semaphores[bucket_name] + bound_call = partial(func, *args, **kwargs) + loop = asyncio.get_running_loop() + + await semaphore.acquire() + try: + future = _get_blocking_executor(bucket_name).submit(bound_call) + except Exception: + semaphore.release() + raise + + def _release_semaphore(_future) -> None: + try: + _future.exception() + except Exception: + pass + try: + loop.call_soon_threadsafe(semaphore.release) + except RuntimeError: + pass + + future.add_done_callback(_release_semaphore) + return await asyncio.shield(asyncio.wrap_future(future, loop=loop)) + + class MoviePilotTool(BaseTool, metaclass=ABCMeta): """ MoviePilot专用工具基类(LangChain v1 / langchain_core) @@ -236,7 +286,7 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta): # 执行具体工具逻辑 try: - result = await self.run(**kwargs) + result = await self.run_with_timeout(**kwargs) # 记录工具执行结果摘要日志 str_result = serialize_tool_result_for_agent(result) @@ -246,6 +296,10 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta): summary = str_result logger.info(f"Agent工具 {self.name} 执行完成,结果摘要: {summary}") + except ToolExecutionTimeoutError as e: + error_message = str(e) + logger.warning(error_message) + result = error_message except Exception as e: error_message = f"工具执行异常 ({type(e).__name__}): {str(e)}" logger.error(f"Tool {self.name} execution failed: {e}", exc_info=True) @@ -276,6 +330,18 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta): """子类实现具体的工具执行逻辑""" raise NotImplementedError + async def run_with_timeout(self, **kwargs) -> str: + """按系统配置限制单个工具调用的最长执行时间。""" + timeout = _get_tool_timeout_seconds() + if not timeout: + return await self.run(**kwargs) + try: + return await asyncio.wait_for(self.run(**kwargs), timeout=timeout) + except asyncio.TimeoutError as err: + raise ToolExecutionTimeoutError( + f"工具 {self.name} 执行超时(超过 {timeout:g} 秒),已停止等待结果。" + ) from err + @staticmethod async def run_blocking( bucket: str, func: Callable[..., Any], *args: Any, **kwargs: Any @@ -283,15 +349,7 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta): """ 在受控线程池中运行阻塞型同步代码,避免拖住 FastAPI 主事件循环。 """ - bucket_name = bucket if bucket in _BLOCKING_BUCKET_LIMITS else "default" - semaphore = _blocking_semaphores[bucket_name] - bound_call = partial(func, *args, **kwargs) - - async with semaphore: - loop = asyncio.get_running_loop() - return await loop.run_in_executor( - _get_blocking_executor(bucket_name), bound_call - ) + return await run_agent_blocking(bucket, func, *args, **kwargs) def set_message_attr(self, channel: str, source: str, username: str): """ diff --git a/app/agent/tools/impl/_plugin_tool_utils.py b/app/agent/tools/impl/_plugin_tool_utils.py index 85769066..6c92745b 100644 --- a/app/agent/tools/impl/_plugin_tool_utils.py +++ b/app/agent/tools/impl/_plugin_tool_utils.py @@ -1,6 +1,5 @@ """插件 Agent 工具共享辅助方法""" -import asyncio import json import shutil from typing import Any, Optional @@ -251,7 +250,9 @@ async def install_plugin_runtime( SystemConfigKey.UserInstalledPlugins, install_plugins ) - await asyncio.to_thread(reload_plugin_runtime, plugin_id) + from app.agent.tools.base import run_agent_blocking + + await run_agent_blocking("plugin", reload_plugin_runtime, plugin_id) return True, message or "插件安装成功", refreshed_only diff --git a/app/agent/tools/impl/browse_webpage.py b/app/agent/tools/impl/browse_webpage.py index 035a1496..eeee294a 100644 --- a/app/agent/tools/impl/browse_webpage.py +++ b/app/agent/tools/impl/browse_webpage.py @@ -1,6 +1,5 @@ """浏览器操作工具 - 让Agent能够通过Playwright控制浏览器进行网页交互""" -import asyncio import base64 import json from enum import Enum @@ -167,21 +166,18 @@ class BrowseWebpageTool(MoviePilotTool): if browser_action == BrowserAction.EVALUATE and not script: return "错误: 'evaluate' 操作需要提供 script 参数" - # 在线程池中运行同步的 Playwright 操作 - loop = asyncio.get_running_loop() - result = await loop.run_in_executor( - None, - lambda: self._execute_browser_action( - browser_action=browser_action, - url=url, - selector=selector, - value=value, - script=script, - content_type=content_type, - timeout=timeout, - cookies=cookies, - user_agent=user_agent, - ), + result = await self.run_blocking( + "web", + self._execute_browser_action, + browser_action=browser_action, + url=url, + selector=selector, + value=value, + script=script, + content_type=content_type, + timeout=timeout, + cookies=cookies, + user_agent=user_agent, ) return result diff --git a/app/agent/tools/impl/execute_command.py b/app/agent/tools/impl/execute_command.py index db10e8bc..e2b417ca 100644 --- a/app/agent/tools/impl/execute_command.py +++ b/app/agent/tools/impl/execute_command.py @@ -451,6 +451,9 @@ class ExecuteCommandTool(MoviePilotTool): except asyncio.TimeoutError: timed_out = True await self._cleanup_process(process, wait_task) + except asyncio.CancelledError: + await self._cleanup_process(process, wait_task) + raise try: await self._finish_reader_tasks(reader_tasks) diff --git a/app/agent/tools/impl/search_web.py b/app/agent/tools/impl/search_web.py index 14b4254c..ad718af1 100644 --- a/app/agent/tools/impl/search_web.py +++ b/app/agent/tools/impl/search_web.py @@ -1,4 +1,3 @@ -import asyncio import json import re from dataclasses import dataclass @@ -445,8 +444,7 @@ class SearchWebTool(MoviePilotTool): logger.warning(f"搜索引擎搜索进程失败: {err}") return results - loop = asyncio.get_running_loop() - results = await loop.run_in_executor(None, sync_search) + results = await self.run_blocking("web", sync_search) return self._filter_results_by_site(results, site_filter) except Exception as e: diff --git a/app/agent/tools/impl/send_voice_message.py b/app/agent/tools/impl/send_voice_message.py index 982ceabe..7b16dfb7 100644 --- a/app/agent/tools/impl/send_voice_message.py +++ b/app/agent/tools/impl/send_voice_message.py @@ -1,6 +1,4 @@ """发送语音消息工具。""" - -import asyncio from typing import Optional, Type from pydantic import BaseModel, Field @@ -74,7 +72,8 @@ class SendVoiceMessageTool(MoviePilotTool): reply_mode == AgentCapabilityManager.REPLY_MODE_NATIVE and AgentCapabilityManager.is_audio_output_available() ): - voice_file = await asyncio.to_thread( + voice_file = await self.run_blocking( + "default", AgentCapabilityManager.synthesize_speech, message ) if voice_file: diff --git a/app/agent/tools/manager.py b/app/agent/tools/manager.py index ec8a5db6..96301d23 100644 --- a/app/agent/tools/manager.py +++ b/app/agent/tools/manager.py @@ -2,7 +2,7 @@ import json import uuid from typing import Any, Dict, List, Optional -from app.agent.tools.base import format_tool_result_for_agent +from app.agent.tools.base import ToolExecutionTimeoutError, format_tool_result_for_agent from app.agent.tools.factory import MoviePilotToolFactory from app.log import logger @@ -259,10 +259,14 @@ class MoviePilotToolsManager: # 调用工具的run方法。HTTP/MCP 工具调用不会经过 BaseTool._arun, # 因此这里也必须复用同一套返回值格式化和兜底截断逻辑。 - result = await tool_instance.run(**normalized_arguments) + result = await tool_instance.run_with_timeout(**normalized_arguments) # 记录工具执行结果摘要日志 - str_result = format_tool_result_for_agent(result, tool_name=tool_name, max_chars=getattr(tool_instance, "result_max_chars", None)) + str_result = format_tool_result_for_agent( + result, + tool_name=tool_name, + max_chars=getattr(tool_instance, "result_max_chars", None), + ) if len(str_result) > 500: summary = str_result[:500] + f"...(已截断,总长度: {len(str_result)})" else: @@ -270,6 +274,13 @@ class MoviePilotToolsManager: logger.info(f"Agent工具 {tool_name} 执行完成,结果摘要: {summary}") return str_result + except ToolExecutionTimeoutError as e: + logger.warning(str(e)) + return format_tool_result_for_agent( + str(e), + tool_name=tool_name, + max_chars=getattr(tool_instance, "result_max_chars", None), + ) except Exception as e: logger.error(f"调用工具 {tool_name} 时发生错误: {e}", exc_info=True) error_msg = json.dumps( diff --git a/tests/test_agent_plugin_tools.py b/tests/test_agent_plugin_tools.py index 4c9bd11f..d6127995 100644 --- a/tests/test_agent_plugin_tools.py +++ b/tests/test_agent_plugin_tools.py @@ -179,8 +179,8 @@ class TestAgentPluginTools(unittest.TestCase): config_oper.get.return_value = ["DemoPlugin"] calls = [] - async def fake_to_thread(func, *args, **kwargs): - calls.append((func, args, kwargs)) + async def fake_run_agent_blocking(bucket, func, *args, **kwargs): + calls.append((bucket, func, args, kwargs)) return None with patch( @@ -198,8 +198,8 @@ class TestAgentPluginTools(unittest.TestCase): "app.agent.tools.impl._plugin_tool_utils.MoviePilotServerHelper.async_install_plugin_reg", AsyncMock(return_value=True), ) as install_reg, patch( - "app.agent.tools.impl._plugin_tool_utils.asyncio.to_thread", - side_effect=fake_to_thread, + "app.agent.tools.base.run_agent_blocking", + side_effect=fake_run_agent_blocking, ): success, message, refreshed_only = asyncio.run( install_plugin_runtime( @@ -217,9 +217,10 @@ class TestAgentPluginTools(unittest.TestCase): repo_url="https://example.com/market", ) self.assertEqual(1, len(calls)) - self.assertEqual(reload_runtime, calls[0][0]) - self.assertEqual(("DemoPlugin",), calls[0][1]) - self.assertEqual({}, calls[0][2]) + self.assertEqual("plugin", calls[0][0]) + self.assertEqual(reload_runtime, calls[0][1]) + self.assertEqual(("DemoPlugin",), calls[0][2]) + self.assertEqual({}, calls[0][3]) def test_uninstall_plugin_uninstalls_installed_candidate(self): tool = UninstallPluginTool(session_id="session-1", user_id="10001") diff --git a/tests/test_agent_tool_timeouts.py b/tests/test_agent_tool_timeouts.py new file mode 100644 index 00000000..50e4af4b --- /dev/null +++ b/tests/test_agent_tool_timeouts.py @@ -0,0 +1,140 @@ +import asyncio +import threading +from unittest.mock import patch + +import pytest + +from app.agent.tools.base import MoviePilotTool +from app.agent.tools.manager import MoviePilotToolsManager + + +class SlowAgentTool(MoviePilotTool): + """用于验证工具超时保护的慢工具。""" + + name: str = "slow_agent_tool" + description: str = "Test slow tool." + + async def run(self, **kwargs) -> str: + """等待足够久以触发测试中的短超时。""" + await asyncio.sleep(1) + return "finished" + + +class BlockingAgentTool(MoviePilotTool): + """用于验证阻塞调用并发名额释放时机的工具。""" + + name: str = "blocking_agent_tool" + description: str = "Test blocking tool." + + async def run(self, **kwargs) -> str: + """本测试不会直接调用该方法。""" + return "unused" + + +def test_arun_returns_timeout_message_when_tool_exceeds_limit(): + """LangChain 工具入口应按 LLM_TOOL_TIMEOUT 停止等待慢工具。""" + tool = SlowAgentTool(session_id="session-1", user_id="10001") + + async def _run_tool(): + with patch("app.agent.tools.base.settings.LLM_TOOL_TIMEOUT", 0.05): + return await tool._arun() + + result = asyncio.run(_run_tool()) + + assert "工具 slow_agent_tool 执行超时" in result + assert "超过 0.05 秒" in result + + +def test_http_tool_manager_uses_same_timeout_guard(): + """HTTP/MCP 工具入口绕过 _arun 时也应复用工具超时保护。""" + manager = MoviePilotToolsManager(is_admin=True) + manager.tools = [SlowAgentTool(session_id="session-1", user_id="10001")] + + async def _call_tool(): + with patch("app.agent.tools.base.settings.LLM_TOOL_TIMEOUT", 0.05): + return await manager.call_tool("slow_agent_tool", {}) + + result = asyncio.run(_call_tool()) + + assert "工具 slow_agent_tool 执行超时" in result + + +def test_run_blocking_keeps_bucket_slot_until_worker_finishes(): + """被取消的阻塞调用在底层线程结束前不应释放同桶并发名额。""" + tool = BlockingAgentTool(session_id="session-1", user_id="10001") + started = asyncio.Event() + release = threading.Event() + + def _blocking_call() -> str: + loop.call_soon_threadsafe(started.set) + release.wait() + return "done" + + async def _run_scenario(): + nonlocal loop + loop = asyncio.get_running_loop() + with patch.dict( + "app.agent.tools.base._blocking_semaphores", + {"subscribe": asyncio.Semaphore(1)}, + ): + task = asyncio.create_task(tool.run_blocking("subscribe", _blocking_call)) + await started.wait() + task.cancel() + with pytest.raises(asyncio.CancelledError): + await task + + second_task = asyncio.create_task( + tool.run_blocking("subscribe", lambda: "second") + ) + await asyncio.sleep(0.05) + assert not second_task.done() + + release.set() + assert await asyncio.wait_for(second_task, timeout=1) == "second" + + loop = None + asyncio.run(_run_scenario()) + + +def test_create_agent_config_uses_llm_max_iterations(): + """Agent 执行配置应把 LLM_MAX_ITERATIONS 传给 LangGraph recursion_limit。""" + from app.agent import MoviePilotAgent + from langchain_core.messages import AIMessage + + class _FakeGraphState: + """提供最小 LangGraph 状态替身。""" + + values = {"messages": [AIMessage(content="ok")]} + + class _FakeAgent: + """记录 ainvoke 收到的 config。""" + + def __init__(self) -> None: + self.config = None + + async def ainvoke(self, _payload, config=None): + """保存运行配置供断言。""" + self.config = config + + def get_state(self, _config): + """返回最小消息状态。""" + return _FakeGraphState() + + async def _execute() -> dict: + agent = MoviePilotAgent(session_id="session-1", user_id="10001") + fake_agent = _FakeAgent() + agent._should_stream = lambda: False + + async def _create_agent(streaming=False): + """返回测试替身 Agent。""" + return fake_agent + + agent._create_agent = _create_agent + agent.stream_handler.stop_streaming = lambda: asyncio.sleep(0, result=(False, "")) + with patch("app.agent.settings.LLM_MAX_ITERATIONS", 7): + await agent._execute_agent([]) + return fake_agent.config + + config = asyncio.run(_execute()) + + assert config["recursion_limit"] == 7 diff --git a/tests/test_execute_command_tool.py b/tests/test_execute_command_tool.py index 7234ef87..94baa380 100644 --- a/tests/test_execute_command_tool.py +++ b/tests/test_execute_command_tool.py @@ -7,6 +7,7 @@ import subprocess import sys import time import unittest +from unittest.mock import patch from app.agent.tools.impl.execute_command import ( ExecuteCommandTool, @@ -69,6 +70,39 @@ class TestExecuteCommandTool(unittest.TestCase): self.assertIn("命令执行超时", result) self.assertIn("started", result) + def test_cancelled_run_cleans_up_process(self): + """外层取消 action=run 时应同步清理已经启动的子进程。""" + async def _run_and_cancel(): + tool = ExecuteCommandTool(session_id="session-1", user_id="10001") + command = _python_command("import time; time.sleep(20)") + original_create = asyncio.create_subprocess_shell + process_holder = {} + + async def wrapped_create(*args, **kwargs): + process = await original_create(*args, **kwargs) + process_holder["process"] = process + return process + + with patch( + "app.agent.tools.impl.execute_command.asyncio.create_subprocess_shell", + side_effect=wrapped_create, + ): + task = asyncio.create_task( + tool.run(action="run", command=command, timeout=60) + ) + for _ in range(50): + if "process" in process_holder: + break + await asyncio.sleep(0.02) + self.assertIn("process", process_holder) + task.cancel() + with self.assertRaises(asyncio.CancelledError): + await task + return process_holder["process"] + + process = asyncio.run(_run_and_cancel()) + self.assertIsNotNone(process.returncode) + def test_timeout_with_large_output_writes_partial_full_log_to_temp_file(self): """超时且输出较大时,终止前完整输出应写入临时文件。""" command = _python_command(