feat: enhance GetRecommendationsTool and update query tools for improved functionality

- Expanded the GetRecommendationsTool to support additional recommendation sources, including TMDB popular movies and TV shows, as well as various Douban categories.
- Updated the limit for results in QuerySubscribesTool, SearchMediaTool, and QueryTransferHistoryTool from 20 to 50 or 30, respectively, to provide more comprehensive results.
- Removed unnecessary description fields from media objects in QueryPopularSubscribesTool, QuerySubscribeHistoryTool, and QuerySubscribeSharesTool for cleaner output.
This commit is contained in:
jxxghp
2025-11-18 16:21:13 +08:00
parent a8c6516b31
commit 21fabf7436
7 changed files with 75 additions and 18 deletions

View File

@@ -14,7 +14,21 @@ class GetRecommendationsInput(BaseModel):
"""获取推荐工具的输入参数模型""" """获取推荐工具的输入参数模型"""
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context") explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
source: Optional[str] = Field("tmdb_trending", source: Optional[str] = Field("tmdb_trending",
description="Recommendation source: 'tmdb_trending' for TMDB trending content, 'douban_hot' for Douban popular content, 'bangumi_calendar' for Bangumi anime calendar") description="Recommendation source: "
"'tmdb_trending' for TMDB trending content, "
"'tmdb_movies' for TMDB popular movies, "
"'tmdb_tvs' for TMDB popular TV shows, "
"'douban_hot' for Douban popular content, "
"'douban_movie_hot' for Douban hot movies, "
"'douban_tv_hot' for Douban hot TV shows, "
"'douban_movie_showing' for Douban movies currently showing, "
"'douban_movies' for Douban latest movies, "
"'douban_tvs' for Douban latest TV shows, "
"'douban_movie_top250' for Douban movie TOP250, "
"'douban_tv_weekly_chinese' for Douban Chinese TV weekly chart, "
"'douban_tv_weekly_global' for Douban global TV weekly chart, "
"'douban_tv_animation' for Douban popular animation, "
"'bangumi_calendar' for Bangumi anime calendar")
media_type: Optional[str] = Field("all", media_type: Optional[str] = Field("all",
description="Type of media content: '电影' for films, '电视剧' for television series or anime series, 'all' for all types") description="Type of media content: '电影' for films, '电视剧' for television series or anime series, 'all' for all types")
limit: Optional[int] = Field(20, limit: Optional[int] = Field(20,
@@ -33,8 +47,19 @@ class GetRecommendationsTool(MoviePilotTool):
limit = kwargs.get("limit", 20) limit = kwargs.get("limit", 20)
source_map = { source_map = {
"tmdb_trending": "TMDB热门", "tmdb_trending": "TMDB流行趋势",
"tmdb_movies": "TMDB热门电影",
"tmdb_tvs": "TMDB热门电视剧",
"douban_hot": "豆瓣热门", "douban_hot": "豆瓣热门",
"douban_movie_hot": "豆瓣热门电影",
"douban_tv_hot": "豆瓣热门电视剧",
"douban_movie_showing": "豆瓣正在热映",
"douban_movies": "豆瓣最新电影",
"douban_tvs": "豆瓣最新电视剧",
"douban_movie_top250": "豆瓣电影TOP250",
"douban_tv_weekly_chinese": "豆瓣国产剧集榜",
"douban_tv_weekly_global": "豆瓣全球剧集榜",
"douban_tv_animation": "豆瓣热门动漫",
"bangumi_calendar": "番组计划" "bangumi_calendar": "番组计划"
} }
source_desc = source_map.get(source, source) source_desc = source_map.get(source, source)
@@ -58,8 +83,17 @@ class GetRecommendationsTool(MoviePilotTool):
results = await recommend_chain.async_tmdb_trending(page=1) results = await recommend_chain.async_tmdb_trending(page=1)
if limit and limit > 0: if limit and limit > 0:
results = results[:limit] results = results[:limit]
elif source == "tmdb_movies":
# async_tmdb_movies 接受 page 参数,返回固定数量的结果
results = await recommend_chain.async_tmdb_movies(page=1)
if limit and limit > 0:
results = results[:limit]
elif source == "tmdb_tvs":
# async_tmdb_tvs 接受 page 参数,返回固定数量的结果
results = await recommend_chain.async_tmdb_tvs(page=1)
if limit and limit > 0:
results = results[:limit]
elif source == "douban_hot": elif source == "douban_hot":
# async_douban_movie_hot 和 async_douban_tv_hot 接受 page 和 count 参数
if media_type == "movie": if media_type == "movie":
results = await recommend_chain.async_douban_movie_hot(page=1, count=limit) results = await recommend_chain.async_douban_movie_hot(page=1, count=limit)
elif media_type == "tv": elif media_type == "tv":
@@ -67,9 +101,37 @@ class GetRecommendationsTool(MoviePilotTool):
else: # all else: # all
results.extend(await recommend_chain.async_douban_movie_hot(page=1, count=limit)) results.extend(await recommend_chain.async_douban_movie_hot(page=1, count=limit))
results.extend(await recommend_chain.async_douban_tv_hot(page=1, count=limit)) results.extend(await recommend_chain.async_douban_tv_hot(page=1, count=limit))
elif source == "douban_movie_hot":
results = await recommend_chain.async_douban_movie_hot(page=1, count=limit)
elif source == "douban_tv_hot":
results = await recommend_chain.async_douban_tv_hot(page=1, count=limit)
elif source == "douban_movie_showing":
results = await recommend_chain.async_douban_movie_showing(page=1, count=limit)
elif source == "douban_movies":
results = await recommend_chain.async_douban_movies(page=1, count=limit)
elif source == "douban_tvs":
results = await recommend_chain.async_douban_tvs(page=1, count=limit)
elif source == "douban_movie_top250":
results = await recommend_chain.async_douban_movie_top250(page=1, count=limit)
elif source == "douban_tv_weekly_chinese":
results = await recommend_chain.async_douban_tv_weekly_chinese(page=1, count=limit)
elif source == "douban_tv_weekly_global":
results = await recommend_chain.async_douban_tv_weekly_global(page=1, count=limit)
elif source == "douban_tv_animation":
results = await recommend_chain.async_douban_tv_animation(page=1, count=limit)
elif source == "bangumi_calendar": elif source == "bangumi_calendar":
# async_bangumi_calendar 接受 page 和 count 参数
results = await recommend_chain.async_bangumi_calendar(page=1, count=limit) results = await recommend_chain.async_bangumi_calendar(page=1, count=limit)
else:
# 不支持的推荐来源
supported_sources = [
"tmdb_trending", "tmdb_movies", "tmdb_tvs",
"douban_hot", "douban_movie_hot", "douban_tv_hot",
"douban_movie_showing", "douban_movies", "douban_tvs",
"douban_movie_top250", "douban_tv_weekly_chinese",
"douban_tv_weekly_global", "douban_tv_animation",
"bangumi_calendar"
]
return f"不支持的推荐来源: {source}。支持的来源包括: {', '.join(supported_sources)}"
if results: if results:
# 限制最多20条结果 # 限制最多20条结果

View File

@@ -110,7 +110,6 @@ class QueryPopularSubscribesTool(MoviePilotTool):
media.tvdb_id = sub.get("tvdbid") media.tvdb_id = sub.get("tvdbid")
media.imdb_id = sub.get("imdbid") media.imdb_id = sub.get("imdbid")
media.season = sub.get("season") media.season = sub.get("season")
media.overview = sub.get("description")
media.vote_average = sub.get("vote") media.vote_average = sub.get("vote")
media.poster_path = sub.get("poster") media.poster_path = sub.get("poster")
media.backdrop_path = sub.get("backdrop") media.backdrop_path = sub.get("backdrop")
@@ -134,7 +133,6 @@ class QueryPopularSubscribesTool(MoviePilotTool):
"tvdb_id": media_dict.get("tvdb_id"), "tvdb_id": media_dict.get("tvdb_id"),
"imdb_id": media_dict.get("imdb_id"), "imdb_id": media_dict.get("imdb_id"),
"season": media_dict.get("season"), "season": media_dict.get("season"),
"overview": media_dict.get("overview"),
"vote_average": media_dict.get("vote_average"), "vote_average": media_dict.get("vote_average"),
"poster_path": media_dict.get("poster_path"), "poster_path": media_dict.get("poster_path"),
"backdrop_path": media_dict.get("backdrop_path"), "backdrop_path": media_dict.get("backdrop_path"),

View File

@@ -87,7 +87,6 @@ class QuerySubscribeHistoryTool(MoviePilotTool):
"bangumiid": record.bangumiid, "bangumiid": record.bangumiid,
"poster": record.poster, "poster": record.poster,
"vote": record.vote, "vote": record.vote,
"description": record.description[:200] + "..." if record.description and len(record.description) > 200 else record.description,
"total_episode": record.total_episode, "total_episode": record.total_episode,
"date": record.date, "date": record.date,
"username": record.username "username": record.username

View File

@@ -92,7 +92,6 @@ class QuerySubscribeSharesTool(MoviePilotTool):
"bangumiid": share.get("bangumiid"), "bangumiid": share.get("bangumiid"),
"poster": share.get("poster"), "poster": share.get("poster"),
"vote": share.get("vote"), "vote": share.get("vote"),
"description": share.get("description"),
"share_title": share.get("share_title"), "share_title": share.get("share_title"),
"share_comment": share.get("share_comment"), "share_comment": share.get("share_comment"),
"share_user": share.get("share_user"), "share_user": share.get("share_user"),

View File

@@ -55,9 +55,9 @@ class QuerySubscribesTool(MoviePilotTool):
continue continue
filtered_subscribes.append(sub) filtered_subscribes.append(sub)
if filtered_subscribes: if filtered_subscribes:
# 限制最多20条结果 # 限制最多50条结果
total_count = len(filtered_subscribes) total_count = len(filtered_subscribes)
limited_subscribes = filtered_subscribes[:20] limited_subscribes = filtered_subscribes[:50]
# 精简字段,只保留关键信息 # 精简字段,只保留关键信息
simplified_subscribes = [] simplified_subscribes = []
for s in limited_subscribes: for s in limited_subscribes:
@@ -72,7 +72,6 @@ class QuerySubscribesTool(MoviePilotTool):
"bangumiid": s.bangumiid, "bangumiid": s.bangumiid,
"poster": s.poster, "poster": s.poster,
"vote": s.vote, "vote": s.vote,
"description": s.description[:200] + "..." if s.description and len(s.description) > 200 else s.description,
"state": s.state, "state": s.state,
"total_episode": s.total_episode, "total_episode": s.total_episode,
"lack_episode": s.lack_episode, "lack_episode": s.lack_episode,
@@ -82,8 +81,8 @@ class QuerySubscribesTool(MoviePilotTool):
simplified_subscribes.append(simplified) simplified_subscribes.append(simplified)
result_json = json.dumps(simplified_subscribes, ensure_ascii=False, indent=2) result_json = json.dumps(simplified_subscribes, ensure_ascii=False, indent=2)
# 如果结果被裁剪,添加提示信息 # 如果结果被裁剪,添加提示信息
if total_count > 20: if total_count > 50:
return f"注意:查询结果共找到 {total_count} 条,为节省上下文空间,仅显示前 20 条结果。\n\n{result_json}" return f"注意:查询结果共找到 {total_count} 条,为节省上下文空间,仅显示前 50 条结果。\n\n{result_json}"
return result_json return result_json
return "未找到相关订阅" return "未找到相关订阅"
except Exception as e: except Exception as e:

View File

@@ -62,7 +62,7 @@ class QueryTransferHistoryTool(MoviePilotTool):
page = 1 page = 1
# 每页记录数 # 每页记录数
count = 30 count = 50
# 获取数据库会话 # 获取数据库会话
async with AsyncSessionFactory() as db: async with AsyncSessionFactory() as db:

View File

@@ -77,9 +77,9 @@ class SearchMediaTool(MoviePilotTool):
filtered_results.append(result) filtered_results.append(result)
if filtered_results: if filtered_results:
# 限制最多20条结果 # 限制最多30条结果
total_count = len(filtered_results) total_count = len(filtered_results)
limited_results = filtered_results[:20] limited_results = filtered_results[:30]
# 精简字段,只保留关键信息 # 精简字段,只保留关键信息
simplified_results = [] simplified_results = []
for r in limited_results: for r in limited_results:
@@ -100,8 +100,8 @@ class SearchMediaTool(MoviePilotTool):
simplified_results.append(simplified) simplified_results.append(simplified)
result_json = json.dumps(simplified_results, ensure_ascii=False, indent=2) result_json = json.dumps(simplified_results, ensure_ascii=False, indent=2)
# 如果结果被裁剪,添加提示信息 # 如果结果被裁剪,添加提示信息
if total_count > 20: if total_count > 30:
return f"注意:搜索结果共找到 {total_count} 条,为节省上下文空间,仅显示前 20 条结果。\n\n{result_json}" return f"注意:搜索结果共找到 {total_count} 条,为节省上下文空间,仅显示前 30 条结果。\n\n{result_json}"
return result_json return result_json
else: else:
return f"未找到符合条件的媒体资源: {title}" return f"未找到符合条件的媒体资源: {title}"