"""MoviePilot 自定义工具筛选中间件。""" import json from collections.abc import Awaitable, Callable from dataclasses import dataclass from typing import Annotated, Any, Literal, Union, NotRequired from langchain.agents.middleware.types import ( AgentMiddleware, AgentState, ContextT, ModelRequest, ModelResponse, ResponseT, ) from langchain.agents.middleware.types import ( PrivateStateAttr, # noqa ) from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.messages import HumanMessage from langchain_core.runnables import RunnableConfig from langchain_core.tools import BaseTool from langgraph.runtime import Runtime from pydantic import Field, TypeAdapter from typing_extensions import TypedDict # noqa from app.log import logger DEFAULT_SYSTEM_PROMPT = ( "Your goal is to select the most relevant tools for answering the user's query." ) @dataclass class _SelectionRequest: """Prepared inputs for tool selection.""" available_tools: list[BaseTool] system_message: str last_user_message: HumanMessage model: BaseChatModel valid_tool_names: list[str] def _create_tool_selection_response(tools: list[BaseTool]) -> TypeAdapter[Any]: """Create a structured output schema for tool selection. Args: tools: Available tools to include in the schema. Returns: `TypeAdapter` for a schema where each tool name is a `Literal` with its description. Raises: AssertionError: If `tools` is empty. """ if not tools: msg = "Invalid usage: tools must be non-empty" raise AssertionError(msg) # Create a Union of Annotated Literal types for each tool name with description # For instance: Union[Annotated[Literal["tool1"], Field(description="...")], ...] literals = [ Annotated[Literal[tool.name], Field(description=tool.description)] for tool in tools # noqa ] selected_tool_type = Union[tuple(literals)] # type: ignore[valid-type] # noqa: UP007 description = "Tools to use. Place the most relevant tools first." class ToolSelectionResponse(TypedDict): """Use to select relevant tools.""" tools: Annotated[list[selected_tool_type], Field(description=description)] # type: ignore[valid-type] return TypeAdapter(ToolSelectionResponse) def _render_tool_list(tools: list[BaseTool]) -> str: """Format tools as markdown list. Args: tools: Tools to format. Returns: Markdown string with each tool on a new line. """ return "\n".join(f"- {tool.name}: {tool.description}" for tool in tools) class ToolSelectionState(AgentState): """工具筛选中间件私有状态。""" selected_tool_names: NotRequired[Annotated[list[str] | None, PrivateStateAttr]] """当前这条用户请求首轮筛选得到的工具名列表。""" class ToolSelectionStateUpdate(TypedDict): """工具筛选中间件状态更新项。""" selected_tool_names: list[str] | None class ToolSelectorMiddleware( AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT] ): """ 为 DeepSeek 兼容端点提供更稳妥的工具筛选实现。 LangChain 默认会通过 `with_structured_output()` 走 OpenAI 的 `response_format=json_schema` 路径,但 DeepSeek 官方 OpenAI 兼容端点公开文档 仅保证 `json_object` 模式可用。对于 `deepseek-reasoner`,这会在工具筛选阶段 提前触发 400,导致 Agent 还没真正开始执行工具就失败。 因此这里仅在识别到 DeepSeek 模型/端点时,退回到显式 JSON 输出模式: 1. 使用 `response_format={"type": "json_object"}`; 2. 在提示词中明确约束返回 JSON 结构; 3. 手动解析 `{"tools": [...]}`,其余模型继续沿用 LangChain 默认实现。 另外,LangChain 原生工具筛选挂在 `wrap_model_call` 上,会在同一条用户请求 的每次“模型回合”前都重新筛选一次工具。对于会多轮调用工具的复杂任务, 这会重复消耗一次额外的 LLM 调用。这里改成: - `abefore_agent()`:在本轮 Agent 执行开始时筛选一次; - `awrap_model_call()`:从 `request.state` 读取首轮筛选结果并复用。 """ state_schema = ToolSelectionState def __init__( self, model: BaseChatModel, system_prompt: str = DEFAULT_SYSTEM_PROMPT, selection_tools: list[Any] | None = None, max_tools: int | None = None, always_include: list[str] | None = None, ) -> None: super().__init__() self.model = model self.system_prompt = system_prompt self.max_tools = max_tools self.always_include = always_include or [] self.selection_tools = selection_tools or [] def _prepare_selection_request( self, request: ModelRequest[ContextT] ) -> _SelectionRequest | None: """Prepare inputs for tool selection. Args: request: the model request. Returns: `SelectionRequest` with prepared inputs, or `None` if no selection is needed. Raises: ValueError: If tools in `always_include` are not found in the request. AssertionError: If no user message is found in the request messages. """ # If no tools available, return None if not request.tools or len(request.tools) == 0: return None # Filter to only BaseTool instances (exclude provider-specific tool dicts) base_tools = [tool for tool in request.tools if not isinstance(tool, dict)] # Validate that always_include tools exist if self.always_include: available_tool_names = {tool.name for tool in base_tools} missing_tools = [ name for name in self.always_include if name not in available_tool_names ] if missing_tools: msg = ( f"Tools in always_include not found in request: {missing_tools}. " f"Available tools: {sorted(available_tool_names)}" ) raise ValueError(msg) # Separate tools that are always included from those available for selection available_tools = [ tool for tool in base_tools if tool.name not in self.always_include ] # If no tools available for selection, return None if not available_tools: return None system_message = self.system_prompt # If there's a max_tools limit, append instructions to the system prompt if self.max_tools is not None: system_message += ( f"\nIMPORTANT: List the tool names in order of relevance, " f"with the most relevant first. " f"If you exceed the maximum number of tools, " f"only the first {self.max_tools} will be used." ) # Get the last user message from the conversation history last_user_message: HumanMessage for message in reversed(request.messages): if isinstance(message, HumanMessage): last_user_message = message break else: msg = "No user message found in request messages" raise AssertionError(msg) model = self.model or request.model valid_tool_names = [tool.name for tool in available_tools] return _SelectionRequest( available_tools=available_tools, system_message=system_message, last_user_message=last_user_message, model=model, valid_tool_names=valid_tool_names, ) def _process_selection_response( self, response: dict[str, Any], available_tools: list[BaseTool], valid_tool_names: list[str], request: ModelRequest[ContextT], ) -> ModelRequest[ContextT]: """Process the selection response and return filtered `ModelRequest`.""" selected_tool_names: list[str] = [] invalid_tool_selections = [] for tool_name in response["tools"]: if tool_name not in valid_tool_names: invalid_tool_selections.append(tool_name) continue # Only add if not already selected and within max_tools limit if tool_name not in selected_tool_names and ( self.max_tools is None or len(selected_tool_names) < self.max_tools ): selected_tool_names.append(tool_name) if invalid_tool_selections: msg = f"Model selected invalid tools: {invalid_tool_selections}" raise ValueError(msg) # Filter tools based on selection and append always-included tools if selected_tool_names: selected_tools: list[BaseTool] = [ tool for tool in available_tools if tool.name in selected_tool_names ] else: # 如果模型筛选结果为空,则不对工具进行裁剪,使用所有可用工具 logger.warning("工具筛选结果为空,将恢复使用所有工具。") selected_tools = available_tools always_included_tools: list[BaseTool] = [ tool for tool in request.tools if not isinstance(tool, dict) and tool.name in self.always_include ] selected_tools.extend(always_included_tools) # Also preserve any provider-specific tool dicts from the original request provider_tools = [tool for tool in request.tools if isinstance(tool, dict)] return request.override(tools=[*selected_tools, *provider_tools]) @staticmethod def _is_deepseek_compatible_model(model: BaseChatModel) -> bool: """ 判断当前模型是否应当走 DeepSeek JSON 兼容分支。 除了官方 `langchain_deepseek`,用户也可能通过 OpenAI-compatible 配置把 DeepSeek 端点接到 `ChatOpenAI`。因此这里同时检查模块名、模型名 和 Base URL,避免只靠单一条件漏判。 """ module_name = type(model).__module__.lower() model_name = ( str(getattr(model, "model_name", "") or getattr(model, "model", "")) .strip() .lower() ) base_url = ( str(getattr(model, "openai_api_base", "") or getattr(model, "api_base", "")) .strip() .lower() ) return ( "deepseek" in module_name or model_name.startswith("deepseek-") or "api.deepseek.com" in base_url ) @staticmethod def _extract_text_content(content: Any) -> str: """ 从模型响应中提取纯文本。 这里不依赖上层 LLMHelper,避免中间件与 LLM 构造逻辑互相耦合。 """ if content is None: return "" if isinstance(content, str): return content if isinstance(content, list): text_parts: list[str] = [] for block in content: if isinstance(block, str): text_parts.append(block) continue if isinstance(block, dict): if block.get("type") == "text" and isinstance( block.get("text"), str ): text_parts.append(block["text"]) continue if not block.get("type") and isinstance(block.get("text"), str): text_parts.append(block["text"]) return "".join(text_parts) if isinstance(content, dict): if content.get("type") == "text" and isinstance(content.get("text"), str): return content["text"] if not content.get("type") and isinstance(content.get("text"), str): return content["text"] return "" @staticmethod def _parse_json_object(text: str) -> dict[str, Any]: """ 解析模型返回的 JSON。 DeepSeek 在 JSON 模式下通常会返回纯 JSON,但这里仍做一层兜底, 兼容模型偶发输出围栏或前后说明文本的情况。 """ stripped_text = text.strip() if not stripped_text: raise ValueError("工具筛选返回了空响应") try: payload = json.loads(stripped_text) if isinstance(payload, dict): return payload except json.JSONDecodeError: pass start = stripped_text.find("{") end = stripped_text.rfind("}") if start == -1 or end == -1 or end <= start: raise ValueError(f"工具筛选返回的内容不是合法 JSON: {stripped_text}") payload = json.loads(stripped_text[start: end + 1]) if not isinstance(payload, dict): raise ValueError("工具筛选 JSON 顶层必须是对象") return payload @staticmethod def _render_tool_list(available_tools: list[Any]) -> str: """把工具名和描述渲染成稳定的文本列表。""" return "\n".join( f"- {tool.name}: {tool.description}" for tool in available_tools ) def _build_deepseek_selection_prompt(self, selection_request: Any) -> str: """ 为 DeepSeek 生成显式 JSON 输出提示。 DeepSeek 官方文档要求在 JSON 输出模式下,提示词中必须明确包含 JSON 约束,否则兼容端点可能返回空内容或无意义输出。 """ limit_instruction = "" if self.max_tools: limit_instruction = f"- Select up to {self.max_tools} tools. IF NO TOOLS ARE RELEVANT, DO NOT RETURN AN EMPTY ARRAY. SELECT THE MOST APPLICABLE ONES TO ENSURE THE REQUEST IS HANDLED." return ( f"{selection_request.system_message}\n\n" "Return the answer in JSON only.\n" 'Use exactly this shape: {"tools": ["tool_name_1", "tool_name_2"]}\n' "Rules:\n" "- The `tools` field must be a JSON array of strings.\n" "- Only use tool names from the allowed list below.\n" "- Order tools by relevance, with the most relevant first.\n" f"{limit_instruction}\n" "- Do not add explanations, markdown, or extra keys.\n\n" "Allowed tools:\n" f"{self._render_tool_list(selection_request.available_tools)}" ) def _normalize_selection_response(self, response: Any) -> dict[str, list[str]]: """ 解析并标准化 DeepSeek JSON 模式的工具筛选结果。 """ content = getattr(response, "content", response) text = self._extract_text_content(content) logger.debug(f"工具筛选原始响应: {text}") payload = self._parse_json_object(text) tools = payload.get("tools") if not isinstance(tools, list): raise ValueError(f"工具筛选 JSON 缺少 `tools` 数组: {payload}") normalized_tools = [ tool_name for tool_name in tools if isinstance(tool_name, str) ] logger.debug(f"工具筛选标准化结果: {normalized_tools}") return {"tools": normalized_tools} async def _aselect_tools_with_deepseek( self, selection_request: Any ) -> dict[str, list[str]]: """ 使用 DeepSeek 兼容的 JSON 输出模式执行异步工具筛选。 """ logger.debug("工具筛选走 DeepSeek JSON 兼容分支") structured_model = selection_request.model.bind( response_format={"type": "json_object"} ) response = await structured_model.ainvoke( [ { "role": "system", "content": self._build_deepseek_selection_prompt(selection_request), }, selection_request.last_user_message, ] ) return self._normalize_selection_response(response) @staticmethod def _extract_selected_tool_names(request: ModelRequest) -> list[str]: """从已筛选后的请求中提取最终工具名,保留原有顺序。""" return [tool.name for tool in request.tools if not isinstance(tool, dict)] @staticmethod def _apply_selected_tools( request: ModelRequest[ContextT], selected_tool_names: list[str], ) -> ModelRequest[ContextT]: """ 将已筛选出的工具集应用到当前模型请求。 这里只复用首次筛选出的客户端工具名;provider-specific 的 dict 工具仍然 原样保留,避免破坏 LangChain/provider 自身的工具绑定约定。 """ if not selected_tool_names: return request current_tools_by_name = { tool.name: tool for tool in request.tools if not isinstance(tool, dict) } selected_tools = [ current_tools_by_name[tool_name] for tool_name in selected_tool_names if tool_name in current_tools_by_name ] provider_tools = [tool for tool in request.tools if isinstance(tool, dict)] return request.override(tools=[*selected_tools, *provider_tools]) async def _aselect_request_once( self, request: ModelRequest[ContextT] ) -> ModelRequest[ContextT]: """ 执行一次真实工具筛选,并返回筛选后的请求对象。 这里单独抽成 helper,便于首次筛选后缓存结果,也便于测试覆盖 “首轮筛选,后续复用”的行为。 """ selection_request = self._prepare_selection_request(request) if selection_request is None: return request if not self._is_deepseek_compatible_model(selection_request.model): captured_request: ModelRequest[ContextT] = request async def _capture_handler( updated_request: ModelRequest[ContextT], ) -> ModelRequest[ContextT]: nonlocal captured_request captured_request = updated_request return updated_request await super().awrap_model_call(request, _capture_handler) return captured_request response = await self._aselect_tools_with_deepseek(selection_request) return self._process_selection_response( response, selection_request.available_tools, selection_request.valid_tool_names, request, ) async def abefore_agent( # noqa self, state: ToolSelectionState, runtime: Runtime, # noqa config: RunnableConfig, ) -> ToolSelectionStateUpdate | None: # ty: ignore[invalid-method-override] """ 在本轮 Agent 执行开始前完成一次真实工具筛选。 这样后续多轮 `model -> tools -> model` 循环都只复用这一次结果, 不会为每次模型回合重复追加一笔 selector LLM 开销。 """ if "selected_tool_names" in state: return None if not self.selection_tools or self.model is None: return ToolSelectionStateUpdate(selected_tool_names=None) selection_request = ModelRequest( model=self.model, tools=list(self.selection_tools), messages=state["messages"], state=state, runtime=runtime, ) modified_request = await self._aselect_request_once(selection_request) selected_tool_names = self._extract_selected_tool_names(modified_request) return ToolSelectionStateUpdate(selected_tool_names=selected_tool_names or None) async def awrap_model_call( self, request: ModelRequest[ContextT], handler: Callable[ [ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]] ], ) -> ModelResponse[ResponseT]: """ 从 state 中读取首次筛选结果,并应用到每次模型回合。 """ selected_tool_names = request.state.get("selected_tool_names") # noqa # 正常路径下,`abefore_agent()` 已经提前写入状态;这里只保留一层兜底, # 兼容直接单测或未来某些绕过 before_agent 的调用场景。 if ( selected_tool_names is None and self.selection_tools and self.model is not None ): request = await self._aselect_request_once(request) selected_tool_names = self._extract_selected_tool_names(request) or None request.state["selected_tool_names"] = selected_tool_names # noqa if selected_tool_names: request = self._apply_selected_tools(request, selected_tool_names) return await handler(request)