diff --git a/app/agent/tools/impl/add_subscribe.py b/app/agent/tools/impl/add_subscribe.py index fa471a10..775b7b79 100644 --- a/app/agent/tools/impl/add_subscribe.py +++ b/app/agent/tools/impl/add_subscribe.py @@ -21,6 +21,8 @@ class AddSubscribeInput(BaseModel): description="Season number for TV shows (optional, if not specified will subscribe to all seasons)") tmdb_id: Optional[int] = Field(None, description="TMDB database ID for precise media identification (optional, can be obtained from search_media tool)") + douban_id: Optional[str] = Field(None, + description="Douban ID for precise media identification (optional, alternative to tmdb_id)") start_episode: Optional[int] = Field(None, description="Starting episode number for TV shows (optional, defaults to 1 if not specified)") total_episode: Optional[int] = Field(None, @@ -61,13 +63,14 @@ class AddSubscribeTool(MoviePilotTool): async def run(self, title: str, year: str, media_type: str, season: Optional[int] = None, tmdb_id: Optional[int] = None, + douban_id: Optional[str] = None, start_episode: Optional[int] = None, total_episode: Optional[int] = None, quality: Optional[str] = None, resolution: Optional[str] = None, effect: Optional[str] = None, filter_groups: Optional[List[str]] = None, sites: Optional[List[int]] = None, **kwargs) -> str: logger.info( f"执行工具: {self.name}, 参数: title={title}, year={year}, media_type={media_type}, " - f"season={season}, tmdb_id={tmdb_id}, start_episode={start_episode}, " + f"season={season}, tmdb_id={tmdb_id}, douban_id={douban_id}, start_episode={start_episode}, " f"total_episode={total_episode}, quality={quality}, resolution={resolution}, " f"effect={effect}, filter_groups={filter_groups}, sites={sites}") @@ -99,6 +102,7 @@ class AddSubscribeTool(MoviePilotTool): title=title, year=year, tmdbid=tmdb_id, + doubanid=douban_id, season=season, username=self._user_id, **subscribe_kwargs diff --git a/app/agent/tools/impl/query_media_detail.py b/app/agent/tools/impl/query_media_detail.py index a318bba4..f1600aa9 100644 --- a/app/agent/tools/impl/query_media_detail.py +++ b/app/agent/tools/impl/query_media_detail.py @@ -14,22 +14,32 @@ from app.schemas.types import MediaType class QueryMediaDetailInput(BaseModel): """查询媒体详情工具的输入参数模型""" explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context") - tmdb_id: int = Field(..., description="TMDB ID of the media (movie or TV series, can be obtained from search_media tool)") + tmdb_id: Optional[int] = Field(None, description="TMDB ID of the media (movie or TV series, can be obtained from search_media tool)") + douban_id: Optional[str] = Field(None, description="Douban ID of the media (alternative to tmdb_id)") media_type: str = Field(..., description="Allowed values: movie, tv") class QueryMediaDetailTool(MoviePilotTool): name: str = "query_media_detail" - description: str = "Query supplementary media details from TMDB by ID and media_type. media_type accepts 'movie' or 'tv'. Returns non-duplicated detail fields such as status, genres, directors, actors, and season info for TV series." + description: str = "Query supplementary media details from TMDB by ID and media_type. Accepts tmdb_id or douban_id (at least one required). media_type accepts 'movie' or 'tv'. Returns non-duplicated detail fields such as status, genres, directors, actors, and season info for TV series." args_schema: Type[BaseModel] = QueryMediaDetailInput def get_tool_message(self, **kwargs) -> Optional[str]: """根据查询参数生成友好的提示消息""" tmdb_id = kwargs.get("tmdb_id") - return f"正在查询媒体详情: TMDB ID {tmdb_id}" + douban_id = kwargs.get("douban_id") + if tmdb_id: + return f"正在查询媒体详情: TMDB ID {tmdb_id}" + return f"正在查询媒体详情: 豆瓣 ID {douban_id}" - async def run(self, tmdb_id: int, media_type: str, **kwargs) -> str: - logger.info(f"执行工具: {self.name}, 参数: tmdb_id={tmdb_id}, media_type={media_type}") + async def run(self, media_type: str, tmdb_id: Optional[int] = None, douban_id: Optional[str] = None, **kwargs) -> str: + logger.info(f"执行工具: {self.name}, 参数: tmdb_id={tmdb_id}, douban_id={douban_id}, media_type={media_type}") + + if tmdb_id is None and douban_id is None: + return json.dumps({ + "success": False, + "message": "必须提供 tmdb_id 或 douban_id 之一" + }, ensure_ascii=False) try: media_chain = MediaChain() @@ -41,12 +51,13 @@ class QueryMediaDetailTool(MoviePilotTool): "message": f"无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv'" }, ensure_ascii=False) - mediainfo = await media_chain.async_recognize_media(tmdbid=tmdb_id, mtype=media_type_enum) - + mediainfo = await media_chain.async_recognize_media(tmdbid=tmdb_id, doubanid=douban_id, mtype=media_type_enum) + if not mediainfo: + id_info = f"TMDB ID {tmdb_id}" if tmdb_id else f"豆瓣 ID {douban_id}" return json.dumps({ "success": False, - "message": f"未找到 TMDB ID {tmdb_id} 的媒体信息" + "message": f"未找到 {id_info} 的媒体信息" }, ensure_ascii=False) # 精简 genres - 只保留名称 @@ -110,5 +121,6 @@ class QueryMediaDetailTool(MoviePilotTool): return json.dumps({ "success": False, "message": error_message, - "tmdb_id": tmdb_id + "tmdb_id": tmdb_id, + "douban_id": douban_id }, ensure_ascii=False) diff --git a/app/agent/tools/impl/query_subscribes.py b/app/agent/tools/impl/query_subscribes.py index 4018dbaf..ec594fe4 100644 --- a/app/agent/tools/impl/query_subscribes.py +++ b/app/agent/tools/impl/query_subscribes.py @@ -19,6 +19,7 @@ class QuerySubscribesInput(BaseModel): media_type: Optional[str] = Field("all", description="Allowed values: movie, tv, all") tmdb_id: Optional[int] = Field(None, description="Filter by TMDB ID to check if a specific media is already subscribed") + douban_id: Optional[str] = Field(None, description="Filter by Douban ID to check if a specific media is already subscribed") class QuerySubscribesTool(MoviePilotTool): @@ -45,8 +46,8 @@ class QuerySubscribesTool(MoviePilotTool): return " | ".join(parts) if len(parts) > 1 else parts[0] async def run(self, status: Optional[str] = "all", media_type: Optional[str] = "all", - tmdb_id: Optional[int] = None, **kwargs) -> str: - logger.info(f"执行工具: {self.name}, 参数: status={status}, media_type={media_type}, tmdb_id={tmdb_id}") + tmdb_id: Optional[int] = None, douban_id: Optional[str] = None, **kwargs) -> str: + logger.info(f"执行工具: {self.name}, 参数: status={status}, media_type={media_type}, tmdb_id={tmdb_id}, douban_id={douban_id}") try: if media_type != "all" and not MediaType.from_agent(media_type): return f"错误:无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv', 'all'" @@ -61,6 +62,8 @@ class QuerySubscribesTool(MoviePilotTool): continue if tmdb_id is not None and sub.tmdbid != tmdb_id: continue + if douban_id is not None and sub.doubanid != douban_id: + continue filtered_subscribes.append(sub) if filtered_subscribes: # 限制最多50条结果