diff --git a/app/api/endpoints/message.py b/app/api/endpoints/message.py index a1af1f6c..239b052f 100644 --- a/app/api/endpoints/message.py +++ b/app/api/endpoints/message.py @@ -1,8 +1,7 @@ import json from typing import Union, Any, List -from fastapi import APIRouter, BackgroundTasks, Depends -from fastapi import Request +from fastapi import APIRouter, BackgroundTasks, Depends, Request from pywebpush import WebPushException, webpush from sqlalchemy.orm import Session from starlette.responses import PlainTextResponse @@ -15,7 +14,7 @@ from app.db import get_db from app.db.models import User from app.db.models.message import Message from app.db.user_oper import get_current_active_superuser -from app.helper.notification import NotificationHelper +from app.helper.serviceconfig import ServiceConfigHelper from app.log import logger from app.modules.wechat.WXBizMsgCrypt3 import WXBizMsgCrypt from app.schemas.types import MessageChannel @@ -81,7 +80,7 @@ def wechat_verify(echostr: str, msg_signature: str, timestamp: Union[str, int], """ 微信验证响应 """ - clients = NotificationHelper().get_clients() + clients = ServiceConfigHelper.get_notification_configs() if not clients: return for client in clients: diff --git a/app/chain/__init__.py b/app/chain/__init__.py index 8dce6ae3..038d10b0 100644 --- a/app/chain/__init__.py +++ b/app/chain/__init__.py @@ -10,15 +10,14 @@ from ruamel.yaml import CommentedMap from transmission_rpc import File from app.core.config import settings -from app.core.context import Context -from app.core.context import MediaInfo, TorrentInfo +from app.core.context import Context, MediaInfo, TorrentInfo from app.core.event import EventManager from app.core.meta import MetaBase from app.core.module import ModuleManager from app.db.message_oper import MessageOper from app.db.user_oper import UserOper from app.helper.message import MessageHelper -from app.helper.notification import NotificationHelper +from app.helper.serviceconfig import ServiceConfigHelper from app.log import logger from app.schemas import TransferInfo, TransferTorrent, ExistMediaInfo, DownloadingTorrent, CommingMessage, Notification, \ WebhookEventInfo, TmdbEpisode, MediaPerson, FileItem @@ -39,7 +38,6 @@ class ChainBase(metaclass=ABCMeta): self.eventmanager = EventManager() self.messageoper = MessageOper() self.messagehelper = MessageHelper() - self.notificationhelper = NotificationHelper() self.useroper = UserOper() @staticmethod @@ -484,7 +482,7 @@ class ChainBase(metaclass=ABCMeta): # 没有指定用户ID时,按规则确定发送对象 # 默认发送全体 to_targets = None - notify_action = self.notificationhelper.get_switch(message.mtype) + notify_action = ServiceConfigHelper.get_notification_switch(message.mtype) if notify_action == "admin": # 仅发送管理员 logger.info(f"已设置 {message.mtype} 的消息只发送给管理员") diff --git a/app/chain/mediaserver.py b/app/chain/mediaserver.py index a6766a88..b00d1fe3 100644 --- a/app/chain/mediaserver.py +++ b/app/chain/mediaserver.py @@ -5,7 +5,7 @@ from typing import List, Union, Optional from app import schemas from app.chain import ChainBase from app.db.mediaserver_oper import MediaServerOper -from app.helper.mediaserver import MediaServerHelper +from app.helper.serviceconfig import ServiceConfigHelper from app.log import logger lock = threading.Lock() @@ -19,7 +19,6 @@ class MediaServerChain(ChainBase): def __init__(self): super().__init__() self.dboper = MediaServerOper() - self.mediaserverhelper = MediaServerHelper() def librarys(self, server: str, username: str = None, hidden: bool = False) -> List[schemas.MediaServerLibrary]: """ @@ -27,12 +26,14 @@ class MediaServerChain(ChainBase): """ return self.run_module("mediaserver_librarys", server=server, username=username, hidden=hidden) - def items(self, server: str, library_id: Union[str, int], start_index: int = 0, limit: int = 100) -> List[schemas.MediaServerItem]: + def items(self, server: str, library_id: Union[str, int], start_index: int = 0, limit: int = 100) \ + -> List[schemas.MediaServerItem]: """ 获取媒体服务器所有项目 """ data = [] - data_generator = self.run_module("mediaserver_items", server=server, library_id=library_id, start_index=start_index, limit=limit) + data_generator = self.run_module("mediaserver_items", server=server, library_id=library_id, + start_index=start_index, limit=limit) if data_generator: for item in data_generator: if item: @@ -74,7 +75,7 @@ class MediaServerChain(ChainBase): 同步媒体库所有数据到本地数据库 """ # 设置的媒体服务器 - mediaservers = self.mediaserverhelper.get_mediaservers() + mediaservers = ServiceConfigHelper.get_mediaserver_configs() if not mediaservers: return with lock: diff --git a/app/helper/downloader.py b/app/helper/downloader.py index e7bf14fc..86d3d557 100644 --- a/app/helper/downloader.py +++ b/app/helper/downloader.py @@ -1,23 +1,5 @@ -from typing import List - -from app.db.systemconfig_oper import SystemConfigOper -from app.schemas.system import DownloaderConf -from app.schemas.types import SystemConfigKey - - class DownloaderHelper: """ 下载器帮助类 """ - - def __init__(self): - self.systemconfig = SystemConfigOper() - - def get_downloaders(self) -> List[DownloaderConf]: - """ - 获取下载器 - """ - downloader_confs: List[dict] = self.systemconfig.get(SystemConfigKey.Downloaders) - if not downloader_confs: - return [] - return [DownloaderConf(**conf) for conf in downloader_confs] + pass diff --git a/app/helper/mediaserver.py b/app/helper/mediaserver.py index 0b881d96..a21a1155 100644 --- a/app/helper/mediaserver.py +++ b/app/helper/mediaserver.py @@ -1,23 +1,5 @@ -from typing import List - -from app.db.systemconfig_oper import SystemConfigOper -from app.schemas import MediaServerConf -from app.schemas.types import SystemConfigKey - - class MediaServerHelper: """ 媒体服务器帮助类 """ - - def __init__(self): - self.systemconfig = SystemConfigOper() - - def get_mediaservers(self) -> List[MediaServerConf]: - """ - 获取媒体服务器 - """ - mediaserver_confs: List[dict] = self.systemconfig.get(SystemConfigKey.MediaServers) - if not mediaserver_confs: - return [] - return [MediaServerConf(**conf) for conf in mediaserver_confs] + pass diff --git a/app/helper/notification.py b/app/helper/notification.py index 61b20c17..8d32d452 100644 --- a/app/helper/notification.py +++ b/app/helper/notification.py @@ -1,42 +1,5 @@ -from typing import List, Optional - -from app.db.systemconfig_oper import SystemConfigOper -from app.schemas import NotificationConf, NotificationSwitchConf -from app.schemas.types import SystemConfigKey, NotificationType - - class NotificationHelper: """ 消息通知渠道帮助类 """ - - def __init__(self): - self.systemconfig = SystemConfigOper() - - def get_clients(self) -> List[NotificationConf]: - """ - 获取消息通知渠道 - """ - client_confs: List[dict] = self.systemconfig.get(SystemConfigKey.Notifications) - if not client_confs: - return [] - return [NotificationConf(**conf) for conf in client_confs] - - def get_switchs(self) -> List[NotificationSwitchConf]: - """ - 获取消息通知场景开关 - """ - switchs: List[dict] = self.systemconfig.get(SystemConfigKey.NotificationSwitchs) - if not switchs: - return [] - return [NotificationSwitchConf(**switch) for switch in switchs] - - def get_switch(self, mtype: NotificationType) -> Optional[str]: - """ - 获取消息通知场景开关 - """ - switchs = self.get_switchs() - for switch in switchs: - if switch.type == mtype.value: - return switch.action - return None + pass diff --git a/app/helper/serviceconfig.py b/app/helper/serviceconfig.py new file mode 100644 index 00000000..c65cdc12 --- /dev/null +++ b/app/helper/serviceconfig.py @@ -0,0 +1,65 @@ +from typing import List, Type, Optional + +from app.db.systemconfig_oper import SystemConfigOper +from app.schemas import DownloaderConf, MediaServerConf, NotificationConf, NotificationSwitchConf +from app.schemas.types import SystemConfigKey, NotificationType + + +class ServiceConfigHelper: + """ + 配置帮助类,获取不同类型的服务配置 + """ + + @staticmethod + def get_configs(config_key: SystemConfigKey, conf_type: Type) -> List: + """ + 通用获取配置的方法,根据 config_key 获取相应的配置并返回指定类型的配置列表 + + :param config_key: 系统配置的 key + :param conf_type: 用于实例化配置对象的类类型 + :return: 配置对象列表 + """ + config_data = SystemConfigOper().get(config_key) + if not config_data: + return [] + # 直接使用 conf_type 来实例化配置对象 + return [conf_type(**conf) for conf in config_data] + + @staticmethod + def get_downloader_configs() -> List[DownloaderConf]: + """ + 获取下载器的配置 + """ + return ServiceConfigHelper.get_configs(SystemConfigKey.Downloaders, DownloaderConf) + + @staticmethod + def get_mediaserver_configs() -> List[MediaServerConf]: + """ + 获取媒体服务器的配置 + """ + return ServiceConfigHelper.get_configs(SystemConfigKey.MediaServers, MediaServerConf) + + @staticmethod + def get_notification_configs() -> List[NotificationConf]: + """ + 获取消息通知渠道的配置 + """ + return ServiceConfigHelper.get_configs(SystemConfigKey.Notifications, NotificationConf) + + @staticmethod + def get_notification_switches() -> List[NotificationSwitchConf]: + """ + 获取消息通知场景的开关 + """ + return ServiceConfigHelper.get_configs(SystemConfigKey.NotificationSwitchs, NotificationSwitchConf) + + @staticmethod + def get_notification_switch(mtype: NotificationType) -> Optional[str]: + """ + 获取消息通知场景开关 + """ + switchs = ServiceConfigHelper.get_notification_switches() + for switch in switchs: + if switch.type == mtype.value: + return switch.action + return None diff --git a/app/modules/__init__.py b/app/modules/__init__.py index 6b723e46..3a54d817 100644 --- a/app/modules/__init__.py +++ b/app/modules/__init__.py @@ -1,7 +1,8 @@ from abc import abstractmethod, ABCMeta -from typing import Tuple, Union, Dict, Any, Optional +from typing import Generic, Tuple, Union, TypeVar, Type, Dict, Optional, Callable, Any, List -from app.schemas import Notification, MessageChannel, NotificationConf, MediaServerConf +from app.helper.serviceconfig import ServiceConfigHelper +from app.schemas import Notification, MessageChannel, NotificationConf, MediaServerConf, DownloaderConf class _ModuleBase(metaclass=ABCMeta): @@ -49,37 +50,109 @@ class _ModuleBase(metaclass=ABCMeta): pass -class _MessageBase: +# 定义泛型,用于表示具体的服务类型和配置类型 +TService = TypeVar("TService", bound=object) +TConf = TypeVar("TConf") + + +class ServiceBase(Generic[TService, TConf], metaclass=ABCMeta): """ - 消息基类 + 抽象服务基类,负责服务的初始化、获取实例和配置管理 """ - _channel: MessageChannel = None - _configs: Dict[str, NotificationConf] = {} - _clients: Dict[str, Any] = {} - - def get_client(self, name: str) -> Optional[Any]: + def __init__(self): """ - 获取客户端 + 初始化 ServiceBase 类的实例 + """ + self._configs: Dict[str, TConf] = {} + self._instances: Dict[str, TService] = {} + + def init_service(self, service_name: str, + service_type: Optional[Union[Type[TService], Callable[..., TService]]] = None): + """ + 初始化服务,获取配置并实例化对应服务 + + :param service_name: 服务名称,作为配置匹配的依据 + :param service_type: 服务的类型,可以是类类型(Type[TService])、工厂函数(Callable)或 None 来跳过实例化 + """ + configs = self.get_configs() + if not configs: + return + for conf in configs: + if conf.enabled and conf.type == service_name: + self._configs[conf.name] = conf + if service_type: + # 通过服务类型或工厂函数来创建实例 + if isinstance(service_type, type): + # 如果传入的是类类型,调用构造函数实例化 + self._instances[conf.name] = service_type(**conf.config) + else: + # 如果传入的是工厂函数,直接调用工厂函数 + self._instances[conf.name] = service_type(conf) + + def get_instance(self, name: str) -> Optional[TService]: + """ + 获取服务实例 + + :param name: 实例名称 + :return: 返回对应名称的服务实例,若不存在则返回 None """ if not name: return None - return self._clients.get(name) + return self._instances.get(name) - def get_config(self, name: str, ctype: str = None) -> Optional[NotificationConf]: + @abstractmethod + def get_configs(self) -> List[TConf]: """ - 获取配置 + 获取服务的配置列表 + + :return: 返回配置列表 + """ + pass + + def get_config(self, name: str, ctype: str = None) -> Optional[TConf]: + """ + 获取配置,支持类型过滤 + + :param name: 配置名称 + :param ctype: 配置类型,可选,默认不进行类型过滤 + :return: 返回符合条件的配置,若不存在则返回 None """ if not name: return None conf = self._configs.get(name) if not ctype: return conf - return conf if conf.type == ctype else None + return conf if getattr(conf, "type", None) == ctype else None - def checkMessage(self, message: Notification, source: str = None) -> bool: + +class _MessageBase(ServiceBase[TService, NotificationConf]): + """ + 消息基类 + """ + + def __init__(self): """ - 检查消息渠道及消息类型,如不符合则不处理 + 初始化消息基类,并设置消息通道 + """ + super().__init__() + self._channel: Optional[MessageChannel] = None + + def get_configs(self) -> List[NotificationConf]: + """ + 获取消息通知渠道的配置 + + :return: 返回消息通知的配置列表 + """ + return ServiceConfigHelper.get_notification_configs() + + def check_message(self, message: Notification, source: str = None) -> bool: + """ + 检查消息渠道及消息类型,判断是否处理消息 + + :param message: 要检查的通知消息 + :param source: 消息来源,可选 + :return: 返回布尔值,表示是否处理该消息 """ # 检查消息渠道 if message.channel and message.channel != self._channel: @@ -97,45 +170,63 @@ class _MessageBase: return True -class _DownloaderBase: +class _DownloaderBase(ServiceBase[TService, DownloaderConf]): """ 下载器基类 """ - _servers: Dict[str, Any] = {} - _default_server: Any = None - _default_server_name: str = None - - def get_server(self, name: str = None) -> Optional[Any]: + def __init__(self): """ - 获取服务器,name为空则返回默认服务器 + 初始化下载器基类,并设置默认服务器 + """ + super().__init__() + self._default_server: Any = None + self._default_server_name: Optional[str] = None + + def init_service(self, service_name: str, + service_type: Optional[Union[Type[TService], Callable[..., TService]]] = None): + """ + 初始化服务,获取配置并实例化对应服务 + + :param service_name: 服务名称,作为配置匹配的依据 + :param service_type: 服务的类型,可以是类类型(Type[TService])或工厂函数(Callable),用于创建服务实例 + """ + super().init_service(service_name=service_name, service_type=service_type) + if self._configs: + for conf in self._configs.values(): + if conf.default: + self._default_server_name = conf.name + self._default_server = self.get_instance(conf.name) + + def get_instance(self, name: str = None) -> Optional[Any]: + """ + 获取实例,name为空时,返回默认实例 + + :param name: 实例名称,可选,默认为 None + :return: 返回指定名称的实例,若 name 为 None 则返回默认实例 """ if name: - return self._servers.get(name) + return self._instances.get(name) return self._default_server + def get_configs(self) -> List[DownloaderConf]: + """ + 获取下载器的配置 -class _MediaServerBase: + :return: 返回下载器配置列表 + """ + return ServiceConfigHelper.get_downloader_configs() + + +class _MediaServerBase(ServiceBase[TService, MediaServerConf]): """ 媒体服务器基类 """ - _servers: Dict[str, Any] = {} - _configs: Dict[str, MediaServerConf] = {} + def get_configs(self) -> List[MediaServerConf]: + """ + 获取媒体服务器的配置 - def get_server(self, name: str) -> Optional[Any]: + :return: 返回媒体服务器配置列表 """ - 获取Plex服务器 - """ - return self._servers.get(name) - - def get_config(self, name: str, mtype: str = None) -> Optional[MediaServerConf]: - """ - 获取配置 - """ - if not name: - return None - conf = self._configs.get(name) - if not mtype: - return conf - return conf if conf.type == mtype else None + return ServiceConfigHelper.get_mediaserver_configs() diff --git a/app/modules/emby/__init__.py b/app/modules/emby/__init__.py index d0b345fa..20432cf9 100644 --- a/app/modules/emby/__init__.py +++ b/app/modules/emby/__init__.py @@ -1,8 +1,7 @@ -from typing import Optional, Tuple, Union, Any, List, Generator, Dict +from typing import Optional, Tuple, Union, Any, List, Generator from app import schemas from app.core.context import MediaInfo -from app.helper.mediaserver import MediaServerHelper from app.log import logger from app.modules import _ModuleBase, _MediaServerBase from app.modules.emby.emby import Emby @@ -10,22 +9,14 @@ from app.schemas import MediaServerConf from app.schemas.types import MediaType -class EmbyModule(_ModuleBase, _MediaServerBase): +class EmbyModule(_ModuleBase, _MediaServerBase[Emby]): def init_module(self) -> None: """ 初始化模块 """ - # 读取媒体服务器配置 - self._servers: Dict[str, Emby] = {} - self._configs: Dict[str, MediaServerConf] = {} - mediaservers = MediaServerHelper().get_mediaservers() - if not mediaservers: - return - for server in mediaservers: - if server.type == "emby" and server.enabled: - self._configs[server.name] = server - self._servers[server.name] = Emby(**server.config, sync_libraries=server.sync_libraries) + super().init_service(service_name=Emby.__name__.lower(), + service_type=lambda conf: Emby(**conf.config, sync_libraries=conf.sync_libraries)) @staticmethod def get_name() -> str: @@ -38,9 +29,9 @@ class EmbyModule(_ModuleBase, _MediaServerBase): """ 测试模块连接性 """ - if not self._servers: + if not self._instances: return None - for name, server in self._servers.items(): + for name, server in self._instances.items(): if server.is_inactive(): server.reconnect() if not server.get_user(): @@ -55,7 +46,7 @@ class EmbyModule(_ModuleBase, _MediaServerBase): 定时任务,每10分钟调用一次 """ # 定时重连 - for name, server in self._servers.items(): + for name, server in self._instances.items(): if server.is_inactive(): logger.info(f"Emby服务器 {name} 连接断开,尝试重连 ...") server.reconnect() @@ -68,7 +59,7 @@ class EmbyModule(_ModuleBase, _MediaServerBase): :return: token or None """ # Emby认证 - for server in self._servers.values(): + for server in self._instances.values(): result = server.authenticate(name, password) if result: return result @@ -87,7 +78,7 @@ class EmbyModule(_ModuleBase, _MediaServerBase): server_config: MediaServerConf = self.get_config(source, 'emby') if not server_config: return None - server: Emby = self.get_server(source) + server: Emby = self.get_instance(source) if not server: return None return server.get_webhook_message(form, args) @@ -95,7 +86,7 @@ class EmbyModule(_ModuleBase, _MediaServerBase): for conf in self._configs.values(): if conf.type != "emby": continue - server = self.get_server(conf.name) + server = self.get_instance(conf.name) if server: result = server.get_webhook_message(form, args) if result: @@ -109,7 +100,7 @@ class EmbyModule(_ModuleBase, _MediaServerBase): :param itemid: 媒体服务器ItemID :return: 如不存在返回None,存在时返回信息,包括每季已存在所有集{type: movie/tv, seasons: {season: [episodes]}} """ - for name, server in self._servers.items(): + for name, server in self._instances.items(): if mediainfo.type == MediaType.MOVIE: if itemid: movie = server.get_iteminfo(itemid) @@ -156,12 +147,12 @@ class EmbyModule(_ModuleBase, _MediaServerBase): 媒体数量统计 """ if server: - server: Emby = self.get_server(server) + server: Emby = self.get_instance(server) if not server: return None servers = [server] else: - servers = self._servers.values() + servers = self._instances.values() media_statistics = [] for server in servers: media_statistic = server.get_medias_count() @@ -177,16 +168,17 @@ class EmbyModule(_ModuleBase, _MediaServerBase): """ 媒体库列表 """ - server: Emby = self.get_server(server) + server: Emby = self.get_instance(server) if server: return server.get_librarys(username=username, hidden=hidden) return None - def mediaserver_items(self, server: str, library_id: str, start_index: int = 0, limit: int = 100) -> Optional[Generator]: + def mediaserver_items(self, server: str, library_id: str, start_index: int = 0, limit: int = 100) \ + -> Optional[Generator]: """ 媒体库项目列表 """ - server: Emby = self.get_server(server) + server: Emby = self.get_instance(server) if server: return server.get_items(library_id, start_index, limit) return None @@ -195,7 +187,7 @@ class EmbyModule(_ModuleBase, _MediaServerBase): """ 媒体库项目详情 """ - server: Emby = self.get_server(server) + server: Emby = self.get_instance(server) if server: return server.get_iteminfo(item_id) return None @@ -205,7 +197,7 @@ class EmbyModule(_ModuleBase, _MediaServerBase): """ 获取剧集信息 """ - server: Emby = self.get_server(server) + server: Emby = self.get_instance(server) if not server: return None _, seasoninfo = server.get_tv_episodes(item_id=item_id) @@ -221,7 +213,7 @@ class EmbyModule(_ModuleBase, _MediaServerBase): """ 获取媒体服务器正在播放信息 """ - server: Emby = self.get_server(server) + server: Emby = self.get_instance(server) if not server: return [] return server.get_resume(num=count, username=username) @@ -230,7 +222,7 @@ class EmbyModule(_ModuleBase, _MediaServerBase): """ 获取媒体库播放地址 """ - server: Emby = self.get_server(server) + server: Emby = self.get_instance(server) if not server: return None return server.get_play_url(item_id) @@ -240,7 +232,7 @@ class EmbyModule(_ModuleBase, _MediaServerBase): """ 获取媒体服务器最新入库条目 """ - server: Emby = self.get_server(server) + server: Emby = self.get_instance(server) if not server: return [] return server.get_latest(num=count, username=username) diff --git a/app/modules/jellyfin/__init__.py b/app/modules/jellyfin/__init__.py index b3ccac7b..c03bb8d3 100644 --- a/app/modules/jellyfin/__init__.py +++ b/app/modules/jellyfin/__init__.py @@ -1,8 +1,7 @@ -from typing import Optional, Tuple, Union, Any, List, Generator, Dict +from typing import Optional, Tuple, Union, Any, List, Generator from app import schemas from app.core.context import MediaInfo -from app.helper.mediaserver import MediaServerHelper from app.log import logger from app.modules import _ModuleBase, _MediaServerBase from app.modules.jellyfin.jellyfin import Jellyfin @@ -10,22 +9,14 @@ from app.schemas import MediaServerConf from app.schemas.types import MediaType -class JellyfinModule(_ModuleBase, _MediaServerBase): +class JellyfinModule(_ModuleBase, _MediaServerBase[Jellyfin]): def init_module(self) -> None: """ 初始化模块 """ - # 读取媒体服务器配置 - self._servers: Dict[str, Jellyfin] = {} - self._configs: Dict[str, MediaServerConf] = {} - mediaservers = MediaServerHelper().get_mediaservers() - if not mediaservers: - return - for server in mediaservers: - if server.type == "jellyfin" and server.enabled: - self._configs[server.name] = server - self._servers[server.name] = Jellyfin(**server.config, sync_libraries=server.sync_libraries) + super().init_service(service_name=Jellyfin.__name__.lower(), + service_type=lambda conf: Jellyfin(**conf.config, sync_libraries=conf.sync_libraries)) @staticmethod def get_name() -> str: @@ -39,7 +30,7 @@ class JellyfinModule(_ModuleBase, _MediaServerBase): 定时任务,每10分钟调用一次 """ # 定时重连 - for name, server in self._servers.items(): + for name, server in self._instances.items(): if server.is_inactive(): logger.info(f"Jellyfin {name} 服务器连接断开,尝试重连 ...") server.reconnect() @@ -51,9 +42,9 @@ class JellyfinModule(_ModuleBase, _MediaServerBase): """ 测试模块连接性 """ - if not self._servers: + if not self._instances: return None - for name, server in self._servers.items(): + for name, server in self._instances.items(): if server.is_inactive(): server.reconnect() if not server.get_user(): @@ -68,7 +59,7 @@ class JellyfinModule(_ModuleBase, _MediaServerBase): :return: Token or None """ # Jellyfin认证 - for server in self._servers.values(): + for server in self._instances.values(): result = server.authenticate(name, password) if result: return result @@ -87,7 +78,7 @@ class JellyfinModule(_ModuleBase, _MediaServerBase): server_config: MediaServerConf = self.get_config(source, 'jellyfin') if not server_config: return None - server: Jellyfin = self.get_server(source) + server: Jellyfin = self.get_instance(source) if not server: return None return server.get_webhook_message(body) @@ -95,7 +86,7 @@ class JellyfinModule(_ModuleBase, _MediaServerBase): for conf in self._configs.values(): if conf.type != "jellyfin": continue - server = self.get_server(conf.name) + server = self.get_instance(conf.name) if server: result = server.get_webhook_message(body) if result: @@ -109,7 +100,7 @@ class JellyfinModule(_ModuleBase, _MediaServerBase): :param itemid: 媒体服务器ItemID :return: 如不存在返回None,存在时返回信息,包括每季已存在所有集{type: movie/tv, seasons: {season: [episodes]}} """ - for name, server in self._servers.items(): + for name, server in self._instances.items(): if mediainfo.type == MediaType.MOVIE: if itemid: movie = server.get_iteminfo(itemid) @@ -154,12 +145,12 @@ class JellyfinModule(_ModuleBase, _MediaServerBase): 媒体数量统计 """ if server: - server: Jellyfin = self.get_server(server) + server: Jellyfin = self.get_instance(server) if not server: return None servers = [server] else: - servers = self._servers.values() + servers = self._instances.values() media_statistics = [] for server in servers: media_statistic = server.get_medias_count() @@ -175,16 +166,17 @@ class JellyfinModule(_ModuleBase, _MediaServerBase): """ 媒体库列表 """ - server: Jellyfin = self.get_server(server) + server: Jellyfin = self.get_instance(server) if server: return server.get_librarys(username=username, hidden=hidden) return None - def mediaserver_items(self, server: str, library_id: str, start_index: int = 0, limit: int = 100) -> Optional[Generator]: + def mediaserver_items(self, server: str, library_id: str, start_index: int = 0, limit: int = 100) \ + -> Optional[Generator]: """ 媒体库项目列表 """ - server: Jellyfin = self.get_server(server) + server: Jellyfin = self.get_instance(server) if server: return server.get_items(library_id, start_index, limit) return None @@ -193,7 +185,7 @@ class JellyfinModule(_ModuleBase, _MediaServerBase): """ 媒体库项目详情 """ - server: Jellyfin = self.get_server(server) + server: Jellyfin = self.get_instance(server) if server: return server.get_iteminfo(item_id) return None @@ -203,7 +195,7 @@ class JellyfinModule(_ModuleBase, _MediaServerBase): """ 获取剧集信息 """ - server: Jellyfin = self.get_server(server) + server: Jellyfin = self.get_instance(server) if not server: return None _, seasoninfo = server.get_tv_episodes(item_id=item_id) @@ -219,7 +211,7 @@ class JellyfinModule(_ModuleBase, _MediaServerBase): """ 获取媒体服务器正在播放信息 """ - server: Jellyfin = self.get_server(server) + server: Jellyfin = self.get_instance(server) if not server: return [] return server.get_resume(num=count, username=username) @@ -228,7 +220,7 @@ class JellyfinModule(_ModuleBase, _MediaServerBase): """ 获取媒体库播放地址 """ - server: Jellyfin = self.get_server(server) + server: Jellyfin = self.get_instance(server) if not server: return None return server.get_play_url(item_id) @@ -238,7 +230,7 @@ class JellyfinModule(_ModuleBase, _MediaServerBase): """ 获取媒体服务器最新入库条目 """ - server: Jellyfin = self.get_server(server) + server: Jellyfin = self.get_instance(server) if not server: return [] return server.get_latest(num=count, username=username) diff --git a/app/modules/plex/__init__.py b/app/modules/plex/__init__.py index 013e9273..afe108d7 100644 --- a/app/modules/plex/__init__.py +++ b/app/modules/plex/__init__.py @@ -1,8 +1,7 @@ -from typing import Optional, Tuple, Union, Any, List, Generator, Dict +from typing import Optional, Tuple, Union, Any, List, Generator from app import schemas from app.core.context import MediaInfo -from app.helper.mediaserver import MediaServerHelper from app.log import logger from app.modules import _ModuleBase, _MediaServerBase from app.modules.plex.plex import Plex @@ -10,22 +9,14 @@ from app.schemas import MediaServerConf from app.schemas.types import MediaType -class PlexModule(_ModuleBase, _MediaServerBase): +class PlexModule(_ModuleBase, _MediaServerBase[Plex]): def init_module(self) -> None: """ 初始化模块 """ - # 读取媒体服务器配置 - self._servers: Dict[str, Plex] = {} - self._configs: Dict[str, MediaServerConf] = {} - mediaservers = MediaServerHelper().get_mediaservers() - if not mediaservers: - return - for server in mediaservers: - if server.type == "plex" and server.enabled: - self._configs[server.name] = server - self._servers[server.name] = Plex(**server.config, sync_libraries=server.sync_libraries) + super().init_service(service_name=Plex.__name__.lower(), + service_type=lambda conf: Plex(**conf.config, sync_libraries=conf.sync_libraries)) @staticmethod def get_name() -> str: @@ -38,9 +29,9 @@ class PlexModule(_ModuleBase, _MediaServerBase): """ 测试模块连接性 """ - if not self._servers: + if not self._instances: return None - for name, server in self._servers.items(): + for name, server in self._instances.items(): if server.is_inactive(): server.reconnect() if not server.get_librarys(): @@ -55,7 +46,7 @@ class PlexModule(_ModuleBase, _MediaServerBase): 定时任务,每10分钟调用一次 """ # 定时重连 - for name, server in self._servers.items(): + for name, server in self._instances.items(): if server.is_inactive(): logger.info(f"Plex {name} 服务器连接断开,尝试重连 ...") server.reconnect() @@ -73,7 +64,7 @@ class PlexModule(_ModuleBase, _MediaServerBase): server_config: MediaServerConf = self.get_config(source, 'plex') if not server_config: return None - server: Plex = self.get_server(source) + server: Plex = self.get_instance(source) if not server: return None return server.get_webhook_message(body) @@ -81,7 +72,7 @@ class PlexModule(_ModuleBase, _MediaServerBase): for conf in self._configs.values(): if conf.type != "plex": continue - server = self.get_server(conf.name) + server = self.get_instance(conf.name) if server: result = server.get_webhook_message(body) if result: @@ -95,7 +86,7 @@ class PlexModule(_ModuleBase, _MediaServerBase): :param itemid: 媒体服务器ItemID :return: 如不存在返回None,存在时返回信息,包括每季已存在所有集{type: movie/tv, seasons: {season: [episodes]}} """ - for name, server in self._servers.items(): + for name, server in self._instances.items(): if mediainfo.type == MediaType.MOVIE: if itemid: movie = server.get_iteminfo(itemid) @@ -144,12 +135,12 @@ class PlexModule(_ModuleBase, _MediaServerBase): 媒体数量统计 """ if server: - server: Plex = self.get_server(server) + server: Plex = self.get_instance(server) if not server: return None servers = [server] else: - servers = self._servers.values() + servers = self._instances.values() media_statistics = [] for server in servers: media_statistic = server.get_medias_count() @@ -163,16 +154,17 @@ class PlexModule(_ModuleBase, _MediaServerBase): """ 媒体库列表 """ - server: Plex = self.get_server(server) + server: Plex = self.get_instance(server) if server: return server.get_librarys(hidden) return None - def mediaserver_items(self, server: str, library_id: str, start_index: int = 0, limit: int = 100) -> Optional[Generator]: + def mediaserver_items(self, server: str, library_id: str, start_index: int = 0, limit: int = 100) \ + -> Optional[Generator]: """ 媒体库项目列表 """ - server: Plex = self.get_server(server) + server: Plex = self.get_instance(server) if server: return server.get_items(library_id, start_index, limit) return None @@ -181,7 +173,7 @@ class PlexModule(_ModuleBase, _MediaServerBase): """ 媒体库项目详情 """ - server: Plex = self.get_server(server) + server: Plex = self.get_instance(server) if server: return server.get_iteminfo(item_id) return None @@ -191,7 +183,7 @@ class PlexModule(_ModuleBase, _MediaServerBase): """ 获取剧集信息 """ - server: Plex = self.get_server(server) + server: Plex = self.get_instance(server) if not server: return None _, seasoninfo = server.get_tv_episodes(item_id=item_id) @@ -206,7 +198,7 @@ class PlexModule(_ModuleBase, _MediaServerBase): """ 获取媒体服务器正在播放信息 """ - server: Plex = self.get_server(server) + server: Plex = self.get_instance(server) if not server: return [] return server.get_resume(num=count) @@ -215,7 +207,7 @@ class PlexModule(_ModuleBase, _MediaServerBase): """ 获取媒体服务器最新入库条目 """ - server: Plex = self.get_server(server) + server: Plex = self.get_instance(server) if not server: return [] return server.get_latest(num=count) @@ -224,7 +216,7 @@ class PlexModule(_ModuleBase, _MediaServerBase): """ 获取媒体库播放地址 """ - server: Plex = self.get_server(server) + server: Plex = self.get_instance(server) if not server: return None return server.get_play_url(item_id) diff --git a/app/modules/qbittorrent/__init__.py b/app/modules/qbittorrent/__init__.py index 4a2ce5be..7724cec1 100644 --- a/app/modules/qbittorrent/__init__.py +++ b/app/modules/qbittorrent/__init__.py @@ -1,6 +1,6 @@ import shutil from pathlib import Path -from typing import Set, Tuple, Optional, Union, List, Dict +from typing import Set, Tuple, Optional, Union, List from qbittorrentapi import TorrentFilesList from torrentool.torrent import Torrent @@ -8,7 +8,6 @@ from torrentool.torrent import Torrent from app import schemas from app.core.config import settings from app.core.metainfo import MetaInfo -from app.helper.downloader import DownloaderHelper from app.log import logger from app.modules import _ModuleBase, _DownloaderBase from app.modules.qbittorrent.qbittorrent import Qbittorrent @@ -18,23 +17,14 @@ from app.utils.string import StringUtils from app.utils.system import SystemUtils -class QbittorrentModule(_ModuleBase, _DownloaderBase): +class QbittorrentModule(_ModuleBase, _DownloaderBase[Qbittorrent]): def init_module(self) -> None: """ 初始化模块 """ - # 读取下载器配置 - self._servers: Dict[str, Qbittorrent] = {} - downloaders = DownloaderHelper().get_downloaders() - if not downloaders: - return - for server in downloaders: - if server.type == "qbittorrent" and server.enabled: - self._servers[server.name] = Qbittorrent(**server.config) - if server.default: - self._default_server_name = server.name - self._default_server = self._servers[server.name] + super().init_service(service_name=Qbittorrent.__name__.lower(), + service_type=Qbittorrent) @staticmethod def get_name() -> str: @@ -47,9 +37,9 @@ class QbittorrentModule(_ModuleBase, _DownloaderBase): """ 测试模块连接性 """ - if not self._servers: + if not self._instances: return None - for name, server in self._servers.items(): + for name, server in self._instances.items(): if server.is_inactive(): server.reconnect() if not server.transfer_info(): @@ -63,7 +53,7 @@ class QbittorrentModule(_ModuleBase, _DownloaderBase): """ 定时任务,每10分钟调用一次 """ - for name, server in self._servers.items(): + for name, server in self._instances.items(): if server.is_inactive(): logger.info(f"Qbittorrent下载器 {name} 连接断开,尝试重连 ...") server.reconnect() @@ -103,7 +93,7 @@ class QbittorrentModule(_ModuleBase, _DownloaderBase): return None, None, f"种子文件不存在:{content}" # 获取下载器 - server: Qbittorrent = self.get_server(downloader) + server: Qbittorrent = self.get_instance(downloader) if not server: return None @@ -201,7 +191,7 @@ class QbittorrentModule(_ModuleBase, _DownloaderBase): :return: 下载器中符合状态的种子列表 """ # 获取下载器 - server: Qbittorrent = self.get_server(downloader) + server: Qbittorrent = self.get_instance(downloader) if not server: return None @@ -274,7 +264,7 @@ class QbittorrentModule(_ModuleBase, _DownloaderBase): :param downloader: 下载器 :param transfer_type: 整理方式 """ - server: Qbittorrent = self.get_server(downloader) + server: Qbittorrent = self.get_instance(downloader) if not server: return None server.set_torrents_tag(ids=hashs, tags=['已整理']) @@ -298,7 +288,7 @@ class QbittorrentModule(_ModuleBase, _DownloaderBase): :param downloader: 下载器 :return: bool """ - server: Qbittorrent = self.get_server(downloader) + server: Qbittorrent = self.get_instance(downloader) if not server: return None return server.delete_torrents(delete_file=delete_file, ids=hashs) @@ -311,7 +301,7 @@ class QbittorrentModule(_ModuleBase, _DownloaderBase): :param downloader: 下载器 :return: bool """ - server: Qbittorrent = self.get_server(downloader) + server: Qbittorrent = self.get_instance(downloader) if not server: return None return server.start_torrents(ids=hashs) @@ -323,7 +313,7 @@ class QbittorrentModule(_ModuleBase, _DownloaderBase): :param downloader: 下载器 :return: bool """ - server: Qbittorrent = self.get_server(downloader) + server: Qbittorrent = self.get_instance(downloader) if not server: return None return server.stop_torrents(ids=hashs) @@ -332,7 +322,7 @@ class QbittorrentModule(_ModuleBase, _DownloaderBase): """ 获取种子文件列表 """ - server: Qbittorrent = self.get_server(downloader) + server: Qbittorrent = self.get_instance(downloader) if not server: return None return server.get_files(tid=tid) @@ -342,12 +332,12 @@ class QbittorrentModule(_ModuleBase, _DownloaderBase): 下载器信息 """ if downloader: - server: Qbittorrent = self.get_server(downloader) + server: Qbittorrent = self.get_instance(downloader) if not server: return None servers = [server] else: - servers = self._servers.values() + servers = self._instances.values() # 调用Qbittorrent API查询实时信息 ret_info = [] for server in servers: diff --git a/app/modules/slack/__init__.py b/app/modules/slack/__init__.py index 7d129aad..0503273b 100644 --- a/app/modules/slack/__init__.py +++ b/app/modules/slack/__init__.py @@ -4,28 +4,20 @@ from typing import Optional, Union, List, Tuple, Any from app.core.config import settings from app.core.context import MediaInfo, Context -from app.helper.notification import NotificationHelper from app.log import logger from app.modules import _ModuleBase, _MessageBase from app.modules.slack.slack import Slack from app.schemas import MessageChannel, CommingMessage, Notification -class SlackModule(_ModuleBase, _MessageBase): +class SlackModule(_ModuleBase, _MessageBase[Slack]): def init_module(self) -> None: """ 初始化模块 """ - clients = NotificationHelper().get_clients() - if not clients: - return - self._configs = {} - self._clients = {} - for client in clients: - if client.type == "slack" and client.enabled: - self._configs[client.name] = client - self._clients[client.name] = Slack(**client.config, name=client.name) + super().init_service(service_name=Slack.__name__.lower(), + service_type=Slack) @staticmethod def get_name() -> str: @@ -35,16 +27,16 @@ class SlackModule(_ModuleBase, _MessageBase): """ 停止模块 """ - for client in self._clients.values(): + for client in self._instances.values(): client.stop() def test(self) -> Optional[Tuple[bool, str]]: """ 测试模块连接性 """ - if not self._clients: + if not self._instances: return None - for name, client in self._clients.items(): + for name, client in self._instances.items(): state = client.get_state() if not state: return False, f"Slack {name} 未就续" @@ -223,7 +215,7 @@ class SlackModule(_ModuleBase, _MessageBase): :return: 成功或失败 """ for conf in self._configs.values(): - if not self.checkMessage(message, conf.name): + if not self.check_message(message, conf.name): continue targets = message.targets userid = message.userid @@ -232,7 +224,7 @@ class SlackModule(_ModuleBase, _MessageBase): if not userid: logger.warn(f"用户没有指定 Slack用户ID,消息无法发送") return - client: Slack = self.get_client(conf.name) + client: Slack = self.get_instance(conf.name) if client: client.send_msg(title=message.title, text=message.text, image=message.image, userid=userid, link=message.link) @@ -245,9 +237,9 @@ class SlackModule(_ModuleBase, _MessageBase): :return: 成功或失败 """ for conf in self._configs.values(): - if not self.checkMessage(message, conf.name): + if not self.check_message(message, conf.name): continue - client: Slack = self.get_client(conf.name) + client: Slack = self.get_instance(conf.name) if client: client.send_medias_msg(title=message.title, medias=medias, userid=message.userid) @@ -259,9 +251,9 @@ class SlackModule(_ModuleBase, _MessageBase): :return: 成功或失败 """ for conf in self._configs.values(): - if not self.checkMessage(message, conf.name): + if not self.check_message(message, conf.name): continue - client: Slack = self.get_client(conf.name) + client: Slack = self.get_instance(conf.name) if client: client.send_torrents_msg(title=message.title, torrents=torrents, userid=message.userid) diff --git a/app/modules/synologychat/__init__.py b/app/modules/synologychat/__init__.py index e0247bcb..7180c430 100644 --- a/app/modules/synologychat/__init__.py +++ b/app/modules/synologychat/__init__.py @@ -1,28 +1,20 @@ from typing import Optional, Union, List, Tuple, Any from app.core.context import MediaInfo, Context -from app.helper.notification import NotificationHelper from app.log import logger from app.modules import _ModuleBase, _MessageBase from app.modules.synologychat.synologychat import SynologyChat from app.schemas import MessageChannel, CommingMessage, Notification -class SynologyChatModule(_ModuleBase, _MessageBase): +class SynologyChatModule(_ModuleBase, _MessageBase[SynologyChat]): def init_module(self) -> None: """ 初始化模块 """ - clients = NotificationHelper().get_clients() - if not clients: - return - self._configs = {} - self._clients = {} - for client in clients: - if client.type == "synologychat" and client.enabled: - self._configs[client.name] = client - self._clients[client.name] = SynologyChat(**client.config) + super().init_service(service_name=SynologyChat.__name__.lower(), + service_type=SynologyChat) @staticmethod def get_name() -> str: @@ -35,9 +27,9 @@ class SynologyChatModule(_ModuleBase, _MessageBase): """ 测试模块连接性 """ - if not self._clients: + if not self._instances: return None - for name, client in self._clients.items(): + for name, client in self._instances.items(): state = client.get_state() if not state: return False, f"Synology Chat {name} 未就续" @@ -64,7 +56,7 @@ class SynologyChatModule(_ModuleBase, _MessageBase): client_config = self.get_config(source, 'synologychat') if not client_config: return None - client: SynologyChat = self.get_client(source) + client: SynologyChat = self.get_instance(source) # 解析消息 message: dict = form if not message: @@ -94,7 +86,7 @@ class SynologyChatModule(_ModuleBase, _MessageBase): :return: 成功或失败 """ for conf in self._configs.values(): - if not self.checkMessage(message, conf.name): + if not self.check_message(message, conf.name): continue targets = message.targets userid = message.userid @@ -103,7 +95,7 @@ class SynologyChatModule(_ModuleBase, _MessageBase): if not userid: logger.warn(f"用户没有指定 SynologyChat用户ID,消息无法发送") return - client: SynologyChat = self.get_client(conf.name) + client: SynologyChat = self.get_instance(conf.name) if client: client.send_msg(title=message.title, text=message.text, image=message.image, userid=userid, link=message.link) @@ -116,9 +108,9 @@ class SynologyChatModule(_ModuleBase, _MessageBase): :return: 成功或失败 """ for conf in self._configs.values(): - if not self.checkMessage(message, conf.name): + if not self.check_message(message, conf.name): continue - client: SynologyChat = self.get_client(conf.name) + client: SynologyChat = self.get_instance(conf.name) if client: client.send_medias_msg(title=message.title, medias=medias, userid=message.userid) @@ -131,9 +123,9 @@ class SynologyChatModule(_ModuleBase, _MessageBase): :return: 成功或失败 """ for conf in self._configs.values(): - if not self.checkMessage(message, conf.name): + if not self.check_message(message, conf.name): continue - client: SynologyChat = self.get_client(conf.name) + client: SynologyChat = self.get_instance(conf.name) if client: client.send_torrents_msg(title=message.title, torrents=torrents, userid=message.userid, link=message.link) diff --git a/app/modules/telegram/__init__.py b/app/modules/telegram/__init__.py index b4421f2f..97c5351d 100644 --- a/app/modules/telegram/__init__.py +++ b/app/modules/telegram/__init__.py @@ -3,28 +3,20 @@ from typing import Optional, Union, List, Tuple, Any, Dict from app.core.config import settings from app.core.context import MediaInfo, Context -from app.helper.notification import NotificationHelper from app.log import logger from app.modules import _ModuleBase, _MessageBase from app.modules.telegram.telegram import Telegram from app.schemas import MessageChannel, CommingMessage, Notification -class TelegramModule(_ModuleBase, _MessageBase): +class TelegramModule(_ModuleBase, _MessageBase[Telegram]): def init_module(self) -> None: """ 初始化模块 """ - clients = NotificationHelper().get_clients() - if not clients: - return - self._configs = {} - self._clients = {} - for client in clients: - if client.type == "telegram" and client.enabled: - self._configs[client.name] = client - self._clients[client.name] = Telegram(**client.config, name=client.name) + super().init_service(service_name=Telegram.__name__.lower(), + service_type=Telegram) @staticmethod def get_name() -> str: @@ -34,16 +26,16 @@ class TelegramModule(_ModuleBase, _MessageBase): """ 停止模块 """ - for client in self._clients.values(): + for client in self._instances.values(): client.stop() def test(self) -> Optional[Tuple[bool, str]]: """ 测试模块连接性 """ - if not self._clients: + if not self._instances: return None - for name, client in self._clients.items(): + for name, client in self._instances.items(): state = client.get_state() if not state: return False, f"Telegram {name} 未就续" @@ -92,7 +84,7 @@ class TelegramModule(_ModuleBase, _MessageBase): client_config = self.get_config(source, 'telegram') if not client_config: return None - client: Telegram = self.get_client(source) + client: Telegram = self.get_instance(source) # 校验token token = args.get("token") if not token or token != settings.API_TOKEN: @@ -136,7 +128,7 @@ class TelegramModule(_ModuleBase, _MessageBase): :return: 成功或失败 """ for conf in self._configs.values(): - if not self.checkMessage(message, conf.name): + if not self.check_message(message, conf.name): continue targets = message.targets userid = message.userid @@ -145,7 +137,7 @@ class TelegramModule(_ModuleBase, _MessageBase): if not userid: logger.warn(f"用户没有指定 Telegram用户ID,消息无法发送") return - client: Telegram = self.get_client(conf.name) + client: Telegram = self.get_instance(conf.name) if client: client.send_msg(title=message.title, text=message.text, image=message.image, userid=userid, link=message.link) @@ -158,9 +150,9 @@ class TelegramModule(_ModuleBase, _MessageBase): :return: 成功或失败 """ for conf in self._configs.values(): - if not self.checkMessage(message, conf.name): + if not self.check_message(message, conf.name): continue - client: Telegram = self.get_client(conf.name) + client: Telegram = self.get_instance(conf.name) if client: client.send_medias_msg(title=message.title, medias=medias, userid=message.userid, link=message.link) @@ -173,9 +165,9 @@ class TelegramModule(_ModuleBase, _MessageBase): :return: 成功或失败 """ for conf in self._configs.values(): - if not self.checkMessage(message, conf.name): + if not self.check_message(message, conf.name): continue - client: Telegram = self.get_client(conf.name) + client: Telegram = self.get_instance(conf.name) if client: client.send_torrents_msg(title=message.title, torrents=torrents, userid=message.userid, link=message.link) @@ -185,5 +177,5 @@ class TelegramModule(_ModuleBase, _MessageBase): 注册命令,实现这个函数接收系统可用的命令菜单 :param commands: 命令字典 """ - for client in self._clients.values(): + for client in self._instances.values(): client.register_commands(commands) diff --git a/app/modules/transmission/__init__.py b/app/modules/transmission/__init__.py index 0d501bd0..dfd865bd 100644 --- a/app/modules/transmission/__init__.py +++ b/app/modules/transmission/__init__.py @@ -1,6 +1,6 @@ import shutil from pathlib import Path -from typing import Set, Tuple, Optional, Union, List, Dict +from typing import Set, Tuple, Optional, Union, List from torrentool.torrent import Torrent from transmission_rpc import File @@ -8,7 +8,6 @@ from transmission_rpc import File from app import schemas from app.core.config import settings from app.core.metainfo import MetaInfo -from app.helper.downloader import DownloaderHelper from app.log import logger from app.modules import _ModuleBase, _DownloaderBase from app.modules.transmission.transmission import Transmission @@ -18,20 +17,14 @@ from app.utils.string import StringUtils from app.utils.system import SystemUtils -class TransmissionModule(_ModuleBase, _DownloaderBase): +class TransmissionModule(_ModuleBase, _DownloaderBase[Transmission]): def init_module(self) -> None: - # 读取下载器配置 - self._servers: Dict[str, Transmission] = {} - downloaders = DownloaderHelper().get_downloaders() - if not downloaders: - return - for server in downloaders: - if server.type == "transmission" and server.enabled: - self._servers[server.name] = Transmission(**server.config) - if server.default: - self._default_server_name = server.name - self._default_server = self._servers[server.name] + """ + 初始化模块 + """ + super().init_service(service_name=Transmission.__name__.lower(), + service_type=Transmission) @staticmethod def get_name() -> str: @@ -44,9 +37,9 @@ class TransmissionModule(_ModuleBase, _DownloaderBase): """ 测试模块连接性 """ - if not self._servers: + if not self._instances: return None - for name, server in self._servers.items(): + for name, server in self._instances.items(): if server.is_inactive(): server.reconnect() if not server.transfer_info(): @@ -61,7 +54,7 @@ class TransmissionModule(_ModuleBase, _DownloaderBase): 定时任务,每10分钟调用一次 """ # 定时重连 - for name, server in self._servers.items(): + for name, server in self._instances.items(): if server.is_inactive(): logger.info(f"Transmission下载器 {name} 连接断开,尝试重连 ...") server.reconnect() @@ -100,7 +93,7 @@ class TransmissionModule(_ModuleBase, _DownloaderBase): return None, None, f"种子文件不存在:{content}" # 获取下载器 - server: Transmission = self.get_server(downloader) + server: Transmission = self.get_instance(downloader) if not server: return None @@ -191,7 +184,7 @@ class TransmissionModule(_ModuleBase, _DownloaderBase): :return: 下载器中符合状态的种子列表 """ # 获取下载器 - server: Transmission = self.get_server(downloader) + server: Transmission = self.get_instance(downloader) if not server: return None ret_torrents = [] @@ -259,7 +252,7 @@ class TransmissionModule(_ModuleBase, _DownloaderBase): :param transfer_type: 整理方式 """ # 获取下载器 - server: Transmission = self.get_server(downloader) + server: Transmission = self.get_instance(downloader) if not server: return None # 获取原标签 @@ -291,7 +284,7 @@ class TransmissionModule(_ModuleBase, _DownloaderBase): :return: bool """ # 获取下载器 - server: Transmission = self.get_server(downloader) + server: Transmission = self.get_instance(downloader) if not server: return None return server.delete_torrents(delete_file=delete_file, ids=hashs) @@ -305,7 +298,7 @@ class TransmissionModule(_ModuleBase, _DownloaderBase): :return: bool """ # 获取下载器 - server: Transmission = self.get_server(downloader) + server: Transmission = self.get_instance(downloader) if not server: return None return server.start_torrents(ids=hashs) @@ -319,7 +312,7 @@ class TransmissionModule(_ModuleBase, _DownloaderBase): :return: bool """ # 获取下载器 - server: Transmission = self.get_server(downloader) + server: Transmission = self.get_instance(downloader) if not server: return None return server.start_torrents(ids=hashs) @@ -329,7 +322,7 @@ class TransmissionModule(_ModuleBase, _DownloaderBase): 获取种子文件列表 """ # 获取下载器 - server: Transmission = self.get_server(downloader) + server: Transmission = self.get_instance(downloader) if not server: return None return server.get_files(tid=tid) @@ -339,12 +332,12 @@ class TransmissionModule(_ModuleBase, _DownloaderBase): 下载器信息 """ if downloader: - server: Transmission = self.get_server(downloader) + server: Transmission = self.get_instance(downloader) if not server: return None servers = [server] else: - servers = self._servers.values() + servers = self._instances.values() # 调用Qbittorrent API查询实时信息 ret_info = [] for server in servers: diff --git a/app/modules/vocechat/__init__.py b/app/modules/vocechat/__init__.py index 031da1b6..bc802f41 100644 --- a/app/modules/vocechat/__init__.py +++ b/app/modules/vocechat/__init__.py @@ -3,28 +3,20 @@ from typing import Optional, Union, List, Tuple, Any, Dict from app.core.config import settings from app.core.context import Context, MediaInfo -from app.helper.notification import NotificationHelper from app.log import logger from app.modules import _ModuleBase, _MessageBase from app.modules.vocechat.vocechat import VoceChat from app.schemas import MessageChannel, CommingMessage, Notification -class VoceChatModule(_ModuleBase, _MessageBase): +class VoceChatModule(_ModuleBase, _MessageBase[VoceChat]): def init_module(self) -> None: """ 初始化模块 """ - clients = NotificationHelper().get_clients() - if not clients: - return - self._configs = {} - self._clients = {} - for client in clients: - if client.type == "vocechat" and client.enabled: - self._configs[client.name] = client - self._clients[client.name] = VoceChat(**client.config) + super().init_service(service_name=VoceChat.__name__.lower(), + service_type=VoceChat) @staticmethod def get_name() -> str: @@ -37,9 +29,9 @@ class VoceChatModule(_ModuleBase, _MessageBase): """ 测试模块连接性 """ - if not self._clients: + if not self._instances: return None - for name, client in self._clients.items(): + for name, client in self._instances.items(): state = client.get_state() if not state: return False, f"VoceChat {name} 未就续" @@ -122,13 +114,13 @@ class VoceChatModule(_ModuleBase, _MessageBase): :return: 成功或失败 """ for conf in self._configs.values(): - if not self.checkMessage(message, conf.name): + if not self.check_message(message, conf.name): continue targets = message.targets userid = message.userid if not message.userid and targets: userid = targets.get('telegram_userid') - client: VoceChat = self.get_client(conf.name) + client: VoceChat = self.get_instance(conf.name) if client: client.send_msg(title=message.title, text=message.text, userid=userid, link=message.link) @@ -141,9 +133,9 @@ class VoceChatModule(_ModuleBase, _MessageBase): :return: 成功或失败 """ for conf in self._configs.values(): - if not self.checkMessage(message, conf.name): + if not self.check_message(message, conf.name): continue - client: VoceChat = self.get_client(conf.name) + client: VoceChat = self.get_instance(conf.name) if client: client.send_msg(title=message.title, userid=message.userid) client.send_medias_msg(title=message.title, medias=medias, @@ -157,7 +149,7 @@ class VoceChatModule(_ModuleBase, _MessageBase): :return: 成功或失败 """ for conf in self._configs.values(): - if not self.checkMessage(message, conf.name): + if not self.check_message(message, conf.name): continue targets = message.targets userid = message.userid @@ -166,7 +158,7 @@ class VoceChatModule(_ModuleBase, _MessageBase): if not userid: logger.warn(f"用户没有指定 VoceChat用户ID,消息无法发送") return - client: VoceChat = self.get_client(conf.name) + client: VoceChat = self.get_instance(conf.name) if client: client.send_torrents_msg(title=message.title, torrents=torrents, userid=userid, link=message.link) diff --git a/app/modules/webpush/__init__.py b/app/modules/webpush/__init__.py index f158666b..89f45f3c 100644 --- a/app/modules/webpush/__init__.py +++ b/app/modules/webpush/__init__.py @@ -4,24 +4,18 @@ from typing import Union, Tuple from pywebpush import webpush, WebPushException from app.core.config import global_vars, settings -from app.helper.notification import NotificationHelper from app.log import logger from app.modules import _ModuleBase, _MessageBase from app.schemas import Notification class WebPushModule(_ModuleBase, _MessageBase): + def init_module(self) -> None: """ 初始化模块 """ - clients = NotificationHelper().get_clients() - if not clients: - return - self._configs = {} - for client in clients: - if client.type == "webpush" and client.enabled: - self._configs[client.name] = client + super().init_service(service_name=self.get_name().lower()) @staticmethod def get_name() -> str: @@ -46,7 +40,7 @@ class WebPushModule(_ModuleBase, _MessageBase): :return: 成功或失败 """ for conf in self._configs.values(): - if not self.checkMessage(message, conf.name): + if not self.check_message(message, conf.name): continue webpush_users = conf.config.get("WEBPUSH_USERNAME") or "" if webpush_users: diff --git a/app/modules/wechat/__init__.py b/app/modules/wechat/__init__.py index 690a77e6..5994e46a 100644 --- a/app/modules/wechat/__init__.py +++ b/app/modules/wechat/__init__.py @@ -2,7 +2,6 @@ import xml.dom.minidom from typing import Optional, Union, List, Tuple, Any, Dict from app.core.context import Context, MediaInfo -from app.helper.notification import NotificationHelper from app.log import logger from app.modules import _ModuleBase, _MessageBase from app.modules.wechat.WXBizMsgCrypt3 import WXBizMsgCrypt @@ -11,21 +10,14 @@ from app.schemas import MessageChannel, CommingMessage, Notification from app.utils.dom import DomUtils -class WechatModule(_ModuleBase, _MessageBase): +class WechatModule(_ModuleBase, _MessageBase[WeChat]): def init_module(self) -> None: """ 初始化模块 """ - clients = NotificationHelper().get_clients() - if not clients: - return - self._configs = {} - self._clients = {} - for client in clients: - if client.type == "wechat" and client.enabled: - self._configs[client.name] = client - self._clients[client.name] = WeChat(**client.config) + super().init_service(service_name=WeChat.__name__.lower(), + service_type=WeChat) @staticmethod def get_name() -> str: @@ -38,9 +30,9 @@ class WechatModule(_ModuleBase, _MessageBase): """ 测试模块连接性 """ - if not self._clients: + if not self._instances: return None - for name, client in self._clients.items(): + for name, client in self._instances.items(): state = client.get_state() if not state: return False, f"企业微信 {name} 未就续" @@ -67,7 +59,7 @@ class WechatModule(_ModuleBase, _MessageBase): client_config = self.get_config(source, 'wechat') if not client_config: return None - client: WeChat = self.get_client(source) + client: WeChat = self.get_instance(source) # URL参数 sVerifyMsgSig = args.get("msg_signature") sVerifyTimeStamp = args.get("timestamp") @@ -159,7 +151,7 @@ class WechatModule(_ModuleBase, _MessageBase): :return: 成功或失败 """ for conf in self._configs.values(): - if not self.checkMessage(message, conf.name): + if not self.check_message(message, conf.name): continue targets = message.targets userid = message.userid @@ -168,7 +160,7 @@ class WechatModule(_ModuleBase, _MessageBase): if not userid: logger.warn(f"用户没有指定 微信用户ID,消息无法发送") return - client: WeChat = self.get_client(conf.name) + client: WeChat = self.get_instance(conf.name) if client: client.send_msg(title=message.title, text=message.text, image=message.image, userid=userid, link=message.link) @@ -181,9 +173,9 @@ class WechatModule(_ModuleBase, _MessageBase): :return: 成功或失败 """ for conf in self._configs.values(): - if not self.checkMessage(message, conf.name): + if not self.check_message(message, conf.name): continue - client: WeChat = self.get_client(conf.name) + client: WeChat = self.get_instance(conf.name) if client: # 先发送标题 client.send_msg(title=message.title, userid=message.userid, link=message.link) @@ -198,9 +190,9 @@ class WechatModule(_ModuleBase, _MessageBase): :return: 成功或失败 """ for conf in self._configs.values(): - if not self.checkMessage(message, conf.name): + if not self.check_message(message, conf.name): continue - client: WeChat = self.get_client(conf.name) + client: WeChat = self.get_instance(conf.name) if client: client.send_torrents_msg(title=message.title, torrents=torrents, userid=message.userid, link=message.link) @@ -210,5 +202,5 @@ class WechatModule(_ModuleBase, _MessageBase): 注册命令,实现这个函数接收系统可用的命令菜单 :param commands: 命令字典 """ - for client in self._clients.values(): + for client in self._instances.values(): client.create_menus(commands)