Files
MoviePilot/app/helper/voice.py

235 lines
7.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""语音能力辅助功能。"""
from abc import ABC, abstractmethod
from io import BytesIO
from pathlib import Path
from typing import Dict, Optional
from uuid import uuid4
from app.core.config import settings
from app.log import logger
class VoiceProvider(ABC):
    """Abstract interface implemented by every voice (STT/TTS) provider.

    A concrete provider exposes a ``name`` and reports independently whether
    it can serve speech-to-text and text-to-speech requests.
    """

    # Upper bound, in bytes, on audio accepted for transcription (10 MiB).
    MAX_TRANSCRIBE_BYTES = 10 * 1024 ** 2

    @property
    @abstractmethod
    def name(self) -> str:
        """Return the provider's name."""

    @abstractmethod
    def is_available_for_stt(self) -> bool:
        """Return whether this provider can be used for speech recognition."""

    @abstractmethod
    def is_available_for_tts(self) -> bool:
        """Return whether this provider can be used for speech synthesis."""

    @abstractmethod
    def transcribe_bytes(self, content: bytes, filename: str = "input.ogg") -> Optional[str]:
        """Convert raw audio bytes to text.

        Args:
            content: The audio data.
            filename: File name associated with the audio data.

        Returns:
            The transcribed text, or ``None`` when no text is produced.
        """

    @abstractmethod
    def synthesize_speech(self, text: str) -> Optional[Path]:
        """Convert text to a speech audio file.

        Args:
            text: The text to speak.

        Returns:
            Path of the generated audio file, or ``None`` when none is produced.
        """
class OpenAIVoiceProvider(VoiceProvider):
    """Voice provider backed by the OpenAI (or an OpenAI-compatible) audio API."""

    @property
    def name(self) -> str:
        """Return the provider name, ``"openai"``."""
        return "openai"

    @staticmethod
    def _resolve_provider_name() -> str:
        """Return the configured voice provider name, stripped and lower-cased.

        Falls back to ``"openai"`` when ``settings.AI_VOICE_PROVIDER`` is unset.
        """
        provider = settings.AI_VOICE_PROVIDER or "openai"
        return provider.strip().lower()

    def _resolve_credentials(self) -> tuple[Optional[str], Optional[str]]:
        """Return the ``(api_key, base_url)`` pair to use for voice requests.

        The dedicated ``AI_VOICE_*`` settings take precedence. The LLM
        credentials are reused only when no voice API key is set and both the
        voice provider and the LLM provider resolve to ``"openai"``; in that
        case an explicitly configured voice base URL still wins over the LLM
        base URL.
        """
        provider = self._resolve_provider_name()
        api_key = settings.AI_VOICE_API_KEY
        base_url = settings.AI_VOICE_BASE_URL
        if (
            not api_key
            and provider == "openai"
            and (settings.LLM_PROVIDER or "").strip().lower() == "openai"
        ):
            api_key = settings.LLM_API_KEY
            base_url = base_url or settings.LLM_BASE_URL
        return api_key, base_url

    def _get_client(self, mode: str):
        """Build an OpenAI client from the resolved credentials.

        Args:
            mode: Capability being requested (e.g. ``"stt"`` or ``"tts"``);
                only used in the error message.

        Raises:
            ValueError: If no API key could be resolved.
        """
        # Imported here rather than at module level.
        from openai import OpenAI

        api_key, base_url = self._resolve_credentials()
        if not api_key:
            raise ValueError(f"{mode.upper()} provider 未配置 API Key")
        # An empty or blank base URL must become None: the SDK only falls back
        # to its default endpoint when base_url is None, so passing "" would
        # produce requests against an invalid URL.
        if isinstance(base_url, str):
            base_url = base_url.strip()
        return OpenAI(api_key=api_key, base_url=base_url or None, max_retries=3)

    def is_available_for_stt(self) -> bool:
        """Return True when an API key is available for speech recognition."""
        api_key, _ = self._resolve_credentials()
        return bool(api_key)

    def is_available_for_tts(self) -> bool:
        """Return True when an API key is available for speech synthesis."""
        api_key, _ = self._resolve_credentials()
        return bool(api_key)

    def transcribe_bytes(self, content: bytes, filename: str = "input.ogg") -> Optional[str]:
        """Transcribe audio bytes to text via the transcription endpoint.

        Args:
            content: The audio data.
            filename: Name attached to the in-memory file sent to the API.

        Returns:
            The stripped transcript, or ``None`` when ``content`` is empty,
            the response carries no text, or the request fails (the failure
            is logged).

        Raises:
            ValueError: If ``content`` is larger than ``MAX_TRANSCRIBE_BYTES``.
        """
        if not content:
            return None
        # The size check sits outside the try block, so this error reaches the caller.
        if len(content) > self.MAX_TRANSCRIBE_BYTES:
            raise ValueError("语音文件超过 10MB无法识别")
        try:
            client = self._get_client("stt")
            audio_file = BytesIO(content)
            # Attach a name to the in-memory buffer before uploading it.
            audio_file.name = filename
            response = client.audio.transcriptions.create(
                model=settings.AI_VOICE_STT_MODEL,
                file=audio_file,
                language=settings.AI_VOICE_LANGUAGE or "zh",
                response_format="verbose_json",
            )
            text = getattr(response, "text", None)
            return text.strip() if text else None
        except Exception as err:
            logger.error(f"语音转文字失败: provider={self.name}, error={err}")
            return None

    def synthesize_speech(self, text: str) -> Optional[Path]:
        """Synthesize ``text`` into an Opus audio file.

        The file is written under ``settings.TEMP_PATH / "voice"`` with a
        random name.

        Args:
            text: The text to speak.

        Returns:
            Path of the written ``.opus`` file, or ``None`` when ``text`` is
            empty or whitespace-only, or when the request fails (the failure
            is logged).
        """
        # Whitespace-only input carries nothing to speak; treat it like empty input.
        if not text or not text.strip():
            return None
        output_path: Optional[Path] = None
        try:
            client = self._get_client("tts")
            voice_dir = settings.TEMP_PATH / "voice"
            voice_dir.mkdir(parents=True, exist_ok=True)
            output_path = voice_dir / f"{uuid4().hex}.opus"
            response = client.audio.speech.create(
                model=settings.AI_VOICE_TTS_MODEL,
                voice=settings.AI_VOICE_TTS_VOICE,
                input=text,
                response_format="opus",
            )
            response.write_to_file(output_path)
            return output_path
        except Exception as err:
            logger.error(f"文字转语音失败: provider={self.name}, error={err}")
            # Best-effort cleanup so a failed write does not leave a partial
            # file behind in the temp directory.
            if output_path is not None:
                try:
                    output_path.unlink(missing_ok=True)
                except OSError:
                    pass
            return None
class VoiceHelper:
    """Single entry point for voice features.

    Decides whether audio capabilities are usable and routes STT/TTS requests
    to the configured provider.
    """

    # Provider registry keyed by normalized (stripped, lower-cased) name.
    _providers: Dict[str, VoiceProvider] = {
        "openai": OpenAIVoiceProvider(),
    }

    REPLY_MODE_NATIVE = "native_voice"
    REPLY_MODE_TEXT = "text"

    @classmethod
    def register_provider(cls, provider: VoiceProvider) -> None:
        """Register ``provider`` under its normalized name.

        The key is normalized the same way as the configured provider name
        (strip + lower-case) so a registered provider can always be looked up.
        An existing entry with the same key is replaced.
        """
        cls._providers[provider.name.strip().lower()] = provider

    @staticmethod
    def is_enabled() -> bool:
        """Return the master audio input/output switch from the settings."""
        return bool(settings.LLM_SUPPORT_AUDIO_INPUT_OUTPUT)

    @staticmethod
    def _resolve_provider_name() -> str:
        """Return the configured voice provider name, stripped and lower-cased.

        Falls back to ``"openai"`` when ``settings.AI_VOICE_PROVIDER`` is unset.
        """
        provider = settings.AI_VOICE_PROVIDER or "openai"
        return provider.strip().lower()

    @classmethod
    def get_provider(cls, mode: str) -> Optional[VoiceProvider]:
        """Return the provider selected by the settings, or ``None``.

        Args:
            mode: Capability being requested; only used in the warning that
                is logged when the configured provider is not registered.
        """
        provider_name = cls._resolve_provider_name()
        provider = cls._providers.get(provider_name)
        if provider:
            return provider
        logger.warning(f"未注册语音 provider: mode={mode}, provider={provider_name}")
        return None

    @classmethod
    def get_registered_providers(cls) -> list[str]:
        """Return the registered provider names in sorted order."""
        return sorted(cls._providers)

    @classmethod
    def is_available(cls, mode: Optional[str] = None) -> bool:
        """Return whether voice features can be used.

        Args:
            mode: ``"stt"`` checks speech recognition; any other non-empty
                value checks speech synthesis. When omitted, either
                capability being available is enough.
        """
        if not cls.is_enabled():
            return False
        if not mode:
            return cls.is_available("stt") or cls.is_available("tts")
        provider = cls.get_provider(mode)
        if not provider:
            return False
        # Tolerate surrounding whitespace and case differences in the mode.
        if mode.strip().lower() == "stt":
            return provider.is_available_for_stt()
        return provider.is_available_for_tts()

    @classmethod
    def supports_native_voice_reply(
        cls, channel: Optional[str], source: Optional[str]
    ) -> bool:
        """Return whether the channel can send native voice messages.

        Args:
            channel: Value of a ``MessageChannel`` member.
            source: Name of the notification config the message came from;
                used to look up the WeChat mode.

        Returns:
            ``True`` for Telegram, and for WeChat when the matching
            notification config is not in ``"bot"`` mode; ``False`` otherwise,
            including for unknown channels.
        """
        if not channel:
            return False

        from app.helper.service import ServiceConfigHelper
        from app.schemas.types import MessageChannel

        try:
            channel_enum = MessageChannel(channel)
        except (TypeError, ValueError):
            return False

        if channel_enum == MessageChannel.Telegram:
            return True
        if channel_enum != MessageChannel.Wechat:
            return False

        # WeCom (enterprise WeChat) bot mode cannot send voice messages; only
        # app mode can. The first config whose name matches decides.
        for config in ServiceConfigHelper.get_notification_configs():
            if config.name == source:
                return (config.config or {}).get("WECHAT_MODE", "app") != "bot"
        return False

    @classmethod
    def resolve_reply_mode(cls, channel: Optional[str], source: Optional[str]) -> str:
        """Return the reply mode for a channel.

        Audio is only sent on channels that support native voice replies;
        every other channel falls back to text.
        """
        if cls.supports_native_voice_reply(channel=channel, source=source):
            return cls.REPLY_MODE_NATIVE
        return cls.REPLY_MODE_TEXT

    @classmethod
    def transcribe_bytes(cls, content: bytes, filename: str = "input.ogg") -> Optional[str]:
        """Transcribe audio bytes with the configured provider.

        Returns ``None`` when voice features are disabled or the configured
        provider is not registered; otherwise returns the provider's result.
        """
        if not cls.is_enabled():
            return None
        provider = cls.get_provider("stt")
        if not provider:
            return None
        return provider.transcribe_bytes(content=content, filename=filename)

    @classmethod
    def synthesize_speech(cls, text: str) -> Optional[Path]:
        """Synthesize speech with the configured provider.

        Returns ``None`` when voice features are disabled or the configured
        provider is not registered; otherwise returns the provider's result.
        """
        if not cls.is_enabled():
            return None
        provider = cls.get_provider("tts")
        if not provider:
            return None
        return provider.synthesize_speech(text=text)