diff --git a/app/agent/tools/impl/add_subscribe.py b/app/agent/tools/impl/add_subscribe.py index afc1b10e..fa471a10 100644 --- a/app/agent/tools/impl/add_subscribe.py +++ b/app/agent/tools/impl/add_subscribe.py @@ -73,12 +73,8 @@ class AddSubscribeTool(MoviePilotTool): try: subscribe_chain = SubscribeChain() - media_type_key = media_type.strip().lower() - if media_type_key == "movie": - media_type_enum = MediaType.MOVIE - elif media_type_key == "tv": - media_type_enum = MediaType.TV - else: + media_type_enum = MediaType.from_agent(media_type) + if not media_type_enum: return f"错误:无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv'" # 构建额外的订阅参数 diff --git a/app/agent/tools/impl/get_recommendations.py b/app/agent/tools/impl/get_recommendations.py index 249a2dea..e0617b18 100644 --- a/app/agent/tools/impl/get_recommendations.py +++ b/app/agent/tools/impl/get_recommendations.py @@ -8,6 +8,7 @@ from pydantic import BaseModel, Field from app.agent.tools.base import MoviePilotTool from app.chain.recommend import RecommendChain from app.log import logger +from app.schemas.types import MediaType, media_type_to_agent class GetRecommendationsInput(BaseModel): @@ -75,8 +76,11 @@ class GetRecommendationsTool(MoviePilotTool): media_type: Optional[str] = "all", limit: Optional[int] = 20, **kwargs) -> str: logger.info(f"执行工具: {self.name}, 参数: source={source}, media_type={media_type}, limit={limit}") try: - if media_type not in ["all", "movie", "tv"]: - return f"错误:无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv', 'all'" + if media_type != "all": + media_type_enum = MediaType.from_agent(media_type) + if not media_type_enum: + return f"错误:无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv', 'all'" + media_type = media_type_enum.to_agent() # 归一化为 "movie"/"tv" recommend_chain = RecommendChain() results = [] @@ -152,7 +156,7 @@ class GetRecommendationsTool(MoviePilotTool): "title": r.get("title"), "en_title": r.get("en_title"), "year": r.get("year"), - "type": r.get("type"), + "type": media_type_to_agent(r.get("type")), "season": r.get("season"), "tmdb_id": r.get("tmdb_id"), "imdb_id": r.get("imdb_id"), diff --git a/app/agent/tools/impl/query_download_tasks.py b/app/agent/tools/impl/query_download_tasks.py index c872b9eb..4d7c85e5 100644 --- a/app/agent/tools/impl/query_download_tasks.py +++ b/app/agent/tools/impl/query_download_tasks.py @@ -10,7 +10,7 @@ from app.chain.download import DownloadChain from app.db.downloadhistory_oper import DownloadHistoryOper from app.log import logger from app.schemas import TransferTorrent, DownloadingTorrent -from app.schemas.types import TorrentStatus +from app.schemas.types import TorrentStatus, media_type_to_agent class QueryDownloadTasksInput(BaseModel): @@ -208,7 +208,7 @@ class QueryDownloadTasksTool(MoviePilotTool): if d.media: simplified["media"] = { "tmdbid": d.media.get("tmdbid"), - "type": d.media.get("type"), + "type": media_type_to_agent(d.media.get("type")), "title": d.media.get("title"), "season": d.media.get("season"), "episode": d.media.get("episode") diff --git a/app/agent/tools/impl/query_library_exists.py b/app/agent/tools/impl/query_library_exists.py index b7fdfdc5..d14ade56 100644 --- a/app/agent/tools/impl/query_library_exists.py +++ b/app/agent/tools/impl/query_library_exists.py @@ -8,7 +8,7 @@ from pydantic import BaseModel, Field from app.agent.tools.base import MoviePilotTool from app.chain.mediaserver import MediaServerChain from app.log import logger -from app.schemas.types import MediaType +from app.schemas.types import MediaType, media_type_to_agent class QueryLibraryExistsInput(BaseModel): @@ -49,12 +49,8 @@ class QueryLibraryExistsTool(MoviePilotTool): media_type_enum = None if media_type: - media_type_key = media_type.strip().lower() - if media_type_key == "movie": - media_type_enum = MediaType.MOVIE - elif media_type_key == "tv": - media_type_enum = MediaType.TV - else: + media_type_enum = MediaType.from_agent(media_type) + if not media_type_enum: return f"错误:无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv'" media_chain = MediaServerChain() @@ -109,7 +105,7 @@ class QueryLibraryExistsTool(MoviePilotTool): result_dict = { "title": mediainfo.title, "year": mediainfo.year, - "type": existsinfo.type.value if existsinfo.type else None, + "type": media_type_to_agent(existsinfo.type), "server": existsinfo.server, "server_type": existsinfo.server_type, "itemid": existsinfo.itemid, diff --git a/app/agent/tools/impl/query_media_detail.py b/app/agent/tools/impl/query_media_detail.py index 1d5f55a9..a318bba4 100644 --- a/app/agent/tools/impl/query_media_detail.py +++ b/app/agent/tools/impl/query_media_detail.py @@ -8,7 +8,7 @@ from pydantic import BaseModel, Field from app.agent.tools.base import MoviePilotTool from app.chain.media import MediaChain from app.log import logger -from app.schemas import MediaType +from app.schemas.types import MediaType class QueryMediaDetailInput(BaseModel): @@ -34,15 +34,13 @@ class QueryMediaDetailTool(MoviePilotTool): try: media_chain = MediaChain() - media_type_key = (media_type or "").strip().lower() - if media_type_key not in ["movie", "tv"]: + media_type_enum = MediaType.from_agent(media_type) + if not media_type_enum: return json.dumps({ "success": False, "message": f"无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv'" }, ensure_ascii=False) - media_type_enum = MediaType.MOVIE if media_type_key == "movie" else MediaType.TV - mediainfo = await media_chain.async_recognize_media(tmdbid=tmdb_id, mtype=media_type_enum) if not mediainfo: diff --git a/app/agent/tools/impl/query_popular_subscribes.py b/app/agent/tools/impl/query_popular_subscribes.py index e519c148..5243aabc 100644 --- a/app/agent/tools/impl/query_popular_subscribes.py +++ b/app/agent/tools/impl/query_popular_subscribes.py @@ -10,7 +10,7 @@ from app.agent.tools.base import MoviePilotTool from app.core.context import MediaInfo from app.helper.subscribe import SubscribeHelper from app.log import logger -from app.schemas.types import MediaType +from app.schemas.types import MediaType, media_type_to_agent class QueryPopularSubscribesInput(BaseModel): @@ -69,12 +69,13 @@ class QueryPopularSubscribesTool(MoviePilotTool): page = 1 if count is None or count < 1: count = 30 - if media_type not in ["movie", "tv"]: + media_type_enum = MediaType.from_agent(media_type) + if not media_type_enum: return f"错误:无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv'" subscribe_helper = SubscribeHelper() subscribes = await subscribe_helper.async_get_statistic( - stype=media_type, + stype=media_type_enum.to_agent(), page=page, count=count, genre_id=genre_id, @@ -134,7 +135,7 @@ class QueryPopularSubscribesTool(MoviePilotTool): for media in ret_medias: media_dict = media.to_dict() simplified = { - "type": media_dict.get("type"), + "type": media_type_to_agent(media_dict.get("type")), "title": media_dict.get("title"), "year": media_dict.get("year"), "tmdb_id": media_dict.get("tmdb_id"), diff --git a/app/agent/tools/impl/query_subscribe_history.py b/app/agent/tools/impl/query_subscribe_history.py index e3b7f660..f0cc51f1 100644 --- a/app/agent/tools/impl/query_subscribe_history.py +++ b/app/agent/tools/impl/query_subscribe_history.py @@ -9,6 +9,7 @@ from app.agent.tools.base import MoviePilotTool from app.db import AsyncSessionFactory from app.db.models.subscribehistory import SubscribeHistory from app.log import logger +from app.schemas.types import media_type_to_agent class QuerySubscribeHistoryInput(BaseModel): @@ -83,7 +84,7 @@ class QuerySubscribeHistoryTool(MoviePilotTool): "id": record.id, "name": record.name, "year": record.year, - "type": record.type, + "type": media_type_to_agent(record.type), "season": record.season, "tmdbid": record.tmdbid, "doubanid": record.doubanid, diff --git a/app/agent/tools/impl/query_subscribes.py b/app/agent/tools/impl/query_subscribes.py index 2153155c..4018dbaf 100644 --- a/app/agent/tools/impl/query_subscribes.py +++ b/app/agent/tools/impl/query_subscribes.py @@ -8,6 +8,7 @@ from pydantic import BaseModel, Field from app.agent.tools.base import MoviePilotTool from app.db.subscribe_oper import SubscribeOper from app.log import logger +from app.schemas.types import MediaType, media_type_to_agent class QuerySubscribesInput(BaseModel): @@ -47,17 +48,16 @@ class QuerySubscribesTool(MoviePilotTool): tmdb_id: Optional[int] = None, **kwargs) -> str: logger.info(f"执行工具: {self.name}, 参数: status={status}, media_type={media_type}, tmdb_id={tmdb_id}") try: - if media_type not in ["all", "movie", "tv"]: + if media_type != "all" and not MediaType.from_agent(media_type): return f"错误:无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv', 'all'" - media_type_map = {"movie": "电影", "tv": "电视剧"} subscribe_oper = SubscribeOper() subscribes = await subscribe_oper.async_list() filtered_subscribes = [] for sub in subscribes: if status != "all" and sub.state != status: continue - if media_type != "all" and sub.type != media_type_map.get(media_type): + if media_type != "all" and sub.type != MediaType.from_agent(media_type).value: continue if tmdb_id is not None and sub.tmdbid != tmdb_id: continue @@ -73,7 +73,7 @@ class QuerySubscribesTool(MoviePilotTool): "id": s.id, "name": s.name, "year": s.year, - "type": s.type, + "type": media_type_to_agent(s.type), "season": s.season, "tmdbid": s.tmdbid, "doubanid": s.doubanid, diff --git a/app/agent/tools/impl/query_transfer_history.py b/app/agent/tools/impl/query_transfer_history.py index 158031fb..b228d01b 100644 --- a/app/agent/tools/impl/query_transfer_history.py +++ b/app/agent/tools/impl/query_transfer_history.py @@ -10,6 +10,7 @@ from app.agent.tools.base import MoviePilotTool from app.db import AsyncSessionFactory from app.db.models.transferhistory import TransferHistory from app.log import logger +from app.schemas.types import media_type_to_agent class QueryTransferHistoryInput(BaseModel): @@ -95,7 +96,7 @@ class QueryTransferHistoryTool(MoviePilotTool): "id": record.id, "title": record.title, "year": record.year, - "type": record.type, + "type": media_type_to_agent(record.type), "category": record.category, "seasons": record.seasons, "episodes": record.episodes, diff --git a/app/agent/tools/impl/recognize_media.py b/app/agent/tools/impl/recognize_media.py index 82dea564..2f852770 100644 --- a/app/agent/tools/impl/recognize_media.py +++ b/app/agent/tools/impl/recognize_media.py @@ -10,6 +10,7 @@ from app.chain.media import MediaChain from app.core.context import Context from app.core.metainfo import MetaInfo from app.log import logger +from app.schemas.types import media_type_to_agent class RecognizeMediaInput(BaseModel): @@ -124,7 +125,7 @@ class RecognizeMediaTool(MoviePilotTool): "title": media_info.get("title"), "en_title": media_info.get("en_title"), "year": media_info.get("year"), - "type": media_info.get("type"), + "type": media_type_to_agent(media_info.get("type")), "season": media_info.get("season"), "tmdb_id": media_info.get("tmdb_id"), "imdb_id": media_info.get("imdb_id"), @@ -145,7 +146,7 @@ class RecognizeMediaTool(MoviePilotTool): "name": meta_info.get("name"), "title": meta_info.get("title"), "year": meta_info.get("year"), - "type": meta_info.get("type"), + "type": media_type_to_agent(meta_info.get("type")), "begin_season": meta_info.get("begin_season"), "end_season": meta_info.get("end_season"), "begin_episode": meta_info.get("begin_episode"), diff --git a/app/agent/tools/impl/search_media.py b/app/agent/tools/impl/search_media.py index acfdbe38..4b15c5bb 100644 --- a/app/agent/tools/impl/search_media.py +++ b/app/agent/tools/impl/search_media.py @@ -8,7 +8,7 @@ from pydantic import BaseModel, Field from app.agent.tools.base import MoviePilotTool from app.chain.media import MediaChain from app.log import logger -from app.schemas.types import MediaType +from app.schemas.types import MediaType, media_type_to_agent class SearchMediaInput(BaseModel): @@ -58,12 +58,8 @@ class SearchMediaTool(MoviePilotTool): if results: media_type_enum = None if media_type: - media_type_key = media_type.strip().lower() - if media_type_key == "movie": - media_type_enum = MediaType.MOVIE - elif media_type_key == "tv": - media_type_enum = MediaType.TV - else: + media_type_enum = MediaType.from_agent(media_type) + if not media_type_enum: return f"错误:无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv'" filtered_results = [] @@ -87,7 +83,7 @@ class SearchMediaTool(MoviePilotTool): "title": r.title, "en_title": r.en_title, "year": r.year, - "type": r.type.value if r.type else None, + "type": media_type_to_agent(r.type), "season": r.season, "tmdb_id": r.tmdb_id, "imdb_id": r.imdb_id, diff --git a/app/agent/tools/impl/search_subscribe.py b/app/agent/tools/impl/search_subscribe.py index 1dc0ef03..b53783c4 100644 --- a/app/agent/tools/impl/search_subscribe.py +++ b/app/agent/tools/impl/search_subscribe.py @@ -10,6 +10,7 @@ from app.chain.subscribe import SubscribeChain from app.core.config import global_vars from app.db.subscribe_oper import SubscribeOper from app.log import logger +from app.schemas.types import media_type_to_agent class SearchSubscribeInput(BaseModel): @@ -58,7 +59,7 @@ class SearchSubscribeTool(MoviePilotTool): "id": subscribe.id, "name": subscribe.name, "year": subscribe.year, - "type": subscribe.type, + "type": media_type_to_agent(subscribe.type), "season": subscribe.season, "state": subscribe.state, "total_episode": subscribe.total_episode, diff --git a/app/agent/tools/impl/search_torrents.py b/app/agent/tools/impl/search_torrents.py index 63ee91aa..fb71fca7 100644 --- a/app/agent/tools/impl/search_torrents.py +++ b/app/agent/tools/impl/search_torrents.py @@ -63,12 +63,8 @@ class SearchTorrentsTool(MoviePilotTool): search_chain = SearchChain() media_type_enum = None if media_type: - media_type_key = media_type.strip().lower() - if media_type_key == "movie": - media_type_enum = MediaType.MOVIE - elif media_type_key == "tv": - media_type_enum = MediaType.TV - else: + media_type_enum = MediaType.from_agent(media_type) + if not media_type_enum: return f"错误:无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv'" filtered_torrents = await search_chain.async_search_by_id( diff --git a/app/agent/tools/impl/transfer_file.py b/app/agent/tools/impl/transfer_file.py index 8ebebaf4..ff17911c 100644 --- a/app/agent/tools/impl/transfer_file.py +++ b/app/agent/tools/impl/transfer_file.py @@ -93,12 +93,8 @@ class TransferFileTool(MoviePilotTool): # 处理媒体类型 media_type_enum = None if media_type: - media_type_key = media_type.strip().lower() - if media_type_key == "movie": - media_type_enum = MediaType.MOVIE - elif media_type_key == "tv": - media_type_enum = MediaType.TV - else: + media_type_enum = MediaType.from_agent(media_type) + if not media_type_enum: return f"错误:无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv'" # 调用整理方法 diff --git a/app/schemas/types.py b/app/schemas/types.py index ad0a714f..e8e4ddc3 100644 --- a/app/schemas/types.py +++ b/app/schemas/types.py @@ -1,4 +1,5 @@ from enum import Enum +from typing import Optional # 媒体类型 @@ -8,6 +9,26 @@ class MediaType(Enum): COLLECTION = '系列' UNKNOWN = '未知' + @staticmethod + def from_agent(key: str) -> Optional["MediaType"]: + """'movie' -> MediaType.MOVIE, 'tv' -> MediaType.TV, 否则 None""" + _map = {"movie": MediaType.MOVIE, "tv": MediaType.TV} + return _map.get(key.strip().lower() if key else "") + + def to_agent(self) -> str: + """MediaType.MOVIE -> 'movie', MediaType.TV -> 'tv', 其他返回 .value""" + return {MediaType.MOVIE: "movie", MediaType.TV: "tv"}.get(self, self.value) + + +def media_type_to_agent(value) -> Optional[str]: + """将 MediaType 枚举或中文字符串统一转为 'movie'/'tv'""" + if isinstance(value, MediaType): + return value.to_agent() + if isinstance(value, str): + mt = MediaType.from_agent(value) + return mt.to_agent() if mt else value + return None + # 排序类型枚举 class SortType(Enum):