Files
MoviePilot/app/agent/tools/impl/get_search_results.py

82 lines
4.1 KiB
Python

"""获取搜索结果工具"""
import json
from typing import List, Optional, Type
from pydantic import BaseModel, Field
from app.agent.tools.base import MoviePilotTool
from app.chain.search import SearchChain
from app.log import logger
from ._torrent_search_utils import (
TORRENT_RESULT_LIMIT,
filter_contexts,
simplify_search_result,
)
class GetSearchResultsInput(BaseModel):
    """Input parameters for the get_search_results tool.

    Every filter is optional and accepts several values at once.
    """
    # Why the agent is calling the tool. GetSearchResultsTool.run() does not take
    # it as a named parameter, so it only arrives via **kwargs there.
    explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
    # Multi-value filters below are forwarded as-is to filter_contexts(); the exact
    # accepted values (labels, spelling) are defined there — not visible in this file.
    site: Optional[List[str]] = Field(None, description="Filter by site name, supports multiple values")
    season: Optional[List[str]] = Field(None, description="Filter by season/episode label, supports multiple values")
    free_state: Optional[List[str]] = Field(None, description="Filter by promotion state, supports multiple values")
    video_code: Optional[List[str]] = Field(None, description="Filter by video codec, supports multiple values")
    edition: Optional[List[str]] = Field(None, description="Filter by edition/quality, supports multiple values")
    resolution: Optional[List[str]] = Field(None, description="Filter by resolution, supports multiple values")
    release_group: Optional[List[str]] = Field(None, description="Filter by release group, supports multiple values")
class GetSearchResultsTool(MoviePilotTool):
    """Agent tool returning (optionally filtered) results of the last torrent search.

    Results are read from ``SearchChain().async_last_search_results()``, narrowed
    with ``filter_contexts`` and serialized as JSON, capped at
    ``TORRENT_RESULT_LIMIT`` entries.
    """

    name: str = "get_search_results"
    # Derived from TORRENT_RESULT_LIMIT so the cap advertised to the agent always
    # matches the cap actually applied in run().
    description: str = (
        "Get torrent search results from the most recent search_torrents call, "
        "with optional frontend-style filters such as site, season, promotion state, "
        "codec, quality, resolution, and release group. "
        f"Returns at most the first {TORRENT_RESULT_LIMIT} matching results."
    )
    args_schema: Type[BaseModel] = GetSearchResultsInput

    def get_tool_message(self, **kwargs) -> Optional[str]:
        """Return the status text displayed while this tool is running."""
        return "正在获取搜索结果"

    async def run(self, site: Optional[List[str]] = None, season: Optional[List[str]] = None,
                  free_state: Optional[List[str]] = None, video_code: Optional[List[str]] = None,
                  edition: Optional[List[str]] = None, resolution: Optional[List[str]] = None,
                  release_group: Optional[List[str]] = None, **kwargs) -> str:
        """Filter the cached results of the last torrent search and serialize them.

        Args:
            site, season, free_state, video_code, edition, resolution, release_group:
                Optional multi-value filters, passed unchanged to ``filter_contexts``.
            **kwargs: Any other tool arguments (e.g. ``explanation``); unused here.

        Returns:
            A JSON string with ``total_count`` (number of filtered matches) and
            ``results`` (at most ``TORRENT_RESULT_LIMIT`` entries, each built by
            ``simplify_search_result`` with the item's 1-based position in the
            unfiltered result list). When results were truncated, a ``message``
            key is added as well. If there are no cached results or no matches,
            a plain hint string is returned instead. If anything raises, an
            error string is returned.
        """
        logger.info(
            f"执行工具: {self.name}, 参数: site={site}, season={season}, free_state={free_state}, "
            f"video_code={video_code}, edition={edition}, resolution={resolution}, release_group={release_group}")
        try:
            items = await SearchChain().async_last_search_results() or []
            if not items:
                return "没有可用的搜索结果,请先使用 search_torrents 搜索"
            filtered_items = filter_contexts(
                items=items,
                site=site,
                season=season,
                free_state=free_state,
                video_code=video_code,
                edition=edition,
                resolution=resolution,
                release_group=release_group,
            )
            if not filtered_items:
                return "没有符合筛选条件的搜索结果,请调整筛选条件"
            total_count = len(filtered_items)

            # Tag every filtered item with ITS OWN 1-based position in the unfiltered
            # list by looking it up by identity. This stays correct even if
            # filter_contexts reorders or repeats entries, unlike pairing two
            # independently sliced lists. `items` is alive for the whole call, so
            # id() values cannot be reused by other objects here.
            index_by_id = {id(item): index for index, item in enumerate(items, start=1)}
            indexed_items = [
                (index_by_id[id(item)], item)
                for item in filtered_items
                if id(item) in index_by_id
            ]
            if len(indexed_items) != total_count:
                # Objects not in the cached list have no trustworthy index; reporting
                # a guessed one could point a follow-up action at the wrong torrent.
                logger.warning(
                    f"{total_count - len(indexed_items)} 条筛选结果不在原始搜索结果中,已忽略")

            results = [
                simplify_search_result(item, index)
                for index, item in indexed_items[:TORRENT_RESULT_LIMIT]
            ]
            payload = {
                "total_count": total_count,
                "results": results,
            }
            if total_count > TORRENT_RESULT_LIMIT:
                payload["message"] = f"搜索结果共找到 {total_count} 条,仅显示前 {TORRENT_RESULT_LIMIT} 条结果。"
            return json.dumps(payload, ensure_ascii=False, indent=2)
        except Exception as e:
            # Tool boundary: turn any failure into text for the agent instead of raising.
            error_message = f"获取搜索结果失败: {str(e)}"
            logger.error(f"获取搜索结果失败: {e}", exc_info=True)
            return error_message