mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-03-31 09:21:04 +08:00
feat(agent): 为需要管理员权限的工具添加 require_admin 字段
- ExecuteCommandTool: 执行命令行 - DeleteDownloadHistoryTool: 删除下载历史 - EditFileTool: 编辑文件 - WriteFileTool: 写入文件 - TransferFileTool: 传输文件 - UpdateSiteTool: 更新站点 - UpdateSiteCookieTool: 更新站点Cookie - UpdateSubscribeTool: 更新订阅 - DeleteSubscribeTool: 删除订阅 - DeleteDownloadTool: 删除下载 - ModifyDownloadTool: 修改下载 - RunSchedulerTool: 运行定时任务 - RunWorkflowTool: 运行工作流 - RunPluginCommandTool: 运行插件命令 - SendMessageTool: 发送消息
This commit is contained in:
@@ -30,11 +30,13 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
_source: Optional[str] = PrivateAttr(default=None)
|
||||
_username: Optional[str] = PrivateAttr(default=None)
|
||||
_stream_handler: Optional[StreamingHandler] = PrivateAttr(default=None)
|
||||
_require_admin: bool = PrivateAttr(default=False)
|
||||
|
||||
def __init__(self, session_id: str, user_id: str, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._session_id = session_id
|
||||
self._user_id = user_id
|
||||
self._require_admin = getattr(self.__class__, "require_admin", False)
|
||||
|
||||
def _run(self, *args: Any, **kwargs: Any) -> Any:
|
||||
raise NotImplementedError("MoviePilotTool 只支持异步调用,请使用 _arun")
|
||||
@@ -143,11 +145,15 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
async def _check_permission(self) -> Optional[str]:
|
||||
"""
|
||||
检查用户权限:
|
||||
1. 首先检查用户是否是渠道管理员
|
||||
2. 如果渠道没有设置管理员名单,则检查用户是否是系统管理员
|
||||
3. 如果都不是系统管理员,检查用户ID是否等于渠道配置的用户ID
|
||||
4. 如果都不是,返回权限拒绝消息
|
||||
1. 首先检查工具是否需要管理员权限
|
||||
2. 如果需要管理员权限,则检查用户是否是渠道管理员
|
||||
3. 如果渠道没有设置管理员名单,则检查用户是否是系统管理员
|
||||
4. 如果都不是系统管理员,检查用户ID是否等于渠道配置的用户ID
|
||||
5. 如果都不是,返回权限拒绝消息
|
||||
"""
|
||||
if not self._require_admin:
|
||||
return None
|
||||
|
||||
if not self._channel or not self._source:
|
||||
return None
|
||||
|
||||
|
||||
@@ -11,46 +11,68 @@ from app.log import logger
|
||||
|
||||
class DeleteDownloadInput(BaseModel):
|
||||
"""删除下载任务工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
hash: str = Field(..., description="Task hash (can be obtained from query_download_tasks tool)")
|
||||
downloader: Optional[str] = Field(None, description="Name of specific downloader (optional, if not provided will search all downloaders)")
|
||||
delete_files: Optional[bool] = Field(False, description="Whether to delete downloaded files along with the task (default: False, only removes the task from downloader)")
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
hash: str = Field(
|
||||
..., description="Task hash (can be obtained from query_download_tasks tool)"
|
||||
)
|
||||
downloader: Optional[str] = Field(
|
||||
None,
|
||||
description="Name of specific downloader (optional, if not provided will search all downloaders)",
|
||||
)
|
||||
delete_files: Optional[bool] = Field(
|
||||
False,
|
||||
description="Whether to delete downloaded files along with the task (default: False, only removes the task from downloader)",
|
||||
)
|
||||
|
||||
|
||||
class DeleteDownloadTool(MoviePilotTool):
|
||||
name: str = "delete_download"
|
||||
description: str = "Delete a download task from the downloader by task hash only. Optionally specify the downloader name and whether to delete downloaded files."
|
||||
args_schema: Type[BaseModel] = DeleteDownloadInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据删除参数生成友好的提示消息"""
|
||||
hash_value = kwargs.get("hash", "")
|
||||
downloader = kwargs.get("downloader")
|
||||
delete_files = kwargs.get("delete_files", False)
|
||||
|
||||
|
||||
message = f"正在删除下载任务: {hash_value}"
|
||||
if downloader:
|
||||
message += f" [下载器: {downloader}]"
|
||||
if delete_files:
|
||||
message += " (包含文件)"
|
||||
|
||||
|
||||
return message
|
||||
|
||||
async def run(self, hash: str, downloader: Optional[str] = None,
|
||||
delete_files: Optional[bool] = False, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: hash={hash}, downloader={downloader}, delete_files={delete_files}")
|
||||
async def run(
|
||||
self,
|
||||
hash: str,
|
||||
downloader: Optional[str] = None,
|
||||
delete_files: Optional[bool] = False,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: hash={hash}, downloader={downloader}, delete_files={delete_files}"
|
||||
)
|
||||
|
||||
try:
|
||||
download_chain = DownloadChain()
|
||||
|
||||
# 仅支持通过hash删除任务
|
||||
if len(hash) != 40 or not all(c in '0123456789abcdefABCDEF' for c in hash):
|
||||
if len(hash) != 40 or not all(c in "0123456789abcdefABCDEF" for c in hash):
|
||||
return "参数错误:hash 格式无效,请先使用 query_download_tasks 工具获取正确的 hash。"
|
||||
|
||||
|
||||
# 删除下载任务
|
||||
# remove_torrents 支持 delete_file 参数,可以控制是否删除文件
|
||||
result = download_chain.remove_torrents(hashs=[hash], downloader=downloader, delete_file=delete_files)
|
||||
|
||||
result = download_chain.remove_torrents(
|
||||
hashs=[hash], downloader=downloader, delete_file=delete_files
|
||||
)
|
||||
|
||||
if result:
|
||||
files_info = "(包含文件)" if delete_files else "(不包含文件)"
|
||||
return f"成功删除下载任务:{hash} {files_info}"
|
||||
@@ -59,4 +81,3 @@ class DeleteDownloadTool(MoviePilotTool):
|
||||
except Exception as e:
|
||||
logger.error(f"删除下载任务失败: {e}", exc_info=True)
|
||||
return f"删除下载任务时发生错误: {str(e)}"
|
||||
|
||||
|
||||
@@ -26,6 +26,7 @@ class DeleteDownloadHistoryTool(MoviePilotTool):
|
||||
name: str = "delete_download_history"
|
||||
description: str = "Delete a download history record by ID. This only removes the record from the database, does not delete any actual files."
|
||||
args_schema: Type[BaseModel] = DeleteDownloadHistoryInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
history_id = kwargs.get("history_id")
|
||||
|
||||
@@ -14,14 +14,22 @@ from app.schemas.types import EventType
|
||||
|
||||
class DeleteSubscribeInput(BaseModel):
|
||||
"""删除订阅工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
subscribe_id: int = Field(..., description="The ID of the subscription to delete (can be obtained from query_subscribes tool)")
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
subscribe_id: int = Field(
|
||||
...,
|
||||
description="The ID of the subscription to delete (can be obtained from query_subscribes tool)",
|
||||
)
|
||||
|
||||
|
||||
class DeleteSubscribeTool(MoviePilotTool):
|
||||
name: str = "delete_subscribe"
|
||||
description: str = "Delete a media subscription by its ID. This will remove the subscription and stop automatic downloads for that media."
|
||||
args_schema: Type[BaseModel] = DeleteSubscribeInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据删除参数生成友好的提示消息"""
|
||||
@@ -37,27 +45,25 @@ class DeleteSubscribeTool(MoviePilotTool):
|
||||
subscribe = await subscribe_oper.async_get(subscribe_id)
|
||||
if not subscribe:
|
||||
return f"订阅 ID {subscribe_id} 不存在"
|
||||
|
||||
|
||||
# 在删除之前获取订阅信息(用于事件)
|
||||
subscribe_info = subscribe.to_dict()
|
||||
|
||||
|
||||
# 删除订阅
|
||||
subscribe_oper.delete(subscribe_id)
|
||||
|
||||
|
||||
# 发送事件
|
||||
await eventmanager.async_send_event(EventType.SubscribeDeleted, {
|
||||
"subscribe_id": subscribe_id,
|
||||
"subscribe_info": subscribe_info
|
||||
})
|
||||
|
||||
await eventmanager.async_send_event(
|
||||
EventType.SubscribeDeleted,
|
||||
{"subscribe_id": subscribe_id, "subscribe_info": subscribe_info},
|
||||
)
|
||||
|
||||
# 统计订阅
|
||||
SubscribeHelper().sub_done_async({
|
||||
"tmdbid": subscribe.tmdbid,
|
||||
"doubanid": subscribe.doubanid
|
||||
})
|
||||
|
||||
SubscribeHelper().sub_done_async(
|
||||
{"tmdbid": subscribe.tmdbid, "doubanid": subscribe.doubanid}
|
||||
)
|
||||
|
||||
return f"成功删除订阅:{subscribe.name} ({subscribe.year})"
|
||||
except Exception as e:
|
||||
logger.error(f"删除订阅失败: {e}", exc_info=True)
|
||||
return f"删除订阅时发生错误: {str(e)}"
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ from app.log import logger
|
||||
|
||||
class EditFileInput(BaseModel):
|
||||
"""Input parameters for edit file tool"""
|
||||
|
||||
file_path: str = Field(..., description="The absolute path of the file to edit")
|
||||
old_text: str = Field(..., description="The exact old text to be replaced")
|
||||
new_text: str = Field(..., description="The new text to replace with")
|
||||
@@ -21,6 +22,7 @@ class EditFileTool(MoviePilotTool):
|
||||
name: str = "edit_file"
|
||||
description: str = "Edit a file by replacing specific old text with new text. Useful for modifying configuration files, code, or scripts."
|
||||
args_schema: Type[BaseModel] = EditFileInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据参数生成友好的提示消息"""
|
||||
@@ -38,7 +40,7 @@ class EditFileTool(MoviePilotTool):
|
||||
# 如果 old_text 为空,可能用户想直接创建文件,但通常 edit_file 需要匹配旧内容
|
||||
if old_text:
|
||||
return f"错误:文件 {file_path} 不存在,无法进行内容替换。"
|
||||
|
||||
|
||||
if await path.exists() and not await path.is_file():
|
||||
return f"错误:{file_path} 不是一个文件"
|
||||
|
||||
@@ -56,14 +58,13 @@ class EditFileTool(MoviePilotTool):
|
||||
|
||||
# 自动创建父目录
|
||||
await path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
# 写入文件
|
||||
await path.write_text(new_content, encoding="utf-8")
|
||||
|
||||
|
||||
logger.info(f"成功编辑文件 {file_path},替换了 {occurrences} 处内容")
|
||||
return f"成功编辑文件 {file_path} (替换了 {occurrences} 处匹配内容)"
|
||||
|
||||
|
||||
except PermissionError:
|
||||
return f"错误:没有访问/修改 {file_path} 的权限"
|
||||
except UnicodeDecodeError:
|
||||
@@ -71,5 +72,3 @@ class EditFileTool(MoviePilotTool):
|
||||
except Exception as e:
|
||||
logger.error(f"编辑文件 {file_path} 时发生错误: {str(e)}", exc_info=True)
|
||||
return f"操作失败: {str(e)}"
|
||||
|
||||
|
||||
|
||||
@@ -11,15 +11,21 @@ from app.log import logger
|
||||
|
||||
class ExecuteCommandInput(BaseModel):
|
||||
"""执行Shell命令工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this command is being executed")
|
||||
|
||||
explanation: str = Field(
|
||||
..., description="Clear explanation of why this command is being executed"
|
||||
)
|
||||
command: str = Field(..., description="The shell command to execute")
|
||||
timeout: Optional[int] = Field(60, description="Max execution time in seconds (default: 60)")
|
||||
timeout: Optional[int] = Field(
|
||||
60, description="Max execution time in seconds (default: 60)"
|
||||
)
|
||||
|
||||
|
||||
class ExecuteCommandTool(MoviePilotTool):
|
||||
name: str = "execute_command"
|
||||
description: str = "Safely execute shell commands on the server. Useful for system maintenance, checking status, or running custom scripts. Includes timeout and output limits."
|
||||
args_schema: Type[BaseModel] = ExecuteCommandInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据命令生成友好的提示消息"""
|
||||
@@ -27,10 +33,19 @@ class ExecuteCommandTool(MoviePilotTool):
|
||||
return f"正在执行系统命令: {command}"
|
||||
|
||||
async def run(self, command: str, timeout: Optional[int] = 60, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: command={command}, timeout={timeout}")
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: command={command}, timeout={timeout}"
|
||||
)
|
||||
|
||||
# 简单安全过滤
|
||||
forbidden_keywords = ["rm -rf /", ":(){ :|:& };:", "dd if=/dev/zero", "mkfs", "reboot", "shutdown"]
|
||||
forbidden_keywords = [
|
||||
"rm -rf /",
|
||||
":(){ :|:& };:",
|
||||
"dd if=/dev/zero",
|
||||
"mkfs",
|
||||
"reboot",
|
||||
"shutdown",
|
||||
]
|
||||
for keyword in forbidden_keywords:
|
||||
if keyword in command:
|
||||
return f"错误:命令包含禁止使用的关键字 '{keyword}'"
|
||||
@@ -38,18 +53,18 @@ class ExecuteCommandTool(MoviePilotTool):
|
||||
try:
|
||||
# 执行命令
|
||||
process = await asyncio.create_subprocess_shell(
|
||||
command,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE
|
||||
command, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
|
||||
)
|
||||
|
||||
try:
|
||||
# 等待完成,带超时
|
||||
stdout, stderr = await asyncio.wait_for(process.communicate(), timeout=timeout)
|
||||
|
||||
stdout, stderr = await asyncio.wait_for(
|
||||
process.communicate(), timeout=timeout
|
||||
)
|
||||
|
||||
# 处理输出
|
||||
stdout_str = stdout.decode('utf-8', errors='replace').strip()
|
||||
stderr_str = stderr.decode('utf-8', errors='replace').strip()
|
||||
stdout_str = stdout.decode("utf-8", errors="replace").strip()
|
||||
stderr_str = stderr.decode("utf-8", errors="replace").strip()
|
||||
exit_code = process.returncode
|
||||
|
||||
result = f"命令执行完成 (退出码: {exit_code})"
|
||||
@@ -57,15 +72,15 @@ class ExecuteCommandTool(MoviePilotTool):
|
||||
result += f"\n\n标准输出:\n{stdout_str}"
|
||||
if stderr_str:
|
||||
result += f"\n\n错误输出:\n{stderr_str}"
|
||||
|
||||
|
||||
# 如果没有输出
|
||||
if not stdout_str and not stderr_str:
|
||||
result += "\n\n(无输出内容)"
|
||||
|
||||
|
||||
# 限制输出长度,防止上下文过长
|
||||
if len(result) > 3000:
|
||||
result = result[:3000] + "\n\n...(输出内容过长,已截断)"
|
||||
|
||||
|
||||
return result
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
|
||||
@@ -47,6 +47,7 @@ class ModifyDownloadTool(MoviePilotTool):
|
||||
"Multiple operations can be performed in a single call."
|
||||
)
|
||||
args_schema: Type[BaseModel] = ModifyDownloadInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
hash_value = kwargs.get("hash", "")
|
||||
|
||||
@@ -37,6 +37,7 @@ class RunPluginCommandTool(MoviePilotTool):
|
||||
"Note: This tool triggers the command execution but the actual processing happens in the background."
|
||||
)
|
||||
args_schema: Type[BaseModel] = RunPluginCommandInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""生成友好的提示消息"""
|
||||
|
||||
@@ -11,14 +11,22 @@ from app.scheduler import Scheduler
|
||||
|
||||
class RunSchedulerInput(BaseModel):
|
||||
"""运行定时服务工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
job_id: str = Field(..., description="The ID of the scheduled job to run (can be obtained from query_schedulers tool)")
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
job_id: str = Field(
|
||||
...,
|
||||
description="The ID of the scheduled job to run (can be obtained from query_schedulers tool)",
|
||||
)
|
||||
|
||||
|
||||
class RunSchedulerTool(MoviePilotTool):
|
||||
name: str = "run_scheduler"
|
||||
description: str = "Manually trigger a scheduled task to run immediately. This will execute the specified scheduler job by its ID."
|
||||
args_schema: Type[BaseModel] = RunSchedulerInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据运行参数生成友好的提示消息"""
|
||||
@@ -39,15 +47,14 @@ class RunSchedulerTool(MoviePilotTool):
|
||||
job_exists = True
|
||||
job_name = s.name
|
||||
break
|
||||
|
||||
|
||||
if not job_exists:
|
||||
return f"定时服务 ID {job_id} 不存在,请使用 query_schedulers 工具查询可用的定时服务"
|
||||
|
||||
|
||||
# 运行定时服务
|
||||
scheduler.start(job_id)
|
||||
|
||||
|
||||
return f"成功触发定时服务:{job_name} (ID: {job_id})"
|
||||
except Exception as e:
|
||||
logger.error(f"运行定时服务失败: {e}", exc_info=True)
|
||||
return f"运行定时服务时发生错误: {str(e)}"
|
||||
|
||||
|
||||
@@ -13,46 +13,61 @@ from app.log import logger
|
||||
|
||||
class RunWorkflowInput(BaseModel):
|
||||
"""执行工作流工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
workflow_id: int = Field(..., description="Workflow ID (can be obtained from query_workflows tool)")
|
||||
from_begin: Optional[bool] = Field(True, description="Whether to run workflow from the beginning (default: True, if False will continue from last executed action)")
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
workflow_id: int = Field(
|
||||
..., description="Workflow ID (can be obtained from query_workflows tool)"
|
||||
)
|
||||
from_begin: Optional[bool] = Field(
|
||||
True,
|
||||
description="Whether to run workflow from the beginning (default: True, if False will continue from last executed action)",
|
||||
)
|
||||
|
||||
|
||||
class RunWorkflowTool(MoviePilotTool):
|
||||
name: str = "run_workflow"
|
||||
description: str = "Execute a specific workflow manually by workflow ID. Supports running from the beginning or continuing from the last executed action."
|
||||
args_schema: Type[BaseModel] = RunWorkflowInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据工作流参数生成友好的提示消息"""
|
||||
workflow_id = kwargs.get("workflow_id")
|
||||
from_begin = kwargs.get("from_begin", True)
|
||||
|
||||
|
||||
message = f"正在执行工作流: {workflow_id}"
|
||||
if not from_begin:
|
||||
message += " (从上次位置继续)"
|
||||
else:
|
||||
message += " (从头开始)"
|
||||
|
||||
|
||||
return message
|
||||
|
||||
async def run(self, workflow_id: int,
|
||||
from_begin: Optional[bool] = True, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: workflow_id={workflow_id}, from_begin={from_begin}")
|
||||
async def run(
|
||||
self, workflow_id: int, from_begin: Optional[bool] = True, **kwargs
|
||||
) -> str:
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: workflow_id={workflow_id}, from_begin={from_begin}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 获取数据库会话
|
||||
async with AsyncSessionFactory() as db:
|
||||
workflow_oper = WorkflowOper(db)
|
||||
workflow = await workflow_oper.async_get(workflow_id)
|
||||
|
||||
|
||||
if not workflow:
|
||||
return f"未找到工作流:{workflow_id},请使用 query_workflows 工具查询可用的工作流"
|
||||
|
||||
|
||||
# 执行工作流
|
||||
workflow_chain = WorkflowChain()
|
||||
state, errmsg = workflow_chain.process(workflow.id, from_begin=from_begin)
|
||||
|
||||
state, errmsg = workflow_chain.process(
|
||||
workflow.id, from_begin=from_begin
|
||||
)
|
||||
|
||||
if not state:
|
||||
return f"执行工作流失败:{workflow.name} (ID: {workflow.id})\n错误原因:{errmsg}"
|
||||
else:
|
||||
@@ -60,4 +75,3 @@ class RunWorkflowTool(MoviePilotTool):
|
||||
except Exception as e:
|
||||
logger.error(f"执行工作流失败: {e}", exc_info=True)
|
||||
return f"执行工作流时发生错误: {str(e)}"
|
||||
|
||||
|
||||
@@ -10,33 +10,52 @@ from app.log import logger
|
||||
|
||||
class SendMessageInput(BaseModel):
|
||||
"""发送消息工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
message: str = Field(..., description="The message content to send to the user (should be clear and informative)")
|
||||
message_type: Optional[str] = Field("info",
|
||||
description="Type of message: 'info' for general information, 'success' for successful operations, 'warning' for warnings, 'error' for error messages")
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
message: str = Field(
|
||||
...,
|
||||
description="The message content to send to the user (should be clear and informative)",
|
||||
)
|
||||
message_type: Optional[str] = Field(
|
||||
"info",
|
||||
description="Type of message: 'info' for general information, 'success' for successful operations, 'warning' for warnings, 'error' for error messages",
|
||||
)
|
||||
|
||||
|
||||
class SendMessageTool(MoviePilotTool):
|
||||
name: str = "send_message"
|
||||
description: str = "Send notification message to the user through configured notification channels (Telegram, Slack, WeChat, etc.). Used to inform users about operation results, errors, or important updates."
|
||||
args_schema: Type[BaseModel] = SendMessageInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据消息参数生成友好的提示消息"""
|
||||
message = kwargs.get("message", "")
|
||||
message_type = kwargs.get("message_type", "info")
|
||||
|
||||
type_map = {"info": "信息", "success": "成功", "warning": "警告", "error": "错误"}
|
||||
|
||||
type_map = {
|
||||
"info": "信息",
|
||||
"success": "成功",
|
||||
"warning": "警告",
|
||||
"error": "错误",
|
||||
}
|
||||
type_desc = type_map.get(message_type, message_type)
|
||||
|
||||
|
||||
# 截断过长的消息
|
||||
if len(message) > 50:
|
||||
message = message[:50] + "..."
|
||||
|
||||
|
||||
return f"正在发送{type_desc}消息: {message}"
|
||||
|
||||
async def run(self, message: str, message_type: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: message={message}, message_type={message_type}")
|
||||
async def run(
|
||||
self, message: str, message_type: Optional[str] = None, **kwargs
|
||||
) -> str:
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: message={message}, message_type={message_type}"
|
||||
)
|
||||
try:
|
||||
await self.send_tool_message(message, title=message_type)
|
||||
return "消息已发送"
|
||||
|
||||
@@ -13,23 +13,53 @@ from app.schemas import FileItem, MediaType
|
||||
|
||||
class TransferFileInput(BaseModel):
|
||||
"""整理文件或目录工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
file_path: str = Field(..., description="Path to the file or directory to transfer (e.g., '/path/to/file.mkv' or '/path/to/directory')")
|
||||
storage: Optional[str] = Field("local", description="Storage type of the source file (default: 'local', can be 'smb', 'alist', etc.)")
|
||||
target_path: Optional[str] = Field(None, description="Target path for the transferred file/directory (optional, uses default library path if not specified)")
|
||||
target_storage: Optional[str] = Field(None, description="Target storage type (optional, uses default storage if not specified)")
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
file_path: str = Field(
|
||||
...,
|
||||
description="Path to the file or directory to transfer (e.g., '/path/to/file.mkv' or '/path/to/directory')",
|
||||
)
|
||||
storage: Optional[str] = Field(
|
||||
"local",
|
||||
description="Storage type of the source file (default: 'local', can be 'smb', 'alist', etc.)",
|
||||
)
|
||||
target_path: Optional[str] = Field(
|
||||
None,
|
||||
description="Target path for the transferred file/directory (optional, uses default library path if not specified)",
|
||||
)
|
||||
target_storage: Optional[str] = Field(
|
||||
None,
|
||||
description="Target storage type (optional, uses default storage if not specified)",
|
||||
)
|
||||
media_type: Optional[str] = Field(None, description="Allowed values: movie, tv")
|
||||
tmdbid: Optional[int] = Field(None, description="TMDB ID for precise media identification (optional but recommended for accuracy)")
|
||||
doubanid: Optional[str] = Field(None, description="Douban ID for media identification (optional)")
|
||||
season: Optional[int] = Field(None, description="Season number for TV shows (optional)")
|
||||
transfer_type: Optional[str] = Field(None, description="Transfer mode: 'move' to move files, 'copy' to copy files, 'link' for hard link, 'softlink' for symbolic link (optional, uses default mode if not specified)")
|
||||
background: Optional[bool] = Field(False, description="Whether to run transfer in background (default: False, runs synchronously)")
|
||||
tmdbid: Optional[int] = Field(
|
||||
None,
|
||||
description="TMDB ID for precise media identification (optional but recommended for accuracy)",
|
||||
)
|
||||
doubanid: Optional[str] = Field(
|
||||
None, description="Douban ID for media identification (optional)"
|
||||
)
|
||||
season: Optional[int] = Field(
|
||||
None, description="Season number for TV shows (optional)"
|
||||
)
|
||||
transfer_type: Optional[str] = Field(
|
||||
None,
|
||||
description="Transfer mode: 'move' to move files, 'copy' to copy files, 'link' for hard link, 'softlink' for symbolic link (optional, uses default mode if not specified)",
|
||||
)
|
||||
background: Optional[bool] = Field(
|
||||
False,
|
||||
description="Whether to run transfer in background (default: False, runs synchronously)",
|
||||
)
|
||||
|
||||
|
||||
class TransferFileTool(MoviePilotTool):
|
||||
name: str = "transfer_file"
|
||||
description: str = "Transfer/organize a file or directory to the media library. Automatically recognizes media information and organizes files according to configured rules. Supports custom target paths, media identification, and transfer modes."
|
||||
args_schema: Type[BaseModel] = TransferFileInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据整理参数生成友好的提示消息"""
|
||||
@@ -37,66 +67,79 @@ class TransferFileTool(MoviePilotTool):
|
||||
media_type = kwargs.get("media_type")
|
||||
transfer_type = kwargs.get("transfer_type")
|
||||
background = kwargs.get("background", False)
|
||||
|
||||
|
||||
message = f"正在整理文件: {file_path}"
|
||||
if media_type:
|
||||
message += f" [{media_type}]"
|
||||
if transfer_type:
|
||||
transfer_map = {"move": "移动", "copy": "复制", "link": "硬链接", "softlink": "软链接"}
|
||||
transfer_map = {
|
||||
"move": "移动",
|
||||
"copy": "复制",
|
||||
"link": "硬链接",
|
||||
"softlink": "软链接",
|
||||
}
|
||||
message += f" 模式: {transfer_map.get(transfer_type, transfer_type)}"
|
||||
if background:
|
||||
message += " [后台运行]"
|
||||
|
||||
|
||||
return message
|
||||
|
||||
async def run(self, file_path: str, storage: Optional[str] = "local",
|
||||
target_path: Optional[str] = None,
|
||||
target_storage: Optional[str] = None,
|
||||
media_type: Optional[str] = None,
|
||||
tmdbid: Optional[int] = None,
|
||||
doubanid: Optional[str] = None,
|
||||
season: Optional[int] = None,
|
||||
transfer_type: Optional[str] = None,
|
||||
background: Optional[bool] = False, **kwargs) -> str:
|
||||
async def run(
|
||||
self,
|
||||
file_path: str,
|
||||
storage: Optional[str] = "local",
|
||||
target_path: Optional[str] = None,
|
||||
target_storage: Optional[str] = None,
|
||||
media_type: Optional[str] = None,
|
||||
tmdbid: Optional[int] = None,
|
||||
doubanid: Optional[str] = None,
|
||||
season: Optional[int] = None,
|
||||
transfer_type: Optional[str] = None,
|
||||
background: Optional[bool] = False,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: file_path={file_path}, storage={storage}, target_path={target_path}, "
|
||||
f"target_storage={target_storage}, media_type={media_type}, tmdbid={tmdbid}, doubanid={doubanid}, "
|
||||
f"season={season}, transfer_type={transfer_type}, background={background}")
|
||||
f"season={season}, transfer_type={transfer_type}, background={background}"
|
||||
)
|
||||
|
||||
try:
|
||||
if not file_path:
|
||||
return "错误:必须提供文件或目录路径"
|
||||
|
||||
|
||||
# 规范化路径
|
||||
if storage == "local":
|
||||
# 本地路径处理
|
||||
if not file_path.startswith("/") and not (len(file_path) > 1 and file_path[1] == ":"):
|
||||
if not file_path.startswith("/") and not (
|
||||
len(file_path) > 1 and file_path[1] == ":"
|
||||
):
|
||||
# 相对路径,尝试转换为绝对路径
|
||||
file_path = str(Path(file_path).resolve())
|
||||
else:
|
||||
# 远程存储路径,确保以/开头
|
||||
if not file_path.startswith("/"):
|
||||
file_path = "/" + file_path
|
||||
|
||||
|
||||
# 创建FileItem
|
||||
fileitem = FileItem(
|
||||
storage=storage or "local",
|
||||
path=file_path,
|
||||
type="dir" if file_path.endswith("/") else "file"
|
||||
type="dir" if file_path.endswith("/") else "file",
|
||||
)
|
||||
|
||||
|
||||
# 处理目标路径
|
||||
target_path_obj = None
|
||||
if target_path:
|
||||
target_path_obj = Path(target_path)
|
||||
|
||||
|
||||
# 处理媒体类型
|
||||
media_type_enum = None
|
||||
if media_type:
|
||||
media_type_enum = MediaType.from_agent(media_type)
|
||||
if not media_type_enum:
|
||||
return f"错误:无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv'"
|
||||
|
||||
|
||||
# 调用整理方法
|
||||
transfer_chain = TransferChain()
|
||||
state, errormsg = transfer_chain.manual_transfer(
|
||||
@@ -108,15 +151,17 @@ class TransferFileTool(MoviePilotTool):
|
||||
mtype=media_type_enum,
|
||||
season=season,
|
||||
transfer_type=transfer_type,
|
||||
background=background
|
||||
background=background,
|
||||
)
|
||||
|
||||
|
||||
if not state:
|
||||
# 处理错误信息
|
||||
if isinstance(errormsg, list):
|
||||
error_text = f"整理完成,{len(errormsg)} 个文件转移失败"
|
||||
if errormsg:
|
||||
error_text += f":\n" + "\n".join(str(e) for e in errormsg[:5]) # 只显示前5个错误
|
||||
error_text += f":\n" + "\n".join(
|
||||
str(e) for e in errormsg[:5]
|
||||
) # 只显示前5个错误
|
||||
if len(errormsg) > 5:
|
||||
error_text += f"\n... 还有 {len(errormsg) - 5} 个错误"
|
||||
else:
|
||||
@@ -130,4 +175,3 @@ class TransferFileTool(MoviePilotTool):
|
||||
except Exception as e:
|
||||
logger.error(f"整理文件失败: {e}", exc_info=True)
|
||||
return f"整理文件时发生错误: {str(e)}"
|
||||
|
||||
|
||||
@@ -16,37 +16,67 @@ from app.utils.string import StringUtils
|
||||
|
||||
class UpdateSiteInput(BaseModel):
|
||||
"""更新站点工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
site_id: int = Field(..., description="The ID of the site to update (can be obtained from query_sites tool)")
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
site_id: int = Field(
|
||||
...,
|
||||
description="The ID of the site to update (can be obtained from query_sites tool)",
|
||||
)
|
||||
name: Optional[str] = Field(None, description="Site name (optional)")
|
||||
url: Optional[str] = Field(None, description="Site URL (optional, will be automatically formatted)")
|
||||
pri: Optional[int] = Field(None, description="Site priority (optional, smaller value = higher priority, e.g., pri=1 has higher priority than pri=10)")
|
||||
url: Optional[str] = Field(
|
||||
None, description="Site URL (optional, will be automatically formatted)"
|
||||
)
|
||||
pri: Optional[int] = Field(
|
||||
None,
|
||||
description="Site priority (optional, smaller value = higher priority, e.g., pri=1 has higher priority than pri=10)",
|
||||
)
|
||||
rss: Optional[str] = Field(None, description="RSS feed URL (optional)")
|
||||
cookie: Optional[str] = Field(None, description="Site cookie (optional)")
|
||||
ua: Optional[str] = Field(None, description="User-Agent string (optional)")
|
||||
apikey: Optional[str] = Field(None, description="API key (optional)")
|
||||
token: Optional[str] = Field(None, description="API token (optional)")
|
||||
proxy: Optional[int] = Field(None, description="Whether to use proxy: 0 for no, 1 for yes (optional)")
|
||||
filter: Optional[str] = Field(None, description="Filter rule as regular expression (optional)")
|
||||
proxy: Optional[int] = Field(
|
||||
None, description="Whether to use proxy: 0 for no, 1 for yes (optional)"
|
||||
)
|
||||
filter: Optional[str] = Field(
|
||||
None, description="Filter rule as regular expression (optional)"
|
||||
)
|
||||
note: Optional[str] = Field(None, description="Site notes/remarks (optional)")
|
||||
timeout: Optional[int] = Field(None, description="Request timeout in seconds (optional, default: 15)")
|
||||
limit_interval: Optional[int] = Field(None, description="Rate limit interval in seconds (optional)")
|
||||
limit_count: Optional[int] = Field(None, description="Rate limit count per interval (optional)")
|
||||
limit_seconds: Optional[int] = Field(None, description="Rate limit seconds between requests (optional)")
|
||||
is_active: Optional[bool] = Field(None, description="Whether site is active: True for enabled, False for disabled (optional)")
|
||||
downloader: Optional[str] = Field(None, description="Downloader name for this site (optional)")
|
||||
timeout: Optional[int] = Field(
|
||||
None, description="Request timeout in seconds (optional, default: 15)"
|
||||
)
|
||||
limit_interval: Optional[int] = Field(
|
||||
None, description="Rate limit interval in seconds (optional)"
|
||||
)
|
||||
limit_count: Optional[int] = Field(
|
||||
None, description="Rate limit count per interval (optional)"
|
||||
)
|
||||
limit_seconds: Optional[int] = Field(
|
||||
None, description="Rate limit seconds between requests (optional)"
|
||||
)
|
||||
is_active: Optional[bool] = Field(
|
||||
None,
|
||||
description="Whether site is active: True for enabled, False for disabled (optional)",
|
||||
)
|
||||
downloader: Optional[str] = Field(
|
||||
None, description="Downloader name for this site (optional)"
|
||||
)
|
||||
|
||||
|
||||
class UpdateSiteTool(MoviePilotTool):
|
||||
name: str = "update_site"
|
||||
description: str = "Update site configuration including URL, priority, authentication credentials (cookie, UA, API key), proxy settings, rate limits, and other site properties. Supports updating multiple site attributes at once. Site priority (pri): smaller values have higher priority (e.g., pri=1 has higher priority than pri=10)."
|
||||
args_schema: Type[BaseModel] = UpdateSiteInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据更新参数生成友好的提示消息"""
|
||||
site_id = kwargs.get("site_id")
|
||||
fields_updated = []
|
||||
|
||||
|
||||
if kwargs.get("name"):
|
||||
fields_updated.append("名称")
|
||||
if kwargs.get("url"):
|
||||
@@ -63,60 +93,63 @@ class UpdateSiteTool(MoviePilotTool):
|
||||
fields_updated.append("启用状态")
|
||||
if kwargs.get("downloader"):
|
||||
fields_updated.append("下载器")
|
||||
|
||||
|
||||
if fields_updated:
|
||||
return f"正在更新站点 #{site_id}: {', '.join(fields_updated)}"
|
||||
return f"正在更新站点 #{site_id}"
|
||||
|
||||
async def run(self, site_id: int,
|
||||
name: Optional[str] = None,
|
||||
url: Optional[str] = None,
|
||||
pri: Optional[int] = None,
|
||||
rss: Optional[str] = None,
|
||||
cookie: Optional[str] = None,
|
||||
ua: Optional[str] = None,
|
||||
apikey: Optional[str] = None,
|
||||
token: Optional[str] = None,
|
||||
proxy: Optional[int] = None,
|
||||
filter: Optional[str] = None,
|
||||
note: Optional[str] = None,
|
||||
timeout: Optional[int] = None,
|
||||
limit_interval: Optional[int] = None,
|
||||
limit_count: Optional[int] = None,
|
||||
limit_seconds: Optional[int] = None,
|
||||
is_active: Optional[bool] = None,
|
||||
downloader: Optional[str] = None,
|
||||
**kwargs) -> str:
|
||||
async def run(
|
||||
self,
|
||||
site_id: int,
|
||||
name: Optional[str] = None,
|
||||
url: Optional[str] = None,
|
||||
pri: Optional[int] = None,
|
||||
rss: Optional[str] = None,
|
||||
cookie: Optional[str] = None,
|
||||
ua: Optional[str] = None,
|
||||
apikey: Optional[str] = None,
|
||||
token: Optional[str] = None,
|
||||
proxy: Optional[int] = None,
|
||||
filter: Optional[str] = None,
|
||||
note: Optional[str] = None,
|
||||
timeout: Optional[int] = None,
|
||||
limit_interval: Optional[int] = None,
|
||||
limit_count: Optional[int] = None,
|
||||
limit_seconds: Optional[int] = None,
|
||||
is_active: Optional[bool] = None,
|
||||
downloader: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: site_id={site_id}")
|
||||
|
||||
|
||||
try:
|
||||
# 获取数据库会话
|
||||
async with AsyncSessionFactory() as db:
|
||||
# 获取站点
|
||||
site = await Site.async_get(db, site_id)
|
||||
if not site:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": f"站点不存在: {site_id}"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
return json.dumps(
|
||||
{"success": False, "message": f"站点不存在: {site_id}"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
# 构建更新字典
|
||||
site_dict = {}
|
||||
|
||||
|
||||
# 基本信息
|
||||
if name is not None:
|
||||
site_dict["name"] = name
|
||||
|
||||
|
||||
# URL处理(需要校正格式)
|
||||
if url is not None:
|
||||
_scheme, _netloc = StringUtils.get_url_netloc(url)
|
||||
site_dict["url"] = f"{_scheme}://{_netloc}/"
|
||||
|
||||
|
||||
if pri is not None:
|
||||
site_dict["pri"] = pri
|
||||
if rss is not None:
|
||||
site_dict["rss"] = rss
|
||||
|
||||
|
||||
# 认证信息
|
||||
if cookie is not None:
|
||||
site_dict["cookie"] = cookie
|
||||
@@ -126,7 +159,7 @@ class UpdateSiteTool(MoviePilotTool):
|
||||
site_dict["apikey"] = apikey
|
||||
if token is not None:
|
||||
site_dict["token"] = token
|
||||
|
||||
|
||||
# 配置选项
|
||||
if proxy is not None:
|
||||
site_dict["proxy"] = proxy
|
||||
@@ -136,7 +169,7 @@ class UpdateSiteTool(MoviePilotTool):
|
||||
site_dict["note"] = note
|
||||
if timeout is not None:
|
||||
site_dict["timeout"] = timeout
|
||||
|
||||
|
||||
# 流控设置
|
||||
if limit_interval is not None:
|
||||
site_dict["limit_interval"] = limit_interval
|
||||
@@ -144,39 +177,40 @@ class UpdateSiteTool(MoviePilotTool):
|
||||
site_dict["limit_count"] = limit_count
|
||||
if limit_seconds is not None:
|
||||
site_dict["limit_seconds"] = limit_seconds
|
||||
|
||||
|
||||
# 状态和下载器
|
||||
if is_active is not None:
|
||||
site_dict["is_active"] = is_active
|
||||
if downloader is not None:
|
||||
site_dict["downloader"] = downloader
|
||||
|
||||
|
||||
# 如果没有要更新的字段
|
||||
if not site_dict:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": "没有提供要更新的字段"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
return json.dumps(
|
||||
{"success": False, "message": "没有提供要更新的字段"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
# 更新站点
|
||||
await site.async_update(db, site_dict)
|
||||
|
||||
|
||||
# 重新获取更新后的站点数据
|
||||
updated_site = await Site.async_get(db, site_id)
|
||||
|
||||
|
||||
# 发送站点更新事件
|
||||
await eventmanager.async_send_event(EventType.SiteUpdated, {
|
||||
"domain": updated_site.domain if updated_site else site.domain
|
||||
})
|
||||
|
||||
await eventmanager.async_send_event(
|
||||
EventType.SiteUpdated,
|
||||
{"domain": updated_site.domain if updated_site else site.domain},
|
||||
)
|
||||
|
||||
# 构建返回结果
|
||||
result = {
|
||||
"success": True,
|
||||
"message": f"站点 #{site_id} 更新成功",
|
||||
"site_id": site_id,
|
||||
"updated_fields": list(site_dict.keys())
|
||||
"updated_fields": list(site_dict.keys()),
|
||||
}
|
||||
|
||||
|
||||
if updated_site:
|
||||
result["site"] = {
|
||||
"id": updated_site.id,
|
||||
@@ -187,17 +221,15 @@ class UpdateSiteTool(MoviePilotTool):
|
||||
"is_active": updated_site.is_active,
|
||||
"downloader": updated_site.downloader,
|
||||
"proxy": updated_site.proxy,
|
||||
"timeout": updated_site.timeout
|
||||
"timeout": updated_site.timeout,
|
||||
}
|
||||
|
||||
|
||||
return json.dumps(result, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
error_message = f"更新站点失败: {str(e)}"
|
||||
logger.error(f"更新站点失败: {e}", exc_info=True)
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": error_message,
|
||||
"site_id": site_id
|
||||
}, ensure_ascii=False)
|
||||
|
||||
return json.dumps(
|
||||
{"success": False, "message": error_message, "site_id": site_id},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
@@ -12,50 +12,69 @@ from app.log import logger
|
||||
|
||||
class UpdateSiteCookieInput(BaseModel):
|
||||
"""更新站点Cookie和UA工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
site_identifier: int = Field(..., description="Site ID to update Cookie and User-Agent for (can be obtained from query_sites tool)")
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
site_identifier: int = Field(
|
||||
...,
|
||||
description="Site ID to update Cookie and User-Agent for (can be obtained from query_sites tool)",
|
||||
)
|
||||
username: str = Field(..., description="Site login username")
|
||||
password: str = Field(..., description="Site login password")
|
||||
two_step_code: Optional[str] = Field(None, description="Two-step verification code or secret key (optional, required for sites with 2FA enabled)")
|
||||
two_step_code: Optional[str] = Field(
|
||||
None,
|
||||
description="Two-step verification code or secret key (optional, required for sites with 2FA enabled)",
|
||||
)
|
||||
|
||||
|
||||
class UpdateSiteCookieTool(MoviePilotTool):
|
||||
name: str = "update_site_cookie"
|
||||
description: str = "Update site Cookie and User-Agent by logging in with username and password. This tool can automatically obtain and update the site's authentication credentials. Supports two-step verification for sites that require it. Accepts site ID only."
|
||||
args_schema: Type[BaseModel] = UpdateSiteCookieInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据更新参数生成友好的提示消息"""
|
||||
site_identifier = kwargs.get("site_identifier")
|
||||
username = kwargs.get("username", "")
|
||||
two_step_code = kwargs.get("two_step_code")
|
||||
|
||||
|
||||
message = f"正在更新站点Cookie: {site_identifier} (用户: {username})"
|
||||
if two_step_code:
|
||||
message += " [需要两步验证]"
|
||||
|
||||
|
||||
return message
|
||||
|
||||
async def run(self, site_identifier: int, username: str, password: str,
|
||||
two_step_code: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: site_identifier={site_identifier}, username={username}")
|
||||
async def run(
|
||||
self,
|
||||
site_identifier: int,
|
||||
username: str,
|
||||
password: str,
|
||||
two_step_code: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: site_identifier={site_identifier}, username={username}"
|
||||
)
|
||||
|
||||
try:
|
||||
site_oper = SiteOper()
|
||||
site_chain = SiteChain()
|
||||
site = await site_oper.async_get(site_identifier)
|
||||
|
||||
|
||||
if not site:
|
||||
return f"未找到站点:{site_identifier},请使用 query_sites 工具查询可用的站点"
|
||||
|
||||
|
||||
# 更新站点Cookie和UA
|
||||
status, message = site_chain.update_cookie(
|
||||
site_info=site,
|
||||
username=username,
|
||||
password=password,
|
||||
two_step_code=two_step_code
|
||||
two_step_code=two_step_code,
|
||||
)
|
||||
|
||||
|
||||
if status:
|
||||
return f"站点【{site.name}】Cookie和UA更新成功\n{message}"
|
||||
else:
|
||||
@@ -63,4 +82,3 @@ class UpdateSiteCookieTool(MoviePilotTool):
|
||||
except Exception as e:
|
||||
logger.error(f"更新站点Cookie和UA失败: {e}", exc_info=True)
|
||||
return f"更新站点Cookie和UA时发生错误: {str(e)}"
|
||||
|
||||
|
||||
@@ -15,40 +15,87 @@ from app.schemas.types import EventType
|
||||
|
||||
class UpdateSubscribeInput(BaseModel):
|
||||
"""更新订阅工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
subscribe_id: int = Field(..., description="The ID of the subscription to update (can be obtained from query_subscribes tool)")
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
subscribe_id: int = Field(
|
||||
...,
|
||||
description="The ID of the subscription to update (can be obtained from query_subscribes tool)",
|
||||
)
|
||||
name: Optional[str] = Field(None, description="Subscription name/title (optional)")
|
||||
year: Optional[str] = Field(None, description="Release year (optional)")
|
||||
season: Optional[int] = Field(None, description="Season number for TV shows (optional)")
|
||||
total_episode: Optional[int] = Field(None, description="Total number of episodes (optional)")
|
||||
lack_episode: Optional[int] = Field(None, description="Number of missing episodes (optional)")
|
||||
start_episode: Optional[int] = Field(None, description="Starting episode number (optional)")
|
||||
quality: Optional[str] = Field(None, description="Quality filter as regular expression (optional, e.g., 'BluRay|WEB-DL|HDTV')")
|
||||
resolution: Optional[str] = Field(None, description="Resolution filter as regular expression (optional, e.g., '1080p|720p|2160p')")
|
||||
effect: Optional[str] = Field(None, description="Effect filter as regular expression (optional, e.g., 'HDR|DV|SDR')")
|
||||
include: Optional[str] = Field(None, description="Include filter as regular expression (optional)")
|
||||
exclude: Optional[str] = Field(None, description="Exclude filter as regular expression (optional)")
|
||||
filter: Optional[str] = Field(None, description="Filter rule as regular expression (optional)")
|
||||
state: Optional[str] = Field(None, description="Subscription state: 'R' for enabled, 'P' for pending, 'S' for paused (optional)")
|
||||
sites: Optional[List[int]] = Field(None, description="List of site IDs to search from (optional)")
|
||||
season: Optional[int] = Field(
|
||||
None, description="Season number for TV shows (optional)"
|
||||
)
|
||||
total_episode: Optional[int] = Field(
|
||||
None, description="Total number of episodes (optional)"
|
||||
)
|
||||
lack_episode: Optional[int] = Field(
|
||||
None, description="Number of missing episodes (optional)"
|
||||
)
|
||||
start_episode: Optional[int] = Field(
|
||||
None, description="Starting episode number (optional)"
|
||||
)
|
||||
quality: Optional[str] = Field(
|
||||
None,
|
||||
description="Quality filter as regular expression (optional, e.g., 'BluRay|WEB-DL|HDTV')",
|
||||
)
|
||||
resolution: Optional[str] = Field(
|
||||
None,
|
||||
description="Resolution filter as regular expression (optional, e.g., '1080p|720p|2160p')",
|
||||
)
|
||||
effect: Optional[str] = Field(
|
||||
None,
|
||||
description="Effect filter as regular expression (optional, e.g., 'HDR|DV|SDR')",
|
||||
)
|
||||
include: Optional[str] = Field(
|
||||
None, description="Include filter as regular expression (optional)"
|
||||
)
|
||||
exclude: Optional[str] = Field(
|
||||
None, description="Exclude filter as regular expression (optional)"
|
||||
)
|
||||
filter: Optional[str] = Field(
|
||||
None, description="Filter rule as regular expression (optional)"
|
||||
)
|
||||
state: Optional[str] = Field(
|
||||
None,
|
||||
description="Subscription state: 'R' for enabled, 'P' for pending, 'S' for paused (optional)",
|
||||
)
|
||||
sites: Optional[List[int]] = Field(
|
||||
None, description="List of site IDs to search from (optional)"
|
||||
)
|
||||
downloader: Optional[str] = Field(None, description="Downloader name (optional)")
|
||||
save_path: Optional[str] = Field(None, description="Save path for downloaded files (optional)")
|
||||
best_version: Optional[int] = Field(None, description="Whether to upgrade to best version: 0 for no, 1 for yes (optional)")
|
||||
custom_words: Optional[str] = Field(None, description="Custom recognition words (optional)")
|
||||
media_category: Optional[str] = Field(None, description="Custom media category (optional)")
|
||||
episode_group: Optional[str] = Field(None, description="Episode group ID (optional)")
|
||||
save_path: Optional[str] = Field(
|
||||
None, description="Save path for downloaded files (optional)"
|
||||
)
|
||||
best_version: Optional[int] = Field(
|
||||
None,
|
||||
description="Whether to upgrade to best version: 0 for no, 1 for yes (optional)",
|
||||
)
|
||||
custom_words: Optional[str] = Field(
|
||||
None, description="Custom recognition words (optional)"
|
||||
)
|
||||
media_category: Optional[str] = Field(
|
||||
None, description="Custom media category (optional)"
|
||||
)
|
||||
episode_group: Optional[str] = Field(
|
||||
None, description="Episode group ID (optional)"
|
||||
)
|
||||
|
||||
|
||||
class UpdateSubscribeTool(MoviePilotTool):
|
||||
name: str = "update_subscribe"
|
||||
description: str = "Update subscription properties including filters, episode counts, state, and other settings. Supports updating quality/resolution filters, episode tracking, subscription state, and download configuration."
|
||||
args_schema: Type[BaseModel] = UpdateSubscribeInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据更新参数生成友好的提示消息"""
|
||||
subscribe_id = kwargs.get("subscribe_id")
|
||||
fields_updated = []
|
||||
|
||||
|
||||
if kwargs.get("name"):
|
||||
fields_updated.append("名称")
|
||||
if kwargs.get("total_episode") is not None:
|
||||
@@ -61,57 +108,62 @@ class UpdateSubscribeTool(MoviePilotTool):
|
||||
fields_updated.append("分辨率过滤")
|
||||
if kwargs.get("state"):
|
||||
state_map = {"R": "启用", "P": "禁用", "S": "暂停"}
|
||||
fields_updated.append(f"状态({state_map.get(kwargs.get('state'), kwargs.get('state'))})")
|
||||
fields_updated.append(
|
||||
f"状态({state_map.get(kwargs.get('state'), kwargs.get('state'))})"
|
||||
)
|
||||
if kwargs.get("sites"):
|
||||
fields_updated.append("站点")
|
||||
if kwargs.get("downloader"):
|
||||
fields_updated.append("下载器")
|
||||
|
||||
|
||||
if fields_updated:
|
||||
return f"正在更新订阅 #{subscribe_id}: {', '.join(fields_updated)}"
|
||||
return f"正在更新订阅 #{subscribe_id}"
|
||||
|
||||
async def run(self, subscribe_id: int,
|
||||
name: Optional[str] = None,
|
||||
year: Optional[str] = None,
|
||||
season: Optional[int] = None,
|
||||
total_episode: Optional[int] = None,
|
||||
lack_episode: Optional[int] = None,
|
||||
start_episode: Optional[int] = None,
|
||||
quality: Optional[str] = None,
|
||||
resolution: Optional[str] = None,
|
||||
effect: Optional[str] = None,
|
||||
include: Optional[str] = None,
|
||||
exclude: Optional[str] = None,
|
||||
filter: Optional[str] = None,
|
||||
state: Optional[str] = None,
|
||||
sites: Optional[List[int]] = None,
|
||||
downloader: Optional[str] = None,
|
||||
save_path: Optional[str] = None,
|
||||
best_version: Optional[int] = None,
|
||||
custom_words: Optional[str] = None,
|
||||
media_category: Optional[str] = None,
|
||||
episode_group: Optional[str] = None,
|
||||
**kwargs) -> str:
|
||||
async def run(
|
||||
self,
|
||||
subscribe_id: int,
|
||||
name: Optional[str] = None,
|
||||
year: Optional[str] = None,
|
||||
season: Optional[int] = None,
|
||||
total_episode: Optional[int] = None,
|
||||
lack_episode: Optional[int] = None,
|
||||
start_episode: Optional[int] = None,
|
||||
quality: Optional[str] = None,
|
||||
resolution: Optional[str] = None,
|
||||
effect: Optional[str] = None,
|
||||
include: Optional[str] = None,
|
||||
exclude: Optional[str] = None,
|
||||
filter: Optional[str] = None,
|
||||
state: Optional[str] = None,
|
||||
sites: Optional[List[int]] = None,
|
||||
downloader: Optional[str] = None,
|
||||
save_path: Optional[str] = None,
|
||||
best_version: Optional[int] = None,
|
||||
custom_words: Optional[str] = None,
|
||||
media_category: Optional[str] = None,
|
||||
episode_group: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: subscribe_id={subscribe_id}")
|
||||
|
||||
|
||||
try:
|
||||
# 获取数据库会话
|
||||
async with AsyncSessionFactory() as db:
|
||||
# 获取订阅
|
||||
subscribe = await Subscribe.async_get(db, subscribe_id)
|
||||
if not subscribe:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": f"订阅不存在: {subscribe_id}"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
return json.dumps(
|
||||
{"success": False, "message": f"订阅不存在: {subscribe_id}"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
# 保存旧数据用于事件
|
||||
old_subscribe_dict = subscribe.to_dict()
|
||||
|
||||
|
||||
# 构建更新字典
|
||||
subscribe_dict = {}
|
||||
|
||||
|
||||
# 基本信息
|
||||
if name is not None:
|
||||
subscribe_dict["name"] = name
|
||||
@@ -119,27 +171,29 @@ class UpdateSubscribeTool(MoviePilotTool):
|
||||
subscribe_dict["year"] = year
|
||||
if season is not None:
|
||||
subscribe_dict["season"] = season
|
||||
|
||||
|
||||
# 集数相关
|
||||
if total_episode is not None:
|
||||
subscribe_dict["total_episode"] = total_episode
|
||||
# 如果总集数增加,缺失集数也要相应增加
|
||||
if total_episode > (subscribe.total_episode or 0):
|
||||
old_lack = subscribe.lack_episode or 0
|
||||
subscribe_dict["lack_episode"] = old_lack + (total_episode - (subscribe.total_episode or 0))
|
||||
subscribe_dict["lack_episode"] = old_lack + (
|
||||
total_episode - (subscribe.total_episode or 0)
|
||||
)
|
||||
# 标记为手动修改过总集数
|
||||
subscribe_dict["manual_total_episode"] = 1
|
||||
|
||||
|
||||
# 缺失集数处理(只有在没有提供总集数时才单独处理)
|
||||
# 注意:如果 lack_episode 为 0,不更新(避免更新为0)
|
||||
if lack_episode is not None and total_episode is None:
|
||||
if lack_episode > 0:
|
||||
subscribe_dict["lack_episode"] = lack_episode
|
||||
# 如果 lack_episode 为 0,不添加到更新字典中(保持原值或由总集数逻辑处理)
|
||||
|
||||
|
||||
if start_episode is not None:
|
||||
subscribe_dict["start_episode"] = start_episode
|
||||
|
||||
|
||||
# 过滤规则
|
||||
if quality is not None:
|
||||
subscribe_dict["quality"] = quality
|
||||
@@ -153,17 +207,20 @@ class UpdateSubscribeTool(MoviePilotTool):
|
||||
subscribe_dict["exclude"] = exclude
|
||||
if filter is not None:
|
||||
subscribe_dict["filter"] = filter
|
||||
|
||||
|
||||
# 状态
|
||||
if state is not None:
|
||||
valid_states = ["R", "P", "S", "N"]
|
||||
if state not in valid_states:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": f"无效的订阅状态: {state},有效状态: {', '.join(valid_states)}"
|
||||
}, ensure_ascii=False)
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"无效的订阅状态: {state},有效状态: {', '.join(valid_states)}",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
subscribe_dict["state"] = state
|
||||
|
||||
|
||||
# 下载配置
|
||||
if sites is not None:
|
||||
subscribe_dict["sites"] = sites
|
||||
@@ -173,7 +230,7 @@ class UpdateSubscribeTool(MoviePilotTool):
|
||||
subscribe_dict["save_path"] = save_path
|
||||
if best_version is not None:
|
||||
subscribe_dict["best_version"] = best_version
|
||||
|
||||
|
||||
# 其他配置
|
||||
if custom_words is not None:
|
||||
subscribe_dict["custom_words"] = custom_words
|
||||
@@ -181,35 +238,40 @@ class UpdateSubscribeTool(MoviePilotTool):
|
||||
subscribe_dict["media_category"] = media_category
|
||||
if episode_group is not None:
|
||||
subscribe_dict["episode_group"] = episode_group
|
||||
|
||||
|
||||
# 如果没有要更新的字段
|
||||
if not subscribe_dict:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": "没有提供要更新的字段"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
return json.dumps(
|
||||
{"success": False, "message": "没有提供要更新的字段"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
# 更新订阅
|
||||
await subscribe.async_update(db, subscribe_dict)
|
||||
|
||||
|
||||
# 重新获取更新后的订阅数据
|
||||
updated_subscribe = await Subscribe.async_get(db, subscribe_id)
|
||||
|
||||
|
||||
# 发送订阅调整事件
|
||||
await eventmanager.async_send_event(EventType.SubscribeModified, {
|
||||
"subscribe_id": subscribe_id,
|
||||
"old_subscribe_info": old_subscribe_dict,
|
||||
"subscribe_info": updated_subscribe.to_dict() if updated_subscribe else {},
|
||||
})
|
||||
|
||||
await eventmanager.async_send_event(
|
||||
EventType.SubscribeModified,
|
||||
{
|
||||
"subscribe_id": subscribe_id,
|
||||
"old_subscribe_info": old_subscribe_dict,
|
||||
"subscribe_info": updated_subscribe.to_dict()
|
||||
if updated_subscribe
|
||||
else {},
|
||||
},
|
||||
)
|
||||
|
||||
# 构建返回结果
|
||||
result = {
|
||||
"success": True,
|
||||
"message": f"订阅 #{subscribe_id} 更新成功",
|
||||
"subscribe_id": subscribe_id,
|
||||
"updated_fields": list(subscribe_dict.keys())
|
||||
"updated_fields": list(subscribe_dict.keys()),
|
||||
}
|
||||
|
||||
|
||||
if updated_subscribe:
|
||||
result["subscribe"] = {
|
||||
"id": updated_subscribe.id,
|
||||
@@ -223,17 +285,19 @@ class UpdateSubscribeTool(MoviePilotTool):
|
||||
"start_episode": updated_subscribe.start_episode,
|
||||
"quality": updated_subscribe.quality,
|
||||
"resolution": updated_subscribe.resolution,
|
||||
"effect": updated_subscribe.effect
|
||||
"effect": updated_subscribe.effect,
|
||||
}
|
||||
|
||||
|
||||
return json.dumps(result, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
error_message = f"更新订阅失败: {str(e)}"
|
||||
logger.error(f"更新订阅失败: {e}", exc_info=True)
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": error_message,
|
||||
"subscribe_id": subscribe_id
|
||||
}, ensure_ascii=False)
|
||||
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"message": error_message,
|
||||
"subscribe_id": subscribe_id,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
@@ -12,6 +12,7 @@ from app.log import logger
|
||||
|
||||
class WriteFileInput(BaseModel):
|
||||
"""Input parameters for write file tool"""
|
||||
|
||||
file_path: str = Field(..., description="The absolute path of the file to write")
|
||||
content: str = Field(..., description="The content to write into the file")
|
||||
|
||||
@@ -20,6 +21,7 @@ class WriteFileTool(MoviePilotTool):
|
||||
name: str = "write_file"
|
||||
description: str = "Write full content to a file. If the file already exists, it will be overwritten. Automatically creates parent directories if they don't exist."
|
||||
args_schema: Type[BaseModel] = WriteFileInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据参数生成友好的提示消息"""
|
||||
@@ -32,16 +34,16 @@ class WriteFileTool(MoviePilotTool):
|
||||
|
||||
try:
|
||||
path = AsyncPath(file_path)
|
||||
|
||||
|
||||
if await path.exists() and not await path.is_file():
|
||||
return f"错误:{file_path} 路径已存在但不是一个文件"
|
||||
|
||||
# 自动创建父目录
|
||||
await path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
# 写入文件
|
||||
await path.write_text(content, encoding="utf-8")
|
||||
|
||||
|
||||
logger.info(f"成功写入文件 {file_path}")
|
||||
return f"成功写入文件 {file_path}"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user