diff --git a/app/agent/tools/factory.py b/app/agent/tools/factory.py index fef53a34..cbf850eb 100644 --- a/app/agent/tools/factory.py +++ b/app/agent/tools/factory.py @@ -4,11 +4,14 @@ from typing import List, Callable from app.agent.tools.impl.add_download import AddDownloadTool from app.agent.tools.impl.add_subscribe import AddSubscribeTool +from app.agent.tools.impl.update_subscribe import UpdateSubscribeTool from app.agent.tools.impl.get_recommendations import GetRecommendationsTool from app.agent.tools.impl.query_downloaders import QueryDownloadersTool from app.agent.tools.impl.query_downloads import QueryDownloadsTool from app.agent.tools.impl.query_media_library import QueryMediaLibraryTool from app.agent.tools.impl.query_sites import QuerySitesTool +from app.agent.tools.impl.update_site import UpdateSiteTool +from app.agent.tools.impl.query_site_userdata import QuerySiteUserdataTool from app.agent.tools.impl.test_site import TestSiteTool from app.agent.tools.impl.query_subscribes import QuerySubscribesTool from app.agent.tools.impl.query_subscribe_shares import QuerySubscribeSharesTool @@ -49,6 +52,7 @@ class MoviePilotToolFactory: RecognizeMediaTool, ScrapeMetadataTool, AddSubscribeTool, + UpdateSubscribeTool, SearchTorrentsTool, AddDownloadTool, QuerySubscribesTool, @@ -60,6 +64,8 @@ class MoviePilotToolFactory: DeleteDownloadTool, QueryDownloadersTool, QuerySitesTool, + UpdateSiteTool, + QuerySiteUserdataTool, TestSiteTool, UpdateSiteCookieTool, GetRecommendationsTool, diff --git a/app/agent/tools/impl/query_site_userdata.py b/app/agent/tools/impl/query_site_userdata.py new file mode 100644 index 00000000..4fd0e395 --- /dev/null +++ b/app/agent/tools/impl/query_site_userdata.py @@ -0,0 +1,136 @@ +"""查询站点用户数据工具""" + +import json +from typing import Optional, Type + +from pydantic import BaseModel, Field + +from app.agent.tools.base import MoviePilotTool +from app.db import AsyncSessionFactory +from app.db.models.site import Site +from app.db.models.siteuserdata import SiteUserData +from app.log import logger + + +class QuerySiteUserdataInput(BaseModel): + """查询站点用户数据工具的输入参数模型""" + explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context") + site_id: int = Field(..., description="The ID of the site to query user data for") + workdate: Optional[str] = Field(None, description="Work date to query (optional, format: 'YYYY-MM-DD', if not specified returns latest data)") + + +class QuerySiteUserdataTool(MoviePilotTool): + name: str = "query_site_userdata" + description: str = "Query user data for a specific site including username, user level, upload/download statistics, seeding information, bonus points, and other account details. Supports querying data for a specific date or latest data." + args_schema: Type[BaseModel] = QuerySiteUserdataInput + + def get_tool_message(self, **kwargs) -> Optional[str]: + """根据查询参数生成友好的提示消息""" + site_id = kwargs.get("site_id") + workdate = kwargs.get("workdate") + + message = f"正在查询站点 #{site_id} 的用户数据" + if workdate: + message += f" (日期: {workdate})" + else: + message += " (最新数据)" + + return message + + async def run(self, site_id: int, workdate: Optional[str] = None, **kwargs) -> str: + logger.info(f"执行工具: {self.name}, 参数: site_id={site_id}, workdate={workdate}") + + try: + # 获取数据库会话 + async with AsyncSessionFactory() as db: + # 获取站点 + site = await Site.async_get(db, site_id) + if not site: + return json.dumps({ + "success": False, + "message": f"站点不存在: {site_id}" + }, ensure_ascii=False) + + # 获取站点用户数据 + user_data_list = await SiteUserData.async_get_by_domain( + db, + domain=site.domain, + workdate=workdate + ) + + if not user_data_list: + return json.dumps({ + "success": False, + "message": f"站点 {site.name} ({site.domain}) 暂无用户数据", + "site_id": site_id, + "site_name": site.name, + "site_domain": site.domain, + "workdate": workdate + }, ensure_ascii=False) + + # 格式化用户数据 + result = { + "success": True, + "site_id": site_id, + "site_name": site.name, + "site_domain": site.domain, + "workdate": workdate, + "data_count": len(user_data_list), + "user_data": [] + } + + for user_data in user_data_list: + # 格式化上传/下载量(转换为可读格式) + upload_gb = user_data.upload / (1024 ** 3) if user_data.upload else 0 + download_gb = user_data.download / (1024 ** 3) if user_data.download else 0 + seeding_size_gb = user_data.seeding_size / (1024 ** 3) if user_data.seeding_size else 0 + leeching_size_gb = user_data.leeching_size / (1024 ** 3) if user_data.leeching_size else 0 + + user_data_dict = { + "domain": user_data.domain, + "name": user_data.name, + "username": user_data.username, + "userid": user_data.userid, + "user_level": user_data.user_level, + "join_at": user_data.join_at, + "bonus": user_data.bonus, + "upload": user_data.upload, + "upload_gb": round(upload_gb, 2), + "download": user_data.download, + "download_gb": round(download_gb, 2), + "ratio": round(user_data.ratio, 2) if user_data.ratio else 0, + "seeding": int(user_data.seeding) if user_data.seeding else 0, + "leeching": int(user_data.leeching) if user_data.leeching else 0, + "seeding_size": user_data.seeding_size, + "seeding_size_gb": round(seeding_size_gb, 2), + "leeching_size": user_data.leeching_size, + "leeching_size_gb": round(leeching_size_gb, 2), + "seeding_info": user_data.seeding_info if user_data.seeding_info else [], + "message_unread": user_data.message_unread, + "message_unread_contents": user_data.message_unread_contents if user_data.message_unread_contents else [], + "err_msg": user_data.err_msg, + "updated_day": user_data.updated_day, + "updated_time": user_data.updated_time + } + result["user_data"].append(user_data_dict) + + # 如果有多条数据,只返回最新的(按更新时间排序) + if len(result["user_data"]) > 1: + result["user_data"].sort( + key=lambda x: (x.get("updated_day", ""), x.get("updated_time", "")), + reverse=True + ) + result["message"] = f"找到 {len(result['user_data'])} 条数据,显示最新的一条" + result["user_data"] = [result["user_data"][0]] + + return json.dumps(result, ensure_ascii=False, indent=2) + + except Exception as e: + error_message = f"查询站点用户数据失败: {str(e)}" + logger.error(f"查询站点用户数据失败: {e}", exc_info=True) + return json.dumps({ + "success": False, + "message": error_message, + "site_id": site_id + }, ensure_ascii=False) + diff --git a/app/agent/tools/impl/update_site.py b/app/agent/tools/impl/update_site.py new file mode 100644 index 00000000..a3b18ead --- /dev/null +++ b/app/agent/tools/impl/update_site.py @@ -0,0 +1,203 @@ +"""更新站点工具""" + +import json +from typing import Optional, Type + +from pydantic import BaseModel, Field + +from app.agent.tools.base import MoviePilotTool +from app.core.event import eventmanager +from app.db import AsyncSessionFactory +from app.db.models.site import Site +from app.log import logger +from app.schemas.types import EventType +from app.utils.string import StringUtils + + +class UpdateSiteInput(BaseModel): + """更新站点工具的输入参数模型""" + explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context") + site_id: int = Field(..., description="The ID of the site to update") + name: Optional[str] = Field(None, description="Site name (optional)") + url: Optional[str] = Field(None, description="Site URL (optional, will be automatically formatted)") + pri: Optional[int] = Field(None, description="Site priority (optional, higher number = higher priority)") + rss: Optional[str] = Field(None, description="RSS feed URL (optional)") + cookie: Optional[str] = Field(None, description="Site cookie (optional)") + ua: Optional[str] = Field(None, description="User-Agent string (optional)") + apikey: Optional[str] = Field(None, description="API key (optional)") + token: Optional[str] = Field(None, description="API token (optional)") + proxy: Optional[int] = Field(None, description="Whether to use proxy: 0 for no, 1 for yes (optional)") + filter: Optional[str] = Field(None, description="Filter rule as regular expression (optional)") + note: Optional[str] = Field(None, description="Site notes/remarks (optional)") + timeout: Optional[int] = Field(None, description="Request timeout in seconds (optional, default: 15)") + limit_interval: Optional[int] = Field(None, description="Rate limit interval in seconds (optional)") + limit_count: Optional[int] = Field(None, description="Rate limit count per interval (optional)") + limit_seconds: Optional[int] = Field(None, description="Rate limit seconds between requests (optional)") + is_active: Optional[bool] = Field(None, description="Whether site is active: True for enabled, False for disabled (optional)") + downloader: Optional[str] = Field(None, description="Downloader name for this site (optional)") + + +class UpdateSiteTool(MoviePilotTool): + name: str = "update_site" + description: str = "Update site configuration including URL, priority, authentication credentials (cookie, UA, API key), proxy settings, rate limits, and other site properties. Supports updating multiple site attributes at once." + args_schema: Type[BaseModel] = UpdateSiteInput + + def get_tool_message(self, **kwargs) -> Optional[str]: + """根据更新参数生成友好的提示消息""" + site_id = kwargs.get("site_id") + fields_updated = [] + + if kwargs.get("name"): + fields_updated.append("名称") + if kwargs.get("url"): + fields_updated.append("URL") + if kwargs.get("pri") is not None: + fields_updated.append("优先级") + if kwargs.get("cookie"): + fields_updated.append("Cookie") + if kwargs.get("ua"): + fields_updated.append("User-Agent") + if kwargs.get("proxy") is not None: + fields_updated.append("代理设置") + if kwargs.get("is_active") is not None: + fields_updated.append("启用状态") + if kwargs.get("downloader"): + fields_updated.append("下载器") + + if fields_updated: + return f"正在更新站点 #{site_id}: {', '.join(fields_updated)}" + return f"正在更新站点 #{site_id}" + + async def run(self, site_id: int, + name: Optional[str] = None, + url: Optional[str] = None, + pri: Optional[int] = None, + rss: Optional[str] = None, + cookie: Optional[str] = None, + ua: Optional[str] = None, + apikey: Optional[str] = None, + token: Optional[str] = None, + proxy: Optional[int] = None, + filter: Optional[str] = None, + note: Optional[str] = None, + timeout: Optional[int] = None, + limit_interval: Optional[int] = None, + limit_count: Optional[int] = None, + limit_seconds: Optional[int] = None, + is_active: Optional[bool] = None, + downloader: Optional[str] = None, + **kwargs) -> str: + logger.info(f"执行工具: {self.name}, 参数: site_id={site_id}") + + try: + # 获取数据库会话 + async with AsyncSessionFactory() as db: + # 获取站点 + site = await Site.async_get(db, site_id) + if not site: + return json.dumps({ + "success": False, + "message": f"站点不存在: {site_id}" + }, ensure_ascii=False) + + # 构建更新字典 + site_dict = {} + + # 基本信息 + if name is not None: + site_dict["name"] = name + + # URL处理(需要校正格式) + if url is not None: + _scheme, _netloc = StringUtils.get_url_netloc(url) + site_dict["url"] = f"{_scheme}://{_netloc}/" + + if pri is not None: + site_dict["pri"] = pri + if rss is not None: + site_dict["rss"] = rss + + # 认证信息 + if cookie is not None: + site_dict["cookie"] = cookie + if ua is not None: + site_dict["ua"] = ua + if apikey is not None: + site_dict["apikey"] = apikey + if token is not None: + site_dict["token"] = token + + # 配置选项 + if proxy is not None: + site_dict["proxy"] = proxy + if filter is not None: + site_dict["filter"] = filter + if note is not None: + site_dict["note"] = note + if timeout is not None: + site_dict["timeout"] = timeout + + # 流控设置 + if limit_interval is not None: + site_dict["limit_interval"] = limit_interval + if limit_count is not None: + site_dict["limit_count"] = limit_count + if limit_seconds is not None: + site_dict["limit_seconds"] = limit_seconds + + # 状态和下载器 + if is_active is not None: + site_dict["is_active"] = is_active + if downloader is not None: + site_dict["downloader"] = downloader + + # 如果没有要更新的字段 + if not site_dict: + return json.dumps({ + "success": False, + "message": "没有提供要更新的字段" + }, ensure_ascii=False) + + # 更新站点 + await site.async_update(db, site_dict) + + # 重新获取更新后的站点数据 + updated_site = await Site.async_get(db, site_id) + + # 发送站点更新事件 + await eventmanager.async_send_event(EventType.SiteUpdated, { + "domain": updated_site.domain if updated_site else site.domain + }) + + # 构建返回结果 + result = { + "success": True, + "message": f"站点 #{site_id} 更新成功", + "site_id": site_id, + "updated_fields": list(site_dict.keys()) + } + + if updated_site: + result["site"] = { + "id": updated_site.id, + "name": updated_site.name, + "domain": updated_site.domain, + "url": updated_site.url, + "pri": updated_site.pri, + "is_active": updated_site.is_active, + "downloader": updated_site.downloader, + "proxy": updated_site.proxy, + "timeout": updated_site.timeout + } + + return json.dumps(result, ensure_ascii=False, indent=2) + + except Exception as e: + error_message = f"更新站点失败: {str(e)}" + logger.error(f"更新站点失败: {e}", exc_info=True) + return json.dumps({ + "success": False, + "message": error_message, + "site_id": site_id + }, ensure_ascii=False) + diff --git a/app/agent/tools/impl/update_subscribe.py b/app/agent/tools/impl/update_subscribe.py new file mode 100644 index 00000000..dbbe1f07 --- /dev/null +++ b/app/agent/tools/impl/update_subscribe.py @@ -0,0 +1,239 @@ +"""更新订阅工具""" + +import json +from typing import Optional, Type, List + +from pydantic import BaseModel, Field + +from app.agent.tools.base import MoviePilotTool +from app.core.event import eventmanager +from app.db import AsyncSessionFactory +from app.db.models.subscribe import Subscribe +from app.log import logger +from app.schemas.types import EventType + + +class UpdateSubscribeInput(BaseModel): + """更新订阅工具的输入参数模型""" + explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context") + subscribe_id: int = Field(..., description="The ID of the subscription to update") + name: Optional[str] = Field(None, description="Subscription name/title (optional)") + year: Optional[str] = Field(None, description="Release year (optional)") + season: Optional[int] = Field(None, description="Season number for TV shows (optional)") + total_episode: Optional[int] = Field(None, description="Total number of episodes (optional)") + lack_episode: Optional[int] = Field(None, description="Number of missing episodes (optional)") + start_episode: Optional[int] = Field(None, description="Starting episode number (optional)") + quality: Optional[str] = Field(None, description="Quality filter as regular expression (optional, e.g., 'BluRay|WEB-DL|HDTV')") + resolution: Optional[str] = Field(None, description="Resolution filter as regular expression (optional, e.g., '1080p|720p|2160p')") + effect: Optional[str] = Field(None, description="Effect filter as regular expression (optional, e.g., 'HDR|DV|SDR')") + include: Optional[str] = Field(None, description="Include filter as regular expression (optional)") + exclude: Optional[str] = Field(None, description="Exclude filter as regular expression (optional)") + filter: Optional[str] = Field(None, description="Filter rule as regular expression (optional)") + state: Optional[str] = Field(None, description="Subscription state: 'R' for enabled, 'P' for disabled, 'S' for paused (optional)") + sites: Optional[List[int]] = Field(None, description="List of site IDs to search from (optional)") + downloader: Optional[str] = Field(None, description="Downloader name (optional)") + save_path: Optional[str] = Field(None, description="Save path for downloaded files (optional)") + best_version: Optional[int] = Field(None, description="Whether to upgrade to best version: 0 for no, 1 for yes (optional)") + custom_words: Optional[str] = Field(None, description="Custom recognition words (optional)") + media_category: Optional[str] = Field(None, description="Custom media category (optional)") + episode_group: Optional[str] = Field(None, description="Episode group ID (optional)") + + +class UpdateSubscribeTool(MoviePilotTool): + name: str = "update_subscribe" + description: str = "Update subscription properties including filters, episode counts, state, and other settings. Supports updating quality/resolution filters, episode tracking, subscription state, and download configuration." + args_schema: Type[BaseModel] = UpdateSubscribeInput + + def get_tool_message(self, **kwargs) -> Optional[str]: + """根据更新参数生成友好的提示消息""" + subscribe_id = kwargs.get("subscribe_id") + fields_updated = [] + + if kwargs.get("name"): + fields_updated.append("名称") + if kwargs.get("total_episode") is not None: + fields_updated.append("总集数") + if kwargs.get("lack_episode") is not None: + fields_updated.append("缺失集数") + if kwargs.get("quality"): + fields_updated.append("质量过滤") + if kwargs.get("resolution"): + fields_updated.append("分辨率过滤") + if kwargs.get("state"): + state_map = {"R": "启用", "P": "禁用", "S": "暂停"} + fields_updated.append(f"状态({state_map.get(kwargs.get('state'), kwargs.get('state'))})") + if kwargs.get("sites"): + fields_updated.append("站点") + if kwargs.get("downloader"): + fields_updated.append("下载器") + + if fields_updated: + return f"正在更新订阅 #{subscribe_id}: {', '.join(fields_updated)}" + return f"正在更新订阅 #{subscribe_id}" + + async def run(self, subscribe_id: int, + name: Optional[str] = None, + year: Optional[str] = None, + season: Optional[int] = None, + total_episode: Optional[int] = None, + lack_episode: Optional[int] = None, + start_episode: Optional[int] = None, + quality: Optional[str] = None, + resolution: Optional[str] = None, + effect: Optional[str] = None, + include: Optional[str] = None, + exclude: Optional[str] = None, + filter: Optional[str] = None, + state: Optional[str] = None, + sites: Optional[List[int]] = None, + downloader: Optional[str] = None, + save_path: Optional[str] = None, + best_version: Optional[int] = None, + custom_words: Optional[str] = None, + media_category: Optional[str] = None, + episode_group: Optional[str] = None, + **kwargs) -> str: + logger.info(f"执行工具: {self.name}, 参数: subscribe_id={subscribe_id}") + + try: + # 获取数据库会话 + async with AsyncSessionFactory() as db: + # 获取订阅 + subscribe = await Subscribe.async_get(db, subscribe_id) + if not subscribe: + return json.dumps({ + "success": False, + "message": f"订阅不存在: {subscribe_id}" + }, ensure_ascii=False) + + # 保存旧数据用于事件 + old_subscribe_dict = subscribe.to_dict() + + # 构建更新字典 + subscribe_dict = {} + + # 基本信息 + if name is not None: + subscribe_dict["name"] = name + if year is not None: + subscribe_dict["year"] = year + if season is not None: + subscribe_dict["season"] = season + + # 集数相关 + if total_episode is not None: + subscribe_dict["total_episode"] = total_episode + # 如果总集数增加,缺失集数也要相应增加 + if total_episode > (subscribe.total_episode or 0): + old_lack = subscribe.lack_episode or 0 + subscribe_dict["lack_episode"] = old_lack + (total_episode - (subscribe.total_episode or 0)) + # 标记为手动修改过总集数 + subscribe_dict["manual_total_episode"] = 1 + + # 缺失集数处理(只有在没有提供总集数时才单独处理) + # 注意:如果 lack_episode 为 0,不更新(避免更新为0) + if lack_episode is not None and total_episode is None: + if lack_episode > 0: + subscribe_dict["lack_episode"] = lack_episode + # 如果 lack_episode 为 0,不添加到更新字典中(保持原值或由总集数逻辑处理) + + if start_episode is not None: + subscribe_dict["start_episode"] = start_episode + + # 过滤规则 + if quality is not None: + subscribe_dict["quality"] = quality + if resolution is not None: + subscribe_dict["resolution"] = resolution + if effect is not None: + subscribe_dict["effect"] = effect + if include is not None: + subscribe_dict["include"] = include + if exclude is not None: + subscribe_dict["exclude"] = exclude + if filter is not None: + subscribe_dict["filter"] = filter + + # 状态 + if state is not None: + valid_states = ["R", "P", "S", "N"] + if state not in valid_states: + return json.dumps({ + "success": False, + "message": f"无效的订阅状态: {state},有效状态: {', '.join(valid_states)}" + }, ensure_ascii=False) + subscribe_dict["state"] = state + + # 下载配置 + if sites is not None: + subscribe_dict["sites"] = sites + if downloader is not None: + subscribe_dict["downloader"] = downloader + if save_path is not None: + subscribe_dict["save_path"] = save_path + if best_version is not None: + subscribe_dict["best_version"] = best_version + + # 其他配置 + if custom_words is not None: + subscribe_dict["custom_words"] = custom_words + if media_category is not None: + subscribe_dict["media_category"] = media_category + if episode_group is not None: + subscribe_dict["episode_group"] = episode_group + + # 如果没有要更新的字段 + if not subscribe_dict: + return json.dumps({ + "success": False, + "message": "没有提供要更新的字段" + }, ensure_ascii=False) + + # 更新订阅 + await subscribe.async_update(db, subscribe_dict) + + # 重新获取更新后的订阅数据 + updated_subscribe = await Subscribe.async_get(db, subscribe_id) + + # 发送订阅调整事件 + await eventmanager.async_send_event(EventType.SubscribeModified, { + "subscribe_id": subscribe_id, + "old_subscribe_info": old_subscribe_dict, + "subscribe_info": updated_subscribe.to_dict() if updated_subscribe else {}, + }) + + # 构建返回结果 + result = { + "success": True, + "message": f"订阅 #{subscribe_id} 更新成功", + "subscribe_id": subscribe_id, + "updated_fields": list(subscribe_dict.keys()) + } + + if updated_subscribe: + result["subscribe"] = { + "id": updated_subscribe.id, + "name": updated_subscribe.name, + "year": updated_subscribe.year, + "type": updated_subscribe.type, + "season": updated_subscribe.season, + "state": updated_subscribe.state, + "total_episode": updated_subscribe.total_episode, + "lack_episode": updated_subscribe.lack_episode, + "start_episode": updated_subscribe.start_episode, + "quality": updated_subscribe.quality, + "resolution": updated_subscribe.resolution, + "effect": updated_subscribe.effect + } + + return json.dumps(result, ensure_ascii=False, indent=2) + + except Exception as e: + error_message = f"更新订阅失败: {str(e)}" + logger.error(f"更新订阅失败: {e}", exc_info=True) + return json.dumps({ + "success": False, + "message": error_message, + "subscribe_id": subscribe_id + }, ensure_ascii=False) +