diff --git a/app/chain/__init__.py b/app/chain/__init__.py index d1b97465..21a81aee 100644 --- a/app/chain/__init__.py +++ b/app/chain/__init__.py @@ -16,11 +16,13 @@ from app.core.event import EventManager from app.core.meta import MetaBase from app.core.module import ModuleManager from app.db.message_oper import MessageOper +from app.db.user_oper import UserOper from app.helper.message import MessageHelper +from app.helper.notification import NotificationHelper from app.log import logger from app.schemas import TransferInfo, TransferTorrent, ExistMediaInfo, DownloadingTorrent, CommingMessage, Notification, \ WebhookEventInfo, TmdbEpisode, MediaPerson, FileItem -from app.schemas.types import TorrentStatus, MediaType, MediaImageType, EventType +from app.schemas.types import TorrentStatus, MediaType, MediaImageType, EventType, NotificationType from app.utils.object import ObjectUtils @@ -37,6 +39,8 @@ class ChainBase(metaclass=ABCMeta): self.eventmanager = EventManager() self.messageoper = MessageOper() self.messagehelper = MessageHelper() + self.notificationhelper = NotificationHelper() + self.useroper = UserOper() @staticmethod def load_cache(filename: str) -> Any: @@ -454,12 +458,31 @@ class ChainBase(metaclass=ABCMeta): :param message: 消息体 :return: 成功或失败 """ - # TODO 根据消息场景开关决定发给谁 logger.info(f"发送消息:channel={message.channel}," f"source={message.source}," f"title={message.title}, " f"text={message.text}," f"userid={message.userid}") + if not message.userid and message.mtype: + # 没有指定用户ID时,按规则确定发送对象 + # 默认发送全体 + to_targets = {} + notify_action = self.notificationhelper.get_switch(message.mtype) + if notify_action == "admin": + # 仅发送管理员 + logger.info(f"已设置 {message.mtype} 的消息只发送给管理员") + to_targets = self.useroper.get_settings(settings.SUPERUSER) + elif notify_action == "user": + # 发送对应用户 + if message.username: + logger.info(f"已设置 {message.mtype} 的消息只发送给用户 {message.username}") + to_targets = self.useroper.get_settings(message.username) + if not message.username or to_targets is None: + if message.username: + logger.info(f"没有 {message.username} 这个用户,该消息将发送给管理员") + # 回滚发送管理员 + to_targets = self.useroper.get_settings(settings.SUPERUSER) + message.targets = to_targets # 发送事件 self.eventmanager.send_event(etype=EventType.NoticeMessage, data=message.dict()) # 保存消息 diff --git a/app/chain/download.py b/app/chain/download.py index a78c8ab9..dd3de97d 100644 --- a/app/chain/download.py +++ b/app/chain/download.py @@ -39,18 +39,18 @@ class DownloadChain(ChainBase): self.messagehelper = MessageHelper() def post_download_message(self, meta: MetaBase, mediainfo: MediaInfo, torrent: TorrentInfo, - channel: MessageChannel = None, userid: str = None, username: str = None, + channel: MessageChannel = None, username: str = None, download_episodes: str = None): """ - 发送添加下载的消息 + 发送添加下载的消息,根据消息场景开关决定发给谁 :param meta: 元数据 :param mediainfo: 媒体信息 :param torrent: 种子信息 :param channel: 通知渠道 - :param userid: 用户ID,指定时精确发送对应用户 :param username: 通知显示的下载用户信息 :param download_episodes: 下载的集数 """ + # 拼装消息内容 msg_text = "" if username: msg_text = f"用户:{username}" @@ -82,15 +82,16 @@ class DownloadChain(ChainBase): torrent.description = re.sub(r'<[^>]+>', '', description) msg_text = f"{msg_text}\n描述:{torrent.description}" + # 下载成功按规则发送消息 self.post_message(Notification( channel=channel, mtype=NotificationType.Download, - userid=userid, title=f"{mediainfo.title_year} " f"{'%s %s' % (meta.season, download_episodes) if download_episodes else meta.season_episode} 开始下载", text=msg_text, image=mediainfo.get_message_image(), - link=settings.MP_DOMAIN('/#/downloading'))) + link=settings.MP_DOMAIN('/#/downloading'), + username=username)) def download_torrent(self, torrent: TorrentInfo, channel: MessageChannel = None, @@ -339,7 +340,7 @@ class DownloadChain(ChainBase): if files_to_add: self.downloadhis.add_files(files_to_add) - # 发送消息 TODO 根据消息场景开关决定发给谁 + # 下载成功发送消息 self.post_download_message(meta=_meta, mediainfo=_media, torrent=_torrent, username=username, download_episodes=download_episodes) # 下载成功后处理 diff --git a/app/chain/subscribe.py b/app/chain/subscribe.py index 3bdab2b1..4f151a15 100644 --- a/app/chain/subscribe.py +++ b/app/chain/subscribe.py @@ -163,7 +163,7 @@ class SubscribeChain(ChainBase): if not sid: logger.error(f'{mediainfo.title_year} {err_msg}') if not exist_ok and message: - # 发回原用户 + # 失败发回原用户 self.post_message(Notification(channel=channel, source=source, mtype=NotificationType.Subscribe, @@ -183,11 +183,13 @@ class SubscribeChain(ChainBase): link = settings.MP_DOMAIN('#/subscribe-tv?tab=mysub') else: link = settings.MP_DOMAIN('#/subscribe-movie?tab=mysub') + # 订阅成功按规则发送消息 self.post_message(Notification(mtype=NotificationType.Subscribe, title=f"{mediainfo.title_year} {metainfo.season} 已添加订阅", text=text, image=mediainfo.get_message_image(), - link=link)) + link=link, + username=username)) # 发送事件 EventManager().send_event(EventType.SubscribeAdded, { "subscribe_id": sid, @@ -926,10 +928,12 @@ class SubscribeChain(ChainBase): link = settings.MP_DOMAIN('#/subscribe-tv?tab=mysub') else: link = settings.MP_DOMAIN('#/subscribe-movie?tab=mysub') + # 完成订阅按规则发送消息 self.post_message(Notification(mtype=NotificationType.Subscribe, title=f'{mediainfo.title_year} {meta.season} 已完成{msgstr}', image=mediainfo.get_message_image(), - link=link)) + link=link, + username=subscribe.username)) # 发送事件 EventManager().send_event(EventType.SubscribeComplete, { "subscribe_id": subscribe.id, diff --git a/app/db/models/user.py b/app/db/models/user.py index 84a5ea5d..39332f1c 100644 --- a/app/db/models/user.py +++ b/app/db/models/user.py @@ -1,9 +1,8 @@ -from typing import Tuple, Optional +from typing import Tuple, Any from sqlalchemy import Boolean, Column, Integer, String, Sequence from sqlalchemy.orm import Session -from app import schemas from app.core.security import verify_password from app.db import db_query, db_update, Base from app.utils.otp import OtpUtils @@ -39,14 +38,14 @@ class User(Base): @staticmethod @db_query def authenticate(db: Session, name: str, password: str, - otp_password: str) -> Tuple[bool, Optional[schemas.User]]: + otp_password: str) -> Tuple[bool, Any]: user = db.query(User).filter(User.name == name).first() if not user: return False, None if not verify_password(password, str(user.hashed_password)): return False, user if user.is_otp: - if not otp_password or not OtpUtils.check(user.otp_secret, otp_password): + if not otp_password or not OtpUtils.check(str(user.otp_secret), otp_password): return False, user return True, user diff --git a/app/db/user_oper.py b/app/db/user_oper.py new file mode 100644 index 00000000..2c2abce2 --- /dev/null +++ b/app/db/user_oper.py @@ -0,0 +1,46 @@ +import json +from typing import Optional + +from app.db import DbOper +from app.db.models.user import User + + +class UserOper(DbOper): + """ + 用户管理 + """ + + def get_permissions(self, name: str) -> dict: + """ + 获取用户权限 + """ + user = User.get_by_name(self._db, name) + if user: + try: + return json.loads(user.permissions) + except json.JSONDecodeError: + return {} + return {} + + def get_settings(self, name: str) -> Optional[dict]: + """ + 获取用户个性化设置,返回None表示用户不存在 + """ + user = User.get_by_name(self._db, name) + if user: + try: + if user.settings: + return json.loads(user.settings) + return {} + except json.JSONDecodeError: + return {} + return None + + def get_setting(self, name: str, key: str) -> Optional[str]: + """ + 获取用户个性化设置 + """ + settings = self.get_settings(name) + if settings: + return settings.get(key) + return None diff --git a/app/helper/notification.py b/app/helper/notification.py index dc58e0b5..61b20c17 100644 --- a/app/helper/notification.py +++ b/app/helper/notification.py @@ -1,8 +1,8 @@ -from typing import List +from typing import List, Optional from app.db.systemconfig_oper import SystemConfigOper from app.schemas import NotificationConf, NotificationSwitchConf -from app.schemas.types import SystemConfigKey +from app.schemas.types import SystemConfigKey, NotificationType class NotificationHelper: @@ -22,7 +22,7 @@ class NotificationHelper: return [] return [NotificationConf(**conf) for conf in client_confs] - def get_switchs(self) -> List[dict]: + def get_switchs(self) -> List[NotificationSwitchConf]: """ 获取消息通知场景开关 """ @@ -30,3 +30,13 @@ class NotificationHelper: if not switchs: return [] return [NotificationSwitchConf(**switch) for switch in switchs] + + def get_switch(self, mtype: NotificationType) -> Optional[str]: + """ + 获取消息通知场景开关 + """ + switchs = self.get_switchs() + for switch in switchs: + if switch.type == mtype.value: + return switch.action + return None diff --git a/app/modules/__init__.py b/app/modules/__init__.py index 734bb9eb..9c19d7e2 100644 --- a/app/modules/__init__.py +++ b/app/modules/__init__.py @@ -1,5 +1,7 @@ from abc import abstractmethod, ABCMeta -from typing import Tuple, Union +from typing import Tuple, Union, Dict, Any, Optional + +from app.schemas import Notification, MessageChannel, NotificationConf class _ModuleBase(metaclass=ABCMeta): @@ -45,3 +47,48 @@ class _ModuleBase(metaclass=ABCMeta): 模块测试, 返回测试结果和错误信息 """ pass + + +class _MessageBase: + """ + 消息基类 + """ + + _channel: MessageChannel = None + _configs: Dict[str, NotificationConf] = {} + _clients: Dict[str, Any] = {} + + def get_client(self, name: str) -> Optional[Any]: + """ + 获取客户端 + """ + if not name: + return None + return self._clients.get(name) + + def get_config(self, name: str) -> Optional[NotificationConf]: + """ + 获取配置 + """ + if not name: + return None + return self._configs.get(name) + + def checkMessage(self, message: Notification, source: str) -> bool: + """ + 检查消息渠道及消息类型,如不符合则不处理 + """ + # 检查消息渠道 + if message.channel and message.channel != self._channel: + return False + # 检查消息来源 + if message.source and message.source != source: + return False + # 检查消息类型开关 + if message.mtype: + conf = self.get_config(source) + if conf: + switchs = conf.switchs or [] + if message.mtype.value not in switchs: + return False + return True diff --git a/app/modules/slack/__init__.py b/app/modules/slack/__init__.py index 3802bbc4..c44d1e36 100644 --- a/app/modules/slack/__init__.py +++ b/app/modules/slack/__init__.py @@ -1,20 +1,17 @@ import json import re -from typing import Optional, Union, List, Tuple, Any, Dict +from typing import Optional, Union, List, Tuple, Any -from app.core.context import MediaInfo, Context from app.core.config import settings +from app.core.context import MediaInfo, Context from app.helper.notification import NotificationHelper from app.log import logger -from app.modules import _ModuleBase +from app.modules import _ModuleBase, _MessageBase from app.modules.slack.slack import Slack -from app.schemas import MessageChannel, CommingMessage, Notification, NotificationConf +from app.schemas import MessageChannel, CommingMessage, Notification -class SlackModule(_ModuleBase): - _channel = MessageChannel.Telegram - _configs: Dict[str, NotificationConf] = {} - _clients: Dict[str, Slack] = {} +class SlackModule(_ModuleBase, _MessageBase): def init_module(self) -> None: """ @@ -34,18 +31,6 @@ class SlackModule(_ModuleBase): def get_name() -> str: return "Slack" - def get_client(self, name: str) -> Optional[Slack]: - """ - 获取Telegram客户端 - """ - return self._clients.get(name) - - def get_config(self, name: str) -> Optional[NotificationConf]: - """ - 获取Telegram配置 - """ - return self._configs.get(name) - def stop(self): """ 停止模块 @@ -66,32 +51,14 @@ class SlackModule(_ModuleBase): def init_setting(self) -> Tuple[str, Union[str, bool]]: pass - def checkMessage(self, message: Notification, source: str) -> bool: - """ - 检查消息渠道及消息类型,如不符合则不处理 - """ - # 检查消息渠道 - if message.channel and message.channel != self._channel: - return False - # 检查消息来源 - if message.source and message.source != source: - return False - # 检查消息类型开关 - if message.mtype: - conf = self.get_config(source) - if conf: - switchs = conf.switchs or [] - if message.mtype.value not in switchs: - return False - return True - - def message_parser(self, body: Any, form: Any, + def message_parser(self, source: str, body: Any, form: Any, args: Any) -> Optional[CommingMessage]: """ 解析消息内容,返回字典,注意以下约定值: userid: 用户ID username: 用户名 text: 内容 + :param source: 消息来源 :param body: 请求体 :param form: 表单 :param args: 参数 @@ -207,11 +174,8 @@ class SlackModule(_ModuleBase): } """ # 来源 - source = args.get("source") - if not source: - return None # 获取客户端 - client = self.get_client(source) + client: Slack = self.get_client(source) if not client: return None # 校验token @@ -260,10 +224,17 @@ class SlackModule(_ModuleBase): for conf in self._configs.values(): if not self.checkMessage(message, conf.name): continue - client = self.get_client(conf.name) + targets = message.targets + userid = message.userid + if not userid and targets is not None: + userid = targets.get('slack_userid') + if not userid: + logger.warn(f"用户没有指定 Slack用户ID,消息无法发送") + return + client: Slack = self.get_client(conf.name) if client: client.send_msg(title=message.title, text=message.text, - image=message.image, userid=message.userid, link=message.link) + image=message.image, userid=userid, link=message.link) def post_medias_message(self, message: Notification, medias: List[MediaInfo]) -> None: """ @@ -275,7 +246,7 @@ class SlackModule(_ModuleBase): for conf in self._configs.values(): if not self.checkMessage(message, conf.name): continue - client = self.get_client(conf.name) + client: Slack = self.get_client(conf.name) if client: client.send_meidas_msg(title=message.title, medias=medias, userid=message.userid) @@ -289,7 +260,7 @@ class SlackModule(_ModuleBase): for conf in self._configs.values(): if not self.checkMessage(message, conf.name): continue - client = self.get_client(conf.name) + client: Slack = self.get_client(conf.name) if client: client.send_torrents_msg(title=message.title, torrents=torrents, userid=message.userid) diff --git a/app/modules/synologychat/__init__.py b/app/modules/synologychat/__init__.py index 3293ab48..e6e70adc 100644 --- a/app/modules/synologychat/__init__.py +++ b/app/modules/synologychat/__init__.py @@ -1,17 +1,14 @@ -from typing import Optional, Union, List, Tuple, Any, Dict +from typing import Optional, Union, List, Tuple, Any from app.core.context import MediaInfo, Context from app.helper.notification import NotificationHelper from app.log import logger -from app.modules import _ModuleBase +from app.modules import _ModuleBase, _MessageBase from app.modules.synologychat.synologychat import SynologyChat -from app.schemas import MessageChannel, CommingMessage, Notification, NotificationConf +from app.schemas import MessageChannel, CommingMessage, Notification -class SynologyChatModule(_ModuleBase): - _channel = MessageChannel.Telegram - _configs: Dict[str, NotificationConf] = {} - _clients: Dict[str, SynologyChat] = {} +class SynologyChatModule(_ModuleBase, _MessageBase): def init_module(self) -> None: """ @@ -31,18 +28,6 @@ class SynologyChatModule(_ModuleBase): def get_name() -> str: return "Synology Chat" - def get_client(self, name: str) -> Optional[SynologyChat]: - """ - 获取Telegram客户端 - """ - return self._clients.get(name) - - def get_config(self, name: str) -> Optional[NotificationConf]: - """ - 获取Telegram配置 - """ - return self._configs.get(name) - def stop(self): pass @@ -59,32 +44,14 @@ class SynologyChatModule(_ModuleBase): def init_setting(self) -> Tuple[str, Union[str, bool]]: pass - def checkMessage(self, message: Notification, source: str) -> bool: - """ - 检查消息渠道及消息类型,如不符合则不处理 - """ - # 检查消息渠道 - if message.channel and message.channel != self._channel: - return False - # 检查消息来源 - if message.source and message.source != source: - return False - # 检查消息类型开关 - if message.mtype: - conf = self.get_config(source) - if conf: - switchs = conf.switchs or [] - if message.mtype.value not in switchs: - return False - return True - - def message_parser(self, body: Any, form: Any, + def message_parser(self, source: str, body: Any, form: Any, args: Any) -> Optional[CommingMessage]: """ 解析消息内容,返回字典,注意以下约定值: userid: 用户ID username: 用户名 text: 内容 + :param source: 消息来源 :param body: 请求体 :param form: 表单 :param args: 参数 @@ -92,10 +59,7 @@ class SynologyChatModule(_ModuleBase): """ try: # 来源 - source = args.get("source") - if not source: - return None - client = self.get_client(source) + client: SynologyChat = self.get_client(source) if not client: return None # 解析消息 @@ -129,10 +93,17 @@ class SynologyChatModule(_ModuleBase): for conf in self._configs.values(): if not self.checkMessage(message, conf.name): continue - client = self.get_client(conf.name) + targets = message.targets + userid = message.userid + if not userid and targets is not None: + userid = targets.get('synologychat_userid') + if not userid: + logger.warn(f"用户没有指定 SynologyChat用户ID,消息无法发送") + return + client: SynologyChat = self.get_client(conf.name) if client: client.send_msg(title=message.title, text=message.text, - image=message.image, userid=message.userid, link=message.link) + image=message.image, userid=userid, link=message.link) def post_medias_message(self, message: Notification, medias: List[MediaInfo]) -> None: """ @@ -144,7 +115,7 @@ class SynologyChatModule(_ModuleBase): for conf in self._configs.values(): if not self.checkMessage(message, conf.name): continue - client = self.get_client(conf.name) + client: SynologyChat = self.get_client(conf.name) if client: client.send_meidas_msg(title=message.title, medias=medias, userid=message.userid) @@ -159,7 +130,7 @@ class SynologyChatModule(_ModuleBase): for conf in self._configs.values(): if not self.checkMessage(message, conf.name): continue - client = self.get_client(conf.name) + client: SynologyChat = self.get_client(conf.name) if client: client.send_torrents_msg(title=message.title, torrents=torrents, userid=message.userid, link=message.link) diff --git a/app/modules/telegram/__init__.py b/app/modules/telegram/__init__.py index 876f38c9..cf8fcf14 100644 --- a/app/modules/telegram/__init__.py +++ b/app/modules/telegram/__init__.py @@ -5,15 +5,12 @@ from app.core.context import MediaInfo, Context from app.core.config import settings from app.helper.notification import NotificationHelper from app.log import logger -from app.modules import _ModuleBase +from app.modules import _ModuleBase, _MessageBase from app.modules.telegram.telegram import Telegram from app.schemas import MessageChannel, CommingMessage, Notification, NotificationConf -class TelegramModule(_ModuleBase): - _channel = MessageChannel.Telegram - _configs: Dict[str, NotificationConf] = {} - _clients: Dict[str, Telegram] = {} +class TelegramModule(_ModuleBase, _MessageBase): def init_module(self) -> None: """ @@ -33,18 +30,6 @@ class TelegramModule(_ModuleBase): def get_name() -> str: return "Telegram" - def get_client(self, name: str) -> Optional[Telegram]: - """ - 获取Telegram客户端 - """ - return self._clients.get(name) - - def get_config(self, name: str) -> Optional[NotificationConf]: - """ - 获取Telegram配置 - """ - return self._configs.get(name) - def stop(self): """ 停止模块 @@ -65,25 +50,6 @@ class TelegramModule(_ModuleBase): def init_setting(self) -> Tuple[str, Union[str, bool]]: pass - def checkMessage(self, message: Notification, source: str) -> bool: - """ - 检查消息渠道及消息类型,如不符合则不处理 - """ - # 检查消息渠道 - if message.channel and message.channel != self._channel: - return False - # 检查消息来源 - if message.source and message.source != source: - return False - # 检查消息类型开关 - if message.mtype: - conf = self.get_config(source) - if conf: - switchs = conf.switchs or [] - if message.mtype.value not in switchs: - return False - return True - def message_parser(self, source: str, body: Any, form: Any, args: Any) -> Optional[CommingMessage]: """ @@ -91,7 +57,7 @@ class TelegramModule(_ModuleBase): userid: 用户ID username: 用户名 text: 内容 - :param source: 消息来源(渠道配置名称) + :param source: 消息来源 :param body: 请求体 :param form: 表单 :param args: 参数 @@ -121,7 +87,7 @@ class TelegramModule(_ModuleBase): } """ # 获取渠道 - client = self.get_client(source) + client: Telegram = self.get_client(source) if not client: return None # 获取配置 @@ -173,10 +139,17 @@ class TelegramModule(_ModuleBase): for conf in self._configs.values(): if not self.checkMessage(message, conf.name): continue - client = self.get_client(conf.name) + targets = message.targets + userid = message.userid + if not userid and targets is not None: + userid = targets.get('telegram_userid') + if not userid: + logger.warn(f"用户没有指定 Telegram用户ID,消息无法发送") + return + client: Telegram = self.get_client(conf.name) if client: client.send_msg(title=message.title, text=message.text, - image=message.image, userid=message.userid, link=message.link) + image=message.image, userid=userid, link=message.link) def post_medias_message(self, message: Notification, medias: List[MediaInfo]) -> None: """ @@ -188,7 +161,7 @@ class TelegramModule(_ModuleBase): for conf in self._configs.values(): if not self.checkMessage(message, conf.name): continue - client = self.get_client(conf.name) + client: Telegram = self.get_client(conf.name) if client: client.send_meidas_msg(title=message.title, medias=medias, userid=message.userid, link=message.link) @@ -203,7 +176,7 @@ class TelegramModule(_ModuleBase): for conf in self._configs.values(): if not self.checkMessage(message, conf.name): continue - client = self.get_client(conf.name) + client: Telegram = self.get_client(conf.name) if client: client.send_torrents_msg(title=message.title, torrents=torrents, userid=message.userid, link=message.link) diff --git a/app/modules/vocechat/__init__.py b/app/modules/vocechat/__init__.py index b7dd3881..3e0e2496 100644 --- a/app/modules/vocechat/__init__.py +++ b/app/modules/vocechat/__init__.py @@ -5,24 +5,25 @@ from app.core.config import settings from app.core.context import Context, MediaInfo from app.helper.notification import NotificationHelper from app.log import logger -from app.modules import _ModuleBase, checkMessage +from app.modules import _ModuleBase, _MessageBase from app.modules.vocechat.vocechat import VoceChat from app.schemas import MessageChannel, CommingMessage, Notification -class VoceChatModule(_ModuleBase): - _clients: Dict[str, VoceChat] = {} +class VoceChatModule(_ModuleBase, _MessageBase): def init_module(self) -> None: """ 初始化模块 """ - self._clients = {} clients = NotificationHelper().get_clients() if not clients: return + self._configs = {} + self._clients = {} for client in clients: if client.type == "vocechat" and client.enabled: + self._configs[client.name] = client self._clients[client.name] = VoceChat(**client.config) @staticmethod @@ -36,22 +37,23 @@ class VoceChatModule(_ModuleBase): """ 测试模块连接性 """ - state = self.vocechat.get_state() - if state: - return True, "" - return False, "获取VoceChat频道失败" + for name, client in self._clients.items(): + state = client.get_state() + if not state: + return False, f"VoceChat {name} 未就续" + return True, "" def init_setting(self) -> Tuple[str, Union[str, bool]]: pass - @staticmethod - def message_parser(body: Any, form: Any, + def message_parser(self, source: str, body: Any, form: Any, args: Any) -> Optional[CommingMessage]: """ 解析消息内容,返回字典,注意以下约定值: userid: 用户ID username: 用户名 text: 内容 + :param source: 消息来源 :param body: 请求体 :param form: 表单 :param args: 参数 @@ -73,6 +75,14 @@ class VoceChatModule(_ModuleBase): "target": { "gid": 2 } //发送给谁,gid代表是发送给频道,uid代表是发送给个人,此时的数据结构举例:{"uid":1} } """ + # 获取渠道 + client: VoceChat = self.get_client(source) + if not client: + return None + # 获取配置 + config = self.get_config(source) + if not config: + return None # 报文体 msg_body = json.loads(body) # 类型 @@ -90,7 +100,8 @@ class VoceChatModule(_ModuleBase): content = msg_body.get("detail", {}).get("content") # 用户ID gid = msg_body.get("target", {}).get("gid") - if gid and str(gid) == str(settings.VOCECHAT_CHANNEL_ID): + channel_id = config.config.get("channel_id") + if gid and str(gid) == str(channel_id): # 来自监听频道的消息 userid = f"GID#{gid}" else: @@ -106,40 +117,61 @@ class VoceChatModule(_ModuleBase): logger.error(f"VoceChat消息处理发生错误:{str(err)}") return None - @checkMessage(MessageChannel.VoceChat) def post_message(self, message: Notification) -> None: """ 发送消息 :param message: 消息内容 :return: 成功或失败 """ - self.vocechat.send_msg(title=message.title, text=message.text, - userid=message.userid, link=message.link) + for conf in self._configs.values(): + if not self.checkMessage(message, conf.name): + continue + targets = message.targets + userid = message.userid + if not message.userid and targets: + userid = targets.get('telegram_userid') + client: VoceChat = self.get_client(conf.name) + if client: + client.send_msg(title=message.title, text=message.text, + userid=userid, link=message.link) - @checkMessage(MessageChannel.VoceChat) - def post_medias_message(self, message: Notification, medias: List[MediaInfo]) -> Optional[bool]: + def post_medias_message(self, message: Notification, medias: List[MediaInfo]) -> None: """ 发送媒体信息选择列表 :param message: 消息内容 :param medias: 媒体列表 :return: 成功或失败 """ - # 先发送标题 - self.vocechat.send_msg(title=message.title, userid=message.userid) - # 再发送内容 - return self.vocechat.send_medias_msg(title=message.title, medias=medias, - userid=message.userid, link=message.link) + for conf in self._configs.values(): + if not self.checkMessage(message, conf.name): + continue + client: VoceChat = self.get_client(conf.name) + if client: + client.send_msg(title=message.title, userid=message.userid) + client.send_medias_msg(title=message.title, medias=medias, + userid=message.userid, link=message.link) - @checkMessage(MessageChannel.VoceChat) - def post_torrents_message(self, message: Notification, torrents: List[Context]) -> Optional[bool]: + def post_torrents_message(self, message: Notification, torrents: List[Context]) -> None: """ 发送种子信息选择列表 :param message: 消息内容 :param torrents: 种子列表 :return: 成功或失败 """ - return self.vocechat.send_torrents_msg(title=message.title, torrents=torrents, - userid=message.userid, link=message.link) + for conf in self._configs.values(): + if not self.checkMessage(message, conf.name): + continue + targets = message.targets + userid = message.userid + if not userid and targets is not None: + userid = targets.get('vocechat_userid') + if not userid: + logger.warn(f"用户没有指定 VoceChat用户ID,消息无法发送") + return + client: VoceChat = self.get_client(conf.name) + if client: + client.send_torrents_msg(title=message.title, torrents=torrents, + userid=userid, link=message.link) def register_commands(self, commands: Dict[str, dict]): pass diff --git a/app/modules/wechat/__init__.py b/app/modules/wechat/__init__.py index 16ee2ec0..0464e9e7 100644 --- a/app/modules/wechat/__init__.py +++ b/app/modules/wechat/__init__.py @@ -5,17 +5,14 @@ from app.core.config import settings from app.core.context import Context, MediaInfo from app.helper.notification import NotificationHelper from app.log import logger -from app.modules import _ModuleBase +from app.modules import _ModuleBase, _MessageBase from app.modules.wechat.WXBizMsgCrypt3 import WXBizMsgCrypt from app.modules.wechat.wechat import WeChat -from app.schemas import MessageChannel, CommingMessage, Notification, NotificationConf +from app.schemas import MessageChannel, CommingMessage, Notification from app.utils.dom import DomUtils -class WechatModule(_ModuleBase): - _channel = MessageChannel.Wechat - _configs: Dict[str, NotificationConf] = {} - _clients: Dict[str, WeChat] = {} +class WechatModule(_ModuleBase, _MessageBase): def init_module(self) -> None: """ @@ -35,18 +32,6 @@ class WechatModule(_ModuleBase): def get_name() -> str: return "微信" - def get_client(self, name: str) -> Optional[WeChat]: - """ - 获取Telegram客户端 - """ - return self._clients.get(name) - - def get_config(self, name: str) -> Optional[NotificationConf]: - """ - 获取Telegram配置 - """ - return self._configs.get(name) - def stop(self): pass @@ -63,25 +48,22 @@ class WechatModule(_ModuleBase): def init_setting(self) -> Tuple[str, Union[str, bool]]: pass - def message_parser(self, body: Any, form: Any, + def message_parser(self, source: str, body: Any, form: Any, args: Any) -> Optional[CommingMessage]: """ 解析消息内容,返回字典,注意以下约定值: userid: 用户ID username: 用户名 text: 内容 + :param source: 消息来源 :param body: 请求体 :param form: 表单 :param args: 参数 :return: 渠道、消息体 """ try: - # 消息来源 - source = args.get("source") - if not source: - return None # 获取客户端 - client = self.get_client(source) + client: WeChat = self.get_client(source) if not client: return None # URL参数 @@ -168,25 +150,6 @@ class WechatModule(_ModuleBase): logger.error(f"微信消息处理发生错误:{str(err)}") return None - def checkMessage(self, message: Notification, source: str) -> bool: - """ - 检查消息渠道及消息类型,如不符合则不处理 - """ - # 检查消息渠道 - if message.channel and message.channel != self._channel: - return False - # 检查消息来源 - if message.source and message.source != source: - return False - # 检查消息类型开关 - if message.mtype: - conf = self.get_config(source) - if conf: - switchs = conf.switchs or [] - if message.mtype.value not in switchs: - return False - return True - def post_message(self, message: Notification) -> None: """ 发送消息 @@ -196,10 +159,17 @@ class WechatModule(_ModuleBase): for conf in self._configs.values(): if not self.checkMessage(message, conf.name): continue - client = self.get_client(conf.name) + targets = message.targets + userid = message.userid + if not userid and targets is not None: + userid = targets.get('wechat_userid') + if not userid: + logger.warn(f"用户没有指定 微信用户ID,消息无法发送") + return + client: WeChat = self.get_client(conf.name) if client: client.send_msg(title=message.title, text=message.text, - image=message.image, userid=message.userid, link=message.link) + image=message.image, userid=userid, link=message.link) def post_medias_message(self, message: Notification, medias: List[MediaInfo]) -> None: """ @@ -211,7 +181,7 @@ class WechatModule(_ModuleBase): for conf in self._configs.values(): if not self.checkMessage(message, conf.name): continue - client = self.get_client(conf.name) + client: WeChat = self.get_client(conf.name) if client: # 先发送标题 client.send_msg(title=message.title, userid=message.userid, link=message.link) @@ -228,7 +198,7 @@ class WechatModule(_ModuleBase): for conf in self._configs.values(): if not self.checkMessage(message, conf.name): continue - client = self.get_client(conf.name) + client: WeChat = self.get_client(conf.name) if client: client.send_torrents_msg(title=message.title, torrents=torrents, userid=message.userid, link=message.link) diff --git a/app/schemas/message.py b/app/schemas/message.py index 088bc491..6065aa8f 100644 --- a/app/schemas/message.py +++ b/app/schemas/message.py @@ -55,10 +55,14 @@ class Notification(BaseModel): link: Optional[str] = None # 用户ID userid: Optional[Union[str, int]] = None + # 用户名称 + username: Optional[str] = None # 时间 date: Optional[str] = None # 消息方向 action: Optional[int] = 1 + # 消息目标用户ID字典,未指定用户ID时使用 + targets: Optional[dict] = None def to_dict(self): """ diff --git a/app/schemas/system.py b/app/schemas/system.py index b7c2c8c2..09fbc8f8 100644 --- a/app/schemas/system.py +++ b/app/schemas/system.py @@ -2,8 +2,6 @@ from typing import Optional from pydantic import BaseModel -from app.schemas import NotificationType - class MediaServerConf(BaseModel): """ @@ -58,8 +56,8 @@ class NotificationSwitchConf(BaseModel): 通知场景开关配置 """ # 场景名称 - type: NotificationType = None - # 通知范围 all/user/admin/userandadmin + type: str = None + # 通知范围 all/user/admin action: Optional[str] = 'all'