diff --git a/app/modules/discord/__init__.py b/app/modules/discord/__init__.py index 15ed851a..c2e9f2b4 100644 --- a/app/modules/discord/__init__.py +++ b/app/modules/discord/__init__.py @@ -4,7 +4,7 @@ from typing import Optional, Union, List, Tuple, Any from app.core.context import MediaInfo, Context from app.log import logger from app.modules import _ModuleBase, _MessageBase -from app.schemas import MessageChannel, CommingMessage, Notification +from app.schemas import MessageChannel, CommingMessage, Notification, MessageResponse from app.schemas.types import ModuleType try: @@ -15,7 +15,6 @@ except Exception as err: # ImportError or other load issues class DiscordModule(_ModuleBase, _MessageBase[Discord]): - def init_module(self) -> None: """ 初始化模块 @@ -24,8 +23,9 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]): logger.error("Discord 依赖未就绪(需要安装 discord.py==2.6.4),模块未启动") return self.stop() - super().init_service(service_name=Discord.__name__.lower(), - service_type=Discord) + super().init_service( + service_name=Discord.__name__.lower(), service_type=Discord + ) self._channel = MessageChannel.Discord @staticmethod @@ -75,7 +75,9 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]): def init_setting(self) -> Tuple[str, Union[str, bool]]: pass - def message_parser(self, source: str, body: Any, form: Any, args: Any) -> Optional[CommingMessage]: + def message_parser( + self, source: str, body: Any, form: Any, args: Any + ) -> Optional[CommingMessage]: """ 解析消息内容,返回字典,注意以下约定值: userid: 用户ID @@ -108,8 +110,10 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]): message_id = msg_json.get("message_id") chat_id = msg_json.get("chat_id") if callback_data and userid: - logger.info(f"收到来自 {client_config.name} 的 Discord 按钮回调:" - f"userid={userid}, username={username}, callback_data={callback_data}") + logger.info( + f"收到来自 {client_config.name} 的 Discord 按钮回调:" + f"userid={userid}, username={username}, callback_data={callback_data}" + ) return CommingMessage( channel=MessageChannel.Discord, source=client_config.name, @@ -119,7 +123,7 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]): is_callback=True, callback_data=callback_data, message_id=message_id, - chat_id=str(chat_id) if chat_id else None + chat_id=str(chat_id) if chat_id else None, ) return None @@ -127,11 +131,18 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]): text = msg_json.get("text") chat_id = msg_json.get("chat_id") if text and userid: - logger.info(f"收到来自 {client_config.name} 的 Discord 消息:" - f"userid={userid}, username={username}, text={text}") - return CommingMessage(channel=MessageChannel.Discord, source=client_config.name, - userid=userid, username=username, text=text, - chat_id=str(chat_id) if chat_id else None) + logger.info( + f"收到来自 {client_config.name} 的 Discord 消息:" + f"userid={userid}, username={username}, text={text}" + ) + return CommingMessage( + channel=MessageChannel.Discord, + source=client_config.name, + userid=userid, + username=username, + text=text, + chat_id=str(chat_id) if chat_id else None, + ) return None def post_message(self, message: Notification, **kwargs) -> None: @@ -141,43 +152,66 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]): """ # DEBUG: Log entry and configs configs = self.get_configs() - logger.debug(f"[Discord] post_message 被调用,message.source={message.source}, " - f"message.userid={message.userid}, message.channel={message.channel}") - logger.debug(f"[Discord] 当前配置数量: {len(configs)}, 配置名称: {list(configs.keys())}") - logger.debug(f"[Discord] 当前实例数量: {len(self.get_instances())}, 实例名称: {list(self.get_instances().keys())}") + logger.debug( + f"[Discord] post_message 被调用,message.source={message.source}, " + f"message.userid={message.userid}, message.channel={message.channel}" + ) + logger.debug( + f"[Discord] 当前配置数量: {len(configs)}, 配置名称: {list(configs.keys())}" + ) + logger.debug( + f"[Discord] 当前实例数量: {len(self.get_instances())}, 实例名称: {list(self.get_instances().keys())}" + ) if not configs: logger.warning("[Discord] get_configs() 返回空,没有可用的 Discord 配置") return for conf in configs.values(): - logger.debug(f"[Discord] 检查配置: name={conf.name}, type={conf.type}, enabled={conf.enabled}") + logger.debug( + f"[Discord] 检查配置: name={conf.name}, type={conf.type}, enabled={conf.enabled}" + ) if not self.check_message(message, conf.name): - logger.debug(f"[Discord] check_message 返回 False,跳过配置: {conf.name}") + logger.debug( + f"[Discord] check_message 返回 False,跳过配置: {conf.name}" + ) continue logger.debug(f"[Discord] check_message 通过,准备发送到: {conf.name}") targets = message.targets userid = message.userid if not userid and targets is not None: - userid = targets.get('discord_userid') + userid = targets.get("discord_userid") if not userid: logger.warn("用户没有指定 Discord 用户ID,消息无法发送") return client: Discord = self.get_instance(conf.name) - logger.debug(f"[Discord] get_instance('{conf.name}') 返回: {client is not None}") + logger.debug( + f"[Discord] get_instance('{conf.name}') 返回: {client is not None}" + ) if client: - logger.debug(f"[Discord] 调用 client.send_msg, userid={userid}, title={message.title[:50] if message.title else None}...") - result = client.send_msg(title=message.title, text=message.text, - image=message.image, userid=userid, link=message.link, - buttons=message.buttons, - original_message_id=message.original_message_id, - original_chat_id=message.original_chat_id, - mtype=message.mtype) + logger.debug( + f"[Discord] 调用 client.send_msg, userid={userid}, title={message.title[:50] if message.title else None}..." + ) + result = client.send_msg( + title=message.title, + text=message.text, + image=message.image, + userid=userid, + link=message.link, + buttons=message.buttons, + original_message_id=message.original_message_id, + original_chat_id=message.original_chat_id, + mtype=message.mtype, + ) logger.debug(f"[Discord] send_msg 返回结果: {result}") else: - logger.warning(f"[Discord] 未找到配置 '{conf.name}' 对应的 Discord 客户端实例") + logger.warning( + f"[Discord] 未找到配置 '{conf.name}' 对应的 Discord 客户端实例" + ) - def post_medias_message(self, message: Notification, medias: List[MediaInfo]) -> None: + def post_medias_message( + self, message: Notification, medias: List[MediaInfo] + ) -> None: """ 发送媒体信息选择列表 :param message: 消息体 @@ -189,12 +223,18 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]): continue client: Discord = self.get_instance(conf.name) if client: - client.send_medias_msg(title=message.title, medias=medias, userid=message.userid, - buttons=message.buttons, - original_message_id=message.original_message_id, - original_chat_id=message.original_chat_id) + client.send_medias_msg( + title=message.title, + medias=medias, + userid=message.userid, + buttons=message.buttons, + original_message_id=message.original_message_id, + original_chat_id=message.original_chat_id, + ) - def post_torrents_message(self, message: Notification, torrents: List[Context]) -> None: + def post_torrents_message( + self, message: Notification, torrents: List[Context] + ) -> None: """ 发送种子信息选择列表 :param message: 消息体 @@ -206,13 +246,22 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]): continue client: Discord = self.get_instance(conf.name) if client: - client.send_torrents_msg(title=message.title, torrents=torrents, - userid=message.userid, buttons=message.buttons, - original_message_id=message.original_message_id, - original_chat_id=message.original_chat_id) + client.send_torrents_msg( + title=message.title, + torrents=torrents, + userid=message.userid, + buttons=message.buttons, + original_message_id=message.original_message_id, + original_chat_id=message.original_chat_id, + ) - def delete_message(self, channel: MessageChannel, source: str, - message_id: str, chat_id: Optional[str] = None) -> bool: + def delete_message( + self, + channel: MessageChannel, + source: str, + message_id: str, + chat_id: Optional[str] = None, + ) -> bool: """ 删除消息 :param channel: 消息渠道 @@ -233,3 +282,80 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]): if result: success = True return success + + def edit_message( + self, + channel: MessageChannel, + source: str, + message_id: Union[str, int], + chat_id: Union[str, int], + text: str, + title: Optional[str] = None, + ) -> bool: + """ + 编辑消息 + :param channel: 消息渠道 + :param source: 指定的消息源 + :param message_id: 消息ID + :param chat_id: 聊天ID + :param text: 新的消息内容 + :param title: 消息标题 + :return: 编辑是否成功 + """ + if channel != self._channel: + return False + for conf in self.get_configs().values(): + if source != conf.name: + continue + client: Discord = self.get_instance(conf.name) + if client: + result = client.send_msg( + title=title or "", + text=text, + original_message_id=message_id, + original_chat_id=str(chat_id), + ) + if result and isinstance(result, tuple) and result[0]: + return True + elif result: + return True + return False + + def send_direct_message(self, message: Notification) -> Optional[MessageResponse]: + """ + 直接发送消息并返回消息ID等信息 + :param message: 消息体 + :return: 消息响应(包含message_id, chat_id等) + """ + for conf in self.get_configs().values(): + if not self.check_message(message, conf.name): + continue + targets = message.targets + userid = message.userid + if not userid and targets is not None: + userid = targets.get("discord_userid") + if not userid: + logger.warn("用户没有指定 Discord 用户ID,消息无法发送") + return None + client: Discord = self.get_instance(conf.name) + if client: + result = client.send_msg( + title=message.title or "", + text=message.text, + userid=userid, + ) + if result: + success, message_id = ( + (result[0], result[1]) + if isinstance(result, tuple) + else (result, None) + ) + if success: + return MessageResponse( + message_id=str(message_id) if message_id else None, + chat_id=None, + channel=MessageChannel.Discord, + source=conf.name, + success=True, + ) + return None diff --git a/app/modules/discord/discord.py b/app/modules/discord/discord.py index f1997c5c..c672613f 100644 --- a/app/modules/discord/discord.py +++ b/app/modules/discord/discord.py @@ -18,10 +18,10 @@ from app.utils.string import StringUtils # Discord embed 字段解析白名单 # 只有这些消息类型会使用复杂的字段解析逻辑 PARSE_FIELD_TYPES = { - NotificationType.Download, # 资源下载 - NotificationType.Organize, # 整理入库 - NotificationType.Subscribe, # 订阅 - NotificationType.Manual, # 手动处理 + NotificationType.Download, # 资源下载 + NotificationType.Organize, # 整理入库 + NotificationType.Subscribe, # 订阅 + NotificationType.Manual, # 手动处理 } @@ -30,13 +30,18 @@ class Discord: Discord Bot 通知与交互实现(基于 discord.py 2.6.4) """ - def __init__(self, DISCORD_BOT_TOKEN: Optional[str] = None, - DISCORD_GUILD_ID: Optional[Union[str, int]] = None, - DISCORD_CHANNEL_ID: Optional[Union[str, int]] = None, - **kwargs): - logger.debug(f"[Discord] 初始化 Discord 实例: name={kwargs.get('name')}, " - f"GUILD_ID={DISCORD_GUILD_ID}, CHANNEL_ID={DISCORD_CHANNEL_ID}, " - f"TOKEN={'已配置' if DISCORD_BOT_TOKEN else '未配置'}") + def __init__( + self, + DISCORD_BOT_TOKEN: Optional[str] = None, + DISCORD_GUILD_ID: Optional[Union[str, int]] = None, + DISCORD_CHANNEL_ID: Optional[Union[str, int]] = None, + **kwargs, + ): + logger.debug( + f"[Discord] 初始化 Discord 实例: name={kwargs.get('name')}, " + f"GUILD_ID={DISCORD_GUILD_ID}, CHANNEL_ID={DISCORD_CHANNEL_ID}, " + f"TOKEN={'已配置' if DISCORD_BOT_TOKEN else '未配置'}" + ) if not DISCORD_BOT_TOKEN: logger.error("Discord Bot Token 未配置!") return @@ -44,12 +49,14 @@ class Discord: self._token = DISCORD_BOT_TOKEN self._guild_id = self._to_int(DISCORD_GUILD_ID) self._channel_id = self._to_int(DISCORD_CHANNEL_ID) - logger.debug(f"[Discord] 解析后的 ID: _guild_id={self._guild_id}, _channel_id={self._channel_id}") + logger.debug( + f"[Discord] 解析后的 ID: _guild_id={self._guild_id}, _channel_id={self._channel_id}" + ) base_ds_url = f"http://127.0.0.1:{settings.PORT}/api/v1/message/" self._ds_url = f"{base_ds_url}?token={settings.API_TOKEN}" if kwargs.get("name"): # URL encode the source name to handle special characters in config names - encoded_name = quote(kwargs.get('name'), safe='') + encoded_name = quote(kwargs.get("name"), safe="") self._ds_url = f"{self._ds_url}&source={encoded_name}" logger.debug(f"[Discord] 消息回调 URL: {self._ds_url}") @@ -59,15 +66,16 @@ class Discord: intents.guilds = True self._client: Optional[discord.Client] = discord.Client( - intents=intents, - proxy=settings.PROXY_HOST + intents=intents, proxy=settings.PROXY_HOST ) self._tree: Optional[app_commands.CommandTree] = None self._loop: asyncio.AbstractEventLoop = asyncio.new_event_loop() self._thread: Optional[threading.Thread] = None self._ready_event = threading.Event() self._user_dm_cache: Dict[str, discord.DMChannel] = {} - self._user_chat_mapping: Dict[str, str] = {} # userid -> chat_id mapping for reply targeting + self._user_chat_mapping: Dict[ + str, str + ] = {} # userid -> chat_id mapping for reply targeting self._broadcast_channel = None self._bot_user_id: Optional[int] = None @@ -96,10 +104,16 @@ class Discord: return # Update user-chat mapping for reply targeting - self._update_user_chat_mapping(str(message.author.id), str(message.channel.id)) + self._update_user_chat_mapping( + str(message.author.id), str(message.channel.id) + ) cleaned_text = self._clean_bot_mention(message.content or "") - username = message.author.display_name or message.author.global_name or message.author.name + username = ( + message.author.display_name + or message.author.global_name + or message.author.name + ) payload = { "type": "message", "userid": str(message.author.id), @@ -108,7 +122,9 @@ class Discord: "text": cleaned_text, "message_id": str(message.id), "chat_id": str(message.channel.id), - "channel_type": "dm" if isinstance(message.channel, discord.DMChannel) else "guild" + "channel_type": "dm" + if isinstance(message.channel, discord.DMChannel) + else "guild", } await self._post_to_ds(payload) @@ -126,18 +142,31 @@ class Discord: # Update user-chat mapping for reply targeting if interaction.user and interaction.channel: - self._update_user_chat_mapping(str(interaction.user.id), str(interaction.channel.id)) + self._update_user_chat_mapping( + str(interaction.user.id), str(interaction.channel.id) + ) - username = (interaction.user.display_name or interaction.user.global_name or interaction.user.name) \ - if interaction.user else None + username = ( + ( + interaction.user.display_name + or interaction.user.global_name + or interaction.user.name + ) + if interaction.user + else None + ) payload = { "type": "interaction", "userid": str(interaction.user.id) if interaction.user else None, "username": username, "user_tag": str(interaction.user) if interaction.user else None, "callback_data": callback_data, - "message_id": str(interaction.message.id) if interaction.message else None, - "chat_id": str(interaction.channel.id) if interaction.channel else None + "message_id": str(interaction.message.id) + if interaction.message + else None, + "chat_id": str(interaction.channel.id) + if interaction.channel + else None, } await self._post_to_ds(payload) @@ -165,7 +194,9 @@ class Discord: if not self._client or not self._loop or not self._thread: return try: - asyncio.run_coroutine_threadsafe(self._client.close(), self._loop).result(timeout=10) + asyncio.run_coroutine_threadsafe(self._client.close(), self._loop).result( + timeout=10 + ) except Exception as err: logger.error(f"关闭 Discord Bot 失败:{err}") finally: @@ -178,16 +209,26 @@ class Discord: def get_state(self) -> bool: return self._ready_event.is_set() and self._client is not None - def send_msg(self, title: str, text: Optional[str] = None, image: Optional[str] = None, - userid: Optional[str] = None, link: Optional[str] = None, - buttons: Optional[List[List[dict]]] = None, - original_message_id: Optional[Union[int, str]] = None, - original_chat_id: Optional[str] = None, - mtype: Optional['NotificationType'] = None) -> Optional[bool]: - logger.debug(f"[Discord] send_msg 被调用: userid={userid}, title={title[:50] if title else None}...") - logger.debug(f"[Discord] get_state() = {self.get_state()}, " - f"_ready_event.is_set() = {self._ready_event.is_set()}, " - f"_client = {self._client is not None}") + def send_msg( + self, + title: str, + text: Optional[str] = None, + image: Optional[str] = None, + userid: Optional[str] = None, + link: Optional[str] = None, + buttons: Optional[List[List[dict]]] = None, + original_message_id: Optional[Union[int, str]] = None, + original_chat_id: Optional[str] = None, + mtype: Optional["NotificationType"] = None, + ) -> Optional[bool]: + logger.debug( + f"[Discord] send_msg 被调用: userid={userid}, title={title[:50] if title else None}..." + ) + logger.debug( + f"[Discord] get_state() = {self.get_state()}, " + f"_ready_event.is_set() = {self._ready_event.is_set()}, " + f"_client = {self._client is not None}" + ) if not self.get_state(): logger.warning("[Discord] get_state() 返回 False,Bot 未就绪,无法发送消息") return False @@ -198,12 +239,19 @@ class Discord: try: logger.debug(f"[Discord] 准备异步发送消息...") future = asyncio.run_coroutine_threadsafe( - self._send_message(title=title, text=text, image=image, userid=userid, - link=link, buttons=buttons, - original_message_id=original_message_id, - original_chat_id=original_chat_id, - mtype=mtype), - self._loop) + self._send_message( + title=title, + text=text, + image=image, + userid=userid, + link=link, + buttons=buttons, + original_message_id=original_message_id, + original_chat_id=original_chat_id, + mtype=mtype, + ), + self._loop, + ) result = future.result(timeout=30) logger.debug(f"[Discord] 异步发送完成,结果: {result}") return result @@ -211,10 +259,15 @@ class Discord: logger.error(f"发送 Discord 消息失败:{err}") return False - def send_medias_msg(self, medias: List[MediaInfo], userid: Optional[str] = None, title: Optional[str] = None, - buttons: Optional[List[List[dict]]] = None, - original_message_id: Optional[Union[int, str]] = None, - original_chat_id: Optional[str] = None) -> Optional[bool]: + def send_medias_msg( + self, + medias: List[MediaInfo], + userid: Optional[str] = None, + title: Optional[str] = None, + buttons: Optional[List[List[dict]]] = None, + original_message_id: Optional[Union[int, str]] = None, + original_chat_id: Optional[str] = None, + ) -> Optional[bool]: if not self.get_state() or not medias: return False title = title or "媒体列表" @@ -223,22 +276,29 @@ class Discord: self._send_list_message( embeds=self._build_media_embeds(medias, title), userid=userid, - buttons=self._build_default_buttons(len(medias)) if not buttons else buttons, + buttons=self._build_default_buttons(len(medias)) + if not buttons + else buttons, fallback_buttons=buttons, original_message_id=original_message_id, - original_chat_id=original_chat_id + original_chat_id=original_chat_id, ), - self._loop + self._loop, ) return future.result(timeout=30) except Exception as err: logger.error(f"发送 Discord 媒体列表失败:{err}") return False - def send_torrents_msg(self, torrents: List[Context], userid: Optional[str] = None, title: Optional[str] = None, - buttons: Optional[List[List[dict]]] = None, - original_message_id: Optional[Union[int, str]] = None, - original_chat_id: Optional[str] = None) -> Optional[bool]: + def send_torrents_msg( + self, + torrents: List[Context], + userid: Optional[str] = None, + title: Optional[str] = None, + buttons: Optional[List[List[dict]]] = None, + original_message_id: Optional[Union[int, str]] = None, + original_chat_id: Optional[str] = None, + ) -> Optional[bool]: if not self.get_state() or not torrents: return False title = title or "种子列表" @@ -247,68 +307,92 @@ class Discord: self._send_list_message( embeds=self._build_torrent_embeds(torrents, title), userid=userid, - buttons=self._build_default_buttons(len(torrents)) if not buttons else buttons, + buttons=self._build_default_buttons(len(torrents)) + if not buttons + else buttons, fallback_buttons=buttons, original_message_id=original_message_id, - original_chat_id=original_chat_id + original_chat_id=original_chat_id, ), - self._loop + self._loop, ) return future.result(timeout=30) except Exception as err: logger.error(f"发送 Discord 种子列表失败:{err}") return False - def delete_msg(self, message_id: Union[str, int], chat_id: Optional[str] = None) -> Optional[bool]: + def delete_msg( + self, message_id: Union[str, int], chat_id: Optional[str] = None + ) -> Optional[bool]: if not self.get_state(): return False try: future = asyncio.run_coroutine_threadsafe( - self._delete_message(message_id=message_id, chat_id=chat_id), - self._loop + self._delete_message(message_id=message_id, chat_id=chat_id), self._loop ) return future.result(timeout=15) except Exception as err: logger.error(f"删除 Discord 消息失败:{err}") return False - async def _send_message(self, title: str, text: Optional[str], image: Optional[str], - userid: Optional[str], link: Optional[str], - buttons: Optional[List[List[dict]]], - original_message_id: Optional[Union[int, str]], - original_chat_id: Optional[str], - mtype: Optional['NotificationType'] = None) -> bool: - logger.debug(f"[Discord] _send_message: userid={userid}, original_chat_id={original_chat_id}") + async def _send_message( + self, + title: str, + text: Optional[str], + image: Optional[str], + userid: Optional[str], + link: Optional[str], + buttons: Optional[List[List[dict]]], + original_message_id: Optional[Union[int, str]], + original_chat_id: Optional[str], + mtype: Optional["NotificationType"] = None, + ) -> Tuple[bool, Optional[int]]: + logger.debug( + f"[Discord] _send_message: userid={userid}, original_chat_id={original_chat_id}" + ) channel = await self._resolve_channel(userid=userid, chat_id=original_chat_id) - logger.debug(f"[Discord] _resolve_channel 返回: {channel}, type={type(channel)}") + logger.debug( + f"[Discord] _resolve_channel 返回: {channel}, type={type(channel)}" + ) if not channel: logger.error("未找到可用的 Discord 频道或私聊") - return False + return False, None - embed = self._build_embed(title=title, text=text, image=image, link=link, mtype=mtype) + embed = self._build_embed( + title=title, text=text, image=image, link=link, mtype=mtype + ) view = self._build_view(buttons=buttons, link=link) content = None if original_message_id and original_chat_id: logger.debug(f"[Discord] 编辑现有消息: message_id={original_message_id}") - return await self._edit_message(chat_id=original_chat_id, message_id=original_message_id, - content=content, embed=embed, view=view) + success = await self._edit_message( + chat_id=original_chat_id, + message_id=original_message_id, + content=content, + embed=embed, + view=view, + ) + return success, int(original_message_id) if original_message_id else None logger.debug(f"[Discord] 发送新消息到频道: {channel}") try: - await channel.send(content=content, embed=embed, view=view) + sent_message = await channel.send(content=content, embed=embed, view=view) logger.debug("[Discord] 消息发送成功") - return True + return True, sent_message.id if sent_message else None except Exception as e: logger.error(f"[Discord] 发送消息到频道失败: {e}") - return False + return False, None - async def _send_list_message(self, embeds: List[discord.Embed], - userid: Optional[str], - buttons: Optional[List[List[dict]]], - fallback_buttons: Optional[List[List[dict]]], - original_message_id: Optional[Union[int, str]], - original_chat_id: Optional[str]) -> bool: + async def _send_list_message( + self, + embeds: List[discord.Embed], + userid: Optional[str], + buttons: Optional[List[List[dict]]], + fallback_buttons: Optional[List[List[dict]]], + original_message_id: Optional[Union[int, str]], + original_chat_id: Optional[str], + ) -> bool: channel = await self._resolve_channel(userid=userid, chat_id=original_chat_id) if not channel: logger.error("未找到可用的 Discord 频道或私聊") @@ -318,17 +402,31 @@ class Discord: embeds = embeds[:10] if embeds else [] # Discord 单条消息最多 10 个 embed if original_message_id and original_chat_id: - return await self._edit_message(chat_id=original_chat_id, message_id=original_message_id, - content=None, embed=None, view=view, embeds=embeds) + return await self._edit_message( + chat_id=original_chat_id, + message_id=original_message_id, + content=None, + embed=None, + view=view, + embeds=embeds, + ) - await channel.send(embed=embeds[0] if len(embeds) == 1 else None, - embeds=embeds if len(embeds) > 1 else None, - view=view) + await channel.send( + embed=embeds[0] if len(embeds) == 1 else None, + embeds=embeds if len(embeds) > 1 else None, + view=view, + ) return True - async def _edit_message(self, chat_id: Union[str, int], message_id: Union[str, int], - content: Optional[str], embed: Optional[discord.Embed], - view: Optional[discord.ui.View], embeds: Optional[List[discord.Embed]] = None) -> bool: + async def _edit_message( + self, + chat_id: Union[str, int], + message_id: Union[str, int], + content: Optional[str], + embed: Optional[discord.Embed], + view: Optional[discord.ui.View], + embeds: Optional[List[discord.Embed]] = None, + ) -> bool: channel = await self._resolve_channel(chat_id=str(chat_id)) if not channel: logger.error(f"未找到要编辑的 Discord 频道:{chat_id}") @@ -349,7 +447,9 @@ class Discord: logger.error(f"编辑 Discord 消息失败:{err}") return False - async def _delete_message(self, message_id: Union[str, int], chat_id: Optional[str]) -> bool: + async def _delete_message( + self, message_id: Union[str, int], chat_id: Optional[str] + ) -> bool: channel = await self._resolve_channel(chat_id=chat_id) if not channel: logger.error("删除 Discord 消息时未找到频道") @@ -363,11 +463,17 @@ class Discord: return False @staticmethod - def _build_embed(title: str, text: Optional[str], image: Optional[str], - link: Optional[str], mtype: Optional['NotificationType'] = None) -> discord.Embed: + def _build_embed( + title: str, + text: Optional[str], + image: Optional[str], + link: Optional[str], + mtype: Optional["NotificationType"] = None, + ) -> discord.Embed: fields: List[Dict[str, str]] = [] desc_lines: List[str] = [] should_parse_fields = mtype in PARSE_FIELD_TYPES if mtype else False + def _collect_spans(s: str, left: str, right: str) -> List[Tuple[int, int]]: spans: List[Tuple[int, int]] = [] start = 0 @@ -383,7 +489,7 @@ class Discord: return spans def _find_colon_index(s: str, m: re.Match) -> Optional[int]: - segment = s[m.start():m.end()] + segment = s[m.start() : m.end()] for i, ch in enumerate(segment): if ch in (":", ":"): return m.start() + i @@ -392,7 +498,11 @@ class Discord: if text: # 处理上游未反序列化的 "\n" 等转义换行,避免被当成普通字符 if "\\n" in text or "\\r" in text: - text = text.replace("\\r\\n", "\n").replace("\\n", "\n").replace("\\r", "\n") + text = ( + text.replace("\\r\\n", "\n") + .replace("\\n", "\n") + .replace("\\r", "\n") + ) if not should_parse_fields: desc_lines.append(text.strip()) else: @@ -410,12 +520,16 @@ class Discord: continue matches = list(pair_pattern.finditer(line)) if matches: - book_spans = _collect_spans(line, "《", "》") + _collect_spans(line, "【", "】") + book_spans = _collect_spans(line, "《", "》") + _collect_spans( + line, "【", "】" + ) if book_spans: has_book_colon = False for m in matches: colon_idx = _find_colon_index(line, m) - if colon_idx is not None and any(l < colon_idx < r for l, r in book_spans): + if colon_idx is not None and any( + l < colon_idx < r for l, r in book_spans + ): has_book_colon = True break if has_book_colon: @@ -423,20 +537,25 @@ class Discord: continue # 若整行只是 URL/时间等自然包含":"的内容,则不当作字段 url_like_names = {"http", "https", "ftp", "ftps", "magnet"} - if all(m.group(1).lower() in url_like_names or m.group(1).isdigit() for m in matches): + if all( + m.group(1).lower() in url_like_names or m.group(1).isdigit() + for m in matches + ): desc_lines.append(line) continue last_end = 0 for m in matches: # 追加匹配前的非空文本到描述 - prefix = line[last_end:m.start()].strip(" ,,;;。、") + prefix = line[last_end : m.start()].strip(" ,,;;。、") # 仅当前缀不全是分隔符/空白时才记录 if prefix and prefix.strip(" ,,;;。、"): desc_lines.append(prefix) name = m.group(1).strip() value = m.group(2).strip(" ,,;;。、\t") or "-" if name: - fields.append({"name": name, "value": value, "inline": False}) + fields.append( + {"name": name, "value": value, "inline": False} + ) last_end = m.end() # 匹配末尾后的文本 suffix = line[last_end:].strip(" ,,;;。、") @@ -451,7 +570,7 @@ class Discord: title=title, url=link or "https://github.com/jxxghp/MoviePilot", description=description if description else None, - color=0xE67E22 + color=0xE67E22, ) for field in fields: embed.add_field(name=field["name"], value=field["value"], inline=False) @@ -465,14 +584,16 @@ class Discord: for index, media in enumerate(medias[:10], start=1): overview = media.get_overview_string(80) desc_parts = [ - f"{media.type.value} | {media.vote_star}" if media.vote_star else media.type.value, - overview + f"{media.type.value} | {media.vote_star}" + if media.vote_star + else media.type.value, + overview, ] embed = discord.Embed( title=f"{index}. {media.title_year}", url=media.detail_link or discord.Embed.Empty, description="\n".join([p for p in desc_parts if p]), - color=0x5865F2 + color=0x5865F2, ) if media.get_poster_image(): embed.set_thumbnail(url=media.get_poster_image()) @@ -482,7 +603,9 @@ class Discord: return embeds @staticmethod - def _build_torrent_embeds(torrents: List[Context], title: str) -> List[discord.Embed]: + def _build_torrent_embeds( + torrents: List[Context], title: str + ) -> List[discord.Embed]: embeds: List[discord.Embed] = [] for index, context in enumerate(torrents[:10], start=1): torrent = context.torrent_info @@ -492,13 +615,13 @@ class Discord: detail = [ f"{torrent.site_name} | {StringUtils.str_filesize(torrent.size)} | {torrent.volume_factor} | {torrent.seeders}↑", meta.resource_term, - meta.video_term + meta.video_term, ] embed = discord.Embed( title=f"{index}. {title_text or torrent.title}", url=torrent.page_url or discord.Embed.Empty, description="\n".join([d for d in detail if d]), - color=0x00A86B + color=0x00A86B, ) poster = getattr(torrent, "poster", None) if poster: @@ -524,7 +647,9 @@ class Discord: return buttons @staticmethod - def _build_view(buttons: Optional[List[List[dict]]], link: Optional[str] = None) -> Optional[discord.ui.View]: + def _build_view( + buttons: Optional[List[List[dict]]], link: Optional[str] = None + ) -> Optional[discord.ui.View]: has_buttons = buttons and any(buttons) if not has_buttons and not link: return None @@ -534,20 +659,34 @@ class Discord: for row_index, button_row in enumerate(buttons[:5]): for button in button_row[:5]: if "url" in button: - btn = discord.ui.Button(label=button.get("text", "链接"), - url=button["url"], - style=discord.ButtonStyle.link) + btn = discord.ui.Button( + label=button.get("text", "链接"), + url=button["url"], + style=discord.ButtonStyle.link, + ) else: - custom_id = (button.get("callback_data") or button.get("text") or f"btn-{row_index}")[:99] - btn = discord.ui.Button(label=button.get("text", "选择")[:80], - custom_id=custom_id, - style=discord.ButtonStyle.primary) + custom_id = ( + button.get("callback_data") + or button.get("text") + or f"btn-{row_index}" + )[:99] + btn = discord.ui.Button( + label=button.get("text", "选择")[:80], + custom_id=custom_id, + style=discord.ButtonStyle.primary, + ) view.add_item(btn) elif link: - view.add_item(discord.ui.Button(label="查看详情", url=link, style=discord.ButtonStyle.link)) + view.add_item( + discord.ui.Button( + label="查看详情", url=link, style=discord.ButtonStyle.link + ) + ) return view - async def _resolve_channel(self, userid: Optional[str] = None, chat_id: Optional[str] = None): + async def _resolve_channel( + self, userid: Optional[str] = None, chat_id: Optional[str] = None + ): """ Resolve the channel to send messages to. Priority order: @@ -557,8 +696,10 @@ class Discord: 4. Any available text channel in configured guild - fallback 5. `userid` (DM) - for private conversations as a final fallback """ - logger.debug(f"[Discord] _resolve_channel: userid={userid}, chat_id={chat_id}, " - f"_channel_id={self._channel_id}, _guild_id={self._guild_id}") + logger.debug( + f"[Discord] _resolve_channel: userid={userid}, chat_id={chat_id}, " + f"_channel_id={self._channel_id}, _guild_id={self._guild_id}" + ) # Priority 1: Use explicit chat_id (reply to the same channel where user sent message) if chat_id: @@ -585,7 +726,9 @@ class Discord: return channel try: channel = await self._client.fetch_channel(int(mapped_chat_id)) - logger.debug(f"[Discord] 通过 fetch_channel 找到映射频道: {channel}") + logger.debug( + f"[Discord] 通过 fetch_channel 找到映射频道: {channel}" + ) return channel except Exception as err: logger.warn(f"通过映射的 chat_id 获取 Discord 频道失败:{err}") @@ -595,7 +738,9 @@ class Discord: logger.debug(f"[Discord] 使用缓存的广播频道: {self._broadcast_channel}") return self._broadcast_channel if self._channel_id: - logger.debug(f"[Discord] 尝试通过配置的 _channel_id={self._channel_id} 获取频道") + logger.debug( + f"[Discord] 尝试通过配置的 _channel_id={self._channel_id} 获取频道" + ) channel = self._client.get_channel(self._channel_id) if not channel: try: @@ -641,7 +786,9 @@ class Discord: async def _get_dm_channel(self, userid: str) -> Optional[discord.DMChannel]: logger.debug(f"[Discord] _get_dm_channel: userid={userid}") if userid in self._user_dm_cache: - logger.debug(f"[Discord] 从缓存获取私聊频道: {self._user_dm_cache.get(userid)}") + logger.debug( + f"[Discord] 从缓存获取私聊频道: {self._user_dm_cache.get(userid)}" + ) return self._user_dm_cache.get(userid) try: logger.debug(f"[Discord] 尝试获取/创建用户 {userid} 的私聊频道") @@ -674,7 +821,9 @@ class Discord: """ if userid and chat_id: self._user_chat_mapping[userid] = chat_id - logger.debug(f"[Discord] 更新用户频道映射: userid={userid} -> chat_id={chat_id}") + logger.debug( + f"[Discord] 更新用户频道映射: userid={userid} -> chat_id={chat_id}" + ) def _get_user_chat_id(self, userid: str) -> Optional[str]: """ @@ -708,7 +857,9 @@ class Discord: proxy = None if settings.PROXY: proxy = settings.PROXY.get("https") or settings.PROXY.get("http") - async with httpx.AsyncClient(timeout=10, verify=False, proxy=proxy) as client: + async with httpx.AsyncClient( + timeout=10, verify=False, proxy=proxy + ) as client: await client.post(self._ds_url, json=payload) except Exception as err: logger.error(f"转发 Discord 消息失败:{err}") diff --git a/app/modules/slack/__init__.py b/app/modules/slack/__init__.py index 0ae6b671..61889ad9 100644 --- a/app/modules/slack/__init__.py +++ b/app/modules/slack/__init__.py @@ -6,18 +6,16 @@ from app.core.context import MediaInfo, Context from app.log import logger from app.modules import _ModuleBase, _MessageBase from app.modules.slack.slack import Slack -from app.schemas import MessageChannel, CommingMessage, Notification +from app.schemas import MessageChannel, CommingMessage, Notification, MessageResponse from app.schemas.types import ModuleType class SlackModule(_ModuleBase, _MessageBase[Slack]): - def init_module(self) -> None: """ 初始化模块 """ - super().init_service(service_name=Slack.__name__.lower(), - service_type=Slack) + super().init_service(service_name=Slack.__name__.lower(), service_type=Slack) self._channel = MessageChannel.Slack @staticmethod @@ -67,7 +65,9 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]): def init_setting(self) -> Tuple[str, Union[str, bool]]: pass - def message_parser(self, source: str, body: Any, form: Any, args: Any) -> Optional[CommingMessage]: + def message_parser( + self, source: str, body: Any, form: Any, args: Any + ) -> Optional[CommingMessage]: """ 解析消息内容,返回字典,注意以下约定值: userid: 用户ID @@ -213,10 +213,14 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]): message_info = msg_json.get("message", {}) # Slack消息的时间戳作为消息ID message_ts = message_info.get("ts") - channel_id = msg_json.get("channel", {}).get("id") or msg_json.get("container", {}).get("channel_id") + channel_id = msg_json.get("channel", {}).get("id") or msg_json.get( + "container", {} + ).get("channel_id") - logger.info(f"收到来自 {client_config.name} 的Slack按钮回调:" - f"userid={userid}, username={username}, callback_data={callback_data}") + logger.info( + f"收到来自 {client_config.name} 的Slack按钮回调:" + f"userid={userid}, username={username}, callback_data={callback_data}" + ) # 创建包含回调信息的CommingMessage return CommingMessage( @@ -228,11 +232,16 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]): is_callback=True, callback_data=callback_data, message_id=message_ts, - chat_id=channel_id + chat_id=channel_id, ) elif msg_json.get("type") == "event_callback": - userid = msg_json.get('event', {}).get('user') - text = re.sub(r"<@[0-9A-Z]+>", "", msg_json.get("event", {}).get("text"), flags=re.IGNORECASE).strip() + userid = msg_json.get("event", {}).get("user") + text = re.sub( + r"<@[0-9A-Z]+>", + "", + msg_json.get("event", {}).get("text"), + flags=re.IGNORECASE, + ).strip() username = "" elif msg_json.get("type") == "shortcut": userid = msg_json.get("user", {}).get("id") @@ -244,9 +253,16 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]): username = msg_json.get("user_name") else: return None - logger.info(f"收到来自 {client_config.name} 的Slack消息:userid={userid}, username={username}, text={text}") - return CommingMessage(channel=MessageChannel.Slack, source=client_config.name, - userid=userid, username=username, text=text) + logger.info( + f"收到来自 {client_config.name} 的Slack消息:userid={userid}, username={username}, text={text}" + ) + return CommingMessage( + channel=MessageChannel.Slack, + source=client_config.name, + userid=userid, + username=username, + text=text, + ) return None def post_message(self, message: Notification, **kwargs) -> None: @@ -261,19 +277,26 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]): targets = message.targets userid = message.userid if not userid and targets is not None: - userid = targets.get('slack_userid') + userid = targets.get("slack_userid") if not userid: logger.warn(f"用户没有指定 Slack用户ID,消息无法发送") return client: Slack = self.get_instance(conf.name) if client: - client.send_msg(title=message.title, text=message.text, - image=message.image, userid=userid, link=message.link, - buttons=message.buttons, - original_message_id=message.original_message_id, - original_chat_id=message.original_chat_id) + client.send_msg( + title=message.title, + text=message.text, + image=message.image, + userid=userid, + link=message.link, + buttons=message.buttons, + original_message_id=message.original_message_id, + original_chat_id=message.original_chat_id, + ) - def post_medias_message(self, message: Notification, medias: List[MediaInfo]) -> None: + def post_medias_message( + self, message: Notification, medias: List[MediaInfo] + ) -> None: """ 发送媒体信息选择列表 :param message: 消息体 @@ -285,12 +308,18 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]): continue client: Slack = self.get_instance(conf.name) if client: - client.send_medias_msg(title=message.title, medias=medias, userid=message.userid, - buttons=message.buttons, - original_message_id=message.original_message_id, - original_chat_id=message.original_chat_id) + client.send_medias_msg( + title=message.title, + medias=medias, + userid=message.userid, + buttons=message.buttons, + original_message_id=message.original_message_id, + original_chat_id=message.original_chat_id, + ) - def post_torrents_message(self, message: Notification, torrents: List[Context]) -> None: + def post_torrents_message( + self, message: Notification, torrents: List[Context] + ) -> None: """ 发送种子信息选择列表 :param message: 消息体 @@ -302,13 +331,22 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]): continue client: Slack = self.get_instance(conf.name) if client: - client.send_torrents_msg(title=message.title, torrents=torrents, - userid=message.userid, buttons=message.buttons, - original_message_id=message.original_message_id, - original_chat_id=message.original_chat_id) + client.send_torrents_msg( + title=message.title, + torrents=torrents, + userid=message.userid, + buttons=message.buttons, + original_message_id=message.original_message_id, + original_chat_id=message.original_chat_id, + ) - def delete_message(self, channel: MessageChannel, source: str, - message_id: str, chat_id: Optional[str] = None) -> bool: + def delete_message( + self, + channel: MessageChannel, + source: str, + message_id: str, + chat_id: Optional[str] = None, + ) -> bool: """ 删除消息 :param channel: 消息渠道 @@ -329,3 +367,86 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]): if result: success = True return success + + def edit_message( + self, + channel: MessageChannel, + source: str, + message_id: Union[str, int], + chat_id: Union[str, int], + text: str, + title: Optional[str] = None, + ) -> bool: + """ + 编辑消息 + :param channel: 消息渠道 + :param source: 指定的消息源 + :param message_id: 消息ID + :param chat_id: 聊天ID + :param text: 新的消息内容 + :param title: 消息标题 + :return: 编辑是否成功 + """ + if channel != self._channel: + return False + for conf in self.get_configs().values(): + if source != conf.name: + continue + client: Slack = self.get_instance(conf.name) + if client: + result = client.send_msg( + title=title or "", + text=text, + original_message_id=str(message_id), + original_chat_id=str(chat_id), + ) + if result and result[0]: + return True + return False + + def send_direct_message(self, message: Notification) -> Optional[MessageResponse]: + """ + 直接发送消息并返回消息ID等信息 + :param message: 消息体 + :return: 消息响应(包含message_id, chat_id等) + """ + for conf in self.get_configs().values(): + if not self.check_message(message, conf.name): + continue + targets = message.targets + userid = message.userid + if not userid and targets is not None: + userid = targets.get("slack_userid") + if not userid: + logger.warn("用户没有指定 Slack 用户ID,消息无法发送") + return None + client: Slack = self.get_instance(conf.name) + if client: + result = client.send_msg( + title=message.title or "", + text=message.text, + userid=userid, + ) + if result and result[0]: + # Slack 使用时间戳作为 message_id,chat_id 是频道ID + # 注意:这里返回的是发送后的结果,需要获取实际的 message_id + # 由于 Slack API 返回的是 result[1],包含完整响应,我们需要从中提取 + response_data = result[1] + message_id = ( + response_data.get("ts") + if isinstance(response_data, dict) + else None + ) + channel_id = ( + response_data.get("channel") + if isinstance(response_data, dict) + else None + ) + return MessageResponse( + message_id=message_id, + chat_id=channel_id, + channel=MessageChannel.Slack, + source=conf.name, + success=True, + ) + return None