fix agent memory

This commit is contained in:
jxxghp
2025-12-13 12:26:08 +08:00
parent 140b0d3df2
commit 81f30ef25a
5 changed files with 94 additions and 48 deletions

View File

@@ -1,11 +1,12 @@
"""MoviePilot工具基类"""
import json
from abc import ABCMeta, abstractmethod
from typing import Callable, Any, Optional
from typing import Any, Optional
from langchain.tools import BaseTool
from pydantic import PrivateAttr
from app.agent import StreamingCallbackHandler
from app.agent import StreamingCallbackHandler, ConversationMemoryManager
from app.chain import ChainBase
from app.log import logger
from app.schemas import Notification
@@ -24,6 +25,7 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
_source: str = PrivateAttr(default=None)
_username: str = PrivateAttr(default=None)
_callback_handler: StreamingCallbackHandler = PrivateAttr(default=None)
_memory_manager: ConversationMemoryManager = PrivateAttr(default=None)
def __init__(self, session_id: str, user_id: str, **kwargs):
super().__init__(**kwargs)
@@ -35,24 +37,58 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
async def _arun(self, **kwargs) -> str:
"""异步运行工具"""
# 发送运行工具前的
# 发送和记忆工具调用前的
agent_message = await self._callback_handler.get_message()
if agent_message:
# 发送消息
await self.send_tool_message(agent_message, title="MoviePilot助手")
# 发送执行工具说明
# 优先使用工具自定义的提示消息,如果没有则使用 explanation
await self.memory_manager.add_memory(
session_id=self._session_id,
user_id=self._user_id,
role="agent",
content=agent_message
)
# 记忆工具调用
await self.memory_manager.add_memory(
session_id=self.session_id,
user_id=self.user_id,
role="tool_call",
metadata={
"call_id": self.__class__.__name__,
"tool_name": self.__class__.__name__,
"parameters": kwargs
}
)
# 发送执行工具说明,优先使用工具自定义的提示消息,如果没有则使用 explanation
tool_message = self.get_tool_message(**kwargs)
if not tool_message:
explanation = kwargs.get("explanation")
if explanation:
tool_message = explanation
if tool_message:
formatted_message = f"⚙️ => {tool_message}"
await self.send_tool_message(formatted_message)
logger.debug(f'Executing tool {self.name} with args: {kwargs}')
result = await self.run(**kwargs)
logger.debug(f'Tool {self.name} executed with result: {result}')
# 记忆工具调用结果
if isinstance(result, str):
formated_result = result
elif isinstance(result, int, float):
formated_result = str(result)
else:
formated_result = json.dumps(result, ensure_ascii=False, indent=2)
await self.memory_manager.add_memory(
session_id=self.session_id,
user_id=self.user_id,
role="tool_result",
content=formated_result
)
return result
def get_tool_message(self, **kwargs) -> Optional[str]:
@@ -84,6 +120,10 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
"""设置回调处理器"""
self._callback_handler = callback_handler
def set_memory_manager(self, memory_manager: ConversationMemoryManager):
"""设置记忆客理器"""
self._memory_manager = memory_manager
async def send_tool_message(self, message: str, title: str = ""):
"""发送工具消息"""
await ToolChain().async_post_message(

View File

@@ -51,7 +51,7 @@ class MoviePilotToolFactory:
@staticmethod
def create_tools(session_id: str, user_id: str,
channel: str = None, source: str = None, username: str = None,
callback_handler: Callable = None) -> List[MoviePilotTool]:
callback_handler: Callable = None, memory_mananger: Callable = None) -> List[MoviePilotTool]:
"""创建MoviePilot工具列表"""
tools = []
tool_definitions = [
@@ -102,6 +102,7 @@ class MoviePilotToolFactory:
)
tool.set_message_attr(channel=channel, source=source, username=username)
tool.set_callback_handler(callback_handler=callback_handler)
tool.set_memory_manager(memory_manager=memory_mananger)
tools.append(tool)
# 加载插件提供的工具
@@ -124,6 +125,7 @@ class MoviePilotToolFactory:
)
tool.set_message_attr(channel=channel, source=source, username=username)
tool.set_callback_handler(callback_handler=callback_handler)
tool.set_memory_manager(memory_manager=memory_mananger)
tools.append(tool)
plugin_tools_count += 1
logger.debug(f"成功加载插件 {plugin_name}({plugin_id}) 的工具: {ToolClass.__name__}")

View File

@@ -3,8 +3,10 @@
"""
import json
import uuid
from typing import Any, Dict, List, Optional
from app.agent import ConversationMemoryManager
from app.agent.tools.factory import MoviePilotToolFactory
from app.log import logger
@@ -21,7 +23,7 @@ class ToolDefinition:
class MoviePilotToolsManager:
"""MoviePilot工具管理器用于HTTP API"""
def __init__(self, user_id: str = "api_user", session_id: str = "api_session"):
def __init__(self, user_id: str = "api_user", session_id: str = uuid.uuid4()):
"""
初始化工具管理器
@@ -32,6 +34,7 @@ class MoviePilotToolsManager:
self.user_id = user_id
self.session_id = session_id
self.tools: List[Any] = []
self.memory_manager = ConversationMemoryManager()
self._load_tools()
def _load_tools(self):
@@ -44,7 +47,8 @@ class MoviePilotToolsManager:
channel=None,
source="api",
username="API Client",
callback_handler=None
callback_handler=None,
memory_mananger=None,
)
logger.info(f"成功加载 {len(self.tools)} 个工具")
except Exception as e:
@@ -121,9 +125,13 @@ class MoviePilotToolsManager:
# 确保返回字符串
if isinstance(result, str):
return result
formated_result = result
elif isinstance(result, int, float):
formated_result = str(result)
else:
return json.dumps(result, ensure_ascii=False, indent=2)
formated_result = json.dumps(result, ensure_ascii=False, indent=2)
return formated_result
except Exception as e:
logger.error(f"调用工具 {tool_name} 时发生错误: {e}", exc_info=True)
error_msg = json.dumps({