diff --git a/app/agent/__init__.py b/app/agent/__init__.py index 83691f6e..2144d2c8 100644 --- a/app/agent/__init__.py +++ b/app/agent/__init__.py @@ -127,7 +127,7 @@ class MoviePilotAgent: elif msg.get("role") == "system": loaded_messages.append(SystemMessage(content=msg.get("content", ""))) - # 验证并修复工具调用的完整性 + # Validate and fix tool call integrity validated_messages = self._ensure_tool_call_integrity(loaded_messages) for msg in validated_messages: chat_history.add_message(msg) @@ -202,9 +202,10 @@ class MoviePilotAgent: def _ensure_tool_call_integrity(self, messages: List[Union[HumanMessage, AIMessage, ToolMessage, SystemMessage]]) \ -> List[Union[HumanMessage, AIMessage, ToolMessage, SystemMessage]]: """ - 确保工具调用的完整性: - 1. 如果AIMessage包含tool_calls,必须后跟相应的ToolMessage - 2. 移除孤立的AIMessage(有tool_calls但没有对应的ToolMessage)及其关联的ToolMessage + Ensure tool call integrity: + 1. AIMessage with tool_calls must be followed by corresponding ToolMessages + 2. Remove incomplete AIMessages (with tool_calls but missing ToolMessage responses) and their partial ToolMessages + 3. Filter out orphaned ToolMessages that don't correspond to any tool_call """ if not messages: return messages @@ -215,32 +216,39 @@ class MoviePilotAgent: while i < len(messages): msg = messages[i] - # 检查是否是包含tool_calls的AIMessage + # Check if this is an AIMessage with tool_calls if isinstance(msg, AIMessage) and getattr(msg, 'tool_calls', None): - tool_call_ids = {tc.get('id') if isinstance(tc, dict) else tc.id - for tc in msg.tool_calls} + # Extract tool_call IDs (ToolCall is a TypedDict, so it's a dict at runtime) + tool_call_ids = {tc.get('id') or tc.get('name') for tc in msg.tool_calls} - # 查找后续的ToolMessage + # Find subsequent ToolMessages j = i + 1 found_tool_messages = [] while j < len(messages) and isinstance(messages[j], ToolMessage): found_tool_messages.append(messages[j]) j += 1 - # 检查是否所有tool_call都有对应的ToolMessage + # Check if all tool_calls have corresponding ToolMessages found_tool_call_ids = {tm.tool_call_id for tm in found_tool_messages} if not tool_call_ids.issubset(found_tool_call_ids): - # 如果缺少某些tool_call的响应,移除这个AIMessage及其相关的ToolMessage - logger.warning(f"移除不完整的tool_call AIMessage及其部分ToolMessage: 缺少tool_call响应") - i = j # 跳过AIMessage和所有相关的ToolMessage + # Missing some tool_call responses, skip this AIMessage and all its ToolMessages + logger.warning("Removing incomplete tool_call AIMessage and its partial ToolMessages: missing tool_call responses") + i = j # Skip the AIMessage and all related ToolMessages continue else: - # 添加AIMessage和所有对应的ToolMessage + # Add the AIMessage and only the ToolMessages that correspond to its tool_calls validated_messages.append(msg) - validated_messages.extend(found_tool_messages) + for tm in found_tool_messages: + if tm.tool_call_id in tool_call_ids: + validated_messages.append(tm) i = j continue + # Skip orphaned ToolMessages (not preceded by an AIMessage with tool_calls) + elif isinstance(msg, ToolMessage): + logger.warning(f"Skipping orphaned ToolMessage with tool_call_id: {msg.tool_call_id}") + i += 1 + continue else: validated_messages.append(msg)