Merge pull request #2780 from InfinityPacer/feature/module

This commit is contained in:
jxxghp
2024-09-27 10:19:28 +08:00
committed by GitHub
19 changed files with 378 additions and 384 deletions

View File

@@ -1,8 +1,7 @@
import json
from typing import Union, Any, List
from fastapi import APIRouter, BackgroundTasks, Depends
from fastapi import Request
from fastapi import APIRouter, BackgroundTasks, Depends, Request
from pywebpush import WebPushException, webpush
from sqlalchemy.orm import Session
from starlette.responses import PlainTextResponse
@@ -15,7 +14,7 @@ from app.db import get_db
from app.db.models import User
from app.db.models.message import Message
from app.db.user_oper import get_current_active_superuser
from app.helper.notification import NotificationHelper
from app.helper.serviceconfig import ServiceConfigHelper
from app.log import logger
from app.modules.wechat.WXBizMsgCrypt3 import WXBizMsgCrypt
from app.schemas.types import MessageChannel
@@ -81,7 +80,7 @@ def wechat_verify(echostr: str, msg_signature: str, timestamp: Union[str, int],
"""
微信验证响应
"""
clients = NotificationHelper().get_clients()
clients = ServiceConfigHelper.get_notification_configs()
if not clients:
return
for client in clients:

View File

@@ -10,15 +10,14 @@ from ruamel.yaml import CommentedMap
from transmission_rpc import File
from app.core.config import settings
from app.core.context import Context
from app.core.context import MediaInfo, TorrentInfo
from app.core.context import Context, MediaInfo, TorrentInfo
from app.core.event import EventManager
from app.core.meta import MetaBase
from app.core.module import ModuleManager
from app.db.message_oper import MessageOper
from app.db.user_oper import UserOper
from app.helper.message import MessageHelper
from app.helper.notification import NotificationHelper
from app.helper.serviceconfig import ServiceConfigHelper
from app.log import logger
from app.schemas import TransferInfo, TransferTorrent, ExistMediaInfo, DownloadingTorrent, CommingMessage, Notification, \
WebhookEventInfo, TmdbEpisode, MediaPerson, FileItem
@@ -39,7 +38,6 @@ class ChainBase(metaclass=ABCMeta):
self.eventmanager = EventManager()
self.messageoper = MessageOper()
self.messagehelper = MessageHelper()
self.notificationhelper = NotificationHelper()
self.useroper = UserOper()
@staticmethod
@@ -484,7 +482,7 @@ class ChainBase(metaclass=ABCMeta):
# 没有指定用户ID时按规则确定发送对象
# 默认发送全体
to_targets = None
notify_action = self.notificationhelper.get_switch(message.mtype)
notify_action = ServiceConfigHelper.get_notification_switch(message.mtype)
if notify_action == "admin":
# 仅发送管理员
logger.info(f"已设置 {message.mtype} 的消息只发送给管理员")

View File

@@ -5,7 +5,7 @@ from typing import List, Union, Optional
from app import schemas
from app.chain import ChainBase
from app.db.mediaserver_oper import MediaServerOper
from app.helper.mediaserver import MediaServerHelper
from app.helper.serviceconfig import ServiceConfigHelper
from app.log import logger
lock = threading.Lock()
@@ -19,7 +19,6 @@ class MediaServerChain(ChainBase):
def __init__(self):
super().__init__()
self.dboper = MediaServerOper()
self.mediaserverhelper = MediaServerHelper()
def librarys(self, server: str, username: str = None, hidden: bool = False) -> List[schemas.MediaServerLibrary]:
"""
@@ -27,12 +26,14 @@ class MediaServerChain(ChainBase):
"""
return self.run_module("mediaserver_librarys", server=server, username=username, hidden=hidden)
def items(self, server: str, library_id: Union[str, int], start_index: int = 0, limit: int = 100) -> List[schemas.MediaServerItem]:
def items(self, server: str, library_id: Union[str, int], start_index: int = 0, limit: int = 100) \
-> List[schemas.MediaServerItem]:
"""
获取媒体服务器所有项目
"""
data = []
data_generator = self.run_module("mediaserver_items", server=server, library_id=library_id, start_index=start_index, limit=limit)
data_generator = self.run_module("mediaserver_items", server=server, library_id=library_id,
start_index=start_index, limit=limit)
if data_generator:
for item in data_generator:
if item:
@@ -74,7 +75,7 @@ class MediaServerChain(ChainBase):
同步媒体库所有数据到本地数据库
"""
# 设置的媒体服务器
mediaservers = self.mediaserverhelper.get_mediaservers()
mediaservers = ServiceConfigHelper.get_mediaserver_configs()
if not mediaservers:
return
with lock:

View File

@@ -1,23 +1,5 @@
from typing import List
from app.db.systemconfig_oper import SystemConfigOper
from app.schemas.system import DownloaderConf
from app.schemas.types import SystemConfigKey
class DownloaderHelper:
"""
下载器帮助类
"""
def __init__(self):
self.systemconfig = SystemConfigOper()
def get_downloaders(self) -> List[DownloaderConf]:
"""
获取下载器
"""
downloader_confs: List[dict] = self.systemconfig.get(SystemConfigKey.Downloaders)
if not downloader_confs:
return []
return [DownloaderConf(**conf) for conf in downloader_confs]
pass

View File

@@ -1,23 +1,5 @@
from typing import List
from app.db.systemconfig_oper import SystemConfigOper
from app.schemas import MediaServerConf
from app.schemas.types import SystemConfigKey
class MediaServerHelper:
"""
媒体服务器帮助类
"""
def __init__(self):
self.systemconfig = SystemConfigOper()
def get_mediaservers(self) -> List[MediaServerConf]:
"""
获取媒体服务器
"""
mediaserver_confs: List[dict] = self.systemconfig.get(SystemConfigKey.MediaServers)
if not mediaserver_confs:
return []
return [MediaServerConf(**conf) for conf in mediaserver_confs]
pass

View File

@@ -1,42 +1,5 @@
from typing import List, Optional
from app.db.systemconfig_oper import SystemConfigOper
from app.schemas import NotificationConf, NotificationSwitchConf
from app.schemas.types import SystemConfigKey, NotificationType
class NotificationHelper:
"""
消息通知渠道帮助类
"""
def __init__(self):
self.systemconfig = SystemConfigOper()
def get_clients(self) -> List[NotificationConf]:
"""
获取消息通知渠道
"""
client_confs: List[dict] = self.systemconfig.get(SystemConfigKey.Notifications)
if not client_confs:
return []
return [NotificationConf(**conf) for conf in client_confs]
def get_switchs(self) -> List[NotificationSwitchConf]:
"""
获取消息通知场景开关
"""
switchs: List[dict] = self.systemconfig.get(SystemConfigKey.NotificationSwitchs)
if not switchs:
return []
return [NotificationSwitchConf(**switch) for switch in switchs]
def get_switch(self, mtype: NotificationType) -> Optional[str]:
"""
获取消息通知场景开关
"""
switchs = self.get_switchs()
for switch in switchs:
if switch.type == mtype.value:
return switch.action
return None
pass

View File

@@ -0,0 +1,65 @@
from typing import List, Type, Optional
from app.db.systemconfig_oper import SystemConfigOper
from app.schemas import DownloaderConf, MediaServerConf, NotificationConf, NotificationSwitchConf
from app.schemas.types import SystemConfigKey, NotificationType
class ServiceConfigHelper:
"""
配置帮助类,获取不同类型的服务配置
"""
@staticmethod
def get_configs(config_key: SystemConfigKey, conf_type: Type) -> List:
"""
通用获取配置的方法,根据 config_key 获取相应的配置并返回指定类型的配置列表
:param config_key: 系统配置的 key
:param conf_type: 用于实例化配置对象的类类型
:return: 配置对象列表
"""
config_data = SystemConfigOper().get(config_key)
if not config_data:
return []
# 直接使用 conf_type 来实例化配置对象
return [conf_type(**conf) for conf in config_data]
@staticmethod
def get_downloader_configs() -> List[DownloaderConf]:
"""
获取下载器的配置
"""
return ServiceConfigHelper.get_configs(SystemConfigKey.Downloaders, DownloaderConf)
@staticmethod
def get_mediaserver_configs() -> List[MediaServerConf]:
"""
获取媒体服务器的配置
"""
return ServiceConfigHelper.get_configs(SystemConfigKey.MediaServers, MediaServerConf)
@staticmethod
def get_notification_configs() -> List[NotificationConf]:
"""
获取消息通知渠道的配置
"""
return ServiceConfigHelper.get_configs(SystemConfigKey.Notifications, NotificationConf)
@staticmethod
def get_notification_switches() -> List[NotificationSwitchConf]:
"""
获取消息通知场景的开关
"""
return ServiceConfigHelper.get_configs(SystemConfigKey.NotificationSwitchs, NotificationSwitchConf)
@staticmethod
def get_notification_switch(mtype: NotificationType) -> Optional[str]:
"""
获取消息通知场景开关
"""
switchs = ServiceConfigHelper.get_notification_switches()
for switch in switchs:
if switch.type == mtype.value:
return switch.action
return None

View File

@@ -1,7 +1,8 @@
from abc import abstractmethod, ABCMeta
from typing import Tuple, Union, Dict, Any, Optional
from typing import Generic, Tuple, Union, TypeVar, Type, Dict, Optional, Callable, Any, List
from app.schemas import Notification, MessageChannel, NotificationConf, MediaServerConf
from app.helper.serviceconfig import ServiceConfigHelper
from app.schemas import Notification, MessageChannel, NotificationConf, MediaServerConf, DownloaderConf
class _ModuleBase(metaclass=ABCMeta):
@@ -49,37 +50,109 @@ class _ModuleBase(metaclass=ABCMeta):
pass
class _MessageBase:
# 定义泛型,用于表示具体的服务类型和配置类型
TService = TypeVar("TService", bound=object)
TConf = TypeVar("TConf")
class ServiceBase(Generic[TService, TConf], metaclass=ABCMeta):
"""
消息基类
抽象服务基类,负责服务的初始化、获取实例和配置管理
"""
_channel: MessageChannel = None
_configs: Dict[str, NotificationConf] = {}
_clients: Dict[str, Any] = {}
def get_client(self, name: str) -> Optional[Any]:
def __init__(self):
"""
获取客户端
初始化 ServiceBase 类的实例
"""
self._configs: Dict[str, TConf] = {}
self._instances: Dict[str, TService] = {}
def init_service(self, service_name: str,
service_type: Optional[Union[Type[TService], Callable[..., TService]]] = None):
"""
初始化服务,获取配置并实例化对应服务
:param service_name: 服务名称,作为配置匹配的依据
:param service_type: 服务的类型可以是类类型Type[TService]、工厂函数Callable或 None 来跳过实例化
"""
configs = self.get_configs()
if not configs:
return
for conf in configs:
if conf.enabled and conf.type == service_name:
self._configs[conf.name] = conf
if service_type:
# 通过服务类型或工厂函数来创建实例
if isinstance(service_type, type):
# 如果传入的是类类型,调用构造函数实例化
self._instances[conf.name] = service_type(**conf.config)
else:
# 如果传入的是工厂函数,直接调用工厂函数
self._instances[conf.name] = service_type(conf)
def get_instance(self, name: str) -> Optional[TService]:
"""
获取服务实例
:param name: 实例名称
:return: 返回对应名称的服务实例,若不存在则返回 None
"""
if not name:
return None
return self._clients.get(name)
return self._instances.get(name)
def get_config(self, name: str, ctype: str = None) -> Optional[NotificationConf]:
@abstractmethod
def get_configs(self) -> List[TConf]:
"""
获取配置
获取服务的配置列表
:return: 返回配置列表
"""
pass
def get_config(self, name: str, ctype: str = None) -> Optional[TConf]:
"""
获取配置,支持类型过滤
:param name: 配置名称
:param ctype: 配置类型,可选,默认不进行类型过滤
:return: 返回符合条件的配置,若不存在则返回 None
"""
if not name:
return None
conf = self._configs.get(name)
if not ctype:
return conf
return conf if conf.type == ctype else None
return conf if getattr(conf, "type", None) == ctype else None
def checkMessage(self, message: Notification, source: str = None) -> bool:
class _MessageBase(ServiceBase[TService, NotificationConf]):
"""
消息基类
"""
def __init__(self):
"""
检查消息渠道及消息类型,如不符合则不处理
初始化消息基类,并设置消息通道
"""
super().__init__()
self._channel: Optional[MessageChannel] = None
def get_configs(self) -> List[NotificationConf]:
"""
获取消息通知渠道的配置
:return: 返回消息通知的配置列表
"""
return ServiceConfigHelper.get_notification_configs()
def check_message(self, message: Notification, source: str = None) -> bool:
"""
检查消息渠道及消息类型,判断是否处理消息
:param message: 要检查的通知消息
:param source: 消息来源,可选
:return: 返回布尔值,表示是否处理该消息
"""
# 检查消息渠道
if message.channel and message.channel != self._channel:
@@ -97,45 +170,63 @@ class _MessageBase:
return True
class _DownloaderBase:
class _DownloaderBase(ServiceBase[TService, DownloaderConf]):
"""
下载器基类
"""
_servers: Dict[str, Any] = {}
_default_server: Any = None
_default_server_name: str = None
def get_server(self, name: str = None) -> Optional[Any]:
def __init__(self):
"""
获取服务器name为空则返回默认服务器
初始化下载器基类,并设置默认服务器
"""
super().__init__()
self._default_server: Any = None
self._default_server_name: Optional[str] = None
def init_service(self, service_name: str,
service_type: Optional[Union[Type[TService], Callable[..., TService]]] = None):
"""
初始化服务,获取配置并实例化对应服务
:param service_name: 服务名称,作为配置匹配的依据
:param service_type: 服务的类型可以是类类型Type[TService]或工厂函数Callable用于创建服务实例
"""
super().init_service(service_name=service_name, service_type=service_type)
if self._configs:
for conf in self._configs.values():
if conf.default:
self._default_server_name = conf.name
self._default_server = self.get_instance(conf.name)
def get_instance(self, name: str = None) -> Optional[Any]:
"""
获取实例name为空时返回默认实例
:param name: 实例名称,可选,默认为 None
:return: 返回指定名称的实例,若 name 为 None 则返回默认实例
"""
if name:
return self._servers.get(name)
return self._instances.get(name)
return self._default_server
def get_configs(self) -> List[DownloaderConf]:
"""
获取下载器的配置
class _MediaServerBase:
:return: 返回下载器配置列表
"""
return ServiceConfigHelper.get_downloader_configs()
class _MediaServerBase(ServiceBase[TService, MediaServerConf]):
"""
媒体服务器基类
"""
_servers: Dict[str, Any] = {}
_configs: Dict[str, MediaServerConf] = {}
def get_configs(self) -> List[MediaServerConf]:
"""
获取媒体服务器的配置
def get_server(self, name: str) -> Optional[Any]:
:return: 返回媒体服务器配置列表
"""
获取Plex服务器
"""
return self._servers.get(name)
def get_config(self, name: str, mtype: str = None) -> Optional[MediaServerConf]:
"""
获取配置
"""
if not name:
return None
conf = self._configs.get(name)
if not mtype:
return conf
return conf if conf.type == mtype else None
return ServiceConfigHelper.get_mediaserver_configs()

View File

@@ -1,8 +1,7 @@
from typing import Optional, Tuple, Union, Any, List, Generator, Dict
from typing import Optional, Tuple, Union, Any, List, Generator
from app import schemas
from app.core.context import MediaInfo
from app.helper.mediaserver import MediaServerHelper
from app.log import logger
from app.modules import _ModuleBase, _MediaServerBase
from app.modules.emby.emby import Emby
@@ -10,22 +9,14 @@ from app.schemas import MediaServerConf
from app.schemas.types import MediaType
class EmbyModule(_ModuleBase, _MediaServerBase):
class EmbyModule(_ModuleBase, _MediaServerBase[Emby]):
def init_module(self) -> None:
"""
初始化模块
"""
# 读取媒体服务器配置
self._servers: Dict[str, Emby] = {}
self._configs: Dict[str, MediaServerConf] = {}
mediaservers = MediaServerHelper().get_mediaservers()
if not mediaservers:
return
for server in mediaservers:
if server.type == "emby" and server.enabled:
self._configs[server.name] = server
self._servers[server.name] = Emby(**server.config, sync_libraries=server.sync_libraries)
super().init_service(service_name=Emby.__name__.lower(),
service_type=lambda conf: Emby(**conf.config, sync_libraries=conf.sync_libraries))
@staticmethod
def get_name() -> str:
@@ -38,9 +29,9 @@ class EmbyModule(_ModuleBase, _MediaServerBase):
"""
测试模块连接性
"""
if not self._servers:
if not self._instances:
return None
for name, server in self._servers.items():
for name, server in self._instances.items():
if server.is_inactive():
server.reconnect()
if not server.get_user():
@@ -55,7 +46,7 @@ class EmbyModule(_ModuleBase, _MediaServerBase):
定时任务每10分钟调用一次
"""
# 定时重连
for name, server in self._servers.items():
for name, server in self._instances.items():
if server.is_inactive():
logger.info(f"Emby服务器 {name} 连接断开,尝试重连 ...")
server.reconnect()
@@ -68,7 +59,7 @@ class EmbyModule(_ModuleBase, _MediaServerBase):
:return: token or None
"""
# Emby认证
for server in self._servers.values():
for server in self._instances.values():
result = server.authenticate(name, password)
if result:
return result
@@ -87,7 +78,7 @@ class EmbyModule(_ModuleBase, _MediaServerBase):
server_config: MediaServerConf = self.get_config(source, 'emby')
if not server_config:
return None
server: Emby = self.get_server(source)
server: Emby = self.get_instance(source)
if not server:
return None
return server.get_webhook_message(form, args)
@@ -95,7 +86,7 @@ class EmbyModule(_ModuleBase, _MediaServerBase):
for conf in self._configs.values():
if conf.type != "emby":
continue
server = self.get_server(conf.name)
server = self.get_instance(conf.name)
if server:
result = server.get_webhook_message(form, args)
if result:
@@ -109,7 +100,7 @@ class EmbyModule(_ModuleBase, _MediaServerBase):
:param itemid: 媒体服务器ItemID
:return: 如不存在返回None存在时返回信息包括每季已存在所有集{type: movie/tv, seasons: {season: [episodes]}}
"""
for name, server in self._servers.items():
for name, server in self._instances.items():
if mediainfo.type == MediaType.MOVIE:
if itemid:
movie = server.get_iteminfo(itemid)
@@ -156,12 +147,12 @@ class EmbyModule(_ModuleBase, _MediaServerBase):
媒体数量统计
"""
if server:
server: Emby = self.get_server(server)
server: Emby = self.get_instance(server)
if not server:
return None
servers = [server]
else:
servers = self._servers.values()
servers = self._instances.values()
media_statistics = []
for server in servers:
media_statistic = server.get_medias_count()
@@ -177,16 +168,17 @@ class EmbyModule(_ModuleBase, _MediaServerBase):
"""
媒体库列表
"""
server: Emby = self.get_server(server)
server: Emby = self.get_instance(server)
if server:
return server.get_librarys(username=username, hidden=hidden)
return None
def mediaserver_items(self, server: str, library_id: str, start_index: int = 0, limit: int = 100) -> Optional[Generator]:
def mediaserver_items(self, server: str, library_id: str, start_index: int = 0, limit: int = 100) \
-> Optional[Generator]:
"""
媒体库项目列表
"""
server: Emby = self.get_server(server)
server: Emby = self.get_instance(server)
if server:
return server.get_items(library_id, start_index, limit)
return None
@@ -195,7 +187,7 @@ class EmbyModule(_ModuleBase, _MediaServerBase):
"""
媒体库项目详情
"""
server: Emby = self.get_server(server)
server: Emby = self.get_instance(server)
if server:
return server.get_iteminfo(item_id)
return None
@@ -205,7 +197,7 @@ class EmbyModule(_ModuleBase, _MediaServerBase):
"""
获取剧集信息
"""
server: Emby = self.get_server(server)
server: Emby = self.get_instance(server)
if not server:
return None
_, seasoninfo = server.get_tv_episodes(item_id=item_id)
@@ -221,7 +213,7 @@ class EmbyModule(_ModuleBase, _MediaServerBase):
"""
获取媒体服务器正在播放信息
"""
server: Emby = self.get_server(server)
server: Emby = self.get_instance(server)
if not server:
return []
return server.get_resume(num=count, username=username)
@@ -230,7 +222,7 @@ class EmbyModule(_ModuleBase, _MediaServerBase):
"""
获取媒体库播放地址
"""
server: Emby = self.get_server(server)
server: Emby = self.get_instance(server)
if not server:
return None
return server.get_play_url(item_id)
@@ -240,7 +232,7 @@ class EmbyModule(_ModuleBase, _MediaServerBase):
"""
获取媒体服务器最新入库条目
"""
server: Emby = self.get_server(server)
server: Emby = self.get_instance(server)
if not server:
return []
return server.get_latest(num=count, username=username)

View File

@@ -1,8 +1,7 @@
from typing import Optional, Tuple, Union, Any, List, Generator, Dict
from typing import Optional, Tuple, Union, Any, List, Generator
from app import schemas
from app.core.context import MediaInfo
from app.helper.mediaserver import MediaServerHelper
from app.log import logger
from app.modules import _ModuleBase, _MediaServerBase
from app.modules.jellyfin.jellyfin import Jellyfin
@@ -10,22 +9,14 @@ from app.schemas import MediaServerConf
from app.schemas.types import MediaType
class JellyfinModule(_ModuleBase, _MediaServerBase):
class JellyfinModule(_ModuleBase, _MediaServerBase[Jellyfin]):
def init_module(self) -> None:
"""
初始化模块
"""
# 读取媒体服务器配置
self._servers: Dict[str, Jellyfin] = {}
self._configs: Dict[str, MediaServerConf] = {}
mediaservers = MediaServerHelper().get_mediaservers()
if not mediaservers:
return
for server in mediaservers:
if server.type == "jellyfin" and server.enabled:
self._configs[server.name] = server
self._servers[server.name] = Jellyfin(**server.config, sync_libraries=server.sync_libraries)
super().init_service(service_name=Jellyfin.__name__.lower(),
service_type=lambda conf: Jellyfin(**conf.config, sync_libraries=conf.sync_libraries))
@staticmethod
def get_name() -> str:
@@ -39,7 +30,7 @@ class JellyfinModule(_ModuleBase, _MediaServerBase):
定时任务每10分钟调用一次
"""
# 定时重连
for name, server in self._servers.items():
for name, server in self._instances.items():
if server.is_inactive():
logger.info(f"Jellyfin {name} 服务器连接断开,尝试重连 ...")
server.reconnect()
@@ -51,9 +42,9 @@ class JellyfinModule(_ModuleBase, _MediaServerBase):
"""
测试模块连接性
"""
if not self._servers:
if not self._instances:
return None
for name, server in self._servers.items():
for name, server in self._instances.items():
if server.is_inactive():
server.reconnect()
if not server.get_user():
@@ -68,7 +59,7 @@ class JellyfinModule(_ModuleBase, _MediaServerBase):
:return: Token or None
"""
# Jellyfin认证
for server in self._servers.values():
for server in self._instances.values():
result = server.authenticate(name, password)
if result:
return result
@@ -87,7 +78,7 @@ class JellyfinModule(_ModuleBase, _MediaServerBase):
server_config: MediaServerConf = self.get_config(source, 'jellyfin')
if not server_config:
return None
server: Jellyfin = self.get_server(source)
server: Jellyfin = self.get_instance(source)
if not server:
return None
return server.get_webhook_message(body)
@@ -95,7 +86,7 @@ class JellyfinModule(_ModuleBase, _MediaServerBase):
for conf in self._configs.values():
if conf.type != "jellyfin":
continue
server = self.get_server(conf.name)
server = self.get_instance(conf.name)
if server:
result = server.get_webhook_message(body)
if result:
@@ -109,7 +100,7 @@ class JellyfinModule(_ModuleBase, _MediaServerBase):
:param itemid: 媒体服务器ItemID
:return: 如不存在返回None存在时返回信息包括每季已存在所有集{type: movie/tv, seasons: {season: [episodes]}}
"""
for name, server in self._servers.items():
for name, server in self._instances.items():
if mediainfo.type == MediaType.MOVIE:
if itemid:
movie = server.get_iteminfo(itemid)
@@ -154,12 +145,12 @@ class JellyfinModule(_ModuleBase, _MediaServerBase):
媒体数量统计
"""
if server:
server: Jellyfin = self.get_server(server)
server: Jellyfin = self.get_instance(server)
if not server:
return None
servers = [server]
else:
servers = self._servers.values()
servers = self._instances.values()
media_statistics = []
for server in servers:
media_statistic = server.get_medias_count()
@@ -175,16 +166,17 @@ class JellyfinModule(_ModuleBase, _MediaServerBase):
"""
媒体库列表
"""
server: Jellyfin = self.get_server(server)
server: Jellyfin = self.get_instance(server)
if server:
return server.get_librarys(username=username, hidden=hidden)
return None
def mediaserver_items(self, server: str, library_id: str, start_index: int = 0, limit: int = 100) -> Optional[Generator]:
def mediaserver_items(self, server: str, library_id: str, start_index: int = 0, limit: int = 100) \
-> Optional[Generator]:
"""
媒体库项目列表
"""
server: Jellyfin = self.get_server(server)
server: Jellyfin = self.get_instance(server)
if server:
return server.get_items(library_id, start_index, limit)
return None
@@ -193,7 +185,7 @@ class JellyfinModule(_ModuleBase, _MediaServerBase):
"""
媒体库项目详情
"""
server: Jellyfin = self.get_server(server)
server: Jellyfin = self.get_instance(server)
if server:
return server.get_iteminfo(item_id)
return None
@@ -203,7 +195,7 @@ class JellyfinModule(_ModuleBase, _MediaServerBase):
"""
获取剧集信息
"""
server: Jellyfin = self.get_server(server)
server: Jellyfin = self.get_instance(server)
if not server:
return None
_, seasoninfo = server.get_tv_episodes(item_id=item_id)
@@ -219,7 +211,7 @@ class JellyfinModule(_ModuleBase, _MediaServerBase):
"""
获取媒体服务器正在播放信息
"""
server: Jellyfin = self.get_server(server)
server: Jellyfin = self.get_instance(server)
if not server:
return []
return server.get_resume(num=count, username=username)
@@ -228,7 +220,7 @@ class JellyfinModule(_ModuleBase, _MediaServerBase):
"""
获取媒体库播放地址
"""
server: Jellyfin = self.get_server(server)
server: Jellyfin = self.get_instance(server)
if not server:
return None
return server.get_play_url(item_id)
@@ -238,7 +230,7 @@ class JellyfinModule(_ModuleBase, _MediaServerBase):
"""
获取媒体服务器最新入库条目
"""
server: Jellyfin = self.get_server(server)
server: Jellyfin = self.get_instance(server)
if not server:
return []
return server.get_latest(num=count, username=username)

View File

@@ -1,8 +1,7 @@
from typing import Optional, Tuple, Union, Any, List, Generator, Dict
from typing import Optional, Tuple, Union, Any, List, Generator
from app import schemas
from app.core.context import MediaInfo
from app.helper.mediaserver import MediaServerHelper
from app.log import logger
from app.modules import _ModuleBase, _MediaServerBase
from app.modules.plex.plex import Plex
@@ -10,22 +9,14 @@ from app.schemas import MediaServerConf
from app.schemas.types import MediaType
class PlexModule(_ModuleBase, _MediaServerBase):
class PlexModule(_ModuleBase, _MediaServerBase[Plex]):
def init_module(self) -> None:
"""
初始化模块
"""
# 读取媒体服务器配置
self._servers: Dict[str, Plex] = {}
self._configs: Dict[str, MediaServerConf] = {}
mediaservers = MediaServerHelper().get_mediaservers()
if not mediaservers:
return
for server in mediaservers:
if server.type == "plex" and server.enabled:
self._configs[server.name] = server
self._servers[server.name] = Plex(**server.config, sync_libraries=server.sync_libraries)
super().init_service(service_name=Plex.__name__.lower(),
service_type=lambda conf: Plex(**conf.config, sync_libraries=conf.sync_libraries))
@staticmethod
def get_name() -> str:
@@ -38,9 +29,9 @@ class PlexModule(_ModuleBase, _MediaServerBase):
"""
测试模块连接性
"""
if not self._servers:
if not self._instances:
return None
for name, server in self._servers.items():
for name, server in self._instances.items():
if server.is_inactive():
server.reconnect()
if not server.get_librarys():
@@ -55,7 +46,7 @@ class PlexModule(_ModuleBase, _MediaServerBase):
定时任务每10分钟调用一次
"""
# 定时重连
for name, server in self._servers.items():
for name, server in self._instances.items():
if server.is_inactive():
logger.info(f"Plex {name} 服务器连接断开,尝试重连 ...")
server.reconnect()
@@ -73,7 +64,7 @@ class PlexModule(_ModuleBase, _MediaServerBase):
server_config: MediaServerConf = self.get_config(source, 'plex')
if not server_config:
return None
server: Plex = self.get_server(source)
server: Plex = self.get_instance(source)
if not server:
return None
return server.get_webhook_message(body)
@@ -81,7 +72,7 @@ class PlexModule(_ModuleBase, _MediaServerBase):
for conf in self._configs.values():
if conf.type != "plex":
continue
server = self.get_server(conf.name)
server = self.get_instance(conf.name)
if server:
result = server.get_webhook_message(body)
if result:
@@ -95,7 +86,7 @@ class PlexModule(_ModuleBase, _MediaServerBase):
:param itemid: 媒体服务器ItemID
:return: 如不存在返回None存在时返回信息包括每季已存在所有集{type: movie/tv, seasons: {season: [episodes]}}
"""
for name, server in self._servers.items():
for name, server in self._instances.items():
if mediainfo.type == MediaType.MOVIE:
if itemid:
movie = server.get_iteminfo(itemid)
@@ -144,12 +135,12 @@ class PlexModule(_ModuleBase, _MediaServerBase):
媒体数量统计
"""
if server:
server: Plex = self.get_server(server)
server: Plex = self.get_instance(server)
if not server:
return None
servers = [server]
else:
servers = self._servers.values()
servers = self._instances.values()
media_statistics = []
for server in servers:
media_statistic = server.get_medias_count()
@@ -163,16 +154,17 @@ class PlexModule(_ModuleBase, _MediaServerBase):
"""
媒体库列表
"""
server: Plex = self.get_server(server)
server: Plex = self.get_instance(server)
if server:
return server.get_librarys(hidden)
return None
def mediaserver_items(self, server: str, library_id: str, start_index: int = 0, limit: int = 100) -> Optional[Generator]:
def mediaserver_items(self, server: str, library_id: str, start_index: int = 0, limit: int = 100) \
-> Optional[Generator]:
"""
媒体库项目列表
"""
server: Plex = self.get_server(server)
server: Plex = self.get_instance(server)
if server:
return server.get_items(library_id, start_index, limit)
return None
@@ -181,7 +173,7 @@ class PlexModule(_ModuleBase, _MediaServerBase):
"""
媒体库项目详情
"""
server: Plex = self.get_server(server)
server: Plex = self.get_instance(server)
if server:
return server.get_iteminfo(item_id)
return None
@@ -191,7 +183,7 @@ class PlexModule(_ModuleBase, _MediaServerBase):
"""
获取剧集信息
"""
server: Plex = self.get_server(server)
server: Plex = self.get_instance(server)
if not server:
return None
_, seasoninfo = server.get_tv_episodes(item_id=item_id)
@@ -206,7 +198,7 @@ class PlexModule(_ModuleBase, _MediaServerBase):
"""
获取媒体服务器正在播放信息
"""
server: Plex = self.get_server(server)
server: Plex = self.get_instance(server)
if not server:
return []
return server.get_resume(num=count)
@@ -215,7 +207,7 @@ class PlexModule(_ModuleBase, _MediaServerBase):
"""
获取媒体服务器最新入库条目
"""
server: Plex = self.get_server(server)
server: Plex = self.get_instance(server)
if not server:
return []
return server.get_latest(num=count)
@@ -224,7 +216,7 @@ class PlexModule(_ModuleBase, _MediaServerBase):
"""
获取媒体库播放地址
"""
server: Plex = self.get_server(server)
server: Plex = self.get_instance(server)
if not server:
return None
return server.get_play_url(item_id)

View File

@@ -1,6 +1,6 @@
import shutil
from pathlib import Path
from typing import Set, Tuple, Optional, Union, List, Dict
from typing import Set, Tuple, Optional, Union, List
from qbittorrentapi import TorrentFilesList
from torrentool.torrent import Torrent
@@ -8,7 +8,6 @@ from torrentool.torrent import Torrent
from app import schemas
from app.core.config import settings
from app.core.metainfo import MetaInfo
from app.helper.downloader import DownloaderHelper
from app.log import logger
from app.modules import _ModuleBase, _DownloaderBase
from app.modules.qbittorrent.qbittorrent import Qbittorrent
@@ -18,23 +17,14 @@ from app.utils.string import StringUtils
from app.utils.system import SystemUtils
class QbittorrentModule(_ModuleBase, _DownloaderBase):
class QbittorrentModule(_ModuleBase, _DownloaderBase[Qbittorrent]):
def init_module(self) -> None:
"""
初始化模块
"""
# 读取下载器配置
self._servers: Dict[str, Qbittorrent] = {}
downloaders = DownloaderHelper().get_downloaders()
if not downloaders:
return
for server in downloaders:
if server.type == "qbittorrent" and server.enabled:
self._servers[server.name] = Qbittorrent(**server.config)
if server.default:
self._default_server_name = server.name
self._default_server = self._servers[server.name]
super().init_service(service_name=Qbittorrent.__name__.lower(),
service_type=Qbittorrent)
@staticmethod
def get_name() -> str:
@@ -47,9 +37,9 @@ class QbittorrentModule(_ModuleBase, _DownloaderBase):
"""
测试模块连接性
"""
if not self._servers:
if not self._instances:
return None
for name, server in self._servers.items():
for name, server in self._instances.items():
if server.is_inactive():
server.reconnect()
if not server.transfer_info():
@@ -63,7 +53,7 @@ class QbittorrentModule(_ModuleBase, _DownloaderBase):
"""
定时任务每10分钟调用一次
"""
for name, server in self._servers.items():
for name, server in self._instances.items():
if server.is_inactive():
logger.info(f"Qbittorrent下载器 {name} 连接断开,尝试重连 ...")
server.reconnect()
@@ -103,7 +93,7 @@ class QbittorrentModule(_ModuleBase, _DownloaderBase):
return None, None, f"种子文件不存在:{content}"
# 获取下载器
server: Qbittorrent = self.get_server(downloader)
server: Qbittorrent = self.get_instance(downloader)
if not server:
return None
@@ -201,7 +191,7 @@ class QbittorrentModule(_ModuleBase, _DownloaderBase):
:return: 下载器中符合状态的种子列表
"""
# 获取下载器
server: Qbittorrent = self.get_server(downloader)
server: Qbittorrent = self.get_instance(downloader)
if not server:
return None
@@ -274,7 +264,7 @@ class QbittorrentModule(_ModuleBase, _DownloaderBase):
:param downloader: 下载器
:param transfer_type: 整理方式
"""
server: Qbittorrent = self.get_server(downloader)
server: Qbittorrent = self.get_instance(downloader)
if not server:
return None
server.set_torrents_tag(ids=hashs, tags=['已整理'])
@@ -298,7 +288,7 @@ class QbittorrentModule(_ModuleBase, _DownloaderBase):
:param downloader: 下载器
:return: bool
"""
server: Qbittorrent = self.get_server(downloader)
server: Qbittorrent = self.get_instance(downloader)
if not server:
return None
return server.delete_torrents(delete_file=delete_file, ids=hashs)
@@ -311,7 +301,7 @@ class QbittorrentModule(_ModuleBase, _DownloaderBase):
:param downloader: 下载器
:return: bool
"""
server: Qbittorrent = self.get_server(downloader)
server: Qbittorrent = self.get_instance(downloader)
if not server:
return None
return server.start_torrents(ids=hashs)
@@ -323,7 +313,7 @@ class QbittorrentModule(_ModuleBase, _DownloaderBase):
:param downloader: 下载器
:return: bool
"""
server: Qbittorrent = self.get_server(downloader)
server: Qbittorrent = self.get_instance(downloader)
if not server:
return None
return server.stop_torrents(ids=hashs)
@@ -332,7 +322,7 @@ class QbittorrentModule(_ModuleBase, _DownloaderBase):
"""
获取种子文件列表
"""
server: Qbittorrent = self.get_server(downloader)
server: Qbittorrent = self.get_instance(downloader)
if not server:
return None
return server.get_files(tid=tid)
@@ -342,12 +332,12 @@ class QbittorrentModule(_ModuleBase, _DownloaderBase):
下载器信息
"""
if downloader:
server: Qbittorrent = self.get_server(downloader)
server: Qbittorrent = self.get_instance(downloader)
if not server:
return None
servers = [server]
else:
servers = self._servers.values()
servers = self._instances.values()
# 调用Qbittorrent API查询实时信息
ret_info = []
for server in servers:

View File

@@ -4,28 +4,20 @@ from typing import Optional, Union, List, Tuple, Any
from app.core.config import settings
from app.core.context import MediaInfo, Context
from app.helper.notification import NotificationHelper
from app.log import logger
from app.modules import _ModuleBase, _MessageBase
from app.modules.slack.slack import Slack
from app.schemas import MessageChannel, CommingMessage, Notification
class SlackModule(_ModuleBase, _MessageBase):
class SlackModule(_ModuleBase, _MessageBase[Slack]):
def init_module(self) -> None:
"""
初始化模块
"""
clients = NotificationHelper().get_clients()
if not clients:
return
self._configs = {}
self._clients = {}
for client in clients:
if client.type == "slack" and client.enabled:
self._configs[client.name] = client
self._clients[client.name] = Slack(**client.config, name=client.name)
super().init_service(service_name=Slack.__name__.lower(),
service_type=Slack)
@staticmethod
def get_name() -> str:
@@ -35,16 +27,16 @@ class SlackModule(_ModuleBase, _MessageBase):
"""
停止模块
"""
for client in self._clients.values():
for client in self._instances.values():
client.stop()
def test(self) -> Optional[Tuple[bool, str]]:
"""
测试模块连接性
"""
if not self._clients:
if not self._instances:
return None
for name, client in self._clients.items():
for name, client in self._instances.items():
state = client.get_state()
if not state:
return False, f"Slack {name} 未就续"
@@ -223,7 +215,7 @@ class SlackModule(_ModuleBase, _MessageBase):
:return: 成功或失败
"""
for conf in self._configs.values():
if not self.checkMessage(message, conf.name):
if not self.check_message(message, conf.name):
continue
targets = message.targets
userid = message.userid
@@ -232,7 +224,7 @@ class SlackModule(_ModuleBase, _MessageBase):
if not userid:
logger.warn(f"用户没有指定 Slack用户ID消息无法发送")
return
client: Slack = self.get_client(conf.name)
client: Slack = self.get_instance(conf.name)
if client:
client.send_msg(title=message.title, text=message.text,
image=message.image, userid=userid, link=message.link)
@@ -245,9 +237,9 @@ class SlackModule(_ModuleBase, _MessageBase):
:return: 成功或失败
"""
for conf in self._configs.values():
if not self.checkMessage(message, conf.name):
if not self.check_message(message, conf.name):
continue
client: Slack = self.get_client(conf.name)
client: Slack = self.get_instance(conf.name)
if client:
client.send_medias_msg(title=message.title, medias=medias, userid=message.userid)
@@ -259,9 +251,9 @@ class SlackModule(_ModuleBase, _MessageBase):
:return: 成功或失败
"""
for conf in self._configs.values():
if not self.checkMessage(message, conf.name):
if not self.check_message(message, conf.name):
continue
client: Slack = self.get_client(conf.name)
client: Slack = self.get_instance(conf.name)
if client:
client.send_torrents_msg(title=message.title, torrents=torrents,
userid=message.userid)

View File

@@ -1,28 +1,20 @@
from typing import Optional, Union, List, Tuple, Any
from app.core.context import MediaInfo, Context
from app.helper.notification import NotificationHelper
from app.log import logger
from app.modules import _ModuleBase, _MessageBase
from app.modules.synologychat.synologychat import SynologyChat
from app.schemas import MessageChannel, CommingMessage, Notification
class SynologyChatModule(_ModuleBase, _MessageBase):
class SynologyChatModule(_ModuleBase, _MessageBase[SynologyChat]):
def init_module(self) -> None:
"""
初始化模块
"""
clients = NotificationHelper().get_clients()
if not clients:
return
self._configs = {}
self._clients = {}
for client in clients:
if client.type == "synologychat" and client.enabled:
self._configs[client.name] = client
self._clients[client.name] = SynologyChat(**client.config)
super().init_service(service_name=SynologyChat.__name__.lower(),
service_type=SynologyChat)
@staticmethod
def get_name() -> str:
@@ -35,9 +27,9 @@ class SynologyChatModule(_ModuleBase, _MessageBase):
"""
测试模块连接性
"""
if not self._clients:
if not self._instances:
return None
for name, client in self._clients.items():
for name, client in self._instances.items():
state = client.get_state()
if not state:
return False, f"Synology Chat {name} 未就续"
@@ -64,7 +56,7 @@ class SynologyChatModule(_ModuleBase, _MessageBase):
client_config = self.get_config(source, 'synologychat')
if not client_config:
return None
client: SynologyChat = self.get_client(source)
client: SynologyChat = self.get_instance(source)
# 解析消息
message: dict = form
if not message:
@@ -94,7 +86,7 @@ class SynologyChatModule(_ModuleBase, _MessageBase):
:return: 成功或失败
"""
for conf in self._configs.values():
if not self.checkMessage(message, conf.name):
if not self.check_message(message, conf.name):
continue
targets = message.targets
userid = message.userid
@@ -103,7 +95,7 @@ class SynologyChatModule(_ModuleBase, _MessageBase):
if not userid:
logger.warn(f"用户没有指定 SynologyChat用户ID消息无法发送")
return
client: SynologyChat = self.get_client(conf.name)
client: SynologyChat = self.get_instance(conf.name)
if client:
client.send_msg(title=message.title, text=message.text,
image=message.image, userid=userid, link=message.link)
@@ -116,9 +108,9 @@ class SynologyChatModule(_ModuleBase, _MessageBase):
:return: 成功或失败
"""
for conf in self._configs.values():
if not self.checkMessage(message, conf.name):
if not self.check_message(message, conf.name):
continue
client: SynologyChat = self.get_client(conf.name)
client: SynologyChat = self.get_instance(conf.name)
if client:
client.send_medias_msg(title=message.title, medias=medias,
userid=message.userid)
@@ -131,9 +123,9 @@ class SynologyChatModule(_ModuleBase, _MessageBase):
:return: 成功或失败
"""
for conf in self._configs.values():
if not self.checkMessage(message, conf.name):
if not self.check_message(message, conf.name):
continue
client: SynologyChat = self.get_client(conf.name)
client: SynologyChat = self.get_instance(conf.name)
if client:
client.send_torrents_msg(title=message.title, torrents=torrents,
userid=message.userid, link=message.link)

View File

@@ -3,28 +3,20 @@ from typing import Optional, Union, List, Tuple, Any, Dict
from app.core.config import settings
from app.core.context import MediaInfo, Context
from app.helper.notification import NotificationHelper
from app.log import logger
from app.modules import _ModuleBase, _MessageBase
from app.modules.telegram.telegram import Telegram
from app.schemas import MessageChannel, CommingMessage, Notification
class TelegramModule(_ModuleBase, _MessageBase):
class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
def init_module(self) -> None:
"""
初始化模块
"""
clients = NotificationHelper().get_clients()
if not clients:
return
self._configs = {}
self._clients = {}
for client in clients:
if client.type == "telegram" and client.enabled:
self._configs[client.name] = client
self._clients[client.name] = Telegram(**client.config, name=client.name)
super().init_service(service_name=Telegram.__name__.lower(),
service_type=Telegram)
@staticmethod
def get_name() -> str:
@@ -34,16 +26,16 @@ class TelegramModule(_ModuleBase, _MessageBase):
"""
停止模块
"""
for client in self._clients.values():
for client in self._instances.values():
client.stop()
def test(self) -> Optional[Tuple[bool, str]]:
"""
测试模块连接性
"""
if not self._clients:
if not self._instances:
return None
for name, client in self._clients.items():
for name, client in self._instances.items():
state = client.get_state()
if not state:
return False, f"Telegram {name} 未就续"
@@ -92,7 +84,7 @@ class TelegramModule(_ModuleBase, _MessageBase):
client_config = self.get_config(source, 'telegram')
if not client_config:
return None
client: Telegram = self.get_client(source)
client: Telegram = self.get_instance(source)
# 校验token
token = args.get("token")
if not token or token != settings.API_TOKEN:
@@ -136,7 +128,7 @@ class TelegramModule(_ModuleBase, _MessageBase):
:return: 成功或失败
"""
for conf in self._configs.values():
if not self.checkMessage(message, conf.name):
if not self.check_message(message, conf.name):
continue
targets = message.targets
userid = message.userid
@@ -145,7 +137,7 @@ class TelegramModule(_ModuleBase, _MessageBase):
if not userid:
logger.warn(f"用户没有指定 Telegram用户ID消息无法发送")
return
client: Telegram = self.get_client(conf.name)
client: Telegram = self.get_instance(conf.name)
if client:
client.send_msg(title=message.title, text=message.text,
image=message.image, userid=userid, link=message.link)
@@ -158,9 +150,9 @@ class TelegramModule(_ModuleBase, _MessageBase):
:return: 成功或失败
"""
for conf in self._configs.values():
if not self.checkMessage(message, conf.name):
if not self.check_message(message, conf.name):
continue
client: Telegram = self.get_client(conf.name)
client: Telegram = self.get_instance(conf.name)
if client:
client.send_medias_msg(title=message.title, medias=medias,
userid=message.userid, link=message.link)
@@ -173,9 +165,9 @@ class TelegramModule(_ModuleBase, _MessageBase):
:return: 成功或失败
"""
for conf in self._configs.values():
if not self.checkMessage(message, conf.name):
if not self.check_message(message, conf.name):
continue
client: Telegram = self.get_client(conf.name)
client: Telegram = self.get_instance(conf.name)
if client:
client.send_torrents_msg(title=message.title, torrents=torrents,
userid=message.userid, link=message.link)
@@ -185,5 +177,5 @@ class TelegramModule(_ModuleBase, _MessageBase):
注册命令,实现这个函数接收系统可用的命令菜单
:param commands: 命令字典
"""
for client in self._clients.values():
for client in self._instances.values():
client.register_commands(commands)

View File

@@ -1,6 +1,6 @@
import shutil
from pathlib import Path
from typing import Set, Tuple, Optional, Union, List, Dict
from typing import Set, Tuple, Optional, Union, List
from torrentool.torrent import Torrent
from transmission_rpc import File
@@ -8,7 +8,6 @@ from transmission_rpc import File
from app import schemas
from app.core.config import settings
from app.core.metainfo import MetaInfo
from app.helper.downloader import DownloaderHelper
from app.log import logger
from app.modules import _ModuleBase, _DownloaderBase
from app.modules.transmission.transmission import Transmission
@@ -18,20 +17,14 @@ from app.utils.string import StringUtils
from app.utils.system import SystemUtils
class TransmissionModule(_ModuleBase, _DownloaderBase):
class TransmissionModule(_ModuleBase, _DownloaderBase[Transmission]):
def init_module(self) -> None:
# 读取下载器配置
self._servers: Dict[str, Transmission] = {}
downloaders = DownloaderHelper().get_downloaders()
if not downloaders:
return
for server in downloaders:
if server.type == "transmission" and server.enabled:
self._servers[server.name] = Transmission(**server.config)
if server.default:
self._default_server_name = server.name
self._default_server = self._servers[server.name]
"""
初始化模块
"""
super().init_service(service_name=Transmission.__name__.lower(),
service_type=Transmission)
@staticmethod
def get_name() -> str:
@@ -44,9 +37,9 @@ class TransmissionModule(_ModuleBase, _DownloaderBase):
"""
测试模块连接性
"""
if not self._servers:
if not self._instances:
return None
for name, server in self._servers.items():
for name, server in self._instances.items():
if server.is_inactive():
server.reconnect()
if not server.transfer_info():
@@ -61,7 +54,7 @@ class TransmissionModule(_ModuleBase, _DownloaderBase):
定时任务每10分钟调用一次
"""
# 定时重连
for name, server in self._servers.items():
for name, server in self._instances.items():
if server.is_inactive():
logger.info(f"Transmission下载器 {name} 连接断开,尝试重连 ...")
server.reconnect()
@@ -100,7 +93,7 @@ class TransmissionModule(_ModuleBase, _DownloaderBase):
return None, None, f"种子文件不存在:{content}"
# 获取下载器
server: Transmission = self.get_server(downloader)
server: Transmission = self.get_instance(downloader)
if not server:
return None
@@ -191,7 +184,7 @@ class TransmissionModule(_ModuleBase, _DownloaderBase):
:return: 下载器中符合状态的种子列表
"""
# 获取下载器
server: Transmission = self.get_server(downloader)
server: Transmission = self.get_instance(downloader)
if not server:
return None
ret_torrents = []
@@ -259,7 +252,7 @@ class TransmissionModule(_ModuleBase, _DownloaderBase):
:param transfer_type: 整理方式
"""
# 获取下载器
server: Transmission = self.get_server(downloader)
server: Transmission = self.get_instance(downloader)
if not server:
return None
# 获取原标签
@@ -291,7 +284,7 @@ class TransmissionModule(_ModuleBase, _DownloaderBase):
:return: bool
"""
# 获取下载器
server: Transmission = self.get_server(downloader)
server: Transmission = self.get_instance(downloader)
if not server:
return None
return server.delete_torrents(delete_file=delete_file, ids=hashs)
@@ -305,7 +298,7 @@ class TransmissionModule(_ModuleBase, _DownloaderBase):
:return: bool
"""
# 获取下载器
server: Transmission = self.get_server(downloader)
server: Transmission = self.get_instance(downloader)
if not server:
return None
return server.start_torrents(ids=hashs)
@@ -319,7 +312,7 @@ class TransmissionModule(_ModuleBase, _DownloaderBase):
:return: bool
"""
# 获取下载器
server: Transmission = self.get_server(downloader)
server: Transmission = self.get_instance(downloader)
if not server:
return None
return server.start_torrents(ids=hashs)
@@ -329,7 +322,7 @@ class TransmissionModule(_ModuleBase, _DownloaderBase):
获取种子文件列表
"""
# 获取下载器
server: Transmission = self.get_server(downloader)
server: Transmission = self.get_instance(downloader)
if not server:
return None
return server.get_files(tid=tid)
@@ -339,12 +332,12 @@ class TransmissionModule(_ModuleBase, _DownloaderBase):
下载器信息
"""
if downloader:
server: Transmission = self.get_server(downloader)
server: Transmission = self.get_instance(downloader)
if not server:
return None
servers = [server]
else:
servers = self._servers.values()
servers = self._instances.values()
# 调用Qbittorrent API查询实时信息
ret_info = []
for server in servers:

View File

@@ -3,28 +3,20 @@ from typing import Optional, Union, List, Tuple, Any, Dict
from app.core.config import settings
from app.core.context import Context, MediaInfo
from app.helper.notification import NotificationHelper
from app.log import logger
from app.modules import _ModuleBase, _MessageBase
from app.modules.vocechat.vocechat import VoceChat
from app.schemas import MessageChannel, CommingMessage, Notification
class VoceChatModule(_ModuleBase, _MessageBase):
class VoceChatModule(_ModuleBase, _MessageBase[VoceChat]):
def init_module(self) -> None:
"""
初始化模块
"""
clients = NotificationHelper().get_clients()
if not clients:
return
self._configs = {}
self._clients = {}
for client in clients:
if client.type == "vocechat" and client.enabled:
self._configs[client.name] = client
self._clients[client.name] = VoceChat(**client.config)
super().init_service(service_name=VoceChat.__name__.lower(),
service_type=VoceChat)
@staticmethod
def get_name() -> str:
@@ -37,9 +29,9 @@ class VoceChatModule(_ModuleBase, _MessageBase):
"""
测试模块连接性
"""
if not self._clients:
if not self._instances:
return None
for name, client in self._clients.items():
for name, client in self._instances.items():
state = client.get_state()
if not state:
return False, f"VoceChat {name} 未就续"
@@ -122,13 +114,13 @@ class VoceChatModule(_ModuleBase, _MessageBase):
:return: 成功或失败
"""
for conf in self._configs.values():
if not self.checkMessage(message, conf.name):
if not self.check_message(message, conf.name):
continue
targets = message.targets
userid = message.userid
if not message.userid and targets:
userid = targets.get('telegram_userid')
client: VoceChat = self.get_client(conf.name)
client: VoceChat = self.get_instance(conf.name)
if client:
client.send_msg(title=message.title, text=message.text,
userid=userid, link=message.link)
@@ -141,9 +133,9 @@ class VoceChatModule(_ModuleBase, _MessageBase):
:return: 成功或失败
"""
for conf in self._configs.values():
if not self.checkMessage(message, conf.name):
if not self.check_message(message, conf.name):
continue
client: VoceChat = self.get_client(conf.name)
client: VoceChat = self.get_instance(conf.name)
if client:
client.send_msg(title=message.title, userid=message.userid)
client.send_medias_msg(title=message.title, medias=medias,
@@ -157,7 +149,7 @@ class VoceChatModule(_ModuleBase, _MessageBase):
:return: 成功或失败
"""
for conf in self._configs.values():
if not self.checkMessage(message, conf.name):
if not self.check_message(message, conf.name):
continue
targets = message.targets
userid = message.userid
@@ -166,7 +158,7 @@ class VoceChatModule(_ModuleBase, _MessageBase):
if not userid:
logger.warn(f"用户没有指定 VoceChat用户ID消息无法发送")
return
client: VoceChat = self.get_client(conf.name)
client: VoceChat = self.get_instance(conf.name)
if client:
client.send_torrents_msg(title=message.title, torrents=torrents,
userid=userid, link=message.link)

View File

@@ -4,24 +4,18 @@ from typing import Union, Tuple
from pywebpush import webpush, WebPushException
from app.core.config import global_vars, settings
from app.helper.notification import NotificationHelper
from app.log import logger
from app.modules import _ModuleBase, _MessageBase
from app.schemas import Notification
class WebPushModule(_ModuleBase, _MessageBase):
def init_module(self) -> None:
"""
初始化模块
"""
clients = NotificationHelper().get_clients()
if not clients:
return
self._configs = {}
for client in clients:
if client.type == "webpush" and client.enabled:
self._configs[client.name] = client
super().init_service(service_name=self.get_name().lower())
@staticmethod
def get_name() -> str:
@@ -46,7 +40,7 @@ class WebPushModule(_ModuleBase, _MessageBase):
:return: 成功或失败
"""
for conf in self._configs.values():
if not self.checkMessage(message, conf.name):
if not self.check_message(message, conf.name):
continue
webpush_users = conf.config.get("WEBPUSH_USERNAME") or ""
if webpush_users:

View File

@@ -2,7 +2,6 @@ import xml.dom.minidom
from typing import Optional, Union, List, Tuple, Any, Dict
from app.core.context import Context, MediaInfo
from app.helper.notification import NotificationHelper
from app.log import logger
from app.modules import _ModuleBase, _MessageBase
from app.modules.wechat.WXBizMsgCrypt3 import WXBizMsgCrypt
@@ -11,21 +10,14 @@ from app.schemas import MessageChannel, CommingMessage, Notification
from app.utils.dom import DomUtils
class WechatModule(_ModuleBase, _MessageBase):
class WechatModule(_ModuleBase, _MessageBase[WeChat]):
def init_module(self) -> None:
"""
初始化模块
"""
clients = NotificationHelper().get_clients()
if not clients:
return
self._configs = {}
self._clients = {}
for client in clients:
if client.type == "wechat" and client.enabled:
self._configs[client.name] = client
self._clients[client.name] = WeChat(**client.config)
super().init_service(service_name=WeChat.__name__.lower(),
service_type=WeChat)
@staticmethod
def get_name() -> str:
@@ -38,9 +30,9 @@ class WechatModule(_ModuleBase, _MessageBase):
"""
测试模块连接性
"""
if not self._clients:
if not self._instances:
return None
for name, client in self._clients.items():
for name, client in self._instances.items():
state = client.get_state()
if not state:
return False, f"企业微信 {name} 未就续"
@@ -67,7 +59,7 @@ class WechatModule(_ModuleBase, _MessageBase):
client_config = self.get_config(source, 'wechat')
if not client_config:
return None
client: WeChat = self.get_client(source)
client: WeChat = self.get_instance(source)
# URL参数
sVerifyMsgSig = args.get("msg_signature")
sVerifyTimeStamp = args.get("timestamp")
@@ -159,7 +151,7 @@ class WechatModule(_ModuleBase, _MessageBase):
:return: 成功或失败
"""
for conf in self._configs.values():
if not self.checkMessage(message, conf.name):
if not self.check_message(message, conf.name):
continue
targets = message.targets
userid = message.userid
@@ -168,7 +160,7 @@ class WechatModule(_ModuleBase, _MessageBase):
if not userid:
logger.warn(f"用户没有指定 微信用户ID消息无法发送")
return
client: WeChat = self.get_client(conf.name)
client: WeChat = self.get_instance(conf.name)
if client:
client.send_msg(title=message.title, text=message.text,
image=message.image, userid=userid, link=message.link)
@@ -181,9 +173,9 @@ class WechatModule(_ModuleBase, _MessageBase):
:return: 成功或失败
"""
for conf in self._configs.values():
if not self.checkMessage(message, conf.name):
if not self.check_message(message, conf.name):
continue
client: WeChat = self.get_client(conf.name)
client: WeChat = self.get_instance(conf.name)
if client:
# 先发送标题
client.send_msg(title=message.title, userid=message.userid, link=message.link)
@@ -198,9 +190,9 @@ class WechatModule(_ModuleBase, _MessageBase):
:return: 成功或失败
"""
for conf in self._configs.values():
if not self.checkMessage(message, conf.name):
if not self.check_message(message, conf.name):
continue
client: WeChat = self.get_client(conf.name)
client: WeChat = self.get_instance(conf.name)
if client:
client.send_torrents_msg(title=message.title, torrents=torrents,
userid=message.userid, link=message.link)
@@ -210,5 +202,5 @@ class WechatModule(_ModuleBase, _MessageBase):
注册命令,实现这个函数接收系统可用的命令菜单
:param commands: 命令字典
"""
for client in self._clients.values():
for client in self._instances.values():
client.create_menus(commands)