diff --git a/app/agent/__init__.py b/app/agent/__init__.py index 9eac704c..dae14d29 100644 --- a/app/agent/__init__.py +++ b/app/agent/__init__.py @@ -129,6 +129,8 @@ class MoviePilotAgent: self.channel = channel self.source = source self.username = username + self.reply_with_voice = False + self._tool_context: Dict[str, object] = {} # 流式token管理 self.stream_handler = StreamingHandler() @@ -151,6 +153,8 @@ class MoviePilotAgent: """ if self.is_background: return False + if self.reply_with_voice: + return False # 啰嗦模式下始终需要流式输出来捕获工具调用前的 Agent 文字 if settings.AI_AGENT_VERBOSE: return True @@ -214,6 +218,7 @@ class MoviePilotAgent: source=self.source, username=self.username, stream_handler=self.stream_handler, + agent_context=self._tool_context, ) def _create_agent(self, streaming: bool = False): @@ -223,7 +228,10 @@ class MoviePilotAgent: """ try: # 系统提示词 - system_prompt = prompt_manager.get_agent_prompt(channel=self.channel) + system_prompt = prompt_manager.get_agent_prompt( + channel=self.channel, + prefer_voice_reply=self.reply_with_voice, + ) # LLM 模型(用于 agent 执行) llm = self._initialize_llm(streaming=streaming) @@ -281,6 +289,11 @@ class MoviePilotAgent: logger.info( f"Agent推理: session_id={self.session_id}, input={message}, images={len(images) if images else 0}" ) + self._tool_context = { + "incoming_voice": self.reply_with_voice, + "user_reply_sent": False, + "reply_mode": None, + } # 获取历史消息 messages = memory_manager.get_agent_messages( @@ -417,7 +430,7 @@ class MoviePilotAgent: # 流式输出未能发送全部内容(发送失败等) # 通过常规方式发送剩余内容 remaining_text = await self.stream_handler.take() - if remaining_text: + if remaining_text and not self._tool_context.get("user_reply_sent"): await self.send_agent_message(remaining_text) elif streamed_text: # 流式输出已发送全部内容,但未记录到数据库,补充保存消息记录 @@ -447,7 +460,7 @@ class MoviePilotAgent: final_text = text.strip() break - if final_text: + if final_text and not self._tool_context.get("user_reply_sent"): if self.is_background: # 后台任务仅广播最终回复,带标题 await 
self.send_agent_message( @@ -534,6 +547,7 @@ class _MessageTask: channel: Optional[str] = None source: Optional[str] = None username: Optional[str] = None + reply_with_voice: bool = False class AgentManager: @@ -599,6 +613,7 @@ class AgentManager: channel: str = None, source: str = None, username: str = None, + reply_with_voice: bool = False, ) -> str: """ 处理用户消息:将消息放入会话队列,按顺序依次处理。 @@ -612,6 +627,7 @@ class AgentManager: channel=channel, source=source, username=username, + reply_with_voice=reply_with_voice, ) # 获取或创建会话队列 @@ -709,6 +725,7 @@ class AgentManager: agent.source = task.source if task.username: agent.username = task.username + agent.reply_with_voice = task.reply_with_voice return await agent.process(task.message, images=task.images) diff --git a/app/agent/prompt/Agent Prompt.txt b/app/agent/prompt/Agent Prompt.txt index e33fbbf3..2d4e9c4b 100644 --- a/app/agent/prompt/Agent Prompt.txt +++ b/app/agent/prompt/Agent Prompt.txt @@ -21,6 +21,7 @@ Core Capabilities: - Include key details (year, rating, resolution) but do NOT over-explain. - Do not stop for approval on read-only operations. Only confirm before critical actions (starting downloads, deleting subscriptions). - If the current channel supports image sending and an image would materially help, you may use the `send_message` tool with `image_url` to send it. +{voice_reply_spec} - NOT a coding assistant. Do not offer code snippets. - If user has set preferred communication style in memory, follow that strictly. 
diff --git a/app/agent/prompt/__init__.py b/app/agent/prompt/__init__.py index 3680c5af..4c34ceaa 100644 --- a/app/agent/prompt/__init__.py +++ b/app/agent/prompt/__init__.py @@ -50,10 +50,13 @@ class PromptManager: logger.error(f"加载提示词失败: {prompt_name}, 错误: {e}") raise - def get_agent_prompt(self, channel: str = None) -> str: + def get_agent_prompt( + self, channel: str = None, prefer_voice_reply: bool = False + ) -> str: """ 获取智能体提示词 :param channel: 消息渠道(Telegram、微信、Slack等) + :param prefer_voice_reply: 是否优先使用语音回复 :return: 提示词内容 """ # 基础提示词 @@ -87,12 +90,16 @@ class PromptManager: # MoviePilot系统信息 moviepilot_info = self._get_moviepilot_info() + voice_reply_spec = self._generate_voice_reply_instructions( + prefer_voice_reply=prefer_voice_reply + ) # 始终替换占位符,避免后续 .format() 时因残留花括号报 KeyError base_prompt = base_prompt.format( markdown_spec=markdown_spec, verbose_spec=verbose_spec, moviepilot_info=moviepilot_info, + voice_reply_spec=voice_reply_spec, ) return base_prompt @@ -166,6 +173,20 @@ class PromptManager: instructions.append("- Links: Paste URLs directly as text.") return "\n".join(instructions) + @staticmethod + def _generate_voice_reply_instructions(prefer_voice_reply: bool) -> str: + if not prefer_voice_reply: + return ( + "- Voice replies: Use normal text replies by default. " + "Only call `send_voice_message` when spoken playback is clearly better than plain text." + ) + return ( + "- Current message context: The user sent a voice message.\n" + "- Reply preference: Prioritize calling `send_voice_message` for the main user-facing reply.\n" + "- Fallback: If voice is unavailable on the current channel, `send_voice_message` will fall back to text.\n" + "- Do not repeat the same full reply again after calling `send_voice_message`." 
+ ) + def clear_cache(self): """ 清空缓存 diff --git a/app/agent/tools/base.py b/app/agent/tools/base.py index 40f2bc3a..76e3e287 100644 --- a/app/agent/tools/base.py +++ b/app/agent/tools/base.py @@ -31,6 +31,7 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta): _username: Optional[str] = PrivateAttr(default=None) _stream_handler: Optional[StreamingHandler] = PrivateAttr(default=None) _require_admin: bool = PrivateAttr(default=False) + _agent_context: dict = PrivateAttr(default_factory=dict) def __init__(self, session_id: str, user_id: str, **kwargs): super().__init__(**kwargs) @@ -142,6 +143,12 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta): """ self._stream_handler = stream_handler + def set_agent_context(self, agent_context: Optional[dict]): + """ + 设置与当前 Agent 共享的上下文。 + """ + self._agent_context = agent_context or {} + async def _check_permission(self) -> Optional[str]: """ 检查用户权限: diff --git a/app/agent/tools/factory.py b/app/agent/tools/factory.py index 4c741d37..f54251af 100644 --- a/app/agent/tools/factory.py +++ b/app/agent/tools/factory.py @@ -30,6 +30,7 @@ from app.agent.tools.impl.search_torrents import SearchTorrentsTool from app.agent.tools.impl.get_search_results import GetSearchResultsTool from app.agent.tools.impl.search_web import SearchWebTool from app.agent.tools.impl.send_message import SendMessageTool +from app.agent.tools.impl.send_voice_message import SendVoiceMessageTool from app.agent.tools.impl.query_schedulers import QuerySchedulersTool from app.agent.tools.impl.run_scheduler import RunSchedulerTool from app.agent.tools.impl.query_workflows import QueryWorkflowsTool @@ -72,6 +73,7 @@ class MoviePilotToolFactory: source: str = None, username: str = None, stream_handler: Callable = None, + agent_context: dict = None, ) -> List[MoviePilotTool]: """ 创建MoviePilot工具列表 @@ -117,6 +119,7 @@ class MoviePilotToolFactory: QueryTransferHistoryTool, TransferFileTool, SendMessageTool, + SendVoiceMessageTool, QuerySchedulersTool, RunSchedulerTool, 
QueryWorkflowsTool, @@ -138,6 +141,7 @@ class MoviePilotToolFactory: tool = ToolClass(session_id=session_id, user_id=user_id) tool.set_message_attr(channel=channel, source=source, username=username) tool.set_stream_handler(stream_handler=stream_handler) + tool.set_agent_context(agent_context=agent_context) tools.append(tool) # 加载插件提供的工具 @@ -161,6 +165,7 @@ class MoviePilotToolFactory: channel=channel, source=source, username=username ) tool.set_stream_handler(stream_handler=stream_handler) + tool.set_agent_context(agent_context=agent_context) tools.append(tool) plugin_tools_count += 1 logger.debug( diff --git a/app/agent/tools/impl/send_voice_message.py b/app/agent/tools/impl/send_voice_message.py new file mode 100644 index 00000000..6f2b199a --- /dev/null +++ b/app/agent/tools/impl/send_voice_message.py @@ -0,0 +1,96 @@ +"""发送语音消息工具。""" + +import asyncio +from typing import Optional, Type + +from pydantic import BaseModel, Field + +from app.agent.tools.base import MoviePilotTool, ToolChain +from app.core.config import settings +from app.helper.voice import VoiceHelper +from app.helper.service import ServiceConfigHelper +from app.log import logger +from app.schemas import Notification, NotificationType +from app.schemas.types import MessageChannel + + +class SendVoiceMessageInput(BaseModel): + """发送语音消息工具输入。""" + + explanation: str = Field( + ..., + description="Clear explanation of why a voice reply is the best fit in the current context", + ) + message: str = Field( + ..., + description="The spoken content to send back to the user", + ) + + +class SendVoiceMessageTool(MoviePilotTool): + name: str = "send_voice_message" + description: str = ( + "Send a voice reply to the current user. Prefer this when the user sent a voice message " + "or when spoken playback is more natural. On channels without voice support or when TTS " + "is unavailable, it automatically falls back to sending the same content as plain text." 
+ ) + args_schema: Type[BaseModel] = SendVoiceMessageInput + require_admin: bool = False + + def get_tool_message(self, **kwargs) -> Optional[str]: + message = kwargs.get("message") or "" + if len(message) > 40: + message = message[:40] + "..." + return f"正在发送语音回复: {message}" + + def _supports_real_voice_reply(self) -> bool: + channel = self._channel or "" + if channel == MessageChannel.Telegram.value: + return True + if channel != MessageChannel.Wechat.value: + return False + for config in ServiceConfigHelper.get_notification_configs(): + if config.name != self._source: + continue + return (config.config or {}).get("WECHAT_MODE", "app") != "bot" + return False + + async def run(self, message: str, **kwargs) -> str: + if not message: + return "语音回复内容不能为空" + + voice_path = None + used_voice = False + channel = self._channel or "" + if self._supports_real_voice_reply() and VoiceHelper.is_available("tts"): + voice_file = await asyncio.to_thread(VoiceHelper.synthesize_speech, message) + if voice_file: + voice_path = str(voice_file) + used_voice = True + + logger.info( + "执行工具: %s, channel=%s, use_voice=%s, text_len=%s", + self.name, + channel, + used_voice, + len(message), + ) + + await ToolChain().async_post_message( + Notification( + channel=self._channel, + source=self._source, + mtype=NotificationType.Agent, + userid=self._user_id, + username=self._username, + text=message, + voice_path=voice_path, + voice_caption=message if settings.AI_VOICE_REPLY_WITH_TEXT else None, + ) + ) + self._agent_context["user_reply_sent"] = True + self._agent_context["reply_mode"] = "voice" if used_voice else "text_fallback" + + if used_voice: + return "语音回复已发送" + return "当前未使用语音通道,已自动回退为文字回复" diff --git a/app/chain/message.py b/app/chain/message.py index cf2e63fb..e6648e75 100644 --- a/app/chain/message.py +++ b/app/chain/message.py @@ -4,6 +4,8 @@ import time from datetime import datetime, timedelta from typing import Any, Optional, Dict, Union, List +import base64 + from app.agent 
import agent_manager from app.chain import ChainBase from app.chain.download import DownloadChain @@ -15,6 +17,7 @@ from app.core.context import MediaInfo, Context from app.core.meta import MetaBase from app.db.user_oper import UserOper from app.helper.torrent import TorrentHelper +from app.helper.voice import VoiceHelper from app.log import logger from app.schemas import Notification, NotExistMediaInfo, CommingMessage from app.schemas.message import ChannelCapabilityManager @@ -44,26 +47,6 @@ class MessageChain(ChainBase): # 会话超时时间(分钟) _session_timeout_minutes: int = 24 * 60 - @staticmethod - def _summarize_images(images: Optional[List[str]], max_items: int = 3) -> List[str]: - """ - 图片引用摘要,避免日志过长或直接输出完整 base64。 - """ - if not images: - return [] - summary = [] - for image in images[:max_items]: - if not image: - continue - image = str(image) - if image.startswith("data:"): - summary.append(f"{image[:32]}...({len(image)} chars)") - elif len(image) > 120: - summary.append(f"{image[:117]}...") - else: - summary.append(image) - return summary - @staticmethod def __get_noexits_info( _meta: MetaBase, _mediainfo: MediaInfo @@ -146,23 +129,15 @@ class MessageChain(ChainBase): if userid is None or userid == "": logger.debug(f"未识别到用户ID:{body}{form}{args}") return + # 消息内容 text = str(info.text).strip() if info.text else "" images = info.images - if not text and not images: + audio_refs = info.audio_refs + if not text and not images and not audio_refs: logger.debug(f"未识别到消息内容::{body}{form}{args}") return - logger.info( - "消息链路解析完成: source=%s, channel=%s, userid=%s, text_len=%s, image_count=%s, image_refs=%s", - source, - channel.value if channel else None, - userid, - len(text), - len(images or []), - self._summarize_images(images), - ) - # 获取原消息ID信息 original_message_id = info.message_id original_chat_id = info.chat_id @@ -177,6 +152,7 @@ class MessageChain(ChainBase): original_message_id=original_message_id, original_chat_id=original_chat_id, images=images, + 
audio_refs=audio_refs, ) def handle_message( @@ -189,25 +165,43 @@ class MessageChain(ChainBase): original_message_id: Optional[Union[str, int]] = None, original_chat_id: Optional[str] = None, images: Optional[List[str]] = None, + audio_refs: Optional[List[str]] = None, ) -> None: """ 识别消息内容,执行操作 """ # 申明全局变量 global _current_page, _current_meta, _current_media - # 处理消息 - logger.info( - "收到用户消息内容: channel=%s, source=%s, userid=%s, text=%s, image_count=%s, image_refs=%s", - channel.value if channel else None, - source, - userid, - text, - len(images or []), - self._summarize_images(images), - ) + # 加载缓存 user_cache: Dict[str, dict] = self.load_cache(self._cache_file) or {} + try: + # 识别语音为文本 + reply_with_voice = bool(audio_refs) + if audio_refs: + transcript = self._transcribe_audio_refs(audio_refs, channel, source) + merged_parts = [] + seen_parts = set() + for item in [text.strip() if text else "", transcript or ""]: + normalized = item.strip() + if not normalized or normalized in seen_parts: + continue + seen_parts.add(normalized) + merged_parts.append(normalized) + text = "\n".join(merged_parts).strip() + if not text: + self.post_message( + Notification( + channel=channel, + source=source, + userid=userid, + username=username, + title="语音识别失败,请稍后重试", + ) + ) + return + # 保存消息 if not text.startswith("CALLBACK:"): self.messagehelper.put( @@ -251,32 +245,6 @@ class MessageChain(ChainBase): {"cmd": text, "user": userid, "channel": channel, "source": source}, ) elif text.lower().startswith("/ai"): - # 用户指定AI智能体消息响应 - logger.info( - "消息链路分流到AI: reason=explicit_ai, channel=%s, source=%s, userid=%s, image_count=%s", - channel.value if channel else None, - source, - userid, - len(images or []), - ) - self._handle_ai_message( - text=text, - channel=channel, - source=source, - userid=userid, - username=username, - images=images, - ) - elif settings.AI_AGENT_ENABLE and images: - # 带图消息优先交给智能体处理,避免图片在传统消息链路中丢失 - logger.info( - "消息链路分流到AI: reason=image_message, channel=%s, 
source=%s, userid=%s, image_count=%s, image_refs=%s", - channel.value if channel else None, - source, - userid, - len(images or []), - self._summarize_images(images), - ) self._handle_ai_message( text=text, channel=channel, @@ -284,16 +252,10 @@ class MessageChain(ChainBase): userid=userid, username=username, images=images, + reply_with_voice=reply_with_voice, ) elif settings.AI_AGENT_ENABLE and settings.AI_AGENT_GLOBAL: # 普通消息,全局智能体响应 - logger.info( - "消息链路分流到AI: reason=global_agent, channel=%s, source=%s, userid=%s, image_count=%s", - channel.value if channel else None, - source, - userid, - len(images or []), - ) self._handle_ai_message( text=text, channel=channel, @@ -301,6 +263,7 @@ class MessageChain(ChainBase): userid=userid, username=username, images=images, + reply_with_voice=reply_with_voice, ) else: # 非智能体普通消息响应 @@ -1266,6 +1229,7 @@ class MessageChain(ChainBase): userid: Union[str, int], username: str, images: Optional[List[str]] = None, + reply_with_voice: bool = False, ) -> None: """ 处理AI智能体消息 @@ -1290,16 +1254,6 @@ class MessageChain(ChainBase): else: user_message = text.strip() # 按原消息处理 - logger.info( - "AI消息入口: channel=%s, source=%s, userid=%s, text_len=%s, raw_image_count=%s, raw_image_refs=%s", - channel.value if channel else None, - source, - userid, - len(user_message), - len(images or []), - self._summarize_images(images), - ) - if not user_message and not images: self.post_message( Notification( @@ -1319,15 +1273,6 @@ class MessageChain(ChainBase): original_images = images if images: images = self._download_images_to_base64(images, channel, source) - logger.info( - "AI图片预处理完成: channel=%s, source=%s, userid=%s, raw_image_count=%s, converted_image_count=%s, converted_image_refs=%s", - channel.value if channel else None, - source, - userid, - len(original_images or []), - len(images or []), - self._summarize_images(images), - ) if original_images and not images and not user_message: self.post_message( Notification( @@ -1350,6 +1295,7 @@ class 
MessageChain(ChainBase): channel=channel.value if channel else None, source=source, username=username, + reply_with_voice=reply_with_voice, ), global_vars.loop, ) @@ -1360,6 +1306,64 @@ class MessageChain(ChainBase): f"AI智能体处理失败: {str(e)}", role="system", title="MoviePilot助手" ) + def _transcribe_audio_refs( + self, audio_refs: List[str], channel: MessageChannel, source: str + ) -> Optional[str]: + """ + 下载并识别语音消息,仅处理当前已接入的渠道。 + """ + if not audio_refs: + return None + if not VoiceHelper.is_available("stt"): + logger.warning("语音能力未配置,跳过语音识别") + return None + + transcripts = [] + for audio_ref in audio_refs: + try: + if audio_ref.startswith("tg://voice_file_id/"): + file_id = audio_ref.replace("tg://voice_file_id/", "", 1) + content = self.run_module( + "download_telegram_file_bytes", file_id=file_id, source=source + ) + filename = "input.ogg" + elif audio_ref.startswith("tg://audio_file_id/"): + file_id = audio_ref.replace("tg://audio_file_id/", "", 1) + content = self.run_module( + "download_telegram_file_bytes", file_id=file_id, source=source + ) + filename = "input.mp3" + elif audio_ref.startswith("wxwork://voice_media_id/"): + content = self.run_module( + "download_wechat_media_bytes", media_ref=audio_ref, source=source + ) + filename = "input.amr" + elif audio_ref.startswith("wxbot://voice"): + continue + else: + logger.debug( + "暂不支持的语音引用: channel=%s, source=%s, ref=%s", + channel.value if channel else None, + source, + audio_ref, + ) + continue + + transcript = VoiceHelper.transcribe_bytes(content=content, filename=filename) + if transcript: + transcripts.append(transcript) + logger.info( + "语音识别成功: channel=%s, source=%s, ref=%s, text_len=%s", + channel.value if channel else None, + source, + audio_ref, + len(transcript), + ) + except Exception as err: + logger.error(f"语音识别失败: {err}") + + return "\n".join(transcripts).strip() if transcripts else None + def _download_images_to_base64( self, images: List[str], channel: MessageChannel, source: str ) -> 
List[str]: @@ -1373,16 +1377,10 @@ class MessageChain(ChainBase): try: if img.startswith("data:"): base64_images.append(img) - logger.info( - "图片无需下载: channel=%s, source=%s, input=%s", - channel.value if channel else None, - source, - self._summarize_images([img])[0], - ) elif img.startswith("tg://file_id/"): file_id = img.replace("tg://file_id/", "") base64_data = self.run_module( - "download_file_to_base64", file_id=file_id, source=source + "download_telegram_file_to_base64", file_id=file_id, source=source ) if base64_data: base64_images.append(f"data:image/jpeg;base64,{base64_data}") @@ -1402,26 +1400,12 @@ class MessageChain(ChainBase): ) if data_url: base64_images.append(data_url) - logger.info( - "图片下载成功: channel=%s, source=%s, input=%s, output=%s", - channel.value if channel else None, - source, - img, - self._summarize_images([data_url])[0], - ) elif channel == MessageChannel.Slack: data_url = self.run_module( - "download_file_to_data_url", file_url=img, source=source + "download_slack_file_to_data_url", file_url=img, source=source ) if data_url: base64_images.append(data_url) - logger.info( - "图片下载成功: channel=%s, source=%s, input=%s, output=%s", - channel.value if channel else None, - source, - img, - self._summarize_images([data_url])[0], - ) elif img.startswith("vocechat://file/"): data_url = self.run_module( "download_vocechat_image_to_data_url", @@ -1430,30 +1414,12 @@ class MessageChain(ChainBase): ) if data_url: base64_images.append(data_url) - logger.info( - "图片下载成功: channel=%s, source=%s, input=%s, output=%s", - channel.value if channel else None, - source, - img, - self._summarize_images([data_url])[0], - ) elif img.startswith("http"): resp = RequestUtils(timeout=30).get_res(img) if resp and resp.content: - import base64 - base64_data = base64.b64encode(resp.content).decode() mime_type = resp.headers.get("Content-Type", "image/jpeg") base64_images.append(f"data:{mime_type};base64,{base64_data}") - logger.info( - "图片下载成功: channel=%s, source=%s, 
input=%s, output=%s", - channel.value if channel else None, - source, - img, - self._summarize_images( - [f"data:{mime_type};base64,{base64_data}"] - )[0], - ) except Exception as e: logger.error(f"下载图片失败: {img}, error: {e}") return base64_images if base64_images else None diff --git a/app/core/config.py b/app/core/config.py index fcfa79e0..e1eb6fd7 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -538,6 +538,35 @@ class ConfigModel(BaseModel): # AI智能体自动重试整理失败记录开关 AI_AGENT_RETRY_TRANSFER: bool = False + # 语音能力提供商(当前仅支持 openai) + AI_VOICE_PROVIDER: str = "openai" + # 语音识别提供商,未设置时回退到 AI_VOICE_PROVIDER + AI_VOICE_STT_PROVIDER: Optional[str] = None + # 语音合成提供商,未设置时回退到 AI_VOICE_PROVIDER + AI_VOICE_TTS_PROVIDER: Optional[str] = None + # 语音能力 API 密钥,未设置且 LLM_PROVIDER=openai 时回退使用 LLM_API_KEY + AI_VOICE_API_KEY: Optional[str] = None + # 语音识别 API 密钥,未设置时回退到 AI_VOICE_API_KEY + AI_VOICE_STT_API_KEY: Optional[str] = None + # 语音合成 API 密钥,未设置时回退到 AI_VOICE_API_KEY + AI_VOICE_TTS_API_KEY: Optional[str] = None + # 语音能力基础URL,未设置且 LLM_PROVIDER=openai 时回退使用 LLM_BASE_URL + AI_VOICE_BASE_URL: Optional[str] = None + # 语音识别基础URL,未设置时回退到 AI_VOICE_BASE_URL + AI_VOICE_STT_BASE_URL: Optional[str] = None + # 语音合成基础URL,未设置时回退到 AI_VOICE_BASE_URL + AI_VOICE_TTS_BASE_URL: Optional[str] = None + # 语音转文字模型 + AI_VOICE_STT_MODEL: str = "gpt-4o-mini-transcribe" + # 文字转语音模型 + AI_VOICE_TTS_MODEL: str = "gpt-4o-mini-tts" + # TTS 发音人 + AI_VOICE_TTS_VOICE: str = "alloy" + # 语音识别语言 + AI_VOICE_LANGUAGE: str = "zh" + # 回复语音时是否同时附带文字说明 + AI_VOICE_REPLY_WITH_TEXT: bool = False + class Settings(BaseSettings, ConfigModel, LogConfigModel): """ diff --git a/app/helper/voice.py b/app/helper/voice.py new file mode 100644 index 00000000..a05f6236 --- /dev/null +++ b/app/helper/voice.py @@ -0,0 +1,197 @@ +"""语音能力辅助功能。""" + +from abc import ABC, abstractmethod +from io import BytesIO +from pathlib import Path +from typing import Dict, Optional +from uuid import uuid4 + +from app.core.config import settings 
+from app.log import logger + + +class VoiceProvider(ABC): + """语音 provider 抽象层。""" + + MAX_TRANSCRIBE_BYTES = 25 * 1024 * 1024 + + @property + @abstractmethod + def name(self) -> str: + """provider 名称。""" + + @abstractmethod + def is_available_for_stt(self) -> bool: + """是否可用于语音识别。""" + + @abstractmethod + def is_available_for_tts(self) -> bool: + """是否可用于语音合成。""" + + @abstractmethod + def transcribe_bytes(self, content: bytes, filename: str = "input.ogg") -> Optional[str]: + """将音频字节转成文字。""" + + @abstractmethod + def synthesize_speech(self, text: str) -> Optional[Path]: + """将文字转成语音文件。""" + + +class OpenAIVoiceProvider(VoiceProvider): + """OpenAI / OpenAI-compatible provider。""" + + @property + def name(self) -> str: + return "openai" + + @staticmethod + def _resolve_credentials(mode: str) -> tuple[Optional[str], Optional[str]]: + mode = mode.lower() + provider = ( + settings.AI_VOICE_STT_PROVIDER + if mode == "stt" + else settings.AI_VOICE_TTS_PROVIDER + ) or settings.AI_VOICE_PROVIDER + provider = (provider or "").strip().lower() + + api_key = ( + settings.AI_VOICE_STT_API_KEY + if mode == "stt" + else settings.AI_VOICE_TTS_API_KEY + ) or settings.AI_VOICE_API_KEY + base_url = ( + settings.AI_VOICE_STT_BASE_URL + if mode == "stt" + else settings.AI_VOICE_TTS_BASE_URL + ) or settings.AI_VOICE_BASE_URL + + if ( + not api_key + and provider == "openai" + and (settings.LLM_PROVIDER or "").strip().lower() == "openai" + ): + api_key = settings.LLM_API_KEY + base_url = base_url or settings.LLM_BASE_URL + + return api_key, base_url + + def _get_client(self, mode: str): + from openai import OpenAI + + api_key, base_url = self._resolve_credentials(mode) + if not api_key: + raise ValueError(f"{mode.upper()} provider 未配置 API Key") + return OpenAI(api_key=api_key, base_url=base_url, max_retries=3) + + def is_available_for_stt(self) -> bool: + api_key, _ = self._resolve_credentials("stt") + return bool(api_key) + + def is_available_for_tts(self) -> bool: + api_key, _ = 
self._resolve_credentials("tts") + return bool(api_key) + + def transcribe_bytes(self, content: bytes, filename: str = "input.ogg") -> Optional[str]: + if not content: + return None + if len(content) > self.MAX_TRANSCRIBE_BYTES: + raise ValueError("语音文件超过 25MB,无法识别") + + try: + client = self._get_client("stt") + audio_file = BytesIO(content) + audio_file.name = filename + response = client.audio.transcriptions.create( + model=settings.AI_VOICE_STT_MODEL, + file=audio_file, + language=settings.AI_VOICE_LANGUAGE or "zh", + response_format="verbose_json", + ) + text = getattr(response, "text", None) + return text.strip() if text else None + except Exception as err: + logger.error(f"语音转文字失败: provider={self.name}, error={err}") + return None + + def synthesize_speech(self, text: str) -> Optional[Path]: + if not text: + return None + + try: + client = self._get_client("tts") + voice_dir = settings.TEMP_PATH / "voice" + voice_dir.mkdir(parents=True, exist_ok=True) + output_path = voice_dir / f"{uuid4().hex}.opus" + response = client.audio.speech.create( + model=settings.AI_VOICE_TTS_MODEL, + voice=settings.AI_VOICE_TTS_VOICE, + input=text, + response_format="opus", + ) + response.write_to_file(output_path) + return output_path + except Exception as err: + logger.error(f"文字转语音失败: provider={self.name}, error={err}") + return None + + +class VoiceHelper: + """统一语音入口,负责按 STT/TTS provider 路由。""" + + _providers: Dict[str, VoiceProvider] = { + "openai": OpenAIVoiceProvider(), + } + + @classmethod + def register_provider(cls, provider: VoiceProvider) -> None: + cls._providers[provider.name.lower()] = provider + + @staticmethod + def _resolve_provider_name(mode: str) -> str: + mode = mode.lower() + provider = ( + settings.AI_VOICE_STT_PROVIDER + if mode == "stt" + else settings.AI_VOICE_TTS_PROVIDER + ) or settings.AI_VOICE_PROVIDER + return (provider or "openai").strip().lower() + + @classmethod + def get_provider(cls, mode: str) -> Optional[VoiceProvider]: + provider_name = 
cls._resolve_provider_name(mode) + provider = cls._providers.get(provider_name) + if provider: + return provider + logger.warning(f"未注册语音 provider: mode={mode}, provider={provider_name}") + return None + + @classmethod + def get_registered_providers(cls) -> list[str]: + return sorted(cls._providers.keys()) + + @classmethod + def is_available(cls, mode: Optional[str] = None) -> bool: + if mode: + provider = cls.get_provider(mode) + if not provider: + return False + return ( + provider.is_available_for_stt() + if mode.lower() == "stt" + else provider.is_available_for_tts() + ) + return cls.is_available("stt") or cls.is_available("tts") + + @classmethod + def transcribe_bytes(cls, content: bytes, filename: str = "input.ogg") -> Optional[str]: + provider = cls.get_provider("stt") + if not provider: + return None + return provider.transcribe_bytes(content=content, filename=filename) + + @classmethod + def synthesize_speech(cls, text: str) -> Optional[Path]: + provider = cls.get_provider("tts") + if not provider: + return None + return provider.synthesize_speech(text=text) diff --git a/app/modules/slack/__init__.py b/app/modules/slack/__init__.py index 69a5602c..657adc18 100644 --- a/app/modules/slack/__init__.py +++ b/app/modules/slack/__init__.py @@ -297,7 +297,7 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]): images.append(url) return images if images else None - def download_file_to_data_url(self, file_url: str, source: str) -> Optional[str]: + def download_slack_file_to_data_url(self, file_url: str, source: str) -> Optional[str]: """ 下载Slack文件并转为data URL :param file_url: Slack私有文件URL diff --git a/app/modules/telegram/__init__.py b/app/modules/telegram/__init__.py index 050260bf..79adf911 100644 --- a/app/modules/telegram/__init__.py +++ b/app/modules/telegram/__init__.py @@ -214,17 +214,19 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]): text = self._append_reply_markup_links(text, msg.get("reply_markup")) images = self._extract_images(msg) + 
audio_refs = self._extract_audio_refs(msg) if user_id: - if not text and not images: + if not text and not images and not audio_refs: logger.debug( - f"收到来自 {client_config.name} 的Telegram消息无文本和图片" + f"收到来自 {client_config.name} 的Telegram消息无文本、图片和语音" ) return None logger.info( f"收到来自 {client_config.name} 的Telegram消息:" - f"userid={user_id}, username={user_name}, chat_id={chat_id}, text={text}, images={len(images) if images else 0}" + f"userid={user_id}, username={user_name}, chat_id={chat_id}, text={text}, " + f"images={len(images) if images else 0}, audios={len(audio_refs) if audio_refs else 0}" ) cleaned_text = ( @@ -263,6 +265,7 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]): text=cleaned_text, chat_id=str(chat_id) if chat_id else None, images=images if images else None, + audio_refs=audio_refs if audio_refs else None, ) return None @@ -288,6 +291,26 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]): return images if images else None + @staticmethod + def _extract_audio_refs(msg: dict) -> Optional[List[str]]: + """ + 从Telegram消息中提取语音/音频 file_id。 + """ + audio_refs = [] + voice = msg.get("voice") + if voice: + file_id = voice.get("file_id") + if file_id: + audio_refs.append(f"tg://voice_file_id/{file_id}") + + audio = msg.get("audio") + if audio: + file_id = audio.get("file_id") + if file_id: + audio_refs.append(f"tg://audio_file_id/{file_id}") + + return audio_refs if audio_refs else None + @staticmethod def _embed_entity_links(text: str, entities: Optional[List[dict]]) -> str: """ @@ -389,17 +412,25 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]): return client: Telegram = self.get_instance(conf.name) if client: - client.send_msg( - title=message.title, - text=message.text, - image=message.image, - userid=userid, - link=message.link, - buttons=message.buttons, - original_message_id=message.original_message_id, - original_chat_id=message.original_chat_id, - disable_web_page_preview=message.disable_web_page_preview, - ) + if 
message.voice_path: + client.send_voice( + voice_path=message.voice_path, + userid=userid, + caption=message.voice_caption, + original_chat_id=message.original_chat_id, + ) + else: + client.send_msg( + title=message.title, + text=message.text, + image=message.image, + userid=userid, + link=message.link, + buttons=message.buttons, + original_message_id=message.original_message_id, + original_chat_id=message.original_chat_id, + disable_web_page_preview=message.disable_web_page_preview, + ) def post_medias_message( self, message: Notification, medias: List[MediaInfo] @@ -531,14 +562,22 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]): return None client: Telegram = self.get_instance(conf.name) if client: - result = client.send_msg( - title=message.title, - text=message.text, - image=message.image, - userid=userid, - link=message.link, - disable_web_page_preview=message.disable_web_page_preview, - ) + if message.voice_path: + result = client.send_voice( + voice_path=message.voice_path, + userid=userid, + caption=message.voice_caption, + original_chat_id=message.original_chat_id, + ) + else: + result = client.send_msg( + title=message.title, + text=message.text, + image=message.image, + userid=userid, + link=message.link, + disable_web_page_preview=message.disable_web_page_preview, + ) if result and result.get("success"): return MessageResponse( message_id=result.get("message_id"), @@ -601,7 +640,7 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]): ) client.register_commands(filtered_scoped_commands) - def download_file_to_base64(self, file_id: str, source: str) -> Optional[str]: + def download_telegram_file_to_base64(self, file_id: str, source: str) -> Optional[str]: """ 下载Telegram文件并转为base64 :param file_id: Telegram文件ID @@ -620,3 +659,15 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]): return base64.b64encode(file_content).decode() return None + + def download_telegram_file_bytes(self, file_id: str, source: str) -> Optional[bytes]: 
def send_voice(
    self,
    voice_path: str,
    userid: Optional[str] = None,
    caption: Optional[str] = None,
    original_chat_id: Optional[str] = None,
) -> Optional[dict]:
    """
    Send a local audio file to Telegram as a voice message.

    The file at ``voice_path`` is treated as a temporary artifact: once a send
    has been attempted it is removed, whether or not the send succeeded.

    :param voice_path: path of the audio file to upload
    :param userid: target user id, forwarded to ``_determine_target_chat_id``
    :param caption: optional caption; when given it is passed through
        ``standardize`` and sent with ``parse_mode="MarkdownV2"``
    :param original_chat_id: originating chat id, forwarded to
        ``_determine_target_chat_id``
    :return: ``None`` when the bot is not initialised or no path was given;
        otherwise a dict with ``success`` and, when the bot returned a message
        object, ``message_id`` and ``chat_id``
    """
    if not self._bot or not voice_path:
        return None

    chat_id = self._determine_target_chat_id(userid, original_chat_id)
    voice_file = Path(voice_path)

    try:
        # is_file() also rejects directories, which exists() would let through
        # only to fail later inside open().
        if not voice_file.is_file():
            logger.error(f"语音文件不存在: {voice_file}")
            return {"success": False}

        with voice_file.open("rb") as fp:
            sent = self._bot.send_voice(
                chat_id=chat_id,
                voice=fp,
                caption=standardize(caption) if caption else None,
                parse_mode="MarkdownV2" if caption else None,
            )
        if sent and hasattr(sent, "message_id"):
            return {
                "success": True,
                "message_id": sent.message_id,
                "chat_id": sent.chat.id if hasattr(sent, "chat") else chat_id,
            }
        return {"success": bool(sent)}
    except Exception as err:
        logger.error(f"发送语音消息失败:{err}")
        return {"success": False}
    finally:
        # Best-effort removal of the temporary audio file.
        try:
            voice_file.unlink(missing_ok=True)
        except Exception as cleanup_err:
            logger.debug(f"清理语音临时文件失败: {cleanup_err}")
        # Stop the typing indicator on every exit path, including the
        # missing-file return, so it cannot keep running after a failed reply.
        self._stop_typing_task(chat_id)
a/app/modules/wechat/__init__.py b/app/modules/wechat/__init__.py index 04d7c2cc..cee2c55b 100644 --- a/app/modules/wechat/__init__.py +++ b/app/modules/wechat/__init__.py @@ -167,6 +167,7 @@ class WechatModule(_ModuleBase, _MessageBase[WeChat]): # 解析消息内容 content = None images = None + audio_refs = None if msg_type == "event" and event == "click": # 校验用户有权限执行交互命令 if client_config.config.get('WECHAT_ADMINS'): @@ -192,14 +193,24 @@ class WechatModule(_ModuleBase, _MessageBase[WeChat]): logger.info( f"收到来自 {client_config.name} 的微信图片消息:userid={user_id}, images={len(images) if images else 0}" ) + elif msg_type == "voice": + media_id = DomUtils.tag_value(root_node, "MediaId") + recognition = DomUtils.tag_value(root_node, "Recognition", default="") + content = (recognition or "").strip() + if media_id: + audio_refs = [f"wxwork://voice_media_id/{media_id}"] + logger.info( + f"收到来自 {client_config.name} 的微信语音消息:userid={user_id}, " + f"text={content}, audios={len(audio_refs) if audio_refs else 0}" + ) else: return None - if content or images: + if content or images or audio_refs: # 处理消息内容 return CommingMessage(channel=MessageChannel.Wechat, source=client_config.name, userid=user_id, username=user_id, text=content or "", - images=images) + images=images, audio_refs=audio_refs) except Exception as err: logger.error(f"微信消息处理发生错误:{str(err)}") return None @@ -230,6 +241,7 @@ class WechatModule(_ModuleBase, _MessageBase[WeChat]): text = WeChatBot._extract_text_from_body(payload_body) images = WeChatBot._extract_images_from_body(payload_body) + audio_refs = ["wxbot://voice"] if payload_body.get("msgtype") == "voice" else None if text: text = re.sub(r"@\S+", "", text).strip() @@ -245,7 +257,7 @@ class WechatModule(_ModuleBase, _MessageBase[WeChat]): client.send_msg(title="只有管理员才有权限执行此命令", userid=sender) return None - if not text and not images: + if not text and not images and not audio_refs: return None logger.info( @@ -259,6 +271,7 @@ class WechatModule(_ModuleBase, 
def download_wechat_media_bytes(self, media_ref: str, source: str) -> Optional[bytes]:
    """
    Download a WeCom voice attachment and return its raw bytes.

    :param media_ref: reference of the form ``wxwork://voice_media_id/<media_id>``
    :param source: name of the message source whose client performs the download
    :return: the media bytes, or ``None`` when the reference is empty or uses
        another scheme, the source has no config/client, or the client has no
        ``download_media_bytes`` method
    """
    voice_scheme = "wxwork://voice_media_id/"

    if not media_ref:
        return None

    conf = self.get_config(source)
    wechat = self.get_instance(conf.name) if conf else None
    # Not every client type resolved here offers a media download method.
    can_download = bool(wechat) and hasattr(wechat, "download_media_bytes")
    if not can_download or not media_ref.startswith(voice_scheme):
        return None

    # Everything after the scheme prefix is the media id.
    return wechat.download_media_bytes(media_ref[len(voice_scheme):])
def download_media_bytes(self, media_id: str) -> Optional[bytes]:
    """
    Download a WeCom media file and return its raw bytes.

    :param media_id: media id taken from an incoming message
    :return: the file content, or ``None`` when the id is empty, no access
        token is available, the request fails, or the response is a JSON body
    """
    if not media_id:
        return None
    access_token = self.__get_access_token()
    if not access_token:
        logger.error("下载企业微信媒体失败:access_token 获取失败")
        return None
    req_url = self._download_media_url.format(
        access_token=access_token,
        media_id=media_id,
    )
    try:
        res = RequestUtils(timeout=30).get_res(req_url)
    except Exception as err:
        logger.error(f"下载企业微信媒体失败:{err}")
        return None
    if not res or not res.content:
        return None
    # A JSON body is treated as an error payload rather than media; header
    # values are compared case-insensitively.
    content_type = (res.headers.get("Content-Type") or "").split(";")[0].strip().lower()
    if content_type == "application/json":
        try:
            logger.error(f"企业微信媒体下载失败:{res.json()}")
        except Exception:
            logger.error(f"企业微信媒体下载失败:{res.text}")
        return None
    return res.content


@staticmethod
def _convert_voice_to_amr(voice_path: str) -> Optional[Path]:
    """
    Transcode an audio file to the AMR format WeCom voice messages require
    (8 kHz mono, truncated to 60 seconds, at most 2 MB).

    :param voice_path: path of the source audio file
    :return: path of the converted file, or ``None`` when the source is
        missing, ffmpeg fails or times out, or the result exceeds 2 MB
    """
    src_path = Path(voice_path)
    if not src_path.exists():
        logger.error(f"语音文件不存在:{src_path}")
        return None

    dst_path = src_path.with_suffix(".amr")
    if dst_path == src_path:
        # The source is already named *.amr; ffmpeg cannot transcode a file
        # onto itself, so write to a sibling path instead.
        dst_path = src_path.with_name(f"{src_path.stem}.wecom.amr")

    cmd = [
        "ffmpeg",
        "-y",
        "-i",
        str(src_path),
        "-ar",
        "8000",
        "-ac",
        "1",
        "-t",
        "60",
        str(dst_path),
    ]
    try:
        result = subprocess.run(
            cmd,
            stdin=subprocess.DEVNULL,  # never let ffmpeg wait on our stdin
            capture_output=True,
            text=True,
            errors="replace",  # ffmpeg stderr is not guaranteed to be valid UTF-8
            check=False,
            timeout=120,  # a stuck ffmpeg must not block the sender forever
        )
    except Exception as err:
        logger.error(f"调用 ffmpeg 转换 AMR 失败:{err}")
        dst_path.unlink(missing_ok=True)
        return None

    if result.returncode != 0 or not dst_path.exists():
        stderr_tail = (result.stderr or "").strip()[:500]
        logger.error(
            f"ffmpeg 转换 AMR 失败: returncode={result.returncode}, stderr={stderr_tail}"
        )
        # Do not leave a partial output file behind.
        dst_path.unlink(missing_ok=True)
        return None

    if dst_path.stat().st_size > 2 * 1024 * 1024:
        logger.error("AMR 语音文件超过 2MB,无法发送到企业微信")
        dst_path.unlink(missing_ok=True)
        return None
    return dst_path


def _upload_temp_media(self, media_path: Path, media_type: str = "voice") -> Optional[str]:
    """
    Upload a file as WeCom temporary media.

    :param media_path: file to upload
    :param media_type: WeCom media type placed in the upload URL
    :return: the ``media_id`` from the response, or ``None`` when no access
        token is available, the request fails, or ``errcode`` is not 0
    """
    access_token = self.__get_access_token()
    if not access_token:
        return None
    req_url = self._upload_media_url.format(
        access_token=access_token,
        media_type=media_type,
    )
    mime_type = "voice/amr" if media_type == "voice" else "application/octet-stream"
    try:
        with media_path.open("rb") as media_file:
            response = RequestUtils(timeout=60).request(
                method="post",
                url=req_url,
                files={"media": (media_path.name, media_file, mime_type)},
            )
    except Exception as err:
        logger.error(f"上传企业微信临时素材失败:{err}")
        return None

    if not response:
        return None

    try:
        ret_json = response.json()
    except Exception as err:
        logger.error(f"解析企业微信临时素材响应失败:{err}")
        return None

    if ret_json.get("errcode") != 0:
        logger.error(f"上传企业微信临时素材失败:{ret_json}")
        return None
    return ret_json.get("media_id")


def send_voice(self, voice_path: str, userid: Optional[str] = None) -> Optional[bool]:
    """
    Send a WeCom voice message.

    The audio is converted to AMR, uploaded as temporary media and then sent
    as a ``voice`` message. ``voice_path`` is treated as a temporary file: it
    and the converted file are removed on every exit path once a path was
    given, so a failed send followed by a text fallback leaves nothing behind.

    :param voice_path: path of the audio file to send
    :param userid: recipient; defaults to ``@all`` when empty
    :return: the result of the send request, ``False`` on any failure, or
        ``None`` when no access token could be obtained
    """
    if not voice_path:
        return False

    source_path = Path(voice_path)
    converted_path: Optional[Path] = None
    try:
        if not self.__get_access_token():
            logger.error("获取微信access_token失败,请检查参数配置")
            return None

        converted_path = self._convert_voice_to_amr(voice_path)
        if not converted_path:
            return False

        media_id = self._upload_temp_media(converted_path, media_type="voice")
        if not media_id:
            return False

        req_json = {
            "touser": userid or "@all",
            "msgtype": "voice",
            "agentid": self._appid,
            "voice": {
                "media_id": media_id
            },
            "safe": 0,
            "enable_id_trans": 0,
            "enable_duplicate_check": 0
        }
        return self.__post_request(self._send_msg_url, req_json)
    except Exception as err:
        logger.error(f"发送企业微信语音消息失败:{err}")
        return False
    finally:
        # Best-effort cleanup; a cleanup error must not replace the result.
        for leftover in (converted_path, source_path):
            if leftover is None:
                continue
            try:
                leftover.unlink(missing_ok=True)
            except OSError as cleanup_err:
                logger.debug(f"清理语音临时文件失败: {cleanup_err}")
unittest.mock import Mock, patch +from unittest.mock import AsyncMock, Mock, patch from telebot import apihelper from app.agent.tools.impl.send_message import SendMessageInput +from app.agent import MoviePilotAgent, AgentChain from app.chain.message import MessageChain from app.core.config import settings from app.modules.discord import DiscordModule @@ -23,6 +24,19 @@ from app.schemas.types import MessageChannel class AgentImageSupportTest(unittest.TestCase): + def test_telegram_extract_audio_refs_returns_prefixed_file_ids(self): + audio_refs = TelegramModule._extract_audio_refs( + { + "voice": {"file_id": "voice-1"}, + "audio": {"file_id": "audio-1"}, + } + ) + + self.assertEqual( + audio_refs, + ["tg://voice_file_id/voice-1", "tg://audio_file_id/audio-1"], + ) + def test_telegram_extract_images_returns_prefixed_file_ids(self): images = TelegramModule._extract_images( { @@ -126,6 +140,25 @@ class AgentImageSupportTest(unittest.TestCase): self.assertEqual(handle_kwargs["text"], "") self.assertEqual(handle_kwargs["images"], ["tg://file_id/image-1"]) + def test_process_allows_audio_only_message(self): + chain = MessageChain() + message = CommingMessage( + channel=MessageChannel.Telegram, + source="telegram-test", + userid="10001", + username="tester", + audio_refs=["tg://voice_file_id/voice-1"], + ) + + with patch.object(chain, "message_parser", return_value=message), patch.object( + chain, "handle_message" + ) as handle_message: + chain.process(body="{}", form={}, args={"source": "telegram-test"}) + + handle_kwargs = handle_message.call_args.kwargs + self.assertEqual(handle_kwargs["text"], "") + self.assertEqual(handle_kwargs["audio_refs"], ["tg://voice_file_id/voice-1"]) + def test_image_message_routes_to_agent_even_when_global_agent_is_disabled(self): chain = MessageChain() @@ -149,6 +182,48 @@ class AgentImageSupportTest(unittest.TestCase): handle_ai_message.assert_called_once() + def test_audio_message_routes_to_agent_with_voice_reply_flag(self): + chain = 
MessageChain() + + with patch.object(chain, "load_cache", return_value={}), patch.object( + chain, "_transcribe_audio_refs", return_value="帮我推荐一部电影" + ), patch.object(chain.messagehelper, "put"), patch.object( + chain.messageoper, "add" + ), patch.object(chain, "_handle_ai_message") as handle_ai_message: + chain.handle_message( + channel=MessageChannel.Telegram, + source="telegram-test", + userid="10001", + username="tester", + text="", + audio_refs=["tg://voice_file_id/voice-1"], + ) + + handle_ai_message.assert_called_once() + self.assertEqual(handle_ai_message.call_args.kwargs["text"], "帮我推荐一部电影") + self.assertTrue(handle_ai_message.call_args.kwargs["reply_with_voice"]) + + def test_agent_send_agent_message_does_not_auto_convert_to_voice(self): + agent = MoviePilotAgent( + session_id="session-1", + user_id="user-1", + channel=MessageChannel.Telegram.value, + source="telegram-test", + username="tester", + ) + agent.reply_with_voice = True + + with patch.object( + AgentChain, "async_post_message", new_callable=AsyncMock + ) as async_post_message: + import asyncio + + asyncio.run(agent.send_agent_message("这是语音回复")) + + notification = async_post_message.await_args.args[0] + self.assertIsNone(notification.voice_path) + self.assertEqual(notification.text, "这是语音回复") + def test_slack_images_use_authenticated_data_url_download(self): chain = MessageChain() diff --git a/tests/test_voice_helper.py b/tests/test_voice_helper.py new file mode 100644 index 00000000..5159bdf8 --- /dev/null +++ b/tests/test_voice_helper.py @@ -0,0 +1,55 @@ +import unittest +import sys +from unittest.mock import Mock, patch + +sys.modules.setdefault("psutil", Mock()) +sys.modules.setdefault("pyquery", Mock()) + +from app.core.config import settings +from app.helper.voice import VoiceHelper, OpenAIVoiceProvider + + +class VoiceHelperTest(unittest.TestCase): + def test_registered_providers_contains_openai(self): + self.assertIn("openai", VoiceHelper.get_registered_providers()) + + def 
class VoiceHelperTest(unittest.TestCase):
    """Tests for VoiceHelper's provider registry and STT/TTS dispatch."""

    def test_registered_providers_contains_openai(self):
        """The OpenAI provider is registered under the name "openai"."""
        registered = VoiceHelper.get_registered_providers()
        self.assertIn("openai", registered)

    def test_get_provider_falls_back_to_global_provider(self):
        """With no STT-specific provider set, the global provider is used."""
        with patch.object(settings, "AI_VOICE_PROVIDER", "openai"):
            with patch.object(settings, "AI_VOICE_STT_PROVIDER", None):
                resolved = VoiceHelper.get_provider("stt")

        self.assertIsInstance(resolved, OpenAIVoiceProvider)

    def test_is_available_checks_stt_and_tts_separately(self):
        """Availability is reported per capability."""
        fake_provider = Mock(
            **{
                "is_available_for_stt.return_value": True,
                "is_available_for_tts.return_value": False,
            }
        )

        with patch.object(VoiceHelper, "get_provider", return_value=fake_provider):
            stt_ready = VoiceHelper.is_available("stt")
            self.assertTrue(stt_ready)
            tts_ready = VoiceHelper.is_available("tts")
            self.assertFalse(tts_ready)

    def test_transcribe_bytes_routes_to_stt_provider(self):
        """transcribe_bytes returns what the provider's transcribe_bytes returns."""
        fake_provider = Mock(**{"transcribe_bytes.return_value": "你好"})

        with patch.object(VoiceHelper, "get_provider", return_value=fake_provider):
            transcript = VoiceHelper.transcribe_bytes(b"audio")

        self.assertEqual(transcript, "你好")
        fake_provider.transcribe_bytes.assert_called_once()

    def test_synthesize_speech_routes_to_tts_provider(self):
        """synthesize_speech forwards the text as a keyword argument."""
        fake_provider = Mock(**{"synthesize_speech.return_value": "/tmp/reply.opus"})

        with patch.object(VoiceHelper, "get_provider", return_value=fake_provider):
            audio_path = VoiceHelper.synthesize_speech("你好")

        self.assertEqual(audio_path, "/tmp/reply.opus")
        fake_provider.synthesize_speech.assert_called_once_with(text="你好")


if __name__ == "__main__":
    unittest.main()