feat: add agent session usage status reporting

Track per-session model and token usage so users can inspect context pressure and cumulative usage with /session_status.
Author: jxxghp, 2026-04-26 08:19:05 +08:00
parent 79bfeaf2af, commit 0277288a41
5 changed files with 552 additions and 12 deletions
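
For reviewers: the per-session status payload assembled by this commit carries the fields below. Values are illustrative only, lifted from the unit-test fixture at the end of this diff; the two trailing keys are added by AgentManager.get_session_status on top of the per-agent snapshot.

# Illustrative /session_status payload (fixture values from the test file in this commit)
{
    "session_id": "session-1",
    "model": "gpt-4o-mini",
    "context_window_tokens": 128000,
    "last_input_tokens": 1200,
    "last_output_tokens": 300,
    "last_total_tokens": 1500,
    "last_context_usage_ratio": 1200 / 128000,
    "total_input_tokens": 4500,
    "total_output_tokens": 1500,
    "total_tokens": 6000,
    "model_call_count": 4,
    "last_updated_at": "2026-04-26 12:34:56",
    "pending_messages": 2,
    "is_processing": True,
}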


@@ -4,7 +4,8 @@ import re
import traceback
import uuid
from dataclasses import dataclass
-from typing import Callable, Dict, List, Optional
+from datetime import datetime
+from typing import Any, Callable, Dict, List, Optional
from langchain.agents import create_agent
from langchain.agents.middleware import (
@@ -24,6 +25,7 @@ from app.agent.middleware.jobs import JobsMiddleware
from app.agent.middleware.memory import MemoryMiddleware
from app.agent.middleware.patch_tool_calls import PatchToolCallsMiddleware
from app.agent.middleware.skills import SkillsMiddleware
from app.agent.middleware.usage import UsageMiddleware
from app.agent.prompt import prompt_manager
from app.agent.tools.factory import MoviePilotToolFactory
from app.chain import ChainBase
@@ -41,6 +43,39 @@ class AgentChain(ChainBase):
pass
@dataclass
class _SessionUsageSnapshot:
model: Optional[str] = None
context_window_tokens: Optional[int] = None
last_input_tokens: int = 0
last_output_tokens: int = 0
last_total_tokens: int = 0
last_context_usage_ratio: Optional[float] = None
total_input_tokens: int = 0
total_output_tokens: int = 0
total_tokens: int = 0
model_call_count: int = 0
last_updated_at: Optional[datetime] = None
def to_dict(self, session_id: str) -> dict[str, Any]:
return {
"session_id": session_id,
"model": self.model,
"context_window_tokens": self.context_window_tokens,
"last_input_tokens": self.last_input_tokens,
"last_output_tokens": self.last_output_tokens,
"last_total_tokens": self.last_total_tokens,
"last_context_usage_ratio": self.last_context_usage_ratio,
"total_input_tokens": self.total_input_tokens,
"total_output_tokens": self.total_output_tokens,
"total_tokens": self.total_tokens,
"model_call_count": self.model_call_count,
"last_updated_at": self.last_updated_at.strftime("%Y-%m-%d %H:%M:%S")
if self.last_updated_at
else None,
}
class _ThinkTagStripper:
"""
Helper for stripping <think>...</think> tags from streamed output.
@@ -138,10 +173,92 @@ class MoviePilotAgent:
self.force_streaming = False
self.suppress_user_reply = False
self._streamed_output = ""
self._session_usage = _SessionUsageSnapshot()
# Streaming token management
self.stream_handler = StreamingHandler()
@staticmethod
def _coerce_int(value: Any) -> Optional[int]:
if value is None:
return None
try:
return int(value)
except (TypeError, ValueError):
return None
@classmethod
def _get_model_name(cls, llm: Any) -> Optional[str]:
return (
getattr(llm, "model", None)
or getattr(llm, "model_name", None)
or getattr(llm, "model_id", None)
)
@classmethod
def _get_context_window_tokens(cls, llm: Any) -> Optional[int]:
profile = getattr(llm, "profile", None)
if not profile:
return None
if isinstance(profile, dict):
return cls._coerce_int(
profile.get("max_input_tokens") or profile.get("input_token_limit")
)
return cls._coerce_int(
getattr(profile, "max_input_tokens", None)
or getattr(profile, "input_token_limit", None)
)
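The profile probe above (and the equivalent _lookup_int path in UsageMiddleware below) accepts either a plain dict or an attribute-style object. A minimal sketch of both shapes, using SimpleNamespace purely as a stand-in for a real chat-model instance:

from types import SimpleNamespace

# dict-style profile, as used in the unit test in this commit
llm_a = SimpleNamespace(model="gpt-4o-mini", profile={"max_input_tokens": 128000})

# attribute-style profile (hypothetical object exposing the same field as an attribute)
llm_b = SimpleNamespace(model_name="gpt-4o-mini", profile=SimpleNamespace(max_input_tokens=128000))

# Both resolve to a 128,000-token context window via _get_context_window_tokens().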
def _sync_model_profile(self, llm: Any) -> None:
model_name = self._get_model_name(llm)
context_window_tokens = self._get_context_window_tokens(llm)
if model_name:
self._session_usage.model = model_name
if context_window_tokens:
self._session_usage.context_window_tokens = context_window_tokens
def _record_usage(self, usage: dict[str, Any]) -> None:
if not usage:
return
model_name = usage.get("model")
context_window_tokens = self._coerce_int(usage.get("context_window_tokens"))
if model_name:
self._session_usage.model = model_name
if context_window_tokens:
self._session_usage.context_window_tokens = context_window_tokens
self._session_usage.model_call_count += 1
self._session_usage.last_updated_at = datetime.now()
if not usage.get("has_usage"):
return
input_tokens = self._coerce_int(usage.get("input_tokens")) or 0
output_tokens = self._coerce_int(usage.get("output_tokens")) or 0
total_tokens = self._coerce_int(usage.get("total_tokens"))
if total_tokens is None:
total_tokens = input_tokens + output_tokens
self._session_usage.last_input_tokens = input_tokens
self._session_usage.last_output_tokens = output_tokens
self._session_usage.last_total_tokens = total_tokens
self._session_usage.last_context_usage_ratio = usage.get("context_usage_ratio")
self._session_usage.total_input_tokens += input_tokens
self._session_usage.total_output_tokens += output_tokens
self._session_usage.total_tokens += total_tokens
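For context while reading _record_usage: the usage dict it receives is the payload that UsageMiddleware.awrap_model_call builds further down in this commit, roughly of this shape (values illustrative):

usage = {
    "model": "gpt-4o-mini",
    "context_window_tokens": 128000,
    "context_usage_ratio": 1200 / 128000,  # None when the window or usage data is unavailable
    "has_usage": True,
    "input_tokens": 1200,
    "output_tokens": 300,
    "total_tokens": 1500,
}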
def get_session_status(self) -> dict[str, Any]:
if not self._session_usage.model:
self._session_usage.model = settings.LLM_MODEL
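# NOTE: LLM_MAX_CONTEXT_TOKENS appears to be configured in units of 1K tokens, hence the * 1000 scaling below.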
if not self._session_usage.context_window_tokens:
self._session_usage.context_window_tokens = (
settings.LLM_MAX_CONTEXT_TOKENS * 1000
if settings.LLM_MAX_CONTEXT_TOKENS
else None
)
return self._session_usage.to_dict(self.session_id)
@property
def is_background(self) -> bool:
"""
@@ -258,6 +375,7 @@ class MoviePilotAgent:
# LLM model (used for agent execution)
llm = self._initialize_llm(streaming=streaming)
self._sync_model_profile(llm)
# Tool list
tools = self._initialize_tools()
@@ -279,6 +397,8 @@ class MoviePilotAgent:
ActivityLogMiddleware(
activity_dir=str(settings.CONFIG_PATH / "agent" / "activity"),
),
# Usage accounting
UsageMiddleware(on_usage=self._record_usage),
# Context compression
SummarizationMiddleware(model=llm, trigger=("fraction", 0.85)),
# Repair malformed tool calls
@@ -608,6 +728,37 @@ class AgentManager:
# Lock guarding the retry transfer buffer
self._retry_transfer_lock = asyncio.Lock()
def get_session_status(self, session_id: str) -> dict[str, Any]:
"""获取会话当前模型与 token 使用状态。"""
agent = self.active_agents.get(session_id)
if agent:
status = agent.get_session_status()
else:
status = {
"session_id": session_id,
"model": settings.LLM_MODEL,
"context_window_tokens": settings.LLM_MAX_CONTEXT_TOKENS * 1000
if settings.LLM_MAX_CONTEXT_TOKENS
else None,
"last_input_tokens": 0,
"last_output_tokens": 0,
"last_total_tokens": 0,
"last_context_usage_ratio": None,
"total_input_tokens": 0,
"total_output_tokens": 0,
"total_tokens": 0,
"model_call_count": 0,
"last_updated_at": None,
}
queue = self._session_queues.get(session_id)
status["pending_messages"] = queue.qsize() if queue else 0
status["is_processing"] = (
session_id in self._session_workers
and not self._session_workers[session_id].done()
)
return status
@staticmethod
async def initialize():
"""


@@ -0,0 +1,184 @@
from collections.abc import Awaitable, Callable
from typing import Any
from langchain.agents.middleware.types import (
AgentMiddleware,
ContextT,
ModelRequest,
ModelResponse,
ResponseT,
)
from langchain_core.messages import AIMessage
from app.log import logger
class UsageMiddleware(AgentMiddleware):
"""记录模型调用 usage 信息并回传给外部会话。"""
def __init__(
self,
*,
on_usage: Callable[[dict[str, Any]], None] | None = None,
) -> None:
self.on_usage = on_usage
@staticmethod
def _coerce_int(value: Any) -> int | None:
if value is None:
return None
try:
return int(value)
except (TypeError, ValueError):
return None
@classmethod
def _lookup_int(cls, container: Any, *keys: str) -> int | None:
if not container:
return None
getter = getattr(container, "get", None)
if callable(getter):
for key in keys:
value = getter(key)
if value is not None:
return cls._coerce_int(value)
for key in keys:
value = getattr(container, key, None)
if value is not None:
return cls._coerce_int(value)
return None
@classmethod
def _extract_model_name(cls, model: Any) -> str | None:
return (
getattr(model, "model", None)
or getattr(model, "model_name", None)
or getattr(model, "model_id", None)
)
@classmethod
def _extract_context_window_tokens(cls, model: Any) -> int | None:
profile = getattr(model, "profile", None)
if not profile:
return None
return cls._lookup_int(profile, "max_input_tokens", "input_token_limit")
@classmethod
def _extract_usage(cls, ai_message: AIMessage) -> dict[str, Any]:
usage_metadata = getattr(ai_message, "usage_metadata", None)
input_tokens = cls._lookup_int(usage_metadata, "input_tokens")
output_tokens = cls._lookup_int(usage_metadata, "output_tokens")
total_tokens = cls._lookup_int(usage_metadata, "total_tokens")
response_metadata = getattr(ai_message, "response_metadata", None) or {}
token_usage = (
response_metadata.get("token_usage")
or response_metadata.get("usage")
or response_metadata.get("usage_metadata")
or {}
)
if input_tokens is None:
input_tokens = cls._lookup_int(
token_usage,
"prompt_tokens",
"input_tokens",
)
if input_tokens is None:
input_tokens = cls._lookup_int(
response_metadata,
"prompt_token_count",
"input_tokens",
)
if output_tokens is None:
output_tokens = cls._lookup_int(
token_usage,
"completion_tokens",
"output_tokens",
)
if output_tokens is None:
output_tokens = cls._lookup_int(
response_metadata,
"candidates_token_count",
"output_tokens",
)
if total_tokens is None:
total_tokens = cls._lookup_int(token_usage, "total_tokens")
if total_tokens is None:
total_tokens = cls._lookup_int(response_metadata, "total_token_count")
has_usage = any(
value is not None for value in (input_tokens, output_tokens, total_tokens)
)
resolved_input = input_tokens or 0
resolved_output = output_tokens or 0
resolved_total = (
total_tokens
if total_tokens is not None
else resolved_input + resolved_output
)
return {
"has_usage": has_usage,
"input_tokens": resolved_input,
"output_tokens": resolved_output,
"total_tokens": resolved_total,
}
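The fallback chain above normalizes the differing field names providers report. A rough illustration; the key names are exactly what the code probes, while the metadata layouts themselves vary by provider:

from langchain_core.messages import AIMessage

# OpenAI-style accounting nested under response_metadata["token_usage"]
openai_like = AIMessage(
    content="ok",
    response_metadata={
        "token_usage": {"prompt_tokens": 1200, "completion_tokens": 300, "total_tokens": 1500}
    },
)

# *_token_count style accounting reported directly on response_metadata (e.g. Gemini-like responses)
count_style = AIMessage(
    content="ok",
    response_metadata={
        "prompt_token_count": 1200,
        "candidates_token_count": 300,
        "total_token_count": 1500,
    },
)

# Both normalize to input_tokens=1200, output_tokens=300, total_tokens=1500 via _extract_usage().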
async def awrap_model_call(
self,
request: ModelRequest[ContextT],
handler: Callable[
[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]
],
) -> ModelResponse[ResponseT]:
response = await handler(request)
if not callable(self.on_usage):
return response
try:
ai_message = next(
(
message
for message in reversed(response.result)
if isinstance(message, AIMessage)
),
None,
)
usage = (
self._extract_usage(ai_message)
if ai_message
else {
"has_usage": False,
"input_tokens": 0,
"output_tokens": 0,
"total_tokens": 0,
}
)
context_window_tokens = self._extract_context_window_tokens(request.model)
context_usage_ratio = None
if context_window_tokens and usage["has_usage"]:
context_usage_ratio = usage["input_tokens"] / context_window_tokens
self.on_usage(
{
"model": self._extract_model_name(request.model),
"context_window_tokens": context_window_tokens,
"context_usage_ratio": context_usage_ratio,
**usage,
}
)
except Exception as e:
logger.debug("记录模型 usage 失败: %s", e)
return response
__all__ = ["UsageMiddleware"]


@@ -786,6 +786,74 @@ class MessageChain(ChainBase):
)
)
@staticmethod
def _format_token_count(value: Optional[int]) -> str:
return f"{value:,}" if value is not None else "未知"
@classmethod
def _format_session_status_text(cls, status: Dict[str, Any]) -> str:
context_window_tokens = status.get("context_window_tokens")
last_input_tokens = status.get("last_input_tokens")
if context_window_tokens and status.get("model_call_count"):
context_ratio = status.get("last_context_usage_ratio")
if context_ratio is None and last_input_tokens is not None:
context_ratio = last_input_tokens / context_window_tokens
context_usage_text = (
f"{cls._format_token_count(last_input_tokens)} / "
f"{cls._format_token_count(context_window_tokens)} "
f"({context_ratio * 100:.2f}%)"
if context_ratio is not None
else f"{cls._format_token_count(last_input_tokens)} / "
f"{cls._format_token_count(context_window_tokens)}"
)
else:
context_usage_text = "暂无模型调用数据"
lines = [
f"会话ID: {status.get('session_id') or '未知'}",
f"执行状态: {'运行中' if status.get('is_processing') else '空闲'}",
f"当前模型: {status.get('model') or '未知'}",
f"上下文窗口: {cls._format_token_count(context_window_tokens)} tokens",
f"最近一次上下文占用: {context_usage_text}",
f"最近一次 tokens: 输入 {cls._format_token_count(status.get('last_input_tokens'))} / 输出 {cls._format_token_count(status.get('last_output_tokens'))} / 总计 {cls._format_token_count(status.get('last_total_tokens'))}",
f"当前会话累计 tokens: 输入 {cls._format_token_count(status.get('total_input_tokens'))} / 输出 {cls._format_token_count(status.get('total_output_tokens'))} / 总计 {cls._format_token_count(status.get('total_tokens'))}",
f"模型调用次数: {status.get('model_call_count', 0)}",
f"排队消息数: {status.get('pending_messages', 0)}",
f"最后更新: {status.get('last_updated_at') or '暂无'}",
]
return "\n".join(lines)
def remote_session_status(
self,
channel: MessageChannel,
userid: Union[str, int],
source: Optional[str] = None,
):
"""查询当前用户的智能体会话状态。"""
session_info = self._user_sessions.get(userid)
if not session_info:
self.post_message(
Notification(
channel=channel,
source=source,
title="您当前没有活跃的智能体会话",
userid=userid,
)
)
return
session_id, _ = session_info
status = agent_manager.get_session_status(session_id=session_id)
self.post_message(
Notification(
channel=channel,
source=source,
title="当前智能体会话状态",
text=self._format_session_status_text(status),
userid=userid,
)
)
def _handle_ai_message(
self,
text: str,
@@ -857,7 +925,12 @@ class MessageChain(ChainBase):
return
elif images:
image_attachments = self._build_image_attachments(images)
-if original_images and not image_attachments and not user_message and not files:
+if (
+original_images
+and not image_attachments
+and not user_message
+and not files
+):
self.post_message(
Notification(
channel=channel,
@@ -940,42 +1013,58 @@ class MessageChain(ChainBase):
filename = "input.mp3"
elif audio_ref.startswith("wxwork://voice_media_id/"):
content = self.run_module(
"download_wechat_media_bytes", media_ref=audio_ref, source=source
"download_wechat_media_bytes",
media_ref=audio_ref,
source=source,
)
filename = "input.amr"
elif audio_ref.startswith("slack://file/"):
content = self.run_module(
"download_slack_file_bytes", file_ref=audio_ref, source=source
)
-filename = self._guess_audio_filename(audio_ref, default="input.ogg")
+filename = self._guess_audio_filename(
+audio_ref, default="input.ogg"
+)
elif audio_ref.startswith("discord://file/"):
content = self.run_module(
"download_discord_file_bytes", file_ref=audio_ref, source=source
)
-filename = self._guess_audio_filename(audio_ref, default="input.ogg")
+filename = self._guess_audio_filename(
+audio_ref, default="input.ogg"
+)
elif audio_ref.startswith("qq://file/"):
content = self.run_module(
"download_qq_file_bytes", file_ref=audio_ref, source=source
)
-filename = self._guess_audio_filename(audio_ref, default="input.ogg")
+filename = self._guess_audio_filename(
+audio_ref, default="input.ogg"
+)
elif audio_ref.startswith("vocechat://file/"):
content = self.run_module(
"download_vocechat_file_bytes", file_ref=audio_ref, source=source
"download_vocechat_file_bytes",
file_ref=audio_ref,
source=source,
)
filename = self._guess_audio_filename(
audio_ref, default="input.ogg"
)
filename = self._guess_audio_filename(audio_ref, default="input.ogg")
elif audio_ref.startswith("synology://file/"):
content = self.run_module(
"download_synologychat_file_bytes",
file_ref=audio_ref,
source=source,
)
-filename = self._guess_audio_filename(audio_ref, default="input.ogg")
+filename = self._guess_audio_filename(
+audio_ref, default="input.ogg"
+)
elif audio_ref.startswith("wxbot://voice"):
continue
elif audio_ref.startswith("http"):
resp = RequestUtils(timeout=30).get_res(audio_ref)
content = resp.content if resp and resp.content else None
-filename = self._guess_audio_filename(audio_ref, default="input.ogg")
+filename = self._guess_audio_filename(
+audio_ref, default="input.ogg"
+)
else:
logger.debug(
"暂不支持的语音引用: channel=%s, source=%s, ref=%s",
@@ -994,7 +1083,9 @@ class MessageChain(ChainBase):
)
continue
-transcript = VoiceHelper.transcribe_bytes(content=content, filename=filename)
+transcript = VoiceHelper.transcribe_bytes(
+content=content, filename=filename
+)
if transcript:
transcripts.append(transcript)
logger.info(
@@ -1047,7 +1138,9 @@ class MessageChain(ChainBase):
elif img.startswith("tg://file_id/"):
file_id = img.replace("tg://file_id/", "")
base64_data = self.run_module(
"download_telegram_file_to_base64", file_id=file_id, source=source
"download_telegram_file_to_base64",
file_id=file_id,
source=source,
)
if base64_data:
base64_images.append(f"data:image/jpeg;base64,{base64_data}")


@@ -155,6 +155,12 @@ class Command(metaclass=Singleton):
"category": "管理",
"data": {},
},
"/session_status": {
"func": MessageChain().remote_session_status,
"description": "会话状态",
"category": "智能体",
"data": {},
},
"/skills": {
"func": SkillsChain().remote_manage,
"description": "管理技能",


@@ -0,0 +1,106 @@
import asyncio
import unittest
from datetime import datetime
from types import SimpleNamespace
from unittest.mock import patch
from langchain.agents.middleware.types import ModelRequest, ModelResponse
from langchain_core.messages import AIMessage
from app.agent.middleware.usage import UsageMiddleware
from app.chain.message import MessageChain
from app.schemas.types import MessageChannel
class TestAgentSessionStatus(unittest.TestCase):
def test_usage_middleware_records_usage_metadata(self):
snapshots = []
middleware = UsageMiddleware(on_usage=snapshots.append)
request = ModelRequest(
model=SimpleNamespace(
model="gpt-4o-mini", profile={"max_input_tokens": 128000}
),
messages=[],
state={},
runtime=None,
)
response = ModelResponse(
result=[
AIMessage(
content="ok",
usage_metadata={
"input_tokens": 1200,
"output_tokens": 300,
"total_tokens": 1500,
},
)
]
)
async def handler(_: ModelRequest):
return response
result = asyncio.run(middleware.awrap_model_call(request, handler))
self.assertIs(result, response)
self.assertEqual(len(snapshots), 1)
self.assertEqual(snapshots[0]["model"], "gpt-4o-mini")
self.assertEqual(snapshots[0]["context_window_tokens"], 128000)
self.assertEqual(snapshots[0]["input_tokens"], 1200)
self.assertEqual(snapshots[0]["output_tokens"], 300)
self.assertEqual(snapshots[0]["total_tokens"], 1500)
self.assertAlmostEqual(snapshots[0]["context_usage_ratio"], 1200 / 128000)
def test_remote_session_status_sends_usage_summary(self):
chain = MessageChain()
chain._user_sessions["10001"] = ("session-1", datetime.now())
status = {
"session_id": "session-1",
"model": "gpt-4o-mini",
"context_window_tokens": 128000,
"last_input_tokens": 1200,
"last_output_tokens": 300,
"last_total_tokens": 1500,
"last_context_usage_ratio": 1200 / 128000,
"total_input_tokens": 4500,
"total_output_tokens": 1500,
"total_tokens": 6000,
"model_call_count": 4,
"last_updated_at": "2026-04-26 12:34:56",
"is_processing": True,
"pending_messages": 2,
}
with (
patch(
"app.chain.message.agent_manager.get_session_status",
return_value=status,
),
patch.object(chain, "post_message") as post_message,
):
chain.remote_session_status(
channel=MessageChannel.Telegram,
userid="10001",
source="telegram-test",
)
notification = post_message.call_args.args[0]
self.assertEqual(notification.title, "当前智能体会话状态")
self.assertIn("session-1", notification.text)
self.assertIn("gpt-4o-mini", notification.text)
self.assertIn("1,200 / 128,000 (0.94%)", notification.text)
self.assertIn("输入 4,500 / 输出 1,500 / 总计 6,000", notification.text)
self.assertIn("运行中", notification.text)
def test_remote_session_status_handles_missing_session(self):
chain = MessageChain()
with patch.object(chain, "post_message") as post_message:
chain.remote_session_status(
channel=MessageChannel.Telegram,
userid="10001",
source="telegram-test",
)
notification = post_message.call_args.args[0]
self.assertEqual(notification.title, "您当前没有活跃的智能体会话")