diff --git a/app/agent/__init__.py b/app/agent/__init__.py index 6f9ed667..b4afc341 100644 --- a/app/agent/__init__.py +++ b/app/agent/__init__.py @@ -4,7 +4,8 @@ import re import traceback import uuid from dataclasses import dataclass -from typing import Callable, Dict, List, Optional +from datetime import datetime +from typing import Any, Callable, Dict, List, Optional from langchain.agents import create_agent from langchain.agents.middleware import ( @@ -24,6 +25,7 @@ from app.agent.middleware.jobs import JobsMiddleware from app.agent.middleware.memory import MemoryMiddleware from app.agent.middleware.patch_tool_calls import PatchToolCallsMiddleware from app.agent.middleware.skills import SkillsMiddleware +from app.agent.middleware.usage import UsageMiddleware from app.agent.prompt import prompt_manager from app.agent.tools.factory import MoviePilotToolFactory from app.chain import ChainBase @@ -41,6 +43,39 @@ class AgentChain(ChainBase): pass +@dataclass +class _SessionUsageSnapshot: + model: Optional[str] = None + context_window_tokens: Optional[int] = None + last_input_tokens: int = 0 + last_output_tokens: int = 0 + last_total_tokens: int = 0 + last_context_usage_ratio: Optional[float] = None + total_input_tokens: int = 0 + total_output_tokens: int = 0 + total_tokens: int = 0 + model_call_count: int = 0 + last_updated_at: Optional[datetime] = None + + def to_dict(self, session_id: str) -> dict[str, Any]: + return { + "session_id": session_id, + "model": self.model, + "context_window_tokens": self.context_window_tokens, + "last_input_tokens": self.last_input_tokens, + "last_output_tokens": self.last_output_tokens, + "last_total_tokens": self.last_total_tokens, + "last_context_usage_ratio": self.last_context_usage_ratio, + "total_input_tokens": self.total_input_tokens, + "total_output_tokens": self.total_output_tokens, + "total_tokens": self.total_tokens, + "model_call_count": self.model_call_count, + "last_updated_at": self.last_updated_at.strftime("%Y-%m-%d %H:%M:%S") + if self.last_updated_at + else None, + } + + class _ThinkTagStripper: """ 流式剥离 ... 标签的辅助类。 @@ -138,10 +173,92 @@ class MoviePilotAgent: self.force_streaming = False self.suppress_user_reply = False self._streamed_output = "" + self._session_usage = _SessionUsageSnapshot() # 流式token管理 self.stream_handler = StreamingHandler() + @staticmethod + def _coerce_int(value: Any) -> Optional[int]: + if value is None: + return None + try: + return int(value) + except (TypeError, ValueError): + return None + + @classmethod + def _get_model_name(cls, llm: Any) -> Optional[str]: + return ( + getattr(llm, "model", None) + or getattr(llm, "model_name", None) + or getattr(llm, "model_id", None) + ) + + @classmethod + def _get_context_window_tokens(cls, llm: Any) -> Optional[int]: + profile = getattr(llm, "profile", None) + if not profile: + return None + if isinstance(profile, dict): + return cls._coerce_int( + profile.get("max_input_tokens") or profile.get("input_token_limit") + ) + return cls._coerce_int( + getattr(profile, "max_input_tokens", None) + or getattr(profile, "input_token_limit", None) + ) + + def _sync_model_profile(self, llm: Any) -> None: + model_name = self._get_model_name(llm) + context_window_tokens = self._get_context_window_tokens(llm) + if model_name: + self._session_usage.model = model_name + if context_window_tokens: + self._session_usage.context_window_tokens = context_window_tokens + + def _record_usage(self, usage: dict[str, Any]) -> None: + if not usage: + return + + model_name = usage.get("model") + context_window_tokens = self._coerce_int(usage.get("context_window_tokens")) + if model_name: + self._session_usage.model = model_name + if context_window_tokens: + self._session_usage.context_window_tokens = context_window_tokens + + self._session_usage.model_call_count += 1 + self._session_usage.last_updated_at = datetime.now() + + if not usage.get("has_usage"): + return + + input_tokens = self._coerce_int(usage.get("input_tokens")) or 0 + output_tokens = self._coerce_int(usage.get("output_tokens")) or 0 + total_tokens = self._coerce_int(usage.get("total_tokens")) + if total_tokens is None: + total_tokens = input_tokens + output_tokens + + self._session_usage.last_input_tokens = input_tokens + self._session_usage.last_output_tokens = output_tokens + self._session_usage.last_total_tokens = total_tokens + self._session_usage.last_context_usage_ratio = usage.get("context_usage_ratio") + self._session_usage.total_input_tokens += input_tokens + self._session_usage.total_output_tokens += output_tokens + self._session_usage.total_tokens += total_tokens + + def get_session_status(self) -> dict[str, Any]: + if not self._session_usage.model: + self._session_usage.model = settings.LLM_MODEL + if not self._session_usage.context_window_tokens: + self._session_usage.context_window_tokens = ( + settings.LLM_MAX_CONTEXT_TOKENS * 1000 + if settings.LLM_MAX_CONTEXT_TOKENS + else None + ) + return self._session_usage.to_dict(self.session_id) + @property def is_background(self) -> bool: """ @@ -258,6 +375,7 @@ class MoviePilotAgent: # LLM 模型(用于 agent 执行) llm = self._initialize_llm(streaming=streaming) + self._sync_model_profile(llm) # 工具列表 tools = self._initialize_tools() @@ -279,6 +397,8 @@ class MoviePilotAgent: ActivityLogMiddleware( activity_dir=str(settings.CONFIG_PATH / "agent" / "activity"), ), + # 用量统计 + UsageMiddleware(on_usage=self._record_usage), # 上下文压缩 SummarizationMiddleware(model=llm, trigger=("fraction", 0.85)), # 错误工具调用修复 @@ -608,6 +728,37 @@ class AgentManager: # 重试整理缓冲区锁 self._retry_transfer_lock = asyncio.Lock() + def get_session_status(self, session_id: str) -> dict[str, Any]: + """获取会话当前模型与 token 使用状态。""" + agent = self.active_agents.get(session_id) + if agent: + status = agent.get_session_status() + else: + status = { + "session_id": session_id, + "model": settings.LLM_MODEL, + "context_window_tokens": settings.LLM_MAX_CONTEXT_TOKENS * 1000 + if settings.LLM_MAX_CONTEXT_TOKENS + else None, + "last_input_tokens": 0, + "last_output_tokens": 0, + "last_total_tokens": 0, + "last_context_usage_ratio": None, + "total_input_tokens": 0, + "total_output_tokens": 0, + "total_tokens": 0, + "model_call_count": 0, + "last_updated_at": None, + } + + queue = self._session_queues.get(session_id) + status["pending_messages"] = queue.qsize() if queue else 0 + status["is_processing"] = ( + session_id in self._session_workers + and not self._session_workers[session_id].done() + ) + return status + @staticmethod async def initialize(): """ diff --git a/app/agent/middleware/usage.py b/app/agent/middleware/usage.py new file mode 100644 index 00000000..91f1b3ad --- /dev/null +++ b/app/agent/middleware/usage.py @@ -0,0 +1,184 @@ +from collections.abc import Awaitable, Callable +from typing import Any + +from langchain.agents.middleware.types import ( + AgentMiddleware, + ContextT, + ModelRequest, + ModelResponse, + ResponseT, +) +from langchain_core.messages import AIMessage + +from app.log import logger + + +class UsageMiddleware(AgentMiddleware): + """记录模型调用 usage 信息并回传给外部会话。""" + + def __init__( + self, + *, + on_usage: Callable[[dict[str, Any]], None] | None = None, + ) -> None: + self.on_usage = on_usage + + @staticmethod + def _coerce_int(value: Any) -> int | None: + if value is None: + return None + try: + return int(value) + except (TypeError, ValueError): + return None + + @classmethod + def _lookup_int(cls, container: Any, *keys: str) -> int | None: + if not container: + return None + + getter = getattr(container, "get", None) + if callable(getter): + for key in keys: + value = getter(key) + if value is not None: + return cls._coerce_int(value) + + for key in keys: + value = getattr(container, key, None) + if value is not None: + return cls._coerce_int(value) + + return None + + @classmethod + def _extract_model_name(cls, model: Any) -> str | None: + return ( + getattr(model, "model", None) + or getattr(model, "model_name", None) + or getattr(model, "model_id", None) + ) + + @classmethod + def _extract_context_window_tokens(cls, model: Any) -> int | None: + profile = getattr(model, "profile", None) + if not profile: + return None + return cls._lookup_int(profile, "max_input_tokens", "input_token_limit") + + @classmethod + def _extract_usage(cls, ai_message: AIMessage) -> dict[str, Any]: + usage_metadata = getattr(ai_message, "usage_metadata", None) + + input_tokens = cls._lookup_int(usage_metadata, "input_tokens") + output_tokens = cls._lookup_int(usage_metadata, "output_tokens") + total_tokens = cls._lookup_int(usage_metadata, "total_tokens") + + response_metadata = getattr(ai_message, "response_metadata", None) or {} + token_usage = ( + response_metadata.get("token_usage") + or response_metadata.get("usage") + or response_metadata.get("usage_metadata") + or {} + ) + + if input_tokens is None: + input_tokens = cls._lookup_int( + token_usage, + "prompt_tokens", + "input_tokens", + ) + if input_tokens is None: + input_tokens = cls._lookup_int( + response_metadata, + "prompt_token_count", + "input_tokens", + ) + + if output_tokens is None: + output_tokens = cls._lookup_int( + token_usage, + "completion_tokens", + "output_tokens", + ) + if output_tokens is None: + output_tokens = cls._lookup_int( + response_metadata, + "candidates_token_count", + "output_tokens", + ) + + if total_tokens is None: + total_tokens = cls._lookup_int(token_usage, "total_tokens") + if total_tokens is None: + total_tokens = cls._lookup_int(response_metadata, "total_token_count") + + has_usage = any( + value is not None for value in (input_tokens, output_tokens, total_tokens) + ) + resolved_input = input_tokens or 0 + resolved_output = output_tokens or 0 + resolved_total = ( + total_tokens + if total_tokens is not None + else resolved_input + resolved_output + ) + + return { + "has_usage": has_usage, + "input_tokens": resolved_input, + "output_tokens": resolved_output, + "total_tokens": resolved_total, + } + + async def awrap_model_call( + self, + request: ModelRequest[ContextT], + handler: Callable[ + [ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]] + ], + ) -> ModelResponse[ResponseT]: + response = await handler(request) + + if not callable(self.on_usage): + return response + + try: + ai_message = next( + ( + message + for message in reversed(response.result) + if isinstance(message, AIMessage) + ), + None, + ) + usage = ( + self._extract_usage(ai_message) + if ai_message + else { + "has_usage": False, + "input_tokens": 0, + "output_tokens": 0, + "total_tokens": 0, + } + ) + context_window_tokens = self._extract_context_window_tokens(request.model) + context_usage_ratio = None + if context_window_tokens and usage["has_usage"]: + context_usage_ratio = usage["input_tokens"] / context_window_tokens + + self.on_usage( + { + "model": self._extract_model_name(request.model), + "context_window_tokens": context_window_tokens, + "context_usage_ratio": context_usage_ratio, + **usage, + } + ) + except Exception as e: + logger.debug("记录模型 usage 失败: %s", e) + + return response + + +__all__ = ["UsageMiddleware"] diff --git a/app/chain/message.py b/app/chain/message.py index db749994..31ec9342 100644 --- a/app/chain/message.py +++ b/app/chain/message.py @@ -786,6 +786,74 @@ class MessageChain(ChainBase): ) ) + @staticmethod + def _format_token_count(value: Optional[int]) -> str: + return f"{value:,}" if value is not None else "未知" + + @classmethod + def _format_session_status_text(cls, status: Dict[str, Any]) -> str: + context_window_tokens = status.get("context_window_tokens") + last_input_tokens = status.get("last_input_tokens") + if context_window_tokens and status.get("model_call_count"): + context_ratio = status.get("last_context_usage_ratio") + if context_ratio is None and last_input_tokens is not None: + context_ratio = last_input_tokens / context_window_tokens + context_usage_text = ( + f"{cls._format_token_count(last_input_tokens)} / " + f"{cls._format_token_count(context_window_tokens)} " + f"({context_ratio * 100:.2f}%)" + if context_ratio is not None + else f"{cls._format_token_count(last_input_tokens)} / " + f"{cls._format_token_count(context_window_tokens)}" + ) + else: + context_usage_text = "暂无模型调用数据" + + lines = [ + f"会话ID: {status.get('session_id') or '未知'}", + f"执行状态: {'运行中' if status.get('is_processing') else '空闲'}", + f"当前模型: {status.get('model') or '未知'}", + f"上下文窗口: {cls._format_token_count(context_window_tokens)} tokens", + f"最近一次上下文占用: {context_usage_text}", + f"最近一次 tokens: 输入 {cls._format_token_count(status.get('last_input_tokens'))} / 输出 {cls._format_token_count(status.get('last_output_tokens'))} / 总计 {cls._format_token_count(status.get('last_total_tokens'))}", + f"当前会话累计 tokens: 输入 {cls._format_token_count(status.get('total_input_tokens'))} / 输出 {cls._format_token_count(status.get('total_output_tokens'))} / 总计 {cls._format_token_count(status.get('total_tokens'))}", + f"模型调用次数: {status.get('model_call_count', 0)}", + f"排队消息数: {status.get('pending_messages', 0)}", + f"最后更新: {status.get('last_updated_at') or '暂无'}", + ] + return "\n".join(lines) + + def remote_session_status( + self, + channel: MessageChannel, + userid: Union[str, int], + source: Optional[str] = None, + ): + """查询当前用户的智能体会话状态。""" + session_info = self._user_sessions.get(userid) + if not session_info: + self.post_message( + Notification( + channel=channel, + source=source, + title="您当前没有活跃的智能体会话", + userid=userid, + ) + ) + return + + session_id, _ = session_info + status = agent_manager.get_session_status(session_id=session_id) + self.post_message( + Notification( + channel=channel, + source=source, + title="当前智能体会话状态", + text=self._format_session_status_text(status), + userid=userid, + ) + ) + def _handle_ai_message( self, text: str, @@ -857,7 +925,12 @@ class MessageChain(ChainBase): return elif images: image_attachments = self._build_image_attachments(images) - if original_images and not image_attachments and not user_message and not files: + if ( + original_images + and not image_attachments + and not user_message + and not files + ): self.post_message( Notification( channel=channel, @@ -940,42 +1013,58 @@ class MessageChain(ChainBase): filename = "input.mp3" elif audio_ref.startswith("wxwork://voice_media_id/"): content = self.run_module( - "download_wechat_media_bytes", media_ref=audio_ref, source=source + "download_wechat_media_bytes", + media_ref=audio_ref, + source=source, ) filename = "input.amr" elif audio_ref.startswith("slack://file/"): content = self.run_module( "download_slack_file_bytes", file_ref=audio_ref, source=source ) - filename = self._guess_audio_filename(audio_ref, default="input.ogg") + filename = self._guess_audio_filename( + audio_ref, default="input.ogg" + ) elif audio_ref.startswith("discord://file/"): content = self.run_module( "download_discord_file_bytes", file_ref=audio_ref, source=source ) - filename = self._guess_audio_filename(audio_ref, default="input.ogg") + filename = self._guess_audio_filename( + audio_ref, default="input.ogg" + ) elif audio_ref.startswith("qq://file/"): content = self.run_module( "download_qq_file_bytes", file_ref=audio_ref, source=source ) - filename = self._guess_audio_filename(audio_ref, default="input.ogg") + filename = self._guess_audio_filename( + audio_ref, default="input.ogg" + ) elif audio_ref.startswith("vocechat://file/"): content = self.run_module( - "download_vocechat_file_bytes", file_ref=audio_ref, source=source + "download_vocechat_file_bytes", + file_ref=audio_ref, + source=source, + ) + filename = self._guess_audio_filename( + audio_ref, default="input.ogg" ) - filename = self._guess_audio_filename(audio_ref, default="input.ogg") elif audio_ref.startswith("synology://file/"): content = self.run_module( "download_synologychat_file_bytes", file_ref=audio_ref, source=source, ) - filename = self._guess_audio_filename(audio_ref, default="input.ogg") + filename = self._guess_audio_filename( + audio_ref, default="input.ogg" + ) elif audio_ref.startswith("wxbot://voice"): continue elif audio_ref.startswith("http"): resp = RequestUtils(timeout=30).get_res(audio_ref) content = resp.content if resp and resp.content else None - filename = self._guess_audio_filename(audio_ref, default="input.ogg") + filename = self._guess_audio_filename( + audio_ref, default="input.ogg" + ) else: logger.debug( "暂不支持的语音引用: channel=%s, source=%s, ref=%s", @@ -994,7 +1083,9 @@ class MessageChain(ChainBase): ) continue - transcript = VoiceHelper.transcribe_bytes(content=content, filename=filename) + transcript = VoiceHelper.transcribe_bytes( + content=content, filename=filename + ) if transcript: transcripts.append(transcript) logger.info( @@ -1047,7 +1138,9 @@ class MessageChain(ChainBase): elif img.startswith("tg://file_id/"): file_id = img.replace("tg://file_id/", "") base64_data = self.run_module( - "download_telegram_file_to_base64", file_id=file_id, source=source + "download_telegram_file_to_base64", + file_id=file_id, + source=source, ) if base64_data: base64_images.append(f"data:image/jpeg;base64,{base64_data}") diff --git a/app/command.py b/app/command.py index 717e3c65..923f3f2b 100644 --- a/app/command.py +++ b/app/command.py @@ -155,6 +155,12 @@ class Command(metaclass=Singleton): "category": "管理", "data": {}, }, + "/session_status": { + "func": MessageChain().remote_session_status, + "description": "会话状态", + "category": "智能体", + "data": {}, + }, "/skills": { "func": SkillsChain().remote_manage, "description": "管理技能", diff --git a/tests/test_agent_session_status.py b/tests/test_agent_session_status.py new file mode 100644 index 00000000..0fae74bb --- /dev/null +++ b/tests/test_agent_session_status.py @@ -0,0 +1,106 @@ +import asyncio +import unittest +from datetime import datetime +from types import SimpleNamespace +from unittest.mock import patch + +from langchain.agents.middleware.types import ModelRequest, ModelResponse +from langchain_core.messages import AIMessage + +from app.agent.middleware.usage import UsageMiddleware +from app.chain.message import MessageChain +from app.schemas.types import MessageChannel + + +class TestAgentSessionStatus(unittest.TestCase): + def test_usage_middleware_records_usage_metadata(self): + snapshots = [] + middleware = UsageMiddleware(on_usage=snapshots.append) + request = ModelRequest( + model=SimpleNamespace( + model="gpt-4o-mini", profile={"max_input_tokens": 128000} + ), + messages=[], + state={}, + runtime=None, + ) + response = ModelResponse( + result=[ + AIMessage( + content="ok", + usage_metadata={ + "input_tokens": 1200, + "output_tokens": 300, + "total_tokens": 1500, + }, + ) + ] + ) + + async def handler(_: ModelRequest): + return response + + result = asyncio.run(middleware.awrap_model_call(request, handler)) + + self.assertIs(result, response) + self.assertEqual(len(snapshots), 1) + self.assertEqual(snapshots[0]["model"], "gpt-4o-mini") + self.assertEqual(snapshots[0]["context_window_tokens"], 128000) + self.assertEqual(snapshots[0]["input_tokens"], 1200) + self.assertEqual(snapshots[0]["output_tokens"], 300) + self.assertEqual(snapshots[0]["total_tokens"], 1500) + self.assertAlmostEqual(snapshots[0]["context_usage_ratio"], 1200 / 128000) + + def test_remote_session_status_sends_usage_summary(self): + chain = MessageChain() + chain._user_sessions["10001"] = ("session-1", datetime.now()) + status = { + "session_id": "session-1", + "model": "gpt-4o-mini", + "context_window_tokens": 128000, + "last_input_tokens": 1200, + "last_output_tokens": 300, + "last_total_tokens": 1500, + "last_context_usage_ratio": 1200 / 128000, + "total_input_tokens": 4500, + "total_output_tokens": 1500, + "total_tokens": 6000, + "model_call_count": 4, + "last_updated_at": "2026-04-26 12:34:56", + "is_processing": True, + "pending_messages": 2, + } + + with ( + patch( + "app.chain.message.agent_manager.get_session_status", + return_value=status, + ), + patch.object(chain, "post_message") as post_message, + ): + chain.remote_session_status( + channel=MessageChannel.Telegram, + userid="10001", + source="telegram-test", + ) + + notification = post_message.call_args.args[0] + self.assertEqual(notification.title, "当前智能体会话状态") + self.assertIn("session-1", notification.text) + self.assertIn("gpt-4o-mini", notification.text) + self.assertIn("1,200 / 128,000 (0.94%)", notification.text) + self.assertIn("输入 4,500 / 输出 1,500 / 总计 6,000", notification.text) + self.assertIn("运行中", notification.text) + + def test_remote_session_status_handles_missing_session(self): + chain = MessageChain() + + with patch.object(chain, "post_message") as post_message: + chain.remote_session_status( + channel=MessageChannel.Telegram, + userid="10001", + source="telegram-test", + ) + + notification = post_message.call_args.args[0] + self.assertEqual(notification.title, "您当前没有活跃的智能体会话")