fix agent

This commit is contained in:
jxxghp
2025-11-01 19:08:05 +08:00
parent e885fb15a0
commit 8016a9539a
14 changed files with 160 additions and 112 deletions

View File

@@ -1,17 +1,16 @@
"""MoviePilot AI智能体实现""" """MoviePilot AI智能体实现"""
import asyncio import asyncio
import threading
from typing import Dict, List, Any from typing import Dict, List, Any
from langchain.agents import AgentExecutor, create_openai_tools_agent from langchain.agents import AgentExecutor, create_openai_tools_agent
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_community.callbacks import get_openai_callback from langchain_community.callbacks import get_openai_callback
from langchain_core.callbacks import AsyncCallbackHandler
from langchain_core.chat_history import InMemoryChatMessageHistory from langchain_core.chat_history import InMemoryChatMessageHistory
from langchain_core.messages import HumanMessage, AIMessage, ToolCall from langchain_core.messages import HumanMessage, AIMessage, ToolCall
from langchain_core.runnables.history import RunnableWithMessageHistory from langchain_core.runnables.history import RunnableWithMessageHistory
from app.agent.callback import StreamingCallbackHandler
from app.agent.memory import ConversationMemoryManager from app.agent.memory import ConversationMemoryManager
from app.agent.prompt import PromptManager from app.agent.prompt import PromptManager
from app.agent.tools import MoviePilotToolFactory from app.agent.tools import MoviePilotToolFactory
@@ -26,34 +25,6 @@ class AgentChain(ChainBase):
pass pass
class StreamingCallbackHandler(AsyncCallbackHandler):
"""流式输出回调处理器"""
def __init__(self, session_id: str):
self._lock = threading.Lock()
self.session_id = session_id
self.current_message = ""
self.message_helper = MessageHelper()
async def get_message(self):
"""获取当前消息内容,获取后清空"""
with self._lock:
if not self.current_message:
return ""
msg = self.current_message
logger.info(f"Agent消息: {msg}")
self.current_message = ""
return msg
async def on_llm_new_token(self, token: str, **kwargs):
"""处理新的token"""
if not token:
return
with self._lock:
# 缓存当前消息
self.current_message += token
class MoviePilotAgent: class MoviePilotAgent:
"""MoviePilot AI智能体""" """MoviePilot AI智能体"""
@@ -142,7 +113,8 @@ class MoviePilotAgent:
user_id=self.user_id, user_id=self.user_id,
channel=self.channel, channel=self.channel,
source=self.source, source=self.source,
username=self.username username=self.username,
callback_handler=self.callback_handler
) )
@staticmethod @staticmethod
@@ -249,7 +221,7 @@ class MoviePilotAgent:
agent_message = await self.callback_handler.get_message() agent_message = await self.callback_handler.get_message()
# 发送Agent回复给用户通过原渠道 # 发送Agent回复给用户通过原渠道
self.send_agent_message(agent_message) await self.send_agent_message(agent_message)
# 添加Agent回复到记忆 # 添加Agent回复到记忆
await self.memory_manager.add_memory( await self.memory_manager.add_memory(
@@ -265,7 +237,7 @@ class MoviePilotAgent:
error_message = f"处理消息时发生错误: {str(e)}" error_message = f"处理消息时发生错误: {str(e)}"
logger.error(error_message) logger.error(error_message)
# 发送错误消息给用户(通过原渠道) # 发送错误消息给用户(通过原渠道)
self.send_agent_message(error_message) await self.send_agent_message(error_message)
return error_message return error_message
async def _execute_agent(self, input_context: Dict[str, Any]) -> Dict[str, Any]: async def _execute_agent(self, input_context: Dict[str, Any]) -> Dict[str, Any]:
@@ -301,9 +273,9 @@ class MoviePilotAgent:
"token_usage": {} "token_usage": {}
} }
def send_agent_message(self, message: str, title: str = "MoviePilot助手"): async def send_agent_message(self, message: str, title: str = "MoviePilot助手"):
"""通过原渠道发送消息给用户""" """通过原渠道发送消息给用户"""
AgentChain().post_message( await AgentChain().async_post_message(
Notification( Notification(
channel=self.channel, channel=self.channel,
source=self.source, source=self.source,

View File

@@ -0,0 +1,33 @@
import threading
from langchain_core.callbacks import AsyncCallbackHandler
from app.log import logger
class StreamingCallbackHandler(AsyncCallbackHandler):
"""流式输出回调处理器"""
def __init__(self, session_id: str):
self._lock = threading.Lock()
self.session_id = session_id
self.current_message = ""
async def get_message(self):
"""获取当前消息内容,获取后清空"""
with self._lock:
if not self.current_message:
return ""
msg = self.current_message
logger.info(f"Agent消息: {msg}")
self.current_message = ""
return msg
async def on_llm_new_token(self, token: str, **kwargs):
"""处理新的token"""
if not token:
return
with self._lock:
# 缓存当前消息
self.current_message += token

View File

@@ -1,8 +1,11 @@
"""MoviePilot工具基类""" """MoviePilot工具基类"""
from abc import ABCMeta, abstractmethod
from typing import Callable, Any
from langchain.tools import BaseTool from langchain.tools import BaseTool
from pydantic import PrivateAttr from pydantic import PrivateAttr
from app.agent import StreamingCallbackHandler
from app.chain import ChainBase from app.chain import ChainBase
from app.schemas import Notification from app.schemas import Notification
@@ -11,7 +14,7 @@ class ToolChain(ChainBase):
pass pass
class MoviePilotTool(BaseTool): class MoviePilotTool(BaseTool, metaclass=ABCMeta):
"""MoviePilot专用工具基类""" """MoviePilot专用工具基类"""
_session_id: str = PrivateAttr() _session_id: str = PrivateAttr()
@@ -19,16 +22,26 @@ class MoviePilotTool(BaseTool):
_channel: str = PrivateAttr(default=None) _channel: str = PrivateAttr(default=None)
_source: str = PrivateAttr(default=None) _source: str = PrivateAttr(default=None)
_username: str = PrivateAttr(default=None) _username: str = PrivateAttr(default=None)
_callback_handler: StreamingCallbackHandler = PrivateAttr(default=None)
def __init__(self, session_id: str, user_id: str, **kwargs): def __init__(self, session_id: str, user_id: str, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self._session_id = session_id self._session_id = session_id
self._user_id = user_id self._user_id = user_id
def _run(self, **kwargs) -> str: def _run(self, *args: Any, **kwargs: Any) -> Any:
raise NotImplementedError pass
async def _arun(self, **kwargs) -> str: async def _arun(self, **kwargs) -> str:
"""异步运行工具"""
# 发送运行工具前的消息
agent_message = await self._callback_handler.get_message()
if agent_message:
await self.send_tool_message(agent_message)
return await self.run(**kwargs)
@abstractmethod
async def run(self, **kwargs) -> str:
raise NotImplementedError raise NotImplementedError
def set_message_attr(self, channel: str, source: str, username: str): def set_message_attr(self, channel: str, source: str, username: str):
@@ -37,9 +50,13 @@ class MoviePilotTool(BaseTool):
self._source = source self._source = source
self._username = username self._username = username
def send_tool_message(self, message: str, title: str = "执行工具"): def set_callback_handler(self, callback_handler: StreamingCallbackHandler):
"""设置回调处理器"""
self._callback_handler = callback_handler
async def send_tool_message(self, message: str, title: str = "执行工具"):
"""发送工具消息""" """发送工具消息"""
ToolChain().post_message( await ToolChain().async_post_message(
Notification( Notification(
channel=self._channel, channel=self._channel,
source=self._source, source=self._source,

View File

@@ -1,6 +1,6 @@
"""MoviePilot工具工厂""" """MoviePilot工具工厂"""
from typing import List from typing import List, Callable
from app.agent.tools.impl.add_download import AddDownloadTool from app.agent.tools.impl.add_download import AddDownloadTool
from app.agent.tools.impl.add_subscribe import AddSubscribeTool from app.agent.tools.impl.add_subscribe import AddSubscribeTool
@@ -21,7 +21,8 @@ class MoviePilotToolFactory:
@staticmethod @staticmethod
def create_tools(session_id: str, user_id: str, def create_tools(session_id: str, user_id: str,
channel: str = None, source: str = None, username: str = None) -> List[MoviePilotTool]: channel: str = None, source: str = None, username: str = None,
callback_handler: Callable = None) -> List[MoviePilotTool]:
"""创建MoviePilot工具列表""" """创建MoviePilot工具列表"""
tools = [] tools = []
tool_definitions = [ tool_definitions = [
@@ -42,6 +43,7 @@ class MoviePilotToolFactory:
user_id=user_id user_id=user_id
) )
tool.set_message_attr(channel=channel, source=source, username=username) tool.set_message_attr(channel=channel, source=source, username=username)
tool.set_callback_handler(callback_handler=callback_handler)
tools.append(tool) tools.append(tool)
logger.info(f"成功创建 {len(tools)} 个MoviePilot工具") logger.info(f"成功创建 {len(tools)} 个MoviePilot工具")
return tools return tools

View File

@@ -15,11 +15,15 @@ from app.schemas import TorrentInfo
class AddDownloadInput(BaseModel): class AddDownloadInput(BaseModel):
"""添加下载工具的输入参数模型""" """添加下载工具的输入参数模型"""
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context") explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
torrent_title: str = Field(..., description="The display name/title of the torrent (e.g., 'The.Matrix.1999.1080p.BluRay.x264')") torrent_title: str = Field(...,
description="The display name/title of the torrent (e.g., 'The.Matrix.1999.1080p.BluRay.x264')")
torrent_url: str = Field(..., description="Direct URL to the torrent file (.torrent) or magnet link") torrent_url: str = Field(..., description="Direct URL to the torrent file (.torrent) or magnet link")
downloader: Optional[str] = Field(None, description="Name of the downloader to use (optional, uses default if not specified)") downloader: Optional[str] = Field(None,
save_path: Optional[str] = Field(None, description="Directory path where the downloaded files should be saved (optional, uses default path if not specified)") description="Name of the downloader to use (optional, uses default if not specified)")
labels: Optional[str] = Field(None, description="Comma-separated list of labels/tags to assign to the download (optional, e.g., 'movie,hd,bluray')") save_path: Optional[str] = Field(None,
description="Directory path where the downloaded files should be saved (optional, uses default path if not specified)")
labels: Optional[str] = Field(None,
description="Comma-separated list of labels/tags to assign to the download (optional, e.g., 'movie,hd,bluray')")
class AddDownloadTool(MoviePilotTool): class AddDownloadTool(MoviePilotTool):
@@ -27,19 +31,19 @@ class AddDownloadTool(MoviePilotTool):
description: str = "Add torrent download task to the configured downloader (qBittorrent, Transmission, etc.). Downloads the torrent file and starts the download process with specified settings." description: str = "Add torrent download task to the configured downloader (qBittorrent, Transmission, etc.). Downloads the torrent file and starts the download process with specified settings."
args_schema: Type[BaseModel] = AddDownloadInput args_schema: Type[BaseModel] = AddDownloadInput
async def _arun(self, torrent_title: str, torrent_url: str, async def run(self, torrent_title: str, torrent_url: str,
downloader: Optional[str] = None, save_path: Optional[str] = None, downloader: Optional[str] = None, save_path: Optional[str] = None,
labels: Optional[str] = None, **kwargs) -> str: labels: Optional[str] = None, **kwargs) -> str:
logger.info( logger.info(
f"执行工具: {self.name}, 参数: torrent_title={torrent_title}, torrent_url={torrent_url}, downloader={downloader}, save_path={save_path}, labels={labels}") f"执行工具: {self.name}, 参数: torrent_title={torrent_title}, torrent_url={torrent_url}, downloader={downloader}, save_path={save_path}, labels={labels}")
# 发送工具执行说明 # 发送工具执行说明
self.send_tool_message(f"正在添加下载任务: {torrent_title}", title="添加下载") await self.send_tool_message(f"正在添加下载任务: {torrent_title}", title="添加下载")
try: try:
if not torrent_title or not torrent_url: if not torrent_title or not torrent_url:
error_message = "错误:必须提供种子标题和下载链接" error_message = "错误:必须提供种子标题和下载链接"
self.send_tool_message(error_message, title="下载失败") await self.send_tool_message(error_message, title="下载失败")
return error_message return error_message
# 使用DownloadChain添加下载 # 使用DownloadChain添加下载
@@ -64,14 +68,14 @@ class AddDownloadTool(MoviePilotTool):
) )
if did: if did:
success_message = f"成功添加下载任务:{torrent_title}" success_message = f"成功添加下载任务:{torrent_title}"
self.send_tool_message(success_message, title="下载成功") await self.send_tool_message(success_message, title="下载成功")
return success_message return success_message
else: else:
error_message = "添加下载任务失败" error_message = "添加下载任务失败"
self.send_tool_message(error_message, title="下载失败") await self.send_tool_message(error_message, title="下载失败")
return error_message return error_message
except Exception as e: except Exception as e:
error_message = f"添加下载任务时发生错误: {str(e)}" error_message = f"添加下载任务时发生错误: {str(e)}"
logger.error(f"添加下载任务失败: {e}", exc_info=True) logger.error(f"添加下载任务失败: {e}", exc_info=True)
self.send_tool_message(error_message, title="下载失败") await self.send_tool_message(error_message, title="下载失败")
return error_message return error_message

View File

@@ -15,9 +15,12 @@ class AddSubscribeInput(BaseModel):
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context") explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
title: str = Field(..., description="The title of the media to subscribe to (e.g., 'The Matrix', 'Breaking Bad')") title: str = Field(..., description="The title of the media to subscribe to (e.g., 'The Matrix', 'Breaking Bad')")
year: str = Field(..., description="Release year of the media (required for accurate identification)") year: str = Field(..., description="Release year of the media (required for accurate identification)")
media_type: str = Field(..., description="Type of media content: '电影' for films, '电视剧' for television series or anime series") media_type: str = Field(...,
season: Optional[int] = Field(None, description="Season number for TV shows (optional, if not specified will subscribe to all seasons)") description="Type of media content: '电影' for films, '电视剧' for television series or anime series")
tmdb_id: Optional[str] = Field(None, description="TMDB database ID for precise media identification (optional but recommended for accuracy)") season: Optional[int] = Field(None,
description="Season number for TV shows (optional, if not specified will subscribe to all seasons)")
tmdb_id: Optional[str] = Field(None,
description="TMDB database ID for precise media identification (optional but recommended for accuracy)")
class AddSubscribeTool(MoviePilotTool): class AddSubscribeTool(MoviePilotTool):
@@ -25,13 +28,13 @@ class AddSubscribeTool(MoviePilotTool):
description: str = "Add media subscription to create automated download rules for movies and TV shows. The system will automatically search and download new episodes or releases based on the subscription criteria." description: str = "Add media subscription to create automated download rules for movies and TV shows. The system will automatically search and download new episodes or releases based on the subscription criteria."
args_schema: Type[BaseModel] = AddSubscribeInput args_schema: Type[BaseModel] = AddSubscribeInput
async def _arun(self, title: str, year: str, media_type: str, async def run(self, title: str, year: str, media_type: str,
season: Optional[int] = None, tmdb_id: Optional[str] = None, **kwargs) -> str: season: Optional[int] = None, tmdb_id: Optional[str] = None, **kwargs) -> str:
logger.info( logger.info(
f"执行工具: {self.name}, 参数: title={title}, year={year}, media_type={media_type}, season={season}, tmdb_id={tmdb_id}") f"执行工具: {self.name}, 参数: title={title}, year={year}, media_type={media_type}, season={season}, tmdb_id={tmdb_id}")
# 发送工具执行说明 # 发送工具执行说明
self.send_tool_message(f"正在添加订阅: {title} ({year}) - {media_type}", title="添加订阅") await self.send_tool_message(f"正在添加订阅: {title} ({year}) - {media_type}", title="添加订阅")
try: try:
subscribe_chain = SubscribeChain() subscribe_chain = SubscribeChain()
@@ -53,14 +56,14 @@ class AddSubscribeTool(MoviePilotTool):
) )
if sid: if sid:
success_message = f"成功添加订阅:{title} ({year})" success_message = f"成功添加订阅:{title} ({year})"
self.send_tool_message(success_message, title="订阅成功") await self.send_tool_message(success_message, title="订阅成功")
return success_message return success_message
else: else:
error_message = f"添加订阅失败:{message}" error_message = f"添加订阅失败:{message}"
self.send_tool_message(error_message, title="订阅失败") await self.send_tool_message(error_message, title="订阅失败")
return error_message return error_message
except Exception as e: except Exception as e:
error_message = f"添加订阅时发生错误: {str(e)}" error_message = f"添加订阅时发生错误: {str(e)}"
logger.error(f"添加订阅失败: {e}", exc_info=True) logger.error(f"添加订阅失败: {e}", exc_info=True)
self.send_tool_message(error_message, title="订阅失败") await self.send_tool_message(error_message, title="订阅失败")
return error_message return error_message

View File

@@ -5,17 +5,20 @@ from typing import Optional, Type
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from app.agent.tools.base import MoviePilotTool
from app.chain.recommend import RecommendChain from app.chain.recommend import RecommendChain
from app.log import logger from app.log import logger
from app.agent.tools.base import MoviePilotTool
class GetRecommendationsInput(BaseModel): class GetRecommendationsInput(BaseModel):
"""获取推荐工具的输入参数模型""" """获取推荐工具的输入参数模型"""
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context") explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
source: Optional[str] = Field("tmdb_trending", description="Recommendation source: 'tmdb_trending' for TMDB trending content, 'douban_hot' for Douban popular content, 'bangumi_calendar' for Bangumi anime calendar") source: Optional[str] = Field("tmdb_trending",
media_type: Optional[str] = Field("all", description="Type of media content: '电影' for films, '电视剧' for television series or anime series, 'all' for all types") description="Recommendation source: 'tmdb_trending' for TMDB trending content, 'douban_hot' for Douban popular content, 'bangumi_calendar' for Bangumi anime calendar")
limit: Optional[int] = Field(20, description="Maximum number of recommendations to return (default: 20, maximum: 100)") media_type: Optional[str] = Field("all",
description="Type of media content: '电影' for films, '电视剧' for television series or anime series, 'all' for all types")
limit: Optional[int] = Field(20,
description="Maximum number of recommendations to return (default: 20, maximum: 100)")
class GetRecommendationsTool(MoviePilotTool): class GetRecommendationsTool(MoviePilotTool):
@@ -23,8 +26,8 @@ class GetRecommendationsTool(MoviePilotTool):
description: str = "Get trending and popular media recommendations from various sources. Returns curated lists of popular movies, TV shows, and anime based on different criteria like trending, ratings, or calendar schedules." description: str = "Get trending and popular media recommendations from various sources. Returns curated lists of popular movies, TV shows, and anime based on different criteria like trending, ratings, or calendar schedules."
args_schema: Type[BaseModel] = GetRecommendationsInput args_schema: Type[BaseModel] = GetRecommendationsInput
async def _arun(self, source: Optional[str] = "tmdb_trending", async def run(self, source: Optional[str] = "tmdb_trending",
media_type: Optional[str] = "all", limit: Optional[int] = 20, **kwargs) -> str: media_type: Optional[str] = "all", limit: Optional[int] = 20, **kwargs) -> str:
logger.info(f"执行工具: {self.name}, 参数: source={source}, media_type={media_type}, limit={limit}") logger.info(f"执行工具: {self.name}, 参数: source={source}, media_type={media_type}, limit={limit}")
try: try:
recommend_chain = RecommendChain() recommend_chain = RecommendChain()
@@ -36,12 +39,12 @@ class GetRecommendationsTool(MoviePilotTool):
results = recommend_chain.douban_movie_hot(limit=limit) results = recommend_chain.douban_movie_hot(limit=limit)
elif media_type == "tv": elif media_type == "tv":
results = recommend_chain.douban_tv_hot(limit=limit) results = recommend_chain.douban_tv_hot(limit=limit)
else: # all else: # all
results.extend(recommend_chain.douban_movie_hot(limit=limit)) results.extend(recommend_chain.douban_movie_hot(limit=limit))
results.extend(recommend_chain.douban_tv_hot(limit=limit)) results.extend(recommend_chain.douban_tv_hot(limit=limit))
elif source == "bangumi_calendar": elif source == "bangumi_calendar":
results = recommend_chain.bangumi_calendar(limit=limit) results = recommend_chain.bangumi_calendar(limit=limit)
if results: if results:
# 使用 to_dict() 方法 # 使用 to_dict() 方法
return json.dumps(results) return json.dumps(results)

View File

@@ -5,10 +5,10 @@ from typing import Type
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from app.agent.tools.base import MoviePilotTool
from app.db.systemconfig_oper import SystemConfigOper from app.db.systemconfig_oper import SystemConfigOper
from app.log import logger from app.log import logger
from app.schemas.types import SystemConfigKey from app.schemas.types import SystemConfigKey
from app.agent.tools.base import MoviePilotTool
class QueryDownloadersInput(BaseModel): class QueryDownloadersInput(BaseModel):
@@ -21,7 +21,7 @@ class QueryDownloadersTool(MoviePilotTool):
description: str = "Query downloader configuration and list all available downloaders. Shows downloader status, connection details, and configuration settings." description: str = "Query downloader configuration and list all available downloaders. Shows downloader status, connection details, and configuration settings."
args_schema: Type[BaseModel] = QueryDownloadersInput args_schema: Type[BaseModel] = QueryDownloadersInput
async def _arun(self, **kwargs) -> str: async def run(self, **kwargs) -> str:
logger.info(f"执行工具: {self.name}") logger.info(f"执行工具: {self.name}")
try: try:
system_config_oper = SystemConfigOper() system_config_oper = SystemConfigOper()

View File

@@ -5,16 +5,18 @@ from typing import Optional, Type
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from app.agent.tools.base import MoviePilotTool
from app.chain.download import DownloadChain from app.chain.download import DownloadChain
from app.log import logger from app.log import logger
from app.agent.tools.base import MoviePilotTool
class QueryDownloadsInput(BaseModel): class QueryDownloadsInput(BaseModel):
"""查询下载工具的输入参数模型""" """查询下载工具的输入参数模型"""
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context") explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
downloader: Optional[str] = Field(None, description="Name of specific downloader to query (optional, if not provided queries all configured downloaders)") downloader: Optional[str] = Field(None,
status: Optional[str] = Field("all", description="Filter downloads by status: 'downloading' for active downloads, 'completed' for finished downloads, 'paused' for paused downloads, 'all' for all downloads") description="Name of specific downloader to query (optional, if not provided queries all configured downloaders)")
status: Optional[str] = Field("all",
description="Filter downloads by status: 'downloading' for active downloads, 'completed' for finished downloads, 'paused' for paused downloads, 'all' for all downloads")
class QueryDownloadsTool(MoviePilotTool): class QueryDownloadsTool(MoviePilotTool):
@@ -22,8 +24,8 @@ class QueryDownloadsTool(MoviePilotTool):
description: str = "Query download status and list all active download tasks. Shows download progress, completion status, and task details from configured downloaders." description: str = "Query download status and list all active download tasks. Shows download progress, completion status, and task details from configured downloaders."
args_schema: Type[BaseModel] = QueryDownloadsInput args_schema: Type[BaseModel] = QueryDownloadsInput
async def _arun(self, downloader: Optional[str] = None, async def run(self, downloader: Optional[str] = None,
status: Optional[str] = "all", **kwargs) -> str: status: Optional[str] = "all", **kwargs) -> str:
logger.info(f"执行工具: {self.name}, 参数: downloader={downloader}, status={status}") logger.info(f"执行工具: {self.name}, 参数: downloader={downloader}, status={status}")
try: try:
download_chain = DownloadChain() download_chain = DownloadChain()

View File

@@ -14,9 +14,12 @@ from app.schemas import MediaServerItem
class QueryMediaLibraryInput(BaseModel): class QueryMediaLibraryInput(BaseModel):
"""查询媒体库工具的输入参数模型""" """查询媒体库工具的输入参数模型"""
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context") explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
media_type: Optional[str] = Field("all", description="Type of media content: '电影' for films, '电视剧' for television series or anime series, 'all' for all types") media_type: Optional[str] = Field("all",
title: Optional[str] = Field(None, description="Specific media title to search for (optional, if provided returns detailed info for that specific media)") description="Type of media content: '电影' for films, '电视剧' for television series or anime series, 'all' for all types")
year: Optional[str] = Field(None, description="Release year of the media (optional, helps narrow down search results)") title: Optional[str] = Field(None,
description="Specific media title to search for (optional, if provided returns detailed info for that specific media)")
year: Optional[str] = Field(None,
description="Release year of the media (optional, helps narrow down search results)")
class QueryMediaLibraryTool(MoviePilotTool): class QueryMediaLibraryTool(MoviePilotTool):
@@ -24,8 +27,8 @@ class QueryMediaLibraryTool(MoviePilotTool):
description: str = "Query media library status and list all media files that have been successfully processed and added to the media server (Plex, Emby, Jellyfin). Shows library statistics and file details." description: str = "Query media library status and list all media files that have been successfully processed and added to the media server (Plex, Emby, Jellyfin). Shows library statistics and file details."
args_schema: Type[BaseModel] = QueryMediaLibraryInput args_schema: Type[BaseModel] = QueryMediaLibraryInput
async def _arun(self, media_type: Optional[str] = "all", async def run(self, media_type: Optional[str] = "all",
title: Optional[str] = None, year: Optional[str] = None, **kwargs) -> str: title: Optional[str] = None, year: Optional[str] = None, **kwargs) -> str:
logger.info(f"执行工具: {self.name}, 参数: media_type={media_type}, title={title}") logger.info(f"执行工具: {self.name}, 参数: media_type={media_type}, title={title}")
try: try:
media_server_oper = MediaServerOper() media_server_oper = MediaServerOper()

View File

@@ -5,16 +5,18 @@ from typing import Optional, Type
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from app.agent.tools.base import MoviePilotTool
from app.db.subscribe_oper import SubscribeOper from app.db.subscribe_oper import SubscribeOper
from app.log import logger from app.log import logger
from app.agent.tools.base import MoviePilotTool
class QuerySubscribesInput(BaseModel): class QuerySubscribesInput(BaseModel):
"""查询订阅工具的输入参数模型""" """查询订阅工具的输入参数模型"""
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context") explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
status: Optional[str] = Field("all", description="Filter subscriptions by status: 'R' for enabled subscriptions, 'P' for disabled ones, 'all' for all subscriptions") status: Optional[str] = Field("all",
media_type: Optional[str] = Field("all", description="Filter by media type: 'movie' for films, 'tv' for television series, 'all' for all types") description="Filter subscriptions by status: 'R' for enabled subscriptions, 'P' for disabled ones, 'all' for all subscriptions")
media_type: Optional[str] = Field("all",
description="Filter by media type: 'movie' for films, 'tv' for television series, 'all' for all types")
class QuerySubscribesTool(MoviePilotTool): class QuerySubscribesTool(MoviePilotTool):
@@ -22,7 +24,7 @@ class QuerySubscribesTool(MoviePilotTool):
description: str = "Query subscription status and list all user subscriptions. Shows active subscriptions, their download status, and configuration details." description: str = "Query subscription status and list all user subscriptions. Shows active subscriptions, their download status, and configuration details."
args_schema: Type[BaseModel] = QuerySubscribesInput args_schema: Type[BaseModel] = QuerySubscribesInput
async def _arun(self, status: Optional[str] = "all", media_type: Optional[str] = "all", **kwargs) -> str: async def run(self, status: Optional[str] = "all", media_type: Optional[str] = "all", **kwargs) -> str:
logger.info(f"执行工具: {self.name}, 参数: status={status}, media_type={media_type}") logger.info(f"执行工具: {self.name}, 参数: status={status}, media_type={media_type}")
try: try:
subscribe_oper = SubscribeOper() subscribe_oper = SubscribeOper()

View File

@@ -16,8 +16,10 @@ class SearchMediaInput(BaseModel):
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context") explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
title: str = Field(..., description="The title of the media to search for (e.g., 'The Matrix', 'Breaking Bad')") title: str = Field(..., description="The title of the media to search for (e.g., 'The Matrix', 'Breaking Bad')")
year: Optional[str] = Field(None, description="Release year of the media (optional, helps narrow down results)") year: Optional[str] = Field(None, description="Release year of the media (optional, helps narrow down results)")
media_type: Optional[str] = Field(None, description="Type of media content: '电影' for films, '电视剧' for television series or anime series") media_type: Optional[str] = Field(None,
season: Optional[int] = Field(None, description="Season number for TV shows and anime (optional, only applicable for series)") description="Type of media content: '电影' for films, '电视剧' for television series or anime series")
season: Optional[int] = Field(None,
description="Season number for TV shows and anime (optional, only applicable for series)")
class SearchMediaTool(MoviePilotTool): class SearchMediaTool(MoviePilotTool):
@@ -25,13 +27,13 @@ class SearchMediaTool(MoviePilotTool):
description: str = "Search for media resources including movies, TV shows, anime, etc. Supports searching by title, year, type, and other criteria. Returns detailed media information from TMDB database." description: str = "Search for media resources including movies, TV shows, anime, etc. Supports searching by title, year, type, and other criteria. Returns detailed media information from TMDB database."
args_schema: Type[BaseModel] = SearchMediaInput args_schema: Type[BaseModel] = SearchMediaInput
async def _arun(self, title: str, year: Optional[str] = None, async def run(self, title: str, year: Optional[str] = None,
media_type: Optional[str] = None, season: Optional[int] = None, **kwargs) -> str: media_type: Optional[str] = None, season: Optional[int] = None, **kwargs) -> str:
logger.info( logger.info(
f"执行工具: {self.name}, 参数: title={title}, year={year}, media_type={media_type}, season={season}") f"执行工具: {self.name}, 参数: title={title}, year={year}, media_type={media_type}, season={season}")
# 发送工具执行说明 # 发送工具执行说明
self.send_tool_message(f"正在搜索媒体资源: {title}" + (f" ({year})" if year else ""), title="搜索中") await self.send_tool_message(f"正在搜索媒体资源: {title}" + (f" ({year})" if year else ""), title="搜索中")
try: try:
media_chain = MediaChain() media_chain = MediaChain()
@@ -62,24 +64,24 @@ class SearchMediaTool(MoviePilotTool):
if filtered_results: if filtered_results:
result_message = f"找到 {len(filtered_results)} 个相关媒体资源" result_message = f"找到 {len(filtered_results)} 个相关媒体资源"
self.send_tool_message(result_message, title="搜索成功") await self.send_tool_message(result_message, title="搜索成功")
# 发送详细结果 # 发送详细结果
for i, result in enumerate(filtered_results[:5]): # 只显示前5个结果 for i, result in enumerate(filtered_results[:5]): # 只显示前5个结果
media_info = f"{i + 1}. {result.title} ({result.year}) - {result.type.value if result.type else '未知'}" media_info = f"{i + 1}. {result.title} ({result.year}) - {result.type.value if result.type else '未知'}"
self.send_tool_message(media_info, title="搜索结果") await self.send_tool_message(media_info, title="搜索结果")
return json.dumps([r.to_dict() for r in filtered_results], ensure_ascii=False, indent=2) return json.dumps([r.to_dict() for r in filtered_results], ensure_ascii=False, indent=2)
else: else:
error_message = f"未找到符合条件的媒体资源: {title}" error_message = f"未找到符合条件的媒体资源: {title}"
self.send_tool_message(error_message, title="搜索完成") await self.send_tool_message(error_message, title="搜索完成")
return error_message return error_message
else: else:
error_message = f"未找到相关媒体资源: {title}" error_message = f"未找到相关媒体资源: {title}"
self.send_tool_message(error_message, title="搜索完成") await self.send_tool_message(error_message, title="搜索完成")
return error_message return error_message
except Exception as e: except Exception as e:
error_message = f"搜索媒体失败: {str(e)}" error_message = f"搜索媒体失败: {str(e)}"
logger.error(f"搜索媒体失败: {e}", exc_info=True) logger.error(f"搜索媒体失败: {e}", exc_info=True)
self.send_tool_message(error_message, title="搜索失败") await self.send_tool_message(error_message, title="搜索失败")
return error_message return error_message

View File

@@ -14,11 +14,15 @@ from app.schemas.types import MediaType
class SearchTorrentsInput(BaseModel): class SearchTorrentsInput(BaseModel):
"""搜索种子工具的输入参数模型""" """搜索种子工具的输入参数模型"""
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context") explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
title: str = Field(..., description="The title of the media resource to search for (e.g., 'The Matrix 1999', 'Breaking Bad S01E01')") title: str = Field(...,
year: Optional[str] = Field(None, description="Release year of the media (optional, helps narrow down search results)") description="The title of the media resource to search for (e.g., 'The Matrix 1999', 'Breaking Bad S01E01')")
media_type: Optional[str] = Field(None, description="Type of media content: '电影' for films, '电视剧' for television series or anime series") year: Optional[str] = Field(None,
description="Release year of the media (optional, helps narrow down search results)")
media_type: Optional[str] = Field(None,
description="Type of media content: '电影' for films, '电视剧' for television series or anime series")
season: Optional[int] = Field(None, description="Season number for TV shows (optional, only applicable for series)") season: Optional[int] = Field(None, description="Season number for TV shows (optional, only applicable for series)")
sites: Optional[List[int]] = Field(None, description="Array of specific site IDs to search on (optional, if not provided searches all configured sites)") sites: Optional[List[int]] = Field(None,
description="Array of specific site IDs to search on (optional, if not provided searches all configured sites)")
class SearchTorrentsTool(MoviePilotTool): class SearchTorrentsTool(MoviePilotTool):
@@ -26,14 +30,14 @@ class SearchTorrentsTool(MoviePilotTool):
description: str = "Search for torrent files across configured indexer sites based on media information. Returns available torrent downloads with details like file size, quality, and download links." description: str = "Search for torrent files across configured indexer sites based on media information. Returns available torrent downloads with details like file size, quality, and download links."
args_schema: Type[BaseModel] = SearchTorrentsInput args_schema: Type[BaseModel] = SearchTorrentsInput
async def _arun(self, title: str, year: Optional[str] = None, async def run(self, title: str, year: Optional[str] = None,
media_type: Optional[str] = None, season: Optional[int] = None, media_type: Optional[str] = None, season: Optional[int] = None,
sites: Optional[List[int]] = None, **kwargs) -> str: sites: Optional[List[int]] = None, **kwargs) -> str:
logger.info( logger.info(
f"执行工具: {self.name}, 参数: title={title}, year={year}, media_type={media_type}, season={season}, sites={sites}") f"执行工具: {self.name}, 参数: title={title}, year={year}, media_type={media_type}, season={season}, sites={sites}")
# 发送工具执行说明 # 发送工具执行说明
self.send_tool_message(f"正在搜索种子资源: {title}" + (f" ({year})" if year else ""), title="搜索种子") await self.send_tool_message(f"正在搜索种子资源: {title}" + (f" ({year})" if year else ""), title="搜索种子")
try: try:
search_chain = SearchChain() search_chain = SearchChain()
@@ -52,22 +56,22 @@ class SearchTorrentsTool(MoviePilotTool):
if filtered_torrents: if filtered_torrents:
result_message = f"找到 {len(filtered_torrents)} 个相关种子资源" result_message = f"找到 {len(filtered_torrents)} 个相关种子资源"
self.send_tool_message(result_message, title="搜索成功") await self.send_tool_message(result_message, title="搜索成功")
# 发送详细结果 # 发送详细结果
for i, torrent in enumerate(filtered_torrents[:5]): # 只显示前5个结果 for i, torrent in enumerate(filtered_torrents[:5]): # 只显示前5个结果
torrent_title = torrent.torrent_info.title if torrent.torrent_info else torrent.meta_info.title if torrent.meta_info else "未知" torrent_title = torrent.torrent_info.title if torrent.torrent_info else torrent.meta_info.title if torrent.meta_info else "未知"
site_name = torrent.torrent_info.site_name if torrent.torrent_info else "未知站点" site_name = torrent.torrent_info.site_name if torrent.torrent_info else "未知站点"
torrent_info = f"{i + 1}. {torrent_title} - {site_name}" torrent_info = f"{i + 1}. {torrent_title} - {site_name}"
self.send_tool_message(torrent_info, title="搜索结果") await self.send_tool_message(torrent_info, title="搜索结果")
return json.dumps([t.to_dict() for t in filtered_torrents], ensure_ascii=False, indent=2) return json.dumps([t.to_dict() for t in filtered_torrents], ensure_ascii=False, indent=2)
else: else:
error_message = f"未找到相关种子资源: {title}" error_message = f"未找到相关种子资源: {title}"
self.send_tool_message(error_message, title="搜索完成") await self.send_tool_message(error_message, title="搜索完成")
return error_message return error_message
except Exception as e: except Exception as e:
error_message = f"搜索种子时发生错误: {str(e)}" error_message = f"搜索种子时发生错误: {str(e)}"
logger.error(f"搜索种子失败: {e}", exc_info=True) logger.error(f"搜索种子失败: {e}", exc_info=True)
self.send_tool_message(error_message, title="搜索失败") await self.send_tool_message(error_message, title="搜索失败")
return error_message return error_message

View File

@@ -12,7 +12,8 @@ class SendMessageInput(BaseModel):
"""发送消息工具的输入参数模型""" """发送消息工具的输入参数模型"""
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context") explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
message: str = Field(..., description="The message content to send to the user (should be clear and informative)") message: str = Field(..., description="The message content to send to the user (should be clear and informative)")
message_type: Optional[str] = Field("info", description="Type of message: 'info' for general information, 'success' for successful operations, 'warning' for warnings, 'error' for error messages") message_type: Optional[str] = Field("info",
description="Type of message: 'info' for general information, 'success' for successful operations, 'warning' for warnings, 'error' for error messages")
class SendMessageTool(MoviePilotTool): class SendMessageTool(MoviePilotTool):
@@ -20,10 +21,10 @@ class SendMessageTool(MoviePilotTool):
description: str = "Send notification message to the user through configured notification channels (Telegram, Slack, WeChat, etc.). Used to inform users about operation results, errors, or important updates." description: str = "Send notification message to the user through configured notification channels (Telegram, Slack, WeChat, etc.). Used to inform users about operation results, errors, or important updates."
args_schema: Type[BaseModel] = SendMessageInput args_schema: Type[BaseModel] = SendMessageInput
async def _arun(self, message: str, message_type: Optional[str] = None, **kwargs) -> str: async def run(self, message: str, message_type: Optional[str] = None, **kwargs) -> str:
logger.info(f"执行工具: {self.name}, 参数: message={message}, message_type={message_type}") logger.info(f"执行工具: {self.name}, 参数: message={message}, message_type={message_type}")
try: try:
self.send_tool_message(message, title=message_type) await self.send_tool_message(message, title=message_type)
return "消息已发送。" return "消息已发送。"
except Exception as e: except Exception as e:
logger.error(f"发送消息失败: {e}") logger.error(f"发送消息失败: {e}")