mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-05-08 22:33:16 +08:00
188 lines
7.5 KiB
Python
188 lines
7.5 KiB
Python
"""查询订阅历史工具"""
|
|
|
|
import json
|
|
from typing import Optional, Type
|
|
|
|
from pydantic import BaseModel, Field
|
|
|
|
from app.agent.tools.base import MoviePilotTool
|
|
from app.db import AsyncSessionFactory
|
|
from app.db.models.subscribehistory import SubscribeHistory
|
|
from app.log import logger
|
|
from app.schemas.types import media_type_to_agent
|
|
|
|
# Number of history records per page; used to compute the fetch size and the
# slice bounds of the paginated (no name filter) query.
PAGE_SIZE = 20
|
|
|
|
|
|
class QuerySubscribeHistoryInput(BaseModel):
    """查询订阅历史工具的输入参数模型"""

    # Input argument schema for the subscribe-history query tool.
    # NOTE(review): the docstring above is kept verbatim on purpose. Pydantic
    # can surface a model's docstring as its JSON-schema description, so it may
    # be runtime-visible text rather than a plain comment - confirm before
    # translating it.

    # Free-text justification supplied by the caller; required.
    explanation: str = Field(
        default=...,
        description="Clear explanation of why this tool is being used in the current context",
    )

    # Media type selector; defaults to "all".
    media_type: Optional[str] = Field(
        default="all",
        description="Allowed values: movie, tv, all",
    )

    # Optional name filter; defaults to None (no filter).
    name: Optional[str] = Field(
        default=None,
        description="Filter by media name (partial match, optional)",
    )

    # Page number; defaults to 1.
    page: Optional[int] = Field(
        default=1,
        description="Page number for pagination (default: 1, 20 items per page). Ignored when name filter is provided.",
    )
|
|
|
|
|
|
class QuerySubscribeHistoryTool(MoviePilotTool):
    """Agent tool that lists subscription history records.

    Two query modes are supported:

    * Without ``name``: results are paginated, ``PAGE_SIZE`` records per page.
    * With ``name``: a bounded batch of records is loaded and filtered in
      memory by case-insensitive substring match; ``page`` is ignored.
    """

    name: str = "query_subscribe_history"
    description: str = "Query subscription history records. Shows completed subscriptions with their details including name, type, rating, completion date, and other subscription information. Supports filtering by media type and name. Supports pagination with 20 records per page."
    args_schema: Type[BaseModel] = QuerySubscribeHistoryInput

    def get_tool_message(self, **kwargs) -> Optional[str]:
        """Build a short progress message describing the query being run.

        Args:
            **kwargs: The tool call arguments; ``media_type``, ``name`` and
                ``page`` are read, everything else is ignored.

        Returns:
            The message parts joined with ``" | "``.
        """
        # An explicit None means "not provided"; fall back to the defaults so
        # the message never shows a literal "None".
        media_type = kwargs.get("media_type")
        if media_type is None:
            media_type = "all"
        name = kwargs.get("name")
        page = kwargs.get("page") or 1

        parts = ["正在查询订阅历史"]

        if media_type != "all":
            parts.append(f"类型: {media_type}")
        if name:
            parts.append(f"名称: {name}")
        else:
            # The page is only meaningful when no name filter is applied.
            parts.append(f"第{page}页")

        return " | ".join(parts)

    async def run(
        self,
        media_type: Optional[str] = "all",
        name: Optional[str] = None,
        page: Optional[int] = 1,
        **kwargs,
    ) -> str:
        """Query subscription history and return the result as text.

        Args:
            media_type: ``"movie"``, ``"tv"`` or ``"all"``. ``None`` is treated
                as ``"all"``.
            name: Optional case-insensitive substring to match against record
                names. When given, pagination is skipped.
            page: 1-based page number; ``None`` and values below 1 become 1.
            **kwargs: Ignored extra arguments.

        Returns:
            A JSON document (optionally preceded by a paging summary line), or
            a human-readable message when the arguments are invalid, nothing
            was found, or the query failed.
        """
        if media_type is None:
            media_type = "all"
        page = max(1, page or 1)
        logger.info(
            f"执行工具: {self.name}, 参数: media_type={media_type}, name={name}, page={page}"
        )

        try:
            if media_type not in ("all", "movie", "tv"):
                return f"错误:无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv', 'all'"

            async with AsyncSessionFactory() as db:
                if name:
                    return await self._search_by_name(db, media_type, name)
                return await self._list_page(db, media_type, page)
        except Exception as e:
            # Tool boundary: report the failure as text instead of raising.
            logger.error(f"查询订阅历史失败: {e}", exc_info=True)
            return f"查询订阅历史时发生错误: {str(e)}"

    @staticmethod
    async def _fetch_history(db, media_type: str, count: int) -> list:
        """Load up to ``count`` records per requested media type.

        For ``"all"`` the movie and tv lists are fetched separately (``count``
        each), concatenated and sorted by ``date`` descending, so up to
        ``2 * count`` records may be returned.
        """
        # NOTE(review): results are passed through media_type_to_agent() when
        # serialized, which suggests the stored type may differ from the
        # agent-facing "movie"/"tv" values passed as mtype here - confirm
        # against SubscribeHistory.async_list_by_type.
        if media_type != "all":
            return list(
                await SubscribeHistory.async_list_by_type(
                    db, mtype=media_type, page=1, count=count
                )
            )

        movie_history = await SubscribeHistory.async_list_by_type(
            db, mtype="movie", page=1, count=count
        )
        tv_history = await SubscribeHistory.async_list_by_type(
            db, mtype="tv", page=1, count=count
        )
        merged = list(movie_history) + list(tv_history)
        # Records without a date sort last.
        merged.sort(key=lambda record: record.date or "", reverse=True)
        return merged

    async def _search_by_name(self, db, media_type: str, name: str) -> str:
        """Return every fetched record whose name contains ``name``, as JSON."""
        # Filtering happens in memory, so only a bounded batch is inspected;
        # matches beyond this many records per type are not seen.
        fetch_count = 500
        candidates = await self._fetch_history(db, media_type, fetch_count)

        needle = name.lower()
        matches = [
            record
            for record in candidates
            if record.name and needle in record.name.lower()
        ]
        if not matches:
            return "未找到相关订阅历史记录"

        # Name searches return all matches at once, without pagination.
        return json.dumps(
            self._simplify_records(matches), ensure_ascii=False, indent=2
        )

    async def _list_page(self, db, media_type: str, page: int) -> str:
        """Return one page of history plus a summary line with a paging hint."""
        end = page * PAGE_SIZE
        start = end - PAGE_SIZE

        # Ask for one record beyond the requested page. Fetching exactly
        # ``end`` records made "more data available" undetectable, because
        # the result could never be longer than ``end``.
        # Assumes ``count`` caps the number of returned rows - TODO confirm.
        history = await self._fetch_history(db, media_type, end + 1)
        if not history:
            return "未找到相关订阅历史记录"

        page_records = history[start:end]
        if not page_records:
            return f"第 {page} 页没有数据。"

        simplified_records = self._simplify_records(page_records)
        result_json = json.dumps(simplified_records, ensure_ascii=False, indent=2)

        payload_msg = f"第 {page} 页,当前页 {len(simplified_records)} 条结果。"
        if len(history) > end:
            payload_msg += f" 可能有更多数据,可使用 page={page + 1} 获取下一页。"

        return f"{payload_msg}\n\n{result_json}"

    @staticmethod
    def _simplify_records(records) -> list:
        """Convert history records to plain dicts holding only the key fields.

        ``filter``, ``quality`` and ``resolution`` are included only when they
        have a truthy value.
        """
        simplified_records = []
        for record in records:
            simplified = {
                "id": record.id,
                "name": record.name,
                "year": record.year,
                "type": media_type_to_agent(record.type),
                "season": record.season,
                "tmdbid": record.tmdbid,
                "doubanid": record.doubanid,
                "bangumiid": record.bangumiid,
                "poster": record.poster,
                "vote": record.vote,
                "total_episode": record.total_episode,
                "date": record.date,
                "username": record.username,
            }
            for optional_key in ("filter", "quality", "resolution"):
                value = getattr(record, optional_key)
                if value:
                    simplified[optional_key] = value
            simplified_records.append(simplified)
        return simplified_records
|