mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-07-04 19:06:39 +08:00
Simplify agent message handling and streaming cleanup
This commit is contained in:
@@ -317,6 +317,8 @@ class MoviePilotAgent:
|
||||
"""
|
||||
判断当前 Agent 是否需要写入会话历史表。
|
||||
"""
|
||||
if self._tool_context.get("user_reply_sent"):
|
||||
return False
|
||||
return bool(self.channel and self.source)
|
||||
|
||||
def _save_display_history_messages(self, messages: List[dict]) -> None:
|
||||
@@ -1102,9 +1104,9 @@ class MoviePilotAgent:
|
||||
self._streamed_output = ""
|
||||
|
||||
# 获取历史消息
|
||||
messages = memory_manager.get_agent_messages(
|
||||
messages = list(memory_manager.get_agent_messages(
|
||||
session_id=self.session_id, user_id=self.user_id
|
||||
)
|
||||
))
|
||||
|
||||
# 构建结构化用户消息内容
|
||||
request_payload = {
|
||||
@@ -1269,6 +1271,7 @@ class MoviePilotAgent:
|
||||
self._agent_started_at = datetime.now()
|
||||
self._llm_runtime_config = None
|
||||
self._llm_provider_selection = {}
|
||||
streaming_stopped = False
|
||||
try:
|
||||
# Agent运行配置
|
||||
agent_config = {
|
||||
@@ -1316,6 +1319,7 @@ class MoviePilotAgent:
|
||||
all_sent_via_stream,
|
||||
streamed_text,
|
||||
) = await self.stream_handler.stop_streaming()
|
||||
streaming_stopped = True
|
||||
|
||||
if not all_sent_via_stream:
|
||||
# 流式输出未能发送全部内容(发送失败等)
|
||||
@@ -1418,7 +1422,8 @@ class MoviePilotAgent:
|
||||
error=execution_error,
|
||||
)
|
||||
# 确保停止流式输出
|
||||
await self.stream_handler.stop_streaming()
|
||||
if not streaming_stopped:
|
||||
await self.stream_handler.stop_streaming()
|
||||
|
||||
async def send_agent_message(self, message: str, title: str = ""):
|
||||
"""
|
||||
|
||||
@@ -87,7 +87,6 @@ You act as a proactive agent. Your goal is to fully resolve the user's media-rel
|
||||
{button_choice_spec}
|
||||
- Voice replies: {voice_reply_spec}
|
||||
- If the current channel supports image sending and an image would materially help, you may use the `send_message` tool with `image_url` to send it.
|
||||
{send_message_format_spec}
|
||||
- If the current channel supports file sending and you need to return a local image or file for the user to download, use `send_local_file`.
|
||||
</communication_runtime>
|
||||
|
||||
|
||||
@@ -141,9 +141,6 @@ class PromptManager:
|
||||
if caps:
|
||||
markdown_spec = self._generate_formatting_instructions(caps)
|
||||
button_choice_spec = self._generate_button_choice_instructions(msg_channel)
|
||||
send_message_format_spec = self._generate_send_message_format_instructions(
|
||||
msg_channel
|
||||
)
|
||||
|
||||
# 啰嗦模式
|
||||
verbose_spec = ""
|
||||
@@ -169,7 +166,6 @@ class PromptManager:
|
||||
moviepilot_info=moviepilot_info,
|
||||
voice_reply_spec=voice_reply_spec,
|
||||
button_choice_spec=button_choice_spec,
|
||||
send_message_format_spec=send_message_format_spec,
|
||||
)
|
||||
|
||||
return base_prompt
|
||||
@@ -404,27 +400,6 @@ class PromptManager:
|
||||
"content as a text fallback and still completes the reply."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _generate_send_message_format_instructions(
|
||||
channel: MessageChannel = None,
|
||||
) -> str:
|
||||
"""
|
||||
根据渠道生成 send_message 工具的格式参数提示。
|
||||
"""
|
||||
if channel != MessageChannel.Telegram:
|
||||
return ""
|
||||
return (
|
||||
"- Telegram message formatting: `send_message` supports an optional "
|
||||
"`parse_mode` argument. Leave it empty for default MarkdownV2. When a "
|
||||
"structured Telegram notice would be clearer in HTML, set "
|
||||
"`parse_mode=\"HTML\"` and write the `message` using only Telegram-supported "
|
||||
"HTML tags such as `<b>`, `<i>`, `<u>`, `<s>`, `<code>`, `<pre>`, "
|
||||
"`<blockquote>`, and `<a href=\"...\">`. Keep `title` as plain text; "
|
||||
"the Telegram module renders it as a bold heading automatically. Escape "
|
||||
"user-provided or dynamic values before embedding them in HTML. Do "
|
||||
"not mix Markdown syntax into an HTML-formatted message."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _generate_button_choice_instructions(
|
||||
channel: MessageChannel = None,
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
"""发送消息工具"""
|
||||
|
||||
import html as html_utils
|
||||
import re
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
@@ -12,49 +10,6 @@ from app.log import logger
|
||||
from app.schemas import Notification
|
||||
from app.schemas.types import NotificationType
|
||||
|
||||
SEND_MESSAGE_PARSE_MODE_MARKDOWN = "MarkdownV2"
|
||||
SEND_MESSAGE_PARSE_MODE_HTML = "HTML"
|
||||
SEND_MESSAGE_PARSE_MODE_ALIASES = {
|
||||
"markdownv2": SEND_MESSAGE_PARSE_MODE_MARKDOWN,
|
||||
"mdv2": SEND_MESSAGE_PARSE_MODE_MARKDOWN,
|
||||
"html": SEND_MESSAGE_PARSE_MODE_HTML,
|
||||
}
|
||||
SEND_MESSAGE_HTML_ALLOWED_TAGS = {
|
||||
"a",
|
||||
"b",
|
||||
"blockquote",
|
||||
"code",
|
||||
"del",
|
||||
"em",
|
||||
"i",
|
||||
"ins",
|
||||
"pre",
|
||||
"s",
|
||||
"span",
|
||||
"strike",
|
||||
"strong",
|
||||
"tg-spoiler",
|
||||
"u",
|
||||
}
|
||||
SEND_MESSAGE_HTML_NORMALIZATION_RULES = (
|
||||
(re.compile(r"<\s*br\s*/?\s*>", re.IGNORECASE), "\n"),
|
||||
(re.compile(r"<\s*/\s*p\s*>", re.IGNORECASE), "\n"),
|
||||
(re.compile(r"<\s*p(?:\s+[^>]*)?>", re.IGNORECASE), ""),
|
||||
(re.compile(r"<\s*/\s*div\s*>", re.IGNORECASE), "\n"),
|
||||
(re.compile(r"<\s*div(?:\s+[^>]*)?>", re.IGNORECASE), ""),
|
||||
(re.compile(r"<\s*/\s*li\s*>", re.IGNORECASE), "\n"),
|
||||
(re.compile(r"<\s*li(?:\s+[^>]*)?>", re.IGNORECASE), "• "),
|
||||
(re.compile(r"<\s*/?\s*(?:ul|ol)(?:\s+[^>]*)?>", re.IGNORECASE), ""),
|
||||
(re.compile(r"<\s*h[1-6](?:\s+[^>]*)?>", re.IGNORECASE), "<b>"),
|
||||
(re.compile(r"<\s*/\s*h[1-6]\s*>", re.IGNORECASE), "</b>\n"),
|
||||
)
|
||||
SEND_MESSAGE_HTML_TAG_PATTERN = re.compile(
|
||||
r"<\s*(/?)\s*([a-zA-Z][\w:-]*)\b([^>]*)>"
|
||||
)
|
||||
SEND_MESSAGE_HTML_ATTR_PATTERN_TEMPLATE = (
|
||||
r"""\b{attr_name}\s*=\s*(?:"([^"]*)"|'([^']*)'|([^\s"'>]+))"""
|
||||
)
|
||||
|
||||
|
||||
class SendMessageInput(BaseModel):
|
||||
"""发送消息工具的输入参数模型"""
|
||||
@@ -75,22 +30,12 @@ class SendMessageInput(BaseModel):
|
||||
None,
|
||||
description="Optional image URL to send together with the message on channels that support images (such as Telegram and Slack)",
|
||||
)
|
||||
parse_mode: Optional[str] = Field(
|
||||
None,
|
||||
description=(
|
||||
"Optional Telegram message body format. Supported values: HTML or MarkdownV2. "
|
||||
"Leave empty for default."
|
||||
),
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_payload(self) -> "SendMessageInput":
|
||||
"""校验消息内容和可选格式参数。"""
|
||||
if not self.message and not self.title and not self.image_url:
|
||||
raise ValueError("message、title、image_url 至少需要提供一个")
|
||||
self.parse_mode = SendMessageTool.normalize_parse_mode(self.parse_mode)
|
||||
if self.parse_mode == SEND_MESSAGE_PARSE_MODE_HTML:
|
||||
self.message = SendMessageTool.normalize_html_message(self.message)
|
||||
return self
|
||||
|
||||
|
||||
@@ -109,92 +54,12 @@ class SendMessageTool(MoviePilotTool):
|
||||
description: str = (
|
||||
"Send notification message to the user through configured notification channels "
|
||||
"(Telegram, Slack, WeChat, etc.). Supports optional image_url on channels that can "
|
||||
"send images. For Telegram, the optional parse_mode parameter controls message body "
|
||||
"rendering. Supported values are HTML and MarkdownV2; leave it empty for default. "
|
||||
"This is a terminal response tool: after it sends the user-facing message, do not "
|
||||
"send another final text reply with the same content."
|
||||
"send images. This is a terminal response tool: after it sends the user-facing "
|
||||
"message, do not send another final text reply with the same content."
|
||||
)
|
||||
args_schema: Type[BaseModel] = SendMessageInput
|
||||
require_admin: bool = True
|
||||
|
||||
@staticmethod
|
||||
def normalize_parse_mode(parse_mode: Optional[str]) -> Optional[str]:
|
||||
"""
|
||||
规范化 send_message 支持的 Telegram 格式参数。
|
||||
"""
|
||||
if not parse_mode:
|
||||
return None
|
||||
normalized = SEND_MESSAGE_PARSE_MODE_ALIASES.get(str(parse_mode).strip().lower())
|
||||
if not normalized:
|
||||
raise ValueError("parse_mode 仅支持 MarkdownV2 或 HTML")
|
||||
return normalized
|
||||
|
||||
@staticmethod
|
||||
def _extract_html_attr(attrs: str, attr_name: str) -> Optional[str]:
|
||||
"""
|
||||
从 HTML 标签属性中提取指定属性值。
|
||||
"""
|
||||
pattern = SEND_MESSAGE_HTML_ATTR_PATTERN_TEMPLATE.format(
|
||||
attr_name=re.escape(attr_name)
|
||||
)
|
||||
match = re.search(pattern, attrs or "", re.IGNORECASE)
|
||||
if not match:
|
||||
return None
|
||||
return next((value for value in match.groups() if value is not None), None)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_html_tag(match: re.Match) -> str:
|
||||
"""
|
||||
规范化 Telegram 支持的 HTML 标签,并剥离不支持的属性。
|
||||
"""
|
||||
closing, tag_name, attrs = match.groups()
|
||||
tag_name = tag_name.lower()
|
||||
if tag_name not in SEND_MESSAGE_HTML_ALLOWED_TAGS:
|
||||
raise ValueError(f"HTML 标签 <{tag_name}> 不受 Telegram 支持")
|
||||
|
||||
if closing:
|
||||
return f"</{tag_name}>"
|
||||
|
||||
if tag_name == "a":
|
||||
href = SendMessageTool._extract_html_attr(attrs, "href")
|
||||
if not href:
|
||||
raise ValueError("HTML 标签 <a> 必须包含 href 属性")
|
||||
return f'<a href="{html_utils.escape(href, quote=True)}">'
|
||||
|
||||
if tag_name == "span":
|
||||
class_name = SendMessageTool._extract_html_attr(attrs, "class")
|
||||
if class_name != "tg-spoiler":
|
||||
raise ValueError('HTML 标签 <span> 仅支持 class="tg-spoiler"')
|
||||
return '<span class="tg-spoiler">'
|
||||
|
||||
if tag_name == "blockquote":
|
||||
if re.search(r"(^|\s)expandable(\s|/|$)", attrs or "", re.IGNORECASE):
|
||||
return "<blockquote expandable>"
|
||||
return "<blockquote>"
|
||||
|
||||
if tag_name == "code":
|
||||
class_name = SendMessageTool._extract_html_attr(attrs, "class")
|
||||
if class_name and class_name.startswith("language-"):
|
||||
escaped_class = html_utils.escape(class_name, quote=True)
|
||||
return f'<code class="{escaped_class}">'
|
||||
return "<code>"
|
||||
|
||||
return f"<{tag_name}>"
|
||||
|
||||
@staticmethod
|
||||
def normalize_html_message(message: Optional[str]) -> Optional[str]:
|
||||
"""
|
||||
规范化 Agent 生成的 Telegram HTML 正文。
|
||||
"""
|
||||
if not message:
|
||||
return message
|
||||
normalized = message
|
||||
for pattern, replacement in SEND_MESSAGE_HTML_NORMALIZATION_RULES:
|
||||
normalized = pattern.sub(replacement, normalized)
|
||||
return SEND_MESSAGE_HTML_TAG_PATTERN.sub(
|
||||
SendMessageTool._normalize_html_tag, normalized
|
||||
)
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据消息参数生成友好的提示消息"""
|
||||
message = kwargs.get("message", "") or ""
|
||||
@@ -218,22 +83,15 @@ class SendMessageTool(MoviePilotTool):
|
||||
message: Optional[str] = None,
|
||||
title: Optional[str] = None,
|
||||
image_url: Optional[str] = None,
|
||||
parse_mode: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
"""发送消息到当前会话渠道。"""
|
||||
title = title or ("图片" if image_url and not message else "")
|
||||
text = message or ""
|
||||
try:
|
||||
parse_mode = self.normalize_parse_mode(parse_mode)
|
||||
if parse_mode == SEND_MESSAGE_PARSE_MODE_HTML:
|
||||
text = self.normalize_html_message(text) or ""
|
||||
except ValueError as e:
|
||||
return str(e)
|
||||
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: title={title}, message={text}, "
|
||||
f"image_url={image_url}, parse_mode={parse_mode}"
|
||||
f"image_url={image_url}"
|
||||
)
|
||||
try:
|
||||
await self.send_notification_message(
|
||||
@@ -246,7 +104,6 @@ class SendMessageTool(MoviePilotTool):
|
||||
title=title,
|
||||
text=text,
|
||||
image=image_url,
|
||||
parse_mode=parse_mode,
|
||||
)
|
||||
)
|
||||
self._agent_context["user_reply_sent"] = True
|
||||
|
||||
@@ -56,6 +56,13 @@ class _FakeStreamingFailingAgent(_FakeFailingAgent):
|
||||
yield None
|
||||
|
||||
|
||||
class _FakeStreamingAgent(_FakeAgent):
|
||||
async def astream(self, _messages, **_kwargs):
|
||||
return
|
||||
# 保持 async generator 形态,当前用例不需要实际 token。
|
||||
yield None
|
||||
|
||||
|
||||
class StreamChunkTimeoutError(RuntimeError):
|
||||
"""模拟 langchain_openai 的流式分块超时异常。"""
|
||||
|
||||
@@ -191,6 +198,81 @@ class AgentBackgroundOutputTest(unittest.IsolatedAsyncioTestCase):
|
||||
self.assertNotIn("Tune or disable", sent_message)
|
||||
self.assertEqual(expected, agent._streamed_output)
|
||||
|
||||
async def test_streaming_success_stops_streaming_once(self):
|
||||
"""流式正常完成时不应在 finally 中重复停止流式输出。"""
|
||||
agent = MoviePilotAgent(session_id="stream-ok", user_id="user-1")
|
||||
agent.channel = "Telegram"
|
||||
agent.source = "telegram-test"
|
||||
agent._tool_context = {"user_reply_sent": False}
|
||||
agent._streamed_output = ""
|
||||
agent.stream_handler = SimpleNamespace(
|
||||
set_dispatch_policy=lambda allow_dispatch_without_context=False: None,
|
||||
start_streaming=AsyncMock(),
|
||||
flush_pending_tool_summary=lambda: "",
|
||||
stop_streaming=AsyncMock(return_value=(True, "已发送")),
|
||||
)
|
||||
agent._should_stream = lambda: True
|
||||
agent._create_agent = AsyncMock(
|
||||
return_value=_FakeStreamingAgent([AIMessage(content="已发送")])
|
||||
)
|
||||
agent.send_agent_message = AsyncMock()
|
||||
|
||||
await agent._execute_agent([HumanMessage(content="测试")])
|
||||
|
||||
agent.stream_handler.stop_streaming.assert_awaited_once()
|
||||
|
||||
async def test_tool_sent_reply_does_not_persist_raw_agent_messages(self):
|
||||
"""工具已发送用户回复时不应把工具调用状态写入下一轮记忆。"""
|
||||
agent = MoviePilotAgent(session_id="tool-reply", user_id="user-1")
|
||||
agent.channel = "Telegram"
|
||||
agent.source = "telegram-test"
|
||||
agent._tool_context = {"user_reply_sent": True}
|
||||
agent._streamed_output = ""
|
||||
agent.stream_handler = SimpleNamespace(
|
||||
stop_streaming=AsyncMock(return_value=(False, ""))
|
||||
)
|
||||
agent._should_stream = lambda: False
|
||||
agent._create_agent = AsyncMock(
|
||||
return_value=_FakeAgent([AIMessage(content="消息已发送")])
|
||||
)
|
||||
agent.send_agent_message = AsyncMock()
|
||||
|
||||
with patch.object(memory_manager, "save_agent_messages") as save_messages:
|
||||
await agent._execute_agent([HumanMessage(content="测试")])
|
||||
|
||||
save_messages.assert_not_called()
|
||||
|
||||
async def test_process_does_not_mutate_cached_agent_messages(self):
|
||||
"""处理新消息时不应直接修改记忆缓存中的历史消息列表。"""
|
||||
agent = MoviePilotAgent(
|
||||
session_id="cached-memory",
|
||||
user_id="user-1",
|
||||
channel="Telegram",
|
||||
source="telegram-test",
|
||||
)
|
||||
cached_messages = [HumanMessage(content="上一轮")]
|
||||
captured = {}
|
||||
|
||||
async def _execute_agent(messages):
|
||||
captured["messages"] = messages
|
||||
return "消息已发送", {}
|
||||
|
||||
agent._execute_agent = AsyncMock(side_effect=_execute_agent)
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
memory_manager, "get_agent_messages", return_value=cached_messages
|
||||
),
|
||||
patch.object(agent, "prepare_chat_title", new=AsyncMock()),
|
||||
patch.object(agent, "_save_display_history_messages"),
|
||||
):
|
||||
result = await agent.process("继续")
|
||||
|
||||
self.assertEqual("消息已发送", result)
|
||||
self.assertEqual(1, len(cached_messages))
|
||||
self.assertIsNot(cached_messages, captured["messages"])
|
||||
self.assertEqual(2, len(captured["messages"]))
|
||||
|
||||
async def test_background_non_streaming_sends_when_reply_mode_dispatch(self):
|
||||
agent = MoviePilotAgent(session_id="bg-test", user_id="system")
|
||||
agent.channel = None
|
||||
|
||||
@@ -569,37 +569,6 @@ class AgentImageSupportTest(unittest.TestCase):
|
||||
|
||||
self.assertEqual(payload.image_url, "https://example.com/poster.png")
|
||||
|
||||
def test_send_message_input_normalizes_html_parse_mode(self):
|
||||
payload = SendMessageInput(
|
||||
explanation="send html notice",
|
||||
message="<b>处理完成</b>",
|
||||
parse_mode="html",
|
||||
)
|
||||
|
||||
self.assertEqual(payload.parse_mode, "HTML")
|
||||
|
||||
def test_send_message_input_normalizes_common_html_tags(self):
|
||||
payload = SendMessageInput(
|
||||
explanation="send html notice",
|
||||
message="<h1>标题</h1><p>第一行<br>第二行</p><ul><li>A</li></ul>",
|
||||
parse_mode="HTML",
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
payload.message,
|
||||
"<b>标题</b>\n第一行\n第二行\n• A\n",
|
||||
)
|
||||
|
||||
def test_send_message_input_rejects_unsupported_html_tags(self):
|
||||
with self.assertRaises(ValueError) as error:
|
||||
SendMessageInput(
|
||||
explanation="send html notice",
|
||||
message="<table><tr><td>A</td></tr></table>",
|
||||
parse_mode="HTML",
|
||||
)
|
||||
|
||||
self.assertIn("HTML 标签 <table> 不受 Telegram 支持", str(error.exception))
|
||||
|
||||
def test_send_message_tool_uses_regular_notification_type(self):
|
||||
"""发送消息工具应按普通通知消息登记。"""
|
||||
|
||||
@@ -619,7 +588,6 @@ class AgentImageSupportTest(unittest.TestCase):
|
||||
message="处理完成",
|
||||
title="智能体通知",
|
||||
image_url="https://example.com/poster.png",
|
||||
parse_mode="HTML",
|
||||
)
|
||||
return result, async_post_message
|
||||
|
||||
@@ -633,7 +601,35 @@ class AgentImageSupportTest(unittest.TestCase):
|
||||
self.assertEqual(notification.title, "智能体通知")
|
||||
self.assertEqual(notification.text, "处理完成")
|
||||
self.assertEqual(notification.image, "https://example.com/poster.png")
|
||||
self.assertEqual(notification.parse_mode, "HTML")
|
||||
self.assertIsNone(notification.parse_mode)
|
||||
|
||||
def test_send_message_tool_ignores_parse_mode_argument(self):
|
||||
"""发送消息工具不再支持由 Agent 指定 Telegram parse_mode。"""
|
||||
|
||||
async def _run():
|
||||
tool = SendMessageTool(session_id="session-1", user_id="10001")
|
||||
tool.set_message_attr(
|
||||
channel=MessageChannel.Telegram.value,
|
||||
source="telegram-test",
|
||||
username="tester",
|
||||
)
|
||||
|
||||
with patch(
|
||||
"app.agent.tools.base.ToolChain.async_post_message",
|
||||
new_callable=AsyncMock,
|
||||
) as async_post_message:
|
||||
result = await tool.run(
|
||||
message="<b>处理完成</b>",
|
||||
parse_mode="HTML",
|
||||
)
|
||||
return result, async_post_message
|
||||
|
||||
result, async_post_message = asyncio.run(_run())
|
||||
notification = async_post_message.await_args.args[0]
|
||||
|
||||
self.assertEqual(result, "消息已发送")
|
||||
self.assertEqual(notification.text, "<b>处理完成</b>")
|
||||
self.assertIsNone(notification.parse_mode)
|
||||
|
||||
def test_send_message_tool_marks_reply_sent_after_dispatch(self):
|
||||
"""发送消息工具成功发送后应终止本轮回复。"""
|
||||
@@ -652,7 +648,7 @@ class AgentImageSupportTest(unittest.TestCase):
|
||||
"app.agent.tools.base.ToolChain.async_post_message",
|
||||
new_callable=AsyncMock,
|
||||
):
|
||||
result = await tool.run(message="<b>处理完成</b>", parse_mode="HTML")
|
||||
result = await tool.run(message="处理完成")
|
||||
return result, agent_context
|
||||
|
||||
result, agent_context = asyncio.run(_run())
|
||||
@@ -661,58 +657,6 @@ class AgentImageSupportTest(unittest.TestCase):
|
||||
self.assertTrue(agent_context["user_reply_sent"])
|
||||
self.assertEqual(agent_context["reply_mode"], "send_message")
|
||||
|
||||
def test_send_message_tool_rejects_unsupported_html_before_dispatch(self):
|
||||
"""发送消息工具应在进入消息链路前拒绝不支持的 HTML。"""
|
||||
|
||||
async def _run():
|
||||
tool = SendMessageTool(session_id="session-1", user_id="10001")
|
||||
tool.set_message_attr(
|
||||
channel=MessageChannel.Telegram.value,
|
||||
source="telegram-test",
|
||||
username="tester",
|
||||
)
|
||||
|
||||
with patch(
|
||||
"app.agent.tools.base.ToolChain.async_post_message",
|
||||
new_callable=AsyncMock,
|
||||
) as async_post_message:
|
||||
result = await tool.run(
|
||||
message="<table><tr><td>A</td></tr></table>",
|
||||
parse_mode="HTML",
|
||||
)
|
||||
return result, async_post_message
|
||||
|
||||
result, async_post_message = asyncio.run(_run())
|
||||
|
||||
self.assertIn("HTML 标签 <table> 不受 Telegram 支持", result)
|
||||
async_post_message.assert_not_awaited()
|
||||
|
||||
def test_send_message_tool_rejects_invalid_parse_mode(self):
|
||||
"""发送消息工具应拒绝不支持的格式类型。"""
|
||||
|
||||
async def _run():
|
||||
tool = SendMessageTool(session_id="session-1", user_id="10001")
|
||||
tool.set_message_attr(
|
||||
channel=MessageChannel.Telegram.value,
|
||||
source="telegram-test",
|
||||
username="tester",
|
||||
)
|
||||
|
||||
with patch(
|
||||
"app.agent.tools.base.ToolChain.async_post_message",
|
||||
new_callable=AsyncMock,
|
||||
) as async_post_message:
|
||||
result = await tool.run(
|
||||
message="处理完成",
|
||||
parse_mode="Markdown",
|
||||
)
|
||||
return result, async_post_message
|
||||
|
||||
result, async_post_message = asyncio.run(_run())
|
||||
|
||||
self.assertIn("parse_mode 仅支持 MarkdownV2 或 HTML", result)
|
||||
async_post_message.assert_not_awaited()
|
||||
|
||||
def test_send_local_file_input_accepts_file_payload(self):
|
||||
payload = SendLocalFileInput(
|
||||
explanation="send generated report",
|
||||
|
||||
@@ -39,7 +39,7 @@ class TestAgentInteraction(unittest.TestCase):
|
||||
self.assertIn("do not write a final text reply after it", telegram_prompt)
|
||||
self.assertNotIn("ask_user_choice", wechat_prompt)
|
||||
|
||||
def test_prompt_injects_send_message_html_hint_only_for_telegram(self):
|
||||
def test_prompt_does_not_inject_send_message_html_hint(self):
|
||||
telegram_prompt = prompt_manager.get_agent_prompt(
|
||||
channel=MessageChannel.Telegram.value
|
||||
)
|
||||
@@ -47,9 +47,8 @@ class TestAgentInteraction(unittest.TestCase):
|
||||
channel=MessageChannel.Wechat.value
|
||||
)
|
||||
|
||||
self.assertIn("parse_mode=\"HTML\"", telegram_prompt)
|
||||
self.assertIn("Telegram-supported HTML tags", telegram_prompt)
|
||||
self.assertIn("Do not mix Markdown syntax", telegram_prompt)
|
||||
self.assertNotIn("parse_mode=\"HTML\"", telegram_prompt)
|
||||
self.assertNotIn("Telegram-supported HTML tags", telegram_prompt)
|
||||
self.assertNotIn("parse_mode=\"HTML\"", wechat_prompt)
|
||||
|
||||
def test_factory_injects_choice_tool_only_for_button_channels(self):
|
||||
|
||||
Reference in New Issue
Block a user