# Mirror of https://github.com/jxxghp/MoviePilot.git
# Synced 2026-02-03 02:25:32 +08:00
# 348 lines, 12 KiB, Python
"""对话记忆管理器"""
|
||
|
||
import asyncio
|
||
import json
|
||
from datetime import datetime, timedelta
|
||
from typing import Dict, List, Optional, Any
|
||
|
||
from app.core.config import settings
|
||
from app.helper.redis import AsyncRedisHelper
|
||
from app.log import logger
|
||
from app.schemas.agent import ConversationMemory
|
||
|
||
|
||
class ConversationMemoryManager:
    """
    Conversation memory manager.

    Keeps ``ConversationMemory`` objects in an in-process dict and, when
    ``settings.CACHE_BACKEND_TYPE == "redis"``, also writes them to Redis
    (region ``AI_AGENT``) with a TTL so that Redis entries expire on their own.
    """

    # Redis region under which every key of this manager is stored
    _REDIS_REGION = "AI_AGENT"

    def __init__(self):
        # In-process cache of conversation memories, keyed by _get_memory_key()
        self.memory_cache: Dict[str, "ConversationMemory"] = {}
        # Reuse the project's async Redis helper
        self.redis_helper = AsyncRedisHelper()
        # Background task evicting stale entries from memory_cache
        # (Redis entries expire through their TTL instead)
        self.cleanup_task: Optional[asyncio.Task] = None

    async def initialize(self):
        """
        Start the background task that evicts expired in-memory conversations.

        Calling this again while the task is still running does nothing, so a
        second call cannot orphan the first task.
        """
        if self.cleanup_task and not self.cleanup_task.done():
            return
        try:
            self.cleanup_task = asyncio.create_task(self._cleanup_expired_memories())
            logger.info("对话记忆管理器初始化完成")
        except Exception as e:
            # NOTE(review): nothing in this block talks to Redis; the message
            # text is kept unchanged for log compatibility.
            logger.warning(f"Redis连接失败,将使用内存存储: {e}")

    async def close(self):
        """
        Stop the cleanup task (if any) and close the Redis helper.
        """
        task, self.cleanup_task = self.cleanup_task, None
        if task:
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass

        await self.redis_helper.close()

        logger.info("对话记忆管理器已关闭")

    @staticmethod
    def _get_memory_key(session_id: str, user_id: Optional[str]) -> str:
        """
        Build the in-memory cache key: ``user_id:session_id``, or just
        ``session_id`` when ``user_id`` is empty.
        """
        return f"{user_id}:{session_id}" if user_id else session_id

    @staticmethod
    def _get_redis_key(session_id: str, user_id: Optional[str]) -> str:
        """
        Build the Redis key: ``agent_memory:user_id:session_id``, or
        ``agent_memory:session_id`` when ``user_id`` is empty.
        """
        return f"agent_memory:{user_id}:{session_id}" if user_id else f"agent_memory:{session_id}"

    @staticmethod
    def _parse_memory(raw: Any) -> "ConversationMemory":
        """
        Build a ``ConversationMemory`` from a Redis payload.

        A ``dict`` is used as-is; anything else is decoded with ``json.loads``.
        Raises whatever ``json.loads`` or the model constructor raises.
        """
        data = raw if isinstance(raw, dict) else json.loads(raw)
        return ConversationMemory(**data)

    @staticmethod
    def _trim_messages(messages: List[Dict[str, Any]], max_messages: int) -> List[Dict[str, Any]]:
        """
        Return ``messages`` reduced to at most ``max_messages`` entries.

        Every message whose role is ``"system"`` is kept and moved to the
        front; the remaining budget is filled with the most recent non-system
        messages. If the system messages alone exceed the budget they are all
        still kept and no other message is retained.
        """
        if len(messages) <= max_messages:
            return list(messages)
        system_messages = [msg for msg in messages if msg.get("role") == "system"]
        other_messages = [msg for msg in messages if msg.get("role") != "system"]
        budget = max(max_messages - len(system_messages), 0)
        # A slice of [-0:] would return the whole list, so guard the zero case
        recent_messages = other_messages[-budget:] if budget else []
        return system_messages + recent_messages

    def _get_memory(self, session_id: str, user_id: str) -> Optional["ConversationMemory"]:
        """
        Return the cached memory for the session, or ``None`` if not cached.
        """
        cache_key = self._get_memory_key(session_id, user_id)
        return self.memory_cache.get(cache_key)

    async def _get_redis(self, session_id: str, user_id: str) -> Optional["ConversationMemory"]:
        """
        Load the memory from Redis.

        Returns ``None`` when the Redis backend is not enabled, the key is
        missing, or loading fails (failures are logged as warnings).
        """
        if settings.CACHE_BACKEND_TYPE != "redis":
            return None
        try:
            redis_key = self._get_redis_key(session_id, user_id)
            memory_data = await self.redis_helper.get(redis_key, region=self._REDIS_REGION)
            if memory_data:
                return self._parse_memory(memory_data)
        except Exception as e:
            logger.warning(f"从Redis加载记忆失败: {e}")
        return None

    async def get_conversation(self, session_id: str, user_id: str) -> "ConversationMemory":
        """
        Return the conversation memory for the session, creating it if needed.

        Lookup order: in-memory cache, then Redis, then a freshly created (and
        saved) empty memory.
        """
        # Check the in-memory cache first
        cached = self._get_memory(session_id, user_id)
        if cached:
            return cached

        # Then try Redis and, on a hit, populate the in-memory cache
        memory = await self._get_redis(session_id, user_id)
        if memory:
            self._save_memory(memory)
            return memory

        # Nothing stored yet: create and persist a new memory
        memory = ConversationMemory(session_id=session_id, user_id=user_id)
        await self._save_conversation(memory)

        return memory

    async def set_title(self, session_id: str, user_id: str, title: str):
        """
        Set the session title, bump ``updated_at`` and persist the memory.
        """
        memory = await self.get_conversation(session_id=session_id, user_id=user_id)
        memory.title = title
        memory.updated_at = datetime.now()
        await self._save_conversation(memory)

    async def get_title(self, session_id: str, user_id: str) -> Optional[str]:
        """
        Return the session title.
        """
        memory = await self.get_conversation(session_id=session_id, user_id=user_id)
        return memory.title

    async def list_sessions(self, user_id: str, limit: int = 100) -> List[Dict[str, Any]]:
        """
        List session summaries, most recently updated first.

        - With the Redis backend: every ``agent_memory:*`` entry in the region
          is read, then merged with the in-memory cache.
        - Otherwise: only the in-memory cache is used.

        When ``user_id`` is non-empty only that user's sessions are returned.
        The owner is taken from the stored memory itself rather than parsed out
        of the Redis key, so both key layouts (with and without a user id) are
        handled.

        :param user_id: user to filter on; empty means all users
        :param limit: maximum number of summaries returned
        :return: dicts with ``session_id``, ``title``, ``message_count``,
                 ``created_at`` and ``updated_at`` (ISO strings)
        """
        sessions: List["ConversationMemory"] = []

        if settings.CACHE_BACKEND_TYPE == "redis":
            try:
                async for key, value in self.redis_helper.items(region=self._REDIS_REGION):
                    if not key.startswith("agent_memory:"):
                        continue
                    try:
                        memory = self._parse_memory(value)
                    except Exception as err:
                        logger.warning(f"解析Redis记忆数据失败: {err}")
                        continue
                    if not user_id or memory.user_id == user_id:
                        sessions.append(memory)
            except Exception as e:
                logger.warning(f"遍历Redis会话失败: {e}")

        # Merge the in-memory cache so that recent sessions are included
        sessions.extend(
            memory for memory in self.memory_cache.values()
            if not user_id or memory.user_id == user_id
        )

        # De-duplicate per (user, session), keeping the most recently updated copy
        uniq: Dict[str, "ConversationMemory"] = {}
        for mem in sessions:
            ident = self._get_memory_key(mem.session_id, mem.user_id)
            existed = uniq.get(ident)
            if existed is None or mem.updated_at > existed.updated_at:
                uniq[ident] = mem

        # Sort newest first and cut to the requested size
        newest_first = sorted(uniq.values(), key=lambda m: m.updated_at, reverse=True)
        return [
            {
                "session_id": m.session_id,
                "title": m.title or "新会话",
                "message_count": len(m.messages),
                "created_at": m.created_at.isoformat(),
                "updated_at": m.updated_at.isoformat(),
            }
            for m in newest_first[:max(limit, 0)]
        ]

    async def add_conversation(
            self,
            session_id: str,
            user_id: str,
            role: str,
            content: str,
            metadata: Optional[Dict[str, Any]] = None
    ):
        """
        Append a message to the session memory and persist it.

        When the number of messages exceeds
        ``settings.LLM_MAX_MEMORY_MESSAGES`` the history is trimmed: all system
        messages are kept, followed by the most recent other messages.
        """
        memory = await self.get_conversation(session_id=session_id, user_id=user_id)

        now = datetime.now()
        memory.messages.append({
            "role": role,
            "content": content,
            "timestamp": now.isoformat(),
            "metadata": metadata or {}
        })
        memory.updated_at = now

        # Cap the history so the memory does not grow without bound
        max_messages = settings.LLM_MAX_MEMORY_MESSAGES
        if len(memory.messages) > max_messages:
            memory.messages = self._trim_messages(memory.messages, max_messages)

        await self._save_conversation(memory)

        logger.debug(f"消息已添加到记忆: session_id={session_id}, user_id={user_id}, role={role}")

    def get_recent_messages_for_agent(
            self,
            session_id: str,
            user_id: str
    ) -> List[Dict[str, Any]]:
        """
        Return the cached messages of the session for the agent.

        Only the in-memory cache is consulted (no Redis lookup). The last
        stored message is left out of the result; an unknown session yields an
        empty list.
        NOTE(review): presumably the last message is the current user input
        that the caller passes to the agent separately - confirm with callers.
        """
        memory = self._get_memory(session_id, user_id)
        if not memory:
            return []

        # Everything except the most recent message
        return memory.messages[:-1]

    async def get_recent_messages(
            self,
            session_id: str,
            user_id: str,
            limit: int = 10,
            role_filter: Optional[list] = None
    ) -> List[Dict[str, Any]]:
        """
        Return the last ``limit`` messages of the session.

        :param limit: maximum number of messages; a non-positive value yields
                      an empty list
        :param role_filter: when given, only messages whose role is in this
                            list are considered
        """
        memory = await self.get_conversation(session_id=session_id, user_id=user_id)

        messages = memory.messages
        if role_filter:
            messages = [msg for msg in messages if msg["role"] in role_filter]

        # messages[-0:] would be the whole list, so handle limit <= 0 explicitly
        if limit <= 0 or not messages:
            return []
        return messages[-limit:]

    async def get_context(self, session_id: str, user_id: str) -> Dict[str, Any]:
        """
        Return the context stored on the session memory.
        """
        memory = await self.get_conversation(session_id=session_id, user_id=user_id)
        return memory.context

    async def clear_memory(self, session_id: str, user_id: str):
        """
        Remove the session memory from the in-memory cache and, when the Redis
        backend is enabled, from Redis as well.
        """
        cache_key = self._get_memory_key(session_id, user_id)
        self.memory_cache.pop(cache_key, None)

        if settings.CACHE_BACKEND_TYPE == "redis":
            redis_key = self._get_redis_key(session_id, user_id)
            await self.redis_helper.delete(redis_key, region=self._REDIS_REGION)

        logger.info(f"会话记忆已清空: session_id={session_id}, user_id={user_id}")

    def _save_memory(self, memory: "ConversationMemory"):
        """
        Store the memory in the in-memory cache.
        """
        cache_key = self._get_memory_key(memory.session_id, memory.user_id)
        self.memory_cache[cache_key] = memory

    async def _save_redis(self, memory: "ConversationMemory"):
        """
        Store the memory in Redis with a TTL of
        ``settings.LLM_REDIS_MEMORY_RETENTION_DAYS`` days.

        Does nothing when the Redis backend is not enabled; failures are
        logged as warnings and not raised.
        """
        if settings.CACHE_BACKEND_TYPE != "redis":
            return
        try:
            memory_dict = memory.model_dump()
            redis_key = self._get_redis_key(memory.session_id, memory.user_id)
            ttl = int(timedelta(days=settings.LLM_REDIS_MEMORY_RETENTION_DAYS).total_seconds())
            await self.redis_helper.set(
                redis_key,
                memory_dict,
                ttl=ttl,
                region=self._REDIS_REGION
            )
        except Exception as e:
            logger.warning(f"保存记忆到Redis失败: {e}")

    async def _save_conversation(self, memory: "ConversationMemory"):
        """
        Persist the memory: in-memory cache first, then Redis.

        Redis entries expire through their TTL, so no manual cleanup is needed
        for them.
        """
        self._save_memory(memory)
        await self._save_redis(memory)

    async def _cleanup_expired_memories(self):
        """
        Background loop that evicts expired entries from the in-memory cache.

        Runs once per hour until cancelled. An entry is expired when it has not
        been updated for more than ``settings.LLM_MEMORY_RETENTION_DAYS`` days.
        Redis keys are not touched here; they expire through their TTL.
        """
        while True:
            try:
                # One pass per hour
                await asyncio.sleep(3600)

                current_time = datetime.now()
                expired_keys = [
                    cache_key
                    for cache_key, memory in self.memory_cache.items()
                    if (current_time - memory.updated_at).days > settings.LLM_MEMORY_RETENTION_DAYS
                ]

                # Only the in-memory copies are dropped
                for cache_key in expired_keys:
                    self.memory_cache.pop(cache_key, None)

                if expired_keys:
                    logger.info(f"清理了{len(expired_keys)}个过期内存会话记忆")

            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"清理记忆时发生错误: {e}")
|
||
|
||
# Module-level instance of the conversation memory manager, created at import time
conversation_manager = ConversationMemoryManager()
|