diff --git a/app/agent/__init__.py b/app/agent/__init__.py index 401fb70e..457c2d19 100644 --- a/app/agent/__init__.py +++ b/app/agent/__init__.py @@ -1635,6 +1635,10 @@ class _MessageTask: processing_status: Optional[dict] = None reply_mode: ReplyMode = ReplyMode.DISPATCH allow_message_tools: bool = True + output_callback: Optional[Callable[[str], None]] = None + notification_callback: Optional[Callable[[Any], None]] = None + agent_factory: Optional[Callable[..., MoviePilotAgent]] = None + completion_future: Optional[asyncio.Future] = None class AgentManager: @@ -1780,11 +1784,18 @@ class AgentManager: original_chat_id: Optional[str] = None, reply_mode: ReplyMode = ReplyMode.DISPATCH, allow_message_tools: bool = True, + output_callback: Optional[Callable[[str], None]] = None, + notification_callback: Optional[Callable[[Any], None]] = None, + agent_factory: Optional[Callable[..., MoviePilotAgent]] = None, + wait_for_completion: bool = False, ) -> str: """ 处理用户消息:将消息放入会话队列,按顺序依次处理。 同一会话的消息排队等待,不同会话之间互不影响。 """ + completion_future = ( + asyncio.get_running_loop().create_future() if wait_for_completion else None + ) task = _MessageTask( session_id=session_id, user_id=user_id, @@ -1799,6 +1810,10 @@ class AgentManager: original_chat_id=original_chat_id, reply_mode=reply_mode, allow_message_tools=allow_message_tools, + output_callback=output_callback, + notification_callback=notification_callback, + agent_factory=agent_factory, + completion_future=completion_future, ) self._record_session_activity(session_id, user_id) @@ -1831,6 +1846,8 @@ class AgentManager: self._session_worker(session_id) ) + if completion_future: + return await completion_future return "" async def _session_worker(self, session_id: str): @@ -1854,9 +1871,17 @@ class AgentManager: try: await self._start_task_processing_status(task) - await self._process_message_internal(task) + result = await self._process_message_internal(task) + if task.completion_future and not task.completion_future.done(): + task.completion_future.set_result(result) + except asyncio.CancelledError as err: + if task.completion_future and not task.completion_future.done(): + task.completion_future.set_exception(err) + raise except Exception as e: logger.error(f"处理会话 {session_id} 的消息失败: {e}") + if task.completion_future and not task.completion_future.done(): + task.completion_future.set_exception(e) finally: await self._finish_task_processing_status(task) queue.task_done() @@ -1895,21 +1920,36 @@ class AgentManager: 实际处理单条消息 """ session_id = task.session_id + existing_agent = self.active_agents.get(session_id) + if ( + existing_agent + and task.agent_factory + and isinstance(task.agent_factory, type) + and not isinstance(existing_agent, task.agent_factory) + ): + await existing_agent.cleanup() + self.active_agents.pop(session_id, None) + if session_id not in self.active_agents: logger.info( f"创建新的AI智能体实例,session_id: {session_id}, user_id: {task.user_id}" ) - agent = MoviePilotAgent( - session_id=session_id, - user_id=task.user_id, - channel=task.channel, - source=task.source, - username=task.username, - original_message_id=task.original_message_id, - original_chat_id=task.original_chat_id, - replay_mode=task.reply_mode, - allow_message_tools=task.allow_message_tools, - ) + agent_factory = task.agent_factory or MoviePilotAgent + agent_kwargs = { + "session_id": session_id, + "user_id": task.user_id, + "channel": task.channel, + "source": task.source, + "username": task.username, + "original_message_id": task.original_message_id, + "original_chat_id": task.original_chat_id, + "replay_mode": task.reply_mode, + "allow_message_tools": task.allow_message_tools, + "output_callback": task.output_callback, + } + if task.notification_callback is not None and task.agent_factory: + agent_kwargs["notification_callback"] = task.notification_callback + agent = agent_factory(**agent_kwargs) self.active_agents[session_id] = agent else: agent = self.active_agents[session_id] @@ -1924,6 +1964,12 @@ class AgentManager: agent.original_chat_id = task.original_chat_id agent.reply_mode = task.reply_mode agent.allow_message_tools = task.allow_message_tools + if hasattr(agent, "set_output_callback"): + agent.set_output_callback(task.output_callback) + else: + agent.output_callback = task.output_callback + if task.notification_callback is not None and hasattr(agent, "set_notification_callback"): + agent.set_notification_callback(task.notification_callback) process_kwargs = { "images": task.images, diff --git a/app/agent/tools/impl/ask_user_choice.py b/app/agent/tools/impl/ask_user_choice.py index 8f3f7758..e6564390 100644 --- a/app/agent/tools/impl/ask_user_choice.py +++ b/app/agent/tools/impl/ask_user_choice.py @@ -24,9 +24,14 @@ class UserChoiceOptionInput(BaseModel): ..., description="The exact content that will be sent back to the agent after the user clicks this button", ) + description: Optional[str] = Field( + None, + description="Optional user-facing description stored in chat history after this option is selected", + ) @model_validator(mode="after") def validate_option(self): + """校验按钮选项的文案和值不能为空。""" label = str(self.label) value = str(self.value) if not label.strip(): @@ -56,6 +61,7 @@ class AskUserChoiceInput(BaseModel): @model_validator(mode="after") def validate_payload(self): + """校验按钮选择工具必须提供问题和选项。""" message = str(self.message) if not message.strip(): raise ValueError("message 不能为空") @@ -85,6 +91,7 @@ class AskUserChoiceTool(MoviePilotTool): args_schema: Type[BaseModel] = AskUserChoiceInput def get_tool_message(self, **kwargs) -> Optional[str]: + """生成工具执行提示文案。""" message = kwargs.get("message", "") or "" if len(message) > 40: message = message[:40] + "..." @@ -92,6 +99,7 @@ class AskUserChoiceTool(MoviePilotTool): @staticmethod def _truncate_button_text(text: str, max_length: int) -> str: + """按渠道限制截断按钮文案。""" if max_length <= 0 or len(text) <= max_length: return text if max_length <= 3: @@ -114,6 +122,14 @@ class AskUserChoiceTool(MoviePilotTool): title: Optional[str] = None, **kwargs, ) -> str: + """ + 发送按钮选择消息,并登记待回调的交互上下文。 + + :param message: 展示给用户的问题 + :param options: 可点击的选项列表 + :param title: 可选标题 + :return: 工具执行结果描述 + """ if self._blocked_by_feedback_quality_gate(): logger.warning( "ask_user_choice blocked after feedback issue rejected_quality: " @@ -148,7 +164,9 @@ class AskUserChoiceTool(MoviePilotTool): choice_options = [ AgentInteractionOption( - label=option.label.strip(), value=option.value.strip() + label=option.label.strip(), + value=option.value.strip(), + description=(option.description.strip() if option.description else None), ) for option in options ] @@ -172,6 +190,7 @@ class AskUserChoiceTool(MoviePilotTool): "callback_data": ( f"agent_interaction:choice:{request.request_id}:{index}" ), + "description": option.description or option.label, } ) if len(current_row) >= max_per_row: diff --git a/app/api/endpoints/agent.py b/app/api/endpoints/agent.py index 1a8a5cd1..87285331 100644 --- a/app/api/endpoints/agent.py +++ b/app/api/endpoints/agent.py @@ -1,4 +1,5 @@ import asyncio +import copy import hashlib import json import mimetypes @@ -6,8 +7,10 @@ import shutil import subprocess import time import uuid +from queue import Empty, Queue from pathlib import Path -from typing import Any, AsyncIterator, Callable, Optional +from threading import Lock +from typing import Any, AsyncIterator, Callable, Optional, Union from fastapi import APIRouter, Depends, File, Form, HTTPException, Request, UploadFile, status from fastapi.concurrency import run_in_threadpool @@ -15,17 +18,24 @@ from fastapi.responses import FileResponse, StreamingResponse from sqlalchemy.ext.asyncio import AsyncSession from app import schemas -from app.agent import MoviePilotAgent, ReplyMode, StreamingHandler +from app.agent import MoviePilotAgent, ReplyMode, StreamingHandler, agent_manager from app.agent.llm.capability import AgentCapabilityManager +from app.chain.message import MessageChain +from app.chain.site import site_interaction_manager +from app.chain.skills import skills_interaction_manager +from app.chain.subscribe import subscribe_interaction_manager +from app.command import Command from app.core.config import global_vars, settings +from app.core.event import Event, EventManager from app.db import get_async_db from app.db.agentchat_oper import AgentChatOper from app.db.models import User from app.db.models.agentchat import AgentChat from app.db.user_oper import UserOper, get_current_active_user -from app.helper.interaction import agent_interaction_manager +from app.helper.agent import attach_web_agent_edit_queue, detach_web_agent_edit_queue +from app.helper.interaction import agent_interaction_manager, media_interaction_manager from app.log import logger -from app.schemas.types import MessageChannel +from app.schemas.types import EventType, MessageChannel router = APIRouter() @@ -36,7 +46,12 @@ WEB_AGENT_FILE_MAX_ITEMS = 256 WEB_AGENT_UPLOAD_MAX_BYTES = 32 * 1024 * 1024 WEB_AGENT_UPLOAD_CHUNK_SIZE = 1024 * 1024 WEB_AGENT_BROWSER_AUDIO_SUFFIXES = {".aac", ".m4a", ".mp3", ".mp4", ".wav", ".wave"} +WEB_AGENT_TRADITIONAL_IDLE_TIMEOUT_SECONDS = 2.0 +WEB_AGENT_TRADITIONAL_MAX_WAIT_SECONDS = 60.0 _WEB_AGENT_FILE_REGISTRY: dict[str, dict[str, Any]] = {} +_WEB_AGENT_NOTICE_QUEUES: dict[str, list[Queue[schemas.Notification]]] = {} +_WEB_AGENT_NOTICE_LOCK = Lock() +_WEB_AGENT_NOTICE_LISTENER_REGISTERED = False class _WebAgentStreamingHandler(StreamingHandler): @@ -48,6 +63,14 @@ class _WebAgentStreamingHandler(StreamingHandler): super().__init__() self._on_emit = on_emit + def set_emit_callback(self, on_emit: Callable[[str], None]) -> None: + """ + 更新流式输出回调,复用 WebAgent 实例时指向当前 SSE 请求。 + + :param on_emit: 当前请求的输出回调 + """ + self._on_emit = on_emit + def emit(self, token: str) -> str: """追加 token 并同步通知 SSE 生产者。""" emitted = super().emit(token) @@ -124,6 +147,27 @@ class _WebAgentMoviePilotAgent(MoviePilotAgent): """Web 面板需要实时输出,即使 Web 渠道本身不支持消息编辑。""" return True + def set_notification_callback( + self, + notification_callback: Optional[Callable[[schemas.Notification], None]], + ) -> None: + """ + 更新 Web SSE 通知回调,复用 Agent 实例时指向当前请求队列。 + + :param notification_callback: 当前请求的 Web 通知回调 + """ + self._notification_callback = notification_callback + + def set_output_callback(self, output_callback: Optional[Callable[[str], None]]) -> None: + """ + 更新 Web SSE 输出回调,复用 Agent 实例时指向当前请求队列。 + + :param output_callback: 当前请求的输出回调 + """ + self.output_callback = output_callback + if output_callback and isinstance(self.stream_handler, _WebAgentStreamingHandler): + self.stream_handler.set_emit_callback(self._emit_output) + async def _is_system_admin_context(self) -> bool: """Web Agent 根据当前登录用户 ID 判断工具管理员上下文。""" if not self.user_id: @@ -218,6 +262,14 @@ def _apply_web_agent_display_event(event: dict, assistant_message: dict) -> None assistant_message["attachments"].append(event["attachment"]) elif event_type == "choice" and event.get("choice"): assistant_message["choices"].append({**event["choice"], "status": "pending"}) + elif event_type == "message_update": + target_message = event.get("target_message") or {} + assistant_message["id"] = target_message.get("id") or assistant_message.get("id") + assistant_message["content"] = target_message.get("content") or "" + assistant_message["attachments"] = target_message.get("attachments") or [] + assistant_message["choices"] = target_message.get("choices") or [] + assistant_message["tools"] = target_message.get("tools") or [] + assistant_message["status"] = target_message.get("status") or "done" elif event_type == "error": assistant_message["status"] = "error" assistant_message["content"] = ( @@ -657,27 +709,32 @@ def _parse_web_agent_choice_callback(callback_data: str) -> Optional[tuple[str, return request_id, int(option_index) -def _flatten_web_agent_choice_buttons(buttons: Optional[list[list[dict]]]) -> list[dict]: +def _normalize_web_agent_choice_button_rows(buttons: Optional[list[list[dict]]]) -> list[list[dict]]: """ - 将消息渠道按钮二维结构转换为 Web 前端可渲染的一维选项列表。 + 将消息渠道按钮二维结构转换为 Web 前端可渲染的按钮行。 :param buttons: Notification 中的按钮行 - :return: Web 选择卡片按钮列表 + :return: Web 选择卡片按钮行 """ - flattened = [] + normalized_rows = [] for row in buttons or []: + normalized_row = [] for button in row or []: text = str(button.get("text") or "").strip() callback_data = str(button.get("callback_data") or "").strip() if not text or not callback_data: continue - flattened.append( + description = str(button.get("description") or "").strip() + normalized_row.append( { "label": text, "callback_data": callback_data, + **({"description": description} if description else {}), } ) - return flattened + if normalized_row: + normalized_rows.append(normalized_row) + return normalized_rows def _build_web_agent_choice_event(notification: schemas.Notification) -> Optional[dict]: @@ -687,7 +744,8 @@ def _build_web_agent_choice_event(notification: schemas.Notification) -> Optiona :param notification: Agent 工具发出的按钮通知 :return: 选择卡片事件,按钮为空时返回 None """ - buttons = _flatten_web_agent_choice_buttons(notification.buttons) + button_rows = _normalize_web_agent_choice_button_rows(notification.buttons) + buttons = [button for row in button_rows for button in row] if not buttons: return None @@ -703,10 +761,32 @@ def _build_web_agent_choice_event(notification: schemas.Notification) -> Optiona "title": notification.title, "prompt": notification.text or "", "buttons": buttons, + "button_rows": button_rows, }, } +def _build_web_agent_choice_buttons_from_request( + request, +) -> tuple[list[dict], list[list[dict]]]: + """ + 根据待处理交互请求重建可持久化的按钮列表与按钮行。 + + :param request: 等待用户选择的交互请求 + :return: 平铺按钮列表与按行分组的按钮结构 + """ + buttons = [ + { + "label": option.label, + "callback_data": f"agent_interaction:choice:{request.request_id}:{index}", + "description": option.description or option.label, + } + for index, option in enumerate(request.options, start=1) + ] + button_rows = [[button] for button in buttons] + return buttons, button_rows + + def _resolve_web_agent_choice_payload(callback_data: str, user_id: str) -> Optional[dict]: """ 解析并消费 Web Agent 按钮选择,生成前端反馈与下一条用户消息。 @@ -729,15 +809,31 @@ def _resolve_web_agent_choice_payload(callback_data: str, user_id: str) -> Optio return None request, option = resolved + buttons, button_rows = _build_web_agent_choice_buttons_from_request(request) + selected_description = option.description or option.label return { "message": option.value, "session_id": request.session_id, + "display_message": selected_description, + "choice_selection": { + "choice_id": request.request_id, + "title": request.title, + "prompt": request.prompt, + "buttons": buttons, + "button_rows": button_rows, + "selected_label": option.label, + "selected_value": option.value, + "selected_description": selected_description, + }, "feedback": { "request_id": request.request_id, "title": request.title, "prompt": request.prompt, "selected_label": option.label, "selected_value": option.value, + "selected_description": selected_description, + "buttons": buttons, + "button_rows": button_rows, }, } @@ -802,6 +898,354 @@ def _build_web_agent_notification_events( return events +def _build_web_agent_display_message_from_events( + events: list[dict], +) -> dict: + """ + 将传统消息事件聚合为前端展示消息快照。 + + :param events: 已转换的 WebAgent SSE 事件列表 + :return: 可持久化的助手展示消息 + """ + message = MoviePilotAgent.build_display_message( + role="assistant", + status="streaming", + ) + for event in events: + _apply_web_agent_display_event(copy.deepcopy(event), message) + _apply_web_agent_display_event({"type": "done"}, message) + return message + + +def _is_web_agent_traditional_message(text: str) -> bool: + """ + 判断用户输入是否应走传统消息命令/交互链路。 + + :param text: 前端输入文本 + :return: 需要交给 MessageChain 时返回 True + """ + normalized = str(text or "").strip() + return normalized.startswith("/") or normalized.startswith("CALLBACK:") + + +def _has_web_agent_traditional_interaction(user_id: str) -> bool: + """ + 判断当前用户是否存在待继续的传统交互会话。 + + :param user_id: 当前登录用户 ID + :return: 存在传统交互上下文时返回 True + """ + return any( + manager.get_by_user(user_id) + for manager in ( + site_interaction_manager, + subscribe_interaction_manager, + skills_interaction_manager, + media_interaction_manager, + ) + ) + + +def _extract_web_agent_notification_from_event_data( + data: dict, +) -> Optional[schemas.Notification]: + """ + 从 NoticeMessage 事件数据中提取 WebAgent 通知。 + + :param data: NoticeMessage 事件数据,兼容扁平字段和 message 包装格式 + :return: WebAgent 通知,不属于 WebAgent 或数据无效时返回 None + """ + if not isinstance(data, dict): + return None + + try: + message = data.get("message") + if isinstance(message, schemas.Notification): + notification = message + elif isinstance(message, dict): + notification_data = copy.deepcopy(message) + notification_data.pop("type", None) + notification = schemas.Notification(**notification_data) + else: + notification_data = copy.deepcopy(data) + notification_data.pop("type", None) + notification_data.pop("current_time", None) + notification = schemas.Notification(**notification_data) + except Exception as err: + logger.debug(f"解析WebAgent通知事件失败: {err}") + return None + + channel = notification.channel + channel_value = channel.value if isinstance(channel, MessageChannel) else channel + if channel_value != MessageChannel.WebAgent.value: + return None + return notification + + +def _is_web_agent_notice_for_user( + notification: schemas.Notification, + user_id: str, +) -> bool: + """ + 判断 NoticeMessage 事件是否属于当前 WebAgent 用户。 + + :param notification: NoticeMessage 中的通知消息 + :param user_id: 当前登录用户 ID + :return: 可被本次 WebAgent 请求消费时返回 True + """ + try: + target_user = notification.userid + return target_user is None or str(target_user) == str(user_id) + except Exception: + return False + + +def _get_web_agent_notice_user_id(notification: schemas.Notification) -> Optional[str]: + """ + 从 NoticeMessage 事件中解析 WebAgent 目标用户。 + + :param notification: NoticeMessage 中的通知消息 + :return: 用户 ID 字符串,事件不属于 WebAgent 时返回 None + """ + try: + channel = notification.channel + channel_value = channel.value if isinstance(channel, MessageChannel) else channel + if channel_value != MessageChannel.WebAgent.value: + return None + user_id = notification.userid + return str(user_id) if user_id is not None else None + except Exception: + return None + + +def _dispatch_web_agent_notice_event(event: Event) -> None: + """ + 将 WebAgent NoticeMessage 分发给正在等待的请求队列。 + + :param event: NoticeMessage 广播事件 + """ + data = event.event_data if isinstance(event.event_data, dict) else {} + notification = _extract_web_agent_notification_from_event_data(data) + if not notification: + return + with _WEB_AGENT_NOTICE_LOCK: + user_id = _get_web_agent_notice_user_id(notification) + if user_id is None: + queues = [ + notice_queue + for user_queues in _WEB_AGENT_NOTICE_QUEUES.values() + for notice_queue in user_queues + ] + else: + queues = list(_WEB_AGENT_NOTICE_QUEUES.get(user_id) or []) + for notice_queue in queues: + notice_queue.put(notification) + + +def _ensure_web_agent_notice_listener() -> None: + """ + 确保 WebAgent NoticeMessage 全局监听器已注册。 + """ + global _WEB_AGENT_NOTICE_LISTENER_REGISTERED + if _WEB_AGENT_NOTICE_LISTENER_REGISTERED: + return + with _WEB_AGENT_NOTICE_LOCK: + if _WEB_AGENT_NOTICE_LISTENER_REGISTERED: + return + EventManager().add_event_listener( + EventType.NoticeMessage, + _dispatch_web_agent_notice_event, + ) + _WEB_AGENT_NOTICE_LISTENER_REGISTERED = True + + +def _attach_web_agent_notice_queue(user_id: str, notice_queue: Queue[schemas.Notification]) -> None: + """ + 为当前 WebAgent 请求挂载通知收集队列。 + + :param user_id: 当前用户 ID + :param notice_queue: 用于接收通知事件的队列 + """ + _ensure_web_agent_notice_listener() + with _WEB_AGENT_NOTICE_LOCK: + _WEB_AGENT_NOTICE_QUEUES.setdefault(str(user_id), []).append(notice_queue) + + +def _detach_web_agent_notice_queue(user_id: str, notice_queue: Queue[schemas.Notification]) -> None: + """ + 移除当前 WebAgent 请求的通知收集队列。 + + :param user_id: 当前用户 ID + :param notice_queue: 需要移除的队列 + """ + with _WEB_AGENT_NOTICE_LOCK: + queues = _WEB_AGENT_NOTICE_QUEUES.get(str(user_id)) + if not queues: + return + _WEB_AGENT_NOTICE_QUEUES[str(user_id)] = [ + item for item in queues if item is not notice_queue + ] + if not _WEB_AGENT_NOTICE_QUEUES[str(user_id)]: + _WEB_AGENT_NOTICE_QUEUES.pop(str(user_id), None) + + +def _build_web_agent_command_items() -> list[dict]: + """ + 读取当前可用斜杠命令并转换为前端建议列表。 + + :return: 按分类和命令名排序的命令列表 + """ + commands = Command().get_commands() or {} + items = [] + for command, data in commands.items(): + if not command.startswith("/"): + continue + if data.get("show") is False: + continue + items.append( + { + "command": command, + "description": data.get("description") or "", + "category": data.get("category") or "其他", + "type": data.get("type") or "", + "pid": data.get("pid"), + } + ) + return sorted(items, key=lambda item: (item["category"], item["command"])) + + +def _extract_web_agent_slash_command(text: str) -> Optional[str]: + """ + 从 WebAgent 输入中提取斜杠命令名。 + + :param text: 前端输入文本 + :return: 斜杠命令名,非命令输入返回 None + """ + normalized = str(text or "").strip() + if not normalized.startswith("/") or normalized.startswith("//"): + return None + command = normalized.split(maxsplit=1)[0].strip() + return command or None + + +def _get_web_agent_unknown_command_message(text: str) -> Optional[str]: + """ + 判断 WebAgent 斜杠命令是否不存在。 + + :param text: 前端输入文本 + :return: 命令不存在时返回错误提示,命令存在或非命令时返回 None + """ + command = _extract_web_agent_slash_command(text) + if not command: + return None + if Command().get(command): + return None + return f"命令不存在:{command}" + + +def _ensure_web_agent_command_allowed(current_user: User) -> Optional[str]: + """ + 校验当前 Web 用户是否可以执行传统斜杠命令。 + + :param current_user: 当前登录用户 + :return: 无权限时返回错误提示,允许执行时返回 None + """ + if getattr(current_user, "is_superuser", False): + return None + return "只有管理员才有权限执行此命令" + + +async def _collect_web_agent_traditional_events( + *, + text: str, + current_user: User, + original_message_id: Optional[Union[str, int]] = None, + original_chat_id: Optional[Union[str, int]] = None, +) -> list[dict]: + """ + 执行传统消息链路并收集本次 WebAgent 用户产生的通知事件。 + + :param text: 需要交给传统消息链路处理的文本 + :param current_user: 当前登录用户 + :param original_message_id: WebAgent 原助手消息 ID + :param original_chat_id: WebAgent 原聊天 ID + :return: 可直接发送给前端的 SSE 事件列表 + """ + notice_queue: Queue[schemas.Notification] = Queue() + edit_queue: Queue[dict] = Queue() + user_id = str(current_user.id) + + _attach_web_agent_notice_queue(user_id, notice_queue) + attach_web_agent_edit_queue(user_id, edit_queue) + try: + await run_in_threadpool( + MessageChain().handle_message, + channel=MessageChannel.WebAgent, + source=WEB_AGENT_SOURCE, + userid=user_id, + username=current_user.name or user_id, + text=text, + original_message_id=original_message_id, + original_chat_id=original_chat_id, + ) + + events = [] + deadline = time.monotonic() + WEB_AGENT_TRADITIONAL_MAX_WAIT_SECONDS + idle_deadline: Optional[float] = None + while time.monotonic() < deadline: + now = time.monotonic() + drained_edit_event = False + while True: + try: + events.append(edit_queue.get_nowait()) + drained_edit_event = True + except Empty: + break + if drained_edit_event: + idle_deadline = time.monotonic() + WEB_AGENT_TRADITIONAL_IDLE_TIMEOUT_SECONDS + continue + + wait_until = idle_deadline or deadline + timeout = max(0.05, min(0.25, wait_until - now, deadline - now)) + try: + notification = await asyncio.to_thread(notice_queue.get, True, timeout) + except Empty: + if idle_deadline and time.monotonic() >= idle_deadline: + break + continue + + if not _is_web_agent_notice_for_user(notification, user_id): + continue + events.extend(_build_web_agent_notification_events(notification)) + idle_deadline = time.monotonic() + WEB_AGENT_TRADITIONAL_IDLE_TIMEOUT_SECONDS + return events + finally: + _detach_web_agent_notice_queue(user_id, notice_queue) + detach_web_agent_edit_queue(user_id, edit_queue) + + +def _build_web_agent_traditional_callback_payload( + callback_data: str, + original_message_id: Optional[Union[str, int]] = None, + original_chat_id: Optional[Union[str, int]] = None, +) -> dict: + """ + 构造传统消息链按钮回调的前端执行载荷。 + + :param callback_data: 前端点击的传统按钮回调数据 + :param original_message_id: WebAgent 原助手消息 ID + :param original_chat_id: WebAgent 原聊天 ID + :return: 前端可继续发送到 /stream 的消息载荷 + """ + return { + "message": f"CALLBACK:{callback_data}", + "display_message": callback_data, + "traditional": True, + "original_message_id": original_message_id, + "original_chat_id": original_chat_id, + } + + def _split_web_agent_output(text: str) -> list[dict]: """ 将 Agent 输出拆成普通文本与工具提示事件。 @@ -943,6 +1387,19 @@ async def web_agent_callback( :param current_user: 当前登录用户 :return: 下一条需要发送给 Agent 的用户消息与卡片反馈 """ + if not _parse_web_agent_choice_callback(payload.callback_data): + denied_message = _ensure_web_agent_command_allowed(current_user) + if denied_message: + return schemas.Response(success=False, message=denied_message) + return schemas.Response( + success=True, + data=_build_web_agent_traditional_callback_payload( + payload.callback_data, + original_message_id=payload.original_message_id, + original_chat_id=payload.original_chat_id, + ), + ) + result = _resolve_web_agent_choice_payload( callback_data=payload.callback_data, user_id=str(current_user.id), @@ -952,6 +1409,22 @@ async def web_agent_callback( return schemas.Response(success=True, data=result) +@router.get("/commands", summary="获取 Web 智能助手可用命令", response_model=schemas.Response) +async def list_web_agent_commands( + current_user: User = Depends(get_current_active_user), +) -> schemas.Response: + """ + 获取当前 Web 智能助手可补全的斜杠命令。 + + :param current_user: 当前登录用户 + :return: 可用命令列表 + """ + denied_message = _ensure_web_agent_command_allowed(current_user) + if denied_message: + return schemas.Response(success=False, message=denied_message) + return schemas.Response(success=True, data=_build_web_agent_command_items()) + + @router.get("/sessions", summary="获取 Agent 历史会话", response_model=schemas.Response) async def list_agent_chat_sessions( current_user: User = Depends(get_current_active_user), @@ -1076,6 +1549,92 @@ async def web_agent_stream( :param current_user: 当前登录用户 :return: SSE 流式响应 """ + prompt = payload.text.strip() + display_prompt = (payload.display_text or payload.text).strip() + is_traditional_message = ( + _is_web_agent_traditional_message(prompt) + or _has_web_agent_traditional_interaction(str(current_user.id)) + ) + if is_traditional_message: + denied_message = _ensure_web_agent_command_allowed(current_user) + if denied_message: + return StreamingResponse( + iter([ + _build_web_agent_sse( + "error", + {"message": denied_message}, + ) + ]), + media_type="text/event-stream", + ) + unknown_command_message = _get_web_agent_unknown_command_message(prompt) + if unknown_command_message: + return StreamingResponse( + iter([ + _build_web_agent_sse( + "error", + {"message": unknown_command_message}, + ) + ]), + media_type="text/event-stream", + ) + + session_id = _build_web_agent_session_id(current_user, payload.session_id) + user_attachments = _build_web_agent_input_attachments( + images=payload.images or [], + files=[ + file.model_dump(exclude_none=True) + for file in (payload.files or []) + ], + audio_refs=payload.audio_refs or [], + ) + display_messages = [] + if payload.echo_user: + display_messages.append( + MoviePilotAgent.build_display_message( + role="user", + content=display_prompt or prompt, + attachments=user_attachments, + ) + ) + + async def traditional_event_generator() -> AsyncIterator[str]: + """ + 生成传统消息链路的 WebAgent SSE 事件。 + """ + yield _build_web_agent_sse("start", {"session_id": session_id}) + events = await _collect_web_agent_traditional_events( + text=prompt, + current_user=current_user, + original_message_id=payload.original_message_id, + original_chat_id=payload.original_chat_id, + ) + assistant_message = _build_web_agent_display_message_from_events(events) + display_messages.append(assistant_message) + for event in events: + event_payload = copy.deepcopy(event) + yield _build_web_agent_sse(event_payload.pop("type"), event_payload) + if await request.is_disconnected(): + break + await run_in_threadpool( + _save_web_agent_display_snapshot, + session_id=session_id, + current_user=current_user, + messages=display_messages, + client_session_id=payload.session_id or session_id, + ) + yield _build_web_agent_sse("done", {}) + + return StreamingResponse( + traditional_event_generator(), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no", + }, + ) + if not settings.AI_AGENT_ENABLE: return StreamingResponse( iter([ @@ -1087,9 +1646,9 @@ async def web_agent_stream( media_type="text/event-stream", ) - prompt = payload.text.strip() transcript = _transcribe_web_agent_audio_refs(payload.audio_refs or []) prompt = _merge_web_agent_prompt_with_transcript(prompt, transcript) + display_prompt = _merge_web_agent_prompt_with_transcript(display_prompt, transcript) has_audio_input = bool(transcript) if not prompt and payload.audio_refs and not payload.images and not payload.files: return StreamingResponse( @@ -1113,6 +1672,7 @@ async def web_agent_stream( ) session_id = _build_web_agent_session_id(current_user, payload.session_id) + MessageChain().bind_user_session(str(current_user.id), session_id) event_queue: asyncio.Queue = asyncio.Queue() last_output = "" user_attachments = _build_web_agent_input_attachments( @@ -1125,13 +1685,14 @@ async def web_agent_stream( ) display_messages = [] if payload.echo_user: - display_messages.append( - MoviePilotAgent.build_display_message( - role="user", - content=prompt, - attachments=user_attachments, - ) + user_display_message = MoviePilotAgent.build_display_message( + role="user", + content=display_prompt or prompt, + attachments=user_attachments, ) + if payload.choice_selection: + user_display_message["choice_selection"] = payload.choice_selection + display_messages.append(user_display_message) assistant_display_message = MoviePilotAgent.build_display_message( role="assistant", status="streaming", @@ -1170,26 +1731,25 @@ async def web_agent_stream( for audio_ref in payload.audio_refs or []: files.append({"ref": audio_ref, "mime_type": "audio/*"}) - agent = _WebAgentMoviePilotAgent( - session_id=session_id, - user_id=str(current_user.id), - channel=MessageChannel.WebAgent.value, - source=WEB_AGENT_SOURCE, - username=current_user.name, - replay_mode=ReplyMode.CAPTURE_ONLY, - allow_message_tools=True, - output_callback=output_callback, - notification_callback=notification_callback, - ) - async def run_agent() -> None: """后台执行 Agent,并将结果写入事件队列。""" try: - await agent.process( + await agent_manager.process_message( + session_id=session_id, + user_id=str(current_user.id), message=prompt, images=payload.images or [], files=files or None, has_audio_input=has_audio_input, + channel=MessageChannel.WebAgent.value, + source=WEB_AGENT_SOURCE, + username=current_user.name, + reply_mode=ReplyMode.CAPTURE_ONLY, + allow_message_tools=True, + output_callback=output_callback, + notification_callback=notification_callback, + agent_factory=_WebAgentMoviePilotAgent, + wait_for_completion=True, ) except Exception as err: logger.error(f"Web智能助手执行失败: {str(err)}") @@ -1233,7 +1793,7 @@ async def web_agent_stream( await task except asyncio.CancelledError: pass - await agent.cleanup() + # WebAgent 会话由 AgentManager 统一管理,空闲清理或 /clear_session 时释放。 return StreamingResponse( event_generator(), diff --git a/app/chain/__init__.py b/app/chain/__init__.py index 26cd0d6c..72e16bb9 100644 --- a/app/chain/__init__.py +++ b/app/chain/__init__.py @@ -1769,6 +1769,21 @@ class ChainBase(metaclass=ABCMeta): :param metadata: 其他消息元数据 :return: 编辑是否成功 """ + if channel == MessageChannel.WebAgent: + try: + from app.helper.agent import edit_web_agent_message + + return edit_web_agent_message( + user_id=str((metadata or {}).get("userid") or ""), + message_id=message_id, + title=title, + text=text, + buttons=buttons, + ) + except Exception as err: + logger.debug(f"编辑 WebAgent 消息失败: {err}") + return False + return self.run_module( "edit_message", channel=channel, diff --git a/app/chain/message.py b/app/chain/message.py index a2bb1382..eb7a09e4 100644 --- a/app/chain/message.py +++ b/app/chain/message.py @@ -1000,6 +1000,15 @@ class MessageChain(ChainBase): self._schedule_agent_session_clear(old_session[0], userid) self._user_sessions[userid] = (session_id, datetime.now()) + def bind_user_session(self, userid: Union[str, int], session_id: str) -> None: + """ + 绑定用户与指定智能体会话,供非传统入口复用远程命令状态查询。 + + :param userid: 用户 ID + :param session_id: 智能体会话 ID + """ + self._bind_session_id(userid, session_id) + def _record_user_message( self, channel: MessageChannel, diff --git a/app/chain/skills.py b/app/chain/skills.py index 4eaf0f1c..865719f2 100644 --- a/app/chain/skills.py +++ b/app/chain/skills.py @@ -722,7 +722,7 @@ class SkillsChain(ChainBase): buttons = None if self._supports_interactive_buttons(request.channel): buttons = [] - for index, skill in enumerate(page_items, start=1): + for index, skill in enumerate(items, start=1): if not skill.removable: continue buttons.append( diff --git a/app/core/event.py b/app/core/event.py index c78b5a5a..d0807745 100644 --- a/app/core/event.py +++ b/app/core/event.py @@ -525,6 +525,13 @@ class EventManager(metaclass=Singleton): class_name=class_name, method_name=method_name, e=e) else: # 全局处理器 + if not class_name: + try: + handler(event) + except Exception as e: + self.__handle_event_error(event=event, module_name=self.__get_handler_identifier(handler), + class_name=class_name, method_name=method_name, e=e) + return class_obj = self.__get_class_instance(class_name) if not class_obj or not hasattr(class_obj, method_name): return @@ -556,6 +563,16 @@ class EventManager(metaclass=Singleton): elif class_name in module_manager.get_module_ids(): await self.__invoke_module_method_async(module_manager, class_name, method_name, event) else: + if not class_name: + try: + if inspect.iscoroutinefunction(handler): + await handler(event) + else: + await run_in_threadpool(handler, event) + except Exception as e: + self.__handle_event_error(event=event, module_name=self.__get_handler_identifier(handler), + class_name=class_name, method_name=method_name, e=e) + return await self.__invoke_global_method_async(class_name, method_name, event) @staticmethod @@ -566,6 +583,8 @@ class EventManager(metaclass=Singleton): :return: (class_name, method_name) """ names = handler.__qualname__.split(".") + if len(names) < 2: + return "", names[0] return names[0], names[1] async def __invoke_plugin_method_async(self, handler: Any, class_name: str, method_name: str, event: Event): diff --git a/app/helper/agent.py b/app/helper/agent.py new file mode 100644 index 00000000..a3708c4b --- /dev/null +++ b/app/helper/agent.py @@ -0,0 +1,175 @@ +from queue import Queue +from threading import Lock +from typing import Optional, Union + + +_WEB_AGENT_EDIT_QUEUES: dict[str, list[Queue[dict]]] = {} +_WEB_AGENT_EDIT_LOCK = Lock() + + +def normalize_web_agent_button_rows(buttons: Optional[list[list[dict]]]) -> list[list[dict]]: + """ + 将消息按钮转换为 WebAgent 前端可识别的按钮行。 + + :param buttons: 传统消息模块返回的按钮二维数组 + :return: WebAgent 前端选项按钮二维数组 + """ + button_rows: list[list[dict]] = [] + for row in buttons or []: + normalized_row = [] + for button in row or []: + label = str(button.get("text") or button.get("label") or "").strip() + callback_data = str(button.get("callback_data") or "").strip() + if not label or not callback_data: + continue + normalized_button = { + "label": label, + "callback_data": callback_data, + } + if button.get("description"): + normalized_button["description"] = str(button.get("description")) + normalized_row.append(normalized_button) + if normalized_row: + button_rows.append(normalized_row) + return button_rows + + +def _resolve_web_agent_choice_id( + message_id: Union[str, int], + button_rows: list[list[dict]], +) -> str: + """ + 从按钮回调中提取稳定的 WebAgent 选项 ID。 + + :param message_id: 前端助手消息 ID + :param button_rows: 已规范化的按钮行 + :return: 选项卡片 ID + """ + for row in button_rows: + for button in row: + callback_data = str(button.get("callback_data") or "").strip() + if not callback_data: + continue + parts = callback_data.split(":") + if len(parts) >= 2 and parts[1]: + return parts[1] + return callback_data + return str(message_id) + + +def build_web_agent_message_update_event( + *, + message_id: Union[str, int], + title: Optional[str], + text: str, + buttons: Optional[list[list[dict]]], +) -> dict: + """ + 构造 WebAgent 原消息更新事件。 + + :param message_id: 前端助手消息 ID + :param title: 更新后的标题 + :param text: 更新后的正文 + :param buttons: 更新后的按钮 + :return: 前端可应用到原消息的 SSE 事件 + """ + button_rows = normalize_web_agent_button_rows(buttons) + content_parts = [part for part in (title, text) if part] + target_message = { + "id": str(message_id), + "content": "" if button_rows else "\n\n".join(content_parts), + "choices": [], + "attachments": [], + "tools": [], + "status": "done", + } + if button_rows: + target_message["choices"].append({ + "id": _resolve_web_agent_choice_id(message_id, button_rows), + "title": title, + "prompt": text or "", + "buttons": [button for row in button_rows for button in row], + "button_rows": button_rows, + "status": "pending", + }) + return { + "type": "message_update", + "target_message": target_message, + } + + +def attach_web_agent_edit_queue(user_id: str, edit_queue: Queue[dict]) -> None: + """ + 为当前 WebAgent 请求挂载原消息编辑事件队列。 + + :param user_id: 当前用户 ID + :param edit_queue: 用于接收编辑事件的队列 + """ + with _WEB_AGENT_EDIT_LOCK: + _WEB_AGENT_EDIT_QUEUES.setdefault(str(user_id), []).append(edit_queue) + + +def detach_web_agent_edit_queue(user_id: str, edit_queue: Queue[dict]) -> None: + """ + 移除当前 WebAgent 请求的原消息编辑事件队列。 + + :param user_id: 当前用户 ID + :param edit_queue: 需要移除的队列 + """ + with _WEB_AGENT_EDIT_LOCK: + queues = _WEB_AGENT_EDIT_QUEUES.get(str(user_id)) + if not queues: + return + _WEB_AGENT_EDIT_QUEUES[str(user_id)] = [ + item for item in queues if item is not edit_queue + ] + if not _WEB_AGENT_EDIT_QUEUES[str(user_id)]: + _WEB_AGENT_EDIT_QUEUES.pop(str(user_id), None) + + +def dispatch_web_agent_edit_event( + *, + user_id: str, + event: dict, +) -> bool: + """ + 将 WebAgent 原消息编辑事件分发给正在等待的请求队列。 + + :param user_id: 当前用户 ID + :param event: 前端可应用的 SSE 事件 + :return: 是否存在接收本次编辑事件的请求队列 + """ + with _WEB_AGENT_EDIT_LOCK: + queues = list(_WEB_AGENT_EDIT_QUEUES.get(str(user_id)) or []) + for edit_queue in queues: + edit_queue.put(event) + return bool(queues) + + +def edit_web_agent_message( + *, + user_id: str, + message_id: Union[str, int], + title: Optional[str], + text: str, + buttons: Optional[list[list[dict]]] = None, +) -> bool: + """ + 原地更新 WebAgent 前端消息卡片。 + + :param user_id: 当前用户 ID + :param message_id: 前端助手消息 ID + :param title: 更新后的标题 + :param text: 更新后的正文 + :param buttons: 更新后的按钮 + :return: 是否已投递编辑事件 + """ + if not user_id: + return False + event = build_web_agent_message_update_event( + message_id=message_id, + title=title, + text=text, + buttons=buttons, + ) + return dispatch_web_agent_edit_event(user_id=user_id, event=event) diff --git a/app/helper/interaction.py b/app/helper/interaction.py index 5c75a873..7ebbd0f3 100644 --- a/app/helper/interaction.py +++ b/app/helper/interaction.py @@ -197,6 +197,9 @@ def update_or_post_message( and original_chat_id and ChannelCapabilityManager.supports_editing(channel) ): + edit_kwargs = {} + if channel == MessageChannel.WebAgent: + edit_kwargs["metadata"] = {"userid": userid} edited = chain.edit_message( channel=channel, source=source, @@ -205,6 +208,7 @@ def update_or_post_message( title=title, text=text, buttons=buttons, + **edit_kwargs, ) if edited: return @@ -402,6 +406,7 @@ class AgentInteractionOption: label: str value: str + description: Optional[str] = None @dataclass diff --git a/app/schemas/agent.py b/app/schemas/agent.py index ff94452d..9e3b24dc 100644 --- a/app/schemas/agent.py +++ b/app/schemas/agent.py @@ -88,6 +88,22 @@ class AgentChatChoiceButton(BaseModel): label: str = Field(..., description="按钮文案") callback_data: str = Field(..., description="回调数据") + description: Optional[str] = Field(None, description="选项描述") + + +class AgentChatChoiceSelection(BaseModel): + """ + Agent 会话中用户选择的展示快照。 + """ + + choice_id: str = Field(..., description="选择卡片 ID") + title: Optional[str] = Field(None, description="标题") + prompt: str = Field(default="", description="提示语") + buttons: list[AgentChatChoiceButton] = Field(default_factory=list, description="按钮列表") + button_rows: list[list[AgentChatChoiceButton]] = Field(default_factory=list, description="按钮行") + selected_label: Optional[str] = Field(None, description="已选择文案") + selected_value: Optional[str] = Field(None, description="已选择值") + selected_description: Optional[str] = Field(None, description="已选择描述") class AgentChatChoiceCard(BaseModel): @@ -99,9 +115,11 @@ class AgentChatChoiceCard(BaseModel): title: Optional[str] = Field(None, description="标题") prompt: str = Field(default="", description="提示语") buttons: list[AgentChatChoiceButton] = Field(default_factory=list, description="按钮列表") + button_rows: list[list[AgentChatChoiceButton]] = Field(default_factory=list, description="按钮行") status: str = Field(default="pending", description="选择状态") selected_label: Optional[str] = Field(None, description="已选择文案") selected_value: Optional[str] = Field(None, description="已选择值") + selected_description: Optional[str] = Field(None, description="已选择描述") class AgentChatMessage(BaseModel): @@ -117,6 +135,7 @@ class AgentChatMessage(BaseModel): tools: list[AgentChatToolCall] = Field(default_factory=list, description="工具提示列表") attachments: list[AgentChatAttachment] = Field(default_factory=list, description="附件列表") choices: list[AgentChatChoiceCard] = Field(default_factory=list, description="选择卡片列表") + choice_selection: Optional[AgentChatChoiceSelection] = Field(None, description="用户选择项快照") class AgentChatSession(BaseModel): diff --git a/app/schemas/message.py b/app/schemas/message.py index e02261c3..08345709 100644 --- a/app/schemas/message.py +++ b/app/schemas/message.py @@ -330,6 +330,8 @@ class AgentWebChatRequest(BaseModel): # 用户本轮输入 text: str = Field(default="") + # 展示历史中记录的用户可读文本;为空时使用 text + display_text: Optional[str] = Field(None) # 前端会话标识,相同标识复用同一段 Agent 记忆 session_id: Optional[str] = Field(None) # 图片 URL 或 data URL 列表 @@ -338,6 +340,12 @@ class AgentWebChatRequest(BaseModel): audio_refs: Optional[List[str]] = Field(default_factory=list) # 文件附件列表 files: Optional[List[AgentWebChatFile]] = Field(default_factory=list) + # 用户通过按钮选择时的完整选择快照 + choice_selection: Optional[Dict[str, Any]] = Field(default=None) + # WebAgent 按钮回调关联的原消息 ID,用于传统交互原地编辑卡片 + original_message_id: Optional[Union[str, int]] = Field(default=None) + # WebAgent 按钮回调关联的原聊天 ID,用于传统交互原地编辑卡片 + original_chat_id: Optional[Union[str, int]] = Field(default=None) # 是否在展示历史中记录本轮用户消息 echo_user: bool = Field(default=True) @@ -351,6 +359,10 @@ class AgentWebChoiceRequest(BaseModel): session_id: Optional[str] = Field(None) # Agent 工具生成的按钮回调数据 callback_data: str = Field(..., min_length=1) + # WebAgent 原助手消息 ID,用于传统按钮回调原地编辑 + original_message_id: Optional[Union[str, int]] = Field(default=None) + # WebAgent 原聊天 ID,用于传统按钮回调原地编辑 + original_chat_id: Optional[Union[str, int]] = Field(default=None) class ChannelCapability(Enum): diff --git a/tests/test_skills_command.py b/tests/test_skills_command.py index 343ec9ee..d533459c 100644 --- a/tests/test_skills_command.py +++ b/tests/test_skills_command.py @@ -512,6 +512,48 @@ class TestSkillsCommand(unittest.TestCase): self.assertIn("仓库来源 · acme/custom-skills", text) self.assertIn("3. 管理技能源", text) + def test_skills_chain_installed_view_builds_remove_buttons(self): + chain = SkillsChain() + request = skills_interaction_manager.create_or_replace( + user_id="10001", + channel=MessageChannel.WebAgent, + source="web-agent", + username="tester", + ) + + with patch.object( + chain.skillhelper, + "list_local_skills", + return_value=[ + SkillInfo( + id="builtin", + name="Builtin", + description="Built in skill", + source_type="builtin", + source_label="内置", + removable=False, + ), + SkillInfo( + id="custom", + name="Custom", + description="Custom skill", + source_type="local", + source_label="本地", + removable=True, + ), + ], + ): + title, text, buttons = chain._build_installed_view(request=request) + + self.assertEqual(title, "已安装技能") + self.assertIn("builtin", text) + self.assertIn("custom", text) + self.assertTrue(buttons) + self.assertIn( + {"text": "删除 2", "callback_data": f"skills:{request.request_id}:remove:2"}, + [button for row in buttons for button in row], + ) + def test_skills_chain_callback_enters_search_input_mode(self): chain = SkillsChain() request = skills_interaction_manager.create_or_replace( diff --git a/tests/test_web_agent_stream.py b/tests/test_web_agent_stream.py index 2e5841e9..fe173458 100644 --- a/tests/test_web_agent_stream.py +++ b/tests/test_web_agent_stream.py @@ -1,27 +1,39 @@ import asyncio import time +from queue import Queue from types import SimpleNamespace from unittest.mock import AsyncMock, patch from app import schemas -from app.agent import ReplyMode +from app.agent import ReplyMode, agent_manager from app.api.endpoints.agent import ( _WebAgentMoviePilotAgent, _WEB_AGENT_FILE_REGISTRY, + _WEB_AGENT_NOTICE_QUEUES, _apply_web_agent_display_event, _build_web_agent_input_attachments, _build_web_agent_notification_events, + _build_web_agent_command_items, _build_web_agent_session_id, + _build_web_agent_traditional_callback_payload, + _build_web_agent_display_message_from_events, + _collect_web_agent_traditional_events, + _dispatch_web_agent_notice_event, + _extract_web_agent_notification_from_event_data, + _has_web_agent_traditional_interaction, _prepare_web_agent_audio_attachment_path, _transcribe_web_agent_audio_refs, web_agent_stream, _resolve_web_agent_choice_payload, _split_web_agent_output, ) +from app.core.event import Event from app.db.agentchat_oper import AgentChatOper -from app.helper.interaction import AgentInteractionOption, agent_interaction_manager +from app.helper.agent import build_web_agent_message_update_event +from app.helper.interaction import AgentInteractionOption, agent_interaction_manager, skills_interaction_manager +from app.chain.message import MessageChain from app.schemas.message import ChannelCapability, ChannelCapabilityManager -from app.schemas.types import MessageChannel, NotificationType +from app.schemas.types import EventType, MessageChannel, NotificationType def test_split_web_agent_output_extracts_verbose_tool_message(): @@ -144,6 +156,125 @@ def test_build_web_agent_input_attachments_marks_kinds(): assert attachments[1]["name"] == "report.txt" +def test_build_web_agent_command_items_returns_slash_commands(): + """WebAgent 命令建议应返回可展示的斜杠命令。""" + with patch( + "app.api.endpoints.agent.Command", + return_value=SimpleNamespace( + get_commands=lambda: { + "/sites": {"description": "管理站点", "category": "站点"}, + "hidden": {"description": "忽略", "category": "其他"}, + "/hidden": {"description": "隐藏", "category": "其他", "show": False}, + } + ), + ): + commands = _build_web_agent_command_items() + + assert commands == [ + { + "command": "/sites", + "description": "管理站点", + "category": "站点", + "type": "", + "pid": None, + } + ] + + +def test_build_web_agent_command_items_includes_sites_command(): + """WebAgent 命令建议应包含内建站点管理命令。""" + commands = _build_web_agent_command_items() + + assert any(command["command"] == "/sites" for command in commands) + + +def test_build_web_agent_traditional_callback_payload_wraps_callback(): + """传统按钮回调应包装为可继续提交给 MessageChain 的消息。""" + payload = _build_web_agent_traditional_callback_payload( + "skills:req-1:root", + original_message_id="assistant-1", + original_chat_id="web-session", + ) + + assert payload["message"] == "CALLBACK:skills:req-1:root" + assert payload["traditional"] is True + assert payload["original_message_id"] == "assistant-1" + assert payload["original_chat_id"] == "web-session" + + +def test_web_agent_stream_returns_error_for_unknown_command(): + """不存在的 WebAgent 斜杠命令应立即返回错误,不进入等待队列。""" + payload = schemas.AgentWebChatRequest( + text="/missing_command 参数", + session_id="browser-session", + ) + request = SimpleNamespace() + user = SimpleNamespace(id=1, name="admin", is_superuser=True) + + with patch( + "app.api.endpoints.agent.Command", + return_value=SimpleNamespace(get=lambda _: {}), + ), patch("app.api.endpoints.agent.MessageChain.handle_message") as handle_message: + response = asyncio.run(web_agent_stream(payload, request, user)) + body = "".join(asyncio.run(_collect_streaming_response(response))) + + assert "error" in body + assert "命令不存在:/missing_command" in body + handle_message.assert_not_called() + + +def test_build_web_agent_message_update_event_converts_buttons(): + """WebAgent 编辑消息应转换为可原地更新卡片的事件。""" + event = build_web_agent_message_update_event( + message_id="assistant-1", + title="技能管理", + text="请选择操作", + buttons=[[{"text": "返回", "callback_data": "skills:req-1:root"}]], + ) + + assert event["type"] == "message_update" + assert event["target_message"]["id"] == "assistant-1" + assert event["target_message"]["choices"][0]["title"] == "技能管理" + assert event["target_message"]["choices"][0]["prompt"] == "请选择操作" + assert event["target_message"]["choices"][0]["buttons"][0]["label"] == "返回" + + +def test_build_web_agent_display_message_from_events_marks_done(): + """传统消息事件应聚合为完成态助手展示消息。""" + message = _build_web_agent_display_message_from_events([ + {"type": "delta", "content": "菜单"}, + { + "type": "choice", + "choice": { + "id": "choice-1", + "prompt": "请选择", + "buttons": [{"label": "返回", "callback_data": "back"}], + }, + }, + ]) + + assert message["content"] == "菜单" + assert message["status"] == "done" + assert message["choices"][0]["prompt"] == "请选择" + + +def test_has_web_agent_traditional_interaction_detects_pending_skills(): + """WebAgent 应能识别命令后的传统交互上下文。""" + skills_interaction_manager.clear() + try: + skills_interaction_manager.create_or_replace( + user_id="1", + channel=MessageChannel.WebAgent, + source="web-agent", + username="admin", + ) + + assert _has_web_agent_traditional_interaction("1") is True + assert _has_web_agent_traditional_interaction("2") is False + finally: + skills_interaction_manager.clear() + + def test_web_agent_admin_context_uses_current_user_id(): """Web Agent 工具权限应按当前登录用户 ID 判断管理员身份。""" agent = _WebAgentMoviePilotAgent( @@ -213,6 +344,69 @@ def test_build_web_agent_notification_events_extracts_image(): ] +def test_extract_web_agent_notification_supports_wrapped_message_event(): + """NoticeMessage 包装 Notification 时应仍能解析为 WebAgent 通知。""" + notification = schemas.Notification( + channel=MessageChannel.WebAgent, + source="web-agent", + title="会话状态", + userid="1", + ) + + extracted = _extract_web_agent_notification_from_event_data( + {"message": notification, "current_time": "2026-06-26 09:18:38"} + ) + + assert extracted == notification + + +def test_dispatch_web_agent_notice_event_accepts_wrapped_message_event(): + """WebAgent 等待队列应接收 message 包装格式的 NoticeMessage 事件。""" + notice_queue = Queue() + _WEB_AGENT_NOTICE_QUEUES["1"] = [notice_queue] + notification = schemas.Notification( + channel=MessageChannel.WebAgent, + source="web-agent", + title="会话状态", + userid="1", + ) + + try: + _dispatch_web_agent_notice_event( + Event( + EventType.NoticeMessage, + {"message": notification, "current_time": "2026-06-26 09:18:38"}, + ) + ) + finally: + _WEB_AGENT_NOTICE_QUEUES.pop("1", None) + + assert notice_queue.get_nowait() == notification + + +def test_collect_web_agent_traditional_events_does_not_emit_submit_hint(): + """传统命令未产生通知时不应返回“命令已提交”的兜底提示。""" + user = SimpleNamespace(id=1, name="admin") + + with patch( + "app.api.endpoints.agent.MessageChain.handle_message", + ), patch( + "app.api.endpoints.agent.WEB_AGENT_TRADITIONAL_IDLE_TIMEOUT_SECONDS", + 0.01, + ), patch( + "app.api.endpoints.agent.WEB_AGENT_TRADITIONAL_MAX_WAIT_SECONDS", + 0.05, + ): + events = asyncio.run( + _collect_web_agent_traditional_events( + text="/session_status", + current_user=user, + ) + ) + + assert events == [] + + def test_build_web_agent_notification_events_registers_local_file(tmp_path): """Agent 工具发送本地文件时应生成可下载附件事件。""" file_path = tmp_path / "report.txt" @@ -335,6 +529,70 @@ def test_web_agent_stream_returns_error_when_voice_transcription_fails(): assert "语音识别失败" in body +def test_web_agent_stream_binds_session_to_agent_manager(): + """WebAgent 普通对话应统一进入 AgentManager 并绑定远程命令会话。""" + payload = schemas.AgentWebChatRequest( + text="查看会话", + session_id="browser-session", + ) + request = SimpleNamespace(is_disconnected=AsyncMock(return_value=False)) + user = SimpleNamespace(id=1, name="admin", is_superuser=True) + + class FakeWebAgent: + """测试用 WebAgent,模拟 AgentManager 内部的持久实例。""" + + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + self.processed = [] + + def set_output_callback(self, output_callback): + """更新当前 SSE 输出回调。""" + self.output_callback = output_callback + + def set_notification_callback(self, notification_callback): + """更新当前 SSE 通知回调。""" + self.notification_callback = notification_callback + + async def process(self, message, **kwargs): + """模拟一次 WebAgent 推理输出。""" + self.processed.append((message, kwargs)) + self.output_callback("状态正常") + return "状态正常" + + async def cleanup(self): + """模拟 Agent 资源清理。""" + return None + + session_id = _build_web_agent_session_id(user, payload.session_id) + MessageChain._user_sessions.clear() + agent_manager.active_agents.pop(session_id, None) + agent_manager._session_queues.pop(session_id, None) + worker = agent_manager._session_workers.pop(session_id, None) + if worker: + worker.cancel() + + try: + with patch("app.api.endpoints.agent.settings.AI_AGENT_ENABLE", True), patch( + "app.api.endpoints.agent._WebAgentMoviePilotAgent", + FakeWebAgent, + ): + response = asyncio.run(web_agent_stream(payload, request, user)) + body = "".join(asyncio.run(_collect_streaming_response(response))) + + assert "状态正常" in body + assert MessageChain._user_sessions["1"][0] == session_id + assert isinstance(agent_manager.active_agents[session_id], FakeWebAgent) + finally: + MessageChain._user_sessions.clear() + agent = agent_manager.active_agents.pop(session_id, None) + if agent: + asyncio.run(agent.cleanup()) + agent_manager._session_queues.pop(session_id, None) + worker = agent_manager._session_workers.pop(session_id, None) + if worker: + worker.cancel() + + async def _collect_streaming_response(response): """读取 StreamingResponse,便于断言 SSE 内容。""" chunks = [] @@ -356,6 +614,7 @@ def test_build_web_agent_notification_events_extracts_choice_card(): { "text": "继续下载", "callback_data": "agent_interaction:choice:req-1:1", + "description": "继续当前下载任务", } ], [ @@ -379,12 +638,28 @@ def test_build_web_agent_notification_events_extracts_choice_card(): { "label": "继续下载", "callback_data": "agent_interaction:choice:req-1:1", + "description": "继续当前下载任务", }, { "label": "查看详情", "callback_data": "agent_interaction:choice:req-1:2", }, ], + "button_rows": [ + [ + { + "label": "继续下载", + "callback_data": "agent_interaction:choice:req-1:1", + "description": "继续当前下载任务", + } + ], + [ + { + "label": "查看详情", + "callback_data": "agent_interaction:choice:req-1:2", + } + ], + ], }, } ] @@ -403,7 +678,7 @@ def test_resolve_web_agent_choice_payload_returns_next_message(): prompt="请选择", options=[ AgentInteractionOption(label="电影", value="我选择电影"), - AgentInteractionOption(label="电视剧", value="我选择电视剧"), + AgentInteractionOption(label="电视剧", value="我选择电视剧", description="选择电视剧并继续清理日志"), ], ) @@ -416,6 +691,12 @@ def test_resolve_web_agent_choice_payload_returns_next_message(): agent_interaction_manager.clear() assert result["message"] == "我选择电视剧" + assert result["display_message"] == "选择电视剧并继续清理日志" assert result["session_id"] == "web-agent:session" assert result["feedback"]["prompt"] == "请选择" assert result["feedback"]["selected_label"] == "电视剧" + assert result["feedback"]["selected_value"] == "我选择电视剧" + assert result["feedback"]["selected_description"] == "选择电视剧并继续清理日志" + assert result["choice_selection"]["prompt"] == "请选择" + assert result["choice_selection"]["selected_description"] == "选择电视剧并继续清理日志" + assert result["choice_selection"]["button_rows"][1][0]["description"] == "选择电视剧并继续清理日志"