diff --git a/app/agent/tools/factory.py b/app/agent/tools/factory.py index f665a416..1c2e4282 100644 --- a/app/agent/tools/factory.py +++ b/app/agent/tools/factory.py @@ -36,6 +36,7 @@ from app.agent.tools.impl.query_workflows import QueryWorkflowsTool from app.agent.tools.impl.run_workflow import RunWorkflowTool from app.agent.tools.impl.update_site_cookie import UpdateSiteCookieTool from app.agent.tools.impl.delete_download import DeleteDownloadTool +from app.agent.tools.impl.modify_download import ModifyDownloadTool from app.agent.tools.impl.query_directory_settings import QueryDirectorySettingsTool from app.agent.tools.impl.list_directory import ListDirectoryTool from app.agent.tools.impl.query_transfer_history import QueryTransferHistoryTool @@ -85,6 +86,7 @@ class MoviePilotToolFactory: DeleteSubscribeTool, QueryDownloadTasksTool, DeleteDownloadTool, + ModifyDownloadTool, QueryDownloadersTool, QuerySitesTool, UpdateSiteTool, diff --git a/app/agent/tools/impl/modify_download.py b/app/agent/tools/impl/modify_download.py new file mode 100644 index 00000000..4c0c0a47 --- /dev/null +++ b/app/agent/tools/impl/modify_download.py @@ -0,0 +1,123 @@ +"""修改下载任务工具""" + +from typing import Optional, Type, List + +from pydantic import BaseModel, Field + +from app.agent.tools.base import MoviePilotTool +from app.chain.download import DownloadChain +from app.log import logger + + +class ModifyDownloadInput(BaseModel): + """修改下载任务工具的输入参数模型""" + + explanation: str = Field( + ..., + description="Clear explanation of why this tool is being used in the current context", + ) + hash: str = Field( + ..., description="Task hash (can be obtained from query_download_tasks tool)" + ) + action: Optional[str] = Field( + None, + description="Action to perform on the task: 'start' to resume downloading, 'stop' to pause downloading. " + "If not provided, no start/stop action will be performed.", + ) + tags: Optional[List[str]] = Field( + None, + description="List of tags to set on the download task. If provided, these tags will be added to the task. " + "Example: ['movie', 'hd']", + ) + downloader: Optional[str] = Field( + None, + description="Name of specific downloader (optional, if not provided will search all downloaders)", + ) + + +class ModifyDownloadTool(MoviePilotTool): + """修改下载任务工具""" + + name: str = "modify_download" + description: str = ( + "Modify a download task in the downloader by task hash. " + "Supports: 1) Setting tags on a download task, " + "2) Starting (resuming) a paused download task, " + "3) Stopping (pausing) a downloading task. " + "Multiple operations can be performed in a single call." + ) + args_schema: Type[BaseModel] = ModifyDownloadInput + + def get_tool_message(self, **kwargs) -> Optional[str]: + hash_value = kwargs.get("hash", "") + action = kwargs.get("action") + tags = kwargs.get("tags") + downloader = kwargs.get("downloader") + + parts = [f"正在修改下载任务: {hash_value}"] + if action == "start": + parts.append("操作: 开始下载") + elif action == "stop": + parts.append("操作: 暂停下载") + if tags: + parts.append(f"标签: {', '.join(tags)}") + if downloader: + parts.append(f"下载器: {downloader}") + return " | ".join(parts) + + async def run( + self, + hash: str, + action: Optional[str] = None, + tags: Optional[List[str]] = None, + downloader: Optional[str] = None, + **kwargs, + ) -> str: + logger.info( + f"执行工具: {self.name}, 参数: hash={hash}, action={action}, tags={tags}, downloader={downloader}" + ) + + try: + # 校验 hash 格式 + if len(hash) != 40 or not all(c in "0123456789abcdefABCDEF" for c in hash): + return "参数错误:hash 格式无效,请先使用 query_download_tasks 工具获取正确的 hash。" + + # 校验参数:至少需要一个操作 + if not action and not tags: + return "参数错误:至少需要指定 action(start/stop)或 tags 中的一个。" + + # 校验 action 参数 + if action and action not in ("start", "stop"): + return f"参数错误:action 只支持 'start'(开始下载)或 'stop'(暂停下载),收到: '{action}'。" + + download_chain = DownloadChain() + results = [] + + # 设置标签 + if tags: + tag_result = download_chain.set_torrents_tag( + hashs=[hash], tags=tags, downloader=downloader + ) + if tag_result: + results.append(f"成功设置标签:{', '.join(tags)}") + else: + results.append(f"设置标签失败,请检查任务是否存在或下载器是否可用") + + # 执行开始/暂停操作 + if action: + action_result = download_chain.set_downloading( + hash_str=hash, oper=action, name=downloader + ) + action_desc = "开始" if action == "start" else "暂停" + if action_result: + results.append(f"成功{action_desc}下载任务") + else: + results.append( + f"{action_desc}下载任务失败,请检查任务是否存在或下载器是否可用" + ) + + return f"下载任务 {hash}:" + ";".join(results) + + except Exception as e: + logger.error(f"修改下载任务失败: {e}", exc_info=True) + return f"修改下载任务时发生错误: {str(e)}" diff --git a/app/agent/tools/impl/query_download_tasks.py b/app/agent/tools/impl/query_download_tasks.py index c0e87544..6f22cd63 100644 --- a/app/agent/tools/impl/query_download_tasks.py +++ b/app/agent/tools/impl/query_download_tasks.py @@ -214,6 +214,7 @@ class QueryDownloadTasksTool(MoviePilotTool): "state": d.state, "upspeed": d.upspeed, "dlspeed": d.dlspeed, + "tags": d.tags, "left_time": d.left_time } # 精简 media 字段 diff --git a/app/chain/__init__.py b/app/chain/__init__.py index 246c2887..de760c4e 100644 --- a/app/chain/__init__.py +++ b/app/chain/__init__.py @@ -1034,6 +1034,18 @@ class ChainBase(metaclass=ABCMeta): """ return self.run_module("stop_torrents", hashs=hashs, downloader=downloader) + def set_torrents_tag( + self, hashs: Union[list, str], tags: list, downloader: Optional[str] = None + ) -> bool: + """ + 设置种子标签 + :param hashs: 种子Hash + :param tags: 标签列表 + :param downloader: 下载器 + :return: bool + """ + return self.run_module("set_torrents_tag", hashs=hashs, tags=tags, downloader=downloader) + def torrent_files( self, tid: str, downloader: Optional[str] = None ) -> Optional[Union[TorrentFilesList, List[File]]]: diff --git a/app/modules/qbittorrent/__init__.py b/app/modules/qbittorrent/__init__.py index e7582d7b..c335a4a5 100644 --- a/app/modules/qbittorrent/__init__.py +++ b/app/modules/qbittorrent/__init__.py @@ -318,6 +318,7 @@ class QbittorrentModule(_ModuleBase, _DownloaderBase[Qbittorrent]): state="paused" if torrent.get('state') in ("paused", "pausedDL") else "downloading", dlspeed=StringUtils.str_filesize(torrent.get('dlspeed')), upspeed=StringUtils.str_filesize(torrent.get('upspeed')), + tags=torrent.get('tags'), left_time=StringUtils.str_secends( (torrent.get('total_size') - torrent.get('completed')) / torrent.get( 'dlspeed')) if torrent.get( @@ -356,6 +357,21 @@ class QbittorrentModule(_ModuleBase, _DownloaderBase[Qbittorrent]): return None return server.delete_torrents(delete_file=delete_file, ids=hashs) + def set_torrents_tag(self, hashs: Union[str, list], tags: list, + downloader: Optional[str] = None) -> Optional[bool]: + """ + 设置种子标签 + :param hashs: 种子Hash + :param tags: 标签列表 + :param downloader: 下载器 + :return: bool + """ + server: Qbittorrent = self.get_instance(downloader) + if not server: + return None + server.set_torrents_tag(ids=hashs, tags=tags) + return True + def start_torrents(self, hashs: Union[list, str], downloader: Optional[str] = None) -> Optional[bool]: """ diff --git a/app/modules/rtorrent/__init__.py b/app/modules/rtorrent/__init__.py index e55f76f4..ebed2313 100644 --- a/app/modules/rtorrent/__init__.py +++ b/app/modules/rtorrent/__init__.py @@ -391,6 +391,7 @@ class RtorrentModule(_ModuleBase, _DownloaderBase[Rtorrent]): else "downloading", dlspeed=StringUtils.str_filesize(dlspeed), upspeed=StringUtils.str_filesize(upspeed), + tags=torrent.get("tags"), left_time=StringUtils.str_secends( (total_size - completed) / dlspeed ) @@ -445,6 +446,22 @@ class RtorrentModule(_ModuleBase, _DownloaderBase[Rtorrent]): return None return server.delete_torrents(delete_file=delete_file, ids=hashs) + def set_torrents_tag( + self, hashs: Union[str, list], tags: list, + downloader: Optional[str] = None, + ) -> Optional[bool]: + """ + 设置种子标签 + :param hashs: 种子Hash + :param tags: 标签列表 + :param downloader: 下载器 + :return: bool + """ + server: Rtorrent = self.get_instance(downloader) + if not server: + return None + return server.set_torrents_tag(ids=hashs, tags=tags) + def start_torrents( self, hashs: Union[list, str], downloader: Optional[str] = None ) -> Optional[bool]: diff --git a/app/modules/transmission/__init__.py b/app/modules/transmission/__init__.py index 15146e0a..0da127e9 100644 --- a/app/modules/transmission/__init__.py +++ b/app/modules/transmission/__init__.py @@ -309,6 +309,7 @@ class TransmissionModule(_ModuleBase, _DownloaderBase[Transmission]): state="paused" if torrent.status == "stopped" else "downloading", dlspeed=StringUtils.str_filesize(dlspeed), upspeed=StringUtils.str_filesize(upspeed), + tags=",".join(torrent.labels or []), left_time=StringUtils.str_secends(torrent.left_until_done / dlspeed) if dlspeed > 0 else '' )) finally: @@ -353,6 +354,23 @@ class TransmissionModule(_ModuleBase, _DownloaderBase[Transmission]): return None return server.delete_torrents(delete_file=delete_file, ids=hashs) + def set_torrents_tag(self, hashs: Union[str, list], tags: list, + downloader: Optional[str] = None) -> Optional[bool]: + """ + 设置种子标签 + :param hashs: 种子Hash + :param tags: 标签列表 + :param downloader: 下载器 + :return: bool + """ + # 获取下载器 + server: Transmission = self.get_instance(downloader) + if not server: + return None + # 获取原标签,TR默认会覆盖,需追加 + org_tags = server.get_torrent_tags(ids=hashs) + return server.set_torrent_tag(ids=hashs, tags=tags, org_tags=org_tags) + def start_torrents(self, hashs: Union[list, str], downloader: Optional[str] = None) -> Optional[bool]: """ diff --git a/app/schemas/transfer.py b/app/schemas/transfer.py index 94cb93c1..642ef213 100644 --- a/app/schemas/transfer.py +++ b/app/schemas/transfer.py @@ -40,6 +40,7 @@ class DownloadingTorrent(BaseModel): state: Optional[str] = 'downloading' upspeed: Optional[str] = None dlspeed: Optional[str] = None + tags: Optional[str] = None media: Optional[dict] = Field(default_factory=dict) userid: Optional[str] = None username: Optional[str] = None