mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-07-03 02:17:19 +08:00
feat(web-agent): enhance message handling with edit capabilities and button descriptions
This commit is contained in:
@@ -1635,6 +1635,10 @@ class _MessageTask:
|
||||
processing_status: Optional[dict] = None
|
||||
reply_mode: ReplyMode = ReplyMode.DISPATCH
|
||||
allow_message_tools: bool = True
|
||||
output_callback: Optional[Callable[[str], None]] = None
|
||||
notification_callback: Optional[Callable[[Any], None]] = None
|
||||
agent_factory: Optional[Callable[..., MoviePilotAgent]] = None
|
||||
completion_future: Optional[asyncio.Future] = None
|
||||
|
||||
|
||||
class AgentManager:
|
||||
@@ -1780,11 +1784,18 @@ class AgentManager:
|
||||
original_chat_id: Optional[str] = None,
|
||||
reply_mode: ReplyMode = ReplyMode.DISPATCH,
|
||||
allow_message_tools: bool = True,
|
||||
output_callback: Optional[Callable[[str], None]] = None,
|
||||
notification_callback: Optional[Callable[[Any], None]] = None,
|
||||
agent_factory: Optional[Callable[..., MoviePilotAgent]] = None,
|
||||
wait_for_completion: bool = False,
|
||||
) -> str:
|
||||
"""
|
||||
处理用户消息:将消息放入会话队列,按顺序依次处理。
|
||||
同一会话的消息排队等待,不同会话之间互不影响。
|
||||
"""
|
||||
completion_future = (
|
||||
asyncio.get_running_loop().create_future() if wait_for_completion else None
|
||||
)
|
||||
task = _MessageTask(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
@@ -1799,6 +1810,10 @@ class AgentManager:
|
||||
original_chat_id=original_chat_id,
|
||||
reply_mode=reply_mode,
|
||||
allow_message_tools=allow_message_tools,
|
||||
output_callback=output_callback,
|
||||
notification_callback=notification_callback,
|
||||
agent_factory=agent_factory,
|
||||
completion_future=completion_future,
|
||||
)
|
||||
self._record_session_activity(session_id, user_id)
|
||||
|
||||
@@ -1831,6 +1846,8 @@ class AgentManager:
|
||||
self._session_worker(session_id)
|
||||
)
|
||||
|
||||
if completion_future:
|
||||
return await completion_future
|
||||
return ""
|
||||
|
||||
async def _session_worker(self, session_id: str):
|
||||
@@ -1854,9 +1871,17 @@ class AgentManager:
|
||||
|
||||
try:
|
||||
await self._start_task_processing_status(task)
|
||||
await self._process_message_internal(task)
|
||||
result = await self._process_message_internal(task)
|
||||
if task.completion_future and not task.completion_future.done():
|
||||
task.completion_future.set_result(result)
|
||||
except asyncio.CancelledError as err:
|
||||
if task.completion_future and not task.completion_future.done():
|
||||
task.completion_future.set_exception(err)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"处理会话 {session_id} 的消息失败: {e}")
|
||||
if task.completion_future and not task.completion_future.done():
|
||||
task.completion_future.set_exception(e)
|
||||
finally:
|
||||
await self._finish_task_processing_status(task)
|
||||
queue.task_done()
|
||||
@@ -1895,21 +1920,36 @@ class AgentManager:
|
||||
实际处理单条消息
|
||||
"""
|
||||
session_id = task.session_id
|
||||
existing_agent = self.active_agents.get(session_id)
|
||||
if (
|
||||
existing_agent
|
||||
and task.agent_factory
|
||||
and isinstance(task.agent_factory, type)
|
||||
and not isinstance(existing_agent, task.agent_factory)
|
||||
):
|
||||
await existing_agent.cleanup()
|
||||
self.active_agents.pop(session_id, None)
|
||||
|
||||
if session_id not in self.active_agents:
|
||||
logger.info(
|
||||
f"创建新的AI智能体实例,session_id: {session_id}, user_id: {task.user_id}"
|
||||
)
|
||||
agent = MoviePilotAgent(
|
||||
session_id=session_id,
|
||||
user_id=task.user_id,
|
||||
channel=task.channel,
|
||||
source=task.source,
|
||||
username=task.username,
|
||||
original_message_id=task.original_message_id,
|
||||
original_chat_id=task.original_chat_id,
|
||||
replay_mode=task.reply_mode,
|
||||
allow_message_tools=task.allow_message_tools,
|
||||
)
|
||||
agent_factory = task.agent_factory or MoviePilotAgent
|
||||
agent_kwargs = {
|
||||
"session_id": session_id,
|
||||
"user_id": task.user_id,
|
||||
"channel": task.channel,
|
||||
"source": task.source,
|
||||
"username": task.username,
|
||||
"original_message_id": task.original_message_id,
|
||||
"original_chat_id": task.original_chat_id,
|
||||
"replay_mode": task.reply_mode,
|
||||
"allow_message_tools": task.allow_message_tools,
|
||||
"output_callback": task.output_callback,
|
||||
}
|
||||
if task.notification_callback is not None and task.agent_factory:
|
||||
agent_kwargs["notification_callback"] = task.notification_callback
|
||||
agent = agent_factory(**agent_kwargs)
|
||||
self.active_agents[session_id] = agent
|
||||
else:
|
||||
agent = self.active_agents[session_id]
|
||||
@@ -1924,6 +1964,12 @@ class AgentManager:
|
||||
agent.original_chat_id = task.original_chat_id
|
||||
agent.reply_mode = task.reply_mode
|
||||
agent.allow_message_tools = task.allow_message_tools
|
||||
if hasattr(agent, "set_output_callback"):
|
||||
agent.set_output_callback(task.output_callback)
|
||||
else:
|
||||
agent.output_callback = task.output_callback
|
||||
if task.notification_callback is not None and hasattr(agent, "set_notification_callback"):
|
||||
agent.set_notification_callback(task.notification_callback)
|
||||
|
||||
process_kwargs = {
|
||||
"images": task.images,
|
||||
|
||||
@@ -24,9 +24,14 @@ class UserChoiceOptionInput(BaseModel):
|
||||
...,
|
||||
description="The exact content that will be sent back to the agent after the user clicks this button",
|
||||
)
|
||||
description: Optional[str] = Field(
|
||||
None,
|
||||
description="Optional user-facing description stored in chat history after this option is selected",
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_option(self):
|
||||
"""校验按钮选项的文案和值不能为空。"""
|
||||
label = str(self.label)
|
||||
value = str(self.value)
|
||||
if not label.strip():
|
||||
@@ -56,6 +61,7 @@ class AskUserChoiceInput(BaseModel):
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_payload(self):
|
||||
"""校验按钮选择工具必须提供问题和选项。"""
|
||||
message = str(self.message)
|
||||
if not message.strip():
|
||||
raise ValueError("message 不能为空")
|
||||
@@ -85,6 +91,7 @@ class AskUserChoiceTool(MoviePilotTool):
|
||||
args_schema: Type[BaseModel] = AskUserChoiceInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""生成工具执行提示文案。"""
|
||||
message = kwargs.get("message", "") or ""
|
||||
if len(message) > 40:
|
||||
message = message[:40] + "..."
|
||||
@@ -92,6 +99,7 @@ class AskUserChoiceTool(MoviePilotTool):
|
||||
|
||||
@staticmethod
|
||||
def _truncate_button_text(text: str, max_length: int) -> str:
|
||||
"""按渠道限制截断按钮文案。"""
|
||||
if max_length <= 0 or len(text) <= max_length:
|
||||
return text
|
||||
if max_length <= 3:
|
||||
@@ -114,6 +122,14 @@ class AskUserChoiceTool(MoviePilotTool):
|
||||
title: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
"""
|
||||
发送按钮选择消息,并登记待回调的交互上下文。
|
||||
|
||||
:param message: 展示给用户的问题
|
||||
:param options: 可点击的选项列表
|
||||
:param title: 可选标题
|
||||
:return: 工具执行结果描述
|
||||
"""
|
||||
if self._blocked_by_feedback_quality_gate():
|
||||
logger.warning(
|
||||
"ask_user_choice blocked after feedback issue rejected_quality: "
|
||||
@@ -148,7 +164,9 @@ class AskUserChoiceTool(MoviePilotTool):
|
||||
|
||||
choice_options = [
|
||||
AgentInteractionOption(
|
||||
label=option.label.strip(), value=option.value.strip()
|
||||
label=option.label.strip(),
|
||||
value=option.value.strip(),
|
||||
description=(option.description.strip() if option.description else None),
|
||||
)
|
||||
for option in options
|
||||
]
|
||||
@@ -172,6 +190,7 @@ class AskUserChoiceTool(MoviePilotTool):
|
||||
"callback_data": (
|
||||
f"agent_interaction:choice:{request.request_id}:{index}"
|
||||
),
|
||||
"description": option.description or option.label,
|
||||
}
|
||||
)
|
||||
if len(current_row) >= max_per_row:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import copy
|
||||
import hashlib
|
||||
import json
|
||||
import mimetypes
|
||||
@@ -6,8 +7,10 @@ import shutil
|
||||
import subprocess
|
||||
import time
|
||||
import uuid
|
||||
from queue import Empty, Queue
|
||||
from pathlib import Path
|
||||
from typing import Any, AsyncIterator, Callable, Optional
|
||||
from threading import Lock
|
||||
from typing import Any, AsyncIterator, Callable, Optional, Union
|
||||
|
||||
from fastapi import APIRouter, Depends, File, Form, HTTPException, Request, UploadFile, status
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
@@ -15,17 +18,24 @@ from fastapi.responses import FileResponse, StreamingResponse
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app import schemas
|
||||
from app.agent import MoviePilotAgent, ReplyMode, StreamingHandler
|
||||
from app.agent import MoviePilotAgent, ReplyMode, StreamingHandler, agent_manager
|
||||
from app.agent.llm.capability import AgentCapabilityManager
|
||||
from app.chain.message import MessageChain
|
||||
from app.chain.site import site_interaction_manager
|
||||
from app.chain.skills import skills_interaction_manager
|
||||
from app.chain.subscribe import subscribe_interaction_manager
|
||||
from app.command import Command
|
||||
from app.core.config import global_vars, settings
|
||||
from app.core.event import Event, EventManager
|
||||
from app.db import get_async_db
|
||||
from app.db.agentchat_oper import AgentChatOper
|
||||
from app.db.models import User
|
||||
from app.db.models.agentchat import AgentChat
|
||||
from app.db.user_oper import UserOper, get_current_active_user
|
||||
from app.helper.interaction import agent_interaction_manager
|
||||
from app.helper.agent import attach_web_agent_edit_queue, detach_web_agent_edit_queue
|
||||
from app.helper.interaction import agent_interaction_manager, media_interaction_manager
|
||||
from app.log import logger
|
||||
from app.schemas.types import MessageChannel
|
||||
from app.schemas.types import EventType, MessageChannel
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -36,7 +46,12 @@ WEB_AGENT_FILE_MAX_ITEMS = 256
|
||||
WEB_AGENT_UPLOAD_MAX_BYTES = 32 * 1024 * 1024
|
||||
WEB_AGENT_UPLOAD_CHUNK_SIZE = 1024 * 1024
|
||||
WEB_AGENT_BROWSER_AUDIO_SUFFIXES = {".aac", ".m4a", ".mp3", ".mp4", ".wav", ".wave"}
|
||||
WEB_AGENT_TRADITIONAL_IDLE_TIMEOUT_SECONDS = 2.0
|
||||
WEB_AGENT_TRADITIONAL_MAX_WAIT_SECONDS = 60.0
|
||||
_WEB_AGENT_FILE_REGISTRY: dict[str, dict[str, Any]] = {}
|
||||
_WEB_AGENT_NOTICE_QUEUES: dict[str, list[Queue[schemas.Notification]]] = {}
|
||||
_WEB_AGENT_NOTICE_LOCK = Lock()
|
||||
_WEB_AGENT_NOTICE_LISTENER_REGISTERED = False
|
||||
|
||||
|
||||
class _WebAgentStreamingHandler(StreamingHandler):
|
||||
@@ -48,6 +63,14 @@ class _WebAgentStreamingHandler(StreamingHandler):
|
||||
super().__init__()
|
||||
self._on_emit = on_emit
|
||||
|
||||
def set_emit_callback(self, on_emit: Callable[[str], None]) -> None:
|
||||
"""
|
||||
更新流式输出回调,复用 WebAgent 实例时指向当前 SSE 请求。
|
||||
|
||||
:param on_emit: 当前请求的输出回调
|
||||
"""
|
||||
self._on_emit = on_emit
|
||||
|
||||
def emit(self, token: str) -> str:
|
||||
"""追加 token 并同步通知 SSE 生产者。"""
|
||||
emitted = super().emit(token)
|
||||
@@ -124,6 +147,27 @@ class _WebAgentMoviePilotAgent(MoviePilotAgent):
|
||||
"""Web 面板需要实时输出,即使 Web 渠道本身不支持消息编辑。"""
|
||||
return True
|
||||
|
||||
def set_notification_callback(
|
||||
self,
|
||||
notification_callback: Optional[Callable[[schemas.Notification], None]],
|
||||
) -> None:
|
||||
"""
|
||||
更新 Web SSE 通知回调,复用 Agent 实例时指向当前请求队列。
|
||||
|
||||
:param notification_callback: 当前请求的 Web 通知回调
|
||||
"""
|
||||
self._notification_callback = notification_callback
|
||||
|
||||
def set_output_callback(self, output_callback: Optional[Callable[[str], None]]) -> None:
|
||||
"""
|
||||
更新 Web SSE 输出回调,复用 Agent 实例时指向当前请求队列。
|
||||
|
||||
:param output_callback: 当前请求的输出回调
|
||||
"""
|
||||
self.output_callback = output_callback
|
||||
if output_callback and isinstance(self.stream_handler, _WebAgentStreamingHandler):
|
||||
self.stream_handler.set_emit_callback(self._emit_output)
|
||||
|
||||
async def _is_system_admin_context(self) -> bool:
|
||||
"""Web Agent 根据当前登录用户 ID 判断工具管理员上下文。"""
|
||||
if not self.user_id:
|
||||
@@ -218,6 +262,14 @@ def _apply_web_agent_display_event(event: dict, assistant_message: dict) -> None
|
||||
assistant_message["attachments"].append(event["attachment"])
|
||||
elif event_type == "choice" and event.get("choice"):
|
||||
assistant_message["choices"].append({**event["choice"], "status": "pending"})
|
||||
elif event_type == "message_update":
|
||||
target_message = event.get("target_message") or {}
|
||||
assistant_message["id"] = target_message.get("id") or assistant_message.get("id")
|
||||
assistant_message["content"] = target_message.get("content") or ""
|
||||
assistant_message["attachments"] = target_message.get("attachments") or []
|
||||
assistant_message["choices"] = target_message.get("choices") or []
|
||||
assistant_message["tools"] = target_message.get("tools") or []
|
||||
assistant_message["status"] = target_message.get("status") or "done"
|
||||
elif event_type == "error":
|
||||
assistant_message["status"] = "error"
|
||||
assistant_message["content"] = (
|
||||
@@ -657,27 +709,32 @@ def _parse_web_agent_choice_callback(callback_data: str) -> Optional[tuple[str,
|
||||
return request_id, int(option_index)
|
||||
|
||||
|
||||
def _flatten_web_agent_choice_buttons(buttons: Optional[list[list[dict]]]) -> list[dict]:
|
||||
def _normalize_web_agent_choice_button_rows(buttons: Optional[list[list[dict]]]) -> list[list[dict]]:
|
||||
"""
|
||||
将消息渠道按钮二维结构转换为 Web 前端可渲染的一维选项列表。
|
||||
将消息渠道按钮二维结构转换为 Web 前端可渲染的按钮行。
|
||||
|
||||
:param buttons: Notification 中的按钮行
|
||||
:return: Web 选择卡片按钮列表
|
||||
:return: Web 选择卡片按钮行
|
||||
"""
|
||||
flattened = []
|
||||
normalized_rows = []
|
||||
for row in buttons or []:
|
||||
normalized_row = []
|
||||
for button in row or []:
|
||||
text = str(button.get("text") or "").strip()
|
||||
callback_data = str(button.get("callback_data") or "").strip()
|
||||
if not text or not callback_data:
|
||||
continue
|
||||
flattened.append(
|
||||
description = str(button.get("description") or "").strip()
|
||||
normalized_row.append(
|
||||
{
|
||||
"label": text,
|
||||
"callback_data": callback_data,
|
||||
**({"description": description} if description else {}),
|
||||
}
|
||||
)
|
||||
return flattened
|
||||
if normalized_row:
|
||||
normalized_rows.append(normalized_row)
|
||||
return normalized_rows
|
||||
|
||||
|
||||
def _build_web_agent_choice_event(notification: schemas.Notification) -> Optional[dict]:
|
||||
@@ -687,7 +744,8 @@ def _build_web_agent_choice_event(notification: schemas.Notification) -> Optiona
|
||||
:param notification: Agent 工具发出的按钮通知
|
||||
:return: 选择卡片事件,按钮为空时返回 None
|
||||
"""
|
||||
buttons = _flatten_web_agent_choice_buttons(notification.buttons)
|
||||
button_rows = _normalize_web_agent_choice_button_rows(notification.buttons)
|
||||
buttons = [button for row in button_rows for button in row]
|
||||
if not buttons:
|
||||
return None
|
||||
|
||||
@@ -703,10 +761,32 @@ def _build_web_agent_choice_event(notification: schemas.Notification) -> Optiona
|
||||
"title": notification.title,
|
||||
"prompt": notification.text or "",
|
||||
"buttons": buttons,
|
||||
"button_rows": button_rows,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _build_web_agent_choice_buttons_from_request(
|
||||
request,
|
||||
) -> tuple[list[dict], list[list[dict]]]:
|
||||
"""
|
||||
根据待处理交互请求重建可持久化的按钮列表与按钮行。
|
||||
|
||||
:param request: 等待用户选择的交互请求
|
||||
:return: 平铺按钮列表与按行分组的按钮结构
|
||||
"""
|
||||
buttons = [
|
||||
{
|
||||
"label": option.label,
|
||||
"callback_data": f"agent_interaction:choice:{request.request_id}:{index}",
|
||||
"description": option.description or option.label,
|
||||
}
|
||||
for index, option in enumerate(request.options, start=1)
|
||||
]
|
||||
button_rows = [[button] for button in buttons]
|
||||
return buttons, button_rows
|
||||
|
||||
|
||||
def _resolve_web_agent_choice_payload(callback_data: str, user_id: str) -> Optional[dict]:
|
||||
"""
|
||||
解析并消费 Web Agent 按钮选择,生成前端反馈与下一条用户消息。
|
||||
@@ -729,15 +809,31 @@ def _resolve_web_agent_choice_payload(callback_data: str, user_id: str) -> Optio
|
||||
return None
|
||||
|
||||
request, option = resolved
|
||||
buttons, button_rows = _build_web_agent_choice_buttons_from_request(request)
|
||||
selected_description = option.description or option.label
|
||||
return {
|
||||
"message": option.value,
|
||||
"session_id": request.session_id,
|
||||
"display_message": selected_description,
|
||||
"choice_selection": {
|
||||
"choice_id": request.request_id,
|
||||
"title": request.title,
|
||||
"prompt": request.prompt,
|
||||
"buttons": buttons,
|
||||
"button_rows": button_rows,
|
||||
"selected_label": option.label,
|
||||
"selected_value": option.value,
|
||||
"selected_description": selected_description,
|
||||
},
|
||||
"feedback": {
|
||||
"request_id": request.request_id,
|
||||
"title": request.title,
|
||||
"prompt": request.prompt,
|
||||
"selected_label": option.label,
|
||||
"selected_value": option.value,
|
||||
"selected_description": selected_description,
|
||||
"buttons": buttons,
|
||||
"button_rows": button_rows,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -802,6 +898,354 @@ def _build_web_agent_notification_events(
|
||||
return events
|
||||
|
||||
|
||||
def _build_web_agent_display_message_from_events(
|
||||
events: list[dict],
|
||||
) -> dict:
|
||||
"""
|
||||
将传统消息事件聚合为前端展示消息快照。
|
||||
|
||||
:param events: 已转换的 WebAgent SSE 事件列表
|
||||
:return: 可持久化的助手展示消息
|
||||
"""
|
||||
message = MoviePilotAgent.build_display_message(
|
||||
role="assistant",
|
||||
status="streaming",
|
||||
)
|
||||
for event in events:
|
||||
_apply_web_agent_display_event(copy.deepcopy(event), message)
|
||||
_apply_web_agent_display_event({"type": "done"}, message)
|
||||
return message
|
||||
|
||||
|
||||
def _is_web_agent_traditional_message(text: str) -> bool:
|
||||
"""
|
||||
判断用户输入是否应走传统消息命令/交互链路。
|
||||
|
||||
:param text: 前端输入文本
|
||||
:return: 需要交给 MessageChain 时返回 True
|
||||
"""
|
||||
normalized = str(text or "").strip()
|
||||
return normalized.startswith("/") or normalized.startswith("CALLBACK:")
|
||||
|
||||
|
||||
def _has_web_agent_traditional_interaction(user_id: str) -> bool:
|
||||
"""
|
||||
判断当前用户是否存在待继续的传统交互会话。
|
||||
|
||||
:param user_id: 当前登录用户 ID
|
||||
:return: 存在传统交互上下文时返回 True
|
||||
"""
|
||||
return any(
|
||||
manager.get_by_user(user_id)
|
||||
for manager in (
|
||||
site_interaction_manager,
|
||||
subscribe_interaction_manager,
|
||||
skills_interaction_manager,
|
||||
media_interaction_manager,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _extract_web_agent_notification_from_event_data(
|
||||
data: dict,
|
||||
) -> Optional[schemas.Notification]:
|
||||
"""
|
||||
从 NoticeMessage 事件数据中提取 WebAgent 通知。
|
||||
|
||||
:param data: NoticeMessage 事件数据,兼容扁平字段和 message 包装格式
|
||||
:return: WebAgent 通知,不属于 WebAgent 或数据无效时返回 None
|
||||
"""
|
||||
if not isinstance(data, dict):
|
||||
return None
|
||||
|
||||
try:
|
||||
message = data.get("message")
|
||||
if isinstance(message, schemas.Notification):
|
||||
notification = message
|
||||
elif isinstance(message, dict):
|
||||
notification_data = copy.deepcopy(message)
|
||||
notification_data.pop("type", None)
|
||||
notification = schemas.Notification(**notification_data)
|
||||
else:
|
||||
notification_data = copy.deepcopy(data)
|
||||
notification_data.pop("type", None)
|
||||
notification_data.pop("current_time", None)
|
||||
notification = schemas.Notification(**notification_data)
|
||||
except Exception as err:
|
||||
logger.debug(f"解析WebAgent通知事件失败: {err}")
|
||||
return None
|
||||
|
||||
channel = notification.channel
|
||||
channel_value = channel.value if isinstance(channel, MessageChannel) else channel
|
||||
if channel_value != MessageChannel.WebAgent.value:
|
||||
return None
|
||||
return notification
|
||||
|
||||
|
||||
def _is_web_agent_notice_for_user(
|
||||
notification: schemas.Notification,
|
||||
user_id: str,
|
||||
) -> bool:
|
||||
"""
|
||||
判断 NoticeMessage 事件是否属于当前 WebAgent 用户。
|
||||
|
||||
:param notification: NoticeMessage 中的通知消息
|
||||
:param user_id: 当前登录用户 ID
|
||||
:return: 可被本次 WebAgent 请求消费时返回 True
|
||||
"""
|
||||
try:
|
||||
target_user = notification.userid
|
||||
return target_user is None or str(target_user) == str(user_id)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _get_web_agent_notice_user_id(notification: schemas.Notification) -> Optional[str]:
|
||||
"""
|
||||
从 NoticeMessage 事件中解析 WebAgent 目标用户。
|
||||
|
||||
:param notification: NoticeMessage 中的通知消息
|
||||
:return: 用户 ID 字符串,事件不属于 WebAgent 时返回 None
|
||||
"""
|
||||
try:
|
||||
channel = notification.channel
|
||||
channel_value = channel.value if isinstance(channel, MessageChannel) else channel
|
||||
if channel_value != MessageChannel.WebAgent.value:
|
||||
return None
|
||||
user_id = notification.userid
|
||||
return str(user_id) if user_id is not None else None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _dispatch_web_agent_notice_event(event: Event) -> None:
|
||||
"""
|
||||
将 WebAgent NoticeMessage 分发给正在等待的请求队列。
|
||||
|
||||
:param event: NoticeMessage 广播事件
|
||||
"""
|
||||
data = event.event_data if isinstance(event.event_data, dict) else {}
|
||||
notification = _extract_web_agent_notification_from_event_data(data)
|
||||
if not notification:
|
||||
return
|
||||
with _WEB_AGENT_NOTICE_LOCK:
|
||||
user_id = _get_web_agent_notice_user_id(notification)
|
||||
if user_id is None:
|
||||
queues = [
|
||||
notice_queue
|
||||
for user_queues in _WEB_AGENT_NOTICE_QUEUES.values()
|
||||
for notice_queue in user_queues
|
||||
]
|
||||
else:
|
||||
queues = list(_WEB_AGENT_NOTICE_QUEUES.get(user_id) or [])
|
||||
for notice_queue in queues:
|
||||
notice_queue.put(notification)
|
||||
|
||||
|
||||
def _ensure_web_agent_notice_listener() -> None:
|
||||
"""
|
||||
确保 WebAgent NoticeMessage 全局监听器已注册。
|
||||
"""
|
||||
global _WEB_AGENT_NOTICE_LISTENER_REGISTERED
|
||||
if _WEB_AGENT_NOTICE_LISTENER_REGISTERED:
|
||||
return
|
||||
with _WEB_AGENT_NOTICE_LOCK:
|
||||
if _WEB_AGENT_NOTICE_LISTENER_REGISTERED:
|
||||
return
|
||||
EventManager().add_event_listener(
|
||||
EventType.NoticeMessage,
|
||||
_dispatch_web_agent_notice_event,
|
||||
)
|
||||
_WEB_AGENT_NOTICE_LISTENER_REGISTERED = True
|
||||
|
||||
|
||||
def _attach_web_agent_notice_queue(user_id: str, notice_queue: Queue[schemas.Notification]) -> None:
|
||||
"""
|
||||
为当前 WebAgent 请求挂载通知收集队列。
|
||||
|
||||
:param user_id: 当前用户 ID
|
||||
:param notice_queue: 用于接收通知事件的队列
|
||||
"""
|
||||
_ensure_web_agent_notice_listener()
|
||||
with _WEB_AGENT_NOTICE_LOCK:
|
||||
_WEB_AGENT_NOTICE_QUEUES.setdefault(str(user_id), []).append(notice_queue)
|
||||
|
||||
|
||||
def _detach_web_agent_notice_queue(user_id: str, notice_queue: Queue[schemas.Notification]) -> None:
|
||||
"""
|
||||
移除当前 WebAgent 请求的通知收集队列。
|
||||
|
||||
:param user_id: 当前用户 ID
|
||||
:param notice_queue: 需要移除的队列
|
||||
"""
|
||||
with _WEB_AGENT_NOTICE_LOCK:
|
||||
queues = _WEB_AGENT_NOTICE_QUEUES.get(str(user_id))
|
||||
if not queues:
|
||||
return
|
||||
_WEB_AGENT_NOTICE_QUEUES[str(user_id)] = [
|
||||
item for item in queues if item is not notice_queue
|
||||
]
|
||||
if not _WEB_AGENT_NOTICE_QUEUES[str(user_id)]:
|
||||
_WEB_AGENT_NOTICE_QUEUES.pop(str(user_id), None)
|
||||
|
||||
|
||||
def _build_web_agent_command_items() -> list[dict]:
|
||||
"""
|
||||
读取当前可用斜杠命令并转换为前端建议列表。
|
||||
|
||||
:return: 按分类和命令名排序的命令列表
|
||||
"""
|
||||
commands = Command().get_commands() or {}
|
||||
items = []
|
||||
for command, data in commands.items():
|
||||
if not command.startswith("/"):
|
||||
continue
|
||||
if data.get("show") is False:
|
||||
continue
|
||||
items.append(
|
||||
{
|
||||
"command": command,
|
||||
"description": data.get("description") or "",
|
||||
"category": data.get("category") or "其他",
|
||||
"type": data.get("type") or "",
|
||||
"pid": data.get("pid"),
|
||||
}
|
||||
)
|
||||
return sorted(items, key=lambda item: (item["category"], item["command"]))
|
||||
|
||||
|
||||
def _extract_web_agent_slash_command(text: str) -> Optional[str]:
|
||||
"""
|
||||
从 WebAgent 输入中提取斜杠命令名。
|
||||
|
||||
:param text: 前端输入文本
|
||||
:return: 斜杠命令名,非命令输入返回 None
|
||||
"""
|
||||
normalized = str(text or "").strip()
|
||||
if not normalized.startswith("/") or normalized.startswith("//"):
|
||||
return None
|
||||
command = normalized.split(maxsplit=1)[0].strip()
|
||||
return command or None
|
||||
|
||||
|
||||
def _get_web_agent_unknown_command_message(text: str) -> Optional[str]:
|
||||
"""
|
||||
判断 WebAgent 斜杠命令是否不存在。
|
||||
|
||||
:param text: 前端输入文本
|
||||
:return: 命令不存在时返回错误提示,命令存在或非命令时返回 None
|
||||
"""
|
||||
command = _extract_web_agent_slash_command(text)
|
||||
if not command:
|
||||
return None
|
||||
if Command().get(command):
|
||||
return None
|
||||
return f"命令不存在:{command}"
|
||||
|
||||
|
||||
def _ensure_web_agent_command_allowed(current_user: User) -> Optional[str]:
|
||||
"""
|
||||
校验当前 Web 用户是否可以执行传统斜杠命令。
|
||||
|
||||
:param current_user: 当前登录用户
|
||||
:return: 无权限时返回错误提示,允许执行时返回 None
|
||||
"""
|
||||
if getattr(current_user, "is_superuser", False):
|
||||
return None
|
||||
return "只有管理员才有权限执行此命令"
|
||||
|
||||
|
||||
async def _collect_web_agent_traditional_events(
|
||||
*,
|
||||
text: str,
|
||||
current_user: User,
|
||||
original_message_id: Optional[Union[str, int]] = None,
|
||||
original_chat_id: Optional[Union[str, int]] = None,
|
||||
) -> list[dict]:
|
||||
"""
|
||||
执行传统消息链路并收集本次 WebAgent 用户产生的通知事件。
|
||||
|
||||
:param text: 需要交给传统消息链路处理的文本
|
||||
:param current_user: 当前登录用户
|
||||
:param original_message_id: WebAgent 原助手消息 ID
|
||||
:param original_chat_id: WebAgent 原聊天 ID
|
||||
:return: 可直接发送给前端的 SSE 事件列表
|
||||
"""
|
||||
notice_queue: Queue[schemas.Notification] = Queue()
|
||||
edit_queue: Queue[dict] = Queue()
|
||||
user_id = str(current_user.id)
|
||||
|
||||
_attach_web_agent_notice_queue(user_id, notice_queue)
|
||||
attach_web_agent_edit_queue(user_id, edit_queue)
|
||||
try:
|
||||
await run_in_threadpool(
|
||||
MessageChain().handle_message,
|
||||
channel=MessageChannel.WebAgent,
|
||||
source=WEB_AGENT_SOURCE,
|
||||
userid=user_id,
|
||||
username=current_user.name or user_id,
|
||||
text=text,
|
||||
original_message_id=original_message_id,
|
||||
original_chat_id=original_chat_id,
|
||||
)
|
||||
|
||||
events = []
|
||||
deadline = time.monotonic() + WEB_AGENT_TRADITIONAL_MAX_WAIT_SECONDS
|
||||
idle_deadline: Optional[float] = None
|
||||
while time.monotonic() < deadline:
|
||||
now = time.monotonic()
|
||||
drained_edit_event = False
|
||||
while True:
|
||||
try:
|
||||
events.append(edit_queue.get_nowait())
|
||||
drained_edit_event = True
|
||||
except Empty:
|
||||
break
|
||||
if drained_edit_event:
|
||||
idle_deadline = time.monotonic() + WEB_AGENT_TRADITIONAL_IDLE_TIMEOUT_SECONDS
|
||||
continue
|
||||
|
||||
wait_until = idle_deadline or deadline
|
||||
timeout = max(0.05, min(0.25, wait_until - now, deadline - now))
|
||||
try:
|
||||
notification = await asyncio.to_thread(notice_queue.get, True, timeout)
|
||||
except Empty:
|
||||
if idle_deadline and time.monotonic() >= idle_deadline:
|
||||
break
|
||||
continue
|
||||
|
||||
if not _is_web_agent_notice_for_user(notification, user_id):
|
||||
continue
|
||||
events.extend(_build_web_agent_notification_events(notification))
|
||||
idle_deadline = time.monotonic() + WEB_AGENT_TRADITIONAL_IDLE_TIMEOUT_SECONDS
|
||||
return events
|
||||
finally:
|
||||
_detach_web_agent_notice_queue(user_id, notice_queue)
|
||||
detach_web_agent_edit_queue(user_id, edit_queue)
|
||||
|
||||
|
||||
def _build_web_agent_traditional_callback_payload(
|
||||
callback_data: str,
|
||||
original_message_id: Optional[Union[str, int]] = None,
|
||||
original_chat_id: Optional[Union[str, int]] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
构造传统消息链按钮回调的前端执行载荷。
|
||||
|
||||
:param callback_data: 前端点击的传统按钮回调数据
|
||||
:param original_message_id: WebAgent 原助手消息 ID
|
||||
:param original_chat_id: WebAgent 原聊天 ID
|
||||
:return: 前端可继续发送到 /stream 的消息载荷
|
||||
"""
|
||||
return {
|
||||
"message": f"CALLBACK:{callback_data}",
|
||||
"display_message": callback_data,
|
||||
"traditional": True,
|
||||
"original_message_id": original_message_id,
|
||||
"original_chat_id": original_chat_id,
|
||||
}
|
||||
|
||||
|
||||
def _split_web_agent_output(text: str) -> list[dict]:
|
||||
"""
|
||||
将 Agent 输出拆成普通文本与工具提示事件。
|
||||
@@ -943,6 +1387,19 @@ async def web_agent_callback(
|
||||
:param current_user: 当前登录用户
|
||||
:return: 下一条需要发送给 Agent 的用户消息与卡片反馈
|
||||
"""
|
||||
if not _parse_web_agent_choice_callback(payload.callback_data):
|
||||
denied_message = _ensure_web_agent_command_allowed(current_user)
|
||||
if denied_message:
|
||||
return schemas.Response(success=False, message=denied_message)
|
||||
return schemas.Response(
|
||||
success=True,
|
||||
data=_build_web_agent_traditional_callback_payload(
|
||||
payload.callback_data,
|
||||
original_message_id=payload.original_message_id,
|
||||
original_chat_id=payload.original_chat_id,
|
||||
),
|
||||
)
|
||||
|
||||
result = _resolve_web_agent_choice_payload(
|
||||
callback_data=payload.callback_data,
|
||||
user_id=str(current_user.id),
|
||||
@@ -952,6 +1409,22 @@ async def web_agent_callback(
|
||||
return schemas.Response(success=True, data=result)
|
||||
|
||||
|
||||
@router.get("/commands", summary="获取 Web 智能助手可用命令", response_model=schemas.Response)
|
||||
async def list_web_agent_commands(
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
) -> schemas.Response:
|
||||
"""
|
||||
获取当前 Web 智能助手可补全的斜杠命令。
|
||||
|
||||
:param current_user: 当前登录用户
|
||||
:return: 可用命令列表
|
||||
"""
|
||||
denied_message = _ensure_web_agent_command_allowed(current_user)
|
||||
if denied_message:
|
||||
return schemas.Response(success=False, message=denied_message)
|
||||
return schemas.Response(success=True, data=_build_web_agent_command_items())
|
||||
|
||||
|
||||
@router.get("/sessions", summary="获取 Agent 历史会话", response_model=schemas.Response)
|
||||
async def list_agent_chat_sessions(
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
@@ -1076,6 +1549,92 @@ async def web_agent_stream(
|
||||
:param current_user: 当前登录用户
|
||||
:return: SSE 流式响应
|
||||
"""
|
||||
prompt = payload.text.strip()
|
||||
display_prompt = (payload.display_text or payload.text).strip()
|
||||
is_traditional_message = (
|
||||
_is_web_agent_traditional_message(prompt)
|
||||
or _has_web_agent_traditional_interaction(str(current_user.id))
|
||||
)
|
||||
if is_traditional_message:
|
||||
denied_message = _ensure_web_agent_command_allowed(current_user)
|
||||
if denied_message:
|
||||
return StreamingResponse(
|
||||
iter([
|
||||
_build_web_agent_sse(
|
||||
"error",
|
||||
{"message": denied_message},
|
||||
)
|
||||
]),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
unknown_command_message = _get_web_agent_unknown_command_message(prompt)
|
||||
if unknown_command_message:
|
||||
return StreamingResponse(
|
||||
iter([
|
||||
_build_web_agent_sse(
|
||||
"error",
|
||||
{"message": unknown_command_message},
|
||||
)
|
||||
]),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
||||
session_id = _build_web_agent_session_id(current_user, payload.session_id)
|
||||
user_attachments = _build_web_agent_input_attachments(
|
||||
images=payload.images or [],
|
||||
files=[
|
||||
file.model_dump(exclude_none=True)
|
||||
for file in (payload.files or [])
|
||||
],
|
||||
audio_refs=payload.audio_refs or [],
|
||||
)
|
||||
display_messages = []
|
||||
if payload.echo_user:
|
||||
display_messages.append(
|
||||
MoviePilotAgent.build_display_message(
|
||||
role="user",
|
||||
content=display_prompt or prompt,
|
||||
attachments=user_attachments,
|
||||
)
|
||||
)
|
||||
|
||||
async def traditional_event_generator() -> AsyncIterator[str]:
|
||||
"""
|
||||
生成传统消息链路的 WebAgent SSE 事件。
|
||||
"""
|
||||
yield _build_web_agent_sse("start", {"session_id": session_id})
|
||||
events = await _collect_web_agent_traditional_events(
|
||||
text=prompt,
|
||||
current_user=current_user,
|
||||
original_message_id=payload.original_message_id,
|
||||
original_chat_id=payload.original_chat_id,
|
||||
)
|
||||
assistant_message = _build_web_agent_display_message_from_events(events)
|
||||
display_messages.append(assistant_message)
|
||||
for event in events:
|
||||
event_payload = copy.deepcopy(event)
|
||||
yield _build_web_agent_sse(event_payload.pop("type"), event_payload)
|
||||
if await request.is_disconnected():
|
||||
break
|
||||
await run_in_threadpool(
|
||||
_save_web_agent_display_snapshot,
|
||||
session_id=session_id,
|
||||
current_user=current_user,
|
||||
messages=display_messages,
|
||||
client_session_id=payload.session_id or session_id,
|
||||
)
|
||||
yield _build_web_agent_sse("done", {})
|
||||
|
||||
return StreamingResponse(
|
||||
traditional_event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
|
||||
if not settings.AI_AGENT_ENABLE:
|
||||
return StreamingResponse(
|
||||
iter([
|
||||
@@ -1087,9 +1646,9 @@ async def web_agent_stream(
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
||||
prompt = payload.text.strip()
|
||||
transcript = _transcribe_web_agent_audio_refs(payload.audio_refs or [])
|
||||
prompt = _merge_web_agent_prompt_with_transcript(prompt, transcript)
|
||||
display_prompt = _merge_web_agent_prompt_with_transcript(display_prompt, transcript)
|
||||
has_audio_input = bool(transcript)
|
||||
if not prompt and payload.audio_refs and not payload.images and not payload.files:
|
||||
return StreamingResponse(
|
||||
@@ -1113,6 +1672,7 @@ async def web_agent_stream(
|
||||
)
|
||||
|
||||
session_id = _build_web_agent_session_id(current_user, payload.session_id)
|
||||
MessageChain().bind_user_session(str(current_user.id), session_id)
|
||||
event_queue: asyncio.Queue = asyncio.Queue()
|
||||
last_output = ""
|
||||
user_attachments = _build_web_agent_input_attachments(
|
||||
@@ -1125,13 +1685,14 @@ async def web_agent_stream(
|
||||
)
|
||||
display_messages = []
|
||||
if payload.echo_user:
|
||||
display_messages.append(
|
||||
MoviePilotAgent.build_display_message(
|
||||
role="user",
|
||||
content=prompt,
|
||||
attachments=user_attachments,
|
||||
)
|
||||
user_display_message = MoviePilotAgent.build_display_message(
|
||||
role="user",
|
||||
content=display_prompt or prompt,
|
||||
attachments=user_attachments,
|
||||
)
|
||||
if payload.choice_selection:
|
||||
user_display_message["choice_selection"] = payload.choice_selection
|
||||
display_messages.append(user_display_message)
|
||||
assistant_display_message = MoviePilotAgent.build_display_message(
|
||||
role="assistant",
|
||||
status="streaming",
|
||||
@@ -1170,26 +1731,25 @@ async def web_agent_stream(
|
||||
for audio_ref in payload.audio_refs or []:
|
||||
files.append({"ref": audio_ref, "mime_type": "audio/*"})
|
||||
|
||||
agent = _WebAgentMoviePilotAgent(
|
||||
session_id=session_id,
|
||||
user_id=str(current_user.id),
|
||||
channel=MessageChannel.WebAgent.value,
|
||||
source=WEB_AGENT_SOURCE,
|
||||
username=current_user.name,
|
||||
replay_mode=ReplyMode.CAPTURE_ONLY,
|
||||
allow_message_tools=True,
|
||||
output_callback=output_callback,
|
||||
notification_callback=notification_callback,
|
||||
)
|
||||
|
||||
async def run_agent() -> None:
|
||||
"""后台执行 Agent,并将结果写入事件队列。"""
|
||||
try:
|
||||
await agent.process(
|
||||
await agent_manager.process_message(
|
||||
session_id=session_id,
|
||||
user_id=str(current_user.id),
|
||||
message=prompt,
|
||||
images=payload.images or [],
|
||||
files=files or None,
|
||||
has_audio_input=has_audio_input,
|
||||
channel=MessageChannel.WebAgent.value,
|
||||
source=WEB_AGENT_SOURCE,
|
||||
username=current_user.name,
|
||||
reply_mode=ReplyMode.CAPTURE_ONLY,
|
||||
allow_message_tools=True,
|
||||
output_callback=output_callback,
|
||||
notification_callback=notification_callback,
|
||||
agent_factory=_WebAgentMoviePilotAgent,
|
||||
wait_for_completion=True,
|
||||
)
|
||||
except Exception as err:
|
||||
logger.error(f"Web智能助手执行失败: {str(err)}")
|
||||
@@ -1233,7 +1793,7 @@ async def web_agent_stream(
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
await agent.cleanup()
|
||||
# WebAgent 会话由 AgentManager 统一管理,空闲清理或 /clear_session 时释放。
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
|
||||
@@ -1769,6 +1769,21 @@ class ChainBase(metaclass=ABCMeta):
|
||||
:param metadata: 其他消息元数据
|
||||
:return: 编辑是否成功
|
||||
"""
|
||||
if channel == MessageChannel.WebAgent:
|
||||
try:
|
||||
from app.helper.agent import edit_web_agent_message
|
||||
|
||||
return edit_web_agent_message(
|
||||
user_id=str((metadata or {}).get("userid") or ""),
|
||||
message_id=message_id,
|
||||
title=title,
|
||||
text=text,
|
||||
buttons=buttons,
|
||||
)
|
||||
except Exception as err:
|
||||
logger.debug(f"编辑 WebAgent 消息失败: {err}")
|
||||
return False
|
||||
|
||||
return self.run_module(
|
||||
"edit_message",
|
||||
channel=channel,
|
||||
|
||||
@@ -1000,6 +1000,15 @@ class MessageChain(ChainBase):
|
||||
self._schedule_agent_session_clear(old_session[0], userid)
|
||||
self._user_sessions[userid] = (session_id, datetime.now())
|
||||
|
||||
def bind_user_session(self, userid: Union[str, int], session_id: str) -> None:
|
||||
"""
|
||||
绑定用户与指定智能体会话,供非传统入口复用远程命令状态查询。
|
||||
|
||||
:param userid: 用户 ID
|
||||
:param session_id: 智能体会话 ID
|
||||
"""
|
||||
self._bind_session_id(userid, session_id)
|
||||
|
||||
def _record_user_message(
|
||||
self,
|
||||
channel: MessageChannel,
|
||||
|
||||
@@ -722,7 +722,7 @@ class SkillsChain(ChainBase):
|
||||
buttons = None
|
||||
if self._supports_interactive_buttons(request.channel):
|
||||
buttons = []
|
||||
for index, skill in enumerate(page_items, start=1):
|
||||
for index, skill in enumerate(items, start=1):
|
||||
if not skill.removable:
|
||||
continue
|
||||
buttons.append(
|
||||
|
||||
@@ -525,6 +525,13 @@ class EventManager(metaclass=Singleton):
|
||||
class_name=class_name, method_name=method_name, e=e)
|
||||
else:
|
||||
# 全局处理器
|
||||
if not class_name:
|
||||
try:
|
||||
handler(event)
|
||||
except Exception as e:
|
||||
self.__handle_event_error(event=event, module_name=self.__get_handler_identifier(handler),
|
||||
class_name=class_name, method_name=method_name, e=e)
|
||||
return
|
||||
class_obj = self.__get_class_instance(class_name)
|
||||
if not class_obj or not hasattr(class_obj, method_name):
|
||||
return
|
||||
@@ -556,6 +563,16 @@ class EventManager(metaclass=Singleton):
|
||||
elif class_name in module_manager.get_module_ids():
|
||||
await self.__invoke_module_method_async(module_manager, class_name, method_name, event)
|
||||
else:
|
||||
if not class_name:
|
||||
try:
|
||||
if inspect.iscoroutinefunction(handler):
|
||||
await handler(event)
|
||||
else:
|
||||
await run_in_threadpool(handler, event)
|
||||
except Exception as e:
|
||||
self.__handle_event_error(event=event, module_name=self.__get_handler_identifier(handler),
|
||||
class_name=class_name, method_name=method_name, e=e)
|
||||
return
|
||||
await self.__invoke_global_method_async(class_name, method_name, event)
|
||||
|
||||
@staticmethod
|
||||
@@ -566,6 +583,8 @@ class EventManager(metaclass=Singleton):
|
||||
:return: (class_name, method_name)
|
||||
"""
|
||||
names = handler.__qualname__.split(".")
|
||||
if len(names) < 2:
|
||||
return "", names[0]
|
||||
return names[0], names[1]
|
||||
|
||||
async def __invoke_plugin_method_async(self, handler: Any, class_name: str, method_name: str, event: Event):
|
||||
|
||||
175
app/helper/agent.py
Normal file
175
app/helper/agent.py
Normal file
@@ -0,0 +1,175 @@
|
||||
from queue import Queue
|
||||
from threading import Lock
|
||||
from typing import Optional, Union
|
||||
|
||||
|
||||
_WEB_AGENT_EDIT_QUEUES: dict[str, list[Queue[dict]]] = {}
|
||||
_WEB_AGENT_EDIT_LOCK = Lock()
|
||||
|
||||
|
||||
def normalize_web_agent_button_rows(buttons: Optional[list[list[dict]]]) -> list[list[dict]]:
|
||||
"""
|
||||
将消息按钮转换为 WebAgent 前端可识别的按钮行。
|
||||
|
||||
:param buttons: 传统消息模块返回的按钮二维数组
|
||||
:return: WebAgent 前端选项按钮二维数组
|
||||
"""
|
||||
button_rows: list[list[dict]] = []
|
||||
for row in buttons or []:
|
||||
normalized_row = []
|
||||
for button in row or []:
|
||||
label = str(button.get("text") or button.get("label") or "").strip()
|
||||
callback_data = str(button.get("callback_data") or "").strip()
|
||||
if not label or not callback_data:
|
||||
continue
|
||||
normalized_button = {
|
||||
"label": label,
|
||||
"callback_data": callback_data,
|
||||
}
|
||||
if button.get("description"):
|
||||
normalized_button["description"] = str(button.get("description"))
|
||||
normalized_row.append(normalized_button)
|
||||
if normalized_row:
|
||||
button_rows.append(normalized_row)
|
||||
return button_rows
|
||||
|
||||
|
||||
def _resolve_web_agent_choice_id(
|
||||
message_id: Union[str, int],
|
||||
button_rows: list[list[dict]],
|
||||
) -> str:
|
||||
"""
|
||||
从按钮回调中提取稳定的 WebAgent 选项 ID。
|
||||
|
||||
:param message_id: 前端助手消息 ID
|
||||
:param button_rows: 已规范化的按钮行
|
||||
:return: 选项卡片 ID
|
||||
"""
|
||||
for row in button_rows:
|
||||
for button in row:
|
||||
callback_data = str(button.get("callback_data") or "").strip()
|
||||
if not callback_data:
|
||||
continue
|
||||
parts = callback_data.split(":")
|
||||
if len(parts) >= 2 and parts[1]:
|
||||
return parts[1]
|
||||
return callback_data
|
||||
return str(message_id)
|
||||
|
||||
|
||||
def build_web_agent_message_update_event(
|
||||
*,
|
||||
message_id: Union[str, int],
|
||||
title: Optional[str],
|
||||
text: str,
|
||||
buttons: Optional[list[list[dict]]],
|
||||
) -> dict:
|
||||
"""
|
||||
构造 WebAgent 原消息更新事件。
|
||||
|
||||
:param message_id: 前端助手消息 ID
|
||||
:param title: 更新后的标题
|
||||
:param text: 更新后的正文
|
||||
:param buttons: 更新后的按钮
|
||||
:return: 前端可应用到原消息的 SSE 事件
|
||||
"""
|
||||
button_rows = normalize_web_agent_button_rows(buttons)
|
||||
content_parts = [part for part in (title, text) if part]
|
||||
target_message = {
|
||||
"id": str(message_id),
|
||||
"content": "" if button_rows else "\n\n".join(content_parts),
|
||||
"choices": [],
|
||||
"attachments": [],
|
||||
"tools": [],
|
||||
"status": "done",
|
||||
}
|
||||
if button_rows:
|
||||
target_message["choices"].append({
|
||||
"id": _resolve_web_agent_choice_id(message_id, button_rows),
|
||||
"title": title,
|
||||
"prompt": text or "",
|
||||
"buttons": [button for row in button_rows for button in row],
|
||||
"button_rows": button_rows,
|
||||
"status": "pending",
|
||||
})
|
||||
return {
|
||||
"type": "message_update",
|
||||
"target_message": target_message,
|
||||
}
|
||||
|
||||
|
||||
def attach_web_agent_edit_queue(user_id: str, edit_queue: Queue[dict]) -> None:
|
||||
"""
|
||||
为当前 WebAgent 请求挂载原消息编辑事件队列。
|
||||
|
||||
:param user_id: 当前用户 ID
|
||||
:param edit_queue: 用于接收编辑事件的队列
|
||||
"""
|
||||
with _WEB_AGENT_EDIT_LOCK:
|
||||
_WEB_AGENT_EDIT_QUEUES.setdefault(str(user_id), []).append(edit_queue)
|
||||
|
||||
|
||||
def detach_web_agent_edit_queue(user_id: str, edit_queue: Queue[dict]) -> None:
|
||||
"""
|
||||
移除当前 WebAgent 请求的原消息编辑事件队列。
|
||||
|
||||
:param user_id: 当前用户 ID
|
||||
:param edit_queue: 需要移除的队列
|
||||
"""
|
||||
with _WEB_AGENT_EDIT_LOCK:
|
||||
queues = _WEB_AGENT_EDIT_QUEUES.get(str(user_id))
|
||||
if not queues:
|
||||
return
|
||||
_WEB_AGENT_EDIT_QUEUES[str(user_id)] = [
|
||||
item for item in queues if item is not edit_queue
|
||||
]
|
||||
if not _WEB_AGENT_EDIT_QUEUES[str(user_id)]:
|
||||
_WEB_AGENT_EDIT_QUEUES.pop(str(user_id), None)
|
||||
|
||||
|
||||
def dispatch_web_agent_edit_event(
|
||||
*,
|
||||
user_id: str,
|
||||
event: dict,
|
||||
) -> bool:
|
||||
"""
|
||||
将 WebAgent 原消息编辑事件分发给正在等待的请求队列。
|
||||
|
||||
:param user_id: 当前用户 ID
|
||||
:param event: 前端可应用的 SSE 事件
|
||||
:return: 是否存在接收本次编辑事件的请求队列
|
||||
"""
|
||||
with _WEB_AGENT_EDIT_LOCK:
|
||||
queues = list(_WEB_AGENT_EDIT_QUEUES.get(str(user_id)) or [])
|
||||
for edit_queue in queues:
|
||||
edit_queue.put(event)
|
||||
return bool(queues)
|
||||
|
||||
|
||||
def edit_web_agent_message(
|
||||
*,
|
||||
user_id: str,
|
||||
message_id: Union[str, int],
|
||||
title: Optional[str],
|
||||
text: str,
|
||||
buttons: Optional[list[list[dict]]] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
原地更新 WebAgent 前端消息卡片。
|
||||
|
||||
:param user_id: 当前用户 ID
|
||||
:param message_id: 前端助手消息 ID
|
||||
:param title: 更新后的标题
|
||||
:param text: 更新后的正文
|
||||
:param buttons: 更新后的按钮
|
||||
:return: 是否已投递编辑事件
|
||||
"""
|
||||
if not user_id:
|
||||
return False
|
||||
event = build_web_agent_message_update_event(
|
||||
message_id=message_id,
|
||||
title=title,
|
||||
text=text,
|
||||
buttons=buttons,
|
||||
)
|
||||
return dispatch_web_agent_edit_event(user_id=user_id, event=event)
|
||||
@@ -197,6 +197,9 @@ def update_or_post_message(
|
||||
and original_chat_id
|
||||
and ChannelCapabilityManager.supports_editing(channel)
|
||||
):
|
||||
edit_kwargs = {}
|
||||
if channel == MessageChannel.WebAgent:
|
||||
edit_kwargs["metadata"] = {"userid": userid}
|
||||
edited = chain.edit_message(
|
||||
channel=channel,
|
||||
source=source,
|
||||
@@ -205,6 +208,7 @@ def update_or_post_message(
|
||||
title=title,
|
||||
text=text,
|
||||
buttons=buttons,
|
||||
**edit_kwargs,
|
||||
)
|
||||
if edited:
|
||||
return
|
||||
@@ -402,6 +406,7 @@ class AgentInteractionOption:
|
||||
|
||||
label: str
|
||||
value: str
|
||||
description: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -88,6 +88,22 @@ class AgentChatChoiceButton(BaseModel):
|
||||
|
||||
label: str = Field(..., description="按钮文案")
|
||||
callback_data: str = Field(..., description="回调数据")
|
||||
description: Optional[str] = Field(None, description="选项描述")
|
||||
|
||||
|
||||
class AgentChatChoiceSelection(BaseModel):
|
||||
"""
|
||||
Agent 会话中用户选择的展示快照。
|
||||
"""
|
||||
|
||||
choice_id: str = Field(..., description="选择卡片 ID")
|
||||
title: Optional[str] = Field(None, description="标题")
|
||||
prompt: str = Field(default="", description="提示语")
|
||||
buttons: list[AgentChatChoiceButton] = Field(default_factory=list, description="按钮列表")
|
||||
button_rows: list[list[AgentChatChoiceButton]] = Field(default_factory=list, description="按钮行")
|
||||
selected_label: Optional[str] = Field(None, description="已选择文案")
|
||||
selected_value: Optional[str] = Field(None, description="已选择值")
|
||||
selected_description: Optional[str] = Field(None, description="已选择描述")
|
||||
|
||||
|
||||
class AgentChatChoiceCard(BaseModel):
|
||||
@@ -99,9 +115,11 @@ class AgentChatChoiceCard(BaseModel):
|
||||
title: Optional[str] = Field(None, description="标题")
|
||||
prompt: str = Field(default="", description="提示语")
|
||||
buttons: list[AgentChatChoiceButton] = Field(default_factory=list, description="按钮列表")
|
||||
button_rows: list[list[AgentChatChoiceButton]] = Field(default_factory=list, description="按钮行")
|
||||
status: str = Field(default="pending", description="选择状态")
|
||||
selected_label: Optional[str] = Field(None, description="已选择文案")
|
||||
selected_value: Optional[str] = Field(None, description="已选择值")
|
||||
selected_description: Optional[str] = Field(None, description="已选择描述")
|
||||
|
||||
|
||||
class AgentChatMessage(BaseModel):
|
||||
@@ -117,6 +135,7 @@ class AgentChatMessage(BaseModel):
|
||||
tools: list[AgentChatToolCall] = Field(default_factory=list, description="工具提示列表")
|
||||
attachments: list[AgentChatAttachment] = Field(default_factory=list, description="附件列表")
|
||||
choices: list[AgentChatChoiceCard] = Field(default_factory=list, description="选择卡片列表")
|
||||
choice_selection: Optional[AgentChatChoiceSelection] = Field(None, description="用户选择项快照")
|
||||
|
||||
|
||||
class AgentChatSession(BaseModel):
|
||||
|
||||
@@ -330,6 +330,8 @@ class AgentWebChatRequest(BaseModel):
|
||||
|
||||
# 用户本轮输入
|
||||
text: str = Field(default="")
|
||||
# 展示历史中记录的用户可读文本;为空时使用 text
|
||||
display_text: Optional[str] = Field(None)
|
||||
# 前端会话标识,相同标识复用同一段 Agent 记忆
|
||||
session_id: Optional[str] = Field(None)
|
||||
# 图片 URL 或 data URL 列表
|
||||
@@ -338,6 +340,12 @@ class AgentWebChatRequest(BaseModel):
|
||||
audio_refs: Optional[List[str]] = Field(default_factory=list)
|
||||
# 文件附件列表
|
||||
files: Optional[List[AgentWebChatFile]] = Field(default_factory=list)
|
||||
# 用户通过按钮选择时的完整选择快照
|
||||
choice_selection: Optional[Dict[str, Any]] = Field(default=None)
|
||||
# WebAgent 按钮回调关联的原消息 ID,用于传统交互原地编辑卡片
|
||||
original_message_id: Optional[Union[str, int]] = Field(default=None)
|
||||
# WebAgent 按钮回调关联的原聊天 ID,用于传统交互原地编辑卡片
|
||||
original_chat_id: Optional[Union[str, int]] = Field(default=None)
|
||||
# 是否在展示历史中记录本轮用户消息
|
||||
echo_user: bool = Field(default=True)
|
||||
|
||||
@@ -351,6 +359,10 @@ class AgentWebChoiceRequest(BaseModel):
|
||||
session_id: Optional[str] = Field(None)
|
||||
# Agent 工具生成的按钮回调数据
|
||||
callback_data: str = Field(..., min_length=1)
|
||||
# WebAgent 原助手消息 ID,用于传统按钮回调原地编辑
|
||||
original_message_id: Optional[Union[str, int]] = Field(default=None)
|
||||
# WebAgent 原聊天 ID,用于传统按钮回调原地编辑
|
||||
original_chat_id: Optional[Union[str, int]] = Field(default=None)
|
||||
|
||||
|
||||
class ChannelCapability(Enum):
|
||||
|
||||
@@ -512,6 +512,48 @@ class TestSkillsCommand(unittest.TestCase):
|
||||
self.assertIn("仓库来源 · acme/custom-skills", text)
|
||||
self.assertIn("3. 管理技能源", text)
|
||||
|
||||
def test_skills_chain_installed_view_builds_remove_buttons(self):
|
||||
chain = SkillsChain()
|
||||
request = skills_interaction_manager.create_or_replace(
|
||||
user_id="10001",
|
||||
channel=MessageChannel.WebAgent,
|
||||
source="web-agent",
|
||||
username="tester",
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
chain.skillhelper,
|
||||
"list_local_skills",
|
||||
return_value=[
|
||||
SkillInfo(
|
||||
id="builtin",
|
||||
name="Builtin",
|
||||
description="Built in skill",
|
||||
source_type="builtin",
|
||||
source_label="内置",
|
||||
removable=False,
|
||||
),
|
||||
SkillInfo(
|
||||
id="custom",
|
||||
name="Custom",
|
||||
description="Custom skill",
|
||||
source_type="local",
|
||||
source_label="本地",
|
||||
removable=True,
|
||||
),
|
||||
],
|
||||
):
|
||||
title, text, buttons = chain._build_installed_view(request=request)
|
||||
|
||||
self.assertEqual(title, "已安装技能")
|
||||
self.assertIn("builtin", text)
|
||||
self.assertIn("custom", text)
|
||||
self.assertTrue(buttons)
|
||||
self.assertIn(
|
||||
{"text": "删除 2", "callback_data": f"skills:{request.request_id}:remove:2"},
|
||||
[button for row in buttons for button in row],
|
||||
)
|
||||
|
||||
def test_skills_chain_callback_enters_search_input_mode(self):
|
||||
chain = SkillsChain()
|
||||
request = skills_interaction_manager.create_or_replace(
|
||||
|
||||
@@ -1,27 +1,39 @@
|
||||
import asyncio
|
||||
import time
|
||||
from queue import Queue
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from app import schemas
|
||||
from app.agent import ReplyMode
|
||||
from app.agent import ReplyMode, agent_manager
|
||||
from app.api.endpoints.agent import (
|
||||
_WebAgentMoviePilotAgent,
|
||||
_WEB_AGENT_FILE_REGISTRY,
|
||||
_WEB_AGENT_NOTICE_QUEUES,
|
||||
_apply_web_agent_display_event,
|
||||
_build_web_agent_input_attachments,
|
||||
_build_web_agent_notification_events,
|
||||
_build_web_agent_command_items,
|
||||
_build_web_agent_session_id,
|
||||
_build_web_agent_traditional_callback_payload,
|
||||
_build_web_agent_display_message_from_events,
|
||||
_collect_web_agent_traditional_events,
|
||||
_dispatch_web_agent_notice_event,
|
||||
_extract_web_agent_notification_from_event_data,
|
||||
_has_web_agent_traditional_interaction,
|
||||
_prepare_web_agent_audio_attachment_path,
|
||||
_transcribe_web_agent_audio_refs,
|
||||
web_agent_stream,
|
||||
_resolve_web_agent_choice_payload,
|
||||
_split_web_agent_output,
|
||||
)
|
||||
from app.core.event import Event
|
||||
from app.db.agentchat_oper import AgentChatOper
|
||||
from app.helper.interaction import AgentInteractionOption, agent_interaction_manager
|
||||
from app.helper.agent import build_web_agent_message_update_event
|
||||
from app.helper.interaction import AgentInteractionOption, agent_interaction_manager, skills_interaction_manager
|
||||
from app.chain.message import MessageChain
|
||||
from app.schemas.message import ChannelCapability, ChannelCapabilityManager
|
||||
from app.schemas.types import MessageChannel, NotificationType
|
||||
from app.schemas.types import EventType, MessageChannel, NotificationType
|
||||
|
||||
|
||||
def test_split_web_agent_output_extracts_verbose_tool_message():
|
||||
@@ -144,6 +156,125 @@ def test_build_web_agent_input_attachments_marks_kinds():
|
||||
assert attachments[1]["name"] == "report.txt"
|
||||
|
||||
|
||||
def test_build_web_agent_command_items_returns_slash_commands():
|
||||
"""WebAgent 命令建议应返回可展示的斜杠命令。"""
|
||||
with patch(
|
||||
"app.api.endpoints.agent.Command",
|
||||
return_value=SimpleNamespace(
|
||||
get_commands=lambda: {
|
||||
"/sites": {"description": "管理站点", "category": "站点"},
|
||||
"hidden": {"description": "忽略", "category": "其他"},
|
||||
"/hidden": {"description": "隐藏", "category": "其他", "show": False},
|
||||
}
|
||||
),
|
||||
):
|
||||
commands = _build_web_agent_command_items()
|
||||
|
||||
assert commands == [
|
||||
{
|
||||
"command": "/sites",
|
||||
"description": "管理站点",
|
||||
"category": "站点",
|
||||
"type": "",
|
||||
"pid": None,
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
def test_build_web_agent_command_items_includes_sites_command():
|
||||
"""WebAgent 命令建议应包含内建站点管理命令。"""
|
||||
commands = _build_web_agent_command_items()
|
||||
|
||||
assert any(command["command"] == "/sites" for command in commands)
|
||||
|
||||
|
||||
def test_build_web_agent_traditional_callback_payload_wraps_callback():
|
||||
"""传统按钮回调应包装为可继续提交给 MessageChain 的消息。"""
|
||||
payload = _build_web_agent_traditional_callback_payload(
|
||||
"skills:req-1:root",
|
||||
original_message_id="assistant-1",
|
||||
original_chat_id="web-session",
|
||||
)
|
||||
|
||||
assert payload["message"] == "CALLBACK:skills:req-1:root"
|
||||
assert payload["traditional"] is True
|
||||
assert payload["original_message_id"] == "assistant-1"
|
||||
assert payload["original_chat_id"] == "web-session"
|
||||
|
||||
|
||||
def test_web_agent_stream_returns_error_for_unknown_command():
|
||||
"""不存在的 WebAgent 斜杠命令应立即返回错误,不进入等待队列。"""
|
||||
payload = schemas.AgentWebChatRequest(
|
||||
text="/missing_command 参数",
|
||||
session_id="browser-session",
|
||||
)
|
||||
request = SimpleNamespace()
|
||||
user = SimpleNamespace(id=1, name="admin", is_superuser=True)
|
||||
|
||||
with patch(
|
||||
"app.api.endpoints.agent.Command",
|
||||
return_value=SimpleNamespace(get=lambda _: {}),
|
||||
), patch("app.api.endpoints.agent.MessageChain.handle_message") as handle_message:
|
||||
response = asyncio.run(web_agent_stream(payload, request, user))
|
||||
body = "".join(asyncio.run(_collect_streaming_response(response)))
|
||||
|
||||
assert "error" in body
|
||||
assert "命令不存在:/missing_command" in body
|
||||
handle_message.assert_not_called()
|
||||
|
||||
|
||||
def test_build_web_agent_message_update_event_converts_buttons():
|
||||
"""WebAgent 编辑消息应转换为可原地更新卡片的事件。"""
|
||||
event = build_web_agent_message_update_event(
|
||||
message_id="assistant-1",
|
||||
title="技能管理",
|
||||
text="请选择操作",
|
||||
buttons=[[{"text": "返回", "callback_data": "skills:req-1:root"}]],
|
||||
)
|
||||
|
||||
assert event["type"] == "message_update"
|
||||
assert event["target_message"]["id"] == "assistant-1"
|
||||
assert event["target_message"]["choices"][0]["title"] == "技能管理"
|
||||
assert event["target_message"]["choices"][0]["prompt"] == "请选择操作"
|
||||
assert event["target_message"]["choices"][0]["buttons"][0]["label"] == "返回"
|
||||
|
||||
|
||||
def test_build_web_agent_display_message_from_events_marks_done():
|
||||
"""传统消息事件应聚合为完成态助手展示消息。"""
|
||||
message = _build_web_agent_display_message_from_events([
|
||||
{"type": "delta", "content": "菜单"},
|
||||
{
|
||||
"type": "choice",
|
||||
"choice": {
|
||||
"id": "choice-1",
|
||||
"prompt": "请选择",
|
||||
"buttons": [{"label": "返回", "callback_data": "back"}],
|
||||
},
|
||||
},
|
||||
])
|
||||
|
||||
assert message["content"] == "菜单"
|
||||
assert message["status"] == "done"
|
||||
assert message["choices"][0]["prompt"] == "请选择"
|
||||
|
||||
|
||||
def test_has_web_agent_traditional_interaction_detects_pending_skills():
|
||||
"""WebAgent 应能识别命令后的传统交互上下文。"""
|
||||
skills_interaction_manager.clear()
|
||||
try:
|
||||
skills_interaction_manager.create_or_replace(
|
||||
user_id="1",
|
||||
channel=MessageChannel.WebAgent,
|
||||
source="web-agent",
|
||||
username="admin",
|
||||
)
|
||||
|
||||
assert _has_web_agent_traditional_interaction("1") is True
|
||||
assert _has_web_agent_traditional_interaction("2") is False
|
||||
finally:
|
||||
skills_interaction_manager.clear()
|
||||
|
||||
|
||||
def test_web_agent_admin_context_uses_current_user_id():
|
||||
"""Web Agent 工具权限应按当前登录用户 ID 判断管理员身份。"""
|
||||
agent = _WebAgentMoviePilotAgent(
|
||||
@@ -213,6 +344,69 @@ def test_build_web_agent_notification_events_extracts_image():
|
||||
]
|
||||
|
||||
|
||||
def test_extract_web_agent_notification_supports_wrapped_message_event():
|
||||
"""NoticeMessage 包装 Notification 时应仍能解析为 WebAgent 通知。"""
|
||||
notification = schemas.Notification(
|
||||
channel=MessageChannel.WebAgent,
|
||||
source="web-agent",
|
||||
title="会话状态",
|
||||
userid="1",
|
||||
)
|
||||
|
||||
extracted = _extract_web_agent_notification_from_event_data(
|
||||
{"message": notification, "current_time": "2026-06-26 09:18:38"}
|
||||
)
|
||||
|
||||
assert extracted == notification
|
||||
|
||||
|
||||
def test_dispatch_web_agent_notice_event_accepts_wrapped_message_event():
|
||||
"""WebAgent 等待队列应接收 message 包装格式的 NoticeMessage 事件。"""
|
||||
notice_queue = Queue()
|
||||
_WEB_AGENT_NOTICE_QUEUES["1"] = [notice_queue]
|
||||
notification = schemas.Notification(
|
||||
channel=MessageChannel.WebAgent,
|
||||
source="web-agent",
|
||||
title="会话状态",
|
||||
userid="1",
|
||||
)
|
||||
|
||||
try:
|
||||
_dispatch_web_agent_notice_event(
|
||||
Event(
|
||||
EventType.NoticeMessage,
|
||||
{"message": notification, "current_time": "2026-06-26 09:18:38"},
|
||||
)
|
||||
)
|
||||
finally:
|
||||
_WEB_AGENT_NOTICE_QUEUES.pop("1", None)
|
||||
|
||||
assert notice_queue.get_nowait() == notification
|
||||
|
||||
|
||||
def test_collect_web_agent_traditional_events_does_not_emit_submit_hint():
|
||||
"""传统命令未产生通知时不应返回“命令已提交”的兜底提示。"""
|
||||
user = SimpleNamespace(id=1, name="admin")
|
||||
|
||||
with patch(
|
||||
"app.api.endpoints.agent.MessageChain.handle_message",
|
||||
), patch(
|
||||
"app.api.endpoints.agent.WEB_AGENT_TRADITIONAL_IDLE_TIMEOUT_SECONDS",
|
||||
0.01,
|
||||
), patch(
|
||||
"app.api.endpoints.agent.WEB_AGENT_TRADITIONAL_MAX_WAIT_SECONDS",
|
||||
0.05,
|
||||
):
|
||||
events = asyncio.run(
|
||||
_collect_web_agent_traditional_events(
|
||||
text="/session_status",
|
||||
current_user=user,
|
||||
)
|
||||
)
|
||||
|
||||
assert events == []
|
||||
|
||||
|
||||
def test_build_web_agent_notification_events_registers_local_file(tmp_path):
|
||||
"""Agent 工具发送本地文件时应生成可下载附件事件。"""
|
||||
file_path = tmp_path / "report.txt"
|
||||
@@ -335,6 +529,70 @@ def test_web_agent_stream_returns_error_when_voice_transcription_fails():
|
||||
assert "语音识别失败" in body
|
||||
|
||||
|
||||
def test_web_agent_stream_binds_session_to_agent_manager():
|
||||
"""WebAgent 普通对话应统一进入 AgentManager 并绑定远程命令会话。"""
|
||||
payload = schemas.AgentWebChatRequest(
|
||||
text="查看会话",
|
||||
session_id="browser-session",
|
||||
)
|
||||
request = SimpleNamespace(is_disconnected=AsyncMock(return_value=False))
|
||||
user = SimpleNamespace(id=1, name="admin", is_superuser=True)
|
||||
|
||||
class FakeWebAgent:
|
||||
"""测试用 WebAgent,模拟 AgentManager 内部的持久实例。"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.__dict__.update(kwargs)
|
||||
self.processed = []
|
||||
|
||||
def set_output_callback(self, output_callback):
|
||||
"""更新当前 SSE 输出回调。"""
|
||||
self.output_callback = output_callback
|
||||
|
||||
def set_notification_callback(self, notification_callback):
|
||||
"""更新当前 SSE 通知回调。"""
|
||||
self.notification_callback = notification_callback
|
||||
|
||||
async def process(self, message, **kwargs):
|
||||
"""模拟一次 WebAgent 推理输出。"""
|
||||
self.processed.append((message, kwargs))
|
||||
self.output_callback("状态正常")
|
||||
return "状态正常"
|
||||
|
||||
async def cleanup(self):
|
||||
"""模拟 Agent 资源清理。"""
|
||||
return None
|
||||
|
||||
session_id = _build_web_agent_session_id(user, payload.session_id)
|
||||
MessageChain._user_sessions.clear()
|
||||
agent_manager.active_agents.pop(session_id, None)
|
||||
agent_manager._session_queues.pop(session_id, None)
|
||||
worker = agent_manager._session_workers.pop(session_id, None)
|
||||
if worker:
|
||||
worker.cancel()
|
||||
|
||||
try:
|
||||
with patch("app.api.endpoints.agent.settings.AI_AGENT_ENABLE", True), patch(
|
||||
"app.api.endpoints.agent._WebAgentMoviePilotAgent",
|
||||
FakeWebAgent,
|
||||
):
|
||||
response = asyncio.run(web_agent_stream(payload, request, user))
|
||||
body = "".join(asyncio.run(_collect_streaming_response(response)))
|
||||
|
||||
assert "状态正常" in body
|
||||
assert MessageChain._user_sessions["1"][0] == session_id
|
||||
assert isinstance(agent_manager.active_agents[session_id], FakeWebAgent)
|
||||
finally:
|
||||
MessageChain._user_sessions.clear()
|
||||
agent = agent_manager.active_agents.pop(session_id, None)
|
||||
if agent:
|
||||
asyncio.run(agent.cleanup())
|
||||
agent_manager._session_queues.pop(session_id, None)
|
||||
worker = agent_manager._session_workers.pop(session_id, None)
|
||||
if worker:
|
||||
worker.cancel()
|
||||
|
||||
|
||||
async def _collect_streaming_response(response):
|
||||
"""读取 StreamingResponse,便于断言 SSE 内容。"""
|
||||
chunks = []
|
||||
@@ -356,6 +614,7 @@ def test_build_web_agent_notification_events_extracts_choice_card():
|
||||
{
|
||||
"text": "继续下载",
|
||||
"callback_data": "agent_interaction:choice:req-1:1",
|
||||
"description": "继续当前下载任务",
|
||||
}
|
||||
],
|
||||
[
|
||||
@@ -379,12 +638,28 @@ def test_build_web_agent_notification_events_extracts_choice_card():
|
||||
{
|
||||
"label": "继续下载",
|
||||
"callback_data": "agent_interaction:choice:req-1:1",
|
||||
"description": "继续当前下载任务",
|
||||
},
|
||||
{
|
||||
"label": "查看详情",
|
||||
"callback_data": "agent_interaction:choice:req-1:2",
|
||||
},
|
||||
],
|
||||
"button_rows": [
|
||||
[
|
||||
{
|
||||
"label": "继续下载",
|
||||
"callback_data": "agent_interaction:choice:req-1:1",
|
||||
"description": "继续当前下载任务",
|
||||
}
|
||||
],
|
||||
[
|
||||
{
|
||||
"label": "查看详情",
|
||||
"callback_data": "agent_interaction:choice:req-1:2",
|
||||
}
|
||||
],
|
||||
],
|
||||
},
|
||||
}
|
||||
]
|
||||
@@ -403,7 +678,7 @@ def test_resolve_web_agent_choice_payload_returns_next_message():
|
||||
prompt="请选择",
|
||||
options=[
|
||||
AgentInteractionOption(label="电影", value="我选择电影"),
|
||||
AgentInteractionOption(label="电视剧", value="我选择电视剧"),
|
||||
AgentInteractionOption(label="电视剧", value="我选择电视剧", description="选择电视剧并继续清理日志"),
|
||||
],
|
||||
)
|
||||
|
||||
@@ -416,6 +691,12 @@ def test_resolve_web_agent_choice_payload_returns_next_message():
|
||||
agent_interaction_manager.clear()
|
||||
|
||||
assert result["message"] == "我选择电视剧"
|
||||
assert result["display_message"] == "选择电视剧并继续清理日志"
|
||||
assert result["session_id"] == "web-agent:session"
|
||||
assert result["feedback"]["prompt"] == "请选择"
|
||||
assert result["feedback"]["selected_label"] == "电视剧"
|
||||
assert result["feedback"]["selected_value"] == "我选择电视剧"
|
||||
assert result["feedback"]["selected_description"] == "选择电视剧并继续清理日志"
|
||||
assert result["choice_selection"]["prompt"] == "请选择"
|
||||
assert result["choice_selection"]["selected_description"] == "选择电视剧并继续清理日志"
|
||||
assert result["choice_selection"]["button_rows"][1][0]["description"] == "选择电视剧并继续清理日志"
|
||||
|
||||
Reference in New Issue
Block a user