diff --git a/app/agent/middleware/tool_selection.py b/app/agent/middleware/tool_selection.py index 875ddc76..16a6b07b 100644 --- a/app/agent/middleware/tool_selection.py +++ b/app/agent/middleware/tool_selection.py @@ -328,7 +328,7 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware): request: ModelRequest[ContextT], ) -> ModelRequest[ContextT]: """ - 处理工具筛选响应,并保留空结果回退所有工具的 MoviePilot 策略。 + 处理工具筛选响应,并在正常空结果时禁用可筛选工具。 """ if response.get("tools") == []: always_included_tools: list[BaseTool] = [ @@ -337,10 +337,7 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware): if not isinstance(tool, dict) and tool.name in self.always_include ] provider_tools = [tool for tool in request.tools if isinstance(tool, dict)] - - return request.override( - tools=[*available_tools, *always_included_tools, *provider_tools] - ) + return request.override(tools=[*always_included_tools, *provider_tools]) response["tools"] = self._complete_low_count_selection( selected_tool_names=[ @@ -427,7 +424,7 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware): """ limit_instruction = "" if self.max_tools: - limit_instruction = f"- Select up to {self.max_tools} tools. IF NO TOOLS ARE RELEVANT, DO NOT RETURN AN EMPTY ARRAY. SELECT THE MOST APPLICABLE ONES TO ENSURE THE REQUEST IS HANDLED." + limit_instruction = f"- Select up to {self.max_tools} tools. Return an empty array if no tools are relevant." return ( f"{selection_request.system_message}\n\n" @@ -502,9 +499,6 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware): selected_text = ", ".join(attempt.selected_tool_names) or "无有效工具" logger.info(f"工具筛选结果: {selected_text}") return - if attempt.status == "empty_fallback": - logger.info(f"工具筛选结果为空,将恢复使用所有工具(共 {tool_count} 个)。") - return if attempt.status == "failed_fallback": logger.warning( f"工具筛选失败,将恢复使用所有工具(共 {tool_count} 个): {attempt.detail}" @@ -528,9 +522,6 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware): 这里只复用首次筛选出的客户端工具名;provider-specific 的 dict 工具仍然 原样保留,避免破坏 LangChain/provider 自身的工具绑定约定。 """ - if not selected_tool_names: - return request - current_tools_by_name = { tool.name: tool for tool in request.tools if not isinstance(tool, dict) } @@ -576,15 +567,10 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware): selection_request.valid_tool_names, request, ) - status = ( - "empty_fallback" - if response.get("tools") == [] - else "selected" - ) return _ToolSelectionAttempt( request=modified_request, selected_tool_names=self._extract_selected_tool_names(modified_request), - status=status, + status="selected", ) except Exception as err: return _ToolSelectionAttempt( @@ -650,7 +636,7 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware): attempt = await self._aselect_request_once_with_status(selection_request) self._log_selection_attempt(attempt) selected_tool_names = attempt.selected_tool_names - return ToolSelectionStateUpdate(selected_tool_names=selected_tool_names or None) + return ToolSelectionStateUpdate(selected_tool_names=selected_tool_names) async def awrap_model_call( self, @@ -674,10 +660,10 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware): attempt = await self._aselect_request_once_with_status(request) self._log_selection_attempt(attempt) request = attempt.request - selected_tool_names = attempt.selected_tool_names or None + selected_tool_names = attempt.selected_tool_names request.state["selected_tool_names"] = selected_tool_names # noqa - if selected_tool_names: + if selected_tool_names is not None: request = self._apply_selected_tools(request, selected_tool_names) return await handler(request) diff --git a/tests/test_agent_tool_selector_middleware.py b/tests/test_agent_tool_selector_middleware.py index 2286eca9..c7718ee5 100644 --- a/tests/test_agent_tool_selector_middleware.py +++ b/tests/test_agent_tool_selector_middleware.py @@ -292,8 +292,8 @@ def test_tool_selection_failure_falls_back_to_all_tools(): assert state_update == {"selected_tool_names": ["search", "calendar"]} -def test_empty_tool_selection_logs_info_not_warning(): - """工具筛选返回空数组时应按信息日志记录降级。""" +def test_empty_tool_selection_keeps_empty_tool_list(): + """工具筛选返回空数组时应保持空工具列表。""" tools = [ SimpleNamespace(name="search", description="Search for information"), SimpleNamespace(name="calendar", description="Manage events"), @@ -310,17 +310,55 @@ def test_empty_tool_selection_logs_info_not_warning(): model=model, ) + async def handler(updated_request): + return updated_request + with patch.object(tool_selector_module.logger, "info") as logger_info, \ patch.object(tool_selector_module.logger, "warning") as logger_warning: state_update = asyncio.run( middleware.abefore_agent(request.state, runtime=None, config=None) ) + request.state.update(state_update) + result = asyncio.run(middleware.awrap_model_call(request, handler)) - assert state_update == {"selected_tool_names": ["search", "calendar"]} - logger_info.assert_called_once_with("工具筛选结果为空,将恢复使用所有工具(共 2 个)。") + assert state_update == {"selected_tool_names": []} + assert result.tools == [] + logger_info.assert_called_once_with("工具筛选结果: 无有效工具") logger_warning.assert_not_called() +def test_empty_tool_selection_keeps_always_included_tools(): + """工具筛选返回空数组时仍应保留必须包括的工具。""" + tools = [ + SimpleNamespace(name="search", description="Search for information"), + SimpleNamespace(name="skill", description="Run skill"), + ] + model = _FakeModel(content='{"tools": []}') + middleware = tool_selector_module.ToolSelectorMiddleware( + max_tools=2, + selection_tools=tools, + always_include=["skill"], + ) + middleware.model = model + request = _FakeRequest( + tools=tools, + messages=[HumanMessage(content="不用工具,直接回答")], + model=model, + ) + + async def handler(updated_request): + return updated_request + + state_update = asyncio.run( + middleware.abefore_agent(request.state, runtime=None, config=None) + ) + request.state.update(state_update) + result = asyncio.run(middleware.awrap_model_call(request, handler)) + + assert state_update == {"selected_tool_names": ["skill"]} + assert [tool.name for tool in result.tools] == ["skill"] + + def test_abefore_agent_logs_selected_tools(): """工具筛选返回有效工具时应记录最终生效的工具名。""" tools = [