"""对话记忆管理器""" import asyncio from datetime import datetime from typing import Dict, List, Optional from langchain_core.messages import BaseMessage from app.core.config import settings from app.log import logger from app.schemas.agent import ConversationMemory class MemoryManager: """ 对话记忆管理器 """ def __init__(self): # 内存中的会话记忆缓存 self.memory_cache: Dict[str, ConversationMemory] = {} # 内存缓存清理任务 self.cleanup_task: Optional[asyncio.Task] = None def initialize(self): """ 初始化记忆管理器 """ try: # 启动内存缓存清理任务(Redis通过TTL自动过期) self.cleanup_task = asyncio.create_task( self._cleanup_expired_memories() ) logger.info("对话记忆管理器初始化完成") except Exception as e: logger.warning(f"Redis连接失败,将使用内存存储: {e}") async def close(self): """ 关闭记忆管理器 """ if self.cleanup_task: self.cleanup_task.cancel() try: await self.cleanup_task except asyncio.CancelledError: pass logger.info("对话记忆管理器已关闭") @staticmethod def _get_memory_key(session_id: str, user_id: str): """ 计算内存Key """ return f"{user_id}:{session_id}" if user_id else session_id def get_memory(self, session_id: str, user_id: str) -> Optional[ConversationMemory]: """ 获取内存中的记忆 """ cache_key = self._get_memory_key(session_id, user_id) return self.memory_cache.get(cache_key) def get_agent_messages( self, session_id: str, user_id: str ) -> List[BaseMessage]: """ 为Agent获取最近的消息(仅内存缓存) 如果消息Token数量超过模型最大上下文长度的阀值,会自动进行摘要裁剪 """ memory = self.get_memory(session_id, user_id) if not memory: return [] # 获取所有消息 return memory.messages def save_agent_messages( self, session_id: str, user_id: str, messages: List[BaseMessage] ): """ 保存Agent消息(仅内存缓存) 注意:Redis中的记忆通过TTL机制自动过期,这里只更新内存缓存,Redis会在下次访问时自动过期 """ memory = self.get_memory(session_id, user_id) if not memory: memory = ConversationMemory(session_id=session_id, user_id=user_id) memory.messages = messages memory.updated_at = datetime.now() # 更新内存缓存 self.save_memory(memory) def save_memory(self, memory: ConversationMemory): """ 保存记忆到内存缓存 注意:Redis中的记忆通过TTL机制自动过期,这里只更新内存缓存,Redis会在下次访问时自动过期 """ cache_key = self._get_memory_key(memory.session_id, memory.user_id) self.memory_cache[cache_key] = memory def clear_memory(self, session_id: str, user_id: str): """ 清空会话记忆 """ cache_key = self._get_memory_key(session_id, user_id) if cache_key in self.memory_cache: del self.memory_cache[cache_key] logger.info(f"会话记忆已清空: session_id={session_id}, user_id={user_id}") async def _cleanup_expired_memories(self): """ 清理内存中过期记忆的后台任务 注意:Redis中的记忆通过TTL机制自动过期,这里只清理内存缓存 """ while True: try: # 每小时清理一次 await asyncio.sleep(3600) current_time = datetime.now() expired_sessions = [] # 只检查内存缓存中的过期记忆 # Redis中的记忆会通过TTL自动过期,无需手动处理 for cache_key, memory in self.memory_cache.items(): if ( current_time - memory.updated_at ).days > settings.LLM_MEMORY_RETENTION_DAYS: expired_sessions.append(cache_key) # 只清理内存缓存,不删除Redis中的键(Redis会自动过期) for cache_key in expired_sessions: if cache_key in self.memory_cache: del self.memory_cache[cache_key] if expired_sessions: logger.info(f"清理了{len(expired_sessions)}个过期内存会话记忆") except asyncio.CancelledError: break except Exception as e: logger.error(f"清理记忆时发生错误: {e}") memory_manager = MemoryManager()