mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-05-09 06:42:38 +08:00
feat(agent): add voice message support with TTS/STT for Telegram and WeChat
- Integrate voice message handling: detect and extract audio references from Telegram and WeChat messages, route to agent with voice reply preference.
- Add voice provider abstraction and OpenAI-based TTS/STT implementation.
- Implement agent tool `send_voice_message` for generating and sending voice replies, with fallback to text if voice is unavailable.
- Extend agent prompt and context to support voice reply instructions.
- Update notification and message schemas to support audio fields.
- Add Telegram and WeChat voice sending logic, including audio file conversion and temporary media upload for WeChat.
- Add tests for voice helper and agent voice routing.
This commit is contained in:
@@ -129,6 +129,8 @@ class MoviePilotAgent:
|
||||
self.channel = channel
|
||||
self.source = source
|
||||
self.username = username
|
||||
self.reply_with_voice = False
|
||||
self._tool_context: Dict[str, object] = {}
|
||||
|
||||
# 流式token管理
|
||||
self.stream_handler = StreamingHandler()
|
||||
@@ -151,6 +153,8 @@ class MoviePilotAgent:
|
||||
"""
|
||||
if self.is_background:
|
||||
return False
|
||||
if self.reply_with_voice:
|
||||
return False
|
||||
# 啰嗦模式下始终需要流式输出来捕获工具调用前的 Agent 文字
|
||||
if settings.AI_AGENT_VERBOSE:
|
||||
return True
|
||||
@@ -214,6 +218,7 @@ class MoviePilotAgent:
|
||||
source=self.source,
|
||||
username=self.username,
|
||||
stream_handler=self.stream_handler,
|
||||
agent_context=self._tool_context,
|
||||
)
|
||||
|
||||
def _create_agent(self, streaming: bool = False):
|
||||
@@ -223,7 +228,10 @@ class MoviePilotAgent:
|
||||
"""
|
||||
try:
|
||||
# 系统提示词
|
||||
system_prompt = prompt_manager.get_agent_prompt(channel=self.channel)
|
||||
system_prompt = prompt_manager.get_agent_prompt(
|
||||
channel=self.channel,
|
||||
prefer_voice_reply=self.reply_with_voice,
|
||||
)
|
||||
|
||||
# LLM 模型(用于 agent 执行)
|
||||
llm = self._initialize_llm(streaming=streaming)
|
||||
@@ -281,6 +289,11 @@ class MoviePilotAgent:
|
||||
logger.info(
|
||||
f"Agent推理: session_id={self.session_id}, input={message}, images={len(images) if images else 0}"
|
||||
)
|
||||
self._tool_context = {
|
||||
"incoming_voice": self.reply_with_voice,
|
||||
"user_reply_sent": False,
|
||||
"reply_mode": None,
|
||||
}
|
||||
|
||||
# 获取历史消息
|
||||
messages = memory_manager.get_agent_messages(
|
||||
@@ -417,7 +430,7 @@ class MoviePilotAgent:
|
||||
# 流式输出未能发送全部内容(发送失败等)
|
||||
# 通过常规方式发送剩余内容
|
||||
remaining_text = await self.stream_handler.take()
|
||||
if remaining_text:
|
||||
if remaining_text and not self._tool_context.get("user_reply_sent"):
|
||||
await self.send_agent_message(remaining_text)
|
||||
elif streamed_text:
|
||||
# 流式输出已发送全部内容,但未记录到数据库,补充保存消息记录
|
||||
@@ -447,7 +460,7 @@ class MoviePilotAgent:
|
||||
final_text = text.strip()
|
||||
break
|
||||
|
||||
if final_text:
|
||||
if final_text and not self._tool_context.get("user_reply_sent"):
|
||||
if self.is_background:
|
||||
# 后台任务仅广播最终回复,带标题
|
||||
await self.send_agent_message(
|
||||
@@ -534,6 +547,7 @@ class _MessageTask:
|
||||
channel: Optional[str] = None
|
||||
source: Optional[str] = None
|
||||
username: Optional[str] = None
|
||||
reply_with_voice: bool = False
|
||||
|
||||
|
||||
class AgentManager:
|
||||
@@ -599,6 +613,7 @@ class AgentManager:
|
||||
channel: str = None,
|
||||
source: str = None,
|
||||
username: str = None,
|
||||
reply_with_voice: bool = False,
|
||||
) -> str:
|
||||
"""
|
||||
处理用户消息:将消息放入会话队列,按顺序依次处理。
|
||||
@@ -612,6 +627,7 @@ class AgentManager:
|
||||
channel=channel,
|
||||
source=source,
|
||||
username=username,
|
||||
reply_with_voice=reply_with_voice,
|
||||
)
|
||||
|
||||
# 获取或创建会话队列
|
||||
@@ -709,6 +725,7 @@ class AgentManager:
|
||||
agent.source = task.source
|
||||
if task.username:
|
||||
agent.username = task.username
|
||||
agent.reply_with_voice = task.reply_with_voice
|
||||
|
||||
return await agent.process(task.message, images=task.images)
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@ Core Capabilities:
|
||||
- Include key details (year, rating, resolution) but do NOT over-explain.
|
||||
- Do not stop for approval on read-only operations. Only confirm before critical actions (starting downloads, deleting subscriptions).
|
||||
- If the current channel supports image sending and an image would materially help, you may use the `send_message` tool with `image_url` to send it.
|
||||
- Voice replies: {voice_reply_spec}
|
||||
- NOT a coding assistant. Do not offer code snippets.
|
||||
- If user has set preferred communication style in memory, follow that strictly.
|
||||
</communication>
|
||||
|
||||
@@ -50,10 +50,13 @@ class PromptManager:
|
||||
logger.error(f"加载提示词失败: {prompt_name}, 错误: {e}")
|
||||
raise
|
||||
|
||||
def get_agent_prompt(self, channel: str = None) -> str:
|
||||
def get_agent_prompt(
|
||||
self, channel: str = None, prefer_voice_reply: bool = False
|
||||
) -> str:
|
||||
"""
|
||||
获取智能体提示词
|
||||
:param channel: 消息渠道(Telegram、微信、Slack等)
|
||||
:param prefer_voice_reply: 是否优先使用语音回复
|
||||
:return: 提示词内容
|
||||
"""
|
||||
# 基础提示词
|
||||
@@ -87,12 +90,16 @@ class PromptManager:
|
||||
|
||||
# MoviePilot系统信息
|
||||
moviepilot_info = self._get_moviepilot_info()
|
||||
voice_reply_spec = self._generate_voice_reply_instructions(
|
||||
prefer_voice_reply=prefer_voice_reply
|
||||
)
|
||||
|
||||
# 始终替换占位符,避免后续 .format() 时因残留花括号报 KeyError
|
||||
base_prompt = base_prompt.format(
|
||||
markdown_spec=markdown_spec,
|
||||
verbose_spec=verbose_spec,
|
||||
moviepilot_info=moviepilot_info,
|
||||
voice_reply_spec=voice_reply_spec,
|
||||
)
|
||||
|
||||
return base_prompt
|
||||
@@ -166,6 +173,20 @@ class PromptManager:
|
||||
instructions.append("- Links: Paste URLs directly as text.")
|
||||
return "\n".join(instructions)
|
||||
|
||||
@staticmethod
|
||||
def _generate_voice_reply_instructions(prefer_voice_reply: bool) -> str:
|
||||
if not prefer_voice_reply:
|
||||
return (
|
||||
"- Voice replies: Use normal text replies by default. "
|
||||
"Only call `send_voice_message` when spoken playback is clearly better than plain text."
|
||||
)
|
||||
return (
|
||||
"- Current message context: The user sent a voice message.\n"
|
||||
"- Reply preference: Prioritize calling `send_voice_message` for the main user-facing reply.\n"
|
||||
"- Fallback: If voice is unavailable on the current channel, `send_voice_message` will fall back to text.\n"
|
||||
"- Do not repeat the same full reply again after calling `send_voice_message`."
|
||||
)
|
||||
|
||||
def clear_cache(self):
|
||||
"""
|
||||
清空缓存
|
||||
|
||||
@@ -31,6 +31,7 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
_username: Optional[str] = PrivateAttr(default=None)
|
||||
_stream_handler: Optional[StreamingHandler] = PrivateAttr(default=None)
|
||||
_require_admin: bool = PrivateAttr(default=False)
|
||||
_agent_context: dict = PrivateAttr(default_factory=dict)
|
||||
|
||||
def __init__(self, session_id: str, user_id: str, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
@@ -142,6 +143,12 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
"""
|
||||
self._stream_handler = stream_handler
|
||||
|
||||
def set_agent_context(self, agent_context: Optional[dict]):
    """
    Attach the context dict shared with the current agent.

    The dict is stored by reference so that values a tool writes into it are
    visible to whoever passed it in. An empty dict is a valid shared context
    and must be kept as-is; only ``None`` is replaced with a fresh dict.
    (The previous ``agent_context or {}`` silently swapped an empty shared
    dict for a private one, breaking the sharing.)

    :param agent_context: context dict shared with the agent, or None
    """
    self._agent_context = {} if agent_context is None else agent_context
|
||||
|
||||
async def _check_permission(self) -> Optional[str]:
|
||||
"""
|
||||
检查用户权限:
|
||||
|
||||
@@ -30,6 +30,7 @@ from app.agent.tools.impl.search_torrents import SearchTorrentsTool
|
||||
from app.agent.tools.impl.get_search_results import GetSearchResultsTool
|
||||
from app.agent.tools.impl.search_web import SearchWebTool
|
||||
from app.agent.tools.impl.send_message import SendMessageTool
|
||||
from app.agent.tools.impl.send_voice_message import SendVoiceMessageTool
|
||||
from app.agent.tools.impl.query_schedulers import QuerySchedulersTool
|
||||
from app.agent.tools.impl.run_scheduler import RunSchedulerTool
|
||||
from app.agent.tools.impl.query_workflows import QueryWorkflowsTool
|
||||
@@ -72,6 +73,7 @@ class MoviePilotToolFactory:
|
||||
source: str = None,
|
||||
username: str = None,
|
||||
stream_handler: Callable = None,
|
||||
agent_context: dict = None,
|
||||
) -> List[MoviePilotTool]:
|
||||
"""
|
||||
创建MoviePilot工具列表
|
||||
@@ -117,6 +119,7 @@ class MoviePilotToolFactory:
|
||||
QueryTransferHistoryTool,
|
||||
TransferFileTool,
|
||||
SendMessageTool,
|
||||
SendVoiceMessageTool,
|
||||
QuerySchedulersTool,
|
||||
RunSchedulerTool,
|
||||
QueryWorkflowsTool,
|
||||
@@ -138,6 +141,7 @@ class MoviePilotToolFactory:
|
||||
tool = ToolClass(session_id=session_id, user_id=user_id)
|
||||
tool.set_message_attr(channel=channel, source=source, username=username)
|
||||
tool.set_stream_handler(stream_handler=stream_handler)
|
||||
tool.set_agent_context(agent_context=agent_context)
|
||||
tools.append(tool)
|
||||
|
||||
# 加载插件提供的工具
|
||||
@@ -161,6 +165,7 @@ class MoviePilotToolFactory:
|
||||
channel=channel, source=source, username=username
|
||||
)
|
||||
tool.set_stream_handler(stream_handler=stream_handler)
|
||||
tool.set_agent_context(agent_context=agent_context)
|
||||
tools.append(tool)
|
||||
plugin_tools_count += 1
|
||||
logger.debug(
|
||||
|
||||
96
app/agent/tools/impl/send_voice_message.py
Normal file
96
app/agent/tools/impl/send_voice_message.py
Normal file
@@ -0,0 +1,96 @@
|
||||
"""发送语音消息工具。"""
|
||||
|
||||
import asyncio
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool, ToolChain
|
||||
from app.core.config import settings
|
||||
from app.helper.voice import VoiceHelper
|
||||
from app.helper.service import ServiceConfigHelper
|
||||
from app.log import logger
|
||||
from app.schemas import Notification, NotificationType
|
||||
from app.schemas.types import MessageChannel
|
||||
|
||||
|
||||
class SendVoiceMessageInput(BaseModel):
    """Input schema for the send-voice-message tool."""

    # Required: the model's justification for choosing a voice reply.
    explanation: str = Field(..., description="Clear explanation of why a voice reply is the best fit in the current context")
    # Required: the text that will be spoken (or sent as text on fallback).
    message: str = Field(..., description="The spoken content to send back to the user")
|
||||
|
||||
|
||||
class SendVoiceMessageTool(MoviePilotTool):
    """Agent tool that replies with synthesized speech and falls back to plain text."""

    name: str = "send_voice_message"
    description: str = (
        "Send a voice reply to the current user. Prefer this when the user sent a voice message "
        "or when spoken playback is more natural. On channels without voice support or when TTS "
        "is unavailable, it automatically falls back to sending the same content as plain text."
    )
    args_schema: Type[BaseModel] = SendVoiceMessageInput
    require_admin: bool = False

    def get_tool_message(self, **kwargs) -> Optional[str]:
        """
        Build the progress text for this tool call.

        :param kwargs: tool arguments; only ``message`` is read
        :return: status line containing at most the first 40 characters of the message
        """
        preview = kwargs.get("message") or ""
        if len(preview) > 40:
            preview = f"{preview[:40]}..."
        return f"正在发送语音回复: {preview}"

    def _supports_real_voice_reply(self) -> bool:
        """
        Tell whether the current channel can deliver an actual voice message.

        Telegram always qualifies. WeChat qualifies only when the notification
        config whose name equals the current source is not in ``bot`` mode
        (``WECHAT_MODE`` defaults to ``app``). Every other channel, or a WeChat
        source with no matching config, does not.

        :return: True when a real voice reply can be sent
        """
        current = self._channel or ""
        if current == MessageChannel.Telegram.value:
            return True
        if current == MessageChannel.Wechat.value:
            # Only the first config matching the source name decides the outcome.
            matched = next(
                (
                    candidate
                    for candidate in ServiceConfigHelper.get_notification_configs()
                    if candidate.name == self._source
                ),
                None,
            )
            if matched is not None:
                return (matched.config or {}).get("WECHAT_MODE", "app") != "bot"
        return False

    async def run(self, message: str, **kwargs) -> str:
        """
        Send ``message`` to the current user, as voice when possible.

        Speech is synthesized in a worker thread when the channel supports
        voice and TTS is available; otherwise (or when synthesis yields
        nothing) the notification is posted without a voice file. After
        posting, the shared agent context is flagged so the agent knows the
        user-facing reply has already been delivered and in which mode.

        :param message: content to speak / send
        :param kwargs: remaining tool arguments (unused here)
        :return: short status text describing which reply mode was used
        """
        if not message:
            return "语音回复内容不能为空"

        channel = self._channel or ""
        voice_path = None
        if self._supports_real_voice_reply() and VoiceHelper.is_available("tts"):
            synthesized = await asyncio.to_thread(VoiceHelper.synthesize_speech, message)
            if synthesized:
                voice_path = str(synthesized)
        # A voice path exists exactly when synthesis produced a file.
        used_voice = voice_path is not None

        logger.info(
            "执行工具: %s, channel=%s, use_voice=%s, text_len=%s",
            self.name,
            channel,
            used_voice,
            len(message),
        )

        await ToolChain().async_post_message(
            Notification(
                channel=self._channel,
                source=self._source,
                mtype=NotificationType.Agent,
                userid=self._user_id,
                username=self._username,
                text=message,
                voice_path=voice_path,
                voice_caption=message if settings.AI_VOICE_REPLY_WITH_TEXT else None,
            )
        )
        self._agent_context.update(
            user_reply_sent=True,
            reply_mode="voice" if used_voice else "text_fallback",
        )

        return "语音回复已发送" if used_voice else "当前未使用语音通道,已自动回退为文字回复"
|
||||
@@ -4,6 +4,8 @@ import time
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Optional, Dict, Union, List
|
||||
|
||||
import base64
|
||||
|
||||
from app.agent import agent_manager
|
||||
from app.chain import ChainBase
|
||||
from app.chain.download import DownloadChain
|
||||
@@ -15,6 +17,7 @@ from app.core.context import MediaInfo, Context
|
||||
from app.core.meta import MetaBase
|
||||
from app.db.user_oper import UserOper
|
||||
from app.helper.torrent import TorrentHelper
|
||||
from app.helper.voice import VoiceHelper
|
||||
from app.log import logger
|
||||
from app.schemas import Notification, NotExistMediaInfo, CommingMessage
|
||||
from app.schemas.message import ChannelCapabilityManager
|
||||
@@ -44,26 +47,6 @@ class MessageChain(ChainBase):
|
||||
# 会话超时时间(分钟)
|
||||
_session_timeout_minutes: int = 24 * 60
|
||||
|
||||
@staticmethod
|
||||
def _summarize_images(images: Optional[List[str]], max_items: int = 3) -> List[str]:
|
||||
"""
|
||||
图片引用摘要,避免日志过长或直接输出完整 base64。
|
||||
"""
|
||||
if not images:
|
||||
return []
|
||||
summary = []
|
||||
for image in images[:max_items]:
|
||||
if not image:
|
||||
continue
|
||||
image = str(image)
|
||||
if image.startswith("data:"):
|
||||
summary.append(f"{image[:32]}...({len(image)} chars)")
|
||||
elif len(image) > 120:
|
||||
summary.append(f"{image[:117]}...")
|
||||
else:
|
||||
summary.append(image)
|
||||
return summary
|
||||
|
||||
@staticmethod
|
||||
def __get_noexits_info(
|
||||
_meta: MetaBase, _mediainfo: MediaInfo
|
||||
@@ -146,23 +129,15 @@ class MessageChain(ChainBase):
|
||||
if userid is None or userid == "":
|
||||
logger.debug(f"未识别到用户ID:{body}{form}{args}")
|
||||
return
|
||||
|
||||
# 消息内容
|
||||
text = str(info.text).strip() if info.text else ""
|
||||
images = info.images
|
||||
if not text and not images:
|
||||
audio_refs = info.audio_refs
|
||||
if not text and not images and not audio_refs:
|
||||
logger.debug(f"未识别到消息内容::{body}{form}{args}")
|
||||
return
|
||||
|
||||
logger.info(
|
||||
"消息链路解析完成: source=%s, channel=%s, userid=%s, text_len=%s, image_count=%s, image_refs=%s",
|
||||
source,
|
||||
channel.value if channel else None,
|
||||
userid,
|
||||
len(text),
|
||||
len(images or []),
|
||||
self._summarize_images(images),
|
||||
)
|
||||
|
||||
# 获取原消息ID信息
|
||||
original_message_id = info.message_id
|
||||
original_chat_id = info.chat_id
|
||||
@@ -177,6 +152,7 @@ class MessageChain(ChainBase):
|
||||
original_message_id=original_message_id,
|
||||
original_chat_id=original_chat_id,
|
||||
images=images,
|
||||
audio_refs=audio_refs,
|
||||
)
|
||||
|
||||
def handle_message(
|
||||
@@ -189,25 +165,43 @@ class MessageChain(ChainBase):
|
||||
original_message_id: Optional[Union[str, int]] = None,
|
||||
original_chat_id: Optional[str] = None,
|
||||
images: Optional[List[str]] = None,
|
||||
audio_refs: Optional[List[str]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
识别消息内容,执行操作
|
||||
"""
|
||||
# 申明全局变量
|
||||
global _current_page, _current_meta, _current_media
|
||||
# 处理消息
|
||||
logger.info(
|
||||
"收到用户消息内容: channel=%s, source=%s, userid=%s, text=%s, image_count=%s, image_refs=%s",
|
||||
channel.value if channel else None,
|
||||
source,
|
||||
userid,
|
||||
text,
|
||||
len(images or []),
|
||||
self._summarize_images(images),
|
||||
)
|
||||
|
||||
# 加载缓存
|
||||
user_cache: Dict[str, dict] = self.load_cache(self._cache_file) or {}
|
||||
|
||||
try:
|
||||
# 识别语音为文本
|
||||
reply_with_voice = bool(audio_refs)
|
||||
if audio_refs:
|
||||
transcript = self._transcribe_audio_refs(audio_refs, channel, source)
|
||||
merged_parts = []
|
||||
seen_parts = set()
|
||||
for item in [text.strip() if text else "", transcript or ""]:
|
||||
normalized = item.strip()
|
||||
if not normalized or normalized in seen_parts:
|
||||
continue
|
||||
seen_parts.add(normalized)
|
||||
merged_parts.append(normalized)
|
||||
text = "\n".join(merged_parts).strip()
|
||||
if not text:
|
||||
self.post_message(
|
||||
Notification(
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
title="语音识别失败,请稍后重试",
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
# 保存消息
|
||||
if not text.startswith("CALLBACK:"):
|
||||
self.messagehelper.put(
|
||||
@@ -251,32 +245,6 @@ class MessageChain(ChainBase):
|
||||
{"cmd": text, "user": userid, "channel": channel, "source": source},
|
||||
)
|
||||
elif text.lower().startswith("/ai"):
|
||||
# 用户指定AI智能体消息响应
|
||||
logger.info(
|
||||
"消息链路分流到AI: reason=explicit_ai, channel=%s, source=%s, userid=%s, image_count=%s",
|
||||
channel.value if channel else None,
|
||||
source,
|
||||
userid,
|
||||
len(images or []),
|
||||
)
|
||||
self._handle_ai_message(
|
||||
text=text,
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
images=images,
|
||||
)
|
||||
elif settings.AI_AGENT_ENABLE and images:
|
||||
# 带图消息优先交给智能体处理,避免图片在传统消息链路中丢失
|
||||
logger.info(
|
||||
"消息链路分流到AI: reason=image_message, channel=%s, source=%s, userid=%s, image_count=%s, image_refs=%s",
|
||||
channel.value if channel else None,
|
||||
source,
|
||||
userid,
|
||||
len(images or []),
|
||||
self._summarize_images(images),
|
||||
)
|
||||
self._handle_ai_message(
|
||||
text=text,
|
||||
channel=channel,
|
||||
@@ -284,16 +252,10 @@ class MessageChain(ChainBase):
|
||||
userid=userid,
|
||||
username=username,
|
||||
images=images,
|
||||
reply_with_voice=reply_with_voice,
|
||||
)
|
||||
elif settings.AI_AGENT_ENABLE and settings.AI_AGENT_GLOBAL:
|
||||
# 普通消息,全局智能体响应
|
||||
logger.info(
|
||||
"消息链路分流到AI: reason=global_agent, channel=%s, source=%s, userid=%s, image_count=%s",
|
||||
channel.value if channel else None,
|
||||
source,
|
||||
userid,
|
||||
len(images or []),
|
||||
)
|
||||
self._handle_ai_message(
|
||||
text=text,
|
||||
channel=channel,
|
||||
@@ -301,6 +263,7 @@ class MessageChain(ChainBase):
|
||||
userid=userid,
|
||||
username=username,
|
||||
images=images,
|
||||
reply_with_voice=reply_with_voice,
|
||||
)
|
||||
else:
|
||||
# 非智能体普通消息响应
|
||||
@@ -1266,6 +1229,7 @@ class MessageChain(ChainBase):
|
||||
userid: Union[str, int],
|
||||
username: str,
|
||||
images: Optional[List[str]] = None,
|
||||
reply_with_voice: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
处理AI智能体消息
|
||||
@@ -1290,16 +1254,6 @@ class MessageChain(ChainBase):
|
||||
else:
|
||||
user_message = text.strip() # 按原消息处理
|
||||
|
||||
logger.info(
|
||||
"AI消息入口: channel=%s, source=%s, userid=%s, text_len=%s, raw_image_count=%s, raw_image_refs=%s",
|
||||
channel.value if channel else None,
|
||||
source,
|
||||
userid,
|
||||
len(user_message),
|
||||
len(images or []),
|
||||
self._summarize_images(images),
|
||||
)
|
||||
|
||||
if not user_message and not images:
|
||||
self.post_message(
|
||||
Notification(
|
||||
@@ -1319,15 +1273,6 @@ class MessageChain(ChainBase):
|
||||
original_images = images
|
||||
if images:
|
||||
images = self._download_images_to_base64(images, channel, source)
|
||||
logger.info(
|
||||
"AI图片预处理完成: channel=%s, source=%s, userid=%s, raw_image_count=%s, converted_image_count=%s, converted_image_refs=%s",
|
||||
channel.value if channel else None,
|
||||
source,
|
||||
userid,
|
||||
len(original_images or []),
|
||||
len(images or []),
|
||||
self._summarize_images(images),
|
||||
)
|
||||
if original_images and not images and not user_message:
|
||||
self.post_message(
|
||||
Notification(
|
||||
@@ -1350,6 +1295,7 @@ class MessageChain(ChainBase):
|
||||
channel=channel.value if channel else None,
|
||||
source=source,
|
||||
username=username,
|
||||
reply_with_voice=reply_with_voice,
|
||||
),
|
||||
global_vars.loop,
|
||||
)
|
||||
@@ -1360,6 +1306,64 @@ class MessageChain(ChainBase):
|
||||
f"AI智能体处理失败: {str(e)}", role="system", title="MoviePilot助手"
|
||||
)
|
||||
|
||||
def _transcribe_audio_refs(
    self, audio_refs: List[str], channel: MessageChannel, source: str
) -> Optional[str]:
    """
    Download and transcribe voice messages for the channels wired up so far.

    Telegram voice/audio file ids and WeCom voice media ids are downloaded
    through the module chain and passed to the STT helper. ``wxbot://voice``
    references are skipped silently; any other reference is skipped with a
    debug log. A failure on one reference is logged and does not stop the
    remaining ones.

    :param audio_refs: audio references extracted from the incoming message
    :param channel: message channel, used for logging only
    :param source: message source passed to the download modules
    :return: transcripts joined by newlines, or None when nothing was recognized
    """
    if not audio_refs:
        return None
    if not VoiceHelper.is_available("stt"):
        logger.warning("语音能力未配置,跳过语音识别")
        return None

    # Telegram reference prefix -> filename hint handed to the STT provider.
    telegram_prefixes = (
        ("tg://voice_file_id/", "input.ogg"),
        ("tg://audio_file_id/", "input.mp3"),
    )

    recognized_parts = []
    for audio_ref in audio_refs:
        try:
            telegram_match = next(
                (pair for pair in telegram_prefixes if audio_ref.startswith(pair[0])),
                None,
            )
            if telegram_match:
                prefix, filename = telegram_match
                content = self.run_module(
                    "download_telegram_file_bytes",
                    file_id=audio_ref[len(prefix):],
                    source=source,
                )
            elif audio_ref.startswith("wxwork://voice_media_id/"):
                filename = "input.amr"
                content = self.run_module(
                    "download_wechat_media_bytes", media_ref=audio_ref, source=source
                )
            elif audio_ref.startswith("wxbot://voice"):
                continue
            else:
                logger.debug(
                    "暂不支持的语音引用: channel=%s, source=%s, ref=%s",
                    channel.value if channel else None,
                    source,
                    audio_ref,
                )
                continue

            recognized = VoiceHelper.transcribe_bytes(content=content, filename=filename)
            if not recognized:
                continue
            recognized_parts.append(recognized)
            logger.info(
                "语音识别成功: channel=%s, source=%s, ref=%s, text_len=%s",
                channel.value if channel else None,
                source,
                audio_ref,
                len(recognized),
            )
        except Exception as err:
            # Best effort: one bad reference must not block the others.
            logger.error(f"语音识别失败: {err}")

    if not recognized_parts:
        return None
    return "\n".join(recognized_parts).strip()
|
||||
|
||||
def _download_images_to_base64(
|
||||
self, images: List[str], channel: MessageChannel, source: str
|
||||
) -> List[str]:
|
||||
@@ -1373,16 +1377,10 @@ class MessageChain(ChainBase):
|
||||
try:
|
||||
if img.startswith("data:"):
|
||||
base64_images.append(img)
|
||||
logger.info(
|
||||
"图片无需下载: channel=%s, source=%s, input=%s",
|
||||
channel.value if channel else None,
|
||||
source,
|
||||
self._summarize_images([img])[0],
|
||||
)
|
||||
elif img.startswith("tg://file_id/"):
|
||||
file_id = img.replace("tg://file_id/", "")
|
||||
base64_data = self.run_module(
|
||||
"download_file_to_base64", file_id=file_id, source=source
|
||||
"download_telegram_file_to_base64", file_id=file_id, source=source
|
||||
)
|
||||
if base64_data:
|
||||
base64_images.append(f"data:image/jpeg;base64,{base64_data}")
|
||||
@@ -1402,26 +1400,12 @@ class MessageChain(ChainBase):
|
||||
)
|
||||
if data_url:
|
||||
base64_images.append(data_url)
|
||||
logger.info(
|
||||
"图片下载成功: channel=%s, source=%s, input=%s, output=%s",
|
||||
channel.value if channel else None,
|
||||
source,
|
||||
img,
|
||||
self._summarize_images([data_url])[0],
|
||||
)
|
||||
elif channel == MessageChannel.Slack:
|
||||
data_url = self.run_module(
|
||||
"download_file_to_data_url", file_url=img, source=source
|
||||
"download_slack_file_to_data_url", file_url=img, source=source
|
||||
)
|
||||
if data_url:
|
||||
base64_images.append(data_url)
|
||||
logger.info(
|
||||
"图片下载成功: channel=%s, source=%s, input=%s, output=%s",
|
||||
channel.value if channel else None,
|
||||
source,
|
||||
img,
|
||||
self._summarize_images([data_url])[0],
|
||||
)
|
||||
elif img.startswith("vocechat://file/"):
|
||||
data_url = self.run_module(
|
||||
"download_vocechat_image_to_data_url",
|
||||
@@ -1430,30 +1414,12 @@ class MessageChain(ChainBase):
|
||||
)
|
||||
if data_url:
|
||||
base64_images.append(data_url)
|
||||
logger.info(
|
||||
"图片下载成功: channel=%s, source=%s, input=%s, output=%s",
|
||||
channel.value if channel else None,
|
||||
source,
|
||||
img,
|
||||
self._summarize_images([data_url])[0],
|
||||
)
|
||||
elif img.startswith("http"):
|
||||
resp = RequestUtils(timeout=30).get_res(img)
|
||||
if resp and resp.content:
|
||||
import base64
|
||||
|
||||
base64_data = base64.b64encode(resp.content).decode()
|
||||
mime_type = resp.headers.get("Content-Type", "image/jpeg")
|
||||
base64_images.append(f"data:{mime_type};base64,{base64_data}")
|
||||
logger.info(
|
||||
"图片下载成功: channel=%s, source=%s, input=%s, output=%s",
|
||||
channel.value if channel else None,
|
||||
source,
|
||||
img,
|
||||
self._summarize_images(
|
||||
[f"data:{mime_type};base64,{base64_data}"]
|
||||
)[0],
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"下载图片失败: {img}, error: {e}")
|
||||
return base64_images if base64_images else None
|
||||
|
||||
@@ -538,6 +538,35 @@ class ConfigModel(BaseModel):
|
||||
# AI智能体自动重试整理失败记录开关
|
||||
AI_AGENT_RETRY_TRANSFER: bool = False
|
||||
|
||||
# 语音能力提供商(当前仅支持 openai)
|
||||
AI_VOICE_PROVIDER: str = "openai"
|
||||
# 语音识别提供商,未设置时回退到 AI_VOICE_PROVIDER
|
||||
AI_VOICE_STT_PROVIDER: Optional[str] = None
|
||||
# 语音合成提供商,未设置时回退到 AI_VOICE_PROVIDER
|
||||
AI_VOICE_TTS_PROVIDER: Optional[str] = None
|
||||
# 语音能力 API 密钥,未设置且 LLM_PROVIDER=openai 时回退使用 LLM_API_KEY
|
||||
AI_VOICE_API_KEY: Optional[str] = None
|
||||
# 语音识别 API 密钥,未设置时回退到 AI_VOICE_API_KEY
|
||||
AI_VOICE_STT_API_KEY: Optional[str] = None
|
||||
# 语音合成 API 密钥,未设置时回退到 AI_VOICE_API_KEY
|
||||
AI_VOICE_TTS_API_KEY: Optional[str] = None
|
||||
# 语音能力基础URL,未设置且 LLM_PROVIDER=openai 时回退使用 LLM_BASE_URL
|
||||
AI_VOICE_BASE_URL: Optional[str] = None
|
||||
# 语音识别基础URL,未设置时回退到 AI_VOICE_BASE_URL
|
||||
AI_VOICE_STT_BASE_URL: Optional[str] = None
|
||||
# 语音合成基础URL,未设置时回退到 AI_VOICE_BASE_URL
|
||||
AI_VOICE_TTS_BASE_URL: Optional[str] = None
|
||||
# 语音转文字模型
|
||||
AI_VOICE_STT_MODEL: str = "gpt-4o-mini-transcribe"
|
||||
# 文字转语音模型
|
||||
AI_VOICE_TTS_MODEL: str = "gpt-4o-mini-tts"
|
||||
# TTS 发音人
|
||||
AI_VOICE_TTS_VOICE: str = "alloy"
|
||||
# 语音识别语言
|
||||
AI_VOICE_LANGUAGE: str = "zh"
|
||||
# 回复语音时是否同时附带文字说明
|
||||
AI_VOICE_REPLY_WITH_TEXT: bool = False
|
||||
|
||||
|
||||
class Settings(BaseSettings, ConfigModel, LogConfigModel):
|
||||
"""
|
||||
|
||||
197
app/helper/voice.py
Normal file
197
app/helper/voice.py
Normal file
@@ -0,0 +1,197 @@
|
||||
"""语音能力辅助功能。"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from app.core.config import settings
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class VoiceProvider(ABC):
    """Abstract interface implemented by every voice (STT/TTS) provider."""

    # Largest audio payload, in bytes, accepted for transcription (25 MiB).
    MAX_TRANSCRIBE_BYTES = 25 * 1024 ** 2

    @property
    @abstractmethod
    def name(self) -> str:
        """Name of the provider."""

    @abstractmethod
    def is_available_for_stt(self) -> bool:
        """Return whether the provider can be used for speech-to-text."""

    @abstractmethod
    def is_available_for_tts(self) -> bool:
        """Return whether the provider can be used for text-to-speech."""

    @abstractmethod
    def transcribe_bytes(self, content: bytes, filename: str = "input.ogg") -> Optional[str]:
        """Convert raw audio bytes into text."""

    @abstractmethod
    def synthesize_speech(self, text: str) -> Optional[Path]:
        """Convert text into a speech audio file and return its path."""
|
||||
|
||||
|
||||
class OpenAIVoiceProvider(VoiceProvider):
    """Voice provider backed by the OpenAI (or an OpenAI-compatible) audio API."""

    @property
    def name(self) -> str:
        """Name of the provider."""
        return "openai"

    @staticmethod
    def _resolve_credentials(mode: str) -> tuple[Optional[str], Optional[str]]:
        """
        Resolve the API key and base URL for the given mode.

        The mode-specific setting wins, then the shared ``AI_VOICE_*`` setting.
        When no key is configured at all and both the voice provider and
        ``LLM_PROVIDER`` are ``openai``, the LLM key is reused, together with
        the LLM base URL if no voice base URL was set.

        :param mode: ``"stt"`` for speech-to-text; any other value is treated as TTS
        :return: ``(api_key, base_url)``; either element may be None
        """
        is_stt = mode.lower() == "stt"
        provider = (
            settings.AI_VOICE_STT_PROVIDER if is_stt else settings.AI_VOICE_TTS_PROVIDER
        ) or settings.AI_VOICE_PROVIDER
        provider = (provider or "").strip().lower()

        api_key = (
            settings.AI_VOICE_STT_API_KEY if is_stt else settings.AI_VOICE_TTS_API_KEY
        ) or settings.AI_VOICE_API_KEY
        base_url = (
            settings.AI_VOICE_STT_BASE_URL if is_stt else settings.AI_VOICE_TTS_BASE_URL
        ) or settings.AI_VOICE_BASE_URL

        if (
            not api_key
            and provider == "openai"
            and (settings.LLM_PROVIDER or "").strip().lower() == "openai"
        ):
            api_key = settings.LLM_API_KEY
            base_url = base_url or settings.LLM_BASE_URL

        return api_key, base_url

    def _get_client(self, mode: str):
        """
        Create an OpenAI client for the given mode.

        :param mode: ``"stt"`` or ``"tts"``
        :return: OpenAI client configured with the resolved key and base URL
        :raises ValueError: when no API key can be resolved for the mode
        """
        # Imported lazily so this module can be loaded without the SDK being imported.
        from openai import OpenAI

        api_key, base_url = self._resolve_credentials(mode)
        if not api_key:
            raise ValueError(f"{mode.upper()} provider 未配置 API Key")
        return OpenAI(api_key=api_key, base_url=base_url, max_retries=3)

    def is_available_for_stt(self) -> bool:
        """Return whether an API key can be resolved for speech-to-text."""
        api_key, _ = self._resolve_credentials("stt")
        return bool(api_key)

    def is_available_for_tts(self) -> bool:
        """Return whether an API key can be resolved for text-to-speech."""
        api_key, _ = self._resolve_credentials("tts")
        return bool(api_key)

    def transcribe_bytes(self, content: bytes, filename: str = "input.ogg") -> Optional[str]:
        """
        Convert raw audio bytes into text.

        :param content: audio payload
        :param filename: name attached to the upload (its extension hints the audio format)
        :return: stripped transcript, or None for empty input, an empty result or an API failure
        :raises ValueError: when the payload exceeds ``MAX_TRANSCRIBE_BYTES``
        """
        if not content:
            return None
        if len(content) > self.MAX_TRANSCRIBE_BYTES:
            raise ValueError("语音文件超过 25MB,无法识别")

        try:
            client = self._get_client("stt")
            audio_file = BytesIO(content)
            audio_file.name = filename
            response = client.audio.transcriptions.create(
                model=settings.AI_VOICE_STT_MODEL,
                file=audio_file,
                language=settings.AI_VOICE_LANGUAGE or "zh",
                # Only the transcript text is used, so plain "json" is enough.
                # "verbose_json" is rejected by the gpt-4o(-mini)-transcribe
                # models, including the default AI_VOICE_STT_MODEL.
                response_format="json",
            )
            text = getattr(response, "text", None)
            return text.strip() if text else None
        except Exception as err:
            logger.error(f"语音转文字失败: provider={self.name}, error={err}")
            return None

    def synthesize_speech(self, text: str) -> Optional[Path]:
        """
        Convert text into an Opus audio file.

        The file is written to ``<TEMP_PATH>/voice/<random hex>.opus``; this
        method does not delete it afterwards.

        :param text: content to speak
        :return: path of the generated file, or None for empty text or on failure
        """
        if not text:
            return None

        try:
            client = self._get_client("tts")
            voice_dir = settings.TEMP_PATH / "voice"
            voice_dir.mkdir(parents=True, exist_ok=True)
            output_path = voice_dir / f"{uuid4().hex}.opus"
            response = client.audio.speech.create(
                model=settings.AI_VOICE_TTS_MODEL,
                voice=settings.AI_VOICE_TTS_VOICE,
                input=text,
                response_format="opus",
            )
            response.write_to_file(output_path)
            return output_path
        except Exception as err:
            logger.error(f"文字转语音失败: provider={self.name}, error={err}")
            return None
|
||||
|
||||
|
||||
class VoiceHelper:
    """Single entry point for voice features; routes each call to the configured STT/TTS provider."""

    # Registry of provider instances keyed by lower-cased provider name.
    _providers: Dict[str, VoiceProvider] = {"openai": OpenAIVoiceProvider()}

    @classmethod
    def register_provider(cls, provider: VoiceProvider) -> None:
        """Register (or replace) a provider under its lower-cased name."""
        cls._providers[provider.name.lower()] = provider

    @staticmethod
    def _resolve_provider_name(mode: str) -> str:
        """
        Work out which provider name is configured for the given mode.

        :param mode: ``"stt"`` for speech-to-text; any other value is treated as TTS
        :return: lower-cased provider name, ``"openai"`` when nothing is configured
        """
        if mode.lower() == "stt":
            specific = settings.AI_VOICE_STT_PROVIDER
        else:
            specific = settings.AI_VOICE_TTS_PROVIDER
        chosen = specific or settings.AI_VOICE_PROVIDER or "openai"
        return chosen.strip().lower()

    @classmethod
    def get_provider(cls, mode: str) -> Optional[VoiceProvider]:
        """
        Look up the provider instance configured for the given mode.

        :param mode: ``"stt"`` or ``"tts"``
        :return: the registered provider, or None (with a warning) when it is not registered
        """
        provider_name = cls._resolve_provider_name(mode)
        found = cls._providers.get(provider_name)
        if not found:
            logger.warning(f"未注册语音 provider: mode={mode}, provider={provider_name}")
            return None
        return found

    @classmethod
    def get_registered_providers(cls) -> list[str]:
        """Return the registered provider names in sorted order."""
        return sorted(cls._providers)

    @classmethod
    def is_available(cls, mode: Optional[str] = None) -> bool:
        """
        Check whether voice capability is usable.

        :param mode: ``"stt"`` or ``"tts"`` to check one direction; when empty,
                     either direction being available is enough
        :return: True when the requested capability is available
        """
        if not mode:
            return cls.is_available("stt") or cls.is_available("tts")
        provider = cls.get_provider(mode)
        if not provider:
            return False
        if mode.lower() == "stt":
            return provider.is_available_for_stt()
        return provider.is_available_for_tts()

    @classmethod
    def transcribe_bytes(cls, content: bytes, filename: str = "input.ogg") -> Optional[str]:
        """
        Transcribe audio bytes with the configured STT provider.

        :param content: audio payload
        :param filename: name attached to the audio payload
        :return: the provider's transcript, or None when no provider is registered
        """
        provider = cls.get_provider("stt")
        return provider.transcribe_bytes(content=content, filename=filename) if provider else None

    @classmethod
    def synthesize_speech(cls, text: str) -> Optional[Path]:
        """
        Synthesize speech with the configured TTS provider.

        :param text: content to speak
        :return: path of the generated audio file, or None when no provider is registered
        """
        provider = cls.get_provider("tts")
        return provider.synthesize_speech(text=text) if provider else None
|
||||
@@ -297,7 +297,7 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]):
|
||||
images.append(url)
|
||||
return images if images else None
|
||||
|
||||
def download_file_to_data_url(self, file_url: str, source: str) -> Optional[str]:
|
||||
def download_slack_file_to_data_url(self, file_url: str, source: str) -> Optional[str]:
|
||||
"""
|
||||
下载Slack文件并转为data URL
|
||||
:param file_url: Slack私有文件URL
|
||||
|
||||
@@ -214,17 +214,19 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
|
||||
text = self._append_reply_markup_links(text, msg.get("reply_markup"))
|
||||
|
||||
images = self._extract_images(msg)
|
||||
audio_refs = self._extract_audio_refs(msg)
|
||||
|
||||
if user_id:
|
||||
if not text and not images:
|
||||
if not text and not images and not audio_refs:
|
||||
logger.debug(
|
||||
f"收到来自 {client_config.name} 的Telegram消息无文本和图片"
|
||||
f"收到来自 {client_config.name} 的Telegram消息无文本、图片和语音"
|
||||
)
|
||||
return None
|
||||
|
||||
logger.info(
|
||||
f"收到来自 {client_config.name} 的Telegram消息:"
|
||||
f"userid={user_id}, username={user_name}, chat_id={chat_id}, text={text}, images={len(images) if images else 0}"
|
||||
f"userid={user_id}, username={user_name}, chat_id={chat_id}, text={text}, "
|
||||
f"images={len(images) if images else 0}, audios={len(audio_refs) if audio_refs else 0}"
|
||||
)
|
||||
|
||||
cleaned_text = (
|
||||
@@ -263,6 +265,7 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
|
||||
text=cleaned_text,
|
||||
chat_id=str(chat_id) if chat_id else None,
|
||||
images=images if images else None,
|
||||
audio_refs=audio_refs if audio_refs else None,
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -288,6 +291,26 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
|
||||
|
||||
return images if images else None
|
||||
|
||||
@staticmethod
|
||||
def _extract_audio_refs(msg: dict) -> Optional[List[str]]:
|
||||
"""
|
||||
从Telegram消息中提取语音/音频 file_id。
|
||||
"""
|
||||
audio_refs = []
|
||||
voice = msg.get("voice")
|
||||
if voice:
|
||||
file_id = voice.get("file_id")
|
||||
if file_id:
|
||||
audio_refs.append(f"tg://voice_file_id/{file_id}")
|
||||
|
||||
audio = msg.get("audio")
|
||||
if audio:
|
||||
file_id = audio.get("file_id")
|
||||
if file_id:
|
||||
audio_refs.append(f"tg://audio_file_id/{file_id}")
|
||||
|
||||
return audio_refs if audio_refs else None
|
||||
|
||||
@staticmethod
|
||||
def _embed_entity_links(text: str, entities: Optional[List[dict]]) -> str:
|
||||
"""
|
||||
@@ -389,17 +412,25 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
|
||||
return
|
||||
client: Telegram = self.get_instance(conf.name)
|
||||
if client:
|
||||
client.send_msg(
|
||||
title=message.title,
|
||||
text=message.text,
|
||||
image=message.image,
|
||||
userid=userid,
|
||||
link=message.link,
|
||||
buttons=message.buttons,
|
||||
original_message_id=message.original_message_id,
|
||||
original_chat_id=message.original_chat_id,
|
||||
disable_web_page_preview=message.disable_web_page_preview,
|
||||
)
|
||||
if message.voice_path:
|
||||
client.send_voice(
|
||||
voice_path=message.voice_path,
|
||||
userid=userid,
|
||||
caption=message.voice_caption,
|
||||
original_chat_id=message.original_chat_id,
|
||||
)
|
||||
else:
|
||||
client.send_msg(
|
||||
title=message.title,
|
||||
text=message.text,
|
||||
image=message.image,
|
||||
userid=userid,
|
||||
link=message.link,
|
||||
buttons=message.buttons,
|
||||
original_message_id=message.original_message_id,
|
||||
original_chat_id=message.original_chat_id,
|
||||
disable_web_page_preview=message.disable_web_page_preview,
|
||||
)
|
||||
|
||||
def post_medias_message(
|
||||
self, message: Notification, medias: List[MediaInfo]
|
||||
@@ -531,14 +562,22 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
|
||||
return None
|
||||
client: Telegram = self.get_instance(conf.name)
|
||||
if client:
|
||||
result = client.send_msg(
|
||||
title=message.title,
|
||||
text=message.text,
|
||||
image=message.image,
|
||||
userid=userid,
|
||||
link=message.link,
|
||||
disable_web_page_preview=message.disable_web_page_preview,
|
||||
)
|
||||
if message.voice_path:
|
||||
result = client.send_voice(
|
||||
voice_path=message.voice_path,
|
||||
userid=userid,
|
||||
caption=message.voice_caption,
|
||||
original_chat_id=message.original_chat_id,
|
||||
)
|
||||
else:
|
||||
result = client.send_msg(
|
||||
title=message.title,
|
||||
text=message.text,
|
||||
image=message.image,
|
||||
userid=userid,
|
||||
link=message.link,
|
||||
disable_web_page_preview=message.disable_web_page_preview,
|
||||
)
|
||||
if result and result.get("success"):
|
||||
return MessageResponse(
|
||||
message_id=result.get("message_id"),
|
||||
@@ -601,7 +640,7 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
|
||||
)
|
||||
client.register_commands(filtered_scoped_commands)
|
||||
|
||||
def download_file_to_base64(self, file_id: str, source: str) -> Optional[str]:
|
||||
def download_telegram_file_to_base64(self, file_id: str, source: str) -> Optional[str]:
|
||||
"""
|
||||
下载Telegram文件并转为base64
|
||||
:param file_id: Telegram文件ID
|
||||
@@ -620,3 +659,15 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
|
||||
|
||||
return base64.b64encode(file_content).decode()
|
||||
return None
|
||||
|
||||
def download_telegram_file_bytes(self, file_id: str, source: str) -> Optional[bytes]:
    """
    Download a Telegram file and return whatever the client's download_file yields.

    :param file_id: Telegram file identifier.
    :param source: Name of the configured message source to look up.
    :return: Result of ``client.download_file``, or None when the source has no
             config or no client instance.
    """
    conf = self.get_config(source)
    client = self.get_instance(conf.name) if conf else None
    return client.download_file(file_id) if client else None
|
||||
|
||||
@@ -3,6 +3,7 @@ import json
|
||||
import re
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional, List, Dict, Callable, Union
|
||||
from urllib.parse import urljoin, quote
|
||||
|
||||
@@ -461,6 +462,51 @@ class Telegram:
|
||||
self._stop_typing_task(chat_id)
|
||||
return {"success": False}
|
||||
|
||||
def send_voice(
    self,
    voice_path: str,
    userid: Optional[str] = None,
    caption: Optional[str] = None,
    original_chat_id: Optional[str] = None,
) -> Optional[dict]:
    """
    Send a voice message through the Telegram bot.

    :param voice_path: Path of the audio file to send; removed afterwards
                       whether or not the send succeeded.
    :param userid: Target user id, passed to the chat-id resolver.
    :param caption: Optional caption, sent as MarkdownV2 when present.
    :param original_chat_id: Chat the triggering message came from, passed to
                             the chat-id resolver.
    :return: None when there is no bot or no path; otherwise a dict with
             ``success`` and, when Telegram returned a message, its
             ``message_id`` and ``chat_id``.
    """
    if not (self._bot and voice_path):
        return None

    chat_id = self._determine_target_chat_id(userid, original_chat_id)
    audio = Path(voice_path)
    if not audio.exists():
        logger.error(f"语音文件不存在: {audio}")
        return {"success": False}

    try:
        with audio.open("rb") as stream:
            message = self._bot.send_voice(
                chat_id=chat_id,
                voice=stream,
                caption=standardize(caption) if caption else None,
                parse_mode="MarkdownV2" if caption else None,
            )
        self._stop_typing_task(chat_id)
        if not (message and hasattr(message, "message_id")):
            return {"success": bool(message)}
        return {
            "success": True,
            "message_id": message.message_id,
            "chat_id": message.chat.id if hasattr(message, "chat") else chat_id,
        }
    except Exception as err:
        logger.error(f"发送语音消息失败:{err}")
        self._stop_typing_task(chat_id)
        return {"success": False}
    finally:
        # The audio file is a one-shot temp artifact; cleanup is best-effort.
        try:
            audio.unlink(missing_ok=True)
        except Exception as cleanup_err:
            logger.debug(f"清理语音临时文件失败: {cleanup_err}")
|
||||
|
||||
def _determine_target_chat_id(
|
||||
self, userid: Optional[str] = None, original_chat_id: Optional[str] = None
|
||||
) -> str:
|
||||
|
||||
@@ -167,6 +167,7 @@ class WechatModule(_ModuleBase, _MessageBase[WeChat]):
|
||||
# 解析消息内容
|
||||
content = None
|
||||
images = None
|
||||
audio_refs = None
|
||||
if msg_type == "event" and event == "click":
|
||||
# 校验用户有权限执行交互命令
|
||||
if client_config.config.get('WECHAT_ADMINS'):
|
||||
@@ -192,14 +193,24 @@ class WechatModule(_ModuleBase, _MessageBase[WeChat]):
|
||||
logger.info(
|
||||
f"收到来自 {client_config.name} 的微信图片消息:userid={user_id}, images={len(images) if images else 0}"
|
||||
)
|
||||
elif msg_type == "voice":
|
||||
media_id = DomUtils.tag_value(root_node, "MediaId")
|
||||
recognition = DomUtils.tag_value(root_node, "Recognition", default="")
|
||||
content = (recognition or "").strip()
|
||||
if media_id:
|
||||
audio_refs = [f"wxwork://voice_media_id/{media_id}"]
|
||||
logger.info(
|
||||
f"收到来自 {client_config.name} 的微信语音消息:userid={user_id}, "
|
||||
f"text={content}, audios={len(audio_refs) if audio_refs else 0}"
|
||||
)
|
||||
else:
|
||||
return None
|
||||
|
||||
if content or images:
|
||||
if content or images or audio_refs:
|
||||
# 处理消息内容
|
||||
return CommingMessage(channel=MessageChannel.Wechat, source=client_config.name,
|
||||
userid=user_id, username=user_id, text=content or "",
|
||||
images=images)
|
||||
images=images, audio_refs=audio_refs)
|
||||
except Exception as err:
|
||||
logger.error(f"微信消息处理发生错误:{str(err)}")
|
||||
return None
|
||||
@@ -230,6 +241,7 @@ class WechatModule(_ModuleBase, _MessageBase[WeChat]):
|
||||
|
||||
text = WeChatBot._extract_text_from_body(payload_body)
|
||||
images = WeChatBot._extract_images_from_body(payload_body)
|
||||
audio_refs = ["wxbot://voice"] if payload_body.get("msgtype") == "voice" else None
|
||||
if text:
|
||||
text = re.sub(r"@\S+", "", text).strip()
|
||||
|
||||
@@ -245,7 +257,7 @@ class WechatModule(_ModuleBase, _MessageBase[WeChat]):
|
||||
client.send_msg(title="只有管理员才有权限执行此命令", userid=sender)
|
||||
return None
|
||||
|
||||
if not text and not images:
|
||||
if not text and not images and not audio_refs:
|
||||
return None
|
||||
|
||||
logger.info(
|
||||
@@ -259,6 +271,7 @@ class WechatModule(_ModuleBase, _MessageBase[WeChat]):
|
||||
username=sender,
|
||||
text=text or "",
|
||||
images=images,
|
||||
audio_refs=audio_refs,
|
||||
)
|
||||
|
||||
def post_message(self, message: Notification, **kwargs) -> None:
|
||||
@@ -279,8 +292,17 @@ class WechatModule(_ModuleBase, _MessageBase[WeChat]):
|
||||
return
|
||||
client: WeChat = self.get_instance(conf.name)
|
||||
if client:
|
||||
client.send_msg(title=message.title, text=message.text,
|
||||
image=message.image, userid=userid, link=message.link)
|
||||
if message.voice_path and hasattr(client, "send_voice"):
|
||||
sent = client.send_voice(
|
||||
voice_path=message.voice_path,
|
||||
userid=userid,
|
||||
)
|
||||
if not sent:
|
||||
client.send_msg(title=message.title, text=message.text,
|
||||
image=message.image, userid=userid, link=message.link)
|
||||
else:
|
||||
client.send_msg(title=message.title, text=message.text,
|
||||
image=message.image, userid=userid, link=message.link)
|
||||
|
||||
def download_wechat_image_to_data_url(self, image_ref: str, source: str) -> Optional[str]:
|
||||
"""
|
||||
@@ -301,6 +323,23 @@ class WechatModule(_ModuleBase, _MessageBase[WeChat]):
|
||||
return client.download_image_to_data_url(image_ref)
|
||||
return None
|
||||
|
||||
def download_wechat_media_bytes(self, media_ref: str, source: str) -> Optional[bytes]:
    """
    Download an enterprise WeChat voice media item and return its raw bytes.

    :param media_ref: Reference of the form ``wxwork://voice_media_id/<media_id>``.
    :param source: Name of the configured message source to look up.
    :return: Bytes from the client's ``download_media_bytes``, or None when the
             reference is empty or has another scheme, the source is unknown,
             or the client cannot download media.
    """
    scheme = "wxwork://voice_media_id/"
    if not media_ref:
        return None

    conf = self.get_config(source)
    if not conf:
        return None

    client = self.get_instance(conf.name)
    can_download = bool(client) and hasattr(client, "download_media_bytes")
    if not can_download or not media_ref.startswith(scheme):
        return None

    # Strip only the leading scheme; the remainder is the media_id.
    return client.download_media_bytes(media_ref[len(scheme):])
|
||||
|
||||
def post_medias_message(self, message: Notification, medias: List[MediaInfo]) -> None:
|
||||
"""
|
||||
发送媒体信息选择列表
|
||||
|
||||
@@ -2,7 +2,9 @@ import json
|
||||
import re
|
||||
import threading
|
||||
import base64
|
||||
import subprocess
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, Dict
|
||||
|
||||
from app.core.context import MediaInfo, Context
|
||||
@@ -46,6 +48,8 @@ class WeChat:
|
||||
_delete_menu_url = "cgi-bin/menu/delete?access_token={access_token}&agentid={agentid}"
|
||||
# 企业微信下载媒体URL
|
||||
_download_media_url = "cgi-bin/media/get?access_token={access_token}&media_id={media_id}"
|
||||
# 企业微信上传临时素材URL
|
||||
_upload_media_url = "cgi-bin/media/upload?access_token={access_token}&type={media_type}"
|
||||
|
||||
def __init__(self, WECHAT_CORPID: Optional[str] = None, WECHAT_APP_SECRET: Optional[str] = None,
|
||||
WECHAT_APP_ID: Optional[str] = None, WECHAT_PROXY: Optional[str] = None, **kwargs):
|
||||
@@ -66,6 +70,7 @@ class WeChat:
|
||||
self._create_menu_url = UrlUtils.adapt_request_url(self._proxy, self._create_menu_url)
|
||||
self._delete_menu_url = UrlUtils.adapt_request_url(self._proxy, self._delete_menu_url)
|
||||
self._download_media_url = UrlUtils.adapt_request_url(self._proxy, self._download_media_url)
|
||||
self._upload_media_url = UrlUtils.adapt_request_url(self._proxy, self._upload_media_url)
|
||||
|
||||
if self._corpid and self._appsecret and self._appid:
|
||||
self.__get_access_token()
|
||||
@@ -323,6 +328,168 @@ class WeChat:
|
||||
mime_type = self._guess_mime_type(res.content, content_type or "image/jpeg")
|
||||
return f"data:{mime_type};base64,{base64.b64encode(res.content).decode()}"
|
||||
|
||||
def download_media_bytes(self, media_id: str) -> Optional[bytes]:
    """
    Download an enterprise WeChat media file and return its raw bytes.

    :param media_id: Media identifier to fetch.
    :return: Response body bytes, or None when the id is empty, no access
             token is available, the request fails or is empty, or the server
             answers with a JSON body (treated as an error and logged).
    """
    if not media_id:
        return None

    token = self.__get_access_token()
    if not token:
        logger.error("下载企业微信媒体失败:access_token 获取失败")
        return None

    media_url = self._download_media_url.format(
        access_token=token,
        media_id=media_id,
    )
    try:
        reply = RequestUtils(timeout=30).get_res(media_url)
    except Exception as err:
        logger.error(f"下载企业微信媒体失败:{err}")
        return None

    if not reply or not reply.content:
        return None

    # Keep only the media type, dropping parameters such as "; charset=...".
    raw_type = reply.headers.get("Content-Type") or ""
    media_type = raw_type.partition(";")[0].strip()
    if media_type != "application/json":
        return reply.content

    # A JSON body here is an error report rather than media content.
    try:
        detail = reply.json()
    except Exception:
        detail = reply.text
    logger.error(f"企业微信媒体下载失败:{detail}")
    return None
|
||||
|
||||
@staticmethod
def _convert_voice_to_amr(voice_path: str) -> Optional[Path]:
    """
    Convert an audio file to the AMR format required by enterprise WeChat
    voice messages (mono, 8 kHz, at most 60 seconds, at most 2MB).

    The output is written next to the source file. Any partially written
    output is removed when the conversion fails, so a failed call leaves no
    stray ``.amr`` file behind.

    :param voice_path: Path of the source audio file.
    :return: Path of the converted ``.amr`` file, or None when the source is
             missing, ffmpeg cannot be run or fails, or the result exceeds 2MB.
    """
    max_amr_bytes = 2 * 1024 * 1024

    src_path = Path(voice_path)
    if not src_path.exists():
        logger.error(f"语音文件不存在:{src_path}")
        return None

    if src_path.suffix.lower() == ".amr":
        # ffmpeg cannot read from and overwrite the same file, so an input
        # that is already *.amr is re-encoded under a distinct name.
        dst_path = src_path.with_name(f"{src_path.stem}.converted.amr")
    else:
        dst_path = src_path.with_suffix(".amr")

    cmd = [
        "ffmpeg",
        # Never wait on stdin: this runs inside a background service.
        "-nostdin",
        "-y",
        "-i",
        str(src_path),
        "-ar",
        "8000",
        "-ac",
        "1",
        "-t",
        "60",
        str(dst_path),
    ]
    try:
        result = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            # ffmpeg diagnostics may contain bytes that are not valid in the
            # locale encoding; a decode error must not look like a failure.
            errors="replace",
            check=False,
        )
    except Exception as err:
        logger.error(f"调用 ffmpeg 转换 AMR 失败:{err}")
        dst_path.unlink(missing_ok=True)
        return None

    if result.returncode != 0 or not dst_path.exists():
        stderr_tail = (result.stderr or "").strip()[:500]
        logger.error(
            f"ffmpeg 转换 AMR 失败: returncode={result.returncode}, stderr={stderr_tail}"
        )
        # Drop whatever ffmpeg managed to write before failing.
        dst_path.unlink(missing_ok=True)
        return None

    if dst_path.stat().st_size > max_amr_bytes:
        logger.error("AMR 语音文件超过 2MB,无法发送到企业微信")
        dst_path.unlink(missing_ok=True)
        return None
    return dst_path
|
||||
|
||||
def _upload_temp_media(self, media_path: Path, media_type: str = "voice") -> Optional[str]:
    """
    Upload a file as enterprise WeChat temporary media.

    :param media_path: File to upload.
    :param media_type: Media type placed in the upload URL; "voice" files are
                       sent as ``voice/amr``, anything else as a generic
                       octet stream.
    :return: The ``media_id`` from the response, or None when no access token
             is available, the request fails, the response is not JSON, or
             ``errcode`` is not 0.
    """
    token = self.__get_access_token()
    if not token:
        return None

    upload_url = self._upload_media_url.format(
        access_token=token,
        media_type=media_type,
    )
    try:
        with media_path.open("rb") as handle:
            if media_type == "voice":
                mime = "voice/amr"
            else:
                mime = "application/octet-stream"
            reply = RequestUtils(timeout=60).request(
                method="post",
                url=upload_url,
                files={"media": (media_path.name, handle, mime)},
            )
    except Exception as err:
        logger.error(f"上传企业微信临时素材失败:{err}")
        return None

    if not reply:
        return None

    try:
        payload = reply.json()
    except Exception as err:
        logger.error(f"解析企业微信临时素材响应失败:{err}")
        return None

    if payload.get("errcode") == 0:
        return payload.get("media_id")
    logger.error(f"上传企业微信临时素材失败:{payload}")
    return None
|
||||
|
||||
def send_voice(self, voice_path: str, userid: Optional[str] = None) -> Optional[bool]:
    """
    Send a voice message through enterprise WeChat (self-built app mode only).

    The source audio is converted to AMR, uploaded as temporary media and
    then sent by media_id. Once conversion has succeeded, both the converted
    file and the source file are deleted regardless of the outcome.

    :param voice_path: Path of the source audio file.
    :param userid: Recipient; "@all" is used when not given.
    :return: False for an empty path, failed conversion, failed upload or an
             exception; None when no access token is available; otherwise the
             result of posting the message.
    """
    if not voice_path:
        return False
    if not self.__get_access_token():
        logger.error("获取微信access_token失败,请检查参数配置")
        return None

    recipient = userid or "@all"
    original = Path(voice_path)
    amr_file = self._convert_voice_to_amr(voice_path)
    if not amr_file:
        return False

    try:
        media_id = self._upload_temp_media(amr_file, media_type="voice")
        if media_id:
            payload = {
                "touser": recipient,
                "msgtype": "voice",
                "agentid": self._appid,
                "voice": {
                    "media_id": media_id
                },
                "safe": 0,
                "enable_id_trans": 0,
                "enable_duplicate_check": 0
            }
            return self.__post_request(self._send_msg_url, payload)
        return False
    except Exception as err:
        logger.error(f"发送企业微信语音消息失败:{err}")
        return False
    finally:
        # Both files are temporary artifacts of this single send attempt.
        amr_file.unlink(missing_ok=True)
        original.unlink(missing_ok=True)
|
||||
|
||||
def send_medias_msg(self, medias: List[MediaInfo], userid: Optional[str] = None) -> Optional[bool]:
|
||||
"""
|
||||
发送列表类消息
|
||||
|
||||
@@ -55,6 +55,8 @@ class CommingMessage(BaseModel):
|
||||
callback_query: Optional[Dict] = None
|
||||
# 图片列表(图片URL或file_id)
|
||||
images: Optional[List[str]] = None
|
||||
# 语音/音频引用列表
|
||||
audio_refs: Optional[List[str]] = None
|
||||
|
||||
def to_dict(self):
|
||||
"""
|
||||
@@ -86,6 +88,10 @@ class Notification(BaseModel):
|
||||
text: Optional[str] = None
|
||||
# 图片
|
||||
image: Optional[str] = None
|
||||
# 语音文件路径
|
||||
voice_path: Optional[str] = None
|
||||
# 语音消息附带说明文字
|
||||
voice_caption: Optional[str] = None
|
||||
# 链接
|
||||
link: Optional[str] = None
|
||||
# 用户ID
|
||||
|
||||
@@ -2,11 +2,12 @@ import base64
|
||||
import json
|
||||
import unittest
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock, patch
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
from telebot import apihelper
|
||||
|
||||
from app.agent.tools.impl.send_message import SendMessageInput
|
||||
from app.agent import MoviePilotAgent, AgentChain
|
||||
from app.chain.message import MessageChain
|
||||
from app.core.config import settings
|
||||
from app.modules.discord import DiscordModule
|
||||
@@ -23,6 +24,19 @@ from app.schemas.types import MessageChannel
|
||||
|
||||
|
||||
class AgentImageSupportTest(unittest.TestCase):
|
||||
def test_telegram_extract_audio_refs_returns_prefixed_file_ids(self):
    """Voice and audio file_ids are both returned, each with its tg:// prefix."""
    telegram_message = {
        "voice": {"file_id": "voice-1"},
        "audio": {"file_id": "audio-1"},
    }
    expected_refs = ["tg://voice_file_id/voice-1", "tg://audio_file_id/audio-1"]

    self.assertEqual(
        TelegramModule._extract_audio_refs(telegram_message),
        expected_refs,
    )
|
||||
|
||||
def test_telegram_extract_images_returns_prefixed_file_ids(self):
|
||||
images = TelegramModule._extract_images(
|
||||
{
|
||||
@@ -126,6 +140,25 @@ class AgentImageSupportTest(unittest.TestCase):
|
||||
self.assertEqual(handle_kwargs["text"], "")
|
||||
self.assertEqual(handle_kwargs["images"], ["tg://file_id/image-1"])
|
||||
|
||||
def test_process_allows_audio_only_message(self):
    """A message carrying only audio references is still forwarded to handle_message."""
    chain = MessageChain()
    voice_refs = ["tg://voice_file_id/voice-1"]
    incoming = CommingMessage(
        channel=MessageChannel.Telegram,
        source="telegram-test",
        userid="10001",
        username="tester",
        audio_refs=voice_refs,
    )

    with patch.object(chain, "message_parser", return_value=incoming):
        with patch.object(chain, "handle_message") as handle_message:
            chain.process(body="{}", form={}, args={"source": "telegram-test"})

    forwarded = handle_message.call_args.kwargs
    self.assertEqual(forwarded["text"], "")
    self.assertEqual(forwarded["audio_refs"], ["tg://voice_file_id/voice-1"])
|
||||
|
||||
def test_image_message_routes_to_agent_even_when_global_agent_is_disabled(self):
|
||||
chain = MessageChain()
|
||||
|
||||
@@ -149,6 +182,48 @@ class AgentImageSupportTest(unittest.TestCase):
|
||||
|
||||
handle_ai_message.assert_called_once()
|
||||
|
||||
def test_audio_message_routes_to_agent_with_voice_reply_flag(self):
    """An audio message reaches the agent with its transcript as text and reply_with_voice set."""
    chain = MessageChain()
    transcript = "帮我推荐一部电影"

    with patch.object(chain, "load_cache", return_value={}):
        with patch.object(chain, "_transcribe_audio_refs", return_value=transcript):
            with patch.object(chain.messagehelper, "put"), patch.object(chain.messageoper, "add"):
                with patch.object(chain, "_handle_ai_message") as handle_ai_message:
                    chain.handle_message(
                        channel=MessageChannel.Telegram,
                        source="telegram-test",
                        userid="10001",
                        username="tester",
                        text="",
                        audio_refs=["tg://voice_file_id/voice-1"],
                    )

    handle_ai_message.assert_called_once()
    agent_kwargs = handle_ai_message.call_args.kwargs
    self.assertEqual(agent_kwargs["text"], "帮我推荐一部电影")
    self.assertTrue(agent_kwargs["reply_with_voice"])
|
||||
|
||||
def test_agent_send_agent_message_does_not_auto_convert_to_voice(self):
    """send_agent_message posts plain text (no voice_path) even when reply_with_voice is set."""
    import asyncio

    agent = MoviePilotAgent(
        session_id="session-1",
        user_id="user-1",
        channel=MessageChannel.Telegram.value,
        source="telegram-test",
        username="tester",
    )
    agent.reply_with_voice = True

    post_patch = patch.object(AgentChain, "async_post_message", new_callable=AsyncMock)
    with post_patch as async_post_message:
        asyncio.run(agent.send_agent_message("这是语音回复"))

    # The notification is the first positional argument of the awaited call.
    posted = async_post_message.await_args.args[0]
    self.assertIsNone(posted.voice_path)
    self.assertEqual(posted.text, "这是语音回复")
|
||||
|
||||
def test_slack_images_use_authenticated_data_url_download(self):
|
||||
chain = MessageChain()
|
||||
|
||||
|
||||
55
tests/test_voice_helper.py
Normal file
55
tests/test_voice_helper.py
Normal file
@@ -0,0 +1,55 @@
|
||||
import unittest
|
||||
import sys
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
sys.modules.setdefault("psutil", Mock())
|
||||
sys.modules.setdefault("pyquery", Mock())
|
||||
|
||||
from app.core.config import settings
|
||||
from app.helper.voice import VoiceHelper, OpenAIVoiceProvider
|
||||
|
||||
|
||||
class VoiceHelperTest(unittest.TestCase):
    """Unit tests for VoiceHelper provider lookup and STT/TTS routing."""

    @staticmethod
    def _with_provider(provider):
        """Patch VoiceHelper.get_provider so it always hands back ``provider``."""
        return patch.object(VoiceHelper, "get_provider", return_value=provider)

    def test_registered_providers_contains_openai(self):
        """The built-in openai provider is registered by default."""
        registered = VoiceHelper.get_registered_providers()
        self.assertIn("openai", registered)

    def test_get_provider_falls_back_to_global_provider(self):
        """With no STT-specific provider set, the global provider setting is used."""
        with patch.object(settings, "AI_VOICE_PROVIDER", "openai"):
            with patch.object(settings, "AI_VOICE_STT_PROVIDER", None):
                resolved = VoiceHelper.get_provider("stt")

        self.assertIsInstance(resolved, OpenAIVoiceProvider)

    def test_is_available_checks_stt_and_tts_separately(self):
        """Availability is answered per mode by the matching provider check."""
        fake = Mock(
            **{
                "is_available_for_stt.return_value": True,
                "is_available_for_tts.return_value": False,
            }
        )

        with self._with_provider(fake):
            self.assertTrue(VoiceHelper.is_available("stt"))
            self.assertFalse(VoiceHelper.is_available("tts"))

    def test_transcribe_bytes_routes_to_stt_provider(self):
        """transcribe_bytes returns the provider's transcript and calls it once."""
        fake = Mock(**{"transcribe_bytes.return_value": "你好"})

        with self._with_provider(fake):
            transcript = VoiceHelper.transcribe_bytes(b"audio")

        self.assertEqual(transcript, "你好")
        fake.transcribe_bytes.assert_called_once()

    def test_synthesize_speech_routes_to_tts_provider(self):
        """synthesize_speech returns the provider's path and passes the text through."""
        fake = Mock(**{"synthesize_speech.return_value": "/tmp/reply.opus"})

        with self._with_provider(fake):
            audio_path = VoiceHelper.synthesize_speech("你好")

        self.assertEqual(audio_path, "/tmp/reply.opus")
        fake.synthesize_speech.assert_called_once_with(text="你好")


if __name__ == "__main__":
    unittest.main()
|
||||
Reference in New Issue
Block a user