feat: 工具输入输出统一为movie或tv

This commit is contained in:
PKC278
2026-03-18 17:07:03 +08:00
parent 4f3eaa12d5
commit b50a3b9aae
15 changed files with 65 additions and 57 deletions

View File

@@ -73,12 +73,8 @@ class AddSubscribeTool(MoviePilotTool):
try: try:
subscribe_chain = SubscribeChain() subscribe_chain = SubscribeChain()
media_type_key = media_type.strip().lower() media_type_enum = MediaType.from_agent(media_type)
if media_type_key == "movie": if not media_type_enum:
media_type_enum = MediaType.MOVIE
elif media_type_key == "tv":
media_type_enum = MediaType.TV
else:
return f"错误:无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv'" return f"错误:无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv'"
# 构建额外的订阅参数 # 构建额外的订阅参数

View File

@@ -8,6 +8,7 @@ from pydantic import BaseModel, Field
from app.agent.tools.base import MoviePilotTool from app.agent.tools.base import MoviePilotTool
from app.chain.recommend import RecommendChain from app.chain.recommend import RecommendChain
from app.log import logger from app.log import logger
from app.schemas.types import MediaType, media_type_to_agent
class GetRecommendationsInput(BaseModel): class GetRecommendationsInput(BaseModel):
@@ -75,8 +76,11 @@ class GetRecommendationsTool(MoviePilotTool):
media_type: Optional[str] = "all", limit: Optional[int] = 20, **kwargs) -> str: media_type: Optional[str] = "all", limit: Optional[int] = 20, **kwargs) -> str:
logger.info(f"执行工具: {self.name}, 参数: source={source}, media_type={media_type}, limit={limit}") logger.info(f"执行工具: {self.name}, 参数: source={source}, media_type={media_type}, limit={limit}")
try: try:
if media_type not in ["all", "movie", "tv"]: if media_type != "all":
media_type_enum = MediaType.from_agent(media_type)
if not media_type_enum:
return f"错误:无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv', 'all'" return f"错误:无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv', 'all'"
media_type = media_type_enum.to_agent() # 归一化为 "movie"/"tv"
recommend_chain = RecommendChain() recommend_chain = RecommendChain()
results = [] results = []
@@ -152,7 +156,7 @@ class GetRecommendationsTool(MoviePilotTool):
"title": r.get("title"), "title": r.get("title"),
"en_title": r.get("en_title"), "en_title": r.get("en_title"),
"year": r.get("year"), "year": r.get("year"),
"type": r.get("type"), "type": media_type_to_agent(r.get("type")),
"season": r.get("season"), "season": r.get("season"),
"tmdb_id": r.get("tmdb_id"), "tmdb_id": r.get("tmdb_id"),
"imdb_id": r.get("imdb_id"), "imdb_id": r.get("imdb_id"),

View File

@@ -10,7 +10,7 @@ from app.chain.download import DownloadChain
from app.db.downloadhistory_oper import DownloadHistoryOper from app.db.downloadhistory_oper import DownloadHistoryOper
from app.log import logger from app.log import logger
from app.schemas import TransferTorrent, DownloadingTorrent from app.schemas import TransferTorrent, DownloadingTorrent
from app.schemas.types import TorrentStatus from app.schemas.types import TorrentStatus, media_type_to_agent
class QueryDownloadTasksInput(BaseModel): class QueryDownloadTasksInput(BaseModel):
@@ -208,7 +208,7 @@ class QueryDownloadTasksTool(MoviePilotTool):
if d.media: if d.media:
simplified["media"] = { simplified["media"] = {
"tmdbid": d.media.get("tmdbid"), "tmdbid": d.media.get("tmdbid"),
"type": d.media.get("type"), "type": media_type_to_agent(d.media.get("type")),
"title": d.media.get("title"), "title": d.media.get("title"),
"season": d.media.get("season"), "season": d.media.get("season"),
"episode": d.media.get("episode") "episode": d.media.get("episode")

View File

@@ -8,7 +8,7 @@ from pydantic import BaseModel, Field
from app.agent.tools.base import MoviePilotTool from app.agent.tools.base import MoviePilotTool
from app.chain.mediaserver import MediaServerChain from app.chain.mediaserver import MediaServerChain
from app.log import logger from app.log import logger
from app.schemas.types import MediaType from app.schemas.types import MediaType, media_type_to_agent
class QueryLibraryExistsInput(BaseModel): class QueryLibraryExistsInput(BaseModel):
@@ -49,12 +49,8 @@ class QueryLibraryExistsTool(MoviePilotTool):
media_type_enum = None media_type_enum = None
if media_type: if media_type:
media_type_key = media_type.strip().lower() media_type_enum = MediaType.from_agent(media_type)
if media_type_key == "movie": if not media_type_enum:
media_type_enum = MediaType.MOVIE
elif media_type_key == "tv":
media_type_enum = MediaType.TV
else:
return f"错误:无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv'" return f"错误:无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv'"
media_chain = MediaServerChain() media_chain = MediaServerChain()
@@ -109,7 +105,7 @@ class QueryLibraryExistsTool(MoviePilotTool):
result_dict = { result_dict = {
"title": mediainfo.title, "title": mediainfo.title,
"year": mediainfo.year, "year": mediainfo.year,
"type": existsinfo.type.value if existsinfo.type else None, "type": media_type_to_agent(existsinfo.type),
"server": existsinfo.server, "server": existsinfo.server,
"server_type": existsinfo.server_type, "server_type": existsinfo.server_type,
"itemid": existsinfo.itemid, "itemid": existsinfo.itemid,

View File

@@ -8,7 +8,7 @@ from pydantic import BaseModel, Field
from app.agent.tools.base import MoviePilotTool from app.agent.tools.base import MoviePilotTool
from app.chain.media import MediaChain from app.chain.media import MediaChain
from app.log import logger from app.log import logger
from app.schemas import MediaType from app.schemas.types import MediaType
class QueryMediaDetailInput(BaseModel): class QueryMediaDetailInput(BaseModel):
@@ -34,15 +34,13 @@ class QueryMediaDetailTool(MoviePilotTool):
try: try:
media_chain = MediaChain() media_chain = MediaChain()
media_type_key = (media_type or "").strip().lower() media_type_enum = MediaType.from_agent(media_type)
if media_type_key not in ["movie", "tv"]: if not media_type_enum:
return json.dumps({ return json.dumps({
"success": False, "success": False,
"message": f"无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv'" "message": f"无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv'"
}, ensure_ascii=False) }, ensure_ascii=False)
media_type_enum = MediaType.MOVIE if media_type_key == "movie" else MediaType.TV
mediainfo = await media_chain.async_recognize_media(tmdbid=tmdb_id, mtype=media_type_enum) mediainfo = await media_chain.async_recognize_media(tmdbid=tmdb_id, mtype=media_type_enum)
if not mediainfo: if not mediainfo:

View File

@@ -10,7 +10,7 @@ from app.agent.tools.base import MoviePilotTool
from app.core.context import MediaInfo from app.core.context import MediaInfo
from app.helper.subscribe import SubscribeHelper from app.helper.subscribe import SubscribeHelper
from app.log import logger from app.log import logger
from app.schemas.types import MediaType from app.schemas.types import MediaType, media_type_to_agent
class QueryPopularSubscribesInput(BaseModel): class QueryPopularSubscribesInput(BaseModel):
@@ -69,12 +69,13 @@ class QueryPopularSubscribesTool(MoviePilotTool):
page = 1 page = 1
if count is None or count < 1: if count is None or count < 1:
count = 30 count = 30
if media_type not in ["movie", "tv"]: media_type_enum = MediaType.from_agent(media_type)
if not media_type_enum:
return f"错误:无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv'" return f"错误:无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv'"
subscribe_helper = SubscribeHelper() subscribe_helper = SubscribeHelper()
subscribes = await subscribe_helper.async_get_statistic( subscribes = await subscribe_helper.async_get_statistic(
stype=media_type, stype=media_type_enum.to_agent(),
page=page, page=page,
count=count, count=count,
genre_id=genre_id, genre_id=genre_id,
@@ -134,7 +135,7 @@ class QueryPopularSubscribesTool(MoviePilotTool):
for media in ret_medias: for media in ret_medias:
media_dict = media.to_dict() media_dict = media.to_dict()
simplified = { simplified = {
"type": media_dict.get("type"), "type": media_type_to_agent(media_dict.get("type")),
"title": media_dict.get("title"), "title": media_dict.get("title"),
"year": media_dict.get("year"), "year": media_dict.get("year"),
"tmdb_id": media_dict.get("tmdb_id"), "tmdb_id": media_dict.get("tmdb_id"),

View File

@@ -9,6 +9,7 @@ from app.agent.tools.base import MoviePilotTool
from app.db import AsyncSessionFactory from app.db import AsyncSessionFactory
from app.db.models.subscribehistory import SubscribeHistory from app.db.models.subscribehistory import SubscribeHistory
from app.log import logger from app.log import logger
from app.schemas.types import media_type_to_agent
class QuerySubscribeHistoryInput(BaseModel): class QuerySubscribeHistoryInput(BaseModel):
@@ -83,7 +84,7 @@ class QuerySubscribeHistoryTool(MoviePilotTool):
"id": record.id, "id": record.id,
"name": record.name, "name": record.name,
"year": record.year, "year": record.year,
"type": record.type, "type": media_type_to_agent(record.type),
"season": record.season, "season": record.season,
"tmdbid": record.tmdbid, "tmdbid": record.tmdbid,
"doubanid": record.doubanid, "doubanid": record.doubanid,

View File

@@ -8,6 +8,7 @@ from pydantic import BaseModel, Field
from app.agent.tools.base import MoviePilotTool from app.agent.tools.base import MoviePilotTool
from app.db.subscribe_oper import SubscribeOper from app.db.subscribe_oper import SubscribeOper
from app.log import logger from app.log import logger
from app.schemas.types import MediaType, media_type_to_agent
class QuerySubscribesInput(BaseModel): class QuerySubscribesInput(BaseModel):
@@ -47,17 +48,16 @@ class QuerySubscribesTool(MoviePilotTool):
tmdb_id: Optional[int] = None, **kwargs) -> str: tmdb_id: Optional[int] = None, **kwargs) -> str:
logger.info(f"执行工具: {self.name}, 参数: status={status}, media_type={media_type}, tmdb_id={tmdb_id}") logger.info(f"执行工具: {self.name}, 参数: status={status}, media_type={media_type}, tmdb_id={tmdb_id}")
try: try:
if media_type not in ["all", "movie", "tv"]: if media_type != "all" and not MediaType.from_agent(media_type):
return f"错误:无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv', 'all'" return f"错误:无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv', 'all'"
media_type_map = {"movie": "电影", "tv": "电视剧"}
subscribe_oper = SubscribeOper() subscribe_oper = SubscribeOper()
subscribes = await subscribe_oper.async_list() subscribes = await subscribe_oper.async_list()
filtered_subscribes = [] filtered_subscribes = []
for sub in subscribes: for sub in subscribes:
if status != "all" and sub.state != status: if status != "all" and sub.state != status:
continue continue
if media_type != "all" and sub.type != media_type_map.get(media_type): if media_type != "all" and sub.type != MediaType.from_agent(media_type).value:
continue continue
if tmdb_id is not None and sub.tmdbid != tmdb_id: if tmdb_id is not None and sub.tmdbid != tmdb_id:
continue continue
@@ -73,7 +73,7 @@ class QuerySubscribesTool(MoviePilotTool):
"id": s.id, "id": s.id,
"name": s.name, "name": s.name,
"year": s.year, "year": s.year,
"type": s.type, "type": media_type_to_agent(s.type),
"season": s.season, "season": s.season,
"tmdbid": s.tmdbid, "tmdbid": s.tmdbid,
"doubanid": s.doubanid, "doubanid": s.doubanid,

View File

@@ -10,6 +10,7 @@ from app.agent.tools.base import MoviePilotTool
from app.db import AsyncSessionFactory from app.db import AsyncSessionFactory
from app.db.models.transferhistory import TransferHistory from app.db.models.transferhistory import TransferHistory
from app.log import logger from app.log import logger
from app.schemas.types import media_type_to_agent
class QueryTransferHistoryInput(BaseModel): class QueryTransferHistoryInput(BaseModel):
@@ -95,7 +96,7 @@ class QueryTransferHistoryTool(MoviePilotTool):
"id": record.id, "id": record.id,
"title": record.title, "title": record.title,
"year": record.year, "year": record.year,
"type": record.type, "type": media_type_to_agent(record.type),
"category": record.category, "category": record.category,
"seasons": record.seasons, "seasons": record.seasons,
"episodes": record.episodes, "episodes": record.episodes,

View File

@@ -10,6 +10,7 @@ from app.chain.media import MediaChain
from app.core.context import Context from app.core.context import Context
from app.core.metainfo import MetaInfo from app.core.metainfo import MetaInfo
from app.log import logger from app.log import logger
from app.schemas.types import media_type_to_agent
class RecognizeMediaInput(BaseModel): class RecognizeMediaInput(BaseModel):
@@ -124,7 +125,7 @@ class RecognizeMediaTool(MoviePilotTool):
"title": media_info.get("title"), "title": media_info.get("title"),
"en_title": media_info.get("en_title"), "en_title": media_info.get("en_title"),
"year": media_info.get("year"), "year": media_info.get("year"),
"type": media_info.get("type"), "type": media_type_to_agent(media_info.get("type")),
"season": media_info.get("season"), "season": media_info.get("season"),
"tmdb_id": media_info.get("tmdb_id"), "tmdb_id": media_info.get("tmdb_id"),
"imdb_id": media_info.get("imdb_id"), "imdb_id": media_info.get("imdb_id"),
@@ -145,7 +146,7 @@ class RecognizeMediaTool(MoviePilotTool):
"name": meta_info.get("name"), "name": meta_info.get("name"),
"title": meta_info.get("title"), "title": meta_info.get("title"),
"year": meta_info.get("year"), "year": meta_info.get("year"),
"type": meta_info.get("type"), "type": media_type_to_agent(meta_info.get("type")),
"begin_season": meta_info.get("begin_season"), "begin_season": meta_info.get("begin_season"),
"end_season": meta_info.get("end_season"), "end_season": meta_info.get("end_season"),
"begin_episode": meta_info.get("begin_episode"), "begin_episode": meta_info.get("begin_episode"),

View File

@@ -8,7 +8,7 @@ from pydantic import BaseModel, Field
from app.agent.tools.base import MoviePilotTool from app.agent.tools.base import MoviePilotTool
from app.chain.media import MediaChain from app.chain.media import MediaChain
from app.log import logger from app.log import logger
from app.schemas.types import MediaType from app.schemas.types import MediaType, media_type_to_agent
class SearchMediaInput(BaseModel): class SearchMediaInput(BaseModel):
@@ -58,12 +58,8 @@ class SearchMediaTool(MoviePilotTool):
if results: if results:
media_type_enum = None media_type_enum = None
if media_type: if media_type:
media_type_key = media_type.strip().lower() media_type_enum = MediaType.from_agent(media_type)
if media_type_key == "movie": if not media_type_enum:
media_type_enum = MediaType.MOVIE
elif media_type_key == "tv":
media_type_enum = MediaType.TV
else:
return f"错误:无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv'" return f"错误:无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv'"
filtered_results = [] filtered_results = []
@@ -87,7 +83,7 @@ class SearchMediaTool(MoviePilotTool):
"title": r.title, "title": r.title,
"en_title": r.en_title, "en_title": r.en_title,
"year": r.year, "year": r.year,
"type": r.type.value if r.type else None, "type": media_type_to_agent(r.type),
"season": r.season, "season": r.season,
"tmdb_id": r.tmdb_id, "tmdb_id": r.tmdb_id,
"imdb_id": r.imdb_id, "imdb_id": r.imdb_id,

View File

@@ -10,6 +10,7 @@ from app.chain.subscribe import SubscribeChain
from app.core.config import global_vars from app.core.config import global_vars
from app.db.subscribe_oper import SubscribeOper from app.db.subscribe_oper import SubscribeOper
from app.log import logger from app.log import logger
from app.schemas.types import media_type_to_agent
class SearchSubscribeInput(BaseModel): class SearchSubscribeInput(BaseModel):
@@ -58,7 +59,7 @@ class SearchSubscribeTool(MoviePilotTool):
"id": subscribe.id, "id": subscribe.id,
"name": subscribe.name, "name": subscribe.name,
"year": subscribe.year, "year": subscribe.year,
"type": subscribe.type, "type": media_type_to_agent(subscribe.type),
"season": subscribe.season, "season": subscribe.season,
"state": subscribe.state, "state": subscribe.state,
"total_episode": subscribe.total_episode, "total_episode": subscribe.total_episode,

View File

@@ -63,12 +63,8 @@ class SearchTorrentsTool(MoviePilotTool):
search_chain = SearchChain() search_chain = SearchChain()
media_type_enum = None media_type_enum = None
if media_type: if media_type:
media_type_key = media_type.strip().lower() media_type_enum = MediaType.from_agent(media_type)
if media_type_key == "movie": if not media_type_enum:
media_type_enum = MediaType.MOVIE
elif media_type_key == "tv":
media_type_enum = MediaType.TV
else:
return f"错误:无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv'" return f"错误:无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv'"
filtered_torrents = await search_chain.async_search_by_id( filtered_torrents = await search_chain.async_search_by_id(

View File

@@ -93,12 +93,8 @@ class TransferFileTool(MoviePilotTool):
# 处理媒体类型 # 处理媒体类型
media_type_enum = None media_type_enum = None
if media_type: if media_type:
media_type_key = media_type.strip().lower() media_type_enum = MediaType.from_agent(media_type)
if media_type_key == "movie": if not media_type_enum:
media_type_enum = MediaType.MOVIE
elif media_type_key == "tv":
media_type_enum = MediaType.TV
else:
return f"错误:无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv'" return f"错误:无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv'"
# 调用整理方法 # 调用整理方法

View File

@@ -1,4 +1,5 @@
from enum import Enum from enum import Enum
from typing import Optional
# 媒体类型 # 媒体类型
@@ -8,6 +9,26 @@ class MediaType(Enum):
COLLECTION = '系列' COLLECTION = '系列'
UNKNOWN = '未知' UNKNOWN = '未知'
@staticmethod
def from_agent(key: str) -> Optional["MediaType"]:
"""'movie' -> MediaType.MOVIE, 'tv' -> MediaType.TV, 否则 None"""
_map = {"movie": MediaType.MOVIE, "tv": MediaType.TV}
return _map.get(key.strip().lower() if key else "")
def to_agent(self) -> str:
"""MediaType.MOVIE -> 'movie', MediaType.TV -> 'tv', 其他返回 .value"""
return {MediaType.MOVIE: "movie", MediaType.TV: "tv"}.get(self, self.value)
def media_type_to_agent(value) -> Optional[str]:
"""将 MediaType 枚举或中文字符串统一转为 'movie'/'tv'"""
if isinstance(value, MediaType):
return value.to_agent()
if isinstance(value, str):
mt = MediaType.from_agent(value)
return mt.to_agent() if mt else value
return None
# 排序类型枚举 # 排序类型枚举
class SortType(Enum): class SortType(Enum):