mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-06-28 07:46:36 +08:00
Preserve empty tool selections
This commit is contained in:
@@ -328,7 +328,7 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware):
|
||||
request: ModelRequest[ContextT],
|
||||
) -> ModelRequest[ContextT]:
|
||||
"""
|
||||
处理工具筛选响应,并保留空结果回退所有工具的 MoviePilot 策略。
|
||||
处理工具筛选响应,并在正常空结果时禁用可筛选工具。
|
||||
"""
|
||||
if response.get("tools") == []:
|
||||
always_included_tools: list[BaseTool] = [
|
||||
@@ -337,10 +337,7 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware):
|
||||
if not isinstance(tool, dict) and tool.name in self.always_include
|
||||
]
|
||||
provider_tools = [tool for tool in request.tools if isinstance(tool, dict)]
|
||||
|
||||
return request.override(
|
||||
tools=[*available_tools, *always_included_tools, *provider_tools]
|
||||
)
|
||||
return request.override(tools=[*always_included_tools, *provider_tools])
|
||||
|
||||
response["tools"] = self._complete_low_count_selection(
|
||||
selected_tool_names=[
|
||||
@@ -427,7 +424,7 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware):
|
||||
"""
|
||||
limit_instruction = ""
|
||||
if self.max_tools:
|
||||
limit_instruction = f"- Select up to {self.max_tools} tools. IF NO TOOLS ARE RELEVANT, DO NOT RETURN AN EMPTY ARRAY. SELECT THE MOST APPLICABLE ONES TO ENSURE THE REQUEST IS HANDLED."
|
||||
limit_instruction = f"- Select up to {self.max_tools} tools. Return an empty array if no tools are relevant."
|
||||
|
||||
return (
|
||||
f"{selection_request.system_message}\n\n"
|
||||
@@ -502,9 +499,6 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware):
|
||||
selected_text = ", ".join(attempt.selected_tool_names) or "无有效工具"
|
||||
logger.info(f"工具筛选结果: {selected_text}")
|
||||
return
|
||||
if attempt.status == "empty_fallback":
|
||||
logger.info(f"工具筛选结果为空,将恢复使用所有工具(共 {tool_count} 个)。")
|
||||
return
|
||||
if attempt.status == "failed_fallback":
|
||||
logger.warning(
|
||||
f"工具筛选失败,将恢复使用所有工具(共 {tool_count} 个): {attempt.detail}"
|
||||
@@ -528,9 +522,6 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware):
|
||||
这里只复用首次筛选出的客户端工具名;provider-specific 的 dict 工具仍然
|
||||
原样保留,避免破坏 LangChain/provider 自身的工具绑定约定。
|
||||
"""
|
||||
if not selected_tool_names:
|
||||
return request
|
||||
|
||||
current_tools_by_name = {
|
||||
tool.name: tool for tool in request.tools if not isinstance(tool, dict)
|
||||
}
|
||||
@@ -576,15 +567,10 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware):
|
||||
selection_request.valid_tool_names,
|
||||
request,
|
||||
)
|
||||
status = (
|
||||
"empty_fallback"
|
||||
if response.get("tools") == []
|
||||
else "selected"
|
||||
)
|
||||
return _ToolSelectionAttempt(
|
||||
request=modified_request,
|
||||
selected_tool_names=self._extract_selected_tool_names(modified_request),
|
||||
status=status,
|
||||
status="selected",
|
||||
)
|
||||
except Exception as err:
|
||||
return _ToolSelectionAttempt(
|
||||
@@ -650,7 +636,7 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware):
|
||||
attempt = await self._aselect_request_once_with_status(selection_request)
|
||||
self._log_selection_attempt(attempt)
|
||||
selected_tool_names = attempt.selected_tool_names
|
||||
return ToolSelectionStateUpdate(selected_tool_names=selected_tool_names or None)
|
||||
return ToolSelectionStateUpdate(selected_tool_names=selected_tool_names)
|
||||
|
||||
async def awrap_model_call(
|
||||
self,
|
||||
@@ -674,10 +660,10 @@ class ToolSelectorMiddleware(LLMToolSelectorMiddleware):
|
||||
attempt = await self._aselect_request_once_with_status(request)
|
||||
self._log_selection_attempt(attempt)
|
||||
request = attempt.request
|
||||
selected_tool_names = attempt.selected_tool_names or None
|
||||
selected_tool_names = attempt.selected_tool_names
|
||||
request.state["selected_tool_names"] = selected_tool_names # noqa
|
||||
|
||||
if selected_tool_names:
|
||||
if selected_tool_names is not None:
|
||||
request = self._apply_selected_tools(request, selected_tool_names)
|
||||
|
||||
return await handler(request)
|
||||
|
||||
@@ -292,8 +292,8 @@ def test_tool_selection_failure_falls_back_to_all_tools():
|
||||
assert state_update == {"selected_tool_names": ["search", "calendar"]}
|
||||
|
||||
|
||||
def test_empty_tool_selection_logs_info_not_warning():
|
||||
"""工具筛选返回空数组时应按信息日志记录降级。"""
|
||||
def test_empty_tool_selection_keeps_empty_tool_list():
|
||||
"""工具筛选返回空数组时应保持空工具列表。"""
|
||||
tools = [
|
||||
SimpleNamespace(name="search", description="Search for information"),
|
||||
SimpleNamespace(name="calendar", description="Manage events"),
|
||||
@@ -310,17 +310,55 @@ def test_empty_tool_selection_logs_info_not_warning():
|
||||
model=model,
|
||||
)
|
||||
|
||||
async def handler(updated_request):
|
||||
return updated_request
|
||||
|
||||
with patch.object(tool_selector_module.logger, "info") as logger_info, \
|
||||
patch.object(tool_selector_module.logger, "warning") as logger_warning:
|
||||
state_update = asyncio.run(
|
||||
middleware.abefore_agent(request.state, runtime=None, config=None)
|
||||
)
|
||||
request.state.update(state_update)
|
||||
result = asyncio.run(middleware.awrap_model_call(request, handler))
|
||||
|
||||
assert state_update == {"selected_tool_names": ["search", "calendar"]}
|
||||
logger_info.assert_called_once_with("工具筛选结果为空,将恢复使用所有工具(共 2 个)。")
|
||||
assert state_update == {"selected_tool_names": []}
|
||||
assert result.tools == []
|
||||
logger_info.assert_called_once_with("工具筛选结果: 无有效工具")
|
||||
logger_warning.assert_not_called()
|
||||
|
||||
|
||||
def test_empty_tool_selection_keeps_always_included_tools():
|
||||
"""工具筛选返回空数组时仍应保留必须包括的工具。"""
|
||||
tools = [
|
||||
SimpleNamespace(name="search", description="Search for information"),
|
||||
SimpleNamespace(name="skill", description="Run skill"),
|
||||
]
|
||||
model = _FakeModel(content='{"tools": []}')
|
||||
middleware = tool_selector_module.ToolSelectorMiddleware(
|
||||
max_tools=2,
|
||||
selection_tools=tools,
|
||||
always_include=["skill"],
|
||||
)
|
||||
middleware.model = model
|
||||
request = _FakeRequest(
|
||||
tools=tools,
|
||||
messages=[HumanMessage(content="不用工具,直接回答")],
|
||||
model=model,
|
||||
)
|
||||
|
||||
async def handler(updated_request):
|
||||
return updated_request
|
||||
|
||||
state_update = asyncio.run(
|
||||
middleware.abefore_agent(request.state, runtime=None, config=None)
|
||||
)
|
||||
request.state.update(state_update)
|
||||
result = asyncio.run(middleware.awrap_model_call(request, handler))
|
||||
|
||||
assert state_update == {"selected_tool_names": ["skill"]}
|
||||
assert [tool.name for tool in result.tools] == ["skill"]
|
||||
|
||||
|
||||
def test_abefore_agent_logs_selected_tools():
|
||||
"""工具筛选返回有效工具时应记录最终生效的工具名。"""
|
||||
tools = [
|
||||
|
||||
Reference in New Issue
Block a user