diff --git a/README.md b/README.md index 104aa5bd..94a60b25 100644 --- a/README.md +++ b/README.md @@ -26,6 +26,11 @@ 官方Wiki:https://wiki.movie-pilot.org +### 为 AI Agent 添加 Skills +```shell +npx skills add https://github.com/jxxghp/MoviePilot +``` + ## 参与开发 API文档:https://api.movie-pilot.org diff --git a/app/agent/tools/factory.py b/app/agent/tools/factory.py index fa5d1685..5e84891a 100644 --- a/app/agent/tools/factory.py +++ b/app/agent/tools/factory.py @@ -27,6 +27,7 @@ from app.agent.tools.impl.scrape_metadata import ScrapeMetadataTool from app.agent.tools.impl.query_episode_schedule import QueryEpisodeScheduleTool from app.agent.tools.impl.query_media_detail import QueryMediaDetailTool from app.agent.tools.impl.search_torrents import SearchTorrentsTool +from app.agent.tools.impl.get_search_results import GetSearchResultsTool from app.agent.tools.impl.search_web import SearchWebTool from app.agent.tools.impl.send_message import SendMessageTool from app.agent.tools.impl.query_schedulers import QuerySchedulersTool @@ -70,6 +71,7 @@ class MoviePilotToolFactory: UpdateSubscribeTool, SearchSubscribeTool, SearchTorrentsTool, + GetSearchResultsTool, SearchWebTool, AddDownloadTool, QuerySubscribesTool, diff --git a/app/agent/tools/impl/_torrent_search_utils.py b/app/agent/tools/impl/_torrent_search_utils.py new file mode 100644 index 00000000..b5a8b3bd --- /dev/null +++ b/app/agent/tools/impl/_torrent_search_utils.py @@ -0,0 +1,176 @@ +"""种子搜索工具辅助函数""" + +import re +from typing import List, Optional + +from app.core.context import Context +from app.utils.crypto import HashUtils +from app.utils.string import StringUtils + +SEARCH_RESULT_CACHE_FILE = "__search_result__" +TORRENT_RESULT_LIMIT = 50 + + +def build_torrent_ref(context: Optional[Context]) -> str: + """生成用于下载校验的短引用""" + if not context or not context.torrent_info: + return "" + return HashUtils.sha1(context.torrent_info.enclosure or "")[:7] + + +def sort_season_options(options: List[str]) -> List[str]: + """按前端逻辑排序季集选项""" + if len(options) <= 1: + return options + + parsed_options = [] + for index, option in enumerate(options): + match = re.match(r"^S(\d+)(?:-S(\d+))?\s*(?:E(\d+)(?:-E(\d+))?)?$", option or "") + if not match: + parsed_options.append({ + "original": option, + "season_num": 0, + "episode_num": 0, + "max_episode_num": 0, + "is_whole_season": False, + "index": index, + }) + continue + + episode_num = int(match.group(3)) if match.group(3) else 0 + max_episode_num = int(match.group(4)) if match.group(4) else episode_num + parsed_options.append({ + "original": option, + "season_num": int(match.group(1)), + "episode_num": episode_num, + "max_episode_num": max_episode_num, + "is_whole_season": not match.group(3), + "index": index, + }) + + whole_seasons = [item for item in parsed_options if item["is_whole_season"]] + episodes = [item for item in parsed_options if not item["is_whole_season"]] + + whole_seasons.sort(key=lambda item: (-item["season_num"], item["index"])) + episodes.sort( + key=lambda item: ( + -item["season_num"], + -(item["max_episode_num"] or item["episode_num"]), + -item["episode_num"], + item["index"], + ) + ) + return [item["original"] for item in whole_seasons + episodes] + + +def append_option(options: List[str], value: Optional[str]) -> None: + """按前端逻辑收集去重后的筛选项""" + if value and value not in options: + options.append(value) + + +def build_filter_options(items: List[Context]) -> dict: + """从搜索结果中构建筛选项汇总""" + filter_options = { + "site": [], + "season": [], + "freeState": [], + "edition": [], + "resolution": [], + "videoCode": [], + "releaseGroup": [], + } + + for item in items: + torrent_info = item.torrent_info + meta_info = item.meta_info + append_option(filter_options["site"], getattr(torrent_info, "site_name", None)) + append_option(filter_options["season"], getattr(meta_info, "season_episode", None)) + append_option(filter_options["freeState"], getattr(torrent_info, "volume_factor", None)) + append_option(filter_options["edition"], getattr(meta_info, "edition", None)) + append_option(filter_options["resolution"], getattr(meta_info, "resource_pix", None)) + append_option(filter_options["videoCode"], getattr(meta_info, "video_encode", None)) + append_option(filter_options["releaseGroup"], getattr(meta_info, "resource_team", None)) + + filter_options["season"] = sort_season_options(filter_options["season"]) + return filter_options + + +def match_filter(filter_values: Optional[List[str]], value: Optional[str]) -> bool: + """匹配前端同款多选筛选规则""" + return not filter_values or bool(value and value in filter_values) + + +def filter_contexts(items: List[Context], + site: Optional[List[str]] = None, + season: Optional[List[str]] = None, + free_state: Optional[List[str]] = None, + video_code: Optional[List[str]] = None, + edition: Optional[List[str]] = None, + resolution: Optional[List[str]] = None, + release_group: Optional[List[str]] = None) -> List[Context]: + """按前端同款维度筛选结果""" + filtered_items = [] + for item in items: + torrent_info = item.torrent_info + meta_info = item.meta_info + if ( + match_filter(site, getattr(torrent_info, "site_name", None)) + and match_filter(free_state, getattr(torrent_info, "volume_factor", None)) + and match_filter(season, getattr(meta_info, "season_episode", None)) + and match_filter(release_group, getattr(meta_info, "resource_team", None)) + and match_filter(video_code, getattr(meta_info, "video_encode", None)) + and match_filter(resolution, getattr(meta_info, "resource_pix", None)) + and match_filter(edition, getattr(meta_info, "edition", None)) + ): + filtered_items.append(item) + return filtered_items + + +def simplify_search_result(context: Context, index: int) -> dict: + """精简单条搜索结果""" + simplified = {} + torrent_info = context.torrent_info + meta_info = context.meta_info + media_info = context.media_info + + if torrent_info: + simplified["torrent_info"] = { + "title": torrent_info.title, + "size": StringUtils.format_size(torrent_info.size), + "seeders": torrent_info.seeders, + "peers": torrent_info.peers, + "site_name": torrent_info.site_name, + "torrent_url": f"{build_torrent_ref(context)}:{index}", + "page_url": torrent_info.page_url, + "volume_factor": torrent_info.volume_factor, + "freedate_diff": torrent_info.freedate_diff, + "pubdate": torrent_info.pubdate, + } + + if media_info: + simplified["media_info"] = { + "title": media_info.title, + "en_title": media_info.en_title, + "year": media_info.year, + "type": media_info.type.value if media_info.type else None, + "season": media_info.season, + "tmdb_id": media_info.tmdb_id, + } + + if meta_info: + simplified["meta_info"] = { + "name": meta_info.name, + "cn_name": meta_info.cn_name, + "en_name": meta_info.en_name, + "year": meta_info.year, + "type": meta_info.type.value if meta_info.type else None, + "begin_season": meta_info.begin_season, + "season_episode": meta_info.season_episode, + "resource_team": meta_info.resource_team, + "video_encode": meta_info.video_encode, + "edition": meta_info.edition, + "resource_pix": meta_info.resource_pix, + } + + return simplified diff --git a/app/agent/tools/impl/add_download.py b/app/agent/tools/impl/add_download.py index baaf3bfa..8347c30d 100644 --- a/app/agent/tools/impl/add_download.py +++ b/app/agent/tools/impl/add_download.py @@ -1,27 +1,29 @@ """添加下载工具""" +import re from typing import Optional, Type from pydantic import BaseModel, Field from app.agent.tools.base import MoviePilotTool, ToolChain +from app.chain.search import SearchChain from app.chain.download import DownloadChain +from app.core.config import settings from app.core.context import Context from app.core.metainfo import MetaInfo from app.db.site_oper import SiteOper from app.log import logger from app.schemas import TorrentInfo +from app.utils.crypto import HashUtils class AddDownloadInput(BaseModel): """添加下载工具的输入参数模型""" explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context") - site_name: str = Field(..., description="Name of the torrent site/source (e.g., 'The Pirate Bay')") - torrent_title: str = Field(..., - description="The display name/title of the torrent (e.g., 'The.Matrix.1999.1080p.BluRay.x264')") - torrent_url: str = Field(..., description="Direct URL to the torrent file (.torrent) or magnet link") - torrent_description: Optional[str] = Field(None, - description="Brief description of the torrent content (optional)") + torrent_url: str = Field( + ..., + description="torrent_url in hash:id format (obtainable from get_search_results tool per-item results)" + ) downloader: Optional[str] = Field(None, description="Name of the downloader to use (optional, uses default if not specified)") save_path: Optional[str] = Field(None, @@ -32,32 +34,87 @@ class AddDownloadInput(BaseModel): class AddDownloadTool(MoviePilotTool): name: str = "add_download" - description: str = "Add torrent download task to the configured downloader (qBittorrent, Transmission, etc.). Downloads the torrent file and starts the download process with specified settings." + description: str = "Add torrent download task to the configured downloader (qBittorrent, Transmission, etc.) using torrent_url reference from get_search_results results." args_schema: Type[BaseModel] = AddDownloadInput def get_tool_message(self, **kwargs) -> Optional[str]: """根据下载参数生成友好的提示消息""" - torrent_title = kwargs.get("torrent_title", "") - site_name = kwargs.get("site_name", "") + torrent_url = kwargs.get("torrent_url") downloader = kwargs.get("downloader") - message = f"正在添加下载任务: {torrent_title}" - if site_name: - message += f" (来源: {site_name})" + message = f"正在添加下载任务: 资源 {torrent_url}" if downloader: message += f" [下载器: {downloader}]" return message - async def run(self, site_name: str, torrent_title: str, torrent_url: str, torrent_description: Optional[str] = None, + @staticmethod + def _build_torrent_ref(context: Context) -> str: + """生成用于校验缓存项的短引用""" + if not context or not context.torrent_info: + return "" + return HashUtils.sha1(context.torrent_info.enclosure or "")[:7] + + @staticmethod + def _is_torrent_ref(torrent_ref: Optional[str]) -> bool: + """判断是否为内部搜索结果引用""" + if not torrent_ref: + return False + return bool(re.fullmatch(r"[0-9a-f]{7}:\d+", str(torrent_ref).strip())) + + @classmethod + def _resolve_cached_context(cls, torrent_ref: str) -> Optional[Context]: + """从最近一次搜索缓存中解析种子上下文,仅支持 hash:id 格式""" + ref = str(torrent_ref).strip() + if ":" not in ref: + return None + try: + ref_hash, ref_index = ref.split(":", 1) + index = int(ref_index) + except (TypeError, ValueError): + return None + + if index < 1: + return None + + results = SearchChain().last_search_results() or [] + if index > len(results): + return None + context = results[index - 1] + if not ref_hash or cls._build_torrent_ref(context) != ref_hash: + return None + return context + + @staticmethod + def _merge_labels_with_system_tag(labels: Optional[str]) -> Optional[str]: + """合并用户标签与系统默认标签,确保任务可被系统管理""" + system_tag = (settings.TORRENT_TAG or "").strip() + user_labels = [item.strip() for item in (labels or "").split(",") if item.strip()] + + if system_tag and system_tag not in user_labels: + user_labels.append(system_tag) + + return ",".join(user_labels) if user_labels else None + + async def run(self, torrent_url: Optional[str] = None, downloader: Optional[str] = None, save_path: Optional[str] = None, labels: Optional[str] = None, **kwargs) -> str: logger.info( - f"执行工具: {self.name}, 参数: site_name={site_name}, torrent_title={torrent_title}, torrent_url={torrent_url}, downloader={downloader}, save_path={save_path}, labels={labels}") + f"执行工具: {self.name}, 参数: torrent_url={torrent_url}, downloader={downloader}, save_path={save_path}, labels={labels}") try: - if not torrent_title or not torrent_url: - return "错误:必须提供种子标题和下载链接" + if not torrent_url or not self._is_torrent_ref(torrent_url): + return "错误:torrent_url 必须是 get_search_results 返回的 hash:id 引用,请先使用 search_torrents 搜索,再通过 get_search_results 筛选后选择。" + + cached_context = self._resolve_cached_context(torrent_url) + if not cached_context or not cached_context.torrent_info: + return "错误:torrent_url 无效,请重新使用 search_torrents 搜索" + + cached_torrent = cached_context.torrent_info + site_name = cached_torrent.site_name + torrent_title = cached_torrent.title + torrent_description = cached_torrent.description + torrent_url = cached_torrent.enclosure # 使用DownloadChain添加下载 download_chain = DownloadChain() @@ -82,7 +139,9 @@ class AddDownloadTool(MoviePilotTool): site_downloader=siteinfo.downloader ) meta_info = MetaInfo(title=torrent_title, subtitle=torrent_description) - media_info = await ToolChain().async_recognize_media(meta=meta_info) + media_info = cached_context.media_info if cached_context and cached_context.media_info else None + if not media_info: + media_info = await ToolChain().async_recognize_media(meta=meta_info) if not media_info: return "错误:无法识别媒体信息,无法添加下载任务" context = Context( @@ -91,11 +150,13 @@ class AddDownloadTool(MoviePilotTool): media_info=media_info ) + merged_labels = self._merge_labels_with_system_tag(labels) + did = download_chain.download_single( context=context, downloader=downloader, save_path=save_path, - label=labels + label=merged_labels ) if did: return f"成功添加下载任务:{torrent_title}" diff --git a/app/agent/tools/impl/add_subscribe.py b/app/agent/tools/impl/add_subscribe.py index a2b26cc1..775b7b79 100644 --- a/app/agent/tools/impl/add_subscribe.py +++ b/app/agent/tools/impl/add_subscribe.py @@ -16,11 +16,13 @@ class AddSubscribeInput(BaseModel): title: str = Field(..., description="The title of the media to subscribe to (e.g., 'The Matrix', 'Breaking Bad')") year: str = Field(..., description="Release year of the media (required for accurate identification)") media_type: str = Field(..., - description="Type of media content: '电影' for films, '电视剧' for television series or anime series") + description="Allowed values: movie, tv") season: Optional[int] = Field(None, description="Season number for TV shows (optional, if not specified will subscribe to all seasons)") - tmdb_id: Optional[str] = Field(None, - description="TMDB database ID for precise media identification (optional but recommended for accuracy)") + tmdb_id: Optional[int] = Field(None, + description="TMDB database ID for precise media identification (optional, can be obtained from search_media tool)") + douban_id: Optional[str] = Field(None, + description="Douban ID for precise media identification (optional, alternative to tmdb_id)") start_episode: Optional[int] = Field(None, description="Starting episode number for TV shows (optional, defaults to 1 if not specified)") total_episode: Optional[int] = Field(None, @@ -32,9 +34,9 @@ class AddSubscribeInput(BaseModel): effect: Optional[str] = Field(None, description="Effect filter as regular expression (optional, e.g., 'HDR|DV|SDR')") filter_groups: Optional[List[str]] = Field(None, - description="List of filter rule group names to apply (optional, use query_rule_groups tool to get available rule groups)") + description="List of filter rule group names to apply (optional, can be obtained from query_rule_groups tool)") sites: Optional[List[int]] = Field(None, - description="List of site IDs to search from (optional, use query_sites tool to get available site IDs)") + description="List of site IDs to search from (optional, can be obtained from query_sites tool)") class AddSubscribeTool(MoviePilotTool): @@ -60,26 +62,23 @@ class AddSubscribeTool(MoviePilotTool): return message async def run(self, title: str, year: str, media_type: str, - season: Optional[int] = None, tmdb_id: Optional[str] = None, + season: Optional[int] = None, tmdb_id: Optional[int] = None, + douban_id: Optional[str] = None, start_episode: Optional[int] = None, total_episode: Optional[int] = None, quality: Optional[str] = None, resolution: Optional[str] = None, effect: Optional[str] = None, filter_groups: Optional[List[str]] = None, sites: Optional[List[int]] = None, **kwargs) -> str: logger.info( f"执行工具: {self.name}, 参数: title={title}, year={year}, media_type={media_type}, " - f"season={season}, tmdb_id={tmdb_id}, start_episode={start_episode}, " + f"season={season}, tmdb_id={tmdb_id}, douban_id={douban_id}, start_episode={start_episode}, " f"total_episode={total_episode}, quality={quality}, resolution={resolution}, " f"effect={effect}, filter_groups={filter_groups}, sites={sites}") try: subscribe_chain = SubscribeChain() - # 转换 tmdb_id 为整数 - tmdbid_int = None - if tmdb_id: - try: - tmdbid_int = int(tmdb_id) - except (ValueError, TypeError): - logger.warning(f"无效的 tmdb_id: {tmdb_id},将忽略") + media_type_enum = MediaType.from_agent(media_type) + if not media_type_enum: + return f"错误:无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv'" # 构建额外的订阅参数 subscribe_kwargs = {} @@ -99,10 +98,11 @@ class AddSubscribeTool(MoviePilotTool): subscribe_kwargs['sites'] = sites sid, message = await subscribe_chain.async_add( - mtype=MediaType(media_type), + mtype=media_type_enum, title=title, year=year, - tmdbid=tmdbid_int, + tmdbid=tmdb_id, + doubanid=douban_id, season=season, username=self._user_id, **subscribe_kwargs diff --git a/app/agent/tools/impl/delete_download.py b/app/agent/tools/impl/delete_download.py index 12952b16..9433d765 100644 --- a/app/agent/tools/impl/delete_download.py +++ b/app/agent/tools/impl/delete_download.py @@ -12,23 +12,23 @@ from app.log import logger class DeleteDownloadInput(BaseModel): """删除下载任务工具的输入参数模型""" explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context") - task_identifier: str = Field(..., description="Task identifier: can be task hash (unique identifier) or task title/name") + hash: str = Field(..., description="Task hash (can be obtained from query_download_tasks tool)") downloader: Optional[str] = Field(None, description="Name of specific downloader (optional, if not provided will search all downloaders)") delete_files: Optional[bool] = Field(False, description="Whether to delete downloaded files along with the task (default: False, only removes the task from downloader)") class DeleteDownloadTool(MoviePilotTool): name: str = "delete_download" - description: str = "Delete a download task from the downloader. Can delete by task hash (unique identifier) or task title/name. Optionally specify the downloader name and whether to delete downloaded files." + description: str = "Delete a download task from the downloader by task hash only. Optionally specify the downloader name and whether to delete downloaded files." args_schema: Type[BaseModel] = DeleteDownloadInput def get_tool_message(self, **kwargs) -> Optional[str]: """根据删除参数生成友好的提示消息""" - task_identifier = kwargs.get("task_identifier", "") + hash_value = kwargs.get("hash", "") downloader = kwargs.get("downloader") delete_files = kwargs.get("delete_files", False) - message = f"正在删除下载任务: {task_identifier}" + message = f"正在删除下载任务: {hash_value}" if downloader: message += f" [下载器: {downloader}]" if delete_files: @@ -36,40 +36,26 @@ class DeleteDownloadTool(MoviePilotTool): return message - async def run(self, task_identifier: str, downloader: Optional[str] = None, + async def run(self, hash: str, downloader: Optional[str] = None, delete_files: Optional[bool] = False, **kwargs) -> str: - logger.info(f"执行工具: {self.name}, 参数: task_identifier={task_identifier}, downloader={downloader}, delete_files={delete_files}") + logger.info(f"执行工具: {self.name}, 参数: hash={hash}, downloader={downloader}, delete_files={delete_files}") try: download_chain = DownloadChain() - - # 如果task_identifier看起来像hash(通常是40个字符的十六进制字符串) - task_hash = None - if len(task_identifier) == 40 and all(c in '0123456789abcdefABCDEF' for c in task_identifier): - # 直接使用hash - task_hash = task_identifier - else: - # 通过标题查找任务 - downloads = download_chain.downloading(name=downloader) - for dl in downloads: - # 检查标题或名称是否匹配 - if (task_identifier.lower() in (dl.title or "").lower()) or \ - (task_identifier.lower() in (dl.name or "").lower()): - task_hash = dl.hash - break - - if not task_hash: - return f"未找到匹配的下载任务:{task_identifier},请使用 query_downloads 工具查询可用的下载任务" + + # 仅支持通过hash删除任务 + if len(hash) != 40 or not all(c in '0123456789abcdefABCDEF' for c in hash): + return "参数错误:hash 格式无效,请先使用 query_download_tasks 工具获取正确的 hash。" # 删除下载任务 # remove_torrents 支持 delete_file 参数,可以控制是否删除文件 - result = download_chain.remove_torrents(hashs=[task_hash], downloader=downloader, delete_file=delete_files) + result = download_chain.remove_torrents(hashs=[hash], downloader=downloader, delete_file=delete_files) if result: files_info = "(包含文件)" if delete_files else "(不包含文件)" - return f"成功删除下载任务:{task_identifier} {files_info}" + return f"成功删除下载任务:{hash} {files_info}" else: - return f"删除下载任务失败:{task_identifier},请检查任务是否存在或下载器是否可用" + return f"删除下载任务失败:{hash},请检查任务是否存在或下载器是否可用" except Exception as e: logger.error(f"删除下载任务失败: {e}", exc_info=True) return f"删除下载任务时发生错误: {str(e)}" diff --git a/app/agent/tools/impl/get_recommendations.py b/app/agent/tools/impl/get_recommendations.py index 4d7cfa06..e0617b18 100644 --- a/app/agent/tools/impl/get_recommendations.py +++ b/app/agent/tools/impl/get_recommendations.py @@ -8,6 +8,7 @@ from pydantic import BaseModel, Field from app.agent.tools.base import MoviePilotTool from app.chain.recommend import RecommendChain from app.log import logger +from app.schemas.types import MediaType, media_type_to_agent class GetRecommendationsInput(BaseModel): @@ -30,7 +31,7 @@ class GetRecommendationsInput(BaseModel): "'douban_tv_animation' for Douban popular animation, " "'bangumi_calendar' for Bangumi anime calendar") media_type: Optional[str] = Field("all", - description="Type of media content: '电影' for films, '电视剧' for television series or anime series, 'all' for all types") + description="Allowed values: movie, tv, all") limit: Optional[int] = Field(20, description="Maximum number of recommendations to return (default: 20, maximum: 100)") @@ -75,6 +76,12 @@ class GetRecommendationsTool(MoviePilotTool): media_type: Optional[str] = "all", limit: Optional[int] = 20, **kwargs) -> str: logger.info(f"执行工具: {self.name}, 参数: source={source}, media_type={media_type}, limit={limit}") try: + if media_type != "all": + media_type_enum = MediaType.from_agent(media_type) + if not media_type_enum: + return f"错误:无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv', 'all'" + media_type = media_type_enum.to_agent() # 归一化为 "movie"/"tv" + recommend_chain = RecommendChain() results = [] if source == "tmdb_trending": @@ -149,7 +156,7 @@ class GetRecommendationsTool(MoviePilotTool): "title": r.get("title"), "en_title": r.get("en_title"), "year": r.get("year"), - "type": r.get("type"), + "type": media_type_to_agent(r.get("type")), "season": r.get("season"), "tmdb_id": r.get("tmdb_id"), "imdb_id": r.get("imdb_id"), diff --git a/app/agent/tools/impl/get_search_results.py b/app/agent/tools/impl/get_search_results.py new file mode 100644 index 00000000..ab740289 --- /dev/null +++ b/app/agent/tools/impl/get_search_results.py @@ -0,0 +1,108 @@ +"""获取搜索结果工具""" + +import json +import re +from typing import List, Optional, Type + +from pydantic import BaseModel, Field + +from app.agent.tools.base import MoviePilotTool +from app.chain.search import SearchChain +from app.log import logger +from ._torrent_search_utils import ( + TORRENT_RESULT_LIMIT, + build_filter_options, + filter_contexts, + simplify_search_result, +) + + +class GetSearchResultsInput(BaseModel): + """获取搜索结果工具的输入参数模型""" + explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context") + site: Optional[List[str]] = Field(None, description="Site name filters") + season: Optional[List[str]] = Field(None, description="Season or episode filters") + free_state: Optional[List[str]] = Field(None, description="Promotion state filters") + video_code: Optional[List[str]] = Field(None, description="Video codec filters") + edition: Optional[List[str]] = Field(None, description="Edition filters") + resolution: Optional[List[str]] = Field(None, description="Resolution filters") + release_group: Optional[List[str]] = Field(None, description="Release group filters") + title_pattern: Optional[str] = Field(None, description="Regular expression pattern to filter torrent titles (e.g., '4K|2160p|UHD', '1080p.*BluRay')") + show_filter_options: Optional[bool] = Field(False, description="Whether to return only optional filter options for re-checking available conditions") + +class GetSearchResultsTool(MoviePilotTool): + name: str = "get_search_results" + description: str = "Get cached torrent search results from search_torrents with optional filters. Returns at most the first 50 matches." + args_schema: Type[BaseModel] = GetSearchResultsInput + + def get_tool_message(self, **kwargs) -> Optional[str]: + return "正在获取搜索结果" + + async def run(self, site: Optional[List[str]] = None, season: Optional[List[str]] = None, + free_state: Optional[List[str]] = None, video_code: Optional[List[str]] = None, + edition: Optional[List[str]] = None, resolution: Optional[List[str]] = None, + release_group: Optional[List[str]] = None, title_pattern: Optional[str] = None, + show_filter_options: bool = False, + **kwargs) -> str: + logger.info( + f"执行工具: {self.name}, 参数: site={site}, season={season}, free_state={free_state}, video_code={video_code}, edition={edition}, resolution={resolution}, release_group={release_group}, title_pattern={title_pattern}, show_filter_options={show_filter_options}") + + try: + items = await SearchChain().async_last_search_results() or [] + if not items: + return "没有可用的搜索结果,请先使用 search_torrents 搜索" + + if show_filter_options: + payload = { + "total_count": len(items), + "filter_options": build_filter_options(items), + } + return json.dumps(payload, ensure_ascii=False, indent=2) + + regex_pattern = None + if title_pattern: + try: + regex_pattern = re.compile(title_pattern, re.IGNORECASE) + except re.error as e: + logger.warning(f"正则表达式编译失败: {title_pattern}, 错误: {e}") + return f"正则表达式格式错误: {str(e)}" + + filtered_items = filter_contexts( + items=items, + site=site, + season=season, + free_state=free_state, + video_code=video_code, + edition=edition, + resolution=resolution, + release_group=release_group, + ) + if regex_pattern: + filtered_items = [ + item for item in filtered_items + if item.torrent_info and item.torrent_info.title + and regex_pattern.search(item.torrent_info.title) + ] + if not filtered_items: + return "没有符合筛选条件的搜索结果,请调整筛选条件" + + total_count = len(filtered_items) + filtered_ids = {id(item) for item in filtered_items} + matched_indices = [index for index, item in enumerate(items, start=1) if id(item) in filtered_ids] + limited_items = filtered_items[:TORRENT_RESULT_LIMIT] + limited_indices = matched_indices[:TORRENT_RESULT_LIMIT] + results = [ + simplify_search_result(item, index) + for item, index in zip(limited_items, limited_indices) + ] + payload = { + "total_count": total_count, + "results": results, + } + if total_count > TORRENT_RESULT_LIMIT: + payload["message"] = f"搜索结果共找到 {total_count} 条,仅显示前 {TORRENT_RESULT_LIMIT} 条结果。" + return json.dumps(payload, ensure_ascii=False, indent=2) + except Exception as e: + error_message = f"获取搜索结果失败: {str(e)}" + logger.error(f"获取搜索结果失败: {e}", exc_info=True) + return error_message diff --git a/app/agent/tools/impl/list_directory.py b/app/agent/tools/impl/list_directory.py index 85f1a6fa..283315c7 100644 --- a/app/agent/tools/impl/list_directory.py +++ b/app/agent/tools/impl/list_directory.py @@ -24,7 +24,7 @@ class ListDirectoryInput(BaseModel): class ListDirectoryTool(MoviePilotTool): name: str = "list_directory" - description: str = "List actual files and folders in a file system directory (NOT configuration). Shows files and subdirectories with their names, types, sizes, and modification times. Returns up to 20 items and the total count if there are more items. Use 'query_directories' to query directory configuration settings." + description: str = "List actual files and folders in a file system directory (NOT configuration). Shows files and subdirectories with their names, types, sizes, and modification times. Returns up to 20 items and the total count if there are more items. Use 'query_directory_settings' to query directory configuration settings." args_schema: Type[BaseModel] = ListDirectoryInput def get_tool_message(self, **kwargs) -> Optional[str]: diff --git a/app/agent/tools/impl/query_download_tasks.py b/app/agent/tools/impl/query_download_tasks.py index c872b9eb..4d7c85e5 100644 --- a/app/agent/tools/impl/query_download_tasks.py +++ b/app/agent/tools/impl/query_download_tasks.py @@ -10,7 +10,7 @@ from app.chain.download import DownloadChain from app.db.downloadhistory_oper import DownloadHistoryOper from app.log import logger from app.schemas import TransferTorrent, DownloadingTorrent -from app.schemas.types import TorrentStatus +from app.schemas.types import TorrentStatus, media_type_to_agent class QueryDownloadTasksInput(BaseModel): @@ -208,7 +208,7 @@ class QueryDownloadTasksTool(MoviePilotTool): if d.media: simplified["media"] = { "tmdbid": d.media.get("tmdbid"), - "type": d.media.get("type"), + "type": media_type_to_agent(d.media.get("type")), "title": d.media.get("title"), "season": d.media.get("season"), "episode": d.media.get("episode") diff --git a/app/agent/tools/impl/query_episode_schedule.py b/app/agent/tools/impl/query_episode_schedule.py index 9c32ba7f..cb4ce4cf 100644 --- a/app/agent/tools/impl/query_episode_schedule.py +++ b/app/agent/tools/impl/query_episode_schedule.py @@ -6,23 +6,21 @@ from typing import Optional, Type from pydantic import BaseModel, Field from app.agent.tools.base import MoviePilotTool -from app.chain.media import MediaChain from app.chain.tmdb import TmdbChain from app.log import logger -from app.schemas import MediaType class QueryEpisodeScheduleInput(BaseModel): """查询剧集上映时间工具的输入参数模型""" explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context") - tmdb_id: int = Field(..., description="TMDB ID of the TV series") + tmdb_id: int = Field(..., description="TMDB ID of the TV series (can be obtained from search_media tool)") season: int = Field(..., description="Season number to query") episode_group: Optional[str] = Field(None, description="Episode group ID (optional)") class QueryEpisodeScheduleTool(MoviePilotTool): name: str = "query_episode_schedule" - description: str = "Query TV series episode air dates and schedule. Returns detailed information for each episode including air date, episode number, title, overview, and other metadata. Filters out episodes without air dates." + description: str = "Query TV series episode air dates and schedule. Returns non-duplicated schedule fields, including episode list, air-date statistics, and per-episode metadata. Filters out episodes without air dates." args_schema: Type[BaseModel] = QueryEpisodeScheduleInput def get_tool_message(self, **kwargs) -> Optional[str]: @@ -41,12 +39,6 @@ class QueryEpisodeScheduleTool(MoviePilotTool): logger.info(f"执行工具: {self.name}, 参数: tmdb_id={tmdb_id}, season={season}, episode_group={episode_group}") try: - # 获取媒体信息(用于获取标题和海报) - media_chain = MediaChain() - mediainfo = await media_chain.async_recognize_media(tmdbid=tmdb_id, mtype=MediaType.TV) - if not mediainfo: - return f"未找到 TMDB ID {tmdb_id} 的媒体信息" - # 获取集列表 tmdb_chain = TmdbChain() episodes = await tmdb_chain.async_tmdb_episodes( @@ -92,12 +84,7 @@ class QueryEpisodeScheduleTool(MoviePilotTool): episode_list.sort(key=lambda x: (x["air_date"] or "", x["episode_number"] or 0)) result = { - "success": True, - "tmdb_id": tmdb_id, "season": season, - "episode_group": episode_group, - "series_title": mediainfo.title if mediainfo else None, - "series_poster": mediainfo.poster_path if mediainfo else None, "total_episodes": len(episodes), "episodes_with_air_date": len(episode_list), "episodes": episode_list diff --git a/app/agent/tools/impl/query_library_exists.py b/app/agent/tools/impl/query_library_exists.py index 19a009d6..d14ade56 100644 --- a/app/agent/tools/impl/query_library_exists.py +++ b/app/agent/tools/impl/query_library_exists.py @@ -7,76 +7,61 @@ from pydantic import BaseModel, Field from app.agent.tools.base import MoviePilotTool from app.chain.mediaserver import MediaServerChain -from app.core.context import MediaInfo -from app.core.meta import MetaBase from app.log import logger -from app.schemas.types import MediaType +from app.schemas.types import MediaType, media_type_to_agent class QueryLibraryExistsInput(BaseModel): """查询媒体库工具的输入参数模型""" explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context") - media_type: Optional[str] = Field("all", - description="Type of media content: '电影' for films, '电视剧' for television series or anime series, 'all' for all types") - title: Optional[str] = Field(None, - description="Specific media title to check if it exists in the media library (optional, if provided checks for that specific media)") - year: Optional[str] = Field(None, - description="Release year of the media (optional, helps narrow down search results)") + tmdb_id: Optional[int] = Field(None, description="TMDB ID (can be obtained from search_media tool). Either tmdb_id or douban_id must be provided.") + douban_id: Optional[str] = Field(None, description="Douban ID (can be obtained from search_media tool). Either tmdb_id or douban_id must be provided.") + media_type: Optional[str] = Field(None, description="Allowed values: movie, tv") class QueryLibraryExistsTool(MoviePilotTool): name: str = "query_library_exists" - description: str = "Check if a specific media resource already exists in the media library (Plex, Emby, Jellyfin). Use this tool to verify whether a movie or TV series has been successfully processed and added to the media server before performing operations like downloading or subscribing." + description: str = "Check whether a specific media resource already exists in the media library (Plex, Emby, Jellyfin) by media ID. Requires tmdb_id or douban_id (can be obtained from search_media tool) for accurate matching." args_schema: Type[BaseModel] = QueryLibraryExistsInput def get_tool_message(self, **kwargs) -> Optional[str]: """根据查询参数生成友好的提示消息""" - media_type = kwargs.get("media_type", "all") - title = kwargs.get("title") - year = kwargs.get("year") - - parts = ["正在查询媒体库"] - - if title: - parts.append(f"标题: {title}") - if year: - parts.append(f"年份: {year}") - if media_type != "all": - parts.append(f"类型: {media_type}") - - return " | ".join(parts) if len(parts) > 1 else parts[0] + tmdb_id = kwargs.get("tmdb_id") + douban_id = kwargs.get("douban_id") + media_type = kwargs.get("media_type") - async def run(self, media_type: Optional[str] = "all", - title: Optional[str] = None, year: Optional[str] = None, **kwargs) -> str: - logger.info(f"执行工具: {self.name}, 参数: media_type={media_type}, title={title}") + if tmdb_id: + message = f"正在查询媒体库: TMDB={tmdb_id}" + elif douban_id: + message = f"正在查询媒体库: 豆瓣={douban_id}" + else: + message = "正在查询媒体库" + if media_type: + message += f" [{media_type}]" + return message + + async def run(self, tmdb_id: Optional[int] = None, douban_id: Optional[str] = None, + media_type: Optional[str] = None, **kwargs) -> str: + logger.info(f"执行工具: {self.name}, 参数: tmdb_id={tmdb_id}, douban_id={douban_id}, media_type={media_type}") try: - if not title: - return "请提供媒体标题进行查询" + if not tmdb_id and not douban_id: + return "参数错误:tmdb_id 和 douban_id 至少需要提供一个,请先使用 search_media 工具获取媒体 ID。" + + media_type_enum = None + if media_type: + media_type_enum = MediaType.from_agent(media_type) + if not media_type_enum: + return f"错误:无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv'" media_chain = MediaServerChain() - - # 1. 识别媒体信息(获取 TMDB ID 和各季的总集数等元数据) - meta = MetaBase(title=title) - if year: - meta.year = str(year) - if media_type == "电影": - meta.type = MediaType.MOVIE - elif media_type == "电视剧": - meta.type = MediaType.TV - - # 使用识别方法补充信息 - recognize_info = media_chain.recognize_media(meta=meta) - if recognize_info: - mediainfo = recognize_info - else: - # 识别失败,创建基本信息的 MediaInfo - mediainfo = MediaInfo() - mediainfo.title = title - mediainfo.year = year - if media_type == "电影": - mediainfo.type = MediaType.MOVIE - elif media_type == "电视剧": - mediainfo.type = MediaType.TV + mediainfo = media_chain.recognize_media( + tmdbid=tmdb_id, + doubanid=douban_id, + mtype=media_type_enum, + ) + if not mediainfo: + media_id = f"TMDB={tmdb_id}" if tmdb_id else f"豆瓣={douban_id}" + return f"未识别到媒体信息: {media_id}" # 2. 调用媒体服务器接口实时查询存在信息 existsinfo = media_chain.media_exists(mediainfo=mediainfo) @@ -120,7 +105,7 @@ class QueryLibraryExistsTool(MoviePilotTool): result_dict = { "title": mediainfo.title, "year": mediainfo.year, - "type": existsinfo.type.value if existsinfo.type else None, + "type": media_type_to_agent(existsinfo.type), "server": existsinfo.server, "server_type": existsinfo.server_type, "itemid": existsinfo.itemid, diff --git a/app/agent/tools/impl/query_media_detail.py b/app/agent/tools/impl/query_media_detail.py index 84c638b5..f1600aa9 100644 --- a/app/agent/tools/impl/query_media_detail.py +++ b/app/agent/tools/impl/query_media_detail.py @@ -8,45 +8,56 @@ from pydantic import BaseModel, Field from app.agent.tools.base import MoviePilotTool from app.chain.media import MediaChain from app.log import logger -from app.schemas import MediaType +from app.schemas.types import MediaType class QueryMediaDetailInput(BaseModel): """查询媒体详情工具的输入参数模型""" explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context") - tmdb_id: int = Field(..., description="TMDB ID of the media (movie or TV series)") - media_type: str = Field(..., description="Media type: 'movie' or 'tv'") + tmdb_id: Optional[int] = Field(None, description="TMDB ID of the media (movie or TV series, can be obtained from search_media tool)") + douban_id: Optional[str] = Field(None, description="Douban ID of the media (alternative to tmdb_id)") + media_type: str = Field(..., description="Allowed values: movie, tv") class QueryMediaDetailTool(MoviePilotTool): name: str = "query_media_detail" - description: str = "Query detailed media information from TMDB by ID and media_type. IMPORTANT: Convert search results type: '电影'→'movie', '电视剧'→'tv'. Returns core metadata including title, year, overview, status, genres, directors, actors, and season count for TV series." + description: str = "Query supplementary media details from TMDB by ID and media_type. Accepts tmdb_id or douban_id (at least one required). media_type accepts 'movie' or 'tv'. Returns non-duplicated detail fields such as status, genres, directors, actors, and season info for TV series." args_schema: Type[BaseModel] = QueryMediaDetailInput def get_tool_message(self, **kwargs) -> Optional[str]: """根据查询参数生成友好的提示消息""" tmdb_id = kwargs.get("tmdb_id") - return f"正在查询媒体详情: TMDB ID {tmdb_id}" + douban_id = kwargs.get("douban_id") + if tmdb_id: + return f"正在查询媒体详情: TMDB ID {tmdb_id}" + return f"正在查询媒体详情: 豆瓣 ID {douban_id}" - async def run(self, tmdb_id: int, media_type: str, **kwargs) -> str: - logger.info(f"执行工具: {self.name}, 参数: tmdb_id={tmdb_id}, media_type={media_type}") + async def run(self, media_type: str, tmdb_id: Optional[int] = None, douban_id: Optional[str] = None, **kwargs) -> str: + logger.info(f"执行工具: {self.name}, 参数: tmdb_id={tmdb_id}, douban_id={douban_id}, media_type={media_type}") + + if tmdb_id is None and douban_id is None: + return json.dumps({ + "success": False, + "message": "必须提供 tmdb_id 或 douban_id 之一" + }, ensure_ascii=False) try: media_chain = MediaChain() - mtype = None - if media_type: - if media_type.lower() == 'movie': - mtype = MediaType.MOVIE - elif media_type.lower() == 'tv': - mtype = MediaType.TV - - mediainfo = await media_chain.async_recognize_media(tmdbid=tmdb_id, mtype=mtype) - - if not mediainfo: + media_type_enum = MediaType.from_agent(media_type) + if not media_type_enum: return json.dumps({ "success": False, - "message": f"未找到 TMDB ID {tmdb_id} 的媒体信息" + "message": f"无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv'" + }, ensure_ascii=False) + + mediainfo = await media_chain.async_recognize_media(tmdbid=tmdb_id, doubanid=douban_id, mtype=media_type_enum) + + if not mediainfo: + id_info = f"TMDB ID {tmdb_id}" if tmdb_id else f"豆瓣 ID {douban_id}" + return json.dumps({ + "success": False, + "message": f"未找到 {id_info} 的媒体信息" }, ensure_ascii=False) # 精简 genres - 只保留名称 @@ -74,12 +85,6 @@ class QueryMediaDetailTool(MoviePilotTool): # 构建基础媒体详情信息 result = { - "success": True, - "tmdb_id": tmdb_id, - "type": mediainfo.type.value if mediainfo.type else None, - "title": mediainfo.title, - "year": mediainfo.year, - "overview": mediainfo.overview, "status": mediainfo.status, "genres": genres, "directors": directors, @@ -116,5 +121,6 @@ class QueryMediaDetailTool(MoviePilotTool): return json.dumps({ "success": False, "message": error_message, - "tmdb_id": tmdb_id + "tmdb_id": tmdb_id, + "douban_id": douban_id }, ensure_ascii=False) diff --git a/app/agent/tools/impl/query_popular_subscribes.py b/app/agent/tools/impl/query_popular_subscribes.py index a7b139f5..5243aabc 100644 --- a/app/agent/tools/impl/query_popular_subscribes.py +++ b/app/agent/tools/impl/query_popular_subscribes.py @@ -10,13 +10,13 @@ from app.agent.tools.base import MoviePilotTool from app.core.context import MediaInfo from app.helper.subscribe import SubscribeHelper from app.log import logger -from app.schemas.types import MediaType +from app.schemas.types import MediaType, media_type_to_agent class QueryPopularSubscribesInput(BaseModel): """查询热门订阅工具的输入参数模型""" explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context") - stype: str = Field(..., description="Media type: '电影' for films, '电视剧' for television series") + media_type: str = Field(..., description="Allowed values: movie, tv") page: Optional[int] = Field(1, description="Page number for pagination (default: 1)") count: Optional[int] = Field(30, description="Number of items per page (default: 30)") min_sub: Optional[int] = Field(None, description="Minimum number of subscribers filter (optional, e.g., 5)") @@ -33,13 +33,13 @@ class QueryPopularSubscribesTool(MoviePilotTool): def get_tool_message(self, **kwargs) -> Optional[str]: """根据查询参数生成友好的提示消息""" - stype = kwargs.get("stype", "") + media_type = kwargs.get("media_type", "") page = kwargs.get("page", 1) min_sub = kwargs.get("min_sub") min_rating = kwargs.get("min_rating") max_rating = kwargs.get("max_rating") - parts = [f"正在查询热门订阅 [{stype}]"] + parts = [f"正在查询热门订阅 [{media_type}]"] if min_sub: parts.append(f"最少订阅: {min_sub}") @@ -52,7 +52,7 @@ class QueryPopularSubscribesTool(MoviePilotTool): return " | ".join(parts) if len(parts) > 1 else parts[0] - async def run(self, stype: str, + async def run(self, media_type: str, page: Optional[int] = 1, count: Optional[int] = 30, min_sub: Optional[int] = None, @@ -61,7 +61,7 @@ class QueryPopularSubscribesTool(MoviePilotTool): max_rating: Optional[float] = None, sort_type: Optional[str] = None, **kwargs) -> str: logger.info( - f"执行工具: {self.name}, 参数: stype={stype}, page={page}, count={count}, min_sub={min_sub}, " + f"执行工具: {self.name}, 参数: media_type={media_type}, page={page}, count={count}, min_sub={min_sub}, " f"genre_id={genre_id}, min_rating={min_rating}, max_rating={max_rating}, sort_type={sort_type}") try: @@ -69,10 +69,13 @@ class QueryPopularSubscribesTool(MoviePilotTool): page = 1 if count is None or count < 1: count = 30 + media_type_enum = MediaType.from_agent(media_type) + if not media_type_enum: + return f"错误:无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv'" subscribe_helper = SubscribeHelper() subscribes = await subscribe_helper.async_get_statistic( - stype=stype, + stype=media_type_enum.to_agent(), page=page, count=count, genre_id=genre_id, @@ -94,7 +97,15 @@ class QueryPopularSubscribesTool(MoviePilotTool): continue media = MediaInfo() - media.type = MediaType(sub.get("type")) + raw_type = str(sub.get("type") or "").strip().lower() + if raw_type in ["movie", "电影"]: + media.type = MediaType.MOVIE + elif raw_type in ["tv", "电视剧"]: + media.type = MediaType.TV + else: + # 跳过无法识别类型的数据,避免单条脏数据导致整批失败 + logger.warning(f"跳过未知媒体类型: {sub.get('type')}") + continue media.tmdb_id = sub.get("tmdbid") # 处理标题 title = sub.get("name") @@ -124,7 +135,7 @@ class QueryPopularSubscribesTool(MoviePilotTool): for media in ret_medias: media_dict = media.to_dict() simplified = { - "type": media_dict.get("type"), + "type": media_type_to_agent(media_dict.get("type")), "title": media_dict.get("title"), "year": media_dict.get("year"), "tmdb_id": media_dict.get("tmdb_id"), diff --git a/app/agent/tools/impl/query_site_userdata.py b/app/agent/tools/impl/query_site_userdata.py index 4fd0e395..e4ea4aa5 100644 --- a/app/agent/tools/impl/query_site_userdata.py +++ b/app/agent/tools/impl/query_site_userdata.py @@ -15,7 +15,7 @@ from app.log import logger class QuerySiteUserdataInput(BaseModel): """查询站点用户数据工具的输入参数模型""" explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context") - site_id: int = Field(..., description="The ID of the site to query user data for") + site_id: int = Field(..., description="The ID of the site to query user data for (can be obtained from query_sites tool)") workdate: Optional[str] = Field(None, description="Work date to query (optional, format: 'YYYY-MM-DD', if not specified returns latest data)") diff --git a/app/agent/tools/impl/query_subscribe_history.py b/app/agent/tools/impl/query_subscribe_history.py index b4fa1a00..f0cc51f1 100644 --- a/app/agent/tools/impl/query_subscribe_history.py +++ b/app/agent/tools/impl/query_subscribe_history.py @@ -9,12 +9,13 @@ from app.agent.tools.base import MoviePilotTool from app.db import AsyncSessionFactory from app.db.models.subscribehistory import SubscribeHistory from app.log import logger +from app.schemas.types import media_type_to_agent class QuerySubscribeHistoryInput(BaseModel): """查询订阅历史工具的输入参数模型""" explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context") - media_type: Optional[str] = Field("all", description="Filter by media type: '电影' for films, '电视剧' for television series, 'all' for all types (default: 'all')") + media_type: Optional[str] = Field("all", description="Allowed values: movie, tv, all") name: Optional[str] = Field(None, description="Filter by media name (partial match, optional)") @@ -42,6 +43,9 @@ class QuerySubscribeHistoryTool(MoviePilotTool): logger.info(f"执行工具: {self.name}, 参数: media_type={media_type}, name={name}") try: + if media_type not in ["all", "movie", "tv"]: + return f"错误:无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv', 'all'" + # 获取数据库会话 async with AsyncSessionFactory() as db: # 根据类型查询 @@ -80,7 +84,7 @@ class QuerySubscribeHistoryTool(MoviePilotTool): "id": record.id, "name": record.name, "year": record.year, - "type": record.type, + "type": media_type_to_agent(record.type), "season": record.season, "tmdbid": record.tmdbid, "doubanid": record.doubanid, diff --git a/app/agent/tools/impl/query_subscribes.py b/app/agent/tools/impl/query_subscribes.py index ee383a8a..ec594fe4 100644 --- a/app/agent/tools/impl/query_subscribes.py +++ b/app/agent/tools/impl/query_subscribes.py @@ -8,6 +8,7 @@ from pydantic import BaseModel, Field from app.agent.tools.base import MoviePilotTool from app.db.subscribe_oper import SubscribeOper from app.log import logger +from app.schemas.types import MediaType, media_type_to_agent class QuerySubscribesInput(BaseModel): @@ -16,7 +17,9 @@ class QuerySubscribesInput(BaseModel): status: Optional[str] = Field("all", description="Filter subscriptions by status: 'R' for enabled subscriptions, 'S' for paused ones, 'all' for all subscriptions") media_type: Optional[str] = Field("all", - description="Filter by media type: '电影' for films, '电视剧' for television series, 'all' for all types") + description="Allowed values: movie, tv, all") + tmdb_id: Optional[int] = Field(None, description="Filter by TMDB ID to check if a specific media is already subscribed") + douban_id: Optional[str] = Field(None, description="Filter by Douban ID to check if a specific media is already subscribed") class QuerySubscribesTool(MoviePilotTool): @@ -42,16 +45,24 @@ class QuerySubscribesTool(MoviePilotTool): return " | ".join(parts) if len(parts) > 1 else parts[0] - async def run(self, status: Optional[str] = "all", media_type: Optional[str] = "all", **kwargs) -> str: - logger.info(f"执行工具: {self.name}, 参数: status={status}, media_type={media_type}") + async def run(self, status: Optional[str] = "all", media_type: Optional[str] = "all", + tmdb_id: Optional[int] = None, douban_id: Optional[str] = None, **kwargs) -> str: + logger.info(f"执行工具: {self.name}, 参数: status={status}, media_type={media_type}, tmdb_id={tmdb_id}, douban_id={douban_id}") try: + if media_type != "all" and not MediaType.from_agent(media_type): + return f"错误:无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv', 'all'" + subscribe_oper = SubscribeOper() subscribes = await subscribe_oper.async_list() filtered_subscribes = [] for sub in subscribes: if status != "all" and sub.state != status: continue - if media_type != "all" and sub.type != media_type: + if media_type != "all" and sub.type != MediaType.from_agent(media_type).value: + continue + if tmdb_id is not None and sub.tmdbid != tmdb_id: + continue + if douban_id is not None and sub.doubanid != douban_id: continue filtered_subscribes.append(sub) if filtered_subscribes: @@ -65,7 +76,7 @@ class QuerySubscribesTool(MoviePilotTool): "id": s.id, "name": s.name, "year": s.year, - "type": s.type, + "type": media_type_to_agent(s.type), "season": s.season, "tmdbid": s.tmdbid, "doubanid": s.doubanid, diff --git a/app/agent/tools/impl/query_transfer_history.py b/app/agent/tools/impl/query_transfer_history.py index 158031fb..b228d01b 100644 --- a/app/agent/tools/impl/query_transfer_history.py +++ b/app/agent/tools/impl/query_transfer_history.py @@ -10,6 +10,7 @@ from app.agent.tools.base import MoviePilotTool from app.db import AsyncSessionFactory from app.db.models.transferhistory import TransferHistory from app.log import logger +from app.schemas.types import media_type_to_agent class QueryTransferHistoryInput(BaseModel): @@ -95,7 +96,7 @@ class QueryTransferHistoryTool(MoviePilotTool): "id": record.id, "title": record.title, "year": record.year, - "type": record.type, + "type": media_type_to_agent(record.type), "category": record.category, "seasons": record.seasons, "episodes": record.episodes, diff --git a/app/agent/tools/impl/recognize_media.py b/app/agent/tools/impl/recognize_media.py index 82dea564..2f852770 100644 --- a/app/agent/tools/impl/recognize_media.py +++ b/app/agent/tools/impl/recognize_media.py @@ -10,6 +10,7 @@ from app.chain.media import MediaChain from app.core.context import Context from app.core.metainfo import MetaInfo from app.log import logger +from app.schemas.types import media_type_to_agent class RecognizeMediaInput(BaseModel): @@ -124,7 +125,7 @@ class RecognizeMediaTool(MoviePilotTool): "title": media_info.get("title"), "en_title": media_info.get("en_title"), "year": media_info.get("year"), - "type": media_info.get("type"), + "type": media_type_to_agent(media_info.get("type")), "season": media_info.get("season"), "tmdb_id": media_info.get("tmdb_id"), "imdb_id": media_info.get("imdb_id"), @@ -145,7 +146,7 @@ class RecognizeMediaTool(MoviePilotTool): "name": meta_info.get("name"), "title": meta_info.get("title"), "year": meta_info.get("year"), - "type": meta_info.get("type"), + "type": media_type_to_agent(meta_info.get("type")), "begin_season": meta_info.get("begin_season"), "end_season": meta_info.get("end_season"), "begin_episode": meta_info.get("begin_episode"), diff --git a/app/agent/tools/impl/run_workflow.py b/app/agent/tools/impl/run_workflow.py index 98692237..8e20f2bf 100644 --- a/app/agent/tools/impl/run_workflow.py +++ b/app/agent/tools/impl/run_workflow.py @@ -14,21 +14,21 @@ from app.log import logger class RunWorkflowInput(BaseModel): """执行工作流工具的输入参数模型""" explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context") - workflow_identifier: str = Field(..., description="Workflow identifier: can be workflow ID (integer as string) or workflow name") + workflow_id: int = Field(..., description="Workflow ID (can be obtained from query_workflows tool)") from_begin: Optional[bool] = Field(True, description="Whether to run workflow from the beginning (default: True, if False will continue from last executed action)") class RunWorkflowTool(MoviePilotTool): name: str = "run_workflow" - description: str = "Execute a specific workflow manually. Can run workflow by ID or name. Supports running from the beginning or continuing from the last executed action." + description: str = "Execute a specific workflow manually by workflow ID. Supports running from the beginning or continuing from the last executed action." args_schema: Type[BaseModel] = RunWorkflowInput def get_tool_message(self, **kwargs) -> Optional[str]: """根据工作流参数生成友好的提示消息""" - workflow_identifier = kwargs.get("workflow_identifier", "") + workflow_id = kwargs.get("workflow_id") from_begin = kwargs.get("from_begin", True) - message = f"正在执行工作流: {workflow_identifier}" + message = f"正在执行工作流: {workflow_id}" if not from_begin: message += " (从上次位置继续)" else: @@ -36,27 +36,18 @@ class RunWorkflowTool(MoviePilotTool): return message - async def run(self, workflow_identifier: str, + async def run(self, workflow_id: int, from_begin: Optional[bool] = True, **kwargs) -> str: - logger.info(f"执行工具: {self.name}, 参数: workflow_identifier={workflow_identifier}, from_begin={from_begin}") + logger.info(f"执行工具: {self.name}, 参数: workflow_id={workflow_id}, from_begin={from_begin}") try: # 获取数据库会话 async with AsyncSessionFactory() as db: workflow_oper = WorkflowOper(db) - - # 尝试解析为工作流ID - workflow = None - if workflow_identifier.isdigit(): - # 如果是数字,尝试作为工作流ID查询 - workflow = await workflow_oper.async_get(int(workflow_identifier)) - - # 如果不是ID或ID查询失败,尝试按名称查询 - if not workflow: - workflow = await workflow_oper.async_get_by_name(workflow_identifier) + workflow = await workflow_oper.async_get(workflow_id) if not workflow: - return f"未找到工作流:{workflow_identifier},请使用 query_workflows 工具查询可用的工作流" + return f"未找到工作流:{workflow_id},请使用 query_workflows 工具查询可用的工作流" # 执行工作流 workflow_chain = WorkflowChain() diff --git a/app/agent/tools/impl/search_media.py b/app/agent/tools/impl/search_media.py index b1abf57a..4b15c5bb 100644 --- a/app/agent/tools/impl/search_media.py +++ b/app/agent/tools/impl/search_media.py @@ -8,7 +8,7 @@ from pydantic import BaseModel, Field from app.agent.tools.base import MoviePilotTool from app.chain.media import MediaChain from app.log import logger -from app.schemas.types import MediaType +from app.schemas.types import MediaType, media_type_to_agent class SearchMediaInput(BaseModel): @@ -17,7 +17,7 @@ class SearchMediaInput(BaseModel): title: str = Field(..., description="The title of the media to search for (e.g., 'The Matrix', 'Breaking Bad')") year: Optional[str] = Field(None, description="Release year of the media (optional, helps narrow down results)") media_type: Optional[str] = Field(None, - description="Type of media content: '电影' for films, '电视剧' for television series or anime series") + description="Allowed values: movie, tv") season: Optional[int] = Field(None, description="Season number for TV shows and anime (optional, only applicable for series)") @@ -56,13 +56,18 @@ class SearchMediaTool(MoviePilotTool): # 过滤结果 if results: + media_type_enum = None + if media_type: + media_type_enum = MediaType.from_agent(media_type) + if not media_type_enum: + return f"错误:无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv'" + filtered_results = [] for result in results: if year and result.year != year: continue - if media_type: - if result.type != MediaType(media_type): - continue + if media_type_enum and result.type != media_type_enum: + continue if season is not None and result.season != season: continue filtered_results.append(result) @@ -78,7 +83,7 @@ class SearchMediaTool(MoviePilotTool): "title": r.title, "en_title": r.en_title, "year": r.year, - "type": r.type.value if r.type else None, + "type": media_type_to_agent(r.type), "season": r.season, "tmdb_id": r.tmdb_id, "imdb_id": r.imdb_id, diff --git a/app/agent/tools/impl/search_subscribe.py b/app/agent/tools/impl/search_subscribe.py index ebde39cb..b53783c4 100644 --- a/app/agent/tools/impl/search_subscribe.py +++ b/app/agent/tools/impl/search_subscribe.py @@ -10,15 +10,16 @@ from app.chain.subscribe import SubscribeChain from app.core.config import global_vars from app.db.subscribe_oper import SubscribeOper from app.log import logger +from app.schemas.types import media_type_to_agent class SearchSubscribeInput(BaseModel): """搜索订阅缺失剧集工具的输入参数模型""" explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context") - subscribe_id: int = Field(..., description="The ID of the subscription to search for missing episodes") + subscribe_id: int = Field(..., description="The ID of the subscription to search for missing episodes (can be obtained from query_subscribes tool)") manual: Optional[bool] = Field(False, description="Whether this is a manual search (default: False)") filter_groups: Optional[List[str]] = Field(None, - description="List of filter rule group names to apply for this search (optional, use query_rule_groups tool to get available rule groups. If provided, will temporarily update the subscription's filter groups before searching)") + description="List of filter rule group names to apply for this search (optional, can be obtained from query_rule_groups tool. If provided, will temporarily update the subscription's filter groups before searching)") class SearchSubscribeTool(MoviePilotTool): @@ -58,7 +59,7 @@ class SearchSubscribeTool(MoviePilotTool): "id": subscribe.id, "name": subscribe.name, "year": subscribe.year, - "type": subscribe.type, + "type": media_type_to_agent(subscribe.type), "season": subscribe.season, "state": subscribe.state, "total_episode": subscribe.total_episode, diff --git a/app/agent/tools/impl/search_torrents.py b/app/agent/tools/impl/search_torrents.py index efe68f01..fb71fca7 100644 --- a/app/agent/tools/impl/search_torrents.py +++ b/app/agent/tools/impl/search_torrents.py @@ -1,164 +1,109 @@ """搜索种子工具""" import json -import re from typing import List, Optional, Type -from pydantic import BaseModel, Field, field_validator +from pydantic import BaseModel, Field from app.agent.tools.base import MoviePilotTool from app.chain.search import SearchChain +from app.db.systemconfig_oper import SystemConfigOper +from app.helper.sites import SitesHelper from app.log import logger -from app.schemas.types import MediaType -from app.utils.string import StringUtils +from app.schemas.types import MediaType, SystemConfigKey +from ._torrent_search_utils import ( + SEARCH_RESULT_CACHE_FILE, + build_filter_options, +) class SearchTorrentsInput(BaseModel): """搜索种子工具的输入参数模型""" explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context") - title: str = Field(..., - description="The title of the media resource to search for (e.g., 'The Matrix 1999', 'Breaking Bad S01E01')") - year: Optional[str] = Field(None, - description="Release year of the media (optional, helps narrow down search results)") - media_type: Optional[str] = Field(None, - description="Type of media content: '电影' for films, '电视剧' for television series or anime series") - season: Optional[int] = Field(None, description="Season number for TV shows (optional, only applicable for series)") + tmdb_id: Optional[int] = Field(None, description="TMDB ID (can be obtained from search_media tool). Either tmdb_id or douban_id must be provided.") + douban_id: Optional[str] = Field(None, description="Douban ID (can be obtained from search_media tool). Either tmdb_id or douban_id must be provided.") + media_type: Optional[str] = Field(None, description="Allowed values: movie, tv") + area: Optional[str] = Field(None, description="Search scope: 'title' (default) or 'imdbid'") sites: Optional[List[int]] = Field(None, description="Array of specific site IDs to search on (optional, if not provided searches all configured sites)") - filter_pattern: Optional[str] = Field(None, - description="Regular expression pattern to filter torrent titles by resolution, quality, or other keywords (e.g., '4K|2160p|UHD' for 4K content, '1080p|BluRay' for 1080p BluRay)") - - @field_validator("sites", mode="before") - @classmethod - def normalize_sites(cls, value): - """兼容字符串格式的站点列表(如 "[28]"、"28,30")""" - if value is None: - return value - if isinstance(value, str): - value = value.strip() - if not value: - return None - try: - parsed = json.loads(value) - if isinstance(parsed, list): - return parsed - except Exception: - pass - if "," in value: - return [v.strip() for v in value.split(",") if v.strip()] - if value.isdigit(): - return [value] - return value - class SearchTorrentsTool(MoviePilotTool): name: str = "search_torrents" - description: str = "Search for torrent files across configured indexer sites based on media information. Returns available torrent downloads with details like file size, quality, and download links." + description: str = ("Search for torrent files by media ID across configured indexer sites, cache the matched results, " + "and return available filter options for follow-up selection. " + "Requires tmdb_id or douban_id (can be obtained from search_media tool) for accurate matching.") args_schema: Type[BaseModel] = SearchTorrentsInput def get_tool_message(self, **kwargs) -> Optional[str]: """根据搜索参数生成友好的提示消息""" - title = kwargs.get("title", "") - year = kwargs.get("year") + tmdb_id = kwargs.get("tmdb_id") + douban_id = kwargs.get("douban_id") media_type = kwargs.get("media_type") - season = kwargs.get("season") - filter_pattern = kwargs.get("filter_pattern") - - message = f"正在搜索种子: {title}" - if year: - message += f" ({year})" + + if tmdb_id: + message = f"正在搜索种子: TMDB={tmdb_id}" + elif douban_id: + message = f"正在搜索种子: 豆瓣={douban_id}" + else: + message = "正在搜索种子" if media_type: message += f" [{media_type}]" - if season: - message += f" 第{season}季" - if filter_pattern: - message += f" 过滤: {filter_pattern}" - return message - async def run(self, title: str, year: Optional[str] = None, - media_type: Optional[str] = None, season: Optional[int] = None, - sites: Optional[List[int]] = None, filter_pattern: Optional[str] = None, **kwargs) -> str: + async def run(self, tmdb_id: Optional[int] = None, douban_id: Optional[str] = None, + media_type: Optional[str] = None, area: Optional[str] = None, + sites: Optional[List[int]] = None, **kwargs) -> str: logger.info( - f"执行工具: {self.name}, 参数: title={title}, year={year}, media_type={media_type}, season={season}, sites={sites}, filter_pattern={filter_pattern}") + f"执行工具: {self.name}, 参数: tmdb_id={tmdb_id}, douban_id={douban_id}, media_type={media_type}, area={area}, sites={sites}") + + if not tmdb_id and not douban_id: + return "参数错误:tmdb_id 和 douban_id 至少需要提供一个,请先使用 search_media 工具获取媒体 ID。" try: search_chain = SearchChain() - torrents = await search_chain.async_search_by_title(title=title, sites=sites) - filtered_torrents = [] - # 编译正则表达式(如果提供) - regex_pattern = None - if filter_pattern: - try: - regex_pattern = re.compile(filter_pattern, re.IGNORECASE) - except re.error as e: - logger.warning(f"正则表达式编译失败: {filter_pattern}, 错误: {e}") - return f"正则表达式格式错误: {str(e)}" - - for torrent in torrents: - # torrent 是 Context 对象,需要通过 meta_info 和 media_info 访问属性 - if year and torrent.meta_info and torrent.meta_info.year != year: - continue - if media_type and torrent.media_info: - if torrent.media_info.type != MediaType(media_type): - continue - if season is not None and torrent.meta_info and torrent.meta_info.begin_season != season: - continue - # 使用正则表达式过滤标题(分辨率、质量等关键字) - if regex_pattern and torrent.torrent_info and torrent.torrent_info.title: - if not regex_pattern.search(torrent.torrent_info.title): - continue - filtered_torrents.append(torrent) + media_type_enum = None + if media_type: + media_type_enum = MediaType.from_agent(media_type) + if not media_type_enum: + return f"错误:无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv'" + + filtered_torrents = await search_chain.async_search_by_id( + tmdbid=tmdb_id, + doubanid=douban_id, + mtype=media_type_enum, + area=area or "title", + sites=sites, + cache_local=False, + ) + + # 获取站点信息 + all_indexers = await SitesHelper().async_get_indexers() + all_sites = [{"id": indexer.get("id"), "name": indexer.get("name")} for indexer in (all_indexers or [])] + + if sites: + search_site_ids = sites + else: + configured_sites = SystemConfigOper().get(SystemConfigKey.IndexerSites) + search_site_ids = configured_sites if configured_sites else [] if filtered_torrents: - # 限制最多50条结果 - total_count = len(filtered_torrents) - limited_torrents = filtered_torrents[:50] - # 精简字段,只保留关键信息 - simplified_torrents = [] - for t in limited_torrents: - simplified = {} - # 精简 torrent_info - if t.torrent_info: - simplified["torrent_info"] = { - "title": t.torrent_info.title, - "size": StringUtils.format_size(t.torrent_info.size), - "seeders": t.torrent_info.seeders, - "peers": t.torrent_info.peers, - "site_name": t.torrent_info.site_name, - "enclosure": t.torrent_info.enclosure, - "page_url": t.torrent_info.page_url, - "volume_factor": t.torrent_info.volume_factor, - "pubdate": t.torrent_info.pubdate - } - # 精简 media_info - if t.media_info: - simplified["media_info"] = { - "title": t.media_info.title, - "en_title": t.media_info.en_title, - "year": t.media_info.year, - "type": t.media_info.type.value if t.media_info.type else None, - "season": t.media_info.season, - "tmdb_id": t.media_info.tmdb_id - } - # 精简 meta_info - if t.meta_info: - simplified["meta_info"] = { - "name": t.meta_info.name, - "cn_name": t.meta_info.cn_name, - "en_name": t.meta_info.en_name, - "year": t.meta_info.year, - "type": t.meta_info.type.value if t.meta_info.type else None, - "begin_season": t.meta_info.begin_season - } - simplified_torrents.append(simplified) - result_json = json.dumps(simplified_torrents, ensure_ascii=False, indent=2) - # 如果结果被裁剪,添加提示信息 - if total_count > 50: - return f"注意:搜索结果共找到 {total_count} 条,为节省上下文空间,仅显示前 50 条结果。\n\n{result_json}" + await search_chain.async_save_cache(filtered_torrents, SEARCH_RESULT_CACHE_FILE) + result_json = json.dumps({ + "total_count": len(filtered_torrents), + "message": "搜索完成。请使用 get_search_results 工具获取搜索结果。", + "all_sites": all_sites, + "search_site_ids": search_site_ids, + "filter_options": build_filter_options(filtered_torrents), + }, ensure_ascii=False, indent=2) return result_json else: - return f"未找到相关种子资源: {title}" + media_id = f"TMDB={tmdb_id}" if tmdb_id else f"豆瓣={douban_id}" + result_json = json.dumps({ + "message": f"未找到相关种子资源: {media_id}", + "all_sites": all_sites, + "search_site_ids": search_site_ids, + }, ensure_ascii=False, indent=2) + return result_json except Exception as e: error_message = f"搜索种子时发生错误: {str(e)}" logger.error(f"搜索种子失败: {e}", exc_info=True) diff --git a/app/agent/tools/impl/test_site.py b/app/agent/tools/impl/test_site.py index ee61d34d..4ed8343e 100644 --- a/app/agent/tools/impl/test_site.py +++ b/app/agent/tools/impl/test_site.py @@ -8,53 +8,31 @@ from app.agent.tools.base import MoviePilotTool from app.chain.site import SiteChain from app.db.site_oper import SiteOper from app.log import logger -from app.utils.string import StringUtils class TestSiteInput(BaseModel): """测试站点连通性工具的输入参数模型""" explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context") - site_identifier: str = Field(..., description="Site identifier: can be site ID (integer as string), site name, or site domain/URL") + site_identifier: int = Field(..., description="Site ID to test (can be obtained from query_sites tool)") class TestSiteTool(MoviePilotTool): name: str = "test_site" - description: str = "Test site connectivity and availability. This will check if a site is accessible and can be logged in. Accepts site ID, site name, or site domain/URL as identifier." + description: str = "Test site connectivity and availability. This will check if a site is accessible and can be logged in. Accepts site ID only." args_schema: Type[BaseModel] = TestSiteInput def get_tool_message(self, **kwargs) -> Optional[str]: """根据测试参数生成友好的提示消息""" - site_identifier = kwargs.get("site_identifier", "") + site_identifier = kwargs.get("site_identifier") return f"正在测试站点连通性: {site_identifier}" - async def run(self, site_identifier: str, **kwargs) -> str: + async def run(self, site_identifier: int, **kwargs) -> str: logger.info(f"执行工具: {self.name}, 参数: site_identifier={site_identifier}") try: site_oper = SiteOper() site_chain = SiteChain() - - # 尝试解析为站点ID - site = None - if site_identifier.isdigit(): - # 如果是数字,尝试作为站点ID查询 - site = await site_oper.async_get(int(site_identifier)) - - # 如果不是ID或ID查询失败,尝试按名称或域名查询 - if not site: - # 尝试按名称查询 - sites = await site_oper.async_list() - for s in sites: - if (site_identifier.lower() in (s.name or "").lower()) or \ - (site_identifier.lower() in (s.domain or "").lower()): - site = s - break - - # 如果还是没找到,尝试从URL提取域名 - if not site: - domain = StringUtils.get_url_domain(site_identifier) - if domain: - site = await site_oper.async_get_by_domain(domain) + site = await site_oper.async_get(site_identifier) if not site: return f"未找到站点:{site_identifier},请使用 query_sites 工具查询可用的站点" diff --git a/app/agent/tools/impl/transfer_file.py b/app/agent/tools/impl/transfer_file.py index 4684ce2c..ff17911c 100644 --- a/app/agent/tools/impl/transfer_file.py +++ b/app/agent/tools/impl/transfer_file.py @@ -18,7 +18,7 @@ class TransferFileInput(BaseModel): storage: Optional[str] = Field("local", description="Storage type of the source file (default: 'local', can be 'smb', 'alist', etc.)") target_path: Optional[str] = Field(None, description="Target path for the transferred file/directory (optional, uses default library path if not specified)") target_storage: Optional[str] = Field(None, description="Target storage type (optional, uses default storage if not specified)") - media_type: Optional[str] = Field(None, description="Media type: '电影' for films, '电视剧' for television series (optional, will be auto-detected if not specified)") + media_type: Optional[str] = Field(None, description="Allowed values: movie, tv") tmdbid: Optional[int] = Field(None, description="TMDB ID for precise media identification (optional but recommended for accuracy)") doubanid: Optional[str] = Field(None, description="Douban ID for media identification (optional)") season: Optional[int] = Field(None, description="Season number for TV shows (optional)") @@ -91,11 +91,10 @@ class TransferFileTool(MoviePilotTool): target_path_obj = Path(target_path) # 处理媒体类型 - mtype = None + media_type_enum = None if media_type: - try: - mtype = MediaType(media_type) - except ValueError: + media_type_enum = MediaType.from_agent(media_type) + if not media_type_enum: return f"错误:无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv'" # 调用整理方法 @@ -106,7 +105,7 @@ class TransferFileTool(MoviePilotTool): target_path=target_path_obj, tmdbid=tmdbid, doubanid=doubanid, - mtype=mtype, + mtype=media_type_enum, season=season, transfer_type=transfer_type, background=background diff --git a/app/agent/tools/impl/update_site.py b/app/agent/tools/impl/update_site.py index 59d5349b..a9c80643 100644 --- a/app/agent/tools/impl/update_site.py +++ b/app/agent/tools/impl/update_site.py @@ -17,7 +17,7 @@ from app.utils.string import StringUtils class UpdateSiteInput(BaseModel): """更新站点工具的输入参数模型""" explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context") - site_id: int = Field(..., description="The ID of the site to update") + site_id: int = Field(..., description="The ID of the site to update (can be obtained from query_sites tool)") name: Optional[str] = Field(None, description="Site name (optional)") url: Optional[str] = Field(None, description="Site URL (optional, will be automatically formatted)") pri: Optional[int] = Field(None, description="Site priority (optional, smaller value = higher priority, e.g., pri=1 has higher priority than pri=10)") diff --git a/app/agent/tools/impl/update_site_cookie.py b/app/agent/tools/impl/update_site_cookie.py index c93b5a25..f91b706f 100644 --- a/app/agent/tools/impl/update_site_cookie.py +++ b/app/agent/tools/impl/update_site_cookie.py @@ -8,13 +8,12 @@ from app.agent.tools.base import MoviePilotTool from app.chain.site import SiteChain from app.db.site_oper import SiteOper from app.log import logger -from app.utils.string import StringUtils class UpdateSiteCookieInput(BaseModel): """更新站点Cookie和UA工具的输入参数模型""" explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context") - site_identifier: str = Field(..., description="Site identifier: can be site ID (integer as string), site name, or site domain/URL") + site_identifier: int = Field(..., description="Site ID to update Cookie and User-Agent for (can be obtained from query_sites tool)") username: str = Field(..., description="Site login username") password: str = Field(..., description="Site login password") two_step_code: Optional[str] = Field(None, description="Two-step verification code or secret key (optional, required for sites with 2FA enabled)") @@ -22,12 +21,12 @@ class UpdateSiteCookieInput(BaseModel): class UpdateSiteCookieTool(MoviePilotTool): name: str = "update_site_cookie" - description: str = "Update site Cookie and User-Agent by logging in with username and password. This tool can automatically obtain and update the site's authentication credentials. Supports two-step verification for sites that require it. Accepts site ID, site name, or site domain/URL as identifier." + description: str = "Update site Cookie and User-Agent by logging in with username and password. This tool can automatically obtain and update the site's authentication credentials. Supports two-step verification for sites that require it. Accepts site ID only." args_schema: Type[BaseModel] = UpdateSiteCookieInput def get_tool_message(self, **kwargs) -> Optional[str]: """根据更新参数生成友好的提示消息""" - site_identifier = kwargs.get("site_identifier", "") + site_identifier = kwargs.get("site_identifier") username = kwargs.get("username", "") two_step_code = kwargs.get("two_step_code") @@ -37,35 +36,14 @@ class UpdateSiteCookieTool(MoviePilotTool): return message - async def run(self, site_identifier: str, username: str, password: str, + async def run(self, site_identifier: int, username: str, password: str, two_step_code: Optional[str] = None, **kwargs) -> str: logger.info(f"执行工具: {self.name}, 参数: site_identifier={site_identifier}, username={username}") try: site_oper = SiteOper() site_chain = SiteChain() - - # 尝试解析为站点ID - site = None - if site_identifier.isdigit(): - # 如果是数字,尝试作为站点ID查询 - site = await site_oper.async_get(int(site_identifier)) - - # 如果不是ID或ID查询失败,尝试按名称或域名查询 - if not site: - # 尝试按名称查询 - sites = await site_oper.async_list() - for s in sites: - if (site_identifier.lower() in (s.name or "").lower()) or \ - (site_identifier.lower() in (s.domain or "").lower()): - site = s - break - - # 如果还是没找到,尝试从URL提取域名 - if not site: - domain = StringUtils.get_url_domain(site_identifier) - if domain: - site = await site_oper.async_get_by_domain(domain) + site = await site_oper.async_get(site_identifier) if not site: return f"未找到站点:{site_identifier},请使用 query_sites 工具查询可用的站点" diff --git a/app/agent/tools/impl/update_subscribe.py b/app/agent/tools/impl/update_subscribe.py index 362a8b2a..9e635598 100644 --- a/app/agent/tools/impl/update_subscribe.py +++ b/app/agent/tools/impl/update_subscribe.py @@ -16,7 +16,7 @@ from app.schemas.types import EventType class UpdateSubscribeInput(BaseModel): """更新订阅工具的输入参数模型""" explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context") - subscribe_id: int = Field(..., description="The ID of the subscription to update") + subscribe_id: int = Field(..., description="The ID of the subscription to update (can be obtained from query_subscribes tool)") name: Optional[str] = Field(None, description="Subscription name/title (optional)") year: Optional[str] = Field(None, description="Release year (optional)") season: Optional[int] = Field(None, description="Season number for TV shows (optional)") diff --git a/app/agent/tools/manager.py b/app/agent/tools/manager.py index d86e3e23..5ea24a19 100644 --- a/app/agent/tools/manager.py +++ b/app/agent/tools/manager.py @@ -99,6 +99,69 @@ class MoviePilotToolsManager: return tool return None + @staticmethod + def _resolve_field_schema(field_info: Dict[str, Any]) -> Dict[str, Any]: + """ + 解析字段schema,兼容 Optional[T] 生成的 anyOf 结构 + """ + if field_info.get("type"): + return field_info + + any_of = field_info.get("anyOf") + if not any_of: + return field_info + + for type_option in any_of: + if type_option.get("type") and type_option["type"] != "null": + merged = dict(type_option) + if "description" not in merged and field_info.get("description"): + merged["description"] = field_info["description"] + if "default" not in merged and "default" in field_info: + merged["default"] = field_info["default"] + return merged + + return field_info + + @staticmethod + def _normalize_scalar_value(field_type: Optional[str], value: Any, key: str) -> Any: + """ + 根据字段类型规范化单个值 + """ + if field_type == "integer" and isinstance(value, str): + try: + return int(value) + except (ValueError, TypeError): + logger.warning(f"无法将参数 {key}='{value}' 转换为整数,返回 None") + return None + if field_type == "number" and isinstance(value, str): + try: + return float(value) + except (ValueError, TypeError): + logger.warning(f"无法将参数 {key}='{value}' 转换为浮点数,返回 None") + return None + if field_type == "boolean": + if isinstance(value, str): + return value.lower() in ("true", "1", "yes", "on") + if isinstance(value, (int, float)): + return value != 0 + if isinstance(value, bool): + return value + return True + return value + + @staticmethod + def _parse_array_string(value: str, key: str, item_type: str = "string") -> list: + """ + 将逗号分隔的字符串解析为列表,并根据 item_type 转换元素类型 + """ + trimmed = value.strip() + if not trimmed: + return [] + return [ + MoviePilotToolsManager._normalize_scalar_value(item_type, item.strip(), key) + for item in trimmed.split(",") if item.strip() + ] + @staticmethod def _normalize_arguments(tool_instance: Any, arguments: Dict[str, Any]) -> Dict[str, Any]: """ @@ -132,40 +195,17 @@ class MoviePilotToolsManager: normalized[key] = value continue - field_info = properties[key] + field_info = MoviePilotToolsManager._resolve_field_schema(properties[key]) field_type = field_info.get("type") - # 处理 anyOf 类型(例如 Optional[int] 会生成 anyOf) - any_of = field_info.get("anyOf") - if any_of and not field_type: - # 从 anyOf 中提取实际类型 - for type_option in any_of: - if "type" in type_option and type_option["type"] != "null": - field_type = type_option["type"] - break + # 数组类型:将字符串解析为列表 + if field_type == "array" and isinstance(value, str): + item_type = field_info.get("items", {}).get("type", "string") + normalized[key] = MoviePilotToolsManager._parse_array_string(value, key, item_type) + continue # 根据类型进行转换 - if field_type == "integer" and isinstance(value, str): - try: - normalized[key] = int(value) - except (ValueError, TypeError): - logger.warning(f"无法将参数 {key}='{value}' 转换为整数,保持原值") - normalized[key] = None - elif field_type == "number" and isinstance(value, str): - try: - normalized[key] = float(value) - except (ValueError, TypeError): - logger.warning(f"无法将参数 {key}='{value}' 转换为浮点数,保持原值") - normalized[key] = None - elif field_type == "boolean": - if isinstance(value, str): - normalized[key] = value.lower() in ("true", "1", "yes", "on") - elif isinstance(value, (int, float)): - normalized[key] = value != 0 - else: - normalized[key] = True - else: - normalized[key] = value + normalized[key] = MoviePilotToolsManager._normalize_scalar_value(field_type, value, key) return normalized @@ -235,14 +275,15 @@ class MoviePilotToolsManager: if "properties" in schema: for field_name, field_info in schema["properties"].items(): + resolved_field_info = MoviePilotToolsManager._resolve_field_schema(field_info) # 转换字段类型 - field_type = field_info.get("type", "string") - field_description = field_info.get("description", "") + field_type = resolved_field_info.get("type", "string") + field_description = resolved_field_info.get("description", "") # 处理可选字段 if field_name not in schema.get("required", []): # 可选字段 - default_value = field_info.get("default") + default_value = resolved_field_info.get("default") properties[field_name] = { "type": field_type, "description": field_description @@ -257,12 +298,12 @@ class MoviePilotToolsManager: required.append(field_name) # 处理枚举类型 - if "enum" in field_info: - properties[field_name]["enum"] = field_info["enum"] + if "enum" in resolved_field_info: + properties[field_name]["enum"] = resolved_field_info["enum"] # 处理数组类型 - if field_type == "array" and "items" in field_info: - properties[field_name]["items"] = field_info["items"] + if field_type == "array" and "items" in resolved_field_info: + properties[field_name]["items"] = resolved_field_info["items"] return { "type": "object", diff --git a/app/api/endpoints/mcp.py b/app/api/endpoints/mcp.py index 7da7634d..d949b346 100644 --- a/app/api/endpoints/mcp.py +++ b/app/api/endpoints/mcp.py @@ -19,6 +19,17 @@ router = APIRouter() # MCP 协议版本 MCP_PROTOCOL_VERSIONS = ["2025-11-25", "2025-06-18", "2024-11-05"] MCP_PROTOCOL_VERSION = MCP_PROTOCOL_VERSIONS[0] # 默认使用最新版本 +MCP_HIDDEN_TOOLS = {"execute_command", "search_web"} + + +def list_exposed_tools(): + """ + 获取 MCP 可见工具列表 + """ + return [ + tool for tool in moviepilot_tool_manager.list_tools() + if tool.name not in MCP_HIDDEN_TOOLS + ] def create_jsonrpc_response(request_id: Union[str, int, None], result: Any) -> Dict[str, Any]: @@ -174,7 +185,7 @@ async def handle_tools_list() -> Dict[str, Any]: """ 处理工具列表请求 """ - tools = moviepilot_tool_manager.list_tools() + tools = list_exposed_tools() # 转换为 MCP 工具格式 mcp_tools = [] @@ -202,6 +213,9 @@ async def handle_tools_call(params: Dict[str, Any]) -> Dict[str, Any]: raise ValueError("Missing tool name") try: + if tool_name in MCP_HIDDEN_TOOLS: + raise ValueError(f"工具 '{tool_name}' 未找到") + result_text = await moviepilot_tool_manager.call_tool(tool_name, arguments) return { @@ -248,7 +262,7 @@ async def list_tools( """ try: # 获取所有工具定义 - tools = moviepilot_tool_manager.list_tools() + tools = list_exposed_tools() # 转换为字典格式 tools_list = [] @@ -278,7 +292,9 @@ async def call_tool( 工具执行结果 """ try: - # 调用工具 + if request.tool_name in MCP_HIDDEN_TOOLS: + raise ValueError(f"工具 '{request.tool_name}' 未找到") + result_text = await moviepilot_tool_manager.call_tool(request.tool_name, request.arguments) return schemas.ToolCallResponse( @@ -306,7 +322,7 @@ async def get_tool_info( """ try: # 获取所有工具 - tools = moviepilot_tool_manager.list_tools() + tools = list_exposed_tools() # 查找指定工具 for tool in tools: @@ -338,7 +354,7 @@ async def get_tool_schema( """ try: # 获取所有工具 - tools = moviepilot_tool_manager.list_tools() + tools = list_exposed_tools() # 查找指定工具 for tool in tools: diff --git a/app/schemas/types.py b/app/schemas/types.py index ad0a714f..e8e4ddc3 100644 --- a/app/schemas/types.py +++ b/app/schemas/types.py @@ -1,4 +1,5 @@ from enum import Enum +from typing import Optional # 媒体类型 @@ -8,6 +9,26 @@ class MediaType(Enum): COLLECTION = '系列' UNKNOWN = '未知' + @staticmethod + def from_agent(key: str) -> Optional["MediaType"]: + """'movie' -> MediaType.MOVIE, 'tv' -> MediaType.TV, 否则 None""" + _map = {"movie": MediaType.MOVIE, "tv": MediaType.TV} + return _map.get(key.strip().lower() if key else "") + + def to_agent(self) -> str: + """MediaType.MOVIE -> 'movie', MediaType.TV -> 'tv', 其他返回 .value""" + return {MediaType.MOVIE: "movie", MediaType.TV: "tv"}.get(self, self.value) + + +def media_type_to_agent(value) -> Optional[str]: + """将 MediaType 枚举或中文字符串统一转为 'movie'/'tv'""" + if isinstance(value, MediaType): + return value.to_agent() + if isinstance(value, str): + mt = MediaType.from_agent(value) + return mt.to_agent() if mt else value + return None + # 排序类型枚举 class SortType(Enum): diff --git a/app/utils/crypto.py b/app/utils/crypto.py index 5ca7ba9c..39c2ee51 100644 --- a/app/utils/crypto.py +++ b/app/utils/crypto.py @@ -109,6 +109,19 @@ class HashUtils: data = data.encode(encoding) return hashlib.md5(data).hexdigest() + @staticmethod + def sha1(data: Union[str, bytes], encoding: str = "utf-8") -> str: + """ + 生成数据的SHA-1哈希值,并以字符串形式返回 + + :param data: 输入的数据,类型为字符串或字节 + :param encoding: 字符串编码类型,默认使用UTF-8 + :return: 生成的SHA-1哈希字符串 + """ + if isinstance(data, str): + data = data.encode(encoding) + return hashlib.sha1(data).hexdigest() + @staticmethod def md5_bytes(data: Union[str, bytes], encoding: str = "utf-8") -> bytes: """ diff --git a/docs/mcp-api.md b/docs/mcp-api.md index b5f27ba6..44cf3ead 100644 --- a/docs/mcp-api.md +++ b/docs/mcp-api.md @@ -123,7 +123,7 @@ MoviePilot 实现了标准的 **Model Context Protocol (MCP)**,允许 AI 智 "arguments": { "title": "流浪地球", "year": "2019", - "media_type": "电影" + "media_type": "movie" } } ``` diff --git a/skills/moviepilot-cli/SKILL.md b/skills/moviepilot-cli/SKILL.md new file mode 100644 index 00000000..b098c338 --- /dev/null +++ b/skills/moviepilot-cli/SKILL.md @@ -0,0 +1,132 @@ +--- +name: moviepilot-cli +description: Use this skill when the user wants to find, download, or subscribe to a movie or TV show (including anime); asks about download or subscription status; needs to check or organize the media library; or mentions MoviePilot directly. Covers the full media acquisition workflow via MoviePilot — searching TMDB, filtering and downloading torrents from PT indexer sites, managing subscriptions for automatic episode tracking, and handling library organization, site accounts, filter rules, and schedulers. +--- + +# MoviePilot CLI + +Use `scripts/mp-cli.js` to interact with the MoviePilot backend. + +## Discover Commands + +```bash +node scripts/mp-cli.js list # list all available commands +node scripts/mp-cli.js show # show parameters, required fields, and usage +``` + +Always run `show ` before calling a command. Do not guess parameter names or argument formats. + +## Command Groups + +| Category | Commands | +| ------------ | ---------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| Media Search | search_media, recognize_media, query_media_detail, get_recommendations, search_person, search_person_credits | +| Torrent | search_torrents, get_search_results | +| Download | add_download, query_download_tasks, delete_download, query_downloaders | +| Subscription | add_subscribe, query_subscribes, update_subscribe, delete_subscribe, search_subscribe, query_subscribe_history, query_popular_subscribes, query_subscribe_shares | +| Library | query_library_exists, query_library_latest, transfer_file, scrape_metadata, query_transfer_history | +| Files | list_directory, query_directory_settings | +| Sites | query_sites, query_site_userdata, test_site, update_site, update_site_cookie | +| System | query_schedulers, run_scheduler, query_workflows, run_workflow, query_rule_groups, query_episode_schedule, send_message | + +## Gotchas + +- **Don't guess command parameters.** Parameter names vary per command and are not inferrable. Always run `show ` first. +- **`search_torrents` results are cached server-side.** `get_search_results` reads from that cache — always run `search_torrents` first in the same session before filtering. +- **Omitting `sites` uses the user's configured default sites**, not all available sites. Only call `query_sites` and pass `sites=` when the user explicitly asks for a specific site. +- **TMDB season numbers don't always match fan-labeled seasons.** Anime and long-running shows often split one TMDB season into parts. Always validate with `query_media_detail` when the user mentions a specific season. +- **`add_download` is irreversible without manual cleanup.** Always present torrent details and wait for explicit confirmation before calling it. +- **`get_search_results` filter params are ANDed.** Combining multiple fields can silently exclude valid results. If results come back empty, drop the most restrictive filter and retry before reporting failure. +- **`volume_factor` and `freedate_diff` indicate promotional status.** `volume_factor` describes the discount type (e.g. `免费` = free download, `2X` = double upload only, `2X免费` = free download + double upload, `普通` = no discount). `freedate_diff` is the remaining free window (e.g. `2天3小时`); empty means no active promotion. Always include both fields when presenting results — they are critical for the user to pick the best-value torrent. + +## Common Workflows + +### Search and Download + +```bash +# 1. Search TMDB to get tmdb_id +node scripts/mp-cli.js search_media title="流浪地球2" media_type="movie" + +# [TV only, only if user specified a season] Validate season — see "Season Validation" section below +node scripts/mp-cli.js query_media_detail tmdb_id=... media_type="tv" + +# 2. Search torrents using tmdb_id — results are cached server-side +# Response includes available filter options (resolution, release group, etc.) +# [Optional] If the user specifies sites, first run query_sites to get IDs, then pass them via sites param +node scripts/mp-cli.js query_sites # get site IDs +node scripts/mp-cli.js search_torrents tmdb_id=791373 media_type="movie" # use user's default sites +node scripts/mp-cli.js search_torrents tmdb_id=791373 media_type="movie" sites='1,3' # override with specific sites + +# 3. Present ALL available filter_options to the user and ask which ones to apply +# Show every field and its values — do not pre-select or omit any +# e.g. "分辨率: 1080p, 2160p;字幕组: CMCT, PTer;请问需要筛选哪些条件?" + +# 4. Filter cached results based on user preferences and your own judgment +# Filter params are ANDed — if results come back empty, drop the most restrictive field and retry +node scripts/mp-cli.js get_search_results resolution='1080p' + +# [Optional] Re-check available filter options from cached results (same shape as search_torrents; returns filter options only) +node scripts/mp-cli.js get_search_results show_filter_options=true + +# 5. Present ALL filtered results as a numbered list — do not pre-select or discard any +# Show for each: index, title, size, seeders, resolution, release group, volume_factor, freedate_diff +# Let the user pick by number; only then proceed to step 6 + +# 6. After user confirms selection, check library and subscriptions before downloading +node scripts/mp-cli.js query_library_exists tmdb_id=123456 media_type="movie" +node scripts/mp-cli.js query_subscribes tmdb_id=123456 +# If already in library or subscribed, warn the user and ask for confirmation to proceed + +# 7. Add download +node scripts/mp-cli.js add_download torrent_url="..." +``` + +### Add Subscription + +```bash +# 1. Search to get tmdb_id (required for accurate identification) +node scripts/mp-cli.js search_media title="黑镜" media_type="tv" + +# 2. Subscribe — the system will auto-download new episodes +node scripts/mp-cli.js add_subscribe title="黑镜" year="2011" media_type="tv" tmdb_id=42009 +``` + +### Manage Subscriptions + +```bash +node scripts/mp-cli.js query_subscribes status=R # list active +node scripts/mp-cli.js update_subscribe subscribe_id=123 resolution="1080p" # update filters +node scripts/mp-cli.js search_subscribe subscribe_id=123 # search missing episodes +node scripts/mp-cli.js delete_subscribe subscribe_id=123 # remove +``` + +## Season Validation (only when user specifies a season) + +Skip this section if the user did not mention a specific season. + +**Step 1 — Verify the season exists:** + +```bash +node scripts/mp-cli.js query_media_detail tmdb_id= media_type="tv" +``` + +Check `season_info` against the season the user requested: + +- **Season exists:** use that season number directly, then proceed to torrent search. +- **Season does not exist:** anime and long-running shows often split one TMDB season into multiple parts that fans call separate seasons. Use the latest available season number and continue to Step 2. + +**Step 2 — Identify the correct episode range:** + +```bash +node scripts/mp-cli.js query_episode_schedule tmdb_id= season= +``` + +Use `air_date` to find a block of recently-aired episodes that likely corresponds to what the user calls the missing season. If no such block exists, tell the user the content is unavailable. Otherwise, confirm the episode range with the user before proceeding to torrent search. + +## Error Handling + +| Error | Resolution | +| --------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| No search results | Retry with an alternative title (e.g. English title). If still empty, ask the user to confirm the title or provide the TMDB ID directly. | +| Download failure | Run `query_downloaders` to check downloader health, then `query_download_tasks` to check if the task already exists (duplicate tasks are rejected). If both are normal, report findings to the user, suggest checking storage space, and mention it may be a network error — suggest retrying later. | +| Missing configuration | Ask the user for the backend host and API key. Once provided, run `node scripts/mp-cli.js -h -k ` (no command) to save the config persistently — subsequent commands will use it automatically. | diff --git a/skills/moviepilot-cli/scripts/mp-cli.js b/skills/moviepilot-cli/scripts/mp-cli.js new file mode 100755 index 00000000..024b0213 --- /dev/null +++ b/skills/moviepilot-cli/scripts/mp-cli.js @@ -0,0 +1,543 @@ +#!/usr/bin/env node + +'use strict'; + +const fs = require('fs'); +const os = require('os'); +const path = require('path'); +const http = require('http'); +const https = require('https'); + +const SCRIPT_NAME = process.env.MP_SCRIPT_NAME || path.basename(process.argv[1] || 'mp-cli.js'); +const CONFIG_DIR = path.join(os.homedir(), '.config', 'moviepilot_cli'); +const CONFIG_FILE = path.join(CONFIG_DIR, 'config'); + +let commandsJson = []; +let commandsLoaded = false; + +let optHost = ''; +let optKey = ''; + +const envHost = process.env.MP_HOST || ''; +const envKey = process.env.MP_API_KEY || ''; + +let mpHost = ''; +let mpApiKey = ''; + +function fail(message) { + console.error(message); + process.exit(1); +} + +function spacePad(text = '', targetCol = 0) { + const spaces = text.length < targetCol ? targetCol - text.length + 2 : 2; + return ' '.repeat(spaces); +} + +function printBox(title, lines) { + const rightPadding = 0; + const contentWidth = + lines.reduce((max, line) => Math.max(max, line.length), title.length) + rightPadding; + const innerWidth = contentWidth + 2; + const topLabel = `─ ${title}`; + + console.error(`┌${topLabel}${'─'.repeat(Math.max(innerWidth - topLabel.length, 0))}┐`); + for (const line of lines) { + console.error(`│ ${line}${' '.repeat(contentWidth - line.length)} │`); + } + console.error(`└${'─'.repeat(innerWidth)}┘`); +} + +function readConfig() { + let cfgHost = ''; + let cfgKey = ''; + + if (!fs.existsSync(CONFIG_FILE)) { + return { cfgHost, cfgKey }; + } + + const content = fs.readFileSync(CONFIG_FILE, 'utf8'); + for (const line of content.split(/\r?\n/)) { + if (!line.trim() || /^\s*#/.test(line)) { + continue; + } + + const index = line.indexOf('='); + if (index === -1) { + continue; + } + + const key = line.slice(0, index).replace(/\s+/g, ''); + const value = line.slice(index + 1); + + if (key === 'MP_HOST') { + cfgHost = value; + } else if (key === 'MP_API_KEY') { + cfgKey = value; + } + } + + return { cfgHost, cfgKey }; +} + +function saveConfig(host, key) { + fs.mkdirSync(CONFIG_DIR, { recursive: true }); + fs.writeFileSync(CONFIG_FILE, `MP_HOST=${host}\nMP_API_KEY=${key}\n`, 'utf8'); + fs.chmodSync(CONFIG_FILE, 0o600); +} + +function loadConfig() { + const { cfgHost: initialHost, cfgKey: initialKey } = readConfig(); + let cfgHost = initialHost; + let cfgKey = initialKey; + + if (optHost || optKey) { + const nextHost = optHost || cfgHost; + const nextKey = optKey || cfgKey; + saveConfig(nextHost, nextKey); + cfgHost = nextHost; + cfgKey = nextKey; + } + + mpHost = optHost || mpHost || envHost || cfgHost; + mpApiKey = optKey || mpApiKey || envKey || cfgKey; +} + +function normalizeType(schema = {}) { + if (schema.type) { + return schema.type; + } + if (Array.isArray(schema.anyOf)) { + const candidate = schema.anyOf.find((item) => item && item.type && item.type !== 'null'); + return candidate?.type || 'string'; + } + return 'string'; +} + +function normalizeItemType(schema = {}) { + const items = schema.items; + if (!items) { + return null; + } + if (items.type) { + return items.type; + } + if (Array.isArray(items.anyOf)) { + const candidate = items.anyOf.find((item) => item && item.type && item.type !== 'null'); + return candidate?.type || null; + } + return null; +} + +function normalizeCommand(tool = {}) { + const properties = tool?.inputSchema?.properties || {}; + const required = Array.isArray(tool?.inputSchema?.required) ? tool.inputSchema.required : []; + const fields = Object.entries(properties) + .filter(([fieldName]) => fieldName !== 'explanation') + .map(([fieldName, schema]) => ({ + name: fieldName, + type: normalizeType(schema), + description: schema?.description || '', + required: required.includes(fieldName), + item_type: normalizeItemType(schema), + })); + + return { + name: tool?.name, + description: tool?.description || '', + fields, + }; +} + +function request(method, targetUrl, headers = {}, body, timeout = 120000) { + return new Promise((resolve, reject) => { + let url; + try { + url = new URL(targetUrl); + } catch (error) { + reject(new Error(`Invalid URL: ${targetUrl}`)); + return; + } + + const transport = url.protocol === 'https:' ? https : http; + const req = transport.request( + { + method, + hostname: url.hostname, + port: url.port || undefined, + path: `${url.pathname}${url.search}`, + headers, + }, + (res) => { + const chunks = []; + res.on('data', (chunk) => chunks.push(chunk)); + res.on('end', () => { + resolve({ + statusCode: res.statusCode ? String(res.statusCode) : '', + body: Buffer.concat(chunks).toString('utf8'), + }); + }); + } + ); + + req.setTimeout(timeout, () => { + req.destroy(new Error(`Request timed out after ${timeout}ms`)); + }); + + req.on('error', reject); + + if (body !== undefined) { + req.write(body); + } + + req.end(); + }); +} + +async function loadCommandsJson() { + if (commandsLoaded) { + return; + } + + const { statusCode, body } = await request('GET', `${mpHost}/api/v1/mcp/tools`, { + 'X-API-KEY': mpApiKey, + }); + + if (statusCode !== '200') { + console.error(`Error: failed to load command definitions (HTTP ${statusCode || 'unknown'})`); + process.exit(1); + } + + let response; + try { + response = JSON.parse(body); + } catch { + fail('Error: backend returned invalid JSON for command definitions'); + } + + commandsJson = Array.isArray(response) + ? response.map((tool) => normalizeCommand(tool)) + : []; + + commandsLoaded = true; +} + +async function loadCommandJson(commandName) { + const { statusCode, body } = await request('GET', `${mpHost}/api/v1/mcp/tools/${commandName}`, { + 'X-API-KEY': mpApiKey, + }); + + if (statusCode === '404') { + console.error(`Error: command '${commandName}' not found`); + console.error(`Run 'node ${SCRIPT_NAME} list' to see available commands`); + process.exit(1); + } + + if (statusCode !== '200') { + console.error(`Error: failed to load command definition (HTTP ${statusCode || 'unknown'})`); + process.exit(1); + } + + let response; + try { + response = JSON.parse(body); + } catch { + fail(`Error: backend returned invalid JSON for command '${commandName}'`); + } + + return normalizeCommand(response); +} + +function ensureConfig() { + loadConfig(); + let ok = true; + + if (!mpHost) { + console.error('Error: backend host is not configured.'); + console.error(' Use: -h HOST to set it'); + console.error(' Or set environment variable: MP_HOST=http://localhost:3001'); + ok = false; + } + + if (!mpApiKey) { + console.error('Error: API key is not configured.'); + console.error(' Use: -k KEY to set it'); + console.error(' Or set environment variable: MP_API_KEY=your_key'); + ok = false; + } + + if (!ok) { + process.exit(1); + } +} + +function printValue(value) { + if (typeof value === 'string') { + process.stdout.write(`${value}\n`); + return; + } + + process.stdout.write(`${JSON.stringify(value)}\n`); +} + +function formatUsageValue(field) { + if (field?.type === 'array') { + return "','"; + } + return ''; +} + +async function cmdList() { + await loadCommandsJson(); + const sortedCommands = [...commandsJson].sort((left, right) => left.name.localeCompare(right.name)); + for (const command of sortedCommands) { + process.stdout.write(`${command.name}\n`); + } +} + +async function cmdShow(commandName) { + if (!commandName) { + fail(`Usage: ${SCRIPT_NAME} show `); + } + + const command = await loadCommandJson(commandName); + + const commandLabel = 'Command:'; + const descriptionLabel = 'Description:'; + const paramsLabel = 'Parameters:'; + const usageLabel = 'Usage:'; + const detailLabelWidth = Math.max( + commandLabel.length, + descriptionLabel.length, + paramsLabel.length, + usageLabel.length + ); + + process.stdout.write(`${commandLabel} ${command.name}\n`); + process.stdout.write(`${descriptionLabel} ${command.description || '(none)'}\n\n`); + + if (command.fields.length === 0) { + process.stdout.write(`${paramsLabel}${spacePad(paramsLabel, detailLabelWidth)}(none)\n`); + } else { + const fieldLines = command.fields.map((field) => [ + field.required ? `${field.name}*` : field.name, + field.type, + field.description, + ]); + + const nameWidth = Math.max(...fieldLines.map(([name]) => name.length), 0); + const typeWidth = Math.max(...fieldLines.map(([, type]) => type.length), 0); + + process.stdout.write(`${paramsLabel}\n`); + for (const [fieldName, fieldType, fieldDesc] of fieldLines) { + process.stdout.write( + ` ${fieldName}${spacePad(fieldName, nameWidth)}${fieldType}${spacePad(fieldType, typeWidth)}${fieldDesc}\n` + ); + } + } + + const usageLine = `${command.name}`; + const reqPart = command.fields + .filter((field) => field.required) + .map((field) => ` ${field.name}=${formatUsageValue(field)}`) + .join(''); + const optPart = command.fields + .filter((field) => !field.required) + .map((field) => ` [${field.name}=${formatUsageValue(field)}]`) + .join(''); + + process.stdout.write(`\n${usageLabel} ${usageLine}${reqPart}${optPart}\n`); +} + +function buildArguments(pairs) { + const args = { explanation: 'CLI invocation' }; + + for (const kv of pairs) { + if (!kv.includes('=')) { + fail(`Error: argument must be in key=value format, got: '${kv}'`); + } + + const index = kv.indexOf('='); + args[kv.slice(0, index)] = kv.slice(index + 1); + } + + return args; +} + +async function cmdRun(commandName, pairs) { + if (!commandName) { + fail(`Usage: ${SCRIPT_NAME} [key=value ...]`); + } + + const requestBody = JSON.stringify({ + tool_name: commandName, + arguments: buildArguments(pairs), + }); + + const { statusCode, body } = await request( + 'POST', + `${mpHost}/api/v1/mcp/tools/call`, + { + 'Content-Type': 'application/json', + 'Content-Length': Buffer.byteLength(requestBody), + 'X-API-KEY': mpApiKey, + }, + requestBody + ); + + if (statusCode && statusCode !== '200' && statusCode !== '201') { + console.error(`Warning: HTTP status ${statusCode}`); + } + + try { + const parsed = JSON.parse(body); + if (Object.prototype.hasOwnProperty.call(parsed, 'error') && parsed.error) { + printValue(parsed); + return; + } + + if (Object.prototype.hasOwnProperty.call(parsed, 'result')) { + if (typeof parsed.result === 'string') { + try { + printValue(JSON.parse(parsed.result)); + } catch { + printValue(parsed.result); + } + } else { + printValue(parsed.result); + } + return; + } + + printValue(parsed); + } catch { + process.stdout.write(`${body}\n`); + } +} + +function printUsage() { + const { cfgHost, cfgKey } = readConfig(); + let effectiveHost = mpHost || envHost || cfgHost; + let effectiveKey = mpApiKey || envKey || cfgKey; + + if (optHost) { + effectiveHost = optHost; + } + if (optKey) { + effectiveKey = optKey; + } + + if (!effectiveHost || !effectiveKey) { + const warningLines = []; + if (!effectiveHost) { + const opt = '-h HOST'; + const desc = 'set backend host'; + warningLines.push(`${opt}${spacePad(opt)}${desc}`); + } + if (!effectiveKey) { + const opt = '-k KEY'; + const desc = 'set API key'; + warningLines.push(`${opt}${spacePad(opt)}${desc}`); + } + printBox('Warning: not configured', warningLines); + console.error(''); + } + + process.stdout.write(`Usage: ${SCRIPT_NAME} [-h HOST] [-k KEY] [COMMAND] [ARGS...]\n\n`); + const optionWidth = Math.max('-h HOST'.length, '-k KEY'.length); + process.stdout.write('Options:\n'); + process.stdout.write(` -h HOST${spacePad('-h HOST', optionWidth)}backend host\n`); + process.stdout.write(` -k KEY${spacePad('-k KEY', optionWidth)}API key\n\n`); + const commandWidth = Math.max( + '(no command)'.length, + 'list'.length, + 'show '.length, + ' [k=v...]'.length + ); + process.stdout.write('Commands:\n'); + process.stdout.write( + ` (no command)${spacePad('(no command)', commandWidth)}save config when -h and -k are provided\n` + ); + process.stdout.write(` list${spacePad('list', commandWidth)}list all commands\n`); + process.stdout.write( + ` show ${spacePad('show ', commandWidth)}show command details and usage example\n` + ); + process.stdout.write( + ` [k=v...]${spacePad(' [k=v...]', commandWidth)}run a command\n` + ); +} + +async function main() { + const args = []; + const argv = process.argv.slice(2); + + for (let index = 0; index < argv.length; index += 1) { + const arg = argv[index]; + + if (arg === '--help' || arg === '-?') { + printUsage(); + process.exit(0); + } + + if (arg === '-h') { + index += 1; + optHost = argv[index] || ''; + continue; + } + + if (arg === '-k') { + index += 1; + optKey = argv[index] || ''; + continue; + } + + if (arg === '--') { + args.push(...argv.slice(index + 1)); + break; + } + + if (arg.startsWith('-')) { + console.error(`Unknown option: ${arg}`); + printUsage(); + process.exit(1); + } + + args.push(arg); + } + + if ((optHost && !optKey) || (!optHost && optKey)) { + fail('Error: -h and -k must be provided together'); + } + + const command = args[0] || ''; + + if (command === 'list') { + ensureConfig(); + await cmdList(); + return; + } + + if (command === 'show') { + ensureConfig(); + await cmdShow(args[1] || ''); + return; + } + + if (!command) { + if (optHost || optKey) { + loadConfig(); + process.stdout.write('Configuration saved.\n'); + return; + } + + printUsage(); + return; + } + + ensureConfig(); + await cmdRun(command, args.slice(1)); +} + +main().catch((error) => { + fail(`Error: ${error.message}`); +});