mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-03-20 03:57:30 +08:00
feat: add TestSiteTool and enhance AddSubscribeTool with advanced filtering options
- Introduced TestSiteTool to the toolset for site testing functionalities. - Updated __all__ exports in init.py and factory.py to include TestSiteTool. - Enhanced AddSubscribeTool to support additional parameters for episode management and media quality filtering, improving subscription customization.
This commit is contained in:
@@ -10,6 +10,7 @@ from app.agent.tools.impl.delete_subscribe import DeleteSubscribeTool
|
||||
from app.agent.tools.impl.query_downloads import QueryDownloadsTool
|
||||
from app.agent.tools.impl.query_downloaders import QueryDownloadersTool
|
||||
from app.agent.tools.impl.query_sites import QuerySitesTool
|
||||
from app.agent.tools.impl.test_site import TestSiteTool
|
||||
from app.agent.tools.impl.get_recommendations import GetRecommendationsTool
|
||||
from app.agent.tools.impl.query_media_library import QueryMediaLibraryTool
|
||||
from app.agent.tools.impl.send_message import SendMessageTool
|
||||
@@ -28,6 +29,7 @@ __all__ = [
|
||||
"QueryDownloadsTool",
|
||||
"QueryDownloadersTool",
|
||||
"QuerySitesTool",
|
||||
"TestSiteTool",
|
||||
"GetRecommendationsTool",
|
||||
"QueryMediaLibraryTool",
|
||||
"SendMessageTool",
|
||||
|
||||
@@ -9,6 +9,7 @@ from app.agent.tools.impl.query_downloaders import QueryDownloadersTool
|
||||
from app.agent.tools.impl.query_downloads import QueryDownloadsTool
|
||||
from app.agent.tools.impl.query_media_library import QueryMediaLibraryTool
|
||||
from app.agent.tools.impl.query_sites import QuerySitesTool
|
||||
from app.agent.tools.impl.test_site import TestSiteTool
|
||||
from app.agent.tools.impl.query_subscribes import QuerySubscribesTool
|
||||
from app.agent.tools.impl.delete_subscribe import DeleteSubscribeTool
|
||||
from app.agent.tools.impl.search_media import SearchMediaTool
|
||||
@@ -40,6 +41,7 @@ class MoviePilotToolFactory:
|
||||
QueryDownloadsTool,
|
||||
QueryDownloadersTool,
|
||||
QuerySitesTool,
|
||||
TestSiteTool,
|
||||
GetRecommendationsTool,
|
||||
QueryMediaLibraryTool,
|
||||
SendMessageTool,
|
||||
|
||||
@@ -21,17 +21,32 @@ class AddSubscribeInput(BaseModel):
|
||||
description="Season number for TV shows (optional, if not specified will subscribe to all seasons)")
|
||||
tmdb_id: Optional[str] = Field(None,
|
||||
description="TMDB database ID for precise media identification (optional but recommended for accuracy)")
|
||||
start_episode: Optional[int] = Field(None,
|
||||
description="Starting episode number for TV shows (optional, defaults to 1 if not specified)")
|
||||
total_episode: Optional[int] = Field(None,
|
||||
description="Total number of episodes for TV shows (optional, will be auto-detected from TMDB if not specified)")
|
||||
quality: Optional[str] = Field(None,
|
||||
description="Quality filter as regular expression (optional, e.g., 'BluRay|WEB-DL|HDTV')")
|
||||
resolution: Optional[str] = Field(None,
|
||||
description="Resolution filter as regular expression (optional, e.g., '1080p|720p|2160p')")
|
||||
effect: Optional[str] = Field(None,
|
||||
description="Effect filter as regular expression (optional, e.g., 'HDR|DV|SDR')")
|
||||
|
||||
|
||||
class AddSubscribeTool(MoviePilotTool):
|
||||
name: str = "add_subscribe"
|
||||
description: str = "Add media subscription to create automated download rules for movies and TV shows. The system will automatically search and download new episodes or releases based on the subscription criteria."
|
||||
description: str = "Add media subscription to create automated download rules for movies and TV shows. The system will automatically search and download new episodes or releases based on the subscription criteria. Supports advanced filtering options like quality, resolution, and effect filters using regular expressions."
|
||||
args_schema: Type[BaseModel] = AddSubscribeInput
|
||||
|
||||
async def run(self, title: str, year: str, media_type: str,
|
||||
season: Optional[int] = None, tmdb_id: Optional[str] = None, **kwargs) -> str:
|
||||
season: Optional[int] = None, tmdb_id: Optional[str] = None,
|
||||
start_episode: Optional[int] = None, total_episode: Optional[int] = None,
|
||||
quality: Optional[str] = None, resolution: Optional[str] = None,
|
||||
effect: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: title={title}, year={year}, media_type={media_type}, season={season}, tmdb_id={tmdb_id}")
|
||||
f"执行工具: {self.name}, 参数: title={title}, year={year}, media_type={media_type}, "
|
||||
f"season={season}, tmdb_id={tmdb_id}, start_episode={start_episode}, "
|
||||
f"total_episode={total_episode}, quality={quality}, resolution={resolution}, effect={effect}")
|
||||
|
||||
try:
|
||||
subscribe_chain = SubscribeChain()
|
||||
@@ -43,16 +58,45 @@ class AddSubscribeTool(MoviePilotTool):
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(f"无效的 tmdb_id: {tmdb_id},将忽略")
|
||||
|
||||
# 构建额外的订阅参数
|
||||
subscribe_kwargs = {}
|
||||
if start_episode is not None:
|
||||
subscribe_kwargs['start_episode'] = start_episode
|
||||
if total_episode is not None:
|
||||
subscribe_kwargs['total_episode'] = total_episode
|
||||
if quality:
|
||||
subscribe_kwargs['quality'] = quality
|
||||
if resolution:
|
||||
subscribe_kwargs['resolution'] = resolution
|
||||
if effect:
|
||||
subscribe_kwargs['effect'] = effect
|
||||
|
||||
sid, message = await subscribe_chain.async_add(
|
||||
mtype=MediaType(media_type),
|
||||
title=title,
|
||||
year=year,
|
||||
tmdbid=tmdbid_int,
|
||||
season=season,
|
||||
username=self._user_id
|
||||
username=self._user_id,
|
||||
**subscribe_kwargs
|
||||
)
|
||||
if sid:
|
||||
return f"成功添加订阅:{title} ({year})"
|
||||
result_msg = f"成功添加订阅:{title} ({year})"
|
||||
if subscribe_kwargs:
|
||||
params = []
|
||||
if start_episode is not None:
|
||||
params.append(f"开始集数: {start_episode}")
|
||||
if total_episode is not None:
|
||||
params.append(f"总集数: {total_episode}")
|
||||
if quality:
|
||||
params.append(f"质量过滤: {quality}")
|
||||
if resolution:
|
||||
params.append(f"分辨率过滤: {resolution}")
|
||||
if effect:
|
||||
params.append(f"特效过滤: {effect}")
|
||||
if params:
|
||||
result_msg += f"\n配置参数: {', '.join(params)}"
|
||||
return result_msg
|
||||
else:
|
||||
return f"添加订阅失败:{message}"
|
||||
except Exception as e:
|
||||
|
||||
67
app/agent/tools/impl/test_site.py
Normal file
67
app/agent/tools/impl/test_site.py
Normal file
@@ -0,0 +1,67 @@
|
||||
"""测试站点连通性工具"""
|
||||
|
||||
from typing import Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.site import SiteChain
|
||||
from app.db.site_oper import SiteOper
|
||||
from app.log import logger
|
||||
from app.utils.string import StringUtils
|
||||
|
||||
|
||||
class TestSiteInput(BaseModel):
|
||||
"""测试站点连通性工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
site_identifier: str = Field(..., description="Site identifier: can be site ID (integer as string), site name, or site domain/URL")
|
||||
|
||||
|
||||
class TestSiteTool(MoviePilotTool):
|
||||
name: str = "test_site"
|
||||
description: str = "Test site connectivity and availability. This will check if a site is accessible and can be logged in. Accepts site ID, site name, or site domain/URL as identifier."
|
||||
args_schema: Type[BaseModel] = TestSiteInput
|
||||
|
||||
async def run(self, site_identifier: str, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: site_identifier={site_identifier}")
|
||||
|
||||
try:
|
||||
site_oper = SiteOper()
|
||||
site_chain = SiteChain()
|
||||
|
||||
# 尝试解析为站点ID
|
||||
site = None
|
||||
if site_identifier.isdigit():
|
||||
# 如果是数字,尝试作为站点ID查询
|
||||
site = await site_oper.async_get(int(site_identifier))
|
||||
|
||||
# 如果不是ID或ID查询失败,尝试按名称或域名查询
|
||||
if not site:
|
||||
# 尝试按名称查询
|
||||
sites = await site_oper.async_list()
|
||||
for s in sites:
|
||||
if (site_identifier.lower() in (s.name or "").lower()) or \
|
||||
(site_identifier.lower() in (s.domain or "").lower()):
|
||||
site = s
|
||||
break
|
||||
|
||||
# 如果还是没找到,尝试从URL提取域名
|
||||
if not site:
|
||||
domain = StringUtils.get_url_domain(site_identifier)
|
||||
if domain:
|
||||
site = await site_oper.async_get_by_domain(domain)
|
||||
|
||||
if not site:
|
||||
return f"未找到站点:{site_identifier},请使用 query_sites 工具查询可用的站点"
|
||||
|
||||
# 测试站点连通性
|
||||
status, message = site_chain.test(site.domain)
|
||||
|
||||
if status:
|
||||
return f"站点连通性测试成功:{site.name} ({site.domain})\n{message}"
|
||||
else:
|
||||
return f"站点连通性测试失败:{site.name} ({site.domain})\n{message}"
|
||||
except Exception as e:
|
||||
logger.error(f"测试站点连通性失败: {e}", exc_info=True)
|
||||
return f"测试站点连通性时发生错误: {str(e)}"
|
||||
|
||||
Reference in New Issue
Block a user