diff --git a/app/agent/tools/factory.py b/app/agent/tools/factory.py index 0c7bc1de..3b777add 100644 --- a/app/agent/tools/factory.py +++ b/app/agent/tools/factory.py @@ -13,6 +13,7 @@ from app.agent.tools.impl.query_subscribes import QuerySubscribesTool from app.agent.tools.impl.search_media import SearchMediaTool from app.agent.tools.impl.search_torrents import SearchTorrentsTool from app.agent.tools.impl.send_message import SendMessageTool +from app.core.plugin import PluginManager from app.log import logger from .base import MoviePilotTool @@ -39,6 +40,7 @@ class MoviePilotToolFactory: QueryMediaLibraryTool, SendMessageTool ] + # 创建内置工具 for ToolClass in tool_definitions: tool = ToolClass( session_id=session_id, @@ -47,5 +49,36 @@ class MoviePilotToolFactory: tool.set_message_attr(channel=channel, source=source, username=username) tool.set_callback_handler(callback_handler=callback_handler) tools.append(tool) - logger.info(f"成功创建 {len(tools)} 个MoviePilot工具") + + # 加载插件提供的工具 + plugin_tools_count = 0 + plugin_tools_info = PluginManager().get_plugin_agent_tools() + for plugin_info in plugin_tools_info: + plugin_id = plugin_info.get("plugin_id") + plugin_name = plugin_info.get("plugin_name") + tool_classes = plugin_info.get("tools", []) + for ToolClass in tool_classes: + try: + # 验证工具类是否继承自 MoviePilotTool + if not issubclass(ToolClass, MoviePilotTool): + logger.warning(f"插件 {plugin_name}({plugin_id}) 提供的工具类 {ToolClass.__name__} 未继承自 MoviePilotTool,已跳过") + continue + # 创建工具实例 + tool = ToolClass( + session_id=session_id, + user_id=user_id + ) + tool.set_message_attr(channel=channel, source=source, username=username) + tool.set_callback_handler(callback_handler=callback_handler) + tools.append(tool) + plugin_tools_count += 1 + logger.debug(f"成功加载插件 {plugin_name}({plugin_id}) 的工具: {ToolClass.__name__}") + except Exception as e: + logger.error(f"加载插件 {plugin_name}({plugin_id}) 的工具 {ToolClass.__name__} 失败: {str(e)}") + + builtin_tools_count = len(tool_definitions) + if plugin_tools_count > 0: + logger.info(f"成功创建 {len(tools)} 个MoviePilot工具(内置工具: {builtin_tools_count} 个,插件工具: {plugin_tools_count} 个)") + else: + logger.info(f"成功创建 {len(tools)} 个MoviePilot工具") return tools diff --git a/app/core/plugin.py b/app/core/plugin.py index a3a019f7..24222eca 100644 --- a/app/core/plugin.py +++ b/app/core/plugin.py @@ -745,6 +745,36 @@ class PluginManager(metaclass=Singleton): logger.error(f"获取插件 {plugin_id} 动作出错:{str(e)}") return ret_actions + def get_plugin_agent_tools(self, pid: Optional[str] = None) -> List[Dict[str, Any]]: + """ + 获取插件智能体工具 + [{ + "plugin_id": "插件ID", + "plugin_name": "插件名称", + "tools": [ToolClass1, ToolClass2, ...] + }] + """ + ret_tools = [] + # 创建字典快照避免并发修改 + running_plugins_snapshot = dict(self._running_plugins) + for plugin_id, plugin in running_plugins_snapshot.items(): + if pid and pid != plugin_id: + continue + if hasattr(plugin, "get_agent_tools") and ObjectUtils.check_method(plugin.get_agent_tools): + try: + if not plugin.get_state(): + continue + tools = plugin.get_agent_tools() + if tools: + ret_tools.append({ + "plugin_id": plugin_id, + "plugin_name": plugin.plugin_name, + "tools": tools + }) + except Exception as e: + logger.error(f"获取插件 {plugin_id} 智能体工具出错:{str(e)}") + return ret_tools + @staticmethod def get_plugin_remote_entry(plugin_id: str, dist_path: str) -> str: """ diff --git a/app/plugins/__init__.py b/app/plugins/__init__.py index 7619ba58..2550b95a 100644 --- a/app/plugins/__init__.py +++ b/app/plugins/__init__.py @@ -1,6 +1,6 @@ from abc import ABCMeta, abstractmethod from pathlib import Path -from typing import Any, List, Dict, Tuple, Optional +from typing import Any, List, Dict, Tuple, Optional, Type from app.chain import ChainBase from app.core.config import settings @@ -200,6 +200,20 @@ class _PluginBase(metaclass=ABCMeta): """ pass + def get_agent_tools(self) -> List[Type]: + """ + 获取插件智能体工具 + 返回工具类列表,每个工具类必须继承自 MoviePilotTool + [ToolClass1, ToolClass2, ...] + + 对工具类的要求: + 1、工具类必须继承自 app.agent.tools.base.MoviePilotTool + 2、工具类需要实现 run 方法(异步方法) + 3、工具类需要定义 name 和 description 属性 + 4、工具类可以定义 args_schema 来指定输入参数模型 + """ + pass + @abstractmethod def stop_service(self): """