Files
MoviePilot/app/agent/tools/impl/add_download.py

175 lines
7.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""添加下载工具"""
import re
from typing import Optional, Type
from pydantic import BaseModel, Field
from app.agent.tools.base import MoviePilotTool, ToolChain
from app.chain.search import SearchChain
from app.chain.download import DownloadChain
from app.core.context import Context
from app.core.metainfo import MetaInfo
from app.db.site_oper import SiteOper
from app.log import logger
from app.schemas import TorrentInfo
from app.utils.crypto import HashUtils
class AddDownloadInput(BaseModel):
    """Input parameter schema for the add-download tool."""

    # Free-text justification supplied by the agent for invoking the tool.
    explanation: str = Field(
        ...,
        description="Clear explanation of why this tool is being used in the current context",
    )
    # Name of the site the torrent comes from.
    site_name: Optional[str] = Field(
        None,
        description="Name of the torrent site/source (e.g., 'The Pirate Bay')",
    )
    # Human-readable torrent title.
    torrent_title: Optional[str] = Field(
        None,
        description="The display name/title of the torrent (e.g., 'The.Matrix.1999.1080p.BluRay.x264')",
    )
    # The only required download locator: a link, a magnet URI or a search-result reference.
    torrent_url: str = Field(
        ...,
        description="Torrent download link, magnet URI, or search result reference returned by search_torrents",
    )
    # Optional short description of the torrent content.
    torrent_description: Optional[str] = Field(
        None,
        description="Brief description of the torrent content (optional)",
    )
    # Optional downloader name.
    downloader: Optional[str] = Field(
        None,
        description="Name of the downloader to use (optional, uses default if not specified)",
    )
    # Optional target directory; the description documents the `<storage>:<path>` form.
    save_path: Optional[str] = Field(
        None,
        description=(
            "Directory path where the downloaded files should be saved. "
            "Using `<storage>:<path>` for remote storage. "
            "e.g. rclone:/MP, smb:/server/share/Movies. "
            "(optional, uses default path if not specified)"
        ),
    )
    # Optional comma-separated labels.
    labels: Optional[str] = Field(
        None,
        description=(
            "Comma-separated list of labels/tags to assign to the download "
            "(optional, e.g., 'movie,hd,bluray')"
        ),
    )
class AddDownloadTool(MoviePilotTool):
    """Agent tool that adds a torrent download task through ``DownloadChain``.

    ``torrent_url`` may either be a direct download link / magnet URI, or a
    short search-result reference of the form ``<7 hex chars>:<index>`` that
    points into the results of the most recent search (``index`` is 1-based
    and the hex part is the first 7 characters of the SHA-1 of the cached
    torrent's enclosure).
    """

    name: str = "add_download"
    description: str = (
        "Add torrent download task to the configured downloader (qBittorrent, Transmission, etc.). "
        "Downloads the torrent file and starts the download process with specified settings."
    )
    args_schema: Type[BaseModel] = AddDownloadInput

    def get_tool_message(self, **kwargs) -> Optional[str]:
        """Build a friendly progress message from the download arguments.

        Reads ``torrent_title``, ``torrent_url``, ``site_name`` and
        ``downloader`` from ``kwargs``. The title is preferred as the display
        name; the URL is used as a fallback when no title was given.
        """
        torrent_title = kwargs.get("torrent_title", "")
        torrent_url = kwargs.get("torrent_url")
        site_name = kwargs.get("site_name", "")
        downloader = kwargs.get("downloader")

        message = f"正在添加下载任务: {torrent_title or f'资源 {torrent_url}'}"
        if site_name:
            message += f" (来源: {site_name})"
        if downloader:
            message += f" [下载器: {downloader}]"
        return message

    @staticmethod
    def _build_torrent_ref(context: Optional[Context]) -> str:
        """Return the short hash used to validate a cached search result.

        The hash is the first 7 characters of the SHA-1 of the torrent's
        enclosure (an empty string is hashed when the enclosure is missing).
        Returns ``""`` when ``context`` or its ``torrent_info`` is falsy.
        """
        if not context or not context.torrent_info:
            return ""
        return HashUtils.sha1(context.torrent_info.enclosure or "")[:7]

    @staticmethod
    def _is_torrent_ref(torrent_ref: Optional[str]) -> bool:
        """Tell whether ``torrent_ref`` looks like an internal search-result reference.

        A reference is exactly 7 lowercase hex characters, a colon and a
        decimal index; surrounding whitespace is ignored.
        """
        if not torrent_ref:
            return False
        return bool(re.fullmatch(r"[0-9a-f]{7}:\d+", str(torrent_ref).strip()))

    @classmethod
    def _resolve_cached_context(cls, torrent_ref: str) -> Optional[Context]:
        """Resolve a ``hash:index`` reference against the last search results.

        Returns the cached ``Context`` at the 1-based ``index`` when the hash
        part matches that entry's short reference; returns ``None`` for any
        malformed, out-of-range or mismatching reference.
        """
        ref = str(torrent_ref).strip()
        ref_hash, separator, ref_index = ref.partition(":")
        if not separator:
            return None
        try:
            index = int(ref_index)
        except ValueError:
            # Ordinary URLs / magnet links end up here (e.g. "https://...").
            return None
        if index < 1:
            return None

        results = SearchChain().last_search_results() or []
        if index > len(results):
            return None

        context = results[index - 1]
        # An empty hash must never match: _build_torrent_ref returns "" for
        # entries without torrent info.
        if not ref_hash or cls._build_torrent_ref(context) != ref_hash:
            return None
        return context

    async def run(self, site_name: Optional[str] = None, torrent_title: Optional[str] = None,
                  torrent_url: Optional[str] = None,
                  torrent_description: Optional[str] = None,
                  downloader: Optional[str] = None, save_path: Optional[str] = None,
                  labels: Optional[str] = None, **kwargs) -> str:
        """Add a download task and return a human-readable result message.

        When ``torrent_url`` resolves to a cached search result, the cached
        torrent supplies the real enclosure URL and fills in any missing
        site name, title and description. Site credentials are looked up by
        site name, media information is taken from the cached result or
        recognized from the title, and the task is handed to
        ``DownloadChain.download_single``.

        Every failure, including unexpected exceptions, is reported through
        the returned string rather than raised.
        """
        logger.info(
            f"执行工具: {self.name}, 参数: site_name={site_name}, torrent_title={torrent_title}, "
            f"torrent_url={torrent_url}, downloader={downloader}, save_path={save_path}, labels={labels}")

        try:
            cached_context = None
            if torrent_url:
                is_torrent_ref = self._is_torrent_ref(torrent_url)
                cached_context = self._resolve_cached_context(torrent_url)
                if not cached_context or not cached_context.torrent_info:
                    # A well-formed reference that cannot be resolved is stale or wrong.
                    if is_torrent_ref:
                        return "错误:torrent_url 无效,请重新使用 search_torrents 搜索"
                    cached_context = None

            if cached_context:
                cached_torrent = cached_context.torrent_info
                site_name = site_name or cached_torrent.site_name
                torrent_title = torrent_title or cached_torrent.title
                # The reference itself is not downloadable; use the cached enclosure.
                torrent_url = cached_torrent.enclosure
                torrent_description = torrent_description or cached_torrent.description

            if not torrent_title or not torrent_url:
                return "错误:必须提供种子标题和下载链接"

            # Add the download through DownloadChain.
            download_chain = DownloadChain()

            # Look up the site record (cookie, UA, proxy, ...) by site name.
            if not site_name:
                return "错误:必须提供站点名称,请从搜索资源结果信息中获取"
            siteinfo = await SiteOper().async_get_by_name(site_name)
            if not siteinfo:
                return f"错误:未找到站点信息:{site_name}"

            # Build the download context.
            torrent_info = TorrentInfo(
                title=torrent_title,
                description=torrent_description,
                enclosure=torrent_url,
                site_name=site_name,
                site_ua=siteinfo.ua,
                site_cookie=siteinfo.cookie,
                site_proxy=siteinfo.proxy,
                site_order=siteinfo.pri,
                site_downloader=siteinfo.downloader
            )
            meta_info = MetaInfo(title=torrent_title, subtitle=torrent_description)

            # Prefer media info already attached to the cached search result.
            media_info = cached_context.media_info if cached_context else None
            if not media_info:
                media_info = await ToolChain().async_recognize_media(meta=meta_info)
                if not media_info:
                    return "错误:无法识别媒体信息,无法添加下载任务"

            context = Context(
                torrent_info=torrent_info,
                meta_info=meta_info,
                media_info=media_info
            )

            # NOTE(review): download_single is invoked synchronously inside this
            # coroutine — confirm it does not block the event loop for long.
            did = download_chain.download_single(
                context=context,
                downloader=downloader,
                save_path=save_path,
                label=labels
            )
            if did:
                return f"成功添加下载任务:{torrent_title}"
            return "添加下载任务失败"
        except Exception as e:
            logger.error(f"添加下载任务失败: {e}", exc_info=True)
            return f"添加下载任务时发生错误: {str(e)}"