fix agent stream blocking during command execution

Offload synchronous message edits from the event loop and stream subprocess output so long-running commands stay responsive.
This commit is contained in:
jxxghp
2026-04-27 07:57:32 +08:00
parent 7bc032d17c
commit 140d224a9a
4 changed files with 293 additions and 35 deletions

View File

@@ -2,6 +2,8 @@ import asyncio
import threading
from typing import Optional, Tuple
from fastapi.concurrency import run_in_threadpool
from app.chain import ChainBase
from app.log import logger
from app.schemas import Notification
@@ -256,7 +258,8 @@ class StreamingHandler:
try:
if self._message_response is None:
# 第一次发送:发送新消息并获取 message_id
response = chain.send_direct_message(
response = await run_in_threadpool(
chain.send_direct_message,
Notification(
channel=self._channel,
source=self._source,
@@ -264,7 +267,7 @@ class StreamingHandler:
username=self._username,
title=self._title,
text=current_text,
)
),
)
if response and response.success and response.message_id:
self._message_response = response
@@ -297,7 +300,8 @@ class StreamingHandler:
# 如果偏移后还有新内容,立即发送为新消息
if current_text:
response = chain.send_direct_message(
response = await run_in_threadpool(
chain.send_direct_message,
Notification(
channel=self._channel,
source=self._source,
@@ -305,7 +309,7 @@ class StreamingHandler:
username=self._username,
title=self._title,
text=current_text,
)
),
)
if response and response.success and response.message_id:
self._message_response = response
@@ -324,7 +328,8 @@ class StreamingHandler:
except (ValueError, KeyError):
return
success = chain.edit_message(
success = await run_in_threadpool(
chain.edit_message,
channel=channel_enum,
source=self._message_response.source,
message_id=self._message_response.message_id,

View File

@@ -1,7 +1,8 @@
"""执行Shell命令工具"""
import asyncio
from typing import Optional, Type
import codecs
from typing import Any, Dict, Optional, Type
from pydantic import BaseModel, Field
@@ -26,12 +27,133 @@ class ExecuteCommandTool(MoviePilotTool):
description: str = "Safely execute shell commands on the server. Useful for system maintenance, checking status, or running custom scripts. Includes timeout and output limits."
args_schema: Type[BaseModel] = ExecuteCommandInput
require_admin: bool = True
RESULT_LIMIT = 3000
STREAM_CAPTURE_LIMIT = 2000
LIVE_OUTPUT_LIMIT = 1200
def get_tool_message(self, **kwargs) -> Optional[str]:
"""根据命令生成友好的提示消息"""
command = kwargs.get("command", "")
return f"执行系统命令: {command}"
def _build_result(
self,
message: str,
stdout_capture: Dict[str, Any],
stderr_capture: Dict[str, Any],
) -> str:
stdout_str = "".join(stdout_capture["chunks"]).strip()
stderr_str = "".join(stderr_capture["chunks"]).strip()
result = message
if stdout_str:
result += f"\n\n标准输出:\n{stdout_str}"
if stderr_str:
result += f"\n\n错误输出:\n{stderr_str}"
if not stdout_str and not stderr_str:
result += "\n\n(无输出内容)"
was_truncated = stdout_capture["truncated"] or stderr_capture["truncated"]
overflow_suffix = "\n\n...(输出内容过长,已截断)"
if was_truncated or len(result) > self.RESULT_LIMIT:
result = (
result[: self.RESULT_LIMIT - len(overflow_suffix)] + overflow_suffix
)
return result
def _append_capture(self, capture: Dict[str, Any], text: str):
if not text:
return
remaining = self.STREAM_CAPTURE_LIMIT - capture["length"]
if remaining <= 0:
capture["truncated"] = True
return
fragment = text[:remaining]
capture["chunks"].append(fragment)
capture["length"] += len(fragment)
if len(text) > remaining:
capture["truncated"] = True
def _should_emit_live_output(self) -> bool:
return bool(
self._stream_handler
and self._stream_handler.is_streaming
and self._stream_handler.is_auto_flushing
)
def _emit_live_output(
self, text: str, stream_name: str, live_state: Dict[str, Any]
):
if not text or not live_state["enabled"]:
return
header_key = f"{stream_name}_header_sent"
prefix = ""
if not live_state[header_key]:
prefix = "标准输出:\n" if stream_name == "stdout" else "\n错误输出:\n"
live_state[header_key] = True
payload = prefix + text
remaining = self.LIVE_OUTPUT_LIMIT - live_state["chars"]
if remaining <= 0:
if not live_state["truncated"]:
self._stream_handler.emit("\n...(命令输出过长,停止实时展示)\n")
live_state["truncated"] = True
return
fragment = payload[:remaining]
if fragment:
self._stream_handler.emit(fragment)
live_state["chars"] += len(fragment)
if len(payload) > remaining and not live_state["truncated"]:
self._stream_handler.emit("\n...(命令输出过长,停止实时展示)\n")
live_state["truncated"] = True
async def _collect_stream(
self,
stream: Optional[asyncio.StreamReader],
stream_name: str,
capture: Dict[str, Any],
live_state: Dict[str, Any],
):
if not stream:
return
decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
while True:
chunk = await stream.read(512)
if not chunk:
tail = decoder.decode(b"", final=True)
if tail:
self._append_capture(capture, tail)
self._emit_live_output(tail, stream_name, live_state)
return
text = decoder.decode(chunk)
if not text:
continue
self._append_capture(capture, text)
self._emit_live_output(text, stream_name, live_state)
@staticmethod
async def _terminate_process(process: asyncio.subprocess.Process):
if process.returncode is not None:
return
try:
process.kill()
except ProcessLookupError:
return
try:
await asyncio.wait_for(process.wait(), timeout=5)
except asyncio.TimeoutError:
logger.warning("终止命令进程超时")
async def run(self, command: str, timeout: Optional[int] = 60, **kwargs) -> str:
logger.info(
f"执行工具: {self.name}, 参数: command={command}, timeout={timeout}"
@@ -56,40 +178,54 @@ class ExecuteCommandTool(MoviePilotTool):
command, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
)
stdout_capture: Dict[str, Any] = {
"chunks": [],
"length": 0,
"truncated": False,
}
stderr_capture: Dict[str, Any] = {
"chunks": [],
"length": 0,
"truncated": False,
}
live_state: Dict[str, Any] = {
"enabled": self._should_emit_live_output(),
"chars": 0,
"truncated": False,
"stdout_header_sent": False,
"stderr_header_sent": False,
}
stdout_task = asyncio.create_task(
self._collect_stream(
process.stdout, "stdout", stdout_capture, live_state
)
)
stderr_task = asyncio.create_task(
self._collect_stream(
process.stderr, "stderr", stderr_capture, live_state
)
)
try:
# 等待完成,带超时
stdout, stderr = await asyncio.wait_for(
process.communicate(), timeout=timeout
await asyncio.wait_for(process.wait(), timeout=timeout)
await asyncio.gather(stdout_task, stderr_task)
return self._build_result(
f"命令执行完成 (退出码: {process.returncode})",
stdout_capture,
stderr_capture,
)
# 处理输出
stdout_str = stdout.decode("utf-8", errors="replace").strip()
stderr_str = stderr.decode("utf-8", errors="replace").strip()
exit_code = process.returncode
result = f"命令执行完成 (退出码: {exit_code})"
if stdout_str:
result += f"\n\n标准输出:\n{stdout_str}"
if stderr_str:
result += f"\n\n错误输出:\n{stderr_str}"
# 如果没有输出
if not stdout_str and not stderr_str:
result += "\n\n(无输出内容)"
# 限制输出长度,防止上下文过长
if len(result) > 3000:
result = result[:3000] + "\n\n...(输出内容过长,已截断)"
return result
except asyncio.TimeoutError:
# 超时处理
try:
process.kill()
except ProcessLookupError:
pass
return f"命令执行超时 (限制: {timeout}秒)"
await self._terminate_process(process)
await asyncio.gather(stdout_task, stderr_task)
return self._build_result(
f"命令执行超时 (限制: {timeout}秒)",
stdout_capture,
stderr_capture,
)
except Exception as e:
logger.error(f"执行命令失败: {e}", exc_info=True)

View File

@@ -1,10 +1,17 @@
import asyncio
import unittest
from unittest.mock import patch
from unittest.mock import AsyncMock, patch
import langchain.agents as langchain_agents
if not hasattr(langchain_agents, "create_agent"):
langchain_agents.create_agent = lambda *args, **kwargs: None
from app.agent.callback import StreamingHandler
from app.agent.tools.base import MoviePilotTool
from app.core.config import settings
from app.schemas.message import MessageResponse
from app.schemas.types import MessageChannel
class DummyTool(MoviePilotTool):
@@ -48,6 +55,59 @@ class TestAgentToolStreaming(unittest.TestCase):
self.assertEqual(result, "ok")
self.assertEqual(buffered_message, "")
def test_flush_sends_direct_message_via_threadpool(self):
handler = StreamingHandler()
handler._channel = MessageChannel.Telegram.value
handler._source = "telegram"
handler._user_id = "10001"
handler._username = "tester"
handler._streaming_enabled = True
handler.emit("hello")
with patch(
"app.agent.callback.run_in_threadpool", new_callable=AsyncMock
) as run_in_threadpool_mock:
run_in_threadpool_mock.return_value = MessageResponse(
message_id=1,
chat_id=2,
source="telegram",
success=True,
)
asyncio.run(handler._flush())
self.assertEqual(run_in_threadpool_mock.await_count, 1)
self.assertEqual(
run_in_threadpool_mock.await_args.args[0].__name__, "send_direct_message"
)
self.assertTrue(handler.has_sent_message)
def test_flush_edits_message_via_threadpool(self):
handler = StreamingHandler()
handler._channel = MessageChannel.Telegram.value
handler._streaming_enabled = True
handler._message_response = MessageResponse(
message_id=1,
chat_id=2,
source="telegram",
success=True,
)
handler._sent_text = "hello"
handler.emit("hello world")
with patch(
"app.agent.callback.run_in_threadpool", new_callable=AsyncMock
) as run_in_threadpool_mock:
run_in_threadpool_mock.return_value = True
asyncio.run(handler._flush())
self.assertEqual(run_in_threadpool_mock.await_count, 1)
self.assertEqual(
run_in_threadpool_mock.await_args.args[0].__name__, "edit_message"
)
self.assertEqual(handler._sent_text, "hello world")
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,57 @@
import asyncio
import shlex
import sys
import unittest
import langchain.agents as langchain_agents
if not hasattr(langchain_agents, "create_agent"):
langchain_agents.create_agent = lambda *args, **kwargs: None
from app.agent.callback import StreamingHandler
from app.agent.tools.impl.execute_command import ExecuteCommandTool
class TestExecuteCommandTool(unittest.TestCase):
@staticmethod
def _build_python_command(script: str) -> str:
return f"{shlex.quote(sys.executable)} -c '{script}'"
@staticmethod
def _build_streaming_tool() -> tuple[ExecuteCommandTool, StreamingHandler]:
tool = ExecuteCommandTool(session_id="session-1", user_id="10001")
handler = StreamingHandler()
handler._streaming_enabled = True
handler._flush_task = object()
tool.set_stream_handler(handler)
return tool, handler
def test_run_streams_live_output_and_collects_result(self):
tool, handler = self._build_streaming_tool()
command = self._build_python_command(
'import sys; print("out"); print("err", file=sys.stderr)'
)
result = asyncio.run(tool.run(command=command, timeout=5))
live_output = asyncio.run(handler.take())
self.assertIn("命令执行完成 (退出码: 0)", result)
self.assertIn("标准输出:\nout", result)
self.assertIn("错误输出:\nerr", result)
self.assertIn("标准输出:\nout", live_output)
self.assertIn("错误输出:\nerr", live_output)
def test_run_timeout_keeps_partial_output(self):
tool = ExecuteCommandTool(session_id="session-1", user_id="10001")
command = self._build_python_command(
'import sys,time; print("start"); sys.stdout.flush(); time.sleep(0.2)'
)
result = asyncio.run(tool.run(command=command, timeout=0.05))
self.assertIn("命令执行超时", result)
self.assertIn("标准输出:\nstart", result)
if __name__ == "__main__":
unittest.main()