From cefb60ba2c69007791fa30eefc593224934e947c Mon Sep 17 00:00:00 2001 From: jxxghp Date: Wed, 22 Apr 2026 15:18:04 +0800 Subject: [PATCH] refactor: unify message interactions --- app/agent/interaction.py | 107 -- app/agent/tools/impl/ask_user_choice.py | 2 +- app/chain/interaction.py | 1363 +++++++++++++++++++++++ app/chain/message.py | 1074 +++--------------- tests/test_agent_interaction.py | 2 +- tests/test_media_interaction.py | 158 +++ 6 files changed, 1657 insertions(+), 1049 deletions(-) delete mode 100644 app/agent/interaction.py create mode 100644 app/chain/interaction.py create mode 100644 tests/test_media_interaction.py diff --git a/app/agent/interaction.py b/app/agent/interaction.py deleted file mode 100644 index f831cc6e..00000000 --- a/app/agent/interaction.py +++ /dev/null @@ -1,107 +0,0 @@ -"""Agent 客户端交互请求管理。""" - -from dataclasses import dataclass, field -from datetime import datetime, timedelta -from threading import Lock -from typing import Dict, List, Optional -import uuid - - -@dataclass(frozen=True) -class AgentInteractionOption: - """交互选项。""" - - label: str - value: str - - -@dataclass -class PendingAgentInteraction: - """待处理的 Agent 客户端交互请求。""" - - request_id: str - session_id: str - user_id: str - channel: Optional[str] - source: Optional[str] - username: Optional[str] - title: Optional[str] - prompt: str - options: List[AgentInteractionOption] - created_at: datetime = field(default_factory=datetime.now) - - -class AgentInteractionManager: - """管理 Agent 发起的客户端交互请求。""" - - _ttl = timedelta(hours=24) - - def __init__(self): - self._pending_interactions: Dict[str, PendingAgentInteraction] = {} - self._lock = Lock() - - def _cleanup_locked(self): - expire_before = datetime.now() - self._ttl - expired_ids = [ - request_id - for request_id, request in self._pending_interactions.items() - if request.created_at < expire_before - ] - for request_id in expired_ids: - self._pending_interactions.pop(request_id, None) - - def create_request( - self, - session_id: str, - user_id: str, - channel: Optional[str], - source: Optional[str], - username: Optional[str], - title: Optional[str], - prompt: str, - options: List[AgentInteractionOption], - ) -> PendingAgentInteraction: - with self._lock: - self._cleanup_locked() - request_id = uuid.uuid4().hex[:12] - while request_id in self._pending_interactions: - request_id = uuid.uuid4().hex[:12] - request = PendingAgentInteraction( - request_id=request_id, - session_id=session_id, - user_id=str(user_id), - channel=channel, - source=source, - username=username, - title=title, - prompt=prompt, - options=options, - ) - self._pending_interactions[request_id] = request - return request - - def resolve( - self, - request_id: str, - option_index: int, - user_id: Optional[str] = None, - ) -> Optional[tuple[PendingAgentInteraction, AgentInteractionOption]]: - with self._lock: - self._cleanup_locked() - request = self._pending_interactions.get(request_id) - if not request: - return None - if user_id is not None and str(request.user_id) != str(user_id): - return None - if option_index < 1 or option_index > len(request.options): - return None - option = request.options[option_index - 1] - self._pending_interactions.pop(request_id, None) - return request, option - - def clear(self): - with self._lock: - self._pending_interactions.clear() - - -agent_interaction_manager = AgentInteractionManager() diff --git a/app/agent/tools/impl/ask_user_choice.py b/app/agent/tools/impl/ask_user_choice.py index f44e8bac..da1c02b9 100644 --- a/app/agent/tools/impl/ask_user_choice.py +++ b/app/agent/tools/impl/ask_user_choice.py @@ -5,7 +5,7 @@ from typing import List, Optional, Type from pydantic import BaseModel, Field, model_validator from app.agent.tools.base import MoviePilotTool, ToolChain -from app.agent.interaction import ( +from app.chain.interaction import ( AgentInteractionOption, agent_interaction_manager, ) diff --git a/app/chain/interaction.py b/app/chain/interaction.py new file mode 100644 index 00000000..c839e611 --- /dev/null +++ b/app/chain/interaction.py @@ -0,0 +1,1363 @@ +import math +import re +import uuid +from dataclasses import dataclass, field +from datetime import datetime, timedelta +from threading import Lock +from typing import Any, Dict, List, Optional, Tuple, Union + +from app.chain import ChainBase +from app.chain.download import DownloadChain +from app.chain.media import MediaChain +from app.chain.search import SearchChain +from app.chain.subscribe import SubscribeChain +from app.core.config import settings +from app.core.context import Context, MediaInfo +from app.core.meta import MetaBase +from app.db.user_oper import UserOper +from app.helper.torrent import TorrentHelper +from app.log import logger +from app.schemas import Notification, NotExistMediaInfo +from app.schemas.message import ChannelCapabilityManager +from app.schemas.types import MediaType, MessageChannel +from app.utils.string import StringUtils + + +@dataclass(frozen=True) +class AgentInteractionOption: + """ + Agent 交互选项。 + """ + + label: str + value: str + + +@dataclass +class PendingAgentInteraction: + """ + 待处理的 Agent 客户端交互请求。 + """ + + request_id: str + session_id: str + user_id: str + channel: Optional[str] + source: Optional[str] + username: Optional[str] + title: Optional[str] + prompt: str + options: List[AgentInteractionOption] + created_at: datetime = field(default_factory=datetime.now) + + +class AgentInteractionManager: + """ + 管理 Agent 发起的客户端交互请求。 + """ + + _ttl = timedelta(hours=24) + + def __init__(self): + self._pending_interactions: Dict[str, PendingAgentInteraction] = {} + self._lock = Lock() + + def _cleanup_locked(self) -> None: + expire_before = datetime.now() - self._ttl + expired_ids = [ + request_id + for request_id, request in self._pending_interactions.items() + if request.created_at < expire_before + ] + for request_id in expired_ids: + self._pending_interactions.pop(request_id, None) + + def create_request( + self, + session_id: str, + user_id: str, + channel: Optional[str], + source: Optional[str], + username: Optional[str], + title: Optional[str], + prompt: str, + options: List[AgentInteractionOption], + ) -> PendingAgentInteraction: + """ + 创建一条待用户确认的 Agent 交互请求。 + """ + with self._lock: + self._cleanup_locked() + request_id = uuid.uuid4().hex[:12] + while request_id in self._pending_interactions: + request_id = uuid.uuid4().hex[:12] + request = PendingAgentInteraction( + request_id=request_id, + session_id=session_id, + user_id=str(user_id), + channel=channel, + source=source, + username=username, + title=title, + prompt=prompt, + options=options, + ) + self._pending_interactions[request_id] = request + return request + + def resolve( + self, + request_id: str, + option_index: int, + user_id: Optional[str] = None, + ) -> Optional[tuple[PendingAgentInteraction, AgentInteractionOption]]: + """ + 消费一条 Agent 交互请求,并返回选中的选项。 + """ + with self._lock: + self._cleanup_locked() + request = self._pending_interactions.get(request_id) + if not request: + return None + if user_id is not None and str(request.user_id) != str(user_id): + return None + if option_index < 1 or option_index > len(request.options): + return None + option = request.options[option_index - 1] + self._pending_interactions.pop(request_id, None) + return request, option + + def clear(self) -> None: + """ + 清空所有 Agent 交互请求。 + """ + with self._lock: + self._pending_interactions.clear() + + +agent_interaction_manager = AgentInteractionManager() + + +@dataclass +class PendingMediaInteraction: + """ + 记录一次搜索/下载/订阅交互的当前上下文。 + """ + + request_id: str + user_id: str + channel: Optional[MessageChannel] + source: Optional[str] + username: Optional[str] + action: str + keyword: str + phase: str = "media" + page: int = 0 + title: str = "" + meta: Optional[MetaBase] = None + current_media: Optional[MediaInfo] = None + items: List[Any] = field(default_factory=list) + created_at: datetime = field(default_factory=datetime.now) + + +class MediaInteractionManager: + """ + 管理用户当前激活的媒体交互状态。 + + 每个用户只保留一个有效会话,避免旧按钮与新一轮搜索混用。 + """ + + _ttl = timedelta(hours=24) + + def __init__(self): + self._by_id: Dict[str, PendingMediaInteraction] = {} + self._by_user: Dict[str, str] = {} + self._lock = Lock() + + def _cleanup_locked(self) -> None: + """ + 清理超时会话,避免内存中残留旧交互状态。 + """ + expire_before = datetime.now() - self._ttl + expired = [ + request_id + for request_id, request in self._by_id.items() + if request.created_at < expire_before + ] + for request_id in expired: + request = self._by_id.pop(request_id, None) + if request: + self._by_user.pop(str(request.user_id), None) + + def create_or_replace( + self, + user_id: Union[str, int], + channel: Optional[MessageChannel], + source: Optional[str], + username: Optional[str], + action: str, + keyword: str, + title: str = "", + meta: Optional[MetaBase] = None, + items: Optional[List[Any]] = None, + ) -> PendingMediaInteraction: + """ + 为用户创建新的交互状态,并替换旧会话。 + """ + with self._lock: + self._cleanup_locked() + user_key = str(user_id) + old_request_id = self._by_user.get(user_key) + if old_request_id: + self._by_id.pop(old_request_id, None) + + request = PendingMediaInteraction( + request_id=uuid.uuid4().hex[:12], + user_id=user_key, + channel=channel, + source=source, + username=username, + action=action, + keyword=keyword, + title=title, + meta=meta, + items=list(items or []), + ) + self._by_id[request.request_id] = request + self._by_user[user_key] = request.request_id + return request + + def get_by_user( + self, user_id: Union[str, int] + ) -> Optional[PendingMediaInteraction]: + """ + 按用户读取当前会话,供文本回复和旧按钮兼容使用。 + """ + with self._lock: + self._cleanup_locked() + request_id = self._by_user.get(str(user_id)) + if not request_id: + return None + return self._by_id.get(request_id) + + def get_by_id( + self, request_id: str, user_id: Union[str, int] + ) -> Optional[PendingMediaInteraction]: + """ + 按请求 ID 读取会话,并校验用户归属。 + """ + with self._lock: + self._cleanup_locked() + request = self._by_id.get(request_id) + if not request or str(request.user_id) != str(user_id): + return None + return request + + def remove(self, request_id: str) -> None: + """ + 主动结束一条会话。 + """ + with self._lock: + request = self._by_id.pop(request_id, None) + if request: + self._by_user.pop(str(request.user_id), None) + + def clear(self) -> None: + """ + 清空所有交互状态,主要用于测试。 + """ + with self._lock: + self._by_id.clear() + self._by_user.clear() + + +media_interaction_manager = MediaInteractionManager() + + +class MediaInteractionChain(ChainBase): + """ + 处理媒体搜索、订阅、资源选择和翻页等交互流程。 + """ + + _button_page_size = 8 + _text_page_size = 8 + + @staticmethod + def has_pending_interaction(user_id: Union[str, int]) -> bool: + """ + 判断用户当前是否存在未结束的媒体交互。 + """ + return media_interaction_manager.get_by_user(user_id) is not None + + @staticmethod + def _get_noexits_info( + meta: MetaBase, mediainfo: MediaInfo + ) -> Dict[Union[int, str], Dict[int, NotExistMediaInfo]]: + """ + 构造媒体缺失集信息,用于全量重搜或自动下载补全集数。 + """ + if mediainfo.type == MediaType.TV: + if not mediainfo.seasons: + mediainfo = MediaChain().recognize_media( + mtype=mediainfo.type, + tmdbid=mediainfo.tmdb_id, + doubanid=mediainfo.douban_id, + cache=False, + ) + if not mediainfo: + logger.warn("媒体信息识别失败,无法补充季集信息") + return {} + if not mediainfo.seasons: + logger.warn( + "媒体信息中没有季集信息,标题:%s,tmdbid:%s,doubanid:%s", + mediainfo.title, + mediainfo.tmdb_id, + mediainfo.douban_id, + ) + return {} + + mediakey = mediainfo.tmdb_id or mediainfo.douban_id + no_exists = {mediakey: {}} + if meta.begin_season: + episodes = mediainfo.seasons.get(meta.begin_season) + if not episodes: + return {} + no_exists[mediakey][meta.begin_season] = NotExistMediaInfo( + season=meta.begin_season, + episodes=[], + total_episode=len(episodes), + start_episode=episodes[0], + ) + else: + for sea, eps in mediainfo.seasons.items(): + if not eps: + continue + no_exists[mediakey][sea] = NotExistMediaInfo( + season=sea, + episodes=[], + total_episode=len(eps), + start_episode=eps[0], + ) + return no_exists + return {} + + @staticmethod + def parse_callback( + callback_data: str, + ) -> Optional[Tuple[Optional[str], str, Optional[int]]]: + """ + 解析新旧两种媒体交互按钮格式。 + """ + if callback_data.startswith("media:"): + parts = callback_data.split(":") + if len(parts) < 3: + return None + request_id = parts[1] + action = parts[2] + index = None + if len(parts) >= 4 and parts[3].isdigit(): + index = int(parts[3]) + return request_id, action, index + + match = re.match(r"^(select|download)_(\d+)$", callback_data) + if match: + return None, match.group(1), int(match.group(2)) + if callback_data == "page_p": + return None, "page-prev", None + if callback_data == "page_n": + return None, "page-next", None + return None + + def handle_callback_interaction( + self, + callback_data: str, + channel: MessageChannel, + source: str, + userid: Union[str, int], + username: str, + original_message_id: Optional[Union[str, int]] = None, + original_chat_id: Optional[str] = None, + ) -> bool: + """ + 处理按钮回调,并将当前视图刷新到原消息上。 + """ + parsed = self.parse_callback(callback_data) + if not parsed: + return False + + request_id, action, index = parsed + if request_id: + request = media_interaction_manager.get_by_id(request_id, userid) + else: + request = media_interaction_manager.get_by_user(userid) + + if not request: + self.post_message( + Notification( + channel=channel, + source=source, + userid=userid, + username=username, + title="交互已失效,请重新搜索或订阅", + ) + ) + return True + + request.channel = channel + request.source = source + request.username = username + + if action == "page-prev": + if request.page <= 0: + self._post_invalid_input( + channel=channel, + source=source, + userid=userid, + username=username, + title="已经是第一页了!", + ) + return True + request.page -= 1 + self._render_interaction( + request=request, + channel=channel, + source=source, + userid=userid, + original_message_id=original_message_id, + original_chat_id=original_chat_id, + ) + return True + + if action == "page-next": + if not self._has_next_page(request): + self._post_invalid_input( + channel=channel, + source=source, + userid=userid, + username=username, + title="已经是最后一页了!", + ) + return True + request.page += 1 + self._render_interaction( + request=request, + channel=channel, + source=source, + userid=userid, + original_message_id=original_message_id, + original_chat_id=original_chat_id, + ) + return True + + if action == "select": + self._handle_media_selection( + request=request, + page_index=index, + channel=channel, + source=source, + userid=userid, + username=username, + original_message_id=original_message_id, + original_chat_id=original_chat_id, + ) + return True + + if action == "download": + self._handle_torrent_selection( + request=request, + page_index=index, + channel=channel, + source=source, + userid=userid, + username=username, + ) + return True + + return False + + def handle_text_interaction( + self, + channel: MessageChannel, + source: str, + userid: Union[str, int], + username: str, + text: str, + ) -> bool: + """ + 处理文本式交互。 + + 有会话时优先处理数字选择和翻页;无会话时负责识别搜索/订阅类入口。 + """ + request = media_interaction_manager.get_by_user(userid) + normalized = (text or "").strip() + lowered = normalized.lower() + + if request and lowered in {"退出", "关闭", "q", "quit", "exit"}: + media_interaction_manager.remove(request.request_id) + self.post_message( + Notification( + channel=channel, + source=source, + userid=userid, + username=username, + title="媒体交互已结束", + ) + ) + return True + + if normalized.isdigit(): + if not request: + self._post_invalid_input( + channel=channel, + source=source, + userid=userid, + username=username, + ) + return True + request.channel = channel + request.source = source + request.username = username + index = int(normalized) + if request.phase == "torrent": + self._handle_torrent_selection( + request=request, + page_index=index, + channel=channel, + source=source, + userid=userid, + username=username, + ) + else: + self._handle_media_selection( + request=request, + page_index=index, + channel=channel, + source=source, + userid=userid, + username=username, + ) + return True + + if lowered in {"p", "prev", "上一页"}: + if not request: + self._post_invalid_input( + channel=channel, + source=source, + userid=userid, + username=username, + ) + return True + if request.page <= 0: + self._post_invalid_input( + channel=channel, + source=source, + userid=userid, + username=username, + title="已经是第一页了!", + ) + return True + request.page -= 1 + request.channel = channel + request.source = source + request.username = username + self._render_interaction( + request=request, + channel=channel, + source=source, + userid=userid, + ) + return True + + if lowered in {"n", "next", "下一页"}: + if not request: + self._post_invalid_input( + channel=channel, + source=source, + userid=userid, + username=username, + ) + return True + if not self._has_next_page(request): + self._post_invalid_input( + channel=channel, + source=source, + userid=userid, + username=username, + title="已经是最后一页了!", + ) + return True + request.page += 1 + request.channel = channel + request.source = source + request.username = username + self._render_interaction( + request=request, + channel=channel, + source=source, + userid=userid, + ) + return True + + action, content = self._resolve_action(normalized) + if not action: + return False + + self._start_media_interaction( + action=action, + content=content, + channel=channel, + source=source, + userid=userid, + username=username, + ) + return True + + @staticmethod + def _resolve_action(text: str) -> Tuple[Optional[str], str]: + """ + 将用户输入归类为搜索、订阅或普通聊天。 + """ + if text.startswith("订阅"): + return "Subscribe", re.sub(r"订阅[::\s]*", "", text) + if text.startswith("洗版"): + return "ReSubscribe", re.sub(r"洗版[::\s]*", "", text) + if text.startswith("搜索") or text.startswith("下载"): + return "ReSearch", re.sub(r"(搜索|下载)[::\s]*", "", text) + if StringUtils.is_link(text): + return None, text + if not StringUtils.is_media_title_like(text): + return None, text + return "Search", text + + def _start_media_interaction( + self, + action: str, + content: str, + channel: MessageChannel, + source: str, + userid: Union[str, int], + username: str, + ) -> None: + """ + 根据用户输入搜索媒体,并进入媒体选择阶段。 + """ + meta, medias = MediaChain().search(content) + if not meta.name: + self._post_invalid_input( + channel=channel, + source=source, + userid=userid, + username=username, + title="无法识别输入内容!", + ) + return + if not medias: + self.post_message( + Notification( + channel=channel, + source=source, + userid=userid, + username=username, + title=f"{meta.name} 没有找到对应的媒体信息!", + ) + ) + return + + logger.info("搜索到 %s 条相关媒体信息", len(medias)) + request = media_interaction_manager.create_or_replace( + user_id=userid, + channel=channel, + source=source, + username=username, + action=action, + keyword=content, + title=meta.name, + meta=meta, + items=medias, + ) + self._render_interaction( + request=request, + channel=channel, + source=source, + userid=userid, + ) + + def _handle_media_selection( + self, + request: PendingMediaInteraction, + page_index: Optional[int], + channel: MessageChannel, + source: str, + userid: Union[str, int], + username: str, + original_message_id: Optional[Union[str, int]] = None, + original_chat_id: Optional[str] = None, + ) -> None: + """ + 处理媒体选择阶段的序号输入。 + """ + page_items, page, _ = self._page_items( + items=request.items, + page=request.page, + page_size=self._page_size(request.channel), + ) + request.page = page + if not page_index or page_index < 1 or page_index > len(page_items): + self._post_invalid_input( + channel=channel, + source=source, + userid=userid, + username=username, + ) + return + + mediainfo: MediaInfo = page_items[page_index - 1] + request.current_media = mediainfo + + if request.action in {"Search", "ReSearch"}: + self._search_media_resources( + request=request, + mediainfo=mediainfo, + channel=channel, + source=source, + userid=userid, + username=username, + original_message_id=original_message_id, + original_chat_id=original_chat_id, + ) + return + + if request.action in {"Subscribe", "ReSubscribe"}: + self._subscribe_media( + request=request, + mediainfo=mediainfo, + channel=channel, + source=source, + userid=userid, + username=username, + ) + + def _search_media_resources( + self, + request: PendingMediaInteraction, + mediainfo: MediaInfo, + channel: MessageChannel, + source: str, + userid: Union[str, int], + username: str, + original_message_id: Optional[Union[str, int]] = None, + original_chat_id: Optional[str] = None, + ) -> None: + """ + 根据已选媒体搜索资源,并切换到资源选择阶段。 + """ + exist_flag, no_exists = DownloadChain().get_no_exists_info( + meta=request.meta, + mediainfo=mediainfo, + ) + if exist_flag and request.action == "Search": + self.post_message( + Notification( + channel=channel, + source=source, + userid=userid, + username=username, + title=f"【{mediainfo.title_year}{request.meta.sea} 媒体库中已存在,如需重新下载请发送:搜索 名称 或 下载 名称】", + ) + ) + return + if exist_flag: + no_exists = self._get_noexits_info(request.meta, mediainfo) + + messages = self._build_no_exists_messages( + mediainfo=mediainfo, + no_exists=no_exists, + show_missing_only=request.action == "Search", + ) + if messages: + self.post_message( + Notification( + channel=channel, + source=source, + userid=userid, + username=username, + title=f"{mediainfo.title_year}:\n" + "\n".join(messages), + ) + ) + + logger.info("开始搜索 %s ...", mediainfo.title_year) + self.post_message( + Notification( + channel=channel, + source=source, + userid=userid, + username=username, + title=f"开始搜索 {mediainfo.type.value} {mediainfo.title_year} ...", + ) + ) + + contexts = SearchChain().process(mediainfo=mediainfo, no_exists=no_exists) + if not contexts: + self.post_message( + Notification( + channel=channel, + source=source, + userid=userid, + username=username, + title=f"{mediainfo.title}{request.meta.sea} 未搜索到需要的资源!", + ) + ) + return + + contexts = TorrentHelper().sort_torrents(contexts) + if self._should_auto_download(userid): + logger.info("用户 %s 在自动下载用户中,开始自动择优下载 ...", userid) + self._auto_download( + request=request, + cache_list=contexts, + channel=channel, + source=source, + userid=userid, + username=username, + no_exists=no_exists, + ) + return + + request.phase = "torrent" + request.page = 0 + request.title = mediainfo.title + request.items = list(contexts) + self._render_interaction( + request=request, + channel=channel, + source=source, + userid=userid, + original_message_id=original_message_id, + original_chat_id=original_chat_id, + ) + + def _subscribe_media( + self, + request: PendingMediaInteraction, + mediainfo: MediaInfo, + channel: MessageChannel, + source: str, + userid: Union[str, int], + username: str, + ) -> None: + """ + 根据已选媒体创建订阅或洗版订阅。 + """ + best_version = request.action == "ReSubscribe" + if not best_version: + exist_flag, _ = DownloadChain().get_no_exists_info( + meta=request.meta, + mediainfo=mediainfo, + ) + if exist_flag: + self.post_message( + Notification( + channel=channel, + source=source, + userid=userid, + username=username, + title=f"【{mediainfo.title_year}{request.meta.sea} 媒体库中已存在,如需洗版请发送:洗版 XXX】", + ) + ) + return + + mp_name = ( + UserOper().get_name(**{f"{channel.name.lower()}_userid": userid}) + if channel + else None + ) + SubscribeChain().add( + title=mediainfo.title, + year=mediainfo.year, + mtype=mediainfo.type, + tmdbid=mediainfo.tmdb_id, + season=request.meta.begin_season, + channel=channel, + source=source, + userid=userid, + username=mp_name or username, + best_version=best_version, + ) + + def _handle_torrent_selection( + self, + request: PendingMediaInteraction, + page_index: Optional[int], + channel: MessageChannel, + source: str, + userid: Union[str, int], + username: str, + ) -> None: + """ + 处理资源选择阶段的下载操作。 + """ + if request.phase != "torrent": + self._post_invalid_input( + channel=channel, + source=source, + userid=userid, + username=username, + ) + return + + if page_index == 0: + self._auto_download( + request=request, + cache_list=request.items, + channel=channel, + source=source, + userid=userid, + username=username, + ) + return + + page_items, page, _ = self._page_items( + items=request.items, + page=request.page, + page_size=self._page_size(request.channel), + ) + request.page = page + if not page_index or page_index < 1 or page_index > len(page_items): + self._post_invalid_input( + channel=channel, + source=source, + userid=userid, + username=username, + ) + return + + context: Context = page_items[page_index - 1] + DownloadChain().download_single( + context, + channel=channel, + source=source, + userid=userid, + username=username, + ) + + def _auto_download( + self, + request: PendingMediaInteraction, + cache_list: List[Context], + channel: MessageChannel, + source: str, + userid: Union[str, int], + username: str, + no_exists: Optional[Dict[Union[int, str], Dict[int, NotExistMediaInfo]]] = None, + ) -> None: + """ + 自动择优下载当前资源列表,并在未完成时补建订阅。 + """ + downloadchain = DownloadChain() + if no_exists is None: + exist_flag, no_exists = downloadchain.get_no_exists_info( + meta=request.meta, + mediainfo=request.current_media, + ) + if exist_flag: + no_exists = self._get_noexits_info(request.meta, request.current_media) + + downloads, lefts = downloadchain.batch_download( + contexts=cache_list, + no_exists=no_exists, + channel=channel, + source=source, + userid=userid, + username=username, + ) + if downloads and not lefts: + logger.info("%s 下载完成", request.current_media.title_year) + return + + logger.info("%s 未下载未完整,添加订阅 ...", request.current_media.title_year) + if downloads and request.current_media.type == MediaType.TV: + note = [ + download.meta_info.begin_episode + for download in downloads + if download.meta_info.begin_episode + ] + else: + note = None + + mp_name = ( + UserOper().get_name(**{f"{channel.name.lower()}_userid": userid}) + if channel + else None + ) + SubscribeChain().add( + title=request.current_media.title, + year=request.current_media.year, + mtype=request.current_media.type, + tmdbid=request.current_media.tmdb_id, + season=request.meta.begin_season, + channel=channel, + source=source, + userid=userid, + username=mp_name or username, + state="R", + note=note, + ) + + def _render_interaction( + self, + request: PendingMediaInteraction, + channel: MessageChannel, + source: str, + userid: Union[str, int], + original_message_id: Optional[Union[str, int]] = None, + original_chat_id: Optional[str] = None, + ) -> None: + """ + 按当前阶段渲染媒体列表或资源列表。 + """ + if request.phase == "torrent": + self._post_torrents_message( + request=request, + channel=channel, + source=source, + userid=userid, + original_message_id=original_message_id, + original_chat_id=original_chat_id, + ) + else: + self._post_medias_message( + request=request, + channel=channel, + source=source, + userid=userid, + original_message_id=original_message_id, + original_chat_id=original_chat_id, + ) + + def _post_medias_message( + self, + request: PendingMediaInteraction, + channel: MessageChannel, + source: str, + userid: Union[str, int], + original_message_id: Optional[Union[str, int]] = None, + original_chat_id: Optional[str] = None, + ) -> None: + """ + 发送或更新媒体选择列表。 + """ + page_items, page, total_pages = self._page_items( + items=request.items, + page=request.page, + page_size=self._page_size(channel), + ) + request.page = page + total = len(request.items) + if self._supports_interactive_buttons(channel): + title = f"【{request.title}】共找到{total}条相关信息,请选择操作" + buttons = self._create_media_buttons( + channel=channel, + request=request, + items=page_items, + total=total, + total_pages=total_pages, + ) + else: + if total > self._page_size(channel): + title = f"【{request.title}】共找到{total}条相关信息,请回复对应数字选择(p: 上一页 n: 下一页)" + else: + title = f"【{request.title}】共找到{total}条相关信息,请回复对应数字选择" + buttons = None + + self.post_medias_message( + Notification( + channel=channel, + source=source, + title=title, + userid=userid, + buttons=buttons, + original_message_id=original_message_id, + original_chat_id=original_chat_id, + ), + medias=page_items, + ) + + def _post_torrents_message( + self, + request: PendingMediaInteraction, + channel: MessageChannel, + source: str, + userid: Union[str, int], + original_message_id: Optional[Union[str, int]] = None, + original_chat_id: Optional[str] = None, + ) -> None: + """ + 发送或更新资源选择列表。 + """ + page_items, page, total_pages = self._page_items( + items=request.items, + page=request.page, + page_size=self._page_size(channel), + ) + request.page = page + total = len(request.items) + if self._supports_interactive_buttons(channel): + title = f"【{request.title}】共找到{total}条相关资源,请选择下载" + buttons = self._create_torrent_buttons( + channel=channel, + request=request, + items=page_items, + total=total, + total_pages=total_pages, + ) + else: + if total > self._page_size(channel): + title = f"【{request.title}】共找到{total}条相关资源,请回复对应数字下载(0: 自动选择 p: 上一页 n: 下一页)" + else: + title = f"【{request.title}】共找到{total}条相关资源,请回复对应数字下载(0: 自动选择)" + buttons = None + + self.post_torrents_message( + Notification( + channel=channel, + source=source, + title=title, + userid=userid, + link=settings.MP_DOMAIN("#/resource"), + buttons=buttons, + original_message_id=original_message_id, + original_chat_id=original_chat_id, + ), + torrents=page_items, + ) + + def _create_media_buttons( + self, + channel: MessageChannel, + request: PendingMediaInteraction, + items: List[MediaInfo], + total: int, + total_pages: int, + ) -> List[List[Dict[str, str]]]: + """ + 为媒体列表生成选择和翻页按钮。 + """ + buttons: List[List[Dict[str, str]]] = [] + max_text_length = ChannelCapabilityManager.get_max_button_text_length(channel) + max_per_row = ChannelCapabilityManager.get_max_buttons_per_row(channel) + + current_row: List[Dict[str, str]] = [] + for index, media in enumerate(items, start=1): + if max_per_row == 1: + button_text = f"{index}. {media.title_year}" + if len(button_text) > max_text_length: + button_text = button_text[: max_text_length - 3] + "..." + buttons.append( + [ + { + "text": button_text, + "callback_data": f"media:{request.request_id}:select:{index}", + } + ] + ) + continue + + current_row.append( + { + "text": f"{index}", + "callback_data": f"media:{request.request_id}:select:{index}", + } + ) + if len(current_row) == max_per_row or index == len(items): + buttons.append(current_row) + current_row = [] + + if total > self._page_size(channel): + buttons.extend(self._navigation_buttons(request, total_pages)) + return buttons + + def _create_torrent_buttons( + self, + channel: MessageChannel, + request: PendingMediaInteraction, + items: List[Context], + total: int, + total_pages: int, + ) -> List[List[Dict[str, str]]]: + """ + 为资源列表生成下载和翻页按钮。 + """ + buttons: List[List[Dict[str, str]]] = [ + [ + { + "text": "🤖 自动选择下载", + "callback_data": f"media:{request.request_id}:download:0", + } + ] + ] + max_text_length = ChannelCapabilityManager.get_max_button_text_length(channel) + max_per_row = ChannelCapabilityManager.get_max_buttons_per_row(channel) + + current_row: List[Dict[str, str]] = [] + for index, context in enumerate(items, start=1): + torrent = context.torrent_info + if max_per_row == 1: + button_text = f"{index}. {torrent.site_name} - {torrent.seeders}↑" + if len(button_text) > max_text_length: + button_text = button_text[: max_text_length - 3] + "..." + buttons.append( + [ + { + "text": button_text, + "callback_data": f"media:{request.request_id}:download:{index}", + } + ] + ) + continue + + current_row.append( + { + "text": f"{index}", + "callback_data": f"media:{request.request_id}:download:{index}", + } + ) + if len(current_row) == max_per_row or index == len(items): + buttons.append(current_row) + current_row = [] + + if total > self._page_size(channel): + buttons.extend(self._navigation_buttons(request, total_pages)) + return buttons + + def _has_next_page(self, request: PendingMediaInteraction) -> bool: + """ + 判断当前视图是否还有下一页。 + """ + _, page, total_pages = self._page_items( + items=request.items, + page=request.page, + page_size=self._page_size(request.channel), + ) + return page < total_pages - 1 + + @staticmethod + def _navigation_buttons( + request: PendingMediaInteraction, + total_pages: int, + ) -> List[List[Dict[str, str]]]: + """ + 按当前页状态生成上一页和下一页按钮。 + """ + buttons: List[List[Dict[str, str]]] = [] + nav_row: List[Dict[str, str]] = [] + if request.page > 0: + nav_row.append( + { + "text": "⬅️ 上一页", + "callback_data": f"media:{request.request_id}:page-prev", + } + ) + if request.page < total_pages - 1: + nav_row.append( + { + "text": "下一页 ➡️", + "callback_data": f"media:{request.request_id}:page-next", + } + ) + if nav_row: + buttons.append(nav_row) + return buttons + + @staticmethod + def _page_items( + items: List[Any], + page: int, + page_size: int, + ) -> Tuple[List[Any], int, int]: + """ + 返回当前页数据,并把页码限制在有效范围内。 + """ + total_pages = max(1, math.ceil(len(items) / page_size)) if page_size else 1 + page = min(max(0, page), total_pages - 1) + start = page * page_size + end = start + page_size + return items[start:end], page, total_pages + + def _page_size(self, channel: Optional[MessageChannel]) -> int: + """ + 按渠道交互能力选择分页大小。 + """ + return ( + self._button_page_size + if self._supports_interactive_buttons(channel) + else self._text_page_size + ) + + @staticmethod + def _supports_interactive_buttons(channel: Optional[MessageChannel]) -> bool: + """ + 判断渠道是否同时支持按钮展示与按钮回调。 + """ + return bool( + channel + and ChannelCapabilityManager.supports_buttons(channel) + and ChannelCapabilityManager.supports_callbacks(channel) + ) + + @staticmethod + def _build_no_exists_messages( + mediainfo: MediaInfo, + no_exists: Optional[Dict[Union[int, str], Dict[int, NotExistMediaInfo]]], + show_missing_only: bool, + ) -> List[str]: + """ + 将缺失集信息转换为可发送的文案。 + """ + if not no_exists: + return [] + mediakey = mediainfo.tmdb_id or mediainfo.douban_id + season_map = no_exists.get(mediakey) or {} + if show_missing_only: + return [ + f"第 {sea} 季缺失 {StringUtils.str_series(no_exist.episodes) if no_exist.episodes else no_exist.total_episode} 集" + for sea, no_exist in season_map.items() + ] + return [ + f"第 {sea} 季总 {no_exist.total_episode} 集" + for sea, no_exist in season_map.items() + ] + + @staticmethod + def _should_auto_download(userid: Union[str, int]) -> bool: + """ + 判断当前用户是否命中自动下载名单。 + """ + auto_download_user = settings.AUTO_DOWNLOAD_USER + return bool( + auto_download_user + and ( + auto_download_user == "all" + or any(userid == user for user in auto_download_user.split(",")) + ) + ) + + def _post_invalid_input( + self, + channel: MessageChannel, + source: str, + userid: Union[str, int], + username: Optional[str], + title: str = "输入有误!", + ) -> None: + """ + 发送统一的非法输入提示。 + """ + self.post_message( + Notification( + channel=channel, + source=source, + userid=userid, + username=username, + title=title, + ) + ) diff --git a/app/chain/message.py b/app/chain/message.py index d8d7ba55..db749994 100644 --- a/app/chain/message.py +++ b/app/chain/message.py @@ -11,108 +11,34 @@ import uuid import base64 from app.agent import agent_manager -from app.agent.interaction import agent_interaction_manager from app.chain import ChainBase -from app.chain.download import DownloadChain -from app.chain.media import MediaChain -from app.chain.search import SearchChain +from app.chain.interaction import ( + MediaInteractionChain, + agent_interaction_manager, + media_interaction_manager, +) from app.chain.skills import SkillsChain, skills_interaction_manager -from app.chain.subscribe import SubscribeChain from app.chain.transfer import TransferChain from app.core.config import settings, global_vars -from app.core.context import MediaInfo, Context -from app.core.meta import MetaBase -from app.db.user_oper import UserOper -from app.helper.torrent import TorrentHelper from app.helper.llm import LLMHelper from app.helper.voice import VoiceHelper from app.log import logger -from app.schemas import Notification, NotExistMediaInfo, CommingMessage +from app.schemas import Notification, CommingMessage from app.schemas.message import ChannelCapabilityManager -from app.schemas.types import EventType, MessageChannel, MediaType -from app.utils.string import StringUtils +from app.schemas.types import EventType, MessageChannel from app.utils.http import RequestUtils -# 当前页面 -_current_page: int = 0 -# 当前元数据 -_current_meta: Optional[MetaBase] = None -# 当前媒体信息 -_current_media: Optional[MediaInfo] = None - class MessageChain(ChainBase): """ 外来消息处理链 """ - # 缓存的用户数据 {userid: {type: str, items: list}} - _cache_file = "__user_messages__" - # 每页数据量 - _page_size: int = 8 # 用户会话信息 {userid: (session_id, last_time)} _user_sessions: Dict[Union[str, int], tuple] = {} # 会话超时时间(分钟) _session_timeout_minutes: int = 24 * 60 - @staticmethod - def __get_noexits_info( - _meta: MetaBase, _mediainfo: MediaInfo - ) -> Dict[Union[int, str], Dict[int, NotExistMediaInfo]]: - """ - 获取缺失的媒体信息 - """ - if _mediainfo.type == MediaType.TV: - if not _mediainfo.seasons: - # 补充媒体信息 - _mediainfo = MediaChain().recognize_media( - mtype=_mediainfo.type, - tmdbid=_mediainfo.tmdb_id, - doubanid=_mediainfo.douban_id, - cache=False, - ) - if not _mediainfo: - logger.warn( - f"{_mediainfo.tmdb_id or _mediainfo.douban_id} 媒体信息识别失败!" - ) - return {} - if not _mediainfo.seasons: - logger.warn( - f"媒体信息中没有季集信息," - f"标题:{_mediainfo.title}," - f"tmdbid:{_mediainfo.tmdb_id},doubanid:{_mediainfo.douban_id}" - ) - return {} - # KEY - _mediakey = _mediainfo.tmdb_id or _mediainfo.douban_id - _no_exists = {_mediakey: {}} - if _meta.begin_season: - # 指定季 - episodes = _mediainfo.seasons.get(_meta.begin_season) - if not episodes: - return {} - _no_exists[_mediakey][_meta.begin_season] = NotExistMediaInfo( - season=_meta.begin_season, - episodes=[], - total_episode=len(episodes), - start_episode=episodes[0], - ) - else: - # 所有季 - for sea, eps in _mediainfo.seasons.items(): - if not eps: - continue - _no_exists[_mediakey][sea] = NotExistMediaInfo( - season=sea, - episodes=[], - total_episode=len(eps), - start_episode=eps[0], - ) - else: - _no_exists = {} - - return _no_exists - def process(self, body: Any, form: Any, args: Any) -> None: """ 调用模块识别消息内容 @@ -181,590 +107,132 @@ class MessageChain(ChainBase): """ 识别消息内容,执行操作 """ - # 申明全局变量 - global _current_page, _current_meta, _current_media + images = CommingMessage.MessageImage.normalize_list(images) - # 加载缓存 - user_cache: Dict[str, dict] = self.load_cache(self._cache_file) or {} - - try: - images = CommingMessage.MessageImage.normalize_list(images) - - # 识别语音为文本 - reply_with_voice = bool(audio_refs) - if audio_refs: - transcript = self._transcribe_audio_refs(audio_refs, channel, source) - merged_parts = [] - seen_parts = set() - for item in [text.strip() if text else "", transcript or ""]: - normalized = item.strip() - if not normalized or normalized in seen_parts: - continue - seen_parts.add(normalized) - merged_parts.append(normalized) - text = "\n".join(merged_parts).strip() - if not text: - self.post_message( - Notification( - channel=channel, - source=source, - userid=userid, - username=username, - title="语音识别失败,请稍后重试", - ) - ) - return - - # 保存消息 - if not text.startswith("CALLBACK:"): - self._record_user_message( - channel=channel, - source=source, - userid=userid, - username=username, - text=text, - ) - # 处理消息 - if text.startswith("CALLBACK:"): - # 处理按钮回调(适配支持回调的渠),优先级最高 - if ChannelCapabilityManager.supports_callbacks(channel): - self._handle_callback( - text=text, + # 识别语音为文本 + reply_with_voice = bool(audio_refs) + if audio_refs: + transcript = self._transcribe_audio_refs(audio_refs, channel, source) + merged_parts = [] + seen_parts = set() + for item in [text.strip() if text else "", transcript or ""]: + normalized = item.strip() + if not normalized or normalized in seen_parts: + continue + seen_parts.add(normalized) + merged_parts.append(normalized) + text = "\n".join(merged_parts).strip() + if not text: + self.post_message( + Notification( channel=channel, source=source, userid=userid, username=username, - original_message_id=original_message_id, - original_chat_id=original_chat_id, + title="语音识别失败,请稍后重试", ) - else: - logger.warning( - f"渠道 {channel.value} 不支持回调,但收到了回调消息:{text}" - ) - elif text.startswith("/") and not text.lower().startswith("/ai"): - # 执行特定命令命令(但不是/ai) - self.eventmanager.send_event( - EventType.CommandExcute, - {"cmd": text, "user": userid, "channel": channel, "source": source}, ) - elif skills_interaction_manager.get_by_user(userid): - if SkillsChain().handle_text_interaction( - channel=channel, - source=source, - userid=userid, - username=username, - text=text, - ): - return - elif text.lower().startswith("/ai"): - self._handle_ai_message( + return + + if not text.startswith("CALLBACK:"): + self._record_user_message( + channel=channel, + source=source, + userid=userid, + username=username, + text=text, + ) + + if text.startswith("CALLBACK:"): + if ChannelCapabilityManager.supports_callbacks(channel): + self._handle_callback( text=text, channel=channel, source=source, userid=userid, username=username, - images=images, - files=files, - reply_with_voice=reply_with_voice, - ) - elif settings.AI_AGENT_ENABLE and ( - settings.AI_AGENT_GLOBAL or images or files - ): - # 普通消息,全局智能体响应 - self._handle_ai_message( - text=text, - channel=channel, - source=source, - userid=userid, - username=username, - images=images, - files=files, - reply_with_voice=reply_with_voice, + original_message_id=original_message_id, + original_chat_id=original_chat_id, ) else: - # 非智能体普通消息响应 - if text.isdigit(): - # 用户选择了具体的条目 - # 缓存 - cache_data: dict = user_cache.get(userid) - if not cache_data: - # 发送消息 - self.post_message( - Notification( - channel=channel, - source=source, - title="输入有误!", - userid=userid, - ) - ) - return - cache_data = cache_data.copy() - # 选择项目 - if not cache_data.get("items") or len( - cache_data.get("items") - ) < int(text): - # 发送消息 - self.post_message( - Notification( - channel=channel, - source=source, - title="输入有误!", - userid=userid, - ) - ) - return - try: - # 选择的序号 - _choice = int(text) + _current_page * self._page_size - 1 - # 缓存类型 - cache_type: str = cache_data.get("type") - # 缓存列表 - cache_list: list = cache_data.get("items").copy() - # 选择 - try: - if cache_type in ["Search", "ReSearch"]: - # 当前媒体信息 - mediainfo: MediaInfo = cache_list[_choice] - _current_media = mediainfo - # 查询缺失的媒体信息 - exist_flag, no_exists = ( - DownloadChain().get_no_exists_info( - meta=_current_meta, mediainfo=_current_media - ) - ) - if exist_flag and cache_type == "Search": - # 媒体库中已存在 - self.post_message( - Notification( - channel=channel, - source=source, - title=f"【{_current_media.title_year}" - f"{_current_meta.sea} 媒体库中已存在,如需重新下载请发送:搜索 名称 或 下载 名称】", - userid=userid, - ) - ) - return - elif exist_flag: - # 没有缺失,但要全量重新搜索和下载 - no_exists = self.__get_noexits_info( - _current_meta, _current_media - ) - # 发送缺失的媒体信息 - messages = [] - if no_exists and cache_type == "Search": - # 发送缺失消息 - mediakey = mediainfo.tmdb_id or mediainfo.douban_id - messages = [ - f"第 {sea} 季缺失 {StringUtils.str_series(no_exist.episodes) if no_exist.episodes else no_exist.total_episode} 集" - for sea, no_exist in no_exists.get( - mediakey - ).items() - ] - elif no_exists: - # 发送总集数的消息 - mediakey = mediainfo.tmdb_id or mediainfo.douban_id - messages = [ - f"第 {sea} 季总 {no_exist.total_episode} 集" - for sea, no_exist in no_exists.get( - mediakey - ).items() - ] - if messages: - self.post_message( - Notification( - channel=channel, - source=source, - title=f"{mediainfo.title_year}:\n" - + "\n".join(messages), - userid=userid, - ) - ) - # 搜索种子,过滤掉不需要的剧集,以便选择 - logger.info(f"开始搜索 {mediainfo.title_year} ...") - self.post_message( - Notification( - channel=channel, - source=source, - title=f"开始搜索 {mediainfo.type.value} {mediainfo.title_year} ...", - userid=userid, - ) - ) - # 开始搜索 - contexts = SearchChain().process( - mediainfo=mediainfo, no_exists=no_exists - ) - if not contexts: - # 没有数据 - self.post_message( - Notification( - channel=channel, - source=source, - title=f"{mediainfo.title}" - f"{_current_meta.sea} 未搜索到需要的资源!", - userid=userid, - ) - ) - return - # 搜索结果排序 - contexts = TorrentHelper().sort_torrents(contexts) - try: - # 判断是否设置自动下载 - auto_download_user = settings.AUTO_DOWNLOAD_USER - # 匹配到自动下载用户 - if auto_download_user and ( - auto_download_user == "all" - or any( - userid == user - for user in auto_download_user.split(",") - ) - ): - logger.info( - f"用户 {userid} 在自动下载用户中,开始自动择优下载 ..." - ) - # 自动选择下载 - self.__auto_download( - channel=channel, - source=source, - cache_list=contexts, - userid=userid, - username=username, - no_exists=no_exists, - ) - else: - # 更新缓存 - user_cache[userid] = { - "type": "Torrent", - "items": contexts, - } - _current_page = 0 - # 保存缓存 - self.save_cache(user_cache, self._cache_file) - # 删除原消息 - if ( - original_message_id - and original_chat_id - and ChannelCapabilityManager.supports_deletion( - channel - ) - ): - self.delete_message( - channel=channel, - source=source, - message_id=original_message_id, - chat_id=original_chat_id, - ) - # 发送种子数据 - logger.info( - f"搜索到 {len(contexts)} 条数据,开始发送选择消息 ..." - ) - self.__post_torrents_message( - channel=channel, - source=source, - title=mediainfo.title, - items=contexts[: self._page_size], - userid=userid, - total=len(contexts), - ) - finally: - contexts.clear() - del contexts - elif cache_type in ["Subscribe", "ReSubscribe"]: - # 订阅或洗版媒体 - mediainfo: MediaInfo = cache_list[_choice] - # 洗版标识 - best_version = False - # 查询缺失的媒体信息 - if cache_type == "Subscribe": - exist_flag, _ = DownloadChain().get_no_exists_info( - meta=_current_meta, mediainfo=mediainfo - ) - if exist_flag: - self.post_message( - Notification( - channel=channel, - source=source, - title=f"【{mediainfo.title_year}" - f"{_current_meta.sea} 媒体库中已存在,如需洗版请发送:洗版 XXX】", - userid=userid, - ) - ) - return - else: - best_version = True - # 转换用户名 - mp_name = ( - UserOper().get_name( - **{f"{channel.name.lower()}_userid": userid} - ) - if channel - else None - ) - # 添加订阅,状态为N - SubscribeChain().add( - title=mediainfo.title, - year=mediainfo.year, - mtype=mediainfo.type, - tmdbid=mediainfo.tmdb_id, - season=_current_meta.begin_season, - channel=channel, - source=source, - userid=userid, - username=mp_name or username, - best_version=best_version, - ) - elif cache_type == "Torrent": - if int(text) == 0: - # 自动选择下载,强制下载模式 - self.__auto_download( - channel=channel, - source=source, - cache_list=cache_list, - userid=userid, - username=username, - ) - else: - # 下载种子 - context: Context = cache_list[_choice] - # 下载 - DownloadChain().download_single( - context, - channel=channel, - source=source, - userid=userid, - username=username, - ) - finally: - cache_list.clear() - del cache_list - finally: - cache_data.clear() - del cache_data - elif text.lower() == "p": - # 上一页 - cache_data: dict = user_cache.get(userid) - if not cache_data: - # 没有缓存 - self.post_message( - Notification( - channel=channel, - source=source, - title="输入有误!", - userid=userid, - ) - ) - return - cache_data = cache_data.copy() - try: - if _current_page == 0: - # 第一页 - self.post_message( - Notification( - channel=channel, - source=source, - title="已经是第一页了!", - userid=userid, - ) - ) - return - # 减一页 - _current_page -= 1 - cache_type: str = cache_data.get("type") - # 产生副本,避免修改原值 - cache_list: list = cache_data.get("items").copy() - try: - if _current_page == 0: - start = 0 - end = self._page_size - else: - start = _current_page * self._page_size - end = start + self._page_size - if cache_type == "Torrent": - # 发送种子数据 - self.__post_torrents_message( - channel=channel, - source=source, - title=_current_media.title, - items=cache_list[start:end], - userid=userid, - total=len(cache_list), - original_message_id=original_message_id, - original_chat_id=original_chat_id, - ) - else: - # 发送媒体数据 - self.__post_medias_message( - channel=channel, - source=source, - title=_current_meta.name, - items=cache_list[start:end], - userid=userid, - total=len(cache_list), - original_message_id=original_message_id, - original_chat_id=original_chat_id, - ) - finally: - cache_list.clear() - del cache_list - finally: - cache_data.clear() - del cache_data - elif text.lower() == "n": - # 下一页 - cache_data: dict = user_cache.get(userid) - if not cache_data: - # 没有缓存 - self.post_message( - Notification( - channel=channel, - source=source, - title="输入有误!", - userid=userid, - ) - ) - return - cache_data = cache_data.copy() - try: - cache_type: str = cache_data.get("type") - # 产生副本,避免修改原值 - cache_list: list = cache_data.get("items").copy() - total = len(cache_list) - # 加一页 - cache_list = cache_list[ - (_current_page + 1) * self._page_size : (_current_page + 2) - * self._page_size - ] - if not cache_list: - # 没有数据 - self.post_message( - Notification( - channel=channel, - source=source, - title="已经是最后一页了!", - userid=userid, - ) - ) - return - else: - try: - # 加一页 - _current_page += 1 - if cache_type == "Torrent": - # 发送种子数据 - self.__post_torrents_message( - channel=channel, - source=source, - title=_current_media.title, - items=cache_list, - userid=userid, - total=total, - original_message_id=original_message_id, - original_chat_id=original_chat_id, - ) - else: - # 发送媒体数据 - self.__post_medias_message( - channel=channel, - source=source, - title=_current_meta.name, - items=cache_list, - userid=userid, - total=total, - original_message_id=original_message_id, - original_chat_id=original_chat_id, - ) - finally: - cache_list.clear() - del cache_list - finally: - cache_data.clear() - del cache_data - else: - # 搜索或订阅 - if text.startswith("订阅"): - # 订阅 - content = re.sub(r"订阅[::\s]*", "", text) - action = "Subscribe" - elif text.startswith("洗版"): - # 洗版 - content = re.sub(r"洗版[::\s]*", "", text) - action = "ReSubscribe" - elif text.startswith("搜索") or text.startswith("下载"): - # 重新搜索/下载 - content = re.sub(r"(搜索|下载)[::\s]*", "", text) - action = "ReSearch" - elif StringUtils.is_link(text): - # 链接 - content = text - action = "Link" - elif not StringUtils.is_media_title_like(text): - # 聊天 - content = text - action = "Chat" - else: - # 搜索 - content = text - action = "Search" + logger.warning( + "渠道 %s 不支持回调,但收到了回调消息:%s", + channel.value, + text, + ) + return - if action in ["Search", "ReSearch", "Subscribe", "ReSubscribe"]: - # 搜索 - meta, medias = MediaChain().search(content) - # 识别 - if not meta.name: - self.post_message( - Notification( - channel=channel, - source=source, - title="无法识别输入内容!", - userid=userid, - ) - ) - return - # 开始搜索 - if not medias: - self.post_message( - Notification( - channel=channel, - source=source, - title=f"{meta.name} 没有找到对应的媒体信息!", - userid=userid, - ) - ) - return - logger.info(f"搜索到 {len(medias)} 条相关媒体信息") - try: - # 记录当前状态 - _current_meta = meta - # 保存缓存 - user_cache[userid] = {"type": action, "items": medias} - self.save_cache(user_cache, self._cache_file) - _current_page = 0 - _current_media = None - # 发送媒体列表 - self.__post_medias_message( - channel=channel, - source=source, - title=meta.name, - items=medias[: self._page_size], - userid=userid, - total=len(medias), - ) - finally: - medias.clear() - del medias - else: - # 广播事件 - self.eventmanager.send_event( - EventType.UserMessage, - { - "text": content, - "userid": userid, - "channel": channel, - "source": source, - }, - ) - finally: - user_cache.clear() - del user_cache + if text.startswith("/") and not text.lower().startswith("/ai"): + self.eventmanager.send_event( + EventType.CommandExcute, + {"cmd": text, "user": userid, "channel": channel, "source": source}, + ) + return + + if skills_interaction_manager.get_by_user(userid): + if SkillsChain().handle_text_interaction( + channel=channel, + source=source, + userid=userid, + username=username, + text=text, + ): + return + + if media_interaction_manager.get_by_user(userid): + if MediaInteractionChain().handle_text_interaction( + channel=channel, + source=source, + userid=userid, + username=username, + text=text, + ): + return + + if text.lower().startswith("/ai"): + self._handle_ai_message( + text=text, + channel=channel, + source=source, + userid=userid, + username=username, + images=images, + files=files, + reply_with_voice=reply_with_voice, + ) + return + + if settings.AI_AGENT_ENABLE and (settings.AI_AGENT_GLOBAL or images or files): + self._handle_ai_message( + text=text, + channel=channel, + source=source, + userid=userid, + username=username, + images=images, + files=files, + reply_with_voice=reply_with_voice, + ) + return + + if MediaInteractionChain().handle_text_interaction( + channel=channel, + source=source, + userid=userid, + username=username, + text=text, + ): + return + + self.eventmanager.send_event( + EventType.UserMessage, + { + "text": text, + "userid": userid, + "channel": channel, + "source": source, + }, + ) def _handle_callback( self, @@ -780,8 +248,6 @@ class MessageChain(ChainBase): 处理按钮回调 """ - global _current_media - # 提取回调数据 callback_data = text[9:] # 去掉 "CALLBACK:" 前缀 logger.info(f"处理按钮回调:{callback_data}") @@ -806,6 +272,17 @@ class MessageChain(ChainBase): ): return + if MediaInteractionChain().handle_callback_interaction( + callback_data=callback_data, + channel=channel, + source=source, + userid=userid, + username=username, + original_message_id=original_message_id, + original_chat_id=original_chat_id, + ): + return + if self._handle_agent_choice_callback( callback_data=callback_data, channel=channel, @@ -836,29 +313,16 @@ class MessageChain(ChainBase): ) return - # 解析系统回调数据 - try: - page_text = callback_data.split("_", 1)[1] - self.handle_message( + logger.error(f"回调数据格式错误:{callback_data}") + self.post_message( + Notification( channel=channel, source=source, userid=userid, username=username, - text=page_text, - original_message_id=original_message_id, - original_chat_id=original_chat_id, - ) - except IndexError: - logger.error(f"回调数据格式错误:{callback_data}") - self.post_message( - Notification( - channel=channel, - source=source, - userid=userid, - username=username, - title="回调数据格式错误,请检查!", - ) + title="回调数据格式错误,请检查!", ) + ) @staticmethod def _parse_transfer_callback( @@ -1149,276 +613,6 @@ class MessageChain(ChainBase): asyncio.run_coroutine_threadsafe(_run_ai_takeover(), global_vars.loop) - def __auto_download( - self, - channel: MessageChannel, - source: str, - cache_list: list[Context], - userid: Union[str, int], - username: str, - no_exists: Optional[Dict[Union[int, str], Dict[int, NotExistMediaInfo]]] = None, - ): - """ - 自动择优下载 - """ - downloadchain = DownloadChain() - if no_exists is None: - # 查询缺失的媒体信息 - exist_flag, no_exists = downloadchain.get_no_exists_info( - meta=_current_meta, mediainfo=_current_media - ) - if exist_flag: - # 媒体库中已存在,查询全量 - no_exists = self.__get_noexits_info(_current_meta, _current_media) - - # 批量下载 - downloads, lefts = downloadchain.batch_download( - contexts=cache_list, - no_exists=no_exists, - channel=channel, - source=source, - userid=userid, - username=username, - ) - if downloads and not lefts: - # 全部下载完成 - logger.info(f"{_current_media.title_year} 下载完成") - else: - # 未完成下载 - logger.info(f"{_current_media.title_year} 未下载未完整,添加订阅 ...") - if downloads and _current_media.type == MediaType.TV: - # 获取已下载剧集 - downloaded = [ - download.meta_info.begin_episode - for download in downloads - if download.meta_info.begin_episode - ] - note = downloaded - else: - note = None - # 转换用户名 - mp_name = ( - UserOper().get_name(**{f"{channel.name.lower()}_userid": userid}) - if channel - else None - ) - # 添加订阅,状态为R - SubscribeChain().add( - title=_current_media.title, - year=_current_media.year, - mtype=_current_media.type, - tmdbid=_current_media.tmdb_id, - season=_current_meta.begin_season, - channel=channel, - source=source, - userid=userid, - username=mp_name or username, - state="R", - note=note, - ) - - def __post_medias_message( - self, - channel: MessageChannel, - source: str, - title: str, - items: list, - userid: str, - total: int, - original_message_id: Optional[Union[str, int]] = None, - original_chat_id: Optional[str] = None, - ): - """ - 发送媒体列表消息 - """ - # 检查渠道是否支持按钮 - supports_buttons = ChannelCapabilityManager.supports_buttons(channel) - - if supports_buttons: - # 支持按钮的渠道 - if total > self._page_size: - title = f"【{title}】共找到{total}条相关信息,请选择操作" - else: - title = f"【{title}】共找到{total}条相关信息,请选择操作" - - buttons = self._create_media_buttons( - channel=channel, items=items, total=total - ) - else: - # 不支持按钮的渠道,使用文本提示 - if total > self._page_size: - title = f"【{title}】共找到{total}条相关信息,请回复对应数字选择(p: 上一页 n: 下一页)" - else: - title = f"【{title}】共找到{total}条相关信息,请回复对应数字选择" - buttons = None - - notification = Notification( - channel=channel, - source=source, - title=title, - userid=userid, - buttons=buttons, - original_message_id=original_message_id, - original_chat_id=original_chat_id, - ) - - self.post_medias_message(notification, medias=items) - - def _create_media_buttons( - self, channel: MessageChannel, items: list, total: int - ) -> List[List[Dict]]: - """ - 创建媒体选择按钮 - """ - global _current_page - - buttons = [] - max_text_length = ChannelCapabilityManager.get_max_button_text_length(channel) - max_per_row = ChannelCapabilityManager.get_max_buttons_per_row(channel) - - # 为每个媒体项创建选择按钮 - current_row = [] - for i in range(len(items)): - media = items[i] - - if max_per_row == 1: - # 每行一个按钮,使用完整文本 - button_text = f"{i + 1}. {media.title_year}" - if len(button_text) > max_text_length: - button_text = button_text[: max_text_length - 3] + "..." - - buttons.append( - [{"text": button_text, "callback_data": f"select_{i + 1}"}] - ) - else: - # 多按钮一行的情况,使用简化文本 - button_text = f"{i + 1}" - - current_row.append( - {"text": button_text, "callback_data": f"select_{i + 1}"} - ) - - # 如果当前行已满或者是最后一个按钮,添加到按钮列表 - if len(current_row) == max_per_row or i == len(items) - 1: - buttons.append(current_row) - current_row = [] - - # 添加翻页按钮 - if total > self._page_size: - page_buttons = [] - if _current_page > 0: - page_buttons.append({"text": "⬅️ 上一页", "callback_data": "page_p"}) - if (_current_page + 1) * self._page_size < total: - page_buttons.append({"text": "下一页 ➡️", "callback_data": "page_n"}) - if page_buttons: - buttons.append(page_buttons) - - return buttons - - def __post_torrents_message( - self, - channel: MessageChannel, - source: str, - title: str, - items: list, - userid: str, - total: int, - original_message_id: Optional[Union[str, int]] = None, - original_chat_id: Optional[str] = None, - ): - """ - 发送种子列表消息 - """ - # 检查渠道是否支持按钮 - supports_buttons = ChannelCapabilityManager.supports_buttons(channel) - - if supports_buttons: - # 支持按钮的渠道 - if total > self._page_size: - title = f"【{title}】共找到{total}条相关资源,请选择下载" - else: - title = f"【{title}】共找到{total}条相关资源,请选择下载" - - buttons = self._create_torrent_buttons( - channel=channel, items=items, total=total - ) - else: - # 不支持按钮的渠道,使用文本提示 - if total > self._page_size: - title = f"【{title}】共找到{total}条相关资源,请回复对应数字下载(0: 自动选择 p: 上一页 n: 下一页)" - else: - title = f"【{title}】共找到{total}条相关资源,请回复对应数字下载(0: 自动选择)" - buttons = None - - notification = Notification( - channel=channel, - source=source, - title=title, - userid=userid, - link=settings.MP_DOMAIN("#/resource"), - buttons=buttons, - original_message_id=original_message_id, - original_chat_id=original_chat_id, - ) - - self.post_torrents_message(notification, torrents=items) - - def _create_torrent_buttons( - self, channel: MessageChannel, items: list, total: int - ) -> List[List[Dict]]: - """ - 创建种子下载按钮 - """ - - global _current_page - - buttons = [] - max_text_length = ChannelCapabilityManager.get_max_button_text_length(channel) - max_per_row = ChannelCapabilityManager.get_max_buttons_per_row(channel) - - # 自动选择按钮 - buttons.append([{"text": "🤖 自动选择下载", "callback_data": "download_0"}]) - - # 为每个种子项创建下载按钮 - current_row = [] - for i in range(len(items)): - context = items[i] - torrent = context.torrent_info - - if max_per_row == 1: - # 每行一个按钮,使用完整文本 - button_text = f"{i + 1}. {torrent.site_name} - {torrent.seeders}↑" - if len(button_text) > max_text_length: - button_text = button_text[: max_text_length - 3] + "..." - - buttons.append( - [{"text": button_text, "callback_data": f"download_{i + 1}"}] - ) - else: - # 多按钮一行的情况,使用简化文本 - button_text = f"{i + 1}" - - current_row.append( - {"text": button_text, "callback_data": f"download_{i + 1}"} - ) - - # 如果当前行已满或者是最后一个按钮,添加到按钮列表 - if len(current_row) == max_per_row or i == len(items) - 1: - buttons.append(current_row) - current_row = [] - - # 添加翻页按钮 - if total > self._page_size: - page_buttons = [] - if _current_page > 0: - page_buttons.append({"text": "⬅️ 上一页", "callback_data": "page_p"}) - if (_current_page + 1) * self._page_size < total: - page_buttons.append({"text": "下一页 ➡️", "callback_data": "page_n"}) - if page_buttons: - buttons.append(page_buttons) - - return buttons - def _get_or_create_session_id(self, userid: Union[str, int]) -> str: """ 获取或创建会话ID diff --git a/tests/test_agent_interaction.py b/tests/test_agent_interaction.py index ca1992a6..e5b1adc3 100644 --- a/tests/test_agent_interaction.py +++ b/tests/test_agent_interaction.py @@ -8,7 +8,7 @@ from app.agent.tools.impl.ask_user_choice import ( AskUserChoiceTool, UserChoiceOptionInput, ) -from app.agent.interaction import ( +from app.chain.interaction import ( AgentInteractionOption, agent_interaction_manager, ) diff --git a/tests/test_media_interaction.py b/tests/test_media_interaction.py new file mode 100644 index 00000000..862c1d93 --- /dev/null +++ b/tests/test_media_interaction.py @@ -0,0 +1,158 @@ +import sys +import unittest +from types import ModuleType +from unittest.mock import patch + +sys.modules.setdefault("qbittorrentapi", ModuleType("qbittorrentapi")) +setattr(sys.modules["qbittorrentapi"], "TorrentFilesList", list) +sys.modules.setdefault("transmission_rpc", ModuleType("transmission_rpc")) +setattr(sys.modules["transmission_rpc"], "File", object) +sys.modules.setdefault("psutil", ModuleType("psutil")) + +from app.chain.interaction import MediaInteractionChain, media_interaction_manager +from app.chain.message import MessageChain +from app.core.context import MediaInfo +from app.core.meta import MetaBase +from app.schemas.types import MessageChannel + + +class TestMediaInteraction(unittest.TestCase): + def tearDown(self): + media_interaction_manager.clear() + + @staticmethod + def _build_meta(name: str) -> MetaBase: + meta = MetaBase(name) + meta.name = name + meta.begin_season = 1 + return meta + + def test_message_routes_text_reply_to_media_interaction_before_ai(self): + chain = MessageChain() + request = media_interaction_manager.create_or_replace( + user_id="10001", + channel=MessageChannel.Wechat, + source="wechat-test", + username="tester", + action="Search", + keyword="星际穿越", + title="星际穿越", + meta=self._build_meta("星际穿越"), + items=[MediaInfo(title="星际穿越", year="2014")], + ) + self.assertIsNotNone(request) + + with patch.object(chain, "_record_user_message"), patch( + "app.chain.message.MediaInteractionChain.handle_text_interaction", + return_value=True, + ) as handle_text, patch.object(chain, "_handle_ai_message") as handle_ai: + chain.handle_message( + channel=MessageChannel.Wechat, + source="wechat-test", + userid="10001", + username="tester", + text="1", + ) + + handle_text.assert_called_once() + handle_ai.assert_not_called() + + def test_callback_routes_to_media_interaction_chain(self): + chain = MessageChain() + request = media_interaction_manager.create_or_replace( + user_id="10001", + channel=MessageChannel.Telegram, + source="telegram-test", + username="tester", + action="Search", + keyword="星际穿越", + title="星际穿越", + meta=self._build_meta("星际穿越"), + items=[MediaInfo(title="星际穿越", year="2014")], + ) + + with patch( + "app.chain.message.MediaInteractionChain.handle_callback_interaction", + return_value=True, + ) as handle_callback: + chain._handle_callback( + text=f"CALLBACK:media:{request.request_id}:page-next", + channel=MessageChannel.Telegram, + source="telegram-test", + userid="10001", + username="tester", + ) + + handle_callback.assert_called_once() + + def test_media_interaction_starts_search_and_posts_media_list(self): + chain = MediaInteractionChain() + meta = self._build_meta("星际穿越") + medias = [ + MediaInfo(title="星际穿越", year="2014"), + MediaInfo(title="Interstellar", year="2014"), + ] + + with patch( + "app.chain.interaction.MediaChain.search", + return_value=(meta, medias), + ), patch.object(chain, "post_medias_message") as post_medias_message: + handled = chain.handle_text_interaction( + channel=MessageChannel.Telegram, + source="telegram-test", + userid="10001", + username="tester", + text="星际穿越", + ) + + self.assertTrue(handled) + post_medias_message.assert_called_once() + notification = post_medias_message.call_args.args[0] + self.assertTrue(notification.buttons) + self.assertTrue( + notification.buttons[0][0]["callback_data"].startswith("media:") + ) + + request = media_interaction_manager.get_by_user("10001") + self.assertIsNotNone(request) + self.assertEqual(request.action, "Search") + self.assertEqual(len(request.items), 2) + + def test_media_interaction_legacy_page_callback_updates_existing_request(self): + chain = MediaInteractionChain() + request = media_interaction_manager.create_or_replace( + user_id="10001", + channel=MessageChannel.Telegram, + source="telegram-test", + username="tester", + action="Search", + keyword="星际穿越", + title="星际穿越", + meta=self._build_meta("星际穿越"), + items=[ + MediaInfo(title=f"资源 {index}", year="2024") + for index in range(1, 11) + ], + ) + + with patch.object(chain, "post_medias_message") as post_medias_message: + handled = chain.handle_callback_interaction( + callback_data="page_n", + channel=MessageChannel.Telegram, + source="telegram-test", + userid="10001", + username="tester", + original_message_id=123, + original_chat_id="456", + ) + + self.assertTrue(handled) + self.assertEqual(request.page, 1) + post_medias_message.assert_called_once() + notification = post_medias_message.call_args.args[0] + self.assertEqual(notification.original_message_id, 123) + self.assertEqual(notification.original_chat_id, "456") + + +if __name__ == "__main__": + unittest.main()