Fix agent tool_calls integrity validation

Co-authored-by: jxxghp <51039935+jxxghp@users.noreply.github.com>
This commit is contained in:
copilot-swe-agent[bot]
2026-01-24 00:02:47 +00:00
parent 3412498438
commit f4157b52ea

View File

@@ -98,14 +98,15 @@ class MoviePilotAgent:
user_id=self.user_id
)
if messages:
loaded_messages = []
for msg in messages:
if msg.get("role") == "user":
chat_history.add_message(HumanMessage(content=msg.get("content", "")))
loaded_messages.append(HumanMessage(content=msg.get("content", "")))
elif msg.get("role") == "agent":
chat_history.add_message(AIMessage(content=msg.get("content", "")))
loaded_messages.append(AIMessage(content=msg.get("content", "")))
elif msg.get("role") == "tool_call":
metadata = msg.get("metadata", {})
chat_history.add_message(
loaded_messages.append(
AIMessage(
content=msg.get("content", ""),
tool_calls=[
@@ -119,12 +120,18 @@ class MoviePilotAgent:
)
elif msg.get("role") == "tool_result":
metadata = msg.get("metadata", {})
chat_history.add_message(ToolMessage(
loaded_messages.append(ToolMessage(
content=msg.get("content", ""),
tool_call_id=metadata.get("call_id", "unknown")
))
elif msg.get("role") == "system":
chat_history.add_message(SystemMessage(content=msg.get("content", "")))
loaded_messages.append(SystemMessage(content=msg.get("content", "")))
# 验证并修复工具调用的完整性
validated_messages = self._ensure_tool_call_integrity(loaded_messages)
for msg in validated_messages:
chat_history.add_message(msg)
return chat_history
@staticmethod
@@ -192,13 +199,62 @@ class MoviePilotAgent:
# 发生错误时返回一个保守的估算值
return len(str(messages)) // 4
def _ensure_tool_call_integrity(self, messages: List[Union[HumanMessage, AIMessage, ToolMessage, SystemMessage]]) \
-> List[Union[HumanMessage, AIMessage, ToolMessage, SystemMessage]]:
"""
确保工具调用的完整性:
1. 如果AIMessage包含tool_calls必须后跟相应的ToolMessage
2. 移除孤立的AIMessage有tool_calls但没有对应的ToolMessage
"""
if not messages:
return messages
validated_messages = []
i = 0
while i < len(messages):
msg = messages[i]
# 检查是否是包含tool_calls的AIMessage
if isinstance(msg, AIMessage) and getattr(msg, 'tool_calls', None):
tool_call_ids = {tc.get('id') if isinstance(tc, dict) else tc.id
for tc in msg.tool_calls}
# 查找后续的ToolMessage
j = i + 1
found_tool_messages = []
while j < len(messages) and isinstance(messages[j], ToolMessage):
found_tool_messages.append(messages[j])
j += 1
# 检查是否所有tool_call都有对应的ToolMessage
found_tool_call_ids = {tm.tool_call_id for tm in found_tool_messages}
if not tool_call_ids.issubset(found_tool_call_ids):
# 如果缺少某些tool_call的响应移除这个AIMessage
logger.warning(f"移除不完整的tool_call AIMessage: 缺少tool_call响应")
i += 1
continue
else:
# 添加AIMessage和所有对应的ToolMessage
validated_messages.append(msg)
validated_messages.extend(found_tool_messages)
i = j
continue
else:
validated_messages.append(msg)
i += 1
return validated_messages
def _create_agent_executor(self) -> RunnableWithMessageHistory:
"""
创建Agent执行器
"""
try:
# 消息裁剪器,防止上下文超出限制
trimmer = trim_messages(
base_trimmer = trim_messages(
max_tokens=settings.LLM_MAX_CONTEXT_TOKENS * 1000 * 0.8,
strategy="last",
token_counter=self._token_counter,
@@ -206,6 +262,12 @@ class MoviePilotAgent:
allow_partial=False,
start_on="human",
)
# 包装trimmer在裁剪后验证工具调用的完整性
def validated_trimmer(messages):
trimmed = base_trimmer(messages)
return self._ensure_tool_call_integrity(trimmed)
# 创建Agent执行链
agent = (
RunnablePassthrough.assign(
@@ -214,7 +276,7 @@ class MoviePilotAgent:
)
)
| self.prompt
| trimmer
| validated_trimmer
| self.llm.bind_tools(self.tools)
| OpenAIToolsAgentOutputParser()
)