fix agent

This commit is contained in:
jxxghp
2025-11-01 19:08:05 +08:00
parent e885fb15a0
commit 8016a9539a
14 changed files with 160 additions and 112 deletions

View File

@@ -1,17 +1,16 @@
"""MoviePilot AI智能体实现"""
import asyncio
import threading
from typing import Dict, List, Any
from langchain.agents import AgentExecutor, create_openai_tools_agent
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_community.callbacks import get_openai_callback
from langchain_core.callbacks import AsyncCallbackHandler
from langchain_core.chat_history import InMemoryChatMessageHistory
from langchain_core.messages import HumanMessage, AIMessage, ToolCall
from langchain_core.runnables.history import RunnableWithMessageHistory
from app.agent.callback import StreamingCallbackHandler
from app.agent.memory import ConversationMemoryManager
from app.agent.prompt import PromptManager
from app.agent.tools import MoviePilotToolFactory
@@ -26,34 +25,6 @@ class AgentChain(ChainBase):
pass
class StreamingCallbackHandler(AsyncCallbackHandler):
"""流式输出回调处理器"""
def __init__(self, session_id: str):
self._lock = threading.Lock()
self.session_id = session_id
self.current_message = ""
self.message_helper = MessageHelper()
async def get_message(self):
"""获取当前消息内容,获取后清空"""
with self._lock:
if not self.current_message:
return ""
msg = self.current_message
logger.info(f"Agent消息: {msg}")
self.current_message = ""
return msg
async def on_llm_new_token(self, token: str, **kwargs):
"""处理新的token"""
if not token:
return
with self._lock:
# 缓存当前消息
self.current_message += token
class MoviePilotAgent:
"""MoviePilot AI智能体"""
@@ -142,7 +113,8 @@ class MoviePilotAgent:
user_id=self.user_id,
channel=self.channel,
source=self.source,
username=self.username
username=self.username,
callback_handler=self.callback_handler
)
@staticmethod
@@ -249,7 +221,7 @@ class MoviePilotAgent:
agent_message = await self.callback_handler.get_message()
# 发送Agent回复给用户通过原渠道
self.send_agent_message(agent_message)
await self.send_agent_message(agent_message)
# 添加Agent回复到记忆
await self.memory_manager.add_memory(
@@ -265,7 +237,7 @@ class MoviePilotAgent:
error_message = f"处理消息时发生错误: {str(e)}"
logger.error(error_message)
# 发送错误消息给用户(通过原渠道)
self.send_agent_message(error_message)
await self.send_agent_message(error_message)
return error_message
async def _execute_agent(self, input_context: Dict[str, Any]) -> Dict[str, Any]:
@@ -301,9 +273,9 @@ class MoviePilotAgent:
"token_usage": {}
}
def send_agent_message(self, message: str, title: str = "MoviePilot助手"):
async def send_agent_message(self, message: str, title: str = "MoviePilot助手"):
"""通过原渠道发送消息给用户"""
AgentChain().post_message(
await AgentChain().async_post_message(
Notification(
channel=self.channel,
source=self.source,

View File

@@ -0,0 +1,33 @@
import threading
from langchain_core.callbacks import AsyncCallbackHandler
from app.log import logger
class StreamingCallbackHandler(AsyncCallbackHandler):
"""流式输出回调处理器"""
def __init__(self, session_id: str):
self._lock = threading.Lock()
self.session_id = session_id
self.current_message = ""
async def get_message(self):
"""获取当前消息内容,获取后清空"""
with self._lock:
if not self.current_message:
return ""
msg = self.current_message
logger.info(f"Agent消息: {msg}")
self.current_message = ""
return msg
async def on_llm_new_token(self, token: str, **kwargs):
"""处理新的token"""
if not token:
return
with self._lock:
# 缓存当前消息
self.current_message += token

View File

@@ -1,8 +1,11 @@
"""MoviePilot工具基类"""
from abc import ABCMeta, abstractmethod
from typing import Callable, Any
from langchain.tools import BaseTool
from pydantic import PrivateAttr
from app.agent import StreamingCallbackHandler
from app.chain import ChainBase
from app.schemas import Notification
@@ -11,7 +14,7 @@ class ToolChain(ChainBase):
pass
class MoviePilotTool(BaseTool):
class MoviePilotTool(BaseTool, metaclass=ABCMeta):
"""MoviePilot专用工具基类"""
_session_id: str = PrivateAttr()
@@ -19,16 +22,26 @@ class MoviePilotTool(BaseTool):
_channel: str = PrivateAttr(default=None)
_source: str = PrivateAttr(default=None)
_username: str = PrivateAttr(default=None)
_callback_handler: StreamingCallbackHandler = PrivateAttr(default=None)
def __init__(self, session_id: str, user_id: str, **kwargs):
super().__init__(**kwargs)
self._session_id = session_id
self._user_id = user_id
def _run(self, **kwargs) -> str:
raise NotImplementedError
def _run(self, *args: Any, **kwargs: Any) -> Any:
pass
async def _arun(self, **kwargs) -> str:
"""异步运行工具"""
# 发送运行工具前的消息
agent_message = await self._callback_handler.get_message()
if agent_message:
await self.send_tool_message(agent_message)
return await self.run(**kwargs)
@abstractmethod
async def run(self, **kwargs) -> str:
raise NotImplementedError
def set_message_attr(self, channel: str, source: str, username: str):
@@ -37,9 +50,13 @@ class MoviePilotTool(BaseTool):
self._source = source
self._username = username
def send_tool_message(self, message: str, title: str = "执行工具"):
def set_callback_handler(self, callback_handler: StreamingCallbackHandler):
"""设置回调处理器"""
self._callback_handler = callback_handler
async def send_tool_message(self, message: str, title: str = "执行工具"):
"""发送工具消息"""
ToolChain().post_message(
await ToolChain().async_post_message(
Notification(
channel=self._channel,
source=self._source,

View File

@@ -1,6 +1,6 @@
"""MoviePilot工具工厂"""
from typing import List
from typing import List, Callable
from app.agent.tools.impl.add_download import AddDownloadTool
from app.agent.tools.impl.add_subscribe import AddSubscribeTool
@@ -21,7 +21,8 @@ class MoviePilotToolFactory:
@staticmethod
def create_tools(session_id: str, user_id: str,
channel: str = None, source: str = None, username: str = None) -> List[MoviePilotTool]:
channel: str = None, source: str = None, username: str = None,
callback_handler: Callable = None) -> List[MoviePilotTool]:
"""创建MoviePilot工具列表"""
tools = []
tool_definitions = [
@@ -42,6 +43,7 @@ class MoviePilotToolFactory:
user_id=user_id
)
tool.set_message_attr(channel=channel, source=source, username=username)
tool.set_callback_handler(callback_handler=callback_handler)
tools.append(tool)
logger.info(f"成功创建 {len(tools)} 个MoviePilot工具")
return tools

View File

@@ -15,11 +15,15 @@ from app.schemas import TorrentInfo
class AddDownloadInput(BaseModel):
"""添加下载工具的输入参数模型"""
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
torrent_title: str = Field(..., description="The display name/title of the torrent (e.g., 'The.Matrix.1999.1080p.BluRay.x264')")
torrent_title: str = Field(...,
description="The display name/title of the torrent (e.g., 'The.Matrix.1999.1080p.BluRay.x264')")
torrent_url: str = Field(..., description="Direct URL to the torrent file (.torrent) or magnet link")
downloader: Optional[str] = Field(None, description="Name of the downloader to use (optional, uses default if not specified)")
save_path: Optional[str] = Field(None, description="Directory path where the downloaded files should be saved (optional, uses default path if not specified)")
labels: Optional[str] = Field(None, description="Comma-separated list of labels/tags to assign to the download (optional, e.g., 'movie,hd,bluray')")
downloader: Optional[str] = Field(None,
description="Name of the downloader to use (optional, uses default if not specified)")
save_path: Optional[str] = Field(None,
description="Directory path where the downloaded files should be saved (optional, uses default path if not specified)")
labels: Optional[str] = Field(None,
description="Comma-separated list of labels/tags to assign to the download (optional, e.g., 'movie,hd,bluray')")
class AddDownloadTool(MoviePilotTool):
@@ -27,19 +31,19 @@ class AddDownloadTool(MoviePilotTool):
description: str = "Add torrent download task to the configured downloader (qBittorrent, Transmission, etc.). Downloads the torrent file and starts the download process with specified settings."
args_schema: Type[BaseModel] = AddDownloadInput
async def _arun(self, torrent_title: str, torrent_url: str,
downloader: Optional[str] = None, save_path: Optional[str] = None,
labels: Optional[str] = None, **kwargs) -> str:
async def run(self, torrent_title: str, torrent_url: str,
downloader: Optional[str] = None, save_path: Optional[str] = None,
labels: Optional[str] = None, **kwargs) -> str:
logger.info(
f"执行工具: {self.name}, 参数: torrent_title={torrent_title}, torrent_url={torrent_url}, downloader={downloader}, save_path={save_path}, labels={labels}")
# 发送工具执行说明
self.send_tool_message(f"正在添加下载任务: {torrent_title}", title="添加下载")
await self.send_tool_message(f"正在添加下载任务: {torrent_title}", title="添加下载")
try:
if not torrent_title or not torrent_url:
error_message = "错误:必须提供种子标题和下载链接"
self.send_tool_message(error_message, title="下载失败")
await self.send_tool_message(error_message, title="下载失败")
return error_message
# 使用DownloadChain添加下载
@@ -64,14 +68,14 @@ class AddDownloadTool(MoviePilotTool):
)
if did:
success_message = f"成功添加下载任务:{torrent_title}"
self.send_tool_message(success_message, title="下载成功")
await self.send_tool_message(success_message, title="下载成功")
return success_message
else:
error_message = "添加下载任务失败"
self.send_tool_message(error_message, title="下载失败")
await self.send_tool_message(error_message, title="下载失败")
return error_message
except Exception as e:
error_message = f"添加下载任务时发生错误: {str(e)}"
logger.error(f"添加下载任务失败: {e}", exc_info=True)
self.send_tool_message(error_message, title="下载失败")
await self.send_tool_message(error_message, title="下载失败")
return error_message

View File

@@ -15,9 +15,12 @@ class AddSubscribeInput(BaseModel):
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
title: str = Field(..., description="The title of the media to subscribe to (e.g., 'The Matrix', 'Breaking Bad')")
year: str = Field(..., description="Release year of the media (required for accurate identification)")
media_type: str = Field(..., description="Type of media content: '电影' for films, '电视剧' for television series or anime series")
season: Optional[int] = Field(None, description="Season number for TV shows (optional, if not specified will subscribe to all seasons)")
tmdb_id: Optional[str] = Field(None, description="TMDB database ID for precise media identification (optional but recommended for accuracy)")
media_type: str = Field(...,
description="Type of media content: '电影' for films, '电视剧' for television series or anime series")
season: Optional[int] = Field(None,
description="Season number for TV shows (optional, if not specified will subscribe to all seasons)")
tmdb_id: Optional[str] = Field(None,
description="TMDB database ID for precise media identification (optional but recommended for accuracy)")
class AddSubscribeTool(MoviePilotTool):
@@ -25,13 +28,13 @@ class AddSubscribeTool(MoviePilotTool):
description: str = "Add media subscription to create automated download rules for movies and TV shows. The system will automatically search and download new episodes or releases based on the subscription criteria."
args_schema: Type[BaseModel] = AddSubscribeInput
async def _arun(self, title: str, year: str, media_type: str,
season: Optional[int] = None, tmdb_id: Optional[str] = None, **kwargs) -> str:
async def run(self, title: str, year: str, media_type: str,
season: Optional[int] = None, tmdb_id: Optional[str] = None, **kwargs) -> str:
logger.info(
f"执行工具: {self.name}, 参数: title={title}, year={year}, media_type={media_type}, season={season}, tmdb_id={tmdb_id}")
# 发送工具执行说明
self.send_tool_message(f"正在添加订阅: {title} ({year}) - {media_type}", title="添加订阅")
await self.send_tool_message(f"正在添加订阅: {title} ({year}) - {media_type}", title="添加订阅")
try:
subscribe_chain = SubscribeChain()
@@ -53,14 +56,14 @@ class AddSubscribeTool(MoviePilotTool):
)
if sid:
success_message = f"成功添加订阅:{title} ({year})"
self.send_tool_message(success_message, title="订阅成功")
await self.send_tool_message(success_message, title="订阅成功")
return success_message
else:
error_message = f"添加订阅失败:{message}"
self.send_tool_message(error_message, title="订阅失败")
await self.send_tool_message(error_message, title="订阅失败")
return error_message
except Exception as e:
error_message = f"添加订阅时发生错误: {str(e)}"
logger.error(f"添加订阅失败: {e}", exc_info=True)
self.send_tool_message(error_message, title="订阅失败")
await self.send_tool_message(error_message, title="订阅失败")
return error_message

View File

@@ -5,17 +5,20 @@ from typing import Optional, Type
from pydantic import BaseModel, Field
from app.agent.tools.base import MoviePilotTool
from app.chain.recommend import RecommendChain
from app.log import logger
from app.agent.tools.base import MoviePilotTool
class GetRecommendationsInput(BaseModel):
"""获取推荐工具的输入参数模型"""
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
source: Optional[str] = Field("tmdb_trending", description="Recommendation source: 'tmdb_trending' for TMDB trending content, 'douban_hot' for Douban popular content, 'bangumi_calendar' for Bangumi anime calendar")
media_type: Optional[str] = Field("all", description="Type of media content: '电影' for films, '电视剧' for television series or anime series, 'all' for all types")
limit: Optional[int] = Field(20, description="Maximum number of recommendations to return (default: 20, maximum: 100)")
source: Optional[str] = Field("tmdb_trending",
description="Recommendation source: 'tmdb_trending' for TMDB trending content, 'douban_hot' for Douban popular content, 'bangumi_calendar' for Bangumi anime calendar")
media_type: Optional[str] = Field("all",
description="Type of media content: '电影' for films, '电视剧' for television series or anime series, 'all' for all types")
limit: Optional[int] = Field(20,
description="Maximum number of recommendations to return (default: 20, maximum: 100)")
class GetRecommendationsTool(MoviePilotTool):
@@ -23,8 +26,8 @@ class GetRecommendationsTool(MoviePilotTool):
description: str = "Get trending and popular media recommendations from various sources. Returns curated lists of popular movies, TV shows, and anime based on different criteria like trending, ratings, or calendar schedules."
args_schema: Type[BaseModel] = GetRecommendationsInput
async def _arun(self, source: Optional[str] = "tmdb_trending",
media_type: Optional[str] = "all", limit: Optional[int] = 20, **kwargs) -> str:
async def run(self, source: Optional[str] = "tmdb_trending",
media_type: Optional[str] = "all", limit: Optional[int] = 20, **kwargs) -> str:
logger.info(f"执行工具: {self.name}, 参数: source={source}, media_type={media_type}, limit={limit}")
try:
recommend_chain = RecommendChain()
@@ -36,7 +39,7 @@ class GetRecommendationsTool(MoviePilotTool):
results = recommend_chain.douban_movie_hot(limit=limit)
elif media_type == "tv":
results = recommend_chain.douban_tv_hot(limit=limit)
else: # all
else: # all
results.extend(recommend_chain.douban_movie_hot(limit=limit))
results.extend(recommend_chain.douban_tv_hot(limit=limit))
elif source == "bangumi_calendar":

View File

@@ -5,10 +5,10 @@ from typing import Type
from pydantic import BaseModel, Field
from app.agent.tools.base import MoviePilotTool
from app.db.systemconfig_oper import SystemConfigOper
from app.log import logger
from app.schemas.types import SystemConfigKey
from app.agent.tools.base import MoviePilotTool
class QueryDownloadersInput(BaseModel):
@@ -21,7 +21,7 @@ class QueryDownloadersTool(MoviePilotTool):
description: str = "Query downloader configuration and list all available downloaders. Shows downloader status, connection details, and configuration settings."
args_schema: Type[BaseModel] = QueryDownloadersInput
async def _arun(self, **kwargs) -> str:
async def run(self, **kwargs) -> str:
logger.info(f"执行工具: {self.name}")
try:
system_config_oper = SystemConfigOper()

View File

@@ -5,16 +5,18 @@ from typing import Optional, Type
from pydantic import BaseModel, Field
from app.agent.tools.base import MoviePilotTool
from app.chain.download import DownloadChain
from app.log import logger
from app.agent.tools.base import MoviePilotTool
class QueryDownloadsInput(BaseModel):
"""查询下载工具的输入参数模型"""
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
downloader: Optional[str] = Field(None, description="Name of specific downloader to query (optional, if not provided queries all configured downloaders)")
status: Optional[str] = Field("all", description="Filter downloads by status: 'downloading' for active downloads, 'completed' for finished downloads, 'paused' for paused downloads, 'all' for all downloads")
downloader: Optional[str] = Field(None,
description="Name of specific downloader to query (optional, if not provided queries all configured downloaders)")
status: Optional[str] = Field("all",
description="Filter downloads by status: 'downloading' for active downloads, 'completed' for finished downloads, 'paused' for paused downloads, 'all' for all downloads")
class QueryDownloadsTool(MoviePilotTool):
@@ -22,8 +24,8 @@ class QueryDownloadsTool(MoviePilotTool):
description: str = "Query download status and list all active download tasks. Shows download progress, completion status, and task details from configured downloaders."
args_schema: Type[BaseModel] = QueryDownloadsInput
async def _arun(self, downloader: Optional[str] = None,
status: Optional[str] = "all", **kwargs) -> str:
async def run(self, downloader: Optional[str] = None,
status: Optional[str] = "all", **kwargs) -> str:
logger.info(f"执行工具: {self.name}, 参数: downloader={downloader}, status={status}")
try:
download_chain = DownloadChain()

View File

@@ -14,9 +14,12 @@ from app.schemas import MediaServerItem
class QueryMediaLibraryInput(BaseModel):
"""查询媒体库工具的输入参数模型"""
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
media_type: Optional[str] = Field("all", description="Type of media content: '电影' for films, '电视剧' for television series or anime series, 'all' for all types")
title: Optional[str] = Field(None, description="Specific media title to search for (optional, if provided returns detailed info for that specific media)")
year: Optional[str] = Field(None, description="Release year of the media (optional, helps narrow down search results)")
media_type: Optional[str] = Field("all",
description="Type of media content: '电影' for films, '电视剧' for television series or anime series, 'all' for all types")
title: Optional[str] = Field(None,
description="Specific media title to search for (optional, if provided returns detailed info for that specific media)")
year: Optional[str] = Field(None,
description="Release year of the media (optional, helps narrow down search results)")
class QueryMediaLibraryTool(MoviePilotTool):
@@ -24,8 +27,8 @@ class QueryMediaLibraryTool(MoviePilotTool):
description: str = "Query media library status and list all media files that have been successfully processed and added to the media server (Plex, Emby, Jellyfin). Shows library statistics and file details."
args_schema: Type[BaseModel] = QueryMediaLibraryInput
async def _arun(self, media_type: Optional[str] = "all",
title: Optional[str] = None, year: Optional[str] = None, **kwargs) -> str:
async def run(self, media_type: Optional[str] = "all",
title: Optional[str] = None, year: Optional[str] = None, **kwargs) -> str:
logger.info(f"执行工具: {self.name}, 参数: media_type={media_type}, title={title}")
try:
media_server_oper = MediaServerOper()

View File

@@ -5,16 +5,18 @@ from typing import Optional, Type
from pydantic import BaseModel, Field
from app.agent.tools.base import MoviePilotTool
from app.db.subscribe_oper import SubscribeOper
from app.log import logger
from app.agent.tools.base import MoviePilotTool
class QuerySubscribesInput(BaseModel):
"""查询订阅工具的输入参数模型"""
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
status: Optional[str] = Field("all", description="Filter subscriptions by status: 'R' for enabled subscriptions, 'P' for disabled ones, 'all' for all subscriptions")
media_type: Optional[str] = Field("all", description="Filter by media type: 'movie' for films, 'tv' for television series, 'all' for all types")
status: Optional[str] = Field("all",
description="Filter subscriptions by status: 'R' for enabled subscriptions, 'P' for disabled ones, 'all' for all subscriptions")
media_type: Optional[str] = Field("all",
description="Filter by media type: 'movie' for films, 'tv' for television series, 'all' for all types")
class QuerySubscribesTool(MoviePilotTool):
@@ -22,7 +24,7 @@ class QuerySubscribesTool(MoviePilotTool):
description: str = "Query subscription status and list all user subscriptions. Shows active subscriptions, their download status, and configuration details."
args_schema: Type[BaseModel] = QuerySubscribesInput
async def _arun(self, status: Optional[str] = "all", media_type: Optional[str] = "all", **kwargs) -> str:
async def run(self, status: Optional[str] = "all", media_type: Optional[str] = "all", **kwargs) -> str:
logger.info(f"执行工具: {self.name}, 参数: status={status}, media_type={media_type}")
try:
subscribe_oper = SubscribeOper()

View File

@@ -16,8 +16,10 @@ class SearchMediaInput(BaseModel):
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
title: str = Field(..., description="The title of the media to search for (e.g., 'The Matrix', 'Breaking Bad')")
year: Optional[str] = Field(None, description="Release year of the media (optional, helps narrow down results)")
media_type: Optional[str] = Field(None, description="Type of media content: '电影' for films, '电视剧' for television series or anime series")
season: Optional[int] = Field(None, description="Season number for TV shows and anime (optional, only applicable for series)")
media_type: Optional[str] = Field(None,
description="Type of media content: '电影' for films, '电视剧' for television series or anime series")
season: Optional[int] = Field(None,
description="Season number for TV shows and anime (optional, only applicable for series)")
class SearchMediaTool(MoviePilotTool):
@@ -25,13 +27,13 @@ class SearchMediaTool(MoviePilotTool):
description: str = "Search for media resources including movies, TV shows, anime, etc. Supports searching by title, year, type, and other criteria. Returns detailed media information from TMDB database."
args_schema: Type[BaseModel] = SearchMediaInput
async def _arun(self, title: str, year: Optional[str] = None,
media_type: Optional[str] = None, season: Optional[int] = None, **kwargs) -> str:
async def run(self, title: str, year: Optional[str] = None,
media_type: Optional[str] = None, season: Optional[int] = None, **kwargs) -> str:
logger.info(
f"执行工具: {self.name}, 参数: title={title}, year={year}, media_type={media_type}, season={season}")
# 发送工具执行说明
self.send_tool_message(f"正在搜索媒体资源: {title}" + (f" ({year})" if year else ""), title="搜索中")
await self.send_tool_message(f"正在搜索媒体资源: {title}" + (f" ({year})" if year else ""), title="搜索中")
try:
media_chain = MediaChain()
@@ -62,24 +64,24 @@ class SearchMediaTool(MoviePilotTool):
if filtered_results:
result_message = f"找到 {len(filtered_results)} 个相关媒体资源"
self.send_tool_message(result_message, title="搜索成功")
await self.send_tool_message(result_message, title="搜索成功")
# 发送详细结果
for i, result in enumerate(filtered_results[:5]): # 只显示前5个结果
media_info = f"{i + 1}. {result.title} ({result.year}) - {result.type.value if result.type else '未知'}"
self.send_tool_message(media_info, title="搜索结果")
await self.send_tool_message(media_info, title="搜索结果")
return json.dumps([r.to_dict() for r in filtered_results], ensure_ascii=False, indent=2)
else:
error_message = f"未找到符合条件的媒体资源: {title}"
self.send_tool_message(error_message, title="搜索完成")
await self.send_tool_message(error_message, title="搜索完成")
return error_message
else:
error_message = f"未找到相关媒体资源: {title}"
self.send_tool_message(error_message, title="搜索完成")
await self.send_tool_message(error_message, title="搜索完成")
return error_message
except Exception as e:
error_message = f"搜索媒体失败: {str(e)}"
logger.error(f"搜索媒体失败: {e}", exc_info=True)
self.send_tool_message(error_message, title="搜索失败")
await self.send_tool_message(error_message, title="搜索失败")
return error_message

View File

@@ -14,11 +14,15 @@ from app.schemas.types import MediaType
class SearchTorrentsInput(BaseModel):
"""搜索种子工具的输入参数模型"""
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
title: str = Field(..., description="The title of the media resource to search for (e.g., 'The Matrix 1999', 'Breaking Bad S01E01')")
year: Optional[str] = Field(None, description="Release year of the media (optional, helps narrow down search results)")
media_type: Optional[str] = Field(None, description="Type of media content: '电影' for films, '电视剧' for television series or anime series")
title: str = Field(...,
description="The title of the media resource to search for (e.g., 'The Matrix 1999', 'Breaking Bad S01E01')")
year: Optional[str] = Field(None,
description="Release year of the media (optional, helps narrow down search results)")
media_type: Optional[str] = Field(None,
description="Type of media content: '电影' for films, '电视剧' for television series or anime series")
season: Optional[int] = Field(None, description="Season number for TV shows (optional, only applicable for series)")
sites: Optional[List[int]] = Field(None, description="Array of specific site IDs to search on (optional, if not provided searches all configured sites)")
sites: Optional[List[int]] = Field(None,
description="Array of specific site IDs to search on (optional, if not provided searches all configured sites)")
class SearchTorrentsTool(MoviePilotTool):
@@ -26,14 +30,14 @@ class SearchTorrentsTool(MoviePilotTool):
description: str = "Search for torrent files across configured indexer sites based on media information. Returns available torrent downloads with details like file size, quality, and download links."
args_schema: Type[BaseModel] = SearchTorrentsInput
async def _arun(self, title: str, year: Optional[str] = None,
media_type: Optional[str] = None, season: Optional[int] = None,
sites: Optional[List[int]] = None, **kwargs) -> str:
async def run(self, title: str, year: Optional[str] = None,
media_type: Optional[str] = None, season: Optional[int] = None,
sites: Optional[List[int]] = None, **kwargs) -> str:
logger.info(
f"执行工具: {self.name}, 参数: title={title}, year={year}, media_type={media_type}, season={season}, sites={sites}")
# 发送工具执行说明
self.send_tool_message(f"正在搜索种子资源: {title}" + (f" ({year})" if year else ""), title="搜索种子")
await self.send_tool_message(f"正在搜索种子资源: {title}" + (f" ({year})" if year else ""), title="搜索种子")
try:
search_chain = SearchChain()
@@ -52,22 +56,22 @@ class SearchTorrentsTool(MoviePilotTool):
if filtered_torrents:
result_message = f"找到 {len(filtered_torrents)} 个相关种子资源"
self.send_tool_message(result_message, title="搜索成功")
await self.send_tool_message(result_message, title="搜索成功")
# 发送详细结果
for i, torrent in enumerate(filtered_torrents[:5]): # 只显示前5个结果
torrent_title = torrent.torrent_info.title if torrent.torrent_info else torrent.meta_info.title if torrent.meta_info else "未知"
site_name = torrent.torrent_info.site_name if torrent.torrent_info else "未知站点"
torrent_info = f"{i + 1}. {torrent_title} - {site_name}"
self.send_tool_message(torrent_info, title="搜索结果")
await self.send_tool_message(torrent_info, title="搜索结果")
return json.dumps([t.to_dict() for t in filtered_torrents], ensure_ascii=False, indent=2)
else:
error_message = f"未找到相关种子资源: {title}"
self.send_tool_message(error_message, title="搜索完成")
await self.send_tool_message(error_message, title="搜索完成")
return error_message
except Exception as e:
error_message = f"搜索种子时发生错误: {str(e)}"
logger.error(f"搜索种子失败: {e}", exc_info=True)
self.send_tool_message(error_message, title="搜索失败")
await self.send_tool_message(error_message, title="搜索失败")
return error_message

View File

@@ -12,7 +12,8 @@ class SendMessageInput(BaseModel):
"""发送消息工具的输入参数模型"""
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
message: str = Field(..., description="The message content to send to the user (should be clear and informative)")
message_type: Optional[str] = Field("info", description="Type of message: 'info' for general information, 'success' for successful operations, 'warning' for warnings, 'error' for error messages")
message_type: Optional[str] = Field("info",
description="Type of message: 'info' for general information, 'success' for successful operations, 'warning' for warnings, 'error' for error messages")
class SendMessageTool(MoviePilotTool):
@@ -20,10 +21,10 @@ class SendMessageTool(MoviePilotTool):
description: str = "Send notification message to the user through configured notification channels (Telegram, Slack, WeChat, etc.). Used to inform users about operation results, errors, or important updates."
args_schema: Type[BaseModel] = SendMessageInput
async def _arun(self, message: str, message_type: Optional[str] = None, **kwargs) -> str:
async def run(self, message: str, message_type: Optional[str] = None, **kwargs) -> str:
logger.info(f"执行工具: {self.name}, 参数: message={message}, message_type={message_type}")
try:
self.send_tool_message(message, title=message_type)
await self.send_tool_message(message, title=message_type)
return "消息已发送。"
except Exception as e:
logger.error(f"发送消息失败: {e}")