diff --git a/app/agent/__init__.py b/app/agent/__init__.py index 59a28875..9730ccbb 100644 --- a/app/agent/__init__.py +++ b/app/agent/__init__.py @@ -5,7 +5,8 @@ from typing import Dict, List from langchain.agents import create_agent from langchain.agents.middleware import ( - SummarizationMiddleware, LLMToolSelectorMiddleware, + SummarizationMiddleware, + LLMToolSelectorMiddleware, ) from langchain_core.messages import ( HumanMessage, @@ -36,12 +37,12 @@ class MoviePilotAgent: """ def __init__( - self, - session_id: str, - user_id: str = None, - channel: str = None, - source: str = None, - username: str = None, + self, + session_id: str, + user_id: str = None, + channel: str = None, + source: str = None, + username: str = None, ): self.session_id = session_id self.user_id = user_id @@ -80,9 +81,7 @@ class MoviePilotAgent: # 系统提示词 system_prompt = prompt_manager.get_agent_prompt( channel=self.channel - ).format( - current_date=strftime('%Y-%m-%d') - ) + ).format(current_date=strftime("%Y-%m-%d")) # LLM 模型(用于 agent 执行) llm = self._initialize_llm() @@ -93,21 +92,15 @@ class MoviePilotAgent: # 中间件 middlewares = [ # 工具选择 - LLMToolSelectorMiddleware( - model=llm, - max_tools=20 - ), + LLMToolSelectorMiddleware(model=llm, max_tools=20), # 记忆管理 MemoryMiddleware( sources=[str(settings.CONFIG_PATH / "agent" / "MEMORY.md")] ), # 上下文压缩 - SummarizationMiddleware( - model=llm, - trigger=("fraction", 0.85) - ), + SummarizationMiddleware(model=llm, trigger=("fraction", 0.85)), # 错误工具调用修复 - PatchToolCallsMiddleware() + PatchToolCallsMiddleware(), ] return create_agent( @@ -130,8 +123,7 @@ class MoviePilotAgent: # 获取历史消息 messages = memory_manager.get_agent_messages( - session_id=self.session_id, - user_id=self.user_id + session_id=self.session_id, user_id=self.user_id ) # 增加用户消息 @@ -150,6 +142,7 @@ class MoviePilotAgent: """ 调用 LangGraph Agent,通过 astream_events 流式获取 token, 同时用 UsageMetadataCallbackHandler 统计 token 用量。 + 支持流式输出:在支持消息编辑的渠道上实时推送 token。 """ try: # Agent运行配置 @@ -162,37 +155,57 @@ class MoviePilotAgent: # 创建智能体 agent = self._create_agent() + # 启动流式输出(内部会检查渠道是否支持消息编辑) + await self.stream_handler.start_streaming( + channel=self.channel, + source=self.source, + user_id=self.user_id, + username=self.username, + ) + # 流式运行智能体 async for chunk in agent.astream( - {"messages": messages}, - stream_mode="messages", - config=agent_config, - version="v2" + {"messages": messages}, + stream_mode="messages", + config=agent_config, + version="v2", ): # 处理流式token(过滤工具调用token,只保留模型生成的内容) if chunk["type"] == "messages": token, metadata = chunk["data"] - if (token and hasattr(token, "tool_call_chunks") - and not token.tool_call_chunks): + if ( + token + and hasattr(token, "tool_call_chunks") + and not token.tool_call_chunks + ): if token.content: self.stream_handler.emit(token.content) - # 发送最终消息给用户 - await self.send_agent_message( - self.stream_handler.take() - ) + # 停止流式输出,返回是否已通过流式编辑发送了所有内容 + all_sent_via_stream = await self.stream_handler.stop_streaming() + + if not all_sent_via_stream: + # 流式输出未能发送全部内容(渠道不支持编辑,或发送失败) + # 通过常规方式发送剩余内容 + remaining_text = await self.stream_handler.take() + if remaining_text: + await self.send_agent_message(remaining_text) # 保存消息 memory_manager.save_agent_messages( session_id=self.session_id, user_id=self.user_id, - messages=agent.get_state(agent_config).values.get("messages", []) + messages=agent.get_state(agent_config).values.get("messages", []), ) except asyncio.CancelledError: + # 确保取消时也停止流式输出 + await self.stream_handler.stop_streaming() logger.info(f"Agent执行被取消: session_id={self.session_id}") return "任务已取消", {} except Exception as e: + # 确保异常时也停止流式输出 + await self.stream_handler.stop_streaming() logger.error(f"Agent执行失败: {e} - {traceback.format_exc()}") return str(e), {} @@ -243,13 +256,13 @@ class AgentManager: self.active_agents.clear() async def process_message( - self, - session_id: str, - user_id: str, - message: str, - channel: str = None, - source: str = None, - username: str = None, + self, + session_id: str, + user_id: str, + message: str, + channel: str = None, + source: str = None, + username: str = None, ) -> str: """ 处理用户消息 diff --git a/app/agent/callback/__init__.py b/app/agent/callback/__init__.py index 963834c4..e4b6095a 100644 --- a/app/agent/callback/__init__.py +++ b/app/agent/callback/__init__.py @@ -1,6 +1,20 @@ +import asyncio import threading +from typing import Optional +from app.chain import ChainBase from app.log import logger +from app.schemas import Notification +from app.schemas.message import ( + MessageResponse, + ChannelCapabilityManager, + ChannelCapability, +) +from app.schemas.types import MessageChannel + + +class _StreamChain(ChainBase): + pass class StreamingHandler: @@ -8,11 +22,39 @@ class StreamingHandler: 流式Token缓冲管理器 负责从 LLM 流式 token 中积累文本,供 Agent 在工具调用之间穿插发送中间消息。 + 当启用流式输出时,通过定时编辑消息将新产生的 tokens 实时推送给用户。 + + 工作流程: + 1. Agent开始处理时调用 start_streaming(),检查渠道能力并启动定时刷新 + 2. LLM 产生 token 时调用 emit() 积累到缓冲区 + 3. 定时器周期性调用 _flush(): + - 第一次有内容时发送新消息(通过 send_direct_message 获取 message_id) + - 后续有新内容时编辑同一条消息(通过 edit_message) + 4. 工具调用时 take() 被调用:取走缓冲区内容(如果已流式发送则返回空), + 重置消息状态以便工具调用后的新内容开启新的流式消息 + 5. Agent最终完成时调用 stop_streaming():执行最后一次刷新, + 返回是否已通过流式发送完所有内容(调用方据此决定是否还需额外发送) """ + # 流式输出的刷新间隔(秒) + FLUSH_INTERVAL = 3.0 + def __init__(self): self._lock = threading.Lock() self._buffer = "" + # 流式输出相关状态 + self._streaming_enabled = False + self._flush_task: Optional[asyncio.Task] = None + # 当前消息的发送信息(用于编辑消息) + self._message_response: Optional[MessageResponse] = None + # 已发送给用户的文本(用于追踪增量) + self._sent_text = "" + # 消息发送所需的上下文信息 + self._channel: Optional[str] = None + self._source: Optional[str] = None + self._user_id: Optional[str] = None + self._username: Optional[str] = None + self._title: str = "MoviePilot助手" def emit(self, token: str): """ @@ -21,17 +63,51 @@ class StreamingHandler: with self._lock: self._buffer += token - def take(self) -> str: + async def take(self) -> str: """ 获取当前已积累的消息内容,获取后清空缓冲区。 + + 当流式输出启用时: + 1. 先暂停 flush loop(避免与后续发送产生竞争) + 2. 执行最终一次 flush(确保已有内容完整推送到流式消息) + 3. 如果内容已全部通过流式编辑发送给用户,返回空字符串(避免重复发送) + 4. 重置消息状态,以便工具执行后 LLM 产出的新内容开启新的流式消息 + 5. 重新启动 flush loop(恢复后续流式输出能力) """ + if self._streaming_enabled: + # 暂停 flush loop + await self._cancel_flush_task() + # 执行最终一次 flush,确保当前流式消息是完整的 + await self._flush() + with self._lock: if not self._buffer: - return "" - message = self._buffer - logger.info(f"Agent消息: {message}") - self._buffer = "" - return message + message = "" + already_sent = False + else: + message = self._buffer + logger.info(f"Agent消息: {message}") + + # 如果流式输出已经把内容发给用户了,工具不需要再发 + already_sent = ( + self._streaming_enabled + and self._message_response is not None + and self._sent_text == self._buffer + ) + + self._buffer = "" + + # 重置流式消息状态,下次有新内容时会开启新消息 + self._sent_text = "" + self._message_response = None + + # 恢复 flush loop(工具执行完成后 LLM 继续产出 token 时需要) + if self._streaming_enabled: + await self._restart_flush_loop() + + if already_sent or not message: + return "" + return message def clear(self): """ @@ -39,3 +115,196 @@ class StreamingHandler: """ with self._lock: self._buffer = "" + self._sent_text = "" + self._message_response = None + + async def start_streaming( + self, + channel: Optional[str] = None, + source: Optional[str] = None, + user_id: Optional[str] = None, + username: Optional[str] = None, + title: str = "MoviePilot助手", + ): + """ + 启动流式输出。检查渠道是否支持消息编辑,如果支持则启动定时刷新任务。 + :param channel: 消息渠道 + :param source: 消息来源 + :param user_id: 用户ID + :param username: 用户名 + :param title: 消息标题 + """ + self._channel = channel + self._source = source + self._user_id = user_id + self._username = username + self._title = title + + # 检查渠道是否支持消息编辑 + if not self._can_stream(): + logger.debug(f"渠道 {channel} 不支持消息编辑,不启用流式输出") + return + + self._streaming_enabled = True + self._sent_text = "" + self._message_response = None + + # 启动异步定时刷新任务 + self._flush_task = asyncio.create_task(self._flush_loop()) + logger.debug("流式输出已启动") + + async def stop_streaming(self) -> bool: + """ + 停止流式输出。执行最后一次刷新确保所有内容都已发送。 + :return: 是否已经通过流式编辑将最终完整内容发送给了用户 + (True 表示调用方无需再额外发送消息) + """ + if not self._streaming_enabled: + return False + + self._streaming_enabled = False + + # 取消定时任务 + await self._cancel_flush_task() + + # 执行最后一次刷新 + await self._flush() + + # 检查是否所有缓冲内容都已发送 + with self._lock: + all_sent = ( + self._message_response is not None + and self._sent_text + and self._buffer == self._sent_text + ) + # 重置状态 + self._sent_text = "" + self._message_response = None + if all_sent: + # 所有内容已通过流式发送,清空缓冲区 + self._buffer = "" + return all_sent + + def _can_stream(self) -> bool: + """ + 检查当前渠道是否支持流式输出(消息编辑) + """ + if not self._channel: + return False + try: + channel_enum = MessageChannel(self._channel) + return ChannelCapabilityManager.supports_capability( + channel_enum, ChannelCapability.MESSAGE_EDITING + ) + except (ValueError, KeyError): + return False + + async def _flush_loop(self): + """ + 定时刷新循环,定期将缓冲区内容发送/编辑到用户 + """ + try: + while self._streaming_enabled: + await asyncio.sleep(self.FLUSH_INTERVAL) + if self._streaming_enabled: + await self._flush() + except asyncio.CancelledError: + pass + except Exception as e: + logger.error(f"流式刷新异常: {e}") + + async def _cancel_flush_task(self): + """ + 取消当前的定时刷新任务 + """ + if self._flush_task and not self._flush_task.done(): + self._flush_task.cancel() + try: + await self._flush_task + except asyncio.CancelledError: + pass + self._flush_task = None + + async def _restart_flush_loop(self): + """ + 重新启动定时刷新任务(用于 take() 后恢复流式输出) + """ + if not self._streaming_enabled: + return + self._flush_task = asyncio.create_task(self._flush_loop()) + + async def _flush(self): + """ + 将当前缓冲区内容刷新到用户消息 + - 如果还没有发送过消息,先发送一条新消息并记录message_id + - 如果已经发送过消息,编辑该消息为最新的完整内容 + """ + with self._lock: + current_text = self._buffer + if not current_text or current_text == self._sent_text: + # 没有新内容需要刷新 + return + + chain = _StreamChain() + + try: + if self._message_response is None: + # 第一次发送:发送新消息并获取 message_id + response = chain.send_direct_message( + Notification( + channel=self._channel, + source=self._source, + userid=self._user_id, + username=self._username, + title=self._title, + text=current_text, + ) + ) + if response and response.success and response.message_id: + self._message_response = response + with self._lock: + self._sent_text = current_text + logger.debug( + f"流式输出初始消息已发送: message_id={response.message_id}" + ) + else: + logger.debug( + "流式输出初始消息发送失败或未返回message_id,降级为非流式输出" + ) + self._streaming_enabled = False + else: + # 后续更新:编辑已有消息 + try: + channel_enum = MessageChannel(self._channel) + except (ValueError, KeyError): + return + + success = chain.edit_message( + channel=channel_enum, + source=self._message_response.source, + message_id=self._message_response.message_id, + chat_id=self._message_response.chat_id, + text=current_text, + title=self._title, + ) + if success: + with self._lock: + self._sent_text = current_text + else: + logger.debug("流式输出消息编辑失败") + except Exception as e: + logger.error(f"流式输出刷新失败: {e}") + + @property + def is_streaming(self) -> bool: + """ + 是否正在流式输出 + """ + return self._streaming_enabled + + @property + def has_sent_message(self) -> bool: + """ + 是否已经通过流式输出发送过消息(当前轮次) + """ + return self._message_response is not None diff --git a/app/agent/tools/base.py b/app/agent/tools/base.py index bb1b12d5..6f0fe712 100644 --- a/app/agent/tools/base.py +++ b/app/agent/tools/base.py @@ -45,7 +45,7 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta): """ # 获取工具调用前 Agent 已积累的流式文本 agent_message = ( - self._stream_handler.take() if self._stream_handler else "" + await self._stream_handler.take() if self._stream_handler else "" ) # 获取工具执行提示消息 diff --git a/app/chain/__init__.py b/app/chain/__init__.py index f8a2aebd..246c2887 100644 --- a/app/chain/__init__.py +++ b/app/chain/__init__.py @@ -24,10 +24,28 @@ from app.db.user_oper import UserOper from app.helper.message import MessageHelper, MessageQueueManager, MessageTemplateHelper from app.helper.service import ServiceConfigHelper from app.log import logger -from app.schemas import TransferInfo, TransferTorrent, ExistMediaInfo, DownloadingTorrent, CommingMessage, Notification, \ - WebhookEventInfo, TmdbEpisode, MediaPerson, FileItem, TransferDirectoryConf +from app.schemas import ( + TransferInfo, + TransferTorrent, + ExistMediaInfo, + DownloadingTorrent, + CommingMessage, + Notification, + WebhookEventInfo, + TmdbEpisode, + MediaPerson, + FileItem, + TransferDirectoryConf, + MessageResponse, +) from app.schemas.category import CategoryConfig -from app.schemas.types import TorrentStatus, MediaType, MediaImageType, EventType, MessageChannel +from app.schemas.types import ( + TorrentStatus, + MediaType, + MediaImageType, + EventType, + MessageChannel, +) from app.utils.object import ObjectUtils @@ -44,9 +62,7 @@ class ChainBase(metaclass=ABCMeta): self.eventmanager = EventManager() self.messageoper = MessageOper() self.messagehelper = MessageHelper() - self.messagequeue = MessageQueueManager( - send_callback=self.run_module - ) + self.messagequeue = MessageQueueManager(send_callback=self.run_module) self.pluginmanager = PluginManager() self.filecache = FileCache() self.async_filecache = AsyncFileCache() @@ -119,17 +135,20 @@ class ChainBase(metaclass=ABCMeta): else: return ret is None - def __handle_plugin_error(self, err: Exception, plugin_id: str, plugin_name: str, method: str, **kwargs): + def __handle_plugin_error( + self, err: Exception, plugin_id: str, plugin_name: str, method: str, **kwargs + ): """ 处理插件模块执行错误 """ if kwargs.get("raise_exception"): raise logger.error( - f"运行插件 {plugin_id} 模块 {method} 出错:{str(err)}\n{traceback.format_exc()}") - self.messagehelper.put(title=f"{plugin_name} 发生了错误", - message=str(err), - role="plugin") + f"运行插件 {plugin_id} 模块 {method} 出错:{str(err)}\n{traceback.format_exc()}" + ) + self.messagehelper.put( + title=f"{plugin_name} 发生了错误", message=str(err), role="plugin" + ) self.eventmanager.send_event( EventType.SystemError, { @@ -138,21 +157,24 @@ class ChainBase(metaclass=ABCMeta): "plugin_name": plugin_name, "plugin_method": method, "error": str(err), - "traceback": traceback.format_exc() - } + "traceback": traceback.format_exc(), + }, ) - def __handle_system_error(self, err: Exception, module_id: str, module_name: str, method: str, **kwargs): + def __handle_system_error( + self, err: Exception, module_id: str, module_name: str, method: str, **kwargs + ): """ 处理系统模块执行错误 """ if kwargs.get("raise_exception"): raise logger.error( - f"运行模块 {module_id}.{method} 出错:{str(err)}\n{traceback.format_exc()}") - self.messagehelper.put(title=f"{module_name}发生了错误", - message=str(err), - role="system") + f"运行模块 {module_id}.{method} 出错:{str(err)}\n{traceback.format_exc()}" + ) + self.messagehelper.put( + title=f"{module_name}发生了错误", message=str(err), role="system" + ) self.eventmanager.send_event( EventType.SystemError, { @@ -161,11 +183,13 @@ class ChainBase(metaclass=ABCMeta): "module_name": module_name, "module_method": method, "error": str(err), - "traceback": traceback.format_exc() - } + "traceback": traceback.format_exc(), + }, ) - def __execute_plugin_modules(self, method: str, result: Any, *args, **kwargs) -> Any: + def __execute_plugin_modules( + self, method: str, result: Any, *args, **kwargs + ) -> Any: """ 执行插件模块 """ @@ -187,10 +211,14 @@ class ChainBase(metaclass=ABCMeta): else: break except Exception as err: - self.__handle_plugin_error(err, plugin_id, plugin_name, method, **kwargs) + self.__handle_plugin_error( + err, plugin_id, plugin_name, method, **kwargs + ) return result - async def __async_execute_plugin_modules(self, method: str, result: Any, *args, **kwargs) -> Any: + async def __async_execute_plugin_modules( + self, method: str, result: Any, *args, **kwargs + ) -> Any: """ 异步执行插件模块 """ @@ -220,15 +248,22 @@ class ChainBase(metaclass=ABCMeta): else: break except Exception as err: - self.__handle_plugin_error(err, plugin_id, plugin_name, method, **kwargs) + self.__handle_plugin_error( + err, plugin_id, plugin_name, method, **kwargs + ) return result - def __execute_system_modules(self, method: str, result: Any, *args, **kwargs) -> Any: + def __execute_system_modules( + self, method: str, result: Any, *args, **kwargs + ) -> Any: """ 执行系统模块 """ logger.debug(f"请求系统模块执行:{method} ...") - for module in sorted(self.modulemanager.get_running_modules(method), key=lambda x: x.get_priority()): + for module in sorted( + self.modulemanager.get_running_modules(method), + key=lambda x: x.get_priority(), + ): module_id = module.__class__.__name__ try: module_name = module.get_name() @@ -253,15 +288,22 @@ class ChainBase(metaclass=ABCMeta): break except Exception as err: logger.error(traceback.format_exc()) - self.__handle_system_error(err, module_id, module_name, method, **kwargs) + self.__handle_system_error( + err, module_id, module_name, method, **kwargs + ) return result - async def __async_execute_system_modules(self, method: str, result: Any, *args, **kwargs) -> Any: + async def __async_execute_system_modules( + self, method: str, result: Any, *args, **kwargs + ) -> Any: """ 异步执行系统模块 """ logger.debug(f"请求系统模块执行:{method} ...") - for module in sorted(self.modulemanager.get_running_modules(method), key=lambda x: x.get_priority()): + for module in sorted( + self.modulemanager.get_running_modules(method), + key=lambda x: x.get_priority(), + ): module_id = module.__class__.__name__ try: module_name = module.get_name() @@ -295,7 +337,9 @@ class ChainBase(metaclass=ABCMeta): break except Exception as err: logger.error(traceback.format_exc()) - self.__handle_system_error(err, module_id, module_name, method, **kwargs) + self.__handle_system_error( + err, module_id, module_name, method, **kwargs + ) return result def run_module(self, method: str, *args, **kwargs) -> Any: @@ -324,22 +368,29 @@ class ChainBase(metaclass=ABCMeta): result = None # 执行插件模块 - result = await self.__async_execute_plugin_modules(method, result, *args, **kwargs) + result = await self.__async_execute_plugin_modules( + method, result, *args, **kwargs + ) if not self.__is_valid_empty(result) and not isinstance(result, list): # 插件模块返回结果不为空且不是列表,直接返回 return result # 执行系统模块 - return await self.__async_execute_system_modules(method, result, *args, **kwargs) + return await self.__async_execute_system_modules( + method, result, *args, **kwargs + ) - def recognize_media(self, meta: MetaBase = None, - mtype: Optional[MediaType] = None, - tmdbid: Optional[int] = None, - doubanid: Optional[str] = None, - bangumiid: Optional[int] = None, - episode_group: Optional[str] = None, - cache: bool = True) -> Optional[MediaInfo]: + def recognize_media( + self, + meta: MetaBase = None, + mtype: Optional[MediaType] = None, + tmdbid: Optional[int] = None, + doubanid: Optional[str] = None, + bangumiid: Optional[int] = None, + episode_group: Optional[str] = None, + cache: bool = True, + ) -> Optional[MediaInfo]: """ 识别媒体信息,不含Fanart图片 :param meta: 识别的元数据 @@ -363,17 +414,27 @@ class ChainBase(metaclass=ABCMeta): doubanid = None bangumiid = None with fresh(not cache): - return self.run_module("recognize_media", meta=meta, mtype=mtype, - tmdbid=tmdbid, doubanid=doubanid, bangumiid=bangumiid, - episode_group=episode_group, cache=cache) + return self.run_module( + "recognize_media", + meta=meta, + mtype=mtype, + tmdbid=tmdbid, + doubanid=doubanid, + bangumiid=bangumiid, + episode_group=episode_group, + cache=cache, + ) - async def async_recognize_media(self, meta: MetaBase = None, - mtype: Optional[MediaType] = None, - tmdbid: Optional[int] = None, - doubanid: Optional[str] = None, - bangumiid: Optional[int] = None, - episode_group: Optional[str] = None, - cache: bool = True) -> Optional[MediaInfo]: + async def async_recognize_media( + self, + meta: MetaBase = None, + mtype: Optional[MediaType] = None, + tmdbid: Optional[int] = None, + doubanid: Optional[str] = None, + bangumiid: Optional[int] = None, + episode_group: Optional[str] = None, + cache: bool = True, + ) -> Optional[MediaInfo]: """ 识别媒体信息,不含Fanart图片(异步版本) :param meta: 识别的元数据 @@ -397,13 +458,26 @@ class ChainBase(metaclass=ABCMeta): doubanid = None bangumiid = None async with async_fresh(not cache): - return await self.async_run_module("async_recognize_media", meta=meta, mtype=mtype, - tmdbid=tmdbid, doubanid=doubanid, bangumiid=bangumiid, - episode_group=episode_group, cache=cache) + return await self.async_run_module( + "async_recognize_media", + meta=meta, + mtype=mtype, + tmdbid=tmdbid, + doubanid=doubanid, + bangumiid=bangumiid, + episode_group=episode_group, + cache=cache, + ) - def match_doubaninfo(self, name: str, imdbid: Optional[str] = None, - mtype: Optional[MediaType] = None, year: Optional[str] = None, season: Optional[int] = None, - raise_exception: bool = False) -> Optional[dict]: + def match_doubaninfo( + self, + name: str, + imdbid: Optional[str] = None, + mtype: Optional[MediaType] = None, + year: Optional[str] = None, + season: Optional[int] = None, + raise_exception: bool = False, + ) -> Optional[dict]: """ 搜索和匹配豆瓣信息 :param name: 标题 @@ -413,13 +487,25 @@ class ChainBase(metaclass=ABCMeta): :param season: 季 :param raise_exception: 触发速率限制时是否抛出异常 """ - return self.run_module("match_doubaninfo", name=name, imdbid=imdbid, - mtype=mtype, year=year, season=season, raise_exception=raise_exception) + return self.run_module( + "match_doubaninfo", + name=name, + imdbid=imdbid, + mtype=mtype, + year=year, + season=season, + raise_exception=raise_exception, + ) - async def async_match_doubaninfo(self, name: str, imdbid: Optional[str] = None, - mtype: Optional[MediaType] = None, year: Optional[str] = None, - season: Optional[int] = None, - raise_exception: bool = False) -> Optional[dict]: + async def async_match_doubaninfo( + self, + name: str, + imdbid: Optional[str] = None, + mtype: Optional[MediaType] = None, + year: Optional[str] = None, + season: Optional[int] = None, + raise_exception: bool = False, + ) -> Optional[dict]: """ 搜索和匹配豆瓣信息(异步版本) :param name: 标题 @@ -429,11 +515,23 @@ class ChainBase(metaclass=ABCMeta): :param season: 季 :param raise_exception: 触发速率限制时是否抛出异常 """ - return await self.async_run_module("async_match_doubaninfo", name=name, imdbid=imdbid, - mtype=mtype, year=year, season=season, raise_exception=raise_exception) + return await self.async_run_module( + "async_match_doubaninfo", + name=name, + imdbid=imdbid, + mtype=mtype, + year=year, + season=season, + raise_exception=raise_exception, + ) - def match_tmdbinfo(self, name: str, mtype: Optional[MediaType] = None, - year: Optional[str] = None, season: Optional[int] = None) -> Optional[dict]: + def match_tmdbinfo( + self, + name: str, + mtype: Optional[MediaType] = None, + year: Optional[str] = None, + season: Optional[int] = None, + ) -> Optional[dict]: """ 搜索和匹配TMDB信息 :param name: 标题 @@ -441,11 +539,17 @@ class ChainBase(metaclass=ABCMeta): :param year: 年份 :param season: 季 """ - return self.run_module("match_tmdbinfo", name=name, - mtype=mtype, year=year, season=season) + return self.run_module( + "match_tmdbinfo", name=name, mtype=mtype, year=year, season=season + ) - async def async_match_tmdbinfo(self, name: str, mtype: Optional[MediaType] = None, - year: Optional[str] = None, season: Optional[int] = None) -> Optional[dict]: + async def async_match_tmdbinfo( + self, + name: str, + mtype: Optional[MediaType] = None, + year: Optional[str] = None, + season: Optional[int] = None, + ) -> Optional[dict]: """ 搜索和匹配TMDB信息(异步版本) :param name: 标题 @@ -453,8 +557,9 @@ class ChainBase(metaclass=ABCMeta): :param year: 年份 :param season: 季 """ - return await self.async_run_module("async_match_tmdbinfo", name=name, - mtype=mtype, year=year, season=season) + return await self.async_run_module( + "async_match_tmdbinfo", name=name, mtype=mtype, year=year, season=season + ) def obtain_images(self, mediainfo: MediaInfo) -> Optional[MediaInfo]: """ @@ -472,9 +577,15 @@ class ChainBase(metaclass=ABCMeta): """ return await self.async_run_module("async_obtain_images", mediainfo=mediainfo) - def obtain_specific_image(self, mediaid: Union[str, int], mtype: MediaType, - image_type: MediaImageType, image_prefix: Optional[str] = None, - season: Optional[int] = None, episode: Optional[int] = None) -> Optional[str]: + def obtain_specific_image( + self, + mediaid: Union[str, int], + mtype: MediaType, + image_type: MediaImageType, + image_prefix: Optional[str] = None, + season: Optional[int] = None, + episode: Optional[int] = None, + ) -> Optional[str]: """ 获取指定媒体信息图片,返回图片地址 :param mediaid: 媒体ID @@ -484,12 +595,22 @@ class ChainBase(metaclass=ABCMeta): :param season: 季 :param episode: 集 """ - return self.run_module("obtain_specific_image", mediaid=mediaid, mtype=mtype, - image_prefix=image_prefix, image_type=image_type, - season=season, episode=episode) + return self.run_module( + "obtain_specific_image", + mediaid=mediaid, + mtype=mtype, + image_prefix=image_prefix, + image_type=image_type, + season=season, + episode=episode, + ) - def douban_info(self, doubanid: str, mtype: Optional[MediaType] = None, - raise_exception: bool = False) -> Optional[dict]: + def douban_info( + self, + doubanid: str, + mtype: Optional[MediaType] = None, + raise_exception: bool = False, + ) -> Optional[dict]: """ 获取豆瓣信息 :param doubanid: 豆瓣ID @@ -497,10 +618,19 @@ class ChainBase(metaclass=ABCMeta): :return: 豆瓣信息 :param raise_exception: 触发速率限制时是否抛出异常 """ - return self.run_module("douban_info", doubanid=doubanid, mtype=mtype, raise_exception=raise_exception) + return self.run_module( + "douban_info", + doubanid=doubanid, + mtype=mtype, + raise_exception=raise_exception, + ) - async def async_douban_info(self, doubanid: str, mtype: Optional[MediaType] = None, - raise_exception: bool = False) -> Optional[dict]: + async def async_douban_info( + self, + doubanid: str, + mtype: Optional[MediaType] = None, + raise_exception: bool = False, + ) -> Optional[dict]: """ 获取豆瓣信息(异步版本) :param doubanid: 豆瓣ID @@ -508,8 +638,12 @@ class ChainBase(metaclass=ABCMeta): :return: 豆瓣信息 :param raise_exception: 触发速率限制时是否抛出异常 """ - return await self.async_run_module("async_douban_info", doubanid=doubanid, mtype=mtype, - raise_exception=raise_exception) + return await self.async_run_module( + "async_douban_info", + doubanid=doubanid, + mtype=mtype, + raise_exception=raise_exception, + ) def tvdb_info(self, tvdbid: int) -> Optional[dict]: """ @@ -519,7 +653,9 @@ class ChainBase(metaclass=ABCMeta): """ return self.run_module("tvdb_info", tvdbid=tvdbid) - def tmdb_info(self, tmdbid: int, mtype: MediaType, season: Optional[int] = None) -> Optional[dict]: + def tmdb_info( + self, tmdbid: int, mtype: MediaType, season: Optional[int] = None + ) -> Optional[dict]: """ 获取TMDB信息 :param tmdbid: int @@ -529,7 +665,9 @@ class ChainBase(metaclass=ABCMeta): """ return self.run_module("tmdb_info", tmdbid=tmdbid, mtype=mtype, season=season) - async def async_tmdb_info(self, tmdbid: int, mtype: MediaType, season: Optional[int] = None) -> Optional[dict]: + async def async_tmdb_info( + self, tmdbid: int, mtype: MediaType, season: Optional[int] = None + ) -> Optional[dict]: """ 获取TMDB信息(异步版本) :param tmdbid: int @@ -537,7 +675,9 @@ class ChainBase(metaclass=ABCMeta): :param season: 季 :return: TVDB信息 """ - return await self.async_run_module("async_tmdb_info", tmdbid=tmdbid, mtype=mtype, season=season) + return await self.async_run_module( + "async_tmdb_info", tmdbid=tmdbid, mtype=mtype, season=season + ) def bangumi_info(self, bangumiid: int) -> Optional[dict]: """ @@ -555,8 +695,9 @@ class ChainBase(metaclass=ABCMeta): """ return await self.async_run_module("async_bangumi_info", bangumiid=bangumiid) - def message_parser(self, source: str, body: Any, form: Any, - args: Any) -> Optional[CommingMessage]: + def message_parser( + self, source: str, body: Any, form: Any, args: Any + ) -> Optional[CommingMessage]: """ 解析消息内容,返回字典,注意以下约定值: userid: 用户ID @@ -568,9 +709,13 @@ class ChainBase(metaclass=ABCMeta): :param args: 参数 :return: 消息渠道、消息内容 """ - return self.run_module("message_parser", source=source, body=body, form=form, args=args) + return self.run_module( + "message_parser", source=source, body=body, form=form, args=args + ) - def webhook_parser(self, body: Any, form: Any, args: Any) -> Optional[WebhookEventInfo]: + def webhook_parser( + self, body: Any, form: Any, args: Any + ) -> Optional[WebhookEventInfo]: """ 解析Webhook报文体 :param body: 请求体 @@ -624,10 +769,13 @@ class ChainBase(metaclass=ABCMeta): """ return await self.async_run_module("async_search_collections", name=name) - def search_torrents(self, site: dict, - keyword: str, - mtype: Optional[MediaType] = None, - page: Optional[int] = 0) -> List[TorrentInfo]: + def search_torrents( + self, + site: dict, + keyword: str, + mtype: Optional[MediaType] = None, + page: Optional[int] = 0, + ) -> List[TorrentInfo]: """ 搜索一个站点的种子资源 :param site: 站点 @@ -636,13 +784,17 @@ class ChainBase(metaclass=ABCMeta): :param page: 页码 :reutrn: 资源列表 """ - return self.run_module("search_torrents", site=site, keyword=keyword, - mtype=mtype, page=page) + return self.run_module( + "search_torrents", site=site, keyword=keyword, mtype=mtype, page=page + ) - async def async_search_torrents(self, site: dict, - keyword: str, - mtype: Optional[MediaType] = None, - page: Optional[int] = 0) -> List[TorrentInfo]: + async def async_search_torrents( + self, + site: dict, + keyword: str, + mtype: Optional[MediaType] = None, + page: Optional[int] = 0, + ) -> List[TorrentInfo]: """ 异步搜索一个站点的种子资源 :param site: 站点 @@ -651,11 +803,17 @@ class ChainBase(metaclass=ABCMeta): :param page: 页码 :reutrn: 资源列表 """ - return await self.async_run_module("async_search_torrents", site=site, keyword=keyword, - mtype=mtype, page=page) + return await self.async_run_module( + "async_search_torrents", site=site, keyword=keyword, mtype=mtype, page=page + ) - def refresh_torrents(self, site: dict, keyword: Optional[str] = None, - cat: Optional[str] = None, page: Optional[int] = 0) -> List[TorrentInfo]: + def refresh_torrents( + self, + site: dict, + keyword: Optional[str] = None, + cat: Optional[str] = None, + page: Optional[int] = 0, + ) -> List[TorrentInfo]: """ 获取站点最新一页的种子,多个站点需要多线程处理 :param site: 站点 @@ -664,10 +822,17 @@ class ChainBase(metaclass=ABCMeta): :param page: 页码 :reutrn: 种子资源列表 """ - return self.run_module("refresh_torrents", site=site, keyword=keyword, cat=cat, page=page) + return self.run_module( + "refresh_torrents", site=site, keyword=keyword, cat=cat, page=page + ) - async def async_refresh_torrents(self, site: dict, keyword: Optional[str] = None, - cat: Optional[str] = None, page: Optional[int] = 0) -> List[TorrentInfo]: + async def async_refresh_torrents( + self, + site: dict, + keyword: Optional[str] = None, + cat: Optional[str] = None, + page: Optional[int] = 0, + ) -> List[TorrentInfo]: """ 异步获取站点最新一页的种子,多个站点需要多线程处理 :param site: 站点 @@ -676,12 +841,16 @@ class ChainBase(metaclass=ABCMeta): :param page: 页码 :reutrn: 种子资源列表 """ - return await self.async_run_module("async_refresh_torrents", - site=site, keyword=keyword, cat=cat, page=page) + return await self.async_run_module( + "async_refresh_torrents", site=site, keyword=keyword, cat=cat, page=page + ) - def filter_torrents(self, rule_groups: List[str], - torrent_list: List[TorrentInfo], - mediainfo: MediaInfo = None) -> List[TorrentInfo]: + def filter_torrents( + self, + rule_groups: List[str], + torrent_list: List[TorrentInfo], + mediainfo: MediaInfo = None, + ) -> List[TorrentInfo]: """ 过滤种子资源 :param rule_groups: 过滤规则组名称列表 @@ -689,13 +858,23 @@ class ChainBase(metaclass=ABCMeta): :param mediainfo: 识别的媒体信息 :return: 过滤后的资源列表,添加资源优先级 """ - return self.run_module("filter_torrents", rule_groups=rule_groups, - torrent_list=torrent_list, mediainfo=mediainfo) + return self.run_module( + "filter_torrents", + rule_groups=rule_groups, + torrent_list=torrent_list, + mediainfo=mediainfo, + ) - def download(self, content: Union[Path, str, bytes], download_dir: Path, cookie: str, - episodes: Set[int] = None, category: Optional[str] = None, label: Optional[str] = None, - downloader: Optional[str] = None - ) -> Optional[Tuple[Optional[str], Optional[str], Optional[str], str]]: + def download( + self, + content: Union[Path, str, bytes], + download_dir: Path, + cookie: str, + episodes: Set[int] = None, + category: Optional[str] = None, + label: Optional[str] = None, + downloader: Optional[str] = None, + ) -> Optional[Tuple[Optional[str], Optional[str], Optional[str], str]]: """ 根据种子文件,选择并添加下载任务 :param content: 种子文件地址或者磁力链接或者种子内容 @@ -707,11 +886,23 @@ class ChainBase(metaclass=ABCMeta): :param downloader: 下载器 :return: 下载器名称、种子Hash、种子文件布局、错误原因 """ - return self.run_module("download", content=content, download_dir=download_dir, - cookie=cookie, episodes=episodes, category=category, label=label, - downloader=downloader) + return self.run_module( + "download", + content=content, + download_dir=download_dir, + cookie=cookie, + episodes=episodes, + category=category, + label=label, + downloader=downloader, + ) - def download_added(self, context: Context, download_dir: Path, torrent_content: Union[str, bytes] = None) -> None: + def download_added( + self, + context: Context, + download_dir: Path, + torrent_content: Union[str, bytes] = None, + ) -> None: """ 添加下载任务成功后,从站点下载字幕,保存到下载目录 :param context: 上下文,包括识别信息、媒体信息、种子信息 @@ -719,14 +910,19 @@ class ChainBase(metaclass=ABCMeta): :param torrent_content: 种子内容,如果有则直接使用该内容,否则从context中获取种子文件路径 :return: None,该方法可被多个模块同时处理 """ - return self.run_module("download_added", context=context, - torrent_content=torrent_content, - download_dir=download_dir) + return self.run_module( + "download_added", + context=context, + torrent_content=torrent_content, + download_dir=download_dir, + ) - def list_torrents(self, status: TorrentStatus = None, - hashs: Union[list, str] = None, - downloader: Optional[str] = None - ) -> Optional[List[Union[TransferTorrent, DownloadingTorrent]]]: + def list_torrents( + self, + status: TorrentStatus = None, + hashs: Union[list, str] = None, + downloader: Optional[str] = None, + ) -> Optional[List[Union[TransferTorrent, DownloadingTorrent]]]: """ 获取下载器种子列表 :param status: 种子状态 @@ -734,15 +930,26 @@ class ChainBase(metaclass=ABCMeta): :param downloader: 下载器 :return: 下载器中符合状态的种子列表 """ - return self.run_module("list_torrents", status=status, hashs=hashs, downloader=downloader) + return self.run_module( + "list_torrents", status=status, hashs=hashs, downloader=downloader + ) - def transfer(self, fileitem: FileItem, meta: MetaBase, mediainfo: MediaInfo, - target_directory: TransferDirectoryConf = None, - target_storage: Optional[str] = None, target_path: Path = None, - transfer_type: Optional[str] = None, scrape: bool = None, - library_type_folder: bool = None, library_category_folder: bool = None, - episodes_info: List[TmdbEpisode] = None, - source_oper: Callable = None, target_oper: Callable = None) -> Optional[TransferInfo]: + def transfer( + self, + fileitem: FileItem, + meta: MetaBase, + mediainfo: MediaInfo, + target_directory: TransferDirectoryConf = None, + target_storage: Optional[str] = None, + target_path: Path = None, + transfer_type: Optional[str] = None, + scrape: bool = None, + library_type_folder: bool = None, + library_category_folder: bool = None, + episodes_info: List[TmdbEpisode] = None, + source_oper: Callable = None, + target_oper: Callable = None, + ) -> Optional[TransferInfo]: """ 文件转移 :param fileitem: 文件信息 @@ -760,15 +967,22 @@ class ChainBase(metaclass=ABCMeta): :param target_oper: 目标存储操作类 :return: {path, target_path, message} """ - return self.run_module("transfer", - fileitem=fileitem, meta=meta, mediainfo=mediainfo, - target_directory=target_directory, - target_path=target_path, target_storage=target_storage, - transfer_type=transfer_type, scrape=scrape, - library_type_folder=library_type_folder, - library_category_folder=library_category_folder, - episodes_info=episodes_info, - source_oper=source_oper, target_oper=target_oper) + return self.run_module( + "transfer", + fileitem=fileitem, + meta=meta, + mediainfo=mediainfo, + target_directory=target_directory, + target_path=target_path, + target_storage=target_storage, + transfer_type=transfer_type, + scrape=scrape, + library_type_folder=library_type_folder, + library_category_folder=library_category_folder, + episodes_info=episodes_info, + source_oper=source_oper, + target_oper=target_oper, + ) def transfer_completed(self, hashs: str, downloader: Optional[str] = None) -> None: """ @@ -778,8 +992,12 @@ class ChainBase(metaclass=ABCMeta): """ return self.run_module("transfer_completed", hashs=hashs, downloader=downloader) - def remove_torrents(self, hashs: Union[str, list], delete_file: bool = True, - downloader: Optional[str] = None) -> bool: + def remove_torrents( + self, + hashs: Union[str, list], + delete_file: bool = True, + downloader: Optional[str] = None, + ) -> bool: """ 删除下载器种子 :param hashs: 种子Hash @@ -787,9 +1005,16 @@ class ChainBase(metaclass=ABCMeta): :param downloader: 下载器 :return: bool """ - return self.run_module("remove_torrents", hashs=hashs, delete_file=delete_file, downloader=downloader) + return self.run_module( + "remove_torrents", + hashs=hashs, + delete_file=delete_file, + downloader=downloader, + ) - def start_torrents(self, hashs: Union[list, str], downloader: Optional[str] = None) -> bool: + def start_torrents( + self, hashs: Union[list, str], downloader: Optional[str] = None + ) -> bool: """ 开始下载 :param hashs: 种子Hash @@ -798,7 +1023,9 @@ class ChainBase(metaclass=ABCMeta): """ return self.run_module("start_torrents", hashs=hashs, downloader=downloader) - def stop_torrents(self, hashs: Union[list, str], downloader: Optional[str] = None) -> bool: + def stop_torrents( + self, hashs: Union[list, str], downloader: Optional[str] = None + ) -> bool: """ 停止下载 :param hashs: 种子Hash @@ -807,8 +1034,9 @@ class ChainBase(metaclass=ABCMeta): """ return self.run_module("stop_torrents", hashs=hashs, downloader=downloader) - def torrent_files(self, tid: str, - downloader: Optional[str] = None) -> Optional[Union[TorrentFilesList, List[File]]]: + def torrent_files( + self, tid: str, downloader: Optional[str] = None + ) -> Optional[Union[TorrentFilesList, List[File]]]: """ 获取种子文件 :param tid: 种子Hash @@ -817,8 +1045,12 @@ class ChainBase(metaclass=ABCMeta): """ return self.run_module("torrent_files", tid=tid, downloader=downloader) - def media_exists(self, mediainfo: MediaInfo, itemid: Optional[str] = None, - server: Optional[str] = None) -> Optional[ExistMediaInfo]: + def media_exists( + self, + mediainfo: MediaInfo, + itemid: Optional[str] = None, + server: Optional[str] = None, + ) -> Optional[ExistMediaInfo]: """ 判断媒体文件是否存在 :param mediainfo: 识别的媒体信息 @@ -826,7 +1058,9 @@ class ChainBase(metaclass=ABCMeta): :param server: 媒体服务器 :return: 如不存在返回None,存在时返回信息,包括每季已存在所有集{type: movie/tv, seasons: {season: [episodes]}} """ - return self.run_module("media_exists", mediainfo=mediainfo, itemid=itemid, server=server) + return self.run_module( + "media_exists", mediainfo=mediainfo, itemid=itemid, server=server + ) def media_files(self, mediainfo: MediaInfo) -> Optional[List[FileItem]]: """ @@ -836,13 +1070,15 @@ class ChainBase(metaclass=ABCMeta): """ return self.run_module("media_files", mediainfo=mediainfo) - def post_message(self, - message: Optional[Notification] = None, - meta: Optional[MetaBase] = None, - mediainfo: Optional[MediaInfo] = None, - torrentinfo: Optional[TorrentInfo] = None, - transferinfo: Optional[TransferInfo] = None, - **kwargs) -> None: + def post_message( + self, + message: Optional[Notification] = None, + meta: Optional[MetaBase] = None, + mediainfo: Optional[MediaInfo] = None, + torrentinfo: Optional[TorrentInfo] = None, + transferinfo: Optional[TransferInfo] = None, + **kwargs, + ) -> None: """ 发送消息 :param message: Notification实例 @@ -854,10 +1090,16 @@ class ChainBase(metaclass=ABCMeta): :return: 成功或失败 """ # 添加格式化的时间参数 - kwargs.setdefault('current_time', datetime.now().strftime('%Y-%m-%d %H:%M:%S')) + kwargs.setdefault("current_time", datetime.now().strftime("%Y-%m-%d %H:%M:%S")) # 渲染消息 - message = MessageTemplateHelper.render(message=message, meta=meta, mediainfo=mediainfo, - torrentinfo=torrentinfo, transferinfo=transferinfo, **kwargs) + message = MessageTemplateHelper.render( + message=message, + meta=meta, + mediainfo=mediainfo, + torrentinfo=torrentinfo, + transferinfo=transferinfo, + **kwargs, + ) # 检查消息是否有效 if not message: logger.warning("消息为空,跳过发送") @@ -886,20 +1128,30 @@ class ChainBase(metaclass=ABCMeta): admin_sended = True elif action == "user" and send_message.username: # 发送对应用户 - logger.info(f"{send_message.mtype} 的消息已设置发送给用户 {send_message.username}") + logger.info( + f"{send_message.mtype} 的消息已设置发送给用户 {send_message.username}" + ) # 读取用户消息IDS - send_message.targets = useroper.get_settings(send_message.username) + send_message.targets = useroper.get_settings( + send_message.username + ) if send_message.targets is None: # 没有找到用户 if not admin_sended: # 回滚发送管理员 - logger.info(f"用户 {send_message.username} 不存在,消息将发送给管理员") + logger.info( + f"用户 {send_message.username} 不存在,消息将发送给管理员" + ) # 读取管理员消息IDS - send_message.targets = useroper.get_settings(settings.SUPERUSER) + send_message.targets = useroper.get_settings( + settings.SUPERUSER + ) admin_sended = True else: # 管理员发过了,此消息不发了 - logger.info(f"用户 {send_message.username} 不存在,消息无法发送到对应用户") + logger.info( + f"用户 {send_message.username} 不存在,消息无法发送到对应用户" + ) continue elif send_message.username == settings.SUPERUSER: # 管理员同名已发送 @@ -910,24 +1162,37 @@ class ChainBase(metaclass=ABCMeta): send_orignal = True break # 按设定发送 - self.eventmanager.send_event(etype=EventType.NoticeMessage, - data={**send_message.model_dump(), "type": send_message.mtype}) - self.messagequeue.send_message("post_message", message=send_message, **kwargs) + self.eventmanager.send_event( + etype=EventType.NoticeMessage, + data={**send_message.model_dump(), "type": send_message.mtype}, + ) + self.messagequeue.send_message( + "post_message", message=send_message, **kwargs + ) if not send_orignal: return # 发送消息事件 - self.eventmanager.send_event(etype=EventType.NoticeMessage, data={**message.model_dump(), "type": message.mtype}) + self.eventmanager.send_event( + etype=EventType.NoticeMessage, + data={**message.model_dump(), "type": message.mtype}, + ) # 按原消息发送 - self.messagequeue.send_message("post_message", message=message, - immediately=True if message.userid else False, **kwargs) + self.messagequeue.send_message( + "post_message", + message=message, + immediately=True if message.userid else False, + **kwargs, + ) - async def async_post_message(self, - message: Optional[Notification] = None, - meta: Optional[MetaBase] = None, - mediainfo: Optional[MediaInfo] = None, - torrentinfo: Optional[TorrentInfo] = None, - transferinfo: Optional[TransferInfo] = None, - **kwargs) -> None: + async def async_post_message( + self, + message: Optional[Notification] = None, + meta: Optional[MetaBase] = None, + mediainfo: Optional[MediaInfo] = None, + torrentinfo: Optional[TorrentInfo] = None, + transferinfo: Optional[TransferInfo] = None, + **kwargs, + ) -> None: """ 异步发送消息 :param message: Notification实例 @@ -939,10 +1204,16 @@ class ChainBase(metaclass=ABCMeta): :return: 成功或失败 """ # 添加格式化的时间参数 - kwargs.setdefault('current_time', datetime.now().strftime('%Y-%m-%d %H:%M:%S')) + kwargs.setdefault("current_time", datetime.now().strftime("%Y-%m-%d %H:%M:%S")) # 渲染消息 - message = MessageTemplateHelper.render(message=message, meta=meta, mediainfo=mediainfo, - torrentinfo=torrentinfo, transferinfo=transferinfo, **kwargs) + message = MessageTemplateHelper.render( + message=message, + meta=meta, + mediainfo=mediainfo, + torrentinfo=torrentinfo, + transferinfo=transferinfo, + **kwargs, + ) # 检查消息是否有效 if not message: logger.warning("消息为空,跳过发送") @@ -971,20 +1242,30 @@ class ChainBase(metaclass=ABCMeta): admin_sended = True elif action == "user" and send_message.username: # 发送对应用户 - logger.info(f"{send_message.mtype} 的消息已设置发送给用户 {send_message.username}") + logger.info( + f"{send_message.mtype} 的消息已设置发送给用户 {send_message.username}" + ) # 读取用户消息IDS - send_message.targets = useroper.get_settings(send_message.username) + send_message.targets = useroper.get_settings( + send_message.username + ) if send_message.targets is None: # 没有找到用户 if not admin_sended: # 回滚发送管理员 - logger.info(f"用户 {send_message.username} 不存在,消息将发送给管理员") + logger.info( + f"用户 {send_message.username} 不存在,消息将发送给管理员" + ) # 读取管理员消息IDS - send_message.targets = useroper.get_settings(settings.SUPERUSER) + send_message.targets = useroper.get_settings( + settings.SUPERUSER + ) admin_sended = True else: # 管理员发过了,此消息不发了 - logger.info(f"用户 {send_message.username} 不存在,消息无法发送到对应用户") + logger.info( + f"用户 {send_message.username} 不存在,消息无法发送到对应用户" + ) continue elif send_message.username == settings.SUPERUSER: # 管理员同名已发送 @@ -995,19 +1276,31 @@ class ChainBase(metaclass=ABCMeta): send_orignal = True break # 按设定发送 - await self.eventmanager.async_send_event(etype=EventType.NoticeMessage, - data={**send_message.model_dump(), "type": send_message.mtype}) - await self.messagequeue.async_send_message("post_message", message=send_message, **kwargs) + await self.eventmanager.async_send_event( + etype=EventType.NoticeMessage, + data={**send_message.model_dump(), "type": send_message.mtype}, + ) + await self.messagequeue.async_send_message( + "post_message", message=send_message, **kwargs + ) if not send_orignal: return # 发送消息事件 - await self.eventmanager.async_send_event(etype=EventType.NoticeMessage, - data={**message.model_dump(), "type": message.mtype}) + await self.eventmanager.async_send_event( + etype=EventType.NoticeMessage, + data={**message.model_dump(), "type": message.mtype}, + ) # 按原消息发送 - await self.messagequeue.async_send_message("post_message", message=message, - immediately=True if message.userid else False, **kwargs) + await self.messagequeue.async_send_message( + "post_message", + message=message, + immediately=True if message.userid else False, + **kwargs, + ) - def post_medias_message(self, message: Notification, medias: List[MediaInfo]) -> None: + def post_medias_message( + self, message: Notification, medias: List[MediaInfo] + ) -> None: """ 发送媒体信息选择列表 :param message: 消息体 @@ -1015,12 +1308,20 @@ class ChainBase(metaclass=ABCMeta): :return: 成功或失败 """ note_list = [media.to_dict() for media in medias] - self.messagehelper.put(message, role="user", note=note_list, title=message.title) + self.messagehelper.put( + message, role="user", note=note_list, title=message.title + ) self.messageoper.add(**message.model_dump(), note=note_list) - return self.messagequeue.send_message("post_medias_message", message=message, medias=medias, - immediately=True if message.userid else False) + return self.messagequeue.send_message( + "post_medias_message", + message=message, + medias=medias, + immediately=True if message.userid else False, + ) - def post_torrents_message(self, message: Notification, torrents: List[Context]) -> None: + def post_torrents_message( + self, message: Notification, torrents: List[Context] + ) -> None: """ 发送种子信息选择列表 :param message: 消息体 @@ -1028,13 +1329,24 @@ class ChainBase(metaclass=ABCMeta): :return: 成功或失败 """ note_list = [torrent.torrent_info.to_dict() for torrent in torrents] - self.messagehelper.put(message, role="user", note=note_list, title=message.title) + self.messagehelper.put( + message, role="user", note=note_list, title=message.title + ) self.messageoper.add(**message.model_dump(), note=note_list) - return self.messagequeue.send_message("post_torrents_message", message=message, torrents=torrents, - immediately=True if message.userid else False) + return self.messagequeue.send_message( + "post_torrents_message", + message=message, + torrents=torrents, + immediately=True if message.userid else False, + ) - def delete_message(self, channel: MessageChannel, source: str, - message_id: Union[str, int], chat_id: Optional[Union[str, int]] = None) -> bool: + def delete_message( + self, + channel: MessageChannel, + source: str, + message_id: Union[str, int], + chat_id: Optional[Union[str, int]] = None, + ) -> bool: """ 删除消息 :param channel: 消息渠道 @@ -1043,18 +1355,67 @@ class ChainBase(metaclass=ABCMeta): :param chat_id: 聊天ID(如群组ID) :return: 删除是否成功 """ - return self.run_module("delete_message", channel=channel, source=source, - message_id=message_id, chat_id=chat_id) + return self.run_module( + "delete_message", + channel=channel, + source=source, + message_id=message_id, + chat_id=chat_id, + ) - def metadata_img(self, mediainfo: MediaInfo, - season: Optional[int] = None, episode: Optional[int] = None) -> Optional[dict]: + def edit_message( + self, + channel: MessageChannel, + source: str, + message_id: Union[str, int], + chat_id: Union[str, int], + text: str, + title: Optional[str] = None, + ) -> bool: + """ + 编辑已发送的消息 + :param channel: 消息渠道 + :param source: 消息源(指定特定的消息模块) + :param message_id: 消息ID + :param chat_id: 聊天ID + :param text: 新的消息内容 + :param title: 消息标题 + :return: 编辑是否成功 + """ + return self.run_module( + "edit_message", + channel=channel, + source=source, + message_id=message_id, + chat_id=chat_id, + text=text, + title=title, + ) + + def send_direct_message(self, message: Notification) -> Optional[MessageResponse]: + """ + 直接发送消息并返回消息ID等信息(用于后续编辑消息的场景) + 不经过消息队列、不保存消息历史 + :param message: 消息体 + :return: 消息响应(包含message_id, chat_id等) + """ + return self.run_module("send_direct_message", message=message) + + def metadata_img( + self, + mediainfo: MediaInfo, + season: Optional[int] = None, + episode: Optional[int] = None, + ) -> Optional[dict]: """ 获取图片名称和url :param mediainfo: 媒体信息 :param season: 季号 :param episode: 集号 """ - return self.run_module("metadata_img", mediainfo=mediainfo, season=season, episode=episode) + return self.run_module( + "metadata_img", mediainfo=mediainfo, season=season, episode=episode + ) def media_category(self) -> Optional[Dict[str, list]]: """ diff --git a/app/modules/telegram/__init__.py b/app/modules/telegram/__init__.py index b47ea77c..9a81ac7a 100644 --- a/app/modules/telegram/__init__.py +++ b/app/modules/telegram/__init__.py @@ -8,20 +8,26 @@ from app.core.event import eventmanager from app.log import logger from app.modules import _ModuleBase, _MessageBase from app.modules.telegram.telegram import Telegram -from app.schemas import MessageChannel, CommingMessage, Notification, CommandRegisterEventData, \ - NotificationConf +from app.schemas import ( + MessageChannel, + CommingMessage, + Notification, + CommandRegisterEventData, + NotificationConf, + MessageResponse, +) from app.schemas.types import ModuleType, ChainEventType from app.utils.structures import DictUtils class TelegramModule(_ModuleBase, _MessageBase[Telegram]): - def init_module(self) -> None: """ 初始化模块 """ - super().init_service(service_name=Telegram.__name__.lower(), - service_type=Telegram) + super().init_service( + service_name=Telegram.__name__.lower(), service_type=Telegram + ) self._channel = MessageChannel.Telegram @staticmethod @@ -71,8 +77,9 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]): def init_setting(self) -> Tuple[str, Union[str, bool]]: pass - def message_parser(self, source: str, body: Any, form: Any, - args: Any) -> Optional[CommingMessage]: + def message_parser( + self, source: str, body: Any, form: Any, args: Any + ) -> Optional[CommingMessage]: """ 解析消息内容,返回字典,注意以下约定值: userid: 用户ID @@ -140,7 +147,9 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]): return None @staticmethod - def _handle_callback_query(message: dict, client_config: NotificationConf) -> Optional[CommingMessage]: + def _handle_callback_query( + message: dict, client_config: NotificationConf + ) -> Optional[CommingMessage]: """ 处理按钮回调查询 """ @@ -151,8 +160,10 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]): user_name = user_info.get("username") if callback_data and user_id: - logger.info(f"收到来自 {client_config.name} 的Telegram按钮回调:" - f"userid={user_id}, username={user_name}, callback_data={callback_data}") + logger.info( + f"收到来自 {client_config.name} 的Telegram按钮回调:" + f"userid={user_id}, username={user_name}, callback_data={callback_data}" + ) # 将callback_data作为特殊格式的text返回,以便主程序识别这是按钮回调 callback_text = f"CALLBACK:{callback_data}" @@ -167,13 +178,16 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]): is_callback=True, callback_data=callback_data, message_id=callback_query.get("message", {}).get("message_id"), - chat_id=str(callback_query.get("message", {}).get("chat", {}).get("id", "")), - callback_query=callback_query + chat_id=str( + callback_query.get("message", {}).get("chat", {}).get("id", "") + ), + callback_query=callback_query, ) return None - def _handle_text_message(self, msg: dict, - client_config: NotificationConf, client: Telegram) -> Optional[CommingMessage]: + def _handle_text_message( + self, msg: dict, client_config: NotificationConf, client: Telegram + ) -> Optional[CommingMessage]: """ 处理普通文本消息 """ @@ -184,11 +198,15 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]): chat_id = msg.get("chat", {}).get("id") if text and user_id: - logger.info(f"收到来自 {client_config.name} 的Telegram消息:" - f"userid={user_id}, username={user_name}, chat_id={chat_id}, text={text}") + logger.info( + f"收到来自 {client_config.name} 的Telegram消息:" + f"userid={user_id}, username={user_name}, chat_id={chat_id}, text={text}" + ) # Clean bot mentions from text to ensure consistent processing - cleaned_text = self._clean_bot_mention(text, client.bot_username if client else None) + cleaned_text = self._clean_bot_mention( + text, client.bot_username if client else None + ) # 检查权限 admin_users = client_config.config.get("TELEGRAM_ADMINS") @@ -196,16 +214,21 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]): config_chat_id = client_config.config.get("TELEGRAM_CHAT_ID") if cleaned_text.startswith("/"): - if admin_users \ - and str(user_id) not in admin_users.split(',') \ - and str(user_id) != config_chat_id: - client.send_msg(title="只有管理员才有权限执行此命令", userid=user_id) + if ( + admin_users + and str(user_id) not in admin_users.split(",") + and str(user_id) != config_chat_id + ): + client.send_msg( + title="只有管理员才有权限执行此命令", userid=user_id + ) return None else: - if user_list \ - and str(user_id) not in user_list.split(','): + if user_list and str(user_id) not in user_list.split(","): logger.info(f"用户{user_id}不在用户白名单中,无法使用此机器人") - client.send_msg(title="你不在用户白名单中,无法使用此机器人", userid=user_id) + client.send_msg( + title="你不在用户白名单中,无法使用此机器人", userid=user_id + ) return None return CommingMessage( @@ -214,7 +237,7 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]): userid=user_id, username=user_name, text=cleaned_text, # Use cleaned text - chat_id=str(chat_id) if chat_id else None + chat_id=str(chat_id) if chat_id else None, ) return None @@ -235,13 +258,13 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]): # Remove mention at the beginning with optional following space if cleaned.startswith(mention_pattern): - cleaned = cleaned[len(mention_pattern):].lstrip() + cleaned = cleaned[len(mention_pattern) :].lstrip() # Remove mention at any other position cleaned = cleaned.replace(mention_pattern, "").strip() # Clean up multiple spaces - cleaned = re.sub(r'\s+', ' ', cleaned).strip() + cleaned = re.sub(r"\s+", " ", cleaned).strip() return cleaned @@ -257,19 +280,26 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]): targets = message.targets userid = message.userid if not userid and targets is not None: - userid = targets.get('telegram_userid') + userid = targets.get("telegram_userid") if not userid: logger.warn(f"用户没有指定 Telegram用户ID,消息无法发送") return client: Telegram = self.get_instance(conf.name) if client: - client.send_msg(title=message.title, text=message.text, - image=message.image, userid=userid, link=message.link, - buttons=message.buttons, - original_message_id=message.original_message_id, - original_chat_id=message.original_chat_id) + client.send_msg( + title=message.title, + text=message.text, + image=message.image, + userid=userid, + link=message.link, + buttons=message.buttons, + original_message_id=message.original_message_id, + original_chat_id=message.original_chat_id, + ) - def post_medias_message(self, message: Notification, medias: List[MediaInfo]) -> None: + def post_medias_message( + self, message: Notification, medias: List[MediaInfo] + ) -> None: """ 发送媒体信息选择列表 :param message: 消息体 @@ -281,13 +311,19 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]): continue client: Telegram = self.get_instance(conf.name) if client: - client.send_medias_msg(title=message.title, medias=medias, - userid=message.userid, link=message.link, - buttons=message.buttons, - original_message_id=message.original_message_id, - original_chat_id=message.original_chat_id) + client.send_medias_msg( + title=message.title, + medias=medias, + userid=message.userid, + link=message.link, + buttons=message.buttons, + original_message_id=message.original_message_id, + original_chat_id=message.original_chat_id, + ) - def post_torrents_message(self, message: Notification, torrents: List[Context]) -> None: + def post_torrents_message( + self, message: Notification, torrents: List[Context] + ) -> None: """ 发送种子信息选择列表 :param message: 消息体 @@ -299,14 +335,23 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]): continue client: Telegram = self.get_instance(conf.name) if client: - client.send_torrents_msg(title=message.title, torrents=torrents, - userid=message.userid, link=message.link, - buttons=message.buttons, - original_message_id=message.original_message_id, - original_chat_id=message.original_chat_id) + client.send_torrents_msg( + title=message.title, + torrents=torrents, + userid=message.userid, + link=message.link, + buttons=message.buttons, + original_message_id=message.original_message_id, + original_chat_id=message.original_chat_id, + ) - def delete_message(self, channel: MessageChannel, source: str, - message_id: int, chat_id: Optional[int] = None) -> bool: + def delete_message( + self, + channel: MessageChannel, + source: str, + message_id: int, + chat_id: Optional[int] = None, + ) -> bool: """ 删除消息 :param channel: 消息渠道 @@ -328,6 +373,77 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]): success = True return success + def edit_message( + self, + channel: MessageChannel, + source: str, + message_id: Union[str, int], + chat_id: Union[str, int], + text: str, + title: Optional[str] = None, + ) -> bool: + """ + 编辑消息 + :param channel: 消息渠道 + :param source: 指定的消息源 + :param message_id: 消息ID + :param chat_id: 聊天ID + :param text: 新的消息内容 + :param title: 消息标题 + :return: 编辑是否成功 + """ + if channel != self._channel: + return False + for conf in self.get_configs().values(): + if source != conf.name: + continue + client: Telegram = self.get_instance(conf.name) + if client: + result = client.edit_msg( + chat_id=chat_id, + message_id=message_id, + text=text, + title=title, + ) + if result: + return True + return False + + def send_direct_message(self, message: Notification) -> Optional[MessageResponse]: + """ + 直接发送消息并返回消息ID等信息 + :param message: 消息体 + :return: 消息响应(包含message_id, chat_id等) + """ + for conf in self.get_configs().values(): + if not self.check_message(message, conf.name): + continue + targets = message.targets + userid = message.userid + if not userid and targets is not None: + userid = targets.get("telegram_userid") + if not userid: + logger.warn("用户没有指定 Telegram用户ID,消息无法发送") + return None + client: Telegram = self.get_instance(conf.name) + if client: + result = client.send_msg( + title=message.title, + text=message.text, + image=message.image, + userid=userid, + link=message.link, + ) + if result and result.get("success"): + return MessageResponse( + message_id=result.get("message_id"), + chat_id=result.get("chat_id"), + channel=MessageChannel.Telegram, + source=conf.name, + success=True, + ) + return None + def register_commands(self, commands: Dict[str, dict]): """ 注册命令,实现这个函数接收系统可用的命令菜单 @@ -342,7 +458,11 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]): scoped_commands = copy.deepcopy(commands) event = eventmanager.send_event( ChainEventType.CommandRegister, - CommandRegisterEventData(commands=scoped_commands, origin="Telegram", service=client_config.name) + CommandRegisterEventData( + commands=scoped_commands, + origin="Telegram", + service=client_config.name, + ), ) # 如果事件返回有效的 event_data,使用事件中调整后的命令 @@ -361,7 +481,9 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]): client.delete_commands() # scoped_commands 必须是 commands 的子集 - filtered_scoped_commands = DictUtils.filter_keys_to_subset(scoped_commands, commands) + filtered_scoped_commands = DictUtils.filter_keys_to_subset( + scoped_commands, commands + ) # 如果 filtered_scoped_commands 为空,则跳过注册 if not filtered_scoped_commands: logger.debug("Filtered commands are empty, skipping registration.") @@ -369,5 +491,7 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]): continue # 对比调整后的命令与当前命令 if filtered_scoped_commands != commands: - logger.debug(f"Command set has changed, Updating new commands: {filtered_scoped_commands}") + logger.debug( + f"Command set has changed, Updating new commands: {filtered_scoped_commands}" + ) client.register_commands(filtered_scoped_commands) diff --git a/app/modules/telegram/telegram.py b/app/modules/telegram/telegram.py index 45c1b2b2..c5036918 100644 --- a/app/modules/telegram/telegram.py +++ b/app/modules/telegram/telegram.py @@ -1,11 +1,16 @@ import asyncio import re import threading -from typing import Optional, List, Dict, Callable +from typing import Optional, List, Dict, Callable, Union from urllib.parse import urljoin, quote from telebot import TeleBot, apihelper -from telebot.types import BotCommand, InlineKeyboardMarkup, InlineKeyboardButton, InputMediaPhoto +from telebot.types import ( + BotCommand, + InlineKeyboardMarkup, + InlineKeyboardButton, + InputMediaPhoto, +) from telegramify_markdown import standardize, telegramify from telegramify_markdown.type import ContentTypes, SentType @@ -25,13 +30,22 @@ class RetryException(Exception): class Telegram: - _ds_url = f"http://127.0.0.1:{settings.PORT}/api/v1/message?token={settings.API_TOKEN}" + _ds_url = ( + f"http://127.0.0.1:{settings.PORT}/api/v1/message?token={settings.API_TOKEN}" + ) _bot: TeleBot = None _callback_handlers: Dict[str, Callable] = {} # 存储回调处理器 - _user_chat_mapping: Dict[str, str] = {} # userid -> chat_id mapping for reply targeting + _user_chat_mapping: Dict[ + str, str + ] = {} # userid -> chat_id mapping for reply targeting _bot_username: Optional[str] = None # Bot username for mention detection - def __init__(self, TELEGRAM_TOKEN: Optional[str] = None, TELEGRAM_CHAT_ID: Optional[str] = None, **kwargs): + def __init__( + self, + TELEGRAM_TOKEN: Optional[str] = None, + TELEGRAM_CHAT_ID: Optional[str] = None, + **kwargs, + ): """ 初始化参数 """ @@ -46,8 +60,8 @@ class Telegram: if self._telegram_token and self._telegram_chat_id: # telegram bot api 地址,格式:https://api.telegram.org if kwargs.get("API_URL"): - apihelper.API_URL = urljoin(kwargs["API_URL"], '/bot{0}/{1}') - apihelper.FILE_URL = urljoin(kwargs["API_URL"], '/file/bot{0}/{1}') + apihelper.API_URL = urljoin(kwargs["API_URL"], "/bot{0}/{1}") + apihelper.FILE_URL = urljoin(kwargs["API_URL"], "/file/bot{0}/{1}") else: apihelper.proxy = settings.PROXY # bot @@ -66,12 +80,15 @@ class Telegram: # 标记渠道来源 if kwargs.get("name"): # URL encode the source name to handle special characters - encoded_name = quote(kwargs.get('name'), safe='') + encoded_name = quote(kwargs.get("name"), safe="") self._ds_url = f"{self._ds_url}&source={encoded_name}" - @_bot.message_handler(commands=['start', 'help']) + @_bot.message_handler(commands=["start", "help"]) def send_welcome(message): - _bot.reply_to(message, "温馨提示:直接发送名称或`订阅`+名称,搜索或订阅电影、电视剧") + _bot.reply_to( + message, + "温馨提示:直接发送名称或`订阅`+名称,搜索或订阅电影、电视剧", + ) @_bot.message_handler(func=lambda message: True) def echo_all(message): @@ -82,7 +99,7 @@ class Telegram: if self._should_process_message(message): # 发送正在输入状态 try: - _bot.send_chat_action(message.chat.id, 'typing') + _bot.send_chat_action(message.chat.id, "typing") except Exception as err: logger.error(f"发送Telegram正在输入状态失败:{err}") RequestUtils(timeout=15).post_res(self._ds_url, json=message.json) @@ -94,7 +111,9 @@ class Telegram: """ try: # Update user-chat mapping for callbacks too - self._update_user_chat_mapping(call.from_user.id, call.message.chat.id) + self._update_user_chat_mapping( + call.from_user.id, call.message.chat.id + ) # 解析回调数据 callback_data = call.data @@ -111,9 +130,9 @@ class Telegram: "message_id": call.message.message_id, "chat": { "id": call.message.chat.id, - } + }, }, - "data": callback_data + "data": callback_data, } } @@ -122,7 +141,7 @@ class Telegram: # 发送正在输入状态 try: - _bot.send_chat_action(call.message.chat.id, 'typing') + _bot.send_chat_action(call.message.chat.id, "typing") except Exception as e: logger.error(f"发送Telegram正在输入状态失败:{e}") @@ -179,17 +198,17 @@ class Telegram: :return: 是否处理 """ # 私聊消息总是处理 - if message.chat.type == 'private': + if message.chat.type == "private": logger.debug(f"处理私聊消息:用户 {message.from_user.id}") return True # 群聊中的命令消息总是处理(以/开头) - if message.text and message.text.startswith('/'): + if message.text and message.text.startswith("/"): logger.debug(f"处理群聊命令消息:{message.text[:20]}...") return True # 群聊中检查是否@了机器人 - if message.chat.type in ['group', 'supergroup']: + if message.chat.type in ["group", "supergroup"]: if not self._bot_username: # 如果没有获取到bot用户名,为了安全起见处理所有消息 logger.debug("未获取到bot用户名,处理所有群聊消息") @@ -203,14 +222,20 @@ class Telegram: # 检查消息实体中是否有提及bot if message.entities: for entity in message.entities: - if entity.type == 'mention': - mention_text = message.text[entity.offset:entity.offset + entity.length] + if entity.type == "mention": + mention_text = message.text[ + entity.offset : entity.offset + entity.length + ] if mention_text == f"@{self._bot_username}": - logger.debug(f"通过实体检测到@{self._bot_username},处理群聊消息") + logger.debug( + f"通过实体检测到@{self._bot_username},处理群聊消息" + ) return True # 群聊中没有@机器人,不处理 - logger.debug(f"群聊消息未@机器人,跳过处理:{message.text[:30] if message.text else 'No text'}...") + logger.debug( + f"群聊消息未@机器人,跳过处理:{message.text[:30] if message.text else 'No text'}..." + ) return False # 其他类型的聊天默认处理 @@ -223,11 +248,17 @@ class Telegram: """ return self._bot is not None - def send_msg(self, title: str, text: Optional[str] = None, image: Optional[str] = None, - userid: Optional[str] = None, link: Optional[str] = None, - buttons: Optional[List[List[dict]]] = None, - original_message_id: Optional[int] = None, - original_chat_id: Optional[str] = None) -> Optional[bool]: + def send_msg( + self, + title: str, + text: Optional[str] = None, + image: Optional[str] = None, + userid: Optional[str] = None, + link: Optional[str] = None, + buttons: Optional[List[List[dict]]] = None, + original_message_id: Optional[int] = None, + original_chat_id: Optional[str] = None, + ) -> Optional[dict]: """ 发送Telegram消息 :param title: 消息标题 @@ -238,14 +269,14 @@ class Telegram: :param buttons: 按钮列表,格式:[[{"text": "按钮文本", "callback_data": "回调数据"}]] :param original_message_id: 原消息ID,如果提供则编辑原消息 :param original_chat_id: 原消息的聊天ID,编辑消息时需要 - + :return: 包含 message_id, chat_id, success 的字典 """ if not self._telegram_token or not self._telegram_chat_id: return None if not title and not text: logger.warn("标题和内容不能同时为空") - return False + return {"success": False} try: # 标准化标题后再加粗,避免**符号被显示为文本 @@ -275,17 +306,39 @@ class Telegram: # 判断是编辑消息还是发送新消息 if original_message_id and original_chat_id: # 编辑消息 - return self.__edit_message(original_chat_id, original_message_id, caption, buttons, image) + result = self.__edit_message( + original_chat_id, original_message_id, caption, buttons, image + ) + return { + "success": bool(result), + "message_id": original_message_id, + "chat_id": original_chat_id, + } else: # 发送新消息 - return self.__send_request(userid=chat_id, image=image, caption=caption, reply_markup=reply_markup) + sent = self.__send_request( + userid=chat_id, + image=image, + caption=caption, + reply_markup=reply_markup, + ) + if sent and hasattr(sent, "message_id"): + return { + "success": True, + "message_id": sent.message_id, + "chat_id": sent.chat.id if hasattr(sent, "chat") else chat_id, + } + elif sent: + return {"success": True} + return {"success": False} except Exception as msg_e: logger.error(f"发送消息失败:{msg_e}") - return False + return {"success": False} - def _determine_target_chat_id(self, userid: Optional[str] = None, - original_chat_id: Optional[str] = None) -> str: + def _determine_target_chat_id( + self, userid: Optional[str] = None, original_chat_id: Optional[str] = None + ) -> str: """ 确定目标聊天ID,使用用户映射确保回复到正确的聊天 :param userid: 用户ID @@ -307,11 +360,16 @@ class Telegram: # 3. 最后使用默认聊天ID return self._telegram_chat_id - def send_medias_msg(self, medias: List[MediaInfo], userid: Optional[str] = None, - title: Optional[str] = None, link: Optional[str] = None, - buttons: Optional[List[List[Dict]]] = None, - original_message_id: Optional[int] = None, - original_chat_id: Optional[str] = None) -> Optional[bool]: + def send_medias_msg( + self, + medias: List[MediaInfo], + userid: Optional[str] = None, + title: Optional[str] = None, + link: Optional[str] = None, + buttons: Optional[List[List[Dict]]] = None, + original_message_id: Optional[int] = None, + original_chat_id: Optional[str] = None, + ) -> Optional[bool]: """ 发送媒体列表消息 :param medias: 媒体信息列表 @@ -331,18 +389,22 @@ class Telegram: if not image: image = media.get_message_image() if media.vote_average: - caption = "%s\n%s. [%s](%s)\n_%s,%s_" % (caption, - index, - media.title_year, - media.detail_link, - f"类型:{media.type.value}", - f"评分:{media.vote_average}") + caption = "%s\n%s. [%s](%s)\n_%s,%s_" % ( + caption, + index, + media.title_year, + media.detail_link, + f"类型:{media.type.value}", + f"评分:{media.vote_average}", + ) else: - caption = "%s\n%s. [%s](%s)\n_%s_" % (caption, - index, - media.title_year, - media.detail_link, - f"类型:{media.type.value}") + caption = "%s\n%s. [%s](%s)\n_%s_" % ( + caption, + index, + media.title_year, + media.detail_link, + f"类型:{media.type.value}", + ) index += 1 if link: @@ -359,20 +421,32 @@ class Telegram: # 判断是编辑消息还是发送新消息 if original_message_id and original_chat_id: # 编辑消息 - return self.__edit_message(original_chat_id, original_message_id, caption, buttons, image) + return self.__edit_message( + original_chat_id, original_message_id, caption, buttons, image + ) else: # 发送新消息 - return self.__send_request(userid=chat_id, image=image, caption=caption, reply_markup=reply_markup) + return self.__send_request( + userid=chat_id, + image=image, + caption=caption, + reply_markup=reply_markup, + ) except Exception as msg_e: logger.error(f"发送消息失败:{msg_e}") return False - def send_torrents_msg(self, torrents: List[Context], - userid: Optional[str] = None, title: Optional[str] = None, - link: Optional[str] = None, buttons: Optional[List[List[Dict]]] = None, - original_message_id: Optional[int] = None, - original_chat_id: Optional[str] = None) -> Optional[bool]: + def send_torrents_msg( + self, + torrents: List[Context], + userid: Optional[str] = None, + title: Optional[str] = None, + link: Optional[str] = None, + buttons: Optional[List[List[Dict]]] = None, + original_message_id: Optional[int] = None, + original_chat_id: Optional[str] = None, + ) -> Optional[bool]: """ 发送种子列表消息 :param torrents: 种子信息列表 @@ -394,15 +468,19 @@ class Telegram: site_name = torrent.site_name meta = MetaInfo(torrent.title, torrent.description) link = torrent.page_url - title = f"{meta.season_episode} " \ - f"{meta.resource_term} " \ - f"{meta.video_term} " \ - f"{meta.release_group}" + title = ( + f"{meta.season_episode} " + f"{meta.resource_term} " + f"{meta.video_term} " + f"{meta.release_group}" + ) title = re.sub(r"\s+", " ", title).strip() free = torrent.volume_factor seeder = f"{torrent.seeders}↑" - caption = f"{caption}\n{index}.【{site_name}】[{title}]({link}) " \ - f"{StringUtils.str_filesize(torrent.size)} {free} {seeder}" + caption = ( + f"{caption}\n{index}.【{site_name}】[{title}]({link}) " + f"{StringUtils.str_filesize(torrent.size)} {free} {seeder}" + ) index += 1 if link: @@ -419,10 +497,17 @@ class Telegram: # 判断是编辑消息还是发送新消息 if original_message_id and original_chat_id: # 编辑消息(种子消息通常没有图片) - return self.__edit_message(original_chat_id, original_message_id, caption, buttons, image) + return self.__edit_message( + original_chat_id, original_message_id, caption, buttons, image + ) else: # 发送新消息 - return self.__send_request(userid=chat_id, image=image, caption=caption, reply_markup=reply_markup) + return self.__send_request( + userid=chat_id, + image=image, + caption=caption, + reply_markup=reply_markup, + ) except Exception as msg_e: logger.error(f"发送消息失败:{msg_e}") @@ -444,13 +529,19 @@ class Telegram: btn = InlineKeyboardButton(text=button["text"], url=button["url"]) else: # 回调按钮 - btn = InlineKeyboardButton(text=button["text"], callback_data=button["callback_data"]) + btn = InlineKeyboardButton( + text=button["text"], callback_data=button["callback_data"] + ) button_row.append(btn) keyboard.append(button_row) return InlineKeyboardMarkup(keyboard) - def answer_callback_query(self, callback_query_id: int, text: Optional[str] = None, - show_alert: bool = False) -> Optional[bool]: + def answer_callback_query( + self, + callback_query_id: int, + text: Optional[str] = None, + show_alert: bool = False, + ) -> Optional[bool]: """ 回应回调查询 """ @@ -458,13 +549,17 @@ class Telegram: return None try: - self._bot.answer_callback_query(callback_query_id, text=text, show_alert=show_alert) + self._bot.answer_callback_query( + callback_query_id, text=text, show_alert=show_alert + ) return True except Exception as e: logger.error(f"回应回调查询失败:{str(e)}") return False - def delete_msg(self, message_id: int, chat_id: Optional[int] = None) -> Optional[bool]: + def delete_msg( + self, message_id: int, chat_id: Optional[int] = None + ) -> Optional[bool]: """ 删除Telegram消息 :param message_id: 消息ID @@ -482,20 +577,68 @@ class Telegram: target_chat_id = self._telegram_chat_id # 删除消息 - result = self._bot.delete_message(chat_id=target_chat_id, message_id=int(message_id)) + result = self._bot.delete_message( + chat_id=target_chat_id, message_id=int(message_id) + ) if result: - logger.info(f"成功删除Telegram消息: chat_id={target_chat_id}, message_id={message_id}") + logger.info( + f"成功删除Telegram消息: chat_id={target_chat_id}, message_id={message_id}" + ) return True else: - logger.error(f"删除Telegram消息失败: chat_id={target_chat_id}, message_id={message_id}") + logger.error( + f"删除Telegram消息失败: chat_id={target_chat_id}, message_id={message_id}" + ) return False except Exception as e: logger.error(f"删除Telegram消息异常: {str(e)}") return False - def __edit_message(self, chat_id: str, message_id: int, text: str, - buttons: Optional[List[List[dict]]] = None, - image: Optional[str] = None) -> Optional[bool]: + def edit_msg( + self, + chat_id: Union[str, int], + message_id: Union[str, int], + text: str, + title: Optional[str] = None, + ) -> Optional[bool]: + """ + 编辑Telegram消息(公开方法) + :param chat_id: 聊天ID + :param message_id: 消息ID + :param text: 新的消息内容 + :param title: 消息标题 + :return: 编辑是否成功 + """ + if not self._bot: + return None + + try: + # 组合标题和文本 + if title: + bold_title = f"**{standardize(title).removesuffix(chr(10))}**" + caption = f"{bold_title}\n{text}" if text else bold_title + elif text: + caption = text + else: + return False + + return self.__edit_message( + chat_id=str(chat_id), + message_id=int(message_id), + text=caption, + ) + except Exception as e: + logger.error(f"编辑Telegram消息异常: {str(e)}") + return False + + def __edit_message( + self, + chat_id: str, + message_id: int, + text: str, + buttons: Optional[List[List[dict]]] = None, + image: Optional[str] = None, + ) -> Optional[bool]: """ 编辑已发送的消息 :param chat_id: 聊天ID @@ -509,7 +652,6 @@ class Telegram: return None try: - # 创建按钮键盘 reply_markup = None if buttons: @@ -517,12 +659,14 @@ class Telegram: if image: # 如果有图片,使用edit_message_media - media = InputMediaPhoto(media=image, caption=standardize(text), parse_mode="MarkdownV2") + media = InputMediaPhoto( + media=image, caption=standardize(text), parse_mode="MarkdownV2" + ) self._bot.edit_message_media( chat_id=chat_id, message_id=message_id, media=media, - reply_markup=reply_markup + reply_markup=reply_markup, ) else: # 如果没有图片,使用edit_message_text @@ -531,23 +675,29 @@ class Telegram: message_id=message_id, text=standardize(text), parse_mode="MarkdownV2", - reply_markup=reply_markup + reply_markup=reply_markup, ) return True except Exception as e: logger.error(f"编辑消息失败:{str(e)}") return False - def __send_request(self, userid: Optional[str] = None, image="", caption="", - reply_markup: Optional[InlineKeyboardMarkup] = None) -> bool: + def __send_request( + self, + userid: Optional[str] = None, + image="", + caption="", + reply_markup: Optional[InlineKeyboardMarkup] = None, + ): """ - 向Telegram发送报文 + 向Telegram发送报文,返回发送的消息对象 :param reply_markup: 内联键盘 + :return: 发送成功返回消息对象,失败返回None """ kwargs = { - 'chat_id': userid or self._telegram_chat_id, - 'parse_mode': "MarkdownV2", - 'reply_markup': reply_markup + "chat_id": userid or self._telegram_chat_id, + "parse_mode": "MarkdownV2", + "reply_markup": reply_markup, } # 处理图片 @@ -562,10 +712,10 @@ class Telegram: sent_idx = set() ret = self.__send_long_message(image, caption, sent_idx, **kwargs) - return ret is not None + return ret except Exception as e: logger.error(f"发送Telegram消息失败: {e}") - return False + return None @staticmethod def __process_image(image_url: Optional[str]) -> Optional[bytes]: @@ -587,27 +737,28 @@ class Telegram: try: if image: return self._bot.send_photo( - photo=image, - caption=standardize(caption), - **kwargs + photo=image, caption=standardize(caption), **kwargs ) else: - return self._bot.send_message( - text=standardize(caption), - **kwargs - ) + return self._bot.send_message(text=standardize(caption), **kwargs) except Exception: raise RetryException(f"发送{'图片' if image else '文本'}消息失败") @retry(RetryException, logger=logger) - def __send_long_message(self, image: Optional[bytes], caption: str, sent_idx: set, **kwargs): + def __send_long_message( + self, image: Optional[bytes], caption: str, sent_idx: set, **kwargs + ): """ 发送长消息 """ try: reply_markup = kwargs.pop("reply_markup", None) - boxs: SentType = ThreadHelper().submit(lambda x: asyncio.run(telegramify(x)), caption).result() + boxs: SentType = ( + ThreadHelper() + .submit(lambda x: asyncio.run(telegramify(x)), caption) + .result() + ) ret = None for i, item in enumerate(boxs): @@ -618,24 +769,27 @@ class Telegram: current_reply_markup = reply_markup if i == 0 else None if item.content_type == ContentTypes.TEXT and (i != 0 or not image): - ret = self._bot.send_message(**kwargs, - text=item.content, - reply_markup=current_reply_markup + ret = self._bot.send_message( + **kwargs, text=item.content, reply_markup=current_reply_markup ) elif item.content_type == ContentTypes.PHOTO or (image and i == 0): - ret = self._bot.send_photo(**kwargs, - photo=(getattr(item, "file_name", ""), - getattr(item, "file_data", image)), + ret = self._bot.send_photo( + **kwargs, + photo=( + getattr(item, "file_name", ""), + getattr(item, "file_data", image), + ), caption=getattr(item, "caption", item.content), - reply_markup=current_reply_markup + reply_markup=current_reply_markup, ) elif item.content_type == ContentTypes.FILE: - ret = self._bot.send_document(**kwargs, + ret = self._bot.send_document( + **kwargs, document=(item.file_name, item.file_data), caption=item.caption, - reply_markup=current_reply_markup + reply_markup=current_reply_markup, ) sent_idx.add(i) @@ -658,8 +812,8 @@ class Telegram: self._bot.delete_my_commands() self._bot.set_my_commands( commands=[ - BotCommand(cmd[1:], str(desc.get("description"))) for cmd, desc in - commands.items() + BotCommand(cmd[1:], str(desc.get("description"))) + for cmd, desc in commands.items() ] ) diff --git a/app/schemas/message.py b/app/schemas/message.py index 86c265b1..18c3557e 100644 --- a/app/schemas/message.py +++ b/app/schemas/message.py @@ -7,10 +7,28 @@ from pydantic import BaseModel, Field from app.schemas.types import ContentType, NotificationType, MessageChannel +class MessageResponse(BaseModel): + """ + 消息发送响应,包含消息ID等信息用于后续编辑 + """ + + # 消息ID + message_id: Optional[Union[str, int]] = None + # 聊天ID + chat_id: Optional[Union[str, int]] = None + # 消息渠道 + channel: Optional[MessageChannel] = None + # 消息来源 + source: Optional[str] = None + # 是否发送成功 + success: bool = False + + class CommingMessage(BaseModel): """ 外来消息 """ + # 用户ID userid: Optional[Union[str, int]] = None # 用户名称 @@ -51,6 +69,7 @@ class Notification(BaseModel): """ 消息 """ + # 消息渠道 channel: Optional[MessageChannel] = None # 消息来源 @@ -90,8 +109,7 @@ class Notification(BaseModel): """ items = self.model_dump() for k, v in items.items(): - if isinstance(v, MessageChannel) \ - or isinstance(v, NotificationType): + if isinstance(v, MessageChannel) or isinstance(v, NotificationType): items[k] = v.value return items @@ -100,6 +118,7 @@ class NotificationSwitch(BaseModel): """ 消息开关 """ + # 消息类型 mtype: Optional[str] = None # 微信开关 @@ -122,6 +141,7 @@ class Subscription(BaseModel): """ 客户端消息订阅 """ + endpoint: Optional[str] = None keys: Optional[dict] = Field(default_factory=dict) @@ -130,6 +150,7 @@ class SubscriptionMessage(BaseModel): """ 客户端订阅消息体 """ + title: Optional[str] = None body: Optional[str] = None icon: Optional[str] = None @@ -141,6 +162,7 @@ class ChannelCapability(Enum): """ 渠道能力枚举 """ + # 支持内联按钮 INLINE_BUTTONS = "inline_buttons" # 支持菜单命令 @@ -166,6 +188,7 @@ class ChannelCapabilities: """ 渠道能力配置 """ + channel: MessageChannel capabilities: Set[ChannelCapability] max_buttons_per_row: int = 5 @@ -191,20 +214,20 @@ class ChannelCapabilityManager: ChannelCapability.RICH_TEXT, ChannelCapability.IMAGES, ChannelCapability.LINKS, - ChannelCapability.FILE_SENDING + ChannelCapability.FILE_SENDING, }, max_buttons_per_row=4, max_button_rows=10, - max_button_text_length=30 + max_button_text_length=30, ), MessageChannel.Wechat: ChannelCapabilities( channel=MessageChannel.Wechat, capabilities={ ChannelCapability.IMAGES, ChannelCapability.LINKS, - ChannelCapability.MENU_COMMANDS + ChannelCapability.MENU_COMMANDS, }, - fallback_enabled=True + fallback_enabled=True, ), MessageChannel.Slack: ChannelCapabilities( channel=MessageChannel.Slack, @@ -216,12 +239,12 @@ class ChannelCapabilityManager: ChannelCapability.RICH_TEXT, ChannelCapability.IMAGES, ChannelCapability.LINKS, - ChannelCapability.MENU_COMMANDS + ChannelCapability.MENU_COMMANDS, }, max_buttons_per_row=3, max_button_rows=8, max_button_text_length=25, - fallback_enabled=True + fallback_enabled=True, ), MessageChannel.Discord: ChannelCapabilities( channel=MessageChannel.Discord, @@ -232,56 +255,54 @@ class ChannelCapabilityManager: ChannelCapability.CALLBACK_QUERIES, ChannelCapability.RICH_TEXT, ChannelCapability.IMAGES, - ChannelCapability.LINKS + ChannelCapability.LINKS, }, max_buttons_per_row=5, max_button_rows=5, max_button_text_length=80, - fallback_enabled=True + fallback_enabled=True, ), MessageChannel.SynologyChat: ChannelCapabilities( channel=MessageChannel.SynologyChat, capabilities={ ChannelCapability.RICH_TEXT, ChannelCapability.IMAGES, - ChannelCapability.LINKS + ChannelCapability.LINKS, }, - fallback_enabled=True + fallback_enabled=True, ), MessageChannel.VoceChat: ChannelCapabilities( channel=MessageChannel.VoceChat, capabilities={ ChannelCapability.RICH_TEXT, ChannelCapability.IMAGES, - ChannelCapability.LINKS + ChannelCapability.LINKS, }, - fallback_enabled=True + fallback_enabled=True, ), MessageChannel.WebPush: ChannelCapabilities( channel=MessageChannel.WebPush, - capabilities={ - ChannelCapability.LINKS - }, - fallback_enabled=True + capabilities={ChannelCapability.LINKS}, + fallback_enabled=True, ), MessageChannel.Web: ChannelCapabilities( channel=MessageChannel.Web, capabilities={ ChannelCapability.RICH_TEXT, ChannelCapability.IMAGES, - ChannelCapability.LINKS + ChannelCapability.LINKS, }, - fallback_enabled=True + fallback_enabled=True, ), MessageChannel.QQ: ChannelCapabilities( channel=MessageChannel.QQ, capabilities={ ChannelCapability.RICH_TEXT, ChannelCapability.IMAGES, - ChannelCapability.LINKS + ChannelCapability.LINKS, }, - fallback_enabled=True - ) + fallback_enabled=True, + ), } @classmethod @@ -292,7 +313,9 @@ class ChannelCapabilityManager: return cls._capabilities.get(channel) @classmethod - def supports_capability(cls, channel: MessageChannel, capability: ChannelCapability) -> bool: + def supports_capability( + cls, channel: MessageChannel, capability: ChannelCapability + ) -> bool: """ 检查渠道是否支持某项能力 """