mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-02-03 02:25:32 +08:00
feat: 为工具管理器添加参数类型规范化处理,并基于渠道能力动态生成提示词中的格式要求
This commit is contained in:
@@ -1,5 +1,3 @@
|
||||
"""MoviePilot AI智能体实现"""
|
||||
|
||||
import asyncio
|
||||
from typing import Dict, List, Any
|
||||
|
||||
@@ -11,8 +9,8 @@ from langchain_core.messages import HumanMessage, AIMessage, ToolCall, ToolMessa
|
||||
from langchain_core.runnables.history import RunnableWithMessageHistory
|
||||
|
||||
from app.agent.callback import StreamingCallbackHandler
|
||||
from app.agent.memory import ConversationMemoryManager
|
||||
from app.agent.prompt import PromptManager
|
||||
from app.agent.memory import conversation_manager
|
||||
from app.agent.prompt import prompt_manager
|
||||
from app.agent.tools.factory import MoviePilotToolFactory
|
||||
from app.chain import ChainBase
|
||||
from app.core.config import settings
|
||||
@@ -27,7 +25,9 @@ class AgentChain(ChainBase):
|
||||
|
||||
|
||||
class MoviePilotAgent:
|
||||
"""MoviePilot AI智能体"""
|
||||
"""
|
||||
MoviePilot AI智能体
|
||||
"""
|
||||
|
||||
def __init__(self, session_id: str, user_id: str = None,
|
||||
channel: str = None, source: str = None, username: str = None):
|
||||
@@ -40,12 +40,6 @@ class MoviePilotAgent:
|
||||
# 消息助手
|
||||
self.message_helper = MessageHelper()
|
||||
|
||||
# 记忆管理器
|
||||
self.memory_manager = ConversationMemoryManager()
|
||||
|
||||
# 提示词管理器
|
||||
self.prompt_manager = PromptManager()
|
||||
|
||||
# 回调处理器
|
||||
self.callback_handler = StreamingCallbackHandler(
|
||||
session_id=session_id
|
||||
@@ -64,30 +58,37 @@ class MoviePilotAgent:
|
||||
self.agent_executor = self._create_agent_executor()
|
||||
|
||||
def _initialize_llm(self):
|
||||
"""初始化LLM模型"""
|
||||
"""
|
||||
初始化LLM模型
|
||||
"""
|
||||
return LLMHelper.get_llm(streaming=True, callbacks=[self.callback_handler])
|
||||
|
||||
def _initialize_tools(self) -> List:
|
||||
"""初始化工具列表"""
|
||||
"""
|
||||
初始化工具列表
|
||||
"""
|
||||
return MoviePilotToolFactory.create_tools(
|
||||
session_id=self.session_id,
|
||||
user_id=self.user_id,
|
||||
channel=self.channel,
|
||||
source=self.source,
|
||||
username=self.username,
|
||||
callback_handler=self.callback_handler,
|
||||
memory_mananger=self.memory_manager
|
||||
callback_handler=self.callback_handler
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _initialize_session_store() -> Dict[str, InMemoryChatMessageHistory]:
|
||||
"""初始化内存存储"""
|
||||
"""
|
||||
初始化内存存储
|
||||
"""
|
||||
return {}
|
||||
|
||||
def get_session_history(self, session_id: str) -> InMemoryChatMessageHistory:
|
||||
"""获取会话历史"""
|
||||
"""
|
||||
获取会话历史
|
||||
"""
|
||||
chat_history = InMemoryChatMessageHistory()
|
||||
messages: List[dict] = self.memory_manager.get_recent_messages_for_agent(
|
||||
messages: List[dict] = conversation_manager.get_recent_messages_for_agent(
|
||||
session_id=session_id,
|
||||
user_id=self.user_id
|
||||
)
|
||||
@@ -119,7 +120,9 @@ class MoviePilotAgent:
|
||||
|
||||
@staticmethod
|
||||
def _initialize_prompt() -> ChatPromptTemplate:
|
||||
"""初始化提示词模板"""
|
||||
"""
|
||||
初始化提示词模板
|
||||
"""
|
||||
try:
|
||||
prompt_template = ChatPromptTemplate.from_messages([
|
||||
("system", "{system_prompt}"),
|
||||
@@ -134,7 +137,9 @@ class MoviePilotAgent:
|
||||
raise e
|
||||
|
||||
def _create_agent_executor(self) -> RunnableWithMessageHistory:
|
||||
"""创建Agent执行器"""
|
||||
"""
|
||||
创建Agent执行器
|
||||
"""
|
||||
try:
|
||||
agent = create_openai_tools_agent(
|
||||
llm=self.llm,
|
||||
@@ -161,10 +166,12 @@ class MoviePilotAgent:
|
||||
raise e
|
||||
|
||||
async def process_message(self, message: str) -> str:
|
||||
"""处理用户消息"""
|
||||
"""
|
||||
处理用户消息
|
||||
"""
|
||||
try:
|
||||
# 添加用户消息到记忆
|
||||
await self.memory_manager.add_memory(
|
||||
await conversation_manager.add_conversation(
|
||||
self.session_id,
|
||||
user_id=self.user_id,
|
||||
role="user",
|
||||
@@ -173,7 +180,7 @@ class MoviePilotAgent:
|
||||
|
||||
# 构建输入上下文
|
||||
input_context = {
|
||||
"system_prompt": self.prompt_manager.get_agent_prompt(channel=self.channel),
|
||||
"system_prompt": prompt_manager.get_agent_prompt(channel=self.channel),
|
||||
"input": message
|
||||
}
|
||||
|
||||
@@ -190,7 +197,7 @@ class MoviePilotAgent:
|
||||
await self.send_agent_message(agent_message)
|
||||
|
||||
# 添加Agent回复到记忆
|
||||
await self.memory_manager.add_memory(
|
||||
await conversation_manager.add_conversation(
|
||||
session_id=self.session_id,
|
||||
user_id=self.user_id,
|
||||
role="agent",
|
||||
@@ -210,7 +217,9 @@ class MoviePilotAgent:
|
||||
return error_message
|
||||
|
||||
async def _execute_agent(self, input_context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""执行LangChain Agent"""
|
||||
"""
|
||||
执行LangChain Agent
|
||||
"""
|
||||
try:
|
||||
with get_openai_callback() as cb:
|
||||
result = await self.agent_executor.ainvoke(
|
||||
@@ -243,7 +252,9 @@ class MoviePilotAgent:
|
||||
}
|
||||
|
||||
async def send_agent_message(self, message: str, title: str = "MoviePilot助手"):
|
||||
"""通过原渠道发送消息给用户"""
|
||||
"""
|
||||
通过原渠道发送消息给用户
|
||||
"""
|
||||
await AgentChain().async_post_message(
|
||||
Notification(
|
||||
channel=self.channel,
|
||||
@@ -256,24 +267,32 @@ class MoviePilotAgent:
|
||||
)
|
||||
|
||||
async def cleanup(self):
|
||||
"""清理智能体资源"""
|
||||
"""
|
||||
清理智能体资源
|
||||
"""
|
||||
logger.info(f"MoviePilot智能体已清理: session_id={self.session_id}")
|
||||
|
||||
|
||||
class AgentManager:
|
||||
"""AI智能体管理器"""
|
||||
"""
|
||||
AI智能体管理器
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.active_agents: Dict[str, MoviePilotAgent] = {}
|
||||
self.memory_manager = ConversationMemoryManager()
|
||||
|
||||
async def initialize(self):
|
||||
"""初始化管理器"""
|
||||
await self.memory_manager.initialize()
|
||||
@staticmethod
|
||||
async def initialize():
|
||||
"""
|
||||
初始化管理器
|
||||
"""
|
||||
await conversation_manager.initialize()
|
||||
|
||||
async def close(self):
|
||||
"""关闭管理器"""
|
||||
await self.memory_manager.close()
|
||||
"""
|
||||
关闭管理器
|
||||
"""
|
||||
await conversation_manager.close()
|
||||
# 清理所有活跃的智能体
|
||||
for agent in self.active_agents.values():
|
||||
await agent.cleanup()
|
||||
@@ -281,7 +300,9 @@ class AgentManager:
|
||||
|
||||
async def process_message(self, session_id: str, user_id: str, message: str,
|
||||
channel: str = None, source: str = None, username: str = None) -> str:
|
||||
"""处理用户消息"""
|
||||
"""
|
||||
处理用户消息
|
||||
"""
|
||||
# 获取或创建Agent实例
|
||||
if session_id not in self.active_agents:
|
||||
logger.info(f"创建新的AI智能体实例,session_id: {session_id}, user_id: {user_id}")
|
||||
@@ -292,7 +313,6 @@ class AgentManager:
|
||||
source=source,
|
||||
username=username
|
||||
)
|
||||
agent.memory_manager = self.memory_manager
|
||||
self.active_agents[session_id] = agent
|
||||
else:
|
||||
agent = self.active_agents[session_id]
|
||||
@@ -309,12 +329,14 @@ class AgentManager:
|
||||
return await agent.process_message(message)
|
||||
|
||||
async def clear_session(self, session_id: str, user_id: str):
|
||||
"""清空会话"""
|
||||
"""
|
||||
清空会话
|
||||
"""
|
||||
if session_id in self.active_agents:
|
||||
agent = self.active_agents[session_id]
|
||||
await agent.cleanup()
|
||||
del self.active_agents[session_id]
|
||||
await self.memory_manager.clear_memory(session_id, user_id)
|
||||
await conversation_manager.clear_memory(session_id, user_id)
|
||||
logger.info(f"会话 {session_id} 的记忆已清空")
|
||||
|
||||
|
||||
|
||||
@@ -6,7 +6,9 @@ from app.log import logger
|
||||
|
||||
|
||||
class StreamingCallbackHandler(AsyncCallbackHandler):
|
||||
"""流式输出回调处理器"""
|
||||
"""
|
||||
流式输出回调处理器
|
||||
"""
|
||||
|
||||
def __init__(self, session_id: str):
|
||||
self._lock = threading.Lock()
|
||||
@@ -14,7 +16,9 @@ class StreamingCallbackHandler(AsyncCallbackHandler):
|
||||
self.current_message = ""
|
||||
|
||||
async def get_message(self):
|
||||
"""获取当前消息内容,获取后清空"""
|
||||
"""
|
||||
获取当前消息内容,获取后清空
|
||||
"""
|
||||
with self._lock:
|
||||
if not self.current_message:
|
||||
return ""
|
||||
@@ -24,7 +28,9 @@ class StreamingCallbackHandler(AsyncCallbackHandler):
|
||||
return msg
|
||||
|
||||
async def on_llm_new_token(self, token: str, **kwargs):
|
||||
"""处理新的token"""
|
||||
"""
|
||||
处理新的token
|
||||
"""
|
||||
if not token:
|
||||
return
|
||||
with self._lock:
|
||||
|
||||
@@ -12,7 +12,9 @@ from app.schemas.agent import ConversationMemory
|
||||
|
||||
|
||||
class ConversationMemoryManager:
|
||||
"""对话记忆管理器"""
|
||||
"""
|
||||
对话记忆管理器
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# 内存中的会话记忆缓存
|
||||
@@ -23,7 +25,9 @@ class ConversationMemoryManager:
|
||||
self.cleanup_task: Optional[asyncio.Task] = None
|
||||
|
||||
async def initialize(self):
|
||||
"""初始化记忆管理器"""
|
||||
"""
|
||||
初始化记忆管理器
|
||||
"""
|
||||
try:
|
||||
# 启动内存缓存清理任务(Redis通过TTL自动过期)
|
||||
self.cleanup_task = asyncio.create_task(self._cleanup_expired_memories())
|
||||
@@ -33,7 +37,9 @@ class ConversationMemoryManager:
|
||||
logger.warning(f"Redis连接失败,将使用内存存储: {e}")
|
||||
|
||||
async def close(self):
|
||||
"""关闭记忆管理器"""
|
||||
"""
|
||||
关闭记忆管理器
|
||||
"""
|
||||
if self.cleanup_task:
|
||||
self.cleanup_task.cancel()
|
||||
try:
|
||||
@@ -46,56 +52,83 @@ class ConversationMemoryManager:
|
||||
logger.info("对话记忆管理器已关闭")
|
||||
|
||||
@staticmethod
|
||||
def get_memory_key(session_id: str, user_id: str):
|
||||
"""计算内存Key"""
|
||||
def _get_memory_key(session_id: str, user_id: str):
|
||||
"""
|
||||
计算内存Key
|
||||
"""
|
||||
return f"{user_id}:{session_id}" if user_id else session_id
|
||||
|
||||
@staticmethod
|
||||
def get_redis_key(session_id: str, user_id: str):
|
||||
"""计算Redis Key"""
|
||||
def _get_redis_key(session_id: str, user_id: str):
|
||||
"""
|
||||
计算Redis Key
|
||||
"""
|
||||
return f"agent_memory:{user_id}:{session_id}" if user_id else f"agent_memory:{session_id}"
|
||||
|
||||
async def get_memory(self, session_id: str, user_id: str) -> ConversationMemory:
|
||||
"""获取会话记忆"""
|
||||
# 首先检查缓存
|
||||
cache_key = self.get_memory_key(session_id, user_id)
|
||||
if cache_key in self.memory_cache:
|
||||
return self.memory_cache[cache_key]
|
||||
|
||||
# 尝试从Redis加载
|
||||
def _get_memory(self, session_id: str, user_id: str):
|
||||
"""
|
||||
获取内存中的记忆
|
||||
"""
|
||||
cache_key = self._get_memory_key(session_id, user_id)
|
||||
return self.memory_cache.get(cache_key)
|
||||
|
||||
async def _get_redis(self, session_id: str, user_id: str) -> Optional[ConversationMemory]:
|
||||
"""
|
||||
从Redis获取记忆
|
||||
"""
|
||||
if settings.CACHE_BACKEND_TYPE == "redis":
|
||||
try:
|
||||
redis_key = self.get_redis_key(session_id, user_id)
|
||||
redis_key = self._get_redis_key(session_id, user_id)
|
||||
memory_data = await self.redis_helper.get(redis_key, region="AI_AGENT")
|
||||
if memory_data:
|
||||
memory_dict = json.loads(memory_data) if isinstance(memory_data, str) else memory_data
|
||||
memory = ConversationMemory(**memory_dict)
|
||||
self.memory_cache[cache_key] = memory
|
||||
return memory
|
||||
except Exception as e:
|
||||
logger.warning(f"从Redis加载记忆失败: {e}")
|
||||
return None
|
||||
|
||||
async def get_conversation(self, session_id: str, user_id: str) -> ConversationMemory:
|
||||
"""
|
||||
获取会话记忆
|
||||
"""
|
||||
# 首先检查缓存
|
||||
conversion = self._get_memory(session_id, user_id)
|
||||
if conversion:
|
||||
return conversion
|
||||
|
||||
# 尝试从Redis加载
|
||||
memory = await self._get_redis(session_id, user_id)
|
||||
if memory:
|
||||
# 加载到内存缓存
|
||||
self._save_memory(memory)
|
||||
return memory
|
||||
|
||||
# 创建新的记忆
|
||||
memory = ConversationMemory(session_id=session_id, user_id=user_id)
|
||||
self.memory_cache[cache_key] = memory
|
||||
await self._save_memory(memory)
|
||||
await self._save_conversation(memory)
|
||||
|
||||
return memory
|
||||
|
||||
async def set_title(self, session_id: str, user_id: str, title: str):
|
||||
"""设置会话标题"""
|
||||
memory = await self.get_memory(session_id=session_id, user_id=user_id)
|
||||
"""
|
||||
设置会话标题
|
||||
"""
|
||||
memory = await self.get_conversation(session_id=session_id, user_id=user_id)
|
||||
memory.title = title
|
||||
memory.updated_at = datetime.now()
|
||||
await self._save_memory(memory)
|
||||
await self._save_conversation(memory)
|
||||
|
||||
async def get_title(self, session_id: str, user_id: str) -> Optional[str]:
|
||||
"""获取会话标题"""
|
||||
memory = await self.get_memory(session_id=session_id, user_id=user_id)
|
||||
"""
|
||||
获取会话标题
|
||||
"""
|
||||
memory = await self.get_conversation(session_id=session_id, user_id=user_id)
|
||||
return memory.title
|
||||
|
||||
async def list_sessions(self, user_id: str, limit: int = 100) -> List[Dict[str, Any]]:
|
||||
"""列出历史会话摘要(按更新时间倒序)
|
||||
"""
|
||||
列出历史会话摘要(按更新时间倒序)
|
||||
|
||||
- 当启用Redis时:遍历 `agent_memory:*` 键并读取摘要
|
||||
- 当未启用Redis时:基于内存缓存返回
|
||||
@@ -148,7 +181,7 @@ class ConversationMemoryManager:
|
||||
for m in sorted_list
|
||||
]
|
||||
|
||||
async def add_memory(
|
||||
async def add_conversation(
|
||||
self,
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
@@ -156,8 +189,10 @@ class ConversationMemoryManager:
|
||||
content: str,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
):
|
||||
"""添加消息到记忆"""
|
||||
memory = await self.get_memory(session_id=session_id, user_id=user_id)
|
||||
"""
|
||||
添加消息到记忆
|
||||
"""
|
||||
memory = await self.get_conversation(session_id=session_id, user_id=user_id)
|
||||
|
||||
message = {
|
||||
"role": role,
|
||||
@@ -177,7 +212,7 @@ class ConversationMemoryManager:
|
||||
recent_messages = memory.messages[-(max_messages - len(system_messages)):]
|
||||
memory.messages = system_messages + recent_messages
|
||||
|
||||
await self._save_memory(memory)
|
||||
await self._save_conversation(memory)
|
||||
|
||||
logger.debug(f"消息已添加到记忆: session_id={session_id}, user_id={user_id}, role={role}")
|
||||
|
||||
@@ -186,11 +221,12 @@ class ConversationMemoryManager:
|
||||
session_id: str,
|
||||
user_id: str
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""为Agent获取最近的消息(仅内存缓存)
|
||||
"""
|
||||
为Agent获取最近的消息(仅内存缓存)
|
||||
|
||||
如果消息Token数量超过模型最大上下文长度的阀值,会自动进行摘要裁剪
|
||||
"""
|
||||
cache_key = self.get_memory_key(session_id, user_id)
|
||||
cache_key = self._get_memory_key(session_id, user_id)
|
||||
memory = self.memory_cache.get(cache_key)
|
||||
if not memory:
|
||||
return []
|
||||
@@ -205,8 +241,10 @@ class ConversationMemoryManager:
|
||||
limit: int = 10,
|
||||
role_filter: Optional[list] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""获取最近的消息"""
|
||||
memory = await self.get_memory(session_id=session_id, user_id=user_id)
|
||||
"""
|
||||
获取最近的消息
|
||||
"""
|
||||
memory = await self.get_conversation(session_id=session_id, user_id=user_id)
|
||||
|
||||
messages = memory.messages
|
||||
if role_filter:
|
||||
@@ -215,36 +253,41 @@ class ConversationMemoryManager:
|
||||
return messages[-limit:] if messages else []
|
||||
|
||||
async def get_context(self, session_id: str, user_id: str) -> Dict[str, Any]:
|
||||
"""获取会话上下文"""
|
||||
memory = await self.get_memory(session_id=session_id, user_id=user_id)
|
||||
"""
|
||||
获取会话上下文
|
||||
"""
|
||||
memory = await self.get_conversation(session_id=session_id, user_id=user_id)
|
||||
return memory.context
|
||||
|
||||
async def clear_memory(self, session_id: str, user_id: str):
|
||||
"""清空会话记忆"""
|
||||
"""
|
||||
清空会话记忆
|
||||
"""
|
||||
cache_key = f"{user_id}:{session_id}" if user_id else session_id
|
||||
if cache_key in self.memory_cache:
|
||||
del self.memory_cache[cache_key]
|
||||
|
||||
if settings.CACHE_BACKEND_TYPE == "redis":
|
||||
redis_key = self.get_redis_key(session_id, user_id)
|
||||
redis_key = self._get_redis_key(session_id, user_id)
|
||||
await self.redis_helper.delete(redis_key, region="AI_AGENT")
|
||||
|
||||
logger.info(f"会话记忆已清空: session_id={session_id}, user_id={user_id}")
|
||||
|
||||
async def _save_memory(self, memory: ConversationMemory):
|
||||
"""保存记忆到存储
|
||||
|
||||
Redis中的记忆会自动通过TTL机制过期,无需手动清理
|
||||
def _save_memory(self, memory: ConversationMemory):
|
||||
"""
|
||||
# 更新内存缓存
|
||||
cache_key = self.get_memory_key(memory.session_id, memory.user_id)
|
||||
保存记忆到内存
|
||||
"""
|
||||
cache_key = self._get_memory_key(memory.session_id, memory.user_id)
|
||||
self.memory_cache[cache_key] = memory
|
||||
|
||||
# 保存到Redis,设置TTL自动过期
|
||||
async def _save_redis(self, memory: ConversationMemory):
|
||||
"""
|
||||
保存记忆到Redis
|
||||
"""
|
||||
if settings.CACHE_BACKEND_TYPE == "redis":
|
||||
try:
|
||||
memory_dict = memory.model_dump()
|
||||
redis_key = self.get_redis_key(memory.session_id, memory.user_id)
|
||||
redis_key = self._get_redis_key(memory.session_id, memory.user_id)
|
||||
ttl = int(timedelta(days=settings.LLM_REDIS_MEMORY_RETENTION_DAYS).total_seconds())
|
||||
await self.redis_helper.set(
|
||||
redis_key,
|
||||
@@ -255,8 +298,22 @@ class ConversationMemoryManager:
|
||||
except Exception as e:
|
||||
logger.warning(f"保存记忆到Redis失败: {e}")
|
||||
|
||||
async def _save_conversation(self, memory: ConversationMemory):
|
||||
"""
|
||||
保存记忆到存储
|
||||
|
||||
Redis中的记忆会自动通过TTL机制过期,无需手动清理
|
||||
"""
|
||||
# 更新内存缓存
|
||||
self._save_memory(memory)
|
||||
|
||||
# 保存到Redis,设置TTL自动过期
|
||||
await self._save_redis(memory)
|
||||
|
||||
|
||||
async def _cleanup_expired_memories(self):
|
||||
"""清理内存中过期记忆的后台任务
|
||||
"""
|
||||
清理内存中过期记忆的后台任务
|
||||
|
||||
注意:Redis中的记忆通过TTL机制自动过期,这里只清理内存缓存
|
||||
"""
|
||||
@@ -286,3 +343,5 @@ class ConversationMemoryManager:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"清理记忆时发生错误: {e}")
|
||||
|
||||
conversation_manager = ConversationMemoryManager()
|
||||
|
||||
@@ -4,6 +4,20 @@ All your responses must be in **Chinese (中文)**.
|
||||
|
||||
You act as a proactive agent. Your goal is to fully resolve the user's media-related requests autonomously. Do not end your turn until the task is complete or you are blocked and require user feedback.
|
||||
|
||||
Core Capabilities:
|
||||
1. Media Search & Recognition
|
||||
- Identify movies, TV shows, and anime across various metadata providers.
|
||||
- Recognize media info from fuzzy filenames or incomplete titles.
|
||||
2. Subscription Management
|
||||
- Create complex rules for automated downloading of new episodes.
|
||||
- Monitor trending movies/shows for automated suggestions.
|
||||
3. Download Control
|
||||
- Intelligent torrent searching across private/public trackers.
|
||||
- Filter resources by quality (4K/1080p), codec (H265/H264), and release groups.
|
||||
4. System Status & Organization
|
||||
- Monitor download progress and server health.
|
||||
- Manage file transfers, renaming, and library cleanup.
|
||||
|
||||
<communication>
|
||||
- Use Markdown for structured data like movie lists, download statuses, or technical details.
|
||||
- Avoid wrapping the entire response in a single code block. Use `inline code` for titles or parameters and ```code blocks``` for structured logs or data only when necessary.
|
||||
@@ -11,6 +25,11 @@ You act as a proactive agent. Your goal is to fully resolve the user's media-rel
|
||||
- Optimize your writing for clarity and readability, using bold text for key information.
|
||||
- Provide comprehensive details for media (year, rating, resolution) to help users make informed decisions.
|
||||
- Do not stop for approval for read-only operations. Only stop for critical actions like starting a download or deleting a subscription.
|
||||
|
||||
Important Notes:
|
||||
- User-Centric: Your tone should be helpful, professional, and media-savvy.
|
||||
- No Coding Hallucinations: You are NOT a coding assistant. Do not offer code snippets, IDE tips, or programming help. Focus entirely on the MoviePilot media ecosystem.
|
||||
- Contextual Memory: Remember if the user preferred a specific version previously and prioritize similar results in future searches.
|
||||
</communication>
|
||||
|
||||
<status_update_spec>
|
||||
@@ -28,44 +47,26 @@ At the end of your session/turn, provide a concise summary of your actions.
|
||||
</summary_spec>
|
||||
|
||||
<flow>
|
||||
1. **Media Discovery**: Start by identifying the exact media metadata (TMDB ID, Season/Episode) using search tools.
|
||||
2. **Context Checking**: Verify current status (Is it already in the library? Is it already subscribed?).
|
||||
3. **Action Execution**: Perform the requested task (Subscribe, Search Torrents, etc.) with a brief status update.
|
||||
4. **Final Confirmation**: Summarize the final state and wait for the next user command.
|
||||
1. Media Discovery: Start by identifying the exact media metadata (TMDB ID, Season/Episode) using search tools.
|
||||
2. Context Checking: Verify current status (Is it already in the library? Is it already subscribed?).
|
||||
3. Action Execution: Perform the requested task (Subscribe, Search Torrents, etc.) with a brief status update.
|
||||
4. Final Confirmation: Summarize the final state and wait for the next user command.
|
||||
</flow>
|
||||
|
||||
<tool_calling_strategy>
|
||||
- **Parallel Execution**: You MUST call independent tools in parallel. For example, search for torrents on multiple sites or check both subscription and download status at once.
|
||||
- **Information Depth**: If a search returns ambiguous results, use `query_media_detail` or `recognize_media` to resolve the ambiguity before proceeding.
|
||||
- **Proactive Fallback**: If `search_media` fails, try `search_web` or fuzzy search with `recognize_media`. Do not ask the user for help unless all automated search methods are exhausted.
|
||||
- Parallel Execution: You MUST call independent tools in parallel. For example, search for torrents on multiple sites or check both subscription and download status at once.
|
||||
- Information Depth: If a search returns ambiguous results, use `query_media_detail` or `recognize_media` to resolve the ambiguity before proceeding.
|
||||
- Proactive Fallback: If `search_media` fails, try `search_web` or fuzzy search with `recognize_media`. Do not ask the user for help unless all automated search methods are exhausted.
|
||||
</tool_calling_strategy>
|
||||
|
||||
<media_management_rules>
|
||||
1. **Download Safety**: You MUST present a list of found torrents (including size, seeds, and quality) and obtain the user's explicit consent before initiating any download.
|
||||
2. **Subscription Logic**: When adding a subscription, always check for the best matching quality profile based on user history or the default settings.
|
||||
3. **Library Awareness**: Always check if the user already has the content in their library to avoid duplicate downloads.
|
||||
4. **Error Handling**: If a site is down or a tool returns an error, explain the situation in plain Chinese (e.g., "站点响应超时") and suggest an alternative (e.g., "尝试从其他站点进行搜索").
|
||||
1. Download Safety: You MUST present a list of found torrents (including size, seeds, and quality) and obtain the user's explicit consent before initiating any download.
|
||||
2. Subscription Logic: When adding a subscription, always check for the best matching quality profile based on user history or the default settings.
|
||||
3. Library Awareness: Always check if the user already has the content in their library to avoid duplicate downloads.
|
||||
4. Error Handling: If a site is down or a tool returns an error, explain the situation in plain Chinese (e.g., "站点响应超时") and suggest an alternative (e.g., "尝试从其他站点进行搜索").
|
||||
</media_management_rules>
|
||||
|
||||
## Core Capabilities
|
||||
|
||||
### 1. Media Search & Recognition
|
||||
- Identify movies, TV shows, and anime across various metadata providers.
|
||||
- Recognize media info from fuzzy filenames or incomplete titles.
|
||||
|
||||
### 2. Subscription Management
|
||||
- Create complex rules for automated downloading of new episodes.
|
||||
- Monitor trending movies/shows for automated suggestions.
|
||||
|
||||
### 3. Download Control
|
||||
- Intelligent torrent searching across private/public trackers.
|
||||
- Filter resources by quality (4K/1080p), codec (H265/H264), and release groups.
|
||||
|
||||
### 4. System Status & Organization
|
||||
- Monitor download progress and server health.
|
||||
- Manage file transfers, renaming, and library cleanup.
|
||||
|
||||
## Important Notes
|
||||
- **User-Centric**: Your tone should be helpful, professional, and media-savvy.
|
||||
- **No Coding Hallucinations**: You are NOT a coding assistant. Do not offer code snippets, IDE tips, or programming help. Focus entirely on the MoviePilot media ecosystem.
|
||||
- **Contextual Memory**: Remember if the user preferred a specific version previously and prioritize similar results in future searches.
|
||||
<markdown_spec>
|
||||
Specific markdown rules:
|
||||
{markdown_spec}
|
||||
</markdown_spec>
|
||||
@@ -1,13 +1,15 @@
|
||||
"""提示词管理器"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
from app.log import logger
|
||||
from app.schemas import ChannelCapability, ChannelCapabilities, MessageChannel, ChannelCapabilityManager
|
||||
|
||||
|
||||
class PromptManager:
|
||||
"""提示词管理器"""
|
||||
"""
|
||||
提示词管理器
|
||||
"""
|
||||
|
||||
def __init__(self, prompts_dir: str = None):
|
||||
if prompts_dir is None:
|
||||
@@ -17,22 +19,20 @@ class PromptManager:
|
||||
self.prompts_cache: Dict[str, str] = {}
|
||||
|
||||
def load_prompt(self, prompt_name: str) -> str:
|
||||
"""加载指定的提示词"""
|
||||
"""
|
||||
加载指定的提示词
|
||||
"""
|
||||
if prompt_name in self.prompts_cache:
|
||||
return self.prompts_cache[prompt_name]
|
||||
|
||||
prompt_file = self.prompts_dir / prompt_name
|
||||
|
||||
try:
|
||||
with open(prompt_file, 'r', encoding='utf-8') as f:
|
||||
content = f.read().strip()
|
||||
|
||||
# 缓存提示词
|
||||
self.prompts_cache[prompt_name] = content
|
||||
|
||||
logger.info(f"提示词加载成功: {prompt_name},长度:{len(content)} 字符")
|
||||
return content
|
||||
|
||||
except FileNotFoundError:
|
||||
logger.error(f"提示词文件不存在: {prompt_file}")
|
||||
raise
|
||||
@@ -46,73 +46,43 @@ class PromptManager:
|
||||
:param channel: 消息渠道(Telegram、微信、Slack等)
|
||||
:return: 提示词内容
|
||||
"""
|
||||
# 基础提示词
|
||||
base_prompt = self.load_prompt("Agent Prompt.txt")
|
||||
|
||||
# 根据渠道添加特定的格式说明
|
||||
if channel:
|
||||
channel_format_info = self._get_channel_format_info(channel)
|
||||
if channel_format_info:
|
||||
base_prompt += f"\n\n## Current Message Channel Format Requirements\n\n{channel_format_info}"
|
||||
|
||||
|
||||
# 识别渠道
|
||||
msg_channel = next((c for c in MessageChannel if c.value.lower() == channel.lower()), None) if channel else None
|
||||
if msg_channel:
|
||||
# 获取渠道能力说明
|
||||
caps = ChannelCapabilityManager.get_capabilities(msg_channel)
|
||||
if caps:
|
||||
base_prompt = base_prompt.replace(
|
||||
"{markdown_spec}",
|
||||
self._generate_formatting_instructions(caps)
|
||||
)
|
||||
|
||||
return base_prompt
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _get_channel_format_info(channel: str) -> str:
|
||||
def _generate_formatting_instructions(caps: ChannelCapabilities) -> str:
|
||||
"""
|
||||
获取渠道特定的格式说明
|
||||
:param channel: 消息渠道
|
||||
:return: 格式说明文本
|
||||
根据渠道能力动态生成格式指令
|
||||
"""
|
||||
channel_lower = channel.lower() if channel else ""
|
||||
|
||||
if "telegram" in channel_lower:
|
||||
return """Messages are being sent through the **Telegram** channel. You must follow these format requirements:
|
||||
|
||||
**Supported Formatting:**
|
||||
- **Bold text**: Use `*text*` (single asterisk, not double asterisks)
|
||||
- **Italic text**: Use `_text_` (underscore)
|
||||
- **Code**: Use `` `text` `` (backtick)
|
||||
- **Links**: Use `[text](url)` format
|
||||
- **Strikethrough**: Use `~text~` (tilde)
|
||||
|
||||
**IMPORTANT - Headings and Lists:**
|
||||
- **DO NOT use heading syntax** (`#`, `##`, `###`) - Telegram MarkdownV2 does NOT support it
|
||||
- **Instead, use bold text for headings**: `*Heading Text*` followed by a blank line
|
||||
- **DO NOT use list syntax** (`-`, `*`, `+` at line start) - these will be escaped and won't display as lists
|
||||
- **For lists**, use plain text with line breaks, or use bold for list item labels: `*Item 1:* description`
|
||||
|
||||
**Examples:**
|
||||
- ❌ Wrong heading: `# Main Title` or `## Subtitle`
|
||||
- ✅ Correct heading: `*Main Title*` (followed by blank line) or `*Subtitle*` (followed by blank line)
|
||||
- ❌ Wrong list: `- Item 1` or `* Item 2`
|
||||
- ✅ Correct list format: `*Item 1:* description` or use plain text with line breaks
|
||||
|
||||
**Special Characters:**
|
||||
- Avoid using special characters that need escaping in MarkdownV2: `_*[]()~`>#+-=|{}.!` unless they are part of the formatting syntax
|
||||
- Keep formatting simple, avoid nested formatting to ensure proper rendering in Telegram"""
|
||||
|
||||
elif "wechat" in channel_lower or "微信" in channel:
|
||||
return """Messages are being sent through the **WeChat** channel. Please follow these format requirements:
|
||||
|
||||
- WeChat does NOT support Markdown formatting. Use plain text format only.
|
||||
- Do NOT use any Markdown syntax (such as `**bold**`, `*italic*`, `` `code` `` etc.)
|
||||
- Use plain text descriptions. You can organize content using line breaks and punctuation
|
||||
- Links can be provided directly as URLs, no Markdown link format needed
|
||||
- Keep messages concise and clear, use natural Chinese expressions"""
|
||||
|
||||
elif "slack" in channel_lower:
|
||||
return """Messages are being sent through the **Slack** channel. Please follow these format requirements:
|
||||
|
||||
- Slack supports Markdown formatting
|
||||
- Use `*text*` for bold
|
||||
- Use `_text_` for italic
|
||||
- Use `` `text` `` for code
|
||||
- Link format: `<url|text>` or `[text](url)`"""
|
||||
|
||||
# 其他渠道使用标准Markdown
|
||||
return None
|
||||
instructions = []
|
||||
if ChannelCapability.RICH_TEXT not in caps.capabilities:
|
||||
instructions.append("- Formatting: Use **Plain Text ONLY**. The channel does NOT support Markdown.")
|
||||
instructions.append(
|
||||
"- No Markdown Symbols: NEVER use `**`, `*`, `__`, or `[` blocks. Use natural text to emphasize (e.g., using ALL CAPS or separators).")
|
||||
instructions.append(
|
||||
"- Lists: Use plain text symbols like `>` or `*` at the start of lines, followed by manual line breaks.")
|
||||
instructions.append("- Links: Paste URLs directly as text.")
|
||||
return "\n".join(instructions)
|
||||
|
||||
def clear_cache(self):
|
||||
"""清空缓存"""
|
||||
"""
|
||||
清空缓存
|
||||
"""
|
||||
self.prompts_cache.clear()
|
||||
logger.info("提示词缓存已清空")
|
||||
|
||||
|
||||
prompt_manager = PromptManager()
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
"""MoviePilot工具基类"""
|
||||
import json
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from typing import Any, Optional
|
||||
@@ -6,7 +5,7 @@ from typing import Any, Optional
|
||||
from langchain.tools import BaseTool
|
||||
from pydantic import PrivateAttr
|
||||
|
||||
from app.agent import StreamingCallbackHandler, ConversationMemoryManager
|
||||
from app.agent import StreamingCallbackHandler, conversation_manager
|
||||
from app.chain import ChainBase
|
||||
from app.log import logger
|
||||
from app.schemas import Notification
|
||||
@@ -17,7 +16,9 @@ class ToolChain(ChainBase):
|
||||
|
||||
|
||||
class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
"""MoviePilot专用工具基类"""
|
||||
"""
|
||||
MoviePilot专用工具基类
|
||||
"""
|
||||
|
||||
_session_id: str = PrivateAttr()
|
||||
_user_id: str = PrivateAttr()
|
||||
@@ -25,7 +26,6 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
_source: str = PrivateAttr(default=None)
|
||||
_username: str = PrivateAttr(default=None)
|
||||
_callback_handler: StreamingCallbackHandler = PrivateAttr(default=None)
|
||||
_memory_manager: ConversationMemoryManager = PrivateAttr(default=None)
|
||||
|
||||
def __init__(self, session_id: str, user_id: str, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
@@ -36,7 +36,9 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
pass
|
||||
|
||||
async def _arun(self, **kwargs) -> str:
|
||||
"""异步运行工具"""
|
||||
"""
|
||||
异步运行工具
|
||||
"""
|
||||
# 发送和记忆工具调用前的信息
|
||||
agent_message = await self._callback_handler.get_message()
|
||||
if agent_message:
|
||||
@@ -44,7 +46,7 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
await self.send_tool_message(agent_message, title="MoviePilot助手")
|
||||
|
||||
# 记忆工具调用
|
||||
await self._memory_manager.add_memory(
|
||||
await conversation_manager.add_conversation(
|
||||
session_id=self._session_id,
|
||||
user_id=self._user_id,
|
||||
role="tool_call",
|
||||
@@ -77,7 +79,7 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
formated_result = str(result)
|
||||
else:
|
||||
formated_result = json.dumps(result, ensure_ascii=False, indent=2)
|
||||
await self._memory_manager.add_memory(
|
||||
await conversation_manager.add_conversation(
|
||||
session_id=self._session_id,
|
||||
user_id=self._user_id,
|
||||
role="tool_result",
|
||||
@@ -106,21 +108,23 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
raise NotImplementedError
|
||||
|
||||
def set_message_attr(self, channel: str, source: str, username: str):
|
||||
"""设置消息属性"""
|
||||
"""
|
||||
设置消息属性
|
||||
"""
|
||||
self._channel = channel
|
||||
self._source = source
|
||||
self._username = username
|
||||
|
||||
def set_callback_handler(self, callback_handler: StreamingCallbackHandler):
|
||||
"""设置回调处理器"""
|
||||
"""
|
||||
设置回调处理器
|
||||
"""
|
||||
self._callback_handler = callback_handler
|
||||
|
||||
def set_memory_manager(self, memory_manager: ConversationMemoryManager):
|
||||
"""设置记忆客理器"""
|
||||
self._memory_manager = memory_manager
|
||||
|
||||
async def send_tool_message(self, message: str, title: str = ""):
|
||||
"""发送工具消息"""
|
||||
"""
|
||||
发送工具消息
|
||||
"""
|
||||
await ToolChain().async_post_message(
|
||||
Notification(
|
||||
channel=self._channel,
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
"""MoviePilot工具工厂"""
|
||||
|
||||
from typing import List, Callable
|
||||
|
||||
from app.agent.tools.impl.add_download import AddDownloadTool
|
||||
@@ -47,13 +45,17 @@ from .base import MoviePilotTool
|
||||
|
||||
|
||||
class MoviePilotToolFactory:
|
||||
"""MoviePilot工具工厂"""
|
||||
"""
|
||||
MoviePilot工具工厂
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def create_tools(session_id: str, user_id: str,
|
||||
channel: str = None, source: str = None, username: str = None,
|
||||
callback_handler: Callable = None, memory_mananger: Callable = None) -> List[MoviePilotTool]:
|
||||
"""创建MoviePilot工具列表"""
|
||||
callback_handler: Callable = None) -> List[MoviePilotTool]:
|
||||
"""
|
||||
创建MoviePilot工具列表
|
||||
"""
|
||||
tools = []
|
||||
tool_definitions = [
|
||||
SearchMediaTool,
|
||||
@@ -104,7 +106,6 @@ class MoviePilotToolFactory:
|
||||
)
|
||||
tool.set_message_attr(channel=channel, source=source, username=username)
|
||||
tool.set_callback_handler(callback_handler=callback_handler)
|
||||
tool.set_memory_manager(memory_manager=memory_mananger)
|
||||
tools.append(tool)
|
||||
|
||||
# 加载插件提供的工具
|
||||
@@ -127,7 +128,6 @@ class MoviePilotToolFactory:
|
||||
)
|
||||
tool.set_message_attr(channel=channel, source=source, username=username)
|
||||
tool.set_callback_handler(callback_handler=callback_handler)
|
||||
tool.set_memory_manager(memory_manager=memory_mananger)
|
||||
tools.append(tool)
|
||||
plugin_tools_count += 1
|
||||
logger.debug(f"成功加载插件 {plugin_name}({plugin_id}) 的工具: {ToolClass.__name__}")
|
||||
|
||||
@@ -1,18 +1,15 @@
|
||||
"""MoviePilot工具管理器
|
||||
用于HTTP API调用工具
|
||||
"""
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.agent import ConversationMemoryManager
|
||||
from app.agent.tools.factory import MoviePilotToolFactory
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class ToolDefinition:
|
||||
"""工具定义"""
|
||||
"""
|
||||
工具定义
|
||||
"""
|
||||
|
||||
def __init__(self, name: str, description: str, input_schema: Dict[str, Any]):
|
||||
self.name = name
|
||||
@@ -21,7 +18,9 @@ class ToolDefinition:
|
||||
|
||||
|
||||
class MoviePilotToolsManager:
|
||||
"""MoviePilot工具管理器(用于HTTP API)"""
|
||||
"""
|
||||
MoviePilot工具管理器(用于HTTP API)
|
||||
"""
|
||||
|
||||
def __init__(self, user_id: str = "api_user", session_id: str = uuid.uuid4()):
|
||||
"""
|
||||
@@ -34,11 +33,12 @@ class MoviePilotToolsManager:
|
||||
self.user_id = user_id
|
||||
self.session_id = session_id
|
||||
self.tools: List[Any] = []
|
||||
self.memory_manager = ConversationMemoryManager()
|
||||
self._load_tools()
|
||||
|
||||
def _load_tools(self):
|
||||
"""加载所有MoviePilot工具"""
|
||||
"""
|
||||
加载所有MoviePilot工具
|
||||
"""
|
||||
try:
|
||||
# 创建工具实例
|
||||
self.tools = MoviePilotToolFactory.create_tools(
|
||||
@@ -48,7 +48,6 @@ class MoviePilotToolsManager:
|
||||
source="api",
|
||||
username="API Client",
|
||||
callback_handler=None,
|
||||
memory_mananger=None,
|
||||
)
|
||||
logger.info(f"成功加载 {len(self.tools)} 个工具")
|
||||
except Exception as e:
|
||||
@@ -116,7 +115,7 @@ class MoviePilotToolsManager:
|
||||
args_schema = getattr(tool_instance, 'args_schema', None)
|
||||
if not args_schema:
|
||||
return arguments
|
||||
|
||||
|
||||
# 获取schema中的字段定义
|
||||
try:
|
||||
schema = args_schema.model_json_schema()
|
||||
@@ -124,7 +123,7 @@ class MoviePilotToolsManager:
|
||||
except Exception as e:
|
||||
logger.warning(f"获取工具schema失败: {e}")
|
||||
return arguments
|
||||
|
||||
|
||||
# 规范化参数
|
||||
normalized = {}
|
||||
for key, value in arguments.items():
|
||||
@@ -132,10 +131,10 @@ class MoviePilotToolsManager:
|
||||
# 参数不在schema中,保持原样
|
||||
normalized[key] = value
|
||||
continue
|
||||
|
||||
|
||||
field_info = properties[key]
|
||||
field_type = field_info.get("type")
|
||||
|
||||
|
||||
# 处理 anyOf 类型(例如 Optional[int] 会生成 anyOf)
|
||||
any_of = field_info.get("anyOf")
|
||||
if any_of and not field_type:
|
||||
@@ -144,7 +143,7 @@ class MoviePilotToolsManager:
|
||||
if "type" in type_option and type_option["type"] != "null":
|
||||
field_type = type_option["type"]
|
||||
break
|
||||
|
||||
|
||||
# 根据类型进行转换
|
||||
if field_type == "integer" and isinstance(value, str):
|
||||
try:
|
||||
@@ -167,7 +166,7 @@ class MoviePilotToolsManager:
|
||||
normalized[key] = True
|
||||
else:
|
||||
normalized[key] = value
|
||||
|
||||
|
||||
return normalized
|
||||
|
||||
async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> str:
|
||||
@@ -192,7 +191,7 @@ class MoviePilotToolsManager:
|
||||
try:
|
||||
# 规范化参数类型
|
||||
normalized_arguments = self._normalize_arguments(tool_instance, arguments)
|
||||
|
||||
|
||||
# 调用工具的run方法
|
||||
result = await tool_instance.run(**normalized_arguments)
|
||||
|
||||
@@ -270,3 +269,6 @@ class MoviePilotToolsManager:
|
||||
"properties": properties,
|
||||
"required": required
|
||||
}
|
||||
|
||||
|
||||
moviepilot_tool_manager = MoviePilotToolsManager()
|
||||
|
||||
@@ -1,14 +1,10 @@
|
||||
"""工具API端点
|
||||
通过HTTP API暴露MoviePilot的智能体工具功能
|
||||
"""
|
||||
|
||||
from typing import List, Any, Dict, Annotated, Union
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from fastapi.responses import JSONResponse, Response
|
||||
|
||||
from app import schemas
|
||||
from app.agent.tools.manager import MoviePilotToolsManager
|
||||
from app.agent.tools.manager import moviepilot_tool_manager
|
||||
from app.core.security import verify_apikey
|
||||
from app.log import logger
|
||||
|
||||
@@ -25,18 +21,10 @@ MCP_PROTOCOL_VERSIONS = ["2025-11-25", "2025-06-18", "2024-11-05"]
|
||||
MCP_PROTOCOL_VERSION = MCP_PROTOCOL_VERSIONS[0] # 默认使用最新版本
|
||||
|
||||
|
||||
def get_tools_manager() -> MoviePilotToolsManager:
|
||||
"""
|
||||
获取工具管理器实例
|
||||
|
||||
Returns:
|
||||
MoviePilotToolsManager实例
|
||||
"""
|
||||
return MoviePilotToolsManager()
|
||||
|
||||
|
||||
def create_jsonrpc_response(request_id: Union[str, int, None], result: Any) -> Dict[str, Any]:
|
||||
"""创建 JSON-RPC 成功响应"""
|
||||
"""
|
||||
创建 JSON-RPC 成功响应
|
||||
"""
|
||||
response = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request_id,
|
||||
@@ -45,8 +33,11 @@ def create_jsonrpc_response(request_id: Union[str, int, None], result: Any) -> D
|
||||
return response
|
||||
|
||||
|
||||
def create_jsonrpc_error(request_id: Union[str, int, None], code: int, message: str, data: Any = None) -> Dict[str, Any]:
|
||||
"""创建 JSON-RPC 错误响应"""
|
||||
def create_jsonrpc_error(request_id: Union[str, int, None], code: int, message: str, data: Any = None) -> Dict[
|
||||
str, Any]:
|
||||
"""
|
||||
创建 JSON-RPC 错误响应
|
||||
"""
|
||||
error = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request_id,
|
||||
@@ -60,8 +51,6 @@ def create_jsonrpc_error(request_id: Union[str, int, None], code: int, message:
|
||||
return error
|
||||
|
||||
|
||||
# ==================== MCP JSON-RPC 端点 ====================
|
||||
|
||||
@router.post("", summary="MCP JSON-RPC 端点", response_model=None)
|
||||
async def mcp_jsonrpc(
|
||||
request: Request,
|
||||
@@ -146,7 +135,9 @@ async def mcp_jsonrpc(
|
||||
|
||||
|
||||
async def handle_initialize(params: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""处理初始化请求"""
|
||||
"""
|
||||
处理初始化请求
|
||||
"""
|
||||
protocol_version = params.get("protocolVersion")
|
||||
client_info = params.get("clientInfo", {})
|
||||
|
||||
@@ -161,7 +152,7 @@ async def handle_initialize(params: Dict[str, Any]) -> Dict[str, Any]:
|
||||
else:
|
||||
# 客户端版本不支持,使用服务器默认版本
|
||||
logger.warning(f"协议版本不匹配: 客户端={protocol_version}, 使用服务器版本={negotiated_version}")
|
||||
|
||||
|
||||
return {
|
||||
"protocolVersion": negotiated_version,
|
||||
"capabilities": {
|
||||
@@ -180,9 +171,10 @@ async def handle_initialize(params: Dict[str, Any]) -> Dict[str, Any]:
|
||||
|
||||
|
||||
async def handle_tools_list() -> Dict[str, Any]:
|
||||
"""处理工具列表请求"""
|
||||
manager = get_tools_manager()
|
||||
tools = manager.list_tools()
|
||||
"""
|
||||
处理工具列表请求
|
||||
"""
|
||||
tools = moviepilot_tool_manager.list_tools()
|
||||
|
||||
# 转换为 MCP 工具格式
|
||||
mcp_tools = []
|
||||
@@ -200,18 +192,18 @@ async def handle_tools_list() -> Dict[str, Any]:
|
||||
|
||||
|
||||
async def handle_tools_call(params: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""处理工具调用请求"""
|
||||
"""
|
||||
处理工具调用请求
|
||||
"""
|
||||
tool_name = params.get("name")
|
||||
arguments = params.get("arguments", {})
|
||||
|
||||
if not tool_name:
|
||||
raise ValueError("Missing tool name")
|
||||
|
||||
manager = get_tools_manager()
|
||||
|
||||
try:
|
||||
result_text = await manager.call_tool(tool_name, arguments)
|
||||
|
||||
result_text = await moviepilot_tool_manager.call_tool(tool_name, arguments)
|
||||
|
||||
return {
|
||||
"content": [
|
||||
{
|
||||
@@ -243,8 +235,6 @@ async def delete_mcp_session(
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
|
||||
|
||||
# ==================== 兼容的 RESTful API 端点 ====================
|
||||
|
||||
@router.get("/tools", summary="列出所有可用工具", response_model=List[Dict[str, Any]])
|
||||
@@ -257,9 +247,8 @@ async def list_tools(
|
||||
返回每个工具的名称、描述和参数定义
|
||||
"""
|
||||
try:
|
||||
manager = get_tools_manager()
|
||||
# 获取所有工具定义
|
||||
tools = manager.list_tools()
|
||||
tools = moviepilot_tool_manager.list_tools()
|
||||
|
||||
# 转换为字典格式
|
||||
tools_list = []
|
||||
@@ -289,11 +278,8 @@ async def call_tool(
|
||||
工具执行结果
|
||||
"""
|
||||
try:
|
||||
# 使用当前用户ID创建管理器实例
|
||||
manager = get_tools_manager()
|
||||
|
||||
# 调用工具
|
||||
result_text = await manager.call_tool(request.tool_name, request.arguments)
|
||||
result_text = await moviepilot_tool_manager.call_tool(request.tool_name, request.arguments)
|
||||
|
||||
return schemas.ToolCallResponse(
|
||||
success=True,
|
||||
@@ -319,9 +305,8 @@ async def get_tool_info(
|
||||
工具的详细信息,包括名称、描述和参数定义
|
||||
"""
|
||||
try:
|
||||
manager = get_tools_manager()
|
||||
# 获取所有工具
|
||||
tools = manager.list_tools()
|
||||
tools = moviepilot_tool_manager.list_tools()
|
||||
|
||||
# 查找指定工具
|
||||
for tool in tools:
|
||||
@@ -352,9 +337,8 @@ async def get_tool_schema(
|
||||
工具的JSON Schema定义
|
||||
"""
|
||||
try:
|
||||
manager = get_tools_manager()
|
||||
# 获取所有工具
|
||||
tools = manager.list_tools()
|
||||
tools = moviepilot_tool_manager.list_tools()
|
||||
|
||||
# 查找指定工具
|
||||
for tool in tools:
|
||||
|
||||
@@ -1,20 +1,17 @@
|
||||
"""
|
||||
AI智能体初始化器
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
from typing import Optional
|
||||
|
||||
from app.agent import agent_manager
|
||||
from app.core.config import settings
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class AgentInitializer:
|
||||
"""AI智能体初始化器"""
|
||||
"""
|
||||
AI智能体初始化器
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.agent_manager = None
|
||||
self._initialized = False
|
||||
|
||||
async def initialize(self) -> bool:
|
||||
@@ -26,9 +23,6 @@ class AgentInitializer:
|
||||
logger.info("AI智能体功能未启用")
|
||||
return True
|
||||
|
||||
from app.agent import agent_manager
|
||||
self.agent_manager = agent_manager
|
||||
|
||||
await agent_manager.initialize()
|
||||
self._initialized = True
|
||||
logger.info("AI智能体管理器初始化成功")
|
||||
@@ -43,10 +37,10 @@ class AgentInitializer:
|
||||
清理AI智能体管理器
|
||||
"""
|
||||
try:
|
||||
if not self._initialized or not self.agent_manager:
|
||||
if not self._initialized:
|
||||
return
|
||||
|
||||
await self.agent_manager.close()
|
||||
await agent_manager.close()
|
||||
self._initialized = False
|
||||
logger.info("AI智能体管理器已关闭")
|
||||
|
||||
@@ -78,8 +72,8 @@ def init_agent():
|
||||
else:
|
||||
logger.error("AI智能体管理器初始化失败")
|
||||
return success
|
||||
except Exception as e:
|
||||
logger.error(f"AI智能体管理器初始化失败: {e}")
|
||||
except Exception as err:
|
||||
logger.error(f"AI智能体管理器初始化失败: {err}")
|
||||
return False
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
Reference in New Issue
Block a user