From 78ddd6093f45a25e24050926d2fb49594235f6c7 Mon Sep 17 00:00:00 2001 From: jxxghp Date: Sun, 21 Jun 2026 18:49:27 +0800 Subject: [PATCH] Improve streaming tool summary counts --- app/agent/callback/__init__.py | 47 ++- tests/test_agent_summarization_streaming.py | 330 ++++++++++---------- tests/test_agent_tool_streaming.py | 50 +++ 3 files changed, 265 insertions(+), 162 deletions(-) diff --git a/app/agent/callback/__init__.py b/app/agent/callback/__init__.py index ec619489..d2621249 100644 --- a/app/agent/callback/__init__.py +++ b/app/agent/callback/__init__.py @@ -261,6 +261,12 @@ class StreamingHandler: tool_message=tool_message, tool_kwargs=tool_kwargs or {}, ) + target_values = [] + if isinstance(target, (list, tuple, set)): + target_values = [item for item in target if item] + elif target: + target_values = [target] + with self._lock: bucket = self._pending_tool_stats.setdefault( category, @@ -269,9 +275,30 @@ class StreamingHandler: "targets": set(), }, ) - bucket["count"] += 1 - if target: - bucket["targets"].add(str(target)) + if category == "subagent" and target_values: + bucket["count"] += len(target_values) + else: + bucket["count"] += 1 + for target_value in target_values: + bucket["targets"].add(str(target_value)) + + @staticmethod + def _extract_subagent_targets(tool_kwargs: dict[str, Any]) -> list[str]: + """提取子代理工具请求中的目标子代理类型。""" + tasks = tool_kwargs.get("tasks") + if not isinstance(tasks, list): + subagent_type = tool_kwargs.get("subagent_type") + return [str(subagent_type)] if subagent_type else [] + + targets = [] + for task in tasks: + if isinstance(task, dict): + subagent_type = task.get("subagent_type") + else: + subagent_type = getattr(task, "subagent_type", None) + if subagent_type: + targets.append(str(subagent_type)) + return targets def flush_pending_tool_summary(self) -> str: """ @@ -288,11 +315,17 @@ class StreamingHandler: tool_name: str, tool_message: Optional[str], tool_kwargs: dict[str, Any], - ) -> tuple[str, Optional[str]]: + ) -> tuple[str, Optional[Any]]: tool_name = (tool_name or "").strip().lower() tool_message = (tool_message or "").strip() tool_message_lower = tool_message.lower() + if tool_name == "skill": + return "skill", tool_kwargs.get("name") + if tool_name == "query_activity_log": + return "activity_log", tool_kwargs.get("keyword") or tool_kwargs.get("date") + if tool_name == "subagent_task": + return "subagent", StreamingHandler._extract_subagent_targets(tool_kwargs) if tool_name == "task": return "subagent", tool_kwargs.get("subagent_type") if tool_name == "read_file": @@ -369,7 +402,7 @@ class StreamingHandler: parts = [] for category, bucket in self._pending_tool_stats.items(): value = bucket["count"] - if category in {"file_read", "file_write", "directory", "web_browse"} and bucket["targets"]: + if category in {"file_read", "file_write", "directory", "web_browse", "skill"} and bucket["targets"]: value = len(bucket["targets"]) part = self._format_tool_stat(category, value) if part: @@ -406,6 +439,10 @@ class StreamingHandler: return f"执行了 {count} 条命令" if category == "data_query": return f"查询了 {count} 次数据" + if category == "skill": + return f"查询了 {count} 个技能说明" + if category == "activity_log": + return f"查询了 {count} 次活动日志" if category == "action": return f"执行了 {count} 次操作" if category == "interaction": diff --git a/tests/test_agent_summarization_streaming.py b/tests/test_agent_summarization_streaming.py index cf25e36f..a7a4afac 100644 --- a/tests/test_agent_summarization_streaming.py +++ b/tests/test_agent_summarization_streaming.py @@ -1,5 +1,4 @@ import asyncio -import unittest from unittest.mock import patch from langchain.agents.middleware import SummarizationMiddleware @@ -16,180 +15,197 @@ class _FakeLLM: self.profile = {"max_input_tokens": 64000} -class TestAgentSummarizationStreaming(unittest.TestCase): - def test_streaming_agent_uses_non_streaming_llm_for_summary(self): - agent = agent_module.MoviePilotAgent(session_id="session-1", user_id="10001") - main_llm = _FakeLLM("main") - non_streaming_llm = _FakeLLM("non-streaming") - captured: dict = {} +def test_streaming_agent_uses_non_streaming_llm_for_summary(): + """流式 Agent 的摘要中间件应使用非流式 LLM。""" + agent = agent_module.MoviePilotAgent(session_id="session-1", user_id="10001") + main_llm = _FakeLLM("main") + non_streaming_llm = _FakeLLM("non-streaming") + captured: dict = {} - def _fake_create_agent(**kwargs): - captured.update(kwargs) - return object() + def _fake_create_agent(**kwargs): + """捕获 create_agent 参数。""" + captured.update(kwargs) + return object() - with ( - patch.object( - agent, "_initialize_llm", side_effect=[main_llm, non_streaming_llm] - ), - patch.object(agent, "_initialize_tools", return_value=[]), - patch.object( - agent_module.prompt_manager, "get_agent_prompt", return_value="prompt" - ), - patch.object( - agent_module, "create_subagent_middlewares", return_value=([], []) - ), - patch.object(agent_module, "create_agent", side_effect=_fake_create_agent), - patch.object(agent_module.settings, "LLM_MAX_TOOLS", 0), + with ( + patch.object( + agent, "_initialize_llm", side_effect=[main_llm, non_streaming_llm] + ), + patch.object(agent, "_initialize_tools", return_value=[]), + patch.object( + agent_module.prompt_manager, "get_agent_prompt", return_value="prompt" + ), + patch.object( + agent_module, "create_subagent_middlewares", return_value=([], []) + ), + patch.object(agent_module, "create_agent", side_effect=_fake_create_agent), + patch.object(agent_module.settings, "LLM_MAX_TOOLS", 0), + ): + asyncio.run(agent._create_agent(streaming=True)) + + summary_middleware = next( + middleware + for middleware in captured["middleware"] + if isinstance(middleware, SummarizationMiddleware) + ) + + assert captured["model"] is main_llm + assert summary_middleware.model is non_streaming_llm + + +def test_streaming_agent_uses_non_streaming_llm_for_model_middlewares(): + """流式 Agent 的模型型中间件应使用非流式 LLM。""" + agent = agent_module.MoviePilotAgent(session_id="session-1", user_id="10001") + main_llm = _FakeLLM("main") + non_streaming_llm = _FakeLLM("non-streaming") + captured: dict = {} + + class _FakeToolSelectorMiddleware: + """记录工具选择中间件初始化参数。""" + + def __init__( + self, + model, + max_tools, + always_include=None, + selection_tools=None, ): - asyncio.run(agent._create_agent(streaming=True)) + """保存测试断言需要的参数。""" + self.model = model + self.max_tools = max_tools + self.always_include = always_include or [] + self.selection_tools = selection_tools or [] - summary_middleware = next( - middleware - for middleware in captured["middleware"] - if isinstance(middleware, SummarizationMiddleware) - ) + def _fake_create_agent(**kwargs): + """捕获 create_agent 参数。""" + captured.update(kwargs) + return object() - self.assertIs(captured["model"], main_llm) - self.assertIs(summary_middleware.model, non_streaming_llm) + class _FakeTool: + """测试用工具占位对象。""" - def test_streaming_agent_uses_non_streaming_llm_for_model_middlewares(self): - agent = agent_module.MoviePilotAgent(session_id="session-1", user_id="10001") - main_llm = _FakeLLM("main") - non_streaming_llm = _FakeLLM("non-streaming") - captured: dict = {} + def __init__(self, name: str): + """保存工具名。""" + self.name = name - class _FakeToolSelectorMiddleware: - def __init__( - self, - model, - max_tools, - always_include=None, - selection_tools=None, - ): - self.model = model - self.max_tools = max_tools - self.always_include = always_include or [] - self.selection_tools = selection_tools or [] + fake_tools = [ + _FakeTool("list_directory"), + _FakeTool("write_file"), + _FakeTool("read_file"), + _FakeTool("edit_file"), + _FakeTool("execute_command"), + _FakeTool("search_media"), + ] - def _fake_create_agent(**kwargs): - captured.update(kwargs) - return object() + with ( + patch.object( + agent, "_initialize_llm", side_effect=[main_llm, non_streaming_llm] + ), + patch.object(agent, "_initialize_tools", return_value=fake_tools), + patch.object( + agent_module.prompt_manager, "get_agent_prompt", return_value="prompt" + ), + patch.object( + agent_module, "create_subagent_middlewares", return_value=([], []) + ), + patch.object( + agent_module, + "ToolSelectorMiddleware", + _FakeToolSelectorMiddleware, + ), + patch.object(agent_module, "create_agent", side_effect=_fake_create_agent), + patch.object(agent_module.settings, "LLM_MAX_TOOLS", 3), + ): + asyncio.run(agent._create_agent(streaming=True)) - class _FakeTool: - def __init__(self, name: str): - self.name = name + tool_selector_middleware = next( + middleware + for middleware in captured["middleware"] + if isinstance(middleware, _FakeToolSelectorMiddleware) + ) - fake_tools = [ - _FakeTool("list_directory"), - _FakeTool("write_file"), - _FakeTool("read_file"), - _FakeTool("edit_file"), - _FakeTool("execute_command"), - _FakeTool("search_media"), - ] + assert tool_selector_middleware.model is non_streaming_llm + assert tool_selector_middleware.max_tools == 3 + assert tool_selector_middleware.always_include == [ + "list_directory", + "write_file", + "read_file", + "edit_file", + "execute_command", + "skill", + ] + assert tool_selector_middleware.selection_tools[: len(fake_tools)] == fake_tools + assert [ + getattr(tool, "name", None) + for tool in tool_selector_middleware.selection_tools[len(fake_tools):] + ] == ["skill"] - with ( - patch.object( - agent, "_initialize_llm", side_effect=[main_llm, non_streaming_llm] - ), - patch.object(agent, "_initialize_tools", return_value=fake_tools), - patch.object( - agent_module.prompt_manager, "get_agent_prompt", return_value="prompt" - ), - patch.object( - agent_module, "create_subagent_middlewares", return_value=([], []) - ), - patch.object( - agent_module, - "ToolSelectorMiddleware", - _FakeToolSelectorMiddleware, - ), - patch.object(agent_module, "create_agent", side_effect=_fake_create_agent), - patch.object(agent_module.settings, "LLM_MAX_TOOLS", 3), - ): - asyncio.run(agent._create_agent(streaming=True)) - tool_selector_middleware = next( - middleware - for middleware in captured["middleware"] - if isinstance(middleware, _FakeToolSelectorMiddleware) - ) +def test_non_streaming_agent_reuses_main_llm_for_summary(): + """非流式 Agent 的摘要中间件应复用主 LLM。""" + agent = agent_module.MoviePilotAgent(session_id="session-1", user_id="10001") + main_llm = _FakeLLM("main") + captured: dict = {} - self.assertIs(tool_selector_middleware.model, non_streaming_llm) - self.assertEqual(tool_selector_middleware.max_tools, 3) - self.assertEqual( - tool_selector_middleware.always_include, - [ - "list_directory", - "write_file", - "read_file", - "edit_file", - "execute_command", - ], - ) - self.assertEqual(tool_selector_middleware.selection_tools, fake_tools) + def _fake_create_agent(**kwargs): + """捕获 create_agent 参数。""" + captured.update(kwargs) + return object() - def test_non_streaming_agent_reuses_main_llm_for_summary(self): - agent = agent_module.MoviePilotAgent(session_id="session-1", user_id="10001") - main_llm = _FakeLLM("main") - captured: dict = {} + with ( + patch.object(agent, "_initialize_llm", return_value=main_llm), + patch.object(agent, "_initialize_tools", return_value=[]), + patch.object( + agent_module.prompt_manager, "get_agent_prompt", return_value="prompt" + ), + patch.object( + agent_module, "create_subagent_middlewares", return_value=([], []) + ), + patch.object(agent_module, "create_agent", side_effect=_fake_create_agent), + patch.object(agent_module.settings, "LLM_MAX_TOOLS", 0), + ): + asyncio.run(agent._create_agent(streaming=False)) - def _fake_create_agent(**kwargs): - captured.update(kwargs) - return object() + summary_middleware = next( + middleware + for middleware in captured["middleware"] + if isinstance(middleware, SummarizationMiddleware) + ) - with ( - patch.object(agent, "_initialize_llm", return_value=main_llm), - patch.object(agent, "_initialize_tools", return_value=[]), - patch.object( - agent_module.prompt_manager, "get_agent_prompt", return_value="prompt" - ), - patch.object( - agent_module, "create_subagent_middlewares", return_value=([], []) - ), - patch.object(agent_module, "create_agent", side_effect=_fake_create_agent), - patch.object(agent_module.settings, "LLM_MAX_TOOLS", 0), - ): - asyncio.run(agent._create_agent(streaming=False)) + assert captured["model"] is main_llm + assert summary_middleware.model is main_llm - summary_middleware = next( - middleware - for middleware in captured["middleware"] - if isinstance(middleware, SummarizationMiddleware) - ) - self.assertIs(captured["model"], main_llm) - self.assertIs(summary_middleware.model, main_llm) +def test_agent_uses_runtime_config_middleware_instead_of_hooks(): + """Agent 应使用运行时配置中间件而不是旧 hooks。""" + agent = agent_module.MoviePilotAgent(session_id="session-1", user_id="10001") + main_llm = _FakeLLM("main") + captured: dict = {} - def test_agent_uses_runtime_config_middleware_instead_of_hooks(self): - agent = agent_module.MoviePilotAgent(session_id="session-1", user_id="10001") - main_llm = _FakeLLM("main") - captured: dict = {} + def _fake_create_agent(**kwargs): + """捕获 create_agent 参数。""" + captured.update(kwargs) + return object() - def _fake_create_agent(**kwargs): - captured.update(kwargs) - return object() + with ( + patch.object(agent, "_initialize_llm", return_value=main_llm), + patch.object(agent, "_initialize_tools", return_value=[]), + patch.object( + agent_module.prompt_manager, "get_agent_prompt", return_value="prompt" + ), + patch.object( + agent_module, "create_subagent_middlewares", return_value=([], []) + ), + patch.object(agent_module, "create_agent", side_effect=_fake_create_agent), + patch.object(agent_module.settings, "LLM_MAX_TOOLS", 0), + ): + asyncio.run(agent._create_agent(streaming=False)) - with ( - patch.object(agent, "_initialize_llm", return_value=main_llm), - patch.object(agent, "_initialize_tools", return_value=[]), - patch.object( - agent_module.prompt_manager, "get_agent_prompt", return_value="prompt" - ), - patch.object( - agent_module, "create_subagent_middlewares", return_value=([], []) - ), - patch.object(agent_module, "create_agent", side_effect=_fake_create_agent), - patch.object(agent_module.settings, "LLM_MAX_TOOLS", 0), - ): - asyncio.run(agent._create_agent(streaming=False)) - - self.assertTrue( - any( - isinstance(middleware, RuntimeConfigMiddleware) - for middleware in captured["middleware"] - ) - ) - self.assertFalse( - any(type(middleware).__name__ == "AgentHooksMiddleware" for middleware in captured["middleware"]) - ) + assert any( + isinstance(middleware, RuntimeConfigMiddleware) + for middleware in captured["middleware"] + ) + assert not any( + type(middleware).__name__ == "AgentHooksMiddleware" + for middleware in captured["middleware"] + ) diff --git a/tests/test_agent_tool_streaming.py b/tests/test_agent_tool_streaming.py index e2683f0a..f7adc414 100644 --- a/tests/test_agent_tool_streaming.py +++ b/tests/test_agent_tool_streaming.py @@ -141,6 +141,56 @@ class TestAgentToolStreaming: assert buffered_message == "处理中:\n\n(已调用 2 个子代理)\n\n" + def test_non_verbose_tool_summary_describes_skill_lookup(self): + """校验非详细模式单独描述 Skill 说明查询。""" + async def _run(): + handler = StreamingHandler() + await handler.start_streaming() + handler.emit("处理中:") + handler.record_tool_call( + tool_name="skill", + tool_message="Loads the full instructions for a MoviePilot skill", + tool_kwargs={"name": "moviepilot-cli"}, + ) + handler.record_tool_call( + tool_name="skill", + tool_message="Loads the full instructions for a MoviePilot skill", + tool_kwargs={"name": "moviepilot-cli"}, + ) + handler.record_tool_call( + tool_name="query_activity_log", + tool_message="Query recent MoviePilot Agent activity logs", + tool_kwargs={"keyword": "整理"}, + ) + return await handler.take() + + buffered_message = asyncio.run(_run()) + + assert buffered_message == "处理中:\n\n(查询了 1 个技能说明,查询了 1 次活动日志)\n\n" + + def test_non_verbose_tool_summary_counts_subagent_batch_tasks(self): + """校验批量子代理控制工具按子任务数统计。""" + async def _run(): + handler = StreamingHandler() + await handler.start_streaming() + handler.emit("处理中:") + handler.record_tool_call( + tool_name="subagent_task", + tool_message="Start and manage multiple MoviePilot subagent tasks", + tool_kwargs={ + "action": "start", + "tasks": [ + {"subagent_type": "media-researcher"}, + {"subagent_type": "download-diagnostician"}, + ], + }, + ) + return await handler.take() + + buffered_message = asyncio.run(_run()) + + assert buffered_message == "处理中:\n\n(已调用 2 个子代理)\n\n" + def test_subagent_stream_metadata_is_suppressed(self): """校验子代理流式元数据会被识别并抑制。""" assert is_subagent_stream_metadata(