fix agent

This commit is contained in:
jxxghp
2025-11-01 11:41:22 +08:00
parent 182c46037b
commit 247208b8a9
10 changed files with 114 additions and 121 deletions

View File

@@ -249,7 +249,7 @@ class MoviePilotAgent:
agent_message = await self.callback_handler.get_message()
# 发送Agent回复给用户通过原渠道
self._send_message_to_channel(agent_message)
self.send_agent_message(agent_message)
# 添加Agent回复到记忆
await self.memory_manager.add_memory(
@@ -265,22 +265,9 @@ class MoviePilotAgent:
error_message = f"处理消息时发生错误: {str(e)}"
logger.error(error_message)
# 发送错误消息给用户(通过原渠道)
self._send_message_to_channel(error_message)
self.send_agent_message(error_message)
return error_message
def _send_message_to_channel(self, message: str, title: str = "MoviePilot助手"):
"""通过原渠道发送消息给用户"""
AgentChain().post_message(
Notification(
channel=self.channel,
source=self.source,
userid=self.user_id,
username=self.username,
title=title,
text=message
)
)
async def _execute_agent(self, input_context: Dict[str, Any]) -> Dict[str, Any]:
"""执行LangChain Agent"""
try:
@@ -314,6 +301,19 @@ class MoviePilotAgent:
"token_usage": {}
}
def send_agent_message(self, message: str, title: str = "MoviePilot助手"):
"""通过原渠道发送消息给用户"""
AgentChain().post_message(
Notification(
channel=self.channel,
source=self.source,
userid=self.user_id,
username=self.username,
title=title,
text=message
)
)
async def cleanup(self):
"""清理智能体资源"""
if self.session_id in self.session_store:

View File

@@ -257,6 +257,10 @@
"title": {
"description": "Specific media title to search for (optional, if provided returns detailed info for that specific media)",
"type": "string"
},
"year": {
"description": "Release year of the media (optional, helps narrow down search results)",
"type": "string"
}
},
"required": [

View File

@@ -4,8 +4,6 @@ from langchain.tools import BaseTool
from pydantic import PrivateAttr
from app.chain import ChainBase
from app.helper.message import MessageHelper
from app.log import logger
from app.schemas import Notification
@@ -18,17 +16,14 @@ class MoviePilotTool(BaseTool):
_session_id: str = PrivateAttr()
_user_id: str = PrivateAttr()
_message_helper: MessageHelper = PrivateAttr()
_channel: str = PrivateAttr(default=None)
_source: str = PrivateAttr(default=None)
_username: str = PrivateAttr(default=None)
def __init__(self, session_id: str, user_id: str,
channel: str = None, source: str = None, username: str = None, **kwargs):
def __init__(self, session_id: str, user_id: str, **kwargs):
super().__init__(**kwargs)
self._session_id = session_id
self._user_id = user_id
self.channel = channel
self.source = source
self.username = username
self._message_helper = MessageHelper()
def _run(self, **kwargs) -> str:
raise NotImplementedError
@@ -36,8 +31,14 @@ class MoviePilotTool(BaseTool):
async def _arun(self, **kwargs) -> str:
raise NotImplementedError
def _send_tool_message(self, message: str, title: str = None, **kwargs):
"""发送工具执行消息"""
def set_message_attr(self, channel: str, source: str, username: str):
"""设置消息属性"""
self._channel = channel
self._source = source
self._username = username
def send_tool_message(self, message: str, title: str = "执行工具"):
"""发送工具消息"""
ToolChain().post_message(
Notification(
channel=self.channel,

View File

@@ -37,12 +37,11 @@ class MoviePilotToolFactory:
SendMessageTool
]
for ToolClass in tool_definitions:
tools.append(ToolClass(
tool = ToolClass(
session_id=session_id,
user_id=user_id,
channel=channel,
source=source,
username=username
))
user_id=user_id
)
tool.set_message_attr(channel=channel, source=source, username=username)
tools.append(tool)
logger.info(f"成功创建 {len(tools)} 个MoviePilot工具")
return tools

View File

@@ -2,12 +2,12 @@
from typing import Optional
from app.agent.tools.base import MoviePilotTool
from app.chain.download import DownloadChain
from app.core.context import Context
from app.core.metainfo import MetaInfo
from app.log import logger
from app.schemas import TorrentInfo
from app.agent.tools.base import MoviePilotTool
class AddDownloadTool(MoviePilotTool):
@@ -17,15 +17,16 @@ class AddDownloadTool(MoviePilotTool):
async def _arun(self, torrent_title: str, torrent_url: str, explanation: str,
downloader: Optional[str] = None, save_path: Optional[str] = None,
labels: Optional[str] = None, **kwargs) -> str:
logger.info(f"执行工具: {self.name}, 参数: torrent_title={torrent_title}, torrent_url={torrent_url}, downloader={downloader}, save_path={save_path}, labels={labels}")
logger.info(
f"执行工具: {self.name}, 参数: torrent_title={torrent_title}, torrent_url={torrent_url}, downloader={downloader}, save_path={save_path}, labels={labels}")
# 发送工具执行说明
self._send_tool_message(f"正在添加下载任务: {torrent_title}", title="添加下载")
self.send_tool_message(f"正在添加下载任务: {torrent_title}", title="添加下载")
try:
if not torrent_title or not torrent_url:
error_message = "错误:必须提供种子标题和下载链接"
self._send_tool_message(error_message, title="下载失败")
self.send_tool_message(error_message, title="下载失败")
return error_message
# 使用DownloadChain添加下载
@@ -50,14 +51,14 @@ class AddDownloadTool(MoviePilotTool):
)
if did:
success_message = f"成功添加下载任务:{torrent_title}"
self._send_tool_message(success_message, title="下载成功")
self.send_tool_message(success_message, title="下载成功")
return success_message
else:
error_message = "添加下载任务失败"
self._send_tool_message(error_message, title="下载失败")
self.send_tool_message(error_message, title="下载失败")
return error_message
except Exception as e:
error_message = f"添加下载任务时发生错误: {str(e)}"
logger.error(f"添加下载任务失败: {e}", exc_info=True)
self._send_tool_message(error_message, title="下载失败")
self.send_tool_message(error_message, title="下载失败")
return error_message

View File

@@ -2,10 +2,10 @@
from typing import Optional
from app.agent.tools.base import MoviePilotTool
from app.chain.subscribe import SubscribeChain
from app.log import logger
from app.schemas.types import MediaType
from app.agent.tools.base import MoviePilotTool
class AddSubscribeTool(MoviePilotTool):
@@ -14,10 +14,11 @@ class AddSubscribeTool(MoviePilotTool):
async def _arun(self, title: str, year: str, media_type: str, explanation: str,
season: Optional[int] = None, tmdb_id: Optional[str] = None, **kwargs) -> str:
logger.info(f"执行工具: {self.name}, 参数: title={title}, year={year}, media_type={media_type}, season={season}, tmdb_id={tmdb_id}")
logger.info(
f"执行工具: {self.name}, 参数: title={title}, year={year}, media_type={media_type}, season={season}, tmdb_id={tmdb_id}")
# 发送工具执行说明
self._send_tool_message(f"正在添加订阅: {title} ({year}) - {media_type}", title="添加订阅")
self.send_tool_message(f"正在添加订阅: {title} ({year}) - {media_type}", title="添加订阅")
try:
subscribe_chain = SubscribeChain()
@@ -39,14 +40,14 @@ class AddSubscribeTool(MoviePilotTool):
)
if sid:
success_message = f"成功添加订阅:{title} ({year})"
self._send_tool_message(success_message, title="订阅成功")
self.send_tool_message(success_message, title="订阅成功")
return success_message
else:
error_message = f"添加订阅失败:{message}"
self._send_tool_message(error_message, title="订阅失败")
self.send_tool_message(error_message, title="订阅失败")
return error_message
except Exception as e:
error_message = f"添加订阅时发生错误: {str(e)}"
logger.error(f"添加订阅失败: {e}", exc_info=True)
self._send_tool_message(error_message, title="订阅失败")
self.send_tool_message(error_message, title="订阅失败")
return error_message

View File

@@ -1,11 +1,12 @@
"""查询媒体库工具"""
import json
from typing import Optional
from typing import Optional, List
from app.agent.tools.base import MoviePilotTool
from app.db.mediaserver_oper import MediaServerOper
from app.log import logger
from app.agent.tools.base import MoviePilotTool
from app.schemas import MediaServerItem
class QueryMediaLibraryTool(MoviePilotTool):
@@ -13,18 +14,11 @@ class QueryMediaLibraryTool(MoviePilotTool):
description: str = "查询媒体库状态,查看已入库的媒体文件情况。"
async def _arun(self, explanation: str, media_type: Optional[str] = "all",
title: Optional[str] = None, **kwargs) -> str:
title: Optional[str] = None, year: Optional[str] = None, **kwargs) -> str:
logger.info(f"执行工具: {self.name}, 参数: media_type={media_type}, title={title}")
try:
media_server_oper = MediaServerOper()
medias = media_server_oper.list()
filtered_medias = []
for media in medias:
if media_type != "all" and media.type != media_type:
continue
if title and title.lower() not in media.title.lower():
continue
filtered_medias.append(media)
filtered_medias: List[MediaServerItem] = media_server_oper.exists(title=title, year=year, mtype=media_type)
if filtered_medias:
return json.dumps([m.to_dict() for m in filtered_medias])
return "媒体库中未找到相关媒体。"

View File

@@ -3,11 +3,10 @@
import json
from typing import Optional
from app.agent.tools.base import MoviePilotTool
from app.chain.media import MediaChain
from app.core.metainfo import MetaInfo
from app.log import logger
from app.schemas.types import MediaType
from app.agent.tools.base import MoviePilotTool
class SearchMediaTool(MoviePilotTool):
@@ -16,10 +15,11 @@ class SearchMediaTool(MoviePilotTool):
async def _arun(self, title: str, explanation: str, year: Optional[str] = None,
media_type: Optional[str] = None, season: Optional[int] = None, **kwargs) -> str:
logger.info(f"执行工具: {self.name}, 参数: title={title}, year={year}, media_type={media_type}, season={season}")
logger.info(
f"执行工具: {self.name}, 参数: title={title}, year={year}, media_type={media_type}, season={season}")
# 发送工具执行说明
self._send_tool_message(f"正在搜索媒体资源: {title}" + (f" ({year})" if year else ""), title="搜索中")
self.send_tool_message(f"正在搜索媒体资源: {title}" + (f" ({year})" if year else ""), title="搜索中")
try:
media_chain = MediaChain()
@@ -42,35 +42,32 @@ class SearchMediaTool(MoviePilotTool):
if year and result.year != year:
continue
if media_type:
try:
if result.type != MediaType(media_type):
continue
except:
pass
if season and result.season != season:
continue
filtered_results.append(result)
if filtered_results:
result_message = f"找到 {len(filtered_results)} 个相关媒体资源"
self._send_tool_message(result_message, title="搜索成功")
self.send_tool_message(result_message, title="搜索成功")
# 发送详细结果
for i, result in enumerate(filtered_results[:5]): # 只显示前5个结果
media_info = f"{i+1}. {result.title} ({result.year}) - {result.type.value if result.type else '未知'}"
self._send_tool_message(media_info, title="搜索结果")
media_info = f"{i + 1}. {result.title} ({result.year}) - {result.type.value if result.type else '未知'}"
self.send_tool_message(media_info, title="搜索结果")
return json.dumps([r.to_dict() for r in filtered_results], ensure_ascii=False, indent=2)
else:
error_message = f"未找到符合条件的媒体资源: {title}"
self._send_tool_message(error_message, title="搜索完成")
self.send_tool_message(error_message, title="搜索完成")
return error_message
else:
error_message = f"未找到相关媒体资源: {title}"
self._send_tool_message(error_message, title="搜索完成")
self.send_tool_message(error_message, title="搜索完成")
return error_message
except Exception as e:
error_message = f"搜索媒体失败: {str(e)}"
logger.error(f"搜索媒体失败: {e}", exc_info=True)
self._send_tool_message(error_message, title="搜索失败")
self.send_tool_message(error_message, title="搜索失败")
return error_message

View File

@@ -3,10 +3,10 @@
import json
from typing import List, Optional
from app.agent.tools.base import MoviePilotTool
from app.chain.search import SearchChain
from app.log import logger
from app.schemas.types import MediaType
from app.agent.tools.base import MoviePilotTool
class SearchTorrentsTool(MoviePilotTool):
@@ -16,10 +16,11 @@ class SearchTorrentsTool(MoviePilotTool):
async def _arun(self, title: str, explanation: str, year: Optional[str] = None,
media_type: Optional[str] = None, season: Optional[int] = None,
sites: Optional[List[int]] = None, **kwargs) -> str:
logger.info(f"执行工具: {self.name}, 参数: title={title}, year={year}, media_type={media_type}, season={season}, sites={sites}")
logger.info(
f"执行工具: {self.name}, 参数: title={title}, year={year}, media_type={media_type}, season={season}, sites={sites}")
# 发送工具执行说明
self._send_tool_message(f"正在搜索种子资源: {title}" + (f" ({year})" if year else ""), title="搜索种子")
self.send_tool_message(f"正在搜索种子资源: {title}" + (f" ({year})" if year else ""), title="搜索种子")
try:
search_chain = SearchChain()
@@ -30,33 +31,30 @@ class SearchTorrentsTool(MoviePilotTool):
if year and torrent.meta_info and torrent.meta_info.year != year:
continue
if media_type and torrent.media_info:
try:
if torrent.media_info.type != MediaType(media_type):
continue
except:
pass
if season and torrent.meta_info and torrent.meta_info.begin_season != season:
continue
filtered_torrents.append(torrent)
if filtered_torrents:
result_message = f"找到 {len(filtered_torrents)} 个相关种子资源"
self._send_tool_message(result_message, title="搜索成功")
self.send_tool_message(result_message, title="搜索成功")
# 发送详细结果
for i, torrent in enumerate(filtered_torrents[:5]): # 只显示前5个结果
torrent_title = torrent.torrent_info.title if torrent.torrent_info else torrent.meta_info.title if torrent.meta_info else "未知"
site_name = torrent.torrent_info.site_name if torrent.torrent_info else "未知站点"
torrent_info = f"{i+1}. {torrent_title} - {site_name}"
self._send_tool_message(torrent_info, title="搜索结果")
torrent_info = f"{i + 1}. {torrent_title} - {site_name}"
self.send_tool_message(torrent_info, title="搜索结果")
return json.dumps([t.to_dict() for t in filtered_torrents], ensure_ascii=False, indent=2)
else:
error_message = f"未找到相关种子资源: {title}"
self._send_tool_message(error_message, title="搜索完成")
self.send_tool_message(error_message, title="搜索完成")
return error_message
except Exception as e:
error_message = f"搜索种子时发生错误: {str(e)}"
logger.error(f"搜索种子失败: {e}", exc_info=True)
self._send_tool_message(error_message, title="搜索失败")
self.send_tool_message(error_message, title="搜索失败")
return error_message

View File

@@ -2,9 +2,8 @@
from typing import Optional
from app.helper.message import MessageHelper
from app.log import logger
from app.agent.tools.base import MoviePilotTool
from app.log import logger
class SendMessageTool(MoviePilotTool):
@@ -14,8 +13,7 @@ class SendMessageTool(MoviePilotTool):
async def _arun(self, message: str, explanation: str, message_type: Optional[str] = "info", **kwargs) -> str:
logger.info(f"执行工具: {self.name}, 参数: message={message}, message_type={message_type}")
try:
message_helper = MessageHelper()
message_helper.put(message=message, role="system", title=f"MoviePilot助手通知 ({message_type})")
self.send_tool_message(message, title=message_type)
return "消息已发送。"
except Exception as e:
logger.error(f"发送消息失败: {e}")