diff --git a/app/api/endpoints/message.py b/app/api/endpoints/message.py index 7af814d7..f438e9e7 100644 --- a/app/api/endpoints/message.py +++ b/app/api/endpoints/message.py @@ -81,26 +81,29 @@ def wechat_verify(echostr: str, msg_signature: str, timestamp: Union[str, int], """ 微信验证响应 """ - clients = ServiceConfigHelper.get_notification_configs() - if not clients: - return - for client in clients: - if client.type == "wechat" and client.enabled and (not source or client.name == source): - try: - wxcpt = WXBizMsgCrypt(sToken=client.config.get('WECHAT_TOKEN'), - sEncodingAESKey=client.config.get('WECHAT_ENCODING_AESKEY'), - sReceiveId=client.config.get('WECHAT_CORPID')) - ret, sEchoStr = wxcpt.VerifyURL(sMsgSignature=msg_signature, - sTimeStamp=timestamp, - sNonce=nonce, - sEchoStr=echostr) - if ret == 0: - # 验证URL成功,将sEchoStr返回给企业号 - return PlainTextResponse(sEchoStr) - except Exception as err: - logger.error(f"微信请求验证失败: {str(err)}") - return str(err) - return "未找到对应的消息配置" + # 获取服务配置 + client_configs = ServiceConfigHelper.get_notification_configs() + if not client_configs: + return "未找到对应的消息配置" + client_config = next((config for config in client_configs if + config.type == "wechat" and config.enabled and (not source or config.name == source)), None) + if not client_config: + return "未找到对应的消息配置" + try: + wxcpt = WXBizMsgCrypt(sToken=client_config.config.get('WECHAT_TOKEN'), + sEncodingAESKey=client_config.config.get('WECHAT_ENCODING_AESKEY'), + sReceiveId=client_config.config.get('WECHAT_CORPID')) + ret, sEchoStr = wxcpt.VerifyURL(sMsgSignature=msg_signature, + sTimeStamp=timestamp, + sNonce=nonce, + sEchoStr=echostr) + if ret == 0: + # 验证URL成功,将sEchoStr返回给企业号 + return PlainTextResponse(sEchoStr) + return "微信验证失败" + except Exception as err: + logger.error(f"微信请求验证失败: {str(err)}") + return str(err) def vocechat_verify() -> Any: diff --git a/app/modules/__init__.py b/app/modules/__init__.py index d6db3474..b759d832 100644 --- a/app/modules/__init__.py +++ b/app/modules/__init__.py @@ -1,5 +1,5 @@ from abc import abstractmethod, ABCMeta -from typing import Generic, Tuple, Union, TypeVar, Type, Dict, Optional, Callable, Any +from typing import Generic, Tuple, Union, TypeVar, Type, Dict, Optional, Callable from app.helper.service import ServiceConfigHelper from app.schemas import Notification, MessageChannel, NotificationConf, MediaServerConf, DownloaderConf @@ -42,7 +42,7 @@ class _ModuleBase(metaclass=ABCMeta): 获取模块类型 """ pass - + @staticmethod @abstractmethod def get_priority() -> int: @@ -120,16 +120,19 @@ class ServiceBase(Generic[TService, TConf], metaclass=ABCMeta): """ return self._instances or {} - def get_instance(self, name: str) -> Optional[TService]: + def get_instance(self, name: Optional[str] = None) -> Optional[TService]: """ - 获取服务实例 + 获取指定名称的服务实例 - :param name: 实例名称 - :return: 返回对应名称的服务实例,若不存在则返回 None + :param name: 实例名称,可选。如果为 None,则返回默认实例 + :return: 返回符合条件的服务实例,若不存在则返回 None """ - if not name or not self._instances: + if not self._instances: return None - return self._instances.get(name) + if name: + return self._instances.get(name) + name = self.get_default_config_name() + return self._instances.get(name) if name else None @abstractmethod def get_configs(self) -> Dict[str, TConf]: @@ -140,16 +143,33 @@ class ServiceBase(Generic[TService, TConf], metaclass=ABCMeta): """ pass - def get_config(self, name: str) -> Optional[TConf]: + def get_config(self, name: Optional[str] = None) -> Optional[TConf]: """ - 获取配置,支持类型过滤 + 获取指定名称的服务配置 - :param name: 配置名称 + :param name: 配置名称,可选。如果为 None,则返回默认服务配置 :return: 返回符合条件的配置,若不存在则返回 None """ - if not name or not self._configs: + if not self._configs: return None - return self._configs.get(name) + if name: + return self._configs.get(name) + name = self.get_default_config_name() + return self._configs.get(name) if name else None + + def get_default_config_name(self) -> Optional[str]: + """ + 获取默认服务配置的名称 + + :return: 返回第一个设置为默认的配置名称;如果没有默认配置,则返回第一个配置的名称;如果没有配置,返回 None + """ + # 优先查找默认配置 + for conf in self._configs.values(): + if getattr(conf, "default", False): + return conf.name + # 如果没有默认配置,返回第一个配置的名称 + first_conf = next(iter(self._configs.values()), None) + return first_conf.name if first_conf else None class _MessageBase(ServiceBase[TService, NotificationConf]): @@ -206,40 +226,6 @@ class _DownloaderBase(ServiceBase[TService, DownloaderConf]): 下载器基类 """ - def __init__(self): - """ - 初始化下载器基类,并设置默认服务器 - """ - super().__init__() - self._default_server: Any = None - self._default_server_name: Optional[str] = None - - def init_service(self, service_name: str, - service_type: Optional[Union[Type[TService], Callable[..., TService]]] = None): - """ - 初始化服务,获取配置并实例化对应服务 - - :param service_name: 服务名称,作为配置匹配的依据 - :param service_type: 服务的类型,可以是类类型(Type[TService])或工厂函数(Callable),用于创建服务实例 - """ - super().init_service(service_name=service_name, service_type=service_type) - if self._configs: - for conf in self._configs.values(): - if conf.default: - self._default_server_name = conf.name - self._default_server = self.get_instance(conf.name) - - def get_instance(self, name: str = None) -> Optional[Any]: - """ - 获取实例,name为空时,返回默认实例 - - :param name: 实例名称,可选,默认为 None - :return: 返回指定名称的实例,若 name 为 None 则返回默认实例 - """ - if name: - return self._instances.get(name) - return self._default_server - def get_configs(self) -> Dict[str, DownloaderConf]: """ 获取已启用的下载器的配置字典 diff --git a/app/modules/qbittorrent/__init__.py b/app/modules/qbittorrent/__init__.py index 2af31e5a..303356ff 100644 --- a/app/modules/qbittorrent/__init__.py +++ b/app/modules/qbittorrent/__init__.py @@ -150,7 +150,7 @@ class QbittorrentModule(_ModuleBase, _DownloaderBase[Qbittorrent]): if settings.TORRENT_TAG and settings.TORRENT_TAG not in torrent_tags: logger.info(f"给种子 {torrent_hash} 打上标签:{settings.TORRENT_TAG}") server.set_torrents_tag(ids=torrent_hash, tags=[settings.TORRENT_TAG]) - return downloader or self._default_server_name, torrent_hash, f"下载任务已存在" + return downloader or self.get_default_config_name(), torrent_hash, f"下载任务已存在" return None, None, f"添加种子任务失败:{content}" else: # 获取种子Hash @@ -162,7 +162,7 @@ class QbittorrentModule(_ModuleBase, _DownloaderBase[Qbittorrent]): # 种子文件 torrent_files = server.get_files(torrent_hash) if not torrent_files: - return downloader or self._default_server_name, torrent_hash, "获取种子文件失败,下载任务可能在暂停状态" + return downloader or self.get_default_config_name(), torrent_hash, "获取种子文件失败,下载任务可能在暂停状态" # 不需要的文件ID file_ids = [] @@ -187,11 +187,11 @@ class QbittorrentModule(_ModuleBase, _DownloaderBase[Qbittorrent]): server.torrents_set_force_start(torrent_hash) else: server.start_torrents(torrent_hash) - return downloader or self._default_server_name, torrent_hash, f"添加下载成功,已选择集数:{sucess_epidised}" + return downloader or self.get_default_config_name(), torrent_hash, f"添加下载成功,已选择集数:{sucess_epidised}" else: if server.is_force_resume(): server.torrents_set_force_start(torrent_hash) - return downloader or self._default_server_name, torrent_hash, "添加下载成功" + return downloader or self.get_default_config_name(), torrent_hash, "添加下载成功" def list_torrents(self, status: TorrentStatus = None, hashs: Union[list, str] = None, diff --git a/app/modules/slack/__init__.py b/app/modules/slack/__init__.py index 34c8fc85..7266ee82 100644 --- a/app/modules/slack/__init__.py +++ b/app/modules/slack/__init__.py @@ -182,14 +182,8 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]): ] } """ - # 获取客户端 - client_config = None - if source: - client_config = self.get_config(source) - else: - client_configs = self.get_configs() - if client_configs: - client_config = list(client_configs.values())[0] + # 获取服务配置 + client_config = self.get_config(source) if not client_config: return None try: diff --git a/app/modules/synologychat/__init__.py b/app/modules/synologychat/__init__.py index 05a0faf2..131146d8 100644 --- a/app/modules/synologychat/__init__.py +++ b/app/modules/synologychat/__init__.py @@ -68,14 +68,8 @@ class SynologyChatModule(_ModuleBase, _MessageBase[SynologyChat]): :return: 渠道、消息体 """ try: - # 来源 - client_config = None - if source: - client_config = self.get_config(source) - else: - client_configs = self.get_configs() - if client_configs: - client_config = list(client_configs.values())[0] + # 获取服务配置 + client_config = self.get_config(source) if not client_config: return None client: SynologyChat = self.get_instance(client_config.name) diff --git a/app/modules/telegram/__init__.py b/app/modules/telegram/__init__.py index 23106afd..2e8c6b9f 100644 --- a/app/modules/telegram/__init__.py +++ b/app/modules/telegram/__init__.py @@ -95,14 +95,8 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]): } } """ - # 获取渠道 - client_config = None - if source: - client_config = self.get_config(source) - else: - client_configs = self.get_configs() - if client_configs: - client_config = list(client_configs.values())[0] + # 获取服务配置 + client_config = self.get_config(source) if not client_config: return None client: Telegram = self.get_instance(client_config.name) diff --git a/app/modules/transmission/__init__.py b/app/modules/transmission/__init__.py index 46bec806..c4db841a 100644 --- a/app/modules/transmission/__init__.py +++ b/app/modules/transmission/__init__.py @@ -153,7 +153,7 @@ class TransmissionModule(_ModuleBase, _DownloaderBase[Transmission]): if settings.TORRENT_TAG and settings.TORRENT_TAG not in labels: labels.append(settings.TORRENT_TAG) server.set_torrent_tag(ids=torrent_hash, tags=labels) - return downloader or self._default_server_name, torrent_hash, f"下载任务已存在" + return downloader or self.get_default_config_name(), torrent_hash, f"下载任务已存在" return None, None, f"添加种子任务失败:{content}" else: torrent_hash = torrent.hashString @@ -161,7 +161,7 @@ class TransmissionModule(_ModuleBase, _DownloaderBase[Transmission]): # 选择文件 torrent_files = server.get_files(torrent_hash) if not torrent_files: - return downloader or self._default_server_name, torrent_hash, "获取种子文件失败,下载任务可能在暂停状态" + return downloader or self.get_default_config_name(), torrent_hash, "获取种子文件失败,下载任务可能在暂停状态" # 需要的文件信息 file_ids = [] unwanted_file_ids = [] @@ -182,9 +182,9 @@ class TransmissionModule(_ModuleBase, _DownloaderBase[Transmission]): server.set_unwanted_files(torrent_hash, unwanted_file_ids) # 开始任务 server.start_torrents(torrent_hash) - return downloader or self._default_server_name, torrent_hash, "添加下载任务成功" + return downloader or self.get_default_config_name(), torrent_hash, "添加下载任务成功" else: - return downloader or self._default_server_name, torrent_hash, "添加下载任务成功" + return downloader or self.get_default_config_name(), torrent_hash, "添加下载任务成功" def list_torrents(self, status: TorrentStatus = None, hashs: Union[list, str] = None, diff --git a/app/modules/vocechat/__init__.py b/app/modules/vocechat/__init__.py index 3e5b6954..9017bd19 100644 --- a/app/modules/vocechat/__init__.py +++ b/app/modules/vocechat/__init__.py @@ -84,14 +84,8 @@ class VoceChatModule(_ModuleBase, _MessageBase[VoceChat]): "target": { "gid": 2 } //发送给谁,gid代表是发送给频道,uid代表是发送给个人,此时的数据结构举例:{"uid":1} } """ - # 获取渠道 - client_config = None - if source: - client_config = self.get_config(source) - else: - client_configs = self.get_configs() - if client_configs: - client_config = list(client_configs.values())[0] + # 获取服务配置 + client_config = self.get_config(source) if not client_config: return None # 报文体 diff --git a/app/modules/wechat/__init__.py b/app/modules/wechat/__init__.py index db02ac80..40051aa2 100644 --- a/app/modules/wechat/__init__.py +++ b/app/modules/wechat/__init__.py @@ -71,14 +71,8 @@ class WechatModule(_ModuleBase, _MessageBase[WeChat]): :return: 渠道、消息体 """ try: - # 获取客户端 - client_config = None - if source: - client_config = self.get_config(source) - else: - client_configs = self.get_configs() - if client_configs: - client_config = list(client_configs.values())[0] + # 获取服务配置 + client_config = self.get_config(source) if not client_config: return None client: WeChat = self.get_instance(client_config.name)