feat: unify slash command interactions

This commit is contained in:
jxxghp
2026-05-01 08:53:52 +08:00
parent 4bb4f5aeb5
commit db6dc926cf
10 changed files with 1800 additions and 162 deletions

View File

@@ -16,7 +16,9 @@ from app.chain.interaction import (
agent_interaction_manager,
media_interaction_manager,
)
from app.chain.site import SiteChain, site_interaction_manager
from app.chain.skills import SkillsChain, skills_interaction_manager
from app.chain.subscribe import SubscribeChain, subscribe_interaction_manager
from app.chain.transfer import TransferChain
from app.core.config import settings, global_vars
from app.db.models import TransferHistory
@@ -170,13 +172,34 @@ class MessageChain(ChainBase):
)
return
if skills_interaction_manager.get_by_user(userid):
latest_slash_interaction = self._get_latest_slash_interaction(userid)
if latest_slash_interaction == "sites":
if SiteChain().handle_text_interaction(
channel=channel,
source=source,
userid=userid,
username=username,
text=text,
):
return
if latest_slash_interaction == "subscribes":
if SubscribeChain().handle_text_interaction(
channel=channel,
source=source,
userid=userid,
username=username,
text=text,
):
return
if latest_slash_interaction == "skills":
if SkillsChain().handle_text_interaction(
channel=channel,
source=source,
userid=userid,
username=username,
text=text,
channel=channel,
source=source,
userid=userid,
username=username,
text=text,
):
return
@@ -274,6 +297,28 @@ class MessageChain(ChainBase):
):
return
if SiteChain().handle_callback_interaction(
callback_data=callback_data,
channel=channel,
source=source,
userid=userid,
username=username,
original_message_id=original_message_id,
original_chat_id=original_chat_id,
):
return
if SubscribeChain().handle_callback_interaction(
callback_data=callback_data,
channel=channel,
source=source,
userid=userid,
username=username,
original_message_id=original_message_id,
original_chat_id=original_chat_id,
):
return
if MediaInteractionChain().handle_callback_interaction(
callback_data=callback_data,
channel=channel,
@@ -326,6 +371,24 @@ class MessageChain(ChainBase):
)
)
@staticmethod
def _get_latest_slash_interaction(userid: Union[str, int]) -> Optional[str]:
"""
返回当前用户最近一次激活的 slash 交互类型。
"""
candidates = []
for name, manager in (
("sites", site_interaction_manager),
("subscribes", subscribe_interaction_manager),
("skills", skills_interaction_manager),
):
request = manager.get_by_user(userid)
if request:
candidates.append((request.created_at, name))
if not candidates:
return None
return max(candidates, key=lambda item: item[0])[1]
@staticmethod
def _parse_transfer_callback(
callback_data: str,

View File

@@ -1,12 +1,21 @@
import base64
import re
from datetime import datetime
from typing import Optional, Tuple, Union, Dict
from typing import List, Optional, Tuple, Union, Dict
from urllib.parse import urljoin
from lxml import etree
from app.chain import ChainBase
from app.helper.slash import (
SlashInteractionManager,
build_navigation_buttons,
format_markdown_table,
page_items,
supports_interaction_buttons,
supports_markdown,
update_or_post_message,
)
from app.core.config import global_vars, settings
from app.core.event import Event, eventmanager
from app.db.models.site import Site
@@ -26,11 +35,17 @@ from app.utils.site import SiteUtils
from app.utils.string import StringUtils
site_interaction_manager = SlashInteractionManager()
class SiteChain(ChainBase):
"""
站点管理处理链
"""
_button_page_size = 6
_text_page_size = 10
def __init__(self):
super().__init__()
@@ -626,39 +641,548 @@ class SiteChain(ChainBase):
return False, f"无法打开网站!"
return True, "连接成功"
def remote_list(self, channel: MessageChannel,
userid: Union[str, int] = None, source: Optional[str] = None):
def remote_list(
self,
arg_str: str = "",
channel: MessageChannel = None,
userid: Union[str, int] = None,
source: Optional[str] = None,
):
"""
查询所有站点,发送消息
/sites 统一入口。
"""
site_list = SiteOper().list()
if not site_list:
self.post_message(Notification(
channel=channel,
title="没有维护任何站点信息!",
userid=userid,
link=settings.MP_DOMAIN('#/site')))
title = f"共有 {len(site_list)} 个站点,回复对应指令操作:" \
f"\n- 禁用站点:/site_disable [id]" \
f"\n- 启用站点:/site_enable [id]" \
f"\n- 更新站点Cookie/site_cookie [id] [username] [password] [2fa_code/secret]"
messages = []
for site in site_list:
if site.render:
render_str = "🧭"
else:
render_str = ""
if site.is_active:
messages.append(f"{site.id}. {site.name} {render_str}")
else:
messages.append(f"{site.id}. {site.name} ⚠️")
# 发送列表
self.post_message(Notification(
request = site_interaction_manager.create_or_replace(
user_id=userid,
command="/sites",
channel=channel,
source=source,
title=title, text="\n".join(messages), userid=userid,
link=settings.MP_DOMAIN('#/site'))
username=None,
)
normalized_arg = (arg_str or "").strip()
if normalized_arg and self.handle_text_interaction(
channel=channel,
source=source,
userid=userid,
username="",
text=normalized_arg,
):
return
self._render_site_interaction(
request=request,
channel=channel,
source=source,
userid=userid,
username="",
)
@staticmethod
def parse_callback(callback_data: str) -> Optional[Tuple[str, str]]:
"""
解析 /sites 按钮回调。
"""
if not callback_data.startswith("sites:"):
return None
parts = callback_data.split(":")
if len(parts) < 3:
return None
return parts[1], parts[2]
def handle_callback_interaction(
self,
callback_data: str,
channel: MessageChannel,
source: str,
userid: Union[str, int],
username: str,
original_message_id: Optional[Union[str, int]] = None,
original_chat_id: Optional[str] = None,
) -> bool:
"""
处理 /sites 按钮交互。
"""
parsed = self.parse_callback(callback_data)
if not parsed:
return False
request_id, action = parsed
request = site_interaction_manager.get_by_id(request_id, userid)
if not request:
self.post_message(
Notification(
channel=channel,
source=source,
userid=userid,
username=username,
title="站点交互已失效,请重新发送 /sites",
)
)
return True
request.channel = channel
request.source = source
request.username = username
if action == "close":
site_interaction_manager.remove(request.request_id)
update_or_post_message(
chain=self,
channel=channel,
source=source,
userid=userid,
username=username,
title="站点管理",
text="站点交互已结束",
original_message_id=original_message_id,
original_chat_id=original_chat_id,
)
return True
if action == "page-prev":
request.page = max(0, request.page - 1)
request.awaiting_input = None
elif action == "page-next":
request.page += 1
request.awaiting_input = None
elif action in {"cookie", "enable", "disable"}:
request.awaiting_input = action
elif action == "refresh":
request.awaiting_input = None
self._render_site_interaction(
request=request,
channel=channel,
source=source,
userid=userid,
username=username,
original_message_id=original_message_id,
original_chat_id=original_chat_id,
)
return True
def handle_text_interaction(
self,
channel: MessageChannel,
source: str,
userid: Union[str, int],
username: str,
text: str,
) -> bool:
"""
处理 /sites 文本补充输入。
"""
request = site_interaction_manager.get_by_user(userid)
if not request:
return False
request.channel = channel
request.source = source
request.username = username
normalized = (text or "").strip()
lowered = normalized.lower()
if lowered in {"退出", "关闭", "q", "quit", "exit"}:
site_interaction_manager.remove(request.request_id)
self.post_message(
Notification(
channel=channel,
source=source,
userid=userid,
username=username,
title="站点交互已结束",
)
)
return True
if lowered in {"取消", "cancel", "返回", "back"}:
request.awaiting_input = None
self._render_site_interaction(
request=request,
channel=channel,
source=source,
userid=userid,
username=username,
)
return True
if lowered in {"刷新", "refresh", "列表", "list"}:
request.awaiting_input = None
self._render_site_interaction(
request=request,
channel=channel,
source=source,
userid=userid,
username=username,
)
return True
if lowered in {"p", "prev", "上一页"}:
request.awaiting_input = None
request.page = max(0, request.page - 1)
self._render_site_interaction(
request=request,
channel=channel,
source=source,
userid=userid,
username=username,
)
return True
if lowered in {"n", "next", "下一页"}:
request.awaiting_input = None
request.page += 1
self._render_site_interaction(
request=request,
channel=channel,
source=source,
userid=userid,
username=username,
)
return True
cookie_match = re.match(
r"^(?:cookie|更新cookie|更新\s*cookie)\s+(.+)$",
normalized,
re.IGNORECASE,
)
enable_match = re.match(r"^(?:启用|enable)\s+(.+)$", normalized, re.IGNORECASE)
disable_match = re.match(
r"^(?:禁用|disable)\s+(.+)$", normalized, re.IGNORECASE
)
if request.awaiting_input == "cookie":
success, message = self._update_site_cookie_from_input(normalized)
request.awaiting_input = None
self.post_message(
Notification(
channel=channel,
source=source,
userid=userid,
username=username,
title=message,
)
)
self._render_site_interaction(
request=request,
channel=channel,
source=source,
userid=userid,
username=username,
)
return True
if request.awaiting_input == "enable":
success, message = self._set_sites_enabled(normalized, enabled=True)
request.awaiting_input = None
self.post_message(
Notification(
channel=channel,
source=source,
userid=userid,
username=username,
title=message,
)
)
self._render_site_interaction(
request=request,
channel=channel,
source=source,
userid=userid,
username=username,
)
return True
if request.awaiting_input == "disable":
success, message = self._set_sites_enabled(normalized, enabled=False)
request.awaiting_input = None
self.post_message(
Notification(
channel=channel,
source=source,
userid=userid,
username=username,
title=message,
)
)
self._render_site_interaction(
request=request,
channel=channel,
source=source,
userid=userid,
username=username,
)
return True
if cookie_match:
success, message = self._update_site_cookie_from_input(cookie_match.group(1))
self.post_message(
Notification(
channel=channel,
source=source,
userid=userid,
username=username,
title=message,
)
)
self._render_site_interaction(
request=request,
channel=channel,
source=source,
userid=userid,
username=username,
)
return True
if enable_match:
success, message = self._set_sites_enabled(enable_match.group(1), enabled=True)
self.post_message(
Notification(
channel=channel,
source=source,
userid=userid,
username=username,
title=message,
)
)
self._render_site_interaction(
request=request,
channel=channel,
source=source,
userid=userid,
username=username,
)
return True
if disable_match:
success, message = self._set_sites_enabled(
disable_match.group(1), enabled=False
)
self.post_message(
Notification(
channel=channel,
source=source,
userid=userid,
username=username,
title=message,
)
)
self._render_site_interaction(
request=request,
channel=channel,
source=source,
userid=userid,
username=username,
)
return True
self.post_message(
Notification(
channel=channel,
source=source,
userid=userid,
username=username,
title=self._site_usage_hint(request.awaiting_input),
)
)
return True
def _render_site_interaction(
self,
request,
channel: MessageChannel,
source: Optional[str],
userid: Union[str, int],
username: Optional[str],
original_message_id: Optional[Union[str, int]] = None,
original_chat_id: Optional[str] = None,
) -> None:
"""
渲染 /sites 当前页面。
"""
site_list = SiteOper().list()
page_size = self._button_page_size if supports_interaction_buttons(channel) else self._text_page_size
page_sites, page, total_pages = page_items(site_list, request.page, page_size)
request.page = page
if site_list:
body = self._format_site_list(page_sites, channel=channel)
footer = [
f"{page + 1}/{total_pages} 页,共 {len(site_list)} 个站点",
self._site_prompt(request.awaiting_input),
self._site_usage_hint(request.awaiting_input),
]
text = "\n\n".join([body, *[line for line in footer if line]])
else:
text = "当前没有任何站点。\n\n输入 `退出` 结束交互。"
buttons = None
if supports_interaction_buttons(channel):
buttons = build_navigation_buttons("sites", request, page, total_pages)
buttons.extend(
[
[
{
"text": "更新 Cookie",
"callback_data": f"sites:{request.request_id}:cookie",
},
{
"text": "禁用站点",
"callback_data": f"sites:{request.request_id}:disable",
},
{
"text": "启用站点",
"callback_data": f"sites:{request.request_id}:enable",
},
],
[
{
"text": "刷新列表",
"callback_data": f"sites:{request.request_id}:refresh",
},
{
"text": "关闭",
"callback_data": f"sites:{request.request_id}:close",
},
],
]
)
update_or_post_message(
chain=self,
channel=channel,
source=source,
userid=userid,
username=username,
title="站点管理",
text=text,
buttons=buttons,
original_message_id=original_message_id,
original_chat_id=original_chat_id,
)
def _format_site_list(
self, site_list: List[Site], channel: Optional[MessageChannel]
) -> str:
"""
根据渠道能力格式化站点列表。
"""
if supports_markdown(channel):
rows = [
[
site.id,
site.name,
"启用" if site.is_active else "禁用",
"已配置" if site.cookie else "未配置",
"" if site.render else "",
site.domain or StringUtils.get_url_domain(site.url or ""),
]
for site in site_list
]
return format_markdown_table(
headers=["ID", "站点", "状态", "Cookie", "渲染", "域名"],
rows=rows,
)
lines = []
for site in site_list:
lines.append(
f"{site.id}. {site.name} | 状态:{'启用' if site.is_active else '禁用'}"
f" | Cookie{'已配置' if site.cookie else '未配置'}"
f" | 渲染:{'' if site.render else ''}"
f" | 域名:{site.domain or StringUtils.get_url_domain(site.url or '')}"
)
return "\n".join(lines)
@staticmethod
def _site_prompt(awaiting_input: Optional[str]) -> str:
"""
返回当前输入模式提示。
"""
if awaiting_input == "cookie":
return "当前操作:更新站点 Cookie请输入<id> <username> <password> [2fa_code/secret]"
if awaiting_input == "enable":
return "当前操作:启用站点,请输入站点 ID多个 ID 用空格分隔。"
if awaiting_input == "disable":
return "当前操作:禁用站点,请输入站点 ID多个 ID 用空格分隔。"
return ""
@staticmethod
def _site_usage_hint(awaiting_input: Optional[str]) -> str:
"""
返回 /sites 的文本操作提示。
"""
if awaiting_input == "cookie":
return "输入站点 ID、用户名、密码和可选 2FA输入 `取消` 返回列表,输入 `退出` 结束交互。"
if awaiting_input in {"enable", "disable"}:
return "输入一个或多个站点 ID输入 `取消` 返回列表,输入 `退出` 结束交互。"
return (
"可输入:`cookie <id> <username> <password> [2fa]`、`启用 <id...>`、`禁用 <id...>`、"
"`n`、`p`、`刷新`、`退出`。"
)
@staticmethod
def _parse_site_ids(arg_str: str) -> List[int]:
"""
从输入中提取站点 ID。
"""
return [int(item) for item in re.findall(r"\d+", arg_str or "")]
def _set_sites_enabled(self, arg_str: str, enabled: bool) -> Tuple[bool, str]:
"""
批量启用或禁用站点。
"""
site_ids = self._parse_site_ids(arg_str)
if not site_ids:
return False, "请输入至少一个有效的站点 ID"
siteoper = SiteOper()
changed = []
missing = []
for site_id in site_ids:
site = siteoper.get(site_id)
if not site:
missing.append(str(site_id))
continue
siteoper.update(site_id, {"is_active": enabled})
changed.append(site.name)
action = "启用" if enabled else "禁用"
if not changed and missing:
return False, f"未找到站点:{', '.join(missing)}"
message = f"{action} {len(changed)} 个站点"
if changed:
message += f"{', '.join(changed)}"
if missing:
message += f";未找到:{', '.join(missing)}"
return True, message
def _update_site_cookie_from_input(self, arg_str: str) -> Tuple[bool, str]:
"""
根据输入更新单个站点 Cookie。
"""
args = str(arg_str or "").split()
if len(args) not in {3, 4} or not args[0].isdigit():
return (
False,
"格式错误请输入cookie <id> <username> <password> [2fa_code/secret]",
)
site_id = int(args[0])
site_info = SiteOper().get(site_id)
if not site_info:
return False, f"站点编号 {site_id} 不存在"
status, msg = self.update_cookie(
site_info=site_info,
username=args[1],
password=args[2],
two_step_code=args[3] if len(args) == 4 else None,
)
if not status:
logger.error(msg)
return False, f"{site_info.name}】Cookie&UA 更新失败:{msg}"
return True, f"{site_info.name}】Cookie&UA 更新成功"
def remote_disable(self, arg_str: str, channel: MessageChannel,
userid: Union[str, int] = None, source: Optional[str] = None):

View File

@@ -1,4 +1,3 @@
import math
import re
from dataclasses import dataclass, field
from datetime import datetime, timedelta
@@ -7,9 +6,14 @@ from typing import Dict, List, Optional, Tuple, Union
import uuid
from app.chain import ChainBase
from app.helper.slash import (
build_navigation_buttons,
page_items,
supports_interaction_buttons,
update_or_post_message,
)
from app.helper.skill import SkillHelper, SkillInfo
from app.schemas import Notification
from app.schemas.message import ChannelCapabilityManager
from app.schemas.types import MessageChannel
@@ -1055,11 +1059,7 @@ class SkillsChain(ChainBase):
"""
返回当前页的数据,并把页码钳制到有效范围内。
"""
total_pages = max(1, math.ceil(len(items) / page_size)) if page_size else 1
page = min(max(0, page), total_pages - 1)
start = page * page_size
end = start + page_size
return items[start:end], page, total_pages
return page_items(items=items, page=page, page_size=page_size)
def _page_size(self, channel: Optional[MessageChannel]) -> int:
"""
@@ -1076,11 +1076,7 @@ class SkillsChain(ChainBase):
"""
判断当前渠道是否同时支持按钮展示和回调。
"""
return bool(
channel
and ChannelCapabilityManager.supports_buttons(channel)
and ChannelCapabilityManager.supports_callbacks(channel)
)
return supports_interaction_buttons(channel)
@staticmethod
def _navigation_buttons(
@@ -1091,25 +1087,12 @@ class SkillsChain(ChainBase):
"""
为分页视图生成上一页和下一页按钮。
"""
buttons = []
nav_row = []
if page > 0:
nav_row.append(
{
"text": "⬅️ 上一页",
"callback_data": f"skills:{request.request_id}:page-prev",
}
)
if page < total_pages - 1:
nav_row.append(
{
"text": "下一页 ➡️",
"callback_data": f"skills:{request.request_id}:page-next",
}
)
if nav_row:
buttons.append(nav_row)
return buttons
return build_navigation_buttons(
prefix="skills",
request=request,
page=page,
total_pages=total_pages,
)
def _update_or_post_message(
self,
@@ -1126,33 +1109,17 @@ class SkillsChain(ChainBase):
"""
优先编辑原消息,编辑失败时再回退为发送新消息。
"""
if (
original_message_id
and original_chat_id
and ChannelCapabilityManager.supports_editing(channel)
):
edited = self.edit_message(
channel=channel,
source=source,
message_id=original_message_id,
chat_id=original_chat_id,
title=title,
text=text,
buttons=buttons,
)
if edited:
return
self.post_message(
Notification(
channel=channel,
source=source,
userid=userid,
username=username,
title=title,
text=text,
buttons=buttons,
)
update_or_post_message(
chain=self,
channel=channel,
source=source,
userid=userid,
username=username,
title=title,
text=text,
buttons=buttons,
original_message_id=original_message_id,
original_chat_id=original_chat_id,
)
@staticmethod

View File

@@ -11,6 +11,15 @@ from app.chain import ChainBase
from app.chain.download import DownloadChain
from app.chain.media import MediaChain
from app.chain.search import SearchChain
from app.helper.slash import (
SlashInteractionManager,
build_navigation_buttons,
format_markdown_table,
page_items,
supports_interaction_buttons,
supports_markdown,
update_or_post_message,
)
from app.chain.tmdb import TmdbChain
from app.chain.torrents import TorrentsChain
from app.core.config import settings, global_vars
@@ -32,6 +41,9 @@ from app.schemas.types import MediaType, SystemConfigKey, MessageChannel, Notifi
ContentType
subscribe_interaction_manager = SlashInteractionManager()
class SubscribeChain(ChainBase):
"""
订阅管理处理链
@@ -40,6 +52,8 @@ class SubscribeChain(ChainBase):
_rlock = threading.RLock()
# 避免莫名原因导致长时间持有锁
_LOCK_TIMOUT = 3600 * 2
_button_page_size = 6
_text_page_size = 10
@staticmethod
def __get_event_media(_mediaid: str, _meta: MetaBase) -> Optional[MediaInfo]:
@@ -1385,33 +1399,670 @@ class SubscribeChain(ChainBase):
"doubanid": mediainfo.douban_id
})
def remote_list(self, channel: MessageChannel,
userid: Union[str, int] = None, source: Optional[str] = None):
def remote_list(
self,
arg_str: str = "",
channel: MessageChannel = None,
userid: Union[str, int] = None,
source: Optional[str] = None,
):
"""
查询订阅并发送消息
/subscribes 统一入口。
"""
request = subscribe_interaction_manager.create_or_replace(
user_id=userid,
command="/subscribes",
channel=channel,
source=source,
username=None,
)
normalized_arg = (arg_str or "").strip()
if normalized_arg and self.handle_text_interaction(
channel=channel,
source=source,
userid=userid,
username="",
text=normalized_arg,
):
return
self._render_subscribe_interaction(
request=request,
channel=channel,
source=source,
userid=userid,
username="",
)
@staticmethod
def parse_callback(callback_data: str) -> Optional[Tuple[str, str]]:
"""
解析 /subscribes 按钮回调。
"""
if not callback_data.startswith("subscribes:"):
return None
parts = callback_data.split(":")
if len(parts) < 3:
return None
return parts[1], parts[2]
def handle_callback_interaction(
self,
callback_data: str,
channel: MessageChannel,
source: str,
userid: Union[str, int],
username: str,
original_message_id: Optional[Union[str, int]] = None,
original_chat_id: Optional[str] = None,
) -> bool:
"""
处理 /subscribes 按钮交互。
"""
parsed = self.parse_callback(callback_data)
if not parsed:
return False
request_id, action = parsed
request = subscribe_interaction_manager.get_by_id(request_id, userid)
if not request:
self.post_message(
schemas.Notification(
channel=channel,
source=source,
userid=userid,
username=username,
title="订阅交互已失效,请重新发送 /subscribes",
)
)
return True
request.channel = channel
request.source = source
request.username = username
if action == "close":
subscribe_interaction_manager.remove(request.request_id)
update_or_post_message(
chain=self,
channel=channel,
source=source,
userid=userid,
username=username,
title="订阅管理",
text="订阅交互已结束",
original_message_id=original_message_id,
original_chat_id=original_chat_id,
)
return True
if action == "page-prev":
request.page = max(0, request.page - 1)
request.awaiting_input = None
elif action == "page-next":
request.page += 1
request.awaiting_input = None
elif action in {"search", "delete"}:
request.awaiting_input = action
elif action == "refresh":
request.awaiting_input = None
self._run_refresh_action(channel, source, userid, username)
elif action == "refresh-list":
request.awaiting_input = None
elif action == "metadata":
request.awaiting_input = None
self._run_metadata_refresh_action(channel, source, userid, username)
self._render_subscribe_interaction(
request=request,
channel=channel,
source=source,
userid=userid,
username=username,
original_message_id=original_message_id,
original_chat_id=original_chat_id,
)
return True
def handle_text_interaction(
self,
channel: MessageChannel,
source: str,
userid: Union[str, int],
username: str,
text: str,
) -> bool:
"""
处理 /subscribes 文本补充输入。
"""
request = subscribe_interaction_manager.get_by_user(userid)
if not request:
return False
request.channel = channel
request.source = source
request.username = username
normalized = (text or "").strip()
lowered = normalized.lower()
if lowered in {"退出", "关闭", "q", "quit", "exit"}:
subscribe_interaction_manager.remove(request.request_id)
self.post_message(
schemas.Notification(
channel=channel,
source=source,
userid=userid,
username=username,
title="订阅交互已结束",
)
)
return True
if lowered in {"取消", "cancel", "返回", "back"}:
request.awaiting_input = None
self._render_subscribe_interaction(
request=request,
channel=channel,
source=source,
userid=userid,
username=username,
)
return True
if lowered in {"刷新列表", "列表", "list"}:
request.awaiting_input = None
self._render_subscribe_interaction(
request=request,
channel=channel,
source=source,
userid=userid,
username=username,
)
return True
if lowered in {"刷新", "refresh"}:
request.awaiting_input = None
self._run_refresh_action(channel, source, userid, username)
self._render_subscribe_interaction(
request=request,
channel=channel,
source=source,
userid=userid,
username=username,
)
return True
if lowered in {"元数据", "刷新元数据", "metadata"}:
request.awaiting_input = None
self._run_metadata_refresh_action(channel, source, userid, username)
self._render_subscribe_interaction(
request=request,
channel=channel,
source=source,
userid=userid,
username=username,
)
return True
if lowered in {"p", "prev", "上一页"}:
request.awaiting_input = None
request.page = max(0, request.page - 1)
self._render_subscribe_interaction(
request=request,
channel=channel,
source=source,
userid=userid,
username=username,
)
return True
if lowered in {"n", "next", "下一页"}:
request.awaiting_input = None
request.page += 1
self._render_subscribe_interaction(
request=request,
channel=channel,
source=source,
userid=userid,
username=username,
)
return True
search_match = re.match(r"^(?:搜索|search)\s+(.+)$", normalized, re.IGNORECASE)
delete_match = re.match(r"^(?:删除|delete)\s+(.+)$", normalized, re.IGNORECASE)
if request.awaiting_input == "search":
success, message = self._run_search_action(
normalized, channel, source, userid, username
)
request.awaiting_input = None
self.post_message(
schemas.Notification(
channel=channel,
source=source,
userid=userid,
username=username,
title=message,
)
)
self._render_subscribe_interaction(
request=request,
channel=channel,
source=source,
userid=userid,
username=username,
)
return True
if request.awaiting_input == "delete":
success, message = self._delete_subscribes(normalized)
request.awaiting_input = None
self.post_message(
schemas.Notification(
channel=channel,
source=source,
userid=userid,
username=username,
title=message,
)
)
self._render_subscribe_interaction(
request=request,
channel=channel,
source=source,
userid=userid,
username=username,
)
return True
if search_match:
success, message = self._run_search_action(
search_match.group(1), channel, source, userid, username
)
self.post_message(
schemas.Notification(
channel=channel,
source=source,
userid=userid,
username=username,
title=message,
)
)
self._render_subscribe_interaction(
request=request,
channel=channel,
source=source,
userid=userid,
username=username,
)
return True
if delete_match:
success, message = self._delete_subscribes(delete_match.group(1))
self.post_message(
schemas.Notification(
channel=channel,
source=source,
userid=userid,
username=username,
title=message,
)
)
self._render_subscribe_interaction(
request=request,
channel=channel,
source=source,
userid=userid,
username=username,
)
return True
self.post_message(
schemas.Notification(
channel=channel,
source=source,
userid=userid,
username=username,
title=self._subscribe_usage_hint(request.awaiting_input),
)
)
return True
def _render_subscribe_interaction(
self,
request,
channel: MessageChannel,
source: Optional[str],
userid: Union[str, int],
username: Optional[str],
original_message_id: Optional[Union[str, int]] = None,
original_chat_id: Optional[str] = None,
) -> None:
"""
渲染 /subscribes 当前页面。
"""
subscribes = SubscribeOper().list()
if not subscribes:
self.post_message(schemas.Notification(channel=channel,
source=source,
title='没有任何订阅!', userid=userid))
return
title = f"共有 {len(subscribes)} 个订阅,回复对应指令操作: " \
f"\n- 删除订阅:/subscribe_delete [id]" \
f"\n- 搜索订阅:/subscribe_search [id]" \
f"\n- 刷新订阅:/subscribe_refresh"
messages = []
page_size = (
self._button_page_size
if supports_interaction_buttons(channel)
else self._text_page_size
)
page_subscribes, page, total_pages = page_items(
subscribes, request.page, page_size
)
request.page = page
if subscribes:
body = self._format_subscribe_list(page_subscribes, channel=channel)
footer = [
f"{page + 1}/{total_pages} 页,共 {len(subscribes)} 个订阅",
self._subscribe_prompt(request.awaiting_input),
self._subscribe_usage_hint(request.awaiting_input),
]
text = "\n\n".join([body, *[line for line in footer if line]])
else:
text = "当前没有任何订阅。\n\n输入 `退出` 结束交互。"
buttons = None
if supports_interaction_buttons(channel):
buttons = build_navigation_buttons(
"subscribes", request, page, total_pages
)
buttons.extend(
[
[
{
"text": "搜索订阅",
"callback_data": f"subscribes:{request.request_id}:search",
},
{
"text": "删除订阅",
"callback_data": f"subscribes:{request.request_id}:delete",
},
{
"text": "刷新订阅",
"callback_data": f"subscribes:{request.request_id}:refresh",
},
],
[
{
"text": "刷新元数据",
"callback_data": f"subscribes:{request.request_id}:metadata",
},
{
"text": "刷新列表",
"callback_data": f"subscribes:{request.request_id}:refresh-list",
},
{
"text": "关闭",
"callback_data": f"subscribes:{request.request_id}:close",
},
],
]
)
update_or_post_message(
chain=self,
channel=channel,
source=source,
userid=userid,
username=username,
title="订阅管理",
text=text,
buttons=buttons,
original_message_id=original_message_id,
original_chat_id=original_chat_id,
)
def _format_subscribe_list(
self, subscribes: List[Subscribe], channel: Optional[MessageChannel]
) -> str:
"""
根据渠道能力格式化订阅列表。
"""
if supports_markdown(channel):
rows = [
[
subscribe.id,
subscribe.name,
subscribe.type,
subscribe.year or "-",
self._format_subscribe_progress(subscribe),
self._format_subscribe_state(subscribe.state),
]
for subscribe in subscribes
]
return format_markdown_table(
headers=["ID", "名称", "类型", "年份", "季/进度", "状态"],
rows=rows,
)
lines = []
for subscribe in subscribes:
if subscribe.type == MediaType.MOVIE.value:
messages.append(f"{subscribe.id}. {subscribe.name}{subscribe.year}")
else:
messages.append(f"{subscribe.id}. {subscribe.name}{subscribe.year}"
f"{subscribe.season}"
f"[{subscribe.total_episode - (subscribe.lack_episode or subscribe.total_episode)}"
f"/{subscribe.total_episode}]")
# 发送列表
self.post_message(schemas.Notification(channel=channel, source=source,
title=title, text='\n'.join(messages), userid=userid))
lines.append(
f"{subscribe.id}. {subscribe.name}{subscribe.year or '-'}"
f" | {subscribe.type}"
f" | {self._format_subscribe_progress(subscribe)}"
f" | 状态:{self._format_subscribe_state(subscribe.state)}"
)
return "\n".join(lines)
@staticmethod
def _format_subscribe_state(state: Optional[str]) -> str:
"""
订阅状态显示文本。
"""
mapping = {
"N": "新建",
"R": "订阅中",
"P": "待定",
"S": "暂停",
}
return mapping.get(state or "", state or "-")
@staticmethod
def _format_subscribe_progress(subscribe: Subscribe) -> str:
"""
构造订阅的季和进度说明。
"""
if subscribe.type == MediaType.MOVIE.value:
return "电影"
season = subscribe.season or 1
if subscribe.total_episode:
lack_episode = (
subscribe.lack_episode
if subscribe.lack_episode is not None
else subscribe.total_episode
)
downloaded = max(subscribe.total_episode - lack_episode, 0)
return f"{season}季 [{downloaded}/{subscribe.total_episode}]"
return f"{season}"
@staticmethod
def _subscribe_prompt(awaiting_input: Optional[str]) -> str:
"""
返回当前输入模式提示。
"""
if awaiting_input == "search":
return "当前操作:搜索订阅,请输入订阅 ID多个 ID 用空格分隔,或输入 all 搜索全部。"
if awaiting_input == "delete":
return "当前操作:删除订阅,请输入订阅 ID多个 ID 用空格分隔。"
return ""
@staticmethod
def _subscribe_usage_hint(awaiting_input: Optional[str]) -> str:
"""
返回 /subscribes 的文本操作提示。
"""
if awaiting_input == "search":
return "输入订阅 ID 或 all输入 `取消` 返回列表,输入 `退出` 结束交互。"
if awaiting_input == "delete":
return "输入一个或多个订阅 ID输入 `取消` 返回列表,输入 `退出` 结束交互。"
return (
"可输入:`搜索 <id...|all>`、`删除 <id...>`、`刷新`、`刷新元数据`、`n`、`p`、`退出`。"
)
def _run_refresh_action(
self,
channel: MessageChannel,
source: str,
userid: Union[str, int],
username: str,
) -> None:
"""
执行订阅刷新。
"""
self.post_message(
schemas.Notification(
channel=channel,
source=source,
userid=userid,
username=username,
title="开始刷新订阅...",
)
)
self.refresh()
self.post_message(
schemas.Notification(
channel=channel,
source=source,
userid=userid,
username=username,
title="订阅刷新执行完成",
)
)
def _run_metadata_refresh_action(
self,
channel: MessageChannel,
source: str,
userid: Union[str, int],
username: str,
) -> None:
"""
执行订阅元数据刷新。
"""
self.post_message(
schemas.Notification(
channel=channel,
source=source,
userid=userid,
username=username,
title="开始刷新订阅元数据...",
)
)
self.check()
self.post_message(
schemas.Notification(
channel=channel,
source=source,
userid=userid,
username=username,
title="订阅元数据刷新完成",
)
)
@staticmethod
def _parse_subscribe_ids(arg_str: str) -> List[int]:
"""
从输入中提取订阅 ID。
"""
return [int(item) for item in re.findall(r"\d+", arg_str or "")]
def _run_search_action(
self,
arg_str: str,
channel: MessageChannel,
source: str,
userid: Union[str, int],
username: str,
) -> Tuple[bool, str]:
"""
手动执行订阅搜索。
"""
normalized = (arg_str or "").strip()
if not normalized or normalized.lower() in {"all", "全部", "所有"}:
self.post_message(
schemas.Notification(
channel=channel,
source=source,
userid=userid,
username=username,
title="开始搜索所有订阅...",
)
)
self.search(state="N,R,P", manual=True)
return True, "所有订阅搜索完成"
subscribe_ids = self._parse_subscribe_ids(normalized)
if not subscribe_ids:
return False, "请输入订阅 ID多个 ID 用空格分隔,或输入 all"
subscribeoper = SubscribeOper()
missing = []
searched = []
for subscribe_id in subscribe_ids:
subscribe = subscribeoper.get(subscribe_id)
if not subscribe:
missing.append(str(subscribe_id))
continue
self.post_message(
schemas.Notification(
channel=channel,
source=source,
userid=userid,
username=username,
title=f"开始搜索订阅【{subscribe.name}】...",
)
)
self.search(sid=subscribe_id, manual=True)
searched.append(subscribe.name)
if not searched and missing:
return False, f"未找到订阅:{', '.join(missing)}"
message = f"已完成 {len(searched)} 个订阅搜索"
if searched:
message += f"{', '.join(searched)}"
if missing:
message += f";未找到:{', '.join(missing)}"
return True, message
def _delete_subscribes(self, arg_str: str) -> Tuple[bool, str]:
"""
批量删除订阅。
"""
subscribe_ids = self._parse_subscribe_ids(arg_str)
if not subscribe_ids:
return False, "请输入至少一个有效的订阅 ID"
subscribeoper = SubscribeOper()
subscribehelper = SubscribeHelper()
deleted = []
missing = []
for subscribe_id in subscribe_ids:
subscribe = subscribeoper.get(subscribe_id)
if not subscribe:
missing.append(str(subscribe_id))
continue
deleted.append(subscribe.name)
subscribeoper.delete(subscribe_id)
subscribehelper.sub_done_async(
{
"tmdbid": subscribe.tmdbid,
"doubanid": subscribe.doubanid,
}
)
if not deleted and missing:
return False, f"未找到订阅:{', '.join(missing)}"
message = f"已删除 {len(deleted)} 个订阅"
if deleted:
message += f"{', '.join(deleted)}"
if missing:
message += f";未找到:{', '.join(missing)}"
return True, message
def remote_delete(self, arg_str: str, channel: MessageChannel,
userid: Union[str, int] = None, source: Optional[str] = None):

View File

@@ -50,30 +50,10 @@ class Command(metaclass=Singleton):
},
"/sites": {
"func": SiteChain().remote_list,
"description": "查询站点",
"description": "管理站点",
"category": "站点",
"data": {},
},
"/site_cookie": {
"func": SiteChain().remote_cookie,
"description": "更新站点Cookie",
"data": {},
},
"/site_statistic": {
"func": SiteChain().remote_refresh_userdatas,
"description": "站点数据统计",
"data": {},
},
"/site_enable": {
"func": SiteChain().remote_enable,
"description": "启用站点",
"data": {},
},
"/site_disable": {
"func": SiteChain().remote_disable,
"description": "禁用站点",
"data": {},
},
"/mediaserver_sync": {
"id": "mediaserver_sync",
"type": "scheduler",
@@ -82,32 +62,10 @@ class Command(metaclass=Singleton):
},
"/subscribes": {
"func": SubscribeChain().remote_list,
"description": "查询订阅",
"description": "管理订阅",
"category": "订阅",
"data": {},
},
"/subscribe_refresh": {
"id": "subscribe_refresh",
"type": "scheduler",
"description": "刷新订阅",
"category": "订阅",
},
"/subscribe_search": {
"id": "subscribe_search",
"type": "scheduler",
"description": "搜索订阅",
"category": "订阅",
},
"/subscribe_delete": {
"func": SubscribeChain().remote_delete,
"description": "删除订阅",
"data": {},
},
"/subscribe_tmdb": {
"id": "subscribe_tmdb",
"type": "scheduler",
"description": "订阅元数据更新",
},
"/downloading": {
"func": DownloadChain().remote_downloading,
"description": "正在下载",

244
app/helper/slash.py Normal file
View File

@@ -0,0 +1,244 @@
import math
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from threading import Lock
from typing import Dict, List, Optional, Sequence, Tuple, Union
from app.schemas import Notification
from app.schemas.message import ChannelCapabilityManager
from app.schemas.types import MessageChannel
@dataclass
class PendingSlashInteraction:
"""
通用 slash 命令交互上下文。
"""
request_id: str
user_id: str
channel: Optional[MessageChannel]
source: Optional[str]
username: Optional[str]
command: str
page: int = 0
awaiting_input: Optional[str] = None
created_at: datetime = field(default_factory=datetime.now)
class SlashInteractionManager:
"""
管理单个 slash 命令的交互会话。
"""
_ttl = timedelta(hours=24)
def __init__(self):
self._by_id: Dict[str, PendingSlashInteraction] = {}
self._by_user: Dict[str, str] = {}
self._lock = Lock()
def _cleanup_locked(self) -> None:
expire_before = datetime.now() - self._ttl
expired = [
request_id
for request_id, request in self._by_id.items()
if request.created_at < expire_before
]
for request_id in expired:
request = self._by_id.pop(request_id, None)
if request:
self._by_user.pop(str(request.user_id), None)
def create_or_replace(
self,
user_id: Union[str, int],
command: str,
channel: Optional[MessageChannel],
source: Optional[str],
username: Optional[str],
) -> PendingSlashInteraction:
with self._lock:
self._cleanup_locked()
user_key = str(user_id)
old_request_id = self._by_user.get(user_key)
if old_request_id:
self._by_id.pop(old_request_id, None)
request = PendingSlashInteraction(
request_id=uuid.uuid4().hex[:12],
user_id=user_key,
command=command,
channel=channel,
source=source,
username=username,
)
self._by_id[request.request_id] = request
self._by_user[user_key] = request.request_id
return request
def get_by_user(
self, user_id: Union[str, int]
) -> Optional[PendingSlashInteraction]:
with self._lock:
self._cleanup_locked()
request_id = self._by_user.get(str(user_id))
if not request_id:
return None
return self._by_id.get(request_id)
def get_by_id(
self, request_id: str, user_id: Union[str, int]
) -> Optional[PendingSlashInteraction]:
with self._lock:
self._cleanup_locked()
request = self._by_id.get(request_id)
if not request or str(request.user_id) != str(user_id):
return None
return request
def remove(self, request_id: str) -> None:
with self._lock:
request = self._by_id.pop(request_id, None)
if request:
self._by_user.pop(str(request.user_id), None)
def clear(self) -> None:
with self._lock:
self._by_id.clear()
self._by_user.clear()
def supports_interaction_buttons(channel: Optional[MessageChannel]) -> bool:
"""
渠道同时支持按钮和回调时,优先使用按钮交互。
"""
return bool(
channel
and ChannelCapabilityManager.supports_buttons(channel)
and ChannelCapabilityManager.supports_callbacks(channel)
)
def supports_markdown(channel: Optional[MessageChannel]) -> bool:
"""
仅在支持 Markdown 的渠道上输出 Markdown 内容。
"""
return bool(channel and ChannelCapabilityManager.supports_markdown(channel))
def page_items(
items: Sequence,
page: int,
page_size: int,
) -> Tuple[List, int, int]:
"""
对列表做分页并规范化页码。
"""
total = len(items)
if total == 0:
return [], 0, 1
total_pages = max(1, math.ceil(total / max(1, page_size)))
page = min(max(0, page), total_pages - 1)
start = page * page_size
end = start + page_size
return list(items[start:end]), page, total_pages
def build_navigation_buttons(
prefix: str,
request: PendingSlashInteraction,
page: int,
total_pages: int,
) -> List[List[dict]]:
"""
构造标准上一页/下一页按钮。
"""
buttons = []
nav_row = []
if page > 0:
nav_row.append(
{
"text": "⬅️ 上一页",
"callback_data": f"{prefix}:{request.request_id}:page-prev",
}
)
if page < total_pages - 1:
nav_row.append(
{
"text": "下一页 ➡️",
"callback_data": f"{prefix}:{request.request_id}:page-next",
}
)
if nav_row:
buttons.append(nav_row)
return buttons
def update_or_post_message(
chain,
channel: MessageChannel,
source: Optional[str],
userid: Union[str, int],
username: Optional[str],
title: str,
text: str,
buttons: Optional[List[List[dict]]] = None,
original_message_id: Optional[Union[str, int]] = None,
original_chat_id: Optional[str] = None,
) -> None:
"""
优先编辑原消息,失败时回退为发送新消息。
"""
if (
original_message_id
and original_chat_id
and ChannelCapabilityManager.supports_editing(channel)
):
edited = chain.edit_message(
channel=channel,
source=source,
message_id=original_message_id,
chat_id=original_chat_id,
title=title,
text=text,
buttons=buttons,
)
if edited:
return
chain.post_message(
Notification(
channel=channel,
source=source,
userid=userid,
username=username,
title=title,
text=text,
buttons=buttons,
)
)
def escape_markdown_table_cell(value: object) -> str:
"""
最小化转义 Markdown 表格中的特殊字符。
"""
text = str(value or "").replace("\n", "<br>")
text = text.replace("|", "\\|")
return text
def format_markdown_table(headers: Sequence[str], rows: Sequence[Sequence[object]]) -> str:
"""
生成 Markdown 表格文本。
"""
header_line = "| " + " | ".join(escape_markdown_table_cell(item) for item in headers) + " |"
separator_line = "| " + " | ".join("---" for _ in headers) + " |"
data_lines = [
"| "
+ " | ".join(escape_markdown_table_cell(item) for item in row)
+ " |"
for row in rows
]
return "\n".join([header_line, separator_line, *data_lines])

View File

@@ -263,6 +263,8 @@ class ChannelCapability(Enum):
CALLBACK_QUERIES = "callback_queries"
# 支持富文本
RICH_TEXT = "rich_text"
# 支持 Markdown
MARKDOWN = "markdown"
# 支持图片
IMAGES = "images"
# 支持链接
@@ -301,6 +303,7 @@ class ChannelCapabilityManager:
ChannelCapability.MESSAGE_EDITING,
ChannelCapability.MESSAGE_DELETION,
ChannelCapability.CALLBACK_QUERIES,
ChannelCapability.MARKDOWN,
ChannelCapability.RICH_TEXT,
ChannelCapability.IMAGES,
ChannelCapability.LINKS,
@@ -315,6 +318,7 @@ class ChannelCapabilityManager:
MessageChannel.Wechat: ChannelCapabilities(
channel=MessageChannel.Wechat,
capabilities={
ChannelCapability.MARKDOWN,
ChannelCapability.IMAGES,
ChannelCapability.LINKS,
ChannelCapability.MENU_COMMANDS,
@@ -328,6 +332,7 @@ class ChannelCapabilityManager:
ChannelCapability.MESSAGE_EDITING,
ChannelCapability.MESSAGE_DELETION,
ChannelCapability.CALLBACK_QUERIES,
ChannelCapability.MARKDOWN,
ChannelCapability.RICH_TEXT,
ChannelCapability.IMAGES,
ChannelCapability.LINKS,
@@ -348,6 +353,7 @@ class ChannelCapabilityManager:
ChannelCapability.MESSAGE_EDITING,
ChannelCapability.MESSAGE_DELETION,
ChannelCapability.CALLBACK_QUERIES,
ChannelCapability.MARKDOWN,
ChannelCapability.RICH_TEXT,
ChannelCapability.IMAGES,
ChannelCapability.LINKS,
@@ -363,6 +369,7 @@ class ChannelCapabilityManager:
MessageChannel.SynologyChat: ChannelCapabilities(
channel=MessageChannel.SynologyChat,
capabilities={
ChannelCapability.MARKDOWN,
ChannelCapability.RICH_TEXT,
ChannelCapability.IMAGES,
ChannelCapability.LINKS,
@@ -372,6 +379,7 @@ class ChannelCapabilityManager:
MessageChannel.VoceChat: ChannelCapabilities(
channel=MessageChannel.VoceChat,
capabilities={
ChannelCapability.MARKDOWN,
ChannelCapability.RICH_TEXT,
ChannelCapability.IMAGES,
ChannelCapability.LINKS,
@@ -386,6 +394,7 @@ class ChannelCapabilityManager:
MessageChannel.Web: ChannelCapabilities(
channel=MessageChannel.Web,
capabilities={
ChannelCapability.MARKDOWN,
ChannelCapability.RICH_TEXT,
ChannelCapability.IMAGES,
ChannelCapability.LINKS,
@@ -395,6 +404,7 @@ class ChannelCapabilityManager:
MessageChannel.QQ: ChannelCapabilities(
channel=MessageChannel.QQ,
capabilities={
ChannelCapability.MARKDOWN,
ChannelCapability.RICH_TEXT,
ChannelCapability.IMAGES,
ChannelCapability.LINKS,
@@ -443,6 +453,13 @@ class ChannelCapabilityManager:
"""
return cls.supports_capability(channel, ChannelCapability.MESSAGE_EDITING)
@classmethod
def supports_markdown(cls, channel: MessageChannel) -> bool:
"""
检查渠道是否支持 Markdown。
"""
return cls.supports_capability(channel, ChannelCapability.MARKDOWN)
@classmethod
def supports_deletion(cls, channel: MessageChannel) -> bool:
"""

View File

@@ -21,7 +21,7 @@ Use this skill to identify user intent and dispatch the corresponding system or
- The user describes an action in natural language, for example:
- "Sync sites" → `/cookiecloud`
- "Show my subscriptions" → `/subscribes`
- "Refresh subscriptions" → `/subscribe_refresh`
- "Refresh subscriptions" → `/subscribes refresh`
- "What's downloading?" → `/downloading`
- "Organize downloaded files" → `/transfer`
- "Clear cache" → `/clear_cache`
@@ -58,7 +58,8 @@ If the user's description involves a specific plugin's functionality, additional
Some commands support additional arguments (space-separated after the command), for example:
- `/redo <history_id>` — Manually re-organize a specific record
- `/subscribe_delete <name>` — Delete a specific subscription
- `/sites disable <site_id>` — Disable one or more sites
- `/subscribes delete <subscribe_id>` — Delete one or more subscriptions
Use `run_slash_command` to execute the command in the format `/command_name arg1 arg2`.

View File

@@ -13,6 +13,18 @@ sys.modules.setdefault("transmission_rpc", ModuleType("transmission_rpc"))
setattr(sys.modules["transmission_rpc"], "File", object)
sys.modules.setdefault("psutil", ModuleType("psutil"))
sys.modules.setdefault("aioshutil", ModuleType("aioshutil"))
sys.modules.setdefault("pyquery", ModuleType("pyquery"))
setattr(sys.modules["pyquery"], "PyQuery", object)
sys.modules.setdefault("cn2an", ModuleType("cn2an"))
setattr(sys.modules["cn2an"], "cn2an", lambda value, mode=None: value)
setattr(sys.modules["cn2an"], "an2cn", lambda value, mode=None: str(value))
sys.modules.setdefault("dateparser", ModuleType("dateparser"))
setattr(sys.modules["dateparser"], "parse", lambda *args, **kwargs: None)
sys.modules.setdefault("dateutil", ModuleType("dateutil"))
dateutil_parser = ModuleType("dateutil.parser")
setattr(dateutil_parser, "parse", lambda *args, **kwargs: None)
sys.modules.setdefault("dateutil.parser", dateutil_parser)
setattr(sys.modules["dateutil"], "parser", dateutil_parser)
from app.chain.message import MessageChain
from app.chain.skills import SkillsChain, skills_interaction_manager

View File

@@ -0,0 +1,201 @@
import sys
import unittest
from types import ModuleType, SimpleNamespace
from unittest.mock import patch
sys.modules.setdefault("qbittorrentapi", ModuleType("qbittorrentapi"))
setattr(sys.modules["qbittorrentapi"], "TorrentFilesList", list)
sys.modules.setdefault("transmission_rpc", ModuleType("transmission_rpc"))
setattr(sys.modules["transmission_rpc"], "File", object)
sys.modules.setdefault("psutil", ModuleType("psutil"))
sys.modules.setdefault("aioshutil", ModuleType("aioshutil"))
sys.modules.setdefault("pyquery", ModuleType("pyquery"))
setattr(sys.modules["pyquery"], "PyQuery", object)
sys.modules.setdefault("cn2an", ModuleType("cn2an"))
setattr(sys.modules["cn2an"], "cn2an", lambda value, mode=None: value)
setattr(sys.modules["cn2an"], "an2cn", lambda value, mode=None: str(value))
sys.modules.setdefault("dateparser", ModuleType("dateparser"))
setattr(sys.modules["dateparser"], "parse", lambda *args, **kwargs: None)
sys.modules.setdefault("dateutil", ModuleType("dateutil"))
dateutil_parser = ModuleType("dateutil.parser")
setattr(dateutil_parser, "parse", lambda *args, **kwargs: None)
sys.modules.setdefault("dateutil.parser", dateutil_parser)
setattr(sys.modules["dateutil"], "parser", dateutil_parser)
from app.chain.message import MessageChain
from app.chain.site import SiteChain, site_interaction_manager
from app.chain.skills import skills_interaction_manager
from app.chain.subscribe import SubscribeChain, subscribe_interaction_manager
from app.schemas.types import MessageChannel
class TestSlashCommandInteractions(unittest.TestCase):
def tearDown(self):
skills_interaction_manager.clear()
site_interaction_manager.clear()
subscribe_interaction_manager.clear()
def test_message_routes_text_reply_to_latest_sites_interaction(self):
chain = MessageChain()
skills_interaction_manager.create_or_replace(
user_id="10001",
channel=MessageChannel.Wechat,
source="wechat-test",
username="tester",
)
site_interaction_manager.create_or_replace(
user_id="10001",
command="/sites",
channel=MessageChannel.Wechat,
source="wechat-test",
username="tester",
)
with patch.object(chain, "_record_user_message"), patch(
"app.chain.message.SiteChain.handle_text_interaction",
return_value=True,
) as handle_site, patch(
"app.chain.message.SkillsChain.handle_text_interaction"
) as handle_skills:
chain.handle_message(
channel=MessageChannel.Wechat,
source="wechat-test",
userid="10001",
username="tester",
text="禁用 1",
)
handle_site.assert_called_once()
handle_skills.assert_not_called()
def test_message_routes_text_reply_to_latest_subscribes_interaction(self):
chain = MessageChain()
site_interaction_manager.create_or_replace(
user_id="10001",
command="/sites",
channel=MessageChannel.Wechat,
source="wechat-test",
username="tester",
)
subscribe_interaction_manager.create_or_replace(
user_id="10001",
command="/subscribes",
channel=MessageChannel.Wechat,
source="wechat-test",
username="tester",
)
with patch.object(chain, "_record_user_message"), patch(
"app.chain.message.SubscribeChain.handle_text_interaction",
return_value=True,
) as handle_subscribes, patch(
"app.chain.message.SiteChain.handle_text_interaction"
) as handle_sites:
chain.handle_message(
channel=MessageChannel.Wechat,
source="wechat-test",
userid="10001",
username="tester",
text="搜索 all",
)
handle_subscribes.assert_called_once()
handle_sites.assert_not_called()
def test_callback_routes_to_sites_chain(self):
chain = MessageChain()
request = site_interaction_manager.create_or_replace(
user_id="10001",
command="/sites",
channel=MessageChannel.Telegram,
source="telegram-test",
username="tester",
)
with patch(
"app.chain.message.SiteChain.handle_callback_interaction",
return_value=True,
) as handle_callback:
chain._handle_callback(
text=f"CALLBACK:sites:{request.request_id}:refresh",
channel=MessageChannel.Telegram,
source="telegram-test",
userid="10001",
username="tester",
)
handle_callback.assert_called_once()
def test_callback_routes_to_subscribes_chain(self):
chain = MessageChain()
request = subscribe_interaction_manager.create_or_replace(
user_id="10001",
command="/subscribes",
channel=MessageChannel.Telegram,
source="telegram-test",
username="tester",
)
with patch(
"app.chain.message.SubscribeChain.handle_callback_interaction",
return_value=True,
) as handle_callback:
chain._handle_callback(
text=f"CALLBACK:subscribes:{request.request_id}:refresh",
channel=MessageChannel.Telegram,
source="telegram-test",
userid="10001",
username="tester",
)
handle_callback.assert_called_once()
def test_sites_renders_markdown_table_when_channel_supports_markdown(self):
chain = SiteChain()
fake_sites = [
SimpleNamespace(
id=1,
name="M-Team",
is_active=True,
cookie="cookie=value",
render=1,
domain="m-team.io",
url="https://m-team.io/",
)
]
with patch("app.chain.site.SiteOper.list", return_value=fake_sites), patch.object(
chain, "post_message"
) as post_message:
chain.remote_list(channel=MessageChannel.Web, userid="u1", source="web")
notification = post_message.call_args[0][0]
self.assertIn("| ID | 站点 | 状态 | Cookie | 渲染 | 域名 |", notification.text)
self.assertIn("| 1 | M-Team | 启用 | 已配置 | 是 | m-team.io |", notification.text)
def test_subscribes_renders_markdown_table_when_channel_supports_markdown(self):
chain = SubscribeChain()
fake_subscribes = [
SimpleNamespace(
id=12,
name="Example Show",
type="电视剧",
year="2024",
season=1,
total_episode=10,
lack_episode=3,
state="R",
)
]
with patch(
"app.chain.subscribe.SubscribeOper.list", return_value=fake_subscribes
), patch.object(chain, "post_message") as post_message:
chain.remote_list(channel=MessageChannel.Web, userid="u1", source="web")
notification = post_message.call_args[0][0]
self.assertIn("| ID | 名称 | 类型 | 年份 | 季/进度 | 状态 |", notification.text)
self.assertIn(
"| 12 | Example Show | 电视剧 | 2024 | 第1季 [7/10] | 订阅中 |",
notification.text,
)