diff --git a/app/agent/__init__.py b/app/agent/__init__.py index cb996854..afc70238 100644 --- a/app/agent/__init__.py +++ b/app/agent/__init__.py @@ -451,6 +451,7 @@ class MoviePilotAgent: model=non_streaming_model, max_tools=max_tools, always_include=always_include_tools, + selection_tools=tools, ) ) diff --git a/app/agent/middleware/skills.py b/app/agent/middleware/skills.py index a93ff797..2a24325e 100644 --- a/app/agent/middleware/skills.py +++ b/app/agent/middleware/skills.py @@ -310,7 +310,8 @@ def _extract_version(skill_md: Path) -> int: """从 SKILL.md 文件中快速提取 version 字段,无法提取时返回 0。""" try: content = skill_md.read_text(encoding="utf-8") except Exception: + # NOTE(review): an unreadable SKILL.md is deliberately treated as version 0 return 0 match = re.match(r"^---\s*\n(.*?)\n---\s*\n", content, re.DOTALL) if not match: diff --git a/app/agent/middleware/tool_selection.py b/app/agent/middleware/tool_selection.py index 890e8c38..dc52f14d 100644 --- a/app/agent/middleware/tool_selection.py +++ b/app/agent/middleware/tool_selection.py @@ -1,16 +1,39 @@ """MoviePilot 自定义工具筛选中间件。""" - -from __future__ import annotations - import json -from typing import Any +from collections.abc import Awaitable, Callable +from typing import Annotated, Any, NotRequired, TypedDict from langchain.agents.middleware import LLMToolSelectorMiddleware +from langchain.agents.middleware.types import ( + AgentState, + ContextT, + ModelRequest, + ModelResponse, + PrivateStateAttr, # noqa + ResponseT, +) from langchain_core.language_models.chat_models import BaseChatModel +from langchain_core.runnables import RunnableConfig +from langgraph.runtime import Runtime from app.log import logger +class ToolSelectionState(AgentState): + """工具筛选中间件私有状态。""" + + selected_tool_names: NotRequired[ + Annotated[list[str] | None, PrivateStateAttr] + ] + """当前这条用户请求首轮筛选得到的工具名列表。""" + + +class ToolSelectionStateUpdate(TypedDict): + """工具筛选中间件状态更新项。""" + + selected_tool_names: list[str] | None + + class 
MoviePilotToolSelectorMiddleware(LLMToolSelectorMiddleware): """ 为 DeepSeek 兼容端点提供更稳妥的工具筛选实现。 @@ -24,8 +47,22 @@ class MoviePilotToolSelectorMiddleware(LLMToolSelectorMiddleware): 1. 使用 `response_format={"type": "json_object"}`; 2. 在提示词中明确约束返回 JSON 结构; 3. 手动解析 `{"tools": [...]}`,其余模型继续沿用 LangChain 默认实现。 + + 另外,LangChain 原生工具筛选挂在 `wrap_model_call` 上,会在同一条用户请求 + 的每次“模型回合”前都重新筛选一次工具。对于会多轮调用工具的复杂任务, + 这会重复消耗一次额外的 LLM 调用。这里改成: + - `abefore_agent()`:在本轮 Agent 执行开始时筛选一次; + - `awrap_model_call()`:从 `request.state` 读取首轮筛选结果并复用。 """ + state_schema = ToolSelectionState + + def __init__(self, *args, selection_tools: list[Any] | None = None, **kwargs) -> None: + super().__init__(*args, **kwargs) + # `abefore_agent()` 无法直接拿到 ModelRequest,因此把首次可见的工具集 + # 通过初始化参数传入,后续在进入模型循环前完成一次真实筛选。 + self._selection_tools = selection_tools or [] + @staticmethod def _is_deepseek_compatible_model(model: BaseChatModel) -> bool: """ @@ -175,22 +212,125 @@ class MoviePilotToolSelectorMiddleware(LLMToolSelectorMiddleware): ) return self._normalize_selection_response(response) - async def awrap_model_call(self, request: Any, handler: Any) -> Any: + @staticmethod + def _extract_selected_tool_names(request: ModelRequest) -> list[str]: + """从已筛选后的请求中提取最终工具名,保留原有顺序。""" + return [ + tool.name for tool in request.tools if not isinstance(tool, dict) + ] + + @staticmethod + def _apply_selected_tools( + request: ModelRequest[ContextT], + selected_tool_names: list[str], + ) -> ModelRequest[ContextT]: """ - 异步版本的 DeepSeek 工具筛选兼容分支。 + 将已筛选出的工具集应用到当前模型请求。 + + 这里只复用首次筛选出的客户端工具名;provider-specific 的 dict 工具仍然 + 原样保留,避免破坏 LangChain/provider 自身的工具绑定约定。 + """ + if not selected_tool_names: + return request + + current_tools_by_name = { + tool.name: tool + for tool in request.tools + if not isinstance(tool, dict) + } + selected_tools = [ + current_tools_by_name[tool_name] + for tool_name in selected_tool_names + if tool_name in current_tools_by_name + ] + provider_tools = [tool for tool in request.tools if 
isinstance(tool, dict)] + return request.override(tools=[*selected_tools, *provider_tools]) + + async def _aselect_request_once( + self, request: ModelRequest[ContextT] + ) -> ModelRequest[ContextT]: + """ + 执行一次真实工具筛选,并返回筛选后的请求对象。 + + 这里单独抽成 helper,便于首次筛选后缓存结果,也便于测试覆盖 + “首轮筛选,后续复用”的行为。 """ selection_request = self._prepare_selection_request(request) if selection_request is None: - return await handler(request) + return request if not self._is_deepseek_compatible_model(selection_request.model): - return await super().awrap_model_call(request, handler) + captured_request: ModelRequest[ContextT] = request + + async def _capture_handler( + updated_request: ModelRequest[ContextT], + ) -> ModelRequest[ContextT]: + nonlocal captured_request + captured_request = updated_request + return updated_request + + await super().awrap_model_call(request, _capture_handler) + return captured_request response = await self._aselect_tools_with_deepseek(selection_request) - modified_request = self._process_selection_response( + return self._process_selection_response( response, selection_request.available_tools, selection_request.valid_tool_names, request, ) - return await handler(modified_request) + + async def abefore_agent( # noqa + self, + state: ToolSelectionState, + runtime: Runtime, # noqa + config: RunnableConfig, + ) -> ToolSelectionStateUpdate | None: # ty: ignore[invalid-method-override] + """ + 在本轮 Agent 执行开始前完成一次真实工具筛选。 + + 这样后续多轮 `model -> tools -> model` 循环都只复用这一次结果, + 不会为每次模型回合重复追加一笔 selector LLM 开销。 + """ + if "selected_tool_names" in state: + return None + + if not self._selection_tools or self.model is None: + return ToolSelectionStateUpdate(selected_tool_names=None) + + selection_request = ModelRequest( + model=self.model, + tools=list(self._selection_tools), + messages=state["messages"], + state=state, + runtime=runtime, + ) + modified_request = await self._aselect_request_once(selection_request) + selected_tool_names = 
self._extract_selected_tool_names(modified_request) + return ToolSelectionStateUpdate( + selected_tool_names=selected_tool_names or None + ) + + async def awrap_model_call( + self, + request: ModelRequest[ContextT], + handler: Callable[ + [ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]] + ], + ) -> ModelResponse[ResponseT]: + """ + 从 state 中读取首次筛选结果,并应用到每次模型回合。 + """ + selected_tool_names = request.state.get("selected_tool_names") # noqa + + # 正常路径下,`abefore_agent()` 已经提前写入状态;这里只保留一层兜底, + # 兼容直接单测或未来某些绕过 before_agent 的调用场景。 + if selected_tool_names is None and self._selection_tools and self.model is not None: + request = await self._aselect_request_once(request) + selected_tool_names = self._extract_selected_tool_names(request) or None + request.state["selected_tool_names"] = selected_tool_names # noqa + + if selected_tool_names: + request = self._apply_selected_tools(request, selected_tool_names) + + return await handler(request) diff --git a/tests/test_agent_summarization_streaming.py b/tests/test_agent_summarization_streaming.py index 00dff0ed..ab4f2312 100644 --- a/tests/test_agent_summarization_streaming.py +++ b/tests/test_agent_summarization_streaming.py @@ -56,10 +56,17 @@ class TestAgentSummarizationStreaming(unittest.TestCase): captured: dict = {} class _FakeToolSelectorMiddleware: - def __init__(self, model, max_tools, always_include=None): + def __init__( + self, + model, + max_tools, + always_include=None, + selection_tools=None, + ): self.model = model self.max_tools = max_tools self.always_include = always_include or [] + self.selection_tools = selection_tools or [] def _fake_create_agent(**kwargs): captured.update(kwargs) @@ -88,7 +95,7 @@ class TestAgentSummarizationStreaming(unittest.TestCase): ), patch.object( agent_module, - "LLMToolSelectorMiddleware", + "MoviePilotToolSelectorMiddleware", _FakeToolSelectorMiddleware, ), patch.object(agent_module, "create_agent", side_effect=_fake_create_agent), @@ -114,6 +121,7 @@ class 
TestAgentSummarizationStreaming(unittest.TestCase): "execute_command", ], ) + self.assertEqual(tool_selector_middleware.selection_tools, fake_tools) def test_non_streaming_agent_reuses_main_llm_for_summary(self): agent = agent_module.MoviePilotAgent(session_id="session-1", user_id="10001") diff --git a/tests/test_agent_tool_selector_middleware.py b/tests/test_agent_tool_selector_middleware.py index bc38a7e6..eb02d475 100644 --- a/tests/test_agent_tool_selector_middleware.py +++ b/tests/test_agent_tool_selector_middleware.py @@ -4,6 +4,7 @@ import sys import unittest from pathlib import Path from types import ModuleType, SimpleNamespace +from unittest.mock import patch from langchain_core.messages import HumanMessage @@ -22,6 +23,8 @@ sys.modules.pop("app.agent.middleware.tool_selection", None) _stub_module( "app.log", logger=SimpleNamespace(debug=lambda *args, **kwargs: None), + log_settings=lambda *args, **kwargs: None, + LogConfigModel=type("LogConfigModel", (), {}), ) module_path = ( @@ -70,16 +73,20 @@ class _FakeModel: class _FakeRequest: - def __init__(self, *, tools, messages, model): + def __init__(self, *, tools, messages, model, state=None, runtime=None): self.tools = tools self.messages = messages self.model = model + self.state = state if state is not None else {"messages": messages} + self.runtime = runtime def override(self, **kwargs): data = { "tools": self.tools, "messages": self.messages, "model": self.model, + "state": self.state, + "runtime": self.runtime, } data.update(kwargs) return _FakeRequest(**data) @@ -87,13 +94,17 @@ class _FakeRequest: class ToolSelectorMiddlewareTest(unittest.TestCase): def test_awrap_model_call_uses_json_mode_for_deepseek(self): - middleware = tool_selector_module.MoviePilotToolSelectorMiddleware(max_tools=2) tools = [ SimpleNamespace(name="search", description="Search for information"), SimpleNamespace(name="calendar", description="Manage events"), SimpleNamespace(name="translate", description="Translate text"), ] 
model = _FakeModel() + middleware = tool_selector_module.MoviePilotToolSelectorMiddleware( + max_tools=2, + selection_tools=tools, + ) + middleware.model = model request = _FakeRequest( tools=tools, messages=[HumanMessage(content="帮我安排明天的行程并查天气")], @@ -105,6 +116,11 @@ class ToolSelectorMiddlewareTest(unittest.TestCase): handled_requests.append(updated_request) return updated_request + state_update = asyncio.run( + middleware.abefore_agent(request.state, runtime=None, config=None) + ) + if state_update: + request.state.update(state_update) result = asyncio.run(middleware.awrap_model_call(request, handler)) self.assertEqual( @@ -121,6 +137,108 @@ class ToolSelectorMiddlewareTest(unittest.TestCase): self.assertIn('- calendar: Manage events', prompt) self.assertEqual(len(handled_requests), 1) + def test_awrap_model_call_reuses_first_selection_for_later_model_rounds(self): + tools = [ + SimpleNamespace(name="search", description="Search for information"), + SimpleNamespace(name="calendar", description="Manage events"), + SimpleNamespace(name="translate", description="Translate text"), + ] + model = _FakeModel(content='{"tools": ["calendar", "search"]}') + middleware = tool_selector_module.MoviePilotToolSelectorMiddleware( + max_tools=2, + selection_tools=tools, + ) + middleware.model = model + request = _FakeRequest( + tools=tools, + messages=[HumanMessage(content="帮我安排明天的行程并查天气")], + model=model, + ) + handled_requests = [] + + async def handler(updated_request): + handled_requests.append(updated_request) + return updated_request + + state_update = asyncio.run( + middleware.abefore_agent(request.state, runtime=None, config=None) + ) + if state_update: + request.state.update(state_update) + first_result = asyncio.run(middleware.awrap_model_call(request, handler)) + second_result = asyncio.run(middleware.awrap_model_call(request, handler)) + + self.assertEqual( + model.bind_calls, + [{"response_format": {"type": "json_object"}}], + ) + self.assertEqual( + [tool.name for 
tool in first_result.tools], + ["search", "calendar"], + ) + self.assertEqual( + [tool.name for tool in second_result.tools], + ["search", "calendar"], + ) + self.assertEqual(len(handled_requests), 2) + + def test_awrap_model_call_caches_non_deepseek_selection_too(self): + tools = [ + SimpleNamespace(name="search", description="Search for information"), + SimpleNamespace(name="calendar", description="Manage events"), + SimpleNamespace(name="translate", description="Translate text"), + ] + model = _FakeModel( + model_name="gpt-4o-mini", + base_url="https://api.openai.com/v1", + ) + middleware = tool_selector_module.MoviePilotToolSelectorMiddleware( + max_tools=2, + selection_tools=tools, + ) + middleware.model = model + request = _FakeRequest( + tools=tools, + messages=[HumanMessage(content="帮我安排明天的行程并查天气")], + model=model, + ) + + async def handler(updated_request): + return updated_request + + parent_calls = 0 + + async def _fake_parent_awrap(self, request_arg, handler_arg): + nonlocal parent_calls + parent_calls += 1 + selected_request = request_arg.override( + tools=[request_arg.tools[1], request_arg.tools[0]] + ) + return await handler_arg(selected_request) + + with patch.object( + tool_selector_module.LLMToolSelectorMiddleware, + "awrap_model_call", + _fake_parent_awrap, + ): + state_update = asyncio.run( + middleware.abefore_agent(request.state, runtime=None, config=None) + ) + if state_update: + request.state.update(state_update) + first_result = asyncio.run(middleware.awrap_model_call(request, handler)) + second_result = asyncio.run(middleware.awrap_model_call(request, handler)) + + self.assertEqual(parent_calls, 1) + self.assertEqual( + [tool.name for tool in first_result.tools], + ["calendar", "search"], + ) + self.assertEqual( + [tool.name for tool in second_result.tools], + ["calendar", "search"], + ) + def test_normalize_selection_response_accepts_code_fence_json(self): middleware = tool_selector_module.MoviePilotToolSelectorMiddleware() response = 
SimpleNamespace(