From 1641d432ddadf8d19eea3eb92c21b063bfbea919 Mon Sep 17 00:00:00 2001 From: jxxghp Date: Thu, 15 Jan 2026 20:55:35 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E4=B8=BA=E5=B7=A5=E5=85=B7=E7=AE=A1?= =?UTF-8?q?=E7=90=86=E5=99=A8=E6=B7=BB=E5=8A=A0=E5=8F=82=E6=95=B0=E7=B1=BB?= =?UTF-8?q?=E5=9E=8B=E8=A7=84=E8=8C=83=E5=8C=96=E5=A4=84=E7=90=86=EF=BC=8C?= =?UTF-8?q?=E5=B9=B6=E5=9F=BA=E4=BA=8E=E6=B8=A0=E9=81=93=E8=83=BD=E5=8A=9B?= =?UTF-8?q?=E5=8A=A8=E6=80=81=E7=94=9F=E6=88=90=E6=8F=90=E7=A4=BA=E8=AF=8D?= =?UTF-8?q?=E4=B8=AD=E7=9A=84=E6=A0=BC=E5=BC=8F=E8=A6=81=E6=B1=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/agent/__init__.py | 98 +++++++++++-------- app/agent/callback/__init__.py | 12 ++- app/agent/memory/__init__.py | 151 +++++++++++++++++++++--------- app/agent/prompt/Agent Prompt.txt | 67 ++++++------- app/agent/prompt/__init__.py | 106 ++++++++------------- app/agent/tools/base.py | 32 ++++--- app/agent/tools/factory.py | 14 +-- app/agent/tools/manager.py | 36 +++---- app/api/endpoints/mcp.py | 68 +++++--------- app/startup/agent_initializer.py | 22 ++--- 10 files changed, 324 insertions(+), 282 deletions(-) diff --git a/app/agent/__init__.py b/app/agent/__init__.py index 2a510f64..9965ccf9 100644 --- a/app/agent/__init__.py +++ b/app/agent/__init__.py @@ -1,5 +1,3 @@ -"""MoviePilot AI智能体实现""" - import asyncio from typing import Dict, List, Any @@ -11,8 +9,8 @@ from langchain_core.messages import HumanMessage, AIMessage, ToolCall, ToolMessa from langchain_core.runnables.history import RunnableWithMessageHistory from app.agent.callback import StreamingCallbackHandler -from app.agent.memory import ConversationMemoryManager -from app.agent.prompt import PromptManager +from app.agent.memory import conversation_manager +from app.agent.prompt import prompt_manager from app.agent.tools.factory import MoviePilotToolFactory from app.chain import ChainBase from app.core.config import settings @@ -27,7 +25,9 @@ class AgentChain(ChainBase): class MoviePilotAgent: - """MoviePilot AI智能体""" + """ + MoviePilot AI智能体 + """ def __init__(self, session_id: str, user_id: str = None, channel: str = None, source: str = None, username: str = None): @@ -40,12 +40,6 @@ class MoviePilotAgent: # 消息助手 self.message_helper = MessageHelper() - # 记忆管理器 - self.memory_manager = ConversationMemoryManager() - - # 提示词管理器 - self.prompt_manager = PromptManager() - # 回调处理器 self.callback_handler = StreamingCallbackHandler( session_id=session_id @@ -64,30 +58,37 @@ class MoviePilotAgent: self.agent_executor = self._create_agent_executor() def _initialize_llm(self): - """初始化LLM模型""" + """ + 初始化LLM模型 + """ return LLMHelper.get_llm(streaming=True, callbacks=[self.callback_handler]) def _initialize_tools(self) -> List: - """初始化工具列表""" + """ + 初始化工具列表 + """ return MoviePilotToolFactory.create_tools( session_id=self.session_id, user_id=self.user_id, channel=self.channel, source=self.source, username=self.username, - callback_handler=self.callback_handler, - memory_mananger=self.memory_manager + callback_handler=self.callback_handler ) @staticmethod def _initialize_session_store() -> Dict[str, InMemoryChatMessageHistory]: - """初始化内存存储""" + """ + 初始化内存存储 + """ return {} def get_session_history(self, session_id: str) -> InMemoryChatMessageHistory: - """获取会话历史""" + """ + 获取会话历史 + """ chat_history = InMemoryChatMessageHistory() - messages: List[dict] = self.memory_manager.get_recent_messages_for_agent( + messages: List[dict] = conversation_manager.get_recent_messages_for_agent( session_id=session_id, user_id=self.user_id ) @@ -119,7 +120,9 @@ class MoviePilotAgent: @staticmethod def _initialize_prompt() -> ChatPromptTemplate: - """初始化提示词模板""" + """ + 初始化提示词模板 + """ try: prompt_template = ChatPromptTemplate.from_messages([ ("system", "{system_prompt}"), @@ -134,7 +137,9 @@ class MoviePilotAgent: raise e def _create_agent_executor(self) -> RunnableWithMessageHistory: - """创建Agent执行器""" + """ + 创建Agent执行器 + """ try: agent = create_openai_tools_agent( llm=self.llm, @@ -161,10 +166,12 @@ class MoviePilotAgent: raise e async def process_message(self, message: str) -> str: - """处理用户消息""" + """ + 处理用户消息 + """ try: # 添加用户消息到记忆 - await self.memory_manager.add_memory( + await conversation_manager.add_conversation( self.session_id, user_id=self.user_id, role="user", @@ -173,7 +180,7 @@ class MoviePilotAgent: # 构建输入上下文 input_context = { - "system_prompt": self.prompt_manager.get_agent_prompt(channel=self.channel), + "system_prompt": prompt_manager.get_agent_prompt(channel=self.channel), "input": message } @@ -190,7 +197,7 @@ class MoviePilotAgent: await self.send_agent_message(agent_message) # 添加Agent回复到记忆 - await self.memory_manager.add_memory( + await conversation_manager.add_conversation( session_id=self.session_id, user_id=self.user_id, role="agent", @@ -210,7 +217,9 @@ class MoviePilotAgent: return error_message async def _execute_agent(self, input_context: Dict[str, Any]) -> Dict[str, Any]: - """执行LangChain Agent""" + """ + 执行LangChain Agent + """ try: with get_openai_callback() as cb: result = await self.agent_executor.ainvoke( @@ -243,7 +252,9 @@ class MoviePilotAgent: } async def send_agent_message(self, message: str, title: str = "MoviePilot助手"): - """通过原渠道发送消息给用户""" + """ + 通过原渠道发送消息给用户 + """ await AgentChain().async_post_message( Notification( channel=self.channel, @@ -256,24 +267,32 @@ class MoviePilotAgent: ) async def cleanup(self): - """清理智能体资源""" + """ + 清理智能体资源 + """ logger.info(f"MoviePilot智能体已清理: session_id={self.session_id}") class AgentManager: - """AI智能体管理器""" + """ + AI智能体管理器 + """ def __init__(self): self.active_agents: Dict[str, MoviePilotAgent] = {} - self.memory_manager = ConversationMemoryManager() - async def initialize(self): - """初始化管理器""" - await self.memory_manager.initialize() + @staticmethod + async def initialize(): + """ + 初始化管理器 + """ + await conversation_manager.initialize() async def close(self): - """关闭管理器""" - await self.memory_manager.close() + """ + 关闭管理器 + """ + await conversation_manager.close() # 清理所有活跃的智能体 for agent in self.active_agents.values(): await agent.cleanup() @@ -281,7 +300,9 @@ class AgentManager: async def process_message(self, session_id: str, user_id: str, message: str, channel: str = None, source: str = None, username: str = None) -> str: - """处理用户消息""" + """ + 处理用户消息 + """ # 获取或创建Agent实例 if session_id not in self.active_agents: logger.info(f"创建新的AI智能体实例,session_id: {session_id}, user_id: {user_id}") @@ -292,7 +313,6 @@ class AgentManager: source=source, username=username ) - agent.memory_manager = self.memory_manager self.active_agents[session_id] = agent else: agent = self.active_agents[session_id] @@ -309,12 +329,14 @@ class AgentManager: return await agent.process_message(message) async def clear_session(self, session_id: str, user_id: str): - """清空会话""" + """ + 清空会话 + """ if session_id in self.active_agents: agent = self.active_agents[session_id] await agent.cleanup() del self.active_agents[session_id] - await self.memory_manager.clear_memory(session_id, user_id) + await conversation_manager.clear_memory(session_id, user_id) logger.info(f"会话 {session_id} 的记忆已清空") diff --git a/app/agent/callback/__init__.py b/app/agent/callback/__init__.py index 0e511af5..5da3957d 100644 --- a/app/agent/callback/__init__.py +++ b/app/agent/callback/__init__.py @@ -6,7 +6,9 @@ from app.log import logger class StreamingCallbackHandler(AsyncCallbackHandler): - """流式输出回调处理器""" + """ + 流式输出回调处理器 + """ def __init__(self, session_id: str): self._lock = threading.Lock() @@ -14,7 +16,9 @@ class StreamingCallbackHandler(AsyncCallbackHandler): self.current_message = "" async def get_message(self): - """获取当前消息内容,获取后清空""" + """ + 获取当前消息内容,获取后清空 + """ with self._lock: if not self.current_message: return "" @@ -24,7 +28,9 @@ class StreamingCallbackHandler(AsyncCallbackHandler): return msg async def on_llm_new_token(self, token: str, **kwargs): - """处理新的token""" + """ + 处理新的token + """ if not token: return with self._lock: diff --git a/app/agent/memory/__init__.py b/app/agent/memory/__init__.py index aef320bf..7192ebda 100644 --- a/app/agent/memory/__init__.py +++ b/app/agent/memory/__init__.py @@ -12,7 +12,9 @@ from app.schemas.agent import ConversationMemory class ConversationMemoryManager: - """对话记忆管理器""" + """ + 对话记忆管理器 + """ def __init__(self): # 内存中的会话记忆缓存 @@ -23,7 +25,9 @@ class ConversationMemoryManager: self.cleanup_task: Optional[asyncio.Task] = None async def initialize(self): - """初始化记忆管理器""" + """ + 初始化记忆管理器 + """ try: # 启动内存缓存清理任务(Redis通过TTL自动过期) self.cleanup_task = asyncio.create_task(self._cleanup_expired_memories()) @@ -33,7 +37,9 @@ class ConversationMemoryManager: logger.warning(f"Redis连接失败,将使用内存存储: {e}") async def close(self): - """关闭记忆管理器""" + """ + 关闭记忆管理器 + """ if self.cleanup_task: self.cleanup_task.cancel() try: @@ -46,56 +52,83 @@ class ConversationMemoryManager: logger.info("对话记忆管理器已关闭") @staticmethod - def get_memory_key(session_id: str, user_id: str): - """计算内存Key""" + def _get_memory_key(session_id: str, user_id: str): + """ + 计算内存Key + """ return f"{user_id}:{session_id}" if user_id else session_id @staticmethod - def get_redis_key(session_id: str, user_id: str): - """计算Redis Key""" + def _get_redis_key(session_id: str, user_id: str): + """ + 计算Redis Key + """ return f"agent_memory:{user_id}:{session_id}" if user_id else f"agent_memory:{session_id}" - async def get_memory(self, session_id: str, user_id: str) -> ConversationMemory: - """获取会话记忆""" - # 首先检查缓存 - cache_key = self.get_memory_key(session_id, user_id) - if cache_key in self.memory_cache: - return self.memory_cache[cache_key] - - # 尝试从Redis加载 + def _get_memory(self, session_id: str, user_id: str): + """ + 获取内存中的记忆 + """ + cache_key = self._get_memory_key(session_id, user_id) + return self.memory_cache.get(cache_key) + + async def _get_redis(self, session_id: str, user_id: str) -> Optional[ConversationMemory]: + """ + 从Redis获取记忆 + """ if settings.CACHE_BACKEND_TYPE == "redis": try: - redis_key = self.get_redis_key(session_id, user_id) + redis_key = self._get_redis_key(session_id, user_id) memory_data = await self.redis_helper.get(redis_key, region="AI_AGENT") if memory_data: memory_dict = json.loads(memory_data) if isinstance(memory_data, str) else memory_data memory = ConversationMemory(**memory_dict) - self.memory_cache[cache_key] = memory return memory except Exception as e: logger.warning(f"从Redis加载记忆失败: {e}") + return None + + async def get_conversation(self, session_id: str, user_id: str) -> ConversationMemory: + """ + 获取会话记忆 + """ + # 首先检查缓存 + conversion = self._get_memory(session_id, user_id) + if conversion: + return conversion + + # 尝试从Redis加载 + memory = await self._get_redis(session_id, user_id) + if memory: + # 加载到内存缓存 + self._save_memory(memory) + return memory # 创建新的记忆 memory = ConversationMemory(session_id=session_id, user_id=user_id) - self.memory_cache[cache_key] = memory - await self._save_memory(memory) + await self._save_conversation(memory) return memory async def set_title(self, session_id: str, user_id: str, title: str): - """设置会话标题""" - memory = await self.get_memory(session_id=session_id, user_id=user_id) + """ + 设置会话标题 + """ + memory = await self.get_conversation(session_id=session_id, user_id=user_id) memory.title = title memory.updated_at = datetime.now() - await self._save_memory(memory) + await self._save_conversation(memory) async def get_title(self, session_id: str, user_id: str) -> Optional[str]: - """获取会话标题""" - memory = await self.get_memory(session_id=session_id, user_id=user_id) + """ + 获取会话标题 + """ + memory = await self.get_conversation(session_id=session_id, user_id=user_id) return memory.title async def list_sessions(self, user_id: str, limit: int = 100) -> List[Dict[str, Any]]: - """列出历史会话摘要(按更新时间倒序) + """ + 列出历史会话摘要(按更新时间倒序) - 当启用Redis时:遍历 `agent_memory:*` 键并读取摘要 - 当未启用Redis时:基于内存缓存返回 @@ -148,7 +181,7 @@ class ConversationMemoryManager: for m in sorted_list ] - async def add_memory( + async def add_conversation( self, session_id: str, user_id: str, @@ -156,8 +189,10 @@ class ConversationMemoryManager: content: str, metadata: Optional[Dict[str, Any]] = None ): - """添加消息到记忆""" - memory = await self.get_memory(session_id=session_id, user_id=user_id) + """ + 添加消息到记忆 + """ + memory = await self.get_conversation(session_id=session_id, user_id=user_id) message = { "role": role, @@ -177,7 +212,7 @@ class ConversationMemoryManager: recent_messages = memory.messages[-(max_messages - len(system_messages)):] memory.messages = system_messages + recent_messages - await self._save_memory(memory) + await self._save_conversation(memory) logger.debug(f"消息已添加到记忆: session_id={session_id}, user_id={user_id}, role={role}") @@ -186,11 +221,12 @@ class ConversationMemoryManager: session_id: str, user_id: str ) -> List[Dict[str, Any]]: - """为Agent获取最近的消息(仅内存缓存) + """ + 为Agent获取最近的消息(仅内存缓存) 如果消息Token数量超过模型最大上下文长度的阀值,会自动进行摘要裁剪 """ - cache_key = self.get_memory_key(session_id, user_id) + cache_key = self._get_memory_key(session_id, user_id) memory = self.memory_cache.get(cache_key) if not memory: return [] @@ -205,8 +241,10 @@ class ConversationMemoryManager: limit: int = 10, role_filter: Optional[list] = None ) -> List[Dict[str, Any]]: - """获取最近的消息""" - memory = await self.get_memory(session_id=session_id, user_id=user_id) + """ + 获取最近的消息 + """ + memory = await self.get_conversation(session_id=session_id, user_id=user_id) messages = memory.messages if role_filter: @@ -215,36 +253,41 @@ class ConversationMemoryManager: return messages[-limit:] if messages else [] async def get_context(self, session_id: str, user_id: str) -> Dict[str, Any]: - """获取会话上下文""" - memory = await self.get_memory(session_id=session_id, user_id=user_id) + """ + 获取会话上下文 + """ + memory = await self.get_conversation(session_id=session_id, user_id=user_id) return memory.context async def clear_memory(self, session_id: str, user_id: str): - """清空会话记忆""" + """ + 清空会话记忆 + """ cache_key = f"{user_id}:{session_id}" if user_id else session_id if cache_key in self.memory_cache: del self.memory_cache[cache_key] if settings.CACHE_BACKEND_TYPE == "redis": - redis_key = self.get_redis_key(session_id, user_id) + redis_key = self._get_redis_key(session_id, user_id) await self.redis_helper.delete(redis_key, region="AI_AGENT") logger.info(f"会话记忆已清空: session_id={session_id}, user_id={user_id}") - async def _save_memory(self, memory: ConversationMemory): - """保存记忆到存储 - - Redis中的记忆会自动通过TTL机制过期,无需手动清理 + def _save_memory(self, memory: ConversationMemory): """ - # 更新内存缓存 - cache_key = self.get_memory_key(memory.session_id, memory.user_id) + 保存记忆到内存 + """ + cache_key = self._get_memory_key(memory.session_id, memory.user_id) self.memory_cache[cache_key] = memory - # 保存到Redis,设置TTL自动过期 + async def _save_redis(self, memory: ConversationMemory): + """ + 保存记忆到Redis + """ if settings.CACHE_BACKEND_TYPE == "redis": try: memory_dict = memory.model_dump() - redis_key = self.get_redis_key(memory.session_id, memory.user_id) + redis_key = self._get_redis_key(memory.session_id, memory.user_id) ttl = int(timedelta(days=settings.LLM_REDIS_MEMORY_RETENTION_DAYS).total_seconds()) await self.redis_helper.set( redis_key, @@ -255,8 +298,22 @@ class ConversationMemoryManager: except Exception as e: logger.warning(f"保存记忆到Redis失败: {e}") + async def _save_conversation(self, memory: ConversationMemory): + """ + 保存记忆到存储 + + Redis中的记忆会自动通过TTL机制过期,无需手动清理 + """ + # 更新内存缓存 + self._save_memory(memory) + + # 保存到Redis,设置TTL自动过期 + await self._save_redis(memory) + + async def _cleanup_expired_memories(self): - """清理内存中过期记忆的后台任务 + """ + 清理内存中过期记忆的后台任务 注意:Redis中的记忆通过TTL机制自动过期,这里只清理内存缓存 """ @@ -286,3 +343,5 @@ class ConversationMemoryManager: break except Exception as e: logger.error(f"清理记忆时发生错误: {e}") + +conversation_manager = ConversationMemoryManager() diff --git a/app/agent/prompt/Agent Prompt.txt b/app/agent/prompt/Agent Prompt.txt index 12a7ab69..0f46c8c8 100644 --- a/app/agent/prompt/Agent Prompt.txt +++ b/app/agent/prompt/Agent Prompt.txt @@ -4,6 +4,20 @@ All your responses must be in **Chinese (中文)**. You act as a proactive agent. Your goal is to fully resolve the user's media-related requests autonomously. Do not end your turn until the task is complete or you are blocked and require user feedback. +Core Capabilities: +1. Media Search & Recognition +- Identify movies, TV shows, and anime across various metadata providers. +- Recognize media info from fuzzy filenames or incomplete titles. +2. Subscription Management +- Create complex rules for automated downloading of new episodes. +- Monitor trending movies/shows for automated suggestions. +3. Download Control +- Intelligent torrent searching across private/public trackers. +- Filter resources by quality (4K/1080p), codec (H265/H264), and release groups. +4. System Status & Organization +- Monitor download progress and server health. +- Manage file transfers, renaming, and library cleanup. + - Use Markdown for structured data like movie lists, download statuses, or technical details. - Avoid wrapping the entire response in a single code block. Use `inline code` for titles or parameters and ```code blocks``` for structured logs or data only when necessary. @@ -11,6 +25,11 @@ You act as a proactive agent. Your goal is to fully resolve the user's media-rel - Optimize your writing for clarity and readability, using bold text for key information. - Provide comprehensive details for media (year, rating, resolution) to help users make informed decisions. - Do not stop for approval for read-only operations. Only stop for critical actions like starting a download or deleting a subscription. + +Important Notes: +- User-Centric: Your tone should be helpful, professional, and media-savvy. +- No Coding Hallucinations: You are NOT a coding assistant. Do not offer code snippets, IDE tips, or programming help. Focus entirely on the MoviePilot media ecosystem. +- Contextual Memory: Remember if the user preferred a specific version previously and prioritize similar results in future searches. @@ -28,44 +47,26 @@ At the end of your session/turn, provide a concise summary of your actions. -1. **Media Discovery**: Start by identifying the exact media metadata (TMDB ID, Season/Episode) using search tools. -2. **Context Checking**: Verify current status (Is it already in the library? Is it already subscribed?). -3. **Action Execution**: Perform the requested task (Subscribe, Search Torrents, etc.) with a brief status update. -4. **Final Confirmation**: Summarize the final state and wait for the next user command. +1. Media Discovery: Start by identifying the exact media metadata (TMDB ID, Season/Episode) using search tools. +2. Context Checking: Verify current status (Is it already in the library? Is it already subscribed?). +3. Action Execution: Perform the requested task (Subscribe, Search Torrents, etc.) with a brief status update. +4. Final Confirmation: Summarize the final state and wait for the next user command. -- **Parallel Execution**: You MUST call independent tools in parallel. For example, search for torrents on multiple sites or check both subscription and download status at once. -- **Information Depth**: If a search returns ambiguous results, use `query_media_detail` or `recognize_media` to resolve the ambiguity before proceeding. -- **Proactive Fallback**: If `search_media` fails, try `search_web` or fuzzy search with `recognize_media`. Do not ask the user for help unless all automated search methods are exhausted. +- Parallel Execution: You MUST call independent tools in parallel. For example, search for torrents on multiple sites or check both subscription and download status at once. +- Information Depth: If a search returns ambiguous results, use `query_media_detail` or `recognize_media` to resolve the ambiguity before proceeding. +- Proactive Fallback: If `search_media` fails, try `search_web` or fuzzy search with `recognize_media`. Do not ask the user for help unless all automated search methods are exhausted. -1. **Download Safety**: You MUST present a list of found torrents (including size, seeds, and quality) and obtain the user's explicit consent before initiating any download. -2. **Subscription Logic**: When adding a subscription, always check for the best matching quality profile based on user history or the default settings. -3. **Library Awareness**: Always check if the user already has the content in their library to avoid duplicate downloads. -4. **Error Handling**: If a site is down or a tool returns an error, explain the situation in plain Chinese (e.g., "站点响应超时") and suggest an alternative (e.g., "尝试从其他站点进行搜索"). +1. Download Safety: You MUST present a list of found torrents (including size, seeds, and quality) and obtain the user's explicit consent before initiating any download. +2. Subscription Logic: When adding a subscription, always check for the best matching quality profile based on user history or the default settings. +3. Library Awareness: Always check if the user already has the content in their library to avoid duplicate downloads. +4. Error Handling: If a site is down or a tool returns an error, explain the situation in plain Chinese (e.g., "站点响应超时") and suggest an alternative (e.g., "尝试从其他站点进行搜索"). -## Core Capabilities - -### 1. Media Search & Recognition -- Identify movies, TV shows, and anime across various metadata providers. -- Recognize media info from fuzzy filenames or incomplete titles. - -### 2. Subscription Management -- Create complex rules for automated downloading of new episodes. -- Monitor trending movies/shows for automated suggestions. - -### 3. Download Control -- Intelligent torrent searching across private/public trackers. -- Filter resources by quality (4K/1080p), codec (H265/H264), and release groups. - -### 4. System Status & Organization -- Monitor download progress and server health. -- Manage file transfers, renaming, and library cleanup. - -## Important Notes -- **User-Centric**: Your tone should be helpful, professional, and media-savvy. -- **No Coding Hallucinations**: You are NOT a coding assistant. Do not offer code snippets, IDE tips, or programming help. Focus entirely on the MoviePilot media ecosystem. -- **Contextual Memory**: Remember if the user preferred a specific version previously and prioritize similar results in future searches. \ No newline at end of file + +Specific markdown rules: +{markdown_spec} + \ No newline at end of file diff --git a/app/agent/prompt/__init__.py b/app/agent/prompt/__init__.py index 2d3f2066..992550db 100644 --- a/app/agent/prompt/__init__.py +++ b/app/agent/prompt/__init__.py @@ -1,13 +1,15 @@ """提示词管理器""" - from pathlib import Path from typing import Dict from app.log import logger +from app.schemas import ChannelCapability, ChannelCapabilities, MessageChannel, ChannelCapabilityManager class PromptManager: - """提示词管理器""" + """ + 提示词管理器 + """ def __init__(self, prompts_dir: str = None): if prompts_dir is None: @@ -17,22 +19,20 @@ class PromptManager: self.prompts_cache: Dict[str, str] = {} def load_prompt(self, prompt_name: str) -> str: - """加载指定的提示词""" + """ + 加载指定的提示词 + """ if prompt_name in self.prompts_cache: return self.prompts_cache[prompt_name] prompt_file = self.prompts_dir / prompt_name - try: with open(prompt_file, 'r', encoding='utf-8') as f: content = f.read().strip() - # 缓存提示词 self.prompts_cache[prompt_name] = content - logger.info(f"提示词加载成功: {prompt_name},长度:{len(content)} 字符") return content - except FileNotFoundError: logger.error(f"提示词文件不存在: {prompt_file}") raise @@ -46,73 +46,43 @@ class PromptManager: :param channel: 消息渠道(Telegram、微信、Slack等) :return: 提示词内容 """ + # 基础提示词 base_prompt = self.load_prompt("Agent Prompt.txt") - - # 根据渠道添加特定的格式说明 - if channel: - channel_format_info = self._get_channel_format_info(channel) - if channel_format_info: - base_prompt += f"\n\n## Current Message Channel Format Requirements\n\n{channel_format_info}" - + + # 识别渠道 + msg_channel = next((c for c in MessageChannel if c.value.lower() == channel.lower()), None) if channel else None + if msg_channel: + # 获取渠道能力说明 + caps = ChannelCapabilityManager.get_capabilities(msg_channel) + if caps: + base_prompt = base_prompt.replace( + "{markdown_spec}", + self._generate_formatting_instructions(caps) + ) + return base_prompt - + @staticmethod - def _get_channel_format_info(channel: str) -> str: + def _generate_formatting_instructions(caps: ChannelCapabilities) -> str: """ - 获取渠道特定的格式说明 - :param channel: 消息渠道 - :return: 格式说明文本 + 根据渠道能力动态生成格式指令 """ - channel_lower = channel.lower() if channel else "" - - if "telegram" in channel_lower: - return """Messages are being sent through the **Telegram** channel. You must follow these format requirements: - -**Supported Formatting:** -- **Bold text**: Use `*text*` (single asterisk, not double asterisks) -- **Italic text**: Use `_text_` (underscore) -- **Code**: Use `` `text` `` (backtick) -- **Links**: Use `[text](url)` format -- **Strikethrough**: Use `~text~` (tilde) - -**IMPORTANT - Headings and Lists:** -- **DO NOT use heading syntax** (`#`, `##`, `###`) - Telegram MarkdownV2 does NOT support it -- **Instead, use bold text for headings**: `*Heading Text*` followed by a blank line -- **DO NOT use list syntax** (`-`, `*`, `+` at line start) - these will be escaped and won't display as lists -- **For lists**, use plain text with line breaks, or use bold for list item labels: `*Item 1:* description` - -**Examples:** -- ❌ Wrong heading: `# Main Title` or `## Subtitle` -- ✅ Correct heading: `*Main Title*` (followed by blank line) or `*Subtitle*` (followed by blank line) -- ❌ Wrong list: `- Item 1` or `* Item 2` -- ✅ Correct list format: `*Item 1:* description` or use plain text with line breaks - -**Special Characters:** -- Avoid using special characters that need escaping in MarkdownV2: `_*[]()~`>#+-=|{}.!` unless they are part of the formatting syntax -- Keep formatting simple, avoid nested formatting to ensure proper rendering in Telegram""" - - elif "wechat" in channel_lower or "微信" in channel: - return """Messages are being sent through the **WeChat** channel. Please follow these format requirements: - -- WeChat does NOT support Markdown formatting. Use plain text format only. -- Do NOT use any Markdown syntax (such as `**bold**`, `*italic*`, `` `code` `` etc.) -- Use plain text descriptions. You can organize content using line breaks and punctuation -- Links can be provided directly as URLs, no Markdown link format needed -- Keep messages concise and clear, use natural Chinese expressions""" - - elif "slack" in channel_lower: - return """Messages are being sent through the **Slack** channel. Please follow these format requirements: - -- Slack supports Markdown formatting -- Use `*text*` for bold -- Use `_text_` for italic -- Use `` `text` `` for code -- Link format: `` or `[text](url)`""" - - # 其他渠道使用标准Markdown - return None + instructions = [] + if ChannelCapability.RICH_TEXT not in caps.capabilities: + instructions.append("- Formatting: Use **Plain Text ONLY**. The channel does NOT support Markdown.") + instructions.append( + "- No Markdown Symbols: NEVER use `**`, `*`, `__`, or `[` blocks. Use natural text to emphasize (e.g., using ALL CAPS or separators).") + instructions.append( + "- Lists: Use plain text symbols like `>` or `*` at the start of lines, followed by manual line breaks.") + instructions.append("- Links: Paste URLs directly as text.") + return "\n".join(instructions) def clear_cache(self): - """清空缓存""" + """ + 清空缓存 + """ self.prompts_cache.clear() logger.info("提示词缓存已清空") + + +prompt_manager = PromptManager() diff --git a/app/agent/tools/base.py b/app/agent/tools/base.py index 06164f9d..6246ef4f 100644 --- a/app/agent/tools/base.py +++ b/app/agent/tools/base.py @@ -1,4 +1,3 @@ -"""MoviePilot工具基类""" import json from abc import ABCMeta, abstractmethod from typing import Any, Optional @@ -6,7 +5,7 @@ from typing import Any, Optional from langchain.tools import BaseTool from pydantic import PrivateAttr -from app.agent import StreamingCallbackHandler, ConversationMemoryManager +from app.agent import StreamingCallbackHandler, conversation_manager from app.chain import ChainBase from app.log import logger from app.schemas import Notification @@ -17,7 +16,9 @@ class ToolChain(ChainBase): class MoviePilotTool(BaseTool, metaclass=ABCMeta): - """MoviePilot专用工具基类""" + """ + MoviePilot专用工具基类 + """ _session_id: str = PrivateAttr() _user_id: str = PrivateAttr() @@ -25,7 +26,6 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta): _source: str = PrivateAttr(default=None) _username: str = PrivateAttr(default=None) _callback_handler: StreamingCallbackHandler = PrivateAttr(default=None) - _memory_manager: ConversationMemoryManager = PrivateAttr(default=None) def __init__(self, session_id: str, user_id: str, **kwargs): super().__init__(**kwargs) @@ -36,7 +36,9 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta): pass async def _arun(self, **kwargs) -> str: - """异步运行工具""" + """ + 异步运行工具 + """ # 发送和记忆工具调用前的信息 agent_message = await self._callback_handler.get_message() if agent_message: @@ -44,7 +46,7 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta): await self.send_tool_message(agent_message, title="MoviePilot助手") # 记忆工具调用 - await self._memory_manager.add_memory( + await conversation_manager.add_conversation( session_id=self._session_id, user_id=self._user_id, role="tool_call", @@ -77,7 +79,7 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta): formated_result = str(result) else: formated_result = json.dumps(result, ensure_ascii=False, indent=2) - await self._memory_manager.add_memory( + await conversation_manager.add_conversation( session_id=self._session_id, user_id=self._user_id, role="tool_result", @@ -106,21 +108,23 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta): raise NotImplementedError def set_message_attr(self, channel: str, source: str, username: str): - """设置消息属性""" + """ + 设置消息属性 + """ self._channel = channel self._source = source self._username = username def set_callback_handler(self, callback_handler: StreamingCallbackHandler): - """设置回调处理器""" + """ + 设置回调处理器 + """ self._callback_handler = callback_handler - def set_memory_manager(self, memory_manager: ConversationMemoryManager): - """设置记忆客理器""" - self._memory_manager = memory_manager - async def send_tool_message(self, message: str, title: str = ""): - """发送工具消息""" + """ + 发送工具消息 + """ await ToolChain().async_post_message( Notification( channel=self._channel, diff --git a/app/agent/tools/factory.py b/app/agent/tools/factory.py index 7256af53..606a5f05 100644 --- a/app/agent/tools/factory.py +++ b/app/agent/tools/factory.py @@ -1,5 +1,3 @@ -"""MoviePilot工具工厂""" - from typing import List, Callable from app.agent.tools.impl.add_download import AddDownloadTool @@ -47,13 +45,17 @@ from .base import MoviePilotTool class MoviePilotToolFactory: - """MoviePilot工具工厂""" + """ + MoviePilot工具工厂 + """ @staticmethod def create_tools(session_id: str, user_id: str, channel: str = None, source: str = None, username: str = None, - callback_handler: Callable = None, memory_mananger: Callable = None) -> List[MoviePilotTool]: - """创建MoviePilot工具列表""" + callback_handler: Callable = None) -> List[MoviePilotTool]: + """ + 创建MoviePilot工具列表 + """ tools = [] tool_definitions = [ SearchMediaTool, @@ -104,7 +106,6 @@ class MoviePilotToolFactory: ) tool.set_message_attr(channel=channel, source=source, username=username) tool.set_callback_handler(callback_handler=callback_handler) - tool.set_memory_manager(memory_manager=memory_mananger) tools.append(tool) # 加载插件提供的工具 @@ -127,7 +128,6 @@ class MoviePilotToolFactory: ) tool.set_message_attr(channel=channel, source=source, username=username) tool.set_callback_handler(callback_handler=callback_handler) - tool.set_memory_manager(memory_manager=memory_mananger) tools.append(tool) plugin_tools_count += 1 logger.debug(f"成功加载插件 {plugin_name}({plugin_id}) 的工具: {ToolClass.__name__}") diff --git a/app/agent/tools/manager.py b/app/agent/tools/manager.py index 584ba49b..d86e3e23 100644 --- a/app/agent/tools/manager.py +++ b/app/agent/tools/manager.py @@ -1,18 +1,15 @@ -"""MoviePilot工具管理器 -用于HTTP API调用工具 -""" - import json import uuid from typing import Any, Dict, List, Optional -from app.agent import ConversationMemoryManager from app.agent.tools.factory import MoviePilotToolFactory from app.log import logger class ToolDefinition: - """工具定义""" + """ + 工具定义 + """ def __init__(self, name: str, description: str, input_schema: Dict[str, Any]): self.name = name @@ -21,7 +18,9 @@ class ToolDefinition: class MoviePilotToolsManager: - """MoviePilot工具管理器(用于HTTP API)""" + """ + MoviePilot工具管理器(用于HTTP API) + """ def __init__(self, user_id: str = "api_user", session_id: str = uuid.uuid4()): """ @@ -34,11 +33,12 @@ class MoviePilotToolsManager: self.user_id = user_id self.session_id = session_id self.tools: List[Any] = [] - self.memory_manager = ConversationMemoryManager() self._load_tools() def _load_tools(self): - """加载所有MoviePilot工具""" + """ + 加载所有MoviePilot工具 + """ try: # 创建工具实例 self.tools = MoviePilotToolFactory.create_tools( @@ -48,7 +48,6 @@ class MoviePilotToolsManager: source="api", username="API Client", callback_handler=None, - memory_mananger=None, ) logger.info(f"成功加载 {len(self.tools)} 个工具") except Exception as e: @@ -116,7 +115,7 @@ class MoviePilotToolsManager: args_schema = getattr(tool_instance, 'args_schema', None) if not args_schema: return arguments - + # 获取schema中的字段定义 try: schema = args_schema.model_json_schema() @@ -124,7 +123,7 @@ class MoviePilotToolsManager: except Exception as e: logger.warning(f"获取工具schema失败: {e}") return arguments - + # 规范化参数 normalized = {} for key, value in arguments.items(): @@ -132,10 +131,10 @@ class MoviePilotToolsManager: # 参数不在schema中,保持原样 normalized[key] = value continue - + field_info = properties[key] field_type = field_info.get("type") - + # 处理 anyOf 类型(例如 Optional[int] 会生成 anyOf) any_of = field_info.get("anyOf") if any_of and not field_type: @@ -144,7 +143,7 @@ class MoviePilotToolsManager: if "type" in type_option and type_option["type"] != "null": field_type = type_option["type"] break - + # 根据类型进行转换 if field_type == "integer" and isinstance(value, str): try: @@ -167,7 +166,7 @@ class MoviePilotToolsManager: normalized[key] = True else: normalized[key] = value - + return normalized async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> str: @@ -192,7 +191,7 @@ class MoviePilotToolsManager: try: # 规范化参数类型 normalized_arguments = self._normalize_arguments(tool_instance, arguments) - + # 调用工具的run方法 result = await tool_instance.run(**normalized_arguments) @@ -270,3 +269,6 @@ class MoviePilotToolsManager: "properties": properties, "required": required } + + +moviepilot_tool_manager = MoviePilotToolsManager() diff --git a/app/api/endpoints/mcp.py b/app/api/endpoints/mcp.py index 3343a106..7da7634d 100644 --- a/app/api/endpoints/mcp.py +++ b/app/api/endpoints/mcp.py @@ -1,14 +1,10 @@ -"""工具API端点 -通过HTTP API暴露MoviePilot的智能体工具功能 -""" - from typing import List, Any, Dict, Annotated, Union from fastapi import APIRouter, Depends, HTTPException, Request from fastapi.responses import JSONResponse, Response from app import schemas -from app.agent.tools.manager import MoviePilotToolsManager +from app.agent.tools.manager import moviepilot_tool_manager from app.core.security import verify_apikey from app.log import logger @@ -25,18 +21,10 @@ MCP_PROTOCOL_VERSIONS = ["2025-11-25", "2025-06-18", "2024-11-05"] MCP_PROTOCOL_VERSION = MCP_PROTOCOL_VERSIONS[0] # 默认使用最新版本 -def get_tools_manager() -> MoviePilotToolsManager: - """ - 获取工具管理器实例 - - Returns: - MoviePilotToolsManager实例 - """ - return MoviePilotToolsManager() - - def create_jsonrpc_response(request_id: Union[str, int, None], result: Any) -> Dict[str, Any]: - """创建 JSON-RPC 成功响应""" + """ + 创建 JSON-RPC 成功响应 + """ response = { "jsonrpc": "2.0", "id": request_id, @@ -45,8 +33,11 @@ def create_jsonrpc_response(request_id: Union[str, int, None], result: Any) -> D return response -def create_jsonrpc_error(request_id: Union[str, int, None], code: int, message: str, data: Any = None) -> Dict[str, Any]: - """创建 JSON-RPC 错误响应""" +def create_jsonrpc_error(request_id: Union[str, int, None], code: int, message: str, data: Any = None) -> Dict[ + str, Any]: + """ + 创建 JSON-RPC 错误响应 + """ error = { "jsonrpc": "2.0", "id": request_id, @@ -60,8 +51,6 @@ def create_jsonrpc_error(request_id: Union[str, int, None], code: int, message: return error -# ==================== MCP JSON-RPC 端点 ==================== - @router.post("", summary="MCP JSON-RPC 端点", response_model=None) async def mcp_jsonrpc( request: Request, @@ -146,7 +135,9 @@ async def mcp_jsonrpc( async def handle_initialize(params: Dict[str, Any]) -> Dict[str, Any]: - """处理初始化请求""" + """ + 处理初始化请求 + """ protocol_version = params.get("protocolVersion") client_info = params.get("clientInfo", {}) @@ -161,7 +152,7 @@ async def handle_initialize(params: Dict[str, Any]) -> Dict[str, Any]: else: # 客户端版本不支持,使用服务器默认版本 logger.warning(f"协议版本不匹配: 客户端={protocol_version}, 使用服务器版本={negotiated_version}") - + return { "protocolVersion": negotiated_version, "capabilities": { @@ -180,9 +171,10 @@ async def handle_initialize(params: Dict[str, Any]) -> Dict[str, Any]: async def handle_tools_list() -> Dict[str, Any]: - """处理工具列表请求""" - manager = get_tools_manager() - tools = manager.list_tools() + """ + 处理工具列表请求 + """ + tools = moviepilot_tool_manager.list_tools() # 转换为 MCP 工具格式 mcp_tools = [] @@ -200,18 +192,18 @@ async def handle_tools_list() -> Dict[str, Any]: async def handle_tools_call(params: Dict[str, Any]) -> Dict[str, Any]: - """处理工具调用请求""" + """ + 处理工具调用请求 + """ tool_name = params.get("name") arguments = params.get("arguments", {}) if not tool_name: raise ValueError("Missing tool name") - manager = get_tools_manager() - try: - result_text = await manager.call_tool(tool_name, arguments) - + result_text = await moviepilot_tool_manager.call_tool(tool_name, arguments) + return { "content": [ { @@ -243,8 +235,6 @@ async def delete_mcp_session( return Response(status_code=204) - - # ==================== 兼容的 RESTful API 端点 ==================== @router.get("/tools", summary="列出所有可用工具", response_model=List[Dict[str, Any]]) @@ -257,9 +247,8 @@ async def list_tools( 返回每个工具的名称、描述和参数定义 """ try: - manager = get_tools_manager() # 获取所有工具定义 - tools = manager.list_tools() + tools = moviepilot_tool_manager.list_tools() # 转换为字典格式 tools_list = [] @@ -289,11 +278,8 @@ async def call_tool( 工具执行结果 """ try: - # 使用当前用户ID创建管理器实例 - manager = get_tools_manager() - # 调用工具 - result_text = await manager.call_tool(request.tool_name, request.arguments) + result_text = await moviepilot_tool_manager.call_tool(request.tool_name, request.arguments) return schemas.ToolCallResponse( success=True, @@ -319,9 +305,8 @@ async def get_tool_info( 工具的详细信息,包括名称、描述和参数定义 """ try: - manager = get_tools_manager() # 获取所有工具 - tools = manager.list_tools() + tools = moviepilot_tool_manager.list_tools() # 查找指定工具 for tool in tools: @@ -352,9 +337,8 @@ async def get_tool_schema( 工具的JSON Schema定义 """ try: - manager = get_tools_manager() # 获取所有工具 - tools = manager.list_tools() + tools = moviepilot_tool_manager.list_tools() # 查找指定工具 for tool in tools: diff --git a/app/startup/agent_initializer.py b/app/startup/agent_initializer.py index 35b6d34e..81033405 100644 --- a/app/startup/agent_initializer.py +++ b/app/startup/agent_initializer.py @@ -1,20 +1,17 @@ -""" -AI智能体初始化器 -""" - import asyncio import threading -from typing import Optional +from app.agent import agent_manager from app.core.config import settings from app.log import logger class AgentInitializer: - """AI智能体初始化器""" + """ + AI智能体初始化器 + """ def __init__(self): - self.agent_manager = None self._initialized = False async def initialize(self) -> bool: @@ -26,9 +23,6 @@ class AgentInitializer: logger.info("AI智能体功能未启用") return True - from app.agent import agent_manager - self.agent_manager = agent_manager - await agent_manager.initialize() self._initialized = True logger.info("AI智能体管理器初始化成功") @@ -43,10 +37,10 @@ class AgentInitializer: 清理AI智能体管理器 """ try: - if not self._initialized or not self.agent_manager: + if not self._initialized: return - await self.agent_manager.close() + await agent_manager.close() self._initialized = False logger.info("AI智能体管理器已关闭") @@ -78,8 +72,8 @@ def init_agent(): else: logger.error("AI智能体管理器初始化失败") return success - except Exception as e: - logger.error(f"AI智能体管理器初始化失败: {e}") + except Exception as err: + logger.error(f"AI智能体管理器初始化失败: {err}") return False finally: loop.close()