mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-03-20 03:57:30 +08:00
Add filter_groups parameter to AddSubscribeTool and include SearchSubscribeTool and QueryRuleGroupsTool in MoviePilotToolFactory
This commit is contained in:
@@ -5,6 +5,7 @@ from typing import List, Callable
|
||||
from app.agent.tools.impl.add_download import AddDownloadTool
|
||||
from app.agent.tools.impl.add_subscribe import AddSubscribeTool
|
||||
from app.agent.tools.impl.update_subscribe import UpdateSubscribeTool
|
||||
from app.agent.tools.impl.search_subscribe import SearchSubscribeTool
|
||||
from app.agent.tools.impl.get_recommendations import GetRecommendationsTool
|
||||
from app.agent.tools.impl.query_downloaders import QueryDownloadersTool
|
||||
from app.agent.tools.impl.query_downloads import QueryDownloadsTool
|
||||
@@ -15,6 +16,7 @@ from app.agent.tools.impl.query_site_userdata import QuerySiteUserdataTool
|
||||
from app.agent.tools.impl.test_site import TestSiteTool
|
||||
from app.agent.tools.impl.query_subscribes import QuerySubscribesTool
|
||||
from app.agent.tools.impl.query_subscribe_shares import QuerySubscribeSharesTool
|
||||
from app.agent.tools.impl.query_rule_groups import QueryRuleGroupsTool
|
||||
from app.agent.tools.impl.query_popular_subscribes import QueryPopularSubscribesTool
|
||||
from app.agent.tools.impl.query_subscribe_history import QuerySubscribeHistoryTool
|
||||
from app.agent.tools.impl.delete_subscribe import DeleteSubscribeTool
|
||||
@@ -55,11 +57,13 @@ class MoviePilotToolFactory:
|
||||
QueryEpisodeScheduleTool,
|
||||
AddSubscribeTool,
|
||||
UpdateSubscribeTool,
|
||||
SearchSubscribeTool,
|
||||
SearchTorrentsTool,
|
||||
AddDownloadTool,
|
||||
QuerySubscribesTool,
|
||||
QuerySubscribeSharesTool,
|
||||
QueryPopularSubscribesTool,
|
||||
QueryRuleGroupsTool,
|
||||
QuerySubscribeHistoryTool,
|
||||
DeleteSubscribeTool,
|
||||
QueryDownloadsTool,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""添加订阅工具"""
|
||||
|
||||
from typing import Optional, Type
|
||||
from typing import Optional, Type, List
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -31,6 +31,8 @@ class AddSubscribeInput(BaseModel):
|
||||
description="Resolution filter as regular expression (optional, e.g., '1080p|720p|2160p')")
|
||||
effect: Optional[str] = Field(None,
|
||||
description="Effect filter as regular expression (optional, e.g., 'HDR|DV|SDR')")
|
||||
filter_groups: Optional[List[str]] = Field(None,
|
||||
description="List of filter rule group names to apply (optional, use query_rule_groups tool to get available rule groups)")
|
||||
|
||||
|
||||
class AddSubscribeTool(MoviePilotTool):
|
||||
@@ -59,11 +61,12 @@ class AddSubscribeTool(MoviePilotTool):
|
||||
season: Optional[int] = None, tmdb_id: Optional[str] = None,
|
||||
start_episode: Optional[int] = None, total_episode: Optional[int] = None,
|
||||
quality: Optional[str] = None, resolution: Optional[str] = None,
|
||||
effect: Optional[str] = None, **kwargs) -> str:
|
||||
effect: Optional[str] = None, filter_groups: Optional[List[str]] = None, **kwargs) -> str:
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: title={title}, year={year}, media_type={media_type}, "
|
||||
f"season={season}, tmdb_id={tmdb_id}, start_episode={start_episode}, "
|
||||
f"total_episode={total_episode}, quality={quality}, resolution={resolution}, effect={effect}")
|
||||
f"total_episode={total_episode}, quality={quality}, resolution={resolution}, "
|
||||
f"effect={effect}, filter_groups={filter_groups}")
|
||||
|
||||
try:
|
||||
subscribe_chain = SubscribeChain()
|
||||
@@ -87,6 +90,8 @@ class AddSubscribeTool(MoviePilotTool):
|
||||
subscribe_kwargs['resolution'] = resolution
|
||||
if effect:
|
||||
subscribe_kwargs['effect'] = effect
|
||||
if filter_groups:
|
||||
subscribe_kwargs['filter_groups'] = filter_groups
|
||||
|
||||
sid, message = await subscribe_chain.async_add(
|
||||
mtype=MediaType(media_type),
|
||||
@@ -111,6 +116,8 @@ class AddSubscribeTool(MoviePilotTool):
|
||||
params.append(f"分辨率过滤: {resolution}")
|
||||
if effect:
|
||||
params.append(f"特效过滤: {effect}")
|
||||
if filter_groups:
|
||||
params.append(f"规则组: {', '.join(filter_groups)}")
|
||||
if params:
|
||||
result_msg += f"\n配置参数: {', '.join(params)}"
|
||||
return result_msg
|
||||
|
||||
65
app/agent/tools/impl/query_rule_groups.py
Normal file
65
app/agent/tools/impl/query_rule_groups.py
Normal file
@@ -0,0 +1,65 @@
|
||||
"""查询规则组工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.helper.rule import RuleHelper
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class QueryRuleGroupsInput(BaseModel):
|
||||
"""查询规则组工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
|
||||
|
||||
class QueryRuleGroupsTool(MoviePilotTool):
|
||||
name: str = "query_rule_groups"
|
||||
description: str = "Query all filter rule groups available in the system. Rule groups are used to filter torrents when searching or subscribing. Returns rule group names, media types, and categories, but excludes rule_string to keep results concise."
|
||||
args_schema: Type[BaseModel] = QueryRuleGroupsInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据查询参数生成友好的提示消息"""
|
||||
return "正在查询所有规则组"
|
||||
|
||||
async def run(self, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}")
|
||||
|
||||
try:
|
||||
rule_helper = RuleHelper()
|
||||
rule_groups = rule_helper.get_rule_groups()
|
||||
|
||||
if not rule_groups:
|
||||
return json.dumps({
|
||||
"message": "未找到任何规则组",
|
||||
"rule_groups": []
|
||||
}, ensure_ascii=False, indent=2)
|
||||
|
||||
# 精简字段,过滤掉 rule_string 避免结果过大
|
||||
simplified_groups = []
|
||||
for group in rule_groups:
|
||||
simplified = {
|
||||
"name": group.name,
|
||||
"media_type": group.media_type,
|
||||
"category": group.category
|
||||
}
|
||||
simplified_groups.append(simplified)
|
||||
|
||||
result = {
|
||||
"message": f"找到 {len(simplified_groups)} 个规则组",
|
||||
"rule_groups": simplified_groups
|
||||
}
|
||||
|
||||
return json.dumps(result, ensure_ascii=False, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
error_message = f"查询规则组失败: {str(e)}"
|
||||
logger.error(f"查询规则组失败: {e}", exc_info=True)
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": error_message,
|
||||
"rule_groups": []
|
||||
}, ensure_ascii=False)
|
||||
|
||||
128
app/agent/tools/impl/search_subscribe.py
Normal file
128
app/agent/tools/impl/search_subscribe.py
Normal file
@@ -0,0 +1,128 @@
|
||||
"""搜索订阅缺失剧集工具"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Optional, Type, List
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.subscribe import SubscribeChain
|
||||
from app.db.subscribe_oper import SubscribeOper
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class SearchSubscribeInput(BaseModel):
|
||||
"""搜索订阅缺失剧集工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
subscribe_id: int = Field(..., description="The ID of the subscription to search for missing episodes")
|
||||
manual: Optional[bool] = Field(False, description="Whether this is a manual search (default: False)")
|
||||
filter_groups: Optional[List[str]] = Field(None,
|
||||
description="List of filter rule group names to apply for this search (optional, use query_rule_groups tool to get available rule groups. If provided, will temporarily update the subscription's filter groups before searching)")
|
||||
|
||||
|
||||
class SearchSubscribeTool(MoviePilotTool):
|
||||
name: str = "search_subscribe"
|
||||
description: str = "Search for missing episodes/resources for a specific subscription. This tool will search torrent sites for the missing episodes of the subscription and automatically download matching resources. Use this when a user wants to search for missing episodes of a specific subscription."
|
||||
args_schema: Type[BaseModel] = SearchSubscribeInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据搜索参数生成友好的提示消息"""
|
||||
subscribe_id = kwargs.get("subscribe_id")
|
||||
manual = kwargs.get("manual", False)
|
||||
|
||||
message = f"正在搜索订阅 #{subscribe_id} 的缺失剧集"
|
||||
if manual:
|
||||
message += "(手动搜索)"
|
||||
|
||||
return message
|
||||
|
||||
async def run(self, subscribe_id: int, manual: Optional[bool] = False,
|
||||
filter_groups: Optional[List[str]] = None, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: subscribe_id={subscribe_id}, manual={manual}, filter_groups={filter_groups}")
|
||||
|
||||
try:
|
||||
# 先验证订阅是否存在
|
||||
subscribe_oper = SubscribeOper()
|
||||
subscribe = subscribe_oper.get(subscribe_id)
|
||||
|
||||
if not subscribe:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": f"订阅不存在: {subscribe_id}"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
# 获取订阅信息用于返回
|
||||
subscribe_info = {
|
||||
"id": subscribe.id,
|
||||
"name": subscribe.name,
|
||||
"year": subscribe.year,
|
||||
"type": subscribe.type,
|
||||
"season": subscribe.season,
|
||||
"state": subscribe.state,
|
||||
"total_episode": subscribe.total_episode,
|
||||
"lack_episode": subscribe.lack_episode,
|
||||
"tmdbid": subscribe.tmdbid,
|
||||
"doubanid": subscribe.doubanid
|
||||
}
|
||||
|
||||
# 检查订阅状态
|
||||
if subscribe.state == "S":
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": f"订阅 #{subscribe_id} ({subscribe.name}) 已暂停,无法搜索",
|
||||
"subscribe": subscribe_info
|
||||
}, ensure_ascii=False)
|
||||
|
||||
# 如果提供了 filter_groups 参数,先更新订阅的规则组
|
||||
if filter_groups is not None:
|
||||
subscribe_oper.update(subscribe_id, {"filter_groups": filter_groups})
|
||||
logger.info(f"更新订阅 #{subscribe_id} 的规则组为: {filter_groups}")
|
||||
|
||||
# 调用 SubscribeChain 的 search 方法
|
||||
# search 方法是同步的,需要在异步环境中运行
|
||||
subscribe_chain = SubscribeChain()
|
||||
|
||||
# 在线程池中执行同步的搜索操作
|
||||
# 当 sid 有值时,state 参数会被忽略,直接处理该订阅
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.run_in_executor(
|
||||
None,
|
||||
lambda: subscribe_chain.search(
|
||||
sid=subscribe_id,
|
||||
state='R', # 默认状态,当 sid 有值时此参数会被忽略
|
||||
manual=manual
|
||||
)
|
||||
)
|
||||
|
||||
# 重新获取订阅信息以获取更新后的状态
|
||||
updated_subscribe = subscribe_oper.get(subscribe_id)
|
||||
if updated_subscribe:
|
||||
subscribe_info.update({
|
||||
"state": updated_subscribe.state,
|
||||
"lack_episode": updated_subscribe.lack_episode,
|
||||
"last_update": updated_subscribe.last_update,
|
||||
"filter_groups": updated_subscribe.filter_groups
|
||||
})
|
||||
|
||||
# 如果提供了规则组,会在返回信息中显示
|
||||
result = {
|
||||
"success": True,
|
||||
"message": f"订阅 #{subscribe_id} ({subscribe.name}) 搜索完成",
|
||||
"subscribe": subscribe_info
|
||||
}
|
||||
|
||||
if filter_groups is not None:
|
||||
result["message"] += f"(已应用规则组: {', '.join(filter_groups)})"
|
||||
|
||||
return json.dumps(result, ensure_ascii=False, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
error_message = f"搜索订阅缺失剧集失败: {str(e)}"
|
||||
logger.error(f"搜索订阅缺失剧集失败: {e}", exc_info=True)
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": error_message,
|
||||
"subscribe_id": subscribe_id
|
||||
}, ensure_ascii=False)
|
||||
|
||||
Reference in New Issue
Block a user