mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-03-20 03:57:30 +08:00
feat: add UpdateSubscribeTool and additional query tools to MoviePilotToolFactory
- Included UpdateSubscribeTool and QuerySiteUserdataTool in the tool definitions of MoviePilotToolFactory to enhance subscription management and user data querying capabilities. - Updated the factory.py file to reflect the addition of these new tools, improving overall functionality.
This commit is contained in:
@@ -4,11 +4,14 @@ from typing import List, Callable
|
|||||||
|
|
||||||
from app.agent.tools.impl.add_download import AddDownloadTool
|
from app.agent.tools.impl.add_download import AddDownloadTool
|
||||||
from app.agent.tools.impl.add_subscribe import AddSubscribeTool
|
from app.agent.tools.impl.add_subscribe import AddSubscribeTool
|
||||||
|
from app.agent.tools.impl.update_subscribe import UpdateSubscribeTool
|
||||||
from app.agent.tools.impl.get_recommendations import GetRecommendationsTool
|
from app.agent.tools.impl.get_recommendations import GetRecommendationsTool
|
||||||
from app.agent.tools.impl.query_downloaders import QueryDownloadersTool
|
from app.agent.tools.impl.query_downloaders import QueryDownloadersTool
|
||||||
from app.agent.tools.impl.query_downloads import QueryDownloadsTool
|
from app.agent.tools.impl.query_downloads import QueryDownloadsTool
|
||||||
from app.agent.tools.impl.query_media_library import QueryMediaLibraryTool
|
from app.agent.tools.impl.query_media_library import QueryMediaLibraryTool
|
||||||
from app.agent.tools.impl.query_sites import QuerySitesTool
|
from app.agent.tools.impl.query_sites import QuerySitesTool
|
||||||
|
from app.agent.tools.impl.update_site import UpdateSiteTool
|
||||||
|
from app.agent.tools.impl.query_site_userdata import QuerySiteUserdataTool
|
||||||
from app.agent.tools.impl.test_site import TestSiteTool
|
from app.agent.tools.impl.test_site import TestSiteTool
|
||||||
from app.agent.tools.impl.query_subscribes import QuerySubscribesTool
|
from app.agent.tools.impl.query_subscribes import QuerySubscribesTool
|
||||||
from app.agent.tools.impl.query_subscribe_shares import QuerySubscribeSharesTool
|
from app.agent.tools.impl.query_subscribe_shares import QuerySubscribeSharesTool
|
||||||
@@ -49,6 +52,7 @@ class MoviePilotToolFactory:
|
|||||||
RecognizeMediaTool,
|
RecognizeMediaTool,
|
||||||
ScrapeMetadataTool,
|
ScrapeMetadataTool,
|
||||||
AddSubscribeTool,
|
AddSubscribeTool,
|
||||||
|
UpdateSubscribeTool,
|
||||||
SearchTorrentsTool,
|
SearchTorrentsTool,
|
||||||
AddDownloadTool,
|
AddDownloadTool,
|
||||||
QuerySubscribesTool,
|
QuerySubscribesTool,
|
||||||
@@ -60,6 +64,8 @@ class MoviePilotToolFactory:
|
|||||||
DeleteDownloadTool,
|
DeleteDownloadTool,
|
||||||
QueryDownloadersTool,
|
QueryDownloadersTool,
|
||||||
QuerySitesTool,
|
QuerySitesTool,
|
||||||
|
UpdateSiteTool,
|
||||||
|
QuerySiteUserdataTool,
|
||||||
TestSiteTool,
|
TestSiteTool,
|
||||||
UpdateSiteCookieTool,
|
UpdateSiteCookieTool,
|
||||||
GetRecommendationsTool,
|
GetRecommendationsTool,
|
||||||
|
|||||||
136
app/agent/tools/impl/query_site_userdata.py
Normal file
136
app/agent/tools/impl/query_site_userdata.py
Normal file
@@ -0,0 +1,136 @@
|
|||||||
|
"""查询站点用户数据工具"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from typing import Optional, Type
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from app.agent.tools.base import MoviePilotTool
|
||||||
|
from app.db import AsyncSessionFactory
|
||||||
|
from app.db.models.site import Site
|
||||||
|
from app.db.models.siteuserdata import SiteUserData
|
||||||
|
from app.log import logger
|
||||||
|
|
||||||
|
|
||||||
|
class QuerySiteUserdataInput(BaseModel):
|
||||||
|
"""查询站点用户数据工具的输入参数模型"""
|
||||||
|
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||||
|
site_id: int = Field(..., description="The ID of the site to query user data for")
|
||||||
|
workdate: Optional[str] = Field(None, description="Work date to query (optional, format: 'YYYY-MM-DD', if not specified returns latest data)")
|
||||||
|
|
||||||
|
|
||||||
|
class QuerySiteUserdataTool(MoviePilotTool):
|
||||||
|
name: str = "query_site_userdata"
|
||||||
|
description: str = "Query user data for a specific site including username, user level, upload/download statistics, seeding information, bonus points, and other account details. Supports querying data for a specific date or latest data."
|
||||||
|
args_schema: Type[BaseModel] = QuerySiteUserdataInput
|
||||||
|
|
||||||
|
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||||
|
"""根据查询参数生成友好的提示消息"""
|
||||||
|
site_id = kwargs.get("site_id")
|
||||||
|
workdate = kwargs.get("workdate")
|
||||||
|
|
||||||
|
message = f"正在查询站点 #{site_id} 的用户数据"
|
||||||
|
if workdate:
|
||||||
|
message += f" (日期: {workdate})"
|
||||||
|
else:
|
||||||
|
message += " (最新数据)"
|
||||||
|
|
||||||
|
return message
|
||||||
|
|
||||||
|
async def run(self, site_id: int, workdate: Optional[str] = None, **kwargs) -> str:
|
||||||
|
logger.info(f"执行工具: {self.name}, 参数: site_id={site_id}, workdate={workdate}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 获取数据库会话
|
||||||
|
async with AsyncSessionFactory() as db:
|
||||||
|
# 获取站点
|
||||||
|
site = await Site.async_get(db, site_id)
|
||||||
|
if not site:
|
||||||
|
return json.dumps({
|
||||||
|
"success": False,
|
||||||
|
"message": f"站点不存在: {site_id}"
|
||||||
|
}, ensure_ascii=False)
|
||||||
|
|
||||||
|
# 获取站点用户数据
|
||||||
|
user_data_list = await SiteUserData.async_get_by_domain(
|
||||||
|
db,
|
||||||
|
domain=site.domain,
|
||||||
|
workdate=workdate
|
||||||
|
)
|
||||||
|
|
||||||
|
if not user_data_list:
|
||||||
|
return json.dumps({
|
||||||
|
"success": False,
|
||||||
|
"message": f"站点 {site.name} ({site.domain}) 暂无用户数据",
|
||||||
|
"site_id": site_id,
|
||||||
|
"site_name": site.name,
|
||||||
|
"site_domain": site.domain,
|
||||||
|
"workdate": workdate
|
||||||
|
}, ensure_ascii=False)
|
||||||
|
|
||||||
|
# 格式化用户数据
|
||||||
|
result = {
|
||||||
|
"success": True,
|
||||||
|
"site_id": site_id,
|
||||||
|
"site_name": site.name,
|
||||||
|
"site_domain": site.domain,
|
||||||
|
"workdate": workdate,
|
||||||
|
"data_count": len(user_data_list),
|
||||||
|
"user_data": []
|
||||||
|
}
|
||||||
|
|
||||||
|
for user_data in user_data_list:
|
||||||
|
# 格式化上传/下载量(转换为可读格式)
|
||||||
|
upload_gb = user_data.upload / (1024 ** 3) if user_data.upload else 0
|
||||||
|
download_gb = user_data.download / (1024 ** 3) if user_data.download else 0
|
||||||
|
seeding_size_gb = user_data.seeding_size / (1024 ** 3) if user_data.seeding_size else 0
|
||||||
|
leeching_size_gb = user_data.leeching_size / (1024 ** 3) if user_data.leeching_size else 0
|
||||||
|
|
||||||
|
user_data_dict = {
|
||||||
|
"domain": user_data.domain,
|
||||||
|
"name": user_data.name,
|
||||||
|
"username": user_data.username,
|
||||||
|
"userid": user_data.userid,
|
||||||
|
"user_level": user_data.user_level,
|
||||||
|
"join_at": user_data.join_at,
|
||||||
|
"bonus": user_data.bonus,
|
||||||
|
"upload": user_data.upload,
|
||||||
|
"upload_gb": round(upload_gb, 2),
|
||||||
|
"download": user_data.download,
|
||||||
|
"download_gb": round(download_gb, 2),
|
||||||
|
"ratio": round(user_data.ratio, 2) if user_data.ratio else 0,
|
||||||
|
"seeding": int(user_data.seeding) if user_data.seeding else 0,
|
||||||
|
"leeching": int(user_data.leeching) if user_data.leeching else 0,
|
||||||
|
"seeding_size": user_data.seeding_size,
|
||||||
|
"seeding_size_gb": round(seeding_size_gb, 2),
|
||||||
|
"leeching_size": user_data.leeching_size,
|
||||||
|
"leeching_size_gb": round(leeching_size_gb, 2),
|
||||||
|
"seeding_info": user_data.seeding_info if user_data.seeding_info else [],
|
||||||
|
"message_unread": user_data.message_unread,
|
||||||
|
"message_unread_contents": user_data.message_unread_contents if user_data.message_unread_contents else [],
|
||||||
|
"err_msg": user_data.err_msg,
|
||||||
|
"updated_day": user_data.updated_day,
|
||||||
|
"updated_time": user_data.updated_time
|
||||||
|
}
|
||||||
|
result["user_data"].append(user_data_dict)
|
||||||
|
|
||||||
|
# 如果有多条数据,只返回最新的(按更新时间排序)
|
||||||
|
if len(result["user_data"]) > 1:
|
||||||
|
result["user_data"].sort(
|
||||||
|
key=lambda x: (x.get("updated_day", ""), x.get("updated_time", "")),
|
||||||
|
reverse=True
|
||||||
|
)
|
||||||
|
result["message"] = f"找到 {len(result['user_data'])} 条数据,显示最新的一条"
|
||||||
|
result["user_data"] = [result["user_data"][0]]
|
||||||
|
|
||||||
|
return json.dumps(result, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_message = f"查询站点用户数据失败: {str(e)}"
|
||||||
|
logger.error(f"查询站点用户数据失败: {e}", exc_info=True)
|
||||||
|
return json.dumps({
|
||||||
|
"success": False,
|
||||||
|
"message": error_message,
|
||||||
|
"site_id": site_id
|
||||||
|
}, ensure_ascii=False)
|
||||||
|
|
||||||
203
app/agent/tools/impl/update_site.py
Normal file
203
app/agent/tools/impl/update_site.py
Normal file
@@ -0,0 +1,203 @@
|
|||||||
|
"""更新站点工具"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from typing import Optional, Type
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from app.agent.tools.base import MoviePilotTool
|
||||||
|
from app.core.event import eventmanager
|
||||||
|
from app.db import AsyncSessionFactory
|
||||||
|
from app.db.models.site import Site
|
||||||
|
from app.log import logger
|
||||||
|
from app.schemas.types import EventType
|
||||||
|
from app.utils.string import StringUtils
|
||||||
|
|
||||||
|
|
||||||
|
class UpdateSiteInput(BaseModel):
|
||||||
|
"""更新站点工具的输入参数模型"""
|
||||||
|
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||||
|
site_id: int = Field(..., description="The ID of the site to update")
|
||||||
|
name: Optional[str] = Field(None, description="Site name (optional)")
|
||||||
|
url: Optional[str] = Field(None, description="Site URL (optional, will be automatically formatted)")
|
||||||
|
pri: Optional[int] = Field(None, description="Site priority (optional, higher number = higher priority)")
|
||||||
|
rss: Optional[str] = Field(None, description="RSS feed URL (optional)")
|
||||||
|
cookie: Optional[str] = Field(None, description="Site cookie (optional)")
|
||||||
|
ua: Optional[str] = Field(None, description="User-Agent string (optional)")
|
||||||
|
apikey: Optional[str] = Field(None, description="API key (optional)")
|
||||||
|
token: Optional[str] = Field(None, description="API token (optional)")
|
||||||
|
proxy: Optional[int] = Field(None, description="Whether to use proxy: 0 for no, 1 for yes (optional)")
|
||||||
|
filter: Optional[str] = Field(None, description="Filter rule as regular expression (optional)")
|
||||||
|
note: Optional[str] = Field(None, description="Site notes/remarks (optional)")
|
||||||
|
timeout: Optional[int] = Field(None, description="Request timeout in seconds (optional, default: 15)")
|
||||||
|
limit_interval: Optional[int] = Field(None, description="Rate limit interval in seconds (optional)")
|
||||||
|
limit_count: Optional[int] = Field(None, description="Rate limit count per interval (optional)")
|
||||||
|
limit_seconds: Optional[int] = Field(None, description="Rate limit seconds between requests (optional)")
|
||||||
|
is_active: Optional[bool] = Field(None, description="Whether site is active: True for enabled, False for disabled (optional)")
|
||||||
|
downloader: Optional[str] = Field(None, description="Downloader name for this site (optional)")
|
||||||
|
|
||||||
|
|
||||||
|
class UpdateSiteTool(MoviePilotTool):
|
||||||
|
name: str = "update_site"
|
||||||
|
description: str = "Update site configuration including URL, priority, authentication credentials (cookie, UA, API key), proxy settings, rate limits, and other site properties. Supports updating multiple site attributes at once."
|
||||||
|
args_schema: Type[BaseModel] = UpdateSiteInput
|
||||||
|
|
||||||
|
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||||
|
"""根据更新参数生成友好的提示消息"""
|
||||||
|
site_id = kwargs.get("site_id")
|
||||||
|
fields_updated = []
|
||||||
|
|
||||||
|
if kwargs.get("name"):
|
||||||
|
fields_updated.append("名称")
|
||||||
|
if kwargs.get("url"):
|
||||||
|
fields_updated.append("URL")
|
||||||
|
if kwargs.get("pri") is not None:
|
||||||
|
fields_updated.append("优先级")
|
||||||
|
if kwargs.get("cookie"):
|
||||||
|
fields_updated.append("Cookie")
|
||||||
|
if kwargs.get("ua"):
|
||||||
|
fields_updated.append("User-Agent")
|
||||||
|
if kwargs.get("proxy") is not None:
|
||||||
|
fields_updated.append("代理设置")
|
||||||
|
if kwargs.get("is_active") is not None:
|
||||||
|
fields_updated.append("启用状态")
|
||||||
|
if kwargs.get("downloader"):
|
||||||
|
fields_updated.append("下载器")
|
||||||
|
|
||||||
|
if fields_updated:
|
||||||
|
return f"正在更新站点 #{site_id}: {', '.join(fields_updated)}"
|
||||||
|
return f"正在更新站点 #{site_id}"
|
||||||
|
|
||||||
|
async def run(self, site_id: int,
|
||||||
|
name: Optional[str] = None,
|
||||||
|
url: Optional[str] = None,
|
||||||
|
pri: Optional[int] = None,
|
||||||
|
rss: Optional[str] = None,
|
||||||
|
cookie: Optional[str] = None,
|
||||||
|
ua: Optional[str] = None,
|
||||||
|
apikey: Optional[str] = None,
|
||||||
|
token: Optional[str] = None,
|
||||||
|
proxy: Optional[int] = None,
|
||||||
|
filter: Optional[str] = None,
|
||||||
|
note: Optional[str] = None,
|
||||||
|
timeout: Optional[int] = None,
|
||||||
|
limit_interval: Optional[int] = None,
|
||||||
|
limit_count: Optional[int] = None,
|
||||||
|
limit_seconds: Optional[int] = None,
|
||||||
|
is_active: Optional[bool] = None,
|
||||||
|
downloader: Optional[str] = None,
|
||||||
|
**kwargs) -> str:
|
||||||
|
logger.info(f"执行工具: {self.name}, 参数: site_id={site_id}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 获取数据库会话
|
||||||
|
async with AsyncSessionFactory() as db:
|
||||||
|
# 获取站点
|
||||||
|
site = await Site.async_get(db, site_id)
|
||||||
|
if not site:
|
||||||
|
return json.dumps({
|
||||||
|
"success": False,
|
||||||
|
"message": f"站点不存在: {site_id}"
|
||||||
|
}, ensure_ascii=False)
|
||||||
|
|
||||||
|
# 构建更新字典
|
||||||
|
site_dict = {}
|
||||||
|
|
||||||
|
# 基本信息
|
||||||
|
if name is not None:
|
||||||
|
site_dict["name"] = name
|
||||||
|
|
||||||
|
# URL处理(需要校正格式)
|
||||||
|
if url is not None:
|
||||||
|
_scheme, _netloc = StringUtils.get_url_netloc(url)
|
||||||
|
site_dict["url"] = f"{_scheme}://{_netloc}/"
|
||||||
|
|
||||||
|
if pri is not None:
|
||||||
|
site_dict["pri"] = pri
|
||||||
|
if rss is not None:
|
||||||
|
site_dict["rss"] = rss
|
||||||
|
|
||||||
|
# 认证信息
|
||||||
|
if cookie is not None:
|
||||||
|
site_dict["cookie"] = cookie
|
||||||
|
if ua is not None:
|
||||||
|
site_dict["ua"] = ua
|
||||||
|
if apikey is not None:
|
||||||
|
site_dict["apikey"] = apikey
|
||||||
|
if token is not None:
|
||||||
|
site_dict["token"] = token
|
||||||
|
|
||||||
|
# 配置选项
|
||||||
|
if proxy is not None:
|
||||||
|
site_dict["proxy"] = proxy
|
||||||
|
if filter is not None:
|
||||||
|
site_dict["filter"] = filter
|
||||||
|
if note is not None:
|
||||||
|
site_dict["note"] = note
|
||||||
|
if timeout is not None:
|
||||||
|
site_dict["timeout"] = timeout
|
||||||
|
|
||||||
|
# 流控设置
|
||||||
|
if limit_interval is not None:
|
||||||
|
site_dict["limit_interval"] = limit_interval
|
||||||
|
if limit_count is not None:
|
||||||
|
site_dict["limit_count"] = limit_count
|
||||||
|
if limit_seconds is not None:
|
||||||
|
site_dict["limit_seconds"] = limit_seconds
|
||||||
|
|
||||||
|
# 状态和下载器
|
||||||
|
if is_active is not None:
|
||||||
|
site_dict["is_active"] = is_active
|
||||||
|
if downloader is not None:
|
||||||
|
site_dict["downloader"] = downloader
|
||||||
|
|
||||||
|
# 如果没有要更新的字段
|
||||||
|
if not site_dict:
|
||||||
|
return json.dumps({
|
||||||
|
"success": False,
|
||||||
|
"message": "没有提供要更新的字段"
|
||||||
|
}, ensure_ascii=False)
|
||||||
|
|
||||||
|
# 更新站点
|
||||||
|
await site.async_update(db, site_dict)
|
||||||
|
|
||||||
|
# 重新获取更新后的站点数据
|
||||||
|
updated_site = await Site.async_get(db, site_id)
|
||||||
|
|
||||||
|
# 发送站点更新事件
|
||||||
|
await eventmanager.async_send_event(EventType.SiteUpdated, {
|
||||||
|
"domain": updated_site.domain if updated_site else site.domain
|
||||||
|
})
|
||||||
|
|
||||||
|
# 构建返回结果
|
||||||
|
result = {
|
||||||
|
"success": True,
|
||||||
|
"message": f"站点 #{site_id} 更新成功",
|
||||||
|
"site_id": site_id,
|
||||||
|
"updated_fields": list(site_dict.keys())
|
||||||
|
}
|
||||||
|
|
||||||
|
if updated_site:
|
||||||
|
result["site"] = {
|
||||||
|
"id": updated_site.id,
|
||||||
|
"name": updated_site.name,
|
||||||
|
"domain": updated_site.domain,
|
||||||
|
"url": updated_site.url,
|
||||||
|
"pri": updated_site.pri,
|
||||||
|
"is_active": updated_site.is_active,
|
||||||
|
"downloader": updated_site.downloader,
|
||||||
|
"proxy": updated_site.proxy,
|
||||||
|
"timeout": updated_site.timeout
|
||||||
|
}
|
||||||
|
|
||||||
|
return json.dumps(result, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_message = f"更新站点失败: {str(e)}"
|
||||||
|
logger.error(f"更新站点失败: {e}", exc_info=True)
|
||||||
|
return json.dumps({
|
||||||
|
"success": False,
|
||||||
|
"message": error_message,
|
||||||
|
"site_id": site_id
|
||||||
|
}, ensure_ascii=False)
|
||||||
|
|
||||||
239
app/agent/tools/impl/update_subscribe.py
Normal file
239
app/agent/tools/impl/update_subscribe.py
Normal file
@@ -0,0 +1,239 @@
|
|||||||
|
"""更新订阅工具"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from typing import Optional, Type, List
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from app.agent.tools.base import MoviePilotTool
|
||||||
|
from app.core.event import eventmanager
|
||||||
|
from app.db import AsyncSessionFactory
|
||||||
|
from app.db.models.subscribe import Subscribe
|
||||||
|
from app.log import logger
|
||||||
|
from app.schemas.types import EventType
|
||||||
|
|
||||||
|
|
||||||
|
class UpdateSubscribeInput(BaseModel):
|
||||||
|
"""更新订阅工具的输入参数模型"""
|
||||||
|
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||||
|
subscribe_id: int = Field(..., description="The ID of the subscription to update")
|
||||||
|
name: Optional[str] = Field(None, description="Subscription name/title (optional)")
|
||||||
|
year: Optional[str] = Field(None, description="Release year (optional)")
|
||||||
|
season: Optional[int] = Field(None, description="Season number for TV shows (optional)")
|
||||||
|
total_episode: Optional[int] = Field(None, description="Total number of episodes (optional)")
|
||||||
|
lack_episode: Optional[int] = Field(None, description="Number of missing episodes (optional)")
|
||||||
|
start_episode: Optional[int] = Field(None, description="Starting episode number (optional)")
|
||||||
|
quality: Optional[str] = Field(None, description="Quality filter as regular expression (optional, e.g., 'BluRay|WEB-DL|HDTV')")
|
||||||
|
resolution: Optional[str] = Field(None, description="Resolution filter as regular expression (optional, e.g., '1080p|720p|2160p')")
|
||||||
|
effect: Optional[str] = Field(None, description="Effect filter as regular expression (optional, e.g., 'HDR|DV|SDR')")
|
||||||
|
include: Optional[str] = Field(None, description="Include filter as regular expression (optional)")
|
||||||
|
exclude: Optional[str] = Field(None, description="Exclude filter as regular expression (optional)")
|
||||||
|
filter: Optional[str] = Field(None, description="Filter rule as regular expression (optional)")
|
||||||
|
state: Optional[str] = Field(None, description="Subscription state: 'R' for enabled, 'P' for disabled, 'S' for paused (optional)")
|
||||||
|
sites: Optional[List[int]] = Field(None, description="List of site IDs to search from (optional)")
|
||||||
|
downloader: Optional[str] = Field(None, description="Downloader name (optional)")
|
||||||
|
save_path: Optional[str] = Field(None, description="Save path for downloaded files (optional)")
|
||||||
|
best_version: Optional[int] = Field(None, description="Whether to upgrade to best version: 0 for no, 1 for yes (optional)")
|
||||||
|
custom_words: Optional[str] = Field(None, description="Custom recognition words (optional)")
|
||||||
|
media_category: Optional[str] = Field(None, description="Custom media category (optional)")
|
||||||
|
episode_group: Optional[str] = Field(None, description="Episode group ID (optional)")
|
||||||
|
|
||||||
|
|
||||||
|
class UpdateSubscribeTool(MoviePilotTool):
|
||||||
|
name: str = "update_subscribe"
|
||||||
|
description: str = "Update subscription properties including filters, episode counts, state, and other settings. Supports updating quality/resolution filters, episode tracking, subscription state, and download configuration."
|
||||||
|
args_schema: Type[BaseModel] = UpdateSubscribeInput
|
||||||
|
|
||||||
|
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||||
|
"""根据更新参数生成友好的提示消息"""
|
||||||
|
subscribe_id = kwargs.get("subscribe_id")
|
||||||
|
fields_updated = []
|
||||||
|
|
||||||
|
if kwargs.get("name"):
|
||||||
|
fields_updated.append("名称")
|
||||||
|
if kwargs.get("total_episode") is not None:
|
||||||
|
fields_updated.append("总集数")
|
||||||
|
if kwargs.get("lack_episode") is not None:
|
||||||
|
fields_updated.append("缺失集数")
|
||||||
|
if kwargs.get("quality"):
|
||||||
|
fields_updated.append("质量过滤")
|
||||||
|
if kwargs.get("resolution"):
|
||||||
|
fields_updated.append("分辨率过滤")
|
||||||
|
if kwargs.get("state"):
|
||||||
|
state_map = {"R": "启用", "P": "禁用", "S": "暂停"}
|
||||||
|
fields_updated.append(f"状态({state_map.get(kwargs.get('state'), kwargs.get('state'))})")
|
||||||
|
if kwargs.get("sites"):
|
||||||
|
fields_updated.append("站点")
|
||||||
|
if kwargs.get("downloader"):
|
||||||
|
fields_updated.append("下载器")
|
||||||
|
|
||||||
|
if fields_updated:
|
||||||
|
return f"正在更新订阅 #{subscribe_id}: {', '.join(fields_updated)}"
|
||||||
|
return f"正在更新订阅 #{subscribe_id}"
|
||||||
|
|
||||||
|
async def run(self, subscribe_id: int,
|
||||||
|
name: Optional[str] = None,
|
||||||
|
year: Optional[str] = None,
|
||||||
|
season: Optional[int] = None,
|
||||||
|
total_episode: Optional[int] = None,
|
||||||
|
lack_episode: Optional[int] = None,
|
||||||
|
start_episode: Optional[int] = None,
|
||||||
|
quality: Optional[str] = None,
|
||||||
|
resolution: Optional[str] = None,
|
||||||
|
effect: Optional[str] = None,
|
||||||
|
include: Optional[str] = None,
|
||||||
|
exclude: Optional[str] = None,
|
||||||
|
filter: Optional[str] = None,
|
||||||
|
state: Optional[str] = None,
|
||||||
|
sites: Optional[List[int]] = None,
|
||||||
|
downloader: Optional[str] = None,
|
||||||
|
save_path: Optional[str] = None,
|
||||||
|
best_version: Optional[int] = None,
|
||||||
|
custom_words: Optional[str] = None,
|
||||||
|
media_category: Optional[str] = None,
|
||||||
|
episode_group: Optional[str] = None,
|
||||||
|
**kwargs) -> str:
|
||||||
|
logger.info(f"执行工具: {self.name}, 参数: subscribe_id={subscribe_id}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 获取数据库会话
|
||||||
|
async with AsyncSessionFactory() as db:
|
||||||
|
# 获取订阅
|
||||||
|
subscribe = await Subscribe.async_get(db, subscribe_id)
|
||||||
|
if not subscribe:
|
||||||
|
return json.dumps({
|
||||||
|
"success": False,
|
||||||
|
"message": f"订阅不存在: {subscribe_id}"
|
||||||
|
}, ensure_ascii=False)
|
||||||
|
|
||||||
|
# 保存旧数据用于事件
|
||||||
|
old_subscribe_dict = subscribe.to_dict()
|
||||||
|
|
||||||
|
# 构建更新字典
|
||||||
|
subscribe_dict = {}
|
||||||
|
|
||||||
|
# 基本信息
|
||||||
|
if name is not None:
|
||||||
|
subscribe_dict["name"] = name
|
||||||
|
if year is not None:
|
||||||
|
subscribe_dict["year"] = year
|
||||||
|
if season is not None:
|
||||||
|
subscribe_dict["season"] = season
|
||||||
|
|
||||||
|
# 集数相关
|
||||||
|
if total_episode is not None:
|
||||||
|
subscribe_dict["total_episode"] = total_episode
|
||||||
|
# 如果总集数增加,缺失集数也要相应增加
|
||||||
|
if total_episode > (subscribe.total_episode or 0):
|
||||||
|
old_lack = subscribe.lack_episode or 0
|
||||||
|
subscribe_dict["lack_episode"] = old_lack + (total_episode - (subscribe.total_episode or 0))
|
||||||
|
# 标记为手动修改过总集数
|
||||||
|
subscribe_dict["manual_total_episode"] = 1
|
||||||
|
|
||||||
|
# 缺失集数处理(只有在没有提供总集数时才单独处理)
|
||||||
|
# 注意:如果 lack_episode 为 0,不更新(避免更新为0)
|
||||||
|
if lack_episode is not None and total_episode is None:
|
||||||
|
if lack_episode > 0:
|
||||||
|
subscribe_dict["lack_episode"] = lack_episode
|
||||||
|
# 如果 lack_episode 为 0,不添加到更新字典中(保持原值或由总集数逻辑处理)
|
||||||
|
|
||||||
|
if start_episode is not None:
|
||||||
|
subscribe_dict["start_episode"] = start_episode
|
||||||
|
|
||||||
|
# 过滤规则
|
||||||
|
if quality is not None:
|
||||||
|
subscribe_dict["quality"] = quality
|
||||||
|
if resolution is not None:
|
||||||
|
subscribe_dict["resolution"] = resolution
|
||||||
|
if effect is not None:
|
||||||
|
subscribe_dict["effect"] = effect
|
||||||
|
if include is not None:
|
||||||
|
subscribe_dict["include"] = include
|
||||||
|
if exclude is not None:
|
||||||
|
subscribe_dict["exclude"] = exclude
|
||||||
|
if filter is not None:
|
||||||
|
subscribe_dict["filter"] = filter
|
||||||
|
|
||||||
|
# 状态
|
||||||
|
if state is not None:
|
||||||
|
valid_states = ["R", "P", "S", "N"]
|
||||||
|
if state not in valid_states:
|
||||||
|
return json.dumps({
|
||||||
|
"success": False,
|
||||||
|
"message": f"无效的订阅状态: {state},有效状态: {', '.join(valid_states)}"
|
||||||
|
}, ensure_ascii=False)
|
||||||
|
subscribe_dict["state"] = state
|
||||||
|
|
||||||
|
# 下载配置
|
||||||
|
if sites is not None:
|
||||||
|
subscribe_dict["sites"] = sites
|
||||||
|
if downloader is not None:
|
||||||
|
subscribe_dict["downloader"] = downloader
|
||||||
|
if save_path is not None:
|
||||||
|
subscribe_dict["save_path"] = save_path
|
||||||
|
if best_version is not None:
|
||||||
|
subscribe_dict["best_version"] = best_version
|
||||||
|
|
||||||
|
# 其他配置
|
||||||
|
if custom_words is not None:
|
||||||
|
subscribe_dict["custom_words"] = custom_words
|
||||||
|
if media_category is not None:
|
||||||
|
subscribe_dict["media_category"] = media_category
|
||||||
|
if episode_group is not None:
|
||||||
|
subscribe_dict["episode_group"] = episode_group
|
||||||
|
|
||||||
|
# 如果没有要更新的字段
|
||||||
|
if not subscribe_dict:
|
||||||
|
return json.dumps({
|
||||||
|
"success": False,
|
||||||
|
"message": "没有提供要更新的字段"
|
||||||
|
}, ensure_ascii=False)
|
||||||
|
|
||||||
|
# 更新订阅
|
||||||
|
await subscribe.async_update(db, subscribe_dict)
|
||||||
|
|
||||||
|
# 重新获取更新后的订阅数据
|
||||||
|
updated_subscribe = await Subscribe.async_get(db, subscribe_id)
|
||||||
|
|
||||||
|
# 发送订阅调整事件
|
||||||
|
await eventmanager.async_send_event(EventType.SubscribeModified, {
|
||||||
|
"subscribe_id": subscribe_id,
|
||||||
|
"old_subscribe_info": old_subscribe_dict,
|
||||||
|
"subscribe_info": updated_subscribe.to_dict() if updated_subscribe else {},
|
||||||
|
})
|
||||||
|
|
||||||
|
# 构建返回结果
|
||||||
|
result = {
|
||||||
|
"success": True,
|
||||||
|
"message": f"订阅 #{subscribe_id} 更新成功",
|
||||||
|
"subscribe_id": subscribe_id,
|
||||||
|
"updated_fields": list(subscribe_dict.keys())
|
||||||
|
}
|
||||||
|
|
||||||
|
if updated_subscribe:
|
||||||
|
result["subscribe"] = {
|
||||||
|
"id": updated_subscribe.id,
|
||||||
|
"name": updated_subscribe.name,
|
||||||
|
"year": updated_subscribe.year,
|
||||||
|
"type": updated_subscribe.type,
|
||||||
|
"season": updated_subscribe.season,
|
||||||
|
"state": updated_subscribe.state,
|
||||||
|
"total_episode": updated_subscribe.total_episode,
|
||||||
|
"lack_episode": updated_subscribe.lack_episode,
|
||||||
|
"start_episode": updated_subscribe.start_episode,
|
||||||
|
"quality": updated_subscribe.quality,
|
||||||
|
"resolution": updated_subscribe.resolution,
|
||||||
|
"effect": updated_subscribe.effect
|
||||||
|
}
|
||||||
|
|
||||||
|
return json.dumps(result, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_message = f"更新订阅失败: {str(e)}"
|
||||||
|
logger.error(f"更新订阅失败: {e}", exc_info=True)
|
||||||
|
return json.dumps({
|
||||||
|
"success": False,
|
||||||
|
"message": error_message,
|
||||||
|
"subscribe_id": subscribe_id
|
||||||
|
}, ensure_ascii=False)
|
||||||
|
|
||||||
Reference in New Issue
Block a user