mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-04-01 01:41:59 +08:00
354 lines
14 KiB
Python
354 lines
14 KiB
Python
import asyncio
|
||
import threading
|
||
from typing import Optional, Tuple
|
||
|
||
from app.chain import ChainBase
|
||
from app.log import logger
|
||
from app.schemas import Notification
|
||
from app.schemas.message import (
|
||
MessageResponse,
|
||
ChannelCapabilityManager,
|
||
ChannelCapability,
|
||
)
|
||
from app.schemas.types import MessageChannel
|
||
|
||
|
||
class _StreamChain(ChainBase):
|
||
pass
|
||
|
||
|
||
class StreamingHandler:
|
||
"""
|
||
流式Token缓冲管理器
|
||
|
||
负责从 LLM 流式 token 中积累文本,并在支持消息编辑的渠道上实时推送给用户。
|
||
|
||
工作流程:
|
||
1. Agent开始处理时调用 start_streaming(),检查渠道能力并启动定时刷新
|
||
2. LLM 产生 token 时调用 emit() 积累到缓冲区
|
||
3. 定时器周期性调用 _flush():
|
||
- 第一次有内容时发送新消息(通过 send_direct_message 获取 message_id)
|
||
- 后续有新内容时编辑同一条消息(通过 edit_message)
|
||
- 当消息长度接近渠道限制时,冻结当前消息并发送新消息继续输出
|
||
4. 工具调用时:
|
||
- 流式渠道:工具消息直接 emit() 追加到 buffer,与 Agent 文字合并为同一条流式消息
|
||
- 非流式渠道:调用 take() 取出已积累的文字,与工具消息合并独立发送
|
||
5. Agent最终完成时调用 stop_streaming():执行最后一次刷新,
|
||
返回是否已通过流式发送完所有内容(调用方据此决定是否还需额外发送)
|
||
"""
|
||
|
||
# 流式输出的刷新间隔(秒)
|
||
FLUSH_INTERVAL = 1.0
|
||
|
||
def __init__(self):
|
||
self._lock = threading.Lock()
|
||
self._buffer = ""
|
||
# 流式输出相关状态
|
||
self._streaming_enabled = False
|
||
self._flush_task: Optional[asyncio.Task] = None
|
||
# 当前消息的发送信息(用于编辑消息)
|
||
self._message_response: Optional[MessageResponse] = None
|
||
# 已发送给用户的文本(用于追踪增量)
|
||
self._sent_text = ""
|
||
# 当前消息的起始偏移量(buffer 中属于当前消息的起始位置)
|
||
self._msg_start_offset = 0
|
||
# 当前渠道的单条消息最大长度(0 表示不限制)
|
||
self._max_message_length = 0
|
||
# 消息发送所需的上下文信息
|
||
self._channel: Optional[str] = None
|
||
self._source: Optional[str] = None
|
||
self._user_id: Optional[str] = None
|
||
self._username: Optional[str] = None
|
||
self._title: str = ""
|
||
|
||
def emit(self, token: str):
|
||
"""
|
||
接收 LLM 流式 token,积累到缓冲区。
|
||
"""
|
||
with self._lock:
|
||
# 如果存量消息结束是两个换行,则去掉新消息前面的换行,避免过多空行
|
||
if self._buffer.endswith("\n\n") and token.startswith("\n"):
|
||
token = token.lstrip("\n")
|
||
self._buffer += token
|
||
|
||
async def take(self) -> str:
|
||
"""
|
||
获取当前已积累的消息内容,获取后清空缓冲区。
|
||
|
||
用于非流式渠道:工具调用前取出 Agent 已产出的文字,
|
||
与工具提示合并后独立发送。
|
||
|
||
注意:流式渠道不调用此方法,工具消息直接 emit 到 buffer 中。
|
||
"""
|
||
with self._lock:
|
||
if not self._buffer:
|
||
return ""
|
||
message = self._buffer
|
||
logger.info(f"Agent消息: {message}")
|
||
self._buffer = ""
|
||
return message
|
||
|
||
def clear(self):
|
||
"""
|
||
清空缓冲区(不返回内容)
|
||
"""
|
||
with self._lock:
|
||
self._buffer = ""
|
||
self._sent_text = ""
|
||
self._message_response = None
|
||
self._msg_start_offset = 0
|
||
|
||
def reset(self):
|
||
"""
|
||
重置缓冲区,清空已发送的文本从头更新,但保持消息编辑能力。
|
||
|
||
与 clear 的区别:
|
||
- clear:完全重置所有状态,后续会开新消息
|
||
- reset:只清空buffer,保留消息编辑状态,后续继续编辑同一条消息
|
||
"""
|
||
with self._lock:
|
||
self._buffer = ""
|
||
self._sent_text = ""
|
||
self._msg_start_offset = 0
|
||
|
||
async def start_streaming(
|
||
self,
|
||
channel: Optional[str] = None,
|
||
source: Optional[str] = None,
|
||
user_id: Optional[str] = None,
|
||
username: Optional[str] = None,
|
||
title: str = "",
|
||
):
|
||
"""
|
||
启动流式输出。检查渠道是否支持消息编辑,如果支持则启动定时刷新任务。
|
||
:param channel: 消息渠道
|
||
:param source: 消息来源
|
||
:param user_id: 用户ID
|
||
:param username: 用户名
|
||
:param title: 消息标题
|
||
"""
|
||
self._channel = channel
|
||
self._source = source
|
||
self._user_id = user_id
|
||
self._username = username
|
||
self._title = title
|
||
|
||
# 检查渠道是否支持消息编辑
|
||
if not self._can_stream():
|
||
logger.debug(f"渠道 {channel} 不支持消息编辑,不启用流式输出")
|
||
return
|
||
|
||
self._streaming_enabled = True
|
||
self._sent_text = ""
|
||
self._message_response = None
|
||
self._msg_start_offset = 0
|
||
|
||
# 从渠道能力中获取单条消息最大长度
|
||
try:
|
||
channel_enum = MessageChannel(self._channel)
|
||
self._max_message_length = ChannelCapabilityManager.get_max_message_length(
|
||
channel_enum
|
||
)
|
||
except (ValueError, KeyError):
|
||
self._max_message_length = 0
|
||
|
||
# 启动异步定时刷新任务
|
||
self._flush_task = asyncio.create_task(self._flush_loop())
|
||
logger.debug("流式输出已启动")
|
||
|
||
async def stop_streaming(self) -> Tuple[bool, str]:
|
||
"""
|
||
停止流式输出。执行最后一次刷新确保所有内容都已发送。
|
||
:return: (all_sent, final_text)
|
||
all_sent: 是否已经通过流式编辑将最终完整内容发送给了用户
|
||
(True 表示调用方无需再额外发送消息)
|
||
final_text: 流式发送的完整文本内容(用于调用方保存消息记录)
|
||
"""
|
||
if not self._streaming_enabled:
|
||
return False, ""
|
||
|
||
self._streaming_enabled = False
|
||
|
||
# 取消定时任务
|
||
await self._cancel_flush_task()
|
||
|
||
# 执行最后一次刷新
|
||
await self._flush()
|
||
|
||
# 检查是否所有缓冲内容都已发送
|
||
with self._lock:
|
||
# 当前消息的文本 = buffer 中从 _msg_start_offset 开始的部分
|
||
current_msg_text = self._buffer[self._msg_start_offset :]
|
||
all_sent = (
|
||
self._message_response is not None
|
||
and self._sent_text
|
||
and current_msg_text == self._sent_text
|
||
)
|
||
# 保留最终文本用于返回(返回完整 buffer 内容,包含所有分段消息)
|
||
final_text = self._buffer if all_sent else ""
|
||
# 重置状态
|
||
self._sent_text = ""
|
||
self._message_response = None
|
||
self._msg_start_offset = 0
|
||
if all_sent:
|
||
# 所有内容已通过流式发送,清空缓冲区
|
||
self._buffer = ""
|
||
return all_sent, final_text
|
||
|
||
def _can_stream(self) -> bool:
|
||
"""
|
||
检查当前渠道是否支持流式输出(消息编辑)
|
||
"""
|
||
if not self._channel:
|
||
return False
|
||
try:
|
||
channel_enum = MessageChannel(self._channel)
|
||
return ChannelCapabilityManager.supports_capability(
|
||
channel_enum, ChannelCapability.MESSAGE_EDITING
|
||
)
|
||
except (ValueError, KeyError):
|
||
return False
|
||
|
||
async def _flush_loop(self):
|
||
"""
|
||
定时刷新循环,定期将缓冲区内容发送/编辑到用户
|
||
"""
|
||
try:
|
||
while self._streaming_enabled:
|
||
await asyncio.sleep(self.FLUSH_INTERVAL)
|
||
if self._streaming_enabled:
|
||
await self._flush()
|
||
except asyncio.CancelledError:
|
||
pass
|
||
except Exception as e:
|
||
logger.error(f"流式刷新异常: {e}")
|
||
|
||
async def _cancel_flush_task(self):
|
||
"""
|
||
取消当前的定时刷新任务
|
||
"""
|
||
if self._flush_task and not self._flush_task.done():
|
||
self._flush_task.cancel()
|
||
try:
|
||
await self._flush_task
|
||
except asyncio.CancelledError:
|
||
pass
|
||
self._flush_task = None
|
||
|
||
async def _flush(self):
|
||
"""
|
||
将当前缓冲区内容刷新到用户消息
|
||
- 如果还没有发送过消息,先发送一条新消息并记录message_id
|
||
- 如果已经发送过消息,编辑该消息为最新的完整内容
|
||
- 如果当前消息内容超过长度限制,冻结当前消息并发送新消息继续输出
|
||
"""
|
||
with self._lock:
|
||
# 当前消息的文本 = buffer 中从 _msg_start_offset 开始的部分
|
||
current_text = self._buffer[self._msg_start_offset :]
|
||
if not current_text or current_text == self._sent_text:
|
||
# 没有新内容需要刷新
|
||
return
|
||
|
||
chain = _StreamChain()
|
||
|
||
try:
|
||
if self._message_response is None:
|
||
# 第一次发送:发送新消息并获取 message_id
|
||
response = chain.send_direct_message(
|
||
Notification(
|
||
channel=self._channel,
|
||
source=self._source,
|
||
userid=self._user_id,
|
||
username=self._username,
|
||
title=self._title,
|
||
text=current_text,
|
||
)
|
||
)
|
||
if response and response.success and response.message_id:
|
||
self._message_response = response
|
||
with self._lock:
|
||
self._sent_text = current_text
|
||
logger.debug(
|
||
f"流式输出初始消息已发送: message_id={response.message_id}"
|
||
)
|
||
else:
|
||
logger.debug(
|
||
"流式输出初始消息发送失败或未返回message_id,降级为非流式输出"
|
||
)
|
||
self._streaming_enabled = False
|
||
else:
|
||
# 检查当前消息内容是否超过长度限制
|
||
if (
|
||
self._max_message_length
|
||
and len(current_text) > self._max_message_length
|
||
):
|
||
# 消息过长,冻结当前消息(保持最后一次成功编辑的内容)
|
||
# 将 offset 移动到已发送文本之后,开启新消息
|
||
logger.debug(
|
||
f"流式消息长度 {len(current_text)} 超过限制 {self._max_message_length},启用新消息"
|
||
)
|
||
with self._lock:
|
||
self._msg_start_offset += len(self._sent_text)
|
||
current_text = self._buffer[self._msg_start_offset :]
|
||
self._message_response = None
|
||
self._sent_text = ""
|
||
|
||
# 如果偏移后还有新内容,立即发送为新消息
|
||
if current_text:
|
||
response = chain.send_direct_message(
|
||
Notification(
|
||
channel=self._channel,
|
||
source=self._source,
|
||
userid=self._user_id,
|
||
username=self._username,
|
||
title=self._title,
|
||
text=current_text,
|
||
)
|
||
)
|
||
if response and response.success and response.message_id:
|
||
self._message_response = response
|
||
with self._lock:
|
||
self._sent_text = current_text
|
||
logger.debug(
|
||
f"流式输出新消息已发送: message_id={response.message_id}"
|
||
)
|
||
else:
|
||
logger.debug("流式输出新消息发送失败,降级为非流式输出")
|
||
self._streaming_enabled = False
|
||
else:
|
||
# 后续更新:编辑已有消息
|
||
try:
|
||
channel_enum = MessageChannel(self._channel)
|
||
except (ValueError, KeyError):
|
||
return
|
||
|
||
success = chain.edit_message(
|
||
channel=channel_enum,
|
||
source=self._message_response.source,
|
||
message_id=self._message_response.message_id,
|
||
chat_id=self._message_response.chat_id,
|
||
text=current_text,
|
||
title=self._title,
|
||
)
|
||
if success:
|
||
with self._lock:
|
||
self._sent_text = current_text
|
||
else:
|
||
logger.debug("流式输出消息编辑失败")
|
||
except Exception as e:
|
||
logger.error(f"流式输出刷新失败: {e}")
|
||
|
||
@property
|
||
def is_streaming(self) -> bool:
|
||
"""
|
||
是否正在流式输出
|
||
"""
|
||
return self._streaming_enabled
|
||
|
||
@property
|
||
def has_sent_message(self) -> bool:
|
||
"""
|
||
是否已经通过流式输出发送过消息(当前轮次)
|
||
"""
|
||
return self._message_response is not None
|