diff --git a/app/agent/middleware/tool_selection.py b/app/agent/middleware/tool_selection.py index 840e2550..875ddc76 100644 --- a/app/agent/middleware/tool_selection.py +++ b/app/agent/middleware/tool_selection.py @@ -1,6 +1,6 @@ """MoviePilot 自定义工具筛选中间件。""" -from dataclasses import replace +from dataclasses import dataclass, replace import json from collections.abc import Awaitable, Callable from typing import Annotated, Any, NotRequired @@ -68,6 +68,16 @@ class ToolSelectionStateUpdate(TypedDict): selected_tool_names: list[str] | None +@dataclass(frozen=True) +class _ToolSelectionAttempt: + """工具筛选尝试结果,用于统一记录最终日志。""" + + request: ModelRequest + selected_tool_names: list[str] + status: str + detail: str = "" + + class ToolSelectorMiddleware(LLMToolSelectorMiddleware): """ 使用 provider-neutral JSON 提示执行工具筛选。 @@ -321,8 +331,6 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware): 处理工具筛选响应,并保留空结果回退所有工具的 MoviePilot 策略。 """ if response.get("tools") == []: - logger.info("工具筛选结果为空,将恢复使用所有工具。") - always_included_tools: list[BaseTool] = [ tool for tool in request.tools @@ -349,8 +357,6 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware): valid_tool_names, request, ) - selected_tool_names = self._extract_selected_tool_names(modified_request) - logger.info(f"工具筛选结果: {', '.join(selected_tool_names) or '无有效工具'}") return modified_request @staticmethod @@ -483,6 +489,34 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware): """从已筛选后的请求中提取最终工具名,保留原有顺序。""" return [tool.name for tool in request.tools if not isinstance(tool, dict)] + @staticmethod + def _count_request_tools(request: ModelRequest) -> int: + """统计当前请求中的 LangChain 工具数量,不包含 provider 原生工具字典。""" + return len([tool for tool in request.tools if not isinstance(tool, dict)]) + + @classmethod + def _log_selection_attempt(cls, attempt: _ToolSelectionAttempt) -> None: + """按工具筛选最终状态记录稳定日志。""" + tool_count = cls._count_request_tools(attempt.request) + if attempt.status == "selected": + selected_text = ", ".join(attempt.selected_tool_names) or "无有效工具" + logger.info(f"工具筛选结果: {selected_text}") + return + if attempt.status == "empty_fallback": + logger.info(f"工具筛选结果为空,将恢复使用所有工具(共 {tool_count} 个)。") + return + if attempt.status == "failed_fallback": + logger.warning( + f"工具筛选失败,将恢复使用所有工具(共 {tool_count} 个): {attempt.detail}" + ) + return + if attempt.status == "skipped": + logger.info(f"工具筛选跳过: {attempt.detail}。") + return + if attempt.status == "reused": + selected_text = ", ".join(attempt.selected_tool_names) or "无有效工具" + logger.info(f"工具筛选复用已有结果: {selected_text}") + @staticmethod def _apply_selected_tools( request: ModelRequest[ContextT], @@ -517,21 +551,48 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware): 这里单独抽成 helper,便于首次筛选后缓存结果,也便于测试覆盖 “首轮筛选,后续复用”的行为。 """ + return (await self._aselect_request_once_with_status(request)).request + + async def _aselect_request_once_with_status( + self, request: ModelRequest[ContextT] + ) -> _ToolSelectionAttempt: + """ + 执行一次真实工具筛选,并携带最终状态供调用方统一记录日志。 + """ selection_request = self._prepare_selection_request(request) if selection_request is None: - return request + return _ToolSelectionAttempt( + request=request, + selected_tool_names=self._extract_selected_tool_names(request), + status="skipped", + detail="没有需要筛选的工具", + ) try: response = await self._aselect_tools_with_json_prompt(selection_request) - return self._process_selection_response( + modified_request = self._process_selection_response( response, selection_request.available_tools, selection_request.valid_tool_names, request, ) + status = ( + "empty_fallback" + if response.get("tools") == [] + else "selected" + ) + return _ToolSelectionAttempt( + request=modified_request, + selected_tool_names=self._extract_selected_tool_names(modified_request), + status=status, + ) except Exception as err: - logger.warning(f"工具筛选失败,将恢复使用所有工具: {str(err)}") - return request + return _ToolSelectionAttempt( + request=request, + selected_tool_names=self._extract_selected_tool_names(request), + status="failed_fallback", + detail=str(err), + ) async def abefore_agent( # noqa self, @@ -546,9 +607,37 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware): 不会为每次模型回合重复追加一笔 selector LLM 开销。 """ if "selected_tool_names" in state: + self._log_selection_attempt( + _ToolSelectionAttempt( + request=ModelRequest( + model=self.model, + tools=list(self.selection_tools), + messages=state["messages"], + state=state, + runtime=runtime, + ), + selected_tool_names=state.get("selected_tool_names") or [], + status="reused", + ) + ) return None if not self.selection_tools or self.model is None: + detail = "没有可筛选工具" if not self.selection_tools else "未配置筛选模型" + self._log_selection_attempt( + _ToolSelectionAttempt( + request=ModelRequest( + model=self.model, + tools=list(self.selection_tools), + messages=state["messages"], + state=state, + runtime=runtime, + ), + selected_tool_names=[], + status="skipped", + detail=detail, + ) + ) return ToolSelectionStateUpdate(selected_tool_names=None) selection_request = ModelRequest( @@ -558,8 +647,9 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware): state=state, runtime=runtime, ) - modified_request = await self._aselect_request_once(selection_request) - selected_tool_names = self._extract_selected_tool_names(modified_request) + attempt = await self._aselect_request_once_with_status(selection_request) + self._log_selection_attempt(attempt) + selected_tool_names = attempt.selected_tool_names return ToolSelectionStateUpdate(selected_tool_names=selected_tool_names or None) async def awrap_model_call( @@ -581,8 +671,10 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware): and self.selection_tools and self.model is not None ): - request = await self._aselect_request_once(request) - selected_tool_names = self._extract_selected_tool_names(request) or None + attempt = await self._aselect_request_once_with_status(request) + self._log_selection_attempt(attempt) + request = attempt.request + selected_tool_names = attempt.selected_tool_names or None request.state["selected_tool_names"] = selected_tool_names # noqa if selected_tool_names: diff --git a/tests/test_agent_tool_selector_middleware.py b/tests/test_agent_tool_selector_middleware.py index 47f82eb3..2286eca9 100644 --- a/tests/test_agent_tool_selector_middleware.py +++ b/tests/test_agent_tool_selector_middleware.py @@ -298,58 +298,70 @@ def test_empty_tool_selection_logs_info_not_warning(): SimpleNamespace(name="search", description="Search for information"), SimpleNamespace(name="calendar", description="Manage events"), ] + model = _FakeModel(content='{"tools": []}') middleware = tool_selector_module.ToolSelectorMiddleware( max_tools=2, selection_tools=tools, ) + middleware.model = model request = _FakeRequest( tools=tools, messages=[HumanMessage(content="帮我安排明天的行程并查天气")], - model=_FakeModel(), + model=model, ) with patch.object(tool_selector_module.logger, "info") as logger_info, \ patch.object(tool_selector_module.logger, "warning") as logger_warning: - result = middleware._process_selection_response( - {"tools": []}, - available_tools=tools, - valid_tool_names=[tool.name for tool in tools], - request=request, + state_update = asyncio.run( + middleware.abefore_agent(request.state, runtime=None, config=None) ) - assert [tool.name for tool in result.tools] == ["search", "calendar"] - logger_info.assert_called_once_with("工具筛选结果为空,将恢复使用所有工具。") + assert state_update == {"selected_tool_names": ["search", "calendar"]} + logger_info.assert_called_once_with("工具筛选结果为空,将恢复使用所有工具(共 2 个)。") logger_warning.assert_not_called() -def test_process_selection_response_logs_selected_tools(): +def test_abefore_agent_logs_selected_tools(): """工具筛选返回有效工具时应记录最终生效的工具名。""" tools = [ SimpleNamespace(name="search", description="Search for information"), SimpleNamespace(name="calendar", description="Manage events"), ] + model = _FakeModel(content='{"tools": ["calendar"]}') middleware = tool_selector_module.ToolSelectorMiddleware( max_tools=2, selection_tools=tools, ) + middleware.model = model request = _FakeRequest( tools=tools, messages=[HumanMessage(content="帮我安排明天的行程并查天气")], - model=_FakeModel(), + model=model, ) with patch.object(tool_selector_module.logger, "info") as logger_info: - result = middleware._process_selection_response( - {"tools": ["calendar"]}, - available_tools=tools, - valid_tool_names=[tool.name for tool in tools], - request=request, + state_update = asyncio.run( + middleware.abefore_agent(request.state, runtime=None, config=None) ) - assert [tool.name for tool in result.tools] == ["calendar"] + assert state_update == {"selected_tool_names": ["calendar"]} logger_info.assert_called_once_with("工具筛选结果: calendar") +def test_abefore_agent_logs_skipped_selection(): + """工具筛选未启用时也应记录跳过原因。""" + middleware = tool_selector_module.ToolSelectorMiddleware(selection_tools=[]) + request_state = {"messages": [HumanMessage(content="帮我安排明天的行程")]} + + with patch.object(tool_selector_module.logger, "info") as logger_info: + state_update = asyncio.run( + middleware.abefore_agent(request_state, runtime=None, config=None) + ) + + assert state_update == {"selected_tool_names": None} + logger_info.assert_called_once_with("工具筛选跳过: 没有可筛选工具。") + + def test_normalize_selection_response_accepts_code_fence_json(): """工具筛选响应应兼容 Markdown 代码围栏包裹的 JSON。""" middleware = tool_selector_module.ToolSelectorMiddleware()