feat: Agent context trimming

jxxghp
2026-01-23 22:47:18 +08:00
parent 1b3ae6ab25
commit 7c4d736b54
2 changed files with 78 additions and 9 deletions

View File

@@ -1,12 +1,17 @@
 import asyncio
-from typing import Dict, List, Any
+from typing import Dict, List, Any, Union
+import json
+import tiktoken
-from langchain.agents import AgentExecutor, create_openai_tools_agent
+from langchain.agents import AgentExecutor
 from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
 from langchain_community.callbacks import get_openai_callback
 from langchain_core.chat_history import InMemoryChatMessageHistory
-from langchain_core.messages import HumanMessage, AIMessage, ToolCall, ToolMessage, SystemMessage
+from langchain_core.messages import HumanMessage, AIMessage, ToolCall, ToolMessage, SystemMessage, trim_messages
+from langchain_core.runnables import RunnablePassthrough
 from langchain_core.runnables.history import RunnableWithMessageHistory
+from langchain.agents.format_scratchpad.openai_tools import format_to_openai_tool_messages
+from langchain.agents.output_parsers.openai_tools import OpenAIToolsAgentOutputParser
 from app.agent.callback import StreamingCallbackHandler
 from app.agent.memory import conversation_manager
@@ -140,15 +145,77 @@ class MoviePilotAgent:
logger.error(f"初始化提示词失败: {e}")
raise e
def _token_counter(self, messages: List[Union[HumanMessage, AIMessage, ToolMessage, SystemMessage]]) -> int:
"""
通用的Token计数器
"""
try:
# 尝试从模型获取编码集,如果失败则回退到 cl100k_base (大多数现代模型使用的编码)
try:
encoding = tiktoken.encoding_for_model(settings.LLM_MODEL)
except KeyError:
encoding = tiktoken.get_encoding("cl100k_base")
num_tokens = 0
for message in messages:
# 基础开销 (每个消息大约 3 个 token)
num_tokens += 3
# 1. 处理文本内容 (content)
if isinstance(message.content, str):
num_tokens += len(encoding.encode(message.content))
elif isinstance(message.content, list):
for part in message.content:
if isinstance(part, dict) and part.get("type") == "text":
num_tokens += len(encoding.encode(part.get("text", "")))
# 2. 处理工具调用 (仅 AIMessage 包含 tool_calls)
if getattr(message, "tool_calls", None):
for tool_call in message.tool_calls:
# 函数名
num_tokens += len(encoding.encode(tool_call.get("name", "")))
# 参数 (转为 JSON 估算)
args_str = json.dumps(tool_call.get("args", {}), ensure_ascii=False)
num_tokens += len(encoding.encode(args_str))
# 额外的结构开销 (ID 等)
num_tokens += 3
# 3. 处理角色权重
num_tokens += 1
# 加上回复的起始 Token (大约 3 个 token)
num_tokens += 3
return num_tokens
except Exception as e:
logger.error(f"Token计数失败: {e}")
# 发生错误时返回一个保守的估算值
return len(str(messages)) // 4
def _create_agent_executor(self) -> RunnableWithMessageHistory:
"""
创建Agent执行器
"""
try:
agent = create_openai_tools_agent(
llm=self.llm,
tools=self.tools,
prompt=self.prompt
# 消息裁剪器,防止上下文超出限制
trimmer = trim_messages(
max_tokens=settings.LLM_MAX_CONTEXT_TOKENS * 1000 * 0.8,
strategy="last",
token_counter=self._token_counter,
include_system=True,
allow_partial=False,
start_on="human",
)
# 创建Agent执行链
agent = (
RunnablePassthrough.assign(
agent_scratchpad=lambda x: format_to_openai_tool_messages(
x["intermediate_steps"]
)
)
| self.prompt
| trimmer
| self.llm.bind_tools(self.tools)
| OpenAIToolsAgentOutputParser()
)
executor = AgentExecutor(
agent=agent,
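
The trimmer above sits between the rendered prompt (system prompt, chat history and agent scratchpad) and the LLM call, so the oldest turns are dropped once the budget is exceeded. A minimal standalone sketch of the same trim_messages behaviour, assuming langchain-core and tiktoken are installed; the conversation and the simplified counter below are made up for illustration:

import tiktoken
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, trim_messages

encoding = tiktoken.get_encoding("cl100k_base")

def count_tokens(messages) -> int:
    # Same idea as _token_counter, reduced to plain text: ~3 tokens of
    # per-message overhead plus the encoded content length.
    return sum(3 + len(encoding.encode(m.content)) for m in messages)

history = [
    SystemMessage(content="You are MoviePilot's media assistant."),
    HumanMessage(content="Find the 2010 movie Inception."),
    AIMessage(content="Found it: Inception (2010), directed by Christopher Nolan."),
    HumanMessage(content="Subscribe to it in 4K please."),
]

# strategy="last" keeps the newest messages, include_system=True always keeps the
# system prompt, and start_on="human" makes the kept window begin with a user turn.
trimmed = trim_messages(
    history,
    max_tokens=40,  # deliberately tiny budget so the trimming is visible
    strategy="last",
    token_counter=count_tokens,
    include_system=True,
    allow_partial=False,
    start_on="human",
)
print([type(m).__name__ for m in trimmed])  # oldest non-system turns are dropped first
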
@@ -190,7 +257,7 @@ class MoviePilotAgent:
         # Run the Agent
         logger.info(f"Agent running inference: session_id={self.session_id}, input={message}")
-        await self._execute_agent(input_context)
+        result = await self._execute_agent(input_context)
 
         # Fetch the Agent reply
         agent_message = await self.callback_handler.get_message()
@@ -208,7 +275,7 @@ class MoviePilotAgent:
                 content=agent_message
             )
         else:
-            agent_message = "Sorry, the agent ran into an error and could not generate a reply."
+            agent_message = result.get("output") or "Sorry, the agent ran into an error and could not generate a reply."
         await self.send_agent_message(agent_message)
         return agent_message
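
The fallback change leans on the executor result: AgentExecutor returns a dict whose "output" key holds the final answer, so when the streaming callback yields nothing the agent can still answer with that value (assuming _execute_agent passes the executor result through). A tiny sketch of the pattern, with made-up values:

# Shape of an AgentExecutor result (illustrative values, not from MoviePilot)
result = {"input": "Subscribe to Inception in 4K", "output": "Subscription created."}
agent_message = result.get("output") or "Sorry, the agent ran into an error and could not generate a reply."
print(agent_message)
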

View File

@@ -427,6 +427,8 @@ class ConfigModel(BaseModel):
     LLM_API_KEY: Optional[str] = None
     # LLM base URL, for custom API endpoints
     LLM_BASE_URL: Optional[str] = "https://api.deepseek.com"
+    # Maximum LLM context length, in K tokens
+    LLM_MAX_CONTEXT_TOKENS: int = 64
     # LLM temperature
     LLM_TEMPERATURE: float = 0.1
     # Maximum number of LLM iterations
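
For reference, how the new setting feeds the trimmer in _create_agent_executor (arithmetic only, using the default added in this commit):

LLM_MAX_CONTEXT_TOKENS = 64                     # configured in K tokens (64 -> a 64K-context model)
budget = LLM_MAX_CONTEXT_TOKENS * 1000 * 0.8    # value passed as trim_messages(max_tokens=...)
print(budget)                                   # 51200.0 -> roughly 80% of the window kept for the prompt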