From 2ce3ddb75a31d8601c7f96982629187573ccff88 Mon Sep 17 00:00:00 2001 From: InfinityPacer <160988576+InfinityPacer@users.noreply.github.com> Date: Fri, 27 Sep 2024 04:04:56 +0800 Subject: [PATCH] refactor(module): simplify service instantiation with generics --- app/api/endpoints/message.py | 7 +- app/chain/__init__.py | 8 +- app/chain/mediaserver.py | 11 ++- app/helper/downloader.py | 20 +--- app/helper/mediaserver.py | 20 +--- app/helper/notification.py | 39 +------- app/helper/serviceconfig.py | 65 +++++++++++++ app/modules/__init__.py | 136 +++++++++++++++++++++++---- app/modules/emby/__init__.py | 20 ++-- app/modules/jellyfin/__init__.py | 21 ++--- app/modules/plex/__init__.py | 20 ++-- app/modules/qbittorrent/__init__.py | 18 +--- app/modules/slack/__init__.py | 14 +-- app/modules/synologychat/__init__.py | 14 +-- app/modules/telegram/__init__.py | 14 +-- app/modules/transmission/__init__.py | 21 ++--- app/modules/vocechat/__init__.py | 14 +-- app/modules/webpush/__init__.py | 10 +- app/modules/wechat/__init__.py | 14 +-- 19 files changed, 244 insertions(+), 242 deletions(-) create mode 100644 app/helper/serviceconfig.py diff --git a/app/api/endpoints/message.py b/app/api/endpoints/message.py index a1af1f6c..239b052f 100644 --- a/app/api/endpoints/message.py +++ b/app/api/endpoints/message.py @@ -1,8 +1,7 @@ import json from typing import Union, Any, List -from fastapi import APIRouter, BackgroundTasks, Depends -from fastapi import Request +from fastapi import APIRouter, BackgroundTasks, Depends, Request from pywebpush import WebPushException, webpush from sqlalchemy.orm import Session from starlette.responses import PlainTextResponse @@ -15,7 +14,7 @@ from app.db import get_db from app.db.models import User from app.db.models.message import Message from app.db.user_oper import get_current_active_superuser -from app.helper.notification import NotificationHelper +from app.helper.serviceconfig import ServiceConfigHelper from app.log import logger from app.modules.wechat.WXBizMsgCrypt3 import WXBizMsgCrypt from app.schemas.types import MessageChannel @@ -81,7 +80,7 @@ def wechat_verify(echostr: str, msg_signature: str, timestamp: Union[str, int], """ 微信验证响应 """ - clients = NotificationHelper().get_clients() + clients = ServiceConfigHelper.get_notification_configs() if not clients: return for client in clients: diff --git a/app/chain/__init__.py b/app/chain/__init__.py index 8dce6ae3..038d10b0 100644 --- a/app/chain/__init__.py +++ b/app/chain/__init__.py @@ -10,15 +10,14 @@ from ruamel.yaml import CommentedMap from transmission_rpc import File from app.core.config import settings -from app.core.context import Context -from app.core.context import MediaInfo, TorrentInfo +from app.core.context import Context, MediaInfo, TorrentInfo from app.core.event import EventManager from app.core.meta import MetaBase from app.core.module import ModuleManager from app.db.message_oper import MessageOper from app.db.user_oper import UserOper from app.helper.message import MessageHelper -from app.helper.notification import NotificationHelper +from app.helper.serviceconfig import ServiceConfigHelper from app.log import logger from app.schemas import TransferInfo, TransferTorrent, ExistMediaInfo, DownloadingTorrent, CommingMessage, Notification, \ WebhookEventInfo, TmdbEpisode, MediaPerson, FileItem @@ -39,7 +38,6 @@ class ChainBase(metaclass=ABCMeta): self.eventmanager = EventManager() self.messageoper = MessageOper() self.messagehelper = MessageHelper() - self.notificationhelper = NotificationHelper() self.useroper = UserOper() @staticmethod @@ -484,7 +482,7 @@ class ChainBase(metaclass=ABCMeta): # 没有指定用户ID时,按规则确定发送对象 # 默认发送全体 to_targets = None - notify_action = self.notificationhelper.get_switch(message.mtype) + notify_action = ServiceConfigHelper.get_notification_switch(message.mtype) if notify_action == "admin": # 仅发送管理员 logger.info(f"已设置 {message.mtype} 的消息只发送给管理员") diff --git a/app/chain/mediaserver.py b/app/chain/mediaserver.py index a6766a88..b00d1fe3 100644 --- a/app/chain/mediaserver.py +++ b/app/chain/mediaserver.py @@ -5,7 +5,7 @@ from typing import List, Union, Optional from app import schemas from app.chain import ChainBase from app.db.mediaserver_oper import MediaServerOper -from app.helper.mediaserver import MediaServerHelper +from app.helper.serviceconfig import ServiceConfigHelper from app.log import logger lock = threading.Lock() @@ -19,7 +19,6 @@ class MediaServerChain(ChainBase): def __init__(self): super().__init__() self.dboper = MediaServerOper() - self.mediaserverhelper = MediaServerHelper() def librarys(self, server: str, username: str = None, hidden: bool = False) -> List[schemas.MediaServerLibrary]: """ @@ -27,12 +26,14 @@ class MediaServerChain(ChainBase): """ return self.run_module("mediaserver_librarys", server=server, username=username, hidden=hidden) - def items(self, server: str, library_id: Union[str, int], start_index: int = 0, limit: int = 100) -> List[schemas.MediaServerItem]: + def items(self, server: str, library_id: Union[str, int], start_index: int = 0, limit: int = 100) \ + -> List[schemas.MediaServerItem]: """ 获取媒体服务器所有项目 """ data = [] - data_generator = self.run_module("mediaserver_items", server=server, library_id=library_id, start_index=start_index, limit=limit) + data_generator = self.run_module("mediaserver_items", server=server, library_id=library_id, + start_index=start_index, limit=limit) if data_generator: for item in data_generator: if item: @@ -74,7 +75,7 @@ class MediaServerChain(ChainBase): 同步媒体库所有数据到本地数据库 """ # 设置的媒体服务器 - mediaservers = self.mediaserverhelper.get_mediaservers() + mediaservers = ServiceConfigHelper.get_mediaserver_configs() if not mediaservers: return with lock: diff --git a/app/helper/downloader.py b/app/helper/downloader.py index e7bf14fc..86d3d557 100644 --- a/app/helper/downloader.py +++ b/app/helper/downloader.py @@ -1,23 +1,5 @@ -from typing import List - -from app.db.systemconfig_oper import SystemConfigOper -from app.schemas.system import DownloaderConf -from app.schemas.types import SystemConfigKey - - class DownloaderHelper: """ 下载器帮助类 """ - - def __init__(self): - self.systemconfig = SystemConfigOper() - - def get_downloaders(self) -> List[DownloaderConf]: - """ - 获取下载器 - """ - downloader_confs: List[dict] = self.systemconfig.get(SystemConfigKey.Downloaders) - if not downloader_confs: - return [] - return [DownloaderConf(**conf) for conf in downloader_confs] + pass diff --git a/app/helper/mediaserver.py b/app/helper/mediaserver.py index 0b881d96..a21a1155 100644 --- a/app/helper/mediaserver.py +++ b/app/helper/mediaserver.py @@ -1,23 +1,5 @@ -from typing import List - -from app.db.systemconfig_oper import SystemConfigOper -from app.schemas import MediaServerConf -from app.schemas.types import SystemConfigKey - - class MediaServerHelper: """ 媒体服务器帮助类 """ - - def __init__(self): - self.systemconfig = SystemConfigOper() - - def get_mediaservers(self) -> List[MediaServerConf]: - """ - 获取媒体服务器 - """ - mediaserver_confs: List[dict] = self.systemconfig.get(SystemConfigKey.MediaServers) - if not mediaserver_confs: - return [] - return [MediaServerConf(**conf) for conf in mediaserver_confs] + pass diff --git a/app/helper/notification.py b/app/helper/notification.py index 61b20c17..8d32d452 100644 --- a/app/helper/notification.py +++ b/app/helper/notification.py @@ -1,42 +1,5 @@ -from typing import List, Optional - -from app.db.systemconfig_oper import SystemConfigOper -from app.schemas import NotificationConf, NotificationSwitchConf -from app.schemas.types import SystemConfigKey, NotificationType - - class NotificationHelper: """ 消息通知渠道帮助类 """ - - def __init__(self): - self.systemconfig = SystemConfigOper() - - def get_clients(self) -> List[NotificationConf]: - """ - 获取消息通知渠道 - """ - client_confs: List[dict] = self.systemconfig.get(SystemConfigKey.Notifications) - if not client_confs: - return [] - return [NotificationConf(**conf) for conf in client_confs] - - def get_switchs(self) -> List[NotificationSwitchConf]: - """ - 获取消息通知场景开关 - """ - switchs: List[dict] = self.systemconfig.get(SystemConfigKey.NotificationSwitchs) - if not switchs: - return [] - return [NotificationSwitchConf(**switch) for switch in switchs] - - def get_switch(self, mtype: NotificationType) -> Optional[str]: - """ - 获取消息通知场景开关 - """ - switchs = self.get_switchs() - for switch in switchs: - if switch.type == mtype.value: - return switch.action - return None + pass diff --git a/app/helper/serviceconfig.py b/app/helper/serviceconfig.py new file mode 100644 index 00000000..c65cdc12 --- /dev/null +++ b/app/helper/serviceconfig.py @@ -0,0 +1,65 @@ +from typing import List, Type, Optional + +from app.db.systemconfig_oper import SystemConfigOper +from app.schemas import DownloaderConf, MediaServerConf, NotificationConf, NotificationSwitchConf +from app.schemas.types import SystemConfigKey, NotificationType + + +class ServiceConfigHelper: + """ + 配置帮助类,获取不同类型的服务配置 + """ + + @staticmethod + def get_configs(config_key: SystemConfigKey, conf_type: Type) -> List: + """ + 通用获取配置的方法,根据 config_key 获取相应的配置并返回指定类型的配置列表 + + :param config_key: 系统配置的 key + :param conf_type: 用于实例化配置对象的类类型 + :return: 配置对象列表 + """ + config_data = SystemConfigOper().get(config_key) + if not config_data: + return [] + # 直接使用 conf_type 来实例化配置对象 + return [conf_type(**conf) for conf in config_data] + + @staticmethod + def get_downloader_configs() -> List[DownloaderConf]: + """ + 获取下载器的配置 + """ + return ServiceConfigHelper.get_configs(SystemConfigKey.Downloaders, DownloaderConf) + + @staticmethod + def get_mediaserver_configs() -> List[MediaServerConf]: + """ + 获取媒体服务器的配置 + """ + return ServiceConfigHelper.get_configs(SystemConfigKey.MediaServers, MediaServerConf) + + @staticmethod + def get_notification_configs() -> List[NotificationConf]: + """ + 获取消息通知渠道的配置 + """ + return ServiceConfigHelper.get_configs(SystemConfigKey.Notifications, NotificationConf) + + @staticmethod + def get_notification_switches() -> List[NotificationSwitchConf]: + """ + 获取消息通知场景的开关 + """ + return ServiceConfigHelper.get_configs(SystemConfigKey.NotificationSwitchs, NotificationSwitchConf) + + @staticmethod + def get_notification_switch(mtype: NotificationType) -> Optional[str]: + """ + 获取消息通知场景开关 + """ + switchs = ServiceConfigHelper.get_notification_switches() + for switch in switchs: + if switch.type == mtype.value: + return switch.action + return None diff --git a/app/modules/__init__.py b/app/modules/__init__.py index 8bdc51b6..0dc26334 100644 --- a/app/modules/__init__.py +++ b/app/modules/__init__.py @@ -1,6 +1,7 @@ from abc import abstractmethod, ABCMeta -from typing import Dict, Any, Optional, Generic, Tuple, Union, TypeVar +from typing import Generic, Tuple, Union, TypeVar, Type, Dict, Optional, Callable, Any, List +from app.helper.serviceconfig import ServiceConfigHelper from app.schemas import Notification, MessageChannel, NotificationConf, MediaServerConf, DownloaderConf @@ -49,29 +50,73 @@ class _ModuleBase(metaclass=ABCMeta): pass -# 定义一个泛型 T,用于表示具体的配置类型 +# 定义泛型,用于表示具体的服务类型和配置类型 +TService = TypeVar("TService", bound=object) TConf = TypeVar("TConf") -class ConfManagerBase(Generic[TConf]): +class ServiceBase(Generic[TService, TConf], metaclass=ABCMeta): """ - 通用管理基类,支持配置管理和实例管理 + 抽象服务基类,负责服务的初始化、获取实例和配置管理 """ - _configs: Dict[str, TConf] = {} - _instances: Dict[str, Any] = {} - - def get_instance(self, name: str) -> Optional[Any]: + def __init__(self): """ - 获取实例 (如服务/客户端) + 初始化 ServiceBase 类的实例 + """ + self._configs: Dict[str, TConf] = {} + self._instances: Dict[str, TService] = {} + + def init_service(self, service_name: str, + service_type: Optional[Union[Type[TService], Callable[..., TService]]] = None): + """ + 初始化服务,获取配置并实例化对应服务 + + :param service_name: 服务名称,作为配置匹配的依据 + :param service_type: 服务的类型,可以是类类型(Type[TService])、工厂函数(Callable)或 None 来跳过实例化 + """ + configs = self.get_configs() + if not configs: + return + for conf in configs: + if conf.enabled and conf.type == service_name: + self._configs[conf.name] = conf + if service_type: + # 通过服务类型或工厂函数来创建实例 + if isinstance(service_type, type): + # 如果传入的是类类型,调用构造函数实例化 + self._instances[conf.name] = service_type(**conf.config) + else: + # 如果传入的是工厂函数,直接调用工厂函数 + self._instances[conf.name] = service_type(conf) + + def get_instance(self, name: str) -> Optional[TService]: + """ + 获取服务实例 + + :param name: 实例名称 + :return: 返回对应名称的服务实例,若不存在则返回 None """ if not name: return None return self._instances.get(name) + @abstractmethod + def get_configs(self) -> List[TConf]: + """ + 获取服务的配置列表 + + :return: 返回配置列表 + """ + pass + def get_config(self, name: str, ctype: str = None) -> Optional[TConf]: """ 获取配置,支持类型过滤 + + :param name: 配置名称 + :param ctype: 配置类型,可选,默认不进行类型过滤 + :return: 返回符合条件的配置,若不存在则返回 None """ if not name: return None @@ -81,16 +126,33 @@ class ConfManagerBase(Generic[TConf]): return conf if getattr(conf, "type", None) == ctype else None -class _MessageBase(ConfManagerBase[NotificationConf]): +class _MessageBase(ServiceBase[TService, NotificationConf]): """ - 消息基类,继承了通用的配置和实例管理功能,指定配置类型为 NotificationConf + 消息基类 """ - _channel: MessageChannel = None + def __init__(self): + """ + 初始化消息基类,并设置消息通道 + """ + super().__init__() + self._channel: Optional[MessageChannel] = None + + def get_configs(self) -> List[NotificationConf]: + """ + 获取消息通知渠道的配置 + + :return: 返回消息通知的配置列表 + """ + return ServiceConfigHelper.get_notification_configs() def check_message(self, message: Notification, source: str = None) -> bool: """ - 检查消息渠道及消息类型,如不符合则不处理 + 检查消息渠道及消息类型,判断是否处理消息 + + :param message: 要检查的通知消息 + :param source: 消息来源,可选 + :return: 返回布尔值,表示是否处理该消息 """ # 检查消息渠道 if message.channel and message.channel != self._channel: @@ -108,25 +170,63 @@ class _MessageBase(ConfManagerBase[NotificationConf]): return True -class _DownloaderBase(ConfManagerBase[DownloaderConf]): +class _DownloaderBase(ServiceBase[TService, DownloaderConf]): """ 下载器基类 """ - _default_server: Any = None - _default_server_name: str = None + def __init__(self): + """ + 初始化下载器基类,并设置默认服务器 + """ + super().__init__() + self._default_server: Any = None + self._default_server_name: Optional[str] = None + + def init_service(self, service_name: str, + service_type: Optional[Union[Type[TService], Callable[..., TService]]] = None): + """ + 初始化服务,获取配置并实例化对应服务 + + :param service_name: 服务名称,作为配置匹配的依据 + :param service_type: 服务的类型,可以是类类型(Type[TService])或工厂函数(Callable),用于创建服务实例 + """ + super().init_service(service_name=service_name, service_type=service_type) + if self._configs: + for conf in self._configs.values(): + if conf.default: + self._default_server_name = conf.name + self._default_server = self.get_instance(conf.name) def get_instance(self, name: str = None) -> Optional[Any]: """ 获取实例,name为空时,返回默认实例 + + :param name: 实例名称,可选,默认为 None + :return: 返回指定名称的实例,若 name 为 None 则返回默认实例 """ if name: return self.get_instance(name) return self._default_server + def get_configs(self) -> List[DownloaderConf]: + """ + 获取下载器的配置 -class _MediaServerBase(ConfManagerBase[MediaServerConf]): + :return: 返回下载器配置列表 + """ + return ServiceConfigHelper.get_downloader_configs() + + +class _MediaServerBase(ServiceBase[TService, MediaServerConf]): """ 媒体服务器基类 """ - pass + + def get_configs(self) -> List[MediaServerConf]: + """ + 获取媒体服务器的配置 + + :return: 返回媒体服务器配置列表 + """ + return ServiceConfigHelper.get_mediaserver_configs() diff --git a/app/modules/emby/__init__.py b/app/modules/emby/__init__.py index 734be30f..20432cf9 100644 --- a/app/modules/emby/__init__.py +++ b/app/modules/emby/__init__.py @@ -1,8 +1,7 @@ -from typing import Optional, Tuple, Union, Any, List, Generator, Dict +from typing import Optional, Tuple, Union, Any, List, Generator from app import schemas from app.core.context import MediaInfo -from app.helper.mediaserver import MediaServerHelper from app.log import logger from app.modules import _ModuleBase, _MediaServerBase from app.modules.emby.emby import Emby @@ -10,22 +9,14 @@ from app.schemas import MediaServerConf from app.schemas.types import MediaType -class EmbyModule(_ModuleBase, _MediaServerBase): +class EmbyModule(_ModuleBase, _MediaServerBase[Emby]): def init_module(self) -> None: """ 初始化模块 """ - # 读取媒体服务器配置 - self._instances: Dict[str, Emby] = {} - self._configs: Dict[str, MediaServerConf] = {} - mediaservers = MediaServerHelper().get_mediaservers() - if not mediaservers: - return - for server in mediaservers: - if server.type == "emby" and server.enabled: - self._configs[server.name] = server - self._instances[server.name] = Emby(**server.config, sync_libraries=server.sync_libraries) + super().init_service(service_name=Emby.__name__.lower(), + service_type=lambda conf: Emby(**conf.config, sync_libraries=conf.sync_libraries)) @staticmethod def get_name() -> str: @@ -182,7 +173,8 @@ class EmbyModule(_ModuleBase, _MediaServerBase): return server.get_librarys(username=username, hidden=hidden) return None - def mediaserver_items(self, server: str, library_id: str, start_index: int = 0, limit: int = 100) -> Optional[Generator]: + def mediaserver_items(self, server: str, library_id: str, start_index: int = 0, limit: int = 100) \ + -> Optional[Generator]: """ 媒体库项目列表 """ diff --git a/app/modules/jellyfin/__init__.py b/app/modules/jellyfin/__init__.py index e0ec0e71..c03bb8d3 100644 --- a/app/modules/jellyfin/__init__.py +++ b/app/modules/jellyfin/__init__.py @@ -1,8 +1,7 @@ -from typing import Optional, Tuple, Union, Any, List, Generator, Dict +from typing import Optional, Tuple, Union, Any, List, Generator from app import schemas from app.core.context import MediaInfo -from app.helper.mediaserver import MediaServerHelper from app.log import logger from app.modules import _ModuleBase, _MediaServerBase from app.modules.jellyfin.jellyfin import Jellyfin @@ -10,22 +9,14 @@ from app.schemas import MediaServerConf from app.schemas.types import MediaType -class JellyfinModule(_ModuleBase, _MediaServerBase): +class JellyfinModule(_ModuleBase, _MediaServerBase[Jellyfin]): def init_module(self) -> None: """ 初始化模块 """ - # 读取媒体服务器配置 - self._instances: Dict[str, Jellyfin] = {} - self._configs: Dict[str, MediaServerConf] = {} - mediaservers = MediaServerHelper().get_mediaservers() - if not mediaservers: - return - for server in mediaservers: - if server.type == "jellyfin" and server.enabled: - self._configs[server.name] = server - self._instances[server.name] = Jellyfin(**server.config, sync_libraries=server.sync_libraries) + super().init_service(service_name=Jellyfin.__name__.lower(), + service_type=lambda conf: Jellyfin(**conf.config, sync_libraries=conf.sync_libraries)) @staticmethod def get_name() -> str: @@ -180,8 +171,8 @@ class JellyfinModule(_ModuleBase, _MediaServerBase): return server.get_librarys(username=username, hidden=hidden) return None - def mediaserver_items(self, server: str, library_id: str, start_index: int = 0, limit: int = 100) -> Optional[ - Generator]: + def mediaserver_items(self, server: str, library_id: str, start_index: int = 0, limit: int = 100) \ + -> Optional[Generator]: """ 媒体库项目列表 """ diff --git a/app/modules/plex/__init__.py b/app/modules/plex/__init__.py index bda5dc99..afe108d7 100644 --- a/app/modules/plex/__init__.py +++ b/app/modules/plex/__init__.py @@ -1,8 +1,7 @@ -from typing import Optional, Tuple, Union, Any, List, Generator, Dict +from typing import Optional, Tuple, Union, Any, List, Generator from app import schemas from app.core.context import MediaInfo -from app.helper.mediaserver import MediaServerHelper from app.log import logger from app.modules import _ModuleBase, _MediaServerBase from app.modules.plex.plex import Plex @@ -10,22 +9,14 @@ from app.schemas import MediaServerConf from app.schemas.types import MediaType -class PlexModule(_ModuleBase, _MediaServerBase): +class PlexModule(_ModuleBase, _MediaServerBase[Plex]): def init_module(self) -> None: """ 初始化模块 """ - # 读取媒体服务器配置 - self._instances: Dict[str, Plex] = {} - self._configs: Dict[str, MediaServerConf] = {} - mediaservers = MediaServerHelper().get_mediaservers() - if not mediaservers: - return - for server in mediaservers: - if server.type == "plex" and server.enabled: - self._configs[server.name] = server - self._instances[server.name] = Plex(**server.config, sync_libraries=server.sync_libraries) + super().init_service(service_name=Plex.__name__.lower(), + service_type=lambda conf: Plex(**conf.config, sync_libraries=conf.sync_libraries)) @staticmethod def get_name() -> str: @@ -168,7 +159,8 @@ class PlexModule(_ModuleBase, _MediaServerBase): return server.get_librarys(hidden) return None - def mediaserver_items(self, server: str, library_id: str, start_index: int = 0, limit: int = 100) -> Optional[Generator]: + def mediaserver_items(self, server: str, library_id: str, start_index: int = 0, limit: int = 100) \ + -> Optional[Generator]: """ 媒体库项目列表 """ diff --git a/app/modules/qbittorrent/__init__.py b/app/modules/qbittorrent/__init__.py index ab3707a5..7724cec1 100644 --- a/app/modules/qbittorrent/__init__.py +++ b/app/modules/qbittorrent/__init__.py @@ -1,6 +1,6 @@ import shutil from pathlib import Path -from typing import Set, Tuple, Optional, Union, List, Dict +from typing import Set, Tuple, Optional, Union, List from qbittorrentapi import TorrentFilesList from torrentool.torrent import Torrent @@ -8,7 +8,6 @@ from torrentool.torrent import Torrent from app import schemas from app.core.config import settings from app.core.metainfo import MetaInfo -from app.helper.downloader import DownloaderHelper from app.log import logger from app.modules import _ModuleBase, _DownloaderBase from app.modules.qbittorrent.qbittorrent import Qbittorrent @@ -18,23 +17,14 @@ from app.utils.string import StringUtils from app.utils.system import SystemUtils -class QbittorrentModule(_ModuleBase, _DownloaderBase): +class QbittorrentModule(_ModuleBase, _DownloaderBase[Qbittorrent]): def init_module(self) -> None: """ 初始化模块 """ - # 读取下载器配置 - self._instances: Dict[str, Qbittorrent] = {} - configs = DownloaderHelper().get_downloader_conf() - if not configs: - return - for conf in configs: - if conf.type == "qbittorrent" and conf.enabled: - self._instances[conf.name] = Qbittorrent(**conf.config) - if conf.default: - self._default_server_name = conf.name - self._default_server = self._instances[conf.name] + super().init_service(service_name=Qbittorrent.__name__.lower(), + service_type=Qbittorrent) @staticmethod def get_name() -> str: diff --git a/app/modules/slack/__init__.py b/app/modules/slack/__init__.py index e9f9c544..0503273b 100644 --- a/app/modules/slack/__init__.py +++ b/app/modules/slack/__init__.py @@ -4,28 +4,20 @@ from typing import Optional, Union, List, Tuple, Any from app.core.config import settings from app.core.context import MediaInfo, Context -from app.helper.notification import NotificationHelper from app.log import logger from app.modules import _ModuleBase, _MessageBase from app.modules.slack.slack import Slack from app.schemas import MessageChannel, CommingMessage, Notification -class SlackModule(_ModuleBase, _MessageBase): +class SlackModule(_ModuleBase, _MessageBase[Slack]): def init_module(self) -> None: """ 初始化模块 """ - clients = NotificationHelper().get_clients() - if not clients: - return - self._configs = {} - self._instances = {} - for client in clients: - if client.type == "slack" and client.enabled: - self._configs[client.name] = client - self._instances[client.name] = Slack(**client.config, name=client.name) + super().init_service(service_name=Slack.__name__.lower(), + service_type=Slack) @staticmethod def get_name() -> str: diff --git a/app/modules/synologychat/__init__.py b/app/modules/synologychat/__init__.py index 588eba0b..7180c430 100644 --- a/app/modules/synologychat/__init__.py +++ b/app/modules/synologychat/__init__.py @@ -1,28 +1,20 @@ from typing import Optional, Union, List, Tuple, Any from app.core.context import MediaInfo, Context -from app.helper.notification import NotificationHelper from app.log import logger from app.modules import _ModuleBase, _MessageBase from app.modules.synologychat.synologychat import SynologyChat from app.schemas import MessageChannel, CommingMessage, Notification -class SynologyChatModule(_ModuleBase, _MessageBase): +class SynologyChatModule(_ModuleBase, _MessageBase[SynologyChat]): def init_module(self) -> None: """ 初始化模块 """ - clients = NotificationHelper().get_clients() - if not clients: - return - self._configs = {} - self._instances = {} - for client in clients: - if client.type == "synologychat" and client.enabled: - self._configs[client.name] = client - self._instances[client.name] = SynologyChat(**client.config) + super().init_service(service_name=SynologyChat.__name__.lower(), + service_type=SynologyChat) @staticmethod def get_name() -> str: diff --git a/app/modules/telegram/__init__.py b/app/modules/telegram/__init__.py index 4696302a..97c5351d 100644 --- a/app/modules/telegram/__init__.py +++ b/app/modules/telegram/__init__.py @@ -3,28 +3,20 @@ from typing import Optional, Union, List, Tuple, Any, Dict from app.core.config import settings from app.core.context import MediaInfo, Context -from app.helper.notification import NotificationHelper from app.log import logger from app.modules import _ModuleBase, _MessageBase from app.modules.telegram.telegram import Telegram from app.schemas import MessageChannel, CommingMessage, Notification -class TelegramModule(_ModuleBase, _MessageBase): +class TelegramModule(_ModuleBase, _MessageBase[Telegram]): def init_module(self) -> None: """ 初始化模块 """ - clients = NotificationHelper().get_clients() - if not clients: - return - self._configs = {} - self._instances = {} - for client in clients: - if client.type == "telegram" and client.enabled: - self._configs[client.name] = client - self._instances[client.name] = Telegram(**client.config, name=client.name) + super().init_service(service_name=Telegram.__name__.lower(), + service_type=Telegram) @staticmethod def get_name() -> str: diff --git a/app/modules/transmission/__init__.py b/app/modules/transmission/__init__.py index 53f02fb5..dfd865bd 100644 --- a/app/modules/transmission/__init__.py +++ b/app/modules/transmission/__init__.py @@ -1,6 +1,6 @@ import shutil from pathlib import Path -from typing import Set, Tuple, Optional, Union, List, Dict +from typing import Set, Tuple, Optional, Union, List from torrentool.torrent import Torrent from transmission_rpc import File @@ -8,7 +8,6 @@ from transmission_rpc import File from app import schemas from app.core.config import settings from app.core.metainfo import MetaInfo -from app.helper.downloader import DownloaderHelper from app.log import logger from app.modules import _ModuleBase, _DownloaderBase from app.modules.transmission.transmission import Transmission @@ -18,20 +17,14 @@ from app.utils.string import StringUtils from app.utils.system import SystemUtils -class TransmissionModule(_ModuleBase, _DownloaderBase): +class TransmissionModule(_ModuleBase, _DownloaderBase[Transmission]): def init_module(self) -> None: - # 读取下载器配置 - self._instances: Dict[str, Transmission] = {} - configs = DownloaderHelper().get_downloader_conf() - if not configs: - return - for conf in configs: - if conf.type == "transmission" and conf.enabled: - self._instances[conf.name] = Transmission(**conf.config) - if conf.default: - self._default_server_name = conf.name - self._default_server = self._instances[conf.name] + """ + 初始化模块 + """ + super().init_service(service_name=Transmission.__name__.lower(), + service_type=Transmission) @staticmethod def get_name() -> str: diff --git a/app/modules/vocechat/__init__.py b/app/modules/vocechat/__init__.py index e2d8c5be..bc802f41 100644 --- a/app/modules/vocechat/__init__.py +++ b/app/modules/vocechat/__init__.py @@ -3,28 +3,20 @@ from typing import Optional, Union, List, Tuple, Any, Dict from app.core.config import settings from app.core.context import Context, MediaInfo -from app.helper.notification import NotificationHelper from app.log import logger from app.modules import _ModuleBase, _MessageBase from app.modules.vocechat.vocechat import VoceChat from app.schemas import MessageChannel, CommingMessage, Notification -class VoceChatModule(_ModuleBase, _MessageBase): +class VoceChatModule(_ModuleBase, _MessageBase[VoceChat]): def init_module(self) -> None: """ 初始化模块 """ - clients = NotificationHelper().get_clients() - if not clients: - return - self._configs = {} - self._instances = {} - for client in clients: - if client.type == "vocechat" and client.enabled: - self._configs[client.name] = client - self._instances[client.name] = VoceChat(**client.config) + super().init_service(service_name=VoceChat.__name__.lower(), + service_type=VoceChat) @staticmethod def get_name() -> str: diff --git a/app/modules/webpush/__init__.py b/app/modules/webpush/__init__.py index 72d90eb9..89f45f3c 100644 --- a/app/modules/webpush/__init__.py +++ b/app/modules/webpush/__init__.py @@ -4,24 +4,18 @@ from typing import Union, Tuple from pywebpush import webpush, WebPushException from app.core.config import global_vars, settings -from app.helper.notification import NotificationHelper from app.log import logger from app.modules import _ModuleBase, _MessageBase from app.schemas import Notification class WebPushModule(_ModuleBase, _MessageBase): + def init_module(self) -> None: """ 初始化模块 """ - clients = NotificationHelper().get_clients() - if not clients: - return - self._configs = {} - for client in clients: - if client.type == "webpush" and client.enabled: - self._configs[client.name] = client + super().init_service(service_name=self.get_name().lower()) @staticmethod def get_name() -> str: diff --git a/app/modules/wechat/__init__.py b/app/modules/wechat/__init__.py index bbd15203..5994e46a 100644 --- a/app/modules/wechat/__init__.py +++ b/app/modules/wechat/__init__.py @@ -2,7 +2,6 @@ import xml.dom.minidom from typing import Optional, Union, List, Tuple, Any, Dict from app.core.context import Context, MediaInfo -from app.helper.notification import NotificationHelper from app.log import logger from app.modules import _ModuleBase, _MessageBase from app.modules.wechat.WXBizMsgCrypt3 import WXBizMsgCrypt @@ -11,21 +10,14 @@ from app.schemas import MessageChannel, CommingMessage, Notification from app.utils.dom import DomUtils -class WechatModule(_ModuleBase, _MessageBase): +class WechatModule(_ModuleBase, _MessageBase[WeChat]): def init_module(self) -> None: """ 初始化模块 """ - clients = NotificationHelper().get_clients() - if not clients: - return - self._configs = {} - self._instances = {} - for client in clients: - if client.type == "wechat" and client.enabled: - self._configs[client.name] = client - self._instances[client.name] = WeChat(**client.config) + super().init_service(service_name=WeChat.__name__.lower(), + service_type=WeChat) @staticmethod def get_name() -> str: