Files
MoviePilot/app/helper/voice.py
jxxghp e5f97cd299 feat(agent): add voice message support with TTS/STT for Telegram and WeChat
- Integrate voice message handling: detect and extract audio references from Telegram and WeChat messages, route to agent with voice reply preference.
- Add voice provider abstraction and OpenAI-based TTS/STT implementation.
- Implement agent tool `send_voice_message` for generating and sending voice replies, with fallback to text if voice is unavailable.
- Extend agent prompt and context to support voice reply instructions.
- Update notification and message schemas to support audio fields.
- Add Telegram and WeChat voice sending logic, including audio file conversion and temporary media upload for WeChat.
- Add tests for voice helper and agent voice routing.
2026-04-12 12:30:02 +08:00

198 lines
6.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""语音能力辅助功能。"""
from abc import ABC, abstractmethod
from io import BytesIO
from pathlib import Path
from typing import Dict, Optional
from uuid import uuid4
from app.core.config import settings
from app.log import logger
class VoiceProvider(ABC):
    """Abstract interface for a voice capability provider (STT and TTS)."""

    # Size ceiling enforced by transcribe implementations; matches the
    # "25MB" limit mentioned in the user-facing error message.
    MAX_TRANSCRIBE_BYTES = 25 * 1024 * 1024

    @property
    @abstractmethod
    def name(self) -> str:
        """Provider name; used (lower-cased) as the registry key in VoiceHelper."""

    @abstractmethod
    def is_available_for_stt(self) -> bool:
        """Return True when this provider is configured for speech-to-text."""

    @abstractmethod
    def is_available_for_tts(self) -> bool:
        """Return True when this provider is configured for text-to-speech."""

    @abstractmethod
    def transcribe_bytes(self, content: bytes, filename: str = "input.ogg") -> Optional[str]:
        """Transcribe raw audio bytes into text; return None when unavailable/failed."""

    @abstractmethod
    def synthesize_speech(self, text: str) -> Optional[Path]:
        """Synthesize speech for *text*; return the audio file path, or None on failure."""
class OpenAIVoiceProvider(VoiceProvider):
    """Voice provider backed by the OpenAI (or OpenAI-compatible) audio API.

    Credentials are resolved per mode ("stt" / "tts"): mode-specific settings
    take precedence, then the generic AI_VOICE_* settings, and finally — when
    both the voice provider and the LLM provider are "openai" — the LLM
    credentials are borrowed as a last resort.
    """

    @property
    def name(self) -> str:
        return "openai"

    @staticmethod
    def _resolve_credentials(mode: str) -> tuple[Optional[str], Optional[str]]:
        """Return ``(api_key, base_url)`` for *mode*, applying the fallback chain."""
        if mode.lower() == "stt":
            configured_provider = settings.AI_VOICE_STT_PROVIDER
            api_key = settings.AI_VOICE_STT_API_KEY
            base_url = settings.AI_VOICE_STT_BASE_URL
        else:
            configured_provider = settings.AI_VOICE_TTS_PROVIDER
            api_key = settings.AI_VOICE_TTS_API_KEY
            base_url = settings.AI_VOICE_TTS_BASE_URL
        provider = (configured_provider or settings.AI_VOICE_PROVIDER or "").strip().lower()
        api_key = api_key or settings.AI_VOICE_API_KEY
        base_url = base_url or settings.AI_VOICE_BASE_URL
        # Borrow LLM credentials only when no dedicated voice key exists and
        # both sides are configured for OpenAI.
        llm_is_openai = (settings.LLM_PROVIDER or "").strip().lower() == "openai"
        if not api_key and provider == "openai" and llm_is_openai:
            api_key = settings.LLM_API_KEY
            base_url = base_url or settings.LLM_BASE_URL
        return api_key, base_url

    def _get_client(self, mode: str):
        """Build an OpenAI client for *mode*; raise ValueError when no key is set."""
        from openai import OpenAI
        api_key, base_url = self._resolve_credentials(mode)
        if api_key:
            return OpenAI(api_key=api_key, base_url=base_url, max_retries=3)
        raise ValueError(f"{mode.upper()} provider 未配置 API Key")

    def is_available_for_stt(self) -> bool:
        """STT is usable iff an API key resolves for the "stt" mode."""
        key, _ = self._resolve_credentials("stt")
        return bool(key)

    def is_available_for_tts(self) -> bool:
        """TTS is usable iff an API key resolves for the "tts" mode."""
        key, _ = self._resolve_credentials("tts")
        return bool(key)

    def transcribe_bytes(self, content: bytes, filename: str = "input.ogg") -> Optional[str]:
        """Transcribe audio bytes to text.

        Returns None for empty input or on API failure. Raises ValueError when
        the payload exceeds MAX_TRANSCRIBE_BYTES (deliberately not swallowed,
        so callers can surface the size problem to the user).
        """
        if not content:
            return None
        if len(content) > self.MAX_TRANSCRIBE_BYTES:
            raise ValueError("语音文件超过 25MB无法识别")
        try:
            client = self._get_client("stt")
            buffer = BytesIO(content)
            # The SDK inspects the file name to infer the audio container.
            buffer.name = filename
            result = client.audio.transcriptions.create(
                model=settings.AI_VOICE_STT_MODEL,
                file=buffer,
                language=settings.AI_VOICE_LANGUAGE or "zh",
                response_format="verbose_json",
            )
            transcript = getattr(result, "text", None)
            return transcript.strip() if transcript else None
        except Exception as err:
            logger.error(f"语音转文字失败: provider={self.name}, error={err}")
            return None

    def synthesize_speech(self, text: str) -> Optional[Path]:
        """Render *text* to an .opus file under TEMP_PATH/voice; None on failure."""
        if not text:
            return None
        try:
            client = self._get_client("tts")
            target_dir = settings.TEMP_PATH / "voice"
            target_dir.mkdir(parents=True, exist_ok=True)
            target = target_dir / f"{uuid4().hex}.opus"
            speech = client.audio.speech.create(
                model=settings.AI_VOICE_TTS_MODEL,
                voice=settings.AI_VOICE_TTS_VOICE,
                input=text,
                response_format="opus",
            )
            speech.write_to_file(target)
            return target
        except Exception as err:
            logger.error(f"文字转语音失败: provider={self.name}, error={err}")
            return None
class VoiceHelper:
    """Facade that routes STT/TTS requests to the configured voice provider."""

    # Registry of known providers, keyed by lower-cased provider name.
    _providers: Dict[str, VoiceProvider] = {
        "openai": OpenAIVoiceProvider(),
    }

    @classmethod
    def register_provider(cls, provider: VoiceProvider) -> None:
        """Register (or replace) a provider under its lower-cased name."""
        key = provider.name.lower()
        cls._providers[key] = provider

    @staticmethod
    def _resolve_provider_name(mode: str) -> str:
        """Return the configured provider name for *mode*, defaulting to "openai"."""
        if mode.lower() == "stt":
            configured = settings.AI_VOICE_STT_PROVIDER
        else:
            configured = settings.AI_VOICE_TTS_PROVIDER
        configured = configured or settings.AI_VOICE_PROVIDER
        return (configured or "openai").strip().lower()

    @classmethod
    def get_provider(cls, mode: str) -> Optional[VoiceProvider]:
        """Look up the provider for *mode*; warn and return None when unregistered."""
        wanted = cls._resolve_provider_name(mode)
        match = cls._providers.get(wanted)
        if match:
            return match
        logger.warning(f"未注册语音 provider: mode={mode}, provider={wanted}")
        return None

    @classmethod
    def get_registered_providers(cls) -> list[str]:
        """Return all registered provider names, sorted alphabetically."""
        return sorted(cls._providers)

    @classmethod
    def is_available(cls, mode: Optional[str] = None) -> bool:
        """Whether voice support is usable for *mode*, or for either mode when None."""
        if not mode:
            return cls.is_available("stt") or cls.is_available("tts")
        provider = cls.get_provider(mode)
        if not provider:
            return False
        if mode.lower() == "stt":
            return provider.is_available_for_stt()
        return provider.is_available_for_tts()

    @classmethod
    def transcribe_bytes(cls, content: bytes, filename: str = "input.ogg") -> Optional[str]:
        """Convert audio bytes to text via the STT provider; None if unavailable."""
        stt = cls.get_provider("stt")
        return stt.transcribe_bytes(content=content, filename=filename) if stt else None

    @classmethod
    def synthesize_speech(cls, text: str) -> Optional[Path]:
        """Convert text to a speech file via the TTS provider; None if unavailable."""
        tts = cls.get_provider("tts")
        return tts.synthesize_speech(text=text) if tts else None