fix agent

This commit is contained in:
jxxghp
2025-11-01 11:41:22 +08:00
parent 182c46037b
commit 247208b8a9
10 changed files with 114 additions and 121 deletions

View File

@@ -249,7 +249,7 @@ class MoviePilotAgent:
agent_message = await self.callback_handler.get_message() agent_message = await self.callback_handler.get_message()
# 发送Agent回复给用户通过原渠道 # 发送Agent回复给用户通过原渠道
self._send_message_to_channel(agent_message) self.send_agent_message(agent_message)
# 添加Agent回复到记忆 # 添加Agent回复到记忆
await self.memory_manager.add_memory( await self.memory_manager.add_memory(
@@ -265,22 +265,9 @@ class MoviePilotAgent:
error_message = f"处理消息时发生错误: {str(e)}" error_message = f"处理消息时发生错误: {str(e)}"
logger.error(error_message) logger.error(error_message)
# 发送错误消息给用户(通过原渠道) # 发送错误消息给用户(通过原渠道)
self._send_message_to_channel(error_message) self.send_agent_message(error_message)
return error_message return error_message
def _send_message_to_channel(self, message: str, title: str = "MoviePilot助手"):
"""通过原渠道发送消息给用户"""
AgentChain().post_message(
Notification(
channel=self.channel,
source=self.source,
userid=self.user_id,
username=self.username,
title=title,
text=message
)
)
async def _execute_agent(self, input_context: Dict[str, Any]) -> Dict[str, Any]: async def _execute_agent(self, input_context: Dict[str, Any]) -> Dict[str, Any]:
"""执行LangChain Agent""" """执行LangChain Agent"""
try: try:
@@ -314,6 +301,19 @@ class MoviePilotAgent:
"token_usage": {} "token_usage": {}
} }
def send_agent_message(self, message: str, title: str = "MoviePilot助手"):
"""通过原渠道发送消息给用户"""
AgentChain().post_message(
Notification(
channel=self.channel,
source=self.source,
userid=self.user_id,
username=self.username,
title=title,
text=message
)
)
async def cleanup(self): async def cleanup(self):
"""清理智能体资源""" """清理智能体资源"""
if self.session_id in self.session_store: if self.session_id in self.session_store:

View File

@@ -257,6 +257,10 @@
"title": { "title": {
"description": "Specific media title to search for (optional, if provided returns detailed info for that specific media)", "description": "Specific media title to search for (optional, if provided returns detailed info for that specific media)",
"type": "string" "type": "string"
},
"year": {
"description": "Release year of the media (optional, helps narrow down search results)",
"type": "string"
} }
}, },
"required": [ "required": [

View File

@@ -4,8 +4,6 @@ from langchain.tools import BaseTool
from pydantic import PrivateAttr from pydantic import PrivateAttr
from app.chain import ChainBase from app.chain import ChainBase
from app.helper.message import MessageHelper
from app.log import logger
from app.schemas import Notification from app.schemas import Notification
@@ -18,17 +16,14 @@ class MoviePilotTool(BaseTool):
_session_id: str = PrivateAttr() _session_id: str = PrivateAttr()
_user_id: str = PrivateAttr() _user_id: str = PrivateAttr()
_message_helper: MessageHelper = PrivateAttr() _channel: str = PrivateAttr(default=None)
_source: str = PrivateAttr(default=None)
_username: str = PrivateAttr(default=None)
def __init__(self, session_id: str, user_id: str, def __init__(self, session_id: str, user_id: str, **kwargs):
channel: str = None, source: str = None, username: str = None, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self._session_id = session_id self._session_id = session_id
self._user_id = user_id self._user_id = user_id
self.channel = channel
self.source = source
self.username = username
self._message_helper = MessageHelper()
def _run(self, **kwargs) -> str: def _run(self, **kwargs) -> str:
raise NotImplementedError raise NotImplementedError
@@ -36,8 +31,14 @@ class MoviePilotTool(BaseTool):
async def _arun(self, **kwargs) -> str: async def _arun(self, **kwargs) -> str:
raise NotImplementedError raise NotImplementedError
def _send_tool_message(self, message: str, title: str = None, **kwargs): def set_message_attr(self, channel: str, source: str, username: str):
"""发送工具执行消息""" """设置消息属性"""
self._channel = channel
self._source = source
self._username = username
def send_tool_message(self, message: str, title: str = "执行工具"):
"""发送工具消息"""
ToolChain().post_message( ToolChain().post_message(
Notification( Notification(
channel=self.channel, channel=self.channel,

View File

@@ -37,12 +37,11 @@ class MoviePilotToolFactory:
SendMessageTool SendMessageTool
] ]
for ToolClass in tool_definitions: for ToolClass in tool_definitions:
tools.append(ToolClass( tool = ToolClass(
session_id=session_id, session_id=session_id,
user_id=user_id, user_id=user_id
channel=channel, )
source=source, tool.set_message_attr(channel=channel, source=source, username=username)
username=username tools.append(tool)
))
logger.info(f"成功创建 {len(tools)} 个MoviePilot工具") logger.info(f"成功创建 {len(tools)} 个MoviePilot工具")
return tools return tools

View File

@@ -2,12 +2,12 @@
from typing import Optional from typing import Optional
from app.agent.tools.base import MoviePilotTool
from app.chain.download import DownloadChain from app.chain.download import DownloadChain
from app.core.context import Context from app.core.context import Context
from app.core.metainfo import MetaInfo from app.core.metainfo import MetaInfo
from app.log import logger from app.log import logger
from app.schemas import TorrentInfo from app.schemas import TorrentInfo
from app.agent.tools.base import MoviePilotTool
class AddDownloadTool(MoviePilotTool): class AddDownloadTool(MoviePilotTool):
@@ -17,15 +17,16 @@ class AddDownloadTool(MoviePilotTool):
async def _arun(self, torrent_title: str, torrent_url: str, explanation: str, async def _arun(self, torrent_title: str, torrent_url: str, explanation: str,
downloader: Optional[str] = None, save_path: Optional[str] = None, downloader: Optional[str] = None, save_path: Optional[str] = None,
labels: Optional[str] = None, **kwargs) -> str: labels: Optional[str] = None, **kwargs) -> str:
logger.info(f"执行工具: {self.name}, 参数: torrent_title={torrent_title}, torrent_url={torrent_url}, downloader={downloader}, save_path={save_path}, labels={labels}") logger.info(
f"执行工具: {self.name}, 参数: torrent_title={torrent_title}, torrent_url={torrent_url}, downloader={downloader}, save_path={save_path}, labels={labels}")
# 发送工具执行说明 # 发送工具执行说明
self._send_tool_message(f"正在添加下载任务: {torrent_title}", title="添加下载") self.send_tool_message(f"正在添加下载任务: {torrent_title}", title="添加下载")
try: try:
if not torrent_title or not torrent_url: if not torrent_title or not torrent_url:
error_message = "错误:必须提供种子标题和下载链接" error_message = "错误:必须提供种子标题和下载链接"
self._send_tool_message(error_message, title="下载失败") self.send_tool_message(error_message, title="下载失败")
return error_message return error_message
# 使用DownloadChain添加下载 # 使用DownloadChain添加下载
@@ -50,14 +51,14 @@ class AddDownloadTool(MoviePilotTool):
) )
if did: if did:
success_message = f"成功添加下载任务:{torrent_title}" success_message = f"成功添加下载任务:{torrent_title}"
self._send_tool_message(success_message, title="下载成功") self.send_tool_message(success_message, title="下载成功")
return success_message return success_message
else: else:
error_message = "添加下载任务失败" error_message = "添加下载任务失败"
self._send_tool_message(error_message, title="下载失败") self.send_tool_message(error_message, title="下载失败")
return error_message return error_message
except Exception as e: except Exception as e:
error_message = f"添加下载任务时发生错误: {str(e)}" error_message = f"添加下载任务时发生错误: {str(e)}"
logger.error(f"添加下载任务失败: {e}", exc_info=True) logger.error(f"添加下载任务失败: {e}", exc_info=True)
self._send_tool_message(error_message, title="下载失败") self.send_tool_message(error_message, title="下载失败")
return error_message return error_message

View File

@@ -2,10 +2,10 @@
from typing import Optional from typing import Optional
from app.agent.tools.base import MoviePilotTool
from app.chain.subscribe import SubscribeChain from app.chain.subscribe import SubscribeChain
from app.log import logger from app.log import logger
from app.schemas.types import MediaType from app.schemas.types import MediaType
from app.agent.tools.base import MoviePilotTool
class AddSubscribeTool(MoviePilotTool): class AddSubscribeTool(MoviePilotTool):
@@ -14,10 +14,11 @@ class AddSubscribeTool(MoviePilotTool):
async def _arun(self, title: str, year: str, media_type: str, explanation: str, async def _arun(self, title: str, year: str, media_type: str, explanation: str,
season: Optional[int] = None, tmdb_id: Optional[str] = None, **kwargs) -> str: season: Optional[int] = None, tmdb_id: Optional[str] = None, **kwargs) -> str:
logger.info(f"执行工具: {self.name}, 参数: title={title}, year={year}, media_type={media_type}, season={season}, tmdb_id={tmdb_id}") logger.info(
f"执行工具: {self.name}, 参数: title={title}, year={year}, media_type={media_type}, season={season}, tmdb_id={tmdb_id}")
# 发送工具执行说明 # 发送工具执行说明
self._send_tool_message(f"正在添加订阅: {title} ({year}) - {media_type}", title="添加订阅") self.send_tool_message(f"正在添加订阅: {title} ({year}) - {media_type}", title="添加订阅")
try: try:
subscribe_chain = SubscribeChain() subscribe_chain = SubscribeChain()
@@ -39,14 +40,14 @@ class AddSubscribeTool(MoviePilotTool):
) )
if sid: if sid:
success_message = f"成功添加订阅:{title} ({year})" success_message = f"成功添加订阅:{title} ({year})"
self._send_tool_message(success_message, title="订阅成功") self.send_tool_message(success_message, title="订阅成功")
return success_message return success_message
else: else:
error_message = f"添加订阅失败:{message}" error_message = f"添加订阅失败:{message}"
self._send_tool_message(error_message, title="订阅失败") self.send_tool_message(error_message, title="订阅失败")
return error_message return error_message
except Exception as e: except Exception as e:
error_message = f"添加订阅时发生错误: {str(e)}" error_message = f"添加订阅时发生错误: {str(e)}"
logger.error(f"添加订阅失败: {e}", exc_info=True) logger.error(f"添加订阅失败: {e}", exc_info=True)
self._send_tool_message(error_message, title="订阅失败") self.send_tool_message(error_message, title="订阅失败")
return error_message return error_message

View File

@@ -1,11 +1,12 @@
"""查询媒体库工具""" """查询媒体库工具"""
import json import json
from typing import Optional from typing import Optional, List
from app.agent.tools.base import MoviePilotTool
from app.db.mediaserver_oper import MediaServerOper from app.db.mediaserver_oper import MediaServerOper
from app.log import logger from app.log import logger
from app.agent.tools.base import MoviePilotTool from app.schemas import MediaServerItem
class QueryMediaLibraryTool(MoviePilotTool): class QueryMediaLibraryTool(MoviePilotTool):
@@ -13,18 +14,11 @@ class QueryMediaLibraryTool(MoviePilotTool):
description: str = "查询媒体库状态,查看已入库的媒体文件情况。" description: str = "查询媒体库状态,查看已入库的媒体文件情况。"
async def _arun(self, explanation: str, media_type: Optional[str] = "all", async def _arun(self, explanation: str, media_type: Optional[str] = "all",
title: Optional[str] = None, **kwargs) -> str: title: Optional[str] = None, year: Optional[str] = None, **kwargs) -> str:
logger.info(f"执行工具: {self.name}, 参数: media_type={media_type}, title={title}") logger.info(f"执行工具: {self.name}, 参数: media_type={media_type}, title={title}")
try: try:
media_server_oper = MediaServerOper() media_server_oper = MediaServerOper()
medias = media_server_oper.list() filtered_medias: List[MediaServerItem] = media_server_oper.exists(title=title, year=year, mtype=media_type)
filtered_medias = []
for media in medias:
if media_type != "all" and media.type != media_type:
continue
if title and title.lower() not in media.title.lower():
continue
filtered_medias.append(media)
if filtered_medias: if filtered_medias:
return json.dumps([m.to_dict() for m in filtered_medias]) return json.dumps([m.to_dict() for m in filtered_medias])
return "媒体库中未找到相关媒体。" return "媒体库中未找到相关媒体。"

View File

@@ -3,11 +3,10 @@
import json import json
from typing import Optional from typing import Optional
from app.agent.tools.base import MoviePilotTool
from app.chain.media import MediaChain from app.chain.media import MediaChain
from app.core.metainfo import MetaInfo
from app.log import logger from app.log import logger
from app.schemas.types import MediaType from app.schemas.types import MediaType
from app.agent.tools.base import MoviePilotTool
class SearchMediaTool(MoviePilotTool): class SearchMediaTool(MoviePilotTool):
@@ -16,10 +15,11 @@ class SearchMediaTool(MoviePilotTool):
async def _arun(self, title: str, explanation: str, year: Optional[str] = None, async def _arun(self, title: str, explanation: str, year: Optional[str] = None,
media_type: Optional[str] = None, season: Optional[int] = None, **kwargs) -> str: media_type: Optional[str] = None, season: Optional[int] = None, **kwargs) -> str:
logger.info(f"执行工具: {self.name}, 参数: title={title}, year={year}, media_type={media_type}, season={season}") logger.info(
f"执行工具: {self.name}, 参数: title={title}, year={year}, media_type={media_type}, season={season}")
# 发送工具执行说明 # 发送工具执行说明
self._send_tool_message(f"正在搜索媒体资源: {title}" + (f" ({year})" if year else ""), title="搜索中") self.send_tool_message(f"正在搜索媒体资源: {title}" + (f" ({year})" if year else ""), title="搜索中")
try: try:
media_chain = MediaChain() media_chain = MediaChain()
@@ -42,35 +42,32 @@ class SearchMediaTool(MoviePilotTool):
if year and result.year != year: if year and result.year != year:
continue continue
if media_type: if media_type:
try: if result.type != MediaType(media_type):
if result.type != MediaType(media_type): continue
continue
except:
pass
if season and result.season != season: if season and result.season != season:
continue continue
filtered_results.append(result) filtered_results.append(result)
if filtered_results: if filtered_results:
result_message = f"找到 {len(filtered_results)} 个相关媒体资源" result_message = f"找到 {len(filtered_results)} 个相关媒体资源"
self._send_tool_message(result_message, title="搜索成功") self.send_tool_message(result_message, title="搜索成功")
# 发送详细结果 # 发送详细结果
for i, result in enumerate(filtered_results[:5]): # 只显示前5个结果 for i, result in enumerate(filtered_results[:5]): # 只显示前5个结果
media_info = f"{i+1}. {result.title} ({result.year}) - {result.type.value if result.type else '未知'}" media_info = f"{i + 1}. {result.title} ({result.year}) - {result.type.value if result.type else '未知'}"
self._send_tool_message(media_info, title="搜索结果") self.send_tool_message(media_info, title="搜索结果")
return json.dumps([r.to_dict() for r in filtered_results], ensure_ascii=False, indent=2) return json.dumps([r.to_dict() for r in filtered_results], ensure_ascii=False, indent=2)
else: else:
error_message = f"未找到符合条件的媒体资源: {title}" error_message = f"未找到符合条件的媒体资源: {title}"
self._send_tool_message(error_message, title="搜索完成") self.send_tool_message(error_message, title="搜索完成")
return error_message return error_message
else: else:
error_message = f"未找到相关媒体资源: {title}" error_message = f"未找到相关媒体资源: {title}"
self._send_tool_message(error_message, title="搜索完成") self.send_tool_message(error_message, title="搜索完成")
return error_message return error_message
except Exception as e: except Exception as e:
error_message = f"搜索媒体失败: {str(e)}" error_message = f"搜索媒体失败: {str(e)}"
logger.error(f"搜索媒体失败: {e}", exc_info=True) logger.error(f"搜索媒体失败: {e}", exc_info=True)
self._send_tool_message(error_message, title="搜索失败") self.send_tool_message(error_message, title="搜索失败")
return error_message return error_message

View File

@@ -3,10 +3,10 @@
import json import json
from typing import List, Optional from typing import List, Optional
from app.agent.tools.base import MoviePilotTool
from app.chain.search import SearchChain from app.chain.search import SearchChain
from app.log import logger from app.log import logger
from app.schemas.types import MediaType from app.schemas.types import MediaType
from app.agent.tools.base import MoviePilotTool
class SearchTorrentsTool(MoviePilotTool): class SearchTorrentsTool(MoviePilotTool):
@@ -16,10 +16,11 @@ class SearchTorrentsTool(MoviePilotTool):
async def _arun(self, title: str, explanation: str, year: Optional[str] = None, async def _arun(self, title: str, explanation: str, year: Optional[str] = None,
media_type: Optional[str] = None, season: Optional[int] = None, media_type: Optional[str] = None, season: Optional[int] = None,
sites: Optional[List[int]] = None, **kwargs) -> str: sites: Optional[List[int]] = None, **kwargs) -> str:
logger.info(f"执行工具: {self.name}, 参数: title={title}, year={year}, media_type={media_type}, season={season}, sites={sites}") logger.info(
f"执行工具: {self.name}, 参数: title={title}, year={year}, media_type={media_type}, season={season}, sites={sites}")
# 发送工具执行说明 # 发送工具执行说明
self._send_tool_message(f"正在搜索种子资源: {title}" + (f" ({year})" if year else ""), title="搜索种子") self.send_tool_message(f"正在搜索种子资源: {title}" + (f" ({year})" if year else ""), title="搜索种子")
try: try:
search_chain = SearchChain() search_chain = SearchChain()
@@ -30,33 +31,30 @@ class SearchTorrentsTool(MoviePilotTool):
if year and torrent.meta_info and torrent.meta_info.year != year: if year and torrent.meta_info and torrent.meta_info.year != year:
continue continue
if media_type and torrent.media_info: if media_type and torrent.media_info:
try: if torrent.media_info.type != MediaType(media_type):
if torrent.media_info.type != MediaType(media_type): continue
continue
except:
pass
if season and torrent.meta_info and torrent.meta_info.begin_season != season: if season and torrent.meta_info and torrent.meta_info.begin_season != season:
continue continue
filtered_torrents.append(torrent) filtered_torrents.append(torrent)
if filtered_torrents: if filtered_torrents:
result_message = f"找到 {len(filtered_torrents)} 个相关种子资源" result_message = f"找到 {len(filtered_torrents)} 个相关种子资源"
self._send_tool_message(result_message, title="搜索成功") self.send_tool_message(result_message, title="搜索成功")
# 发送详细结果 # 发送详细结果
for i, torrent in enumerate(filtered_torrents[:5]): # 只显示前5个结果 for i, torrent in enumerate(filtered_torrents[:5]): # 只显示前5个结果
torrent_title = torrent.torrent_info.title if torrent.torrent_info else torrent.meta_info.title if torrent.meta_info else "未知" torrent_title = torrent.torrent_info.title if torrent.torrent_info else torrent.meta_info.title if torrent.meta_info else "未知"
site_name = torrent.torrent_info.site_name if torrent.torrent_info else "未知站点" site_name = torrent.torrent_info.site_name if torrent.torrent_info else "未知站点"
torrent_info = f"{i+1}. {torrent_title} - {site_name}" torrent_info = f"{i + 1}. {torrent_title} - {site_name}"
self._send_tool_message(torrent_info, title="搜索结果") self.send_tool_message(torrent_info, title="搜索结果")
return json.dumps([t.to_dict() for t in filtered_torrents], ensure_ascii=False, indent=2) return json.dumps([t.to_dict() for t in filtered_torrents], ensure_ascii=False, indent=2)
else: else:
error_message = f"未找到相关种子资源: {title}" error_message = f"未找到相关种子资源: {title}"
self._send_tool_message(error_message, title="搜索完成") self.send_tool_message(error_message, title="搜索完成")
return error_message return error_message
except Exception as e: except Exception as e:
error_message = f"搜索种子时发生错误: {str(e)}" error_message = f"搜索种子时发生错误: {str(e)}"
logger.error(f"搜索种子失败: {e}", exc_info=True) logger.error(f"搜索种子失败: {e}", exc_info=True)
self._send_tool_message(error_message, title="搜索失败") self.send_tool_message(error_message, title="搜索失败")
return error_message return error_message

View File

@@ -2,9 +2,8 @@
from typing import Optional from typing import Optional
from app.helper.message import MessageHelper
from app.log import logger
from app.agent.tools.base import MoviePilotTool from app.agent.tools.base import MoviePilotTool
from app.log import logger
class SendMessageTool(MoviePilotTool): class SendMessageTool(MoviePilotTool):
@@ -14,8 +13,7 @@ class SendMessageTool(MoviePilotTool):
async def _arun(self, message: str, explanation: str, message_type: Optional[str] = "info", **kwargs) -> str: async def _arun(self, message: str, explanation: str, message_type: Optional[str] = "info", **kwargs) -> str:
logger.info(f"执行工具: {self.name}, 参数: message={message}, message_type={message_type}") logger.info(f"执行工具: {self.name}, 参数: message={message}, message_type={message_type}")
try: try:
message_helper = MessageHelper() self.send_tool_message(message, title=message_type)
message_helper.put(message=message, role="system", title=f"MoviePilot助手通知 ({message_type})")
return "消息已发送。" return "消息已发送。"
except Exception as e: except Exception as e:
logger.error(f"发送消息失败: {e}") logger.error(f"发送消息失败: {e}")