diff --git a/app/helper/voice.py b/app/helper/voice.py index 59b132e6..af633eb1 100644 --- a/app/helper/voice.py +++ b/app/helper/voice.py @@ -13,7 +13,7 @@ from app.log import logger class VoiceProvider(ABC): """语音 provider 抽象层。""" - MAX_TRANSCRIBE_BYTES = 25 * 1024 * 1024 + MAX_TRANSCRIBE_BYTES = 10 * 1024 * 1024 @property @abstractmethod @@ -49,10 +49,8 @@ class OpenAIVoiceProvider(VoiceProvider): provider = settings.AI_VOICE_PROVIDER or "openai" return provider.strip().lower() - @staticmethod - def _resolve_credentials(mode: str) -> tuple[Optional[str], Optional[str]]: - mode = mode.lower() - provider = OpenAIVoiceProvider._resolve_provider_name() + def _resolve_credentials(self) -> tuple[Optional[str], Optional[str]]: + provider = self._resolve_provider_name() api_key = settings.AI_VOICE_API_KEY base_url = settings.AI_VOICE_BASE_URL @@ -69,17 +67,17 @@ class OpenAIVoiceProvider(VoiceProvider): def _get_client(self, mode: str): from openai import OpenAI - api_key, base_url = self._resolve_credentials(mode) + api_key, base_url = self._resolve_credentials() if not api_key: raise ValueError(f"{mode.upper()} provider 未配置 API Key") return OpenAI(api_key=api_key, base_url=base_url, max_retries=3) def is_available_for_stt(self) -> bool: - api_key, _ = self._resolve_credentials("stt") + api_key, _ = self._resolve_credentials() return bool(api_key) def is_available_for_tts(self) -> bool: - api_key, _ = self._resolve_credentials("tts") + api_key, _ = self._resolve_credentials() return bool(api_key) def transcribe_bytes(self, content: bytes, filename: str = "input.ogg") -> Optional[str]: @@ -144,14 +142,12 @@ class VoiceHelper: """音频输入输出总开关,以显式配置为准。""" return bool(settings.LLM_SUPPORT_AUDIO_INPUT_OUTPUT) - @staticmethod - def _resolve_provider_name(mode: str) -> str: - del mode - return OpenAIVoiceProvider._resolve_provider_name() + def _resolve_provider_name(self) -> str: + return self._resolve_provider_name() @classmethod def get_provider(cls, mode: str) -> Optional[VoiceProvider]: - provider_name = cls._resolve_provider_name(mode) + provider_name = cls._resolve_provider_name() provider = cls._providers.get(provider_name) if provider: return provider