mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-05-11 00:39:22 +08:00
- Integrate voice message handling: detect and extract audio references from Telegram and WeChat messages, route to agent with voice reply preference. - Add voice provider abstraction and OpenAI-based TTS/STT implementation. - Implement agent tool `send_voice_message` for generating and sending voice replies, with fallback to text if voice is unavailable. - Extend agent prompt and context to support voice reply instructions. - Update notification and message schemas to support audio fields. - Add Telegram and WeChat voice sending logic, including audio file conversion and temporary media upload for WeChat. - Add tests for voice helper and agent voice routing.
198 lines
6.3 KiB
Python
198 lines
6.3 KiB
Python
"""语音能力辅助功能。"""
|
||
|
||
from abc import ABC, abstractmethod
|
||
from io import BytesIO
|
||
from pathlib import Path
|
||
from typing import Dict, Optional
|
||
from uuid import uuid4
|
||
|
||
from app.core.config import settings
|
||
from app.log import logger
|
||
|
||
|
||
class VoiceProvider(ABC):
|
||
"""语音 provider 抽象层。"""
|
||
|
||
MAX_TRANSCRIBE_BYTES = 25 * 1024 * 1024
|
||
|
||
@property
|
||
@abstractmethod
|
||
def name(self) -> str:
|
||
"""provider 名称。"""
|
||
|
||
@abstractmethod
|
||
def is_available_for_stt(self) -> bool:
|
||
"""是否可用于语音识别。"""
|
||
|
||
@abstractmethod
|
||
def is_available_for_tts(self) -> bool:
|
||
"""是否可用于语音合成。"""
|
||
|
||
@abstractmethod
|
||
def transcribe_bytes(self, content: bytes, filename: str = "input.ogg") -> Optional[str]:
|
||
"""将音频字节转成文字。"""
|
||
|
||
@abstractmethod
|
||
def synthesize_speech(self, text: str) -> Optional[Path]:
|
||
"""将文字转成语音文件。"""
|
||
|
||
|
||
class OpenAIVoiceProvider(VoiceProvider):
|
||
"""OpenAI / OpenAI-compatible provider。"""
|
||
|
||
@property
|
||
def name(self) -> str:
|
||
return "openai"
|
||
|
||
@staticmethod
|
||
def _resolve_credentials(mode: str) -> tuple[Optional[str], Optional[str]]:
|
||
mode = mode.lower()
|
||
provider = (
|
||
settings.AI_VOICE_STT_PROVIDER
|
||
if mode == "stt"
|
||
else settings.AI_VOICE_TTS_PROVIDER
|
||
) or settings.AI_VOICE_PROVIDER
|
||
provider = (provider or "").strip().lower()
|
||
|
||
api_key = (
|
||
settings.AI_VOICE_STT_API_KEY
|
||
if mode == "stt"
|
||
else settings.AI_VOICE_TTS_API_KEY
|
||
) or settings.AI_VOICE_API_KEY
|
||
base_url = (
|
||
settings.AI_VOICE_STT_BASE_URL
|
||
if mode == "stt"
|
||
else settings.AI_VOICE_TTS_BASE_URL
|
||
) or settings.AI_VOICE_BASE_URL
|
||
|
||
if (
|
||
not api_key
|
||
and provider == "openai"
|
||
and (settings.LLM_PROVIDER or "").strip().lower() == "openai"
|
||
):
|
||
api_key = settings.LLM_API_KEY
|
||
base_url = base_url or settings.LLM_BASE_URL
|
||
|
||
return api_key, base_url
|
||
|
||
def _get_client(self, mode: str):
|
||
from openai import OpenAI
|
||
|
||
api_key, base_url = self._resolve_credentials(mode)
|
||
if not api_key:
|
||
raise ValueError(f"{mode.upper()} provider 未配置 API Key")
|
||
return OpenAI(api_key=api_key, base_url=base_url, max_retries=3)
|
||
|
||
def is_available_for_stt(self) -> bool:
|
||
api_key, _ = self._resolve_credentials("stt")
|
||
return bool(api_key)
|
||
|
||
def is_available_for_tts(self) -> bool:
|
||
api_key, _ = self._resolve_credentials("tts")
|
||
return bool(api_key)
|
||
|
||
def transcribe_bytes(self, content: bytes, filename: str = "input.ogg") -> Optional[str]:
|
||
if not content:
|
||
return None
|
||
if len(content) > self.MAX_TRANSCRIBE_BYTES:
|
||
raise ValueError("语音文件超过 25MB,无法识别")
|
||
|
||
try:
|
||
client = self._get_client("stt")
|
||
audio_file = BytesIO(content)
|
||
audio_file.name = filename
|
||
response = client.audio.transcriptions.create(
|
||
model=settings.AI_VOICE_STT_MODEL,
|
||
file=audio_file,
|
||
language=settings.AI_VOICE_LANGUAGE or "zh",
|
||
response_format="verbose_json",
|
||
)
|
||
text = getattr(response, "text", None)
|
||
return text.strip() if text else None
|
||
except Exception as err:
|
||
logger.error(f"语音转文字失败: provider={self.name}, error={err}")
|
||
return None
|
||
|
||
def synthesize_speech(self, text: str) -> Optional[Path]:
|
||
if not text:
|
||
return None
|
||
|
||
try:
|
||
client = self._get_client("tts")
|
||
voice_dir = settings.TEMP_PATH / "voice"
|
||
voice_dir.mkdir(parents=True, exist_ok=True)
|
||
output_path = voice_dir / f"{uuid4().hex}.opus"
|
||
response = client.audio.speech.create(
|
||
model=settings.AI_VOICE_TTS_MODEL,
|
||
voice=settings.AI_VOICE_TTS_VOICE,
|
||
input=text,
|
||
response_format="opus",
|
||
)
|
||
response.write_to_file(output_path)
|
||
return output_path
|
||
except Exception as err:
|
||
logger.error(f"文字转语音失败: provider={self.name}, error={err}")
|
||
return None
|
||
|
||
|
||
class VoiceHelper:
|
||
"""统一语音入口,负责按 STT/TTS provider 路由。"""
|
||
|
||
_providers: Dict[str, VoiceProvider] = {
|
||
"openai": OpenAIVoiceProvider(),
|
||
}
|
||
|
||
@classmethod
|
||
def register_provider(cls, provider: VoiceProvider) -> None:
|
||
cls._providers[provider.name.lower()] = provider
|
||
|
||
@staticmethod
|
||
def _resolve_provider_name(mode: str) -> str:
|
||
mode = mode.lower()
|
||
provider = (
|
||
settings.AI_VOICE_STT_PROVIDER
|
||
if mode == "stt"
|
||
else settings.AI_VOICE_TTS_PROVIDER
|
||
) or settings.AI_VOICE_PROVIDER
|
||
return (provider or "openai").strip().lower()
|
||
|
||
@classmethod
|
||
def get_provider(cls, mode: str) -> Optional[VoiceProvider]:
|
||
provider_name = cls._resolve_provider_name(mode)
|
||
provider = cls._providers.get(provider_name)
|
||
if provider:
|
||
return provider
|
||
logger.warning(f"未注册语音 provider: mode={mode}, provider={provider_name}")
|
||
return None
|
||
|
||
@classmethod
|
||
def get_registered_providers(cls) -> list[str]:
|
||
return sorted(cls._providers.keys())
|
||
|
||
@classmethod
|
||
def is_available(cls, mode: Optional[str] = None) -> bool:
|
||
if mode:
|
||
provider = cls.get_provider(mode)
|
||
if not provider:
|
||
return False
|
||
return (
|
||
provider.is_available_for_stt()
|
||
if mode.lower() == "stt"
|
||
else provider.is_available_for_tts()
|
||
)
|
||
return cls.is_available("stt") or cls.is_available("tts")
|
||
|
||
@classmethod
|
||
def transcribe_bytes(cls, content: bytes, filename: str = "input.ogg") -> Optional[str]:
|
||
provider = cls.get_provider("stt")
|
||
if not provider:
|
||
return None
|
||
return provider.transcribe_bytes(content=content, filename=filename)
|
||
|
||
@classmethod
|
||
def synthesize_speech(cls, text: str) -> Optional[Path]:
|
||
provider = cls.get_provider("tts")
|
||
if not provider:
|
||
return None
|
||
return provider.synthesize_speech(text=text)
|