mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-04-24 02:32:01 +08:00
feat: 优化query_library_exists和query_subscribes工具输出,优化SKILL.md
fix(add_download): 更新torrent_url和description字段的描述,移除错误的添加直链的功能
This commit is contained in:
@@ -24,7 +24,7 @@ class AddDownloadInput(BaseModel):
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
torrent_url: List[str] = Field(
|
||||
...,
|
||||
description="One or more torrent_url values. Values matching the hash:id pattern from get_search_results are treated as internal references; other values must be direct torrent URLs or magnet links."
|
||||
description="One or more torrent_url values. Supports refs from get_search_results (`hash:id`) and magnet links."
|
||||
)
|
||||
downloader: Optional[str] = Field(None,
|
||||
description="Name of the downloader to use (optional, uses default if not specified)")
|
||||
@@ -36,7 +36,7 @@ class AddDownloadInput(BaseModel):
|
||||
|
||||
class AddDownloadTool(MoviePilotTool):
|
||||
name: str = "add_download"
|
||||
description: str = "Add torrent download task to the configured downloader (qBittorrent, Transmission, etc.) using hash:id references from get_search_results or direct torrent URLs / magnet links."
|
||||
description: str = "Add torrent download tasks using refs from get_search_results or magnet links."
|
||||
args_schema: Type[BaseModel] = AddDownloadInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
@@ -49,7 +49,7 @@ class AddDownloadTool(MoviePilotTool):
|
||||
if self._is_torrent_ref(torrent_urls[0]):
|
||||
message = f"正在添加下载任务: 资源 {torrent_urls[0]}"
|
||||
else:
|
||||
message = "正在添加下载任务: 直链或磁力链接"
|
||||
message = "正在添加下载任务: 磁力链接"
|
||||
else:
|
||||
message = f"正在批量添加下载任务: 共 {len(torrent_urls)} 个资源"
|
||||
else:
|
||||
@@ -74,12 +74,12 @@ class AddDownloadTool(MoviePilotTool):
|
||||
return bool(re.fullmatch(r"[0-9a-f]{7}:\d+", str(torrent_ref).strip()))
|
||||
|
||||
@staticmethod
|
||||
def _is_direct_download_url(torrent_url: Optional[str]) -> bool:
|
||||
"""判断是否为允许直传下载器的下载内容"""
|
||||
def _is_magnet_link_input(torrent_url: Optional[str]) -> bool:
|
||||
"""判断输入是否为允许直接添加的磁力链接"""
|
||||
if not torrent_url:
|
||||
return False
|
||||
value = str(torrent_url).strip()
|
||||
return value.startswith("http://") or value.startswith("https://") or value.startswith("magnet:")
|
||||
return value.startswith("magnet:")
|
||||
|
||||
@classmethod
|
||||
def _resolve_cached_context(cls, torrent_ref: str) -> Optional[Context]:
|
||||
@@ -127,7 +127,7 @@ class AddDownloadTool(MoviePilotTool):
|
||||
prefix = "添加种子任务失败:"
|
||||
if normalized_error.startswith(prefix):
|
||||
normalized_error = normalized_error[len(prefix):].lstrip()
|
||||
if AddDownloadTool._is_direct_download_url(normalized_error):
|
||||
if AddDownloadTool._is_magnet_link_input(normalized_error):
|
||||
normalized_error = ""
|
||||
if normalized_error:
|
||||
return f"{torrent_ref} {normalized_error}"
|
||||
@@ -227,9 +227,9 @@ class AddDownloadTool(MoviePilotTool):
|
||||
media_info=media_info
|
||||
)
|
||||
else:
|
||||
if not self._is_direct_download_url(torrent_input):
|
||||
if not self._is_magnet_link_input(torrent_input):
|
||||
failed_messages.append(
|
||||
f"{torrent_input} 不是有效的下载内容,非 hash:id 时仅支持 http://、https:// 或 magnet: 开头"
|
||||
f"{torrent_input} 不是有效的下载内容,非 hash:id 时仅支持 magnet: 开头"
|
||||
)
|
||||
continue
|
||||
download_dir = self._resolve_direct_download_dir(save_path)
|
||||
|
||||
@@ -1,16 +1,78 @@
|
||||
"""查询媒体库工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
from collections import OrderedDict
|
||||
from typing import Optional, Type, Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.mediaserver import MediaServerChain
|
||||
from app.helper.mediaserver import MediaServerHelper
|
||||
from app.log import logger
|
||||
from app.schemas.types import MediaType, media_type_to_agent
|
||||
|
||||
|
||||
def _sort_seasons(seasons: Optional[dict]) -> dict:
|
||||
"""按季号、集号升序整理季集信息,保证输出稳定。"""
|
||||
if not seasons:
|
||||
return {}
|
||||
|
||||
def _sort_key(value):
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError):
|
||||
return str(value)
|
||||
|
||||
return OrderedDict(
|
||||
(season, sorted(episodes, key=_sort_key))
|
||||
for season, episodes in sorted(seasons.items(), key=lambda item: _sort_key(item[0]))
|
||||
)
|
||||
|
||||
|
||||
def _filter_regular_seasons(seasons: Optional[dict]) -> OrderedDict:
|
||||
"""仅保留正片季,忽略 season 0 等特殊季。"""
|
||||
sorted_seasons = _sort_seasons(seasons)
|
||||
regular_seasons = OrderedDict()
|
||||
for season, episodes in sorted_seasons.items():
|
||||
try:
|
||||
season_number = int(season)
|
||||
except (TypeError, ValueError):
|
||||
continue
|
||||
if season_number > 0:
|
||||
regular_seasons[season_number] = episodes
|
||||
return regular_seasons
|
||||
|
||||
|
||||
def _build_tv_server_result(existing_seasons: OrderedDict, total_seasons: OrderedDict) -> dict[str, Any]:
|
||||
"""构建单个服务器的电视剧存在性结果。"""
|
||||
seasons_result = OrderedDict()
|
||||
missing_seasons = []
|
||||
all_seasons = sorted(set(total_seasons.keys()) | set(existing_seasons.keys()))
|
||||
|
||||
for season in all_seasons:
|
||||
existing_episodes = existing_seasons.get(season, [])
|
||||
total_episodes = total_seasons.get(season)
|
||||
if total_episodes is not None:
|
||||
missing_episodes = [episode for episode in total_episodes if episode not in existing_episodes]
|
||||
total_episode_count = len(total_episodes)
|
||||
else:
|
||||
missing_episodes = None
|
||||
total_episode_count = None
|
||||
seasons_result[str(season)] = {
|
||||
"existing_episodes": existing_episodes,
|
||||
"total_episodes": total_episode_count,
|
||||
"missing_episodes": missing_episodes
|
||||
}
|
||||
if total_episodes is not None and not existing_episodes:
|
||||
missing_seasons.append(season)
|
||||
|
||||
return {
|
||||
"seasons": seasons_result,
|
||||
"missing_seasons": missing_seasons
|
||||
}
|
||||
|
||||
|
||||
class QueryLibraryExistsInput(BaseModel):
|
||||
"""查询媒体库工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
@@ -21,7 +83,7 @@ class QueryLibraryExistsInput(BaseModel):
|
||||
|
||||
class QueryLibraryExistsTool(MoviePilotTool):
|
||||
name: str = "query_library_exists"
|
||||
description: str = "Check whether a specific media resource already exists in the media library (Plex, Emby, Jellyfin) by media ID. Requires tmdb_id or douban_id (can be obtained from search_media tool) for accurate matching."
|
||||
description: str = "Check whether media already exists in Plex, Emby, or Jellyfin by media ID. Results are grouped by media server; TV results include existing episodes, total episodes, and missing episodes/seasons. Requires tmdb_id or douban_id from search_media."
|
||||
args_schema: Type[BaseModel] = QueryLibraryExistsInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
@@ -63,62 +125,53 @@ class QueryLibraryExistsTool(MoviePilotTool):
|
||||
media_id = f"TMDB={tmdb_id}" if tmdb_id else f"豆瓣={douban_id}"
|
||||
return f"未识别到媒体信息: {media_id}"
|
||||
|
||||
# 2. 调用媒体服务器接口实时查询存在信息
|
||||
existsinfo = media_chain.media_exists(mediainfo=mediainfo)
|
||||
# 2. 遍历所有媒体服务器,分别查询存在性信息
|
||||
server_results = OrderedDict()
|
||||
media_server_helper = MediaServerHelper()
|
||||
total_seasons = _filter_regular_seasons(mediainfo.seasons)
|
||||
global_existsinfo = media_chain.media_exists(mediainfo=mediainfo)
|
||||
|
||||
if not existsinfo:
|
||||
for service_name in sorted(media_server_helper.get_services().keys()):
|
||||
existsinfo = media_chain.media_exists(mediainfo=mediainfo, server=service_name)
|
||||
if not existsinfo:
|
||||
continue
|
||||
|
||||
if existsinfo.type == MediaType.TV:
|
||||
existing_seasons = _filter_regular_seasons(existsinfo.seasons)
|
||||
server_results[service_name] = _build_tv_server_result(
|
||||
existing_seasons=existing_seasons,
|
||||
total_seasons=total_seasons
|
||||
)
|
||||
else:
|
||||
server_results[service_name] = {
|
||||
"exists": True
|
||||
}
|
||||
|
||||
if global_existsinfo:
|
||||
fallback_server_name = global_existsinfo.server or "local"
|
||||
if fallback_server_name not in server_results:
|
||||
if global_existsinfo.type == MediaType.TV:
|
||||
server_results[fallback_server_name] = _build_tv_server_result(
|
||||
existing_seasons=_filter_regular_seasons(global_existsinfo.seasons),
|
||||
total_seasons=total_seasons
|
||||
)
|
||||
else:
|
||||
server_results[fallback_server_name] = {
|
||||
"exists": True
|
||||
}
|
||||
|
||||
if not server_results:
|
||||
return "媒体库中未找到相关媒体"
|
||||
|
||||
# 3. 如果找到了,获取详细信息并组装结果
|
||||
result_items = []
|
||||
if existsinfo.itemid and existsinfo.server:
|
||||
iteminfo = media_chain.iteminfo(server=existsinfo.server, item_id=existsinfo.itemid)
|
||||
if iteminfo:
|
||||
# 使用 model_dump() 转换为字典格式
|
||||
item_dict = iteminfo.model_dump(exclude_none=True)
|
||||
|
||||
# 对于电视剧,补充已存在的季集详情及进度统计
|
||||
if existsinfo.type == MediaType.TV:
|
||||
# 注入已存在集信息 (Dict[int, list])
|
||||
item_dict["seasoninfo"] = existsinfo.seasons
|
||||
|
||||
# 统计库中已存在的季集总数
|
||||
if existsinfo.seasons:
|
||||
item_dict["existing_episodes_count"] = sum(len(e) for e in existsinfo.seasons.values())
|
||||
item_dict["seasons_existing_count"] = {str(s): len(e) for s, e in existsinfo.seasons.items()}
|
||||
|
||||
# 如果识别到了元数据,补充总计对比和进度概览
|
||||
if mediainfo.seasons:
|
||||
item_dict["seasons_total_count"] = {str(s): len(e) for s, e in mediainfo.seasons.items()}
|
||||
# 进度概览,例如 "Season 1": "3/12"
|
||||
item_dict["seasons_progress"] = {
|
||||
f"第{s}季": f"{len(existsinfo.seasons.get(s, []))}/{len(mediainfo.seasons.get(s, []))} 集"
|
||||
for s in mediainfo.seasons.keys() if (s in existsinfo.seasons or s > 0)
|
||||
}
|
||||
|
||||
result_items.append(item_dict)
|
||||
|
||||
if result_items:
|
||||
return json.dumps(result_items, ensure_ascii=False)
|
||||
|
||||
# 如果找到了但没有获取到 iteminfo,返回基本信息
|
||||
# 3. 组装统一的存在性结果,不查询媒体服务器详情
|
||||
result_dict = {
|
||||
"title": mediainfo.title,
|
||||
"year": mediainfo.year,
|
||||
"type": media_type_to_agent(existsinfo.type),
|
||||
"server": existsinfo.server,
|
||||
"server_type": existsinfo.server_type,
|
||||
"itemid": existsinfo.itemid,
|
||||
"seasons": existsinfo.seasons if existsinfo.seasons else {}
|
||||
"type": media_type_to_agent(mediainfo.type),
|
||||
"servers": server_results
|
||||
}
|
||||
if existsinfo.type == MediaType.TV and existsinfo.seasons:
|
||||
result_dict["existing_episodes_count"] = sum(len(e) for e in existsinfo.seasons.values())
|
||||
result_dict["seasons_existing_count"] = {str(s): len(e) for s, e in existsinfo.seasons.items()}
|
||||
if mediainfo.seasons:
|
||||
result_dict["seasons_total_count"] = {str(s): len(e) for s, e in mediainfo.seasons.items()}
|
||||
|
||||
return json.dumps([result_dict], ensure_ascii=False)
|
||||
except Exception as e:
|
||||
logger.error(f"查询媒体库失败: {e}", exc_info=True)
|
||||
return f"查询媒体库时发生错误: {str(e)}"
|
||||
|
||||
|
||||
@@ -8,7 +8,35 @@ from pydantic import BaseModel, Field
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.db.subscribe_oper import SubscribeOper
|
||||
from app.log import logger
|
||||
from app.schemas.types import MediaType, media_type_to_agent
|
||||
from app.schemas.subscribe import Subscribe as SubscribeSchema
|
||||
from app.schemas.types import MediaType
|
||||
|
||||
QUERY_SUBSCRIBE_OUTPUT_FIELDS = [
|
||||
"id",
|
||||
"name",
|
||||
"year",
|
||||
"type",
|
||||
"season",
|
||||
"total_episode",
|
||||
"start_episode",
|
||||
"lack_episode",
|
||||
"filter",
|
||||
"include",
|
||||
"exclude",
|
||||
"quality",
|
||||
"resolution",
|
||||
"effect",
|
||||
"state",
|
||||
"last_update",
|
||||
"sites",
|
||||
"downloader",
|
||||
"best_version",
|
||||
"save_path",
|
||||
"custom_words",
|
||||
"media_category",
|
||||
"filter_groups",
|
||||
"episode_group"
|
||||
]
|
||||
|
||||
|
||||
class QuerySubscribesInput(BaseModel):
|
||||
@@ -24,7 +52,7 @@ class QuerySubscribesInput(BaseModel):
|
||||
|
||||
class QuerySubscribesTool(MoviePilotTool):
|
||||
name: str = "query_subscribes"
|
||||
description: str = "Query subscription status and list all user subscriptions. Shows active subscriptions, their download status, and configuration details."
|
||||
description: str = "Query subscription status and list user subscriptions. Returns full subscription parameters for each matched subscription."
|
||||
args_schema: Type[BaseModel] = QuerySubscribesInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
@@ -69,28 +97,14 @@ class QuerySubscribesTool(MoviePilotTool):
|
||||
# 限制最多50条结果
|
||||
total_count = len(filtered_subscribes)
|
||||
limited_subscribes = filtered_subscribes[:50]
|
||||
# 精简字段,只保留关键信息
|
||||
simplified_subscribes = []
|
||||
for s in limited_subscribes:
|
||||
simplified = {
|
||||
"id": s.id,
|
||||
"name": s.name,
|
||||
"year": s.year,
|
||||
"type": media_type_to_agent(s.type),
|
||||
"season": s.season,
|
||||
"tmdbid": s.tmdbid,
|
||||
"doubanid": s.doubanid,
|
||||
"bangumiid": s.bangumiid,
|
||||
"poster": s.poster,
|
||||
"vote": s.vote,
|
||||
"state": s.state,
|
||||
"total_episode": s.total_episode,
|
||||
"lack_episode": s.lack_episode,
|
||||
"last_update": s.last_update,
|
||||
"username": s.username
|
||||
}
|
||||
simplified_subscribes.append(simplified)
|
||||
result_json = json.dumps(simplified_subscribes, ensure_ascii=False, indent=2)
|
||||
full_subscribes = [
|
||||
SubscribeSchema.model_validate(s, from_attributes=True).model_dump(
|
||||
include=set(QUERY_SUBSCRIBE_OUTPUT_FIELDS),
|
||||
exclude_none=True
|
||||
)
|
||||
for s in limited_subscribes
|
||||
]
|
||||
result_json = json.dumps(full_subscribes, ensure_ascii=False, indent=2)
|
||||
# 如果结果被裁剪,添加提示信息
|
||||
if total_count > 50:
|
||||
return f"注意:查询结果共找到 {total_count} 条,为节省上下文空间,仅显示前 50 条结果。\n\n{result_json}"
|
||||
|
||||
Reference in New Issue
Block a user