From 81f30ef25adb58de524960a2ac383fa7893f1c0d Mon Sep 17 00:00:00 2001 From: jxxghp Date: Sat, 13 Dec 2025 12:26:08 +0800 Subject: [PATCH] fix agent memory --- app/agent/__init__.py | 66 +++++++++++++++++------------------- app/agent/memory/__init__.py | 4 +-- app/agent/tools/base.py | 52 ++++++++++++++++++++++++---- app/agent/tools/factory.py | 4 ++- app/agent/tools/manager.py | 16 ++++++--- 5 files changed, 94 insertions(+), 48 deletions(-) diff --git a/app/agent/__init__.py b/app/agent/__init__.py index 71561d6e..c41b5c77 100644 --- a/app/agent/__init__.py +++ b/app/agent/__init__.py @@ -7,7 +7,7 @@ from langchain.agents import AgentExecutor, create_openai_tools_agent from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder from langchain_community.callbacks import get_openai_callback from langchain_core.chat_history import InMemoryChatMessageHistory -from langchain_core.messages import HumanMessage, AIMessage, ToolCall +from langchain_core.messages import HumanMessage, AIMessage, ToolCall, ToolMessage, SystemMessage from langchain_core.runnables.history import RunnableWithMessageHistory from app.agent.callback import StreamingCallbackHandler @@ -56,9 +56,6 @@ class MoviePilotAgent: # 工具 self.tools = self._initialize_tools() - # 会话存储 - self.session_store = self._initialize_session_store() - # 提示词模板 self.prompt = self._initialize_prompt() @@ -127,7 +124,8 @@ class MoviePilotAgent: channel=self.channel, source=self.source, username=self.username, - callback_handler=self.callback_handler + callback_handler=self.callback_handler, + memory_mananger=self.memory_manager ) @staticmethod @@ -137,34 +135,36 @@ class MoviePilotAgent: def get_session_history(self, session_id: str) -> InMemoryChatMessageHistory: """获取会话历史""" - if session_id not in self.session_store: - chat_history = InMemoryChatMessageHistory() - messages: List[dict] = self.memory_manager.get_recent_messages_for_agent( - session_id=session_id, - user_id=self.user_id - ) - if messages: - for msg in messages: - if msg.get("role") == "user": - chat_history.add_user_message(HumanMessage(content=msg.get("content", ""))) - elif msg.get("role") == "agent": - chat_history.add_ai_message(AIMessage(content=msg.get("content", ""))) - elif msg.get("role") == "tool_call": - metadata = msg.get("metadata", {}) - chat_history.add_ai_message(AIMessage( + chat_history = InMemoryChatMessageHistory() + messages: List[dict] = self.memory_manager.get_recent_messages_for_agent( + session_id=session_id, + user_id=self.user_id + ) + if messages: + for msg in messages: + if msg.get("role") == "user": + chat_history.add_message(HumanMessage(content=msg.get("content", ""))) + elif msg.get("role") == "agent": + chat_history.add_message(AIMessage(content=msg.get("content", ""))) + elif msg.get("role") == "tool_call": + metadata = msg.get("metadata", {}) + chat_history.add_message( + AIMessage( content=msg.get("content", ""), - tool_calls=[ToolCall( - id=metadata.get("call_id"), - name=metadata.get("tool_name"), - args=metadata.get("parameters"), - )] - )) - elif msg.get("role") == "tool_result": - chat_history.add_ai_message(AIMessage(content=msg.get("content", ""))) - elif msg.get("role") == "system": - chat_history.add_ai_message(AIMessage(content=msg.get("content", ""))) - self.session_store[session_id] = chat_history - return self.session_store[session_id] + tool_calls=[ + ToolCall( + id=metadata.get("call_id"), + name=metadata.get("tool_name"), + args=metadata.get("parameters"), + ) + ] + ) + ) + elif msg.get("role") == "tool_result": + chat_history.add_message(ToolMessage(content=msg.get("content", ""))) + elif msg.get("role") == "system": + chat_history.add_message(SystemMessage(content=msg.get("content", ""))) + return chat_history @staticmethod def _initialize_prompt() -> ChatPromptTemplate: @@ -306,8 +306,6 @@ class MoviePilotAgent: async def cleanup(self): """清理智能体资源""" - if self.session_id in self.session_store: - del self.session_store[self.session_id] logger.info(f"MoviePilot智能体已清理: session_id={self.session_id}") diff --git a/app/agent/memory/__init__.py b/app/agent/memory/__init__.py index dbca172d..02daa8dc 100644 --- a/app/agent/memory/__init__.py +++ b/app/agent/memory/__init__.py @@ -186,9 +186,7 @@ class ConversationMemoryManager: return [] # 获取所有消息 - messages = memory.messages - - return messages + return memory.messages async def get_recent_messages( self, diff --git a/app/agent/tools/base.py b/app/agent/tools/base.py index 8e1262aa..1b9f0c0a 100644 --- a/app/agent/tools/base.py +++ b/app/agent/tools/base.py @@ -1,11 +1,12 @@ """MoviePilot工具基类""" +import json from abc import ABCMeta, abstractmethod -from typing import Callable, Any, Optional +from typing import Any, Optional from langchain.tools import BaseTool from pydantic import PrivateAttr -from app.agent import StreamingCallbackHandler +from app.agent import StreamingCallbackHandler, ConversationMemoryManager from app.chain import ChainBase from app.log import logger from app.schemas import Notification @@ -24,6 +25,7 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta): _source: str = PrivateAttr(default=None) _username: str = PrivateAttr(default=None) _callback_handler: StreamingCallbackHandler = PrivateAttr(default=None) + _memory_manager: ConversationMemoryManager = PrivateAttr(default=None) def __init__(self, session_id: str, user_id: str, **kwargs): super().__init__(**kwargs) @@ -35,24 +37,58 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta): async def _arun(self, **kwargs) -> str: """异步运行工具""" - # 发送运行工具前的消息 + # 发送和记忆工具调用前的信息 agent_message = await self._callback_handler.get_message() if agent_message: + # 发送消息 await self.send_tool_message(agent_message, title="MoviePilot助手") - # 发送执行工具说明 - # 优先使用工具自定义的提示消息,如果没有则使用 explanation + await self.memory_manager.add_memory( + session_id=self._session_id, + user_id=self._user_id, + role="agent", + content=agent_message + ) + + # 记忆工具调用 + await self.memory_manager.add_memory( + session_id=self.session_id, + user_id=self.user_id, + role="tool_call", + metadata={ + "call_id": self.__class__.__name__, + "tool_name": self.__class__.__name__, + "parameters": kwargs + } + ) + + # 发送执行工具说明,优先使用工具自定义的提示消息,如果没有则使用 explanation tool_message = self.get_tool_message(**kwargs) if not tool_message: explanation = kwargs.get("explanation") if explanation: tool_message = explanation - if tool_message: formatted_message = f"⚙️ => {tool_message}" await self.send_tool_message(formatted_message) + logger.debug(f'Executing tool {self.name} with args: {kwargs}') result = await self.run(**kwargs) logger.debug(f'Tool {self.name} executed with result: {result}') + + # 记忆工具调用结果 + if isinstance(result, str): + formated_result = result + elif isinstance(result, int, float): + formated_result = str(result) + else: + formated_result = json.dumps(result, ensure_ascii=False, indent=2) + await self.memory_manager.add_memory( + session_id=self.session_id, + user_id=self.user_id, + role="tool_result", + content=formated_result + ) + return result def get_tool_message(self, **kwargs) -> Optional[str]: @@ -84,6 +120,10 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta): """设置回调处理器""" self._callback_handler = callback_handler + def set_memory_manager(self, memory_manager: ConversationMemoryManager): + """设置记忆客理器""" + self._memory_manager = memory_manager + async def send_tool_message(self, message: str, title: str = ""): """发送工具消息""" await ToolChain().async_post_message( diff --git a/app/agent/tools/factory.py b/app/agent/tools/factory.py index c48930c5..2b5f4686 100644 --- a/app/agent/tools/factory.py +++ b/app/agent/tools/factory.py @@ -51,7 +51,7 @@ class MoviePilotToolFactory: @staticmethod def create_tools(session_id: str, user_id: str, channel: str = None, source: str = None, username: str = None, - callback_handler: Callable = None) -> List[MoviePilotTool]: + callback_handler: Callable = None, memory_mananger: Callable = None) -> List[MoviePilotTool]: """创建MoviePilot工具列表""" tools = [] tool_definitions = [ @@ -102,6 +102,7 @@ class MoviePilotToolFactory: ) tool.set_message_attr(channel=channel, source=source, username=username) tool.set_callback_handler(callback_handler=callback_handler) + tool.set_memory_manager(memory_manager=memory_mananger) tools.append(tool) # 加载插件提供的工具 @@ -124,6 +125,7 @@ class MoviePilotToolFactory: ) tool.set_message_attr(channel=channel, source=source, username=username) tool.set_callback_handler(callback_handler=callback_handler) + tool.set_memory_manager(memory_manager=memory_mananger) tools.append(tool) plugin_tools_count += 1 logger.debug(f"成功加载插件 {plugin_name}({plugin_id}) 的工具: {ToolClass.__name__}") diff --git a/app/agent/tools/manager.py b/app/agent/tools/manager.py index 6e56cc21..770c527a 100644 --- a/app/agent/tools/manager.py +++ b/app/agent/tools/manager.py @@ -3,8 +3,10 @@ """ import json +import uuid from typing import Any, Dict, List, Optional +from app.agent import ConversationMemoryManager from app.agent.tools.factory import MoviePilotToolFactory from app.log import logger @@ -21,7 +23,7 @@ class ToolDefinition: class MoviePilotToolsManager: """MoviePilot工具管理器(用于HTTP API)""" - def __init__(self, user_id: str = "api_user", session_id: str = "api_session"): + def __init__(self, user_id: str = "api_user", session_id: str = uuid.uuid4()): """ 初始化工具管理器 @@ -32,6 +34,7 @@ class MoviePilotToolsManager: self.user_id = user_id self.session_id = session_id self.tools: List[Any] = [] + self.memory_manager = ConversationMemoryManager() self._load_tools() def _load_tools(self): @@ -44,7 +47,8 @@ class MoviePilotToolsManager: channel=None, source="api", username="API Client", - callback_handler=None + callback_handler=None, + memory_mananger=None, ) logger.info(f"成功加载 {len(self.tools)} 个工具") except Exception as e: @@ -121,9 +125,13 @@ class MoviePilotToolsManager: # 确保返回字符串 if isinstance(result, str): - return result + formated_result = result + elif isinstance(result, int, float): + formated_result = str(result) else: - return json.dumps(result, ensure_ascii=False, indent=2) + formated_result = json.dumps(result, ensure_ascii=False, indent=2) + + return formated_result except Exception as e: logger.error(f"调用工具 {tool_name} 时发生错误: {e}", exc_info=True) error_msg = json.dumps({