diff --git a/app/agent/__init__.py b/app/agent/__init__.py index 880f6aa1..d1b0b5cb 100644 --- a/app/agent/__init__.py +++ b/app/agent/__init__.py @@ -1,12 +1,17 @@ import asyncio -from typing import Dict, List, Any +from typing import Dict, List, Any, Union +import json +import tiktoken -from langchain.agents import AgentExecutor, create_openai_tools_agent +from langchain.agents import AgentExecutor from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder from langchain_community.callbacks import get_openai_callback from langchain_core.chat_history import InMemoryChatMessageHistory -from langchain_core.messages import HumanMessage, AIMessage, ToolCall, ToolMessage, SystemMessage +from langchain_core.messages import HumanMessage, AIMessage, ToolCall, ToolMessage, SystemMessage, trim_messages +from langchain_core.runnables import RunnablePassthrough from langchain_core.runnables.history import RunnableWithMessageHistory +from langchain.agents.format_scratchpad.openai_tools import format_to_openai_tool_messages +from langchain.agents.output_parsers.openai_tools import OpenAIToolsAgentOutputParser from app.agent.callback import StreamingCallbackHandler from app.agent.memory import conversation_manager @@ -140,15 +145,77 @@ class MoviePilotAgent: logger.error(f"初始化提示词失败: {e}") raise e + def _token_counter(self, messages: List[Union[HumanMessage, AIMessage, ToolMessage, SystemMessage]]) -> int: + """ + 通用的Token计数器 + """ + try: + # 尝试从模型获取编码集,如果失败则回退到 cl100k_base (大多数现代模型使用的编码) + try: + encoding = tiktoken.encoding_for_model(settings.LLM_MODEL) + except KeyError: + encoding = tiktoken.get_encoding("cl100k_base") + + num_tokens = 0 + for message in messages: + # 基础开销 (每个消息大约 3 个 token) + num_tokens += 3 + + # 1. 
处理文本内容 (content) + if isinstance(message.content, str): + num_tokens += len(encoding.encode(message.content)) + elif isinstance(message.content, list): + for part in message.content: + if isinstance(part, dict) and part.get("type") == "text": + num_tokens += len(encoding.encode(part.get("text", ""))) + + # 2. 处理工具调用 (仅 AIMessage 包含 tool_calls) + if getattr(message, "tool_calls", None): + for tool_call in message.tool_calls: + # 函数名 + num_tokens += len(encoding.encode(tool_call.get("name", ""))) + # 参数 (转为 JSON 估算) + args_str = json.dumps(tool_call.get("args", {}), ensure_ascii=False) + num_tokens += len(encoding.encode(args_str)) + # 额外的结构开销 (ID 等) + num_tokens += 3 + + # 3. 处理角色权重 + num_tokens += 1 + + # 加上回复的起始 Token (大约 3 个 token) + num_tokens += 3 + return num_tokens + except Exception as e: + logger.error(f"Token计数失败: {e}") + # 发生错误时返回一个保守的估算值 + return len(str(messages)) // 4 + def _create_agent_executor(self) -> RunnableWithMessageHistory: """ 创建Agent执行器 """ try: - agent = create_openai_tools_agent( - llm=self.llm, - tools=self.tools, - prompt=self.prompt + # 消息裁剪器,防止上下文超出限制 + trimmer = trim_messages( + max_tokens=int(settings.LLM_MAX_CONTEXT_TOKENS * 1000 * 0.8), + strategy="last", + token_counter=self._token_counter, + include_system=True, + allow_partial=False, + start_on="human", + ) + # 创建Agent执行链 + agent = ( + RunnablePassthrough.assign( + agent_scratchpad=lambda x: format_to_openai_tool_messages( + x["intermediate_steps"] + ) + ) + | self.prompt + | trimmer + | self.llm.bind_tools(self.tools) + | OpenAIToolsAgentOutputParser() ) executor = AgentExecutor( agent=agent, @@ -190,7 +257,7 @@ class MoviePilotAgent: # 执行Agent logger.info(f"Agent执行推理: session_id={self.session_id}, input={message}") - await self._execute_agent(input_context) + result = await self._execute_agent(input_context) # 获取Agent回复 agent_message = await self.callback_handler.get_message() @@ -208,7 +275,7 @@ class MoviePilotAgent: content=agent_message ) else: - agent_message = 
"很抱歉,智能体出错了,未能生成回复内容。" + agent_message = result.get("output") or "很抱歉,智能体出错了,未能生成回复内容。" await self.send_agent_message(agent_message) return agent_message diff --git a/app/core/config.py b/app/core/config.py index 24b8de3c..c180bcbe 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -427,6 +427,8 @@ class ConfigModel(BaseModel): LLM_API_KEY: Optional[str] = None # LLM基础URL(用于自定义API端点) LLM_BASE_URL: Optional[str] = "https://api.deepseek.com" + # LLM最大上下文Token数量(K) + LLM_MAX_CONTEXT_TOKENS: int = 64 # LLM温度参数 LLM_TEMPERATURE: float = 0.1 # LLM最大迭代次数