From 15a7297099fb3dc9f286261d35f38c79bcf9cd26 Mon Sep 17 00:00:00 2001 From: jxxghp Date: Tue, 2 Jul 2024 13:50:41 +0800 Subject: [PATCH] fix messages --- app/api/endpoints/message.py | 3 +- app/chain/__init__.py | 37 ++----- app/chain/download.py | 25 +++-- app/chain/message.py | 63 ++++++++---- app/chain/subscribe.py | 2 + app/command.py | 27 +++--- app/db/message_oper.py | 3 + app/db/models/message.py | 2 + app/helper/notification.py | 13 ++- app/modules/__init__.py | 39 -------- app/modules/emby/__init__.py | 8 +- app/modules/jellyfin/__init__.py | 8 +- app/modules/plex/__init__.py | 8 +- app/modules/qbittorrent/__init__.py | 2 +- app/modules/slack/__init__.py | 119 ++++++++++++++++++----- app/modules/slack/slack.py | 4 + app/modules/synologychat/__init__.py | 110 ++++++++++++++++----- app/modules/telegram/__init__.py | 139 +++++++++++++++++++++------ app/modules/telegram/telegram.py | 3 + app/modules/transmission/__init__.py | 2 +- app/modules/vocechat/__init__.py | 16 ++- app/modules/wechat/__init__.py | 117 +++++++++++++++++----- app/schemas/__init__.py | 2 +- app/schemas/mediaserver.py | 12 --- app/schemas/message.py | 4 + app/schemas/system.py | 20 ++++ 26 files changed, 555 insertions(+), 233 deletions(-) diff --git a/app/api/endpoints/message.py b/app/api/endpoints/message.py index cbf2b618..984ac79e 100644 --- a/app/api/endpoints/message.py +++ b/app/api/endpoints/message.py @@ -50,6 +50,7 @@ def web_message(text: str, current_user: User = Depends(get_current_active_super """ MessageChain().handle_message( channel=MessageChannel.Web, + source=current_user.name, userid=current_user.name, username=current_user.name, text=text @@ -104,7 +105,7 @@ def vocechat_verify(token: str) -> Any: """ if token == settings.API_TOKEN: return {"status": "OK"} - return {"status": "ERROR"} + return {"status": "API_TOKEN ERROR"} @router.get("/", summary="回调请求验证") diff --git a/app/chain/__init__.py b/app/chain/__init__.py index d2f53397..661766f7 100644 --- a/app/chain/__init__.py +++ b/app/chain/__init__.py @@ -249,19 +249,20 @@ class ChainBase(metaclass=ABCMeta): """ return self.run_module("bangumi_info", bangumiid=bangumiid) - def message_parser(self, body: Any, form: Any, + def message_parser(self, source: str, body: Any, form: Any, args: Any) -> Optional[CommingMessage]: """ 解析消息内容,返回字典,注意以下约定值: userid: 用户ID username: 用户名 text: 内容 + :param source: 消息来源(渠道配置名称) :param body: 请求体 :param form: 表单 :param args: 参数 :return: 消息渠道、消息内容 """ - return self.run_module("message_parser", body=body, form=form, args=args) + return self.run_module("message_parser", source=source, body=body, form=form, args=args) def webhook_parser(self, body: Any, form: Any, args: Any) -> Optional[WebhookEventInfo]: """ @@ -454,29 +455,19 @@ class ChainBase(metaclass=ABCMeta): :return: 成功或失败 """ logger.info(f"发送消息:channel={message.channel}," + f"source={message.source}," f"title={message.title}, " f"text={message.text}," f"userid={message.userid}") # 发送事件 - self.eventmanager.send_event(etype=EventType.NoticeMessage, - data={ - "channel": message.channel, - "type": message.mtype, - "title": message.title, - "text": message.text, - "image": message.image, - "userid": message.userid, - }) + self.eventmanager.send_event(etype=EventType.NoticeMessage, data=message.dict()) # 保存消息 self.messagehelper.put(message, role="user") - self.messageoper.add(channel=message.channel, mtype=message.mtype, - title=message.title, text=message.text, - image=message.image, link=message.link, - userid=message.userid, action=1) + self.messageoper.add(**message.dict(), action=1) # 发送 self.run_module("post_message", message=message) - def post_medias_message(self, message: Notification, medias: List[MediaInfo]) -> Optional[bool]: + def post_medias_message(self, message: Notification, medias: List[MediaInfo]) -> None: """ 发送媒体信息选择列表 :param message: 消息体 @@ -485,14 +476,10 @@ class ChainBase(metaclass=ABCMeta): """ note_list = [media.to_dict() for media in medias] self.messagehelper.put(message, role="user", note=note_list) - self.messageoper.add(channel=message.channel, mtype=message.mtype, - title=message.title, text=message.text, - image=message.image, link=message.link, - userid=message.userid, action=1, - note=note_list) + self.messageoper.add(**message.dict(), action=1, note=note_list) return self.run_module("post_medias_message", message=message, medias=medias) - def post_torrents_message(self, message: Notification, torrents: List[Context]) -> Optional[bool]: + def post_torrents_message(self, message: Notification, torrents: List[Context]) -> None: """ 发送种子信息选择列表 :param message: 消息体 @@ -501,11 +488,7 @@ class ChainBase(metaclass=ABCMeta): """ note_list = [torrent.torrent_info.to_dict() for torrent in torrents] self.messagehelper.put(message, role="user", note=note_list) - self.messageoper.add(channel=message.channel, mtype=message.mtype, - title=message.title, text=message.text, - image=message.image, link=message.link, - userid=message.userid, action=1, - note=note_list) + self.messageoper.add(**message.dict(), action=1, note=note_list) return self.run_module("post_torrents_message", message=message, torrents=torrents) def metadata_img(self, mediainfo: MediaInfo, season: int = None) -> Optional[dict]: diff --git a/app/chain/download.py b/app/chain/download.py index 8a37f64c..47e2b09b 100644 --- a/app/chain/download.py +++ b/app/chain/download.py @@ -94,6 +94,7 @@ class DownloadChain(ChainBase): def download_torrent(self, torrent: TorrentInfo, channel: MessageChannel = None, + source: str = None, userid: Union[str, int] = None ) -> Tuple[Optional[Union[Path, str]], str, list]: """ @@ -187,6 +188,7 @@ class DownloadChain(ChainBase): logger.error(f"下载种子文件失败:{torrent.title} - {torrent_url}") self.post_message(Notification( channel=channel, + source=source, mtype=NotificationType.Manual, title=f"{torrent.title} 种子下载失败!", text=f"错误信息:{error_msg}\n站点:{torrent.site_name}", @@ -198,7 +200,7 @@ class DownloadChain(ChainBase): def download_single(self, context: Context, torrent_file: Path = None, episodes: Set[int] = None, - channel: MessageChannel = None, + channel: MessageChannel = None, source: str = None, save_path: str = None, userid: Union[str, int] = None, username: str = None) -> Optional[str]: @@ -208,6 +210,7 @@ class DownloadChain(ChainBase): :param torrent_file: 种子文件路径 :param episodes: 需要下载的集数 :param channel: 通知渠道 + :param source: 通知来源 :param save_path: 保存路径 :param userid: 用户ID :param username: 调用下载的用户名/插件名 @@ -230,6 +233,7 @@ class DownloadChain(ChainBase): # 下载种子文件,得到的可能是文件也可能是磁力链 content, _folder_name, _file_list = self.download_torrent(_torrent, channel=channel, + source=source, userid=userid) if not content: return None @@ -352,6 +356,7 @@ class DownloadChain(ChainBase): # 只发送给对应渠道和用户 self.post_message(Notification( channel=channel, + source=source, mtype=NotificationType.Manual, title="添加下载任务失败:%s %s" % (_media.title_year, _meta.season_episode), @@ -367,6 +372,7 @@ class DownloadChain(ChainBase): no_exists: Dict[Union[int, str], Dict[int, NotExistMediaInfo]] = None, save_path: str = None, channel: MessageChannel = None, + source: str = None, userid: str = None, username: str = None ) -> Tuple[List[Context], Dict[Union[int, str], Dict[int, NotExistMediaInfo]]]: @@ -376,6 +382,7 @@ class DownloadChain(ChainBase): :param no_exists: 缺失的剧集信息 :param save_path: 保存路径 :param channel: 通知渠道 + :param source: 通知来源 :param userid: 用户ID :param username: 调用下载的用户名/插件名 :return: 已经下载的资源列表、剩余未下载到的剧集 no_exists[tmdb_id/douban_id] = {season: NotExistMediaInfo} @@ -446,7 +453,7 @@ class DownloadChain(ChainBase): if context.media_info.type == MediaType.MOVIE: logger.info(f"开始下载电影 {context.torrent_info.title} ...") if self.download_single(context, save_path=save_path, channel=channel, - userid=userid, username=username): + source=source, userid=userid, username=username): # 下载成功 logger.info(f"{context.torrent_info.title} 添加下载成功") downloaded_list.append(context) @@ -526,14 +533,15 @@ class DownloadChain(ChainBase): torrent_file=content if isinstance(content, Path) else None, save_path=save_path, channel=channel, + source=source, userid=userid, username=username ) else: # 下载 logger.info(f"开始下载 {torrent.title} ...") - download_id = self.download_single(context, - save_path=save_path, channel=channel, + download_id = self.download_single(context, save_path=save_path, + channel=channel, source=source, userid=userid, username=username) if download_id: @@ -600,8 +608,8 @@ class DownloadChain(ChainBase): if torrent_episodes.issubset(set(need_episodes)): # 下载 logger.info(f"开始下载 {meta.title} ...") - download_id = self.download_single(context, - save_path=save_path, channel=channel, + download_id = self.download_single(context, save_path=save_path, + channel=channel, source=source, userid=userid, username=username) if download_id: # 下载成功 @@ -686,6 +694,7 @@ class DownloadChain(ChainBase): episodes=selected_episodes, save_path=save_path, channel=channel, + source=source, userid=userid, username=username ) @@ -839,7 +848,7 @@ class DownloadChain(ChainBase): # 全部存在 return True, no_exists - def remote_downloading(self, channel: MessageChannel, userid: Union[str, int] = None): + def remote_downloading(self, channel: MessageChannel, source: str, userid: Union[str, int] = None): """ 查询正在下载的任务,并发送消息 """ @@ -847,6 +856,7 @@ class DownloadChain(ChainBase): if not torrents: self.post_message(Notification( channel=channel, + source=source, mtype=NotificationType.Download, title="没有正在下载的任务!", userid=userid, @@ -864,6 +874,7 @@ class DownloadChain(ChainBase): index += 1 self.post_message(Notification( channel=channel, + source=source, mtype=NotificationType.Download, title=title, text="\n".join(messages), diff --git a/app/chain/message.py b/app/chain/message.py index 4c2644dc..3973d2f0 100644 --- a/app/chain/message.py +++ b/app/chain/message.py @@ -106,8 +106,10 @@ class MessageChain(ChainBase): """ 调用模块识别消息内容 """ + # 消息来源 + source = args.get("source") # 获取消息内容 - info = self.message_parser(body=body, form=form, args=args) + info = self.message_parser(source=source, body=body, form=form, args=args) if not info: return # 渠道 @@ -125,9 +127,10 @@ class MessageChain(ChainBase): logger.debug(f'未识别到消息内容::{body}{form}{args}') return # 处理消息 - self.handle_message(channel=channel, userid=userid, username=username, text=text) + self.handle_message(channel=channel, source=source, userid=userid, username=username, text=text) - def handle_message(self, channel: MessageChannel, userid: Union[str, int], username: str, text: str) -> None: + def handle_message(self, channel: MessageChannel, source: str, + userid: Union[str, int], username: str, text: str) -> None: """ 识别消息内容,执行操作 """ @@ -143,10 +146,12 @@ class MessageChain(ChainBase): userid=userid, username=username, channel=channel, + source=source, text=text ), role="user") self.messageoper.add( channel=channel, + source=source, userid=username or userid, text=text, action=0 @@ -159,7 +164,8 @@ class MessageChain(ChainBase): { "cmd": text, "user": userid, - "channel": channel + "channel": channel, + "source": source } ) @@ -172,7 +178,7 @@ class MessageChain(ChainBase): or not cache_data.get('items') \ or len(cache_data.get('items')) < int(text): # 发送消息 - self.post_message(Notification(channel=channel, title="输入有误!", userid=userid)) + self.post_message(Notification(channel=channel, source=source, title="输入有误!", userid=userid)) return # 选择的序号 _choice = int(text) + _current_page * self._page_size - 1 @@ -192,6 +198,7 @@ class MessageChain(ChainBase): # 媒体库中已存在 self.post_message( Notification(channel=channel, + source=source, title=f"【{_current_media.title_year}" f"{_current_meta.sea} 媒体库中已存在,如需重新下载请发送:搜索 名称 或 下载 名称】", userid=userid)) @@ -215,12 +222,14 @@ class MessageChain(ChainBase): for sea, no_exist in no_exists.get(mediakey).items()] if messages: self.post_message(Notification(channel=channel, + source=source, title=f"{mediainfo.title_year}:\n" + "\n".join(messages), userid=userid)) # 搜索种子,过滤掉不需要的剧集,以便选择 logger.info(f"开始搜索 {mediainfo.title_year} ...") self.post_message( Notification(channel=channel, + source=source, title=f"开始搜索 {mediainfo.type.value} {mediainfo.title_year} ...", userid=userid)) # 开始搜索 @@ -229,8 +238,10 @@ class MessageChain(ChainBase): if not contexts: # 没有数据 self.post_message(Notification( - channel=channel, title=f"{mediainfo.title}" - f"{_current_meta.sea} 未搜索到需要的资源!", + channel=channel, + source=source, + title=f"{mediainfo.title}" + f"{_current_meta.sea} 未搜索到需要的资源!", userid=userid)) return # 搜索结果排序 @@ -244,6 +255,7 @@ class MessageChain(ChainBase): logger.info(f"用户 {userid} 在自动下载用户中,开始自动择优下载 ...") # 自动选择下载 self.__auto_download(channel=channel, + source=source, cache_list=contexts, userid=userid, username=username, @@ -257,6 +269,7 @@ class MessageChain(ChainBase): # 发送种子数据 logger.info(f"搜索到 {len(contexts)} 条数据,开始发送选择消息 ...") self.__post_torrents_message(channel=channel, + source=source, title=mediainfo.title, items=contexts[:self._page_size], userid=userid, @@ -274,6 +287,7 @@ class MessageChain(ChainBase): if exist_flag: self.post_message(Notification( channel=channel, + source=source, title=f"【{mediainfo.title_year}" f"{_current_meta.sea} 媒体库中已存在,如需洗版请发送:洗版 XXX】", userid=userid)) @@ -287,6 +301,7 @@ class MessageChain(ChainBase): tmdbid=mediainfo.tmdb_id, season=_current_meta.begin_season, channel=channel, + source=source, userid=userid, username=username, best_version=best_version) @@ -294,6 +309,7 @@ class MessageChain(ChainBase): if int(text) == 0: # 自动选择下载,强制下载模式 self.__auto_download(channel=channel, + source=source, cache_list=cache_list, userid=userid, username=username) @@ -301,7 +317,7 @@ class MessageChain(ChainBase): # 下载种子 context: Context = cache_list[_choice] # 下载 - self.downloadchain.download_single(context, channel=channel, + self.downloadchain.download_single(context, channel=channel, source=source, userid=userid, username=username) elif text.lower() == "p": @@ -310,13 +326,13 @@ class MessageChain(ChainBase): if not cache_data: # 没有缓存 self.post_message(Notification( - channel=channel, title="输入有误!", userid=userid)) + channel=channel, source=source, title="输入有误!", userid=userid)) return if _current_page == 0: # 第一页 self.post_message(Notification( - channel=channel, title="已经是第一页了!", userid=userid)) + channel=channel, source=source, title="已经是第一页了!", userid=userid)) return # 减一页 _current_page -= 1 @@ -332,6 +348,7 @@ class MessageChain(ChainBase): if cache_type == "Torrent": # 发送种子数据 self.__post_torrents_message(channel=channel, + source=source, title=_current_media.title, items=cache_list[start:end], userid=userid, @@ -339,6 +356,7 @@ class MessageChain(ChainBase): else: # 发送媒体数据 self.__post_medias_message(channel=channel, + source=source, title=_current_meta.name, items=cache_list[start:end], userid=userid, @@ -350,7 +368,7 @@ class MessageChain(ChainBase): if not cache_data: # 没有缓存 self.post_message(Notification( - channel=channel, title="输入有误!", userid=userid)) + channel=channel, source=source, title="输入有误!", userid=userid)) return cache_type: str = cache_data.get('type') # 产生副本,避免修改原值 @@ -362,7 +380,7 @@ class MessageChain(ChainBase): if not cache_list: # 没有数据 self.post_message(Notification( - channel=channel, title="已经是最后一页了!", userid=userid)) + channel=channel, source=source, title="已经是最后一页了!", userid=userid)) return else: # 加一页 @@ -370,11 +388,13 @@ class MessageChain(ChainBase): if cache_type == "Torrent": # 发送种子数据 self.__post_torrents_message(channel=channel, + source=source, title=_current_media.title, items=cache_list, userid=userid, total=total) else: # 发送媒体数据 self.__post_medias_message(channel=channel, + source=source, title=_current_meta.name, items=cache_list, userid=userid, total=total) @@ -411,12 +431,12 @@ class MessageChain(ChainBase): # 识别 if not meta.name: self.post_message(Notification( - channel=channel, title="无法识别输入内容!", userid=userid)) + channel=channel, source=source, title="无法识别输入内容!", userid=userid)) return # 开始搜索 if not medias: self.post_message(Notification( - channel=channel, title=f"{meta.name} 没有找到对应的媒体信息!", userid=userid)) + channel=channel, source=source, title=f"{meta.name} 没有找到对应的媒体信息!", userid=userid)) return logger.info(f"搜索到 {len(medias)} 条相关媒体信息") # 记录当前状态 @@ -429,6 +449,7 @@ class MessageChain(ChainBase): _current_media = None # 发送媒体列表 self.__post_medias_message(channel=channel, + source=source, title=meta.name, items=medias[:self._page_size], userid=userid, total=len(medias)) @@ -439,14 +460,15 @@ class MessageChain(ChainBase): { "text": content, "userid": userid, - "channel": channel + "channel": channel, + "source": source } ) # 保存缓存 self.save_cache(user_cache, self._cache_file) - def __auto_download(self, channel: MessageChannel, cache_list: list[Context], + def __auto_download(self, channel: MessageChannel, source: str, cache_list: list[Context], userid: Union[str, int], username: str, no_exists: Optional[Dict[Union[int, str], Dict[int, NotExistMediaInfo]]] = None): """ @@ -466,6 +488,7 @@ class MessageChain(ChainBase): downloads, lefts = self.downloadchain.batch_download(contexts=cache_list, no_exists=no_exists, channel=channel, + source=source, userid=userid, username=username) if downloads and not lefts: @@ -488,12 +511,13 @@ class MessageChain(ChainBase): tmdbid=_current_media.tmdb_id, season=_current_meta.begin_season, channel=channel, + source=source, userid=userid, username=username, state="R", note=note) - def __post_medias_message(self, channel: MessageChannel, + def __post_medias_message(self, channel: MessageChannel, source: str, title: str, items: list, userid: str, total: int): """ 发送媒体列表消息 @@ -504,11 +528,13 @@ class MessageChain(ChainBase): title = f"【{title}】共找到{total}条相关信息,请回复对应数字选择" self.post_medias_message(Notification( channel=channel, + source=source, title=title, userid=userid ), medias=items) - def __post_torrents_message(self, channel: MessageChannel, title: str, items: list, + def __post_torrents_message(self, channel: MessageChannel, source: str, + title: str, items: list, userid: str, total: int): """ 发送种子列表消息 @@ -519,6 +545,7 @@ class MessageChain(ChainBase): title = f"【{title}】共找到{total}条相关资源,请回复对应数字下载(0: 自动选择)" self.post_torrents_message(Notification( channel=channel, + source=source, title=title, userid=userid, link=settings.MP_DOMAIN('#/resource') diff --git a/app/chain/subscribe.py b/app/chain/subscribe.py index 5c80e5c6..83359ea5 100644 --- a/app/chain/subscribe.py +++ b/app/chain/subscribe.py @@ -54,6 +54,7 @@ class SubscribeChain(ChainBase): bangumiid: int = None, season: int = None, channel: MessageChannel = None, + source: str = None, userid: str = None, username: str = None, message: bool = True, @@ -164,6 +165,7 @@ class SubscribeChain(ChainBase): if not exist_ok and message: # 发回原用户 self.post_message(Notification(channel=channel, + source=source, mtype=NotificationType.Subscribe, title=f"{mediainfo.title_year} {metainfo.season} " f"添加订阅失败!", diff --git a/app/command.py b/app/command.py index 66cfec55..2dc704ff 100644 --- a/app/command.py +++ b/app/command.py @@ -235,9 +235,8 @@ class Command(metaclass=Singleton): } ) - def __run_command(self, command: Dict[str, any], - data_str: str = "", - channel: MessageChannel = None, userid: Union[str, int] = None): + def __run_command(self, command: Dict[str, any], data_str: str = "", + channel: MessageChannel = None, source: str = None, userid: Union[str, int] = None): """ 运行定时服务 """ @@ -247,6 +246,7 @@ class Command(metaclass=Singleton): self.chain.post_message( Notification( channel=channel, + source=source, title=f"开始执行 {command.get('description')} ...", userid=userid ) @@ -259,6 +259,7 @@ class Command(metaclass=Singleton): self.chain.post_message( Notification( channel=channel, + source=source, title=f"{command.get('description')} 执行完成", userid=userid ) @@ -272,17 +273,18 @@ class Command(metaclass=Singleton): # 有内置参数直接使用内置参数 data = cmd_data.get("data") or {} data['channel'] = channel + data['source'] = source data['user'] = userid if data_str: data['args'] = data_str cmd_data['data'] = data command['func'](**cmd_data) - elif args_num == 2: - # 没有输入参数,只输入渠道和用户ID - command['func'](channel, userid) - elif args_num > 2: + elif args_num == 3: + # 没有输入参数,只输入渠道来源、和用户ID + command['func'](channel, source, userid) + elif args_num > 3: # 多个输入参数:用户输入、用户ID - command['func'](data_str, channel, userid) + command['func'](data_str, channel, source, userid) else: # 没有参数 command['func']() @@ -324,7 +326,8 @@ class Command(metaclass=Singleton): return self._commands.get(cmd, {}) def execute(self, cmd: str, data_str: str = "", - channel: MessageChannel = None, userid: Union[str, int] = None) -> None: + channel: MessageChannel = None, source: str = None, + userid: Union[str, int] = None) -> None: """ 执行命令 """ @@ -338,7 +341,7 @@ class Command(metaclass=Singleton): # 执行命令 self.__run_command(command, data_str=data_str, - channel=channel, userid=userid) + channel=channel, source=source, userid=userid) if userid: logger.info(f"用户 {userid} {command.get('description')} 执行完成") @@ -369,10 +372,12 @@ class Command(metaclass=Singleton): event_str = event.event_data.get('cmd') # 消息渠道 event_channel = event.event_data.get('channel') + # 消息来源 + event_source = event.event_data.get('source') # 消息用户 event_user = event.event_data.get('user') if event_str: cmd = event_str.split()[0] args = " ".join(event_str.split()[1:]) if self.get(cmd): - self.execute(cmd, args, event_channel, event_user) + self.execute(cmd, args, event_channel, event_source, event_user) diff --git a/app/db/message_oper.py b/app/db/message_oper.py index 2a46b8df..a1a02058 100644 --- a/app/db/message_oper.py +++ b/app/db/message_oper.py @@ -19,6 +19,7 @@ class MessageOper(DbOper): def add(self, channel: MessageChannel = None, + source: str = None, mtype: NotificationType = None, title: str = None, text: str = None, @@ -31,6 +32,7 @@ class MessageOper(DbOper): """ 新增媒体服务器数据 :param channel: 消息渠道 + :param source: 来源 :param mtype: 消息类型 :param title: 标题 :param text: 文本内容 @@ -42,6 +44,7 @@ class MessageOper(DbOper): """ kwargs.update({ "channel": channel.value if channel else '', + "source": source, "mtype": mtype.value if mtype else '', "title": title, "text": text, diff --git a/app/db/models/message.py b/app/db/models/message.py index a1f8029d..b128f108 100644 --- a/app/db/models/message.py +++ b/app/db/models/message.py @@ -11,6 +11,8 @@ class Message(Base): id = Column(Integer, Sequence('id'), primary_key=True, index=True) # 消息渠道 channel = Column(String) + # 消息来源 + source = Column(String) # 消息类型 mtype = Column(String) # 标题 diff --git a/app/helper/notification.py b/app/helper/notification.py index a898c762..ed5b3819 100644 --- a/app/helper/notification.py +++ b/app/helper/notification.py @@ -1,4 +1,7 @@ +from typing import List + from app.db.systemconfig_oper import SystemConfigOper +from app.schemas import NotificationConf from app.schemas.types import SystemConfigKey @@ -10,11 +13,11 @@ class NotificationHelper: def __init__(self): self.systemconfig = SystemConfigOper() - def get_notifications(self) -> dict: + def get_notifications(self) -> List[NotificationConf]: """ 获取消息通知渠道 """ - notification_conf: dict = self.systemconfig.get(SystemConfigKey.Notifications) - if not notification_conf: - return {} - return notification_conf + notification_confs: List[dict] = self.systemconfig.get(SystemConfigKey.Notifications) + if not notification_confs: + return [] + return [NotificationConf(**conf) for conf in notification_confs] diff --git a/app/modules/__init__.py b/app/modules/__init__.py index 0f4fb807..734bb9eb 100644 --- a/app/modules/__init__.py +++ b/app/modules/__init__.py @@ -1,10 +1,6 @@ from abc import abstractmethod, ABCMeta from typing import Tuple, Union -from app.db.systemconfig_oper import SystemConfigOper -from app.schemas import Notification -from app.schemas.types import SystemConfigKey, MessageChannel - class _ModuleBase(metaclass=ABCMeta): """ @@ -49,38 +45,3 @@ class _ModuleBase(metaclass=ABCMeta): 模块测试, 返回测试结果和错误信息 """ pass - - -def checkMessage(channel_type: MessageChannel): - """ - 检查消息渠道及消息类型,如不符合则不处理 - """ - - def decorator(func): - def wrapper(self, message: Notification, *args, **kwargs): - # 检查消息渠道 - if message.channel and message.channel != channel_type: - return None - else: - # 检查消息类型开关 - if message.mtype: - switchs = SystemConfigOper().get(SystemConfigKey.NotificationChannels) or [] - for switch in switchs: - if switch.get("mtype") == message.mtype.value: - if channel_type == MessageChannel.Wechat and not switch.get("wechat"): - return None - if channel_type == MessageChannel.Telegram and not switch.get("telegram"): - return None - if channel_type == MessageChannel.Slack and not switch.get("slack"): - return None - if channel_type == MessageChannel.SynologyChat and not switch.get("synologychat"): - return None - if channel_type == MessageChannel.VoceChat and not switch.get("vocechat"): - return None - if channel_type == MessageChannel.WebPush and not switch.get("webpush"): - return None - return func(self, message, *args, **kwargs) - - return wrapper - - return decorator diff --git a/app/modules/emby/__init__.py b/app/modules/emby/__init__.py index 43b00ae1..1dd55879 100644 --- a/app/modules/emby/__init__.py +++ b/app/modules/emby/__init__.py @@ -22,7 +22,7 @@ class EmbyModule(_ModuleBase): if not mediaservers: return for server in mediaservers: - if server.type == "emby": + if server.type == "emby" and server.enabled: self._servers[server.name] = Emby(**server.config) def get_server(self, name: str) -> Optional[Emby]: @@ -86,6 +86,12 @@ class EmbyModule(_ModuleBase): :param args: 请求参数 :return: 字典,解析为消息时需要包含:title、text、image """ + source = args.get("source") + if source: + server = self.get_server(source) + if not server: + return None + return server.get_webhook_message(form, args) for server in self._servers.values(): result = server.get_webhook_message(form, args) if result: diff --git a/app/modules/jellyfin/__init__.py b/app/modules/jellyfin/__init__.py index bc94f09b..dd09b279 100644 --- a/app/modules/jellyfin/__init__.py +++ b/app/modules/jellyfin/__init__.py @@ -22,7 +22,7 @@ class JellyfinModule(_ModuleBase): if not mediaservers: return for server in mediaservers: - if server.type == "jellyfin": + if server.type == "jellyfin" and server.enabled: self._servers[server.name] = Jellyfin(**server.config) def get_server(self, name: str) -> Optional[Jellyfin]: @@ -86,6 +86,12 @@ class JellyfinModule(_ModuleBase): :param args: 请求参数 :return: 字典,解析为消息时需要包含:title、text、image """ + source = args.get("source") + if source: + server = self.get_server(source) + if not server: + return None + return server.get_webhook_message(body) for server in self._servers.values(): result = server.get_webhook_message(body) if result: diff --git a/app/modules/plex/__init__.py b/app/modules/plex/__init__.py index bbdb480f..b3f3f593 100644 --- a/app/modules/plex/__init__.py +++ b/app/modules/plex/__init__.py @@ -22,7 +22,7 @@ class PlexModule(_ModuleBase): if not mediaservers: return for server in mediaservers: - if server.type == "plex": + if server.type == "plex" and server.enabled: self._servers[server.name] = Plex(**server.config) @staticmethod @@ -72,6 +72,12 @@ class PlexModule(_ModuleBase): :param args: 请求参数 :return: 字典,解析为消息时需要包含:title、text、image """ + source = args.get("source") + if source: + server = self.get_server(source) + if not server: + return None + return server.get_webhook_message(body) for server in self._servers.values(): result = server.get_webhook_message(body) if result: diff --git a/app/modules/qbittorrent/__init__.py b/app/modules/qbittorrent/__init__.py index 17a93f58..c50e84b3 100644 --- a/app/modules/qbittorrent/__init__.py +++ b/app/modules/qbittorrent/__init__.py @@ -33,7 +33,7 @@ class QbittorrentModule(_ModuleBase): if not downloaders: return for server in downloaders: - if server.type == "qbittorrent": + if server.type == "qbittorrent" and server.enabled: self._servers[server.name] = Qbittorrent(**server.config) if server.default: self._default_server_name = server.name diff --git a/app/modules/slack/__init__.py b/app/modules/slack/__init__.py index 64b12625..ab9f147e 100644 --- a/app/modules/slack/__init__.py +++ b/app/modules/slack/__init__.py @@ -1,42 +1,91 @@ import json import re -from typing import Optional, Union, List, Tuple, Any +from typing import Optional, Union, List, Tuple, Any, Dict from app.core.context import MediaInfo, Context from app.core.config import settings +from app.helper.notification import NotificationHelper from app.log import logger -from app.modules import _ModuleBase, checkMessage +from app.modules import _ModuleBase from app.modules.slack.slack import Slack -from app.schemas import MessageChannel, CommingMessage, Notification +from app.schemas import MessageChannel, CommingMessage, Notification, NotificationConf class SlackModule(_ModuleBase): - slack: Slack = None + _channel = MessageChannel.Telegram + _configs: Dict[str, NotificationConf] = {} + _clients: Dict[str, Slack] = {} def init_module(self) -> None: - self.slack = Slack() + """ + 初始化模块 + """ + clients = NotificationHelper().get_notifications() + if not clients: + return + self._configs = {} + self._clients = {} + for client in clients: + if client.type == "telegram" and client.enabled: + self._configs[client.name] = client + self._clients[client.name] = Slack(**client.config, name=client.name) @staticmethod def get_name() -> str: return "Slack" + def get_client(self, name: str) -> Optional[Slack]: + """ + 获取Telegram客户端 + """ + return self._clients.get(name) + + def get_config(self, name: str) -> Optional[NotificationConf]: + """ + 获取Telegram配置 + """ + return self._configs.get(name) + def stop(self): - self.slack.stop() + """ + 停止模块 + """ + for client in self._clients.values(): + client.stop() def test(self) -> Tuple[bool, str]: """ 测试模块连接性 """ - state = self.slack.get_state() - if state: - return True, "" - return False, "Slack未就续,请检查参数设置和网络连接" + for name, client in self._clients.items(): + state = client.get_state() + if not state: + return False, f"Slack {name} 未就续" + return True, "" def init_setting(self) -> Tuple[str, Union[str, bool]]: - return "MESSAGER", "slack" + pass - @staticmethod - def message_parser(body: Any, form: Any, + def checkMessage(self, message: Notification, source: str) -> bool: + """ + 检查消息渠道及消息类型,如不符合则不处理 + """ + # 检查消息渠道 + if message.channel and message.channel != self._channel: + return False + # 检查消息来源 + if message.source and message.source != source: + return False + # 检查消息类型开关 + if message.mtype: + conf = self.get_config(source) + if conf: + switchs = conf.switchs or [] + if message.mtype.value not in switchs: + return False + return True + + def message_parser(self, body: Any, form: Any, args: Any) -> Optional[CommingMessage]: """ 解析消息内容,返回字典,注意以下约定值: @@ -157,6 +206,14 @@ class SlackModule(_ModuleBase): ] } """ + # 来源 + source = args.get("source") + if not source: + return None + # 获取客户端 + client = self.get_client(source) + if not client: + return None # 校验token token = args.get("token") if not token or token != settings.API_TOKEN: @@ -189,38 +246,50 @@ class SlackModule(_ModuleBase): username = msg_json.get("user_name") else: return None - logger.info(f"收到Slack消息:userid={userid}, username={username}, text={text}") - return CommingMessage(channel=MessageChannel.Slack, + logger.info(f"收到来自 {source} 的Slack消息:userid={userid}, username={username}, text={text}") + return CommingMessage(channel=MessageChannel.Slack, source=source, userid=userid, username=username, text=text) return None - @checkMessage(MessageChannel.Slack) def post_message(self, message: Notification) -> None: """ 发送消息 :param message: 消息 :return: 成功或失败 """ - self.slack.send_msg(title=message.title, text=message.text, - image=message.image, userid=message.userid, link=message.link) + for conf in self._configs.values(): + if not self.checkMessage(message, conf.name): + continue + client = self.get_client(conf.name) + if client: + client.send_msg(title=message.title, text=message.text, + image=message.image, userid=message.userid, link=message.link) - @checkMessage(MessageChannel.Slack) - def post_medias_message(self, message: Notification, medias: List[MediaInfo]) -> Optional[bool]: + def post_medias_message(self, message: Notification, medias: List[MediaInfo]) -> None: """ 发送媒体信息选择列表 :param message: 消息体 :param medias: 媒体信息 :return: 成功或失败 """ - return self.slack.send_meidas_msg(title=message.title, medias=medias, userid=message.userid) + for conf in self._configs.values(): + if not self.checkMessage(message, conf.name): + continue + client = self.get_client(conf.name) + if client: + client.send_meidas_msg(title=message.title, medias=medias, userid=message.userid) - @checkMessage(MessageChannel.Slack) - def post_torrents_message(self, message: Notification, torrents: List[Context]) -> Optional[bool]: + def post_torrents_message(self, message: Notification, torrents: List[Context]) -> None: """ 发送种子信息选择列表 :param message: 消息体 :param torrents: 种子信息 :return: 成功或失败 """ - return self.slack.send_torrents_msg(title=message.title, torrents=torrents, - userid=message.userid) + for conf in self._configs.values(): + if not self.checkMessage(message, conf.name): + continue + client = self.get_client(conf.name) + if client: + client.send_torrents_msg(title=message.title, torrents=torrents, + userid=message.userid) diff --git a/app/modules/slack/slack.py b/app/modules/slack/slack.py index 055fd6c8..cafa631e 100644 --- a/app/modules/slack/slack.py +++ b/app/modules/slack/slack.py @@ -41,6 +41,10 @@ class Slack: self._client = slack_app.client self._channel = channel + # 标记消息来源 + if kwargs.get("name"): + self._ds_url = f"{self._ds_url}&source={kwargs.get('name')}" + # 注册消息响应 @slack_app.event("message") def slack_message(message): diff --git a/app/modules/synologychat/__init__.py b/app/modules/synologychat/__init__.py index caf99c52..fc146e71 100644 --- a/app/modules/synologychat/__init__.py +++ b/app/modules/synologychat/__init__.py @@ -1,22 +1,48 @@ -from typing import Optional, Union, List, Tuple, Any +from typing import Optional, Union, List, Tuple, Any, Dict from app.core.context import MediaInfo, Context +from app.helper.notification import NotificationHelper from app.log import logger -from app.modules import _ModuleBase, checkMessage +from app.modules import _ModuleBase from app.modules.synologychat.synologychat import SynologyChat -from app.schemas import MessageChannel, CommingMessage, Notification +from app.schemas import MessageChannel, CommingMessage, Notification, NotificationConf class SynologyChatModule(_ModuleBase): - synologychat: SynologyChat = None + _channel = MessageChannel.Telegram + _configs: Dict[str, NotificationConf] = {} + _clients: Dict[str, SynologyChat] = {} def init_module(self) -> None: - self.synologychat = SynologyChat() + """ + 初始化模块 + """ + clients = NotificationHelper().get_notifications() + if not clients: + return + self._configs = {} + self._clients = {} + for client in clients: + if client.type == "telegram" and client.enabled: + self._configs[client.name] = client + self._clients[client.name] = SynologyChat(**client.config) @staticmethod def get_name() -> str: return "Synology Chat" + def get_client(self, name: str) -> Optional[SynologyChat]: + """ + 获取Telegram客户端 + """ + return self._clients.get(name) + + def get_config(self, name: str) -> Optional[NotificationConf]: + """ + 获取Telegram配置 + """ + return self._configs.get(name) + def stop(self): pass @@ -24,13 +50,33 @@ class SynologyChatModule(_ModuleBase): """ 测试模块连接性 """ - state = self.synologychat.get_state() - if state: - return True, "" - return False, "SynologyChat未就续,请检查参数设置、网络连接以及机器人是否可见" + for name, client in self._clients.items(): + state = client.get_state() + if not state: + return False, f"Synology Chat {name} 未就续" + return True, "" def init_setting(self) -> Tuple[str, Union[str, bool]]: - return "MESSAGER", "synologychat" + pass + + def checkMessage(self, message: Notification, source: str) -> bool: + """ + 检查消息渠道及消息类型,如不符合则不处理 + """ + # 检查消息渠道 + if message.channel and message.channel != self._channel: + return False + # 检查消息来源 + if message.source and message.source != source: + return False + # 检查消息类型开关 + if message.mtype: + conf = self.get_config(source) + if conf: + switchs = conf.switchs or [] + if message.mtype.value not in switchs: + return False + return True def message_parser(self, body: Any, form: Any, args: Any) -> Optional[CommingMessage]: @@ -45,12 +91,20 @@ class SynologyChatModule(_ModuleBase): :return: 渠道、消息体 """ try: + # 来源 + source = args.get("source") + if not source: + return None + client = self.get_client(source) + if not client: + return None + # 解析消息 message: dict = form if not message: return None # 校验token token = message.get("token") - if not token or not self.synologychat.check_token(token): + if not token or not client.check_token(token): return None # 文本 text = message.get("text") @@ -66,34 +120,46 @@ class SynologyChatModule(_ModuleBase): logger.debug(f"解析SynologyChat消息失败:{str(err)}") return None - @checkMessage(MessageChannel.SynologyChat) def post_message(self, message: Notification) -> None: """ 发送消息 :param message: 消息体 :return: 成功或失败 """ - self.synologychat.send_msg(title=message.title, text=message.text, - image=message.image, userid=message.userid, link=message.link) + for conf in self._configs.values(): + if not self.checkMessage(message, conf.name): + continue + client = self.get_client(conf.name) + if client: + client.send_msg(title=message.title, text=message.text, + image=message.image, userid=message.userid, link=message.link) - @checkMessage(MessageChannel.SynologyChat) - def post_medias_message(self, message: Notification, medias: List[MediaInfo]) -> Optional[bool]: + def post_medias_message(self, message: Notification, medias: List[MediaInfo]) -> None: """ 发送媒体信息选择列表 :param message: 消息体 :param medias: 媒体列表 :return: 成功或失败 """ - return self.synologychat.send_meidas_msg(title=message.title, medias=medias, - userid=message.userid) + for conf in self._configs.values(): + if not self.checkMessage(message, conf.name): + continue + client = self.get_client(conf.name) + if client: + client.send_meidas_msg(title=message.title, medias=medias, + userid=message.userid) - @checkMessage(MessageChannel.SynologyChat) - def post_torrents_message(self, message: Notification, torrents: List[Context]) -> Optional[bool]: + def post_torrents_message(self, message: Notification, torrents: List[Context]) -> None: """ 发送种子信息选择列表 :param message: 消息体 :param torrents: 种子列表 :return: 成功或失败 """ - return self.synologychat.send_torrents_msg(title=message.title, torrents=torrents, - userid=message.userid, link=message.link) + for conf in self._configs.values(): + if not self.checkMessage(message, conf.name): + continue + client = self.get_client(conf.name) + if client: + client.send_torrents_msg(title=message.title, torrents=torrents, + userid=message.userid, link=message.link) diff --git a/app/modules/telegram/__init__.py b/app/modules/telegram/__init__.py index c2bc34db..07a11690 100644 --- a/app/modules/telegram/__init__.py +++ b/app/modules/telegram/__init__.py @@ -3,44 +3,95 @@ from typing import Optional, Union, List, Tuple, Any, Dict from app.core.context import MediaInfo, Context from app.core.config import settings +from app.helper.notification import NotificationHelper from app.log import logger -from app.modules import _ModuleBase, checkMessage +from app.modules import _ModuleBase from app.modules.telegram.telegram import Telegram -from app.schemas import MessageChannel, CommingMessage, Notification +from app.schemas import MessageChannel, CommingMessage, Notification, NotificationConf class TelegramModule(_ModuleBase): - telegram: Telegram = None + _channel = MessageChannel.Telegram + _configs: Dict[str, NotificationConf] = {} + _clients: Dict[str, Telegram] = {} def init_module(self) -> None: - self.telegram = Telegram() + """ + 初始化模块 + """ + clients = NotificationHelper().get_notifications() + if not clients: + return + self._configs = {} + self._clients = {} + for client in clients: + if client.type == "telegram" and client.enabled: + self._configs[client.name] = client + self._clients[client.name] = Telegram(**client.config, name=client.name) @staticmethod def get_name() -> str: return "Telegram" + def get_client(self, name: str) -> Optional[Telegram]: + """ + 获取Telegram客户端 + """ + return self._clients.get(name) + + def get_config(self, name: str) -> Optional[NotificationConf]: + """ + 获取Telegram配置 + """ + return self._configs.get(name) + def stop(self): - self.telegram.stop() + """ + 停止模块 + """ + for client in self._clients.values(): + client.stop() def test(self) -> Tuple[bool, str]: """ 测试模块连接性 """ - state = self.telegram.get_state() - if state: - return True, "" - return False, "Telegram未就续,请检查参数设置和网络连接" + for name, client in self._clients.items(): + state = client.get_state() + if not state: + return False, f"Telegram {name} 未就续" + return True, "" def init_setting(self) -> Tuple[str, Union[str, bool]]: - return "MESSAGER", "telegram" + pass - def message_parser(self, body: Any, form: Any, + def checkMessage(self, message: Notification, source: str) -> bool: + """ + 检查消息渠道及消息类型,如不符合则不处理 + """ + # 检查消息渠道 + if message.channel and message.channel != self._channel: + return False + # 检查消息来源 + if message.source and message.source != source: + return False + # 检查消息类型开关 + if message.mtype: + conf = self.get_config(source) + if conf: + switchs = conf.switchs or [] + if message.mtype.value not in switchs: + return False + return True + + def message_parser(self, source: str, body: Any, form: Any, args: Any) -> Optional[CommingMessage]: """ 解析消息内容,返回字典,注意以下约定值: userid: 用户ID username: 用户名 text: 内容 + :param source: 消息来源(渠道配置名称) :param body: 请求体 :param form: 表单 :param args: 参数 @@ -69,6 +120,14 @@ class TelegramModule(_ModuleBase): } } """ + # 获取渠道 + client = self.get_client(source) + if not client: + return None + # 获取配置 + config = self.get_config(source) + if not config: + return None # 校验token token = args.get("token") if not token or token != settings.API_TOKEN: @@ -84,59 +143,75 @@ class TelegramModule(_ModuleBase): # 获取用户名 user_name = message.get("from", {}).get("username") if text: - logger.info(f"收到Telegram消息:userid={user_id}, username={user_name}, text={text}") + logger.info(f"收到来自 {source} 的Telegram消息:userid={user_id}, username={user_name}, text={text}") # 检查权限 + admin_users = config.config.get("admins") + user_list = config.config.get("users") + chat_id = config.config.get("chat_id") if text.startswith("/"): - if settings.TELEGRAM_ADMINS \ - and str(user_id) not in settings.TELEGRAM_ADMINS.split(',') \ - and str(user_id) != settings.TELEGRAM_CHAT_ID: - self.telegram.send_msg(title="只有管理员才有权限执行此命令", userid=user_id) + if admin_users \ + and str(user_id) not in admin_users.split(',') \ + and str(user_id) != chat_id: + client.send_msg(title="只有管理员才有权限执行此命令", userid=user_id) return None else: - if settings.TELEGRAM_USERS \ - and not str(user_id) in settings.TELEGRAM_USERS.split(','): + if user_list \ + and not str(user_id) in user_list.split(','): logger.info(f"用户{user_id}不在用户白名单中,无法使用此机器人") - self.telegram.send_msg(title="你不在用户白名单中,无法使用此机器人", userid=user_id) + client.send_msg(title="你不在用户白名单中,无法使用此机器人", userid=user_id) return None - return CommingMessage(channel=MessageChannel.Telegram, + return CommingMessage(channel=MessageChannel.Telegram, source=source, userid=user_id, username=user_name, text=text) return None - @checkMessage(MessageChannel.Telegram) def post_message(self, message: Notification) -> None: """ 发送消息 :param message: 消息体 :return: 成功或失败 """ - self.telegram.send_msg(title=message.title, text=message.text, - image=message.image, userid=message.userid, link=message.link) + for conf in self._configs.values(): + if not self.checkMessage(message, conf.name): + continue + client = self.get_client(conf.name) + if client: + client.send_msg(title=message.title, text=message.text, + image=message.image, userid=message.userid, link=message.link) - @checkMessage(MessageChannel.Telegram) - def post_medias_message(self, message: Notification, medias: List[MediaInfo]) -> Optional[bool]: + def post_medias_message(self, message: Notification, medias: List[MediaInfo]) -> None: """ 发送媒体信息选择列表 :param message: 消息体 :param medias: 媒体列表 :return: 成功或失败 """ - return self.telegram.send_meidas_msg(title=message.title, medias=medias, - userid=message.userid, link=message.link) + for conf in self._configs.values(): + if not self.checkMessage(message, conf.name): + continue + client = self.get_client(conf.name) + if client: + client.send_meidas_msg(title=message.title, medias=medias, + userid=message.userid, link=message.link) - @checkMessage(MessageChannel.Telegram) - def post_torrents_message(self, message: Notification, torrents: List[Context]) -> Optional[bool]: + def post_torrents_message(self, message: Notification, torrents: List[Context]) -> None: """ 发送种子信息选择列表 :param message: 消息体 :param torrents: 种子列表 :return: 成功或失败 """ - return self.telegram.send_torrents_msg(title=message.title, torrents=torrents, - userid=message.userid, link=message.link) + for conf in self._configs.values(): + if not self.checkMessage(message, conf.name): + continue + client = self.get_client(conf.name) + if client: + client.send_torrents_msg(title=message.title, torrents=torrents, + userid=message.userid, link=message.link) def register_commands(self, commands: Dict[str, dict]): """ 注册命令,实现这个函数接收系统可用的命令菜单 :param commands: 命令字典 """ - self.telegram.register_commands(commands) + for client in self._clients.values(): + client.register_commands(commands) diff --git a/app/modules/telegram/telegram.py b/app/modules/telegram/telegram.py index 405904c8..c796bf4d 100644 --- a/app/modules/telegram/telegram.py +++ b/app/modules/telegram/telegram.py @@ -41,6 +41,9 @@ class Telegram: _bot = telebot.TeleBot(self._telegram_token, parse_mode="Markdown") # 记录句柄 self._bot = _bot + # 标记渠道来源 + if kwargs.get("name"): + self._ds_url = f"{self._ds_url}&source={kwargs.get('name')}" @_bot.message_handler(commands=['start', 'help']) def send_welcome(message): diff --git a/app/modules/transmission/__init__.py b/app/modules/transmission/__init__.py index 2af4ec76..ccb57555 100644 --- a/app/modules/transmission/__init__.py +++ b/app/modules/transmission/__init__.py @@ -30,7 +30,7 @@ class TransmissionModule(_ModuleBase): if not downloaders: return for server in downloaders: - if server.type == "transmission": + if server.type == "transmission" and server.enabled: self._servers[server.name] = Transmission(**server.config) if server.default: self._default_server_name = server.name diff --git a/app/modules/vocechat/__init__.py b/app/modules/vocechat/__init__.py index ccd4d774..48f0bb1f 100644 --- a/app/modules/vocechat/__init__.py +++ b/app/modules/vocechat/__init__.py @@ -3,6 +3,7 @@ from typing import Optional, Union, List, Tuple, Any, Dict from app.core.config import settings from app.core.context import Context, MediaInfo +from app.helper.notification import NotificationHelper from app.log import logger from app.modules import _ModuleBase, checkMessage from app.modules.vocechat.vocechat import VoceChat @@ -10,10 +11,19 @@ from app.schemas import MessageChannel, CommingMessage, Notification class VoceChatModule(_ModuleBase): - vocechat: VoceChat = None + _clients: Dict[str, VoceChat] = {} def init_module(self) -> None: - self.vocechat = VoceChat() + """ + 初始化模块 + """ + self._clients = {} + clients = NotificationHelper().get_notifications() + if not clients: + return + for client in clients: + if client.type == "vocechat" and client.enabled: + self._clients[client.name] = VoceChat(**client.config) @staticmethod def get_name() -> str: @@ -32,7 +42,7 @@ class VoceChatModule(_ModuleBase): return False, "获取VoceChat频道失败" def init_setting(self) -> Tuple[str, Union[str, bool]]: - return "MESSAGER", "vocechat" + pass @staticmethod def message_parser(body: Any, form: Any, diff --git a/app/modules/wechat/__init__.py b/app/modules/wechat/__init__.py index d363751c..160c9ff8 100644 --- a/app/modules/wechat/__init__.py +++ b/app/modules/wechat/__init__.py @@ -3,24 +3,50 @@ from typing import Optional, Union, List, Tuple, Any, Dict from app.core.config import settings from app.core.context import Context, MediaInfo +from app.helper.notification import NotificationHelper from app.log import logger -from app.modules import _ModuleBase, checkMessage +from app.modules import _ModuleBase from app.modules.wechat.WXBizMsgCrypt3 import WXBizMsgCrypt from app.modules.wechat.wechat import WeChat -from app.schemas import MessageChannel, CommingMessage, Notification +from app.schemas import MessageChannel, CommingMessage, Notification, NotificationConf from app.utils.dom import DomUtils class WechatModule(_ModuleBase): - wechat: WeChat = None + _channel = MessageChannel.Wechat + _configs: Dict[str, NotificationConf] = {} + _clients: Dict[str, WeChat] = {} def init_module(self) -> None: - self.wechat = WeChat() + """ + 初始化模块 + """ + clients = NotificationHelper().get_notifications() + if not clients: + return + self._configs = {} + self._clients = {} + for client in clients: + if client.type == "wechat" and client.enabled: + self._configs[client.name] = client + self._clients[client.name] = WeChat(**client.config) @staticmethod def get_name() -> str: return "微信" + def get_client(self, name: str) -> Optional[WeChat]: + """ + 获取Telegram客户端 + """ + return self._clients.get(name) + + def get_config(self, name: str) -> Optional[NotificationConf]: + """ + 获取Telegram配置 + """ + return self._configs.get(name) + def stop(self): pass @@ -28,13 +54,14 @@ class WechatModule(_ModuleBase): """ 测试模块连接性 """ - state = self.wechat.get_state() - if state: - return True, "" - return False, "获取微信token失败" + for name, client in self._clients.items(): + state = client.get_state() + if not state: + return False, f"企业微信 {name} 未就续" + return True, "" def init_setting(self) -> Tuple[str, Union[str, bool]]: - return "MESSAGER", "wechat" + pass def message_parser(self, body: Any, form: Any, args: Any) -> Optional[CommingMessage]: @@ -49,6 +76,14 @@ class WechatModule(_ModuleBase): :return: 渠道、消息体 """ try: + # 消息来源 + source = args.get("source") + if not source: + return None + # 获取客户端 + client = self.get_client(source) + if not client: + return None # URL参数 sVerifyMsgSig = args.get("msg_signature") sVerifyTimeStamp = args.get("timestamp") @@ -113,7 +148,7 @@ class WechatModule(_ModuleBase): wechat_admins = settings.WECHAT_ADMINS.split(',') if wechat_admins and not any( user_id == admin_user for admin_user in wechat_admins): - self.wechat.send_msg(title="用户无权限执行菜单命令", userid=user_id) + client.send_msg(title="用户无权限执行菜单命令", userid=user_id) return None # 根据EventKey执行命令 content = DomUtils.tag_value(root_node, "EventKey") @@ -127,49 +162,81 @@ class WechatModule(_ModuleBase): if content: # 处理消息内容 - return CommingMessage(channel=MessageChannel.Wechat, + return CommingMessage(channel=MessageChannel.Wechat, source=source, userid=user_id, username=user_id, text=content) except Exception as err: logger.error(f"微信消息处理发生错误:{str(err)}") return None - @checkMessage(MessageChannel.Wechat) + def checkMessage(self, message: Notification, source: str) -> bool: + """ + 检查消息渠道及消息类型,如不符合则不处理 + """ + # 检查消息渠道 + if message.channel and message.channel != self._channel: + return False + # 检查消息来源 + if message.source and message.source != source: + return False + # 检查消息类型开关 + if message.mtype: + conf = self.get_config(source) + if conf: + switchs = conf.switchs or [] + if message.mtype.value not in switchs: + return False + return True + def post_message(self, message: Notification) -> None: """ 发送消息 :param message: 消息内容 :return: 成功或失败 """ - self.wechat.send_msg(title=message.title, text=message.text, - image=message.image, userid=message.userid, link=message.link) + for conf in self._configs.values(): + if not self.checkMessage(message, conf.name): + continue + client = self.get_client(conf.name) + if client: + client.send_msg(title=message.title, text=message.text, + image=message.image, userid=message.userid, link=message.link) - @checkMessage(MessageChannel.Wechat) - def post_medias_message(self, message: Notification, medias: List[MediaInfo]) -> Optional[bool]: + def post_medias_message(self, message: Notification, medias: List[MediaInfo]) -> None: """ 发送媒体信息选择列表 :param message: 消息内容 :param medias: 媒体列表 :return: 成功或失败 """ - # 先发送标题 - self.wechat.send_msg(title=message.title, userid=message.userid, link=message.link) - # 再发送内容 - return self.wechat.send_medias_msg(medias=medias, userid=message.userid) + for conf in self._configs.values(): + if not self.checkMessage(message, conf.name): + continue + client = self.get_client(conf.name) + if client: + # 先发送标题 + client.send_msg(title=message.title, userid=message.userid, link=message.link) + # 再发送内容 + client.send_medias_msg(medias=medias, userid=message.userid) - @checkMessage(MessageChannel.Wechat) - def post_torrents_message(self, message: Notification, torrents: List[Context]) -> Optional[bool]: + def post_torrents_message(self, message: Notification, torrents: List[Context]) -> None: """ 发送种子信息选择列表 :param message: 消息内容 :param torrents: 种子列表 :return: 成功或失败 """ - return self.wechat.send_torrents_msg(title=message.title, torrents=torrents, - userid=message.userid, link=message.link) + for conf in self._configs.values(): + if not self.checkMessage(message, conf.name): + continue + client = self.get_client(conf.name) + if client: + client.send_torrents_msg(title=message.title, torrents=torrents, + userid=message.userid, link=message.link) def register_commands(self, commands: Dict[str, dict]): """ 注册命令,实现这个函数接收系统可用的命令菜单 :param commands: 命令字典 """ - self.wechat.create_menus(commands) + for client in self._clients.values(): + client.create_menus(commands) diff --git a/app/schemas/__init__.py b/app/schemas/__init__.py index 603001aa..0ea1f602 100644 --- a/app/schemas/__init__.py +++ b/app/schemas/__init__.py @@ -14,4 +14,4 @@ from .message import * from .tmdb import * from .transfer import * from .filetransfer import * -from .mediaserver import * +from .system import * diff --git a/app/schemas/mediaserver.py b/app/schemas/mediaserver.py index eadad97f..582ad2d7 100644 --- a/app/schemas/mediaserver.py +++ b/app/schemas/mediaserver.py @@ -156,15 +156,3 @@ class MediaServerPlayItem(BaseModel): image: Optional[str] = None link: Optional[str] = None percent: Optional[float] = None - - -class MediaServerConf(BaseModel): - """ - 媒体服务器配置 - """ - # 名称 - name: Optional[str] = None - # 类型 emby/jellyfin/plex - type: Optional[str] = None - # 配置 - config: Optional[dict] = {} diff --git a/app/schemas/message.py b/app/schemas/message.py index 1d8c2ab4..088bc491 100644 --- a/app/schemas/message.py +++ b/app/schemas/message.py @@ -15,6 +15,8 @@ class CommingMessage(BaseModel): username: Optional[str] = None # 消息渠道 channel: Optional[MessageChannel] = None + # 来源(渠道名称) + source: Optional[str] = None # 消息体 text: Optional[str] = None # 时间 @@ -39,6 +41,8 @@ class Notification(BaseModel): """ # 消息渠道 channel: Optional[MessageChannel] = None + # 消息来源 + source: Optional[str] = None # 消息类型 mtype: Optional[NotificationType] = None # 标题 diff --git a/app/schemas/system.py b/app/schemas/system.py index 1eb6f3a9..b0a963fa 100644 --- a/app/schemas/system.py +++ b/app/schemas/system.py @@ -13,6 +13,8 @@ class MediaServerConf(BaseModel): type: Optional[str] = None # 配置 config: Optional[dict] = {} + # 是否启用 + enabled: Optional[bool] = False class DownloaderConf(BaseModel): @@ -27,3 +29,21 @@ class DownloaderConf(BaseModel): default: Optional[bool] = False # 配置 config: Optional[dict] = {} + # 是否启用 + enabled: Optional[bool] = False + + +class NotificationConf(BaseModel): + """ + 通知配置 + """ + # 名称 + name: Optional[str] = None + # 类型 telegram/wechat/vocechat/synologychat + type: Optional[str] = None + # 配置 + config: Optional[dict] = {} + # 场景开关 + switchs: Optional[list] = [] + # 是否启用 + enabled: Optional[bool] = False