feat: 为工具管理器添加参数类型规范化处理,并基于渠道能力动态生成提示词中的格式要求

This commit is contained in:
jxxghp
2026-01-15 20:55:35 +08:00
parent 1bf9862e47
commit 1641d432dd
10 changed files with 324 additions and 282 deletions

View File

@@ -1,5 +1,3 @@
"""MoviePilot AI智能体实现"""
import asyncio
from typing import Dict, List, Any
@@ -11,8 +9,8 @@ from langchain_core.messages import HumanMessage, AIMessage, ToolCall, ToolMessa
from langchain_core.runnables.history import RunnableWithMessageHistory
from app.agent.callback import StreamingCallbackHandler
from app.agent.memory import ConversationMemoryManager
from app.agent.prompt import PromptManager
from app.agent.memory import conversation_manager
from app.agent.prompt import prompt_manager
from app.agent.tools.factory import MoviePilotToolFactory
from app.chain import ChainBase
from app.core.config import settings
@@ -27,7 +25,9 @@ class AgentChain(ChainBase):
class MoviePilotAgent:
"""MoviePilot AI智能体"""
"""
MoviePilot AI智能体
"""
def __init__(self, session_id: str, user_id: str = None,
channel: str = None, source: str = None, username: str = None):
@@ -40,12 +40,6 @@ class MoviePilotAgent:
# 消息助手
self.message_helper = MessageHelper()
# 记忆管理器
self.memory_manager = ConversationMemoryManager()
# 提示词管理器
self.prompt_manager = PromptManager()
# 回调处理器
self.callback_handler = StreamingCallbackHandler(
session_id=session_id
@@ -64,30 +58,37 @@ class MoviePilotAgent:
self.agent_executor = self._create_agent_executor()
def _initialize_llm(self):
"""初始化LLM模型"""
"""
初始化LLM模型
"""
return LLMHelper.get_llm(streaming=True, callbacks=[self.callback_handler])
def _initialize_tools(self) -> List:
"""初始化工具列表"""
"""
初始化工具列表
"""
return MoviePilotToolFactory.create_tools(
session_id=self.session_id,
user_id=self.user_id,
channel=self.channel,
source=self.source,
username=self.username,
callback_handler=self.callback_handler,
memory_mananger=self.memory_manager
callback_handler=self.callback_handler
)
@staticmethod
def _initialize_session_store() -> Dict[str, InMemoryChatMessageHistory]:
"""初始化内存存储"""
"""
初始化内存存储
"""
return {}
def get_session_history(self, session_id: str) -> InMemoryChatMessageHistory:
"""获取会话历史"""
"""
获取会话历史
"""
chat_history = InMemoryChatMessageHistory()
messages: List[dict] = self.memory_manager.get_recent_messages_for_agent(
messages: List[dict] = conversation_manager.get_recent_messages_for_agent(
session_id=session_id,
user_id=self.user_id
)
@@ -119,7 +120,9 @@ class MoviePilotAgent:
@staticmethod
def _initialize_prompt() -> ChatPromptTemplate:
"""初始化提示词模板"""
"""
初始化提示词模板
"""
try:
prompt_template = ChatPromptTemplate.from_messages([
("system", "{system_prompt}"),
@@ -134,7 +137,9 @@ class MoviePilotAgent:
raise e
def _create_agent_executor(self) -> RunnableWithMessageHistory:
"""创建Agent执行器"""
"""
创建Agent执行器
"""
try:
agent = create_openai_tools_agent(
llm=self.llm,
@@ -161,10 +166,12 @@ class MoviePilotAgent:
raise e
async def process_message(self, message: str) -> str:
"""处理用户消息"""
"""
处理用户消息
"""
try:
# 添加用户消息到记忆
await self.memory_manager.add_memory(
await conversation_manager.add_conversation(
self.session_id,
user_id=self.user_id,
role="user",
@@ -173,7 +180,7 @@ class MoviePilotAgent:
# 构建输入上下文
input_context = {
"system_prompt": self.prompt_manager.get_agent_prompt(channel=self.channel),
"system_prompt": prompt_manager.get_agent_prompt(channel=self.channel),
"input": message
}
@@ -190,7 +197,7 @@ class MoviePilotAgent:
await self.send_agent_message(agent_message)
# 添加Agent回复到记忆
await self.memory_manager.add_memory(
await conversation_manager.add_conversation(
session_id=self.session_id,
user_id=self.user_id,
role="agent",
@@ -210,7 +217,9 @@ class MoviePilotAgent:
return error_message
async def _execute_agent(self, input_context: Dict[str, Any]) -> Dict[str, Any]:
"""执行LangChain Agent"""
"""
执行LangChain Agent
"""
try:
with get_openai_callback() as cb:
result = await self.agent_executor.ainvoke(
@@ -243,7 +252,9 @@ class MoviePilotAgent:
}
async def send_agent_message(self, message: str, title: str = "MoviePilot助手"):
"""通过原渠道发送消息给用户"""
"""
通过原渠道发送消息给用户
"""
await AgentChain().async_post_message(
Notification(
channel=self.channel,
@@ -256,24 +267,32 @@ class MoviePilotAgent:
)
async def cleanup(self):
"""清理智能体资源"""
"""
清理智能体资源
"""
logger.info(f"MoviePilot智能体已清理: session_id={self.session_id}")
class AgentManager:
"""AI智能体管理器"""
"""
AI智能体管理器
"""
def __init__(self):
self.active_agents: Dict[str, MoviePilotAgent] = {}
self.memory_manager = ConversationMemoryManager()
async def initialize(self):
"""初始化管理器"""
await self.memory_manager.initialize()
@staticmethod
async def initialize():
"""
初始化管理器
"""
await conversation_manager.initialize()
async def close(self):
"""关闭管理器"""
await self.memory_manager.close()
"""
关闭管理器
"""
await conversation_manager.close()
# 清理所有活跃的智能体
for agent in self.active_agents.values():
await agent.cleanup()
@@ -281,7 +300,9 @@ class AgentManager:
async def process_message(self, session_id: str, user_id: str, message: str,
channel: str = None, source: str = None, username: str = None) -> str:
"""处理用户消息"""
"""
处理用户消息
"""
# 获取或创建Agent实例
if session_id not in self.active_agents:
logger.info(f"创建新的AI智能体实例session_id: {session_id}, user_id: {user_id}")
@@ -292,7 +313,6 @@ class AgentManager:
source=source,
username=username
)
agent.memory_manager = self.memory_manager
self.active_agents[session_id] = agent
else:
agent = self.active_agents[session_id]
@@ -309,12 +329,14 @@ class AgentManager:
return await agent.process_message(message)
async def clear_session(self, session_id: str, user_id: str):
"""清空会话"""
"""
清空会话
"""
if session_id in self.active_agents:
agent = self.active_agents[session_id]
await agent.cleanup()
del self.active_agents[session_id]
await self.memory_manager.clear_memory(session_id, user_id)
await conversation_manager.clear_memory(session_id, user_id)
logger.info(f"会话 {session_id} 的记忆已清空")

View File

@@ -6,7 +6,9 @@ from app.log import logger
class StreamingCallbackHandler(AsyncCallbackHandler):
"""流式输出回调处理器"""
"""
流式输出回调处理器
"""
def __init__(self, session_id: str):
self._lock = threading.Lock()
@@ -14,7 +16,9 @@ class StreamingCallbackHandler(AsyncCallbackHandler):
self.current_message = ""
async def get_message(self):
"""获取当前消息内容,获取后清空"""
"""
获取当前消息内容,获取后清空
"""
with self._lock:
if not self.current_message:
return ""
@@ -24,7 +28,9 @@ class StreamingCallbackHandler(AsyncCallbackHandler):
return msg
async def on_llm_new_token(self, token: str, **kwargs):
"""处理新的token"""
"""
处理新的token
"""
if not token:
return
with self._lock:

View File

@@ -12,7 +12,9 @@ from app.schemas.agent import ConversationMemory
class ConversationMemoryManager:
"""对话记忆管理器"""
"""
对话记忆管理器
"""
def __init__(self):
# 内存中的会话记忆缓存
@@ -23,7 +25,9 @@ class ConversationMemoryManager:
self.cleanup_task: Optional[asyncio.Task] = None
async def initialize(self):
"""初始化记忆管理器"""
"""
初始化记忆管理器
"""
try:
# 启动内存缓存清理任务Redis通过TTL自动过期
self.cleanup_task = asyncio.create_task(self._cleanup_expired_memories())
@@ -33,7 +37,9 @@ class ConversationMemoryManager:
logger.warning(f"Redis连接失败将使用内存存储: {e}")
async def close(self):
"""关闭记忆管理器"""
"""
关闭记忆管理器
"""
if self.cleanup_task:
self.cleanup_task.cancel()
try:
@@ -46,56 +52,83 @@ class ConversationMemoryManager:
logger.info("对话记忆管理器已关闭")
@staticmethod
def get_memory_key(session_id: str, user_id: str):
"""计算内存Key"""
def _get_memory_key(session_id: str, user_id: str):
"""
计算内存Key
"""
return f"{user_id}:{session_id}" if user_id else session_id
@staticmethod
def get_redis_key(session_id: str, user_id: str):
"""计算Redis Key"""
def _get_redis_key(session_id: str, user_id: str):
"""
计算Redis Key
"""
return f"agent_memory:{user_id}:{session_id}" if user_id else f"agent_memory:{session_id}"
async def get_memory(self, session_id: str, user_id: str) -> ConversationMemory:
"""获取会话记忆"""
# 首先检查缓存
cache_key = self.get_memory_key(session_id, user_id)
if cache_key in self.memory_cache:
return self.memory_cache[cache_key]
# 尝试从Redis加载
def _get_memory(self, session_id: str, user_id: str):
"""
获取内存中的记忆
"""
cache_key = self._get_memory_key(session_id, user_id)
return self.memory_cache.get(cache_key)
async def _get_redis(self, session_id: str, user_id: str) -> Optional[ConversationMemory]:
"""
从Redis获取记忆
"""
if settings.CACHE_BACKEND_TYPE == "redis":
try:
redis_key = self.get_redis_key(session_id, user_id)
redis_key = self._get_redis_key(session_id, user_id)
memory_data = await self.redis_helper.get(redis_key, region="AI_AGENT")
if memory_data:
memory_dict = json.loads(memory_data) if isinstance(memory_data, str) else memory_data
memory = ConversationMemory(**memory_dict)
self.memory_cache[cache_key] = memory
return memory
except Exception as e:
logger.warning(f"从Redis加载记忆失败: {e}")
return None
async def get_conversation(self, session_id: str, user_id: str) -> ConversationMemory:
"""
获取会话记忆
"""
# 首先检查缓存
conversion = self._get_memory(session_id, user_id)
if conversion:
return conversion
# 尝试从Redis加载
memory = await self._get_redis(session_id, user_id)
if memory:
# 加载到内存缓存
self._save_memory(memory)
return memory
# 创建新的记忆
memory = ConversationMemory(session_id=session_id, user_id=user_id)
self.memory_cache[cache_key] = memory
await self._save_memory(memory)
await self._save_conversation(memory)
return memory
async def set_title(self, session_id: str, user_id: str, title: str):
"""设置会话标题"""
memory = await self.get_memory(session_id=session_id, user_id=user_id)
"""
设置会话标题
"""
memory = await self.get_conversation(session_id=session_id, user_id=user_id)
memory.title = title
memory.updated_at = datetime.now()
await self._save_memory(memory)
await self._save_conversation(memory)
async def get_title(self, session_id: str, user_id: str) -> Optional[str]:
"""获取会话标题"""
memory = await self.get_memory(session_id=session_id, user_id=user_id)
"""
获取会话标题
"""
memory = await self.get_conversation(session_id=session_id, user_id=user_id)
return memory.title
async def list_sessions(self, user_id: str, limit: int = 100) -> List[Dict[str, Any]]:
"""列出历史会话摘要(按更新时间倒序)
"""
列出历史会话摘要(按更新时间倒序)
- 当启用Redis时遍历 `agent_memory:*` 键并读取摘要
- 当未启用Redis时基于内存缓存返回
@@ -148,7 +181,7 @@ class ConversationMemoryManager:
for m in sorted_list
]
async def add_memory(
async def add_conversation(
self,
session_id: str,
user_id: str,
@@ -156,8 +189,10 @@ class ConversationMemoryManager:
content: str,
metadata: Optional[Dict[str, Any]] = None
):
"""添加消息到记忆"""
memory = await self.get_memory(session_id=session_id, user_id=user_id)
"""
添加消息到记忆
"""
memory = await self.get_conversation(session_id=session_id, user_id=user_id)
message = {
"role": role,
@@ -177,7 +212,7 @@ class ConversationMemoryManager:
recent_messages = memory.messages[-(max_messages - len(system_messages)):]
memory.messages = system_messages + recent_messages
await self._save_memory(memory)
await self._save_conversation(memory)
logger.debug(f"消息已添加到记忆: session_id={session_id}, user_id={user_id}, role={role}")
@@ -186,11 +221,12 @@ class ConversationMemoryManager:
session_id: str,
user_id: str
) -> List[Dict[str, Any]]:
"""为Agent获取最近的消息仅内存缓存
"""
为Agent获取最近的消息仅内存缓存
如果消息Token数量超过模型最大上下文长度的阀值会自动进行摘要裁剪
"""
cache_key = self.get_memory_key(session_id, user_id)
cache_key = self._get_memory_key(session_id, user_id)
memory = self.memory_cache.get(cache_key)
if not memory:
return []
@@ -205,8 +241,10 @@ class ConversationMemoryManager:
limit: int = 10,
role_filter: Optional[list] = None
) -> List[Dict[str, Any]]:
"""获取最近的消息"""
memory = await self.get_memory(session_id=session_id, user_id=user_id)
"""
获取最近的消息
"""
memory = await self.get_conversation(session_id=session_id, user_id=user_id)
messages = memory.messages
if role_filter:
@@ -215,36 +253,41 @@ class ConversationMemoryManager:
return messages[-limit:] if messages else []
async def get_context(self, session_id: str, user_id: str) -> Dict[str, Any]:
"""获取会话上下文"""
memory = await self.get_memory(session_id=session_id, user_id=user_id)
"""
获取会话上下文
"""
memory = await self.get_conversation(session_id=session_id, user_id=user_id)
return memory.context
async def clear_memory(self, session_id: str, user_id: str):
"""清空会话记忆"""
"""
清空会话记忆
"""
cache_key = f"{user_id}:{session_id}" if user_id else session_id
if cache_key in self.memory_cache:
del self.memory_cache[cache_key]
if settings.CACHE_BACKEND_TYPE == "redis":
redis_key = self.get_redis_key(session_id, user_id)
redis_key = self._get_redis_key(session_id, user_id)
await self.redis_helper.delete(redis_key, region="AI_AGENT")
logger.info(f"会话记忆已清空: session_id={session_id}, user_id={user_id}")
async def _save_memory(self, memory: ConversationMemory):
"""保存记忆到存储
Redis中的记忆会自动通过TTL机制过期无需手动清理
def _save_memory(self, memory: ConversationMemory):
"""
# 更新内存缓存
cache_key = self.get_memory_key(memory.session_id, memory.user_id)
保存记忆到内存
"""
cache_key = self._get_memory_key(memory.session_id, memory.user_id)
self.memory_cache[cache_key] = memory
# 保存到Redis设置TTL自动过期
async def _save_redis(self, memory: ConversationMemory):
"""
保存记忆到Redis
"""
if settings.CACHE_BACKEND_TYPE == "redis":
try:
memory_dict = memory.model_dump()
redis_key = self.get_redis_key(memory.session_id, memory.user_id)
redis_key = self._get_redis_key(memory.session_id, memory.user_id)
ttl = int(timedelta(days=settings.LLM_REDIS_MEMORY_RETENTION_DAYS).total_seconds())
await self.redis_helper.set(
redis_key,
@@ -255,8 +298,22 @@ class ConversationMemoryManager:
except Exception as e:
logger.warning(f"保存记忆到Redis失败: {e}")
async def _save_conversation(self, memory: ConversationMemory):
"""
保存记忆到存储
Redis中的记忆会自动通过TTL机制过期无需手动清理
"""
# 更新内存缓存
self._save_memory(memory)
# 保存到Redis设置TTL自动过期
await self._save_redis(memory)
async def _cleanup_expired_memories(self):
"""清理内存中过期记忆的后台任务
"""
清理内存中过期记忆的后台任务
注意Redis中的记忆通过TTL机制自动过期这里只清理内存缓存
"""
@@ -286,3 +343,5 @@ class ConversationMemoryManager:
break
except Exception as e:
logger.error(f"清理记忆时发生错误: {e}")
conversation_manager = ConversationMemoryManager()

View File

@@ -4,6 +4,20 @@ All your responses must be in **Chinese (中文)**.
You act as a proactive agent. Your goal is to fully resolve the user's media-related requests autonomously. Do not end your turn until the task is complete or you are blocked and require user feedback.
Core Capabilities:
1. Media Search & Recognition
- Identify movies, TV shows, and anime across various metadata providers.
- Recognize media info from fuzzy filenames or incomplete titles.
2. Subscription Management
- Create complex rules for automated downloading of new episodes.
- Monitor trending movies/shows for automated suggestions.
3. Download Control
- Intelligent torrent searching across private/public trackers.
- Filter resources by quality (4K/1080p), codec (H265/H264), and release groups.
4. System Status & Organization
- Monitor download progress and server health.
- Manage file transfers, renaming, and library cleanup.
<communication>
- Use Markdown for structured data like movie lists, download statuses, or technical details.
- Avoid wrapping the entire response in a single code block. Use `inline code` for titles or parameters and ```code blocks``` for structured logs or data only when necessary.
@@ -11,6 +25,11 @@ You act as a proactive agent. Your goal is to fully resolve the user's media-rel
- Optimize your writing for clarity and readability, using bold text for key information.
- Provide comprehensive details for media (year, rating, resolution) to help users make informed decisions.
- Do not stop for approval for read-only operations. Only stop for critical actions like starting a download or deleting a subscription.
Important Notes:
- User-Centric: Your tone should be helpful, professional, and media-savvy.
- No Coding Hallucinations: You are NOT a coding assistant. Do not offer code snippets, IDE tips, or programming help. Focus entirely on the MoviePilot media ecosystem.
- Contextual Memory: Remember if the user preferred a specific version previously and prioritize similar results in future searches.
</communication>
<status_update_spec>
@@ -28,44 +47,26 @@ At the end of your session/turn, provide a concise summary of your actions.
</summary_spec>
<flow>
1. **Media Discovery**: Start by identifying the exact media metadata (TMDB ID, Season/Episode) using search tools.
2. **Context Checking**: Verify current status (Is it already in the library? Is it already subscribed?).
3. **Action Execution**: Perform the requested task (Subscribe, Search Torrents, etc.) with a brief status update.
4. **Final Confirmation**: Summarize the final state and wait for the next user command.
1. Media Discovery: Start by identifying the exact media metadata (TMDB ID, Season/Episode) using search tools.
2. Context Checking: Verify current status (Is it already in the library? Is it already subscribed?).
3. Action Execution: Perform the requested task (Subscribe, Search Torrents, etc.) with a brief status update.
4. Final Confirmation: Summarize the final state and wait for the next user command.
</flow>
<tool_calling_strategy>
- **Parallel Execution**: You MUST call independent tools in parallel. For example, search for torrents on multiple sites or check both subscription and download status at once.
- **Information Depth**: If a search returns ambiguous results, use `query_media_detail` or `recognize_media` to resolve the ambiguity before proceeding.
- **Proactive Fallback**: If `search_media` fails, try `search_web` or fuzzy search with `recognize_media`. Do not ask the user for help unless all automated search methods are exhausted.
- Parallel Execution: You MUST call independent tools in parallel. For example, search for torrents on multiple sites or check both subscription and download status at once.
- Information Depth: If a search returns ambiguous results, use `query_media_detail` or `recognize_media` to resolve the ambiguity before proceeding.
- Proactive Fallback: If `search_media` fails, try `search_web` or fuzzy search with `recognize_media`. Do not ask the user for help unless all automated search methods are exhausted.
</tool_calling_strategy>
<media_management_rules>
1. **Download Safety**: You MUST present a list of found torrents (including size, seeds, and quality) and obtain the user's explicit consent before initiating any download.
2. **Subscription Logic**: When adding a subscription, always check for the best matching quality profile based on user history or the default settings.
3. **Library Awareness**: Always check if the user already has the content in their library to avoid duplicate downloads.
4. **Error Handling**: If a site is down or a tool returns an error, explain the situation in plain Chinese (e.g., "站点响应超时") and suggest an alternative (e.g., "尝试从其他站点进行搜索").
1. Download Safety: You MUST present a list of found torrents (including size, seeds, and quality) and obtain the user's explicit consent before initiating any download.
2. Subscription Logic: When adding a subscription, always check for the best matching quality profile based on user history or the default settings.
3. Library Awareness: Always check if the user already has the content in their library to avoid duplicate downloads.
4. Error Handling: If a site is down or a tool returns an error, explain the situation in plain Chinese (e.g., "站点响应超时") and suggest an alternative (e.g., "尝试从其他站点进行搜索").
</media_management_rules>
## Core Capabilities
### 1. Media Search & Recognition
- Identify movies, TV shows, and anime across various metadata providers.
- Recognize media info from fuzzy filenames or incomplete titles.
### 2. Subscription Management
- Create complex rules for automated downloading of new episodes.
- Monitor trending movies/shows for automated suggestions.
### 3. Download Control
- Intelligent torrent searching across private/public trackers.
- Filter resources by quality (4K/1080p), codec (H265/H264), and release groups.
### 4. System Status & Organization
- Monitor download progress and server health.
- Manage file transfers, renaming, and library cleanup.
## Important Notes
- **User-Centric**: Your tone should be helpful, professional, and media-savvy.
- **No Coding Hallucinations**: You are NOT a coding assistant. Do not offer code snippets, IDE tips, or programming help. Focus entirely on the MoviePilot media ecosystem.
- **Contextual Memory**: Remember if the user preferred a specific version previously and prioritize similar results in future searches.
<markdown_spec>
Specific markdown rules:
{markdown_spec}
</markdown_spec>

View File

@@ -1,13 +1,15 @@
"""提示词管理器"""
from pathlib import Path
from typing import Dict
from app.log import logger
from app.schemas import ChannelCapability, ChannelCapabilities, MessageChannel, ChannelCapabilityManager
class PromptManager:
"""提示词管理器"""
"""
提示词管理器
"""
def __init__(self, prompts_dir: str = None):
if prompts_dir is None:
@@ -17,22 +19,20 @@ class PromptManager:
self.prompts_cache: Dict[str, str] = {}
def load_prompt(self, prompt_name: str) -> str:
"""加载指定的提示词"""
"""
加载指定的提示词
"""
if prompt_name in self.prompts_cache:
return self.prompts_cache[prompt_name]
prompt_file = self.prompts_dir / prompt_name
try:
with open(prompt_file, 'r', encoding='utf-8') as f:
content = f.read().strip()
# 缓存提示词
self.prompts_cache[prompt_name] = content
logger.info(f"提示词加载成功: {prompt_name},长度:{len(content)} 字符")
return content
except FileNotFoundError:
logger.error(f"提示词文件不存在: {prompt_file}")
raise
@@ -46,73 +46,43 @@ class PromptManager:
:param channel: 消息渠道Telegram、微信、Slack等
:return: 提示词内容
"""
# 基础提示词
base_prompt = self.load_prompt("Agent Prompt.txt")
# 根据渠道添加特定的格式说明
if channel:
channel_format_info = self._get_channel_format_info(channel)
if channel_format_info:
base_prompt += f"\n\n## Current Message Channel Format Requirements\n\n{channel_format_info}"
# 识别渠道
msg_channel = next((c for c in MessageChannel if c.value.lower() == channel.lower()), None) if channel else None
if msg_channel:
# 获取渠道能力说明
caps = ChannelCapabilityManager.get_capabilities(msg_channel)
if caps:
base_prompt = base_prompt.replace(
"{markdown_spec}",
self._generate_formatting_instructions(caps)
)
return base_prompt
@staticmethod
def _get_channel_format_info(channel: str) -> str:
def _generate_formatting_instructions(caps: ChannelCapabilities) -> str:
"""
获取渠道特定的格式说明
:param channel: 消息渠道
:return: 格式说明文本
根据渠道能力动态生成格式指令
"""
channel_lower = channel.lower() if channel else ""
if "telegram" in channel_lower:
return """Messages are being sent through the **Telegram** channel. You must follow these format requirements:
**Supported Formatting:**
- **Bold text**: Use `*text*` (single asterisk, not double asterisks)
- **Italic text**: Use `_text_` (underscore)
- **Code**: Use `` `text` `` (backtick)
- **Links**: Use `[text](url)` format
- **Strikethrough**: Use `~text~` (tilde)
**IMPORTANT - Headings and Lists:**
- **DO NOT use heading syntax** (`#`, `##`, `###`) - Telegram MarkdownV2 does NOT support it
- **Instead, use bold text for headings**: `*Heading Text*` followed by a blank line
- **DO NOT use list syntax** (`-`, `*`, `+` at line start) - these will be escaped and won't display as lists
- **For lists**, use plain text with line breaks, or use bold for list item labels: `*Item 1:* description`
**Examples:**
- ❌ Wrong heading: `# Main Title` or `## Subtitle`
- ✅ Correct heading: `*Main Title*` (followed by blank line) or `*Subtitle*` (followed by blank line)
- ❌ Wrong list: `- Item 1` or `* Item 2`
- ✅ Correct list format: `*Item 1:* description` or use plain text with line breaks
**Special Characters:**
- Avoid using special characters that need escaping in MarkdownV2: `_*[]()~`>#+-=|{}.!` unless they are part of the formatting syntax
- Keep formatting simple, avoid nested formatting to ensure proper rendering in Telegram"""
elif "wechat" in channel_lower or "微信" in channel:
return """Messages are being sent through the **WeChat** channel. Please follow these format requirements:
- WeChat does NOT support Markdown formatting. Use plain text format only.
- Do NOT use any Markdown syntax (such as `**bold**`, `*italic*`, `` `code` `` etc.)
- Use plain text descriptions. You can organize content using line breaks and punctuation
- Links can be provided directly as URLs, no Markdown link format needed
- Keep messages concise and clear, use natural Chinese expressions"""
elif "slack" in channel_lower:
return """Messages are being sent through the **Slack** channel. Please follow these format requirements:
- Slack supports Markdown formatting
- Use `*text*` for bold
- Use `_text_` for italic
- Use `` `text` `` for code
- Link format: `<url|text>` or `[text](url)`"""
# 其他渠道使用标准Markdown
return None
instructions = []
if ChannelCapability.RICH_TEXT not in caps.capabilities:
instructions.append("- Formatting: Use **Plain Text ONLY**. The channel does NOT support Markdown.")
instructions.append(
"- No Markdown Symbols: NEVER use `**`, `*`, `__`, or `[` blocks. Use natural text to emphasize (e.g., using ALL CAPS or separators).")
instructions.append(
"- Lists: Use plain text symbols like `>` or `*` at the start of lines, followed by manual line breaks.")
instructions.append("- Links: Paste URLs directly as text.")
return "\n".join(instructions)
def clear_cache(self):
"""清空缓存"""
"""
清空缓存
"""
self.prompts_cache.clear()
logger.info("提示词缓存已清空")
prompt_manager = PromptManager()

View File

@@ -1,4 +1,3 @@
"""MoviePilot工具基类"""
import json
from abc import ABCMeta, abstractmethod
from typing import Any, Optional
@@ -6,7 +5,7 @@ from typing import Any, Optional
from langchain.tools import BaseTool
from pydantic import PrivateAttr
from app.agent import StreamingCallbackHandler, ConversationMemoryManager
from app.agent import StreamingCallbackHandler, conversation_manager
from app.chain import ChainBase
from app.log import logger
from app.schemas import Notification
@@ -17,7 +16,9 @@ class ToolChain(ChainBase):
class MoviePilotTool(BaseTool, metaclass=ABCMeta):
"""MoviePilot专用工具基类"""
"""
MoviePilot专用工具基类
"""
_session_id: str = PrivateAttr()
_user_id: str = PrivateAttr()
@@ -25,7 +26,6 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
_source: str = PrivateAttr(default=None)
_username: str = PrivateAttr(default=None)
_callback_handler: StreamingCallbackHandler = PrivateAttr(default=None)
_memory_manager: ConversationMemoryManager = PrivateAttr(default=None)
def __init__(self, session_id: str, user_id: str, **kwargs):
super().__init__(**kwargs)
@@ -36,7 +36,9 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
pass
async def _arun(self, **kwargs) -> str:
"""异步运行工具"""
"""
异步运行工具
"""
# 发送和记忆工具调用前的信息
agent_message = await self._callback_handler.get_message()
if agent_message:
@@ -44,7 +46,7 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
await self.send_tool_message(agent_message, title="MoviePilot助手")
# 记忆工具调用
await self._memory_manager.add_memory(
await conversation_manager.add_conversation(
session_id=self._session_id,
user_id=self._user_id,
role="tool_call",
@@ -77,7 +79,7 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
formated_result = str(result)
else:
formated_result = json.dumps(result, ensure_ascii=False, indent=2)
await self._memory_manager.add_memory(
await conversation_manager.add_conversation(
session_id=self._session_id,
user_id=self._user_id,
role="tool_result",
@@ -106,21 +108,23 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
raise NotImplementedError
def set_message_attr(self, channel: str, source: str, username: str):
"""设置消息属性"""
"""
设置消息属性
"""
self._channel = channel
self._source = source
self._username = username
def set_callback_handler(self, callback_handler: StreamingCallbackHandler):
"""设置回调处理器"""
"""
设置回调处理器
"""
self._callback_handler = callback_handler
def set_memory_manager(self, memory_manager: ConversationMemoryManager):
"""设置记忆客理器"""
self._memory_manager = memory_manager
async def send_tool_message(self, message: str, title: str = ""):
"""发送工具消息"""
"""
发送工具消息
"""
await ToolChain().async_post_message(
Notification(
channel=self._channel,

View File

@@ -1,5 +1,3 @@
"""MoviePilot工具工厂"""
from typing import List, Callable
from app.agent.tools.impl.add_download import AddDownloadTool
@@ -47,13 +45,17 @@ from .base import MoviePilotTool
class MoviePilotToolFactory:
"""MoviePilot工具工厂"""
"""
MoviePilot工具工厂
"""
@staticmethod
def create_tools(session_id: str, user_id: str,
channel: str = None, source: str = None, username: str = None,
callback_handler: Callable = None, memory_mananger: Callable = None) -> List[MoviePilotTool]:
"""创建MoviePilot工具列表"""
callback_handler: Callable = None) -> List[MoviePilotTool]:
"""
创建MoviePilot工具列表
"""
tools = []
tool_definitions = [
SearchMediaTool,
@@ -104,7 +106,6 @@ class MoviePilotToolFactory:
)
tool.set_message_attr(channel=channel, source=source, username=username)
tool.set_callback_handler(callback_handler=callback_handler)
tool.set_memory_manager(memory_manager=memory_mananger)
tools.append(tool)
# 加载插件提供的工具
@@ -127,7 +128,6 @@ class MoviePilotToolFactory:
)
tool.set_message_attr(channel=channel, source=source, username=username)
tool.set_callback_handler(callback_handler=callback_handler)
tool.set_memory_manager(memory_manager=memory_mananger)
tools.append(tool)
plugin_tools_count += 1
logger.debug(f"成功加载插件 {plugin_name}({plugin_id}) 的工具: {ToolClass.__name__}")

View File

@@ -1,18 +1,15 @@
"""MoviePilot工具管理器
用于HTTP API调用工具
"""
import json
import uuid
from typing import Any, Dict, List, Optional
from app.agent import ConversationMemoryManager
from app.agent.tools.factory import MoviePilotToolFactory
from app.log import logger
class ToolDefinition:
"""工具定义"""
"""
工具定义
"""
def __init__(self, name: str, description: str, input_schema: Dict[str, Any]):
self.name = name
@@ -21,7 +18,9 @@ class ToolDefinition:
class MoviePilotToolsManager:
"""MoviePilot工具管理器用于HTTP API"""
"""
MoviePilot工具管理器用于HTTP API
"""
def __init__(self, user_id: str = "api_user", session_id: str = uuid.uuid4()):
"""
@@ -34,11 +33,12 @@ class MoviePilotToolsManager:
self.user_id = user_id
self.session_id = session_id
self.tools: List[Any] = []
self.memory_manager = ConversationMemoryManager()
self._load_tools()
def _load_tools(self):
"""加载所有MoviePilot工具"""
"""
加载所有MoviePilot工具
"""
try:
# 创建工具实例
self.tools = MoviePilotToolFactory.create_tools(
@@ -48,7 +48,6 @@ class MoviePilotToolsManager:
source="api",
username="API Client",
callback_handler=None,
memory_mananger=None,
)
logger.info(f"成功加载 {len(self.tools)} 个工具")
except Exception as e:
@@ -116,7 +115,7 @@ class MoviePilotToolsManager:
args_schema = getattr(tool_instance, 'args_schema', None)
if not args_schema:
return arguments
# 获取schema中的字段定义
try:
schema = args_schema.model_json_schema()
@@ -124,7 +123,7 @@ class MoviePilotToolsManager:
except Exception as e:
logger.warning(f"获取工具schema失败: {e}")
return arguments
# 规范化参数
normalized = {}
for key, value in arguments.items():
@@ -132,10 +131,10 @@ class MoviePilotToolsManager:
# 参数不在schema中保持原样
normalized[key] = value
continue
field_info = properties[key]
field_type = field_info.get("type")
# 处理 anyOf 类型(例如 Optional[int] 会生成 anyOf
any_of = field_info.get("anyOf")
if any_of and not field_type:
@@ -144,7 +143,7 @@ class MoviePilotToolsManager:
if "type" in type_option and type_option["type"] != "null":
field_type = type_option["type"]
break
# 根据类型进行转换
if field_type == "integer" and isinstance(value, str):
try:
@@ -167,7 +166,7 @@ class MoviePilotToolsManager:
normalized[key] = True
else:
normalized[key] = value
return normalized
async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> str:
@@ -192,7 +191,7 @@ class MoviePilotToolsManager:
try:
# 规范化参数类型
normalized_arguments = self._normalize_arguments(tool_instance, arguments)
# 调用工具的run方法
result = await tool_instance.run(**normalized_arguments)
@@ -270,3 +269,6 @@ class MoviePilotToolsManager:
"properties": properties,
"required": required
}
moviepilot_tool_manager = MoviePilotToolsManager()

View File

@@ -1,14 +1,10 @@
"""工具API端点
通过HTTP API暴露MoviePilot的智能体工具功能
"""
from typing import List, Any, Dict, Annotated, Union
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import JSONResponse, Response
from app import schemas
from app.agent.tools.manager import MoviePilotToolsManager
from app.agent.tools.manager import moviepilot_tool_manager
from app.core.security import verify_apikey
from app.log import logger
@@ -25,18 +21,10 @@ MCP_PROTOCOL_VERSIONS = ["2025-11-25", "2025-06-18", "2024-11-05"]
MCP_PROTOCOL_VERSION = MCP_PROTOCOL_VERSIONS[0] # 默认使用最新版本
def get_tools_manager() -> MoviePilotToolsManager:
"""
获取工具管理器实例
Returns:
MoviePilotToolsManager实例
"""
return MoviePilotToolsManager()
def create_jsonrpc_response(request_id: Union[str, int, None], result: Any) -> Dict[str, Any]:
"""创建 JSON-RPC 成功响应"""
"""
创建 JSON-RPC 成功响应
"""
response = {
"jsonrpc": "2.0",
"id": request_id,
@@ -45,8 +33,11 @@ def create_jsonrpc_response(request_id: Union[str, int, None], result: Any) -> D
return response
def create_jsonrpc_error(request_id: Union[str, int, None], code: int, message: str, data: Any = None) -> Dict[str, Any]:
"""创建 JSON-RPC 错误响应"""
def create_jsonrpc_error(request_id: Union[str, int, None], code: int, message: str, data: Any = None) -> Dict[
str, Any]:
"""
创建 JSON-RPC 错误响应
"""
error = {
"jsonrpc": "2.0",
"id": request_id,
@@ -60,8 +51,6 @@ def create_jsonrpc_error(request_id: Union[str, int, None], code: int, message:
return error
# ==================== MCP JSON-RPC 端点 ====================
@router.post("", summary="MCP JSON-RPC 端点", response_model=None)
async def mcp_jsonrpc(
request: Request,
@@ -146,7 +135,9 @@ async def mcp_jsonrpc(
async def handle_initialize(params: Dict[str, Any]) -> Dict[str, Any]:
"""处理初始化请求"""
"""
处理初始化请求
"""
protocol_version = params.get("protocolVersion")
client_info = params.get("clientInfo", {})
@@ -161,7 +152,7 @@ async def handle_initialize(params: Dict[str, Any]) -> Dict[str, Any]:
else:
# 客户端版本不支持,使用服务器默认版本
logger.warning(f"协议版本不匹配: 客户端={protocol_version}, 使用服务器版本={negotiated_version}")
return {
"protocolVersion": negotiated_version,
"capabilities": {
@@ -180,9 +171,10 @@ async def handle_initialize(params: Dict[str, Any]) -> Dict[str, Any]:
async def handle_tools_list() -> Dict[str, Any]:
"""处理工具列表请求"""
manager = get_tools_manager()
tools = manager.list_tools()
"""
处理工具列表请求
"""
tools = moviepilot_tool_manager.list_tools()
# 转换为 MCP 工具格式
mcp_tools = []
@@ -200,18 +192,18 @@ async def handle_tools_list() -> Dict[str, Any]:
async def handle_tools_call(params: Dict[str, Any]) -> Dict[str, Any]:
"""处理工具调用请求"""
"""
处理工具调用请求
"""
tool_name = params.get("name")
arguments = params.get("arguments", {})
if not tool_name:
raise ValueError("Missing tool name")
manager = get_tools_manager()
try:
result_text = await manager.call_tool(tool_name, arguments)
result_text = await moviepilot_tool_manager.call_tool(tool_name, arguments)
return {
"content": [
{
@@ -243,8 +235,6 @@ async def delete_mcp_session(
return Response(status_code=204)
# ==================== 兼容的 RESTful API 端点 ====================
@router.get("/tools", summary="列出所有可用工具", response_model=List[Dict[str, Any]])
@@ -257,9 +247,8 @@ async def list_tools(
返回每个工具的名称、描述和参数定义
"""
try:
manager = get_tools_manager()
# 获取所有工具定义
tools = manager.list_tools()
tools = moviepilot_tool_manager.list_tools()
# 转换为字典格式
tools_list = []
@@ -289,11 +278,8 @@ async def call_tool(
工具执行结果
"""
try:
# 使用当前用户ID创建管理器实例
manager = get_tools_manager()
# 调用工具
result_text = await manager.call_tool(request.tool_name, request.arguments)
result_text = await moviepilot_tool_manager.call_tool(request.tool_name, request.arguments)
return schemas.ToolCallResponse(
success=True,
@@ -319,9 +305,8 @@ async def get_tool_info(
工具的详细信息,包括名称、描述和参数定义
"""
try:
manager = get_tools_manager()
# 获取所有工具
tools = manager.list_tools()
tools = moviepilot_tool_manager.list_tools()
# 查找指定工具
for tool in tools:
@@ -352,9 +337,8 @@ async def get_tool_schema(
工具的JSON Schema定义
"""
try:
manager = get_tools_manager()
# 获取所有工具
tools = manager.list_tools()
tools = moviepilot_tool_manager.list_tools()
# 查找指定工具
for tool in tools:

View File

@@ -1,20 +1,17 @@
"""
AI智能体初始化器
"""
import asyncio
import threading
from typing import Optional
from app.agent import agent_manager
from app.core.config import settings
from app.log import logger
class AgentInitializer:
"""AI智能体初始化器"""
"""
AI智能体初始化器
"""
def __init__(self):
self.agent_manager = None
self._initialized = False
async def initialize(self) -> bool:
@@ -26,9 +23,6 @@ class AgentInitializer:
logger.info("AI智能体功能未启用")
return True
from app.agent import agent_manager
self.agent_manager = agent_manager
await agent_manager.initialize()
self._initialized = True
logger.info("AI智能体管理器初始化成功")
@@ -43,10 +37,10 @@ class AgentInitializer:
清理AI智能体管理器
"""
try:
if not self._initialized or not self.agent_manager:
if not self._initialized:
return
await self.agent_manager.close()
await agent_manager.close()
self._initialized = False
logger.info("AI智能体管理器已关闭")
@@ -78,8 +72,8 @@ def init_agent():
else:
logger.error("AI智能体管理器初始化失败")
return success
except Exception as e:
logger.error(f"AI智能体管理器初始化失败: {e}")
except Exception as err:
logger.error(f"AI智能体管理器初始化失败: {err}")
return False
finally:
loop.close()