mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-04-01 01:41:59 +08:00
44 lines
1.9 KiB
Python
44 lines
1.9 KiB
Python
from typing import Any
|
|
|
|
from langchain.agents.middleware import AgentMiddleware, AgentState
|
|
from langchain_core.messages import AIMessage, ToolMessage
|
|
from langgraph.runtime import Runtime
|
|
from langgraph.types import Overwrite
|
|
|
|
|
|
class PatchToolCallsMiddleware(AgentMiddleware):
    """Middleware that repairs dangling tool calls in the message history.

    A tool call is treated as dangling when an ``AIMessage`` requests it but
    no message of type ``"tool"`` carrying the same ``tool_call_id`` appears
    at or after that ``AIMessage`` in the history. For each dangling call a
    placeholder ``ToolMessage`` is inserted directly after the requesting
    ``AIMessage``.
    """

    def before_agent(self, state: AgentState, runtime: Runtime[Any]) -> dict[str, Any] | None:  # noqa: ARG002
        """Patch dangling tool calls from any ``AIMessage`` before the agent runs.

        Args:
            state: Agent state; only ``state["messages"]`` is read.
            runtime: Unused by this hook.

        Returns:
            ``None`` when the history is empty. Otherwise a state update whose
            ``"messages"`` entry is an ``Overwrite`` wrapping the full history,
            with a cancellation ``ToolMessage`` added after each ``AIMessage``
            for every tool call that has no matching tool message.
        """
        messages = state["messages"]
        if not messages:
            return None

        # Record the last position at which each tool_call_id is answered, so
        # each lookup below is a dict access instead of re-slicing and
        # re-scanning the tail of the history for every single tool call.
        last_answer_index: dict[str | None, int] = {}
        for index, message in enumerate(messages):
            if message.type == "tool":
                last_answer_index[message.tool_call_id] = index  # ty: ignore[unresolved-attribute]

        patched_messages = []
        for index, message in enumerate(messages):
            patched_messages.append(message)
            if not (isinstance(message, AIMessage) and message.tool_calls):
                continue

            for tool_call in message.tool_calls:
                # A response only counts when it sits at or after this
                # AIMessage; an earlier tool message with the same id does not.
                if last_answer_index.get(tool_call["id"], -1) >= index:
                    continue

                # Dangling tool call: synthesize the ToolMessage it is missing.
                tool_msg = (
                    f"Tool call {tool_call['name']} with id {tool_call['id']} was "
                    "cancelled - another message came in before it could be completed."
                )
                patched_messages.append(
                    ToolMessage(
                        content=tool_msg,
                        name=tool_call["name"],
                        tool_call_id=tool_call["id"],
                    )
                )

        return {"messages": Overwrite(patched_messages)}
|