fix agent memory

This commit is contained in:
jxxghp
2025-12-13 12:26:08 +08:00
parent 140b0d3df2
commit 81f30ef25a
5 changed files with 94 additions and 48 deletions

View File

@@ -7,7 +7,7 @@ from langchain.agents import AgentExecutor, create_openai_tools_agent
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_community.callbacks import get_openai_callback
from langchain_core.chat_history import InMemoryChatMessageHistory
from langchain_core.messages import HumanMessage, AIMessage, ToolCall
from langchain_core.messages import HumanMessage, AIMessage, ToolCall, ToolMessage, SystemMessage
from langchain_core.runnables.history import RunnableWithMessageHistory
from app.agent.callback import StreamingCallbackHandler
@@ -56,9 +56,6 @@ class MoviePilotAgent:
# 工具
self.tools = self._initialize_tools()
# 会话存储
self.session_store = self._initialize_session_store()
# 提示词模板
self.prompt = self._initialize_prompt()
@@ -127,7 +124,8 @@ class MoviePilotAgent:
channel=self.channel,
source=self.source,
username=self.username,
callback_handler=self.callback_handler
callback_handler=self.callback_handler,
memory_mananger=self.memory_manager
)
@staticmethod
@@ -137,34 +135,36 @@ class MoviePilotAgent:
def get_session_history(self, session_id: str) -> InMemoryChatMessageHistory:
"""获取会话历史"""
if session_id not in self.session_store:
chat_history = InMemoryChatMessageHistory()
messages: List[dict] = self.memory_manager.get_recent_messages_for_agent(
session_id=session_id,
user_id=self.user_id
)
if messages:
for msg in messages:
if msg.get("role") == "user":
chat_history.add_user_message(HumanMessage(content=msg.get("content", "")))
elif msg.get("role") == "agent":
chat_history.add_ai_message(AIMessage(content=msg.get("content", "")))
elif msg.get("role") == "tool_call":
metadata = msg.get("metadata", {})
chat_history.add_ai_message(AIMessage(
chat_history = InMemoryChatMessageHistory()
messages: List[dict] = self.memory_manager.get_recent_messages_for_agent(
session_id=session_id,
user_id=self.user_id
)
if messages:
for msg in messages:
if msg.get("role") == "user":
chat_history.add_message(HumanMessage(content=msg.get("content", "")))
elif msg.get("role") == "agent":
chat_history.add_message(AIMessage(content=msg.get("content", "")))
elif msg.get("role") == "tool_call":
metadata = msg.get("metadata", {})
chat_history.add_message(
AIMessage(
content=msg.get("content", ""),
tool_calls=[ToolCall(
id=metadata.get("call_id"),
name=metadata.get("tool_name"),
args=metadata.get("parameters"),
)]
))
elif msg.get("role") == "tool_result":
chat_history.add_ai_message(AIMessage(content=msg.get("content", "")))
elif msg.get("role") == "system":
chat_history.add_ai_message(AIMessage(content=msg.get("content", "")))
self.session_store[session_id] = chat_history
return self.session_store[session_id]
tool_calls=[
ToolCall(
id=metadata.get("call_id"),
name=metadata.get("tool_name"),
args=metadata.get("parameters"),
)
]
)
)
elif msg.get("role") == "tool_result":
chat_history.add_message(ToolMessage(content=msg.get("content", "")))
elif msg.get("role") == "system":
chat_history.add_message(SystemMessage(content=msg.get("content", "")))
return chat_history
@staticmethod
def _initialize_prompt() -> ChatPromptTemplate:
@@ -306,8 +306,6 @@ class MoviePilotAgent:
async def cleanup(self):
"""清理智能体资源"""
if self.session_id in self.session_store:
del self.session_store[self.session_id]
logger.info(f"MoviePilot智能体已清理: session_id={self.session_id}")

View File

@@ -186,9 +186,7 @@ class ConversationMemoryManager:
return []
# 获取所有消息
messages = memory.messages
return messages
return memory.messages
async def get_recent_messages(
self,

View File

@@ -1,11 +1,12 @@
"""MoviePilot工具基类"""
import json
from abc import ABCMeta, abstractmethod
from typing import Callable, Any, Optional
from typing import Any, Optional
from langchain.tools import BaseTool
from pydantic import PrivateAttr
from app.agent import StreamingCallbackHandler
from app.agent import StreamingCallbackHandler, ConversationMemoryManager
from app.chain import ChainBase
from app.log import logger
from app.schemas import Notification
@@ -24,6 +25,7 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
_source: str = PrivateAttr(default=None)
_username: str = PrivateAttr(default=None)
_callback_handler: StreamingCallbackHandler = PrivateAttr(default=None)
_memory_manager: ConversationMemoryManager = PrivateAttr(default=None)
def __init__(self, session_id: str, user_id: str, **kwargs):
super().__init__(**kwargs)
@@ -35,24 +37,58 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
async def _arun(self, **kwargs) -> str:
"""异步运行工具"""
# 发送运行工具前的
# 发送和记忆工具调用前的
agent_message = await self._callback_handler.get_message()
if agent_message:
# 发送消息
await self.send_tool_message(agent_message, title="MoviePilot助手")
# 发送执行工具说明
# 优先使用工具自定义的提示消息,如果没有则使用 explanation
await self.memory_manager.add_memory(
session_id=self._session_id,
user_id=self._user_id,
role="agent",
content=agent_message
)
# 记忆工具调用
await self.memory_manager.add_memory(
session_id=self.session_id,
user_id=self.user_id,
role="tool_call",
metadata={
"call_id": self.__class__.__name__,
"tool_name": self.__class__.__name__,
"parameters": kwargs
}
)
# 发送执行工具说明,优先使用工具自定义的提示消息,如果没有则使用 explanation
tool_message = self.get_tool_message(**kwargs)
if not tool_message:
explanation = kwargs.get("explanation")
if explanation:
tool_message = explanation
if tool_message:
formatted_message = f"⚙️ => {tool_message}"
await self.send_tool_message(formatted_message)
logger.debug(f'Executing tool {self.name} with args: {kwargs}')
result = await self.run(**kwargs)
logger.debug(f'Tool {self.name} executed with result: {result}')
# 记忆工具调用结果
if isinstance(result, str):
formated_result = result
elif isinstance(result, int, float):
formated_result = str(result)
else:
formated_result = json.dumps(result, ensure_ascii=False, indent=2)
await self.memory_manager.add_memory(
session_id=self.session_id,
user_id=self.user_id,
role="tool_result",
content=formated_result
)
return result
def get_tool_message(self, **kwargs) -> Optional[str]:
@@ -84,6 +120,10 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
"""设置回调处理器"""
self._callback_handler = callback_handler
def set_memory_manager(self, memory_manager: ConversationMemoryManager):
"""设置记忆客理器"""
self._memory_manager = memory_manager
async def send_tool_message(self, message: str, title: str = ""):
"""发送工具消息"""
await ToolChain().async_post_message(

View File

@@ -51,7 +51,7 @@ class MoviePilotToolFactory:
@staticmethod
def create_tools(session_id: str, user_id: str,
channel: str = None, source: str = None, username: str = None,
callback_handler: Callable = None) -> List[MoviePilotTool]:
callback_handler: Callable = None, memory_mananger: Callable = None) -> List[MoviePilotTool]:
"""创建MoviePilot工具列表"""
tools = []
tool_definitions = [
@@ -102,6 +102,7 @@ class MoviePilotToolFactory:
)
tool.set_message_attr(channel=channel, source=source, username=username)
tool.set_callback_handler(callback_handler=callback_handler)
tool.set_memory_manager(memory_manager=memory_mananger)
tools.append(tool)
# 加载插件提供的工具
@@ -124,6 +125,7 @@ class MoviePilotToolFactory:
)
tool.set_message_attr(channel=channel, source=source, username=username)
tool.set_callback_handler(callback_handler=callback_handler)
tool.set_memory_manager(memory_manager=memory_mananger)
tools.append(tool)
plugin_tools_count += 1
logger.debug(f"成功加载插件 {plugin_name}({plugin_id}) 的工具: {ToolClass.__name__}")

View File

@@ -3,8 +3,10 @@
"""
import json
import uuid
from typing import Any, Dict, List, Optional
from app.agent import ConversationMemoryManager
from app.agent.tools.factory import MoviePilotToolFactory
from app.log import logger
@@ -21,7 +23,7 @@ class ToolDefinition:
class MoviePilotToolsManager:
"""MoviePilot工具管理器用于HTTP API"""
def __init__(self, user_id: str = "api_user", session_id: str = "api_session"):
def __init__(self, user_id: str = "api_user", session_id: str = uuid.uuid4()):
"""
初始化工具管理器
@@ -32,6 +34,7 @@ class MoviePilotToolsManager:
self.user_id = user_id
self.session_id = session_id
self.tools: List[Any] = []
self.memory_manager = ConversationMemoryManager()
self._load_tools()
def _load_tools(self):
@@ -44,7 +47,8 @@ class MoviePilotToolsManager:
channel=None,
source="api",
username="API Client",
callback_handler=None
callback_handler=None,
memory_mananger=None,
)
logger.info(f"成功加载 {len(self.tools)} 个工具")
except Exception as e:
@@ -121,9 +125,13 @@ class MoviePilotToolsManager:
# 确保返回字符串
if isinstance(result, str):
return result
formated_result = result
elif isinstance(result, int, float):
formated_result = str(result)
else:
return json.dumps(result, ensure_ascii=False, indent=2)
formated_result = json.dumps(result, ensure_ascii=False, indent=2)
return formated_result
except Exception as e:
logger.error(f"调用工具 {tool_name} 时发生错误: {e}", exc_info=True)
error_msg = json.dumps({