fix agent

This commit is contained in:
jxxghp
2025-11-01 10:39:08 +08:00
parent d523c7c916
commit 438d3210bc
18 changed files with 145 additions and 71 deletions

View File

@@ -15,9 +15,15 @@ from langchain_core.runnables.history import RunnableWithMessageHistory
from app.agent.memory import ConversationMemoryManager
from app.agent.prompt import PromptManager
from app.agent.tools import MoviePilotToolFactory
from app.chain import ChainBase
from app.core.config import settings
from app.helper.message import MessageHelper
from app.log import logger
from app.schemas import Notification
class AgentChain(ChainBase):
pass
class StreamingCallbackHandler(AsyncCallbackHandler):
@@ -51,9 +57,13 @@ class StreamingCallbackHandler(AsyncCallbackHandler):
class MoviePilotAgent:
"""MoviePilot AI智能体"""
def __init__(self, session_id: str, user_id: str = None):
def __init__(self, session_id: str, user_id: str = None,
channel: str = None, source: str = None, username: str = None):
self.session_id = session_id
self.user_id = user_id
self.channel = channel # 消息渠道
self.source = source # 消息来源
self.username = username # 用户名
# 消息助手
self.message_helper = MessageHelper()
@@ -173,14 +183,14 @@ class MoviePilotAgent:
def _initialize_prompt() -> ChatPromptTemplate:
"""初始化提示词模板"""
try:
prompt = ChatPromptTemplate.from_messages([
prompt_template = ChatPromptTemplate.from_messages([
("system", "{system_prompt}"),
MessagesPlaceholder(variable_name="chat_history"),
("user", "{input}"),
MessagesPlaceholder(variable_name="agent_scratchpad"),
])
logger.info("LangChain提示词模板初始化成功")
return prompt
return prompt_template
except Exception as e:
logger.error(f"初始化提示词失败: {e}")
raise e
@@ -236,11 +246,8 @@ class MoviePilotAgent:
# 获取Agent回复
agent_message = await self.callback_handler.get_message()
# 发送Agent回复给用户
self.message_helper.put(
message=agent_message,
role="system"
)
# 发送Agent回复给用户(通过原渠道)
self._send_message_to_channel(agent_message)
# 添加Agent回复到记忆
await self.memory_manager.add_memory(
@@ -255,14 +262,23 @@ class MoviePilotAgent:
except Exception as e:
error_message = f"处理消息时发生错误: {str(e)}"
logger.error(error_message)
# 发送错误消息给用户
self.message_helper.put(
message=error_message,
role="system",
title="MoviePilot助手错误"
)
# 发送错误消息给用户(通过原渠道)
self._send_message_to_channel(error_message)
return error_message
def _send_message_to_channel(self, message: str, title: str = "MoviePilot助手"):
"""通过原渠道发送消息给用户"""
AgentChain().post_message(
Notification(
channel=self.channel,
source=self.source,
userid=self.user_id,
username=self.username,
title=title,
text=message
)
)
async def _execute_agent(self, input_context: Dict[str, Any]) -> Dict[str, Any]:
"""执行LangChain Agent"""
try:
@@ -322,20 +338,31 @@ class AgentManager:
await agent.cleanup()
self.active_agents.clear()
async def process_message(self, session_id: str, user_id: str, message: str) -> str:
async def process_message(self, session_id: str, user_id: str, message: str,
channel: str = None, source: str = None, username: str = None) -> str:
"""处理用户消息"""
# 获取或创建Agent实例
if session_id not in self.active_agents:
logger.info(f"创建新的AI智能体实例session_id: {session_id}, user_id: {user_id}")
agent = MoviePilotAgent(
session_id=session_id,
user_id=user_id
user_id=user_id,
channel=channel,
source=source,
username=username
)
agent.memory_manager = self.memory_manager
self.active_agents[session_id] = agent
else:
agent = self.active_agents[session_id]
agent.user_id = user_id # 确保user_id是最新的
# 更新渠道信息
if channel:
agent.channel = channel
if source:
agent.source = source
if username:
agent.username = username
# 处理消息
return await agent.process_message(message)

View File

@@ -3,8 +3,14 @@
from langchain.tools import BaseTool
from pydantic import PrivateAttr
from app.chain import ChainBase
from app.helper.message import MessageHelper
from app.log import logger
from app.schemas import Notification
class ToolChain(ChainBase):
pass
class MoviePilotTool(BaseTool):
@@ -14,11 +20,15 @@ class MoviePilotTool(BaseTool):
_user_id: str = PrivateAttr()
_message_helper: MessageHelper = PrivateAttr()
def __init__(self, session_id: str, user_id: str, message_helper: MessageHelper = None, **kwargs):
def __init__(self, session_id: str, user_id: str,
channel: str = None, source: str = None, username: str = None, **kwargs):
super().__init__(**kwargs)
self._session_id = session_id
self._user_id = user_id
self._message_helper = message_helper or MessageHelper()
self.channel = channel
self.source = source
self.username = username
self._message_helper = MessageHelper()
def _run(self, **kwargs) -> str:
raise NotImplementedError
@@ -28,11 +38,13 @@ class MoviePilotTool(BaseTool):
def _send_tool_message(self, message: str, title: str = None, **kwargs):
"""发送工具执行消息"""
try:
self._message_helper.put(
message=message,
role="system",
title=title or "工具执行"
ToolChain().post_message(
Notification(
channel=self.channel,
source=self.source,
userid=self.user_id,
username=self.username,
title=title,
text=message
)
except Exception as e:
logger.error(f"发送工具消息失败: {e}")
)

View File

@@ -2,26 +2,26 @@
from typing import List
from app.helper.message import MessageHelper
from app.agent.tools.impl.add_download import AddDownloadTool
from app.agent.tools.impl.add_subscribe import AddSubscribeTool
from app.agent.tools.impl.get_recommendations import GetRecommendationsTool
from app.agent.tools.impl.query_downloaders import QueryDownloadersTool
from app.agent.tools.impl.query_downloads import QueryDownloadsTool
from app.agent.tools.impl.query_media_library import QueryMediaLibraryTool
from app.agent.tools.impl.query_subscribes import QuerySubscribesTool
from app.agent.tools.impl.search_media import SearchMediaTool
from app.agent.tools.impl.search_torrents import SearchTorrentsTool
from app.agent.tools.impl.send_message import SendMessageTool
from app.log import logger
from .base import MoviePilotTool
from app.agent.tools.impl.search_media import SearchMediaTool
from app.agent.tools.impl.add_subscribe import AddSubscribeTool
from app.agent.tools.impl.search_torrents import SearchTorrentsTool
from app.agent.tools.impl.add_download import AddDownloadTool
from app.agent.tools.impl.query_subscribes import QuerySubscribesTool
from app.agent.tools.impl.query_downloads import QueryDownloadsTool
from app.agent.tools.impl.query_downloaders import QueryDownloadersTool
from app.agent.tools.impl.get_recommendations import GetRecommendationsTool
from app.agent.tools.impl.query_media_library import QueryMediaLibraryTool
from app.agent.tools.impl.send_message import SendMessageTool
class MoviePilotToolFactory:
"""MoviePilot工具工厂"""
@staticmethod
def create_tools(session_id: str, user_id: str, message_helper: MessageHelper = None) -> List[MoviePilotTool]:
def create_tools(session_id: str, user_id: str,
channel: str = None, source: str = None, username: str = None) -> List[MoviePilotTool]:
"""创建MoviePilot工具列表"""
tools = []
tool_definitions = [
@@ -40,7 +40,9 @@ class MoviePilotToolFactory:
tools.append(ToolClass(
session_id=session_id,
user_id=user_id,
message_helper=message_helper
channel=channel,
source=source,
username=username
))
logger.info(f"成功创建 {len(tools)} 个MoviePilot工具")
return tools

View File

@@ -16,7 +16,7 @@ class AddDownloadTool(MoviePilotTool):
async def _arun(self, torrent_title: str, torrent_url: str, explanation: str,
downloader: Optional[str] = None, save_path: Optional[str] = None,
labels: Optional[str] = None) -> str:
labels: Optional[str] = None, **kwargs) -> str:
logger.info(f"执行工具: {self.name}, 参数: torrent_title={torrent_title}, torrent_url={torrent_url}, downloader={downloader}, save_path={save_path}, labels={labels}")
# 发送工具执行说明

View File

@@ -13,7 +13,7 @@ class AddSubscribeTool(MoviePilotTool):
description: str = "添加媒体订阅,为用户感兴趣的媒体内容创建订阅规则。"
async def _arun(self, title: str, year: str, media_type: str, explanation: str,
season: Optional[int] = None, tmdb_id: Optional[str] = None) -> str:
season: Optional[int] = None, tmdb_id: Optional[str] = None, **kwargs) -> str:
logger.info(f"执行工具: {self.name}, 参数: title={title}, year={year}, media_type={media_type}, season={season}, tmdb_id={tmdb_id}")
# 发送工具执行说明

View File

@@ -13,7 +13,7 @@ class GetRecommendationsTool(MoviePilotTool):
description: str = "获取热门媒体推荐,包括电影、电视剧等热门内容。"
async def _arun(self, explanation: str, source: Optional[str] = "tmdb_trending",
media_type: Optional[str] = "all", limit: Optional[int] = 20) -> str:
media_type: Optional[str] = "all", limit: Optional[int] = 20, **kwargs) -> str:
logger.info(f"执行工具: {self.name}, 参数: source={source}, media_type={media_type}, limit={limit}")
try:
recommend_chain = RecommendChain()

View File

@@ -12,7 +12,7 @@ class QueryDownloadersTool(MoviePilotTool):
name: str = "query_downloaders"
description: str = "查询下载器配置,查看可用的下载器列表和配置信息。"
async def _arun(self, explanation: str) -> str:
async def _arun(self, explanation: str, **kwargs) -> str:
logger.info(f"执行工具: {self.name}")
try:
system_config_oper = SystemConfigOper()

View File

@@ -13,7 +13,7 @@ class QueryDownloadsTool(MoviePilotTool):
description: str = "查询下载状态,查看下载器的任务列表和进度。"
async def _arun(self, explanation: str, downloader: Optional[str] = None,
status: Optional[str] = "all") -> str:
status: Optional[str] = "all", **kwargs) -> str:
logger.info(f"执行工具: {self.name}, 参数: downloader={downloader}, status={status}")
try:
download_chain = DownloadChain()

View File

@@ -13,7 +13,7 @@ class QueryMediaLibraryTool(MoviePilotTool):
description: str = "查询媒体库状态,查看已入库的媒体文件情况。"
async def _arun(self, explanation: str, media_type: Optional[str] = "all",
title: Optional[str] = None) -> str:
title: Optional[str] = None, **kwargs) -> str:
logger.info(f"执行工具: {self.name}, 参数: media_type={media_type}, title={title}")
try:
media_server_oper = MediaServerOper()

View File

@@ -13,7 +13,7 @@ class QuerySubscribesTool(MoviePilotTool):
description: str = "查询订阅状态,查看用户的订阅列表和状态。"
async def _arun(self, explanation: str, status: Optional[str] = "all",
media_type: Optional[str] = "all") -> str:
media_type: Optional[str] = "all", **kwargs) -> str:
logger.info(f"执行工具: {self.name}, 参数: status={status}, media_type={media_type}")
try:
subscribe_oper = SubscribeOper()

View File

@@ -15,7 +15,7 @@ class SearchMediaTool(MoviePilotTool):
description: str = "搜索媒体资源,包括电影、电视剧、动漫等。可以根据标题、年份、类型等条件进行搜索。"
async def _arun(self, title: str, explanation: str, year: Optional[str] = None,
media_type: Optional[str] = None, season: Optional[int] = None) -> str:
media_type: Optional[str] = None, season: Optional[int] = None, **kwargs) -> str:
logger.info(f"执行工具: {self.name}, 参数: title={title}, year={year}, media_type={media_type}, season={season}")
# 发送工具执行说明

View File

@@ -15,7 +15,7 @@ class SearchTorrentsTool(MoviePilotTool):
async def _arun(self, title: str, explanation: str, year: Optional[str] = None,
media_type: Optional[str] = None, season: Optional[int] = None,
sites: Optional[List[int]] = None) -> str:
sites: Optional[List[int]] = None, **kwargs) -> str:
logger.info(f"执行工具: {self.name}, 参数: title={title}, year={year}, media_type={media_type}, season={season}, sites={sites}")
# 发送工具执行说明

View File

@@ -11,11 +11,11 @@ class SendMessageTool(MoviePilotTool):
name: str = "send_message"
description: str = "发送消息通知,向用户发送操作结果或重要信息。"
async def _arun(self, message: str, explanation: str, message_type: Optional[str] = "info") -> str:
async def _arun(self, message: str, explanation: str, message_type: Optional[str] = "info", **kwargs) -> str:
logger.info(f"执行工具: {self.name}, 参数: message={message}, message_type={message_type}")
try:
message_helper = MessageHelper()
message_helper.put(message=message, role="system", title=f"AI助手通知 ({message_type})")
message_helper.put(message=message, role="system", title=f"MoviePilot助手通知 ({message_type})")
return "消息已发送。"
except Exception as e:
logger.error(f"发送消息失败: {e}")

View File

@@ -852,6 +852,10 @@ class ChainBase(metaclass=ABCMeta):
# 渲染消息
message = MessageTemplateHelper.render(message=message, meta=meta, mediainfo=mediainfo,
torrentinfo=torrentinfo, transferinfo=transferinfo, **kwargs)
# 检查消息是否有效
if not message:
logger.warning("消息为空,跳过发送")
return
# 保存消息
self.messagehelper.put(message, role="user", title=message.title)
self.messageoper.add(**message.model_dump())
@@ -931,6 +935,10 @@ class ChainBase(metaclass=ABCMeta):
# 渲染消息
message = MessageTemplateHelper.render(message=message, meta=meta, mediainfo=mediainfo,
torrentinfo=torrentinfo, transferinfo=transferinfo, **kwargs)
# 检查消息是否有效
if not message:
logger.warning("消息为空,跳过发送")
return
# 保存消息
self.messagehelper.put(message, role="user", title=message.title)
await self.messageoper.async_add(**message.model_dump())

View File

@@ -1,6 +1,8 @@
import asyncio
import re
from typing import Any, Optional, Dict, Union, List
from app.agent import agent_manager
from app.chain import ChainBase
from app.chain.download import DownloadChain
from app.chain.media import MediaChain
@@ -828,54 +830,77 @@ class MessageChain(ChainBase):
try:
# 检查AI智能体是否启用
if not settings.AI_AGENT_ENABLE:
self.messagehelper.put("AI智能体功能未启用请在系统设置中启用", role="system", title="AI助手")
self.post_message(Notification(
channel=channel,
source=source,
userid=userid,
username=username,
title="MoviePilot智能助手未启用请在系统设置中启用"
))
return
# 检查LLM配置
if not settings.LLM_API_KEY:
self.messagehelper.put("LLM API密钥未配置请检查系统设置", role="system", title="AI助手")
self.post_message(Notification(
channel=channel,
source=source,
userid=userid,
username=username,
title="MoviePilot智能助未配置请在系统设置中配置"
))
return
# 提取用户消息
user_message = text[3:].strip() # 移除 "/ai" 前缀
if not user_message:
self.messagehelper.put("请输入您的问题或需求", role="system", title="AI助手")
self.post_message(Notification(
channel=channel,
source=source,
userid=userid,
username=username,
title="请输入您的问题或需求"
))
return
# 发送处理中消息
self.messagehelper.put("正在处理您的请求,请稍候...", role="system", title="AI助手")
self.post_message(Notification(
channel=channel,
source=source,
userid=userid,
username=username,
title="正在处理您的请求,请稍候..."
))
# 异步处理AI智能体请求
import asyncio
from app.agent import agent_manager
# 生成会话ID
session_id = f"user_{userid}_{hash(user_message) % 10000}"
# 在事件循环中处理
try:
loop = asyncio.get_event_loop()
response = loop.run_until_complete(
loop.run_until_complete(
agent_manager.process_message(
session_id=session_id,
user_id=str(userid),
message=user_message
message=user_message,
channel=channel.value if channel else None,
source=source,
username=username
)
)
except RuntimeError:
# 如果没有事件循环,创建新的
response = asyncio.run(
asyncio.run(
agent_manager.process_message(
session_id=session_id,
user_id=str(userid),
message=user_message
message=user_message,
channel=channel.value if channel else None,
source=source,
username=username
)
)
# 发送AI智能体回复
self.messagehelper.put(response, role="system", title="AI助手")
except Exception as e:
logger.error(f"处理AI智能体消息失败: {e}")
self.messagehelper.put(f"AI智能体处理失败: {str(e)}", role="system", title="AI助手")
self.messagehelper.put(f"AI智能体处理失败: {str(e)}", role="system", title="MoviePilot助手")

View File

@@ -1151,7 +1151,7 @@ class Emby:
link = self.get_play_url(item.get("Id"))
if item_type == MediaType.MOVIE.value:
title = item.get("Name")
subtitle = item.get("ProductionYear")
subtitle = str(item.get("ProductionYear")) if item.get("ProductionYear") else None
else:
title = f'{item.get("SeriesName")}'
subtitle = f'S{item.get("ParentIndexNumber")}:{item.get("IndexNumber")} - {item.get("Name")}'
@@ -1223,7 +1223,7 @@ class Emby:
ret_latest.append(schemas.MediaServerPlayItem(
id=item.get("Id"),
title=item.get("Name"),
subtitle=item.get("ProductionYear"),
subtitle=str(item.get("ProductionYear")) if item.get("ProductionYear") else None,
type=item_type,
image=image,
link=link,

View File

@@ -924,7 +924,7 @@ class Jellyfin:
image = self.generate_image_link(item.get("Id"), "Backdrop", False)
if item_type == MediaType.MOVIE.value:
title = item.get("Name")
subtitle = item.get("ProductionYear")
subtitle = str(item.get("ProductionYear")) if item.get("ProductionYear") else None
else:
title = f'{item.get("SeriesName")}'
subtitle = f'S{item.get("ParentIndexNumber")}:{item.get("IndexNumber")} - {item.get("Name")}'
@@ -984,7 +984,7 @@ class Jellyfin:
ret_latest.append(schemas.MediaServerPlayItem(
id=item.get("Id"),
title=item.get("Name"),
subtitle=item.get("ProductionYear"),
subtitle=str(item.get("ProductionYear")) if item.get("ProductionYear") else None,
type=item_type,
image=image,
link=link,

View File

@@ -746,7 +746,7 @@ class Plex:
item_type = MediaType.MOVIE.value if item.TYPE == "movie" else MediaType.TV.value
if item_type == MediaType.MOVIE.value:
title = item.title
subtitle = item.year
subtitle = str(item.year) if item.year else None
else:
title = item.grandparentTitle
subtitle = f"S{item.parentIndex}:E{item.index} - {item.title}"
@@ -825,7 +825,7 @@ class Plex:
ret_resume.append(schemas.MediaServerPlayItem(
id=item.key,
title=title,
subtitle=item.year,
subtitle=str(item.year) if item.year else None,
type=item_type,
image=image,
link=link,