Simplify agent message handling and streaming cleanup

This commit is contained in:
jxxghp
2026-06-21 12:34:11 +08:00
parent d483b805d8
commit 495807ef4d
7 changed files with 126 additions and 265 deletions

View File

@@ -317,6 +317,8 @@ class MoviePilotAgent:
"""
判断当前 Agent 是否需要写入会话历史表。
"""
if self._tool_context.get("user_reply_sent"):
return False
return bool(self.channel and self.source)
def _save_display_history_messages(self, messages: List[dict]) -> None:
@@ -1102,9 +1104,9 @@ class MoviePilotAgent:
self._streamed_output = ""
# 获取历史消息
messages = memory_manager.get_agent_messages(
messages = list(memory_manager.get_agent_messages(
session_id=self.session_id, user_id=self.user_id
)
))
# 构建结构化用户消息内容
request_payload = {
@@ -1269,6 +1271,7 @@ class MoviePilotAgent:
self._agent_started_at = datetime.now()
self._llm_runtime_config = None
self._llm_provider_selection = {}
streaming_stopped = False
try:
# Agent运行配置
agent_config = {
@@ -1316,6 +1319,7 @@ class MoviePilotAgent:
all_sent_via_stream,
streamed_text,
) = await self.stream_handler.stop_streaming()
streaming_stopped = True
if not all_sent_via_stream:
# 流式输出未能发送全部内容(发送失败等)
@@ -1418,7 +1422,8 @@ class MoviePilotAgent:
error=execution_error,
)
# 确保停止流式输出
await self.stream_handler.stop_streaming()
if not streaming_stopped:
await self.stream_handler.stop_streaming()
async def send_agent_message(self, message: str, title: str = ""):
"""

View File

@@ -87,7 +87,6 @@ You act as a proactive agent. Your goal is to fully resolve the user's media-rel
{button_choice_spec}
- Voice replies: {voice_reply_spec}
- If the current channel supports image sending and an image would materially help, you may use the `send_message` tool with `image_url` to send it.
{send_message_format_spec}
- If the current channel supports file sending and you need to return a local image or file for the user to download, use `send_local_file`.
</communication_runtime>

View File

@@ -141,9 +141,6 @@ class PromptManager:
if caps:
markdown_spec = self._generate_formatting_instructions(caps)
button_choice_spec = self._generate_button_choice_instructions(msg_channel)
send_message_format_spec = self._generate_send_message_format_instructions(
msg_channel
)
# 啰嗦模式
verbose_spec = ""
@@ -169,7 +166,6 @@ class PromptManager:
moviepilot_info=moviepilot_info,
voice_reply_spec=voice_reply_spec,
button_choice_spec=button_choice_spec,
send_message_format_spec=send_message_format_spec,
)
return base_prompt
@@ -404,27 +400,6 @@ class PromptManager:
"content as a text fallback and still completes the reply."
)
@staticmethod
def _generate_send_message_format_instructions(
    channel: MessageChannel = None,
) -> str:
    """
    Build the prompt hint describing the `parse_mode` argument of `send_message`.

    :param channel: message channel the prompt is being generated for
    :return: the Telegram formatting hint, or an empty string for any other channel
    """
    if channel == MessageChannel.Telegram:
        # Fragments are concatenated without separators, exactly as written.
        hint_fragments = (
            "- Telegram message formatting: `send_message` supports an optional ",
            "`parse_mode` argument. Leave it empty for default MarkdownV2. When a ",
            "structured Telegram notice would be clearer in HTML, set ",
            "`parse_mode=\"HTML\"` and write the `message` using only Telegram-supported ",
            "HTML tags such as `<b>`, `<i>`, `<u>`, `<s>`, `<code>`, `<pre>`, ",
            "`<blockquote>`, and `<a href=\"...\">`. Keep `title` as plain text; ",
            "the Telegram module renders it as a bold heading automatically. Escape ",
            "user-provided or dynamic values before embedding them in HTML. Do ",
            "not mix Markdown syntax into an HTML-formatted message.",
        )
        return "".join(hint_fragments)
    return ""
@staticmethod
def _generate_button_choice_instructions(
channel: MessageChannel = None,

View File

@@ -1,7 +1,5 @@
"""发送消息工具"""
import html as html_utils
import re
from typing import Optional, Type
from pydantic import BaseModel, Field, model_validator
@@ -12,49 +10,6 @@ from app.log import logger
from app.schemas import Notification
from app.schemas.types import NotificationType
# Canonical parse-mode names accepted by the send_message tool.
SEND_MESSAGE_PARSE_MODE_MARKDOWN = "MarkdownV2"
SEND_MESSAGE_PARSE_MODE_HTML = "HTML"
# Lower-cased spellings mapped to the canonical parse-mode names above.
SEND_MESSAGE_PARSE_MODE_ALIASES = {
    "markdownv2": SEND_MESSAGE_PARSE_MODE_MARKDOWN,
    "mdv2": SEND_MESSAGE_PARSE_MODE_MARKDOWN,
    "html": SEND_MESSAGE_PARSE_MODE_HTML,
}
# Lower-cased tag names that pass HTML validation; any other tag is rejected.
SEND_MESSAGE_HTML_ALLOWED_TAGS = {
    "a",
    "b",
    "blockquote",
    "code",
    "del",
    "em",
    "i",
    "ins",
    "pre",
    "s",
    "span",
    "strike",
    "strong",
    "tg-spoiler",
    "u",
}
# Ordered (pattern, replacement) pairs applied to the message body before tag
# validation. They rewrite common block-level tags that are not in the
# allow-list into newlines, bullets, or bold markers.
SEND_MESSAGE_HTML_NORMALIZATION_RULES = (
    (re.compile(r"<\s*br\s*/?\s*>", re.IGNORECASE), "\n"),
    (re.compile(r"<\s*/\s*p\s*>", re.IGNORECASE), "\n"),
    (re.compile(r"<\s*p(?:\s+[^>]*)?>", re.IGNORECASE), ""),
    (re.compile(r"<\s*/\s*div\s*>", re.IGNORECASE), "\n"),
    (re.compile(r"<\s*div(?:\s+[^>]*)?>", re.IGNORECASE), ""),
    (re.compile(r"<\s*/\s*li\s*>", re.IGNORECASE), "\n"),
    # A list item opens a bullet line, so "<li>A</li>" becomes "• A\n".
    (re.compile(r"<\s*li(?:\s+[^>]*)?>", re.IGNORECASE), "• "),
    (re.compile(r"<\s*/?\s*(?:ul|ol)(?:\s+[^>]*)?>", re.IGNORECASE), ""),
    (re.compile(r"<\s*h[1-6](?:\s+[^>]*)?>", re.IGNORECASE), "<b>"),
    (re.compile(r"<\s*/\s*h[1-6]\s*>", re.IGNORECASE), "</b>\n"),
)
# Matches one opening or closing tag: group 1 is the optional "/", group 2 the
# tag name, group 3 the raw attribute text.
SEND_MESSAGE_HTML_TAG_PATTERN = re.compile(
    r"<\s*(/?)\s*([a-zA-Z][\w:-]*)\b([^>]*)>"
)
# Template for extracting one attribute value (double-quoted, single-quoted or
# bare). The negative lookbehind requires the name to start a whole attribute:
# a plain `\b` would also match the tail of a hyphenated attribute, so
# `data-href="x"` would be read as `href`.
SEND_MESSAGE_HTML_ATTR_PATTERN_TEMPLATE = (
    r"""(?<![\w:.-]){attr_name}\s*=\s*(?:"([^"]*)"|'([^']*)'|([^\s"'>]+))"""
)
class SendMessageInput(BaseModel):
"""发送消息工具的输入参数模型"""
@@ -75,22 +30,12 @@ class SendMessageInput(BaseModel):
None,
description="Optional image URL to send together with the message on channels that support images (such as Telegram and Slack)",
)
parse_mode: Optional[str] = Field(
None,
description=(
"Optional Telegram message body format. Supported values: HTML or MarkdownV2. "
"Leave empty for default."
),
)
@model_validator(mode="after")
def validate_payload(self) -> "SendMessageInput":
"""校验消息内容和可选格式参数。"""
if not self.message and not self.title and not self.image_url:
raise ValueError("message、title、image_url 至少需要提供一个")
self.parse_mode = SendMessageTool.normalize_parse_mode(self.parse_mode)
if self.parse_mode == SEND_MESSAGE_PARSE_MODE_HTML:
self.message = SendMessageTool.normalize_html_message(self.message)
return self
@@ -109,92 +54,12 @@ class SendMessageTool(MoviePilotTool):
description: str = (
"Send notification message to the user through configured notification channels "
"(Telegram, Slack, WeChat, etc.). Supports optional image_url on channels that can "
"send images. For Telegram, the optional parse_mode parameter controls message body "
"rendering. Supported values are HTML and MarkdownV2; leave it empty for default. "
"This is a terminal response tool: after it sends the user-facing message, do not "
"send another final text reply with the same content."
"send images. This is a terminal response tool: after it sends the user-facing "
"message, do not send another final text reply with the same content."
)
args_schema: Type[BaseModel] = SendMessageInput
require_admin: bool = True
@staticmethod
def normalize_parse_mode(parse_mode: Optional[str]) -> Optional[str]:
    """
    Map a user-supplied parse mode onto its canonical spelling.

    :param parse_mode: raw value; surrounding whitespace and letter case are ignored
    :return: the canonical name from the alias table, or None when the value is empty
    :raises ValueError: when a non-empty value is not a known alias
    """
    if parse_mode:
        alias_key = str(parse_mode).strip().lower()
        canonical = SEND_MESSAGE_PARSE_MODE_ALIASES.get(alias_key)
        if canonical:
            return canonical
        raise ValueError("parse_mode 仅支持 MarkdownV2 或 HTML")
    return None
@staticmethod
def _extract_html_attr(attrs: str, attr_name: str) -> Optional[str]:
    """
    Read the value of one attribute out of a tag's raw attribute text.

    :param attrs: attribute text of a tag; a falsy value is treated as empty
    :param attr_name: attribute to look up (matched case-insensitively)
    :return: the first captured value (double-quoted, single-quoted or bare),
        or None when the attribute is absent
    """
    attr_regex = SEND_MESSAGE_HTML_ATTR_PATTERN_TEMPLATE.format(
        attr_name=re.escape(attr_name)
    )
    found = re.search(attr_regex, attrs or "", re.IGNORECASE)
    if found is None:
        return None
    # Exactly one of the alternative capture groups holds the value.
    for captured in found.groups():
        if captured is not None:
            return captured
    return None
@staticmethod
def _normalize_html_tag(match: re.Match) -> str:
    """
    Rewrite one matched HTML tag into a canonical form and drop extra attributes.

    :param match: a match whose groups are (closing slash, tag name, attribute text)
    :return: the rebuilt tag; only `href` on `<a>`, `class="tg-spoiler"` on `<span>`,
        `expandable` on `<blockquote>` and a `language-*` class on `<code>` are kept
    :raises ValueError: for a tag outside the allow-list, an `<a>` without `href`,
        or a `<span>` whose class is not `tg-spoiler`
    """
    slash, raw_name, attrs = match.groups()
    tag_name = raw_name.lower()
    if tag_name not in SEND_MESSAGE_HTML_ALLOWED_TAGS:
        raise ValueError(f"HTML 标签 <{tag_name}> 不受 Telegram 支持")

    if slash:
        # Closing tags are emitted bare, without any attribute text.
        return f"</{tag_name}>"

    if tag_name == "a":
        link_target = SendMessageTool._extract_html_attr(attrs, "href")
        if link_target:
            return f'<a href="{html_utils.escape(link_target, quote=True)}">'
        raise ValueError("HTML 标签 <a> 必须包含 href 属性")

    if tag_name == "span":
        span_class = SendMessageTool._extract_html_attr(attrs, "class")
        if span_class == "tg-spoiler":
            return '<span class="tg-spoiler">'
        raise ValueError('HTML 标签 <span> 仅支持 class="tg-spoiler"')

    if tag_name == "blockquote":
        is_expandable = re.search(
            r"(^|\s)expandable(\s|/|$)", attrs or "", re.IGNORECASE
        )
        return "<blockquote expandable>" if is_expandable else "<blockquote>"

    if tag_name == "code":
        code_class = SendMessageTool._extract_html_attr(attrs, "class")
        if not code_class or not code_class.startswith("language-"):
            return "<code>"
        return f'<code class="{html_utils.escape(code_class, quote=True)}">'

    # Every other allowed tag is emitted without attributes.
    return f"<{tag_name}>"
@staticmethod
def normalize_html_message(message: Optional[str]) -> Optional[str]:
    """
    Normalize an agent-written HTML message body.

    The normalization rules are applied in order first; every remaining tag
    is then passed through `_normalize_html_tag`.

    :param message: HTML body; a falsy value is returned unchanged
    :return: the normalized body
    :raises ValueError: propagated from `_normalize_html_tag` for rejected tags
    """
    if message:
        rewritten = message
        for rule_pattern, rule_replacement in SEND_MESSAGE_HTML_NORMALIZATION_RULES:
            rewritten = rule_pattern.sub(rule_replacement, rewritten)
        return SEND_MESSAGE_HTML_TAG_PATTERN.sub(
            SendMessageTool._normalize_html_tag, rewritten
        )
    return message
def get_tool_message(self, **kwargs) -> Optional[str]:
"""根据消息参数生成友好的提示消息"""
message = kwargs.get("message", "") or ""
@@ -218,22 +83,15 @@ class SendMessageTool(MoviePilotTool):
message: Optional[str] = None,
title: Optional[str] = None,
image_url: Optional[str] = None,
parse_mode: Optional[str] = None,
**kwargs,
) -> str:
"""发送消息到当前会话渠道。"""
title = title or ("图片" if image_url and not message else "")
text = message or ""
try:
parse_mode = self.normalize_parse_mode(parse_mode)
if parse_mode == SEND_MESSAGE_PARSE_MODE_HTML:
text = self.normalize_html_message(text) or ""
except ValueError as e:
return str(e)
logger.info(
f"执行工具: {self.name}, 参数: title={title}, message={text}, "
f"image_url={image_url}, parse_mode={parse_mode}"
f"image_url={image_url}"
)
try:
await self.send_notification_message(
@@ -246,7 +104,6 @@ class SendMessageTool(MoviePilotTool):
title=title,
text=text,
image=image_url,
parse_mode=parse_mode,
)
)
self._agent_context["user_reply_sent"] = True

View File

@@ -56,6 +56,13 @@ class _FakeStreamingFailingAgent(_FakeFailingAgent):
yield None
class _FakeStreamingAgent(_FakeAgent):
    """Fake agent whose `astream` is an async generator that emits nothing."""

    async def astream(self, _messages, **_kwargs):
        # The `yield` is never reached; it only keeps this function an async
        # generator. The current test case needs no real tokens.
        if False:
            yield None
class StreamChunkTimeoutError(RuntimeError):
"""模拟 langchain_openai 的流式分块超时异常。"""
@@ -191,6 +198,81 @@ class AgentBackgroundOutputTest(unittest.IsolatedAsyncioTestCase):
self.assertNotIn("Tune or disable", sent_message)
self.assertEqual(expected, agent._streamed_output)
async def test_streaming_success_stops_streaming_once(self):
    """A successful streaming run must not stop streaming again in `finally`."""
    # Stand-in stream handler: stop_streaming reports that everything was sent.
    fake_stream_handler = SimpleNamespace(
        set_dispatch_policy=lambda allow_dispatch_without_context=False: None,
        start_streaming=AsyncMock(),
        flush_pending_tool_summary=lambda: "",
        stop_streaming=AsyncMock(return_value=(True, "已发送")),
    )
    fake_agent = _FakeStreamingAgent([AIMessage(content="已发送")])

    agent = MoviePilotAgent(session_id="stream-ok", user_id="user-1")
    agent.channel = "Telegram"
    agent.source = "telegram-test"
    agent._tool_context = {"user_reply_sent": False}
    agent._streamed_output = ""
    agent.stream_handler = fake_stream_handler
    agent._should_stream = lambda: True
    agent._create_agent = AsyncMock(return_value=fake_agent)
    agent.send_agent_message = AsyncMock()

    user_messages = [HumanMessage(content="测试")]
    await agent._execute_agent(user_messages)

    agent.stream_handler.stop_streaming.assert_awaited_once()
async def test_tool_sent_reply_does_not_persist_raw_agent_messages(self):
"""When a tool has already sent the user reply, tool-call state must not be written into memory for the next turn."""
agent = MoviePilotAgent(session_id="tool-reply", user_id="user-1")
agent.channel = "Telegram"
agent.source = "telegram-test"
# Simulate a tool having already delivered the reply to the user.
agent._tool_context = {"user_reply_sent": True}
agent._streamed_output = ""
agent.stream_handler = SimpleNamespace(
stop_streaming=AsyncMock(return_value=(False, ""))
)
# Non-streaming path.
agent._should_stream = lambda: False
agent._create_agent = AsyncMock(
return_value=_FakeAgent([AIMessage(content="消息已发送")])
)
agent.send_agent_message = AsyncMock()
# Patch the memory writer so any attempt to save messages is observable.
with patch.object(memory_manager, "save_agent_messages") as save_messages:
await agent._execute_agent([HumanMessage(content="测试")])
save_messages.assert_not_called()
async def test_process_does_not_mutate_cached_agent_messages(self):
"""Processing a new message must not directly mutate the history message list held in the memory cache."""
agent = MoviePilotAgent(
session_id="cached-memory",
user_id="user-1",
channel="Telegram",
source="telegram-test",
)
# The list object that the patched memory manager hands back.
cached_messages = [HumanMessage(content="上一轮")]
captured = {}
# Stub for _execute_agent that records the message list it receives.
async def _execute_agent(messages):
captured["messages"] = messages
return "消息已发送", {}
agent._execute_agent = AsyncMock(side_effect=_execute_agent)
with (
patch.object(
memory_manager, "get_agent_messages", return_value=cached_messages
),
patch.object(agent, "prepare_chat_title", new=AsyncMock()),
patch.object(agent, "_save_display_history_messages"),
):
result = await agent.process("继续")
self.assertEqual("消息已发送", result)
# The cached list keeps its single entry ...
self.assertEqual(1, len(cached_messages))
# ... while the agent received a different list object with two entries.
self.assertIsNot(cached_messages, captured["messages"])
self.assertEqual(2, len(captured["messages"]))
async def test_background_non_streaming_sends_when_reply_mode_dispatch(self):
agent = MoviePilotAgent(session_id="bg-test", user_id="system")
agent.channel = None

View File

@@ -569,37 +569,6 @@ class AgentImageSupportTest(unittest.TestCase):
self.assertEqual(payload.image_url, "https://example.com/poster.png")
def test_send_message_input_normalizes_html_parse_mode(self):
# A lower-case "html" parse_mode is expected to come back as "HTML".
payload = SendMessageInput(
explanation="send html notice",
message="<b>处理完成</b>",
parse_mode="html",
)
self.assertEqual(payload.parse_mode, "HTML")
def test_send_message_input_normalizes_common_html_tags(self):
# Heading, paragraph, line-break and list tags are expected to be rewritten
# into bold text, newlines and a bullet line when parse_mode is HTML.
payload = SendMessageInput(
explanation="send html notice",
message="<h1>标题</h1><p>第一行<br>第二行</p><ul><li>A</li></ul>",
parse_mode="HTML",
)
self.assertEqual(
payload.message,
"<b>标题</b>\n第一行\n第二行\n• A\n",
)
def test_send_message_input_rejects_unsupported_html_tags(self):
# Building the input model with a <table> body is expected to raise, and the
# error text must name the rejected tag.
with self.assertRaises(ValueError) as error:
SendMessageInput(
explanation="send html notice",
message="<table><tr><td>A</td></tr></table>",
parse_mode="HTML",
)
self.assertIn("HTML 标签 <table> 不受 Telegram 支持", str(error.exception))
def test_send_message_tool_uses_regular_notification_type(self):
"""发送消息工具应按普通通知消息登记。"""
@@ -619,7 +588,6 @@ class AgentImageSupportTest(unittest.TestCase):
message="处理完成",
title="智能体通知",
image_url="https://example.com/poster.png",
parse_mode="HTML",
)
return result, async_post_message
@@ -633,7 +601,35 @@ class AgentImageSupportTest(unittest.TestCase):
self.assertEqual(notification.title, "智能体通知")
self.assertEqual(notification.text, "处理完成")
self.assertEqual(notification.image, "https://example.com/poster.png")
self.assertEqual(notification.parse_mode, "HTML")
self.assertIsNone(notification.parse_mode)
def test_send_message_tool_ignores_parse_mode_argument(self):
"""The send-message tool no longer supports a Telegram parse_mode chosen by the agent."""
async def _run():
tool = SendMessageTool(session_id="session-1", user_id="10001")
tool.set_message_attr(
channel=MessageChannel.Telegram.value,
source="telegram-test",
username="tester",
)
# Replace the outgoing message call so the posted notification can be inspected.
with patch(
"app.agent.tools.base.ToolChain.async_post_message",
new_callable=AsyncMock,
) as async_post_message:
# parse_mode is still passed in; the assertions below expect it to have no effect.
result = await tool.run(
message="<b>处理完成</b>",
parse_mode="HTML",
)
return result, async_post_message
result, async_post_message = asyncio.run(_run())
# First positional argument of the awaited call is the notification object.
notification = async_post_message.await_args.args[0]
self.assertEqual(result, "消息已发送")
# The message text is forwarded unchanged and no parse_mode is set.
self.assertEqual(notification.text, "<b>处理完成</b>")
self.assertIsNone(notification.parse_mode)
def test_send_message_tool_marks_reply_sent_after_dispatch(self):
"""发送消息工具成功发送后应终止本轮回复。"""
@@ -652,7 +648,7 @@ class AgentImageSupportTest(unittest.TestCase):
"app.agent.tools.base.ToolChain.async_post_message",
new_callable=AsyncMock,
):
result = await tool.run(message="<b>处理完成</b>", parse_mode="HTML")
result = await tool.run(message="处理完成")
return result, agent_context
result, agent_context = asyncio.run(_run())
@@ -661,58 +657,6 @@ class AgentImageSupportTest(unittest.TestCase):
self.assertTrue(agent_context["user_reply_sent"])
self.assertEqual(agent_context["reply_mode"], "send_message")
def test_send_message_tool_rejects_unsupported_html_before_dispatch(self):
"""The send-message tool should reject unsupported HTML before entering the message pipeline."""
async def _run():
tool = SendMessageTool(session_id="session-1", user_id="10001")
tool.set_message_attr(
channel=MessageChannel.Telegram.value,
source="telegram-test",
username="tester",
)
# Replace the outgoing message call so a dispatch attempt is observable.
with patch(
"app.agent.tools.base.ToolChain.async_post_message",
new_callable=AsyncMock,
) as async_post_message:
result = await tool.run(
message="<table><tr><td>A</td></tr></table>",
parse_mode="HTML",
)
return result, async_post_message
result, async_post_message = asyncio.run(_run())
# The tool returns the error text instead of raising, and nothing is posted.
self.assertIn("HTML 标签 <table> 不受 Telegram 支持", result)
async_post_message.assert_not_awaited()
def test_send_message_tool_rejects_invalid_parse_mode(self):
"""The send-message tool should reject an unsupported format type."""
async def _run():
tool = SendMessageTool(session_id="session-1", user_id="10001")
tool.set_message_attr(
channel=MessageChannel.Telegram.value,
source="telegram-test",
username="tester",
)
# Replace the outgoing message call so a dispatch attempt is observable.
with patch(
"app.agent.tools.base.ToolChain.async_post_message",
new_callable=AsyncMock,
) as async_post_message:
# "Markdown" (without the V2 suffix) is the invalid value under test.
result = await tool.run(
message="处理完成",
parse_mode="Markdown",
)
return result, async_post_message
result, async_post_message = asyncio.run(_run())
# The tool returns the error text instead of raising, and nothing is posted.
self.assertIn("parse_mode 仅支持 MarkdownV2 或 HTML", result)
async_post_message.assert_not_awaited()
def test_send_local_file_input_accepts_file_payload(self):
payload = SendLocalFileInput(
explanation="send generated report",

View File

@@ -39,7 +39,7 @@ class TestAgentInteraction(unittest.TestCase):
self.assertIn("do not write a final text reply after it", telegram_prompt)
self.assertNotIn("ask_user_choice", wechat_prompt)
def test_prompt_injects_send_message_html_hint_only_for_telegram(self):
def test_prompt_does_not_inject_send_message_html_hint(self):
telegram_prompt = prompt_manager.get_agent_prompt(
channel=MessageChannel.Telegram.value
)
@@ -47,9 +47,8 @@ class TestAgentInteraction(unittest.TestCase):
channel=MessageChannel.Wechat.value
)
self.assertIn("parse_mode=\"HTML\"", telegram_prompt)
self.assertIn("Telegram-supported HTML tags", telegram_prompt)
self.assertIn("Do not mix Markdown syntax", telegram_prompt)
self.assertNotIn("parse_mode=\"HTML\"", telegram_prompt)
self.assertNotIn("Telegram-supported HTML tags", telegram_prompt)
self.assertNotIn("parse_mode=\"HTML\"", wechat_prompt)
def test_factory_injects_choice_tool_only_for_button_channels(self):