mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-05-07 13:52:42 +08:00
- Refactor MoviePilotToolSelectorMiddleware to select tools once per agent execution and cache the result in state, avoiding a redundant LLM call on every model round.
- Add abefore_agent, which selects tools at the start of agent execution and stores the selected tool names in state.
- Update awrap_model_call to reuse the cached selection from state on subsequent model calls.
- Extend test coverage for tool-selection caching and reuse.
- Improve error logging in skill version extraction.
260 lines
8.5 KiB
Python
260 lines
8.5 KiB
Python
import asyncio
|
|
import importlib.util
|
|
import sys
|
|
import unittest
|
|
from pathlib import Path
|
|
from types import ModuleType, SimpleNamespace
|
|
from unittest.mock import patch
|
|
|
|
from langchain_core.messages import HumanMessage
|
|
|
|
|
|
def _stub_module(name: str, **attrs):
    """Fetch (or lazily create) the module registered as *name* and patch it.

    A missing entry, or one explicitly set to ``None``, is replaced by a fresh
    empty ``ModuleType`` that is registered in ``sys.modules``. Every keyword
    argument is then assigned as an attribute of that module, and the module
    is returned.
    """
    target = sys.modules.get(name)
    if target is None:
        # Register a brand-new empty module under the requested name.
        target = sys.modules[name] = ModuleType(name)

    for attr_name, attr_value in attrs.items():
        setattr(target, attr_name, attr_value)

    return target
|
|
|
|
|
|
# Drop any cached copy of the real middleware module.
# NOTE(review): presumably this guards against interference from an earlier
# import; the load below uses a different module name - TODO confirm intent.
sys.modules.pop("app.agent.middleware.tool_selection", None)


def _noop(*args, **kwargs):
    """Accept and ignore any arguments (stand-in for logging callables)."""
    return None


# Install a do-nothing ``app.log`` before loading the middleware, so it can be
# executed without the application's logging setup. All common logger levels
# are stubbed so a warning/error path in the middleware does not raise
# AttributeError inside the tests.
# NOTE(review): if the real ``app.log`` is already in ``sys.modules``,
# ``_stub_module`` overwrites these attributes on it for the whole process.
_stub_module(
    "app.log",
    logger=SimpleNamespace(
        debug=_noop,
        info=_noop,
        warning=_noop,
        error=_noop,
        exception=_noop,
    ),
    log_settings=_noop,
    LogConfigModel=type("LogConfigModel", (), {}),
)

# Load the middleware directly from its source file under a private module
# name; the resulting module object is not registered in ``sys.modules``.
module_path = (
    Path(__file__).resolve().parents[1]
    / "app"
    / "agent"
    / "middleware"
    / "tool_selection.py"
)
spec = importlib.util.spec_from_file_location("test_tool_selector_module", module_path)
# Validate the spec *before* using it: ``module_from_spec(None)`` would fail
# with an opaque AttributeError, and a bare ``assert`` is stripped under -O.
if spec is None or spec.loader is None:
    raise ImportError(f"Cannot load tool selection middleware from {module_path}")
tool_selector_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(tool_selector_module)
|
|
|
|
|
|
class _FakeBoundModel:
    """Test double for the object returned by ``_FakeModel.bind``.

    Both the sync and async entry points remember the most recent message list
    in ``messages`` and answer with an object whose ``content`` is the canned
    reply supplied at construction time.
    """

    def __init__(self, content):
        self.content = content
        self.messages = None  # last messages received; None until first call

    def _reply(self, messages):
        # Shared by invoke/ainvoke: record the input, return the canned answer.
        self.messages = messages
        return SimpleNamespace(content=self.content)

    def invoke(self, messages):
        return self._reply(messages)

    async def ainvoke(self, messages):
        return self._reply(messages)
|
|
|
|
|
|
class _FakeModel:
    """Chat-model double that records ``bind`` calls.

    ``model_name`` and ``openai_api_base`` are plain attributes; presumably the
    middleware inspects them to recognise the provider (the DeepSeek vs.
    OpenAI scenarios differ only in these values). Every ``bind(**kwargs)``
    call is appended to ``bind_calls`` and returns the same ``_FakeBoundModel``,
    which replies with *content*.
    """

    def __init__(
        self,
        *,
        content='{"tools": ["calendar", "search"]}',
        model_name="deepseek-reasoner",
        base_url="https://api.deepseek.com",
    ):
        self.bound_model = _FakeBoundModel(content)
        self.bind_calls = []
        self.openai_api_base = base_url
        self.model_name = model_name

    def bind(self, **kwargs):
        # Record the binding options so tests can assert on them later.
        self.bind_calls.append(kwargs)
        return self.bound_model
|
|
|
|
|
|
class _FakeRequest:
    """Minimal model-call request double used to drive the middleware.

    ``override`` never mutates the receiver: it builds a new ``_FakeRequest``
    from the current fields with the given keyword arguments substituted.
    Unreplaced fields (including ``state``) are shared by reference.
    """

    def __init__(self, *, tools, messages, model, state=None, runtime=None):
        self.tools = tools
        self.messages = messages
        self.model = model
        # Without an explicit state, start from one that carries the messages.
        self.state = {"messages": messages} if state is None else state
        self.runtime = runtime

    def override(self, **kwargs):
        current = {
            "tools": self.tools,
            "messages": self.messages,
            "model": self.model,
            "state": self.state,
            "runtime": self.runtime,
        }
        return _FakeRequest(**{**current, **kwargs})
|
|
|
|
|
|
class ToolSelectorMiddlewareTest(unittest.TestCase):
    """Tests for ``MoviePilotToolSelectorMiddleware``.

    Each caching scenario runs ``abefore_agent`` once, merges the returned
    state update into the request state, then drives ``awrap_model_call`` and
    checks which tools reach the downstream handler.
    """

    @staticmethod
    def _build_tools():
        """Return three fresh fake tools exposing ``name`` and ``description``."""
        return [
            SimpleNamespace(name="search", description="Search for information"),
            SimpleNamespace(name="calendar", description="Manage events"),
            SimpleNamespace(name="translate", description="Translate text"),
        ]

    @staticmethod
    def _build_middleware(model, tools):
        """Create a selector capped at two tools whose ``model`` is *model*."""
        selector = tool_selector_module.MoviePilotToolSelectorMiddleware(
            max_tools=2,
            selection_tools=tools,
        )
        selector.model = model
        return selector

    @staticmethod
    def _build_request(model, tools):
        """Create a request holding one user message and the default state."""
        return _FakeRequest(
            tools=tools,
            messages=[HumanMessage(content="帮我安排明天的行程并查天气")],
            model=model,
        )

    @staticmethod
    def _recording_handler(sink):
        """Return an async handler that appends each request to *sink* and echoes it."""

        async def handler(updated_request):
            sink.append(updated_request)
            return updated_request

        return handler

    @staticmethod
    def _run_before_agent(middleware, request):
        """Run ``abefore_agent`` and merge any non-empty update into ``request.state``."""
        state_update = asyncio.run(
            middleware.abefore_agent(request.state, runtime=None, config=None)
        )
        if state_update:
            request.state.update(state_update)

    @staticmethod
    def _tool_names(request):
        """Return the names of the tools attached to *request*, in order."""
        return [tool.name for tool in request.tools]

    def test_awrap_model_call_uses_json_mode_for_deepseek(self):
        tools = self._build_tools()
        model = _FakeModel()
        middleware = self._build_middleware(model, tools)
        request = self._build_request(model, tools)
        seen = []

        self._run_before_agent(middleware, request)
        outcome = asyncio.run(
            middleware.awrap_model_call(request, self._recording_handler(seen))
        )

        # The DeepSeek model is bound with JSON-object output exactly once.
        self.assertEqual(
            model.bind_calls,
            [{"response_format": {"type": "json_object"}}],
        )
        # The model answered ["calendar", "search"]; the result follows the
        # order in which the tools were registered.
        self.assertEqual(self._tool_names(outcome), ["search", "calendar"])
        selection_prompt = model.bound_model.messages[0]["content"]
        for expected_fragment in (
            "Return the answer in JSON only.",
            "- search: Search for information",
            "- calendar: Manage events",
        ):
            self.assertIn(expected_fragment, selection_prompt)
        self.assertEqual(len(seen), 1)

    def test_awrap_model_call_reuses_first_selection_for_later_model_rounds(self):
        tools = self._build_tools()
        model = _FakeModel(content='{"tools": ["calendar", "search"]}')
        middleware = self._build_middleware(model, tools)
        request = self._build_request(model, tools)
        seen = []
        handler = self._recording_handler(seen)

        self._run_before_agent(middleware, request)
        rounds = [
            asyncio.run(middleware.awrap_model_call(request, handler))
            for _ in range(2)
        ]

        # A single recorded bind() shows the second round did not re-select.
        self.assertEqual(
            model.bind_calls,
            [{"response_format": {"type": "json_object"}}],
        )
        for outcome in rounds:
            self.assertEqual(self._tool_names(outcome), ["search", "calendar"])
        self.assertEqual(len(seen), 2)

    def test_awrap_model_call_caches_non_deepseek_selection_too(self):
        tools = self._build_tools()
        model = _FakeModel(
            model_name="gpt-4o-mini",
            base_url="https://api.openai.com/v1",
        )
        middleware = self._build_middleware(model, tools)
        request = self._build_request(model, tools)
        handler = self._recording_handler([])
        parent_calls = 0

        async def _fake_parent_awrap(selector, request_arg, handler_arg):
            # Stand-in for the upstream selector: count invocations and pass
            # the first two tools on in swapped order.
            nonlocal parent_calls
            parent_calls += 1
            chosen = request_arg.override(
                tools=[request_arg.tools[1], request_arg.tools[0]]
            )
            return await handler_arg(chosen)

        with patch.object(
            tool_selector_module.LLMToolSelectorMiddleware,
            "awrap_model_call",
            _fake_parent_awrap,
        ):
            self._run_before_agent(middleware, request)
            rounds = [
                asyncio.run(middleware.awrap_model_call(request, handler))
                for _ in range(2)
            ]

        self.assertEqual(parent_calls, 1)
        for outcome in rounds:
            self.assertEqual(self._tool_names(outcome), ["calendar", "search"])

    def test_normalize_selection_response_accepts_code_fence_json(self):
        middleware = tool_selector_module.MoviePilotToolSelectorMiddleware()
        fenced_json = '```json\n{"tools": ["search"]}\n```'
        response = SimpleNamespace(content=[{"type": "text", "text": fenced_json}])

        self.assertEqual(
            middleware._normalize_selection_response(response),
            {"tools": ["search"]},
        )
|
|
|
|
|
|
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()
|