Preserve empty tool selections

This commit is contained in:
jxxghp
2026-06-22 21:29:15 +08:00
parent f9ea0118d9
commit 3407cc8edd
2 changed files with 49 additions and 25 deletions

View File

@@ -328,7 +328,7 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware):
request: ModelRequest[ContextT],
) -> ModelRequest[ContextT]:
"""
处理工具筛选响应,并保留空结果回退所有工具的 MoviePilot 策略
处理工具筛选响应,并在正常空结果时禁用可筛选工具
"""
if response.get("tools") == []:
always_included_tools: list[BaseTool] = [
@@ -337,10 +337,7 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware):
if not isinstance(tool, dict) and tool.name in self.always_include
]
provider_tools = [tool for tool in request.tools if isinstance(tool, dict)]
return request.override(
tools=[*available_tools, *always_included_tools, *provider_tools]
)
return request.override(tools=[*always_included_tools, *provider_tools])
response["tools"] = self._complete_low_count_selection(
selected_tool_names=[
@@ -427,7 +424,7 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware):
"""
limit_instruction = ""
if self.max_tools:
limit_instruction = f"- Select up to {self.max_tools} tools. IF NO TOOLS ARE RELEVANT, DO NOT RETURN AN EMPTY ARRAY. SELECT THE MOST APPLICABLE ONES TO ENSURE THE REQUEST IS HANDLED."
limit_instruction = f"- Select up to {self.max_tools} tools. Return an empty array if no tools are relevant."
return (
f"{selection_request.system_message}\n\n"
@@ -502,9 +499,6 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware):
selected_text = ", ".join(attempt.selected_tool_names) or "无有效工具"
logger.info(f"工具筛选结果: {selected_text}")
return
if attempt.status == "empty_fallback":
logger.info(f"工具筛选结果为空,将恢复使用所有工具(共 {tool_count} 个)。")
return
if attempt.status == "failed_fallback":
logger.warning(
f"工具筛选失败,将恢复使用所有工具(共 {tool_count} 个): {attempt.detail}"
@@ -528,9 +522,6 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware):
这里只复用首次筛选出的客户端工具名provider-specific 的 dict 工具仍然
原样保留,避免破坏 LangChain/provider 自身的工具绑定约定。
"""
if not selected_tool_names:
return request
current_tools_by_name = {
tool.name: tool for tool in request.tools if not isinstance(tool, dict)
}
@@ -576,15 +567,10 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware):
selection_request.valid_tool_names,
request,
)
status = (
"empty_fallback"
if response.get("tools") == []
else "selected"
)
return _ToolSelectionAttempt(
request=modified_request,
selected_tool_names=self._extract_selected_tool_names(modified_request),
status=status,
status="selected",
)
except Exception as err:
return _ToolSelectionAttempt(
@@ -650,7 +636,7 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware):
attempt = await self._aselect_request_once_with_status(selection_request)
self._log_selection_attempt(attempt)
selected_tool_names = attempt.selected_tool_names
return ToolSelectionStateUpdate(selected_tool_names=selected_tool_names or None)
return ToolSelectionStateUpdate(selected_tool_names=selected_tool_names)
async def awrap_model_call(
self,
@@ -674,10 +660,10 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware):
attempt = await self._aselect_request_once_with_status(request)
self._log_selection_attempt(attempt)
request = attempt.request
selected_tool_names = attempt.selected_tool_names or None
selected_tool_names = attempt.selected_tool_names
request.state["selected_tool_names"] = selected_tool_names # noqa
if selected_tool_names:
if selected_tool_names is not None:
request = self._apply_selected_tools(request, selected_tool_names)
return await handler(request)

View File

@@ -292,8 +292,8 @@ def test_tool_selection_failure_falls_back_to_all_tools():
assert state_update == {"selected_tool_names": ["search", "calendar"]}
def test_empty_tool_selection_logs_info_not_warning():
"""工具筛选返回空数组时应按信息日志记录降级"""
def test_empty_tool_selection_keeps_empty_tool_list():
"""工具筛选返回空数组时应保持空工具列表"""
tools = [
SimpleNamespace(name="search", description="Search for information"),
SimpleNamespace(name="calendar", description="Manage events"),
@@ -310,17 +310,55 @@ def test_empty_tool_selection_logs_info_not_warning():
model=model,
)
async def handler(updated_request):
return updated_request
with patch.object(tool_selector_module.logger, "info") as logger_info, \
patch.object(tool_selector_module.logger, "warning") as logger_warning:
state_update = asyncio.run(
middleware.abefore_agent(request.state, runtime=None, config=None)
)
request.state.update(state_update)
result = asyncio.run(middleware.awrap_model_call(request, handler))
assert state_update == {"selected_tool_names": ["search", "calendar"]}
logger_info.assert_called_once_with("工具筛选结果为空,将恢复使用所有工具(共 2 个)。")
assert state_update == {"selected_tool_names": []}
assert result.tools == []
logger_info.assert_called_once_with("工具筛选结果: 无有效工具")
logger_warning.assert_not_called()
def test_empty_tool_selection_keeps_always_included_tools():
"""工具筛选返回空数组时仍应保留必须包括的工具。"""
tools = [
SimpleNamespace(name="search", description="Search for information"),
SimpleNamespace(name="skill", description="Run skill"),
]
model = _FakeModel(content='{"tools": []}')
middleware = tool_selector_module.ToolSelectorMiddleware(
max_tools=2,
selection_tools=tools,
always_include=["skill"],
)
middleware.model = model
request = _FakeRequest(
tools=tools,
messages=[HumanMessage(content="不用工具,直接回答")],
model=model,
)
async def handler(updated_request):
return updated_request
state_update = asyncio.run(
middleware.abefore_agent(request.state, runtime=None, config=None)
)
request.state.update(state_update)
result = asyncio.run(middleware.awrap_model_call(request, handler))
assert state_update == {"selected_tool_names": ["skill"]}
assert [tool.name for tool in result.tools] == ["skill"]
def test_abefore_agent_logs_selected_tools():
"""工具筛选返回有效工具时应记录最终生效的工具名。"""
tools = [