diff --git a/app/agent/__init__.py b/app/agent/__init__.py index 9e749906..b414d86b 100644 --- a/app/agent/__init__.py +++ b/app/agent/__init__.py @@ -1,17 +1,16 @@ """MoviePilot AI智能体实现""" import asyncio -import threading from typing import Dict, List, Any from langchain.agents import AgentExecutor, create_openai_tools_agent from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder from langchain_community.callbacks import get_openai_callback -from langchain_core.callbacks import AsyncCallbackHandler from langchain_core.chat_history import InMemoryChatMessageHistory from langchain_core.messages import HumanMessage, AIMessage, ToolCall from langchain_core.runnables.history import RunnableWithMessageHistory +from app.agent.callback import StreamingCallbackHandler from app.agent.memory import ConversationMemoryManager from app.agent.prompt import PromptManager from app.agent.tools import MoviePilotToolFactory @@ -26,34 +25,6 @@ class AgentChain(ChainBase): pass -class StreamingCallbackHandler(AsyncCallbackHandler): - """流式输出回调处理器""" - - def __init__(self, session_id: str): - self._lock = threading.Lock() - self.session_id = session_id - self.current_message = "" - self.message_helper = MessageHelper() - - async def get_message(self): - """获取当前消息内容,获取后清空""" - with self._lock: - if not self.current_message: - return "" - msg = self.current_message - logger.info(f"Agent消息: {msg}") - self.current_message = "" - return msg - - async def on_llm_new_token(self, token: str, **kwargs): - """处理新的token""" - if not token: - return - with self._lock: - # 缓存当前消息 - self.current_message += token - - class MoviePilotAgent: """MoviePilot AI智能体""" @@ -142,7 +113,8 @@ class MoviePilotAgent: user_id=self.user_id, channel=self.channel, source=self.source, - username=self.username + username=self.username, + callback_handler=self.callback_handler ) @staticmethod @@ -249,7 +221,7 @@ class MoviePilotAgent: agent_message = await self.callback_handler.get_message() # 发送Agent回复给用户(通过原渠道) - self.send_agent_message(agent_message) + await self.send_agent_message(agent_message) # 添加Agent回复到记忆 await self.memory_manager.add_memory( @@ -265,7 +237,7 @@ class MoviePilotAgent: error_message = f"处理消息时发生错误: {str(e)}" logger.error(error_message) # 发送错误消息给用户(通过原渠道) - self.send_agent_message(error_message) + await self.send_agent_message(error_message) return error_message async def _execute_agent(self, input_context: Dict[str, Any]) -> Dict[str, Any]: @@ -301,9 +273,9 @@ class MoviePilotAgent: "token_usage": {} } - def send_agent_message(self, message: str, title: str = "MoviePilot助手"): + async def send_agent_message(self, message: str, title: str = "MoviePilot助手"): """通过原渠道发送消息给用户""" - AgentChain().post_message( + await AgentChain().async_post_message( Notification( channel=self.channel, source=self.source, diff --git a/app/agent/callback/__init__.py b/app/agent/callback/__init__.py new file mode 100644 index 00000000..0e511af5 --- /dev/null +++ b/app/agent/callback/__init__.py @@ -0,0 +1,33 @@ +import threading + +from langchain_core.callbacks import AsyncCallbackHandler + +from app.log import logger + + +class StreamingCallbackHandler(AsyncCallbackHandler): + """流式输出回调处理器""" + + def __init__(self, session_id: str): + self._lock = threading.Lock() + self.session_id = session_id + self.current_message = "" + + async def get_message(self): + """获取当前消息内容,获取后清空""" + with self._lock: + if not self.current_message: + return "" + msg = self.current_message + logger.info(f"Agent消息: {msg}") + self.current_message = "" + return msg + + async def on_llm_new_token(self, token: str, **kwargs): + """处理新的token""" + if not token: + return + with self._lock: + # 缓存当前消息 + self.current_message += token + diff --git a/app/agent/tools/base.py b/app/agent/tools/base.py index 8881cc95..ec4c7740 100644 --- a/app/agent/tools/base.py +++ b/app/agent/tools/base.py @@ -1,8 +1,11 @@ """MoviePilot工具基类""" +from abc import ABCMeta, abstractmethod +from typing import Callable, Any from langchain.tools import BaseTool from pydantic import PrivateAttr +from app.agent import StreamingCallbackHandler from app.chain import ChainBase from app.schemas import Notification @@ -11,7 +14,7 @@ class ToolChain(ChainBase): pass -class MoviePilotTool(BaseTool): +class MoviePilotTool(BaseTool, metaclass=ABCMeta): """MoviePilot专用工具基类""" _session_id: str = PrivateAttr() @@ -19,16 +22,26 @@ class MoviePilotTool(BaseTool): _channel: str = PrivateAttr(default=None) _source: str = PrivateAttr(default=None) _username: str = PrivateAttr(default=None) + _callback_handler: StreamingCallbackHandler = PrivateAttr(default=None) def __init__(self, session_id: str, user_id: str, **kwargs): super().__init__(**kwargs) self._session_id = session_id self._user_id = user_id - def _run(self, **kwargs) -> str: - raise NotImplementedError + def _run(self, *args: Any, **kwargs: Any) -> Any: + pass async def _arun(self, **kwargs) -> str: + """异步运行工具""" + # 发送运行工具前的消息 + agent_message = await self._callback_handler.get_message() + if agent_message: + await self.send_tool_message(agent_message) + return await self.run(**kwargs) + + @abstractmethod + async def run(self, **kwargs) -> str: raise NotImplementedError def set_message_attr(self, channel: str, source: str, username: str): @@ -37,9 +50,13 @@ class MoviePilotTool(BaseTool): self._source = source self._username = username - def send_tool_message(self, message: str, title: str = "执行工具"): + def set_callback_handler(self, callback_handler: StreamingCallbackHandler): + """设置回调处理器""" + self._callback_handler = callback_handler + + async def send_tool_message(self, message: str, title: str = "执行工具"): """发送工具消息""" - ToolChain().post_message( + await ToolChain().async_post_message( Notification( channel=self._channel, source=self._source, diff --git a/app/agent/tools/factory.py b/app/agent/tools/factory.py index 63062b5e..974bb29f 100644 --- a/app/agent/tools/factory.py +++ b/app/agent/tools/factory.py @@ -1,6 +1,6 @@ """MoviePilot工具工厂""" -from typing import List +from typing import List, Callable from app.agent.tools.impl.add_download import AddDownloadTool from app.agent.tools.impl.add_subscribe import AddSubscribeTool @@ -21,7 +21,8 @@ class MoviePilotToolFactory: @staticmethod def create_tools(session_id: str, user_id: str, - channel: str = None, source: str = None, username: str = None) -> List[MoviePilotTool]: + channel: str = None, source: str = None, username: str = None, + callback_handler: Callable = None) -> List[MoviePilotTool]: """创建MoviePilot工具列表""" tools = [] tool_definitions = [ @@ -42,6 +43,7 @@ class MoviePilotToolFactory: user_id=user_id ) tool.set_message_attr(channel=channel, source=source, username=username) + tool.set_callback_handler(callback_handler=callback_handler) tools.append(tool) logger.info(f"成功创建 {len(tools)} 个MoviePilot工具") return tools diff --git a/app/agent/tools/impl/add_download.py b/app/agent/tools/impl/add_download.py index b27e472d..6d5d8395 100644 --- a/app/agent/tools/impl/add_download.py +++ b/app/agent/tools/impl/add_download.py @@ -15,11 +15,15 @@ from app.schemas import TorrentInfo class AddDownloadInput(BaseModel): """添加下载工具的输入参数模型""" explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context") - torrent_title: str = Field(..., description="The display name/title of the torrent (e.g., 'The.Matrix.1999.1080p.BluRay.x264')") + torrent_title: str = Field(..., + description="The display name/title of the torrent (e.g., 'The.Matrix.1999.1080p.BluRay.x264')") torrent_url: str = Field(..., description="Direct URL to the torrent file (.torrent) or magnet link") - downloader: Optional[str] = Field(None, description="Name of the downloader to use (optional, uses default if not specified)") - save_path: Optional[str] = Field(None, description="Directory path where the downloaded files should be saved (optional, uses default path if not specified)") - labels: Optional[str] = Field(None, description="Comma-separated list of labels/tags to assign to the download (optional, e.g., 'movie,hd,bluray')") + downloader: Optional[str] = Field(None, + description="Name of the downloader to use (optional, uses default if not specified)") + save_path: Optional[str] = Field(None, + description="Directory path where the downloaded files should be saved (optional, uses default path if not specified)") + labels: Optional[str] = Field(None, + description="Comma-separated list of labels/tags to assign to the download (optional, e.g., 'movie,hd,bluray')") class AddDownloadTool(MoviePilotTool): @@ -27,19 +31,19 @@ class AddDownloadTool(MoviePilotTool): description: str = "Add torrent download task to the configured downloader (qBittorrent, Transmission, etc.). Downloads the torrent file and starts the download process with specified settings." args_schema: Type[BaseModel] = AddDownloadInput - async def _arun(self, torrent_title: str, torrent_url: str, - downloader: Optional[str] = None, save_path: Optional[str] = None, - labels: Optional[str] = None, **kwargs) -> str: + async def run(self, torrent_title: str, torrent_url: str, + downloader: Optional[str] = None, save_path: Optional[str] = None, + labels: Optional[str] = None, **kwargs) -> str: logger.info( f"执行工具: {self.name}, 参数: torrent_title={torrent_title}, torrent_url={torrent_url}, downloader={downloader}, save_path={save_path}, labels={labels}") # 发送工具执行说明 - self.send_tool_message(f"正在添加下载任务: {torrent_title}", title="添加下载") + await self.send_tool_message(f"正在添加下载任务: {torrent_title}", title="添加下载") try: if not torrent_title or not torrent_url: error_message = "错误:必须提供种子标题和下载链接" - self.send_tool_message(error_message, title="下载失败") + await self.send_tool_message(error_message, title="下载失败") return error_message # 使用DownloadChain添加下载 @@ -64,14 +68,14 @@ class AddDownloadTool(MoviePilotTool): ) if did: success_message = f"成功添加下载任务:{torrent_title}" - self.send_tool_message(success_message, title="下载成功") + await self.send_tool_message(success_message, title="下载成功") return success_message else: error_message = "添加下载任务失败" - self.send_tool_message(error_message, title="下载失败") + await self.send_tool_message(error_message, title="下载失败") return error_message except Exception as e: error_message = f"添加下载任务时发生错误: {str(e)}" logger.error(f"添加下载任务失败: {e}", exc_info=True) - self.send_tool_message(error_message, title="下载失败") + await self.send_tool_message(error_message, title="下载失败") return error_message diff --git a/app/agent/tools/impl/add_subscribe.py b/app/agent/tools/impl/add_subscribe.py index 676a9036..f8533f7c 100644 --- a/app/agent/tools/impl/add_subscribe.py +++ b/app/agent/tools/impl/add_subscribe.py @@ -15,9 +15,12 @@ class AddSubscribeInput(BaseModel): explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context") title: str = Field(..., description="The title of the media to subscribe to (e.g., 'The Matrix', 'Breaking Bad')") year: str = Field(..., description="Release year of the media (required for accurate identification)") - media_type: str = Field(..., description="Type of media content: '电影' for films, '电视剧' for television series or anime series") - season: Optional[int] = Field(None, description="Season number for TV shows (optional, if not specified will subscribe to all seasons)") - tmdb_id: Optional[str] = Field(None, description="TMDB database ID for precise media identification (optional but recommended for accuracy)") + media_type: str = Field(..., + description="Type of media content: '电影' for films, '电视剧' for television series or anime series") + season: Optional[int] = Field(None, + description="Season number for TV shows (optional, if not specified will subscribe to all seasons)") + tmdb_id: Optional[str] = Field(None, + description="TMDB database ID for precise media identification (optional but recommended for accuracy)") class AddSubscribeTool(MoviePilotTool): @@ -25,13 +28,13 @@ class AddSubscribeTool(MoviePilotTool): description: str = "Add media subscription to create automated download rules for movies and TV shows. The system will automatically search and download new episodes or releases based on the subscription criteria." args_schema: Type[BaseModel] = AddSubscribeInput - async def _arun(self, title: str, year: str, media_type: str, - season: Optional[int] = None, tmdb_id: Optional[str] = None, **kwargs) -> str: + async def run(self, title: str, year: str, media_type: str, + season: Optional[int] = None, tmdb_id: Optional[str] = None, **kwargs) -> str: logger.info( f"执行工具: {self.name}, 参数: title={title}, year={year}, media_type={media_type}, season={season}, tmdb_id={tmdb_id}") # 发送工具执行说明 - self.send_tool_message(f"正在添加订阅: {title} ({year}) - {media_type}", title="添加订阅") + await self.send_tool_message(f"正在添加订阅: {title} ({year}) - {media_type}", title="添加订阅") try: subscribe_chain = SubscribeChain() @@ -53,14 +56,14 @@ class AddSubscribeTool(MoviePilotTool): ) if sid: success_message = f"成功添加订阅:{title} ({year})" - self.send_tool_message(success_message, title="订阅成功") + await self.send_tool_message(success_message, title="订阅成功") return success_message else: error_message = f"添加订阅失败:{message}" - self.send_tool_message(error_message, title="订阅失败") + await self.send_tool_message(error_message, title="订阅失败") return error_message except Exception as e: error_message = f"添加订阅时发生错误: {str(e)}" logger.error(f"添加订阅失败: {e}", exc_info=True) - self.send_tool_message(error_message, title="订阅失败") + await self.send_tool_message(error_message, title="订阅失败") return error_message diff --git a/app/agent/tools/impl/get_recommendations.py b/app/agent/tools/impl/get_recommendations.py index e14a7048..60e96ccb 100644 --- a/app/agent/tools/impl/get_recommendations.py +++ b/app/agent/tools/impl/get_recommendations.py @@ -5,17 +5,20 @@ from typing import Optional, Type from pydantic import BaseModel, Field +from app.agent.tools.base import MoviePilotTool from app.chain.recommend import RecommendChain from app.log import logger -from app.agent.tools.base import MoviePilotTool class GetRecommendationsInput(BaseModel): """获取推荐工具的输入参数模型""" explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context") - source: Optional[str] = Field("tmdb_trending", description="Recommendation source: 'tmdb_trending' for TMDB trending content, 'douban_hot' for Douban popular content, 'bangumi_calendar' for Bangumi anime calendar") - media_type: Optional[str] = Field("all", description="Type of media content: '电影' for films, '电视剧' for television series or anime series, 'all' for all types") - limit: Optional[int] = Field(20, description="Maximum number of recommendations to return (default: 20, maximum: 100)") + source: Optional[str] = Field("tmdb_trending", + description="Recommendation source: 'tmdb_trending' for TMDB trending content, 'douban_hot' for Douban popular content, 'bangumi_calendar' for Bangumi anime calendar") + media_type: Optional[str] = Field("all", + description="Type of media content: '电影' for films, '电视剧' for television series or anime series, 'all' for all types") + limit: Optional[int] = Field(20, + description="Maximum number of recommendations to return (default: 20, maximum: 100)") class GetRecommendationsTool(MoviePilotTool): @@ -23,8 +26,8 @@ class GetRecommendationsTool(MoviePilotTool): description: str = "Get trending and popular media recommendations from various sources. Returns curated lists of popular movies, TV shows, and anime based on different criteria like trending, ratings, or calendar schedules." args_schema: Type[BaseModel] = GetRecommendationsInput - async def _arun(self, source: Optional[str] = "tmdb_trending", - media_type: Optional[str] = "all", limit: Optional[int] = 20, **kwargs) -> str: + async def run(self, source: Optional[str] = "tmdb_trending", + media_type: Optional[str] = "all", limit: Optional[int] = 20, **kwargs) -> str: logger.info(f"执行工具: {self.name}, 参数: source={source}, media_type={media_type}, limit={limit}") try: recommend_chain = RecommendChain() @@ -36,12 +39,12 @@ class GetRecommendationsTool(MoviePilotTool): results = recommend_chain.douban_movie_hot(limit=limit) elif media_type == "tv": results = recommend_chain.douban_tv_hot(limit=limit) - else: # all + else: # all results.extend(recommend_chain.douban_movie_hot(limit=limit)) results.extend(recommend_chain.douban_tv_hot(limit=limit)) elif source == "bangumi_calendar": results = recommend_chain.bangumi_calendar(limit=limit) - + if results: # 使用 to_dict() 方法 return json.dumps(results) diff --git a/app/agent/tools/impl/query_downloaders.py b/app/agent/tools/impl/query_downloaders.py index e7a36b69..a7a3c25d 100644 --- a/app/agent/tools/impl/query_downloaders.py +++ b/app/agent/tools/impl/query_downloaders.py @@ -5,10 +5,10 @@ from typing import Type from pydantic import BaseModel, Field +from app.agent.tools.base import MoviePilotTool from app.db.systemconfig_oper import SystemConfigOper from app.log import logger from app.schemas.types import SystemConfigKey -from app.agent.tools.base import MoviePilotTool class QueryDownloadersInput(BaseModel): @@ -21,7 +21,7 @@ class QueryDownloadersTool(MoviePilotTool): description: str = "Query downloader configuration and list all available downloaders. Shows downloader status, connection details, and configuration settings." args_schema: Type[BaseModel] = QueryDownloadersInput - async def _arun(self, **kwargs) -> str: + async def run(self, **kwargs) -> str: logger.info(f"执行工具: {self.name}") try: system_config_oper = SystemConfigOper() diff --git a/app/agent/tools/impl/query_downloads.py b/app/agent/tools/impl/query_downloads.py index c24125c1..4cd15d4b 100644 --- a/app/agent/tools/impl/query_downloads.py +++ b/app/agent/tools/impl/query_downloads.py @@ -5,16 +5,18 @@ from typing import Optional, Type from pydantic import BaseModel, Field +from app.agent.tools.base import MoviePilotTool from app.chain.download import DownloadChain from app.log import logger -from app.agent.tools.base import MoviePilotTool class QueryDownloadsInput(BaseModel): """查询下载工具的输入参数模型""" explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context") - downloader: Optional[str] = Field(None, description="Name of specific downloader to query (optional, if not provided queries all configured downloaders)") - status: Optional[str] = Field("all", description="Filter downloads by status: 'downloading' for active downloads, 'completed' for finished downloads, 'paused' for paused downloads, 'all' for all downloads") + downloader: Optional[str] = Field(None, + description="Name of specific downloader to query (optional, if not provided queries all configured downloaders)") + status: Optional[str] = Field("all", + description="Filter downloads by status: 'downloading' for active downloads, 'completed' for finished downloads, 'paused' for paused downloads, 'all' for all downloads") class QueryDownloadsTool(MoviePilotTool): @@ -22,8 +24,8 @@ class QueryDownloadsTool(MoviePilotTool): description: str = "Query download status and list all active download tasks. Shows download progress, completion status, and task details from configured downloaders." args_schema: Type[BaseModel] = QueryDownloadsInput - async def _arun(self, downloader: Optional[str] = None, - status: Optional[str] = "all", **kwargs) -> str: + async def run(self, downloader: Optional[str] = None, + status: Optional[str] = "all", **kwargs) -> str: logger.info(f"执行工具: {self.name}, 参数: downloader={downloader}, status={status}") try: download_chain = DownloadChain() diff --git a/app/agent/tools/impl/query_media_library.py b/app/agent/tools/impl/query_media_library.py index 8c2fbdd3..b6943c1c 100644 --- a/app/agent/tools/impl/query_media_library.py +++ b/app/agent/tools/impl/query_media_library.py @@ -14,9 +14,12 @@ from app.schemas import MediaServerItem class QueryMediaLibraryInput(BaseModel): """查询媒体库工具的输入参数模型""" explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context") - media_type: Optional[str] = Field("all", description="Type of media content: '电影' for films, '电视剧' for television series or anime series, 'all' for all types") - title: Optional[str] = Field(None, description="Specific media title to search for (optional, if provided returns detailed info for that specific media)") - year: Optional[str] = Field(None, description="Release year of the media (optional, helps narrow down search results)") + media_type: Optional[str] = Field("all", + description="Type of media content: '电影' for films, '电视剧' for television series or anime series, 'all' for all types") + title: Optional[str] = Field(None, + description="Specific media title to search for (optional, if provided returns detailed info for that specific media)") + year: Optional[str] = Field(None, + description="Release year of the media (optional, helps narrow down search results)") class QueryMediaLibraryTool(MoviePilotTool): @@ -24,8 +27,8 @@ class QueryMediaLibraryTool(MoviePilotTool): description: str = "Query media library status and list all media files that have been successfully processed and added to the media server (Plex, Emby, Jellyfin). Shows library statistics and file details." args_schema: Type[BaseModel] = QueryMediaLibraryInput - async def _arun(self, media_type: Optional[str] = "all", - title: Optional[str] = None, year: Optional[str] = None, **kwargs) -> str: + async def run(self, media_type: Optional[str] = "all", + title: Optional[str] = None, year: Optional[str] = None, **kwargs) -> str: logger.info(f"执行工具: {self.name}, 参数: media_type={media_type}, title={title}") try: media_server_oper = MediaServerOper() diff --git a/app/agent/tools/impl/query_subscribes.py b/app/agent/tools/impl/query_subscribes.py index 0d853dd5..2e0bf266 100644 --- a/app/agent/tools/impl/query_subscribes.py +++ b/app/agent/tools/impl/query_subscribes.py @@ -5,16 +5,18 @@ from typing import Optional, Type from pydantic import BaseModel, Field +from app.agent.tools.base import MoviePilotTool from app.db.subscribe_oper import SubscribeOper from app.log import logger -from app.agent.tools.base import MoviePilotTool class QuerySubscribesInput(BaseModel): """查询订阅工具的输入参数模型""" explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context") - status: Optional[str] = Field("all", description="Filter subscriptions by status: 'R' for enabled subscriptions, 'P' for disabled ones, 'all' for all subscriptions") - media_type: Optional[str] = Field("all", description="Filter by media type: 'movie' for films, 'tv' for television series, 'all' for all types") + status: Optional[str] = Field("all", + description="Filter subscriptions by status: 'R' for enabled subscriptions, 'P' for disabled ones, 'all' for all subscriptions") + media_type: Optional[str] = Field("all", + description="Filter by media type: 'movie' for films, 'tv' for television series, 'all' for all types") class QuerySubscribesTool(MoviePilotTool): @@ -22,7 +24,7 @@ class QuerySubscribesTool(MoviePilotTool): description: str = "Query subscription status and list all user subscriptions. Shows active subscriptions, their download status, and configuration details." args_schema: Type[BaseModel] = QuerySubscribesInput - async def _arun(self, status: Optional[str] = "all", media_type: Optional[str] = "all", **kwargs) -> str: + async def run(self, status: Optional[str] = "all", media_type: Optional[str] = "all", **kwargs) -> str: logger.info(f"执行工具: {self.name}, 参数: status={status}, media_type={media_type}") try: subscribe_oper = SubscribeOper() diff --git a/app/agent/tools/impl/search_media.py b/app/agent/tools/impl/search_media.py index 0d666a10..0e5f2db6 100644 --- a/app/agent/tools/impl/search_media.py +++ b/app/agent/tools/impl/search_media.py @@ -16,8 +16,10 @@ class SearchMediaInput(BaseModel): explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context") title: str = Field(..., description="The title of the media to search for (e.g., 'The Matrix', 'Breaking Bad')") year: Optional[str] = Field(None, description="Release year of the media (optional, helps narrow down results)") - media_type: Optional[str] = Field(None, description="Type of media content: '电影' for films, '电视剧' for television series or anime series") - season: Optional[int] = Field(None, description="Season number for TV shows and anime (optional, only applicable for series)") + media_type: Optional[str] = Field(None, + description="Type of media content: '电影' for films, '电视剧' for television series or anime series") + season: Optional[int] = Field(None, + description="Season number for TV shows and anime (optional, only applicable for series)") class SearchMediaTool(MoviePilotTool): @@ -25,13 +27,13 @@ class SearchMediaTool(MoviePilotTool): description: str = "Search for media resources including movies, TV shows, anime, etc. Supports searching by title, year, type, and other criteria. Returns detailed media information from TMDB database." args_schema: Type[BaseModel] = SearchMediaInput - async def _arun(self, title: str, year: Optional[str] = None, - media_type: Optional[str] = None, season: Optional[int] = None, **kwargs) -> str: + async def run(self, title: str, year: Optional[str] = None, + media_type: Optional[str] = None, season: Optional[int] = None, **kwargs) -> str: logger.info( f"执行工具: {self.name}, 参数: title={title}, year={year}, media_type={media_type}, season={season}") # 发送工具执行说明 - self.send_tool_message(f"正在搜索媒体资源: {title}" + (f" ({year})" if year else ""), title="搜索中") + await self.send_tool_message(f"正在搜索媒体资源: {title}" + (f" ({year})" if year else ""), title="搜索中") try: media_chain = MediaChain() @@ -62,24 +64,24 @@ class SearchMediaTool(MoviePilotTool): if filtered_results: result_message = f"找到 {len(filtered_results)} 个相关媒体资源" - self.send_tool_message(result_message, title="搜索成功") + await self.send_tool_message(result_message, title="搜索成功") # 发送详细结果 for i, result in enumerate(filtered_results[:5]): # 只显示前5个结果 media_info = f"{i + 1}. {result.title} ({result.year}) - {result.type.value if result.type else '未知'}" - self.send_tool_message(media_info, title="搜索结果") + await self.send_tool_message(media_info, title="搜索结果") return json.dumps([r.to_dict() for r in filtered_results], ensure_ascii=False, indent=2) else: error_message = f"未找到符合条件的媒体资源: {title}" - self.send_tool_message(error_message, title="搜索完成") + await self.send_tool_message(error_message, title="搜索完成") return error_message else: error_message = f"未找到相关媒体资源: {title}" - self.send_tool_message(error_message, title="搜索完成") + await self.send_tool_message(error_message, title="搜索完成") return error_message except Exception as e: error_message = f"搜索媒体失败: {str(e)}" logger.error(f"搜索媒体失败: {e}", exc_info=True) - self.send_tool_message(error_message, title="搜索失败") + await self.send_tool_message(error_message, title="搜索失败") return error_message diff --git a/app/agent/tools/impl/search_torrents.py b/app/agent/tools/impl/search_torrents.py index 2d0d1589..1725a31f 100644 --- a/app/agent/tools/impl/search_torrents.py +++ b/app/agent/tools/impl/search_torrents.py @@ -14,11 +14,15 @@ from app.schemas.types import MediaType class SearchTorrentsInput(BaseModel): """搜索种子工具的输入参数模型""" explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context") - title: str = Field(..., description="The title of the media resource to search for (e.g., 'The Matrix 1999', 'Breaking Bad S01E01')") - year: Optional[str] = Field(None, description="Release year of the media (optional, helps narrow down search results)") - media_type: Optional[str] = Field(None, description="Type of media content: '电影' for films, '电视剧' for television series or anime series") + title: str = Field(..., + description="The title of the media resource to search for (e.g., 'The Matrix 1999', 'Breaking Bad S01E01')") + year: Optional[str] = Field(None, + description="Release year of the media (optional, helps narrow down search results)") + media_type: Optional[str] = Field(None, + description="Type of media content: '电影' for films, '电视剧' for television series or anime series") season: Optional[int] = Field(None, description="Season number for TV shows (optional, only applicable for series)") - sites: Optional[List[int]] = Field(None, description="Array of specific site IDs to search on (optional, if not provided searches all configured sites)") + sites: Optional[List[int]] = Field(None, + description="Array of specific site IDs to search on (optional, if not provided searches all configured sites)") class SearchTorrentsTool(MoviePilotTool): @@ -26,14 +30,14 @@ class SearchTorrentsTool(MoviePilotTool): description: str = "Search for torrent files across configured indexer sites based on media information. Returns available torrent downloads with details like file size, quality, and download links." args_schema: Type[BaseModel] = SearchTorrentsInput - async def _arun(self, title: str, year: Optional[str] = None, - media_type: Optional[str] = None, season: Optional[int] = None, - sites: Optional[List[int]] = None, **kwargs) -> str: + async def run(self, title: str, year: Optional[str] = None, + media_type: Optional[str] = None, season: Optional[int] = None, + sites: Optional[List[int]] = None, **kwargs) -> str: logger.info( f"执行工具: {self.name}, 参数: title={title}, year={year}, media_type={media_type}, season={season}, sites={sites}") # 发送工具执行说明 - self.send_tool_message(f"正在搜索种子资源: {title}" + (f" ({year})" if year else ""), title="搜索种子") + await self.send_tool_message(f"正在搜索种子资源: {title}" + (f" ({year})" if year else ""), title="搜索种子") try: search_chain = SearchChain() @@ -52,22 +56,22 @@ class SearchTorrentsTool(MoviePilotTool): if filtered_torrents: result_message = f"找到 {len(filtered_torrents)} 个相关种子资源" - self.send_tool_message(result_message, title="搜索成功") + await self.send_tool_message(result_message, title="搜索成功") # 发送详细结果 for i, torrent in enumerate(filtered_torrents[:5]): # 只显示前5个结果 torrent_title = torrent.torrent_info.title if torrent.torrent_info else torrent.meta_info.title if torrent.meta_info else "未知" site_name = torrent.torrent_info.site_name if torrent.torrent_info else "未知站点" torrent_info = f"{i + 1}. {torrent_title} - {site_name}" - self.send_tool_message(torrent_info, title="搜索结果") + await self.send_tool_message(torrent_info, title="搜索结果") return json.dumps([t.to_dict() for t in filtered_torrents], ensure_ascii=False, indent=2) else: error_message = f"未找到相关种子资源: {title}" - self.send_tool_message(error_message, title="搜索完成") + await self.send_tool_message(error_message, title="搜索完成") return error_message except Exception as e: error_message = f"搜索种子时发生错误: {str(e)}" logger.error(f"搜索种子失败: {e}", exc_info=True) - self.send_tool_message(error_message, title="搜索失败") + await self.send_tool_message(error_message, title="搜索失败") return error_message diff --git a/app/agent/tools/impl/send_message.py b/app/agent/tools/impl/send_message.py index 17371be7..39ef4cd9 100644 --- a/app/agent/tools/impl/send_message.py +++ b/app/agent/tools/impl/send_message.py @@ -12,7 +12,8 @@ class SendMessageInput(BaseModel): """发送消息工具的输入参数模型""" explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context") message: str = Field(..., description="The message content to send to the user (should be clear and informative)") - message_type: Optional[str] = Field("info", description="Type of message: 'info' for general information, 'success' for successful operations, 'warning' for warnings, 'error' for error messages") + message_type: Optional[str] = Field("info", + description="Type of message: 'info' for general information, 'success' for successful operations, 'warning' for warnings, 'error' for error messages") class SendMessageTool(MoviePilotTool): @@ -20,10 +21,10 @@ class SendMessageTool(MoviePilotTool): description: str = "Send notification message to the user through configured notification channels (Telegram, Slack, WeChat, etc.). Used to inform users about operation results, errors, or important updates." args_schema: Type[BaseModel] = SendMessageInput - async def _arun(self, message: str, message_type: Optional[str] = None, **kwargs) -> str: + async def run(self, message: str, message_type: Optional[str] = None, **kwargs) -> str: logger.info(f"执行工具: {self.name}, 参数: message={message}, message_type={message_type}") try: - self.send_tool_message(message, title=message_type) + await self.send_tool_message(message, title=message_type) return "消息已发送。" except Exception as e: logger.error(f"发送消息失败: {e}")