feat(agent): add voice message support with TTS/STT for Telegram and WeChat

- Integrate voice message handling: detect and extract audio references from Telegram and WeChat messages, route to agent with voice reply preference.
- Add voice provider abstraction and OpenAI-based TTS/STT implementation.
- Implement agent tool `send_voice_message` for generating and sending voice replies, with fallback to text if voice is unavailable.
- Extend agent prompt and context to support voice reply instructions.
- Update notification and message schemas to support audio fields.
- Add Telegram and WeChat voice sending logic, including audio file conversion and temporary media upload for WeChat.
- Add tests for voice helper and agent voice routing.
This commit is contained in:
jxxghp
2026-04-12 12:30:02 +08:00
parent 9dababbcfd
commit e5f97cd299
17 changed files with 945 additions and 167 deletions

View File

@@ -129,6 +129,8 @@ class MoviePilotAgent:
self.channel = channel
self.source = source
self.username = username
self.reply_with_voice = False
self._tool_context: Dict[str, object] = {}
# 流式token管理
self.stream_handler = StreamingHandler()
@@ -151,6 +153,8 @@ class MoviePilotAgent:
"""
if self.is_background:
return False
if self.reply_with_voice:
return False
# 啰嗦模式下始终需要流式输出来捕获工具调用前的 Agent 文字
if settings.AI_AGENT_VERBOSE:
return True
@@ -214,6 +218,7 @@ class MoviePilotAgent:
source=self.source,
username=self.username,
stream_handler=self.stream_handler,
agent_context=self._tool_context,
)
def _create_agent(self, streaming: bool = False):
@@ -223,7 +228,10 @@ class MoviePilotAgent:
"""
try:
# 系统提示词
system_prompt = prompt_manager.get_agent_prompt(channel=self.channel)
system_prompt = prompt_manager.get_agent_prompt(
channel=self.channel,
prefer_voice_reply=self.reply_with_voice,
)
# LLM 模型(用于 agent 执行)
llm = self._initialize_llm(streaming=streaming)
@@ -281,6 +289,11 @@ class MoviePilotAgent:
logger.info(
f"Agent推理: session_id={self.session_id}, input={message}, images={len(images) if images else 0}"
)
self._tool_context = {
"incoming_voice": self.reply_with_voice,
"user_reply_sent": False,
"reply_mode": None,
}
# 获取历史消息
messages = memory_manager.get_agent_messages(
@@ -417,7 +430,7 @@ class MoviePilotAgent:
# 流式输出未能发送全部内容(发送失败等)
# 通过常规方式发送剩余内容
remaining_text = await self.stream_handler.take()
if remaining_text:
if remaining_text and not self._tool_context.get("user_reply_sent"):
await self.send_agent_message(remaining_text)
elif streamed_text:
# 流式输出已发送全部内容,但未记录到数据库,补充保存消息记录
@@ -447,7 +460,7 @@ class MoviePilotAgent:
final_text = text.strip()
break
if final_text:
if final_text and not self._tool_context.get("user_reply_sent"):
if self.is_background:
# 后台任务仅广播最终回复,带标题
await self.send_agent_message(
@@ -534,6 +547,7 @@ class _MessageTask:
channel: Optional[str] = None
source: Optional[str] = None
username: Optional[str] = None
reply_with_voice: bool = False
class AgentManager:
@@ -599,6 +613,7 @@ class AgentManager:
channel: str = None,
source: str = None,
username: str = None,
reply_with_voice: bool = False,
) -> str:
"""
处理用户消息:将消息放入会话队列,按顺序依次处理。
@@ -612,6 +627,7 @@ class AgentManager:
channel=channel,
source=source,
username=username,
reply_with_voice=reply_with_voice,
)
# 获取或创建会话队列
@@ -709,6 +725,7 @@ class AgentManager:
agent.source = task.source
if task.username:
agent.username = task.username
agent.reply_with_voice = task.reply_with_voice
return await agent.process(task.message, images=task.images)

View File

@@ -21,6 +21,7 @@ Core Capabilities:
- Include key details (year, rating, resolution) but do NOT over-explain.
- Do not stop for approval on read-only operations. Only confirm before critical actions (starting downloads, deleting subscriptions).
- If the current channel supports image sending and an image would materially help, you may use the `send_message` tool with `image_url` to send it.
- Voice replies: {voice_reply_spec}
- NOT a coding assistant. Do not offer code snippets.
- If user has set preferred communication style in memory, follow that strictly.
</communication>

View File

@@ -50,10 +50,13 @@ class PromptManager:
logger.error(f"加载提示词失败: {prompt_name}, 错误: {e}")
raise
def get_agent_prompt(self, channel: str = None) -> str:
def get_agent_prompt(
self, channel: str = None, prefer_voice_reply: bool = False
) -> str:
"""
获取智能体提示词
:param channel: 消息渠道Telegram、微信、Slack等
:param prefer_voice_reply: 是否优先使用语音回复
:return: 提示词内容
"""
# 基础提示词
@@ -87,12 +90,16 @@ class PromptManager:
# MoviePilot系统信息
moviepilot_info = self._get_moviepilot_info()
voice_reply_spec = self._generate_voice_reply_instructions(
prefer_voice_reply=prefer_voice_reply
)
# 始终替换占位符,避免后续 .format() 时因残留花括号报 KeyError
base_prompt = base_prompt.format(
markdown_spec=markdown_spec,
verbose_spec=verbose_spec,
moviepilot_info=moviepilot_info,
voice_reply_spec=voice_reply_spec,
)
return base_prompt
@@ -166,6 +173,20 @@ class PromptManager:
instructions.append("- Links: Paste URLs directly as text.")
return "\n".join(instructions)
@staticmethod
def _generate_voice_reply_instructions(prefer_voice_reply: bool) -> str:
if not prefer_voice_reply:
return (
"- Voice replies: Use normal text replies by default. "
"Only call `send_voice_message` when spoken playback is clearly better than plain text."
)
return (
"- Current message context: The user sent a voice message.\n"
"- Reply preference: Prioritize calling `send_voice_message` for the main user-facing reply.\n"
"- Fallback: If voice is unavailable on the current channel, `send_voice_message` will fall back to text.\n"
"- Do not repeat the same full reply again after calling `send_voice_message`."
)
def clear_cache(self):
"""
清空缓存

View File

@@ -31,6 +31,7 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
_username: Optional[str] = PrivateAttr(default=None)
_stream_handler: Optional[StreamingHandler] = PrivateAttr(default=None)
_require_admin: bool = PrivateAttr(default=False)
_agent_context: dict = PrivateAttr(default_factory=dict)
def __init__(self, session_id: str, user_id: str, **kwargs):
super().__init__(**kwargs)
@@ -142,6 +143,12 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
"""
self._stream_handler = stream_handler
def set_agent_context(self, agent_context: Optional[dict]):
    """
    Set the context dict shared with the current Agent run.

    The same mutable dict is handed to every tool instance, so writes made
    by one tool (e.g. flagging that the user-facing reply was already sent)
    are visible to the agent loop and to sibling tools.

    :param agent_context: shared context dict; ``None`` resets to a fresh empty dict
    """
    # Fall back to an empty dict so later attribute access never hits None.
    self._agent_context = agent_context or {}
async def _check_permission(self) -> Optional[str]:
"""
检查用户权限:

View File

@@ -30,6 +30,7 @@ from app.agent.tools.impl.search_torrents import SearchTorrentsTool
from app.agent.tools.impl.get_search_results import GetSearchResultsTool
from app.agent.tools.impl.search_web import SearchWebTool
from app.agent.tools.impl.send_message import SendMessageTool
from app.agent.tools.impl.send_voice_message import SendVoiceMessageTool
from app.agent.tools.impl.query_schedulers import QuerySchedulersTool
from app.agent.tools.impl.run_scheduler import RunSchedulerTool
from app.agent.tools.impl.query_workflows import QueryWorkflowsTool
@@ -72,6 +73,7 @@ class MoviePilotToolFactory:
source: str = None,
username: str = None,
stream_handler: Callable = None,
agent_context: dict = None,
) -> List[MoviePilotTool]:
"""
创建MoviePilot工具列表
@@ -117,6 +119,7 @@ class MoviePilotToolFactory:
QueryTransferHistoryTool,
TransferFileTool,
SendMessageTool,
SendVoiceMessageTool,
QuerySchedulersTool,
RunSchedulerTool,
QueryWorkflowsTool,
@@ -138,6 +141,7 @@ class MoviePilotToolFactory:
tool = ToolClass(session_id=session_id, user_id=user_id)
tool.set_message_attr(channel=channel, source=source, username=username)
tool.set_stream_handler(stream_handler=stream_handler)
tool.set_agent_context(agent_context=agent_context)
tools.append(tool)
# 加载插件提供的工具
@@ -161,6 +165,7 @@ class MoviePilotToolFactory:
channel=channel, source=source, username=username
)
tool.set_stream_handler(stream_handler=stream_handler)
tool.set_agent_context(agent_context=agent_context)
tools.append(tool)
plugin_tools_count += 1
logger.debug(

View File

@@ -0,0 +1,96 @@
"""发送语音消息工具。"""
import asyncio
from typing import Optional, Type
from pydantic import BaseModel, Field
from app.agent.tools.base import MoviePilotTool, ToolChain
from app.core.config import settings
from app.helper.voice import VoiceHelper
from app.helper.service import ServiceConfigHelper
from app.log import logger
from app.schemas import Notification, NotificationType
from app.schemas.types import MessageChannel
class SendVoiceMessageInput(BaseModel):
    """Input schema for the `send_voice_message` tool."""

    # Rationale field: forces the model to justify choosing voice over text.
    explanation: str = Field(
        ...,
        description="Clear explanation of why a voice reply is the best fit in the current context",
    )
    # The actual reply content; synthesized to speech, or sent as text on fallback.
    message: str = Field(
        ...,
        description="The spoken content to send back to the user",
    )
class SendVoiceMessageTool(MoviePilotTool):
    """Agent tool that sends a voice reply, falling back to plain text."""

    name: str = "send_voice_message"
    description: str = (
        "Send a voice reply to the current user. Prefer this when the user sent a voice message "
        "or when spoken playback is more natural. On channels without voice support or when TTS "
        "is unavailable, it automatically falls back to sending the same content as plain text."
    )
    args_schema: Type[BaseModel] = SendVoiceMessageInput
    require_admin: bool = False

    def get_tool_message(self, **kwargs) -> Optional[str]:
        """Return the progress text shown while the tool runs (truncated preview)."""
        message = kwargs.get("message") or ""
        # Keep the preview short; 40 chars is enough to identify the reply.
        if len(message) > 40:
            message = message[:40] + "..."
        return f"正在发送语音回复: {message}"

    def _supports_real_voice_reply(self) -> bool:
        """
        Whether the current channel can actually deliver an audio message.

        Telegram always can. WeChat only when the matching notification
        config is not in "bot" mode. Every other channel gets the text
        fallback.
        """
        channel = self._channel or ""
        if channel == MessageChannel.Telegram.value:
            return True
        if channel != MessageChannel.Wechat.value:
            return False
        # WeChat: find the notification config matching this message source
        # and check its mode; "bot" mode cannot send voice.
        for config in ServiceConfigHelper.get_notification_configs():
            if config.name != self._source:
                continue
            return (config.config or {}).get("WECHAT_MODE", "app") != "bot"
        return False

    async def run(self, message: str, **kwargs) -> str:
        """
        Synthesize ``message`` to speech when possible and post it to the user.

        Falls back to a plain-text notification when the channel cannot play
        voice or TTS is unavailable. Marks the shared agent context so the
        agent does not send the same reply a second time.

        :param message: spoken content to deliver
        :return: status string fed back to the LLM
        """
        if not message:
            return "语音回复内容不能为空"
        voice_path = None
        used_voice = False
        channel = self._channel or ""
        if self._supports_real_voice_reply() and VoiceHelper.is_available("tts"):
            # TTS is blocking; run it in a thread to keep the event loop responsive.
            voice_file = await asyncio.to_thread(VoiceHelper.synthesize_speech, message)
            if voice_file:
                voice_path = str(voice_file)
                used_voice = True
        logger.info(
            "执行工具: %s, channel=%s, use_voice=%s, text_len=%s",
            self.name,
            channel,
            used_voice,
            len(message),
        )
        await ToolChain().async_post_message(
            Notification(
                channel=self._channel,
                source=self._source,
                mtype=NotificationType.Agent,
                userid=self._user_id,
                username=self._username,
                # Text is always included; channels that send the voice file
                # use voice_path instead and may show voice_caption under it.
                text=message,
                voice_path=voice_path,
                voice_caption=message if settings.AI_VOICE_REPLY_WITH_TEXT else None,
            )
        )
        # Tell the agent loop a user-facing reply already went out, so the
        # final LLM text is not delivered again on top of this one.
        self._agent_context["user_reply_sent"] = True
        self._agent_context["reply_mode"] = "voice" if used_voice else "text_fallback"
        if used_voice:
            return "语音回复已发送"
        return "当前未使用语音通道,已自动回退为文字回复"

View File

@@ -4,6 +4,8 @@ import time
from datetime import datetime, timedelta
from typing import Any, Optional, Dict, Union, List
import base64
from app.agent import agent_manager
from app.chain import ChainBase
from app.chain.download import DownloadChain
@@ -15,6 +17,7 @@ from app.core.context import MediaInfo, Context
from app.core.meta import MetaBase
from app.db.user_oper import UserOper
from app.helper.torrent import TorrentHelper
from app.helper.voice import VoiceHelper
from app.log import logger
from app.schemas import Notification, NotExistMediaInfo, CommingMessage
from app.schemas.message import ChannelCapabilityManager
@@ -44,26 +47,6 @@ class MessageChain(ChainBase):
# 会话超时时间(分钟)
_session_timeout_minutes: int = 24 * 60
@staticmethod
def _summarize_images(images: Optional[List[str]], max_items: int = 3) -> List[str]:
"""
图片引用摘要,避免日志过长或直接输出完整 base64。
"""
if not images:
return []
summary = []
for image in images[:max_items]:
if not image:
continue
image = str(image)
if image.startswith("data:"):
summary.append(f"{image[:32]}...({len(image)} chars)")
elif len(image) > 120:
summary.append(f"{image[:117]}...")
else:
summary.append(image)
return summary
@staticmethod
def __get_noexits_info(
_meta: MetaBase, _mediainfo: MediaInfo
@@ -146,23 +129,15 @@ class MessageChain(ChainBase):
if userid is None or userid == "":
logger.debug(f"未识别到用户ID{body}{form}{args}")
return
# 消息内容
text = str(info.text).strip() if info.text else ""
images = info.images
if not text and not images:
audio_refs = info.audio_refs
if not text and not images and not audio_refs:
logger.debug(f"未识别到消息内容::{body}{form}{args}")
return
logger.info(
"消息链路解析完成: source=%s, channel=%s, userid=%s, text_len=%s, image_count=%s, image_refs=%s",
source,
channel.value if channel else None,
userid,
len(text),
len(images or []),
self._summarize_images(images),
)
# 获取原消息ID信息
original_message_id = info.message_id
original_chat_id = info.chat_id
@@ -177,6 +152,7 @@ class MessageChain(ChainBase):
original_message_id=original_message_id,
original_chat_id=original_chat_id,
images=images,
audio_refs=audio_refs,
)
def handle_message(
@@ -189,25 +165,43 @@ class MessageChain(ChainBase):
original_message_id: Optional[Union[str, int]] = None,
original_chat_id: Optional[str] = None,
images: Optional[List[str]] = None,
audio_refs: Optional[List[str]] = None,
) -> None:
"""
识别消息内容,执行操作
"""
# 申明全局变量
global _current_page, _current_meta, _current_media
# 处理消息
logger.info(
"收到用户消息内容: channel=%s, source=%s, userid=%s, text=%s, image_count=%s, image_refs=%s",
channel.value if channel else None,
source,
userid,
text,
len(images or []),
self._summarize_images(images),
)
# 加载缓存
user_cache: Dict[str, dict] = self.load_cache(self._cache_file) or {}
try:
# 识别语音为文本
reply_with_voice = bool(audio_refs)
if audio_refs:
transcript = self._transcribe_audio_refs(audio_refs, channel, source)
merged_parts = []
seen_parts = set()
for item in [text.strip() if text else "", transcript or ""]:
normalized = item.strip()
if not normalized or normalized in seen_parts:
continue
seen_parts.add(normalized)
merged_parts.append(normalized)
text = "\n".join(merged_parts).strip()
if not text:
self.post_message(
Notification(
channel=channel,
source=source,
userid=userid,
username=username,
title="语音识别失败,请稍后重试",
)
)
return
# 保存消息
if not text.startswith("CALLBACK:"):
self.messagehelper.put(
@@ -251,32 +245,6 @@ class MessageChain(ChainBase):
{"cmd": text, "user": userid, "channel": channel, "source": source},
)
elif text.lower().startswith("/ai"):
# 用户指定AI智能体消息响应
logger.info(
"消息链路分流到AI: reason=explicit_ai, channel=%s, source=%s, userid=%s, image_count=%s",
channel.value if channel else None,
source,
userid,
len(images or []),
)
self._handle_ai_message(
text=text,
channel=channel,
source=source,
userid=userid,
username=username,
images=images,
)
elif settings.AI_AGENT_ENABLE and images:
# 带图消息优先交给智能体处理,避免图片在传统消息链路中丢失
logger.info(
"消息链路分流到AI: reason=image_message, channel=%s, source=%s, userid=%s, image_count=%s, image_refs=%s",
channel.value if channel else None,
source,
userid,
len(images or []),
self._summarize_images(images),
)
self._handle_ai_message(
text=text,
channel=channel,
@@ -284,16 +252,10 @@ class MessageChain(ChainBase):
userid=userid,
username=username,
images=images,
reply_with_voice=reply_with_voice,
)
elif settings.AI_AGENT_ENABLE and settings.AI_AGENT_GLOBAL:
# 普通消息,全局智能体响应
logger.info(
"消息链路分流到AI: reason=global_agent, channel=%s, source=%s, userid=%s, image_count=%s",
channel.value if channel else None,
source,
userid,
len(images or []),
)
self._handle_ai_message(
text=text,
channel=channel,
@@ -301,6 +263,7 @@ class MessageChain(ChainBase):
userid=userid,
username=username,
images=images,
reply_with_voice=reply_with_voice,
)
else:
# 非智能体普通消息响应
@@ -1266,6 +1229,7 @@ class MessageChain(ChainBase):
userid: Union[str, int],
username: str,
images: Optional[List[str]] = None,
reply_with_voice: bool = False,
) -> None:
"""
处理AI智能体消息
@@ -1290,16 +1254,6 @@ class MessageChain(ChainBase):
else:
user_message = text.strip() # 按原消息处理
logger.info(
"AI消息入口: channel=%s, source=%s, userid=%s, text_len=%s, raw_image_count=%s, raw_image_refs=%s",
channel.value if channel else None,
source,
userid,
len(user_message),
len(images or []),
self._summarize_images(images),
)
if not user_message and not images:
self.post_message(
Notification(
@@ -1319,15 +1273,6 @@ class MessageChain(ChainBase):
original_images = images
if images:
images = self._download_images_to_base64(images, channel, source)
logger.info(
"AI图片预处理完成: channel=%s, source=%s, userid=%s, raw_image_count=%s, converted_image_count=%s, converted_image_refs=%s",
channel.value if channel else None,
source,
userid,
len(original_images or []),
len(images or []),
self._summarize_images(images),
)
if original_images and not images and not user_message:
self.post_message(
Notification(
@@ -1350,6 +1295,7 @@ class MessageChain(ChainBase):
channel=channel.value if channel else None,
source=source,
username=username,
reply_with_voice=reply_with_voice,
),
global_vars.loop,
)
@@ -1360,6 +1306,64 @@ class MessageChain(ChainBase):
f"AI智能体处理失败: {str(e)}", role="system", title="MoviePilot助手"
)
def _transcribe_audio_refs(
    self, audio_refs: List[str], channel: MessageChannel, source: str
) -> Optional[str]:
    """
    Download and transcribe incoming voice messages.

    Only the currently wired channels are handled: Telegram (voice/audio
    file_ids) and WeChat work-app (media_ids). Unknown reference schemes
    are skipped with a debug log.

    :param audio_refs: channel-specific audio references (tg://..., wxwork://...)
    :param channel: originating message channel (used for logging only)
    :param source: client/config name used to route the download module call
    :return: all transcripts joined by newlines, or None if nothing was recognized
    """
    if not audio_refs:
        return None
    if not VoiceHelper.is_available("stt"):
        # No STT provider configured — skip quietly rather than fail the message.
        logger.warning("语音能力未配置,跳过语音识别")
        return None
    transcripts = []
    for audio_ref in audio_refs:
        try:
            if audio_ref.startswith("tg://voice_file_id/"):
                # Telegram voice note; filename hint assumes OGG — TODO confirm.
                file_id = audio_ref.replace("tg://voice_file_id/", "", 1)
                content = self.run_module(
                    "download_telegram_file_bytes", file_id=file_id, source=source
                )
                filename = "input.ogg"
            elif audio_ref.startswith("tg://audio_file_id/"):
                # Telegram audio attachment; filename hint assumes MP3 — TODO confirm.
                file_id = audio_ref.replace("tg://audio_file_id/", "", 1)
                content = self.run_module(
                    "download_telegram_file_bytes", file_id=file_id, source=source
                )
                filename = "input.mp3"
            elif audio_ref.startswith("wxwork://voice_media_id/"):
                # WeChat work media; filename hint assumes AMR — TODO confirm.
                content = self.run_module(
                    "download_wechat_media_bytes", media_ref=audio_ref, source=source
                )
                filename = "input.amr"
            elif audio_ref.startswith("wxbot://voice"):
                # WeChat bot mode: no media download path here, skip silently.
                continue
            else:
                logger.debug(
                    "暂不支持的语音引用: channel=%s, source=%s, ref=%s",
                    channel.value if channel else None,
                    source,
                    audio_ref,
                )
                continue
            # The STT helper tolerates empty/None content and returns None.
            transcript = VoiceHelper.transcribe_bytes(content=content, filename=filename)
            if transcript:
                transcripts.append(transcript)
                logger.info(
                    "语音识别成功: channel=%s, source=%s, ref=%s, text_len=%s",
                    channel.value if channel else None,
                    source,
                    audio_ref,
                    len(transcript),
                )
        except Exception as err:
            # Best-effort per reference: one failure must not drop the rest.
            logger.error(f"语音识别失败: {err}")
    return "\n".join(transcripts).strip() if transcripts else None
def _download_images_to_base64(
self, images: List[str], channel: MessageChannel, source: str
) -> List[str]:
@@ -1373,16 +1377,10 @@ class MessageChain(ChainBase):
try:
if img.startswith("data:"):
base64_images.append(img)
logger.info(
"图片无需下载: channel=%s, source=%s, input=%s",
channel.value if channel else None,
source,
self._summarize_images([img])[0],
)
elif img.startswith("tg://file_id/"):
file_id = img.replace("tg://file_id/", "")
base64_data = self.run_module(
"download_file_to_base64", file_id=file_id, source=source
"download_telegram_file_to_base64", file_id=file_id, source=source
)
if base64_data:
base64_images.append(f"data:image/jpeg;base64,{base64_data}")
@@ -1402,26 +1400,12 @@ class MessageChain(ChainBase):
)
if data_url:
base64_images.append(data_url)
logger.info(
"图片下载成功: channel=%s, source=%s, input=%s, output=%s",
channel.value if channel else None,
source,
img,
self._summarize_images([data_url])[0],
)
elif channel == MessageChannel.Slack:
data_url = self.run_module(
"download_file_to_data_url", file_url=img, source=source
"download_slack_file_to_data_url", file_url=img, source=source
)
if data_url:
base64_images.append(data_url)
logger.info(
"图片下载成功: channel=%s, source=%s, input=%s, output=%s",
channel.value if channel else None,
source,
img,
self._summarize_images([data_url])[0],
)
elif img.startswith("vocechat://file/"):
data_url = self.run_module(
"download_vocechat_image_to_data_url",
@@ -1430,30 +1414,12 @@ class MessageChain(ChainBase):
)
if data_url:
base64_images.append(data_url)
logger.info(
"图片下载成功: channel=%s, source=%s, input=%s, output=%s",
channel.value if channel else None,
source,
img,
self._summarize_images([data_url])[0],
)
elif img.startswith("http"):
resp = RequestUtils(timeout=30).get_res(img)
if resp and resp.content:
import base64
base64_data = base64.b64encode(resp.content).decode()
mime_type = resp.headers.get("Content-Type", "image/jpeg")
base64_images.append(f"data:{mime_type};base64,{base64_data}")
logger.info(
"图片下载成功: channel=%s, source=%s, input=%s, output=%s",
channel.value if channel else None,
source,
img,
self._summarize_images(
[f"data:{mime_type};base64,{base64_data}"]
)[0],
)
except Exception as e:
logger.error(f"下载图片失败: {img}, error: {e}")
return base64_images if base64_images else None

View File

@@ -538,6 +538,35 @@ class ConfigModel(BaseModel):
# AI智能体自动重试整理失败记录开关
AI_AGENT_RETRY_TRANSFER: bool = False
# 语音能力提供商(当前仅支持 openai
AI_VOICE_PROVIDER: str = "openai"
# 语音识别提供商,未设置时回退到 AI_VOICE_PROVIDER
AI_VOICE_STT_PROVIDER: Optional[str] = None
# 语音合成提供商,未设置时回退到 AI_VOICE_PROVIDER
AI_VOICE_TTS_PROVIDER: Optional[str] = None
# 语音能力 API 密钥,未设置且 LLM_PROVIDER=openai 时回退使用 LLM_API_KEY
AI_VOICE_API_KEY: Optional[str] = None
# 语音识别 API 密钥,未设置时回退到 AI_VOICE_API_KEY
AI_VOICE_STT_API_KEY: Optional[str] = None
# 语音合成 API 密钥,未设置时回退到 AI_VOICE_API_KEY
AI_VOICE_TTS_API_KEY: Optional[str] = None
# 语音能力基础URL未设置且 LLM_PROVIDER=openai 时回退使用 LLM_BASE_URL
AI_VOICE_BASE_URL: Optional[str] = None
# 语音识别基础URL未设置时回退到 AI_VOICE_BASE_URL
AI_VOICE_STT_BASE_URL: Optional[str] = None
# 语音合成基础URL未设置时回退到 AI_VOICE_BASE_URL
AI_VOICE_TTS_BASE_URL: Optional[str] = None
# 语音转文字模型
AI_VOICE_STT_MODEL: str = "gpt-4o-mini-transcribe"
# 文字转语音模型
AI_VOICE_TTS_MODEL: str = "gpt-4o-mini-tts"
# TTS 发音人
AI_VOICE_TTS_VOICE: str = "alloy"
# 语音识别语言
AI_VOICE_LANGUAGE: str = "zh"
# 回复语音时是否同时附带文字说明
AI_VOICE_REPLY_WITH_TEXT: bool = False
class Settings(BaseSettings, ConfigModel, LogConfigModel):
"""

197
app/helper/voice.py Normal file
View File

@@ -0,0 +1,197 @@
"""语音能力辅助功能。"""
from abc import ABC, abstractmethod
from io import BytesIO
from pathlib import Path
from typing import Dict, Optional
from uuid import uuid4
from app.core.config import settings
from app.log import logger
class VoiceProvider(ABC):
    """Abstract interface for speech providers (speech-to-text + text-to-speech)."""

    # Hard cap on audio payload size accepted for transcription (25 MB).
    MAX_TRANSCRIBE_BYTES = 25 * 1024 * 1024

    @property
    @abstractmethod
    def name(self) -> str:
        """Unique provider name, used as the registry key."""

    @abstractmethod
    def is_available_for_stt(self) -> bool:
        """Return True when speech-to-text is configured and usable."""

    @abstractmethod
    def is_available_for_tts(self) -> bool:
        """Return True when text-to-speech is configured and usable."""

    @abstractmethod
    def transcribe_bytes(self, content: bytes, filename: str = "input.ogg") -> Optional[str]:
        """Convert raw audio bytes to text; None on failure."""

    @abstractmethod
    def synthesize_speech(self, text: str) -> Optional[Path]:
        """Render text to an audio file and return its path; None on failure."""
class OpenAIVoiceProvider(VoiceProvider):
    """OpenAI / OpenAI-compatible speech provider."""

    @property
    def name(self) -> str:
        return "openai"

    @staticmethod
    def _resolve_credentials(mode: str) -> tuple[Optional[str], Optional[str]]:
        """
        Resolve ``(api_key, base_url)`` for the given mode ("stt" or "tts").

        Fallback chain: mode-specific setting -> generic AI_VOICE_* setting ->
        (only when both voice provider and LLM provider are "openai") the main
        LLM credentials.
        """
        mode = mode.lower()
        provider = (
            settings.AI_VOICE_STT_PROVIDER
            if mode == "stt"
            else settings.AI_VOICE_TTS_PROVIDER
        ) or settings.AI_VOICE_PROVIDER
        provider = (provider or "").strip().lower()
        api_key = (
            settings.AI_VOICE_STT_API_KEY
            if mode == "stt"
            else settings.AI_VOICE_TTS_API_KEY
        ) or settings.AI_VOICE_API_KEY
        base_url = (
            settings.AI_VOICE_STT_BASE_URL
            if mode == "stt"
            else settings.AI_VOICE_TTS_BASE_URL
        ) or settings.AI_VOICE_BASE_URL
        # Last resort: reuse the LLM key/URL when both sides use "openai".
        if (
            not api_key
            and provider == "openai"
            and (settings.LLM_PROVIDER or "").strip().lower() == "openai"
        ):
            api_key = settings.LLM_API_KEY
            base_url = base_url or settings.LLM_BASE_URL
        return api_key, base_url

    def _get_client(self, mode: str):
        """Build an OpenAI client for the mode; raises ValueError when no key resolves."""
        # Imported lazily so the openai package is only required when voice is used.
        from openai import OpenAI

        api_key, base_url = self._resolve_credentials(mode)
        if not api_key:
            raise ValueError(f"{mode.upper()} provider 未配置 API Key")
        return OpenAI(api_key=api_key, base_url=base_url, max_retries=3)

    def is_available_for_stt(self) -> bool:
        # Available iff an API key can be resolved for STT.
        api_key, _ = self._resolve_credentials("stt")
        return bool(api_key)

    def is_available_for_tts(self) -> bool:
        # Available iff an API key can be resolved for TTS.
        api_key, _ = self._resolve_credentials("tts")
        return bool(api_key)

    def transcribe_bytes(self, content: bytes, filename: str = "input.ogg") -> Optional[str]:
        """
        Transcribe audio bytes via the OpenAI transcription API.

        :raises ValueError: when the payload exceeds MAX_TRANSCRIBE_BYTES
            (raised outside the try block so callers see the size error)
        """
        if not content:
            return None
        if len(content) > self.MAX_TRANSCRIBE_BYTES:
            raise ValueError("语音文件超过 25MB无法识别")
        try:
            client = self._get_client("stt")
            # The API infers the container format from the file name attribute.
            audio_file = BytesIO(content)
            audio_file.name = filename
            response = client.audio.transcriptions.create(
                model=settings.AI_VOICE_STT_MODEL,
                file=audio_file,
                language=settings.AI_VOICE_LANGUAGE or "zh",
                response_format="verbose_json",
            )
            text = getattr(response, "text", None)
            return text.strip() if text else None
        except Exception as err:
            # Degrade to "no transcript" on API/client errors; log for diagnosis.
            logger.error(f"语音转文字失败: provider={self.name}, error={err}")
            return None

    def synthesize_speech(self, text: str) -> Optional[Path]:
        """Synthesize text into an opus file under TEMP_PATH/voice; None on failure."""
        if not text:
            return None
        try:
            client = self._get_client("tts")
            voice_dir = settings.TEMP_PATH / "voice"
            voice_dir.mkdir(parents=True, exist_ok=True)
            # Unique temp file; the sending channel deletes it after dispatch.
            output_path = voice_dir / f"{uuid4().hex}.opus"
            response = client.audio.speech.create(
                model=settings.AI_VOICE_TTS_MODEL,
                voice=settings.AI_VOICE_TTS_VOICE,
                input=text,
                response_format="opus",
            )
            response.write_to_file(output_path)
            return output_path
        except Exception as err:
            logger.error(f"文字转语音失败: provider={self.name}, error={err}")
            return None
class VoiceHelper:
    """Unified voice entry point that routes STT/TTS calls to the configured provider."""

    # Provider registry, keyed by lower-cased provider name.
    _providers: Dict[str, VoiceProvider] = {
        "openai": OpenAIVoiceProvider(),
    }

    @classmethod
    def register_provider(cls, provider: VoiceProvider) -> None:
        """Register (or replace) a provider under its lower-cased name."""
        cls._providers[provider.name.lower()] = provider

    @staticmethod
    def _resolve_provider_name(mode: str) -> str:
        """Pick the configured provider name for the given mode ("stt"/"tts")."""
        if mode.lower() == "stt":
            configured = settings.AI_VOICE_STT_PROVIDER
        else:
            configured = settings.AI_VOICE_TTS_PROVIDER
        configured = configured or settings.AI_VOICE_PROVIDER
        return (configured or "openai").strip().lower()

    @classmethod
    def get_provider(cls, mode: str) -> Optional[VoiceProvider]:
        """Look up the provider instance for the mode; None when unregistered."""
        provider_name = cls._resolve_provider_name(mode)
        found = cls._providers.get(provider_name)
        if not found:
            logger.warning(f"未注册语音 provider: mode={mode}, provider={provider_name}")
            return None
        return found

    @classmethod
    def get_registered_providers(cls) -> list[str]:
        """All registered provider names, sorted alphabetically."""
        return sorted(cls._providers.keys())

    @classmethod
    def is_available(cls, mode: Optional[str] = None) -> bool:
        """
        Availability check. With a mode, check that single capability;
        without one, True when either STT or TTS is usable.
        """
        if not mode:
            return cls.is_available("stt") or cls.is_available("tts")
        provider = cls.get_provider(mode)
        if provider is None:
            return False
        if mode.lower() == "stt":
            return provider.is_available_for_stt()
        return provider.is_available_for_tts()

    @classmethod
    def transcribe_bytes(cls, content: bytes, filename: str = "input.ogg") -> Optional[str]:
        """Transcribe audio bytes with the configured STT provider."""
        stt = cls.get_provider("stt")
        return stt.transcribe_bytes(content=content, filename=filename) if stt else None

    @classmethod
    def synthesize_speech(cls, text: str) -> Optional[Path]:
        """Synthesize speech with the configured TTS provider."""
        tts = cls.get_provider("tts")
        return tts.synthesize_speech(text=text) if tts else None

View File

@@ -297,7 +297,7 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]):
images.append(url)
return images if images else None
def download_file_to_data_url(self, file_url: str, source: str) -> Optional[str]:
def download_slack_file_to_data_url(self, file_url: str, source: str) -> Optional[str]:
"""
下载Slack文件并转为data URL
:param file_url: Slack私有文件URL

View File

@@ -214,17 +214,19 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
text = self._append_reply_markup_links(text, msg.get("reply_markup"))
images = self._extract_images(msg)
audio_refs = self._extract_audio_refs(msg)
if user_id:
if not text and not images:
if not text and not images and not audio_refs:
logger.debug(
f"收到来自 {client_config.name} 的Telegram消息无文本图片"
f"收到来自 {client_config.name} 的Telegram消息无文本图片和语音"
)
return None
logger.info(
f"收到来自 {client_config.name} 的Telegram消息"
f"userid={user_id}, username={user_name}, chat_id={chat_id}, text={text}, images={len(images) if images else 0}"
f"userid={user_id}, username={user_name}, chat_id={chat_id}, text={text}, "
f"images={len(images) if images else 0}, audios={len(audio_refs) if audio_refs else 0}"
)
cleaned_text = (
@@ -263,6 +265,7 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
text=cleaned_text,
chat_id=str(chat_id) if chat_id else None,
images=images if images else None,
audio_refs=audio_refs if audio_refs else None,
)
return None
@@ -288,6 +291,26 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
return images if images else None
@staticmethod
def _extract_audio_refs(msg: dict) -> Optional[List[str]]:
"""
从Telegram消息中提取语音/音频 file_id。
"""
audio_refs = []
voice = msg.get("voice")
if voice:
file_id = voice.get("file_id")
if file_id:
audio_refs.append(f"tg://voice_file_id/{file_id}")
audio = msg.get("audio")
if audio:
file_id = audio.get("file_id")
if file_id:
audio_refs.append(f"tg://audio_file_id/{file_id}")
return audio_refs if audio_refs else None
@staticmethod
def _embed_entity_links(text: str, entities: Optional[List[dict]]) -> str:
"""
@@ -389,17 +412,25 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
return
client: Telegram = self.get_instance(conf.name)
if client:
client.send_msg(
title=message.title,
text=message.text,
image=message.image,
userid=userid,
link=message.link,
buttons=message.buttons,
original_message_id=message.original_message_id,
original_chat_id=message.original_chat_id,
disable_web_page_preview=message.disable_web_page_preview,
)
if message.voice_path:
client.send_voice(
voice_path=message.voice_path,
userid=userid,
caption=message.voice_caption,
original_chat_id=message.original_chat_id,
)
else:
client.send_msg(
title=message.title,
text=message.text,
image=message.image,
userid=userid,
link=message.link,
buttons=message.buttons,
original_message_id=message.original_message_id,
original_chat_id=message.original_chat_id,
disable_web_page_preview=message.disable_web_page_preview,
)
def post_medias_message(
self, message: Notification, medias: List[MediaInfo]
@@ -531,14 +562,22 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
return None
client: Telegram = self.get_instance(conf.name)
if client:
result = client.send_msg(
title=message.title,
text=message.text,
image=message.image,
userid=userid,
link=message.link,
disable_web_page_preview=message.disable_web_page_preview,
)
if message.voice_path:
result = client.send_voice(
voice_path=message.voice_path,
userid=userid,
caption=message.voice_caption,
original_chat_id=message.original_chat_id,
)
else:
result = client.send_msg(
title=message.title,
text=message.text,
image=message.image,
userid=userid,
link=message.link,
disable_web_page_preview=message.disable_web_page_preview,
)
if result and result.get("success"):
return MessageResponse(
message_id=result.get("message_id"),
@@ -601,7 +640,7 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
)
client.register_commands(filtered_scoped_commands)
def download_file_to_base64(self, file_id: str, source: str) -> Optional[str]:
def download_telegram_file_to_base64(self, file_id: str, source: str) -> Optional[str]:
"""
下载Telegram文件并转为base64
:param file_id: Telegram文件ID
@@ -620,3 +659,15 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
return base64.b64encode(file_content).decode()
return None
def download_telegram_file_bytes(self, file_id: str, source: str) -> Optional[bytes]:
    """
    Download a Telegram file and return its raw bytes.

    :param file_id: Telegram file identifier
    :param source: config name selecting which bot client instance to use
    :return: raw file content, or None when the source/client cannot be resolved
    """
    conf = self.get_config(source)
    client = self.get_instance(conf.name) if conf else None
    return client.download_file(file_id) if client else None

View File

@@ -3,6 +3,7 @@ import json
import re
import threading
import time
from pathlib import Path
from typing import Any, Optional, List, Dict, Callable, Union
from urllib.parse import urljoin, quote
@@ -461,6 +462,51 @@ class Telegram:
self._stop_typing_task(chat_id)
return {"success": False}
def send_voice(
    self,
    voice_path: str,
    userid: Optional[str] = None,
    caption: Optional[str] = None,
    original_chat_id: Optional[str] = None,
) -> Optional[dict]:
    """
    Send a Telegram voice message from a local audio file.

    The file at ``voice_path`` is deleted in the ``finally`` block —
    presumably it is a temporary TTS artifact owned by this call
    (NOTE(review): confirm callers do not expect the file to survive).

    :param voice_path: path to the local audio file to send
    :param userid: target user id; chat resolution is delegated to
                   _determine_target_chat_id
    :param caption: optional caption shown with the voice message
    :param original_chat_id: chat id of the originating message, used as
                             a fallback target
    :return: None when the bot is not initialised or voice_path is empty;
             otherwise a dict with a "success" flag and, on success,
             "message_id" and "chat_id"
    """
    if not self._bot or not voice_path:
        return None
    chat_id = self._determine_target_chat_id(userid, original_chat_id)
    voice_file = Path(voice_path)
    if not voice_file.exists():
        logger.error(f"语音文件不存在: {voice_file}")
        return {"success": False}
    try:
        with voice_file.open("rb") as fp:
            sent = self._bot.send_voice(
                chat_id=chat_id,
                voice=fp,
                # MarkdownV2 requires escaping, hence standardize();
                # omit the parse mode entirely when there is no caption
                caption=standardize(caption) if caption else None,
                parse_mode="MarkdownV2" if caption else None,
            )
        # Stop the "typing…" indicator once the message has been sent
        self._stop_typing_task(chat_id)
        if sent and hasattr(sent, "message_id"):
            return {
                "success": True,
                "message_id": sent.message_id,
                "chat_id": sent.chat.id if hasattr(sent, "chat") else chat_id,
            }
        return {"success": bool(sent)}
    except Exception as err:
        logger.error(f"发送语音消息失败:{err}")
        self._stop_typing_task(chat_id)
        return {"success": False}
    finally:
        # Best-effort removal of the temporary audio file on every path
        try:
            voice_file.unlink(missing_ok=True)
        except Exception as cleanup_err:
            logger.debug(f"清理语音临时文件失败: {cleanup_err}")
def _determine_target_chat_id(
self, userid: Optional[str] = None, original_chat_id: Optional[str] = None
) -> str:

View File

@@ -167,6 +167,7 @@ class WechatModule(_ModuleBase, _MessageBase[WeChat]):
# 解析消息内容
content = None
images = None
audio_refs = None
if msg_type == "event" and event == "click":
# 校验用户有权限执行交互命令
if client_config.config.get('WECHAT_ADMINS'):
@@ -192,14 +193,24 @@ class WechatModule(_ModuleBase, _MessageBase[WeChat]):
logger.info(
f"收到来自 {client_config.name} 的微信图片消息userid={user_id}, images={len(images) if images else 0}"
)
elif msg_type == "voice":
media_id = DomUtils.tag_value(root_node, "MediaId")
recognition = DomUtils.tag_value(root_node, "Recognition", default="")
content = (recognition or "").strip()
if media_id:
audio_refs = [f"wxwork://voice_media_id/{media_id}"]
logger.info(
f"收到来自 {client_config.name} 的微信语音消息userid={user_id}, "
f"text={content}, audios={len(audio_refs) if audio_refs else 0}"
)
else:
return None
if content or images:
if content or images or audio_refs:
# 处理消息内容
return CommingMessage(channel=MessageChannel.Wechat, source=client_config.name,
userid=user_id, username=user_id, text=content or "",
images=images)
images=images, audio_refs=audio_refs)
except Exception as err:
logger.error(f"微信消息处理发生错误:{str(err)}")
return None
@@ -230,6 +241,7 @@ class WechatModule(_ModuleBase, _MessageBase[WeChat]):
text = WeChatBot._extract_text_from_body(payload_body)
images = WeChatBot._extract_images_from_body(payload_body)
audio_refs = ["wxbot://voice"] if payload_body.get("msgtype") == "voice" else None
if text:
text = re.sub(r"@\S+", "", text).strip()
@@ -245,7 +257,7 @@ class WechatModule(_ModuleBase, _MessageBase[WeChat]):
client.send_msg(title="只有管理员才有权限执行此命令", userid=sender)
return None
if not text and not images:
if not text and not images and not audio_refs:
return None
logger.info(
@@ -259,6 +271,7 @@ class WechatModule(_ModuleBase, _MessageBase[WeChat]):
username=sender,
text=text or "",
images=images,
audio_refs=audio_refs,
)
def post_message(self, message: Notification, **kwargs) -> None:
@@ -279,8 +292,17 @@ class WechatModule(_ModuleBase, _MessageBase[WeChat]):
return
client: WeChat = self.get_instance(conf.name)
if client:
client.send_msg(title=message.title, text=message.text,
image=message.image, userid=userid, link=message.link)
if message.voice_path and hasattr(client, "send_voice"):
sent = client.send_voice(
voice_path=message.voice_path,
userid=userid,
)
if not sent:
client.send_msg(title=message.title, text=message.text,
image=message.image, userid=userid, link=message.link)
else:
client.send_msg(title=message.title, text=message.text,
image=message.image, userid=userid, link=message.link)
def download_wechat_image_to_data_url(self, image_ref: str, source: str) -> Optional[str]:
"""
@@ -301,6 +323,23 @@ class WechatModule(_ModuleBase, _MessageBase[WeChat]):
return client.download_image_to_data_url(image_ref)
return None
def download_wechat_media_bytes(self, media_ref: str, source: str) -> Optional[bytes]:
    """
    Download a WeCom (WeChat Work) voice media file and return its raw bytes.

    :param media_ref: audio reference in the form
                      ``wxwork://voice_media_id/<media_id>``
    :param source: name of the configured WeChat channel instance
    :return: raw media bytes, or None when the reference is empty or not
             recognised, or when no capable client is available
    """
    if not media_ref:
        return None
    client_config = self.get_config(source)
    if not client_config:
        return None
    client = self.get_instance(client_config.name)
    # Older client objects may lack voice support — probe before calling
    if not client or not hasattr(client, "download_media_bytes"):
        return None
    prefix = "wxwork://voice_media_id/"
    if not media_ref.startswith(prefix):
        return None
    return client.download_media_bytes(media_ref[len(prefix):])
def post_medias_message(self, message: Notification, medias: List[MediaInfo]) -> None:
"""
发送媒体信息选择列表

View File

@@ -2,7 +2,9 @@ import json
import re
import threading
import base64
import subprocess
from datetime import datetime
from pathlib import Path
from typing import Optional, List, Dict
from app.core.context import MediaInfo, Context
@@ -46,6 +48,8 @@ class WeChat:
_delete_menu_url = "cgi-bin/menu/delete?access_token={access_token}&agentid={agentid}"
# 企业微信下载媒体URL
_download_media_url = "cgi-bin/media/get?access_token={access_token}&media_id={media_id}"
# 企业微信上传临时素材URL
_upload_media_url = "cgi-bin/media/upload?access_token={access_token}&type={media_type}"
def __init__(self, WECHAT_CORPID: Optional[str] = None, WECHAT_APP_SECRET: Optional[str] = None,
WECHAT_APP_ID: Optional[str] = None, WECHAT_PROXY: Optional[str] = None, **kwargs):
@@ -66,6 +70,7 @@ class WeChat:
self._create_menu_url = UrlUtils.adapt_request_url(self._proxy, self._create_menu_url)
self._delete_menu_url = UrlUtils.adapt_request_url(self._proxy, self._delete_menu_url)
self._download_media_url = UrlUtils.adapt_request_url(self._proxy, self._download_media_url)
self._upload_media_url = UrlUtils.adapt_request_url(self._proxy, self._upload_media_url)
if self._corpid and self._appsecret and self._appid:
self.__get_access_token()
@@ -323,6 +328,168 @@ class WeChat:
mime_type = self._guess_mime_type(res.content, content_type or "image/jpeg")
return f"data:{mime_type};base64,{base64.b64encode(res.content).decode()}"
def download_media_bytes(self, media_id: str) -> Optional[bytes]:
    """
    Download a WeCom media file and return its raw bytes.

    :param media_id: WeCom temporary media id
    :return: raw media content, or None when the media id is empty, the
             access token is unavailable, the request fails, or the API
             answers with a JSON error payload
    """
    if not media_id:
        return None
    access_token = self.__get_access_token()
    if not access_token:
        logger.error("下载企业微信媒体失败access_token 获取失败")
        return None
    req_url = self._download_media_url.format(
        access_token=access_token,
        media_id=media_id,
    )
    try:
        res = RequestUtils(timeout=30).get_res(req_url)
    except Exception as err:
        logger.error(f"下载企业微信媒体失败:{err}")
        return None
    if not res or not res.content:
        return None
    # A JSON content type means the API returned an error payload
    # instead of the binary media bytes
    content_type = (res.headers.get("Content-Type") or "").split(";")[0].strip()
    if content_type == "application/json":
        try:
            logger.error(f"企业微信媒体下载失败:{res.json()}")
        except Exception:
            # Body was not valid JSON after all — log it verbatim
            logger.error(f"企业微信媒体下载失败:{res.text}")
        return None
    return res.content
@staticmethod
def _convert_voice_to_amr(voice_path: str) -> Optional[Path]:
    """
    Convert an audio file to the AMR format required by WeCom voice
    messages (mono, 8 kHz, truncated to 60 seconds, at most 2 MB).

    :param voice_path: path to the source audio file
    :return: path of the converted ``.amr`` file, or None on any failure
    """
    src_path = Path(voice_path)
    if not src_path.exists():
        logger.error(f"语音文件不存在:{src_path}")
        return None
    # with_suffix(".amr") would collide with the source when the input
    # is already named *.amr, making ffmpeg read and write the same
    # file — use a distinct output name in that case
    if src_path.suffix.lower() == ".amr":
        dst_path = src_path.with_name(f"{src_path.stem}_wxwork.amr")
    else:
        dst_path = src_path.with_suffix(".amr")
    cmd = [
        "ffmpeg",
        "-y",  # overwrite any stale output file
        "-i",
        str(src_path),
        "-ar",
        "8000",  # AMR-NB requires an 8 kHz sample rate
        "-ac",
        "1",  # mono
        "-t",
        "60",  # WeCom rejects voice clips longer than 60 seconds
        str(dst_path),
    ]
    try:
        result = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            check=False,
        )
    except Exception as err:
        logger.error(f"调用 ffmpeg 转换 AMR 失败:{err}")
        return None
    if result.returncode != 0 or not dst_path.exists():
        # Pre-format the message (f-string) like every other logger call
        # in this module; truncate stderr to keep the log readable
        stderr_tail = (result.stderr or "").strip()[:500]
        logger.error(
            f"ffmpeg 转换 AMR 失败: returncode={result.returncode}, stderr={stderr_tail}"
        )
        return None
    if dst_path.stat().st_size > 2 * 1024 * 1024:
        logger.error("AMR 语音文件超过 2MB无法发送到企业微信")
        dst_path.unlink(missing_ok=True)
        return None
    return dst_path
def _upload_temp_media(self, media_path: Path, media_type: str = "voice") -> Optional[str]:
    """
    Upload a temporary media asset to WeCom and return its media_id.

    :param media_path: local file to upload
    :param media_type: WeCom media type, defaults to "voice"
    :return: the media_id string, or None on any failure
    """
    access_token = self.__get_access_token()
    if not access_token:
        return None
    req_url = self._upload_media_url.format(
        access_token=access_token,
        media_type=media_type,
    )
    mime_type = "voice/amr" if media_type == "voice" else "application/octet-stream"
    try:
        with media_path.open("rb") as media_file:
            res = RequestUtils(timeout=60).request(
                method="post",
                url=req_url,
                files={"media": (media_path.name, media_file, mime_type)},
            )
    except Exception as err:
        logger.error(f"上传企业微信临时素材失败:{err}")
        return None
    if not res:
        return None
    try:
        payload = res.json()
    except Exception as err:
        logger.error(f"解析企业微信临时素材响应失败:{err}")
        return None
    # errcode 0 is WeCom's success marker; anything else is an API error
    if payload.get("errcode") != 0:
        logger.error(f"上传企业微信临时素材失败:{payload}")
        return None
    return payload.get("media_id")
def send_voice(self, voice_path: str, userid: Optional[str] = None) -> Optional[bool]:
    """
    Send a WeCom voice message. Only supported in self-built app mode.

    Both the source audio file and the converted AMR file are treated as
    temporary artifacts and removed before returning, on every code path
    (the previous version leaked the source file when conversion failed).

    :param voice_path: path to the source audio file
    :param userid: target user id; defaults to "@all" when empty
    :return: None when no access token is available, otherwise True/False
             for send success
    """
    if not voice_path:
        return False
    if not self.__get_access_token():
        logger.error("获取微信access_token失败请检查参数配置")
        return None
    if not userid:
        userid = "@all"
    source_path = Path(voice_path)
    converted_path: Optional[Path] = None
    try:
        converted_path = self._convert_voice_to_amr(voice_path)
        if not converted_path:
            return False
        media_id = self._upload_temp_media(converted_path, media_type="voice")
        if not media_id:
            return False
        req_json = {
            "touser": userid,
            "msgtype": "voice",
            "agentid": self._appid,
            "voice": {
                "media_id": media_id
            },
            "safe": 0,
            "enable_id_trans": 0,
            "enable_duplicate_check": 0
        }
        return self.__post_request(self._send_msg_url, req_json)
    except Exception as err:
        logger.error(f"发送企业微信语音消息失败:{err}")
        return False
    finally:
        # Best-effort cleanup; guarded so an unlink failure (e.g. a file
        # lock on Windows) cannot mask the real result or exception
        for tmp_file in (converted_path, source_path):
            if tmp_file is None:
                continue
            try:
                tmp_file.unlink(missing_ok=True)
            except Exception as cleanup_err:
                logger.debug(f"清理语音临时文件失败: {cleanup_err}")
def send_medias_msg(self, medias: List[MediaInfo], userid: Optional[str] = None) -> Optional[bool]:
"""
发送列表类消息

View File

@@ -55,6 +55,8 @@ class CommingMessage(BaseModel):
callback_query: Optional[Dict] = None
# 图片列表图片URL或file_id
images: Optional[List[str]] = None
# 语音/音频引用列表
audio_refs: Optional[List[str]] = None
def to_dict(self):
"""
@@ -86,6 +88,10 @@ class Notification(BaseModel):
text: Optional[str] = None
# 图片
image: Optional[str] = None
# 语音文件路径
voice_path: Optional[str] = None
# 语音消息附带说明文字
voice_caption: Optional[str] = None
# 链接
link: Optional[str] = None
# 用户ID

View File

@@ -2,11 +2,12 @@ import base64
import json
import unittest
from types import SimpleNamespace
from unittest.mock import Mock, patch
from unittest.mock import AsyncMock, Mock, patch
from telebot import apihelper
from app.agent.tools.impl.send_message import SendMessageInput
from app.agent import MoviePilotAgent, AgentChain
from app.chain.message import MessageChain
from app.core.config import settings
from app.modules.discord import DiscordModule
@@ -23,6 +24,19 @@ from app.schemas.types import MessageChannel
class AgentImageSupportTest(unittest.TestCase):
def test_telegram_extract_audio_refs_returns_prefixed_file_ids(self):
    """Voice and audio file_ids are wrapped into tg:// audio reference URIs."""
    audio_refs = TelegramModule._extract_audio_refs(
        {
            "voice": {"file_id": "voice-1"},
            "audio": {"file_id": "audio-1"},
        }
    )
    self.assertEqual(
        audio_refs,
        ["tg://voice_file_id/voice-1", "tg://audio_file_id/audio-1"],
    )
def test_telegram_extract_images_returns_prefixed_file_ids(self):
images = TelegramModule._extract_images(
{
@@ -126,6 +140,25 @@ class AgentImageSupportTest(unittest.TestCase):
self.assertEqual(handle_kwargs["text"], "")
self.assertEqual(handle_kwargs["images"], ["tg://file_id/image-1"])
def test_process_allows_audio_only_message(self):
    """A message carrying only audio_refs (no text) still reaches handle_message."""
    chain = MessageChain()
    message = CommingMessage(
        channel=MessageChannel.Telegram,
        source="telegram-test",
        userid="10001",
        username="tester",
        audio_refs=["tg://voice_file_id/voice-1"],
    )
    with patch.object(chain, "message_parser", return_value=message), patch.object(
        chain, "handle_message"
    ) as handle_message:
        chain.process(body="{}", form={}, args={"source": "telegram-test"})
    handle_kwargs = handle_message.call_args.kwargs
    self.assertEqual(handle_kwargs["text"], "")
    self.assertEqual(handle_kwargs["audio_refs"], ["tg://voice_file_id/voice-1"])
def test_image_message_routes_to_agent_even_when_global_agent_is_disabled(self):
chain = MessageChain()
@@ -149,6 +182,48 @@ class AgentImageSupportTest(unittest.TestCase):
handle_ai_message.assert_called_once()
def test_audio_message_routes_to_agent_with_voice_reply_flag(self):
    """Transcribed audio is routed to the AI handler with reply_with_voice set."""
    chain = MessageChain()
    with patch.object(chain, "load_cache", return_value={}), patch.object(
        chain, "_transcribe_audio_refs", return_value="帮我推荐一部电影"
    ), patch.object(chain.messagehelper, "put"), patch.object(
        chain.messageoper, "add"
    ), patch.object(chain, "_handle_ai_message") as handle_ai_message:
        chain.handle_message(
            channel=MessageChannel.Telegram,
            source="telegram-test",
            userid="10001",
            username="tester",
            text="",
            audio_refs=["tg://voice_file_id/voice-1"],
        )
    handle_ai_message.assert_called_once()
    # The handler receives the STT transcription as the message text
    self.assertEqual(handle_ai_message.call_args.kwargs["text"], "帮我推荐一部电影")
    self.assertTrue(handle_ai_message.call_args.kwargs["reply_with_voice"])
def test_agent_send_agent_message_does_not_auto_convert_to_voice(self):
    """send_agent_message keeps plain text even when reply_with_voice is set."""
    agent = MoviePilotAgent(
        session_id="session-1",
        user_id="user-1",
        channel=MessageChannel.Telegram.value,
        source="telegram-test",
        username="tester",
    )
    agent.reply_with_voice = True
    with patch.object(
        AgentChain, "async_post_message", new_callable=AsyncMock
    ) as async_post_message:
        import asyncio
        asyncio.run(agent.send_agent_message("这是语音回复"))
    notification = async_post_message.await_args.args[0]
    # No implicit TTS: the notification must carry text, not a voice file
    self.assertIsNone(notification.voice_path)
    self.assertEqual(notification.text, "这是语音回复")
def test_slack_images_use_authenticated_data_url_download(self):
chain = MessageChain()

View File

@@ -0,0 +1,55 @@
import unittest
import sys
from unittest.mock import Mock, patch

# Stub heavy optional dependencies before importing app modules so the
# tests can run in environments where these packages are not installed
sys.modules.setdefault("psutil", Mock())
sys.modules.setdefault("pyquery", Mock())

from app.core.config import settings
from app.helper.voice import VoiceHelper, OpenAIVoiceProvider
class VoiceHelperTest(unittest.TestCase):
    """Unit tests for VoiceHelper provider registration and STT/TTS routing."""

    def test_registered_providers_contains_openai(self):
        """The built-in OpenAI provider is registered by default."""
        self.assertIn("openai", VoiceHelper.get_registered_providers())

    def test_get_provider_falls_back_to_global_provider(self):
        """With no STT-specific provider configured, the global provider is used."""
        with patch.object(settings, "AI_VOICE_PROVIDER", "openai"), patch.object(
            settings, "AI_VOICE_STT_PROVIDER", None
        ):
            provider = VoiceHelper.get_provider("stt")
        self.assertIsInstance(provider, OpenAIVoiceProvider)

    def test_is_available_checks_stt_and_tts_separately(self):
        """Availability is evaluated per direction (STT vs TTS) on the provider."""
        provider = Mock()
        provider.is_available_for_stt.return_value = True
        provider.is_available_for_tts.return_value = False
        with patch.object(VoiceHelper, "get_provider", return_value=provider):
            self.assertTrue(VoiceHelper.is_available("stt"))
            self.assertFalse(VoiceHelper.is_available("tts"))

    def test_transcribe_bytes_routes_to_stt_provider(self):
        """transcribe_bytes delegates to the selected STT provider."""
        provider = Mock()
        provider.transcribe_bytes.return_value = "你好"
        with patch.object(VoiceHelper, "get_provider", return_value=provider):
            result = VoiceHelper.transcribe_bytes(b"audio")
        self.assertEqual(result, "你好")
        provider.transcribe_bytes.assert_called_once()

    def test_synthesize_speech_routes_to_tts_provider(self):
        """synthesize_speech delegates to the selected TTS provider."""
        provider = Mock()
        provider.synthesize_speech.return_value = "/tmp/reply.opus"
        with patch.object(VoiceHelper, "get_provider", return_value=provider):
            result = VoiceHelper.synthesize_speech("你好")
        self.assertEqual(result, "/tmp/reply.opus")
        provider.synthesize_speech.assert_called_once_with(text="你好")


if __name__ == "__main__":
    unittest.main()