diff --git a/app/agent/__init__.py b/app/agent/__init__.py index f3701f79..fb74ed08 100644 --- a/app/agent/__init__.py +++ b/app/agent/__init__.py @@ -34,6 +34,7 @@ from app.chain import ChainBase from app.core.config import settings from app.db.transferhistory_oper import TransferHistoryOper from app.helper.llm import LLMHelper +from app.helper.voice import VoiceHelper from app.log import logger from app.schemas import Notification, NotificationType from app.schemas.message import ChannelCapabilityManager, ChannelCapability @@ -677,6 +678,22 @@ class MoviePilotAgent: """ 通过原渠道发送消息给用户 """ + voice_path = None + if ( + self.reply_with_voice + and VoiceHelper.resolve_reply_mode( + channel=self.channel, + source=self.source, + ) + == VoiceHelper.REPLY_MODE_NATIVE + and VoiceHelper.is_available("tts") + ): + # 当用户本轮发来语音且 Agent 未主动调用 send_voice_message 时, + # 这里补一层自动语音回复兜底,避免最终仍只返回纯文字。 + voice_file = await asyncio.to_thread(VoiceHelper.synthesize_speech, message) + if voice_file: + voice_path = str(voice_file) + await AgentChain().async_post_message( Notification( channel=self.channel, @@ -686,6 +703,12 @@ class MoviePilotAgent: username=self.username, title=title, text=message, + voice_path=voice_path, + voice_caption=( + message + if voice_path and settings.AI_VOICE_REPLY_WITH_TEXT + else None + ), ) ) diff --git a/app/agent/prompt/__init__.py b/app/agent/prompt/__init__.py index 814d4e0e..eb92e13e 100644 --- a/app/agent/prompt/__init__.py +++ b/app/agent/prompt/__init__.py @@ -335,7 +335,7 @@ class PromptManager: return ( "- Current message context: The user sent a voice message.\n" "- Reply preference: Prioritize calling `send_voice_message` for the main user-facing reply.\n" - "- Fallback: If voice is unavailable on the current channel, `send_voice_message` will fall back to text.\n" + "- Fallback: If native voice is unavailable on the current channel, `send_voice_message` will fall back to text.\n" "- Do not repeat the same full reply again after calling `send_voice_message`." ) diff --git a/app/agent/tools/impl/send_voice_message.py b/app/agent/tools/impl/send_voice_message.py index 6f353962..08f18b4a 100644 --- a/app/agent/tools/impl/send_voice_message.py +++ b/app/agent/tools/impl/send_voice_message.py @@ -8,10 +8,8 @@ from pydantic import BaseModel, Field from app.agent.tools.base import MoviePilotTool, ToolChain from app.core.config import settings from app.helper.voice import VoiceHelper -from app.helper.service import ServiceConfigHelper from app.log import logger from app.schemas import Notification, NotificationType -from app.schemas.types import MessageChannel class SendVoiceMessageInput(BaseModel): @@ -43,18 +41,6 @@ class SendVoiceMessageTool(MoviePilotTool): message = message[:40] + "..." return f"发送语音回复: {message}" - def _supports_real_voice_reply(self) -> bool: - channel = self._channel or "" - if channel == MessageChannel.Telegram.value: - return True - if channel != MessageChannel.Wechat.value: - return False - for config in ServiceConfigHelper.get_notification_configs(): - if config.name != self._source: - continue - return (config.config or {}).get("WECHAT_MODE", "app") != "bot" - return False - async def run(self, message: str, **kwargs) -> str: if not message: return "语音回复内容不能为空" @@ -62,11 +48,23 @@ class SendVoiceMessageTool(MoviePilotTool): voice_path = None used_voice = False channel = self._channel or "" - if self._supports_real_voice_reply() and VoiceHelper.is_available("tts"): + reply_mode = VoiceHelper.resolve_reply_mode( + channel=channel, + source=self._source, + ) + fallback_reason = "当前渠道不支持语音回复" + if not VoiceHelper.is_enabled(): + fallback_reason = "当前未启用音频输入输出" + if ( + reply_mode == VoiceHelper.REPLY_MODE_NATIVE + and VoiceHelper.is_available("tts") + ): voice_file = await asyncio.to_thread(VoiceHelper.synthesize_speech, message) if voice_file: voice_path = str(voice_file) used_voice = True + elif reply_mode == VoiceHelper.REPLY_MODE_NATIVE: + fallback_reason = "当前未配置可用的语音合成能力" logger.info( "执行工具: %s, channel=%s, use_voice=%s, text_len=%s", @@ -85,7 +83,11 @@ class SendVoiceMessageTool(MoviePilotTool): username=self._username, text=message, voice_path=voice_path, - voice_caption=message if settings.AI_VOICE_REPLY_WITH_TEXT else None, + voice_caption=( + message + if voice_path and settings.AI_VOICE_REPLY_WITH_TEXT + else None + ), ) ) self._agent_context["user_reply_sent"] = True @@ -93,4 +95,4 @@ class SendVoiceMessageTool(MoviePilotTool): if used_voice: return "语音回复已发送" - return "当前未使用语音通道,已自动回退为文字回复" + return f"{fallback_reason},已自动回退为文字回复" diff --git a/app/api/endpoints/system.py b/app/api/endpoints/system.py index 22d7bccf..fba26068 100644 --- a/app/api/endpoints/system.py +++ b/app/api/endpoints/system.py @@ -494,6 +494,7 @@ async def get_user_global_setting(_: User = Depends(get_current_active_user_asyn info = settings.model_dump( include={ "AI_AGENT_ENABLE", + "LLM_SUPPORT_AUDIO_INPUT_OUTPUT", "RECOGNIZE_SOURCE", "SEARCH_SOURCE", "AI_RECOMMEND_ENABLED", @@ -503,6 +504,7 @@ async def get_user_global_setting(_: User = Depends(get_current_active_user_asyn # 智能助手总开关未开启,智能推荐状态强制返回False if not settings.AI_AGENT_ENABLE: info["AI_RECOMMEND_ENABLED"] = False + info["LLM_SUPPORT_AUDIO_INPUT_OUTPUT"] = False # 追加用户唯一ID和订阅分享管理权限 share_admin = SubscribeHelper().is_admin_user() diff --git a/app/chain/message.py b/app/chain/message.py index 31ec9342..7e697586 100644 --- a/app/chain/message.py +++ b/app/chain/message.py @@ -202,7 +202,10 @@ class MessageChain(ChainBase): ) return - if settings.AI_AGENT_ENABLE and (settings.AI_AGENT_GLOBAL or images or files): + if ( + settings.AI_AGENT_ENABLE + and (settings.AI_AGENT_GLOBAL or images or files or reply_with_voice) + ): self._handle_ai_message( text=text, channel=channel, diff --git a/app/core/config.py b/app/core/config.py index b8e58636..e62247ca 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -509,6 +509,8 @@ class ConfigModel(BaseModel): LLM_THINKING_LEVEL: Optional[str] = 'off' # LLM是否支持图片输入,开启后消息图片会按多模态输入发送给模型 LLM_SUPPORT_IMAGE_INPUT: bool = True + # LLM是否支持音频输入输出,开启后才会启用语音转写与语音回复 + LLM_SUPPORT_AUDIO_INPUT_OUTPUT: bool = False # LLM API密钥 LLM_API_KEY: Optional[str] = None # LLM基础URL(用于自定义API端点) @@ -553,24 +555,12 @@ class ConfigModel(BaseModel): # AI智能体自动重试整理失败记录开关 AI_AGENT_RETRY_TRANSFER: bool = False - # 语音能力提供商(当前仅支持 openai) + # 语音能力提供商(当前仅支持 openai/openai-compatible) AI_VOICE_PROVIDER: str = "openai" - # 语音识别提供商,未设置时回退到 AI_VOICE_PROVIDER - AI_VOICE_STT_PROVIDER: Optional[str] = None - # 语音合成提供商,未设置时回退到 AI_VOICE_PROVIDER - AI_VOICE_TTS_PROVIDER: Optional[str] = None - # 语音能力 API 密钥,未设置且 LLM_PROVIDER=openai 时回退使用 LLM_API_KEY + # 语音能力共享 API 密钥,未设置且 LLM_PROVIDER=openai 时回退使用 LLM_API_KEY AI_VOICE_API_KEY: Optional[str] = None - # 语音识别 API 密钥,未设置时回退到 AI_VOICE_API_KEY - AI_VOICE_STT_API_KEY: Optional[str] = None - # 语音合成 API 密钥,未设置时回退到 AI_VOICE_API_KEY - AI_VOICE_TTS_API_KEY: Optional[str] = None - # 语音能力基础URL,未设置且 LLM_PROVIDER=openai 时回退使用 LLM_BASE_URL + # 语音能力共享基础URL,未设置且 LLM_PROVIDER=openai 时回退使用 LLM_BASE_URL AI_VOICE_BASE_URL: Optional[str] = None - # 语音识别基础URL,未设置时回退到 AI_VOICE_BASE_URL - AI_VOICE_STT_BASE_URL: Optional[str] = None - # 语音合成基础URL,未设置时回退到 AI_VOICE_BASE_URL - AI_VOICE_TTS_BASE_URL: Optional[str] = None # 语音转文字模型 AI_VOICE_STT_MODEL: str = "gpt-4o-mini-transcribe" # 文字转语音模型 diff --git a/app/helper/voice.py b/app/helper/voice.py index a05f6236..59b132e6 100644 --- a/app/helper/voice.py +++ b/app/helper/voice.py @@ -44,26 +44,17 @@ class OpenAIVoiceProvider(VoiceProvider): def name(self) -> str: return "openai" + @staticmethod + def _resolve_provider_name() -> str: + provider = settings.AI_VOICE_PROVIDER or "openai" + return provider.strip().lower() + @staticmethod def _resolve_credentials(mode: str) -> tuple[Optional[str], Optional[str]]: mode = mode.lower() - provider = ( - settings.AI_VOICE_STT_PROVIDER - if mode == "stt" - else settings.AI_VOICE_TTS_PROVIDER - ) or settings.AI_VOICE_PROVIDER - provider = (provider or "").strip().lower() - - api_key = ( - settings.AI_VOICE_STT_API_KEY - if mode == "stt" - else settings.AI_VOICE_TTS_API_KEY - ) or settings.AI_VOICE_API_KEY - base_url = ( - settings.AI_VOICE_STT_BASE_URL - if mode == "stt" - else settings.AI_VOICE_TTS_BASE_URL - ) or settings.AI_VOICE_BASE_URL + provider = OpenAIVoiceProvider._resolve_provider_name() + api_key = settings.AI_VOICE_API_KEY + base_url = settings.AI_VOICE_BASE_URL if ( not api_key @@ -136,25 +127,27 @@ class OpenAIVoiceProvider(VoiceProvider): class VoiceHelper: - """统一语音入口,负责按 STT/TTS provider 路由。""" + """统一语音入口,负责音频能力判断与 STT/TTS provider 路由。""" _providers: Dict[str, VoiceProvider] = { "openai": OpenAIVoiceProvider(), } + REPLY_MODE_NATIVE = "native_voice" + REPLY_MODE_TEXT = "text" @classmethod def register_provider(cls, provider: VoiceProvider) -> None: cls._providers[provider.name.lower()] = provider + @staticmethod + def is_enabled() -> bool: + """音频输入输出总开关,以显式配置为准。""" + return bool(settings.LLM_SUPPORT_AUDIO_INPUT_OUTPUT) + @staticmethod def _resolve_provider_name(mode: str) -> str: - mode = mode.lower() - provider = ( - settings.AI_VOICE_STT_PROVIDER - if mode == "stt" - else settings.AI_VOICE_TTS_PROVIDER - ) or settings.AI_VOICE_PROVIDER - return (provider or "openai").strip().lower() + del mode + return OpenAIVoiceProvider._resolve_provider_name() @classmethod def get_provider(cls, mode: str) -> Optional[VoiceProvider]: @@ -171,6 +164,8 @@ class VoiceHelper: @classmethod def is_available(cls, mode: Optional[str] = None) -> bool: + if not cls.is_enabled(): + return False if mode: provider = cls.get_provider(mode) if not provider: @@ -182,8 +177,49 @@ class VoiceHelper: ) return cls.is_available("stt") or cls.is_available("tts") + @classmethod + def supports_native_voice_reply( + cls, channel: Optional[str], source: Optional[str] + ) -> bool: + """ + 判断当前渠道是否支持原生语音消息发送。 + """ + if not channel: + return False + + from app.helper.service import ServiceConfigHelper + from app.schemas.types import MessageChannel + + try: + channel_enum = MessageChannel(channel) + except (TypeError, ValueError): + return False + + if channel_enum == MessageChannel.Telegram: + return True + if channel_enum != MessageChannel.Wechat: + return False + + # 企业微信 bot 模式不支持发送语音,只有应用模式可用。 + for config in ServiceConfigHelper.get_notification_configs(): + if config.name != source: + continue + return (config.config or {}).get("WECHAT_MODE", "app") != "bot" + return False + + @classmethod + def resolve_reply_mode(cls, channel: Optional[str], source: Optional[str]) -> str: + """ + 仅在支持原生语音回复的渠道上发送音频,其余渠道统一回退文字。 + """ + if cls.supports_native_voice_reply(channel=channel, source=source): + return cls.REPLY_MODE_NATIVE + return cls.REPLY_MODE_TEXT + @classmethod def transcribe_bytes(cls, content: bytes, filename: str = "input.ogg") -> Optional[str]: + if not cls.is_enabled(): + return None provider = cls.get_provider("stt") if not provider: return None @@ -191,6 +227,8 @@ class VoiceHelper: @classmethod def synthesize_speech(cls, text: str) -> Optional[Path]: + if not cls.is_enabled(): + return None provider = cls.get_provider("tts") if not provider: return None diff --git a/tests/test_agent_image_support.py b/tests/test_agent_image_support.py index 6e15efa8..f9b43d57 100644 --- a/tests/test_agent_image_support.py +++ b/tests/test_agent_image_support.py @@ -226,14 +226,17 @@ class AgentImageSupportTest(unittest.TestCase): ), patch.object(chain.messagehelper, "put"), patch.object( chain.messageoper, "add" ), patch.object(chain, "_handle_ai_message") as handle_ai_message: - chain.handle_message( - channel=MessageChannel.Telegram, - source="telegram-test", - userid="10001", - username="tester", - text="", - audio_refs=["tg://voice_file_id/voice-1"], - ) + with patch.object(settings, "AI_AGENT_ENABLE", True), patch.object( + settings, "AI_AGENT_GLOBAL", False + ): + chain.handle_message( + channel=MessageChannel.Telegram, + source="telegram-test", + userid="10001", + username="tester", + text="", + audio_refs=["tg://voice_file_id/voice-1"], + ) handle_ai_message.assert_called_once() self.assertEqual(handle_ai_message.call_args.kwargs["text"], "帮我推荐一部电影") @@ -319,7 +322,7 @@ class AgentImageSupportTest(unittest.TestCase): ], ) - def test_agent_send_agent_message_does_not_auto_convert_to_voice(self): + def test_agent_send_agent_message_auto_converts_to_voice_when_supported(self): agent = MoviePilotAgent( session_id="session-1", user_id="user-1", @@ -330,6 +333,14 @@ class AgentImageSupportTest(unittest.TestCase): agent.reply_with_voice = True with patch.object( + VoiceHelper, + "resolve_reply_mode", + return_value=VoiceHelper.REPLY_MODE_NATIVE, + ), patch.object( + VoiceHelper, "is_available", return_value=True + ), patch.object( + VoiceHelper, "synthesize_speech", return_value=Path("/tmp/reply.opus") + ), patch.object( AgentChain, "async_post_message", new_callable=AsyncMock ) as async_post_message: import asyncio @@ -337,7 +348,7 @@ class AgentImageSupportTest(unittest.TestCase): asyncio.run(agent.send_agent_message("这是语音回复")) notification = async_post_message.await_args.args[0] - self.assertIsNone(notification.voice_path) + self.assertEqual(notification.voice_path, "/tmp/reply.opus") self.assertEqual(notification.text, "这是语音回复") def test_agent_process_wraps_request_as_structured_json(self): diff --git a/tests/test_voice_helper.py b/tests/test_voice_helper.py index 5159bdf8..832eb5a9 100644 --- a/tests/test_voice_helper.py +++ b/tests/test_voice_helper.py @@ -13,10 +13,8 @@ class VoiceHelperTest(unittest.TestCase): def test_registered_providers_contains_openai(self): self.assertIn("openai", VoiceHelper.get_registered_providers()) - def test_get_provider_falls_back_to_global_provider(self): - with patch.object(settings, "AI_VOICE_PROVIDER", "openai"), patch.object( - settings, "AI_VOICE_STT_PROVIDER", None - ): + def test_get_provider_uses_single_audio_provider_setting(self): + with patch.object(settings, "AI_VOICE_PROVIDER", "openai"): provider = VoiceHelper.get_provider("stt") self.assertIsInstance(provider, OpenAIVoiceProvider) @@ -26,15 +24,29 @@ class VoiceHelperTest(unittest.TestCase): provider.is_available_for_stt.return_value = True provider.is_available_for_tts.return_value = False - with patch.object(VoiceHelper, "get_provider", return_value=provider): + with patch.object( + settings, "LLM_SUPPORT_AUDIO_INPUT_OUTPUT", True + ), patch.object(VoiceHelper, "get_provider", return_value=provider): self.assertTrue(VoiceHelper.is_available("stt")) self.assertFalse(VoiceHelper.is_available("tts")) + def test_is_available_returns_false_when_audio_switch_is_disabled(self): + provider = Mock() + provider.is_available_for_stt.return_value = True + + with patch.object( + settings, "LLM_SUPPORT_AUDIO_INPUT_OUTPUT", False + ), patch.object(VoiceHelper, "get_provider", return_value=provider): + self.assertFalse(VoiceHelper.is_available("stt")) + self.assertFalse(VoiceHelper.is_available()) + def test_transcribe_bytes_routes_to_stt_provider(self): provider = Mock() provider.transcribe_bytes.return_value = "你好" - with patch.object(VoiceHelper, "get_provider", return_value=provider): + with patch.object( + settings, "LLM_SUPPORT_AUDIO_INPUT_OUTPUT", True + ), patch.object(VoiceHelper, "get_provider", return_value=provider): result = VoiceHelper.transcribe_bytes(b"audio") self.assertEqual(result, "你好") @@ -44,7 +56,9 @@ class VoiceHelperTest(unittest.TestCase): provider = Mock() provider.synthesize_speech.return_value = "/tmp/reply.opus" - with patch.object(VoiceHelper, "get_provider", return_value=provider): + with patch.object( + settings, "LLM_SUPPORT_AUDIO_INPUT_OUTPUT", True + ), patch.object(VoiceHelper, "get_provider", return_value=provider): result = VoiceHelper.synthesize_speech("你好") self.assertEqual(result, "/tmp/reply.opus")