mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-07-05 03:17:38 +08:00
Improve streaming tool summary counts
This commit is contained in:
@@ -261,6 +261,12 @@ class StreamingHandler:
|
||||
tool_message=tool_message,
|
||||
tool_kwargs=tool_kwargs or {},
|
||||
)
|
||||
target_values = []
|
||||
if isinstance(target, (list, tuple, set)):
|
||||
target_values = [item for item in target if item]
|
||||
elif target:
|
||||
target_values = [target]
|
||||
|
||||
with self._lock:
|
||||
bucket = self._pending_tool_stats.setdefault(
|
||||
category,
|
||||
@@ -269,9 +275,30 @@ class StreamingHandler:
|
||||
"targets": set(),
|
||||
},
|
||||
)
|
||||
bucket["count"] += 1
|
||||
if target:
|
||||
bucket["targets"].add(str(target))
|
||||
if category == "subagent" and target_values:
|
||||
bucket["count"] += len(target_values)
|
||||
else:
|
||||
bucket["count"] += 1
|
||||
for target_value in target_values:
|
||||
bucket["targets"].add(str(target_value))
|
||||
|
||||
@staticmethod
|
||||
def _extract_subagent_targets(tool_kwargs: dict[str, Any]) -> list[str]:
|
||||
"""提取子代理工具请求中的目标子代理类型。"""
|
||||
tasks = tool_kwargs.get("tasks")
|
||||
if not isinstance(tasks, list):
|
||||
subagent_type = tool_kwargs.get("subagent_type")
|
||||
return [str(subagent_type)] if subagent_type else []
|
||||
|
||||
targets = []
|
||||
for task in tasks:
|
||||
if isinstance(task, dict):
|
||||
subagent_type = task.get("subagent_type")
|
||||
else:
|
||||
subagent_type = getattr(task, "subagent_type", None)
|
||||
if subagent_type:
|
||||
targets.append(str(subagent_type))
|
||||
return targets
|
||||
|
||||
def flush_pending_tool_summary(self) -> str:
|
||||
"""
|
||||
@@ -288,11 +315,17 @@ class StreamingHandler:
|
||||
tool_name: str,
|
||||
tool_message: Optional[str],
|
||||
tool_kwargs: dict[str, Any],
|
||||
) -> tuple[str, Optional[str]]:
|
||||
) -> tuple[str, Optional[Any]]:
|
||||
tool_name = (tool_name or "").strip().lower()
|
||||
tool_message = (tool_message or "").strip()
|
||||
tool_message_lower = tool_message.lower()
|
||||
|
||||
if tool_name == "skill":
|
||||
return "skill", tool_kwargs.get("name")
|
||||
if tool_name == "query_activity_log":
|
||||
return "activity_log", tool_kwargs.get("keyword") or tool_kwargs.get("date")
|
||||
if tool_name == "subagent_task":
|
||||
return "subagent", StreamingHandler._extract_subagent_targets(tool_kwargs)
|
||||
if tool_name == "task":
|
||||
return "subagent", tool_kwargs.get("subagent_type")
|
||||
if tool_name == "read_file":
|
||||
@@ -369,7 +402,7 @@ class StreamingHandler:
|
||||
parts = []
|
||||
for category, bucket in self._pending_tool_stats.items():
|
||||
value = bucket["count"]
|
||||
if category in {"file_read", "file_write", "directory", "web_browse"} and bucket["targets"]:
|
||||
if category in {"file_read", "file_write", "directory", "web_browse", "skill"} and bucket["targets"]:
|
||||
value = len(bucket["targets"])
|
||||
part = self._format_tool_stat(category, value)
|
||||
if part:
|
||||
@@ -406,6 +439,10 @@ class StreamingHandler:
|
||||
return f"执行了 {count} 条命令"
|
||||
if category == "data_query":
|
||||
return f"查询了 {count} 次数据"
|
||||
if category == "skill":
|
||||
return f"查询了 {count} 个技能说明"
|
||||
if category == "activity_log":
|
||||
return f"查询了 {count} 次活动日志"
|
||||
if category == "action":
|
||||
return f"执行了 {count} 次操作"
|
||||
if category == "interaction":
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import asyncio
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
from langchain.agents.middleware import SummarizationMiddleware
|
||||
@@ -16,180 +15,197 @@ class _FakeLLM:
|
||||
self.profile = {"max_input_tokens": 64000}
|
||||
|
||||
|
||||
class TestAgentSummarizationStreaming(unittest.TestCase):
|
||||
def test_streaming_agent_uses_non_streaming_llm_for_summary(self):
|
||||
agent = agent_module.MoviePilotAgent(session_id="session-1", user_id="10001")
|
||||
main_llm = _FakeLLM("main")
|
||||
non_streaming_llm = _FakeLLM("non-streaming")
|
||||
captured: dict = {}
|
||||
def test_streaming_agent_uses_non_streaming_llm_for_summary():
|
||||
"""流式 Agent 的摘要中间件应使用非流式 LLM。"""
|
||||
agent = agent_module.MoviePilotAgent(session_id="session-1", user_id="10001")
|
||||
main_llm = _FakeLLM("main")
|
||||
non_streaming_llm = _FakeLLM("non-streaming")
|
||||
captured: dict = {}
|
||||
|
||||
def _fake_create_agent(**kwargs):
|
||||
captured.update(kwargs)
|
||||
return object()
|
||||
def _fake_create_agent(**kwargs):
|
||||
"""捕获 create_agent 参数。"""
|
||||
captured.update(kwargs)
|
||||
return object()
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
agent, "_initialize_llm", side_effect=[main_llm, non_streaming_llm]
|
||||
),
|
||||
patch.object(agent, "_initialize_tools", return_value=[]),
|
||||
patch.object(
|
||||
agent_module.prompt_manager, "get_agent_prompt", return_value="prompt"
|
||||
),
|
||||
patch.object(
|
||||
agent_module, "create_subagent_middlewares", return_value=([], [])
|
||||
),
|
||||
patch.object(agent_module, "create_agent", side_effect=_fake_create_agent),
|
||||
patch.object(agent_module.settings, "LLM_MAX_TOOLS", 0),
|
||||
with (
|
||||
patch.object(
|
||||
agent, "_initialize_llm", side_effect=[main_llm, non_streaming_llm]
|
||||
),
|
||||
patch.object(agent, "_initialize_tools", return_value=[]),
|
||||
patch.object(
|
||||
agent_module.prompt_manager, "get_agent_prompt", return_value="prompt"
|
||||
),
|
||||
patch.object(
|
||||
agent_module, "create_subagent_middlewares", return_value=([], [])
|
||||
),
|
||||
patch.object(agent_module, "create_agent", side_effect=_fake_create_agent),
|
||||
patch.object(agent_module.settings, "LLM_MAX_TOOLS", 0),
|
||||
):
|
||||
asyncio.run(agent._create_agent(streaming=True))
|
||||
|
||||
summary_middleware = next(
|
||||
middleware
|
||||
for middleware in captured["middleware"]
|
||||
if isinstance(middleware, SummarizationMiddleware)
|
||||
)
|
||||
|
||||
assert captured["model"] is main_llm
|
||||
assert summary_middleware.model is non_streaming_llm
|
||||
|
||||
|
||||
def test_streaming_agent_uses_non_streaming_llm_for_model_middlewares():
|
||||
"""流式 Agent 的模型型中间件应使用非流式 LLM。"""
|
||||
agent = agent_module.MoviePilotAgent(session_id="session-1", user_id="10001")
|
||||
main_llm = _FakeLLM("main")
|
||||
non_streaming_llm = _FakeLLM("non-streaming")
|
||||
captured: dict = {}
|
||||
|
||||
class _FakeToolSelectorMiddleware:
|
||||
"""记录工具选择中间件初始化参数。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
max_tools,
|
||||
always_include=None,
|
||||
selection_tools=None,
|
||||
):
|
||||
asyncio.run(agent._create_agent(streaming=True))
|
||||
"""保存测试断言需要的参数。"""
|
||||
self.model = model
|
||||
self.max_tools = max_tools
|
||||
self.always_include = always_include or []
|
||||
self.selection_tools = selection_tools or []
|
||||
|
||||
summary_middleware = next(
|
||||
middleware
|
||||
for middleware in captured["middleware"]
|
||||
if isinstance(middleware, SummarizationMiddleware)
|
||||
)
|
||||
def _fake_create_agent(**kwargs):
|
||||
"""捕获 create_agent 参数。"""
|
||||
captured.update(kwargs)
|
||||
return object()
|
||||
|
||||
self.assertIs(captured["model"], main_llm)
|
||||
self.assertIs(summary_middleware.model, non_streaming_llm)
|
||||
class _FakeTool:
|
||||
"""测试用工具占位对象。"""
|
||||
|
||||
def test_streaming_agent_uses_non_streaming_llm_for_model_middlewares(self):
|
||||
agent = agent_module.MoviePilotAgent(session_id="session-1", user_id="10001")
|
||||
main_llm = _FakeLLM("main")
|
||||
non_streaming_llm = _FakeLLM("non-streaming")
|
||||
captured: dict = {}
|
||||
def __init__(self, name: str):
|
||||
"""保存工具名。"""
|
||||
self.name = name
|
||||
|
||||
class _FakeToolSelectorMiddleware:
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
max_tools,
|
||||
always_include=None,
|
||||
selection_tools=None,
|
||||
):
|
||||
self.model = model
|
||||
self.max_tools = max_tools
|
||||
self.always_include = always_include or []
|
||||
self.selection_tools = selection_tools or []
|
||||
fake_tools = [
|
||||
_FakeTool("list_directory"),
|
||||
_FakeTool("write_file"),
|
||||
_FakeTool("read_file"),
|
||||
_FakeTool("edit_file"),
|
||||
_FakeTool("execute_command"),
|
||||
_FakeTool("search_media"),
|
||||
]
|
||||
|
||||
def _fake_create_agent(**kwargs):
|
||||
captured.update(kwargs)
|
||||
return object()
|
||||
with (
|
||||
patch.object(
|
||||
agent, "_initialize_llm", side_effect=[main_llm, non_streaming_llm]
|
||||
),
|
||||
patch.object(agent, "_initialize_tools", return_value=fake_tools),
|
||||
patch.object(
|
||||
agent_module.prompt_manager, "get_agent_prompt", return_value="prompt"
|
||||
),
|
||||
patch.object(
|
||||
agent_module, "create_subagent_middlewares", return_value=([], [])
|
||||
),
|
||||
patch.object(
|
||||
agent_module,
|
||||
"ToolSelectorMiddleware",
|
||||
_FakeToolSelectorMiddleware,
|
||||
),
|
||||
patch.object(agent_module, "create_agent", side_effect=_fake_create_agent),
|
||||
patch.object(agent_module.settings, "LLM_MAX_TOOLS", 3),
|
||||
):
|
||||
asyncio.run(agent._create_agent(streaming=True))
|
||||
|
||||
class _FakeTool:
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
tool_selector_middleware = next(
|
||||
middleware
|
||||
for middleware in captured["middleware"]
|
||||
if isinstance(middleware, _FakeToolSelectorMiddleware)
|
||||
)
|
||||
|
||||
fake_tools = [
|
||||
_FakeTool("list_directory"),
|
||||
_FakeTool("write_file"),
|
||||
_FakeTool("read_file"),
|
||||
_FakeTool("edit_file"),
|
||||
_FakeTool("execute_command"),
|
||||
_FakeTool("search_media"),
|
||||
]
|
||||
assert tool_selector_middleware.model is non_streaming_llm
|
||||
assert tool_selector_middleware.max_tools == 3
|
||||
assert tool_selector_middleware.always_include == [
|
||||
"list_directory",
|
||||
"write_file",
|
||||
"read_file",
|
||||
"edit_file",
|
||||
"execute_command",
|
||||
"skill",
|
||||
]
|
||||
assert tool_selector_middleware.selection_tools[: len(fake_tools)] == fake_tools
|
||||
assert [
|
||||
getattr(tool, "name", None)
|
||||
for tool in tool_selector_middleware.selection_tools[len(fake_tools):]
|
||||
] == ["skill"]
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
agent, "_initialize_llm", side_effect=[main_llm, non_streaming_llm]
|
||||
),
|
||||
patch.object(agent, "_initialize_tools", return_value=fake_tools),
|
||||
patch.object(
|
||||
agent_module.prompt_manager, "get_agent_prompt", return_value="prompt"
|
||||
),
|
||||
patch.object(
|
||||
agent_module, "create_subagent_middlewares", return_value=([], [])
|
||||
),
|
||||
patch.object(
|
||||
agent_module,
|
||||
"ToolSelectorMiddleware",
|
||||
_FakeToolSelectorMiddleware,
|
||||
),
|
||||
patch.object(agent_module, "create_agent", side_effect=_fake_create_agent),
|
||||
patch.object(agent_module.settings, "LLM_MAX_TOOLS", 3),
|
||||
):
|
||||
asyncio.run(agent._create_agent(streaming=True))
|
||||
|
||||
tool_selector_middleware = next(
|
||||
middleware
|
||||
for middleware in captured["middleware"]
|
||||
if isinstance(middleware, _FakeToolSelectorMiddleware)
|
||||
)
|
||||
def test_non_streaming_agent_reuses_main_llm_for_summary():
|
||||
"""非流式 Agent 的摘要中间件应复用主 LLM。"""
|
||||
agent = agent_module.MoviePilotAgent(session_id="session-1", user_id="10001")
|
||||
main_llm = _FakeLLM("main")
|
||||
captured: dict = {}
|
||||
|
||||
self.assertIs(tool_selector_middleware.model, non_streaming_llm)
|
||||
self.assertEqual(tool_selector_middleware.max_tools, 3)
|
||||
self.assertEqual(
|
||||
tool_selector_middleware.always_include,
|
||||
[
|
||||
"list_directory",
|
||||
"write_file",
|
||||
"read_file",
|
||||
"edit_file",
|
||||
"execute_command",
|
||||
],
|
||||
)
|
||||
self.assertEqual(tool_selector_middleware.selection_tools, fake_tools)
|
||||
def _fake_create_agent(**kwargs):
|
||||
"""捕获 create_agent 参数。"""
|
||||
captured.update(kwargs)
|
||||
return object()
|
||||
|
||||
def test_non_streaming_agent_reuses_main_llm_for_summary(self):
|
||||
agent = agent_module.MoviePilotAgent(session_id="session-1", user_id="10001")
|
||||
main_llm = _FakeLLM("main")
|
||||
captured: dict = {}
|
||||
with (
|
||||
patch.object(agent, "_initialize_llm", return_value=main_llm),
|
||||
patch.object(agent, "_initialize_tools", return_value=[]),
|
||||
patch.object(
|
||||
agent_module.prompt_manager, "get_agent_prompt", return_value="prompt"
|
||||
),
|
||||
patch.object(
|
||||
agent_module, "create_subagent_middlewares", return_value=([], [])
|
||||
),
|
||||
patch.object(agent_module, "create_agent", side_effect=_fake_create_agent),
|
||||
patch.object(agent_module.settings, "LLM_MAX_TOOLS", 0),
|
||||
):
|
||||
asyncio.run(agent._create_agent(streaming=False))
|
||||
|
||||
def _fake_create_agent(**kwargs):
|
||||
captured.update(kwargs)
|
||||
return object()
|
||||
summary_middleware = next(
|
||||
middleware
|
||||
for middleware in captured["middleware"]
|
||||
if isinstance(middleware, SummarizationMiddleware)
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(agent, "_initialize_llm", return_value=main_llm),
|
||||
patch.object(agent, "_initialize_tools", return_value=[]),
|
||||
patch.object(
|
||||
agent_module.prompt_manager, "get_agent_prompt", return_value="prompt"
|
||||
),
|
||||
patch.object(
|
||||
agent_module, "create_subagent_middlewares", return_value=([], [])
|
||||
),
|
||||
patch.object(agent_module, "create_agent", side_effect=_fake_create_agent),
|
||||
patch.object(agent_module.settings, "LLM_MAX_TOOLS", 0),
|
||||
):
|
||||
asyncio.run(agent._create_agent(streaming=False))
|
||||
assert captured["model"] is main_llm
|
||||
assert summary_middleware.model is main_llm
|
||||
|
||||
summary_middleware = next(
|
||||
middleware
|
||||
for middleware in captured["middleware"]
|
||||
if isinstance(middleware, SummarizationMiddleware)
|
||||
)
|
||||
|
||||
self.assertIs(captured["model"], main_llm)
|
||||
self.assertIs(summary_middleware.model, main_llm)
|
||||
def test_agent_uses_runtime_config_middleware_instead_of_hooks():
|
||||
"""Agent 应使用运行时配置中间件而不是旧 hooks。"""
|
||||
agent = agent_module.MoviePilotAgent(session_id="session-1", user_id="10001")
|
||||
main_llm = _FakeLLM("main")
|
||||
captured: dict = {}
|
||||
|
||||
def test_agent_uses_runtime_config_middleware_instead_of_hooks(self):
|
||||
agent = agent_module.MoviePilotAgent(session_id="session-1", user_id="10001")
|
||||
main_llm = _FakeLLM("main")
|
||||
captured: dict = {}
|
||||
def _fake_create_agent(**kwargs):
|
||||
"""捕获 create_agent 参数。"""
|
||||
captured.update(kwargs)
|
||||
return object()
|
||||
|
||||
def _fake_create_agent(**kwargs):
|
||||
captured.update(kwargs)
|
||||
return object()
|
||||
with (
|
||||
patch.object(agent, "_initialize_llm", return_value=main_llm),
|
||||
patch.object(agent, "_initialize_tools", return_value=[]),
|
||||
patch.object(
|
||||
agent_module.prompt_manager, "get_agent_prompt", return_value="prompt"
|
||||
),
|
||||
patch.object(
|
||||
agent_module, "create_subagent_middlewares", return_value=([], [])
|
||||
),
|
||||
patch.object(agent_module, "create_agent", side_effect=_fake_create_agent),
|
||||
patch.object(agent_module.settings, "LLM_MAX_TOOLS", 0),
|
||||
):
|
||||
asyncio.run(agent._create_agent(streaming=False))
|
||||
|
||||
with (
|
||||
patch.object(agent, "_initialize_llm", return_value=main_llm),
|
||||
patch.object(agent, "_initialize_tools", return_value=[]),
|
||||
patch.object(
|
||||
agent_module.prompt_manager, "get_agent_prompt", return_value="prompt"
|
||||
),
|
||||
patch.object(
|
||||
agent_module, "create_subagent_middlewares", return_value=([], [])
|
||||
),
|
||||
patch.object(agent_module, "create_agent", side_effect=_fake_create_agent),
|
||||
patch.object(agent_module.settings, "LLM_MAX_TOOLS", 0),
|
||||
):
|
||||
asyncio.run(agent._create_agent(streaming=False))
|
||||
|
||||
self.assertTrue(
|
||||
any(
|
||||
isinstance(middleware, RuntimeConfigMiddleware)
|
||||
for middleware in captured["middleware"]
|
||||
)
|
||||
)
|
||||
self.assertFalse(
|
||||
any(type(middleware).__name__ == "AgentHooksMiddleware" for middleware in captured["middleware"])
|
||||
)
|
||||
assert any(
|
||||
isinstance(middleware, RuntimeConfigMiddleware)
|
||||
for middleware in captured["middleware"]
|
||||
)
|
||||
assert not any(
|
||||
type(middleware).__name__ == "AgentHooksMiddleware"
|
||||
for middleware in captured["middleware"]
|
||||
)
|
||||
|
||||
@@ -141,6 +141,56 @@ class TestAgentToolStreaming:
|
||||
|
||||
assert buffered_message == "处理中:\n\n(已调用 2 个子代理)\n\n"
|
||||
|
||||
def test_non_verbose_tool_summary_describes_skill_lookup(self):
|
||||
"""校验非详细模式单独描述 Skill 说明查询。"""
|
||||
async def _run():
|
||||
handler = StreamingHandler()
|
||||
await handler.start_streaming()
|
||||
handler.emit("处理中:")
|
||||
handler.record_tool_call(
|
||||
tool_name="skill",
|
||||
tool_message="Loads the full instructions for a MoviePilot skill",
|
||||
tool_kwargs={"name": "moviepilot-cli"},
|
||||
)
|
||||
handler.record_tool_call(
|
||||
tool_name="skill",
|
||||
tool_message="Loads the full instructions for a MoviePilot skill",
|
||||
tool_kwargs={"name": "moviepilot-cli"},
|
||||
)
|
||||
handler.record_tool_call(
|
||||
tool_name="query_activity_log",
|
||||
tool_message="Query recent MoviePilot Agent activity logs",
|
||||
tool_kwargs={"keyword": "整理"},
|
||||
)
|
||||
return await handler.take()
|
||||
|
||||
buffered_message = asyncio.run(_run())
|
||||
|
||||
assert buffered_message == "处理中:\n\n(查询了 1 个技能说明,查询了 1 次活动日志)\n\n"
|
||||
|
||||
def test_non_verbose_tool_summary_counts_subagent_batch_tasks(self):
|
||||
"""校验批量子代理控制工具按子任务数统计。"""
|
||||
async def _run():
|
||||
handler = StreamingHandler()
|
||||
await handler.start_streaming()
|
||||
handler.emit("处理中:")
|
||||
handler.record_tool_call(
|
||||
tool_name="subagent_task",
|
||||
tool_message="Start and manage multiple MoviePilot subagent tasks",
|
||||
tool_kwargs={
|
||||
"action": "start",
|
||||
"tasks": [
|
||||
{"subagent_type": "media-researcher"},
|
||||
{"subagent_type": "download-diagnostician"},
|
||||
],
|
||||
},
|
||||
)
|
||||
return await handler.take()
|
||||
|
||||
buffered_message = asyncio.run(_run())
|
||||
|
||||
assert buffered_message == "处理中:\n\n(已调用 2 个子代理)\n\n"
|
||||
|
||||
def test_subagent_stream_metadata_is_suppressed(self):
|
||||
"""校验子代理流式元数据会被识别并抑制。"""
|
||||
assert is_subagent_stream_metadata(
|
||||
|
||||
Reference in New Issue
Block a user