mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-03-20 03:57:30 +08:00
fix(agent):修复智能体工具调用,优化媒体库查询工具
This commit is contained in:
@@ -98,15 +98,14 @@ class MoviePilotAgent:
|
||||
user_id=self.user_id
|
||||
)
|
||||
if messages:
|
||||
loaded_messages = []
|
||||
for msg in messages:
|
||||
if msg.get("role") == "user":
|
||||
loaded_messages.append(HumanMessage(content=msg.get("content", "")))
|
||||
chat_history.add_message(HumanMessage(content=msg.get("content", "")))
|
||||
elif msg.get("role") == "agent":
|
||||
loaded_messages.append(AIMessage(content=msg.get("content", "")))
|
||||
chat_history.add_message(AIMessage(content=msg.get("content", "")))
|
||||
elif msg.get("role") == "tool_call":
|
||||
metadata = msg.get("metadata", {})
|
||||
loaded_messages.append(
|
||||
chat_history.add_message(
|
||||
AIMessage(
|
||||
content=msg.get("content", ""),
|
||||
tool_calls=[
|
||||
@@ -120,17 +119,12 @@ class MoviePilotAgent:
|
||||
)
|
||||
elif msg.get("role") == "tool_result":
|
||||
metadata = msg.get("metadata", {})
|
||||
loaded_messages.append(ToolMessage(
|
||||
chat_history.add_message(ToolMessage(
|
||||
content=msg.get("content", ""),
|
||||
tool_call_id=metadata.get("call_id", "unknown")
|
||||
))
|
||||
elif msg.get("role") == "system":
|
||||
loaded_messages.append(SystemMessage(content=msg.get("content", "")))
|
||||
|
||||
# 验证并修复工具调用的完整性
|
||||
validated_messages = self._ensure_tool_call_integrity(loaded_messages)
|
||||
for msg in validated_messages:
|
||||
chat_history.add_message(msg)
|
||||
chat_history.add_message(SystemMessage(content=msg.get("content", "")))
|
||||
|
||||
return chat_history
|
||||
|
||||
@@ -199,71 +193,6 @@ class MoviePilotAgent:
|
||||
# 发生错误时返回一个保守的估算值
|
||||
return len(str(messages)) // 4
|
||||
|
||||
def _ensure_tool_call_integrity(self, messages: List[Union[HumanMessage, AIMessage, ToolMessage, SystemMessage]]) \
|
||||
-> List[Union[HumanMessage, AIMessage, ToolMessage, SystemMessage]]:
|
||||
"""
|
||||
确保工具调用的完整性:
|
||||
1. 带有 tool_calls 的 AIMessage 必须后跟相应的 ToolMessages
|
||||
2. 移除不完整的 AIMessages(有 tool_calls 但缺少 ToolMessage 响应)及其部分 ToolMessages
|
||||
3. 过滤掉不对应任何 tool_call 的孤立 ToolMessages
|
||||
"""
|
||||
if not messages:
|
||||
return messages
|
||||
|
||||
validated_messages = []
|
||||
i = 0
|
||||
|
||||
while i < len(messages):
|
||||
msg = messages[i]
|
||||
|
||||
# 检查这是否是一个带有 tool_calls 的 AIMessage
|
||||
if isinstance(msg, AIMessage) and getattr(msg, 'tool_calls', None):
|
||||
# 提取 tool_call IDs(ToolCall 是 TypedDict,在运行时是 dict)
|
||||
tool_call_ids = {
|
||||
tc.get('id') if tc.get('id') is not None else tc.get('name')
|
||||
for tc in msg.tool_calls
|
||||
}
|
||||
|
||||
# 查找后续的 ToolMessages
|
||||
j = i + 1
|
||||
found_tool_messages = []
|
||||
while j < len(messages) and isinstance(messages[j], ToolMessage):
|
||||
found_tool_messages.append(messages[j])
|
||||
j += 1
|
||||
|
||||
# 检查所有 tool_calls 是否都有对应的 ToolMessages
|
||||
found_tool_call_ids = {tm.tool_call_id for tm in found_tool_messages}
|
||||
|
||||
# 警告:如果存在不对应任何 tool_call 的额外 ToolMessages
|
||||
extra_tool_messages = found_tool_call_ids - tool_call_ids
|
||||
if extra_tool_messages:
|
||||
logger.warning(f"Found extra ToolMessages that don't correspond to any tool_call: {extra_tool_messages}")
|
||||
|
||||
if not tool_call_ids.issubset(found_tool_call_ids):
|
||||
# 缺少部分 tool_call 响应,跳过此 AIMessage 及其所有 ToolMessages
|
||||
logger.warning("Removing incomplete tool_call AIMessage and its partial ToolMessages: missing tool_call responses")
|
||||
i = j # 跳过此 AIMessage 及其所有相关的 ToolMessages
|
||||
continue
|
||||
else:
|
||||
# 添加此 AIMessage 及其对应的 ToolMessages
|
||||
validated_messages.append(msg)
|
||||
for tm in found_tool_messages:
|
||||
if tm.tool_call_id in tool_call_ids:
|
||||
validated_messages.append(tm)
|
||||
i = j
|
||||
continue
|
||||
# 跳过孤立的 ToolMessages(前面没有带 tool_calls 的 AIMessage)
|
||||
elif isinstance(msg, ToolMessage):
|
||||
logger.warning(f"Skipping orphaned ToolMessage with tool_call_id: {msg.tool_call_id}")
|
||||
i += 1
|
||||
continue
|
||||
else:
|
||||
validated_messages.append(msg)
|
||||
|
||||
i += 1
|
||||
|
||||
return validated_messages
|
||||
|
||||
def _create_agent_executor(self) -> RunnableWithMessageHistory:
|
||||
"""
|
||||
创建Agent执行器
|
||||
@@ -281,8 +210,13 @@ class MoviePilotAgent:
|
||||
|
||||
# 包装trimmer,在裁剪后验证工具调用的完整性
|
||||
def validated_trimmer(messages):
|
||||
trimmed = base_trimmer(messages)
|
||||
return self._ensure_tool_call_integrity(trimmed)
|
||||
# 如果输入是 PromptValue,转换为消息列表
|
||||
if hasattr(messages, "to_messages"):
|
||||
messages = messages.to_messages()
|
||||
trimmed = base_trimmer.invoke(messages)
|
||||
if len(trimmed) < len(messages):
|
||||
logger.info(f"LangChain消息上下文已裁剪: {len(messages)} -> {len(trimmed)}")
|
||||
return trimmed
|
||||
|
||||
# 创建Agent执行链
|
||||
agent = (
|
||||
|
||||
Reference in New Issue
Block a user