from typing import Any

from langchain.agents.middleware import AgentMiddleware, AgentState
from langchain_core.messages import AIMessage, ToolMessage
from langgraph.runtime import Runtime
from langgraph.types import Overwrite


class PatchToolCallsMiddleware(AgentMiddleware):
    """Middleware that repairs dangling tool calls in the message history.

    A tool call is considered dangling when an ``AIMessage`` requests it but no
    message of type ``"tool"`` carrying the same ``tool_call_id`` appears at or
    after that ``AIMessage``. For each dangling call, a synthetic
    ``ToolMessage`` saying the call was cancelled is inserted directly after
    the requesting ``AIMessage``.
    """

    def before_agent(self, state: AgentState, runtime: Runtime[Any]) -> dict[str, Any] | None:  # noqa: ARG002
        """Patch dangling tool calls in any ``AIMessage`` before the agent runs.

        Args:
            state: The agent state; only ``state["messages"]`` is read.
            runtime: Unused.

        Returns:
            ``None`` when there are no messages. Otherwise
            ``{"messages": Overwrite(patched_messages)}``, where
            ``patched_messages`` is the original history with a cancellation
            ``ToolMessage`` inserted after each ``AIMessage`` for every one of
            its tool calls that never received a response.
        """
        messages = state["messages"]
        if not messages:
            return None

        # Index of the LAST tool message answering each tool_call_id. A call
        # issued at index i is answered iff this index exists and is >= i,
        # which is equivalent to searching messages[i:] forward. Computing it
        # once keeps the whole pass O(n) instead of rescanning the tail
        # (and copying it via slicing) for every tool call.
        last_response_index: dict[Any, int] = {}
        for idx, message in enumerate(messages):
            if message.type == "tool":
                last_response_index[message.tool_call_id] = idx  # ty: ignore[unresolved-attribute]

        patched_messages = []
        for i, msg in enumerate(messages):
            patched_messages.append(msg)
            if not (isinstance(msg, AIMessage) and msg.tool_calls):
                continue
            for tool_call in msg.tool_calls:
                # -1 never satisfies ">= i" (i >= 0), so unknown ids count as dangling.
                if last_response_index.get(tool_call["id"], -1) >= i:
                    continue
                # Dangling tool call: supply a ToolMessage so the history
                # pairs every requested tool call with a response.
                tool_msg = (
                    f"Tool call {tool_call['name']} with id {tool_call['id']} was "
                    "cancelled - another message came in before it could be completed."
                )
                patched_messages.append(
                    ToolMessage(
                        content=tool_msg,
                        name=tool_call["name"],
                        tool_call_id=tool_call["id"],
                    )
                )

        return {"messages": Overwrite(patched_messages)}