mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-06-16 15:08:22 +08:00
fix: implement tool execution timeout handling and improve blocking call management
This commit is contained in:
@@ -279,6 +279,15 @@ class MoviePilotAgent:
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _get_recursion_limit() -> int:
|
||||
"""读取 LangGraph 递归上限,防止模型持续循环调用工具。"""
|
||||
try:
|
||||
limit = int(settings.LLM_MAX_ITERATIONS or 0)
|
||||
except (TypeError, ValueError):
|
||||
limit = 0
|
||||
return limit if limit > 0 else 128
|
||||
|
||||
@classmethod
|
||||
def _get_model_name(cls, model: Any) -> Optional[str]:
|
||||
return (
|
||||
@@ -1024,7 +1033,8 @@ class MoviePilotAgent:
|
||||
agent_config = {
|
||||
"configurable": {
|
||||
"thread_id": self.session_id,
|
||||
}
|
||||
},
|
||||
"recursion_limit": self._get_recursion_limit(),
|
||||
}
|
||||
|
||||
# 判断是否启用流式输出
|
||||
|
||||
@@ -76,6 +76,7 @@ def format_tool_result_for_agent(
|
||||
|
||||
# 将常见的阻塞调用按能力域拆分到独立线程池,避免外部慢 IO 抢占同一批 worker。
|
||||
_BLOCKING_BUCKET_LIMITS = {
|
||||
"command": 4,
|
||||
"default": 4,
|
||||
"config": 2,
|
||||
"db": 4,
|
||||
@@ -86,6 +87,7 @@ _BLOCKING_BUCKET_LIMITS = {
|
||||
"site": 4,
|
||||
"storage": 4,
|
||||
"subscribe": 2,
|
||||
"web": 2,
|
||||
"workflow": 2,
|
||||
}
|
||||
_blocking_semaphores = {
|
||||
@@ -112,6 +114,54 @@ def _get_blocking_executor(bucket: str) -> ThreadPoolExecutor:
|
||||
return executor
|
||||
|
||||
|
||||
class ToolExecutionTimeoutError(TimeoutError):
|
||||
"""Agent 工具执行超时异常。"""
|
||||
|
||||
|
||||
def _get_tool_timeout_seconds() -> Optional[float]:
|
||||
"""读取工具执行超时时间,配置为 0 或负数时表示不限制。"""
|
||||
try:
|
||||
timeout = float(settings.LLM_TOOL_TIMEOUT or 0)
|
||||
except (TypeError, ValueError):
|
||||
timeout = 0
|
||||
return timeout if timeout > 0 else None
|
||||
|
||||
|
||||
async def run_agent_blocking(
|
||||
bucket: str, func: Callable[..., Any], *args: Any, **kwargs: Any
|
||||
) -> Any:
|
||||
"""
|
||||
在受控线程池中运行阻塞型同步代码。
|
||||
|
||||
调用方被取消时不会提前释放并发名额,避免底层阻塞调用仍在运行时继续接纳
|
||||
新任务,把同一类慢 IO 的线程池持续打满。
|
||||
"""
|
||||
bucket_name = bucket if bucket in _BLOCKING_BUCKET_LIMITS else "default"
|
||||
semaphore = _blocking_semaphores[bucket_name]
|
||||
bound_call = partial(func, *args, **kwargs)
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
await semaphore.acquire()
|
||||
try:
|
||||
future = _get_blocking_executor(bucket_name).submit(bound_call)
|
||||
except Exception:
|
||||
semaphore.release()
|
||||
raise
|
||||
|
||||
def _release_semaphore(_future) -> None:
|
||||
try:
|
||||
_future.exception()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
loop.call_soon_threadsafe(semaphore.release)
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
future.add_done_callback(_release_semaphore)
|
||||
return await asyncio.shield(asyncio.wrap_future(future, loop=loop))
|
||||
|
||||
|
||||
class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
"""
|
||||
MoviePilot专用工具基类(LangChain v1 / langchain_core)
|
||||
@@ -236,7 +286,7 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
|
||||
# 执行具体工具逻辑
|
||||
try:
|
||||
result = await self.run(**kwargs)
|
||||
result = await self.run_with_timeout(**kwargs)
|
||||
|
||||
# 记录工具执行结果摘要日志
|
||||
str_result = serialize_tool_result_for_agent(result)
|
||||
@@ -246,6 +296,10 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
summary = str_result
|
||||
logger.info(f"Agent工具 {self.name} 执行完成,结果摘要: {summary}")
|
||||
|
||||
except ToolExecutionTimeoutError as e:
|
||||
error_message = str(e)
|
||||
logger.warning(error_message)
|
||||
result = error_message
|
||||
except Exception as e:
|
||||
error_message = f"工具执行异常 ({type(e).__name__}): {str(e)}"
|
||||
logger.error(f"Tool {self.name} execution failed: {e}", exc_info=True)
|
||||
@@ -276,6 +330,18 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
"""子类实现具体的工具执行逻辑"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def run_with_timeout(self, **kwargs) -> str:
|
||||
"""按系统配置限制单个工具调用的最长执行时间。"""
|
||||
timeout = _get_tool_timeout_seconds()
|
||||
if not timeout:
|
||||
return await self.run(**kwargs)
|
||||
try:
|
||||
return await asyncio.wait_for(self.run(**kwargs), timeout=timeout)
|
||||
except asyncio.TimeoutError as err:
|
||||
raise ToolExecutionTimeoutError(
|
||||
f"工具 {self.name} 执行超时(超过 {timeout:g} 秒),已停止等待结果。"
|
||||
) from err
|
||||
|
||||
@staticmethod
|
||||
async def run_blocking(
|
||||
bucket: str, func: Callable[..., Any], *args: Any, **kwargs: Any
|
||||
@@ -283,15 +349,7 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
"""
|
||||
在受控线程池中运行阻塞型同步代码,避免拖住 FastAPI 主事件循环。
|
||||
"""
|
||||
bucket_name = bucket if bucket in _BLOCKING_BUCKET_LIMITS else "default"
|
||||
semaphore = _blocking_semaphores[bucket_name]
|
||||
bound_call = partial(func, *args, **kwargs)
|
||||
|
||||
async with semaphore:
|
||||
loop = asyncio.get_running_loop()
|
||||
return await loop.run_in_executor(
|
||||
_get_blocking_executor(bucket_name), bound_call
|
||||
)
|
||||
return await run_agent_blocking(bucket, func, *args, **kwargs)
|
||||
|
||||
def set_message_attr(self, channel: str, source: str, username: str):
|
||||
"""
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
"""插件 Agent 工具共享辅助方法"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import shutil
|
||||
from typing import Any, Optional
|
||||
@@ -251,7 +250,9 @@ async def install_plugin_runtime(
|
||||
SystemConfigKey.UserInstalledPlugins, install_plugins
|
||||
)
|
||||
|
||||
await asyncio.to_thread(reload_plugin_runtime, plugin_id)
|
||||
from app.agent.tools.base import run_agent_blocking
|
||||
|
||||
await run_agent_blocking("plugin", reload_plugin_runtime, plugin_id)
|
||||
return True, message or "插件安装成功", refreshed_only
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
"""浏览器操作工具 - 让Agent能够通过Playwright控制浏览器进行网页交互"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
from enum import Enum
|
||||
@@ -167,21 +166,18 @@ class BrowseWebpageTool(MoviePilotTool):
|
||||
if browser_action == BrowserAction.EVALUATE and not script:
|
||||
return "错误: 'evaluate' 操作需要提供 script 参数"
|
||||
|
||||
# 在线程池中运行同步的 Playwright 操作
|
||||
loop = asyncio.get_running_loop()
|
||||
result = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: self._execute_browser_action(
|
||||
browser_action=browser_action,
|
||||
url=url,
|
||||
selector=selector,
|
||||
value=value,
|
||||
script=script,
|
||||
content_type=content_type,
|
||||
timeout=timeout,
|
||||
cookies=cookies,
|
||||
user_agent=user_agent,
|
||||
),
|
||||
result = await self.run_blocking(
|
||||
"web",
|
||||
self._execute_browser_action,
|
||||
browser_action=browser_action,
|
||||
url=url,
|
||||
selector=selector,
|
||||
value=value,
|
||||
script=script,
|
||||
content_type=content_type,
|
||||
timeout=timeout,
|
||||
cookies=cookies,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
@@ -451,6 +451,9 @@ class ExecuteCommandTool(MoviePilotTool):
|
||||
except asyncio.TimeoutError:
|
||||
timed_out = True
|
||||
await self._cleanup_process(process, wait_task)
|
||||
except asyncio.CancelledError:
|
||||
await self._cleanup_process(process, wait_task)
|
||||
raise
|
||||
|
||||
try:
|
||||
await self._finish_reader_tasks(reader_tasks)
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
@@ -445,8 +444,7 @@ class SearchWebTool(MoviePilotTool):
|
||||
logger.warning(f"搜索引擎搜索进程失败: {err}")
|
||||
return results
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
results = await loop.run_in_executor(None, sync_search)
|
||||
results = await self.run_blocking("web", sync_search)
|
||||
return self._filter_results_by_site(results, site_filter)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
"""发送语音消息工具。"""
|
||||
|
||||
import asyncio
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
@@ -74,7 +72,8 @@ class SendVoiceMessageTool(MoviePilotTool):
|
||||
reply_mode == AgentCapabilityManager.REPLY_MODE_NATIVE
|
||||
and AgentCapabilityManager.is_audio_output_available()
|
||||
):
|
||||
voice_file = await asyncio.to_thread(
|
||||
voice_file = await self.run_blocking(
|
||||
"default",
|
||||
AgentCapabilityManager.synthesize_speech, message
|
||||
)
|
||||
if voice_file:
|
||||
|
||||
@@ -2,7 +2,7 @@ import json
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.agent.tools.base import format_tool_result_for_agent
|
||||
from app.agent.tools.base import ToolExecutionTimeoutError, format_tool_result_for_agent
|
||||
from app.agent.tools.factory import MoviePilotToolFactory
|
||||
from app.log import logger
|
||||
|
||||
@@ -259,10 +259,14 @@ class MoviePilotToolsManager:
|
||||
|
||||
# 调用工具的run方法。HTTP/MCP 工具调用不会经过 BaseTool._arun,
|
||||
# 因此这里也必须复用同一套返回值格式化和兜底截断逻辑。
|
||||
result = await tool_instance.run(**normalized_arguments)
|
||||
result = await tool_instance.run_with_timeout(**normalized_arguments)
|
||||
|
||||
# 记录工具执行结果摘要日志
|
||||
str_result = format_tool_result_for_agent(result, tool_name=tool_name, max_chars=getattr(tool_instance, "result_max_chars", None))
|
||||
str_result = format_tool_result_for_agent(
|
||||
result,
|
||||
tool_name=tool_name,
|
||||
max_chars=getattr(tool_instance, "result_max_chars", None),
|
||||
)
|
||||
if len(str_result) > 500:
|
||||
summary = str_result[:500] + f"...(已截断,总长度: {len(str_result)})"
|
||||
else:
|
||||
@@ -270,6 +274,13 @@ class MoviePilotToolsManager:
|
||||
logger.info(f"Agent工具 {tool_name} 执行完成,结果摘要: {summary}")
|
||||
|
||||
return str_result
|
||||
except ToolExecutionTimeoutError as e:
|
||||
logger.warning(str(e))
|
||||
return format_tool_result_for_agent(
|
||||
str(e),
|
||||
tool_name=tool_name,
|
||||
max_chars=getattr(tool_instance, "result_max_chars", None),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"调用工具 {tool_name} 时发生错误: {e}", exc_info=True)
|
||||
error_msg = json.dumps(
|
||||
|
||||
@@ -179,8 +179,8 @@ class TestAgentPluginTools(unittest.TestCase):
|
||||
config_oper.get.return_value = ["DemoPlugin"]
|
||||
calls = []
|
||||
|
||||
async def fake_to_thread(func, *args, **kwargs):
|
||||
calls.append((func, args, kwargs))
|
||||
async def fake_run_agent_blocking(bucket, func, *args, **kwargs):
|
||||
calls.append((bucket, func, args, kwargs))
|
||||
return None
|
||||
|
||||
with patch(
|
||||
@@ -198,8 +198,8 @@ class TestAgentPluginTools(unittest.TestCase):
|
||||
"app.agent.tools.impl._plugin_tool_utils.MoviePilotServerHelper.async_install_plugin_reg",
|
||||
AsyncMock(return_value=True),
|
||||
) as install_reg, patch(
|
||||
"app.agent.tools.impl._plugin_tool_utils.asyncio.to_thread",
|
||||
side_effect=fake_to_thread,
|
||||
"app.agent.tools.base.run_agent_blocking",
|
||||
side_effect=fake_run_agent_blocking,
|
||||
):
|
||||
success, message, refreshed_only = asyncio.run(
|
||||
install_plugin_runtime(
|
||||
@@ -217,9 +217,10 @@ class TestAgentPluginTools(unittest.TestCase):
|
||||
repo_url="https://example.com/market",
|
||||
)
|
||||
self.assertEqual(1, len(calls))
|
||||
self.assertEqual(reload_runtime, calls[0][0])
|
||||
self.assertEqual(("DemoPlugin",), calls[0][1])
|
||||
self.assertEqual({}, calls[0][2])
|
||||
self.assertEqual("plugin", calls[0][0])
|
||||
self.assertEqual(reload_runtime, calls[0][1])
|
||||
self.assertEqual(("DemoPlugin",), calls[0][2])
|
||||
self.assertEqual({}, calls[0][3])
|
||||
|
||||
def test_uninstall_plugin_uninstalls_installed_candidate(self):
|
||||
tool = UninstallPluginTool(session_id="session-1", user_id="10001")
|
||||
|
||||
140
tests/test_agent_tool_timeouts.py
Normal file
140
tests/test_agent_tool_timeouts.py
Normal file
@@ -0,0 +1,140 @@
|
||||
import asyncio
|
||||
import threading
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.manager import MoviePilotToolsManager
|
||||
|
||||
|
||||
class SlowAgentTool(MoviePilotTool):
|
||||
"""用于验证工具超时保护的慢工具。"""
|
||||
|
||||
name: str = "slow_agent_tool"
|
||||
description: str = "Test slow tool."
|
||||
|
||||
async def run(self, **kwargs) -> str:
|
||||
"""等待足够久以触发测试中的短超时。"""
|
||||
await asyncio.sleep(1)
|
||||
return "finished"
|
||||
|
||||
|
||||
class BlockingAgentTool(MoviePilotTool):
|
||||
"""用于验证阻塞调用并发名额释放时机的工具。"""
|
||||
|
||||
name: str = "blocking_agent_tool"
|
||||
description: str = "Test blocking tool."
|
||||
|
||||
async def run(self, **kwargs) -> str:
|
||||
"""本测试不会直接调用该方法。"""
|
||||
return "unused"
|
||||
|
||||
|
||||
def test_arun_returns_timeout_message_when_tool_exceeds_limit():
|
||||
"""LangChain 工具入口应按 LLM_TOOL_TIMEOUT 停止等待慢工具。"""
|
||||
tool = SlowAgentTool(session_id="session-1", user_id="10001")
|
||||
|
||||
async def _run_tool():
|
||||
with patch("app.agent.tools.base.settings.LLM_TOOL_TIMEOUT", 0.05):
|
||||
return await tool._arun()
|
||||
|
||||
result = asyncio.run(_run_tool())
|
||||
|
||||
assert "工具 slow_agent_tool 执行超时" in result
|
||||
assert "超过 0.05 秒" in result
|
||||
|
||||
|
||||
def test_http_tool_manager_uses_same_timeout_guard():
|
||||
"""HTTP/MCP 工具入口绕过 _arun 时也应复用工具超时保护。"""
|
||||
manager = MoviePilotToolsManager(is_admin=True)
|
||||
manager.tools = [SlowAgentTool(session_id="session-1", user_id="10001")]
|
||||
|
||||
async def _call_tool():
|
||||
with patch("app.agent.tools.base.settings.LLM_TOOL_TIMEOUT", 0.05):
|
||||
return await manager.call_tool("slow_agent_tool", {})
|
||||
|
||||
result = asyncio.run(_call_tool())
|
||||
|
||||
assert "工具 slow_agent_tool 执行超时" in result
|
||||
|
||||
|
||||
def test_run_blocking_keeps_bucket_slot_until_worker_finishes():
|
||||
"""被取消的阻塞调用在底层线程结束前不应释放同桶并发名额。"""
|
||||
tool = BlockingAgentTool(session_id="session-1", user_id="10001")
|
||||
started = asyncio.Event()
|
||||
release = threading.Event()
|
||||
|
||||
def _blocking_call() -> str:
|
||||
loop.call_soon_threadsafe(started.set)
|
||||
release.wait()
|
||||
return "done"
|
||||
|
||||
async def _run_scenario():
|
||||
nonlocal loop
|
||||
loop = asyncio.get_running_loop()
|
||||
with patch.dict(
|
||||
"app.agent.tools.base._blocking_semaphores",
|
||||
{"subscribe": asyncio.Semaphore(1)},
|
||||
):
|
||||
task = asyncio.create_task(tool.run_blocking("subscribe", _blocking_call))
|
||||
await started.wait()
|
||||
task.cancel()
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await task
|
||||
|
||||
second_task = asyncio.create_task(
|
||||
tool.run_blocking("subscribe", lambda: "second")
|
||||
)
|
||||
await asyncio.sleep(0.05)
|
||||
assert not second_task.done()
|
||||
|
||||
release.set()
|
||||
assert await asyncio.wait_for(second_task, timeout=1) == "second"
|
||||
|
||||
loop = None
|
||||
asyncio.run(_run_scenario())
|
||||
|
||||
|
||||
def test_create_agent_config_uses_llm_max_iterations():
|
||||
"""Agent 执行配置应把 LLM_MAX_ITERATIONS 传给 LangGraph recursion_limit。"""
|
||||
from app.agent import MoviePilotAgent
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
class _FakeGraphState:
|
||||
"""提供最小 LangGraph 状态替身。"""
|
||||
|
||||
values = {"messages": [AIMessage(content="ok")]}
|
||||
|
||||
class _FakeAgent:
|
||||
"""记录 ainvoke 收到的 config。"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.config = None
|
||||
|
||||
async def ainvoke(self, _payload, config=None):
|
||||
"""保存运行配置供断言。"""
|
||||
self.config = config
|
||||
|
||||
def get_state(self, _config):
|
||||
"""返回最小消息状态。"""
|
||||
return _FakeGraphState()
|
||||
|
||||
async def _execute() -> dict:
|
||||
agent = MoviePilotAgent(session_id="session-1", user_id="10001")
|
||||
fake_agent = _FakeAgent()
|
||||
agent._should_stream = lambda: False
|
||||
|
||||
async def _create_agent(streaming=False):
|
||||
"""返回测试替身 Agent。"""
|
||||
return fake_agent
|
||||
|
||||
agent._create_agent = _create_agent
|
||||
agent.stream_handler.stop_streaming = lambda: asyncio.sleep(0, result=(False, ""))
|
||||
with patch("app.agent.settings.LLM_MAX_ITERATIONS", 7):
|
||||
await agent._execute_agent([])
|
||||
return fake_agent.config
|
||||
|
||||
config = asyncio.run(_execute())
|
||||
|
||||
assert config["recursion_limit"] == 7
|
||||
@@ -7,6 +7,7 @@ import subprocess
|
||||
import sys
|
||||
import time
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
from app.agent.tools.impl.execute_command import (
|
||||
ExecuteCommandTool,
|
||||
@@ -69,6 +70,39 @@ class TestExecuteCommandTool(unittest.TestCase):
|
||||
self.assertIn("命令执行超时", result)
|
||||
self.assertIn("started", result)
|
||||
|
||||
def test_cancelled_run_cleans_up_process(self):
|
||||
"""外层取消 action=run 时应同步清理已经启动的子进程。"""
|
||||
async def _run_and_cancel():
|
||||
tool = ExecuteCommandTool(session_id="session-1", user_id="10001")
|
||||
command = _python_command("import time; time.sleep(20)")
|
||||
original_create = asyncio.create_subprocess_shell
|
||||
process_holder = {}
|
||||
|
||||
async def wrapped_create(*args, **kwargs):
|
||||
process = await original_create(*args, **kwargs)
|
||||
process_holder["process"] = process
|
||||
return process
|
||||
|
||||
with patch(
|
||||
"app.agent.tools.impl.execute_command.asyncio.create_subprocess_shell",
|
||||
side_effect=wrapped_create,
|
||||
):
|
||||
task = asyncio.create_task(
|
||||
tool.run(action="run", command=command, timeout=60)
|
||||
)
|
||||
for _ in range(50):
|
||||
if "process" in process_holder:
|
||||
break
|
||||
await asyncio.sleep(0.02)
|
||||
self.assertIn("process", process_holder)
|
||||
task.cancel()
|
||||
with self.assertRaises(asyncio.CancelledError):
|
||||
await task
|
||||
return process_holder["process"]
|
||||
|
||||
process = asyncio.run(_run_and_cancel())
|
||||
self.assertIsNotNone(process.returncode)
|
||||
|
||||
def test_timeout_with_large_output_writes_partial_full_log_to_temp_file(self):
|
||||
"""超时且输出较大时,终止前完整输出应写入临时文件。"""
|
||||
command = _python_command(
|
||||
|
||||
Reference in New Issue
Block a user