From ab32d3347d9a9d4711b9e097397d68bfbbaa873c Mon Sep 17 00:00:00 2001 From: InfinityPacer <160988576+InfinityPacer@users.noreply.github.com> Date: Wed, 23 Oct 2024 02:26:11 +0800 Subject: [PATCH] feat(command): optimize command registration event handling --- app/api/endpoints/plugin.py | 7 + app/chain/command.py | 350 ++++++++++++++++++----------- app/core/event.py | 2 +- app/schemas/types.py | 2 + app/startup/plugins_initializer.py | 4 + 5 files changed, 234 insertions(+), 131 deletions(-) diff --git a/app/api/endpoints/plugin.py b/app/api/endpoints/plugin.py index c41dbbec..1b4d58e5 100644 --- a/app/api/endpoints/plugin.py +++ b/app/api/endpoints/plugin.py @@ -3,6 +3,7 @@ from typing import Annotated, Any, List, Optional from fastapi import APIRouter, Depends, Header from app import schemas +from app.chain.command import CommandChain from app.core.config import settings from app.core.plugin import PluginManager from app.core.security import verify_apikey, verify_token @@ -210,6 +211,8 @@ def install(plugin_id: str, PluginManager().reload_plugin(plugin_id) # 注册插件服务 Scheduler().update_plugin_job(plugin_id) + # 注册菜单命令 + CommandChain().init_commands(plugin_id) # 注册插件API register_plugin_api(plugin_id) return schemas.Response(success=True) @@ -276,6 +279,8 @@ def reset_plugin(plugin_id: str, PluginManager().reload_plugin(plugin_id) # 注册插件服务 Scheduler().update_plugin_job(plugin_id) + # 注册菜单命令 + CommandChain().init_commands(plugin_id) # 注册插件API register_plugin_api(plugin_id) return schemas.Response(success=True) @@ -302,6 +307,8 @@ def set_plugin_config(plugin_id: str, conf: dict, PluginManager().init_plugin(plugin_id, conf) # 注册插件服务 Scheduler().update_plugin_job(plugin_id) + # 注册菜单命令 + CommandChain().init_commands(plugin_id) # 注册插件API register_plugin_api(plugin_id) return schemas.Response(success=True) diff --git a/app/chain/command.py b/app/chain/command.py index a6a2ac4b..fdd75864 100644 --- a/app/chain/command.py +++ b/app/chain/command.py @@ -1,5 +1,6 @@ +import threading import traceback -from typing import Any, Union, Dict +from typing import Any, Union, Dict, Optional from app.chain import ChainBase from app.chain.download import DownloadChain @@ -8,13 +9,13 @@ from app.chain.subscribe import SubscribeChain from app.chain.system import SystemChain from app.chain.transfer import TransferChain from app.core.config import settings -from app.core.event import Event as ManagerEvent, eventmanager +from app.core.event import Event as ManagerEvent, eventmanager, Event from app.core.plugin import PluginManager from app.helper.message import MessageHelper from app.log import logger from app.scheduler import Scheduler from app.schemas import Notification -from app.schemas.types import EventType, MessageChannel +from app.schemas.types import EventType, MessageChannel, ChainEventType from app.utils.object import ObjectUtils from app.utils.singleton import Singleton @@ -23,133 +24,213 @@ class CommandChain(ChainBase, metaclass=Singleton): """ 全局命令管理,消费事件 """ - # 内建命令 + # 注册的命令集合 + _registered_commands = {} + # 所有命令集合 _commands = {} + # 内建命令集合 + _preset_commands = { + "/cookiecloud": { + "id": "cookiecloud", + "type": "scheduler", + "description": "同步站点", + "category": "站点" + }, + "/sites": { + "func": SiteChain().remote_list, + "description": "查询站点", + "category": "站点", + "data": {} + }, + "/site_cookie": { + "func": SiteChain().remote_cookie, + "description": "更新站点Cookie", + "data": {} + }, + "/site_enable": { + "func": SiteChain().remote_enable, + "description": "启用站点", + "data": {} + }, + "/site_disable": { + "func": SiteChain().remote_disable, + "description": "禁用站点", + "data": {} + }, + "/mediaserver_sync": { + "id": "mediaserver_sync", + "type": "scheduler", + "description": "同步媒体服务器", + "category": "管理" + }, + "/subscribes": { + "func": SubscribeChain().remote_list, + "description": "查询订阅", + "category": "订阅", + "data": {} + }, + "/subscribe_refresh": { + "id": "subscribe_refresh", + "type": "scheduler", + "description": "刷新订阅", + "category": "订阅" + }, + "/subscribe_search": { + "id": "subscribe_search", + "type": "scheduler", + "description": "搜索订阅", + "category": "订阅" + }, + "/subscribe_delete": { + "func": SubscribeChain().remote_delete, + "description": "删除订阅", + "data": {} + }, + "/subscribe_tmdb": { + "id": "subscribe_tmdb", + "type": "scheduler", + "description": "订阅元数据更新" + }, + "/downloading": { + "func": DownloadChain().remote_downloading, + "description": "正在下载", + "category": "管理", + "data": {} + }, + "/transfer": { + "id": "transfer", + "type": "scheduler", + "description": "下载文件整理", + "category": "管理" + }, + "/redo": { + "func": TransferChain().remote_transfer, + "description": "手动整理", + "data": {} + }, + "/clear_cache": { + "func": SystemChain().remote_clear_cache, + "description": "清理缓存", + "category": "管理", + "data": {} + }, + "/restart": { + "func": SystemChain().restart, + "description": "重启系统", + "category": "管理", + "data": {} + }, + "/version": { + "func": SystemChain().version, + "description": "当前版本", + "category": "管理", + "data": {} + } + } + # 插件命令集合 + _plugin_commands = {} + # 其他命令集合 + _other_commands = {} def __init__(self): # 插件管理器 super().__init__() + # 初始化锁 + self._rlock = threading.RLock() + # 插件管理 self.pluginmanager = PluginManager() # 定时服务管理 self.scheduler = Scheduler() # 消息管理器 self.messagehelper = MessageHelper() - # 内置命令:标准参数 arg_str: str, channel: MessageChannel, userid: Union[str, int] = None, source: str = None - # 其中 arg_str 为用户输入的参数,channel 为消息渠道,userid 为用户ID,source 为消息来源,arg_str 可选 - self._commands = { - "/cookiecloud": { - "id": "cookiecloud", - "type": "scheduler", - "description": "同步站点", - "category": "站点" - }, - "/sites": { - "func": SiteChain().remote_list, - "description": "查询站点", - "category": "站点", - "data": {} - }, - "/site_cookie": { - "func": SiteChain().remote_cookie, - "description": "更新站点Cookie", - "data": {} - }, - "/site_enable": { - "func": SiteChain().remote_enable, - "description": "启用站点", - "data": {} - }, - "/site_disable": { - "func": SiteChain().remote_disable, - "description": "禁用站点", - "data": {} - }, - "/mediaserver_sync": { - "id": "mediaserver_sync", - "type": "scheduler", - "description": "同步媒体服务器", - "category": "管理" - }, - "/subscribes": { - "func": SubscribeChain().remote_list, - "description": "查询订阅", - "category": "订阅", - "data": {} - }, - "/subscribe_refresh": { - "id": "subscribe_refresh", - "type": "scheduler", - "description": "刷新订阅", - "category": "订阅" - }, - "/subscribe_search": { - "id": "subscribe_search", - "type": "scheduler", - "description": "搜索订阅", - "category": "订阅" - }, - "/subscribe_delete": { - "func": SubscribeChain().remote_delete, - "description": "删除订阅", - "data": {} - }, - "/subscribe_tmdb": { - "id": "subscribe_tmdb", - "type": "scheduler", - "description": "订阅元数据更新" - }, - "/downloading": { - "func": DownloadChain().remote_downloading, - "description": "正在下载", - "category": "管理", - "data": {} - }, - "/transfer": { - "id": "transfer", - "type": "scheduler", - "description": "下载文件整理", - "category": "管理" - }, - "/redo": { - "func": TransferChain().remote_transfer, - "description": "手动整理", - "data": {} - }, - "/clear_cache": { - "func": SystemChain().remote_clear_cache, - "description": "清理缓存", - "category": "管理", - "data": {} - }, - "/restart": { - "func": SystemChain().restart, - "description": "重启系统", - "category": "管理", - "data": {} - }, - "/version": { - "func": SystemChain().version, - "description": "当前版本", - "category": "管理", - "data": {} + # 初始化命令 + self.init_commands() + + def init_commands(self, pid: Optional[str] = None) -> None: + """ + 初始化菜单命令 + """ + if settings.DEV: + logger.debug("Development mode active. Skipping command initialization.") + return + + with self._rlock: + logger.debug("Acquired lock for initializing commands.") + self._plugin_commands = self.__build_plugin_commands() + self._commands = { + **self._preset_commands, + **self._plugin_commands, + **self._other_commands } - } - # 汇总插件命令 - plugin_commands = self.pluginmanager.get_plugin_commands() - for command in plugin_commands: - self.register( - cmd=command.get('cmd'), - func=self.send_plugin_event, - desc=command.get('desc'), - category=command.get('category'), - data={ - 'etype': command.get('event'), - 'data': command.get('data') + + # 触发事件允许可以拦截和调整命令 + event, initial_commands = self.__trigger_register_commands_event() + + # 如果事件返回有效的 event_data,使用事件中调整后的命令 + if event and event.event_data: + initial_commands = event.event_data.get("commands") or {} + logger.debug(f"Registering command count from event: {len(initial_commands)}") + else: + logger.debug(f"Registering initial command count: {len(initial_commands)}") + + # 对比调整后的命令与当前命令 + if initial_commands == self._registered_commands: + logger.debug("Command set unchanged, skipping broadcast registration.") + else: + logger.debug("Command set has changed, Updating and broadcasting new commands.") + self._registered_commands = initial_commands + super().register_commands(commands=initial_commands) + + def __trigger_register_commands_event(self) -> (Optional[Event], dict): + """ + 触发事件,允许调整命令数据 + """ + + def add_commands(source, command_type): + """ + 添加命令集合 + """ + for cmd, command in source.items(): + command_data = { + "type": command_type, + "description": command.get("description"), + "category": command.get("category") } - ) - # 广播注册命令菜单 - if not settings.DEV: - self.register_commands(commands=self.get_commands()) + # 如果有 pid,则添加到命令数据中 + plugin_id = command.get("pid") + if plugin_id: + command_data["pid"] = plugin_id + commands[cmd] = command_data + + # 触发事件允许可以拦截和调整命令 + commands = {} + add_commands(self._preset_commands, "preset") + add_commands(self._plugin_commands, "plugin") + add_commands(self._other_commands, "other") + event_data = { + "commands": commands + } + return eventmanager.send_event(ChainEventType.CommandRegister, event_data), commands + + def __build_plugin_commands(self) -> Dict[str, dict]: + """ + 构建插件命令 + """ + plugin_commands = {} + for command in self.pluginmanager.get_plugin_commands(): + cmd = command.get("cmd") + if cmd: + plugin_commands[cmd] = { + "pid": command.get("pid"), + "func": self.send_plugin_event, + "description": command.get("desc"), + "category": command.get("category"), + "data": { + "etype": command.get("event"), + "data": command.get("data") + } + } + return plugin_commands def __run_command(self, command: Dict[str, any], data_str: str = "", channel: MessageChannel = None, source: str = None, userid: Union[str, int] = None): @@ -211,24 +292,25 @@ class CommandChain(ChainBase, metaclass=Singleton): """ return self._commands - def register(self, cmd: str, func: Any, data: dict = None, - desc: str = None, category: str = None) -> None: - """ - 注册命令 - """ - self._commands[cmd] = { - "func": func, - "description": desc, - "category": category, - "data": data or {} - } - def get(self, cmd: str) -> Any: """ 获取命令 """ return self._commands.get(cmd, {}) + def register(self, cmd: str, func: Any, data: dict = None, + desc: str = None, category: str = None) -> None: + """ + 注册单个命令 + """ + # 单独调用的,统一注册到其他 + self._other_commands[cmd] = { + "func": func, + "description": desc, + "category": category, + "data": data or {} + } + def execute(self, cmd: str, data_str: str = "", channel: MessageChannel = None, source: str = None, userid: Union[str, int] = None) -> None: @@ -286,3 +368,11 @@ class CommandChain(ChainBase, metaclass=Singleton): if self.get(cmd): self.execute(cmd=cmd, data_str=args, channel=event_channel, source=event_source, userid=event_user) + + @eventmanager.register(EventType.ModuleReload) + def module_reload_event(self, event: ManagerEvent) -> None: + """ + 注册模块重载事件 + """ + # 发生模块重载时,重新注册命令 + self.init_commands() diff --git a/app/core/event.py b/app/core/event.py index e74066a2..2b85914f 100644 --- a/app/core/event.py +++ b/app/core/event.py @@ -129,7 +129,7 @@ class EventManager(metaclass=Singleton): for handler in handlers.values() ) - def send_event(self, etype: Union[EventType, ChainEventType], data: Optional[Dict] = None, + def send_event(self, etype: Union[EventType, ChainEventType], data: Optional[Union[Dict, ChainEventData]] = None, priority: int = DEFAULT_EVENT_PRIORITY) -> Optional[Event]: """ 发送事件,根据事件类型决定是广播事件还是链式事件 diff --git a/app/schemas/types.py b/app/schemas/types.py index c2815a85..692e47a4 100644 --- a/app/schemas/types.py +++ b/app/schemas/types.py @@ -66,6 +66,8 @@ class ChainEventType(Enum): AuthVerification = "auth.verification" # 认证拦截请求 AuthIntercept = "auth.intercept" + # 命令注册请求 + CommandRegister = "command.register" # 系统配置Key字典 diff --git a/app/startup/plugins_initializer.py b/app/startup/plugins_initializer.py index c359d324..23494732 100644 --- a/app/startup/plugins_initializer.py +++ b/app/startup/plugins_initializer.py @@ -1,5 +1,6 @@ import asyncio +from app.chain.command import CommandChain from app.core.plugin import PluginManager from app.log import logger from app.scheduler import Scheduler @@ -13,6 +14,7 @@ async def init_plugins_async(): loop = asyncio.get_event_loop() plugin_manager = PluginManager() scheduler = Scheduler() + command = CommandChain() sync_plugins = await loop.run_in_executor(None, plugin_manager.sync) if not sync_plugins: return @@ -22,6 +24,8 @@ async def init_plugins_async(): plugin_manager.init_config() # 插件启动后注册后台任务 scheduler.init_plugin_jobs() + # 插件启动后注册菜单命令 + command.init_commands() # 插件启动后注册插件API register_plugin_api() logger.info("所有插件初始化完成")