mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-05-09 23:02:09 +08:00
feat: add agent session usage status reporting
Track per-session model and token usage so users can inspect context pressure and cumulative usage with /session_status.
This commit is contained in:
@@ -4,7 +4,8 @@ import re
|
||||
import traceback
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Dict, List, Optional
|
||||
from datetime import datetime
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
from langchain.agents import create_agent
|
||||
from langchain.agents.middleware import (
|
||||
@@ -24,6 +25,7 @@ from app.agent.middleware.jobs import JobsMiddleware
|
||||
from app.agent.middleware.memory import MemoryMiddleware
|
||||
from app.agent.middleware.patch_tool_calls import PatchToolCallsMiddleware
|
||||
from app.agent.middleware.skills import SkillsMiddleware
|
||||
from app.agent.middleware.usage import UsageMiddleware
|
||||
from app.agent.prompt import prompt_manager
|
||||
from app.agent.tools.factory import MoviePilotToolFactory
|
||||
from app.chain import ChainBase
|
||||
@@ -41,6 +43,39 @@ class AgentChain(ChainBase):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class _SessionUsageSnapshot:
|
||||
model: Optional[str] = None
|
||||
context_window_tokens: Optional[int] = None
|
||||
last_input_tokens: int = 0
|
||||
last_output_tokens: int = 0
|
||||
last_total_tokens: int = 0
|
||||
last_context_usage_ratio: Optional[float] = None
|
||||
total_input_tokens: int = 0
|
||||
total_output_tokens: int = 0
|
||||
total_tokens: int = 0
|
||||
model_call_count: int = 0
|
||||
last_updated_at: Optional[datetime] = None
|
||||
|
||||
def to_dict(self, session_id: str) -> dict[str, Any]:
|
||||
return {
|
||||
"session_id": session_id,
|
||||
"model": self.model,
|
||||
"context_window_tokens": self.context_window_tokens,
|
||||
"last_input_tokens": self.last_input_tokens,
|
||||
"last_output_tokens": self.last_output_tokens,
|
||||
"last_total_tokens": self.last_total_tokens,
|
||||
"last_context_usage_ratio": self.last_context_usage_ratio,
|
||||
"total_input_tokens": self.total_input_tokens,
|
||||
"total_output_tokens": self.total_output_tokens,
|
||||
"total_tokens": self.total_tokens,
|
||||
"model_call_count": self.model_call_count,
|
||||
"last_updated_at": self.last_updated_at.strftime("%Y-%m-%d %H:%M:%S")
|
||||
if self.last_updated_at
|
||||
else None,
|
||||
}
|
||||
|
||||
|
||||
class _ThinkTagStripper:
|
||||
"""
|
||||
流式剥离 <think>...</think> 标签的辅助类。
|
||||
@@ -138,10 +173,92 @@ class MoviePilotAgent:
|
||||
self.force_streaming = False
|
||||
self.suppress_user_reply = False
|
||||
self._streamed_output = ""
|
||||
self._session_usage = _SessionUsageSnapshot()
|
||||
|
||||
# 流式token管理
|
||||
self.stream_handler = StreamingHandler()
|
||||
|
||||
@staticmethod
|
||||
def _coerce_int(value: Any) -> Optional[int]:
|
||||
if value is None:
|
||||
return None
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _get_model_name(cls, llm: Any) -> Optional[str]:
|
||||
return (
|
||||
getattr(llm, "model", None)
|
||||
or getattr(llm, "model_name", None)
|
||||
or getattr(llm, "model_id", None)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _get_context_window_tokens(cls, llm: Any) -> Optional[int]:
|
||||
profile = getattr(llm, "profile", None)
|
||||
if not profile:
|
||||
return None
|
||||
if isinstance(profile, dict):
|
||||
return cls._coerce_int(
|
||||
profile.get("max_input_tokens") or profile.get("input_token_limit")
|
||||
)
|
||||
return cls._coerce_int(
|
||||
getattr(profile, "max_input_tokens", None)
|
||||
or getattr(profile, "input_token_limit", None)
|
||||
)
|
||||
|
||||
def _sync_model_profile(self, llm: Any) -> None:
|
||||
model_name = self._get_model_name(llm)
|
||||
context_window_tokens = self._get_context_window_tokens(llm)
|
||||
if model_name:
|
||||
self._session_usage.model = model_name
|
||||
if context_window_tokens:
|
||||
self._session_usage.context_window_tokens = context_window_tokens
|
||||
|
||||
def _record_usage(self, usage: dict[str, Any]) -> None:
|
||||
if not usage:
|
||||
return
|
||||
|
||||
model_name = usage.get("model")
|
||||
context_window_tokens = self._coerce_int(usage.get("context_window_tokens"))
|
||||
if model_name:
|
||||
self._session_usage.model = model_name
|
||||
if context_window_tokens:
|
||||
self._session_usage.context_window_tokens = context_window_tokens
|
||||
|
||||
self._session_usage.model_call_count += 1
|
||||
self._session_usage.last_updated_at = datetime.now()
|
||||
|
||||
if not usage.get("has_usage"):
|
||||
return
|
||||
|
||||
input_tokens = self._coerce_int(usage.get("input_tokens")) or 0
|
||||
output_tokens = self._coerce_int(usage.get("output_tokens")) or 0
|
||||
total_tokens = self._coerce_int(usage.get("total_tokens"))
|
||||
if total_tokens is None:
|
||||
total_tokens = input_tokens + output_tokens
|
||||
|
||||
self._session_usage.last_input_tokens = input_tokens
|
||||
self._session_usage.last_output_tokens = output_tokens
|
||||
self._session_usage.last_total_tokens = total_tokens
|
||||
self._session_usage.last_context_usage_ratio = usage.get("context_usage_ratio")
|
||||
self._session_usage.total_input_tokens += input_tokens
|
||||
self._session_usage.total_output_tokens += output_tokens
|
||||
self._session_usage.total_tokens += total_tokens
|
||||
|
||||
def get_session_status(self) -> dict[str, Any]:
|
||||
if not self._session_usage.model:
|
||||
self._session_usage.model = settings.LLM_MODEL
|
||||
if not self._session_usage.context_window_tokens:
|
||||
self._session_usage.context_window_tokens = (
|
||||
settings.LLM_MAX_CONTEXT_TOKENS * 1000
|
||||
if settings.LLM_MAX_CONTEXT_TOKENS
|
||||
else None
|
||||
)
|
||||
return self._session_usage.to_dict(self.session_id)
|
||||
|
||||
@property
|
||||
def is_background(self) -> bool:
|
||||
"""
|
||||
@@ -258,6 +375,7 @@ class MoviePilotAgent:
|
||||
|
||||
# LLM 模型(用于 agent 执行)
|
||||
llm = self._initialize_llm(streaming=streaming)
|
||||
self._sync_model_profile(llm)
|
||||
|
||||
# 工具列表
|
||||
tools = self._initialize_tools()
|
||||
@@ -279,6 +397,8 @@ class MoviePilotAgent:
|
||||
ActivityLogMiddleware(
|
||||
activity_dir=str(settings.CONFIG_PATH / "agent" / "activity"),
|
||||
),
|
||||
# 用量统计
|
||||
UsageMiddleware(on_usage=self._record_usage),
|
||||
# 上下文压缩
|
||||
SummarizationMiddleware(model=llm, trigger=("fraction", 0.85)),
|
||||
# 错误工具调用修复
|
||||
@@ -608,6 +728,37 @@ class AgentManager:
|
||||
# 重试整理缓冲区锁
|
||||
self._retry_transfer_lock = asyncio.Lock()
|
||||
|
||||
def get_session_status(self, session_id: str) -> dict[str, Any]:
|
||||
"""获取会话当前模型与 token 使用状态。"""
|
||||
agent = self.active_agents.get(session_id)
|
||||
if agent:
|
||||
status = agent.get_session_status()
|
||||
else:
|
||||
status = {
|
||||
"session_id": session_id,
|
||||
"model": settings.LLM_MODEL,
|
||||
"context_window_tokens": settings.LLM_MAX_CONTEXT_TOKENS * 1000
|
||||
if settings.LLM_MAX_CONTEXT_TOKENS
|
||||
else None,
|
||||
"last_input_tokens": 0,
|
||||
"last_output_tokens": 0,
|
||||
"last_total_tokens": 0,
|
||||
"last_context_usage_ratio": None,
|
||||
"total_input_tokens": 0,
|
||||
"total_output_tokens": 0,
|
||||
"total_tokens": 0,
|
||||
"model_call_count": 0,
|
||||
"last_updated_at": None,
|
||||
}
|
||||
|
||||
queue = self._session_queues.get(session_id)
|
||||
status["pending_messages"] = queue.qsize() if queue else 0
|
||||
status["is_processing"] = (
|
||||
session_id in self._session_workers
|
||||
and not self._session_workers[session_id].done()
|
||||
)
|
||||
return status
|
||||
|
||||
@staticmethod
|
||||
async def initialize():
|
||||
"""
|
||||
|
||||
184
app/agent/middleware/usage.py
Normal file
184
app/agent/middleware/usage.py
Normal file
@@ -0,0 +1,184 @@
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
ContextT,
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
ResponseT,
|
||||
)
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class UsageMiddleware(AgentMiddleware):
|
||||
"""记录模型调用 usage 信息并回传给外部会话。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
on_usage: Callable[[dict[str, Any]], None] | None = None,
|
||||
) -> None:
|
||||
self.on_usage = on_usage
|
||||
|
||||
@staticmethod
|
||||
def _coerce_int(value: Any) -> int | None:
|
||||
if value is None:
|
||||
return None
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _lookup_int(cls, container: Any, *keys: str) -> int | None:
|
||||
if not container:
|
||||
return None
|
||||
|
||||
getter = getattr(container, "get", None)
|
||||
if callable(getter):
|
||||
for key in keys:
|
||||
value = getter(key)
|
||||
if value is not None:
|
||||
return cls._coerce_int(value)
|
||||
|
||||
for key in keys:
|
||||
value = getattr(container, key, None)
|
||||
if value is not None:
|
||||
return cls._coerce_int(value)
|
||||
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _extract_model_name(cls, model: Any) -> str | None:
|
||||
return (
|
||||
getattr(model, "model", None)
|
||||
or getattr(model, "model_name", None)
|
||||
or getattr(model, "model_id", None)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _extract_context_window_tokens(cls, model: Any) -> int | None:
|
||||
profile = getattr(model, "profile", None)
|
||||
if not profile:
|
||||
return None
|
||||
return cls._lookup_int(profile, "max_input_tokens", "input_token_limit")
|
||||
|
||||
@classmethod
|
||||
def _extract_usage(cls, ai_message: AIMessage) -> dict[str, Any]:
|
||||
usage_metadata = getattr(ai_message, "usage_metadata", None)
|
||||
|
||||
input_tokens = cls._lookup_int(usage_metadata, "input_tokens")
|
||||
output_tokens = cls._lookup_int(usage_metadata, "output_tokens")
|
||||
total_tokens = cls._lookup_int(usage_metadata, "total_tokens")
|
||||
|
||||
response_metadata = getattr(ai_message, "response_metadata", None) or {}
|
||||
token_usage = (
|
||||
response_metadata.get("token_usage")
|
||||
or response_metadata.get("usage")
|
||||
or response_metadata.get("usage_metadata")
|
||||
or {}
|
||||
)
|
||||
|
||||
if input_tokens is None:
|
||||
input_tokens = cls._lookup_int(
|
||||
token_usage,
|
||||
"prompt_tokens",
|
||||
"input_tokens",
|
||||
)
|
||||
if input_tokens is None:
|
||||
input_tokens = cls._lookup_int(
|
||||
response_metadata,
|
||||
"prompt_token_count",
|
||||
"input_tokens",
|
||||
)
|
||||
|
||||
if output_tokens is None:
|
||||
output_tokens = cls._lookup_int(
|
||||
token_usage,
|
||||
"completion_tokens",
|
||||
"output_tokens",
|
||||
)
|
||||
if output_tokens is None:
|
||||
output_tokens = cls._lookup_int(
|
||||
response_metadata,
|
||||
"candidates_token_count",
|
||||
"output_tokens",
|
||||
)
|
||||
|
||||
if total_tokens is None:
|
||||
total_tokens = cls._lookup_int(token_usage, "total_tokens")
|
||||
if total_tokens is None:
|
||||
total_tokens = cls._lookup_int(response_metadata, "total_token_count")
|
||||
|
||||
has_usage = any(
|
||||
value is not None for value in (input_tokens, output_tokens, total_tokens)
|
||||
)
|
||||
resolved_input = input_tokens or 0
|
||||
resolved_output = output_tokens or 0
|
||||
resolved_total = (
|
||||
total_tokens
|
||||
if total_tokens is not None
|
||||
else resolved_input + resolved_output
|
||||
)
|
||||
|
||||
return {
|
||||
"has_usage": has_usage,
|
||||
"input_tokens": resolved_input,
|
||||
"output_tokens": resolved_output,
|
||||
"total_tokens": resolved_total,
|
||||
}
|
||||
|
||||
async def awrap_model_call(
|
||||
self,
|
||||
request: ModelRequest[ContextT],
|
||||
handler: Callable[
|
||||
[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]
|
||||
],
|
||||
) -> ModelResponse[ResponseT]:
|
||||
response = await handler(request)
|
||||
|
||||
if not callable(self.on_usage):
|
||||
return response
|
||||
|
||||
try:
|
||||
ai_message = next(
|
||||
(
|
||||
message
|
||||
for message in reversed(response.result)
|
||||
if isinstance(message, AIMessage)
|
||||
),
|
||||
None,
|
||||
)
|
||||
usage = (
|
||||
self._extract_usage(ai_message)
|
||||
if ai_message
|
||||
else {
|
||||
"has_usage": False,
|
||||
"input_tokens": 0,
|
||||
"output_tokens": 0,
|
||||
"total_tokens": 0,
|
||||
}
|
||||
)
|
||||
context_window_tokens = self._extract_context_window_tokens(request.model)
|
||||
context_usage_ratio = None
|
||||
if context_window_tokens and usage["has_usage"]:
|
||||
context_usage_ratio = usage["input_tokens"] / context_window_tokens
|
||||
|
||||
self.on_usage(
|
||||
{
|
||||
"model": self._extract_model_name(request.model),
|
||||
"context_window_tokens": context_window_tokens,
|
||||
"context_usage_ratio": context_usage_ratio,
|
||||
**usage,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug("记录模型 usage 失败: %s", e)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
__all__ = ["UsageMiddleware"]
|
||||
@@ -786,6 +786,74 @@ class MessageChain(ChainBase):
|
||||
)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _format_token_count(value: Optional[int]) -> str:
|
||||
return f"{value:,}" if value is not None else "未知"
|
||||
|
||||
@classmethod
|
||||
def _format_session_status_text(cls, status: Dict[str, Any]) -> str:
|
||||
context_window_tokens = status.get("context_window_tokens")
|
||||
last_input_tokens = status.get("last_input_tokens")
|
||||
if context_window_tokens and status.get("model_call_count"):
|
||||
context_ratio = status.get("last_context_usage_ratio")
|
||||
if context_ratio is None and last_input_tokens is not None:
|
||||
context_ratio = last_input_tokens / context_window_tokens
|
||||
context_usage_text = (
|
||||
f"{cls._format_token_count(last_input_tokens)} / "
|
||||
f"{cls._format_token_count(context_window_tokens)} "
|
||||
f"({context_ratio * 100:.2f}%)"
|
||||
if context_ratio is not None
|
||||
else f"{cls._format_token_count(last_input_tokens)} / "
|
||||
f"{cls._format_token_count(context_window_tokens)}"
|
||||
)
|
||||
else:
|
||||
context_usage_text = "暂无模型调用数据"
|
||||
|
||||
lines = [
|
||||
f"会话ID: {status.get('session_id') or '未知'}",
|
||||
f"执行状态: {'运行中' if status.get('is_processing') else '空闲'}",
|
||||
f"当前模型: {status.get('model') or '未知'}",
|
||||
f"上下文窗口: {cls._format_token_count(context_window_tokens)} tokens",
|
||||
f"最近一次上下文占用: {context_usage_text}",
|
||||
f"最近一次 tokens: 输入 {cls._format_token_count(status.get('last_input_tokens'))} / 输出 {cls._format_token_count(status.get('last_output_tokens'))} / 总计 {cls._format_token_count(status.get('last_total_tokens'))}",
|
||||
f"当前会话累计 tokens: 输入 {cls._format_token_count(status.get('total_input_tokens'))} / 输出 {cls._format_token_count(status.get('total_output_tokens'))} / 总计 {cls._format_token_count(status.get('total_tokens'))}",
|
||||
f"模型调用次数: {status.get('model_call_count', 0)}",
|
||||
f"排队消息数: {status.get('pending_messages', 0)}",
|
||||
f"最后更新: {status.get('last_updated_at') or '暂无'}",
|
||||
]
|
||||
return "\n".join(lines)
|
||||
|
||||
def remote_session_status(
|
||||
self,
|
||||
channel: MessageChannel,
|
||||
userid: Union[str, int],
|
||||
source: Optional[str] = None,
|
||||
):
|
||||
"""查询当前用户的智能体会话状态。"""
|
||||
session_info = self._user_sessions.get(userid)
|
||||
if not session_info:
|
||||
self.post_message(
|
||||
Notification(
|
||||
channel=channel,
|
||||
source=source,
|
||||
title="您当前没有活跃的智能体会话",
|
||||
userid=userid,
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
session_id, _ = session_info
|
||||
status = agent_manager.get_session_status(session_id=session_id)
|
||||
self.post_message(
|
||||
Notification(
|
||||
channel=channel,
|
||||
source=source,
|
||||
title="当前智能体会话状态",
|
||||
text=self._format_session_status_text(status),
|
||||
userid=userid,
|
||||
)
|
||||
)
|
||||
|
||||
def _handle_ai_message(
|
||||
self,
|
||||
text: str,
|
||||
@@ -857,7 +925,12 @@ class MessageChain(ChainBase):
|
||||
return
|
||||
elif images:
|
||||
image_attachments = self._build_image_attachments(images)
|
||||
if original_images and not image_attachments and not user_message and not files:
|
||||
if (
|
||||
original_images
|
||||
and not image_attachments
|
||||
and not user_message
|
||||
and not files
|
||||
):
|
||||
self.post_message(
|
||||
Notification(
|
||||
channel=channel,
|
||||
@@ -940,42 +1013,58 @@ class MessageChain(ChainBase):
|
||||
filename = "input.mp3"
|
||||
elif audio_ref.startswith("wxwork://voice_media_id/"):
|
||||
content = self.run_module(
|
||||
"download_wechat_media_bytes", media_ref=audio_ref, source=source
|
||||
"download_wechat_media_bytes",
|
||||
media_ref=audio_ref,
|
||||
source=source,
|
||||
)
|
||||
filename = "input.amr"
|
||||
elif audio_ref.startswith("slack://file/"):
|
||||
content = self.run_module(
|
||||
"download_slack_file_bytes", file_ref=audio_ref, source=source
|
||||
)
|
||||
filename = self._guess_audio_filename(audio_ref, default="input.ogg")
|
||||
filename = self._guess_audio_filename(
|
||||
audio_ref, default="input.ogg"
|
||||
)
|
||||
elif audio_ref.startswith("discord://file/"):
|
||||
content = self.run_module(
|
||||
"download_discord_file_bytes", file_ref=audio_ref, source=source
|
||||
)
|
||||
filename = self._guess_audio_filename(audio_ref, default="input.ogg")
|
||||
filename = self._guess_audio_filename(
|
||||
audio_ref, default="input.ogg"
|
||||
)
|
||||
elif audio_ref.startswith("qq://file/"):
|
||||
content = self.run_module(
|
||||
"download_qq_file_bytes", file_ref=audio_ref, source=source
|
||||
)
|
||||
filename = self._guess_audio_filename(audio_ref, default="input.ogg")
|
||||
filename = self._guess_audio_filename(
|
||||
audio_ref, default="input.ogg"
|
||||
)
|
||||
elif audio_ref.startswith("vocechat://file/"):
|
||||
content = self.run_module(
|
||||
"download_vocechat_file_bytes", file_ref=audio_ref, source=source
|
||||
"download_vocechat_file_bytes",
|
||||
file_ref=audio_ref,
|
||||
source=source,
|
||||
)
|
||||
filename = self._guess_audio_filename(
|
||||
audio_ref, default="input.ogg"
|
||||
)
|
||||
filename = self._guess_audio_filename(audio_ref, default="input.ogg")
|
||||
elif audio_ref.startswith("synology://file/"):
|
||||
content = self.run_module(
|
||||
"download_synologychat_file_bytes",
|
||||
file_ref=audio_ref,
|
||||
source=source,
|
||||
)
|
||||
filename = self._guess_audio_filename(audio_ref, default="input.ogg")
|
||||
filename = self._guess_audio_filename(
|
||||
audio_ref, default="input.ogg"
|
||||
)
|
||||
elif audio_ref.startswith("wxbot://voice"):
|
||||
continue
|
||||
elif audio_ref.startswith("http"):
|
||||
resp = RequestUtils(timeout=30).get_res(audio_ref)
|
||||
content = resp.content if resp and resp.content else None
|
||||
filename = self._guess_audio_filename(audio_ref, default="input.ogg")
|
||||
filename = self._guess_audio_filename(
|
||||
audio_ref, default="input.ogg"
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
"暂不支持的语音引用: channel=%s, source=%s, ref=%s",
|
||||
@@ -994,7 +1083,9 @@ class MessageChain(ChainBase):
|
||||
)
|
||||
continue
|
||||
|
||||
transcript = VoiceHelper.transcribe_bytes(content=content, filename=filename)
|
||||
transcript = VoiceHelper.transcribe_bytes(
|
||||
content=content, filename=filename
|
||||
)
|
||||
if transcript:
|
||||
transcripts.append(transcript)
|
||||
logger.info(
|
||||
@@ -1047,7 +1138,9 @@ class MessageChain(ChainBase):
|
||||
elif img.startswith("tg://file_id/"):
|
||||
file_id = img.replace("tg://file_id/", "")
|
||||
base64_data = self.run_module(
|
||||
"download_telegram_file_to_base64", file_id=file_id, source=source
|
||||
"download_telegram_file_to_base64",
|
||||
file_id=file_id,
|
||||
source=source,
|
||||
)
|
||||
if base64_data:
|
||||
base64_images.append(f"data:image/jpeg;base64,{base64_data}")
|
||||
|
||||
@@ -155,6 +155,12 @@ class Command(metaclass=Singleton):
|
||||
"category": "管理",
|
||||
"data": {},
|
||||
},
|
||||
"/session_status": {
|
||||
"func": MessageChain().remote_session_status,
|
||||
"description": "会话状态",
|
||||
"category": "智能体",
|
||||
"data": {},
|
||||
},
|
||||
"/skills": {
|
||||
"func": SkillsChain().remote_manage,
|
||||
"description": "管理技能",
|
||||
|
||||
106
tests/test_agent_session_status.py
Normal file
106
tests/test_agent_session_status.py
Normal file
@@ -0,0 +1,106 @@
|
||||
import asyncio
|
||||
import unittest
|
||||
from datetime import datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
from langchain.agents.middleware.types import ModelRequest, ModelResponse
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
from app.agent.middleware.usage import UsageMiddleware
|
||||
from app.chain.message import MessageChain
|
||||
from app.schemas.types import MessageChannel
|
||||
|
||||
|
||||
class TestAgentSessionStatus(unittest.TestCase):
|
||||
def test_usage_middleware_records_usage_metadata(self):
|
||||
snapshots = []
|
||||
middleware = UsageMiddleware(on_usage=snapshots.append)
|
||||
request = ModelRequest(
|
||||
model=SimpleNamespace(
|
||||
model="gpt-4o-mini", profile={"max_input_tokens": 128000}
|
||||
),
|
||||
messages=[],
|
||||
state={},
|
||||
runtime=None,
|
||||
)
|
||||
response = ModelResponse(
|
||||
result=[
|
||||
AIMessage(
|
||||
content="ok",
|
||||
usage_metadata={
|
||||
"input_tokens": 1200,
|
||||
"output_tokens": 300,
|
||||
"total_tokens": 1500,
|
||||
},
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
async def handler(_: ModelRequest):
|
||||
return response
|
||||
|
||||
result = asyncio.run(middleware.awrap_model_call(request, handler))
|
||||
|
||||
self.assertIs(result, response)
|
||||
self.assertEqual(len(snapshots), 1)
|
||||
self.assertEqual(snapshots[0]["model"], "gpt-4o-mini")
|
||||
self.assertEqual(snapshots[0]["context_window_tokens"], 128000)
|
||||
self.assertEqual(snapshots[0]["input_tokens"], 1200)
|
||||
self.assertEqual(snapshots[0]["output_tokens"], 300)
|
||||
self.assertEqual(snapshots[0]["total_tokens"], 1500)
|
||||
self.assertAlmostEqual(snapshots[0]["context_usage_ratio"], 1200 / 128000)
|
||||
|
||||
def test_remote_session_status_sends_usage_summary(self):
|
||||
chain = MessageChain()
|
||||
chain._user_sessions["10001"] = ("session-1", datetime.now())
|
||||
status = {
|
||||
"session_id": "session-1",
|
||||
"model": "gpt-4o-mini",
|
||||
"context_window_tokens": 128000,
|
||||
"last_input_tokens": 1200,
|
||||
"last_output_tokens": 300,
|
||||
"last_total_tokens": 1500,
|
||||
"last_context_usage_ratio": 1200 / 128000,
|
||||
"total_input_tokens": 4500,
|
||||
"total_output_tokens": 1500,
|
||||
"total_tokens": 6000,
|
||||
"model_call_count": 4,
|
||||
"last_updated_at": "2026-04-26 12:34:56",
|
||||
"is_processing": True,
|
||||
"pending_messages": 2,
|
||||
}
|
||||
|
||||
with (
|
||||
patch(
|
||||
"app.chain.message.agent_manager.get_session_status",
|
||||
return_value=status,
|
||||
),
|
||||
patch.object(chain, "post_message") as post_message,
|
||||
):
|
||||
chain.remote_session_status(
|
||||
channel=MessageChannel.Telegram,
|
||||
userid="10001",
|
||||
source="telegram-test",
|
||||
)
|
||||
|
||||
notification = post_message.call_args.args[0]
|
||||
self.assertEqual(notification.title, "当前智能体会话状态")
|
||||
self.assertIn("session-1", notification.text)
|
||||
self.assertIn("gpt-4o-mini", notification.text)
|
||||
self.assertIn("1,200 / 128,000 (0.94%)", notification.text)
|
||||
self.assertIn("输入 4,500 / 输出 1,500 / 总计 6,000", notification.text)
|
||||
self.assertIn("运行中", notification.text)
|
||||
|
||||
def test_remote_session_status_handles_missing_session(self):
|
||||
chain = MessageChain()
|
||||
|
||||
with patch.object(chain, "post_message") as post_message:
|
||||
chain.remote_session_status(
|
||||
channel=MessageChannel.Telegram,
|
||||
userid="10001",
|
||||
source="telegram-test",
|
||||
)
|
||||
|
||||
notification = post_message.call_args.args[0]
|
||||
self.assertEqual(notification.title, "您当前没有活跃的智能体会话")
|
||||
Reference in New Issue
Block a user