Improve streaming tool summary counts

This commit is contained in:
jxxghp
2026-06-21 18:49:27 +08:00
parent 6647565ec4
commit 78ddd6093f
3 changed files with 265 additions and 162 deletions

View File

@@ -261,6 +261,12 @@ class StreamingHandler:
tool_message=tool_message,
tool_kwargs=tool_kwargs or {},
)
target_values = []
if isinstance(target, (list, tuple, set)):
target_values = [item for item in target if item]
elif target:
target_values = [target]
with self._lock:
bucket = self._pending_tool_stats.setdefault(
category,
@@ -269,9 +275,30 @@ class StreamingHandler:
"targets": set(),
},
)
bucket["count"] += 1
if target:
bucket["targets"].add(str(target))
if category == "subagent" and target_values:
bucket["count"] += len(target_values)
else:
bucket["count"] += 1
for target_value in target_values:
bucket["targets"].add(str(target_value))
@staticmethod
def _extract_subagent_targets(tool_kwargs: dict[str, Any]) -> list[str]:
"""提取子代理工具请求中的目标子代理类型。"""
tasks = tool_kwargs.get("tasks")
if not isinstance(tasks, list):
subagent_type = tool_kwargs.get("subagent_type")
return [str(subagent_type)] if subagent_type else []
targets = []
for task in tasks:
if isinstance(task, dict):
subagent_type = task.get("subagent_type")
else:
subagent_type = getattr(task, "subagent_type", None)
if subagent_type:
targets.append(str(subagent_type))
return targets
def flush_pending_tool_summary(self) -> str:
"""
@@ -288,11 +315,17 @@ class StreamingHandler:
tool_name: str,
tool_message: Optional[str],
tool_kwargs: dict[str, Any],
) -> tuple[str, Optional[str]]:
) -> tuple[str, Optional[Any]]:
tool_name = (tool_name or "").strip().lower()
tool_message = (tool_message or "").strip()
tool_message_lower = tool_message.lower()
if tool_name == "skill":
return "skill", tool_kwargs.get("name")
if tool_name == "query_activity_log":
return "activity_log", tool_kwargs.get("keyword") or tool_kwargs.get("date")
if tool_name == "subagent_task":
return "subagent", StreamingHandler._extract_subagent_targets(tool_kwargs)
if tool_name == "task":
return "subagent", tool_kwargs.get("subagent_type")
if tool_name == "read_file":
@@ -369,7 +402,7 @@ class StreamingHandler:
parts = []
for category, bucket in self._pending_tool_stats.items():
value = bucket["count"]
if category in {"file_read", "file_write", "directory", "web_browse"} and bucket["targets"]:
if category in {"file_read", "file_write", "directory", "web_browse", "skill"} and bucket["targets"]:
value = len(bucket["targets"])
part = self._format_tool_stat(category, value)
if part:
@@ -406,6 +439,10 @@ class StreamingHandler:
return f"执行了 {count} 条命令"
if category == "data_query":
return f"查询了 {count} 次数据"
if category == "skill":
return f"查询了 {count} 个技能说明"
if category == "activity_log":
return f"查询了 {count} 次活动日志"
if category == "action":
return f"执行了 {count} 次操作"
if category == "interaction":

View File

@@ -1,5 +1,4 @@
import asyncio
import unittest
from unittest.mock import patch
from langchain.agents.middleware import SummarizationMiddleware
@@ -16,180 +15,197 @@ class _FakeLLM:
self.profile = {"max_input_tokens": 64000}
class TestAgentSummarizationStreaming(unittest.TestCase):
def test_streaming_agent_uses_non_streaming_llm_for_summary(self):
agent = agent_module.MoviePilotAgent(session_id="session-1", user_id="10001")
main_llm = _FakeLLM("main")
non_streaming_llm = _FakeLLM("non-streaming")
captured: dict = {}
def test_streaming_agent_uses_non_streaming_llm_for_summary():
"""流式 Agent 的摘要中间件应使用非流式 LLM。"""
agent = agent_module.MoviePilotAgent(session_id="session-1", user_id="10001")
main_llm = _FakeLLM("main")
non_streaming_llm = _FakeLLM("non-streaming")
captured: dict = {}
def _fake_create_agent(**kwargs):
captured.update(kwargs)
return object()
def _fake_create_agent(**kwargs):
"""捕获 create_agent 参数。"""
captured.update(kwargs)
return object()
with (
patch.object(
agent, "_initialize_llm", side_effect=[main_llm, non_streaming_llm]
),
patch.object(agent, "_initialize_tools", return_value=[]),
patch.object(
agent_module.prompt_manager, "get_agent_prompt", return_value="prompt"
),
patch.object(
agent_module, "create_subagent_middlewares", return_value=([], [])
),
patch.object(agent_module, "create_agent", side_effect=_fake_create_agent),
patch.object(agent_module.settings, "LLM_MAX_TOOLS", 0),
with (
patch.object(
agent, "_initialize_llm", side_effect=[main_llm, non_streaming_llm]
),
patch.object(agent, "_initialize_tools", return_value=[]),
patch.object(
agent_module.prompt_manager, "get_agent_prompt", return_value="prompt"
),
patch.object(
agent_module, "create_subagent_middlewares", return_value=([], [])
),
patch.object(agent_module, "create_agent", side_effect=_fake_create_agent),
patch.object(agent_module.settings, "LLM_MAX_TOOLS", 0),
):
asyncio.run(agent._create_agent(streaming=True))
summary_middleware = next(
middleware
for middleware in captured["middleware"]
if isinstance(middleware, SummarizationMiddleware)
)
assert captured["model"] is main_llm
assert summary_middleware.model is non_streaming_llm
def test_streaming_agent_uses_non_streaming_llm_for_model_middlewares():
"""流式 Agent 的模型型中间件应使用非流式 LLM。"""
agent = agent_module.MoviePilotAgent(session_id="session-1", user_id="10001")
main_llm = _FakeLLM("main")
non_streaming_llm = _FakeLLM("non-streaming")
captured: dict = {}
class _FakeToolSelectorMiddleware:
"""记录工具选择中间件初始化参数。"""
def __init__(
self,
model,
max_tools,
always_include=None,
selection_tools=None,
):
asyncio.run(agent._create_agent(streaming=True))
"""保存测试断言需要的参数。"""
self.model = model
self.max_tools = max_tools
self.always_include = always_include or []
self.selection_tools = selection_tools or []
summary_middleware = next(
middleware
for middleware in captured["middleware"]
if isinstance(middleware, SummarizationMiddleware)
)
def _fake_create_agent(**kwargs):
"""捕获 create_agent 参数。"""
captured.update(kwargs)
return object()
self.assertIs(captured["model"], main_llm)
self.assertIs(summary_middleware.model, non_streaming_llm)
class _FakeTool:
"""测试用工具占位对象。"""
def test_streaming_agent_uses_non_streaming_llm_for_model_middlewares(self):
agent = agent_module.MoviePilotAgent(session_id="session-1", user_id="10001")
main_llm = _FakeLLM("main")
non_streaming_llm = _FakeLLM("non-streaming")
captured: dict = {}
def __init__(self, name: str):
"""保存工具名。"""
self.name = name
class _FakeToolSelectorMiddleware:
def __init__(
self,
model,
max_tools,
always_include=None,
selection_tools=None,
):
self.model = model
self.max_tools = max_tools
self.always_include = always_include or []
self.selection_tools = selection_tools or []
fake_tools = [
_FakeTool("list_directory"),
_FakeTool("write_file"),
_FakeTool("read_file"),
_FakeTool("edit_file"),
_FakeTool("execute_command"),
_FakeTool("search_media"),
]
def _fake_create_agent(**kwargs):
captured.update(kwargs)
return object()
with (
patch.object(
agent, "_initialize_llm", side_effect=[main_llm, non_streaming_llm]
),
patch.object(agent, "_initialize_tools", return_value=fake_tools),
patch.object(
agent_module.prompt_manager, "get_agent_prompt", return_value="prompt"
),
patch.object(
agent_module, "create_subagent_middlewares", return_value=([], [])
),
patch.object(
agent_module,
"ToolSelectorMiddleware",
_FakeToolSelectorMiddleware,
),
patch.object(agent_module, "create_agent", side_effect=_fake_create_agent),
patch.object(agent_module.settings, "LLM_MAX_TOOLS", 3),
):
asyncio.run(agent._create_agent(streaming=True))
class _FakeTool:
def __init__(self, name: str):
self.name = name
tool_selector_middleware = next(
middleware
for middleware in captured["middleware"]
if isinstance(middleware, _FakeToolSelectorMiddleware)
)
fake_tools = [
_FakeTool("list_directory"),
_FakeTool("write_file"),
_FakeTool("read_file"),
_FakeTool("edit_file"),
_FakeTool("execute_command"),
_FakeTool("search_media"),
]
assert tool_selector_middleware.model is non_streaming_llm
assert tool_selector_middleware.max_tools == 3
assert tool_selector_middleware.always_include == [
"list_directory",
"write_file",
"read_file",
"edit_file",
"execute_command",
"skill",
]
assert tool_selector_middleware.selection_tools[: len(fake_tools)] == fake_tools
assert [
getattr(tool, "name", None)
for tool in tool_selector_middleware.selection_tools[len(fake_tools):]
] == ["skill"]
with (
patch.object(
agent, "_initialize_llm", side_effect=[main_llm, non_streaming_llm]
),
patch.object(agent, "_initialize_tools", return_value=fake_tools),
patch.object(
agent_module.prompt_manager, "get_agent_prompt", return_value="prompt"
),
patch.object(
agent_module, "create_subagent_middlewares", return_value=([], [])
),
patch.object(
agent_module,
"ToolSelectorMiddleware",
_FakeToolSelectorMiddleware,
),
patch.object(agent_module, "create_agent", side_effect=_fake_create_agent),
patch.object(agent_module.settings, "LLM_MAX_TOOLS", 3),
):
asyncio.run(agent._create_agent(streaming=True))
tool_selector_middleware = next(
middleware
for middleware in captured["middleware"]
if isinstance(middleware, _FakeToolSelectorMiddleware)
)
def test_non_streaming_agent_reuses_main_llm_for_summary():
"""非流式 Agent 的摘要中间件应复用主 LLM。"""
agent = agent_module.MoviePilotAgent(session_id="session-1", user_id="10001")
main_llm = _FakeLLM("main")
captured: dict = {}
self.assertIs(tool_selector_middleware.model, non_streaming_llm)
self.assertEqual(tool_selector_middleware.max_tools, 3)
self.assertEqual(
tool_selector_middleware.always_include,
[
"list_directory",
"write_file",
"read_file",
"edit_file",
"execute_command",
],
)
self.assertEqual(tool_selector_middleware.selection_tools, fake_tools)
def _fake_create_agent(**kwargs):
"""捕获 create_agent 参数。"""
captured.update(kwargs)
return object()
def test_non_streaming_agent_reuses_main_llm_for_summary(self):
agent = agent_module.MoviePilotAgent(session_id="session-1", user_id="10001")
main_llm = _FakeLLM("main")
captured: dict = {}
with (
patch.object(agent, "_initialize_llm", return_value=main_llm),
patch.object(agent, "_initialize_tools", return_value=[]),
patch.object(
agent_module.prompt_manager, "get_agent_prompt", return_value="prompt"
),
patch.object(
agent_module, "create_subagent_middlewares", return_value=([], [])
),
patch.object(agent_module, "create_agent", side_effect=_fake_create_agent),
patch.object(agent_module.settings, "LLM_MAX_TOOLS", 0),
):
asyncio.run(agent._create_agent(streaming=False))
def _fake_create_agent(**kwargs):
captured.update(kwargs)
return object()
summary_middleware = next(
middleware
for middleware in captured["middleware"]
if isinstance(middleware, SummarizationMiddleware)
)
with (
patch.object(agent, "_initialize_llm", return_value=main_llm),
patch.object(agent, "_initialize_tools", return_value=[]),
patch.object(
agent_module.prompt_manager, "get_agent_prompt", return_value="prompt"
),
patch.object(
agent_module, "create_subagent_middlewares", return_value=([], [])
),
patch.object(agent_module, "create_agent", side_effect=_fake_create_agent),
patch.object(agent_module.settings, "LLM_MAX_TOOLS", 0),
):
asyncio.run(agent._create_agent(streaming=False))
assert captured["model"] is main_llm
assert summary_middleware.model is main_llm
summary_middleware = next(
middleware
for middleware in captured["middleware"]
if isinstance(middleware, SummarizationMiddleware)
)
self.assertIs(captured["model"], main_llm)
self.assertIs(summary_middleware.model, main_llm)
def test_agent_uses_runtime_config_middleware_instead_of_hooks():
"""Agent 应使用运行时配置中间件而不是旧 hooks。"""
agent = agent_module.MoviePilotAgent(session_id="session-1", user_id="10001")
main_llm = _FakeLLM("main")
captured: dict = {}
def test_agent_uses_runtime_config_middleware_instead_of_hooks(self):
agent = agent_module.MoviePilotAgent(session_id="session-1", user_id="10001")
main_llm = _FakeLLM("main")
captured: dict = {}
def _fake_create_agent(**kwargs):
"""捕获 create_agent 参数。"""
captured.update(kwargs)
return object()
def _fake_create_agent(**kwargs):
captured.update(kwargs)
return object()
with (
patch.object(agent, "_initialize_llm", return_value=main_llm),
patch.object(agent, "_initialize_tools", return_value=[]),
patch.object(
agent_module.prompt_manager, "get_agent_prompt", return_value="prompt"
),
patch.object(
agent_module, "create_subagent_middlewares", return_value=([], [])
),
patch.object(agent_module, "create_agent", side_effect=_fake_create_agent),
patch.object(agent_module.settings, "LLM_MAX_TOOLS", 0),
):
asyncio.run(agent._create_agent(streaming=False))
with (
patch.object(agent, "_initialize_llm", return_value=main_llm),
patch.object(agent, "_initialize_tools", return_value=[]),
patch.object(
agent_module.prompt_manager, "get_agent_prompt", return_value="prompt"
),
patch.object(
agent_module, "create_subagent_middlewares", return_value=([], [])
),
patch.object(agent_module, "create_agent", side_effect=_fake_create_agent),
patch.object(agent_module.settings, "LLM_MAX_TOOLS", 0),
):
asyncio.run(agent._create_agent(streaming=False))
self.assertTrue(
any(
isinstance(middleware, RuntimeConfigMiddleware)
for middleware in captured["middleware"]
)
)
self.assertFalse(
any(type(middleware).__name__ == "AgentHooksMiddleware" for middleware in captured["middleware"])
)
assert any(
isinstance(middleware, RuntimeConfigMiddleware)
for middleware in captured["middleware"]
)
assert not any(
type(middleware).__name__ == "AgentHooksMiddleware"
for middleware in captured["middleware"]
)

View File

@@ -141,6 +141,56 @@ class TestAgentToolStreaming:
assert buffered_message == "处理中:\n\n(已调用 2 个子代理)\n\n"
def test_non_verbose_tool_summary_describes_skill_lookup(self):
"""校验非详细模式单独描述 Skill 说明查询。"""
async def _run():
handler = StreamingHandler()
await handler.start_streaming()
handler.emit("处理中:")
handler.record_tool_call(
tool_name="skill",
tool_message="Loads the full instructions for a MoviePilot skill",
tool_kwargs={"name": "moviepilot-cli"},
)
handler.record_tool_call(
tool_name="skill",
tool_message="Loads the full instructions for a MoviePilot skill",
tool_kwargs={"name": "moviepilot-cli"},
)
handler.record_tool_call(
tool_name="query_activity_log",
tool_message="Query recent MoviePilot Agent activity logs",
tool_kwargs={"keyword": "整理"},
)
return await handler.take()
buffered_message = asyncio.run(_run())
assert buffered_message == "处理中:\n\n(查询了 1 个技能说明,查询了 1 次活动日志)\n\n"
def test_non_verbose_tool_summary_counts_subagent_batch_tasks(self):
"""校验批量子代理控制工具按子任务数统计。"""
async def _run():
handler = StreamingHandler()
await handler.start_streaming()
handler.emit("处理中:")
handler.record_tool_call(
tool_name="subagent_task",
tool_message="Start and manage multiple MoviePilot subagent tasks",
tool_kwargs={
"action": "start",
"tasks": [
{"subagent_type": "media-researcher"},
{"subagent_type": "download-diagnostician"},
],
},
)
return await handler.take()
buffered_message = asyncio.run(_run())
assert buffered_message == "处理中:\n\n(已调用 2 个子代理)\n\n"
def test_subagent_stream_metadata_is_suppressed(self):
"""校验子代理流式元数据会被识别并抑制。"""
assert is_subagent_stream_metadata(