mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-05-11 18:37:39 +08:00
235 lines
7.6 KiB
Python
235 lines
7.6 KiB
Python
"""语音能力辅助功能。"""
|
||
|
||
from abc import ABC, abstractmethod
|
||
from io import BytesIO
|
||
from pathlib import Path
|
||
from typing import Dict, Optional
|
||
from uuid import uuid4
|
||
|
||
from app.core.config import settings
|
||
from app.log import logger
|
||
|
||
|
||
class VoiceProvider(ABC):
    """Abstract interface that every speech (STT/TTS) provider implements."""

    # 10 MiB; a size limit for transcription input that concrete providers
    # can check incoming audio against.
    MAX_TRANSCRIBE_BYTES: int = 10 * 1024 ** 2

    @property
    @abstractmethod
    def name(self) -> str:
        """Return the provider's name."""

    @abstractmethod
    def is_available_for_stt(self) -> bool:
        """Return whether the provider can be used for speech recognition."""

    @abstractmethod
    def is_available_for_tts(self) -> bool:
        """Return whether the provider can be used for speech synthesis."""

    @abstractmethod
    def transcribe_bytes(
        self, content: bytes, filename: str = "input.ogg"
    ) -> Optional[str]:
        """Convert raw audio bytes into text.

        :param content: audio payload to transcribe.
        :param filename: file name associated with the payload.
        :return: the recognised text, or ``None`` when nothing is produced.
        """

    @abstractmethod
    def synthesize_speech(self, text: str) -> Optional[Path]:
        """Convert text into a speech audio file.

        :param text: text to speak.
        :return: path of the generated audio file, or ``None``.
        """
|
||
|
||
|
||
class OpenAIVoiceProvider(VoiceProvider):
    """Voice provider backed by the OpenAI / OpenAI-compatible audio API."""

    @property
    def name(self) -> str:
        """Return the name of this provider (``"openai"``)."""
        return "openai"

    @staticmethod
    def _resolve_provider_name() -> str:
        """Return the configured voice provider name, trimmed and lower-cased.

        Falls back to ``"openai"`` when ``settings.AI_VOICE_PROVIDER`` is empty.
        """
        provider = settings.AI_VOICE_PROVIDER or "openai"
        return provider.strip().lower()

    def _resolve_credentials(self) -> tuple[Optional[str], Optional[str]]:
        """Return the ``(api_key, base_url)`` pair used for voice requests.

        The dedicated ``AI_VOICE_*`` settings take precedence. When no voice
        API key is configured and both the voice provider and
        ``settings.LLM_PROVIDER`` resolve to ``"openai"``, the LLM API key is
        reused; the LLM base URL is used only if no voice base URL is set.
        """
        provider = self._resolve_provider_name()
        api_key = settings.AI_VOICE_API_KEY
        base_url = settings.AI_VOICE_BASE_URL

        if (
            not api_key
            and provider == "openai"
            and (settings.LLM_PROVIDER or "").strip().lower() == "openai"
        ):
            api_key = settings.LLM_API_KEY
            base_url = base_url or settings.LLM_BASE_URL

        return api_key, base_url

    def _get_client(self, mode: str):
        """Build an ``OpenAI`` client from the resolved credentials.

        :param mode: label (e.g. ``"stt"`` / ``"tts"``) used only in the error
            message.
        :raises ValueError: if no API key could be resolved.
        """
        # Imported lazily so the module can be loaded without touching the SDK.
        from openai import OpenAI

        api_key, base_url = self._resolve_credentials()
        if not api_key:
            raise ValueError(f"{mode.upper()} provider 未配置 API Key")
        return OpenAI(api_key=api_key, base_url=base_url, max_retries=3)

    def is_available_for_stt(self) -> bool:
        """Return ``True`` when an API key is available for speech recognition."""
        api_key, _ = self._resolve_credentials()
        return bool(api_key)

    def is_available_for_tts(self) -> bool:
        """Return ``True`` when an API key is available for speech synthesis."""
        api_key, _ = self._resolve_credentials()
        return bool(api_key)

    def transcribe_bytes(self, content: bytes, filename: str = "input.ogg") -> Optional[str]:
        """Transcribe audio bytes into text.

        :param content: audio payload; empty input yields ``None``.
        :param filename: name attached to the in-memory file sent to the API.
        :return: the stripped transcript, or ``None`` when the input is empty,
            the transcript is empty/whitespace-only, or the request fails
            (failures are logged, not raised).
        :raises ValueError: if ``content`` is larger than
            ``MAX_TRANSCRIBE_BYTES``; this check runs outside the ``try`` so
            it propagates to the caller.
        """
        if not content:
            return None
        if len(content) > self.MAX_TRANSCRIBE_BYTES:
            raise ValueError("语音文件超过 10MB,无法识别")

        try:
            client = self._get_client("stt")
            audio_file = BytesIO(content)
            # The buffer is given a name so the upload carries a file name.
            audio_file.name = filename
            response = client.audio.transcriptions.create(
                model=settings.AI_VOICE_STT_MODEL,
                file=audio_file,
                language=settings.AI_VOICE_LANGUAGE or "zh",
                response_format="verbose_json",
            )
            text = getattr(response, "text", None)
            if not text:
                return None
            # A whitespace-only transcript is treated like an empty one so
            # callers get None in both cases instead of "".
            return text.strip() or None
        except Exception as err:
            logger.error(f"语音转文字失败: provider={self.name}, error={err}")
            return None

    def synthesize_speech(self, text: str) -> Optional[Path]:
        """Synthesize ``text`` into an Opus file under ``settings.TEMP_PATH / "voice"``.

        :param text: text to speak; empty input yields ``None``.
        :return: path of the written ``.opus`` file, or ``None`` when the input
            is empty or the request/write fails (failures are logged, not
            raised).
        """
        if not text:
            return None

        output_path: Optional[Path] = None
        try:
            client = self._get_client("tts")
            voice_dir = settings.TEMP_PATH / "voice"
            voice_dir.mkdir(parents=True, exist_ok=True)
            output_path = voice_dir / f"{uuid4().hex}.opus"
            response = client.audio.speech.create(
                model=settings.AI_VOICE_TTS_MODEL,
                voice=settings.AI_VOICE_TTS_VOICE,
                input=text,
                response_format="opus",
            )
            response.write_to_file(output_path)
            return output_path
        except Exception as err:
            logger.error(f"文字转语音失败: provider={self.name}, error={err}")
            # Do not leave a partially written file behind in the temp dir.
            if output_path is not None:
                try:
                    output_path.unlink(missing_ok=True)
                except OSError:
                    # Best-effort cleanup; the failure above is already logged.
                    pass
            return None
|
||
|
||
|
||
class VoiceHelper:
    """Unified voice entry point.

    Decides whether audio features are usable and routes STT/TTS calls to the
    configured provider.
    """

    # Registered providers keyed by normalised (trimmed, lower-cased) name.
    _providers: Dict[str, VoiceProvider] = {
        "openai": OpenAIVoiceProvider(),
    }
    REPLY_MODE_NATIVE = "native_voice"
    REPLY_MODE_TEXT = "text"

    @classmethod
    def register_provider(cls, provider: VoiceProvider) -> None:
        """Register ``provider`` under its normalised name.

        The key is trimmed and lower-cased exactly like the lookup in
        ``_resolve_provider_name`` so a registered provider is always
        reachable. An existing entry with the same key is replaced.
        """
        cls._providers[provider.name.strip().lower()] = provider

    @staticmethod
    def is_enabled() -> bool:
        """Master switch for audio input/output, taken from explicit config."""
        return bool(settings.LLM_SUPPORT_AUDIO_INPUT_OUTPUT)

    @staticmethod
    def _resolve_provider_name() -> str:
        """Return the configured voice provider name, trimmed and lower-cased.

        Falls back to ``"openai"`` when ``settings.AI_VOICE_PROVIDER`` is empty.
        """
        provider = settings.AI_VOICE_PROVIDER or "openai"
        return provider.strip().lower()

    @classmethod
    def get_provider(cls, mode: str) -> Optional[VoiceProvider]:
        """Return the provider selected by configuration, or ``None``.

        :param mode: label such as ``"stt"`` / ``"tts"``; used only in the
            warning logged when the configured provider is not registered.
        """
        provider_name = cls._resolve_provider_name()
        provider = cls._providers.get(provider_name)
        if provider:
            return provider
        logger.warning(f"未注册语音 provider: mode={mode}, provider={provider_name}")
        return None

    @classmethod
    def get_registered_providers(cls) -> list[str]:
        """Return the registered provider names in sorted order."""
        return sorted(cls._providers)

    @classmethod
    def is_available(cls, mode: Optional[str] = None) -> bool:
        """Return whether voice support is usable.

        :param mode: ``"stt"`` (case-insensitive) checks speech recognition;
            any other non-empty value checks speech synthesis. When omitted,
            the result is ``True`` if either direction is available.
        :return: ``False`` whenever the master switch is off or the configured
            provider is not registered.
        """
        if not cls.is_enabled():
            return False
        if mode:
            provider = cls.get_provider(mode)
            if not provider:
                return False
            return (
                provider.is_available_for_stt()
                if mode.lower() == "stt"
                else provider.is_available_for_tts()
            )
        return cls.is_available("stt") or cls.is_available("tts")

    @classmethod
    def supports_native_voice_reply(
        cls, channel: Optional[str], source: Optional[str]
    ) -> bool:
        """Return whether the channel can send native voice messages.

        :param channel: value convertible to ``MessageChannel``; a missing or
            unrecognised value yields ``False``.
        :param source: name of the notification config to inspect for WeChat.
        :return: ``True`` for Telegram; for WeChat, ``True`` only when the
            config named ``source`` has ``WECHAT_MODE`` (default ``"app"``)
            other than ``"bot"``; ``False`` for every other channel or when
            no matching config exists.
        """
        if not channel:
            return False

        # Imported here rather than at module level.
        from app.helper.service import ServiceConfigHelper
        from app.schemas.types import MessageChannel

        try:
            channel_enum = MessageChannel(channel)
        except (TypeError, ValueError):
            return False

        if channel_enum == MessageChannel.Telegram:
            return True
        if channel_enum != MessageChannel.Wechat:
            return False

        # WeCom (enterprise WeChat) bot mode cannot send voice messages; only
        # app mode can.
        for config in ServiceConfigHelper.get_notification_configs():
            if config.name != source:
                continue
            return (config.config or {}).get("WECHAT_MODE", "app") != "bot"
        return False

    @classmethod
    def resolve_reply_mode(cls, channel: Optional[str], source: Optional[str]) -> str:
        """Pick the reply mode for a channel.

        Audio is sent only on channels that support native voice replies;
        every other channel falls back to text.

        :return: ``REPLY_MODE_NATIVE`` or ``REPLY_MODE_TEXT``.
        """
        if cls.supports_native_voice_reply(channel=channel, source=source):
            return cls.REPLY_MODE_NATIVE
        return cls.REPLY_MODE_TEXT

    @classmethod
    def transcribe_bytes(cls, content: bytes, filename: str = "input.ogg") -> Optional[str]:
        """Transcribe audio through the configured provider.

        :return: the provider's result, or ``None`` when audio is disabled or
            the configured provider is not registered.
        """
        if not cls.is_enabled():
            return None
        provider = cls.get_provider("stt")
        if not provider:
            return None
        return provider.transcribe_bytes(content=content, filename=filename)

    @classmethod
    def synthesize_speech(cls, text: str) -> Optional[Path]:
        """Synthesize speech through the configured provider.

        :return: the provider's result, or ``None`` when audio is disabled or
            the configured provider is not registered.
        """
        if not cls.is_enabled():
            return None
        provider = cls.get_provider("tts")
        if not provider:
            return None
        return provider.synthesize_speech(text=text)
|