From ec40f3611438f36a4a3d34e5ea69815f748c33da Mon Sep 17 00:00:00 2001 From: jxxghp Date: Sat, 24 Jan 2026 09:46:19 +0800 Subject: [PATCH] =?UTF-8?q?fix(agent)=EF=BC=9A=E4=BF=AE=E5=A4=8D=E6=99=BA?= =?UTF-8?q?=E8=83=BD=E4=BD=93=E5=B7=A5=E5=85=B7=E8=B0=83=E7=94=A8=EF=BC=8C?= =?UTF-8?q?=E4=BC=98=E5=8C=96=E5=AA=92=E4=BD=93=E5=BA=93=E6=9F=A5=E8=AF=A2?= =?UTF-8?q?=E5=B7=A5=E5=85=B7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/agent/__init__.py | 90 +++----------------- app/agent/memory/__init__.py | 2 +- app/agent/tools/base.py | 18 ++-- app/agent/tools/impl/query_library_exists.py | 82 +++++++++++++----- app/chain/message.py | 22 +++-- 5 files changed, 97 insertions(+), 117 deletions(-) diff --git a/app/agent/__init__.py b/app/agent/__init__.py index daba432b..c341120b 100644 --- a/app/agent/__init__.py +++ b/app/agent/__init__.py @@ -98,15 +98,14 @@ class MoviePilotAgent: user_id=self.user_id ) if messages: - loaded_messages = [] for msg in messages: if msg.get("role") == "user": - loaded_messages.append(HumanMessage(content=msg.get("content", ""))) + chat_history.add_message(HumanMessage(content=msg.get("content", ""))) elif msg.get("role") == "agent": - loaded_messages.append(AIMessage(content=msg.get("content", ""))) + chat_history.add_message(AIMessage(content=msg.get("content", ""))) elif msg.get("role") == "tool_call": metadata = msg.get("metadata", {}) - loaded_messages.append( + chat_history.add_message( AIMessage( content=msg.get("content", ""), tool_calls=[ @@ -120,17 +119,12 @@ class MoviePilotAgent: ) elif msg.get("role") == "tool_result": metadata = msg.get("metadata", {}) - loaded_messages.append(ToolMessage( + chat_history.add_message(ToolMessage( content=msg.get("content", ""), tool_call_id=metadata.get("call_id", "unknown") )) elif msg.get("role") == "system": - loaded_messages.append(SystemMessage(content=msg.get("content", ""))) - - # 验证并修复工具调用的完整性 - validated_messages = self._ensure_tool_call_integrity(loaded_messages) - for msg in validated_messages: - chat_history.add_message(msg) + chat_history.add_message(SystemMessage(content=msg.get("content", ""))) return chat_history @@ -199,71 +193,6 @@ class MoviePilotAgent: # 发生错误时返回一个保守的估算值 return len(str(messages)) // 4 - def _ensure_tool_call_integrity(self, messages: List[Union[HumanMessage, AIMessage, ToolMessage, SystemMessage]]) \ - -> List[Union[HumanMessage, AIMessage, ToolMessage, SystemMessage]]: - """ - 确保工具调用的完整性: - 1. 带有 tool_calls 的 AIMessage 必须后跟相应的 ToolMessages - 2. 移除不完整的 AIMessages(有 tool_calls 但缺少 ToolMessage 响应)及其部分 ToolMessages - 3. 过滤掉不对应任何 tool_call 的孤立 ToolMessages - """ - if not messages: - return messages - - validated_messages = [] - i = 0 - - while i < len(messages): - msg = messages[i] - - # 检查这是否是一个带有 tool_calls 的 AIMessage - if isinstance(msg, AIMessage) and getattr(msg, 'tool_calls', None): - # 提取 tool_call IDs(ToolCall 是 TypedDict,在运行时是 dict) - tool_call_ids = { - tc.get('id') if tc.get('id') is not None else tc.get('name') - for tc in msg.tool_calls - } - - # 查找后续的 ToolMessages - j = i + 1 - found_tool_messages = [] - while j < len(messages) and isinstance(messages[j], ToolMessage): - found_tool_messages.append(messages[j]) - j += 1 - - # 检查所有 tool_calls 是否都有对应的 ToolMessages - found_tool_call_ids = {tm.tool_call_id for tm in found_tool_messages} - - # 警告:如果存在不对应任何 tool_call 的额外 ToolMessages - extra_tool_messages = found_tool_call_ids - tool_call_ids - if extra_tool_messages: - logger.warning(f"Found extra ToolMessages that don't correspond to any tool_call: {extra_tool_messages}") - - if not tool_call_ids.issubset(found_tool_call_ids): - # 缺少部分 tool_call 响应,跳过此 AIMessage 及其所有 ToolMessages - logger.warning("Removing incomplete tool_call AIMessage and its partial ToolMessages: missing tool_call responses") - i = j # 跳过此 AIMessage 及其所有相关的 ToolMessages - continue - else: - # 添加此 AIMessage 及其对应的 ToolMessages - validated_messages.append(msg) - for tm in found_tool_messages: - if tm.tool_call_id in tool_call_ids: - validated_messages.append(tm) - i = j - continue - # 跳过孤立的 ToolMessages(前面没有带 tool_calls 的 AIMessage) - elif isinstance(msg, ToolMessage): - logger.warning(f"Skipping orphaned ToolMessage with tool_call_id: {msg.tool_call_id}") - i += 1 - continue - else: - validated_messages.append(msg) - - i += 1 - - return validated_messages - def _create_agent_executor(self) -> RunnableWithMessageHistory: """ 创建Agent执行器 @@ -281,8 +210,13 @@ class MoviePilotAgent: # 包装trimmer,在裁剪后验证工具调用的完整性 def validated_trimmer(messages): - trimmed = base_trimmer(messages) - return self._ensure_tool_call_integrity(trimmed) + # 如果输入是 PromptValue,转换为消息列表 + if hasattr(messages, "to_messages"): + messages = messages.to_messages() + trimmed = base_trimmer.invoke(messages) + if len(trimmed) < len(messages): + logger.info(f"LangChain消息上下文已裁剪: {len(messages)} -> {len(trimmed)}") + return trimmed # 创建Agent执行链 agent = ( diff --git a/app/agent/memory/__init__.py b/app/agent/memory/__init__.py index 7192ebda..d2957a4a 100644 --- a/app/agent/memory/__init__.py +++ b/app/agent/memory/__init__.py @@ -232,7 +232,7 @@ class ConversationMemoryManager: return [] # 获取所有消息 - return memory.messages + return memory.messages[:-1] async def get_recent_messages( self, diff --git a/app/agent/tools/base.py b/app/agent/tools/base.py index dc534034..5d05dd9a 100644 --- a/app/agent/tools/base.py +++ b/app/agent/tools/base.py @@ -1,4 +1,5 @@ import json +import uuid from abc import ABCMeta, abstractmethod from typing import Any, Optional @@ -42,6 +43,9 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta): # 获取工具调用前的agent消息 agent_message = await self._callback_handler.get_message() + # 生成唯一的工具调用ID + call_id = f"call_{str(uuid.uuid4())[:16]}" + # 记忆工具调用 await conversation_manager.add_conversation( session_id=self._session_id, @@ -49,8 +53,8 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta): role="tool_call", content=agent_message, metadata={ - "call_id": self.__class__.__name__, - "tool_name": self.__class__.__name__, + "call_id": call_id, + "tool_name": self.name, "parameters": kwargs } ) @@ -61,21 +65,21 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta): explanation = kwargs.get("explanation") if explanation: tool_message = explanation - + # 合并agent消息和工具执行消息,一起发送 messages = [] if agent_message: messages.append(agent_message) if tool_message: messages.append(f"⚙️ => {tool_message}") - + # 发送合并后的消息 if messages: merged_message = "\n\n".join(messages) await self.send_tool_message(merged_message, title="MoviePilot助手") logger.debug(f'Executing tool {self.name} with args: {kwargs}') - + # 执行工具,捕获异常确保结果总是被存储到记忆中 try: result = await self.run(**kwargs) @@ -93,13 +97,15 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta): formated_result = str(result) else: formated_result = json.dumps(result, ensure_ascii=False, indent=2) + await conversation_manager.add_conversation( session_id=self._session_id, user_id=self._user_id, role="tool_result", content=formated_result, metadata={ - "call_id": self.__class__.__name__ + "call_id": call_id, + "tool_name": self.name, } ) diff --git a/app/agent/tools/impl/query_library_exists.py b/app/agent/tools/impl/query_library_exists.py index 104c9c43..19a009d6 100644 --- a/app/agent/tools/impl/query_library_exists.py +++ b/app/agent/tools/impl/query_library_exists.py @@ -8,6 +8,7 @@ from pydantic import BaseModel, Field from app.agent.tools.base import MoviePilotTool from app.chain.mediaserver import MediaServerChain from app.core.context import MediaInfo +from app.core.meta import MetaBase from app.log import logger from app.schemas.types import MediaType @@ -51,47 +52,88 @@ class QueryLibraryExistsTool(MoviePilotTool): try: if not title: return "请提供媒体标题进行查询" - - # 创建 MediaInfo 对象 - mediainfo = MediaInfo() - mediainfo.title = title - mediainfo.year = year - - # 转换媒体类型 - if media_type == "电影": - mediainfo.type = MediaType.MOVIE - elif media_type == "电视剧": - mediainfo.type = MediaType.TV - # media_type == "all" 时不设置类型,让媒体服务器自动判断 - - # 调用媒体服务器接口实时查询 + media_chain = MediaServerChain() + + # 1. 识别媒体信息(获取 TMDB ID 和各季的总集数等元数据) + meta = MetaBase(title=title) + if year: + meta.year = str(year) + if media_type == "电影": + meta.type = MediaType.MOVIE + elif media_type == "电视剧": + meta.type = MediaType.TV + + # 使用识别方法补充信息 + recognize_info = media_chain.recognize_media(meta=meta) + if recognize_info: + mediainfo = recognize_info + else: + # 识别失败,创建基本信息的 MediaInfo + mediainfo = MediaInfo() + mediainfo.title = title + mediainfo.year = year + if media_type == "电影": + mediainfo.type = MediaType.MOVIE + elif media_type == "电视剧": + mediainfo.type = MediaType.TV + + # 2. 调用媒体服务器接口实时查询存在信息 existsinfo = media_chain.media_exists(mediainfo=mediainfo) - + if not existsinfo: return "媒体库中未找到相关媒体" - - # 如果找到了,获取详细信息 + + # 3. 如果找到了,获取详细信息并组装结果 result_items = [] if existsinfo.itemid and existsinfo.server: iteminfo = media_chain.iteminfo(server=existsinfo.server, item_id=existsinfo.itemid) if iteminfo: # 使用 model_dump() 转换为字典格式 item_dict = iteminfo.model_dump(exclude_none=True) + + # 对于电视剧,补充已存在的季集详情及进度统计 + if existsinfo.type == MediaType.TV: + # 注入已存在集信息 (Dict[int, list]) + item_dict["seasoninfo"] = existsinfo.seasons + + # 统计库中已存在的季集总数 + if existsinfo.seasons: + item_dict["existing_episodes_count"] = sum(len(e) for e in existsinfo.seasons.values()) + item_dict["seasons_existing_count"] = {str(s): len(e) for s, e in existsinfo.seasons.items()} + + # 如果识别到了元数据,补充总计对比和进度概览 + if mediainfo.seasons: + item_dict["seasons_total_count"] = {str(s): len(e) for s, e in mediainfo.seasons.items()} + # 进度概览,例如 "Season 1": "3/12" + item_dict["seasons_progress"] = { + f"第{s}季": f"{len(existsinfo.seasons.get(s, []))}/{len(mediainfo.seasons.get(s, []))} 集" + for s in mediainfo.seasons.keys() if (s in existsinfo.seasons or s > 0) + } + result_items.append(item_dict) - + if result_items: return json.dumps(result_items, ensure_ascii=False) - - # 如果找到了但没有详细信息,返回基本信息 + + # 如果找到了但没有获取到 iteminfo,返回基本信息 result_dict = { + "title": mediainfo.title, + "year": mediainfo.year, "type": existsinfo.type.value if existsinfo.type else None, "server": existsinfo.server, "server_type": existsinfo.server_type, "itemid": existsinfo.itemid, "seasons": existsinfo.seasons if existsinfo.seasons else {} } + if existsinfo.type == MediaType.TV and existsinfo.seasons: + result_dict["existing_episodes_count"] = sum(len(e) for e in existsinfo.seasons.values()) + result_dict["seasons_existing_count"] = {str(s): len(e) for s, e in existsinfo.seasons.items()} + if mediainfo.seasons: + result_dict["seasons_total_count"] = {str(s): len(e) for s, e in mediainfo.seasons.items()} + return json.dumps([result_dict], ensure_ascii=False) except Exception as e: logger.error(f"查询媒体库失败: {e}", exc_info=True) return f"查询媒体库时发生错误: {str(e)}" + diff --git a/app/chain/message.py b/app/chain/message.py index 275224bd..2396f1cd 100644 --- a/app/chain/message.py +++ b/app/chain/message.py @@ -842,8 +842,7 @@ class MessageChain(ChainBase): return buttons - @staticmethod - def _get_or_create_session_id(userid: Union[str, int]) -> str: + def _get_or_create_session_id(self, userid: Union[str, int]) -> str: """ 获取或创建会话ID 如果用户上次会话在15分钟内,则复用相同的会话ID;否则创建新的会话ID @@ -851,8 +850,8 @@ class MessageChain(ChainBase): current_time = datetime.now() # 检查用户是否有已存在的会话 - if userid in MessageChain._user_sessions: - session_id, last_time = MessageChain._user_sessions[userid] + if userid in self._user_sessions: + session_id, last_time = self._user_sessions[userid] # 计算时间差 time_diff = current_time - last_time @@ -860,25 +859,24 @@ class MessageChain(ChainBase): # 如果时间差小于等于15分钟,复用会话ID if time_diff <= timedelta(minutes=MessageChain._session_timeout_minutes): # 更新最后使用时间 - MessageChain._user_sessions[userid] = (session_id, current_time) + self._user_sessions[userid] = (session_id, current_time) logger.info( f"复用会话ID: {session_id}, 用户: {userid}, 距离上次会话: {time_diff.total_seconds() / 60:.1f}分钟") return session_id # 创建新的会话ID new_session_id = f"user_{userid}_{int(time.time())}" - MessageChain._user_sessions[userid] = (new_session_id, current_time) + self._user_sessions[userid] = (new_session_id, current_time) logger.info(f"创建新会话ID: {new_session_id}, 用户: {userid}") return new_session_id - @staticmethod - def clear_user_session(userid: Union[str, int]) -> bool: + def clear_user_session(self, userid: Union[str, int]) -> bool: """ 清除指定用户的会话信息 返回是否成功清除 """ - if userid in MessageChain._user_sessions: - session_id, _ = MessageChain._user_sessions.pop(userid) + if userid in self._user_sessions: + session_id, _ = self._user_sessions.pop(userid) logger.info(f"已清除用户 {userid} 的会话: {session_id}") return True return False @@ -889,8 +887,8 @@ class MessageChain(ChainBase): """ # 获取并清除会话信息 session_id = None - if userid in MessageChain._user_sessions: - session_id, _ = MessageChain._user_sessions.pop(userid) + if userid in self._user_sessions: + session_id, _ = self._user_sessions.pop(userid) logger.info(f"已清除用户 {userid} 的会话: {session_id}") # 如果有会话ID,同时清除智能体的会话记忆