mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-02-13 23:45:05 +08:00
204 lines
9.1 KiB
Python
204 lines
9.1 KiB
Python
"""更新站点工具"""
|
||
|
||
import json
|
||
from typing import Optional, Type
|
||
|
||
from pydantic import BaseModel, Field
|
||
|
||
from app.agent.tools.base import MoviePilotTool
|
||
from app.core.event import eventmanager
|
||
from app.db import AsyncSessionFactory
|
||
from app.db.models.site import Site
|
||
from app.log import logger
|
||
from app.schemas.types import EventType
|
||
from app.utils.string import StringUtils
|
||
|
||
|
||
class UpdateSiteInput(BaseModel):
|
||
"""更新站点工具的输入参数模型"""
|
||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||
site_id: int = Field(..., description="The ID of the site to update")
|
||
name: Optional[str] = Field(None, description="Site name (optional)")
|
||
url: Optional[str] = Field(None, description="Site URL (optional, will be automatically formatted)")
|
||
pri: Optional[int] = Field(None, description="Site priority (optional, smaller value = higher priority, e.g., pri=1 has higher priority than pri=10)")
|
||
rss: Optional[str] = Field(None, description="RSS feed URL (optional)")
|
||
cookie: Optional[str] = Field(None, description="Site cookie (optional)")
|
||
ua: Optional[str] = Field(None, description="User-Agent string (optional)")
|
||
apikey: Optional[str] = Field(None, description="API key (optional)")
|
||
token: Optional[str] = Field(None, description="API token (optional)")
|
||
proxy: Optional[int] = Field(None, description="Whether to use proxy: 0 for no, 1 for yes (optional)")
|
||
filter: Optional[str] = Field(None, description="Filter rule as regular expression (optional)")
|
||
note: Optional[str] = Field(None, description="Site notes/remarks (optional)")
|
||
timeout: Optional[int] = Field(None, description="Request timeout in seconds (optional, default: 15)")
|
||
limit_interval: Optional[int] = Field(None, description="Rate limit interval in seconds (optional)")
|
||
limit_count: Optional[int] = Field(None, description="Rate limit count per interval (optional)")
|
||
limit_seconds: Optional[int] = Field(None, description="Rate limit seconds between requests (optional)")
|
||
is_active: Optional[bool] = Field(None, description="Whether site is active: True for enabled, False for disabled (optional)")
|
||
downloader: Optional[str] = Field(None, description="Downloader name for this site (optional)")
|
||
|
||
|
||
class UpdateSiteTool(MoviePilotTool):
|
||
name: str = "update_site"
|
||
description: str = "Update site configuration including URL, priority, authentication credentials (cookie, UA, API key), proxy settings, rate limits, and other site properties. Supports updating multiple site attributes at once. Site priority (pri): smaller values have higher priority (e.g., pri=1 has higher priority than pri=10)."
|
||
args_schema: Type[BaseModel] = UpdateSiteInput
|
||
|
||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||
"""根据更新参数生成友好的提示消息"""
|
||
site_id = kwargs.get("site_id")
|
||
fields_updated = []
|
||
|
||
if kwargs.get("name"):
|
||
fields_updated.append("名称")
|
||
if kwargs.get("url"):
|
||
fields_updated.append("URL")
|
||
if kwargs.get("pri") is not None:
|
||
fields_updated.append("优先级")
|
||
if kwargs.get("cookie"):
|
||
fields_updated.append("Cookie")
|
||
if kwargs.get("ua"):
|
||
fields_updated.append("User-Agent")
|
||
if kwargs.get("proxy") is not None:
|
||
fields_updated.append("代理设置")
|
||
if kwargs.get("is_active") is not None:
|
||
fields_updated.append("启用状态")
|
||
if kwargs.get("downloader"):
|
||
fields_updated.append("下载器")
|
||
|
||
if fields_updated:
|
||
return f"正在更新站点 #{site_id}: {', '.join(fields_updated)}"
|
||
return f"正在更新站点 #{site_id}"
|
||
|
||
async def run(self, site_id: int,
|
||
name: Optional[str] = None,
|
||
url: Optional[str] = None,
|
||
pri: Optional[int] = None,
|
||
rss: Optional[str] = None,
|
||
cookie: Optional[str] = None,
|
||
ua: Optional[str] = None,
|
||
apikey: Optional[str] = None,
|
||
token: Optional[str] = None,
|
||
proxy: Optional[int] = None,
|
||
filter: Optional[str] = None,
|
||
note: Optional[str] = None,
|
||
timeout: Optional[int] = None,
|
||
limit_interval: Optional[int] = None,
|
||
limit_count: Optional[int] = None,
|
||
limit_seconds: Optional[int] = None,
|
||
is_active: Optional[bool] = None,
|
||
downloader: Optional[str] = None,
|
||
**kwargs) -> str:
|
||
logger.info(f"执行工具: {self.name}, 参数: site_id={site_id}")
|
||
|
||
try:
|
||
# 获取数据库会话
|
||
async with AsyncSessionFactory() as db:
|
||
# 获取站点
|
||
site = await Site.async_get(db, site_id)
|
||
if not site:
|
||
return json.dumps({
|
||
"success": False,
|
||
"message": f"站点不存在: {site_id}"
|
||
}, ensure_ascii=False)
|
||
|
||
# 构建更新字典
|
||
site_dict = {}
|
||
|
||
# 基本信息
|
||
if name is not None:
|
||
site_dict["name"] = name
|
||
|
||
# URL处理(需要校正格式)
|
||
if url is not None:
|
||
_scheme, _netloc = StringUtils.get_url_netloc(url)
|
||
site_dict["url"] = f"{_scheme}://{_netloc}/"
|
||
|
||
if pri is not None:
|
||
site_dict["pri"] = pri
|
||
if rss is not None:
|
||
site_dict["rss"] = rss
|
||
|
||
# 认证信息
|
||
if cookie is not None:
|
||
site_dict["cookie"] = cookie
|
||
if ua is not None:
|
||
site_dict["ua"] = ua
|
||
if apikey is not None:
|
||
site_dict["apikey"] = apikey
|
||
if token is not None:
|
||
site_dict["token"] = token
|
||
|
||
# 配置选项
|
||
if proxy is not None:
|
||
site_dict["proxy"] = proxy
|
||
if filter is not None:
|
||
site_dict["filter"] = filter
|
||
if note is not None:
|
||
site_dict["note"] = note
|
||
if timeout is not None:
|
||
site_dict["timeout"] = timeout
|
||
|
||
# 流控设置
|
||
if limit_interval is not None:
|
||
site_dict["limit_interval"] = limit_interval
|
||
if limit_count is not None:
|
||
site_dict["limit_count"] = limit_count
|
||
if limit_seconds is not None:
|
||
site_dict["limit_seconds"] = limit_seconds
|
||
|
||
# 状态和下载器
|
||
if is_active is not None:
|
||
site_dict["is_active"] = is_active
|
||
if downloader is not None:
|
||
site_dict["downloader"] = downloader
|
||
|
||
# 如果没有要更新的字段
|
||
if not site_dict:
|
||
return json.dumps({
|
||
"success": False,
|
||
"message": "没有提供要更新的字段"
|
||
}, ensure_ascii=False)
|
||
|
||
# 更新站点
|
||
await site.async_update(db, site_dict)
|
||
|
||
# 重新获取更新后的站点数据
|
||
updated_site = await Site.async_get(db, site_id)
|
||
|
||
# 发送站点更新事件
|
||
await eventmanager.async_send_event(EventType.SiteUpdated, {
|
||
"domain": updated_site.domain if updated_site else site.domain
|
||
})
|
||
|
||
# 构建返回结果
|
||
result = {
|
||
"success": True,
|
||
"message": f"站点 #{site_id} 更新成功",
|
||
"site_id": site_id,
|
||
"updated_fields": list(site_dict.keys())
|
||
}
|
||
|
||
if updated_site:
|
||
result["site"] = {
|
||
"id": updated_site.id,
|
||
"name": updated_site.name,
|
||
"domain": updated_site.domain,
|
||
"url": updated_site.url,
|
||
"pri": updated_site.pri,
|
||
"is_active": updated_site.is_active,
|
||
"downloader": updated_site.downloader,
|
||
"proxy": updated_site.proxy,
|
||
"timeout": updated_site.timeout
|
||
}
|
||
|
||
return json.dumps(result, ensure_ascii=False, indent=2)
|
||
|
||
except Exception as e:
|
||
error_message = f"更新站点失败: {str(e)}"
|
||
logger.error(f"更新站点失败: {e}", exc_info=True)
|
||
return json.dumps({
|
||
"success": False,
|
||
"message": error_message,
|
||
"site_id": site_id
|
||
}, ensure_ascii=False)
|
||
|