mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-04-01 01:41:59 +08:00
155 lines
4.9 KiB
Python
155 lines
4.9 KiB
Python
"""对话记忆管理器"""
|
||
|
||
import asyncio
|
||
from datetime import datetime
|
||
from typing import Dict, List, Optional
|
||
|
||
from langchain_core.messages import BaseMessage
|
||
|
||
from app.core.config import settings
|
||
from app.log import logger
|
||
from app.schemas.agent import ConversationMemory
|
||
|
||
|
||
class MemoryManager:
|
||
"""
|
||
对话记忆管理器
|
||
"""
|
||
|
||
def __init__(self):
|
||
# 内存中的会话记忆缓存
|
||
self.memory_cache: Dict[str, ConversationMemory] = {}
|
||
# 内存缓存清理任务
|
||
self.cleanup_task: Optional[asyncio.Task] = None
|
||
|
||
def initialize(self):
|
||
"""
|
||
初始化记忆管理器
|
||
"""
|
||
try:
|
||
# 启动内存缓存清理任务(Redis通过TTL自动过期)
|
||
self.cleanup_task = asyncio.create_task(
|
||
self._cleanup_expired_memories()
|
||
)
|
||
logger.info("对话记忆管理器初始化完成")
|
||
|
||
except Exception as e:
|
||
logger.warning(f"Redis连接失败,将使用内存存储: {e}")
|
||
|
||
async def close(self):
|
||
"""
|
||
关闭记忆管理器
|
||
"""
|
||
if self.cleanup_task:
|
||
self.cleanup_task.cancel()
|
||
try:
|
||
await self.cleanup_task
|
||
except asyncio.CancelledError:
|
||
pass
|
||
|
||
logger.info("对话记忆管理器已关闭")
|
||
|
||
@staticmethod
|
||
def _get_memory_key(session_id: str, user_id: str):
|
||
"""
|
||
计算内存Key
|
||
"""
|
||
return f"{user_id}:{session_id}" if user_id else session_id
|
||
|
||
def get_memory(self, session_id: str, user_id: str) -> Optional[ConversationMemory]:
|
||
"""
|
||
获取内存中的记忆
|
||
"""
|
||
cache_key = self._get_memory_key(session_id, user_id)
|
||
return self.memory_cache.get(cache_key)
|
||
|
||
def get_agent_messages(
|
||
self, session_id: str, user_id: str
|
||
) -> List[BaseMessage]:
|
||
"""
|
||
为Agent获取最近的消息(仅内存缓存)
|
||
|
||
如果消息Token数量超过模型最大上下文长度的阀值,会自动进行摘要裁剪
|
||
"""
|
||
memory = self.get_memory(session_id, user_id)
|
||
if not memory:
|
||
return []
|
||
|
||
# 获取所有消息
|
||
return memory.messages
|
||
|
||
def save_agent_messages(
|
||
self, session_id: str, user_id: str, messages: List[BaseMessage]
|
||
):
|
||
"""
|
||
保存Agent消息(仅内存缓存)
|
||
|
||
注意:Redis中的记忆通过TTL机制自动过期,这里只更新内存缓存,Redis会在下次访问时自动过期
|
||
"""
|
||
memory = self.get_memory(session_id, user_id)
|
||
if not memory:
|
||
memory = ConversationMemory(session_id=session_id, user_id=user_id)
|
||
|
||
memory.messages = messages
|
||
memory.updated_at = datetime.now()
|
||
|
||
# 更新内存缓存
|
||
self.save_memory(memory)
|
||
|
||
def save_memory(self, memory: ConversationMemory):
|
||
"""
|
||
保存记忆到内存缓存
|
||
|
||
注意:Redis中的记忆通过TTL机制自动过期,这里只更新内存缓存,Redis会在下次访问时自动过期
|
||
"""
|
||
cache_key = self._get_memory_key(memory.session_id, memory.user_id)
|
||
self.memory_cache[cache_key] = memory
|
||
|
||
def clear_memory(self, session_id: str, user_id: str):
|
||
"""
|
||
清空会话记忆
|
||
"""
|
||
cache_key = self._get_memory_key(session_id, user_id)
|
||
if cache_key in self.memory_cache:
|
||
del self.memory_cache[cache_key]
|
||
|
||
logger.info(f"会话记忆已清空: session_id={session_id}, user_id={user_id}")
|
||
|
||
async def _cleanup_expired_memories(self):
|
||
"""
|
||
清理内存中过期记忆的后台任务
|
||
|
||
注意:Redis中的记忆通过TTL机制自动过期,这里只清理内存缓存
|
||
"""
|
||
while True:
|
||
try:
|
||
# 每小时清理一次
|
||
await asyncio.sleep(3600)
|
||
|
||
current_time = datetime.now()
|
||
expired_sessions = []
|
||
|
||
# 只检查内存缓存中的过期记忆
|
||
# Redis中的记忆会通过TTL自动过期,无需手动处理
|
||
for cache_key, memory in self.memory_cache.items():
|
||
if (
|
||
current_time - memory.updated_at
|
||
).days > settings.LLM_MEMORY_RETENTION_DAYS:
|
||
expired_sessions.append(cache_key)
|
||
|
||
# 只清理内存缓存,不删除Redis中的键(Redis会自动过期)
|
||
for cache_key in expired_sessions:
|
||
if cache_key in self.memory_cache:
|
||
del self.memory_cache[cache_key]
|
||
|
||
if expired_sessions:
|
||
logger.info(f"清理了{len(expired_sessions)}个过期内存会话记忆")
|
||
|
||
except asyncio.CancelledError:
|
||
break
|
||
except Exception as e:
|
||
logger.error(f"清理记忆时发生错误: {e}")
|
||
|
||
|
||
memory_manager = MemoryManager()
|