mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-05-06 21:34:41 +08:00
feat: optimize tool selection middleware to cache and reuse tool selection per agent run
- Refactor MoviePilotToolSelectorMiddleware to perform tool selection once per agent execution and cache the result in state, avoiding redundant LLM calls for each model round. - Add abefore_agent to select tools at the start of agent execution and store selected tool names in state. - Update awrap_model_call to reuse cached tool selection from state for subsequent model calls. - Enhance test coverage for tool selection caching and reuse logic. - Improve error logging in skill version extraction.
This commit is contained in:
@@ -451,6 +451,7 @@ class MoviePilotAgent:
|
||||
model=non_streaming_model,
|
||||
max_tools=max_tools,
|
||||
always_include=always_include_tools,
|
||||
selection_tools=tools,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -310,7 +310,8 @@ def _extract_version(skill_md: Path) -> int:
|
||||
"""从 SKILL.md 文件中快速提取 version 字段,无法提取时返回 0。"""
|
||||
try:
|
||||
content = skill_md.read_text(encoding="utf-8")
|
||||
except Exception:
|
||||
except Exception as err:
|
||||
print(err)
|
||||
return 0
|
||||
match = re.match(r"^---\s*\n(.*?)\n---\s*\n", content, re.DOTALL)
|
||||
if not match:
|
||||
|
||||
@@ -1,16 +1,39 @@
|
||||
"""MoviePilot 自定义工具筛选中间件。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Annotated, Any, NotRequired, TypedDict
|
||||
|
||||
from langchain.agents.middleware import LLMToolSelectorMiddleware
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentState,
|
||||
ContextT,
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
PrivateStateAttr, # noqa
|
||||
ResponseT,
|
||||
)
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class ToolSelectionState(AgentState):
|
||||
"""工具筛选中间件私有状态。"""
|
||||
|
||||
selected_tool_names: NotRequired[
|
||||
Annotated[list[str] | None, PrivateStateAttr]
|
||||
]
|
||||
"""当前这条用户请求首轮筛选得到的工具名列表。"""
|
||||
|
||||
|
||||
class ToolSelectionStateUpdate(TypedDict):
|
||||
"""工具筛选中间件状态更新项。"""
|
||||
|
||||
selected_tool_names: list[str] | None
|
||||
|
||||
|
||||
class MoviePilotToolSelectorMiddleware(LLMToolSelectorMiddleware):
|
||||
"""
|
||||
为 DeepSeek 兼容端点提供更稳妥的工具筛选实现。
|
||||
@@ -24,8 +47,22 @@ class MoviePilotToolSelectorMiddleware(LLMToolSelectorMiddleware):
|
||||
1. 使用 `response_format={"type": "json_object"}`;
|
||||
2. 在提示词中明确约束返回 JSON 结构;
|
||||
3. 手动解析 `{"tools": [...]}`,其余模型继续沿用 LangChain 默认实现。
|
||||
|
||||
另外,LangChain 原生工具筛选挂在 `wrap_model_call` 上,会在同一条用户请求
|
||||
的每次“模型回合”前都重新筛选一次工具。对于会多轮调用工具的复杂任务,
|
||||
这会重复消耗一次额外的 LLM 调用。这里改成:
|
||||
- `abefore_agent()`:在本轮 Agent 执行开始时筛选一次;
|
||||
- `awrap_model_call()`:从 `request.state` 读取首轮筛选结果并复用。
|
||||
"""
|
||||
|
||||
state_schema = ToolSelectionState
|
||||
|
||||
def __init__(self, *args, selection_tools: list[Any] | None = None, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
# `abefore_agent()` 无法直接拿到 ModelRequest,因此把首次可见的工具集
|
||||
# 通过初始化参数传入,后续在进入模型循环前完成一次真实筛选。
|
||||
self._selection_tools = selection_tools or []
|
||||
|
||||
@staticmethod
|
||||
def _is_deepseek_compatible_model(model: BaseChatModel) -> bool:
|
||||
"""
|
||||
@@ -175,22 +212,125 @@ class MoviePilotToolSelectorMiddleware(LLMToolSelectorMiddleware):
|
||||
)
|
||||
return self._normalize_selection_response(response)
|
||||
|
||||
async def awrap_model_call(self, request: Any, handler: Any) -> Any:
|
||||
@staticmethod
|
||||
def _extract_selected_tool_names(request: ModelRequest) -> list[str]:
|
||||
"""从已筛选后的请求中提取最终工具名,保留原有顺序。"""
|
||||
return [
|
||||
tool.name for tool in request.tools if not isinstance(tool, dict)
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _apply_selected_tools(
|
||||
request: ModelRequest[ContextT],
|
||||
selected_tool_names: list[str],
|
||||
) -> ModelRequest[ContextT]:
|
||||
"""
|
||||
异步版本的 DeepSeek 工具筛选兼容分支。
|
||||
将已筛选出的工具集应用到当前模型请求。
|
||||
|
||||
这里只复用首次筛选出的客户端工具名;provider-specific 的 dict 工具仍然
|
||||
原样保留,避免破坏 LangChain/provider 自身的工具绑定约定。
|
||||
"""
|
||||
if not selected_tool_names:
|
||||
return request
|
||||
|
||||
current_tools_by_name = {
|
||||
tool.name: tool
|
||||
for tool in request.tools
|
||||
if not isinstance(tool, dict)
|
||||
}
|
||||
selected_tools = [
|
||||
current_tools_by_name[tool_name]
|
||||
for tool_name in selected_tool_names
|
||||
if tool_name in current_tools_by_name
|
||||
]
|
||||
provider_tools = [tool for tool in request.tools if isinstance(tool, dict)]
|
||||
return request.override(tools=[*selected_tools, *provider_tools])
|
||||
|
||||
async def _aselect_request_once(
|
||||
self, request: ModelRequest[ContextT]
|
||||
) -> ModelRequest[ContextT]:
|
||||
"""
|
||||
执行一次真实工具筛选,并返回筛选后的请求对象。
|
||||
|
||||
这里单独抽成 helper,便于首次筛选后缓存结果,也便于测试覆盖
|
||||
“首轮筛选,后续复用”的行为。
|
||||
"""
|
||||
selection_request = self._prepare_selection_request(request)
|
||||
if selection_request is None:
|
||||
return await handler(request)
|
||||
return request
|
||||
|
||||
if not self._is_deepseek_compatible_model(selection_request.model):
|
||||
return await super().awrap_model_call(request, handler)
|
||||
captured_request: ModelRequest[ContextT] = request
|
||||
|
||||
async def _capture_handler(
|
||||
updated_request: ModelRequest[ContextT],
|
||||
) -> ModelRequest[ContextT]:
|
||||
nonlocal captured_request
|
||||
captured_request = updated_request
|
||||
return updated_request
|
||||
|
||||
await super().awrap_model_call(request, _capture_handler)
|
||||
return captured_request
|
||||
|
||||
response = await self._aselect_tools_with_deepseek(selection_request)
|
||||
modified_request = self._process_selection_response(
|
||||
return self._process_selection_response(
|
||||
response,
|
||||
selection_request.available_tools,
|
||||
selection_request.valid_tool_names,
|
||||
request,
|
||||
)
|
||||
return await handler(modified_request)
|
||||
|
||||
async def abefore_agent( # noqa
|
||||
self,
|
||||
state: ToolSelectionState,
|
||||
runtime: Runtime, # noqa
|
||||
config: RunnableConfig,
|
||||
) -> ToolSelectionStateUpdate | None: # ty: ignore[invalid-method-override]
|
||||
"""
|
||||
在本轮 Agent 执行开始前完成一次真实工具筛选。
|
||||
|
||||
这样后续多轮 `model -> tools -> model` 循环都只复用这一次结果,
|
||||
不会为每次模型回合重复追加一笔 selector LLM 开销。
|
||||
"""
|
||||
if "selected_tool_names" in state:
|
||||
return None
|
||||
|
||||
if not self._selection_tools or self.model is None:
|
||||
return ToolSelectionStateUpdate(selected_tool_names=None)
|
||||
|
||||
selection_request = ModelRequest(
|
||||
model=self.model,
|
||||
tools=list(self._selection_tools),
|
||||
messages=state["messages"],
|
||||
state=state,
|
||||
runtime=runtime,
|
||||
)
|
||||
modified_request = await self._aselect_request_once(selection_request)
|
||||
selected_tool_names = self._extract_selected_tool_names(modified_request)
|
||||
return ToolSelectionStateUpdate(
|
||||
selected_tool_names=selected_tool_names or None
|
||||
)
|
||||
|
||||
async def awrap_model_call(
|
||||
self,
|
||||
request: ModelRequest[ContextT],
|
||||
handler: Callable[
|
||||
[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]
|
||||
],
|
||||
) -> ModelResponse[ResponseT]:
|
||||
"""
|
||||
从 state 中读取首次筛选结果,并应用到每次模型回合。
|
||||
"""
|
||||
selected_tool_names = request.state.get("selected_tool_names") # noqa
|
||||
|
||||
# 正常路径下,`abefore_agent()` 已经提前写入状态;这里只保留一层兜底,
|
||||
# 兼容直接单测或未来某些绕过 before_agent 的调用场景。
|
||||
if selected_tool_names is None and self._selection_tools and self.model is not None:
|
||||
request = await self._aselect_request_once(request)
|
||||
selected_tool_names = self._extract_selected_tool_names(request) or None
|
||||
request.state["selected_tool_names"] = selected_tool_names # noqa
|
||||
|
||||
if selected_tool_names:
|
||||
request = self._apply_selected_tools(request, selected_tool_names)
|
||||
|
||||
return await handler(request)
|
||||
|
||||
@@ -56,10 +56,17 @@ class TestAgentSummarizationStreaming(unittest.TestCase):
|
||||
captured: dict = {}
|
||||
|
||||
class _FakeToolSelectorMiddleware:
|
||||
def __init__(self, model, max_tools, always_include=None):
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
max_tools,
|
||||
always_include=None,
|
||||
selection_tools=None,
|
||||
):
|
||||
self.model = model
|
||||
self.max_tools = max_tools
|
||||
self.always_include = always_include or []
|
||||
self.selection_tools = selection_tools or []
|
||||
|
||||
def _fake_create_agent(**kwargs):
|
||||
captured.update(kwargs)
|
||||
@@ -88,7 +95,7 @@ class TestAgentSummarizationStreaming(unittest.TestCase):
|
||||
),
|
||||
patch.object(
|
||||
agent_module,
|
||||
"LLMToolSelectorMiddleware",
|
||||
"MoviePilotToolSelectorMiddleware",
|
||||
_FakeToolSelectorMiddleware,
|
||||
),
|
||||
patch.object(agent_module, "create_agent", side_effect=_fake_create_agent),
|
||||
@@ -114,6 +121,7 @@ class TestAgentSummarizationStreaming(unittest.TestCase):
|
||||
"execute_command",
|
||||
],
|
||||
)
|
||||
self.assertEqual(tool_selector_middleware.selection_tools, fake_tools)
|
||||
|
||||
def test_non_streaming_agent_reuses_main_llm_for_summary(self):
|
||||
agent = agent_module.MoviePilotAgent(session_id="session-1", user_id="10001")
|
||||
|
||||
@@ -4,6 +4,7 @@ import sys
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from types import ModuleType, SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
@@ -22,6 +23,8 @@ sys.modules.pop("app.agent.middleware.tool_selection", None)
|
||||
_stub_module(
|
||||
"app.log",
|
||||
logger=SimpleNamespace(debug=lambda *args, **kwargs: None),
|
||||
log_settings=lambda *args, **kwargs: None,
|
||||
LogConfigModel=type("LogConfigModel", (), {}),
|
||||
)
|
||||
|
||||
module_path = (
|
||||
@@ -70,16 +73,20 @@ class _FakeModel:
|
||||
|
||||
|
||||
class _FakeRequest:
|
||||
def __init__(self, *, tools, messages, model):
|
||||
def __init__(self, *, tools, messages, model, state=None, runtime=None):
|
||||
self.tools = tools
|
||||
self.messages = messages
|
||||
self.model = model
|
||||
self.state = state if state is not None else {"messages": messages}
|
||||
self.runtime = runtime
|
||||
|
||||
def override(self, **kwargs):
|
||||
data = {
|
||||
"tools": self.tools,
|
||||
"messages": self.messages,
|
||||
"model": self.model,
|
||||
"state": self.state,
|
||||
"runtime": self.runtime,
|
||||
}
|
||||
data.update(kwargs)
|
||||
return _FakeRequest(**data)
|
||||
@@ -87,13 +94,17 @@ class _FakeRequest:
|
||||
|
||||
class ToolSelectorMiddlewareTest(unittest.TestCase):
|
||||
def test_awrap_model_call_uses_json_mode_for_deepseek(self):
|
||||
middleware = tool_selector_module.MoviePilotToolSelectorMiddleware(max_tools=2)
|
||||
tools = [
|
||||
SimpleNamespace(name="search", description="Search for information"),
|
||||
SimpleNamespace(name="calendar", description="Manage events"),
|
||||
SimpleNamespace(name="translate", description="Translate text"),
|
||||
]
|
||||
model = _FakeModel()
|
||||
middleware = tool_selector_module.MoviePilotToolSelectorMiddleware(
|
||||
max_tools=2,
|
||||
selection_tools=tools,
|
||||
)
|
||||
middleware.model = model
|
||||
request = _FakeRequest(
|
||||
tools=tools,
|
||||
messages=[HumanMessage(content="帮我安排明天的行程并查天气")],
|
||||
@@ -105,6 +116,11 @@ class ToolSelectorMiddlewareTest(unittest.TestCase):
|
||||
handled_requests.append(updated_request)
|
||||
return updated_request
|
||||
|
||||
state_update = asyncio.run(
|
||||
middleware.abefore_agent(request.state, runtime=None, config=None)
|
||||
)
|
||||
if state_update:
|
||||
request.state.update(state_update)
|
||||
result = asyncio.run(middleware.awrap_model_call(request, handler))
|
||||
|
||||
self.assertEqual(
|
||||
@@ -121,6 +137,108 @@ class ToolSelectorMiddlewareTest(unittest.TestCase):
|
||||
self.assertIn('- calendar: Manage events', prompt)
|
||||
self.assertEqual(len(handled_requests), 1)
|
||||
|
||||
def test_awrap_model_call_reuses_first_selection_for_later_model_rounds(self):
|
||||
tools = [
|
||||
SimpleNamespace(name="search", description="Search for information"),
|
||||
SimpleNamespace(name="calendar", description="Manage events"),
|
||||
SimpleNamespace(name="translate", description="Translate text"),
|
||||
]
|
||||
model = _FakeModel(content='{"tools": ["calendar", "search"]}')
|
||||
middleware = tool_selector_module.MoviePilotToolSelectorMiddleware(
|
||||
max_tools=2,
|
||||
selection_tools=tools,
|
||||
)
|
||||
middleware.model = model
|
||||
request = _FakeRequest(
|
||||
tools=tools,
|
||||
messages=[HumanMessage(content="帮我安排明天的行程并查天气")],
|
||||
model=model,
|
||||
)
|
||||
handled_requests = []
|
||||
|
||||
async def handler(updated_request):
|
||||
handled_requests.append(updated_request)
|
||||
return updated_request
|
||||
|
||||
state_update = asyncio.run(
|
||||
middleware.abefore_agent(request.state, runtime=None, config=None)
|
||||
)
|
||||
if state_update:
|
||||
request.state.update(state_update)
|
||||
first_result = asyncio.run(middleware.awrap_model_call(request, handler))
|
||||
second_result = asyncio.run(middleware.awrap_model_call(request, handler))
|
||||
|
||||
self.assertEqual(
|
||||
model.bind_calls,
|
||||
[{"response_format": {"type": "json_object"}}],
|
||||
)
|
||||
self.assertEqual(
|
||||
[tool.name for tool in first_result.tools],
|
||||
["search", "calendar"],
|
||||
)
|
||||
self.assertEqual(
|
||||
[tool.name for tool in second_result.tools],
|
||||
["search", "calendar"],
|
||||
)
|
||||
self.assertEqual(len(handled_requests), 2)
|
||||
|
||||
def test_awrap_model_call_caches_non_deepseek_selection_too(self):
|
||||
tools = [
|
||||
SimpleNamespace(name="search", description="Search for information"),
|
||||
SimpleNamespace(name="calendar", description="Manage events"),
|
||||
SimpleNamespace(name="translate", description="Translate text"),
|
||||
]
|
||||
model = _FakeModel(
|
||||
model_name="gpt-4o-mini",
|
||||
base_url="https://api.openai.com/v1",
|
||||
)
|
||||
middleware = tool_selector_module.MoviePilotToolSelectorMiddleware(
|
||||
max_tools=2,
|
||||
selection_tools=tools,
|
||||
)
|
||||
middleware.model = model
|
||||
request = _FakeRequest(
|
||||
tools=tools,
|
||||
messages=[HumanMessage(content="帮我安排明天的行程并查天气")],
|
||||
model=model,
|
||||
)
|
||||
|
||||
async def handler(updated_request):
|
||||
return updated_request
|
||||
|
||||
parent_calls = 0
|
||||
|
||||
async def _fake_parent_awrap(self, request_arg, handler_arg):
|
||||
nonlocal parent_calls
|
||||
parent_calls += 1
|
||||
selected_request = request_arg.override(
|
||||
tools=[request_arg.tools[1], request_arg.tools[0]]
|
||||
)
|
||||
return await handler_arg(selected_request)
|
||||
|
||||
with patch.object(
|
||||
tool_selector_module.LLMToolSelectorMiddleware,
|
||||
"awrap_model_call",
|
||||
_fake_parent_awrap,
|
||||
):
|
||||
state_update = asyncio.run(
|
||||
middleware.abefore_agent(request.state, runtime=None, config=None)
|
||||
)
|
||||
if state_update:
|
||||
request.state.update(state_update)
|
||||
first_result = asyncio.run(middleware.awrap_model_call(request, handler))
|
||||
second_result = asyncio.run(middleware.awrap_model_call(request, handler))
|
||||
|
||||
self.assertEqual(parent_calls, 1)
|
||||
self.assertEqual(
|
||||
[tool.name for tool in first_result.tools],
|
||||
["calendar", "search"],
|
||||
)
|
||||
self.assertEqual(
|
||||
[tool.name for tool in second_result.tools],
|
||||
["calendar", "search"],
|
||||
)
|
||||
|
||||
def test_normalize_selection_response_accepts_code_fence_json(self):
|
||||
middleware = tool_selector_module.MoviePilotToolSelectorMiddleware()
|
||||
response = SimpleNamespace(
|
||||
|
||||
Reference in New Issue
Block a user