From 140d224a9a01ac4346a4469ce230b156e7bbe5e3 Mon Sep 17 00:00:00 2001 From: jxxghp Date: Mon, 27 Apr 2026 07:57:32 +0800 Subject: [PATCH] fix agent stream blocking during command execution Offload synchronous message edits from the event loop and stream subprocess output so long-running commands stay responsive. --- app/agent/callback/__init__.py | 15 +- app/agent/tools/impl/execute_command.py | 194 ++++++++++++++++++++---- tests/test_agent_tool_streaming.py | 62 +++++++- tests/test_execute_command_tool.py | 57 +++++++ 4 files changed, 293 insertions(+), 35 deletions(-) create mode 100644 tests/test_execute_command_tool.py diff --git a/app/agent/callback/__init__.py b/app/agent/callback/__init__.py index 2cf5de92..d80ff816 100644 --- a/app/agent/callback/__init__.py +++ b/app/agent/callback/__init__.py @@ -2,6 +2,8 @@ import asyncio import threading from typing import Optional, Tuple +from fastapi.concurrency import run_in_threadpool + from app.chain import ChainBase from app.log import logger from app.schemas import Notification @@ -256,7 +258,8 @@ class StreamingHandler: try: if self._message_response is None: # 第一次发送:发送新消息并获取 message_id - response = chain.send_direct_message( + response = await run_in_threadpool( + chain.send_direct_message, Notification( channel=self._channel, source=self._source, @@ -264,7 +267,7 @@ class StreamingHandler: username=self._username, title=self._title, text=current_text, - ) + ), ) if response and response.success and response.message_id: self._message_response = response @@ -297,7 +300,8 @@ class StreamingHandler: # 如果偏移后还有新内容,立即发送为新消息 if current_text: - response = chain.send_direct_message( + response = await run_in_threadpool( + chain.send_direct_message, Notification( channel=self._channel, source=self._source, @@ -305,7 +309,7 @@ class StreamingHandler: username=self._username, title=self._title, text=current_text, - ) + ), ) if response and response.success and response.message_id: self._message_response = response @@ -324,7 +328,8 @@ class StreamingHandler: except (ValueError, KeyError): return - success = chain.edit_message( + success = await run_in_threadpool( + chain.edit_message, channel=channel_enum, source=self._message_response.source, message_id=self._message_response.message_id, diff --git a/app/agent/tools/impl/execute_command.py b/app/agent/tools/impl/execute_command.py index a5cfb4b1..c05bbc0f 100644 --- a/app/agent/tools/impl/execute_command.py +++ b/app/agent/tools/impl/execute_command.py @@ -1,7 +1,8 @@ """执行Shell命令工具""" import asyncio -from typing import Optional, Type +import codecs +from typing import Any, Dict, Optional, Type from pydantic import BaseModel, Field @@ -26,12 +27,133 @@ class ExecuteCommandTool(MoviePilotTool): description: str = "Safely execute shell commands on the server. Useful for system maintenance, checking status, or running custom scripts. Includes timeout and output limits." args_schema: Type[BaseModel] = ExecuteCommandInput require_admin: bool = True + RESULT_LIMIT = 3000 + STREAM_CAPTURE_LIMIT = 2000 + LIVE_OUTPUT_LIMIT = 1200 def get_tool_message(self, **kwargs) -> Optional[str]: """根据命令生成友好的提示消息""" command = kwargs.get("command", "") return f"执行系统命令: {command}" + def _build_result( + self, + message: str, + stdout_capture: Dict[str, Any], + stderr_capture: Dict[str, Any], + ) -> str: + stdout_str = "".join(stdout_capture["chunks"]).strip() + stderr_str = "".join(stderr_capture["chunks"]).strip() + + result = message + if stdout_str: + result += f"\n\n标准输出:\n{stdout_str}" + if stderr_str: + result += f"\n\n错误输出:\n{stderr_str}" + if not stdout_str and not stderr_str: + result += "\n\n(无输出内容)" + + was_truncated = stdout_capture["truncated"] or stderr_capture["truncated"] + overflow_suffix = "\n\n...(输出内容过长,已截断)" + if was_truncated or len(result) > self.RESULT_LIMIT: + result = ( + result[: self.RESULT_LIMIT - len(overflow_suffix)] + overflow_suffix + ) + return result + + def _append_capture(self, capture: Dict[str, Any], text: str): + if not text: + return + + remaining = self.STREAM_CAPTURE_LIMIT - capture["length"] + if remaining <= 0: + capture["truncated"] = True + return + + fragment = text[:remaining] + capture["chunks"].append(fragment) + capture["length"] += len(fragment) + if len(text) > remaining: + capture["truncated"] = True + + def _should_emit_live_output(self) -> bool: + return bool( + self._stream_handler + and self._stream_handler.is_streaming + and self._stream_handler.is_auto_flushing + ) + + def _emit_live_output( + self, text: str, stream_name: str, live_state: Dict[str, Any] + ): + if not text or not live_state["enabled"]: + return + + header_key = f"{stream_name}_header_sent" + prefix = "" + if not live_state[header_key]: + prefix = "标准输出:\n" if stream_name == "stdout" else "\n错误输出:\n" + live_state[header_key] = True + + payload = prefix + text + remaining = self.LIVE_OUTPUT_LIMIT - live_state["chars"] + if remaining <= 0: + if not live_state["truncated"]: + self._stream_handler.emit("\n...(命令输出过长,停止实时展示)\n") + live_state["truncated"] = True + return + + fragment = payload[:remaining] + if fragment: + self._stream_handler.emit(fragment) + live_state["chars"] += len(fragment) + + if len(payload) > remaining and not live_state["truncated"]: + self._stream_handler.emit("\n...(命令输出过长,停止实时展示)\n") + live_state["truncated"] = True + + async def _collect_stream( + self, + stream: Optional[asyncio.StreamReader], + stream_name: str, + capture: Dict[str, Any], + live_state: Dict[str, Any], + ): + if not stream: + return + + decoder = codecs.getincrementaldecoder("utf-8")(errors="replace") + while True: + chunk = await stream.read(512) + if not chunk: + tail = decoder.decode(b"", final=True) + if tail: + self._append_capture(capture, tail) + self._emit_live_output(tail, stream_name, live_state) + return + + text = decoder.decode(chunk) + if not text: + continue + + self._append_capture(capture, text) + self._emit_live_output(text, stream_name, live_state) + + @staticmethod + async def _terminate_process(process: asyncio.subprocess.Process): + if process.returncode is not None: + return + + try: + process.kill() + except ProcessLookupError: + return + + try: + await asyncio.wait_for(process.wait(), timeout=5) + except asyncio.TimeoutError: + logger.warning("终止命令进程超时") + async def run(self, command: str, timeout: Optional[int] = 60, **kwargs) -> str: logger.info( f"执行工具: {self.name}, 参数: command={command}, timeout={timeout}" @@ -56,40 +178,54 @@ class ExecuteCommandTool(MoviePilotTool): command, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE ) + stdout_capture: Dict[str, Any] = { + "chunks": [], + "length": 0, + "truncated": False, + } + stderr_capture: Dict[str, Any] = { + "chunks": [], + "length": 0, + "truncated": False, + } + live_state: Dict[str, Any] = { + "enabled": self._should_emit_live_output(), + "chars": 0, + "truncated": False, + "stdout_header_sent": False, + "stderr_header_sent": False, + } + + stdout_task = asyncio.create_task( + self._collect_stream( + process.stdout, "stdout", stdout_capture, live_state + ) + ) + stderr_task = asyncio.create_task( + self._collect_stream( + process.stderr, "stderr", stderr_capture, live_state + ) + ) + try: # 等待完成,带超时 - stdout, stderr = await asyncio.wait_for( - process.communicate(), timeout=timeout + await asyncio.wait_for(process.wait(), timeout=timeout) + await asyncio.gather(stdout_task, stderr_task) + return self._build_result( + f"命令执行完成 (退出码: {process.returncode})", + stdout_capture, + stderr_capture, ) - # 处理输出 - stdout_str = stdout.decode("utf-8", errors="replace").strip() - stderr_str = stderr.decode("utf-8", errors="replace").strip() - exit_code = process.returncode - - result = f"命令执行完成 (退出码: {exit_code})" - if stdout_str: - result += f"\n\n标准输出:\n{stdout_str}" - if stderr_str: - result += f"\n\n错误输出:\n{stderr_str}" - - # 如果没有输出 - if not stdout_str and not stderr_str: - result += "\n\n(无输出内容)" - - # 限制输出长度,防止上下文过长 - if len(result) > 3000: - result = result[:3000] + "\n\n...(输出内容过长,已截断)" - - return result - except asyncio.TimeoutError: # 超时处理 - try: - process.kill() - except ProcessLookupError: - pass - return f"命令执行超时 (限制: {timeout}秒)" + await self._terminate_process(process) + await asyncio.gather(stdout_task, stderr_task) + return self._build_result( + f"命令执行超时 (限制: {timeout}秒)", + stdout_capture, + stderr_capture, + ) except Exception as e: logger.error(f"执行命令失败: {e}", exc_info=True) diff --git a/tests/test_agent_tool_streaming.py b/tests/test_agent_tool_streaming.py index e343a6bb..bfeb556c 100644 --- a/tests/test_agent_tool_streaming.py +++ b/tests/test_agent_tool_streaming.py @@ -1,10 +1,17 @@ import asyncio import unittest -from unittest.mock import patch +from unittest.mock import AsyncMock, patch + +import langchain.agents as langchain_agents + +if not hasattr(langchain_agents, "create_agent"): + langchain_agents.create_agent = lambda *args, **kwargs: None from app.agent.callback import StreamingHandler from app.agent.tools.base import MoviePilotTool from app.core.config import settings +from app.schemas.message import MessageResponse +from app.schemas.types import MessageChannel class DummyTool(MoviePilotTool): @@ -48,6 +55,59 @@ class TestAgentToolStreaming(unittest.TestCase): self.assertEqual(result, "ok") self.assertEqual(buffered_message, "") + def test_flush_sends_direct_message_via_threadpool(self): + handler = StreamingHandler() + handler._channel = MessageChannel.Telegram.value + handler._source = "telegram" + handler._user_id = "10001" + handler._username = "tester" + handler._streaming_enabled = True + handler.emit("hello") + + with patch( + "app.agent.callback.run_in_threadpool", new_callable=AsyncMock + ) as run_in_threadpool_mock: + run_in_threadpool_mock.return_value = MessageResponse( + message_id=1, + chat_id=2, + source="telegram", + success=True, + ) + + asyncio.run(handler._flush()) + + self.assertEqual(run_in_threadpool_mock.await_count, 1) + self.assertEqual( + run_in_threadpool_mock.await_args.args[0].__name__, "send_direct_message" + ) + self.assertTrue(handler.has_sent_message) + + def test_flush_edits_message_via_threadpool(self): + handler = StreamingHandler() + handler._channel = MessageChannel.Telegram.value + handler._streaming_enabled = True + handler._message_response = MessageResponse( + message_id=1, + chat_id=2, + source="telegram", + success=True, + ) + handler._sent_text = "hello" + handler.emit("hello world") + + with patch( + "app.agent.callback.run_in_threadpool", new_callable=AsyncMock + ) as run_in_threadpool_mock: + run_in_threadpool_mock.return_value = True + + asyncio.run(handler._flush()) + + self.assertEqual(run_in_threadpool_mock.await_count, 1) + self.assertEqual( + run_in_threadpool_mock.await_args.args[0].__name__, "edit_message" + ) + self.assertEqual(handler._sent_text, "hello world") + if __name__ == "__main__": unittest.main() diff --git a/tests/test_execute_command_tool.py b/tests/test_execute_command_tool.py new file mode 100644 index 00000000..8a9e74fe --- /dev/null +++ b/tests/test_execute_command_tool.py @@ -0,0 +1,57 @@ +import asyncio +import shlex +import sys +import unittest + +import langchain.agents as langchain_agents + +if not hasattr(langchain_agents, "create_agent"): + langchain_agents.create_agent = lambda *args, **kwargs: None + +from app.agent.callback import StreamingHandler +from app.agent.tools.impl.execute_command import ExecuteCommandTool + + +class TestExecuteCommandTool(unittest.TestCase): + @staticmethod + def _build_python_command(script: str) -> str: + return f"{shlex.quote(sys.executable)} -c '{script}'" + + @staticmethod + def _build_streaming_tool() -> tuple[ExecuteCommandTool, StreamingHandler]: + tool = ExecuteCommandTool(session_id="session-1", user_id="10001") + handler = StreamingHandler() + handler._streaming_enabled = True + handler._flush_task = object() + tool.set_stream_handler(handler) + return tool, handler + + def test_run_streams_live_output_and_collects_result(self): + tool, handler = self._build_streaming_tool() + command = self._build_python_command( + 'import sys; print("out"); print("err", file=sys.stderr)' + ) + + result = asyncio.run(tool.run(command=command, timeout=5)) + live_output = asyncio.run(handler.take()) + + self.assertIn("命令执行完成 (退出码: 0)", result) + self.assertIn("标准输出:\nout", result) + self.assertIn("错误输出:\nerr", result) + self.assertIn("标准输出:\nout", live_output) + self.assertIn("错误输出:\nerr", live_output) + + def test_run_timeout_keeps_partial_output(self): + tool = ExecuteCommandTool(session_id="session-1", user_id="10001") + command = self._build_python_command( + 'import sys,time; print("start"); sys.stdout.flush(); time.sleep(0.2)' + ) + + result = asyncio.run(tool.run(command=command, timeout=0.05)) + + self.assertIn("命令执行超时", result) + self.assertIn("标准输出:\nstart", result) + + +if __name__ == "__main__": + unittest.main()