"""MoviePilot 自定义工具筛选中间件。""" from dataclasses import dataclass, replace import json from collections.abc import Awaitable, Callable from typing import Annotated, Any, NotRequired from langchain.agents.middleware.types import ( AgentState, ContextT, ModelRequest, ModelResponse, ResponseT, ) from langchain.agents.middleware.types import ( PrivateStateAttr, # noqa ) from langchain.agents.middleware.tool_selection import ( DEFAULT_SYSTEM_PROMPT, LLMToolSelectorMiddleware, ) from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.messages import AIMessage, HumanMessage, SystemMessage from langchain_core.runnables import RunnableConfig from langchain_core.tools import BaseTool from langgraph.runtime import Runtime from typing_extensions import TypedDict # noqa from app.agent.llm import LLMHelper from app.agent.tools.tags import ToolTag from app.log import logger MIN_SELECTED_TOOL_COUNT = 4 RECENT_SELECTION_CONTEXT_MESSAGE_LIMIT = 6 RECENT_SELECTION_CONTEXT_MAX_CHARS = 6000 RECENT_SELECTION_CONTEXT_TRUNCATION_PREFIX = "..." TOOL_GROUP_EXCLUDED_TAGS = frozenset( { ToolTag.AgentTool.value, ToolTag.Read.value, ToolTag.Write.value, ToolTag.Admin.value, ToolTag.Message.value, ToolTag.UserInteraction.value, ToolTag.TerminalResponse.value, } ) MOVIEPILOT_TOOL_SELECTION_HINT = """ MoviePilot tool-chain hints: - Tools with the same capability tag belong to the same functional group. - For multi-step MoviePilot tasks, keep same-tag tools together when relevant. - Prefer selecting likely next-step tools in the same capability group instead of selecting only the first tool. """ class ToolSelectionState(AgentState): """工具筛选中间件私有状态。""" selected_tool_names: NotRequired[Annotated[list[str] | None, PrivateStateAttr]] """当前这条用户请求首轮筛选得到的工具名列表。""" class ToolSelectionStateUpdate(TypedDict): """工具筛选中间件状态更新项。""" selected_tool_names: list[str] | None @dataclass(frozen=True) class _ToolSelectionAttempt: """工具筛选尝试结果,用于统一记录最终日志。""" request: ModelRequest selected_tool_names: list[str] status: str detail: str = "" class ToolSelectorMiddleware(LLMToolSelectorMiddleware): """ 使用 provider-neutral JSON 提示执行工具筛选。 LangChain 默认会通过 `with_structured_output()` 走 provider-specific 的 结构化输出能力,不同 OpenAI/Anthropic 兼容端点对 `response_format`、 JSON schema 和工具绑定的支持并不一致。工具筛选只是 Agent 执行前的 辅助优化,失败时也会恢复使用全部工具,因此这里统一使用文本提示约束 模型返回 `{"tools": [...]}` 并手动解析,避免在筛选阶段引入额外兼容分支。 另外,LangChain 原生工具筛选挂在 `wrap_model_call` 上,会在同一条用户请求 的每次“模型回合”前都重新筛选一次工具。对于会多轮调用工具的复杂任务, 这会重复消耗一次额外的 LLM 调用。这里改成: - `abefore_agent()`:在本轮 Agent 执行开始时筛选一次; - `awrap_model_call()`:从 `request.state` 读取首轮筛选结果并复用。 """ state_schema = ToolSelectionState def __init__( self, model: BaseChatModel | str | None = None, system_prompt: str = DEFAULT_SYSTEM_PROMPT, selection_tools: list[Any] | None = None, max_tools: int | None = None, always_include: list[str] | None = None, ) -> None: super().__init__( model=model, system_prompt=self._append_tool_selection_hint(system_prompt), max_tools=max_tools, always_include=always_include, ) self.selection_tools = selection_tools or [] @classmethod def _render_recent_conversation_context( cls, messages: list[Any], ) -> tuple[str, int]: """渲染最近对话上下文,供工具筛选模型理解多轮追问。""" rendered_messages = [] for message in messages: if isinstance(message, HumanMessage): role = "User" elif isinstance(message, AIMessage): role = "Assistant" else: continue content = LLMHelper.extract_text_content(message.content).strip() if not content: continue rendered_messages.append(f"{role}: {content}") recent_messages = rendered_messages[-RECENT_SELECTION_CONTEXT_MESSAGE_LIMIT:] context = "\n\n".join(recent_messages) if len(context) > RECENT_SELECTION_CONTEXT_MAX_CHARS: context = ( f"{RECENT_SELECTION_CONTEXT_TRUNCATION_PREFIX}" f"{context[-RECENT_SELECTION_CONTEXT_MAX_CHARS:]}" ) return context, len(recent_messages) @classmethod def _build_contextual_user_message( cls, messages: list[Any], last_user_message: HumanMessage, ) -> HumanMessage: """根据最近对话构造工具筛选专用用户消息。""" context, message_count = cls._render_recent_conversation_context(messages) if message_count <= 1: return last_user_message return HumanMessage( content=( "Recent conversation context for tool selection:\n" f"{context}\n\n" "Select tools for the latest user instruction. Use prior assistant " "messages and earlier user requests when the latest user message " "depends on previous context." ) ) def _prepare_selection_request( self, request: ModelRequest[ContextT], ) -> Any | None: """准备带最近对话上下文的工具筛选请求。""" selection_request = super()._prepare_selection_request(request) if selection_request is None: return None contextual_user_message = self._build_contextual_user_message( messages=request.messages, last_user_message=selection_request.last_user_message, ) if contextual_user_message is selection_request.last_user_message: return selection_request return replace(selection_request, last_user_message=contextual_user_message) @staticmethod def _append_tool_selection_hint(system_prompt: str) -> str: """追加 MoviePilot 工具组选择提示,避免复杂链路只选中首个工具。""" if "MoviePilot tool-chain hints:" in system_prompt: return system_prompt return f"{system_prompt.rstrip()}{MOVIEPILOT_TOOL_SELECTION_HINT}" def _get_tool_selection_limit(self, valid_tool_names: list[str]) -> int: """计算补齐筛选结果时允许使用的工具数量上限。""" if self.max_tools: return min(self.max_tools, len(valid_tool_names)) return len(valid_tool_names) @staticmethod def _normalize_tool_tags(tool: BaseTool) -> list[str]: """读取工具的业务标签,过滤掉无法表达工具组的通用标签。""" tags = getattr(tool, "tags", None) or [] if isinstance(tags, str): tags = [tags] normalized_tags = [] for tag in tags: tag_value = getattr(tag, "value", tag) if not tag_value: continue tag_name = str(tag_value) if tag_name in TOOL_GROUP_EXCLUDED_TAGS or tag_name in normalized_tags: continue normalized_tags.append(tag_name) return normalized_tags @classmethod def _build_tool_groups( cls, available_tools: list[BaseTool], valid_tool_names: list[str], ) -> list[tuple[str, list[str]]]: """根据工具标签构造能力组,保留当前工具列表中的稳定顺序。""" valid_tool_set = set(valid_tool_names) tool_groups: dict[str, list[str]] = {} for tool in available_tools: tool_name = getattr(tool, "name", None) if not tool_name or tool_name not in valid_tool_set: continue for tag in cls._normalize_tool_tags(tool): group_tool_names = tool_groups.setdefault(tag, []) if tool_name not in group_tool_names: group_tool_names.append(tool_name) return [ (tag, tool_names) for tag, tool_names in tool_groups.items() if len(tool_names) > 1 ] @classmethod def _get_matched_tool_groups( cls, selected_names: list[str], available_tools: list[BaseTool], valid_tool_names: list[str], ) -> list[tuple[str, list[str]]]: """返回已选工具命中的标签能力组。""" groups_by_tag = { tag: tool_names for tag, tool_names in cls._build_tool_groups( available_tools=available_tools, valid_tool_names=valid_tool_names, ) } tools_by_name = { tool.name: tool for tool in available_tools if getattr(tool, "name", None) } matched_groups: list[tuple[str, list[str]]] = [] seen_tags = set() for tool_name in selected_names: tool = tools_by_name.get(tool_name) if not tool: continue for tag in cls._normalize_tool_tags(tool): if tag in seen_tags or tag not in groups_by_tag: continue matched_groups.append((tag, groups_by_tag[tag])) seen_tags.add(tag) return matched_groups def _complete_low_count_selection( self, selected_tool_names: list[str], valid_tool_names: list[str], available_tools: list[BaseTool], ) -> list[str]: """ 当模型只选出极少工具时,按工具标签补齐同组工具。 工具标签是工具自身声明的能力归属。这里只补齐已经命中的标签组, 不会把所有工具组都展开。 """ limit = self._get_tool_selection_limit(valid_tool_names) selected_names = [ tool_name for tool_name in selected_tool_names if tool_name in valid_tool_names ] selected_set = set(selected_names) valid_tool_set = set(valid_tool_names) completed_names = list(selected_names) matched_groups = self._get_matched_tool_groups( selected_names=selected_names, available_tools=available_tools, valid_tool_names=valid_tool_names, ) if not matched_groups: return completed_names[:limit] matched_group_tool_names = { tool_name for _, group_tool_names in matched_groups for tool_name in group_tool_names } target_count = min( max(MIN_SELECTED_TOOL_COUNT, len(matched_group_tool_names)), limit, ) if len(selected_names) >= target_count: return selected_names[:limit] for _, group_tool_names in matched_groups: for tool_name in group_tool_names: if tool_name in selected_set or tool_name not in valid_tool_set: continue completed_names.append(tool_name) selected_set.add(tool_name) if len(completed_names) >= target_count: return completed_names[:limit] return completed_names[:limit] def _process_selection_response( self, response: dict[str, Any], available_tools: list[BaseTool], valid_tool_names: list[str], request: ModelRequest[ContextT], ) -> ModelRequest[ContextT]: """ 处理工具筛选响应,并保留空结果回退所有工具的 MoviePilot 策略。 """ if response.get("tools") == []: always_included_tools: list[BaseTool] = [ tool for tool in request.tools if not isinstance(tool, dict) and tool.name in self.always_include ] provider_tools = [tool for tool in request.tools if isinstance(tool, dict)] return request.override( tools=[*available_tools, *always_included_tools, *provider_tools] ) response["tools"] = self._complete_low_count_selection( selected_tool_names=[ tool_name for tool_name in response.get("tools", []) if isinstance(tool_name, str) ], valid_tool_names=valid_tool_names, available_tools=available_tools, ) modified_request = super()._process_selection_response( response, available_tools, valid_tool_names, request, ) return modified_request @staticmethod def _parse_json_object(text: str) -> dict[str, Any]: """ 解析模型返回的 JSON。 不同模型可能偶发输出 Markdown 围栏或前后说明文本,因此这里从 响应中提取第一个 JSON 对象作为兜底。 """ stripped_text = text.strip() if not stripped_text: raise ValueError("工具筛选返回了空响应") try: payload = json.loads(stripped_text) if isinstance(payload, dict): return payload except json.JSONDecodeError: pass start = stripped_text.find("{") end = stripped_text.rfind("}") if start == -1 or end == -1 or end <= start: raise ValueError(f"工具筛选返回的内容不是合法 JSON: {stripped_text}") payload = json.loads(stripped_text[start: end + 1]) if not isinstance(payload, dict): raise ValueError("工具筛选 JSON 顶层必须是对象") return payload @classmethod def _render_tool_list(cls, available_tools: list[Any]) -> str: """把工具名和描述渲染成稳定的文本列表。""" lines = [] for tool in available_tools: tags = cls._normalize_tool_tags(tool) tag_text = f" [group tags: {', '.join(tags)}]" if tags else "" lines.append(f"- {tool.name}{tag_text}: {tool.description}") return "\n".join(lines) @classmethod def _render_tool_groups(cls, available_tools: list[BaseTool]) -> str: """把当前可用工具按标签渲染成能力组提示。""" valid_tool_names = [ tool.name for tool in available_tools if getattr(tool, "name", None) ] groups = cls._build_tool_groups( available_tools=available_tools, valid_tool_names=valid_tool_names, ) if not groups: return "" rendered_groups = "\n".join( f"- {tag}: {', '.join(tool_names)}" for tag, tool_names in groups ) return f"Capability groups from tool tags:\n{rendered_groups}\n\n" def _build_json_selection_prompt(self, selection_request: Any) -> str: """ 生成显式 JSON 输出提示。 使用纯提示约束可覆盖更多兼容端点,避免在工具筛选阶段依赖某个 provider 专属的 `response_format` 或 schema 能力。 """ limit_instruction = "" if self.max_tools: limit_instruction = f"- Select up to {self.max_tools} tools. IF NO TOOLS ARE RELEVANT, DO NOT RETURN AN EMPTY ARRAY. SELECT THE MOST APPLICABLE ONES TO ENSURE THE REQUEST IS HANDLED." return ( f"{selection_request.system_message}\n\n" "Return the answer in JSON only.\n" 'Use exactly this shape: {"tools": ["tool_name_1", "tool_name_2"]}\n' "Rules:\n" "- The `tools` field must be a JSON array of strings.\n" "- Only use tool names from the allowed list below.\n" "- Order tools by relevance, with the most relevant first.\n" "- Tools sharing the same capability tag are in the same group; include same-group tools together when relevant.\n" f"{limit_instruction}\n" "- Do not add explanations, markdown, or extra keys.\n\n" f"{self._render_tool_groups(selection_request.available_tools)}" "Allowed tools:\n" f"{self._render_tool_list(selection_request.available_tools)}" ) def _normalize_selection_response(self, response: Any) -> dict[str, list[str]]: """ 解析并标准化显式 JSON 模式的工具筛选结果。 """ content = getattr(response, "content", response) text = LLMHelper.extract_text_content(content) logger.debug(f"工具筛选原始响应: {text}") payload = self._parse_json_object(text) tools = payload.get("tools") if not isinstance(tools, list): raise ValueError(f"工具筛选 JSON 缺少 `tools` 数组: {payload}") normalized_tools = [ tool_name for tool_name in tools if isinstance(tool_name, str) ] logger.debug(f"工具筛选标准化结果: {normalized_tools}") return {"tools": normalized_tools} async def _aselect_tools_with_json_prompt( self, selection_request: Any ) -> dict[str, list[str]]: """ 使用 JSON 提示执行异步工具筛选。 :param selection_request: LangChain 工具筛选请求 :return: 标准化后的工具名列表 """ logger.debug("工具筛选走 JSON 提示分支") response = await selection_request.model.ainvoke( [ SystemMessage( content=self._build_json_selection_prompt(selection_request) ), selection_request.last_user_message, ] ) return self._normalize_selection_response(response) @staticmethod def _extract_selected_tool_names(request: ModelRequest) -> list[str]: """从已筛选后的请求中提取最终工具名,保留原有顺序。""" return [tool.name for tool in request.tools if not isinstance(tool, dict)] @staticmethod def _count_request_tools(request: ModelRequest) -> int: """统计当前请求中的 LangChain 工具数量,不包含 provider 原生工具字典。""" return len([tool for tool in request.tools if not isinstance(tool, dict)]) @classmethod def _log_selection_attempt(cls, attempt: _ToolSelectionAttempt) -> None: """按工具筛选最终状态记录稳定日志。""" tool_count = cls._count_request_tools(attempt.request) if attempt.status == "selected": selected_text = ", ".join(attempt.selected_tool_names) or "无有效工具" logger.info(f"工具筛选结果: {selected_text}") return if attempt.status == "empty_fallback": logger.info(f"工具筛选结果为空,将恢复使用所有工具(共 {tool_count} 个)。") return if attempt.status == "failed_fallback": logger.warning( f"工具筛选失败,将恢复使用所有工具(共 {tool_count} 个): {attempt.detail}" ) return if attempt.status == "skipped": logger.info(f"工具筛选跳过: {attempt.detail}。") return if attempt.status == "reused": selected_text = ", ".join(attempt.selected_tool_names) or "无有效工具" logger.info(f"工具筛选复用已有结果: {selected_text}") @staticmethod def _apply_selected_tools( request: ModelRequest[ContextT], selected_tool_names: list[str], ) -> ModelRequest[ContextT]: """ 将已筛选出的工具集应用到当前模型请求。 这里只复用首次筛选出的客户端工具名;provider-specific 的 dict 工具仍然 原样保留,避免破坏 LangChain/provider 自身的工具绑定约定。 """ if not selected_tool_names: return request current_tools_by_name = { tool.name: tool for tool in request.tools if not isinstance(tool, dict) } selected_tools = [ current_tools_by_name[tool_name] for tool_name in selected_tool_names if tool_name in current_tools_by_name ] provider_tools = [tool for tool in request.tools if isinstance(tool, dict)] return request.override(tools=[*selected_tools, *provider_tools]) async def _aselect_request_once( self, request: ModelRequest[ContextT] ) -> ModelRequest[ContextT]: """ 执行一次真实工具筛选,并返回筛选后的请求对象。 这里单独抽成 helper,便于首次筛选后缓存结果,也便于测试覆盖 “首轮筛选,后续复用”的行为。 """ return (await self._aselect_request_once_with_status(request)).request async def _aselect_request_once_with_status( self, request: ModelRequest[ContextT] ) -> _ToolSelectionAttempt: """ 执行一次真实工具筛选,并携带最终状态供调用方统一记录日志。 """ selection_request = self._prepare_selection_request(request) if selection_request is None: return _ToolSelectionAttempt( request=request, selected_tool_names=self._extract_selected_tool_names(request), status="skipped", detail="没有需要筛选的工具", ) try: response = await self._aselect_tools_with_json_prompt(selection_request) modified_request = self._process_selection_response( response, selection_request.available_tools, selection_request.valid_tool_names, request, ) status = ( "empty_fallback" if response.get("tools") == [] else "selected" ) return _ToolSelectionAttempt( request=modified_request, selected_tool_names=self._extract_selected_tool_names(modified_request), status=status, ) except Exception as err: return _ToolSelectionAttempt( request=request, selected_tool_names=self._extract_selected_tool_names(request), status="failed_fallback", detail=str(err), ) async def abefore_agent( # noqa self, state: ToolSelectionState, runtime: Runtime, # noqa config: RunnableConfig, ) -> ToolSelectionStateUpdate | None: # ty: ignore[invalid-method-override] """ 在本轮 Agent 执行开始前完成一次真实工具筛选。 这样后续多轮 `model -> tools -> model` 循环都只复用这一次结果, 不会为每次模型回合重复追加一笔 selector LLM 开销。 """ if "selected_tool_names" in state: self._log_selection_attempt( _ToolSelectionAttempt( request=ModelRequest( model=self.model, tools=list(self.selection_tools), messages=state["messages"], state=state, runtime=runtime, ), selected_tool_names=state.get("selected_tool_names") or [], status="reused", ) ) return None if not self.selection_tools or self.model is None: detail = "没有可筛选工具" if not self.selection_tools else "未配置筛选模型" self._log_selection_attempt( _ToolSelectionAttempt( request=ModelRequest( model=self.model, tools=list(self.selection_tools), messages=state["messages"], state=state, runtime=runtime, ), selected_tool_names=[], status="skipped", detail=detail, ) ) return ToolSelectionStateUpdate(selected_tool_names=None) selection_request = ModelRequest( model=self.model, tools=list(self.selection_tools), messages=state["messages"], state=state, runtime=runtime, ) attempt = await self._aselect_request_once_with_status(selection_request) self._log_selection_attempt(attempt) selected_tool_names = attempt.selected_tool_names return ToolSelectionStateUpdate(selected_tool_names=selected_tool_names or None) async def awrap_model_call( self, request: ModelRequest[ContextT], handler: Callable[ [ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]] ], ) -> ModelResponse[ResponseT]: """ 从 state 中读取首次筛选结果,并应用到每次模型回合。 """ selected_tool_names = request.state.get("selected_tool_names") # noqa # 正常路径下,`abefore_agent()` 已经提前写入状态;这里只保留一层兜底, # 兼容直接单测或未来某些绕过 before_agent 的调用场景。 if ( selected_tool_names is None and self.selection_tools and self.model is not None ): attempt = await self._aselect_request_once_with_status(request) self._log_selection_attempt(attempt) request = attempt.request selected_tool_names = attempt.selected_tool_names or None request.state["selected_tool_names"] = selected_tool_names # noqa if selected_tool_names: request = self._apply_selected_tools(request, selected_tool_names) return await handler(request)