mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-05-07 13:52:42 +08:00
Track per-session model and token usage so users can inspect context pressure and cumulative usage with /session_status.
185 lines
5.6 KiB
Python
185 lines
5.6 KiB
Python
from collections.abc import Awaitable, Callable
|
|
from typing import Any
|
|
|
|
from langchain.agents.middleware.types import (
|
|
AgentMiddleware,
|
|
ContextT,
|
|
ModelRequest,
|
|
ModelResponse,
|
|
ResponseT,
|
|
)
|
|
from langchain_core.messages import AIMessage
|
|
|
|
from app.log import logger
|
|
|
|
|
|
class UsageMiddleware(AgentMiddleware):
    """Record the token usage of each model call and report it to the outer session.

    After every async model call the middleware takes the last ``AIMessage``
    of the response and extracts input/output/total token counts from its
    metadata. It then passes a usage payload to the ``on_usage`` callback.
    Reporting is best-effort: failures are logged at debug level and never
    alter the model response.
    """

    def __init__(
        self,
        *,
        on_usage: Callable[[dict[str, Any]], None] | None = None,
    ) -> None:
        """Create the middleware.

        Args:
            on_usage: Callback receiving one usage payload dict per model
                call. When it is ``None`` (or not callable), nothing is
                extracted or reported.
        """
        self.on_usage = on_usage

    @staticmethod
    def _coerce_int(value: Any) -> int | None:
        """Convert ``value`` to ``int``, or return ``None`` if absent or invalid.

        ``OverflowError`` is caught in addition to ``TypeError``/``ValueError``
        so that non-finite floats such as ``float("inf")`` count as "missing"
        instead of escaping the helper.
        """
        if value is None:
            return None
        try:
            return int(value)
        except (TypeError, ValueError, OverflowError):
            return None

    @classmethod
    def _lookup_int(cls, container: Any, *keys: str) -> int | None:
        """Return the first integer found under any of ``keys`` in ``container``.

        If ``container`` has a callable ``get`` (e.g. a dict), mapping lookups
        are tried for every key first. Attribute lookups are tried next. Values
        that cannot be converted to ``int`` are skipped, so a later key or the
        attribute fallback can still supply a usable number.

        Returns:
            The first convertible value, or ``None`` if ``container`` is falsy
            or no key yields an integer.
        """
        if not container:
            return None

        getter = getattr(container, "get", None)
        if callable(getter):
            for key in keys:
                number = cls._coerce_int(getter(key))
                if number is not None:
                    return number

        for key in keys:
            number = cls._coerce_int(getattr(container, key, None))
            if number is not None:
                return number

        return None

    @classmethod
    def _extract_model_name(cls, model: Any) -> str | None:
        """Return the model identifier of a chat model instance, if any.

        Tries the ``model``, ``model_name`` and ``model_id`` attributes in that
        order. Presumably different integrations store the name under different
        attributes. Only non-empty strings are accepted, so the payload carries
        either a name or ``None``, never an arbitrary object.
        """
        for attr in ("model", "model_name", "model_id"):
            value = getattr(model, attr, None)
            if isinstance(value, str) and value:
                return value
        return None

    @classmethod
    def _extract_context_window_tokens(cls, model: Any) -> int | None:
        """Return the model's maximum input tokens taken from its ``profile``.

        Reads ``max_input_tokens`` (falling back to ``input_token_limit``) from
        the optional ``profile`` attribute. Returns ``None`` when the profile or
        both keys are missing.
        """
        profile = getattr(model, "profile", None)
        if not profile:
            return None
        return cls._lookup_int(profile, "max_input_tokens", "input_token_limit")

    @classmethod
    def _extract_usage(cls, ai_message: AIMessage | None) -> dict[str, Any]:
        """Extract token counts from an ``AIMessage``.

        Each count is looked up in this order:

        1. ``ai_message.usage_metadata``;
        2. the ``token_usage`` / ``usage`` / ``usage_metadata`` entry of
           ``response_metadata`` (``prompt_tokens`` / ``completion_tokens`` or
           ``input_tokens`` / ``output_tokens``);
        3. top-level ``response_metadata`` keys (``prompt_token_count`` /
           ``candidates_token_count`` / ``total_token_count``). These are
           presumably Gemini-style names; this is not verified here.

        Args:
            ai_message: Message to inspect. ``None`` yields an empty record.

        Returns:
            Dict with these keys:

            - ``has_usage``: whether any count was found.
            - ``input_tokens`` / ``output_tokens``: missing counts are 0.
            - ``total_tokens``: defaults to input + output when not reported.
        """
        usage_metadata = getattr(ai_message, "usage_metadata", None)
        response_metadata = getattr(ai_message, "response_metadata", None) or {}
        token_usage = (
            response_metadata.get("token_usage")
            or response_metadata.get("usage")
            or response_metadata.get("usage_metadata")
            or {}
        )

        input_tokens = cls._lookup_int(usage_metadata, "input_tokens")
        if input_tokens is None:
            input_tokens = cls._lookup_int(
                token_usage, "prompt_tokens", "input_tokens"
            )
        if input_tokens is None:
            input_tokens = cls._lookup_int(
                response_metadata, "prompt_token_count", "input_tokens"
            )

        output_tokens = cls._lookup_int(usage_metadata, "output_tokens")
        if output_tokens is None:
            output_tokens = cls._lookup_int(
                token_usage, "completion_tokens", "output_tokens"
            )
        if output_tokens is None:
            output_tokens = cls._lookup_int(
                response_metadata, "candidates_token_count", "output_tokens"
            )

        total_tokens = cls._lookup_int(usage_metadata, "total_tokens")
        if total_tokens is None:
            total_tokens = cls._lookup_int(token_usage, "total_tokens")
        if total_tokens is None:
            total_tokens = cls._lookup_int(response_metadata, "total_token_count")

        has_usage = any(
            value is not None for value in (input_tokens, output_tokens, total_tokens)
        )
        resolved_input = input_tokens or 0
        resolved_output = output_tokens or 0
        resolved_total = (
            total_tokens if total_tokens is not None else resolved_input + resolved_output
        )

        return {
            "has_usage": has_usage,
            "input_tokens": resolved_input,
            "output_tokens": resolved_output,
            "total_tokens": resolved_total,
        }

    @staticmethod
    def _find_last_ai_message(response: Any) -> AIMessage | None:
        """Return the last ``AIMessage`` of a model call result, if any.

        Accepts a ``ModelResponse`` (messages in ``.result``) and, defensively,
        a bare ``AIMessage``. Returns ``None`` if no ``AIMessage`` is present.
        """
        if isinstance(response, AIMessage):
            return response
        messages = getattr(response, "result", None) or []
        return next(
            (
                message
                for message in reversed(messages)
                if isinstance(message, AIMessage)
            ),
            None,
        )

    @classmethod
    def _build_usage_payload(
        cls, request: ModelRequest[ContextT], response: Any
    ) -> dict[str, Any]:
        """Assemble the dict passed to ``on_usage`` for one model call.

        Keys: ``model``, ``context_window_tokens`` and ``context_usage_ratio``,
        followed by the keys returned by ``_extract_usage``.
        """
        usage = cls._extract_usage(cls._find_last_ai_message(response))
        context_window_tokens = cls._extract_context_window_tokens(request.model)

        # Only compute the ratio when input tokens were actually reported. If
        # the provider gave only output/total counts, input_tokens is a
        # placeholder 0 and the ratio would wrongly suggest an empty context.
        context_usage_ratio = None
        if (context_window_tokens or 0) > 0 and usage["input_tokens"] > 0:
            context_usage_ratio = usage["input_tokens"] / context_window_tokens

        return {
            "model": cls._extract_model_name(request.model),
            "context_window_tokens": context_window_tokens,
            "context_usage_ratio": context_usage_ratio,
            **usage,
        }

    async def awrap_model_call(
        self,
        request: ModelRequest[ContextT],
        handler: Callable[
            [ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]
        ],
    ) -> ModelResponse[ResponseT]:
        """Run the model call, then report its usage through ``on_usage``.

        The handler's response is always returned unchanged. Any exception
        raised while building the payload or inside the callback is logged at
        debug level and swallowed, so usage tracking can never break a model
        call.
        """
        response = await handler(request)

        if not callable(self.on_usage):
            return response

        try:
            self.on_usage(self._build_usage_payload(request, response))
        except Exception as e:
            logger.debug("记录模型 usage 失败: %s", e)

        return response
|
|
|
|
|
|
# Public API of this module: only the middleware class is exported.
__all__ = ["UsageMiddleware"]
|