mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-05-09 23:02:09 +08:00
重构语音能力配置与逻辑,统一音频输入输出开关并优化语音回复判断
This commit is contained in:
@@ -34,6 +34,7 @@ from app.chain import ChainBase
|
||||
from app.core.config import settings
|
||||
from app.db.transferhistory_oper import TransferHistoryOper
|
||||
from app.helper.llm import LLMHelper
|
||||
from app.helper.voice import VoiceHelper
|
||||
from app.log import logger
|
||||
from app.schemas import Notification, NotificationType
|
||||
from app.schemas.message import ChannelCapabilityManager, ChannelCapability
|
||||
@@ -677,6 +678,22 @@ class MoviePilotAgent:
|
||||
"""
|
||||
通过原渠道发送消息给用户
|
||||
"""
|
||||
voice_path = None
|
||||
if (
|
||||
self.reply_with_voice
|
||||
and VoiceHelper.resolve_reply_mode(
|
||||
channel=self.channel,
|
||||
source=self.source,
|
||||
)
|
||||
== VoiceHelper.REPLY_MODE_NATIVE
|
||||
and VoiceHelper.is_available("tts")
|
||||
):
|
||||
# 当用户本轮发来语音且 Agent 未主动调用 send_voice_message 时,
|
||||
# 这里补一层自动语音回复兜底,避免最终仍只返回纯文字。
|
||||
voice_file = await asyncio.to_thread(VoiceHelper.synthesize_speech, message)
|
||||
if voice_file:
|
||||
voice_path = str(voice_file)
|
||||
|
||||
await AgentChain().async_post_message(
|
||||
Notification(
|
||||
channel=self.channel,
|
||||
@@ -686,6 +703,12 @@ class MoviePilotAgent:
|
||||
username=self.username,
|
||||
title=title,
|
||||
text=message,
|
||||
voice_path=voice_path,
|
||||
voice_caption=(
|
||||
message
|
||||
if voice_path and settings.AI_VOICE_REPLY_WITH_TEXT
|
||||
else None
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -335,7 +335,7 @@ class PromptManager:
|
||||
return (
|
||||
"- Current message context: The user sent a voice message.\n"
|
||||
"- Reply preference: Prioritize calling `send_voice_message` for the main user-facing reply.\n"
|
||||
"- Fallback: If voice is unavailable on the current channel, `send_voice_message` will fall back to text.\n"
|
||||
"- Fallback: If native voice is unavailable on the current channel, `send_voice_message` will fall back to text.\n"
|
||||
"- Do not repeat the same full reply again after calling `send_voice_message`."
|
||||
)
|
||||
|
||||
|
||||
@@ -8,10 +8,8 @@ from pydantic import BaseModel, Field
|
||||
from app.agent.tools.base import MoviePilotTool, ToolChain
|
||||
from app.core.config import settings
|
||||
from app.helper.voice import VoiceHelper
|
||||
from app.helper.service import ServiceConfigHelper
|
||||
from app.log import logger
|
||||
from app.schemas import Notification, NotificationType
|
||||
from app.schemas.types import MessageChannel
|
||||
|
||||
|
||||
class SendVoiceMessageInput(BaseModel):
|
||||
@@ -43,18 +41,6 @@ class SendVoiceMessageTool(MoviePilotTool):
|
||||
message = message[:40] + "..."
|
||||
return f"发送语音回复: {message}"
|
||||
|
||||
def _supports_real_voice_reply(self) -> bool:
|
||||
channel = self._channel or ""
|
||||
if channel == MessageChannel.Telegram.value:
|
||||
return True
|
||||
if channel != MessageChannel.Wechat.value:
|
||||
return False
|
||||
for config in ServiceConfigHelper.get_notification_configs():
|
||||
if config.name != self._source:
|
||||
continue
|
||||
return (config.config or {}).get("WECHAT_MODE", "app") != "bot"
|
||||
return False
|
||||
|
||||
async def run(self, message: str, **kwargs) -> str:
|
||||
if not message:
|
||||
return "语音回复内容不能为空"
|
||||
@@ -62,11 +48,23 @@ class SendVoiceMessageTool(MoviePilotTool):
|
||||
voice_path = None
|
||||
used_voice = False
|
||||
channel = self._channel or ""
|
||||
if self._supports_real_voice_reply() and VoiceHelper.is_available("tts"):
|
||||
reply_mode = VoiceHelper.resolve_reply_mode(
|
||||
channel=channel,
|
||||
source=self._source,
|
||||
)
|
||||
fallback_reason = "当前渠道不支持语音回复"
|
||||
if not VoiceHelper.is_enabled():
|
||||
fallback_reason = "当前未启用音频输入输出"
|
||||
if (
|
||||
reply_mode == VoiceHelper.REPLY_MODE_NATIVE
|
||||
and VoiceHelper.is_available("tts")
|
||||
):
|
||||
voice_file = await asyncio.to_thread(VoiceHelper.synthesize_speech, message)
|
||||
if voice_file:
|
||||
voice_path = str(voice_file)
|
||||
used_voice = True
|
||||
elif reply_mode == VoiceHelper.REPLY_MODE_NATIVE:
|
||||
fallback_reason = "当前未配置可用的语音合成能力"
|
||||
|
||||
logger.info(
|
||||
"执行工具: %s, channel=%s, use_voice=%s, text_len=%s",
|
||||
@@ -85,7 +83,11 @@ class SendVoiceMessageTool(MoviePilotTool):
|
||||
username=self._username,
|
||||
text=message,
|
||||
voice_path=voice_path,
|
||||
voice_caption=message if settings.AI_VOICE_REPLY_WITH_TEXT else None,
|
||||
voice_caption=(
|
||||
message
|
||||
if voice_path and settings.AI_VOICE_REPLY_WITH_TEXT
|
||||
else None
|
||||
),
|
||||
)
|
||||
)
|
||||
self._agent_context["user_reply_sent"] = True
|
||||
@@ -93,4 +95,4 @@ class SendVoiceMessageTool(MoviePilotTool):
|
||||
|
||||
if used_voice:
|
||||
return "语音回复已发送"
|
||||
return "当前未使用语音通道,已自动回退为文字回复"
|
||||
return f"{fallback_reason},已自动回退为文字回复"
|
||||
|
||||
@@ -494,6 +494,7 @@ async def get_user_global_setting(_: User = Depends(get_current_active_user_asyn
|
||||
info = settings.model_dump(
|
||||
include={
|
||||
"AI_AGENT_ENABLE",
|
||||
"LLM_SUPPORT_AUDIO_INPUT_OUTPUT",
|
||||
"RECOGNIZE_SOURCE",
|
||||
"SEARCH_SOURCE",
|
||||
"AI_RECOMMEND_ENABLED",
|
||||
@@ -503,6 +504,7 @@ async def get_user_global_setting(_: User = Depends(get_current_active_user_asyn
|
||||
# 智能助手总开关未开启,智能推荐状态强制返回False
|
||||
if not settings.AI_AGENT_ENABLE:
|
||||
info["AI_RECOMMEND_ENABLED"] = False
|
||||
info["LLM_SUPPORT_AUDIO_INPUT_OUTPUT"] = False
|
||||
|
||||
# 追加用户唯一ID和订阅分享管理权限
|
||||
share_admin = SubscribeHelper().is_admin_user()
|
||||
|
||||
@@ -202,7 +202,10 @@ class MessageChain(ChainBase):
|
||||
)
|
||||
return
|
||||
|
||||
if settings.AI_AGENT_ENABLE and (settings.AI_AGENT_GLOBAL or images or files):
|
||||
if (
|
||||
settings.AI_AGENT_ENABLE
|
||||
and (settings.AI_AGENT_GLOBAL or images or files or reply_with_voice)
|
||||
):
|
||||
self._handle_ai_message(
|
||||
text=text,
|
||||
channel=channel,
|
||||
|
||||
@@ -509,6 +509,8 @@ class ConfigModel(BaseModel):
|
||||
LLM_THINKING_LEVEL: Optional[str] = 'off'
|
||||
# LLM是否支持图片输入,开启后消息图片会按多模态输入发送给模型
|
||||
LLM_SUPPORT_IMAGE_INPUT: bool = True
|
||||
# LLM是否支持音频输入输出,开启后才会启用语音转写与语音回复
|
||||
LLM_SUPPORT_AUDIO_INPUT_OUTPUT: bool = False
|
||||
# LLM API密钥
|
||||
LLM_API_KEY: Optional[str] = None
|
||||
# LLM基础URL(用于自定义API端点)
|
||||
@@ -553,24 +555,12 @@ class ConfigModel(BaseModel):
|
||||
# AI智能体自动重试整理失败记录开关
|
||||
AI_AGENT_RETRY_TRANSFER: bool = False
|
||||
|
||||
# 语音能力提供商(当前仅支持 openai)
|
||||
# 语音能力提供商(当前仅支持 openai/openai-compatible)
|
||||
AI_VOICE_PROVIDER: str = "openai"
|
||||
# 语音识别提供商,未设置时回退到 AI_VOICE_PROVIDER
|
||||
AI_VOICE_STT_PROVIDER: Optional[str] = None
|
||||
# 语音合成提供商,未设置时回退到 AI_VOICE_PROVIDER
|
||||
AI_VOICE_TTS_PROVIDER: Optional[str] = None
|
||||
# 语音能力 API 密钥,未设置且 LLM_PROVIDER=openai 时回退使用 LLM_API_KEY
|
||||
# 语音能力共享 API 密钥,未设置且 LLM_PROVIDER=openai 时回退使用 LLM_API_KEY
|
||||
AI_VOICE_API_KEY: Optional[str] = None
|
||||
# 语音识别 API 密钥,未设置时回退到 AI_VOICE_API_KEY
|
||||
AI_VOICE_STT_API_KEY: Optional[str] = None
|
||||
# 语音合成 API 密钥,未设置时回退到 AI_VOICE_API_KEY
|
||||
AI_VOICE_TTS_API_KEY: Optional[str] = None
|
||||
# 语音能力基础URL,未设置且 LLM_PROVIDER=openai 时回退使用 LLM_BASE_URL
|
||||
# 语音能力共享基础URL,未设置且 LLM_PROVIDER=openai 时回退使用 LLM_BASE_URL
|
||||
AI_VOICE_BASE_URL: Optional[str] = None
|
||||
# 语音识别基础URL,未设置时回退到 AI_VOICE_BASE_URL
|
||||
AI_VOICE_STT_BASE_URL: Optional[str] = None
|
||||
# 语音合成基础URL,未设置时回退到 AI_VOICE_BASE_URL
|
||||
AI_VOICE_TTS_BASE_URL: Optional[str] = None
|
||||
# 语音转文字模型
|
||||
AI_VOICE_STT_MODEL: str = "gpt-4o-mini-transcribe"
|
||||
# 文字转语音模型
|
||||
|
||||
@@ -44,26 +44,17 @@ class OpenAIVoiceProvider(VoiceProvider):
|
||||
def name(self) -> str:
|
||||
return "openai"
|
||||
|
||||
@staticmethod
def _resolve_provider_name() -> str:
    """
    Return the normalised name of the configured voice provider.

    Reads ``settings.AI_VOICE_PROVIDER``, trims surrounding whitespace and
    lower-cases it. An unset, empty, or whitespace-only value falls back to
    ``"openai"``.
    """
    configured = (settings.AI_VOICE_PROVIDER or "").strip().lower()
    # Apply the fallback after normalising so that a blank value can never
    # produce an empty provider name.
    return configured or "openai"
|
||||
|
||||
@staticmethod
|
||||
def _resolve_credentials(mode: str) -> tuple[Optional[str], Optional[str]]:
|
||||
mode = mode.lower()
|
||||
provider = (
|
||||
settings.AI_VOICE_STT_PROVIDER
|
||||
if mode == "stt"
|
||||
else settings.AI_VOICE_TTS_PROVIDER
|
||||
) or settings.AI_VOICE_PROVIDER
|
||||
provider = (provider or "").strip().lower()
|
||||
|
||||
api_key = (
|
||||
settings.AI_VOICE_STT_API_KEY
|
||||
if mode == "stt"
|
||||
else settings.AI_VOICE_TTS_API_KEY
|
||||
) or settings.AI_VOICE_API_KEY
|
||||
base_url = (
|
||||
settings.AI_VOICE_STT_BASE_URL
|
||||
if mode == "stt"
|
||||
else settings.AI_VOICE_TTS_BASE_URL
|
||||
) or settings.AI_VOICE_BASE_URL
|
||||
provider = OpenAIVoiceProvider._resolve_provider_name()
|
||||
api_key = settings.AI_VOICE_API_KEY
|
||||
base_url = settings.AI_VOICE_BASE_URL
|
||||
|
||||
if (
|
||||
not api_key
|
||||
@@ -136,25 +127,27 @@ class OpenAIVoiceProvider(VoiceProvider):
|
||||
|
||||
|
||||
class VoiceHelper:
|
||||
"""统一语音入口,负责按 STT/TTS provider 路由。"""
|
||||
"""统一语音入口,负责音频能力判断与 STT/TTS provider 路由。"""
|
||||
|
||||
_providers: Dict[str, VoiceProvider] = {
|
||||
"openai": OpenAIVoiceProvider(),
|
||||
}
|
||||
REPLY_MODE_NATIVE = "native_voice"
|
||||
REPLY_MODE_TEXT = "text"
|
||||
|
||||
@classmethod
def register_provider(cls, provider: VoiceProvider) -> None:
    """
    Register a voice provider under its lower-cased name.

    The provider is stored in the class-level ``_providers`` mapping; an
    existing entry with the same key is overwritten.
    """
    registry_key = provider.name.lower()
    cls._providers[registry_key] = provider
|
||||
|
||||
@staticmethod
def is_enabled() -> bool:
    """
    Report whether audio input/output is switched on.

    This is the master switch for voice features and is taken directly from
    the explicit ``LLM_SUPPORT_AUDIO_INPUT_OUTPUT`` setting.
    """
    return bool(settings.LLM_SUPPORT_AUDIO_INPUT_OUTPUT)
|
||||
|
||||
@staticmethod
def _resolve_provider_name(mode: str) -> str:
    """
    Return the provider name to use for the given voice mode.

    The result is delegated to ``OpenAIVoiceProvider._resolve_provider_name``,
    so the same provider is returned regardless of ``mode``.

    :param mode: voice mode such as "stt" or "tts"; accepted only to keep
        the existing call signature and otherwise ignored.
    """
    # `mode` is deliberately unused; dropping the name makes that explicit.
    del mode
    return OpenAIVoiceProvider._resolve_provider_name()
|
||||
|
||||
@classmethod
|
||||
def get_provider(cls, mode: str) -> Optional[VoiceProvider]:
|
||||
@@ -171,6 +164,8 @@ class VoiceHelper:
|
||||
|
||||
@classmethod
|
||||
def is_available(cls, mode: Optional[str] = None) -> bool:
|
||||
if not cls.is_enabled():
|
||||
return False
|
||||
if mode:
|
||||
provider = cls.get_provider(mode)
|
||||
if not provider:
|
||||
@@ -182,8 +177,49 @@ class VoiceHelper:
|
||||
)
|
||||
return cls.is_available("stt") or cls.is_available("tts")
|
||||
|
||||
@classmethod
def supports_native_voice_reply(
    cls, channel: Optional[str], source: Optional[str]
) -> bool:
    """
    Decide whether the given channel can send native voice messages.

    Telegram always qualifies. WeChat qualifies only when the notification
    config whose name equals ``source`` has a ``WECHAT_MODE`` other than
    "bot" (a missing value is treated as "app"). Every other channel, an
    empty or unrecognised channel value, and a WeChat source with no
    matching config all yield False.

    :param channel: channel value of the conversation
    :param source: name of the notification config the message came from
    """
    if not channel:
        return False

    # Imported here rather than at module level, as in the original code.
    from app.helper.service import ServiceConfigHelper
    from app.schemas.types import MessageChannel

    try:
        channel_enum = MessageChannel(channel)
    except (TypeError, ValueError):
        # Not a known channel value.
        return False

    if channel_enum == MessageChannel.Telegram:
        return True
    if channel_enum != MessageChannel.Wechat:
        return False

    # WeCom (Enterprise WeChat) bot mode cannot send voice messages; only
    # app mode can. The first config with a matching name decides.
    matched = next(
        (
            config
            for config in ServiceConfigHelper.get_notification_configs()
            if config.name == source
        ),
        None,
    )
    if matched is None:
        return False
    return (matched.config or {}).get("WECHAT_MODE", "app") != "bot"
|
||||
|
||||
@classmethod
def resolve_reply_mode(cls, channel: Optional[str], source: Optional[str]) -> str:
    """
    Pick the reply mode for a channel/source pair.

    Returns ``REPLY_MODE_NATIVE`` when ``supports_native_voice_reply``
    accepts the pair, so audio is only sent where native voice messages
    work; every other case falls back to ``REPLY_MODE_TEXT``.

    :param channel: channel value of the conversation
    :param source: name of the notification config the message came from
    """
    native_ok = cls.supports_native_voice_reply(channel=channel, source=source)
    return cls.REPLY_MODE_NATIVE if native_ok else cls.REPLY_MODE_TEXT
|
||||
|
||||
@classmethod
|
||||
def transcribe_bytes(cls, content: bytes, filename: str = "input.ogg") -> Optional[str]:
|
||||
if not cls.is_enabled():
|
||||
return None
|
||||
provider = cls.get_provider("stt")
|
||||
if not provider:
|
||||
return None
|
||||
@@ -191,6 +227,8 @@ class VoiceHelper:
|
||||
|
||||
@classmethod
|
||||
def synthesize_speech(cls, text: str) -> Optional[Path]:
|
||||
if not cls.is_enabled():
|
||||
return None
|
||||
provider = cls.get_provider("tts")
|
||||
if not provider:
|
||||
return None
|
||||
|
||||
@@ -226,14 +226,17 @@ class AgentImageSupportTest(unittest.TestCase):
|
||||
), patch.object(chain.messagehelper, "put"), patch.object(
|
||||
chain.messageoper, "add"
|
||||
), patch.object(chain, "_handle_ai_message") as handle_ai_message:
|
||||
chain.handle_message(
|
||||
channel=MessageChannel.Telegram,
|
||||
source="telegram-test",
|
||||
userid="10001",
|
||||
username="tester",
|
||||
text="",
|
||||
audio_refs=["tg://voice_file_id/voice-1"],
|
||||
)
|
||||
with patch.object(settings, "AI_AGENT_ENABLE", True), patch.object(
|
||||
settings, "AI_AGENT_GLOBAL", False
|
||||
):
|
||||
chain.handle_message(
|
||||
channel=MessageChannel.Telegram,
|
||||
source="telegram-test",
|
||||
userid="10001",
|
||||
username="tester",
|
||||
text="",
|
||||
audio_refs=["tg://voice_file_id/voice-1"],
|
||||
)
|
||||
|
||||
handle_ai_message.assert_called_once()
|
||||
self.assertEqual(handle_ai_message.call_args.kwargs["text"], "帮我推荐一部电影")
|
||||
@@ -319,7 +322,7 @@ class AgentImageSupportTest(unittest.TestCase):
|
||||
],
|
||||
)
|
||||
|
||||
def test_agent_send_agent_message_does_not_auto_convert_to_voice(self):
|
||||
def test_agent_send_agent_message_auto_converts_to_voice_when_supported(self):
|
||||
agent = MoviePilotAgent(
|
||||
session_id="session-1",
|
||||
user_id="user-1",
|
||||
@@ -330,6 +333,14 @@ class AgentImageSupportTest(unittest.TestCase):
|
||||
agent.reply_with_voice = True
|
||||
|
||||
with patch.object(
|
||||
VoiceHelper,
|
||||
"resolve_reply_mode",
|
||||
return_value=VoiceHelper.REPLY_MODE_NATIVE,
|
||||
), patch.object(
|
||||
VoiceHelper, "is_available", return_value=True
|
||||
), patch.object(
|
||||
VoiceHelper, "synthesize_speech", return_value=Path("/tmp/reply.opus")
|
||||
), patch.object(
|
||||
AgentChain, "async_post_message", new_callable=AsyncMock
|
||||
) as async_post_message:
|
||||
import asyncio
|
||||
@@ -337,7 +348,7 @@ class AgentImageSupportTest(unittest.TestCase):
|
||||
asyncio.run(agent.send_agent_message("这是语音回复"))
|
||||
|
||||
notification = async_post_message.await_args.args[0]
|
||||
self.assertIsNone(notification.voice_path)
|
||||
self.assertEqual(notification.voice_path, "/tmp/reply.opus")
|
||||
self.assertEqual(notification.text, "这是语音回复")
|
||||
|
||||
def test_agent_process_wraps_request_as_structured_json(self):
|
||||
|
||||
@@ -13,10 +13,8 @@ class VoiceHelperTest(unittest.TestCase):
|
||||
def test_registered_providers_contains_openai(self):
    """
    The "openai" provider is present in the registered providers.
    """
    registered = VoiceHelper.get_registered_providers()
    self.assertIn("openai", registered)
|
||||
|
||||
def test_get_provider_falls_back_to_global_provider(self):
|
||||
with patch.object(settings, "AI_VOICE_PROVIDER", "openai"), patch.object(
|
||||
settings, "AI_VOICE_STT_PROVIDER", None
|
||||
):
|
||||
def test_get_provider_uses_single_audio_provider_setting(self):
|
||||
with patch.object(settings, "AI_VOICE_PROVIDER", "openai"):
|
||||
provider = VoiceHelper.get_provider("stt")
|
||||
|
||||
self.assertIsInstance(provider, OpenAIVoiceProvider)
|
||||
@@ -26,15 +24,29 @@ class VoiceHelperTest(unittest.TestCase):
|
||||
provider.is_available_for_stt.return_value = True
|
||||
provider.is_available_for_tts.return_value = False
|
||||
|
||||
with patch.object(VoiceHelper, "get_provider", return_value=provider):
|
||||
with patch.object(
|
||||
settings, "LLM_SUPPORT_AUDIO_INPUT_OUTPUT", True
|
||||
), patch.object(VoiceHelper, "get_provider", return_value=provider):
|
||||
self.assertTrue(VoiceHelper.is_available("stt"))
|
||||
self.assertFalse(VoiceHelper.is_available("tts"))
|
||||
|
||||
def test_is_available_returns_false_when_audio_switch_is_disabled(self):
    """
    With the audio switch off, ``is_available`` reports False both for
    "stt" and for the no-argument form, even though the stubbed provider
    claims STT support.
    """
    provider = Mock()
    provider.is_available_for_stt.return_value = True

    with patch.object(
        settings, "LLM_SUPPORT_AUDIO_INPUT_OUTPUT", False
    ), patch.object(VoiceHelper, "get_provider", return_value=provider):
        self.assertFalse(VoiceHelper.is_available("stt"))
        self.assertFalse(VoiceHelper.is_available())
|
||||
|
||||
def test_transcribe_bytes_routes_to_stt_provider(self):
|
||||
provider = Mock()
|
||||
provider.transcribe_bytes.return_value = "你好"
|
||||
|
||||
with patch.object(VoiceHelper, "get_provider", return_value=provider):
|
||||
with patch.object(
|
||||
settings, "LLM_SUPPORT_AUDIO_INPUT_OUTPUT", True
|
||||
), patch.object(VoiceHelper, "get_provider", return_value=provider):
|
||||
result = VoiceHelper.transcribe_bytes(b"audio")
|
||||
|
||||
self.assertEqual(result, "你好")
|
||||
@@ -44,7 +56,9 @@ class VoiceHelperTest(unittest.TestCase):
|
||||
provider = Mock()
|
||||
provider.synthesize_speech.return_value = "/tmp/reply.opus"
|
||||
|
||||
with patch.object(VoiceHelper, "get_provider", return_value=provider):
|
||||
with patch.object(
|
||||
settings, "LLM_SUPPORT_AUDIO_INPUT_OUTPUT", True
|
||||
), patch.object(VoiceHelper, "get_provider", return_value=provider):
|
||||
result = VoiceHelper.synthesize_speech("你好")
|
||||
|
||||
self.assertEqual(result, "/tmp/reply.opus")
|
||||
|
||||
Reference in New Issue
Block a user