fix agent tools

This commit is contained in:
jxxghp
2025-11-17 12:34:20 +08:00
parent 076fae696c
commit c6806ee648
10 changed files with 20 additions and 21 deletions

View File

@@ -37,13 +37,11 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
# 发送运行工具前的消息
agent_message = await self._callback_handler.get_message()
if agent_message:
await self.send_tool_message(agent_message)
await self.send_tool_message(agent_message, title="MoviePilot助手")
# 发送执行工具说明
explanation = kwargs.get("explanation")
if explanation:
if not explanation.startswith("正在"):
explanation = "正在" + explanation
await self.send_tool_message(f"{explanation} ...")
await self.send_tool_message(f"⏯️[{explanation}]")
return await self.run(**kwargs)
@abstractmethod

View File

@@ -58,7 +58,8 @@ class AddDownloadTool(MoviePilotTool):
# 创建下载上下文
torrent_info = TorrentInfo(
title=torrent_title,
download_url=torrent_url,
description=torrent_description,
enclosure=torrent_url,
site_name=site_name,
site_ua=siteinfo.ua,
site_cookie=siteinfo.cookie,
@@ -67,7 +68,7 @@ class AddDownloadTool(MoviePilotTool):
site_downloader=siteinfo.downloader
)
meta_info = MetaInfo(title=torrent_title, subtitle=torrent_description)
media_info = ToolChain().recognize_media(meta=meta_info)
media_info = await ToolChain().async_recognize_media(meta=meta_info)
if not media_info:
return "错误:无法识别媒体信息,无法添加下载任务"
context = Context(

View File

@@ -43,7 +43,7 @@ class AddSubscribeTool(MoviePilotTool):
except (ValueError, TypeError):
logger.warning(f"无效的 tmdb_id: {tmdb_id},将忽略")
sid, message = subscribe_chain.add(
sid, message = await subscribe_chain.async_add(
mtype=MediaType(media_type),
title=title,
year=year,

View File

@@ -38,17 +38,17 @@ class GetRecommendationsTool(MoviePilotTool):
recommend_chain = RecommendChain()
results = []
if source == "tmdb_trending":
results = recommend_chain.tmdb_trending(limit=limit)
results = await recommend_chain.async_tmdb_trending(limit=limit)
elif source == "douban_hot":
if media_type == "movie":
results = recommend_chain.douban_movie_hot(limit=limit)
results = await recommend_chain.async_douban_movie_hot(limit=limit)
elif media_type == "tv":
results = recommend_chain.douban_tv_hot(limit=limit)
results = await recommend_chain.async_douban_tv_hot(limit=limit)
else: # all
results.extend(recommend_chain.douban_movie_hot(limit=limit))
results.extend(recommend_chain.douban_tv_hot(limit=limit))
results.extend(await recommend_chain.async_douban_movie_hot(limit=limit))
results.extend(await recommend_chain.async_douban_tv_hot(limit=limit))
elif source == "bangumi_calendar":
results = recommend_chain.bangumi_calendar(limit=limit)
results = await recommend_chain.async_bangumi_calendar(limit=limit)
if results:
# 限制最多20条结果

View File

@@ -74,7 +74,7 @@ class QueryDownloadsTool(MoviePilotTool):
if total_count > 20:
return f"注意:查询结果共找到 {total_count} 条,为节省上下文空间,仅显示前 20 条结果。\n\n{result_json}"
return result_json
return "未找到相关下载任务"
return "未找到相关下载任务"
except Exception as e:
logger.error(f"查询下载失败: {e}", exc_info=True)
return f"查询下载时发生错误: {str(e)}"

View File

@@ -32,10 +32,10 @@ class QueryMediaLibraryTool(MoviePilotTool):
logger.info(f"执行工具: {self.name}, 参数: media_type={media_type}, title={title}")
try:
media_server_oper = MediaServerOper()
filtered_medias: List[MediaServerItem] = media_server_oper.exists(title=title, year=year, mtype=media_type)
filtered_medias: List[MediaServerItem] = await media_server_oper.async_exists(title=title, year=year, mtype=media_type)
if filtered_medias:
return json.dumps([m.to_dict() for m in filtered_medias])
return "媒体库中未找到相关媒体"
return "媒体库中未找到相关媒体"
except Exception as e:
logger.error(f"查询媒体库失败: {e}", exc_info=True)
return f"查询媒体库时发生错误: {str(e)}"

View File

@@ -29,7 +29,7 @@ class QuerySitesTool(MoviePilotTool):
try:
site_oper = SiteOper()
# 获取所有站点(按优先级排序)
sites = site_oper.list_order_by_pri()
sites = await site_oper.async_list()
filtered_sites = []
for site in sites:
# 按状态过滤
@@ -59,7 +59,7 @@ class QuerySitesTool(MoviePilotTool):
simplified_sites.append(simplified)
result_json = json.dumps(simplified_sites, ensure_ascii=False, indent=2)
return result_json
return "未找到相关站点"
return "未找到相关站点"
except Exception as e:
logger.error(f"查询站点失败: {e}", exc_info=True)
return f"查询站点时发生错误: {str(e)}"

View File

@@ -28,7 +28,7 @@ class QuerySubscribesTool(MoviePilotTool):
logger.info(f"执行工具: {self.name}, 参数: status={status}, media_type={media_type}")
try:
subscribe_oper = SubscribeOper()
subscribes = subscribe_oper.list()
subscribes = await subscribe_oper.async_list()
filtered_subscribes = []
for sub in subscribes:
if status != "all" and sub.state != status:

View File

@@ -44,7 +44,7 @@ class SearchMediaTool(MoviePilotTool):
search_title = f"{search_title} S{season:02d}"
# 使用 MediaChain.search 方法
meta, results = media_chain.search(title=search_title)
meta, results = await media_chain.async_search(title=search_title)
# 过滤结果
if results:

View File

@@ -41,7 +41,7 @@ class SearchTorrentsTool(MoviePilotTool):
try:
search_chain = SearchChain()
torrents = search_chain.search_by_title(title=title, sites=sites)
torrents = await search_chain.async_search_by_title(title=title, sites=sites)
filtered_torrents = []
# 编译正则表达式(如果提供)
regex_pattern = None