mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-03-31 17:32:30 +08:00
274 lines
9.3 KiB
Python
274 lines
9.3 KiB
Python
import asyncio
|
||
import threading
|
||
from typing import Optional
|
||
|
||
from app.chain import ChainBase
|
||
from app.log import logger
|
||
from app.schemas import Notification
|
||
from app.schemas.message import (
|
||
MessageResponse,
|
||
ChannelCapabilityManager,
|
||
ChannelCapability,
|
||
)
|
||
from app.schemas.types import MessageChannel
|
||
|
||
|
||
class _StreamChain(ChainBase):
|
||
pass
|
||
|
||
|
||
class StreamingHandler:
|
||
"""
|
||
流式Token缓冲管理器
|
||
|
||
负责从 LLM 流式 token 中积累文本,并在支持消息编辑的渠道上实时推送给用户。
|
||
|
||
工作流程:
|
||
1. Agent开始处理时调用 start_streaming(),检查渠道能力并启动定时刷新
|
||
2. LLM 产生 token 时调用 emit() 积累到缓冲区
|
||
3. 定时器周期性调用 _flush():
|
||
- 第一次有内容时发送新消息(通过 send_direct_message 获取 message_id)
|
||
- 后续有新内容时编辑同一条消息(通过 edit_message)
|
||
4. 工具调用时:
|
||
- 流式渠道:工具消息直接 emit() 追加到 buffer,与 Agent 文字合并为同一条流式消息
|
||
- 非流式渠道:调用 take() 取出已积累的文字,与工具消息合并独立发送
|
||
5. Agent最终完成时调用 stop_streaming():执行最后一次刷新,
|
||
返回是否已通过流式发送完所有内容(调用方据此决定是否还需额外发送)
|
||
"""
|
||
|
||
# 流式输出的刷新间隔(秒)
|
||
FLUSH_INTERVAL = 1.0
|
||
|
||
def __init__(self):
|
||
self._lock = threading.Lock()
|
||
self._buffer = ""
|
||
# 流式输出相关状态
|
||
self._streaming_enabled = False
|
||
self._flush_task: Optional[asyncio.Task] = None
|
||
# 当前消息的发送信息(用于编辑消息)
|
||
self._message_response: Optional[MessageResponse] = None
|
||
# 已发送给用户的文本(用于追踪增量)
|
||
self._sent_text = ""
|
||
# 消息发送所需的上下文信息
|
||
self._channel: Optional[str] = None
|
||
self._source: Optional[str] = None
|
||
self._user_id: Optional[str] = None
|
||
self._username: Optional[str] = None
|
||
self._title: str = ""
|
||
|
||
def emit(self, token: str):
|
||
"""
|
||
接收 LLM 流式 token,积累到缓冲区。
|
||
"""
|
||
with self._lock:
|
||
self._buffer += token
|
||
|
||
async def take(self) -> str:
|
||
"""
|
||
获取当前已积累的消息内容,获取后清空缓冲区。
|
||
|
||
用于非流式渠道:工具调用前取出 Agent 已产出的文字,
|
||
与工具提示合并后独立发送。
|
||
|
||
注意:流式渠道不调用此方法,工具消息直接 emit 到 buffer 中。
|
||
"""
|
||
with self._lock:
|
||
if not self._buffer:
|
||
return ""
|
||
message = self._buffer
|
||
logger.info(f"Agent消息: {message}")
|
||
self._buffer = ""
|
||
return message
|
||
|
||
def clear(self):
|
||
"""
|
||
清空缓冲区(不返回内容)
|
||
"""
|
||
with self._lock:
|
||
self._buffer = ""
|
||
self._sent_text = ""
|
||
self._message_response = None
|
||
|
||
async def start_streaming(
|
||
self,
|
||
channel: Optional[str] = None,
|
||
source: Optional[str] = None,
|
||
user_id: Optional[str] = None,
|
||
username: Optional[str] = None,
|
||
title: str = "",
|
||
):
|
||
"""
|
||
启动流式输出。检查渠道是否支持消息编辑,如果支持则启动定时刷新任务。
|
||
:param channel: 消息渠道
|
||
:param source: 消息来源
|
||
:param user_id: 用户ID
|
||
:param username: 用户名
|
||
:param title: 消息标题
|
||
"""
|
||
self._channel = channel
|
||
self._source = source
|
||
self._user_id = user_id
|
||
self._username = username
|
||
self._title = title
|
||
|
||
# 检查渠道是否支持消息编辑
|
||
if not self._can_stream():
|
||
logger.debug(f"渠道 {channel} 不支持消息编辑,不启用流式输出")
|
||
return
|
||
|
||
self._streaming_enabled = True
|
||
self._sent_text = ""
|
||
self._message_response = None
|
||
|
||
# 启动异步定时刷新任务
|
||
self._flush_task = asyncio.create_task(self._flush_loop())
|
||
logger.debug("流式输出已启动")
|
||
|
||
async def stop_streaming(self) -> bool:
|
||
"""
|
||
停止流式输出。执行最后一次刷新确保所有内容都已发送。
|
||
:return: 是否已经通过流式编辑将最终完整内容发送给了用户
|
||
(True 表示调用方无需再额外发送消息)
|
||
"""
|
||
if not self._streaming_enabled:
|
||
return False
|
||
|
||
self._streaming_enabled = False
|
||
|
||
# 取消定时任务
|
||
await self._cancel_flush_task()
|
||
|
||
# 执行最后一次刷新
|
||
await self._flush()
|
||
|
||
# 检查是否所有缓冲内容都已发送
|
||
with self._lock:
|
||
all_sent = (
|
||
self._message_response is not None
|
||
and self._sent_text
|
||
and self._buffer == self._sent_text
|
||
)
|
||
# 重置状态
|
||
self._sent_text = ""
|
||
self._message_response = None
|
||
if all_sent:
|
||
# 所有内容已通过流式发送,清空缓冲区
|
||
self._buffer = ""
|
||
return all_sent
|
||
|
||
def _can_stream(self) -> bool:
|
||
"""
|
||
检查当前渠道是否支持流式输出(消息编辑)
|
||
"""
|
||
if not self._channel:
|
||
return False
|
||
try:
|
||
channel_enum = MessageChannel(self._channel)
|
||
return ChannelCapabilityManager.supports_capability(
|
||
channel_enum, ChannelCapability.MESSAGE_EDITING
|
||
)
|
||
except (ValueError, KeyError):
|
||
return False
|
||
|
||
async def _flush_loop(self):
|
||
"""
|
||
定时刷新循环,定期将缓冲区内容发送/编辑到用户
|
||
"""
|
||
try:
|
||
while self._streaming_enabled:
|
||
await asyncio.sleep(self.FLUSH_INTERVAL)
|
||
if self._streaming_enabled:
|
||
await self._flush()
|
||
except asyncio.CancelledError:
|
||
pass
|
||
except Exception as e:
|
||
logger.error(f"流式刷新异常: {e}")
|
||
|
||
async def _cancel_flush_task(self):
|
||
"""
|
||
取消当前的定时刷新任务
|
||
"""
|
||
if self._flush_task and not self._flush_task.done():
|
||
self._flush_task.cancel()
|
||
try:
|
||
await self._flush_task
|
||
except asyncio.CancelledError:
|
||
pass
|
||
self._flush_task = None
|
||
|
||
async def _flush(self):
|
||
"""
|
||
将当前缓冲区内容刷新到用户消息
|
||
- 如果还没有发送过消息,先发送一条新消息并记录message_id
|
||
- 如果已经发送过消息,编辑该消息为最新的完整内容
|
||
"""
|
||
with self._lock:
|
||
current_text = self._buffer
|
||
if not current_text or current_text == self._sent_text:
|
||
# 没有新内容需要刷新
|
||
return
|
||
|
||
chain = _StreamChain()
|
||
|
||
try:
|
||
if self._message_response is None:
|
||
# 第一次发送:发送新消息并获取 message_id
|
||
response = chain.send_direct_message(
|
||
Notification(
|
||
channel=self._channel,
|
||
source=self._source,
|
||
userid=self._user_id,
|
||
username=self._username,
|
||
title=self._title,
|
||
text=current_text,
|
||
)
|
||
)
|
||
if response and response.success and response.message_id:
|
||
self._message_response = response
|
||
with self._lock:
|
||
self._sent_text = current_text
|
||
logger.debug(
|
||
f"流式输出初始消息已发送: message_id={response.message_id}"
|
||
)
|
||
else:
|
||
logger.debug(
|
||
"流式输出初始消息发送失败或未返回message_id,降级为非流式输出"
|
||
)
|
||
self._streaming_enabled = False
|
||
else:
|
||
# 后续更新:编辑已有消息
|
||
try:
|
||
channel_enum = MessageChannel(self._channel)
|
||
except (ValueError, KeyError):
|
||
return
|
||
|
||
success = chain.edit_message(
|
||
channel=channel_enum,
|
||
source=self._message_response.source,
|
||
message_id=self._message_response.message_id,
|
||
chat_id=self._message_response.chat_id,
|
||
text=current_text,
|
||
title=self._title,
|
||
)
|
||
if success:
|
||
with self._lock:
|
||
self._sent_text = current_text
|
||
else:
|
||
logger.debug("流式输出消息编辑失败")
|
||
except Exception as e:
|
||
logger.error(f"流式输出刷新失败: {e}")
|
||
|
||
@property
|
||
def is_streaming(self) -> bool:
|
||
"""
|
||
是否正在流式输出
|
||
"""
|
||
return self._streaming_enabled
|
||
|
||
@property
|
||
def has_sent_message(self) -> bool:
|
||
"""
|
||
是否已经通过流式输出发送过消息(当前轮次)
|
||
"""
|
||
return self._message_response is not None
|