mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-05-12 02:47:11 +08:00
feat(agent): add voice message support with TTS/STT for Telegram and WeChat
- Integrate voice message handling: detect and extract audio references from Telegram and WeChat messages, route to agent with voice reply preference. - Add voice provider abstraction and OpenAI-based TTS/STT implementation. - Implement agent tool `send_voice_message` for generating and sending voice replies, with fallback to text if voice is unavailable. - Extend agent prompt and context to support voice reply instructions. - Update notification and message schemas to support audio fields. - Add Telegram and WeChat voice sending logic, including audio file conversion and temporary media upload for WeChat. - Add tests for voice helper and agent voice routing.
This commit is contained in:
@@ -2,11 +2,12 @@ import base64
|
||||
import json
|
||||
import unittest
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock, patch
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
from telebot import apihelper
|
||||
|
||||
from app.agent.tools.impl.send_message import SendMessageInput
|
||||
from app.agent import MoviePilotAgent, AgentChain
|
||||
from app.chain.message import MessageChain
|
||||
from app.core.config import settings
|
||||
from app.modules.discord import DiscordModule
|
||||
@@ -23,6 +24,19 @@ from app.schemas.types import MessageChannel
|
||||
|
||||
|
||||
class AgentImageSupportTest(unittest.TestCase):
|
||||
def test_telegram_extract_audio_refs_returns_prefixed_file_ids(self):
|
||||
audio_refs = TelegramModule._extract_audio_refs(
|
||||
{
|
||||
"voice": {"file_id": "voice-1"},
|
||||
"audio": {"file_id": "audio-1"},
|
||||
}
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
audio_refs,
|
||||
["tg://voice_file_id/voice-1", "tg://audio_file_id/audio-1"],
|
||||
)
|
||||
|
||||
def test_telegram_extract_images_returns_prefixed_file_ids(self):
|
||||
images = TelegramModule._extract_images(
|
||||
{
|
||||
@@ -126,6 +140,25 @@ class AgentImageSupportTest(unittest.TestCase):
|
||||
self.assertEqual(handle_kwargs["text"], "")
|
||||
self.assertEqual(handle_kwargs["images"], ["tg://file_id/image-1"])
|
||||
|
||||
def test_process_allows_audio_only_message(self):
|
||||
chain = MessageChain()
|
||||
message = CommingMessage(
|
||||
channel=MessageChannel.Telegram,
|
||||
source="telegram-test",
|
||||
userid="10001",
|
||||
username="tester",
|
||||
audio_refs=["tg://voice_file_id/voice-1"],
|
||||
)
|
||||
|
||||
with patch.object(chain, "message_parser", return_value=message), patch.object(
|
||||
chain, "handle_message"
|
||||
) as handle_message:
|
||||
chain.process(body="{}", form={}, args={"source": "telegram-test"})
|
||||
|
||||
handle_kwargs = handle_message.call_args.kwargs
|
||||
self.assertEqual(handle_kwargs["text"], "")
|
||||
self.assertEqual(handle_kwargs["audio_refs"], ["tg://voice_file_id/voice-1"])
|
||||
|
||||
def test_image_message_routes_to_agent_even_when_global_agent_is_disabled(self):
|
||||
chain = MessageChain()
|
||||
|
||||
@@ -149,6 +182,48 @@ class AgentImageSupportTest(unittest.TestCase):
|
||||
|
||||
handle_ai_message.assert_called_once()
|
||||
|
||||
def test_audio_message_routes_to_agent_with_voice_reply_flag(self):
|
||||
chain = MessageChain()
|
||||
|
||||
with patch.object(chain, "load_cache", return_value={}), patch.object(
|
||||
chain, "_transcribe_audio_refs", return_value="帮我推荐一部电影"
|
||||
), patch.object(chain.messagehelper, "put"), patch.object(
|
||||
chain.messageoper, "add"
|
||||
), patch.object(chain, "_handle_ai_message") as handle_ai_message:
|
||||
chain.handle_message(
|
||||
channel=MessageChannel.Telegram,
|
||||
source="telegram-test",
|
||||
userid="10001",
|
||||
username="tester",
|
||||
text="",
|
||||
audio_refs=["tg://voice_file_id/voice-1"],
|
||||
)
|
||||
|
||||
handle_ai_message.assert_called_once()
|
||||
self.assertEqual(handle_ai_message.call_args.kwargs["text"], "帮我推荐一部电影")
|
||||
self.assertTrue(handle_ai_message.call_args.kwargs["reply_with_voice"])
|
||||
|
||||
def test_agent_send_agent_message_does_not_auto_convert_to_voice(self):
|
||||
agent = MoviePilotAgent(
|
||||
session_id="session-1",
|
||||
user_id="user-1",
|
||||
channel=MessageChannel.Telegram.value,
|
||||
source="telegram-test",
|
||||
username="tester",
|
||||
)
|
||||
agent.reply_with_voice = True
|
||||
|
||||
with patch.object(
|
||||
AgentChain, "async_post_message", new_callable=AsyncMock
|
||||
) as async_post_message:
|
||||
import asyncio
|
||||
|
||||
asyncio.run(agent.send_agent_message("这是语音回复"))
|
||||
|
||||
notification = async_post_message.await_args.args[0]
|
||||
self.assertIsNone(notification.voice_path)
|
||||
self.assertEqual(notification.text, "这是语音回复")
|
||||
|
||||
def test_slack_images_use_authenticated_data_url_download(self):
|
||||
chain = MessageChain()
|
||||
|
||||
|
||||
55
tests/test_voice_helper.py
Normal file
55
tests/test_voice_helper.py
Normal file
@@ -0,0 +1,55 @@
|
||||
import unittest
|
||||
import sys
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
sys.modules.setdefault("psutil", Mock())
|
||||
sys.modules.setdefault("pyquery", Mock())
|
||||
|
||||
from app.core.config import settings
|
||||
from app.helper.voice import VoiceHelper, OpenAIVoiceProvider
|
||||
|
||||
|
||||
class VoiceHelperTest(unittest.TestCase):
|
||||
def test_registered_providers_contains_openai(self):
|
||||
self.assertIn("openai", VoiceHelper.get_registered_providers())
|
||||
|
||||
def test_get_provider_falls_back_to_global_provider(self):
|
||||
with patch.object(settings, "AI_VOICE_PROVIDER", "openai"), patch.object(
|
||||
settings, "AI_VOICE_STT_PROVIDER", None
|
||||
):
|
||||
provider = VoiceHelper.get_provider("stt")
|
||||
|
||||
self.assertIsInstance(provider, OpenAIVoiceProvider)
|
||||
|
||||
def test_is_available_checks_stt_and_tts_separately(self):
|
||||
provider = Mock()
|
||||
provider.is_available_for_stt.return_value = True
|
||||
provider.is_available_for_tts.return_value = False
|
||||
|
||||
with patch.object(VoiceHelper, "get_provider", return_value=provider):
|
||||
self.assertTrue(VoiceHelper.is_available("stt"))
|
||||
self.assertFalse(VoiceHelper.is_available("tts"))
|
||||
|
||||
def test_transcribe_bytes_routes_to_stt_provider(self):
|
||||
provider = Mock()
|
||||
provider.transcribe_bytes.return_value = "你好"
|
||||
|
||||
with patch.object(VoiceHelper, "get_provider", return_value=provider):
|
||||
result = VoiceHelper.transcribe_bytes(b"audio")
|
||||
|
||||
self.assertEqual(result, "你好")
|
||||
provider.transcribe_bytes.assert_called_once()
|
||||
|
||||
def test_synthesize_speech_routes_to_tts_provider(self):
|
||||
provider = Mock()
|
||||
provider.synthesize_speech.return_value = "/tmp/reply.opus"
|
||||
|
||||
with patch.object(VoiceHelper, "get_provider", return_value=provider):
|
||||
result = VoiceHelper.synthesize_speech("你好")
|
||||
|
||||
self.assertEqual(result, "/tmp/reply.opus")
|
||||
provider.synthesize_speech.assert_called_once_with(text="你好")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user