diff --git a/app/chain/message.py b/app/chain/message.py index 0118d03f..2f9c25cf 100644 --- a/app/chain/message.py +++ b/app/chain/message.py @@ -16,7 +16,9 @@ from app.chain.interaction import ( agent_interaction_manager, media_interaction_manager, ) +from app.chain.site import SiteChain, site_interaction_manager from app.chain.skills import SkillsChain, skills_interaction_manager +from app.chain.subscribe import SubscribeChain, subscribe_interaction_manager from app.chain.transfer import TransferChain from app.core.config import settings, global_vars from app.db.models import TransferHistory @@ -170,13 +172,34 @@ class MessageChain(ChainBase): ) return - if skills_interaction_manager.get_by_user(userid): + latest_slash_interaction = self._get_latest_slash_interaction(userid) + if latest_slash_interaction == "sites": + if SiteChain().handle_text_interaction( + channel=channel, + source=source, + userid=userid, + username=username, + text=text, + ): + return + + if latest_slash_interaction == "subscribes": + if SubscribeChain().handle_text_interaction( + channel=channel, + source=source, + userid=userid, + username=username, + text=text, + ): + return + + if latest_slash_interaction == "skills": if SkillsChain().handle_text_interaction( - channel=channel, - source=source, - userid=userid, - username=username, - text=text, + channel=channel, + source=source, + userid=userid, + username=username, + text=text, ): return @@ -274,6 +297,28 @@ class MessageChain(ChainBase): ): return + if SiteChain().handle_callback_interaction( + callback_data=callback_data, + channel=channel, + source=source, + userid=userid, + username=username, + original_message_id=original_message_id, + original_chat_id=original_chat_id, + ): + return + + if SubscribeChain().handle_callback_interaction( + callback_data=callback_data, + channel=channel, + source=source, + userid=userid, + username=username, + original_message_id=original_message_id, + original_chat_id=original_chat_id, + ): + return + if MediaInteractionChain().handle_callback_interaction( callback_data=callback_data, channel=channel, @@ -326,6 +371,24 @@ class MessageChain(ChainBase): ) ) + @staticmethod + def _get_latest_slash_interaction(userid: Union[str, int]) -> Optional[str]: + """ + 返回当前用户最近一次激活的 slash 交互类型。 + """ + candidates = [] + for name, manager in ( + ("sites", site_interaction_manager), + ("subscribes", subscribe_interaction_manager), + ("skills", skills_interaction_manager), + ): + request = manager.get_by_user(userid) + if request: + candidates.append((request.created_at, name)) + if not candidates: + return None + return max(candidates, key=lambda item: item[0])[1] + @staticmethod def _parse_transfer_callback( callback_data: str, diff --git a/app/chain/site.py b/app/chain/site.py index bca0a306..54c6961f 100644 --- a/app/chain/site.py +++ b/app/chain/site.py @@ -1,12 +1,21 @@ import base64 import re from datetime import datetime -from typing import Optional, Tuple, Union, Dict +from typing import List, Optional, Tuple, Union, Dict from urllib.parse import urljoin from lxml import etree from app.chain import ChainBase +from app.helper.slash import ( + SlashInteractionManager, + build_navigation_buttons, + format_markdown_table, + page_items, + supports_interaction_buttons, + supports_markdown, + update_or_post_message, +) from app.core.config import global_vars, settings from app.core.event import Event, eventmanager from app.db.models.site import Site @@ -26,11 +35,17 @@ from app.utils.site import SiteUtils from app.utils.string import StringUtils +site_interaction_manager = SlashInteractionManager() + + class SiteChain(ChainBase): """ 站点管理处理链 """ + _button_page_size = 6 + _text_page_size = 10 + def __init__(self): super().__init__() @@ -626,39 +641,548 @@ class SiteChain(ChainBase): return False, f"无法打开网站!" return True, "连接成功" - def remote_list(self, channel: MessageChannel, - userid: Union[str, int] = None, source: Optional[str] = None): + def remote_list( + self, + arg_str: str = "", + channel: MessageChannel = None, + userid: Union[str, int] = None, + source: Optional[str] = None, + ): """ - 查询所有站点,发送消息 + /sites 统一入口。 """ - site_list = SiteOper().list() - if not site_list: - self.post_message(Notification( - channel=channel, - title="没有维护任何站点信息!", - userid=userid, - link=settings.MP_DOMAIN('#/site'))) - title = f"共有 {len(site_list)} 个站点,回复对应指令操作:" \ - f"\n- 禁用站点:/site_disable [id]" \ - f"\n- 启用站点:/site_enable [id]" \ - f"\n- 更新站点Cookie:/site_cookie [id] [username] [password] [2fa_code/secret]" - messages = [] - for site in site_list: - if site.render: - render_str = "🧭" - else: - render_str = "" - if site.is_active: - messages.append(f"{site.id}. {site.name} {render_str}") - else: - messages.append(f"{site.id}. {site.name} ⚠️") - # 发送列表 - self.post_message(Notification( + request = site_interaction_manager.create_or_replace( + user_id=userid, + command="/sites", channel=channel, source=source, - title=title, text="\n".join(messages), userid=userid, - link=settings.MP_DOMAIN('#/site')) + username=None, ) + normalized_arg = (arg_str or "").strip() + if normalized_arg and self.handle_text_interaction( + channel=channel, + source=source, + userid=userid, + username="", + text=normalized_arg, + ): + return + self._render_site_interaction( + request=request, + channel=channel, + source=source, + userid=userid, + username="", + ) + + @staticmethod + def parse_callback(callback_data: str) -> Optional[Tuple[str, str]]: + """ + 解析 /sites 按钮回调。 + """ + if not callback_data.startswith("sites:"): + return None + parts = callback_data.split(":") + if len(parts) < 3: + return None + return parts[1], parts[2] + + def handle_callback_interaction( + self, + callback_data: str, + channel: MessageChannel, + source: str, + userid: Union[str, int], + username: str, + original_message_id: Optional[Union[str, int]] = None, + original_chat_id: Optional[str] = None, + ) -> bool: + """ + 处理 /sites 按钮交互。 + """ + parsed = self.parse_callback(callback_data) + if not parsed: + return False + + request_id, action = parsed + request = site_interaction_manager.get_by_id(request_id, userid) + if not request: + self.post_message( + Notification( + channel=channel, + source=source, + userid=userid, + username=username, + title="站点交互已失效,请重新发送 /sites", + ) + ) + return True + + request.channel = channel + request.source = source + request.username = username + + if action == "close": + site_interaction_manager.remove(request.request_id) + update_or_post_message( + chain=self, + channel=channel, + source=source, + userid=userid, + username=username, + title="站点管理", + text="站点交互已结束", + original_message_id=original_message_id, + original_chat_id=original_chat_id, + ) + return True + + if action == "page-prev": + request.page = max(0, request.page - 1) + request.awaiting_input = None + elif action == "page-next": + request.page += 1 + request.awaiting_input = None + elif action in {"cookie", "enable", "disable"}: + request.awaiting_input = action + elif action == "refresh": + request.awaiting_input = None + + self._render_site_interaction( + request=request, + channel=channel, + source=source, + userid=userid, + username=username, + original_message_id=original_message_id, + original_chat_id=original_chat_id, + ) + return True + + def handle_text_interaction( + self, + channel: MessageChannel, + source: str, + userid: Union[str, int], + username: str, + text: str, + ) -> bool: + """ + 处理 /sites 文本补充输入。 + """ + request = site_interaction_manager.get_by_user(userid) + if not request: + return False + + request.channel = channel + request.source = source + request.username = username + + normalized = (text or "").strip() + lowered = normalized.lower() + + if lowered in {"退出", "关闭", "q", "quit", "exit"}: + site_interaction_manager.remove(request.request_id) + self.post_message( + Notification( + channel=channel, + source=source, + userid=userid, + username=username, + title="站点交互已结束", + ) + ) + return True + + if lowered in {"取消", "cancel", "返回", "back"}: + request.awaiting_input = None + self._render_site_interaction( + request=request, + channel=channel, + source=source, + userid=userid, + username=username, + ) + return True + + if lowered in {"刷新", "refresh", "列表", "list"}: + request.awaiting_input = None + self._render_site_interaction( + request=request, + channel=channel, + source=source, + userid=userid, + username=username, + ) + return True + + if lowered in {"p", "prev", "上一页"}: + request.awaiting_input = None + request.page = max(0, request.page - 1) + self._render_site_interaction( + request=request, + channel=channel, + source=source, + userid=userid, + username=username, + ) + return True + + if lowered in {"n", "next", "下一页"}: + request.awaiting_input = None + request.page += 1 + self._render_site_interaction( + request=request, + channel=channel, + source=source, + userid=userid, + username=username, + ) + return True + + cookie_match = re.match( + r"^(?:cookie|更新cookie|更新\s*cookie)\s+(.+)$", + normalized, + re.IGNORECASE, + ) + enable_match = re.match(r"^(?:启用|enable)\s+(.+)$", normalized, re.IGNORECASE) + disable_match = re.match( + r"^(?:禁用|disable)\s+(.+)$", normalized, re.IGNORECASE + ) + + if request.awaiting_input == "cookie": + success, message = self._update_site_cookie_from_input(normalized) + request.awaiting_input = None + self.post_message( + Notification( + channel=channel, + source=source, + userid=userid, + username=username, + title=message, + ) + ) + self._render_site_interaction( + request=request, + channel=channel, + source=source, + userid=userid, + username=username, + ) + return True + + if request.awaiting_input == "enable": + success, message = self._set_sites_enabled(normalized, enabled=True) + request.awaiting_input = None + self.post_message( + Notification( + channel=channel, + source=source, + userid=userid, + username=username, + title=message, + ) + ) + self._render_site_interaction( + request=request, + channel=channel, + source=source, + userid=userid, + username=username, + ) + return True + + if request.awaiting_input == "disable": + success, message = self._set_sites_enabled(normalized, enabled=False) + request.awaiting_input = None + self.post_message( + Notification( + channel=channel, + source=source, + userid=userid, + username=username, + title=message, + ) + ) + self._render_site_interaction( + request=request, + channel=channel, + source=source, + userid=userid, + username=username, + ) + return True + + if cookie_match: + success, message = self._update_site_cookie_from_input(cookie_match.group(1)) + self.post_message( + Notification( + channel=channel, + source=source, + userid=userid, + username=username, + title=message, + ) + ) + self._render_site_interaction( + request=request, + channel=channel, + source=source, + userid=userid, + username=username, + ) + return True + + if enable_match: + success, message = self._set_sites_enabled(enable_match.group(1), enabled=True) + self.post_message( + Notification( + channel=channel, + source=source, + userid=userid, + username=username, + title=message, + ) + ) + self._render_site_interaction( + request=request, + channel=channel, + source=source, + userid=userid, + username=username, + ) + return True + + if disable_match: + success, message = self._set_sites_enabled( + disable_match.group(1), enabled=False + ) + self.post_message( + Notification( + channel=channel, + source=source, + userid=userid, + username=username, + title=message, + ) + ) + self._render_site_interaction( + request=request, + channel=channel, + source=source, + userid=userid, + username=username, + ) + return True + + self.post_message( + Notification( + channel=channel, + source=source, + userid=userid, + username=username, + title=self._site_usage_hint(request.awaiting_input), + ) + ) + return True + + def _render_site_interaction( + self, + request, + channel: MessageChannel, + source: Optional[str], + userid: Union[str, int], + username: Optional[str], + original_message_id: Optional[Union[str, int]] = None, + original_chat_id: Optional[str] = None, + ) -> None: + """ + 渲染 /sites 当前页面。 + """ + site_list = SiteOper().list() + page_size = self._button_page_size if supports_interaction_buttons(channel) else self._text_page_size + page_sites, page, total_pages = page_items(site_list, request.page, page_size) + request.page = page + + if site_list: + body = self._format_site_list(page_sites, channel=channel) + footer = [ + f"第 {page + 1}/{total_pages} 页,共 {len(site_list)} 个站点", + self._site_prompt(request.awaiting_input), + self._site_usage_hint(request.awaiting_input), + ] + text = "\n\n".join([body, *[line for line in footer if line]]) + else: + text = "当前没有任何站点。\n\n输入 `退出` 结束交互。" + + buttons = None + if supports_interaction_buttons(channel): + buttons = build_navigation_buttons("sites", request, page, total_pages) + buttons.extend( + [ + [ + { + "text": "更新 Cookie", + "callback_data": f"sites:{request.request_id}:cookie", + }, + { + "text": "禁用站点", + "callback_data": f"sites:{request.request_id}:disable", + }, + { + "text": "启用站点", + "callback_data": f"sites:{request.request_id}:enable", + }, + ], + [ + { + "text": "刷新列表", + "callback_data": f"sites:{request.request_id}:refresh", + }, + { + "text": "关闭", + "callback_data": f"sites:{request.request_id}:close", + }, + ], + ] + ) + + update_or_post_message( + chain=self, + channel=channel, + source=source, + userid=userid, + username=username, + title="站点管理", + text=text, + buttons=buttons, + original_message_id=original_message_id, + original_chat_id=original_chat_id, + ) + + def _format_site_list( + self, site_list: List[Site], channel: Optional[MessageChannel] + ) -> str: + """ + 根据渠道能力格式化站点列表。 + """ + if supports_markdown(channel): + rows = [ + [ + site.id, + site.name, + "启用" if site.is_active else "禁用", + "已配置" if site.cookie else "未配置", + "是" if site.render else "否", + site.domain or StringUtils.get_url_domain(site.url or ""), + ] + for site in site_list + ] + return format_markdown_table( + headers=["ID", "站点", "状态", "Cookie", "渲染", "域名"], + rows=rows, + ) + + lines = [] + for site in site_list: + lines.append( + f"{site.id}. {site.name} | 状态:{'启用' if site.is_active else '禁用'}" + f" | Cookie:{'已配置' if site.cookie else '未配置'}" + f" | 渲染:{'是' if site.render else '否'}" + f" | 域名:{site.domain or StringUtils.get_url_domain(site.url or '')}" + ) + return "\n".join(lines) + + @staticmethod + def _site_prompt(awaiting_input: Optional[str]) -> str: + """ + 返回当前输入模式提示。 + """ + if awaiting_input == "cookie": + return "当前操作:更新站点 Cookie,请输入: [2fa_code/secret]" + if awaiting_input == "enable": + return "当前操作:启用站点,请输入站点 ID,多个 ID 用空格分隔。" + if awaiting_input == "disable": + return "当前操作:禁用站点,请输入站点 ID,多个 ID 用空格分隔。" + return "" + + @staticmethod + def _site_usage_hint(awaiting_input: Optional[str]) -> str: + """ + 返回 /sites 的文本操作提示。 + """ + if awaiting_input == "cookie": + return "输入站点 ID、用户名、密码和可选 2FA;输入 `取消` 返回列表,输入 `退出` 结束交互。" + if awaiting_input in {"enable", "disable"}: + return "输入一个或多个站点 ID;输入 `取消` 返回列表,输入 `退出` 结束交互。" + return ( + "可输入:`cookie [2fa]`、`启用 `、`禁用 `、" + "`n`、`p`、`刷新`、`退出`。" + ) + + @staticmethod + def _parse_site_ids(arg_str: str) -> List[int]: + """ + 从输入中提取站点 ID。 + """ + return [int(item) for item in re.findall(r"\d+", arg_str or "")] + + def _set_sites_enabled(self, arg_str: str, enabled: bool) -> Tuple[bool, str]: + """ + 批量启用或禁用站点。 + """ + site_ids = self._parse_site_ids(arg_str) + if not site_ids: + return False, "请输入至少一个有效的站点 ID" + + siteoper = SiteOper() + changed = [] + missing = [] + for site_id in site_ids: + site = siteoper.get(site_id) + if not site: + missing.append(str(site_id)) + continue + siteoper.update(site_id, {"is_active": enabled}) + changed.append(site.name) + + action = "启用" if enabled else "禁用" + if not changed and missing: + return False, f"未找到站点:{', '.join(missing)}" + + message = f"已{action} {len(changed)} 个站点" + if changed: + message += f":{', '.join(changed)}" + if missing: + message += f";未找到:{', '.join(missing)}" + return True, message + + def _update_site_cookie_from_input(self, arg_str: str) -> Tuple[bool, str]: + """ + 根据输入更新单个站点 Cookie。 + """ + args = str(arg_str or "").split() + if len(args) not in {3, 4} or not args[0].isdigit(): + return ( + False, + "格式错误,请输入:cookie [2fa_code/secret]", + ) + + site_id = int(args[0]) + site_info = SiteOper().get(site_id) + if not site_info: + return False, f"站点编号 {site_id} 不存在" + + status, msg = self.update_cookie( + site_info=site_info, + username=args[1], + password=args[2], + two_step_code=args[3] if len(args) == 4 else None, + ) + if not status: + logger.error(msg) + return False, f"【{site_info.name}】Cookie&UA 更新失败:{msg}" + return True, f"【{site_info.name}】Cookie&UA 更新成功" def remote_disable(self, arg_str: str, channel: MessageChannel, userid: Union[str, int] = None, source: Optional[str] = None): diff --git a/app/chain/skills.py b/app/chain/skills.py index 430c6108..648e324b 100644 --- a/app/chain/skills.py +++ b/app/chain/skills.py @@ -1,4 +1,3 @@ -import math import re from dataclasses import dataclass, field from datetime import datetime, timedelta @@ -7,9 +6,14 @@ from typing import Dict, List, Optional, Tuple, Union import uuid from app.chain import ChainBase +from app.helper.slash import ( + build_navigation_buttons, + page_items, + supports_interaction_buttons, + update_or_post_message, +) from app.helper.skill import SkillHelper, SkillInfo from app.schemas import Notification -from app.schemas.message import ChannelCapabilityManager from app.schemas.types import MessageChannel @@ -1055,11 +1059,7 @@ class SkillsChain(ChainBase): """ 返回当前页的数据,并把页码钳制到有效范围内。 """ - total_pages = max(1, math.ceil(len(items) / page_size)) if page_size else 1 - page = min(max(0, page), total_pages - 1) - start = page * page_size - end = start + page_size - return items[start:end], page, total_pages + return page_items(items=items, page=page, page_size=page_size) def _page_size(self, channel: Optional[MessageChannel]) -> int: """ @@ -1076,11 +1076,7 @@ class SkillsChain(ChainBase): """ 判断当前渠道是否同时支持按钮展示和回调。 """ - return bool( - channel - and ChannelCapabilityManager.supports_buttons(channel) - and ChannelCapabilityManager.supports_callbacks(channel) - ) + return supports_interaction_buttons(channel) @staticmethod def _navigation_buttons( @@ -1091,25 +1087,12 @@ class SkillsChain(ChainBase): """ 为分页视图生成上一页和下一页按钮。 """ - buttons = [] - nav_row = [] - if page > 0: - nav_row.append( - { - "text": "⬅️ 上一页", - "callback_data": f"skills:{request.request_id}:page-prev", - } - ) - if page < total_pages - 1: - nav_row.append( - { - "text": "下一页 ➡️", - "callback_data": f"skills:{request.request_id}:page-next", - } - ) - if nav_row: - buttons.append(nav_row) - return buttons + return build_navigation_buttons( + prefix="skills", + request=request, + page=page, + total_pages=total_pages, + ) def _update_or_post_message( self, @@ -1126,33 +1109,17 @@ class SkillsChain(ChainBase): """ 优先编辑原消息,编辑失败时再回退为发送新消息。 """ - if ( - original_message_id - and original_chat_id - and ChannelCapabilityManager.supports_editing(channel) - ): - edited = self.edit_message( - channel=channel, - source=source, - message_id=original_message_id, - chat_id=original_chat_id, - title=title, - text=text, - buttons=buttons, - ) - if edited: - return - - self.post_message( - Notification( - channel=channel, - source=source, - userid=userid, - username=username, - title=title, - text=text, - buttons=buttons, - ) + update_or_post_message( + chain=self, + channel=channel, + source=source, + userid=userid, + username=username, + title=title, + text=text, + buttons=buttons, + original_message_id=original_message_id, + original_chat_id=original_chat_id, ) @staticmethod diff --git a/app/chain/subscribe.py b/app/chain/subscribe.py index 8450e3a8..17afae40 100644 --- a/app/chain/subscribe.py +++ b/app/chain/subscribe.py @@ -11,6 +11,15 @@ from app.chain import ChainBase from app.chain.download import DownloadChain from app.chain.media import MediaChain from app.chain.search import SearchChain +from app.helper.slash import ( + SlashInteractionManager, + build_navigation_buttons, + format_markdown_table, + page_items, + supports_interaction_buttons, + supports_markdown, + update_or_post_message, +) from app.chain.tmdb import TmdbChain from app.chain.torrents import TorrentsChain from app.core.config import settings, global_vars @@ -32,6 +41,9 @@ from app.schemas.types import MediaType, SystemConfigKey, MessageChannel, Notifi ContentType +subscribe_interaction_manager = SlashInteractionManager() + + class SubscribeChain(ChainBase): """ 订阅管理处理链 @@ -40,6 +52,8 @@ class SubscribeChain(ChainBase): _rlock = threading.RLock() # 避免莫名原因导致长时间持有锁 _LOCK_TIMOUT = 3600 * 2 + _button_page_size = 6 + _text_page_size = 10 @staticmethod def __get_event_media(_mediaid: str, _meta: MetaBase) -> Optional[MediaInfo]: @@ -1385,33 +1399,670 @@ class SubscribeChain(ChainBase): "doubanid": mediainfo.douban_id }) - def remote_list(self, channel: MessageChannel, - userid: Union[str, int] = None, source: Optional[str] = None): + def remote_list( + self, + arg_str: str = "", + channel: MessageChannel = None, + userid: Union[str, int] = None, + source: Optional[str] = None, + ): """ - 查询订阅并发送消息 + /subscribes 统一入口。 + """ + request = subscribe_interaction_manager.create_or_replace( + user_id=userid, + command="/subscribes", + channel=channel, + source=source, + username=None, + ) + normalized_arg = (arg_str or "").strip() + if normalized_arg and self.handle_text_interaction( + channel=channel, + source=source, + userid=userid, + username="", + text=normalized_arg, + ): + return + self._render_subscribe_interaction( + request=request, + channel=channel, + source=source, + userid=userid, + username="", + ) + + @staticmethod + def parse_callback(callback_data: str) -> Optional[Tuple[str, str]]: + """ + 解析 /subscribes 按钮回调。 + """ + if not callback_data.startswith("subscribes:"): + return None + parts = callback_data.split(":") + if len(parts) < 3: + return None + return parts[1], parts[2] + + def handle_callback_interaction( + self, + callback_data: str, + channel: MessageChannel, + source: str, + userid: Union[str, int], + username: str, + original_message_id: Optional[Union[str, int]] = None, + original_chat_id: Optional[str] = None, + ) -> bool: + """ + 处理 /subscribes 按钮交互。 + """ + parsed = self.parse_callback(callback_data) + if not parsed: + return False + + request_id, action = parsed + request = subscribe_interaction_manager.get_by_id(request_id, userid) + if not request: + self.post_message( + schemas.Notification( + channel=channel, + source=source, + userid=userid, + username=username, + title="订阅交互已失效,请重新发送 /subscribes", + ) + ) + return True + + request.channel = channel + request.source = source + request.username = username + + if action == "close": + subscribe_interaction_manager.remove(request.request_id) + update_or_post_message( + chain=self, + channel=channel, + source=source, + userid=userid, + username=username, + title="订阅管理", + text="订阅交互已结束", + original_message_id=original_message_id, + original_chat_id=original_chat_id, + ) + return True + + if action == "page-prev": + request.page = max(0, request.page - 1) + request.awaiting_input = None + elif action == "page-next": + request.page += 1 + request.awaiting_input = None + elif action in {"search", "delete"}: + request.awaiting_input = action + elif action == "refresh": + request.awaiting_input = None + self._run_refresh_action(channel, source, userid, username) + elif action == "refresh-list": + request.awaiting_input = None + elif action == "metadata": + request.awaiting_input = None + self._run_metadata_refresh_action(channel, source, userid, username) + + self._render_subscribe_interaction( + request=request, + channel=channel, + source=source, + userid=userid, + username=username, + original_message_id=original_message_id, + original_chat_id=original_chat_id, + ) + return True + + def handle_text_interaction( + self, + channel: MessageChannel, + source: str, + userid: Union[str, int], + username: str, + text: str, + ) -> bool: + """ + 处理 /subscribes 文本补充输入。 + """ + request = subscribe_interaction_manager.get_by_user(userid) + if not request: + return False + + request.channel = channel + request.source = source + request.username = username + + normalized = (text or "").strip() + lowered = normalized.lower() + + if lowered in {"退出", "关闭", "q", "quit", "exit"}: + subscribe_interaction_manager.remove(request.request_id) + self.post_message( + schemas.Notification( + channel=channel, + source=source, + userid=userid, + username=username, + title="订阅交互已结束", + ) + ) + return True + + if lowered in {"取消", "cancel", "返回", "back"}: + request.awaiting_input = None + self._render_subscribe_interaction( + request=request, + channel=channel, + source=source, + userid=userid, + username=username, + ) + return True + + if lowered in {"刷新列表", "列表", "list"}: + request.awaiting_input = None + self._render_subscribe_interaction( + request=request, + channel=channel, + source=source, + userid=userid, + username=username, + ) + return True + + if lowered in {"刷新", "refresh"}: + request.awaiting_input = None + self._run_refresh_action(channel, source, userid, username) + self._render_subscribe_interaction( + request=request, + channel=channel, + source=source, + userid=userid, + username=username, + ) + return True + + if lowered in {"元数据", "刷新元数据", "metadata"}: + request.awaiting_input = None + self._run_metadata_refresh_action(channel, source, userid, username) + self._render_subscribe_interaction( + request=request, + channel=channel, + source=source, + userid=userid, + username=username, + ) + return True + + if lowered in {"p", "prev", "上一页"}: + request.awaiting_input = None + request.page = max(0, request.page - 1) + self._render_subscribe_interaction( + request=request, + channel=channel, + source=source, + userid=userid, + username=username, + ) + return True + + if lowered in {"n", "next", "下一页"}: + request.awaiting_input = None + request.page += 1 + self._render_subscribe_interaction( + request=request, + channel=channel, + source=source, + userid=userid, + username=username, + ) + return True + + search_match = re.match(r"^(?:搜索|search)\s+(.+)$", normalized, re.IGNORECASE) + delete_match = re.match(r"^(?:删除|delete)\s+(.+)$", normalized, re.IGNORECASE) + + if request.awaiting_input == "search": + success, message = self._run_search_action( + normalized, channel, source, userid, username + ) + request.awaiting_input = None + self.post_message( + schemas.Notification( + channel=channel, + source=source, + userid=userid, + username=username, + title=message, + ) + ) + self._render_subscribe_interaction( + request=request, + channel=channel, + source=source, + userid=userid, + username=username, + ) + return True + + if request.awaiting_input == "delete": + success, message = self._delete_subscribes(normalized) + request.awaiting_input = None + self.post_message( + schemas.Notification( + channel=channel, + source=source, + userid=userid, + username=username, + title=message, + ) + ) + self._render_subscribe_interaction( + request=request, + channel=channel, + source=source, + userid=userid, + username=username, + ) + return True + + if search_match: + success, message = self._run_search_action( + search_match.group(1), channel, source, userid, username + ) + self.post_message( + schemas.Notification( + channel=channel, + source=source, + userid=userid, + username=username, + title=message, + ) + ) + self._render_subscribe_interaction( + request=request, + channel=channel, + source=source, + userid=userid, + username=username, + ) + return True + + if delete_match: + success, message = self._delete_subscribes(delete_match.group(1)) + self.post_message( + schemas.Notification( + channel=channel, + source=source, + userid=userid, + username=username, + title=message, + ) + ) + self._render_subscribe_interaction( + request=request, + channel=channel, + source=source, + userid=userid, + username=username, + ) + return True + + self.post_message( + schemas.Notification( + channel=channel, + source=source, + userid=userid, + username=username, + title=self._subscribe_usage_hint(request.awaiting_input), + ) + ) + return True + + def _render_subscribe_interaction( + self, + request, + channel: MessageChannel, + source: Optional[str], + userid: Union[str, int], + username: Optional[str], + original_message_id: Optional[Union[str, int]] = None, + original_chat_id: Optional[str] = None, + ) -> None: + """ + 渲染 /subscribes 当前页面。 """ subscribes = SubscribeOper().list() - if not subscribes: - self.post_message(schemas.Notification(channel=channel, - source=source, - title='没有任何订阅!', userid=userid)) - return - title = f"共有 {len(subscribes)} 个订阅,回复对应指令操作: " \ - f"\n- 删除订阅:/subscribe_delete [id]" \ - f"\n- 搜索订阅:/subscribe_search [id]" \ - f"\n- 刷新订阅:/subscribe_refresh" - messages = [] + page_size = ( + self._button_page_size + if supports_interaction_buttons(channel) + else self._text_page_size + ) + page_subscribes, page, total_pages = page_items( + subscribes, request.page, page_size + ) + request.page = page + + if subscribes: + body = self._format_subscribe_list(page_subscribes, channel=channel) + footer = [ + f"第 {page + 1}/{total_pages} 页,共 {len(subscribes)} 个订阅", + self._subscribe_prompt(request.awaiting_input), + self._subscribe_usage_hint(request.awaiting_input), + ] + text = "\n\n".join([body, *[line for line in footer if line]]) + else: + text = "当前没有任何订阅。\n\n输入 `退出` 结束交互。" + + buttons = None + if supports_interaction_buttons(channel): + buttons = build_navigation_buttons( + "subscribes", request, page, total_pages + ) + buttons.extend( + [ + [ + { + "text": "搜索订阅", + "callback_data": f"subscribes:{request.request_id}:search", + }, + { + "text": "删除订阅", + "callback_data": f"subscribes:{request.request_id}:delete", + }, + { + "text": "刷新订阅", + "callback_data": f"subscribes:{request.request_id}:refresh", + }, + ], + [ + { + "text": "刷新元数据", + "callback_data": f"subscribes:{request.request_id}:metadata", + }, + { + "text": "刷新列表", + "callback_data": f"subscribes:{request.request_id}:refresh-list", + }, + { + "text": "关闭", + "callback_data": f"subscribes:{request.request_id}:close", + }, + ], + ] + ) + + update_or_post_message( + chain=self, + channel=channel, + source=source, + userid=userid, + username=username, + title="订阅管理", + text=text, + buttons=buttons, + original_message_id=original_message_id, + original_chat_id=original_chat_id, + ) + + def _format_subscribe_list( + self, subscribes: List[Subscribe], channel: Optional[MessageChannel] + ) -> str: + """ + 根据渠道能力格式化订阅列表。 + """ + if supports_markdown(channel): + rows = [ + [ + subscribe.id, + subscribe.name, + subscribe.type, + subscribe.year or "-", + self._format_subscribe_progress(subscribe), + self._format_subscribe_state(subscribe.state), + ] + for subscribe in subscribes + ] + return format_markdown_table( + headers=["ID", "名称", "类型", "年份", "季/进度", "状态"], + rows=rows, + ) + + lines = [] for subscribe in subscribes: - if subscribe.type == MediaType.MOVIE.value: - messages.append(f"{subscribe.id}. {subscribe.name}({subscribe.year})") - else: - messages.append(f"{subscribe.id}. {subscribe.name}({subscribe.year})" - f"第{subscribe.season}季 " - f"[{subscribe.total_episode - (subscribe.lack_episode or subscribe.total_episode)}" - f"/{subscribe.total_episode}]") - # 发送列表 - self.post_message(schemas.Notification(channel=channel, source=source, - title=title, text='\n'.join(messages), userid=userid)) + lines.append( + f"{subscribe.id}. {subscribe.name}({subscribe.year or '-'})" + f" | {subscribe.type}" + f" | {self._format_subscribe_progress(subscribe)}" + f" | 状态:{self._format_subscribe_state(subscribe.state)}" + ) + return "\n".join(lines) + + @staticmethod + def _format_subscribe_state(state: Optional[str]) -> str: + """ + 订阅状态显示文本。 + """ + mapping = { + "N": "新建", + "R": "订阅中", + "P": "待定", + "S": "暂停", + } + return mapping.get(state or "", state or "-") + + @staticmethod + def _format_subscribe_progress(subscribe: Subscribe) -> str: + """ + 构造订阅的季和进度说明。 + """ + if subscribe.type == MediaType.MOVIE.value: + return "电影" + season = subscribe.season or 1 + if subscribe.total_episode: + lack_episode = ( + subscribe.lack_episode + if subscribe.lack_episode is not None + else subscribe.total_episode + ) + downloaded = max(subscribe.total_episode - lack_episode, 0) + return f"第{season}季 [{downloaded}/{subscribe.total_episode}]" + return f"第{season}季" + + @staticmethod + def _subscribe_prompt(awaiting_input: Optional[str]) -> str: + """ + 返回当前输入模式提示。 + """ + if awaiting_input == "search": + return "当前操作:搜索订阅,请输入订阅 ID,多个 ID 用空格分隔,或输入 all 搜索全部。" + if awaiting_input == "delete": + return "当前操作:删除订阅,请输入订阅 ID,多个 ID 用空格分隔。" + return "" + + @staticmethod + def _subscribe_usage_hint(awaiting_input: Optional[str]) -> str: + """ + 返回 /subscribes 的文本操作提示。 + """ + if awaiting_input == "search": + return "输入订阅 ID 或 all;输入 `取消` 返回列表,输入 `退出` 结束交互。" + if awaiting_input == "delete": + return "输入一个或多个订阅 ID;输入 `取消` 返回列表,输入 `退出` 结束交互。" + return ( + "可输入:`搜索 `、`删除 `、`刷新`、`刷新元数据`、`n`、`p`、`退出`。" + ) + + def _run_refresh_action( + self, + channel: MessageChannel, + source: str, + userid: Union[str, int], + username: str, + ) -> None: + """ + 执行订阅刷新。 + """ + self.post_message( + schemas.Notification( + channel=channel, + source=source, + userid=userid, + username=username, + title="开始刷新订阅...", + ) + ) + self.refresh() + self.post_message( + schemas.Notification( + channel=channel, + source=source, + userid=userid, + username=username, + title="订阅刷新执行完成", + ) + ) + + def _run_metadata_refresh_action( + self, + channel: MessageChannel, + source: str, + userid: Union[str, int], + username: str, + ) -> None: + """ + 执行订阅元数据刷新。 + """ + self.post_message( + schemas.Notification( + channel=channel, + source=source, + userid=userid, + username=username, + title="开始刷新订阅元数据...", + ) + ) + self.check() + self.post_message( + schemas.Notification( + channel=channel, + source=source, + userid=userid, + username=username, + title="订阅元数据刷新完成", + ) + ) + + @staticmethod + def _parse_subscribe_ids(arg_str: str) -> List[int]: + """ + 从输入中提取订阅 ID。 + """ + return [int(item) for item in re.findall(r"\d+", arg_str or "")] + + def _run_search_action( + self, + arg_str: str, + channel: MessageChannel, + source: str, + userid: Union[str, int], + username: str, + ) -> Tuple[bool, str]: + """ + 手动执行订阅搜索。 + """ + normalized = (arg_str or "").strip() + if not normalized or normalized.lower() in {"all", "全部", "所有"}: + self.post_message( + schemas.Notification( + channel=channel, + source=source, + userid=userid, + username=username, + title="开始搜索所有订阅...", + ) + ) + self.search(state="N,R,P", manual=True) + return True, "所有订阅搜索完成" + + subscribe_ids = self._parse_subscribe_ids(normalized) + if not subscribe_ids: + return False, "请输入订阅 ID,多个 ID 用空格分隔,或输入 all" + + subscribeoper = SubscribeOper() + missing = [] + searched = [] + for subscribe_id in subscribe_ids: + subscribe = subscribeoper.get(subscribe_id) + if not subscribe: + missing.append(str(subscribe_id)) + continue + self.post_message( + schemas.Notification( + channel=channel, + source=source, + userid=userid, + username=username, + title=f"开始搜索订阅【{subscribe.name}】...", + ) + ) + self.search(sid=subscribe_id, manual=True) + searched.append(subscribe.name) + + if not searched and missing: + return False, f"未找到订阅:{', '.join(missing)}" + + message = f"已完成 {len(searched)} 个订阅搜索" + if searched: + message += f":{', '.join(searched)}" + if missing: + message += f";未找到:{', '.join(missing)}" + return True, message + + def _delete_subscribes(self, arg_str: str) -> Tuple[bool, str]: + """ + 批量删除订阅。 + """ + subscribe_ids = self._parse_subscribe_ids(arg_str) + if not subscribe_ids: + return False, "请输入至少一个有效的订阅 ID" + + subscribeoper = SubscribeOper() + subscribehelper = SubscribeHelper() + deleted = [] + missing = [] + for subscribe_id in subscribe_ids: + subscribe = subscribeoper.get(subscribe_id) + if not subscribe: + missing.append(str(subscribe_id)) + continue + deleted.append(subscribe.name) + subscribeoper.delete(subscribe_id) + subscribehelper.sub_done_async( + { + "tmdbid": subscribe.tmdbid, + "doubanid": subscribe.doubanid, + } + ) + + if not deleted and missing: + return False, f"未找到订阅:{', '.join(missing)}" + + message = f"已删除 {len(deleted)} 个订阅" + if deleted: + message += f":{', '.join(deleted)}" + if missing: + message += f";未找到:{', '.join(missing)}" + return True, message def remote_delete(self, arg_str: str, channel: MessageChannel, userid: Union[str, int] = None, source: Optional[str] = None): diff --git a/app/command.py b/app/command.py index 923f3f2b..57e547e7 100644 --- a/app/command.py +++ b/app/command.py @@ -50,30 +50,10 @@ class Command(metaclass=Singleton): }, "/sites": { "func": SiteChain().remote_list, - "description": "查询站点", + "description": "管理站点", "category": "站点", "data": {}, }, - "/site_cookie": { - "func": SiteChain().remote_cookie, - "description": "更新站点Cookie", - "data": {}, - }, - "/site_statistic": { - "func": SiteChain().remote_refresh_userdatas, - "description": "站点数据统计", - "data": {}, - }, - "/site_enable": { - "func": SiteChain().remote_enable, - "description": "启用站点", - "data": {}, - }, - "/site_disable": { - "func": SiteChain().remote_disable, - "description": "禁用站点", - "data": {}, - }, "/mediaserver_sync": { "id": "mediaserver_sync", "type": "scheduler", @@ -82,32 +62,10 @@ class Command(metaclass=Singleton): }, "/subscribes": { "func": SubscribeChain().remote_list, - "description": "查询订阅", + "description": "管理订阅", "category": "订阅", "data": {}, }, - "/subscribe_refresh": { - "id": "subscribe_refresh", - "type": "scheduler", - "description": "刷新订阅", - "category": "订阅", - }, - "/subscribe_search": { - "id": "subscribe_search", - "type": "scheduler", - "description": "搜索订阅", - "category": "订阅", - }, - "/subscribe_delete": { - "func": SubscribeChain().remote_delete, - "description": "删除订阅", - "data": {}, - }, - "/subscribe_tmdb": { - "id": "subscribe_tmdb", - "type": "scheduler", - "description": "订阅元数据更新", - }, "/downloading": { "func": DownloadChain().remote_downloading, "description": "正在下载", diff --git a/app/helper/slash.py b/app/helper/slash.py new file mode 100644 index 00000000..49f7105e --- /dev/null +++ b/app/helper/slash.py @@ -0,0 +1,244 @@ +import math +import uuid +from dataclasses import dataclass, field +from datetime import datetime, timedelta +from threading import Lock +from typing import Dict, List, Optional, Sequence, Tuple, Union + +from app.schemas import Notification +from app.schemas.message import ChannelCapabilityManager +from app.schemas.types import MessageChannel + + +@dataclass +class PendingSlashInteraction: + """ + 通用 slash 命令交互上下文。 + """ + + request_id: str + user_id: str + channel: Optional[MessageChannel] + source: Optional[str] + username: Optional[str] + command: str + page: int = 0 + awaiting_input: Optional[str] = None + created_at: datetime = field(default_factory=datetime.now) + + +class SlashInteractionManager: + """ + 管理单个 slash 命令的交互会话。 + """ + + _ttl = timedelta(hours=24) + + def __init__(self): + self._by_id: Dict[str, PendingSlashInteraction] = {} + self._by_user: Dict[str, str] = {} + self._lock = Lock() + + def _cleanup_locked(self) -> None: + expire_before = datetime.now() - self._ttl + expired = [ + request_id + for request_id, request in self._by_id.items() + if request.created_at < expire_before + ] + for request_id in expired: + request = self._by_id.pop(request_id, None) + if request: + self._by_user.pop(str(request.user_id), None) + + def create_or_replace( + self, + user_id: Union[str, int], + command: str, + channel: Optional[MessageChannel], + source: Optional[str], + username: Optional[str], + ) -> PendingSlashInteraction: + with self._lock: + self._cleanup_locked() + user_key = str(user_id) + old_request_id = self._by_user.get(user_key) + if old_request_id: + self._by_id.pop(old_request_id, None) + request = PendingSlashInteraction( + request_id=uuid.uuid4().hex[:12], + user_id=user_key, + command=command, + channel=channel, + source=source, + username=username, + ) + self._by_id[request.request_id] = request + self._by_user[user_key] = request.request_id + return request + + def get_by_user( + self, user_id: Union[str, int] + ) -> Optional[PendingSlashInteraction]: + with self._lock: + self._cleanup_locked() + request_id = self._by_user.get(str(user_id)) + if not request_id: + return None + return self._by_id.get(request_id) + + def get_by_id( + self, request_id: str, user_id: Union[str, int] + ) -> Optional[PendingSlashInteraction]: + with self._lock: + self._cleanup_locked() + request = self._by_id.get(request_id) + if not request or str(request.user_id) != str(user_id): + return None + return request + + def remove(self, request_id: str) -> None: + with self._lock: + request = self._by_id.pop(request_id, None) + if request: + self._by_user.pop(str(request.user_id), None) + + def clear(self) -> None: + with self._lock: + self._by_id.clear() + self._by_user.clear() + + +def supports_interaction_buttons(channel: Optional[MessageChannel]) -> bool: + """ + 渠道同时支持按钮和回调时,优先使用按钮交互。 + """ + return bool( + channel + and ChannelCapabilityManager.supports_buttons(channel) + and ChannelCapabilityManager.supports_callbacks(channel) + ) + + +def supports_markdown(channel: Optional[MessageChannel]) -> bool: + """ + 仅在支持 Markdown 的渠道上输出 Markdown 内容。 + """ + return bool(channel and ChannelCapabilityManager.supports_markdown(channel)) + + +def page_items( + items: Sequence, + page: int, + page_size: int, +) -> Tuple[List, int, int]: + """ + 对列表做分页并规范化页码。 + """ + total = len(items) + if total == 0: + return [], 0, 1 + total_pages = max(1, math.ceil(total / max(1, page_size))) + page = min(max(0, page), total_pages - 1) + start = page * page_size + end = start + page_size + return list(items[start:end]), page, total_pages + + +def build_navigation_buttons( + prefix: str, + request: PendingSlashInteraction, + page: int, + total_pages: int, +) -> List[List[dict]]: + """ + 构造标准上一页/下一页按钮。 + """ + buttons = [] + nav_row = [] + if page > 0: + nav_row.append( + { + "text": "⬅️ 上一页", + "callback_data": f"{prefix}:{request.request_id}:page-prev", + } + ) + if page < total_pages - 1: + nav_row.append( + { + "text": "下一页 ➡️", + "callback_data": f"{prefix}:{request.request_id}:page-next", + } + ) + if nav_row: + buttons.append(nav_row) + return buttons + + +def update_or_post_message( + chain, + channel: MessageChannel, + source: Optional[str], + userid: Union[str, int], + username: Optional[str], + title: str, + text: str, + buttons: Optional[List[List[dict]]] = None, + original_message_id: Optional[Union[str, int]] = None, + original_chat_id: Optional[str] = None, +) -> None: + """ + 优先编辑原消息,失败时回退为发送新消息。 + """ + if ( + original_message_id + and original_chat_id + and ChannelCapabilityManager.supports_editing(channel) + ): + edited = chain.edit_message( + channel=channel, + source=source, + message_id=original_message_id, + chat_id=original_chat_id, + title=title, + text=text, + buttons=buttons, + ) + if edited: + return + + chain.post_message( + Notification( + channel=channel, + source=source, + userid=userid, + username=username, + title=title, + text=text, + buttons=buttons, + ) + ) + + +def escape_markdown_table_cell(value: object) -> str: + """ + 最小化转义 Markdown 表格中的特殊字符。 + """ + text = str(value or "").replace("\n", "
") + text = text.replace("|", "\\|") + return text + + +def format_markdown_table(headers: Sequence[str], rows: Sequence[Sequence[object]]) -> str: + """ + 生成 Markdown 表格文本。 + """ + header_line = "| " + " | ".join(escape_markdown_table_cell(item) for item in headers) + " |" + separator_line = "| " + " | ".join("---" for _ in headers) + " |" + data_lines = [ + "| " + + " | ".join(escape_markdown_table_cell(item) for item in row) + + " |" + for row in rows + ] + return "\n".join([header_line, separator_line, *data_lines]) diff --git a/app/schemas/message.py b/app/schemas/message.py index b96f9fa3..461b173a 100644 --- a/app/schemas/message.py +++ b/app/schemas/message.py @@ -263,6 +263,8 @@ class ChannelCapability(Enum): CALLBACK_QUERIES = "callback_queries" # 支持富文本 RICH_TEXT = "rich_text" + # 支持 Markdown + MARKDOWN = "markdown" # 支持图片 IMAGES = "images" # 支持链接 @@ -301,6 +303,7 @@ class ChannelCapabilityManager: ChannelCapability.MESSAGE_EDITING, ChannelCapability.MESSAGE_DELETION, ChannelCapability.CALLBACK_QUERIES, + ChannelCapability.MARKDOWN, ChannelCapability.RICH_TEXT, ChannelCapability.IMAGES, ChannelCapability.LINKS, @@ -315,6 +318,7 @@ class ChannelCapabilityManager: MessageChannel.Wechat: ChannelCapabilities( channel=MessageChannel.Wechat, capabilities={ + ChannelCapability.MARKDOWN, ChannelCapability.IMAGES, ChannelCapability.LINKS, ChannelCapability.MENU_COMMANDS, @@ -328,6 +332,7 @@ class ChannelCapabilityManager: ChannelCapability.MESSAGE_EDITING, ChannelCapability.MESSAGE_DELETION, ChannelCapability.CALLBACK_QUERIES, + ChannelCapability.MARKDOWN, ChannelCapability.RICH_TEXT, ChannelCapability.IMAGES, ChannelCapability.LINKS, @@ -348,6 +353,7 @@ class ChannelCapabilityManager: ChannelCapability.MESSAGE_EDITING, ChannelCapability.MESSAGE_DELETION, ChannelCapability.CALLBACK_QUERIES, + ChannelCapability.MARKDOWN, ChannelCapability.RICH_TEXT, ChannelCapability.IMAGES, ChannelCapability.LINKS, @@ -363,6 +369,7 @@ class ChannelCapabilityManager: MessageChannel.SynologyChat: ChannelCapabilities( channel=MessageChannel.SynologyChat, capabilities={ + ChannelCapability.MARKDOWN, ChannelCapability.RICH_TEXT, ChannelCapability.IMAGES, ChannelCapability.LINKS, @@ -372,6 +379,7 @@ class ChannelCapabilityManager: MessageChannel.VoceChat: ChannelCapabilities( channel=MessageChannel.VoceChat, capabilities={ + ChannelCapability.MARKDOWN, ChannelCapability.RICH_TEXT, ChannelCapability.IMAGES, ChannelCapability.LINKS, @@ -386,6 +394,7 @@ class ChannelCapabilityManager: MessageChannel.Web: ChannelCapabilities( channel=MessageChannel.Web, capabilities={ + ChannelCapability.MARKDOWN, ChannelCapability.RICH_TEXT, ChannelCapability.IMAGES, ChannelCapability.LINKS, @@ -395,6 +404,7 @@ class ChannelCapabilityManager: MessageChannel.QQ: ChannelCapabilities( channel=MessageChannel.QQ, capabilities={ + ChannelCapability.MARKDOWN, ChannelCapability.RICH_TEXT, ChannelCapability.IMAGES, ChannelCapability.LINKS, @@ -443,6 +453,13 @@ class ChannelCapabilityManager: """ return cls.supports_capability(channel, ChannelCapability.MESSAGE_EDITING) + @classmethod + def supports_markdown(cls, channel: MessageChannel) -> bool: + """ + 检查渠道是否支持 Markdown。 + """ + return cls.supports_capability(channel, ChannelCapability.MARKDOWN) + @classmethod def supports_deletion(cls, channel: MessageChannel) -> bool: """ diff --git a/skills/command-dispatch/SKILL.md b/skills/command-dispatch/SKILL.md index 1a4539a2..4182c3cb 100644 --- a/skills/command-dispatch/SKILL.md +++ b/skills/command-dispatch/SKILL.md @@ -21,7 +21,7 @@ Use this skill to identify user intent and dispatch the corresponding system or - The user describes an action in natural language, for example: - "Sync sites" → `/cookiecloud` - "Show my subscriptions" → `/subscribes` - - "Refresh subscriptions" → `/subscribe_refresh` + - "Refresh subscriptions" → `/subscribes refresh` - "What's downloading?" → `/downloading` - "Organize downloaded files" → `/transfer` - "Clear cache" → `/clear_cache` @@ -58,7 +58,8 @@ If the user's description involves a specific plugin's functionality, additional Some commands support additional arguments (space-separated after the command), for example: - `/redo ` — Manually re-organize a specific record -- `/subscribe_delete ` — Delete a specific subscription +- `/sites disable ` — Disable one or more sites +- `/subscribes delete ` — Delete one or more subscriptions Use `run_slash_command` to execute the command in the format `/command_name arg1 arg2`. diff --git a/tests/test_skills_command.py b/tests/test_skills_command.py index f3b8f410..35bd7ad8 100644 --- a/tests/test_skills_command.py +++ b/tests/test_skills_command.py @@ -13,6 +13,18 @@ sys.modules.setdefault("transmission_rpc", ModuleType("transmission_rpc")) setattr(sys.modules["transmission_rpc"], "File", object) sys.modules.setdefault("psutil", ModuleType("psutil")) sys.modules.setdefault("aioshutil", ModuleType("aioshutil")) +sys.modules.setdefault("pyquery", ModuleType("pyquery")) +setattr(sys.modules["pyquery"], "PyQuery", object) +sys.modules.setdefault("cn2an", ModuleType("cn2an")) +setattr(sys.modules["cn2an"], "cn2an", lambda value, mode=None: value) +setattr(sys.modules["cn2an"], "an2cn", lambda value, mode=None: str(value)) +sys.modules.setdefault("dateparser", ModuleType("dateparser")) +setattr(sys.modules["dateparser"], "parse", lambda *args, **kwargs: None) +sys.modules.setdefault("dateutil", ModuleType("dateutil")) +dateutil_parser = ModuleType("dateutil.parser") +setattr(dateutil_parser, "parse", lambda *args, **kwargs: None) +sys.modules.setdefault("dateutil.parser", dateutil_parser) +setattr(sys.modules["dateutil"], "parser", dateutil_parser) from app.chain.message import MessageChain from app.chain.skills import SkillsChain, skills_interaction_manager diff --git a/tests/test_slash_command_interactions.py b/tests/test_slash_command_interactions.py new file mode 100644 index 00000000..bca1241a --- /dev/null +++ b/tests/test_slash_command_interactions.py @@ -0,0 +1,201 @@ +import sys +import unittest +from types import ModuleType, SimpleNamespace +from unittest.mock import patch + +sys.modules.setdefault("qbittorrentapi", ModuleType("qbittorrentapi")) +setattr(sys.modules["qbittorrentapi"], "TorrentFilesList", list) +sys.modules.setdefault("transmission_rpc", ModuleType("transmission_rpc")) +setattr(sys.modules["transmission_rpc"], "File", object) +sys.modules.setdefault("psutil", ModuleType("psutil")) +sys.modules.setdefault("aioshutil", ModuleType("aioshutil")) +sys.modules.setdefault("pyquery", ModuleType("pyquery")) +setattr(sys.modules["pyquery"], "PyQuery", object) +sys.modules.setdefault("cn2an", ModuleType("cn2an")) +setattr(sys.modules["cn2an"], "cn2an", lambda value, mode=None: value) +setattr(sys.modules["cn2an"], "an2cn", lambda value, mode=None: str(value)) +sys.modules.setdefault("dateparser", ModuleType("dateparser")) +setattr(sys.modules["dateparser"], "parse", lambda *args, **kwargs: None) +sys.modules.setdefault("dateutil", ModuleType("dateutil")) +dateutil_parser = ModuleType("dateutil.parser") +setattr(dateutil_parser, "parse", lambda *args, **kwargs: None) +sys.modules.setdefault("dateutil.parser", dateutil_parser) +setattr(sys.modules["dateutil"], "parser", dateutil_parser) + +from app.chain.message import MessageChain +from app.chain.site import SiteChain, site_interaction_manager +from app.chain.skills import skills_interaction_manager +from app.chain.subscribe import SubscribeChain, subscribe_interaction_manager +from app.schemas.types import MessageChannel + + +class TestSlashCommandInteractions(unittest.TestCase): + def tearDown(self): + skills_interaction_manager.clear() + site_interaction_manager.clear() + subscribe_interaction_manager.clear() + + def test_message_routes_text_reply_to_latest_sites_interaction(self): + chain = MessageChain() + skills_interaction_manager.create_or_replace( + user_id="10001", + channel=MessageChannel.Wechat, + source="wechat-test", + username="tester", + ) + site_interaction_manager.create_or_replace( + user_id="10001", + command="/sites", + channel=MessageChannel.Wechat, + source="wechat-test", + username="tester", + ) + + with patch.object(chain, "_record_user_message"), patch( + "app.chain.message.SiteChain.handle_text_interaction", + return_value=True, + ) as handle_site, patch( + "app.chain.message.SkillsChain.handle_text_interaction" + ) as handle_skills: + chain.handle_message( + channel=MessageChannel.Wechat, + source="wechat-test", + userid="10001", + username="tester", + text="禁用 1", + ) + + handle_site.assert_called_once() + handle_skills.assert_not_called() + + def test_message_routes_text_reply_to_latest_subscribes_interaction(self): + chain = MessageChain() + site_interaction_manager.create_or_replace( + user_id="10001", + command="/sites", + channel=MessageChannel.Wechat, + source="wechat-test", + username="tester", + ) + subscribe_interaction_manager.create_or_replace( + user_id="10001", + command="/subscribes", + channel=MessageChannel.Wechat, + source="wechat-test", + username="tester", + ) + + with patch.object(chain, "_record_user_message"), patch( + "app.chain.message.SubscribeChain.handle_text_interaction", + return_value=True, + ) as handle_subscribes, patch( + "app.chain.message.SiteChain.handle_text_interaction" + ) as handle_sites: + chain.handle_message( + channel=MessageChannel.Wechat, + source="wechat-test", + userid="10001", + username="tester", + text="搜索 all", + ) + + handle_subscribes.assert_called_once() + handle_sites.assert_not_called() + + def test_callback_routes_to_sites_chain(self): + chain = MessageChain() + request = site_interaction_manager.create_or_replace( + user_id="10001", + command="/sites", + channel=MessageChannel.Telegram, + source="telegram-test", + username="tester", + ) + + with patch( + "app.chain.message.SiteChain.handle_callback_interaction", + return_value=True, + ) as handle_callback: + chain._handle_callback( + text=f"CALLBACK:sites:{request.request_id}:refresh", + channel=MessageChannel.Telegram, + source="telegram-test", + userid="10001", + username="tester", + ) + + handle_callback.assert_called_once() + + def test_callback_routes_to_subscribes_chain(self): + chain = MessageChain() + request = subscribe_interaction_manager.create_or_replace( + user_id="10001", + command="/subscribes", + channel=MessageChannel.Telegram, + source="telegram-test", + username="tester", + ) + + with patch( + "app.chain.message.SubscribeChain.handle_callback_interaction", + return_value=True, + ) as handle_callback: + chain._handle_callback( + text=f"CALLBACK:subscribes:{request.request_id}:refresh", + channel=MessageChannel.Telegram, + source="telegram-test", + userid="10001", + username="tester", + ) + + handle_callback.assert_called_once() + + def test_sites_renders_markdown_table_when_channel_supports_markdown(self): + chain = SiteChain() + fake_sites = [ + SimpleNamespace( + id=1, + name="M-Team", + is_active=True, + cookie="cookie=value", + render=1, + domain="m-team.io", + url="https://m-team.io/", + ) + ] + + with patch("app.chain.site.SiteOper.list", return_value=fake_sites), patch.object( + chain, "post_message" + ) as post_message: + chain.remote_list(channel=MessageChannel.Web, userid="u1", source="web") + + notification = post_message.call_args[0][0] + self.assertIn("| ID | 站点 | 状态 | Cookie | 渲染 | 域名 |", notification.text) + self.assertIn("| 1 | M-Team | 启用 | 已配置 | 是 | m-team.io |", notification.text) + + def test_subscribes_renders_markdown_table_when_channel_supports_markdown(self): + chain = SubscribeChain() + fake_subscribes = [ + SimpleNamespace( + id=12, + name="Example Show", + type="电视剧", + year="2024", + season=1, + total_episode=10, + lack_episode=3, + state="R", + ) + ] + + with patch( + "app.chain.subscribe.SubscribeOper.list", return_value=fake_subscribes + ), patch.object(chain, "post_message") as post_message: + chain.remote_list(channel=MessageChannel.Web, userid="u1", source="web") + + notification = post_message.call_args[0][0] + self.assertIn("| ID | 名称 | 类型 | 年份 | 季/进度 | 状态 |", notification.text) + self.assertIn( + "| 12 | Example Show | 电视剧 | 2024 | 第1季 [7/10] | 订阅中 |", + notification.text, + )