diff --git a/app/agent/__init__.py b/app/agent/__init__.py index e9e849ae..7c1f31f2 100644 --- a/app/agent/__init__.py +++ b/app/agent/__init__.py @@ -98,14 +98,15 @@ class MoviePilotAgent: user_id=self.user_id ) if messages: + loaded_messages = [] for msg in messages: if msg.get("role") == "user": - chat_history.add_message(HumanMessage(content=msg.get("content", ""))) + loaded_messages.append(HumanMessage(content=msg.get("content", ""))) elif msg.get("role") == "agent": - chat_history.add_message(AIMessage(content=msg.get("content", ""))) + loaded_messages.append(AIMessage(content=msg.get("content", ""))) elif msg.get("role") == "tool_call": metadata = msg.get("metadata", {}) - chat_history.add_message( + loaded_messages.append( AIMessage( content=msg.get("content", ""), tool_calls=[ @@ -119,12 +120,18 @@ class MoviePilotAgent: ) elif msg.get("role") == "tool_result": metadata = msg.get("metadata", {}) - chat_history.add_message(ToolMessage( + loaded_messages.append(ToolMessage( content=msg.get("content", ""), tool_call_id=metadata.get("call_id", "unknown") )) elif msg.get("role") == "system": - chat_history.add_message(SystemMessage(content=msg.get("content", ""))) + loaded_messages.append(SystemMessage(content=msg.get("content", ""))) + + # Validate and fix tool call integrity + validated_messages = self._ensure_tool_call_integrity(loaded_messages) + for msg in validated_messages: + chat_history.add_message(msg) + return chat_history @staticmethod @@ -192,13 +199,78 @@ class MoviePilotAgent: # 发生错误时返回一个保守的估算值 return len(str(messages)) // 4 + def _ensure_tool_call_integrity(self, messages: List[Union[HumanMessage, AIMessage, ToolMessage, SystemMessage]]) \ + -> List[Union[HumanMessage, AIMessage, ToolMessage, SystemMessage]]: + """ + Ensure tool call integrity: + 1. AIMessage with tool_calls must be followed by corresponding ToolMessages + 2. Remove incomplete AIMessages (with tool_calls but missing ToolMessage responses) and their partial ToolMessages + 3. Filter out orphaned ToolMessages that don't correspond to any tool_call + """ + if not messages: + return messages + + validated_messages = [] + i = 0 + + while i < len(messages): + msg = messages[i] + + # Check if this is an AIMessage with tool_calls + if isinstance(msg, AIMessage) and getattr(msg, 'tool_calls', None): + # Extract tool_call IDs (ToolCall is a TypedDict, so it's a dict at runtime) + tool_call_ids = { + tc.get('id') if tc.get('id') is not None else tc.get('name') + for tc in msg.tool_calls + } + + # Find subsequent ToolMessages + j = i + 1 + found_tool_messages = [] + while j < len(messages) and isinstance(messages[j], ToolMessage): + found_tool_messages.append(messages[j]) + j += 1 + + # Check if all tool_calls have corresponding ToolMessages + found_tool_call_ids = {tm.tool_call_id for tm in found_tool_messages} + + # Warn if there are extra ToolMessages that don't correspond to any tool_call + extra_tool_messages = found_tool_call_ids - tool_call_ids + if extra_tool_messages: + logger.warning(f"Found extra ToolMessages that don't correspond to any tool_call: {extra_tool_messages}") + + if not tool_call_ids.issubset(found_tool_call_ids): + # Missing some tool_call responses, skip this AIMessage and all its ToolMessages + logger.warning("Removing incomplete tool_call AIMessage and its partial ToolMessages: missing tool_call responses") + i = j # Skip the AIMessage and all related ToolMessages + continue + else: + # Add the AIMessage and only the ToolMessages that correspond to its tool_calls + validated_messages.append(msg) + for tm in found_tool_messages: + if tm.tool_call_id in tool_call_ids: + validated_messages.append(tm) + i = j + continue + # Skip orphaned ToolMessages (not preceded by an AIMessage with tool_calls) + elif isinstance(msg, ToolMessage): + logger.warning(f"Skipping orphaned ToolMessage with tool_call_id: {msg.tool_call_id}") + i += 1 + continue + else: + validated_messages.append(msg) + + i += 1 + + return validated_messages + def _create_agent_executor(self) -> RunnableWithMessageHistory: """ 创建Agent执行器 """ try: # 消息裁剪器,防止上下文超出限制 - trimmer = trim_messages( + base_trimmer = trim_messages( max_tokens=settings.LLM_MAX_CONTEXT_TOKENS * 1000 * 0.8, strategy="last", token_counter=self._token_counter, @@ -206,6 +278,12 @@ class MoviePilotAgent: allow_partial=False, start_on="human", ) + + # 包装trimmer,在裁剪后验证工具调用的完整性 + def validated_trimmer(messages): + trimmed = base_trimmer(messages) + return self._ensure_tool_call_integrity(trimmed) + # 创建Agent执行链 agent = ( RunnablePassthrough.assign( @@ -214,7 +292,7 @@ class MoviePilotAgent: ) ) | self.prompt - | trimmer + | validated_trimmer | self.llm.bind_tools(self.tools) | OpenAIToolsAgentOutputParser() )