重构语音能力配置与逻辑,统一音频输入输出开关并优化语音回复判断

This commit is contained in:
jxxghp
2026-04-29 18:15:34 +08:00
parent e4a7333b79
commit b7749c44fd
9 changed files with 159 additions and 76 deletions

View File

@@ -34,6 +34,7 @@ from app.chain import ChainBase
from app.core.config import settings
from app.db.transferhistory_oper import TransferHistoryOper
from app.helper.llm import LLMHelper
from app.helper.voice import VoiceHelper
from app.log import logger
from app.schemas import Notification, NotificationType
from app.schemas.message import ChannelCapabilityManager, ChannelCapability
@@ -677,6 +678,22 @@ class MoviePilotAgent:
"""
通过原渠道发送消息给用户
"""
voice_path = None
if (
self.reply_with_voice
and VoiceHelper.resolve_reply_mode(
channel=self.channel,
source=self.source,
)
== VoiceHelper.REPLY_MODE_NATIVE
and VoiceHelper.is_available("tts")
):
# 当用户本轮发来语音且 Agent 未主动调用 send_voice_message 时,
# 这里补一层自动语音回复兜底,避免最终仍只返回纯文字。
voice_file = await asyncio.to_thread(VoiceHelper.synthesize_speech, message)
if voice_file:
voice_path = str(voice_file)
await AgentChain().async_post_message(
Notification(
channel=self.channel,
@@ -686,6 +703,12 @@ class MoviePilotAgent:
username=self.username,
title=title,
text=message,
voice_path=voice_path,
voice_caption=(
message
if voice_path and settings.AI_VOICE_REPLY_WITH_TEXT
else None
),
)
)

View File

@@ -335,7 +335,7 @@ class PromptManager:
return (
"- Current message context: The user sent a voice message.\n"
"- Reply preference: Prioritize calling `send_voice_message` for the main user-facing reply.\n"
"- Fallback: If voice is unavailable on the current channel, `send_voice_message` will fall back to text.\n"
"- Fallback: If native voice is unavailable on the current channel, `send_voice_message` will fall back to text.\n"
"- Do not repeat the same full reply again after calling `send_voice_message`."
)

View File

@@ -8,10 +8,8 @@ from pydantic import BaseModel, Field
from app.agent.tools.base import MoviePilotTool, ToolChain
from app.core.config import settings
from app.helper.voice import VoiceHelper
from app.helper.service import ServiceConfigHelper
from app.log import logger
from app.schemas import Notification, NotificationType
from app.schemas.types import MessageChannel
class SendVoiceMessageInput(BaseModel):
@@ -43,18 +41,6 @@ class SendVoiceMessageTool(MoviePilotTool):
message = message[:40] + "..."
return f"发送语音回复: {message}"
def _supports_real_voice_reply(self) -> bool:
    """Return True when the current channel can deliver a native voice reply.

    Telegram always supports voice. WeChat supports it only when the
    notification config matching this message source is not running in
    "bot" mode. Every other channel falls back to text.
    """
    current = self._channel or ""
    if current == MessageChannel.Telegram.value:
        return True
    if current != MessageChannel.Wechat.value:
        return False
    # Find the notification config for this source and inspect its
    # WeChat mode (defaults to "app" when unset).
    matched = next(
        (
            cfg
            for cfg in ServiceConfigHelper.get_notification_configs()
            if cfg.name == self._source
        ),
        None,
    )
    if matched is None:
        return False
    return (matched.config or {}).get("WECHAT_MODE", "app") != "bot"
async def run(self, message: str, **kwargs) -> str:
if not message:
return "语音回复内容不能为空"
@@ -62,11 +48,23 @@ class SendVoiceMessageTool(MoviePilotTool):
voice_path = None
used_voice = False
channel = self._channel or ""
if self._supports_real_voice_reply() and VoiceHelper.is_available("tts"):
reply_mode = VoiceHelper.resolve_reply_mode(
channel=channel,
source=self._source,
)
fallback_reason = "当前渠道不支持语音回复"
if not VoiceHelper.is_enabled():
fallback_reason = "当前未启用音频输入输出"
if (
reply_mode == VoiceHelper.REPLY_MODE_NATIVE
and VoiceHelper.is_available("tts")
):
voice_file = await asyncio.to_thread(VoiceHelper.synthesize_speech, message)
if voice_file:
voice_path = str(voice_file)
used_voice = True
elif reply_mode == VoiceHelper.REPLY_MODE_NATIVE:
fallback_reason = "当前未配置可用的语音合成能力"
logger.info(
"执行工具: %s, channel=%s, use_voice=%s, text_len=%s",
@@ -85,7 +83,11 @@ class SendVoiceMessageTool(MoviePilotTool):
username=self._username,
text=message,
voice_path=voice_path,
voice_caption=message if settings.AI_VOICE_REPLY_WITH_TEXT else None,
voice_caption=(
message
if voice_path and settings.AI_VOICE_REPLY_WITH_TEXT
else None
),
)
)
self._agent_context["user_reply_sent"] = True
@@ -93,4 +95,4 @@ class SendVoiceMessageTool(MoviePilotTool):
if used_voice:
return "语音回复已发送"
return "当前未使用语音通道,已自动回退为文字回复"
return f"{fallback_reason},已自动回退为文字回复"

View File

@@ -494,6 +494,7 @@ async def get_user_global_setting(_: User = Depends(get_current_active_user_asyn
info = settings.model_dump(
include={
"AI_AGENT_ENABLE",
"LLM_SUPPORT_AUDIO_INPUT_OUTPUT",
"RECOGNIZE_SOURCE",
"SEARCH_SOURCE",
"AI_RECOMMEND_ENABLED",
@@ -503,6 +504,7 @@ async def get_user_global_setting(_: User = Depends(get_current_active_user_asyn
# 智能助手总开关未开启智能推荐状态强制返回False
if not settings.AI_AGENT_ENABLE:
info["AI_RECOMMEND_ENABLED"] = False
info["LLM_SUPPORT_AUDIO_INPUT_OUTPUT"] = False
# 追加用户唯一ID和订阅分享管理权限
share_admin = SubscribeHelper().is_admin_user()

View File

@@ -202,7 +202,10 @@ class MessageChain(ChainBase):
)
return
if settings.AI_AGENT_ENABLE and (settings.AI_AGENT_GLOBAL or images or files):
if (
settings.AI_AGENT_ENABLE
and (settings.AI_AGENT_GLOBAL or images or files or reply_with_voice)
):
self._handle_ai_message(
text=text,
channel=channel,

View File

@@ -509,6 +509,8 @@ class ConfigModel(BaseModel):
LLM_THINKING_LEVEL: Optional[str] = 'off'
# LLM是否支持图片输入开启后消息图片会按多模态输入发送给模型
LLM_SUPPORT_IMAGE_INPUT: bool = True
# LLM是否支持音频输入输出开启后才会启用语音转写与语音回复
LLM_SUPPORT_AUDIO_INPUT_OUTPUT: bool = False
# LLM API密钥
LLM_API_KEY: Optional[str] = None
# LLM基础URL用于自定义API端点
@@ -553,24 +555,12 @@ class ConfigModel(BaseModel):
# AI智能体自动重试整理失败记录开关
AI_AGENT_RETRY_TRANSFER: bool = False
# 语音能力提供商(当前仅支持 openai
# 语音能力提供商(当前仅支持 openai/openai-compatible
AI_VOICE_PROVIDER: str = "openai"
# 语音识别提供商,未设置时回退到 AI_VOICE_PROVIDER
AI_VOICE_STT_PROVIDER: Optional[str] = None
# 语音合成提供商,未设置时回退到 AI_VOICE_PROVIDER
AI_VOICE_TTS_PROVIDER: Optional[str] = None
# 语音能力 API 密钥,未设置且 LLM_PROVIDER=openai 时回退使用 LLM_API_KEY
# 语音能力共享 API 密钥,未设置且 LLM_PROVIDER=openai 时回退使用 LLM_API_KEY
AI_VOICE_API_KEY: Optional[str] = None
# 语音识别 API 密钥,未设置时回退到 AI_VOICE_API_KEY
AI_VOICE_STT_API_KEY: Optional[str] = None
# 语音合成 API 密钥,未设置时回退到 AI_VOICE_API_KEY
AI_VOICE_TTS_API_KEY: Optional[str] = None
# 语音能力基础URL未设置且 LLM_PROVIDER=openai 时回退使用 LLM_BASE_URL
# 语音能力共享基础URL未设置且 LLM_PROVIDER=openai 时回退使用 LLM_BASE_URL
AI_VOICE_BASE_URL: Optional[str] = None
# 语音识别基础URL未设置时回退到 AI_VOICE_BASE_URL
AI_VOICE_STT_BASE_URL: Optional[str] = None
# 语音合成基础URL未设置时回退到 AI_VOICE_BASE_URL
AI_VOICE_TTS_BASE_URL: Optional[str] = None
# 语音转文字模型
AI_VOICE_STT_MODEL: str = "gpt-4o-mini-transcribe"
# 文字转语音模型

View File

@@ -44,26 +44,17 @@ class OpenAIVoiceProvider(VoiceProvider):
def name(self) -> str:
return "openai"
@staticmethod
def _resolve_provider_name() -> str:
provider = settings.AI_VOICE_PROVIDER or "openai"
return provider.strip().lower()
@staticmethod
def _resolve_credentials(mode: str) -> tuple[Optional[str], Optional[str]]:
mode = mode.lower()
provider = (
settings.AI_VOICE_STT_PROVIDER
if mode == "stt"
else settings.AI_VOICE_TTS_PROVIDER
) or settings.AI_VOICE_PROVIDER
provider = (provider or "").strip().lower()
api_key = (
settings.AI_VOICE_STT_API_KEY
if mode == "stt"
else settings.AI_VOICE_TTS_API_KEY
) or settings.AI_VOICE_API_KEY
base_url = (
settings.AI_VOICE_STT_BASE_URL
if mode == "stt"
else settings.AI_VOICE_TTS_BASE_URL
) or settings.AI_VOICE_BASE_URL
provider = OpenAIVoiceProvider._resolve_provider_name()
api_key = settings.AI_VOICE_API_KEY
base_url = settings.AI_VOICE_BASE_URL
if (
not api_key
@@ -136,25 +127,27 @@ class OpenAIVoiceProvider(VoiceProvider):
class VoiceHelper:
"""统一语音入口,负责 STT/TTS provider 路由。"""
"""统一语音入口,负责音频能力判断与 STT/TTS provider 路由。"""
_providers: Dict[str, VoiceProvider] = {
"openai": OpenAIVoiceProvider(),
}
REPLY_MODE_NATIVE = "native_voice"
REPLY_MODE_TEXT = "text"
@classmethod
def register_provider(cls, provider: VoiceProvider) -> None:
cls._providers[provider.name.lower()] = provider
@staticmethod
def is_enabled() -> bool:
"""音频输入输出总开关,以显式配置为准。"""
return bool(settings.LLM_SUPPORT_AUDIO_INPUT_OUTPUT)
@staticmethod
def _resolve_provider_name(mode: str) -> str:
mode = mode.lower()
provider = (
settings.AI_VOICE_STT_PROVIDER
if mode == "stt"
else settings.AI_VOICE_TTS_PROVIDER
) or settings.AI_VOICE_PROVIDER
return (provider or "openai").strip().lower()
del mode
return OpenAIVoiceProvider._resolve_provider_name()
@classmethod
def get_provider(cls, mode: str) -> Optional[VoiceProvider]:
@@ -171,6 +164,8 @@ class VoiceHelper:
@classmethod
def is_available(cls, mode: Optional[str] = None) -> bool:
if not cls.is_enabled():
return False
if mode:
provider = cls.get_provider(mode)
if not provider:
@@ -182,8 +177,49 @@ class VoiceHelper:
)
return cls.is_available("stt") or cls.is_available("tts")
@classmethod
def supports_native_voice_reply(
    cls, channel: Optional[str], source: Optional[str]
) -> bool:
    """Tell whether the given channel can send a native voice message.

    Telegram always can; WeChat can only when the notification config
    matching ``source`` is not in "bot" mode; all other channels cannot.
    """
    if not channel:
        return False
    # Deferred imports — presumably to avoid import cycles; confirm before
    # hoisting them to module level.
    from app.helper.service import ServiceConfigHelper
    from app.schemas.types import MessageChannel

    try:
        resolved = MessageChannel(channel)
    except (TypeError, ValueError):
        # Unknown channel identifier: no native voice support.
        return False
    if resolved == MessageChannel.Telegram:
        return True
    if resolved != MessageChannel.Wechat:
        return False
    # Enterprise WeChat "bot" mode cannot send voice; only "app" mode can.
    matching = [
        cfg
        for cfg in ServiceConfigHelper.get_notification_configs()
        if cfg.name == source
    ]
    if not matching:
        return False
    return (matching[0].config or {}).get("WECHAT_MODE", "app") != "bot"
@classmethod
def resolve_reply_mode(cls, channel: Optional[str], source: Optional[str]) -> str:
    """Pick the reply mode for a channel: native audio where the channel
    supports it, plain text everywhere else."""
    native = cls.supports_native_voice_reply(channel=channel, source=source)
    return cls.REPLY_MODE_NATIVE if native else cls.REPLY_MODE_TEXT
@classmethod
def transcribe_bytes(cls, content: bytes, filename: str = "input.ogg") -> Optional[str]:
if not cls.is_enabled():
return None
provider = cls.get_provider("stt")
if not provider:
return None
@@ -191,6 +227,8 @@ class VoiceHelper:
@classmethod
def synthesize_speech(cls, text: str) -> Optional[Path]:
if not cls.is_enabled():
return None
provider = cls.get_provider("tts")
if not provider:
return None

View File

@@ -226,14 +226,17 @@ class AgentImageSupportTest(unittest.TestCase):
), patch.object(chain.messagehelper, "put"), patch.object(
chain.messageoper, "add"
), patch.object(chain, "_handle_ai_message") as handle_ai_message:
chain.handle_message(
channel=MessageChannel.Telegram,
source="telegram-test",
userid="10001",
username="tester",
text="",
audio_refs=["tg://voice_file_id/voice-1"],
)
with patch.object(settings, "AI_AGENT_ENABLE", True), patch.object(
settings, "AI_AGENT_GLOBAL", False
):
chain.handle_message(
channel=MessageChannel.Telegram,
source="telegram-test",
userid="10001",
username="tester",
text="",
audio_refs=["tg://voice_file_id/voice-1"],
)
handle_ai_message.assert_called_once()
self.assertEqual(handle_ai_message.call_args.kwargs["text"], "帮我推荐一部电影")
@@ -319,7 +322,7 @@ class AgentImageSupportTest(unittest.TestCase):
],
)
def test_agent_send_agent_message_does_not_auto_convert_to_voice(self):
def test_agent_send_agent_message_auto_converts_to_voice_when_supported(self):
agent = MoviePilotAgent(
session_id="session-1",
user_id="user-1",
@@ -330,6 +333,14 @@ class AgentImageSupportTest(unittest.TestCase):
agent.reply_with_voice = True
with patch.object(
VoiceHelper,
"resolve_reply_mode",
return_value=VoiceHelper.REPLY_MODE_NATIVE,
), patch.object(
VoiceHelper, "is_available", return_value=True
), patch.object(
VoiceHelper, "synthesize_speech", return_value=Path("/tmp/reply.opus")
), patch.object(
AgentChain, "async_post_message", new_callable=AsyncMock
) as async_post_message:
import asyncio
@@ -337,7 +348,7 @@ class AgentImageSupportTest(unittest.TestCase):
asyncio.run(agent.send_agent_message("这是语音回复"))
notification = async_post_message.await_args.args[0]
self.assertIsNone(notification.voice_path)
self.assertEqual(notification.voice_path, "/tmp/reply.opus")
self.assertEqual(notification.text, "这是语音回复")
def test_agent_process_wraps_request_as_structured_json(self):

View File

@@ -13,10 +13,8 @@ class VoiceHelperTest(unittest.TestCase):
def test_registered_providers_contains_openai(self):
    """The built-in openai provider must always be registered."""
    registered = VoiceHelper.get_registered_providers()
    self.assertIn("openai", registered)
def test_get_provider_falls_back_to_global_provider(self):
with patch.object(settings, "AI_VOICE_PROVIDER", "openai"), patch.object(
settings, "AI_VOICE_STT_PROVIDER", None
):
def test_get_provider_uses_single_audio_provider_setting(self):
    """STT provider lookup is driven by the single AI_VOICE_PROVIDER value."""
    with patch.object(settings, "AI_VOICE_PROVIDER", "openai"):
        resolved = VoiceHelper.get_provider("stt")
        self.assertIsInstance(resolved, OpenAIVoiceProvider)
@@ -26,15 +24,29 @@ class VoiceHelperTest(unittest.TestCase):
provider.is_available_for_stt.return_value = True
provider.is_available_for_tts.return_value = False
with patch.object(VoiceHelper, "get_provider", return_value=provider):
with patch.object(
settings, "LLM_SUPPORT_AUDIO_INPUT_OUTPUT", True
), patch.object(VoiceHelper, "get_provider", return_value=provider):
self.assertTrue(VoiceHelper.is_available("stt"))
self.assertFalse(VoiceHelper.is_available("tts"))
def test_is_available_returns_false_when_audio_switch_is_disabled(self):
    """With the audio master switch off, availability is always False."""
    fake_provider = Mock()
    fake_provider.is_available_for_stt.return_value = True
    with patch.object(
        settings, "LLM_SUPPORT_AUDIO_INPUT_OUTPUT", False
    ), patch.object(VoiceHelper, "get_provider", return_value=fake_provider):
        # Even a fully capable provider must be ignored when disabled.
        self.assertFalse(VoiceHelper.is_available("stt"))
        self.assertFalse(VoiceHelper.is_available())
def test_transcribe_bytes_routes_to_stt_provider(self):
provider = Mock()
provider.transcribe_bytes.return_value = "你好"
with patch.object(VoiceHelper, "get_provider", return_value=provider):
with patch.object(
settings, "LLM_SUPPORT_AUDIO_INPUT_OUTPUT", True
), patch.object(VoiceHelper, "get_provider", return_value=provider):
result = VoiceHelper.transcribe_bytes(b"audio")
self.assertEqual(result, "你好")
@@ -44,7 +56,9 @@ class VoiceHelperTest(unittest.TestCase):
provider = Mock()
provider.synthesize_speech.return_value = "/tmp/reply.opus"
with patch.object(VoiceHelper, "get_provider", return_value=provider):
with patch.object(
settings, "LLM_SUPPORT_AUDIO_INPUT_OUTPUT", True
), patch.object(VoiceHelper, "get_provider", return_value=provider):
result = VoiceHelper.synthesize_speech("你好")
self.assertEqual(result, "/tmp/reply.opus")