diff --git a/app/agent/tools/base.py b/app/agent/tools/base.py index 525095a5..8cc1b28c 100644 --- a/app/agent/tools/base.py +++ b/app/agent/tools/base.py @@ -30,11 +30,13 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta): _source: Optional[str] = PrivateAttr(default=None) _username: Optional[str] = PrivateAttr(default=None) _stream_handler: Optional[StreamingHandler] = PrivateAttr(default=None) + _require_admin: bool = PrivateAttr(default=False) def __init__(self, session_id: str, user_id: str, **kwargs): super().__init__(**kwargs) self._session_id = session_id self._user_id = user_id + self._require_admin = getattr(self.__class__, "require_admin", False) def _run(self, *args: Any, **kwargs: Any) -> Any: raise NotImplementedError("MoviePilotTool 只支持异步调用,请使用 _arun") @@ -143,11 +145,15 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta): async def _check_permission(self) -> Optional[str]: """ 检查用户权限: - 1. 首先检查用户是否是渠道管理员 - 2. 如果渠道没有设置管理员名单,则检查用户是否是系统管理员 - 3. 如果都不是系统管理员,检查用户ID是否等于渠道配置的用户ID - 4. 如果都不是,返回权限拒绝消息 + 1. 首先检查工具是否需要管理员权限 + 2. 如果需要管理员权限,则检查用户是否是渠道管理员 + 3. 如果渠道没有设置管理员名单,则检查用户是否是系统管理员 + 4. 如果都不是系统管理员,检查用户ID是否等于渠道配置的用户ID + 5. 如果都不是,返回权限拒绝消息 """ + if not self._require_admin: + return None + if not self._channel or not self._source: return None diff --git a/app/agent/tools/impl/delete_download.py b/app/agent/tools/impl/delete_download.py index 9433d765..ebd54bac 100644 --- a/app/agent/tools/impl/delete_download.py +++ b/app/agent/tools/impl/delete_download.py @@ -11,46 +11,68 @@ from app.log import logger class DeleteDownloadInput(BaseModel): """删除下载任务工具的输入参数模型""" - explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context") - hash: str = Field(..., description="Task hash (can be obtained from query_download_tasks tool)") - downloader: Optional[str] = Field(None, description="Name of specific downloader (optional, if not provided will search all downloaders)") - delete_files: Optional[bool] = Field(False, description="Whether to delete downloaded files along with the task (default: False, only removes the task from downloader)") + + explanation: str = Field( + ..., + description="Clear explanation of why this tool is being used in the current context", + ) + hash: str = Field( + ..., description="Task hash (can be obtained from query_download_tasks tool)" + ) + downloader: Optional[str] = Field( + None, + description="Name of specific downloader (optional, if not provided will search all downloaders)", + ) + delete_files: Optional[bool] = Field( + False, + description="Whether to delete downloaded files along with the task (default: False, only removes the task from downloader)", + ) class DeleteDownloadTool(MoviePilotTool): name: str = "delete_download" description: str = "Delete a download task from the downloader by task hash only. Optionally specify the downloader name and whether to delete downloaded files." args_schema: Type[BaseModel] = DeleteDownloadInput + require_admin: bool = True def get_tool_message(self, **kwargs) -> Optional[str]: """根据删除参数生成友好的提示消息""" hash_value = kwargs.get("hash", "") downloader = kwargs.get("downloader") delete_files = kwargs.get("delete_files", False) - + message = f"正在删除下载任务: {hash_value}" if downloader: message += f" [下载器: {downloader}]" if delete_files: message += " (包含文件)" - + return message - async def run(self, hash: str, downloader: Optional[str] = None, - delete_files: Optional[bool] = False, **kwargs) -> str: - logger.info(f"执行工具: {self.name}, 参数: hash={hash}, downloader={downloader}, delete_files={delete_files}") + async def run( + self, + hash: str, + downloader: Optional[str] = None, + delete_files: Optional[bool] = False, + **kwargs, + ) -> str: + logger.info( + f"执行工具: {self.name}, 参数: hash={hash}, downloader={downloader}, delete_files={delete_files}" + ) try: download_chain = DownloadChain() # 仅支持通过hash删除任务 - if len(hash) != 40 or not all(c in '0123456789abcdefABCDEF' for c in hash): + if len(hash) != 40 or not all(c in "0123456789abcdefABCDEF" for c in hash): return "参数错误:hash 格式无效,请先使用 query_download_tasks 工具获取正确的 hash。" - + # 删除下载任务 # remove_torrents 支持 delete_file 参数,可以控制是否删除文件 - result = download_chain.remove_torrents(hashs=[hash], downloader=downloader, delete_file=delete_files) - + result = download_chain.remove_torrents( + hashs=[hash], downloader=downloader, delete_file=delete_files + ) + if result: files_info = "(包含文件)" if delete_files else "(不包含文件)" return f"成功删除下载任务:{hash} {files_info}" @@ -59,4 +81,3 @@ class DeleteDownloadTool(MoviePilotTool): except Exception as e: logger.error(f"删除下载任务失败: {e}", exc_info=True) return f"删除下载任务时发生错误: {str(e)}" - diff --git a/app/agent/tools/impl/delete_download_history.py b/app/agent/tools/impl/delete_download_history.py index bfe0c88f..0fb3be94 100644 --- a/app/agent/tools/impl/delete_download_history.py +++ b/app/agent/tools/impl/delete_download_history.py @@ -26,6 +26,7 @@ class DeleteDownloadHistoryTool(MoviePilotTool): name: str = "delete_download_history" description: str = "Delete a download history record by ID. This only removes the record from the database, does not delete any actual files." args_schema: Type[BaseModel] = DeleteDownloadHistoryInput + require_admin: bool = True def get_tool_message(self, **kwargs) -> Optional[str]: history_id = kwargs.get("history_id") diff --git a/app/agent/tools/impl/delete_subscribe.py b/app/agent/tools/impl/delete_subscribe.py index 9657cdd1..045cfffc 100644 --- a/app/agent/tools/impl/delete_subscribe.py +++ b/app/agent/tools/impl/delete_subscribe.py @@ -14,14 +14,22 @@ from app.schemas.types import EventType class DeleteSubscribeInput(BaseModel): """删除订阅工具的输入参数模型""" - explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context") - subscribe_id: int = Field(..., description="The ID of the subscription to delete (can be obtained from query_subscribes tool)") + + explanation: str = Field( + ..., + description="Clear explanation of why this tool is being used in the current context", + ) + subscribe_id: int = Field( + ..., + description="The ID of the subscription to delete (can be obtained from query_subscribes tool)", + ) class DeleteSubscribeTool(MoviePilotTool): name: str = "delete_subscribe" description: str = "Delete a media subscription by its ID. This will remove the subscription and stop automatic downloads for that media." args_schema: Type[BaseModel] = DeleteSubscribeInput + require_admin: bool = True def get_tool_message(self, **kwargs) -> Optional[str]: """根据删除参数生成友好的提示消息""" @@ -37,27 +45,25 @@ class DeleteSubscribeTool(MoviePilotTool): subscribe = await subscribe_oper.async_get(subscribe_id) if not subscribe: return f"订阅 ID {subscribe_id} 不存在" - + # 在删除之前获取订阅信息(用于事件) subscribe_info = subscribe.to_dict() - + # 删除订阅 subscribe_oper.delete(subscribe_id) - + # 发送事件 - await eventmanager.async_send_event(EventType.SubscribeDeleted, { - "subscribe_id": subscribe_id, - "subscribe_info": subscribe_info - }) - + await eventmanager.async_send_event( + EventType.SubscribeDeleted, + {"subscribe_id": subscribe_id, "subscribe_info": subscribe_info}, + ) + # 统计订阅 - SubscribeHelper().sub_done_async({ - "tmdbid": subscribe.tmdbid, - "doubanid": subscribe.doubanid - }) - + SubscribeHelper().sub_done_async( + {"tmdbid": subscribe.tmdbid, "doubanid": subscribe.doubanid} + ) + return f"成功删除订阅:{subscribe.name} ({subscribe.year})" except Exception as e: logger.error(f"删除订阅失败: {e}", exc_info=True) return f"删除订阅时发生错误: {str(e)}" - diff --git a/app/agent/tools/impl/edit_file.py b/app/agent/tools/impl/edit_file.py index ddf97770..b26dce66 100644 --- a/app/agent/tools/impl/edit_file.py +++ b/app/agent/tools/impl/edit_file.py @@ -12,6 +12,7 @@ from app.log import logger class EditFileInput(BaseModel): """Input parameters for edit file tool""" + file_path: str = Field(..., description="The absolute path of the file to edit") old_text: str = Field(..., description="The exact old text to be replaced") new_text: str = Field(..., description="The new text to replace with") @@ -21,6 +22,7 @@ class EditFileTool(MoviePilotTool): name: str = "edit_file" description: str = "Edit a file by replacing specific old text with new text. Useful for modifying configuration files, code, or scripts." args_schema: Type[BaseModel] = EditFileInput + require_admin: bool = True def get_tool_message(self, **kwargs) -> Optional[str]: """根据参数生成友好的提示消息""" @@ -38,7 +40,7 @@ class EditFileTool(MoviePilotTool): # 如果 old_text 为空,可能用户想直接创建文件,但通常 edit_file 需要匹配旧内容 if old_text: return f"错误:文件 {file_path} 不存在,无法进行内容替换。" - + if await path.exists() and not await path.is_file(): return f"错误:{file_path} 不是一个文件" @@ -56,14 +58,13 @@ class EditFileTool(MoviePilotTool): # 自动创建父目录 await path.parent.mkdir(parents=True, exist_ok=True) - + # 写入文件 await path.write_text(new_content, encoding="utf-8") - + logger.info(f"成功编辑文件 {file_path},替换了 {occurrences} 处内容") return f"成功编辑文件 {file_path} (替换了 {occurrences} 处匹配内容)" - except PermissionError: return f"错误:没有访问/修改 {file_path} 的权限" except UnicodeDecodeError: @@ -71,5 +72,3 @@ class EditFileTool(MoviePilotTool): except Exception as e: logger.error(f"编辑文件 {file_path} 时发生错误: {str(e)}", exc_info=True) return f"操作失败: {str(e)}" - - diff --git a/app/agent/tools/impl/execute_command.py b/app/agent/tools/impl/execute_command.py index 81e8a57d..52b40336 100644 --- a/app/agent/tools/impl/execute_command.py +++ b/app/agent/tools/impl/execute_command.py @@ -11,15 +11,21 @@ from app.log import logger class ExecuteCommandInput(BaseModel): """执行Shell命令工具的输入参数模型""" - explanation: str = Field(..., description="Clear explanation of why this command is being executed") + + explanation: str = Field( + ..., description="Clear explanation of why this command is being executed" + ) command: str = Field(..., description="The shell command to execute") - timeout: Optional[int] = Field(60, description="Max execution time in seconds (default: 60)") + timeout: Optional[int] = Field( + 60, description="Max execution time in seconds (default: 60)" + ) class ExecuteCommandTool(MoviePilotTool): name: str = "execute_command" description: str = "Safely execute shell commands on the server. Useful for system maintenance, checking status, or running custom scripts. Includes timeout and output limits." args_schema: Type[BaseModel] = ExecuteCommandInput + require_admin: bool = True def get_tool_message(self, **kwargs) -> Optional[str]: """根据命令生成友好的提示消息""" @@ -27,10 +33,19 @@ class ExecuteCommandTool(MoviePilotTool): return f"正在执行系统命令: {command}" async def run(self, command: str, timeout: Optional[int] = 60, **kwargs) -> str: - logger.info(f"执行工具: {self.name}, 参数: command={command}, timeout={timeout}") + logger.info( + f"执行工具: {self.name}, 参数: command={command}, timeout={timeout}" + ) # 简单安全过滤 - forbidden_keywords = ["rm -rf /", ":(){ :|:& };:", "dd if=/dev/zero", "mkfs", "reboot", "shutdown"] + forbidden_keywords = [ + "rm -rf /", + ":(){ :|:& };:", + "dd if=/dev/zero", + "mkfs", + "reboot", + "shutdown", + ] for keyword in forbidden_keywords: if keyword in command: return f"错误:命令包含禁止使用的关键字 '{keyword}'" @@ -38,18 +53,18 @@ class ExecuteCommandTool(MoviePilotTool): try: # 执行命令 process = await asyncio.create_subprocess_shell( - command, - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE + command, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE ) try: # 等待完成,带超时 - stdout, stderr = await asyncio.wait_for(process.communicate(), timeout=timeout) - + stdout, stderr = await asyncio.wait_for( + process.communicate(), timeout=timeout + ) + # 处理输出 - stdout_str = stdout.decode('utf-8', errors='replace').strip() - stderr_str = stderr.decode('utf-8', errors='replace').strip() + stdout_str = stdout.decode("utf-8", errors="replace").strip() + stderr_str = stderr.decode("utf-8", errors="replace").strip() exit_code = process.returncode result = f"命令执行完成 (退出码: {exit_code})" @@ -57,15 +72,15 @@ class ExecuteCommandTool(MoviePilotTool): result += f"\n\n标准输出:\n{stdout_str}" if stderr_str: result += f"\n\n错误输出:\n{stderr_str}" - + # 如果没有输出 if not stdout_str and not stderr_str: result += "\n\n(无输出内容)" - + # 限制输出长度,防止上下文过长 if len(result) > 3000: result = result[:3000] + "\n\n...(输出内容过长,已截断)" - + return result except asyncio.TimeoutError: diff --git a/app/agent/tools/impl/modify_download.py b/app/agent/tools/impl/modify_download.py index 4c0c0a47..65b98900 100644 --- a/app/agent/tools/impl/modify_download.py +++ b/app/agent/tools/impl/modify_download.py @@ -47,6 +47,7 @@ class ModifyDownloadTool(MoviePilotTool): "Multiple operations can be performed in a single call." ) args_schema: Type[BaseModel] = ModifyDownloadInput + require_admin: bool = True def get_tool_message(self, **kwargs) -> Optional[str]: hash_value = kwargs.get("hash", "") diff --git a/app/agent/tools/impl/run_plugin_command.py b/app/agent/tools/impl/run_plugin_command.py index 1aaa67af..1d66817f 100644 --- a/app/agent/tools/impl/run_plugin_command.py +++ b/app/agent/tools/impl/run_plugin_command.py @@ -37,6 +37,7 @@ class RunPluginCommandTool(MoviePilotTool): "Note: This tool triggers the command execution but the actual processing happens in the background." ) args_schema: Type[BaseModel] = RunPluginCommandInput + require_admin: bool = True def get_tool_message(self, **kwargs) -> Optional[str]: """生成友好的提示消息""" diff --git a/app/agent/tools/impl/run_scheduler.py b/app/agent/tools/impl/run_scheduler.py index e8b9b2fb..802c2f8b 100644 --- a/app/agent/tools/impl/run_scheduler.py +++ b/app/agent/tools/impl/run_scheduler.py @@ -11,14 +11,22 @@ from app.scheduler import Scheduler class RunSchedulerInput(BaseModel): """运行定时服务工具的输入参数模型""" - explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context") - job_id: str = Field(..., description="The ID of the scheduled job to run (can be obtained from query_schedulers tool)") + + explanation: str = Field( + ..., + description="Clear explanation of why this tool is being used in the current context", + ) + job_id: str = Field( + ..., + description="The ID of the scheduled job to run (can be obtained from query_schedulers tool)", + ) class RunSchedulerTool(MoviePilotTool): name: str = "run_scheduler" description: str = "Manually trigger a scheduled task to run immediately. This will execute the specified scheduler job by its ID." args_schema: Type[BaseModel] = RunSchedulerInput + require_admin: bool = True def get_tool_message(self, **kwargs) -> Optional[str]: """根据运行参数生成友好的提示消息""" @@ -39,15 +47,14 @@ class RunSchedulerTool(MoviePilotTool): job_exists = True job_name = s.name break - + if not job_exists: return f"定时服务 ID {job_id} 不存在,请使用 query_schedulers 工具查询可用的定时服务" - + # 运行定时服务 scheduler.start(job_id) - + return f"成功触发定时服务:{job_name} (ID: {job_id})" except Exception as e: logger.error(f"运行定时服务失败: {e}", exc_info=True) return f"运行定时服务时发生错误: {str(e)}" - diff --git a/app/agent/tools/impl/run_workflow.py b/app/agent/tools/impl/run_workflow.py index 8e20f2bf..9a8ed78e 100644 --- a/app/agent/tools/impl/run_workflow.py +++ b/app/agent/tools/impl/run_workflow.py @@ -13,46 +13,61 @@ from app.log import logger class RunWorkflowInput(BaseModel): """执行工作流工具的输入参数模型""" - explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context") - workflow_id: int = Field(..., description="Workflow ID (can be obtained from query_workflows tool)") - from_begin: Optional[bool] = Field(True, description="Whether to run workflow from the beginning (default: True, if False will continue from last executed action)") + + explanation: str = Field( + ..., + description="Clear explanation of why this tool is being used in the current context", + ) + workflow_id: int = Field( + ..., description="Workflow ID (can be obtained from query_workflows tool)" + ) + from_begin: Optional[bool] = Field( + True, + description="Whether to run workflow from the beginning (default: True, if False will continue from last executed action)", + ) class RunWorkflowTool(MoviePilotTool): name: str = "run_workflow" description: str = "Execute a specific workflow manually by workflow ID. Supports running from the beginning or continuing from the last executed action." args_schema: Type[BaseModel] = RunWorkflowInput + require_admin: bool = True def get_tool_message(self, **kwargs) -> Optional[str]: """根据工作流参数生成友好的提示消息""" workflow_id = kwargs.get("workflow_id") from_begin = kwargs.get("from_begin", True) - + message = f"正在执行工作流: {workflow_id}" if not from_begin: message += " (从上次位置继续)" else: message += " (从头开始)" - + return message - async def run(self, workflow_id: int, - from_begin: Optional[bool] = True, **kwargs) -> str: - logger.info(f"执行工具: {self.name}, 参数: workflow_id={workflow_id}, from_begin={from_begin}") + async def run( + self, workflow_id: int, from_begin: Optional[bool] = True, **kwargs + ) -> str: + logger.info( + f"执行工具: {self.name}, 参数: workflow_id={workflow_id}, from_begin={from_begin}" + ) try: # 获取数据库会话 async with AsyncSessionFactory() as db: workflow_oper = WorkflowOper(db) workflow = await workflow_oper.async_get(workflow_id) - + if not workflow: return f"未找到工作流:{workflow_id},请使用 query_workflows 工具查询可用的工作流" - + # 执行工作流 workflow_chain = WorkflowChain() - state, errmsg = workflow_chain.process(workflow.id, from_begin=from_begin) - + state, errmsg = workflow_chain.process( + workflow.id, from_begin=from_begin + ) + if not state: return f"执行工作流失败:{workflow.name} (ID: {workflow.id})\n错误原因:{errmsg}" else: @@ -60,4 +75,3 @@ class RunWorkflowTool(MoviePilotTool): except Exception as e: logger.error(f"执行工作流失败: {e}", exc_info=True) return f"执行工作流时发生错误: {str(e)}" - diff --git a/app/agent/tools/impl/send_message.py b/app/agent/tools/impl/send_message.py index 090729c0..0bcd1b7c 100644 --- a/app/agent/tools/impl/send_message.py +++ b/app/agent/tools/impl/send_message.py @@ -10,33 +10,52 @@ from app.log import logger class SendMessageInput(BaseModel): """发送消息工具的输入参数模型""" - explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context") - message: str = Field(..., description="The message content to send to the user (should be clear and informative)") - message_type: Optional[str] = Field("info", - description="Type of message: 'info' for general information, 'success' for successful operations, 'warning' for warnings, 'error' for error messages") + + explanation: str = Field( + ..., + description="Clear explanation of why this tool is being used in the current context", + ) + message: str = Field( + ..., + description="The message content to send to the user (should be clear and informative)", + ) + message_type: Optional[str] = Field( + "info", + description="Type of message: 'info' for general information, 'success' for successful operations, 'warning' for warnings, 'error' for error messages", + ) class SendMessageTool(MoviePilotTool): name: str = "send_message" description: str = "Send notification message to the user through configured notification channels (Telegram, Slack, WeChat, etc.). Used to inform users about operation results, errors, or important updates." args_schema: Type[BaseModel] = SendMessageInput + require_admin: bool = True def get_tool_message(self, **kwargs) -> Optional[str]: """根据消息参数生成友好的提示消息""" message = kwargs.get("message", "") message_type = kwargs.get("message_type", "info") - - type_map = {"info": "信息", "success": "成功", "warning": "警告", "error": "错误"} + + type_map = { + "info": "信息", + "success": "成功", + "warning": "警告", + "error": "错误", + } type_desc = type_map.get(message_type, message_type) - + # 截断过长的消息 if len(message) > 50: message = message[:50] + "..." - + return f"正在发送{type_desc}消息: {message}" - async def run(self, message: str, message_type: Optional[str] = None, **kwargs) -> str: - logger.info(f"执行工具: {self.name}, 参数: message={message}, message_type={message_type}") + async def run( + self, message: str, message_type: Optional[str] = None, **kwargs + ) -> str: + logger.info( + f"执行工具: {self.name}, 参数: message={message}, message_type={message_type}" + ) try: await self.send_tool_message(message, title=message_type) return "消息已发送" diff --git a/app/agent/tools/impl/transfer_file.py b/app/agent/tools/impl/transfer_file.py index ff17911c..cc1df642 100644 --- a/app/agent/tools/impl/transfer_file.py +++ b/app/agent/tools/impl/transfer_file.py @@ -13,23 +13,53 @@ from app.schemas import FileItem, MediaType class TransferFileInput(BaseModel): """整理文件或目录工具的输入参数模型""" - explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context") - file_path: str = Field(..., description="Path to the file or directory to transfer (e.g., '/path/to/file.mkv' or '/path/to/directory')") - storage: Optional[str] = Field("local", description="Storage type of the source file (default: 'local', can be 'smb', 'alist', etc.)") - target_path: Optional[str] = Field(None, description="Target path for the transferred file/directory (optional, uses default library path if not specified)") - target_storage: Optional[str] = Field(None, description="Target storage type (optional, uses default storage if not specified)") + + explanation: str = Field( + ..., + description="Clear explanation of why this tool is being used in the current context", + ) + file_path: str = Field( + ..., + description="Path to the file or directory to transfer (e.g., '/path/to/file.mkv' or '/path/to/directory')", + ) + storage: Optional[str] = Field( + "local", + description="Storage type of the source file (default: 'local', can be 'smb', 'alist', etc.)", + ) + target_path: Optional[str] = Field( + None, + description="Target path for the transferred file/directory (optional, uses default library path if not specified)", + ) + target_storage: Optional[str] = Field( + None, + description="Target storage type (optional, uses default storage if not specified)", + ) media_type: Optional[str] = Field(None, description="Allowed values: movie, tv") - tmdbid: Optional[int] = Field(None, description="TMDB ID for precise media identification (optional but recommended for accuracy)") - doubanid: Optional[str] = Field(None, description="Douban ID for media identification (optional)") - season: Optional[int] = Field(None, description="Season number for TV shows (optional)") - transfer_type: Optional[str] = Field(None, description="Transfer mode: 'move' to move files, 'copy' to copy files, 'link' for hard link, 'softlink' for symbolic link (optional, uses default mode if not specified)") - background: Optional[bool] = Field(False, description="Whether to run transfer in background (default: False, runs synchronously)") + tmdbid: Optional[int] = Field( + None, + description="TMDB ID for precise media identification (optional but recommended for accuracy)", + ) + doubanid: Optional[str] = Field( + None, description="Douban ID for media identification (optional)" + ) + season: Optional[int] = Field( + None, description="Season number for TV shows (optional)" + ) + transfer_type: Optional[str] = Field( + None, + description="Transfer mode: 'move' to move files, 'copy' to copy files, 'link' for hard link, 'softlink' for symbolic link (optional, uses default mode if not specified)", + ) + background: Optional[bool] = Field( + False, + description="Whether to run transfer in background (default: False, runs synchronously)", + ) class TransferFileTool(MoviePilotTool): name: str = "transfer_file" description: str = "Transfer/organize a file or directory to the media library. Automatically recognizes media information and organizes files according to configured rules. Supports custom target paths, media identification, and transfer modes." args_schema: Type[BaseModel] = TransferFileInput + require_admin: bool = True def get_tool_message(self, **kwargs) -> Optional[str]: """根据整理参数生成友好的提示消息""" @@ -37,66 +67,79 @@ class TransferFileTool(MoviePilotTool): media_type = kwargs.get("media_type") transfer_type = kwargs.get("transfer_type") background = kwargs.get("background", False) - + message = f"正在整理文件: {file_path}" if media_type: message += f" [{media_type}]" if transfer_type: - transfer_map = {"move": "移动", "copy": "复制", "link": "硬链接", "softlink": "软链接"} + transfer_map = { + "move": "移动", + "copy": "复制", + "link": "硬链接", + "softlink": "软链接", + } message += f" 模式: {transfer_map.get(transfer_type, transfer_type)}" if background: message += " [后台运行]" - + return message - async def run(self, file_path: str, storage: Optional[str] = "local", - target_path: Optional[str] = None, - target_storage: Optional[str] = None, - media_type: Optional[str] = None, - tmdbid: Optional[int] = None, - doubanid: Optional[str] = None, - season: Optional[int] = None, - transfer_type: Optional[str] = None, - background: Optional[bool] = False, **kwargs) -> str: + async def run( + self, + file_path: str, + storage: Optional[str] = "local", + target_path: Optional[str] = None, + target_storage: Optional[str] = None, + media_type: Optional[str] = None, + tmdbid: Optional[int] = None, + doubanid: Optional[str] = None, + season: Optional[int] = None, + transfer_type: Optional[str] = None, + background: Optional[bool] = False, + **kwargs, + ) -> str: logger.info( f"执行工具: {self.name}, 参数: file_path={file_path}, storage={storage}, target_path={target_path}, " f"target_storage={target_storage}, media_type={media_type}, tmdbid={tmdbid}, doubanid={doubanid}, " - f"season={season}, transfer_type={transfer_type}, background={background}") + f"season={season}, transfer_type={transfer_type}, background={background}" + ) try: if not file_path: return "错误:必须提供文件或目录路径" - + # 规范化路径 if storage == "local": # 本地路径处理 - if not file_path.startswith("/") and not (len(file_path) > 1 and file_path[1] == ":"): + if not file_path.startswith("/") and not ( + len(file_path) > 1 and file_path[1] == ":" + ): # 相对路径,尝试转换为绝对路径 file_path = str(Path(file_path).resolve()) else: # 远程存储路径,确保以/开头 if not file_path.startswith("/"): file_path = "/" + file_path - + # 创建FileItem fileitem = FileItem( storage=storage or "local", path=file_path, - type="dir" if file_path.endswith("/") else "file" + type="dir" if file_path.endswith("/") else "file", ) - + # 处理目标路径 target_path_obj = None if target_path: target_path_obj = Path(target_path) - + # 处理媒体类型 media_type_enum = None if media_type: media_type_enum = MediaType.from_agent(media_type) if not media_type_enum: return f"错误:无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv'" - + # 调用整理方法 transfer_chain = TransferChain() state, errormsg = transfer_chain.manual_transfer( @@ -108,15 +151,17 @@ class TransferFileTool(MoviePilotTool): mtype=media_type_enum, season=season, transfer_type=transfer_type, - background=background + background=background, ) - + if not state: # 处理错误信息 if isinstance(errormsg, list): error_text = f"整理完成,{len(errormsg)} 个文件转移失败" if errormsg: - error_text += f":\n" + "\n".join(str(e) for e in errormsg[:5]) # 只显示前5个错误 + error_text += f":\n" + "\n".join( + str(e) for e in errormsg[:5] + ) # 只显示前5个错误 if len(errormsg) > 5: error_text += f"\n... 还有 {len(errormsg) - 5} 个错误" else: @@ -130,4 +175,3 @@ class TransferFileTool(MoviePilotTool): except Exception as e: logger.error(f"整理文件失败: {e}", exc_info=True) return f"整理文件时发生错误: {str(e)}" - diff --git a/app/agent/tools/impl/update_site.py b/app/agent/tools/impl/update_site.py index a9c80643..1976a737 100644 --- a/app/agent/tools/impl/update_site.py +++ b/app/agent/tools/impl/update_site.py @@ -16,37 +16,67 @@ from app.utils.string import StringUtils class UpdateSiteInput(BaseModel): """更新站点工具的输入参数模型""" - explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context") - site_id: int = Field(..., description="The ID of the site to update (can be obtained from query_sites tool)") + + explanation: str = Field( + ..., + description="Clear explanation of why this tool is being used in the current context", + ) + site_id: int = Field( + ..., + description="The ID of the site to update (can be obtained from query_sites tool)", + ) name: Optional[str] = Field(None, description="Site name (optional)") - url: Optional[str] = Field(None, description="Site URL (optional, will be automatically formatted)") - pri: Optional[int] = Field(None, description="Site priority (optional, smaller value = higher priority, e.g., pri=1 has higher priority than pri=10)") + url: Optional[str] = Field( + None, description="Site URL (optional, will be automatically formatted)" + ) + pri: Optional[int] = Field( + None, + description="Site priority (optional, smaller value = higher priority, e.g., pri=1 has higher priority than pri=10)", + ) rss: Optional[str] = Field(None, description="RSS feed URL (optional)") cookie: Optional[str] = Field(None, description="Site cookie (optional)") ua: Optional[str] = Field(None, description="User-Agent string (optional)") apikey: Optional[str] = Field(None, description="API key (optional)") token: Optional[str] = Field(None, description="API token (optional)") - proxy: Optional[int] = Field(None, description="Whether to use proxy: 0 for no, 1 for yes (optional)") - filter: Optional[str] = Field(None, description="Filter rule as regular expression (optional)") + proxy: Optional[int] = Field( + None, description="Whether to use proxy: 0 for no, 1 for yes (optional)" + ) + filter: Optional[str] = Field( + None, description="Filter rule as regular expression (optional)" + ) note: Optional[str] = Field(None, description="Site notes/remarks (optional)") - timeout: Optional[int] = Field(None, description="Request timeout in seconds (optional, default: 15)") - limit_interval: Optional[int] = Field(None, description="Rate limit interval in seconds (optional)") - limit_count: Optional[int] = Field(None, description="Rate limit count per interval (optional)") - limit_seconds: Optional[int] = Field(None, description="Rate limit seconds between requests (optional)") - is_active: Optional[bool] = Field(None, description="Whether site is active: True for enabled, False for disabled (optional)") - downloader: Optional[str] = Field(None, description="Downloader name for this site (optional)") + timeout: Optional[int] = Field( + None, description="Request timeout in seconds (optional, default: 15)" + ) + limit_interval: Optional[int] = Field( + None, description="Rate limit interval in seconds (optional)" + ) + limit_count: Optional[int] = Field( + None, description="Rate limit count per interval (optional)" + ) + limit_seconds: Optional[int] = Field( + None, description="Rate limit seconds between requests (optional)" + ) + is_active: Optional[bool] = Field( + None, + description="Whether site is active: True for enabled, False for disabled (optional)", + ) + downloader: Optional[str] = Field( + None, description="Downloader name for this site (optional)" + ) class UpdateSiteTool(MoviePilotTool): name: str = "update_site" description: str = "Update site configuration including URL, priority, authentication credentials (cookie, UA, API key), proxy settings, rate limits, and other site properties. Supports updating multiple site attributes at once. Site priority (pri): smaller values have higher priority (e.g., pri=1 has higher priority than pri=10)." args_schema: Type[BaseModel] = UpdateSiteInput + require_admin: bool = True def get_tool_message(self, **kwargs) -> Optional[str]: """根据更新参数生成友好的提示消息""" site_id = kwargs.get("site_id") fields_updated = [] - + if kwargs.get("name"): fields_updated.append("名称") if kwargs.get("url"): @@ -63,60 +93,63 @@ class UpdateSiteTool(MoviePilotTool): fields_updated.append("启用状态") if kwargs.get("downloader"): fields_updated.append("下载器") - + if fields_updated: return f"正在更新站点 #{site_id}: {', '.join(fields_updated)}" return f"正在更新站点 #{site_id}" - async def run(self, site_id: int, - name: Optional[str] = None, - url: Optional[str] = None, - pri: Optional[int] = None, - rss: Optional[str] = None, - cookie: Optional[str] = None, - ua: Optional[str] = None, - apikey: Optional[str] = None, - token: Optional[str] = None, - proxy: Optional[int] = None, - filter: Optional[str] = None, - note: Optional[str] = None, - timeout: Optional[int] = None, - limit_interval: Optional[int] = None, - limit_count: Optional[int] = None, - limit_seconds: Optional[int] = None, - is_active: Optional[bool] = None, - downloader: Optional[str] = None, - **kwargs) -> str: + async def run( + self, + site_id: int, + name: Optional[str] = None, + url: Optional[str] = None, + pri: Optional[int] = None, + rss: Optional[str] = None, + cookie: Optional[str] = None, + ua: Optional[str] = None, + apikey: Optional[str] = None, + token: Optional[str] = None, + proxy: Optional[int] = None, + filter: Optional[str] = None, + note: Optional[str] = None, + timeout: Optional[int] = None, + limit_interval: Optional[int] = None, + limit_count: Optional[int] = None, + limit_seconds: Optional[int] = None, + is_active: Optional[bool] = None, + downloader: Optional[str] = None, + **kwargs, + ) -> str: logger.info(f"执行工具: {self.name}, 参数: site_id={site_id}") - + try: # 获取数据库会话 async with AsyncSessionFactory() as db: # 获取站点 site = await Site.async_get(db, site_id) if not site: - return json.dumps({ - "success": False, - "message": f"站点不存在: {site_id}" - }, ensure_ascii=False) - + return json.dumps( + {"success": False, "message": f"站点不存在: {site_id}"}, + ensure_ascii=False, + ) + # 构建更新字典 site_dict = {} - + # 基本信息 if name is not None: site_dict["name"] = name - + # URL处理(需要校正格式) if url is not None: _scheme, _netloc = StringUtils.get_url_netloc(url) site_dict["url"] = f"{_scheme}://{_netloc}/" - + if pri is not None: site_dict["pri"] = pri if rss is not None: site_dict["rss"] = rss - + # 认证信息 if cookie is not None: site_dict["cookie"] = cookie @@ -126,7 +159,7 @@ class UpdateSiteTool(MoviePilotTool): site_dict["apikey"] = apikey if token is not None: site_dict["token"] = token - + # 配置选项 if proxy is not None: site_dict["proxy"] = proxy @@ -136,7 +169,7 @@ class UpdateSiteTool(MoviePilotTool): site_dict["note"] = note if timeout is not None: site_dict["timeout"] = timeout - + # 流控设置 if limit_interval is not None: site_dict["limit_interval"] = limit_interval @@ -144,39 +177,40 @@ class UpdateSiteTool(MoviePilotTool): site_dict["limit_count"] = limit_count if limit_seconds is not None: site_dict["limit_seconds"] = limit_seconds - + # 状态和下载器 if is_active is not None: site_dict["is_active"] = is_active if downloader is not None: site_dict["downloader"] = downloader - + # 如果没有要更新的字段 if not site_dict: - return json.dumps({ - "success": False, - "message": "没有提供要更新的字段" - }, ensure_ascii=False) - + return json.dumps( + {"success": False, "message": "没有提供要更新的字段"}, + ensure_ascii=False, + ) + # 更新站点 await site.async_update(db, site_dict) - + # 重新获取更新后的站点数据 updated_site = await Site.async_get(db, site_id) - + # 发送站点更新事件 - await eventmanager.async_send_event(EventType.SiteUpdated, { - "domain": updated_site.domain if updated_site else site.domain - }) - + await eventmanager.async_send_event( + EventType.SiteUpdated, + {"domain": updated_site.domain if updated_site else site.domain}, + ) + # 构建返回结果 result = { "success": True, "message": f"站点 #{site_id} 更新成功", "site_id": site_id, - "updated_fields": list(site_dict.keys()) + "updated_fields": list(site_dict.keys()), } - + if updated_site: result["site"] = { "id": updated_site.id, @@ -187,17 +221,15 @@ class UpdateSiteTool(MoviePilotTool): "is_active": updated_site.is_active, "downloader": updated_site.downloader, "proxy": updated_site.proxy, - "timeout": updated_site.timeout + "timeout": updated_site.timeout, } - + return json.dumps(result, ensure_ascii=False, indent=2) - + except Exception as e: error_message = f"更新站点失败: {str(e)}" logger.error(f"更新站点失败: {e}", exc_info=True) - return json.dumps({ - "success": False, - "message": error_message, - "site_id": site_id - }, ensure_ascii=False) - + return json.dumps( + {"success": False, "message": error_message, "site_id": site_id}, + ensure_ascii=False, + ) diff --git a/app/agent/tools/impl/update_site_cookie.py b/app/agent/tools/impl/update_site_cookie.py index f91b706f..a9b208fa 100644 --- a/app/agent/tools/impl/update_site_cookie.py +++ b/app/agent/tools/impl/update_site_cookie.py @@ -12,50 +12,69 @@ from app.log import logger class UpdateSiteCookieInput(BaseModel): """更新站点Cookie和UA工具的输入参数模型""" - explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context") - site_identifier: int = Field(..., description="Site ID to update Cookie and User-Agent for (can be obtained from query_sites tool)") + + explanation: str = Field( + ..., + description="Clear explanation of why this tool is being used in the current context", + ) + site_identifier: int = Field( + ..., + description="Site ID to update Cookie and User-Agent for (can be obtained from query_sites tool)", + ) username: str = Field(..., description="Site login username") password: str = Field(..., description="Site login password") - two_step_code: Optional[str] = Field(None, description="Two-step verification code or secret key (optional, required for sites with 2FA enabled)") + two_step_code: Optional[str] = Field( + None, + description="Two-step verification code or secret key (optional, required for sites with 2FA enabled)", + ) class UpdateSiteCookieTool(MoviePilotTool): name: str = "update_site_cookie" description: str = "Update site Cookie and User-Agent by logging in with username and password. This tool can automatically obtain and update the site's authentication credentials. Supports two-step verification for sites that require it. Accepts site ID only." args_schema: Type[BaseModel] = UpdateSiteCookieInput + require_admin: bool = True def get_tool_message(self, **kwargs) -> Optional[str]: """根据更新参数生成友好的提示消息""" site_identifier = kwargs.get("site_identifier") username = kwargs.get("username", "") two_step_code = kwargs.get("two_step_code") - + message = f"正在更新站点Cookie: {site_identifier} (用户: {username})" if two_step_code: message += " [需要两步验证]" - + return message - async def run(self, site_identifier: int, username: str, password: str, - two_step_code: Optional[str] = None, **kwargs) -> str: - logger.info(f"执行工具: {self.name}, 参数: site_identifier={site_identifier}, username={username}") + async def run( + self, + site_identifier: int, + username: str, + password: str, + two_step_code: Optional[str] = None, + **kwargs, + ) -> str: + logger.info( + f"执行工具: {self.name}, 参数: site_identifier={site_identifier}, username={username}" + ) try: site_oper = SiteOper() site_chain = SiteChain() site = await site_oper.async_get(site_identifier) - + if not site: return f"未找到站点:{site_identifier},请使用 query_sites 工具查询可用的站点" - + # 更新站点Cookie和UA status, message = site_chain.update_cookie( site_info=site, username=username, password=password, - two_step_code=two_step_code + two_step_code=two_step_code, ) - + if status: return f"站点【{site.name}】Cookie和UA更新成功\n{message}" else: @@ -63,4 +82,3 @@ class UpdateSiteCookieTool(MoviePilotTool): except Exception as e: logger.error(f"更新站点Cookie和UA失败: {e}", exc_info=True) return f"更新站点Cookie和UA时发生错误: {str(e)}" - diff --git a/app/agent/tools/impl/update_subscribe.py b/app/agent/tools/impl/update_subscribe.py index 9e635598..43c4b39b 100644 --- a/app/agent/tools/impl/update_subscribe.py +++ b/app/agent/tools/impl/update_subscribe.py @@ -15,40 +15,87 @@ from app.schemas.types import EventType class UpdateSubscribeInput(BaseModel): """更新订阅工具的输入参数模型""" - explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context") - subscribe_id: int = Field(..., description="The ID of the subscription to update (can be obtained from query_subscribes tool)") + + explanation: str = Field( + ..., + description="Clear explanation of why this tool is being used in the current context", + ) + subscribe_id: int = Field( + ..., + description="The ID of the subscription to update (can be obtained from query_subscribes tool)", + ) name: Optional[str] = Field(None, description="Subscription name/title (optional)") year: Optional[str] = Field(None, description="Release year (optional)") - season: Optional[int] = Field(None, description="Season number for TV shows (optional)") - total_episode: Optional[int] = Field(None, description="Total number of episodes (optional)") - lack_episode: Optional[int] = Field(None, description="Number of missing episodes (optional)") - start_episode: Optional[int] = Field(None, description="Starting episode number (optional)") - quality: Optional[str] = Field(None, description="Quality filter as regular expression (optional, e.g., 'BluRay|WEB-DL|HDTV')") - resolution: Optional[str] = Field(None, description="Resolution filter as regular expression (optional, e.g., '1080p|720p|2160p')") - effect: Optional[str] = Field(None, description="Effect filter as regular expression (optional, e.g., 'HDR|DV|SDR')") - include: Optional[str] = Field(None, description="Include filter as regular expression (optional)") - exclude: Optional[str] = Field(None, description="Exclude filter as regular expression (optional)") - filter: Optional[str] = Field(None, description="Filter rule as regular expression (optional)") - state: Optional[str] = Field(None, description="Subscription state: 'R' for enabled, 'P' for pending, 'S' for paused (optional)") - sites: Optional[List[int]] = Field(None, description="List of site IDs to search from (optional)") + season: Optional[int] = Field( + None, description="Season number for TV shows (optional)" + ) + total_episode: Optional[int] = Field( + None, description="Total number of episodes (optional)" + ) + lack_episode: Optional[int] = Field( + None, description="Number of missing episodes (optional)" + ) + start_episode: Optional[int] = Field( + None, description="Starting episode number (optional)" + ) + quality: Optional[str] = Field( + None, + description="Quality filter as regular expression (optional, e.g., 'BluRay|WEB-DL|HDTV')", + ) + resolution: Optional[str] = Field( + None, + description="Resolution filter as regular expression (optional, e.g., '1080p|720p|2160p')", + ) + effect: Optional[str] = Field( + None, + description="Effect filter as regular expression (optional, e.g., 'HDR|DV|SDR')", + ) + include: Optional[str] = Field( + None, description="Include filter as regular expression (optional)" + ) + exclude: Optional[str] = Field( + None, description="Exclude filter as regular expression (optional)" + ) + filter: Optional[str] = Field( + None, description="Filter rule as regular expression (optional)" + ) + state: Optional[str] = Field( + None, + description="Subscription state: 'R' for enabled, 'P' for pending, 'S' for paused (optional)", + ) + sites: Optional[List[int]] = Field( + None, description="List of site IDs to search from (optional)" + ) downloader: Optional[str] = Field(None, description="Downloader name (optional)") - save_path: Optional[str] = Field(None, description="Save path for downloaded files (optional)") - best_version: Optional[int] = Field(None, description="Whether to upgrade to best version: 0 for no, 1 for yes (optional)") - custom_words: Optional[str] = Field(None, description="Custom recognition words (optional)") - media_category: Optional[str] = Field(None, description="Custom media category (optional)") - episode_group: Optional[str] = Field(None, description="Episode group ID (optional)") + save_path: Optional[str] = Field( + None, description="Save path for downloaded files (optional)" + ) + best_version: Optional[int] = Field( + None, + description="Whether to upgrade to best version: 0 for no, 1 for yes (optional)", + ) + custom_words: Optional[str] = Field( + None, description="Custom recognition words (optional)" + ) + media_category: Optional[str] = Field( + None, description="Custom media category (optional)" + ) + episode_group: Optional[str] = Field( + None, description="Episode group ID (optional)" + ) class UpdateSubscribeTool(MoviePilotTool): name: str = "update_subscribe" description: str = "Update subscription properties including filters, episode counts, state, and other settings. Supports updating quality/resolution filters, episode tracking, subscription state, and download configuration." args_schema: Type[BaseModel] = UpdateSubscribeInput + require_admin: bool = True def get_tool_message(self, **kwargs) -> Optional[str]: """根据更新参数生成友好的提示消息""" subscribe_id = kwargs.get("subscribe_id") fields_updated = [] - + if kwargs.get("name"): fields_updated.append("名称") if kwargs.get("total_episode") is not None: @@ -61,57 +108,62 @@ class UpdateSubscribeTool(MoviePilotTool): fields_updated.append("分辨率过滤") if kwargs.get("state"): state_map = {"R": "启用", "P": "禁用", "S": "暂停"} - fields_updated.append(f"状态({state_map.get(kwargs.get('state'), kwargs.get('state'))})") + fields_updated.append( + f"状态({state_map.get(kwargs.get('state'), kwargs.get('state'))})" + ) if kwargs.get("sites"): fields_updated.append("站点") if kwargs.get("downloader"): fields_updated.append("下载器") - + if fields_updated: return f"正在更新订阅 #{subscribe_id}: {', '.join(fields_updated)}" return f"正在更新订阅 #{subscribe_id}" - async def run(self, subscribe_id: int, - name: Optional[str] = None, - year: Optional[str] = None, - season: Optional[int] = None, - total_episode: Optional[int] = None, - lack_episode: Optional[int] = None, - start_episode: Optional[int] = None, - quality: Optional[str] = None, - resolution: Optional[str] = None, - effect: Optional[str] = None, - include: Optional[str] = None, - exclude: Optional[str] = None, - filter: Optional[str] = None, - state: Optional[str] = None, - sites: Optional[List[int]] = None, - downloader: Optional[str] = None, - save_path: Optional[str] = None, - best_version: Optional[int] = None, - custom_words: Optional[str] = None, - media_category: Optional[str] = None, - episode_group: Optional[str] = None, - **kwargs) -> str: + async def run( + self, + subscribe_id: int, + name: Optional[str] = None, + year: Optional[str] = None, + season: Optional[int] = None, + total_episode: Optional[int] = None, + lack_episode: Optional[int] = None, + start_episode: Optional[int] = None, + quality: Optional[str] = None, + resolution: Optional[str] = None, + effect: Optional[str] = None, + include: Optional[str] = None, + exclude: Optional[str] = None, + filter: Optional[str] = None, + state: Optional[str] = None, + sites: Optional[List[int]] = None, + downloader: Optional[str] = None, + save_path: Optional[str] = None, + best_version: Optional[int] = None, + custom_words: Optional[str] = None, + media_category: Optional[str] = None, + episode_group: Optional[str] = None, + **kwargs, + ) -> str: logger.info(f"执行工具: {self.name}, 参数: subscribe_id={subscribe_id}") - + try: # 获取数据库会话 async with AsyncSessionFactory() as db: # 获取订阅 subscribe = await Subscribe.async_get(db, subscribe_id) if not subscribe: - return json.dumps({ - "success": False, - "message": f"订阅不存在: {subscribe_id}" - }, ensure_ascii=False) - + return json.dumps( + {"success": False, "message": f"订阅不存在: {subscribe_id}"}, + ensure_ascii=False, + ) + # 保存旧数据用于事件 old_subscribe_dict = subscribe.to_dict() - + # 构建更新字典 subscribe_dict = {} - + # 基本信息 if name is not None: subscribe_dict["name"] = name @@ -119,27 +171,29 @@ class UpdateSubscribeTool(MoviePilotTool): subscribe_dict["year"] = year if season is not None: subscribe_dict["season"] = season - + # 集数相关 if total_episode is not None: subscribe_dict["total_episode"] = total_episode # 如果总集数增加,缺失集数也要相应增加 if total_episode > (subscribe.total_episode or 0): old_lack = subscribe.lack_episode or 0 - subscribe_dict["lack_episode"] = old_lack + (total_episode - (subscribe.total_episode or 0)) + subscribe_dict["lack_episode"] = old_lack + ( + total_episode - (subscribe.total_episode or 0) + ) # 标记为手动修改过总集数 subscribe_dict["manual_total_episode"] = 1 - + # 缺失集数处理(只有在没有提供总集数时才单独处理) # 注意:如果 lack_episode 为 0,不更新(避免更新为0) if lack_episode is not None and total_episode is None: if lack_episode > 0: subscribe_dict["lack_episode"] = lack_episode # 如果 lack_episode 为 0,不添加到更新字典中(保持原值或由总集数逻辑处理) - + if start_episode is not None: subscribe_dict["start_episode"] = start_episode - + # 过滤规则 if quality is not None: subscribe_dict["quality"] = quality @@ -153,17 +207,20 @@ class UpdateSubscribeTool(MoviePilotTool): subscribe_dict["exclude"] = exclude if filter is not None: subscribe_dict["filter"] = filter - + # 状态 if state is not None: valid_states = ["R", "P", "S", "N"] if state not in valid_states: - return json.dumps({ - "success": False, - "message": f"无效的订阅状态: {state},有效状态: {', '.join(valid_states)}" - }, ensure_ascii=False) + return json.dumps( + { + "success": False, + "message": f"无效的订阅状态: {state},有效状态: {', '.join(valid_states)}", + }, + ensure_ascii=False, + ) subscribe_dict["state"] = state - + # 下载配置 if sites is not None: subscribe_dict["sites"] = sites @@ -173,7 +230,7 @@ class UpdateSubscribeTool(MoviePilotTool): subscribe_dict["save_path"] = save_path if best_version is not None: subscribe_dict["best_version"] = best_version - + # 其他配置 if custom_words is not None: subscribe_dict["custom_words"] = custom_words @@ -181,35 +238,40 @@ class UpdateSubscribeTool(MoviePilotTool): subscribe_dict["media_category"] = media_category if episode_group is not None: subscribe_dict["episode_group"] = episode_group - + # 如果没有要更新的字段 if not subscribe_dict: - return json.dumps({ - "success": False, - "message": "没有提供要更新的字段" - }, ensure_ascii=False) - + return json.dumps( + {"success": False, "message": "没有提供要更新的字段"}, + ensure_ascii=False, + ) + # 更新订阅 await subscribe.async_update(db, subscribe_dict) - + # 重新获取更新后的订阅数据 updated_subscribe = await Subscribe.async_get(db, subscribe_id) - + # 发送订阅调整事件 - await eventmanager.async_send_event(EventType.SubscribeModified, { - "subscribe_id": subscribe_id, - "old_subscribe_info": old_subscribe_dict, - "subscribe_info": updated_subscribe.to_dict() if updated_subscribe else {}, - }) - + await eventmanager.async_send_event( + EventType.SubscribeModified, + { + "subscribe_id": subscribe_id, + "old_subscribe_info": old_subscribe_dict, + "subscribe_info": updated_subscribe.to_dict() + if updated_subscribe + else {}, + }, + ) + # 构建返回结果 result = { "success": True, "message": f"订阅 #{subscribe_id} 更新成功", "subscribe_id": subscribe_id, - "updated_fields": list(subscribe_dict.keys()) + "updated_fields": list(subscribe_dict.keys()), } - + if updated_subscribe: result["subscribe"] = { "id": updated_subscribe.id, @@ -223,17 +285,19 @@ class UpdateSubscribeTool(MoviePilotTool): "start_episode": updated_subscribe.start_episode, "quality": updated_subscribe.quality, "resolution": updated_subscribe.resolution, - "effect": updated_subscribe.effect + "effect": updated_subscribe.effect, } - + return json.dumps(result, ensure_ascii=False, indent=2) - + except Exception as e: error_message = f"更新订阅失败: {str(e)}" logger.error(f"更新订阅失败: {e}", exc_info=True) - return json.dumps({ - "success": False, - "message": error_message, - "subscribe_id": subscribe_id - }, ensure_ascii=False) - + return json.dumps( + { + "success": False, + "message": error_message, + "subscribe_id": subscribe_id, + }, + ensure_ascii=False, + ) diff --git a/app/agent/tools/impl/write_file.py b/app/agent/tools/impl/write_file.py index 41be6e3b..565c5156 100644 --- a/app/agent/tools/impl/write_file.py +++ b/app/agent/tools/impl/write_file.py @@ -12,6 +12,7 @@ from app.log import logger class WriteFileInput(BaseModel): """Input parameters for write file tool""" + file_path: str = Field(..., description="The absolute path of the file to write") content: str = Field(..., description="The content to write into the file") @@ -20,6 +21,7 @@ class WriteFileTool(MoviePilotTool): name: str = "write_file" description: str = "Write full content to a file. If the file already exists, it will be overwritten. Automatically creates parent directories if they don't exist." args_schema: Type[BaseModel] = WriteFileInput + require_admin: bool = True def get_tool_message(self, **kwargs) -> Optional[str]: """根据参数生成友好的提示消息""" @@ -32,16 +34,16 @@ class WriteFileTool(MoviePilotTool): try: path = AsyncPath(file_path) - + if await path.exists() and not await path.is_file(): return f"错误:{file_path} 路径已存在但不是一个文件" # 自动创建父目录 await path.parent.mkdir(parents=True, exist_ok=True) - + # 写入文件 await path.write_text(content, encoding="utf-8") - + logger.info(f"成功写入文件 {file_path}") return f"成功写入文件 {file_path}"