mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-04-03 02:38:39 +08:00
refactor:Command提到上层
This commit is contained in:
@@ -3,7 +3,7 @@ from typing import Annotated, Any, List, Optional
|
||||
from fastapi import APIRouter, Depends, Header
|
||||
|
||||
from app import schemas
|
||||
from app.chain.command import CommandChain
|
||||
from app.command import Command
|
||||
from app.core.config import settings
|
||||
from app.core.plugin import PluginManager
|
||||
from app.core.security import verify_apikey, verify_token
|
||||
@@ -212,7 +212,7 @@ def install(plugin_id: str,
|
||||
# 注册插件服务
|
||||
Scheduler().update_plugin_job(plugin_id)
|
||||
# 注册菜单命令
|
||||
CommandChain().init_commands(plugin_id)
|
||||
Command().init_commands(plugin_id)
|
||||
# 注册插件API
|
||||
register_plugin_api(plugin_id)
|
||||
return schemas.Response(success=True)
|
||||
@@ -280,7 +280,7 @@ def reset_plugin(plugin_id: str,
|
||||
# 注册插件服务
|
||||
Scheduler().update_plugin_job(plugin_id)
|
||||
# 注册菜单命令
|
||||
CommandChain().init_commands(plugin_id)
|
||||
Command().init_commands(plugin_id)
|
||||
# 注册插件API
|
||||
register_plugin_api(plugin_id)
|
||||
return schemas.Response(success=True)
|
||||
@@ -308,7 +308,7 @@ def set_plugin_config(plugin_id: str, conf: dict,
|
||||
# 注册插件服务
|
||||
Scheduler().update_plugin_job(plugin_id)
|
||||
# 注册菜单命令
|
||||
CommandChain().init_commands(plugin_id)
|
||||
Command().init_commands(plugin_id)
|
||||
# 注册插件API
|
||||
register_plugin_api(plugin_id)
|
||||
return schemas.Response(success=True)
|
||||
|
||||
@@ -23,7 +23,11 @@ from app.utils.singleton import Singleton
|
||||
from app.utils.structures import DictUtils
|
||||
|
||||
|
||||
class CommandChain(ChainBase, metaclass=Singleton):
|
||||
class CommandChain(ChainBase):
|
||||
pass
|
||||
|
||||
|
||||
class Command(metaclass=Singleton):
|
||||
"""
|
||||
全局命令管理,消费事件
|
||||
"""
|
||||
@@ -210,7 +214,7 @@ class CommandChain(ChainBase, metaclass=Singleton):
|
||||
if filtered_initial_commands != self._registered_commands or force_register:
|
||||
logger.debug("Command set has changed or force registration is enabled.")
|
||||
self._registered_commands = filtered_initial_commands
|
||||
super().register_commands(commands=filtered_initial_commands)
|
||||
CommandChain().register_commands(commands=filtered_initial_commands)
|
||||
else:
|
||||
logger.debug("Command set unchanged, skipping broadcast registration.")
|
||||
except Exception as e:
|
||||
@@ -248,7 +252,7 @@ class CommandChain(ChainBase, metaclass=Singleton):
|
||||
event = eventmanager.send_event(ChainEventType.CommandRegister, event_data)
|
||||
return event, commands
|
||||
|
||||
def __build_plugin_commands(self, pid: Optional[str] = None) -> Dict[str, dict]:
|
||||
def __build_plugin_commands(self, _: Optional[str] = None) -> Dict[str, dict]:
|
||||
"""
|
||||
构建插件命令
|
||||
"""
|
||||
@@ -277,7 +281,7 @@ class CommandChain(ChainBase, metaclass=Singleton):
|
||||
if command.get("type") == "scheduler":
|
||||
# 定时服务
|
||||
if userid:
|
||||
self.post_message(
|
||||
CommandChain().post_message(
|
||||
Notification(
|
||||
channel=channel,
|
||||
source=source,
|
||||
@@ -290,7 +294,7 @@ class CommandChain(ChainBase, metaclass=Singleton):
|
||||
self.scheduler.start(job_id=command.get("id"))
|
||||
|
||||
if userid:
|
||||
self.post_message(
|
||||
CommandChain().post_message(
|
||||
Notification(
|
||||
channel=channel,
|
||||
source=source,
|
||||
@@ -438,12 +438,15 @@ class EventManager(metaclass=Singleton):
|
||||
|
||||
# 如果类不在全局变量中,尝试动态导入模块并创建实例
|
||||
try:
|
||||
# 导入模块,除了插件,只有chain能响应事件
|
||||
if not class_name.endswith("Chain"):
|
||||
if class_name == "Command":
|
||||
module_name = "app.command"
|
||||
module = importlib.import_module(module_name)
|
||||
elif class_name.endswith("Chain"):
|
||||
module_name = f"app.chain.{class_name[:-5].lower()}"
|
||||
module = importlib.import_module(module_name)
|
||||
else:
|
||||
logger.debug(f"事件处理出错:无效的 Chain 类名: {class_name},类名必须以 'Chain' 结尾")
|
||||
return None
|
||||
module_name = f"app.chain.{class_name[:-5].lower()}"
|
||||
module = importlib.import_module(module_name)
|
||||
if hasattr(module, class_name):
|
||||
class_obj = getattr(module, class_name)()
|
||||
return class_obj
|
||||
|
||||
@@ -210,7 +210,7 @@ class LoggerManager:
|
||||
"""
|
||||
输出警告级别日志(兼容)
|
||||
"""
|
||||
self.logger("warning", msg, *args, **kwargs)
|
||||
self.warning(msg, *args, **kwargs)
|
||||
|
||||
def error(self, msg: str, *args, **kwargs):
|
||||
"""
|
||||
|
||||
@@ -46,9 +46,9 @@ class AuthCredentials(ChainEventData):
|
||||
|
||||
# 输出参数
|
||||
# grant_type 为 authorization_code 时,输出参数包括 username、token、channel、service
|
||||
token: Optional[str] = Field(None, description="认证令牌")
|
||||
channel: Optional[str] = Field(None, description="认证渠道")
|
||||
service: Optional[str] = Field(None, description="服务名称")
|
||||
token: Optional[str] = Field(default=None, description="认证令牌")
|
||||
channel: Optional[str] = Field(default=None, description="认证渠道")
|
||||
service: Optional[str] = Field(default=None, description="服务名称")
|
||||
|
||||
@root_validator(pre=True)
|
||||
def check_fields_based_on_grant_type(cls, values):
|
||||
@@ -92,8 +92,8 @@ class AuthInterceptCredentials(ChainEventData):
|
||||
token: Optional[str] = Field(None, description="认证令牌")
|
||||
|
||||
# 输出参数
|
||||
source: str = Field("未知拦截源", description="拦截源")
|
||||
cancel: bool = Field(False, description="是否取消认证")
|
||||
source: str = Field(default="未知拦截源", description="拦截源")
|
||||
cancel: bool = Field(default=False, description="是否取消认证")
|
||||
|
||||
|
||||
class CommandRegisterEventData(ChainEventData):
|
||||
@@ -116,8 +116,8 @@ class CommandRegisterEventData(ChainEventData):
|
||||
service: Optional[str] = Field(..., description="服务名称")
|
||||
|
||||
# 输出参数
|
||||
cancel: bool = Field(False, description="是否取消注册")
|
||||
source: str = Field("未知拦截源", description="拦截源")
|
||||
cancel: bool = Field(default=False, description="是否取消注册")
|
||||
source: str = Field(default="未知拦截源", description="拦截源")
|
||||
|
||||
|
||||
class TransferRenameEventData(ChainEventData):
|
||||
@@ -143,9 +143,9 @@ class TransferRenameEventData(ChainEventData):
|
||||
render_str: str = Field(..., description="渲染生成的字符串")
|
||||
|
||||
# 输出参数
|
||||
updated: bool = Field(False, description="是否已更新")
|
||||
updated_str: Optional[str] = Field(None, description="更新后的字符串")
|
||||
source: Optional[str] = Field("未知拦截源", description="拦截源")
|
||||
updated: bool = Field(default=False, description="是否已更新")
|
||||
updated_str: Optional[str] = Field(default=None, description="更新后的字符串")
|
||||
source: Optional[str] = Field(default="未知拦截源", description="拦截源")
|
||||
|
||||
|
||||
class ResourceSelectionEventData(BaseModel):
|
||||
@@ -168,9 +168,9 @@ class ResourceSelectionEventData(BaseModel):
|
||||
origin: Optional[str] = Field(None, description="来源")
|
||||
|
||||
# 输出参数
|
||||
updated: bool = Field(False, description="是否已更新")
|
||||
updated_contexts: Optional[List[Context]] = Field(None, description="已更新的资源上下文列表")
|
||||
source: Optional[str] = Field("未知拦截源", description="拦截源")
|
||||
updated: bool = Field(default=False, description="是否已更新")
|
||||
updated_contexts: Optional[List[Context]] = Field(default=None, description="已更新的资源上下文列表")
|
||||
source: Optional[str] = Field(default="未知拦截源", description="拦截源")
|
||||
|
||||
|
||||
class ResourceDownloadEventData(ChainEventData):
|
||||
@@ -200,6 +200,6 @@ class ResourceDownloadEventData(ChainEventData):
|
||||
options: Optional[dict] = Field(None, description="其他参数")
|
||||
|
||||
# 输出参数
|
||||
cancel: bool = Field(False, description="是否取消下载")
|
||||
source: str = Field("未知拦截源", description="拦截源")
|
||||
reason: str = Field("", description="拦截原因")
|
||||
cancel: bool = Field(default=False, description="是否取消下载")
|
||||
source: str = Field(default="未知拦截源", description="拦截源")
|
||||
reason: str = Field(default="", description="拦截原因")
|
||||
|
||||
@@ -27,7 +27,7 @@ from app.schemas import Notification, NotificationType
|
||||
from app.schemas.types import SystemConfigKey
|
||||
from app.db import close_database
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.chain.command import CommandChain
|
||||
from app.command import Command, CommandChain
|
||||
|
||||
|
||||
def start_frontend():
|
||||
@@ -159,7 +159,7 @@ def start_modules(_: FastAPI):
|
||||
# 启动定时服务
|
||||
Scheduler()
|
||||
# 加载命令
|
||||
CommandChain()
|
||||
Command()
|
||||
# 启动前端服务
|
||||
start_frontend()
|
||||
# 检查认证状态
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import asyncio
|
||||
|
||||
from app.chain.command import CommandChain
|
||||
from app.command import Command
|
||||
from app.core.plugin import PluginManager
|
||||
from app.log import logger
|
||||
from app.scheduler import Scheduler
|
||||
@@ -14,7 +14,7 @@ async def init_plugins_async():
|
||||
loop = asyncio.get_event_loop()
|
||||
plugin_manager = PluginManager()
|
||||
scheduler = Scheduler()
|
||||
command = CommandChain()
|
||||
command = Command()
|
||||
|
||||
sync_result = await execute_task(loop, plugin_manager.sync, "插件同步到本地")
|
||||
resolved_dependencies = await execute_task(loop, plugin_manager.install_plugin_missing_dependencies,
|
||||
|
||||
Reference in New Issue
Block a user