mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-05-07 05:43:55 +08:00
fix agent stream blocking during command execution
Offload synchronous message edits from the event loop and stream subprocess output so long-running commands stay responsive.
This commit is contained in:
@@ -2,6 +2,8 @@ import asyncio
|
||||
import threading
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
|
||||
from app.chain import ChainBase
|
||||
from app.log import logger
|
||||
from app.schemas import Notification
|
||||
@@ -256,7 +258,8 @@ class StreamingHandler:
|
||||
try:
|
||||
if self._message_response is None:
|
||||
# 第一次发送:发送新消息并获取 message_id
|
||||
response = chain.send_direct_message(
|
||||
response = await run_in_threadpool(
|
||||
chain.send_direct_message,
|
||||
Notification(
|
||||
channel=self._channel,
|
||||
source=self._source,
|
||||
@@ -264,7 +267,7 @@ class StreamingHandler:
|
||||
username=self._username,
|
||||
title=self._title,
|
||||
text=current_text,
|
||||
)
|
||||
),
|
||||
)
|
||||
if response and response.success and response.message_id:
|
||||
self._message_response = response
|
||||
@@ -297,7 +300,8 @@ class StreamingHandler:
|
||||
|
||||
# 如果偏移后还有新内容,立即发送为新消息
|
||||
if current_text:
|
||||
response = chain.send_direct_message(
|
||||
response = await run_in_threadpool(
|
||||
chain.send_direct_message,
|
||||
Notification(
|
||||
channel=self._channel,
|
||||
source=self._source,
|
||||
@@ -305,7 +309,7 @@ class StreamingHandler:
|
||||
username=self._username,
|
||||
title=self._title,
|
||||
text=current_text,
|
||||
)
|
||||
),
|
||||
)
|
||||
if response and response.success and response.message_id:
|
||||
self._message_response = response
|
||||
@@ -324,7 +328,8 @@ class StreamingHandler:
|
||||
except (ValueError, KeyError):
|
||||
return
|
||||
|
||||
success = chain.edit_message(
|
||||
success = await run_in_threadpool(
|
||||
chain.edit_message,
|
||||
channel=channel_enum,
|
||||
source=self._message_response.source,
|
||||
message_id=self._message_response.message_id,
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
"""执行Shell命令工具"""
|
||||
|
||||
import asyncio
|
||||
from typing import Optional, Type
|
||||
import codecs
|
||||
from typing import Any, Dict, Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -26,12 +27,133 @@ class ExecuteCommandTool(MoviePilotTool):
|
||||
description: str = "Safely execute shell commands on the server. Useful for system maintenance, checking status, or running custom scripts. Includes timeout and output limits."
|
||||
args_schema: Type[BaseModel] = ExecuteCommandInput
|
||||
require_admin: bool = True
|
||||
RESULT_LIMIT = 3000
|
||||
STREAM_CAPTURE_LIMIT = 2000
|
||||
LIVE_OUTPUT_LIMIT = 1200
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据命令生成友好的提示消息"""
|
||||
command = kwargs.get("command", "")
|
||||
return f"执行系统命令: {command}"
|
||||
|
||||
def _build_result(
|
||||
self,
|
||||
message: str,
|
||||
stdout_capture: Dict[str, Any],
|
||||
stderr_capture: Dict[str, Any],
|
||||
) -> str:
|
||||
stdout_str = "".join(stdout_capture["chunks"]).strip()
|
||||
stderr_str = "".join(stderr_capture["chunks"]).strip()
|
||||
|
||||
result = message
|
||||
if stdout_str:
|
||||
result += f"\n\n标准输出:\n{stdout_str}"
|
||||
if stderr_str:
|
||||
result += f"\n\n错误输出:\n{stderr_str}"
|
||||
if not stdout_str and not stderr_str:
|
||||
result += "\n\n(无输出内容)"
|
||||
|
||||
was_truncated = stdout_capture["truncated"] or stderr_capture["truncated"]
|
||||
overflow_suffix = "\n\n...(输出内容过长,已截断)"
|
||||
if was_truncated or len(result) > self.RESULT_LIMIT:
|
||||
result = (
|
||||
result[: self.RESULT_LIMIT - len(overflow_suffix)] + overflow_suffix
|
||||
)
|
||||
return result
|
||||
|
||||
def _append_capture(self, capture: Dict[str, Any], text: str):
|
||||
if not text:
|
||||
return
|
||||
|
||||
remaining = self.STREAM_CAPTURE_LIMIT - capture["length"]
|
||||
if remaining <= 0:
|
||||
capture["truncated"] = True
|
||||
return
|
||||
|
||||
fragment = text[:remaining]
|
||||
capture["chunks"].append(fragment)
|
||||
capture["length"] += len(fragment)
|
||||
if len(text) > remaining:
|
||||
capture["truncated"] = True
|
||||
|
||||
def _should_emit_live_output(self) -> bool:
|
||||
return bool(
|
||||
self._stream_handler
|
||||
and self._stream_handler.is_streaming
|
||||
and self._stream_handler.is_auto_flushing
|
||||
)
|
||||
|
||||
def _emit_live_output(
|
||||
self, text: str, stream_name: str, live_state: Dict[str, Any]
|
||||
):
|
||||
if not text or not live_state["enabled"]:
|
||||
return
|
||||
|
||||
header_key = f"{stream_name}_header_sent"
|
||||
prefix = ""
|
||||
if not live_state[header_key]:
|
||||
prefix = "标准输出:\n" if stream_name == "stdout" else "\n错误输出:\n"
|
||||
live_state[header_key] = True
|
||||
|
||||
payload = prefix + text
|
||||
remaining = self.LIVE_OUTPUT_LIMIT - live_state["chars"]
|
||||
if remaining <= 0:
|
||||
if not live_state["truncated"]:
|
||||
self._stream_handler.emit("\n...(命令输出过长,停止实时展示)\n")
|
||||
live_state["truncated"] = True
|
||||
return
|
||||
|
||||
fragment = payload[:remaining]
|
||||
if fragment:
|
||||
self._stream_handler.emit(fragment)
|
||||
live_state["chars"] += len(fragment)
|
||||
|
||||
if len(payload) > remaining and not live_state["truncated"]:
|
||||
self._stream_handler.emit("\n...(命令输出过长,停止实时展示)\n")
|
||||
live_state["truncated"] = True
|
||||
|
||||
async def _collect_stream(
|
||||
self,
|
||||
stream: Optional[asyncio.StreamReader],
|
||||
stream_name: str,
|
||||
capture: Dict[str, Any],
|
||||
live_state: Dict[str, Any],
|
||||
):
|
||||
if not stream:
|
||||
return
|
||||
|
||||
decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
|
||||
while True:
|
||||
chunk = await stream.read(512)
|
||||
if not chunk:
|
||||
tail = decoder.decode(b"", final=True)
|
||||
if tail:
|
||||
self._append_capture(capture, tail)
|
||||
self._emit_live_output(tail, stream_name, live_state)
|
||||
return
|
||||
|
||||
text = decoder.decode(chunk)
|
||||
if not text:
|
||||
continue
|
||||
|
||||
self._append_capture(capture, text)
|
||||
self._emit_live_output(text, stream_name, live_state)
|
||||
|
||||
@staticmethod
|
||||
async def _terminate_process(process: asyncio.subprocess.Process):
|
||||
if process.returncode is not None:
|
||||
return
|
||||
|
||||
try:
|
||||
process.kill()
|
||||
except ProcessLookupError:
|
||||
return
|
||||
|
||||
try:
|
||||
await asyncio.wait_for(process.wait(), timeout=5)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("终止命令进程超时")
|
||||
|
||||
async def run(self, command: str, timeout: Optional[int] = 60, **kwargs) -> str:
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: command={command}, timeout={timeout}"
|
||||
@@ -56,40 +178,54 @@ class ExecuteCommandTool(MoviePilotTool):
|
||||
command, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
|
||||
)
|
||||
|
||||
stdout_capture: Dict[str, Any] = {
|
||||
"chunks": [],
|
||||
"length": 0,
|
||||
"truncated": False,
|
||||
}
|
||||
stderr_capture: Dict[str, Any] = {
|
||||
"chunks": [],
|
||||
"length": 0,
|
||||
"truncated": False,
|
||||
}
|
||||
live_state: Dict[str, Any] = {
|
||||
"enabled": self._should_emit_live_output(),
|
||||
"chars": 0,
|
||||
"truncated": False,
|
||||
"stdout_header_sent": False,
|
||||
"stderr_header_sent": False,
|
||||
}
|
||||
|
||||
stdout_task = asyncio.create_task(
|
||||
self._collect_stream(
|
||||
process.stdout, "stdout", stdout_capture, live_state
|
||||
)
|
||||
)
|
||||
stderr_task = asyncio.create_task(
|
||||
self._collect_stream(
|
||||
process.stderr, "stderr", stderr_capture, live_state
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
# 等待完成,带超时
|
||||
stdout, stderr = await asyncio.wait_for(
|
||||
process.communicate(), timeout=timeout
|
||||
await asyncio.wait_for(process.wait(), timeout=timeout)
|
||||
await asyncio.gather(stdout_task, stderr_task)
|
||||
return self._build_result(
|
||||
f"命令执行完成 (退出码: {process.returncode})",
|
||||
stdout_capture,
|
||||
stderr_capture,
|
||||
)
|
||||
|
||||
# 处理输出
|
||||
stdout_str = stdout.decode("utf-8", errors="replace").strip()
|
||||
stderr_str = stderr.decode("utf-8", errors="replace").strip()
|
||||
exit_code = process.returncode
|
||||
|
||||
result = f"命令执行完成 (退出码: {exit_code})"
|
||||
if stdout_str:
|
||||
result += f"\n\n标准输出:\n{stdout_str}"
|
||||
if stderr_str:
|
||||
result += f"\n\n错误输出:\n{stderr_str}"
|
||||
|
||||
# 如果没有输出
|
||||
if not stdout_str and not stderr_str:
|
||||
result += "\n\n(无输出内容)"
|
||||
|
||||
# 限制输出长度,防止上下文过长
|
||||
if len(result) > 3000:
|
||||
result = result[:3000] + "\n\n...(输出内容过长,已截断)"
|
||||
|
||||
return result
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
# 超时处理
|
||||
try:
|
||||
process.kill()
|
||||
except ProcessLookupError:
|
||||
pass
|
||||
return f"命令执行超时 (限制: {timeout}秒)"
|
||||
await self._terminate_process(process)
|
||||
await asyncio.gather(stdout_task, stderr_task)
|
||||
return self._build_result(
|
||||
f"命令执行超时 (限制: {timeout}秒)",
|
||||
stdout_capture,
|
||||
stderr_capture,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"执行命令失败: {e}", exc_info=True)
|
||||
|
||||
@@ -1,10 +1,17 @@
|
||||
import asyncio
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import langchain.agents as langchain_agents
|
||||
|
||||
if not hasattr(langchain_agents, "create_agent"):
|
||||
langchain_agents.create_agent = lambda *args, **kwargs: None
|
||||
|
||||
from app.agent.callback import StreamingHandler
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.core.config import settings
|
||||
from app.schemas.message import MessageResponse
|
||||
from app.schemas.types import MessageChannel
|
||||
|
||||
|
||||
class DummyTool(MoviePilotTool):
|
||||
@@ -48,6 +55,59 @@ class TestAgentToolStreaming(unittest.TestCase):
|
||||
self.assertEqual(result, "ok")
|
||||
self.assertEqual(buffered_message, "")
|
||||
|
||||
def test_flush_sends_direct_message_via_threadpool(self):
|
||||
handler = StreamingHandler()
|
||||
handler._channel = MessageChannel.Telegram.value
|
||||
handler._source = "telegram"
|
||||
handler._user_id = "10001"
|
||||
handler._username = "tester"
|
||||
handler._streaming_enabled = True
|
||||
handler.emit("hello")
|
||||
|
||||
with patch(
|
||||
"app.agent.callback.run_in_threadpool", new_callable=AsyncMock
|
||||
) as run_in_threadpool_mock:
|
||||
run_in_threadpool_mock.return_value = MessageResponse(
|
||||
message_id=1,
|
||||
chat_id=2,
|
||||
source="telegram",
|
||||
success=True,
|
||||
)
|
||||
|
||||
asyncio.run(handler._flush())
|
||||
|
||||
self.assertEqual(run_in_threadpool_mock.await_count, 1)
|
||||
self.assertEqual(
|
||||
run_in_threadpool_mock.await_args.args[0].__name__, "send_direct_message"
|
||||
)
|
||||
self.assertTrue(handler.has_sent_message)
|
||||
|
||||
def test_flush_edits_message_via_threadpool(self):
|
||||
handler = StreamingHandler()
|
||||
handler._channel = MessageChannel.Telegram.value
|
||||
handler._streaming_enabled = True
|
||||
handler._message_response = MessageResponse(
|
||||
message_id=1,
|
||||
chat_id=2,
|
||||
source="telegram",
|
||||
success=True,
|
||||
)
|
||||
handler._sent_text = "hello"
|
||||
handler.emit("hello world")
|
||||
|
||||
with patch(
|
||||
"app.agent.callback.run_in_threadpool", new_callable=AsyncMock
|
||||
) as run_in_threadpool_mock:
|
||||
run_in_threadpool_mock.return_value = True
|
||||
|
||||
asyncio.run(handler._flush())
|
||||
|
||||
self.assertEqual(run_in_threadpool_mock.await_count, 1)
|
||||
self.assertEqual(
|
||||
run_in_threadpool_mock.await_args.args[0].__name__, "edit_message"
|
||||
)
|
||||
self.assertEqual(handler._sent_text, "hello world")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
57
tests/test_execute_command_tool.py
Normal file
57
tests/test_execute_command_tool.py
Normal file
@@ -0,0 +1,57 @@
|
||||
import asyncio
|
||||
import shlex
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
import langchain.agents as langchain_agents
|
||||
|
||||
if not hasattr(langchain_agents, "create_agent"):
|
||||
langchain_agents.create_agent = lambda *args, **kwargs: None
|
||||
|
||||
from app.agent.callback import StreamingHandler
|
||||
from app.agent.tools.impl.execute_command import ExecuteCommandTool
|
||||
|
||||
|
||||
class TestExecuteCommandTool(unittest.TestCase):
|
||||
@staticmethod
|
||||
def _build_python_command(script: str) -> str:
|
||||
return f"{shlex.quote(sys.executable)} -c '{script}'"
|
||||
|
||||
@staticmethod
|
||||
def _build_streaming_tool() -> tuple[ExecuteCommandTool, StreamingHandler]:
|
||||
tool = ExecuteCommandTool(session_id="session-1", user_id="10001")
|
||||
handler = StreamingHandler()
|
||||
handler._streaming_enabled = True
|
||||
handler._flush_task = object()
|
||||
tool.set_stream_handler(handler)
|
||||
return tool, handler
|
||||
|
||||
def test_run_streams_live_output_and_collects_result(self):
|
||||
tool, handler = self._build_streaming_tool()
|
||||
command = self._build_python_command(
|
||||
'import sys; print("out"); print("err", file=sys.stderr)'
|
||||
)
|
||||
|
||||
result = asyncio.run(tool.run(command=command, timeout=5))
|
||||
live_output = asyncio.run(handler.take())
|
||||
|
||||
self.assertIn("命令执行完成 (退出码: 0)", result)
|
||||
self.assertIn("标准输出:\nout", result)
|
||||
self.assertIn("错误输出:\nerr", result)
|
||||
self.assertIn("标准输出:\nout", live_output)
|
||||
self.assertIn("错误输出:\nerr", live_output)
|
||||
|
||||
def test_run_timeout_keeps_partial_output(self):
|
||||
tool = ExecuteCommandTool(session_id="session-1", user_id="10001")
|
||||
command = self._build_python_command(
|
||||
'import sys,time; print("start"); sys.stdout.flush(); time.sleep(0.2)'
|
||||
)
|
||||
|
||||
result = asyncio.run(tool.run(command=command, timeout=0.05))
|
||||
|
||||
self.assertIn("命令执行超时", result)
|
||||
self.assertIn("标准输出:\nstart", result)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user