Files
MoviePilot/app/agent/tools/impl/add_download.py
PKC278 bc2302baeb feat: 优化query_library_exists和query_subscribes工具输出,优化SKILL.md
fix(add_download): 更新torrent_url和description字段的描述,移除错误的添加直链的功能
2026-03-22 01:47:38 +08:00

278 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""添加下载工具"""
import re
from pathlib import Path
from typing import List, Optional, Type
from pydantic import BaseModel, Field
from app.agent.tools.base import MoviePilotTool, ToolChain
from app.chain.search import SearchChain
from app.chain.download import DownloadChain
from app.core.config import settings
from app.core.context import Context
from app.core.metainfo import MetaInfo
from app.db.site_oper import SiteOper
from app.helper.directory import DirectoryHelper
from app.log import logger
from app.schemas import TorrentInfo, FileURI
from app.utils.crypto import HashUtils
class AddDownloadInput(BaseModel):
    """Input parameters accepted by the add_download tool."""

    # Free-text justification supplied by the agent for invoking this tool.
    explanation: str = Field(
        ...,
        description="Clear explanation of why this tool is being used in the current context",
    )
    # Required list of inputs: `hash:id` refs into the last search results, or magnet links.
    torrent_url: List[str] = Field(
        ...,
        description="One or more torrent_url values. Supports refs from get_search_results (`hash:id`) and magnet links.",
    )
    # Optional downloader name; None lets the download chain pick its default.
    downloader: Optional[str] = Field(
        default=None,
        description="Name of the downloader to use (optional, uses default if not specified)",
    )
    # Optional target directory; `<storage>:<path>` selects a remote storage backend.
    save_path: Optional[str] = Field(
        default=None,
        description="Directory path where the downloaded files should be saved. Using `<storage>:<path>` for remote storage. e.g. rclone:/MP, smb:/server/share/Movies. (optional, uses default path if not specified)",
    )
    # Optional comma-separated tag list forwarded to the downloader.
    labels: Optional[str] = Field(
        default=None,
        description="Comma-separated list of labels/tags to assign to the download (optional, e.g., 'movie,hd,bluray')",
    )
class AddDownloadTool(MoviePilotTool):
    """Agent tool that adds torrent download tasks.

    Each input is either a short ``hash:id`` reference into the most recent
    search results (as produced by get_search_results) or a ``magnet:`` link.
    Inputs are processed independently and the tool returns an aggregated
    success / partial-failure / failure summary string.
    """

    name: str = "add_download"
    description: str = "Add torrent download tasks using refs from get_search_results or magnet links."
    args_schema: Type[BaseModel] = AddDownloadInput

    def get_tool_message(self, **kwargs) -> Optional[str]:
        """Build a user-facing progress message from the raw tool arguments."""
        torrent_urls = self._normalize_torrent_urls(kwargs.get("torrent_url"))
        downloader = kwargs.get("downloader")
        if torrent_urls:
            if len(torrent_urls) == 1:
                if self._is_torrent_ref(torrent_urls[0]):
                    message = f"正在添加下载任务: 资源 {torrent_urls[0]}"
                else:
                    message = "正在添加下载任务: 磁力链接"
            else:
                message = f"正在批量添加下载任务: 共 {len(torrent_urls)} 个资源"
        else:
            message = "正在添加下载任务"
        if downloader:
            message += f" [下载器: {downloader}]"
        return message

    @staticmethod
    def _build_torrent_ref(context: Context) -> str:
        """Return the short hash used to validate a cached search-result entry.

        The hash is the first 7 characters of ``HashUtils.sha1`` over the
        torrent's enclosure. It is empty when the context has no torrent info.
        NOTE(review): _is_torrent_ref expects ``[0-9a-f]{7}``, which assumes
        HashUtils.sha1 returns a lowercase hex digest — confirm.
        """
        if not context or not context.torrent_info:
            return ""
        return HashUtils.sha1(context.torrent_info.enclosure or "")[:7]

    @staticmethod
    def _is_torrent_ref(torrent_ref: Optional[str]) -> bool:
        """Return True if the value looks like an internal ``hash:id`` search-result ref."""
        if not torrent_ref:
            return False
        return bool(re.fullmatch(r"[0-9a-f]{7}:\d+", str(torrent_ref).strip()))

    @staticmethod
    def _is_magnet_link_input(torrent_url: Optional[str]) -> bool:
        """Return True if the value is a ``magnet:`` link that may be added directly."""
        if not torrent_url:
            return False
        value = str(torrent_url).strip()
        return value.startswith("magnet:")

    @classmethod
    def _resolve_cached_context(cls, torrent_ref: str,
                                results: Optional[List[Context]] = None) -> Optional[Context]:
        """Resolve a ``hash:id`` ref to a Context from the latest search results.

        :param torrent_ref: ref in ``hash:id`` form, where ``id`` is 1-based.
        :param results: already-loaded search results. When None, they are
            loaded via ``SearchChain().last_search_results()``. Callers that
            resolve many refs pass a single snapshot to avoid reloading.
        :return: the matching Context, or None if the ref is malformed, out of
            range, or its hash no longer matches the entry at that position
            (i.e. the results changed since the ref was issued).
        """
        ref = str(torrent_ref).strip()
        if ":" not in ref:
            return None
        try:
            ref_hash, ref_index = ref.split(":", 1)
            index = int(ref_index)
        except (TypeError, ValueError):
            return None
        if index < 1:
            return None
        if results is None:
            results = SearchChain().last_search_results() or []
        if index > len(results):
            return None
        context = results[index - 1]
        if not ref_hash or cls._build_torrent_ref(context) != ref_hash:
            return None
        return context

    @staticmethod
    def _merge_labels_with_system_tag(labels: Optional[str]) -> Optional[str]:
        """Merge user labels with ``settings.TORRENT_TAG`` so the system can manage the task.

        Empty and whitespace-only entries are dropped. The system tag is
        appended only if it is not already present. Returns None when no
        labels remain.
        """
        system_tag = (settings.TORRENT_TAG or "").strip()
        user_labels = [item.strip() for item in (labels or "").split(",") if item.strip()]
        if system_tag and system_tag not in user_labels:
            user_labels.append(system_tag)
        return ",".join(user_labels) if user_labels else None

    @staticmethod
    def _format_failed_result(failed_messages: List[str]) -> str:
        """Join non-empty failure messages into a single comma-separated string."""
        return ", ".join([message for message in failed_messages if message])

    @staticmethod
    def _build_failure_message(torrent_ref: str, error_msg: Optional[str] = None) -> str:
        """Build a per-item failure message.

        Removes a redundant "添加种子任务失败:" prefix from the error. It also
        discards an error that is itself just a magnet link. If no error text
        remains, returns the bare ref for ``hash:id`` inputs and "" otherwise.
        """
        normalized_error = (error_msg or "").strip()
        prefix = "添加种子任务失败:"
        if normalized_error.startswith(prefix):
            normalized_error = normalized_error[len(prefix):].lstrip()
        if AddDownloadTool._is_magnet_link_input(normalized_error):
            normalized_error = ""
        if normalized_error:
            return f"{torrent_ref} {normalized_error}"
        if AddDownloadTool._is_torrent_ref(torrent_ref):
            return torrent_ref
        return ""

    @classmethod
    def _normalize_torrent_urls(cls, torrent_url: Optional[List[str] | str]) -> List[str]:
        """Normalize torrent_url input into a list of stripped, non-empty strings.

        A plain string is split on commas. NOTE(review): a magnet link that
        contains a literal comma would be split by this path. The input
        schema declares a list, so this mainly matters for loosely typed
        kwargs such as those passed to get_tool_message.
        """
        if torrent_url is None:
            return []
        if isinstance(torrent_url, str):
            candidates = torrent_url.split(",")
        else:
            candidates = torrent_url
        return [str(item).strip() for item in candidates if item and str(item).strip()]

    @staticmethod
    def _resolve_direct_download_dir(save_path: Optional[str]) -> Optional[Path]:
        """Resolve the directory used for direct (magnet) downloads.

        Uses ``save_path`` when given. Otherwise uses the first directory
        returned by ``DirectoryHelper().get_download_dirs()``, rendered as a
        FileURI (storage defaults to "local"). Returns None when neither is
        available.
        """
        if save_path:
            return Path(save_path)
        download_dirs = DirectoryHelper().get_download_dirs()
        if not download_dirs:
            return None
        dir_conf = download_dirs[0]
        if not dir_conf.download_path:
            return None
        return Path(FileURI(storage=dir_conf.storage or "local", path=dir_conf.download_path).uri)

    async def _build_ref_context(self, torrent_input: str,
                                 search_results: List[Context]) -> tuple[Optional[Context], Optional[str]]:
        """Turn a ``hash:id`` ref into a download Context enriched with site credentials.

        :return: ``(context, None)`` on success, or ``(None, failure_message)``.
        """
        cached_context = self._resolve_cached_context(torrent_input, search_results)
        if not cached_context or not cached_context.torrent_info:
            return None, f"{torrent_input} 引用无效,请重新使用 get_search_results 查看搜索结果"
        cached_torrent = cached_context.torrent_info
        site_name = cached_torrent.site_name
        torrent_title = cached_torrent.title or torrent_input
        torrent_description = cached_torrent.description
        if not site_name:
            return None, f"{torrent_input} 缺少站点名称"
        siteinfo = await SiteOper().async_get_by_name(site_name)
        if not siteinfo:
            return None, f"{torrent_input} 未找到站点信息 {site_name}"
        # Rebuild TorrentInfo with the site's current UA/cookie/proxy so the
        # downloader can fetch the enclosure with valid credentials.
        torrent_info = TorrentInfo(
            title=torrent_title,
            description=torrent_description,
            enclosure=cached_torrent.enclosure,
            site_name=site_name,
            site_ua=siteinfo.ua,
            site_cookie=siteinfo.cookie,
            site_proxy=siteinfo.proxy,
            site_order=siteinfo.pri,
            site_downloader=siteinfo.downloader
        )
        meta_info = MetaInfo(title=torrent_title, subtitle=torrent_description)
        # Prefer the media info cached with the search result. Only run
        # recognition when it is missing.
        media_info = cached_context.media_info
        if not media_info:
            media_info = await ToolChain().async_recognize_media(meta=meta_info)
        if not media_info:
            return None, f"{torrent_input} 无法识别媒体信息"
        return Context(
            torrent_info=torrent_info,
            meta_info=meta_info,
            media_info=media_info
        ), None

    def _add_magnet(self, download_chain: DownloadChain, torrent_input: str,
                    save_path: Optional[str], merged_labels: Optional[str],
                    downloader: Optional[str]) -> Optional[str]:
        """Add a magnet link directly via ``DownloadChain.download``.

        :return: None on success, otherwise a failure message (possibly "").
        """
        download_dir = self._resolve_direct_download_dir(save_path)
        if not download_dir:
            return f"{torrent_input} 缺少保存路径,且系统未配置可用下载目录"
        result = download_chain.download(
            content=torrent_input,
            download_dir=download_dir,
            cookie=None,
            label=merged_labels,
            downloader=downloader
        )
        # A falsy result is treated as "no downloader available".
        if result:
            _, did, _, error_msg = result
        else:
            did, error_msg = None, "未找到下载器"
        if did:
            return None
        return self._build_failure_message(torrent_input, error_msg)

    async def _add_single(self, download_chain: DownloadChain, torrent_input: str,
                          search_results: List[Context], downloader: Optional[str],
                          save_path: Optional[str], merged_labels: Optional[str]) -> Optional[str]:
        """Add one input (ref or magnet).

        :return: None on success, otherwise a failure message (possibly "").
        """
        if self._is_torrent_ref(torrent_input):
            context, failure = await self._build_ref_context(torrent_input, search_results)
            if failure is not None:
                return failure
            did, error_msg = download_chain.download_single(
                context=context,
                downloader=downloader,
                save_path=save_path,
                label=merged_labels,
                return_detail=True
            )
            if did:
                return None
            return self._build_failure_message(torrent_input, error_msg)
        if not self._is_magnet_link_input(torrent_input):
            return f"{torrent_input} 不是有效的下载内容,非 hash:id 时仅支持 magnet: 开头"
        return self._add_magnet(download_chain, torrent_input, save_path, merged_labels, downloader)

    async def run(self, torrent_url: Optional[List[str]] = None,
                  downloader: Optional[str] = None, save_path: Optional[str] = None,
                  labels: Optional[str] = None, **kwargs) -> str:
        """Add every requested torrent and return a summary string.

        Each input is processed independently. A failure or unexpected
        exception on one item is recorded for that item and does not abort
        the rest of the batch or hide earlier successes.
        """
        logger.info(
            f"执行工具: {self.name}, 参数: torrent_url={torrent_url}, downloader={downloader}, save_path={save_path}, labels={labels}")
        try:
            torrent_inputs = self._normalize_torrent_urls(torrent_url)
            if not torrent_inputs:
                return "错误torrent_url 不能为空。"
            download_chain = DownloadChain()
            merged_labels = self._merge_labels_with_system_tag(labels)
            # Load the cached search results once per batch, and only when at
            # least one ref needs resolving, rather than reloading per ref.
            search_results: List[Context] = []
            if any(self._is_torrent_ref(item) for item in torrent_inputs):
                search_results = SearchChain().last_search_results() or []
            success_count = 0
            failed_messages: List[str] = []
            for torrent_input in torrent_inputs:
                try:
                    failure = await self._add_single(
                        download_chain, torrent_input, search_results,
                        downloader, save_path, merged_labels
                    )
                except Exception as err:
                    # Isolate per-item errors so earlier successes are still reported
                    # and later items are still attempted.
                    logger.error(f"添加下载任务失败: {torrent_input} {err}", exc_info=True)
                    failure = f"{torrent_input} 添加时发生错误: {err}"
                if failure is None:
                    success_count += 1
                else:
                    failed_messages.append(failure)
            if success_count and not failed_messages:
                return "任务添加成功"
            if success_count:
                return f"部分任务添加失败:{self._format_failed_result(failed_messages)}"
            return f"任务添加失败:{self._format_failed_result(failed_messages)}"
        except Exception as e:
            logger.error(f"添加下载任务失败: {e}", exc_info=True)
            return f"添加下载任务时发生错误: {str(e)}"