Files
MoviePilot/app/agent/__init__.py

488 lines
18 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import asyncio
from typing import Dict, List, Any, Union
import json
import tiktoken
from langchain.agents import AgentExecutor
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_community.callbacks import get_openai_callback
from langchain_core.chat_history import InMemoryChatMessageHistory
from langchain_core.messages import HumanMessage, AIMessage, ToolCall, ToolMessage, SystemMessage, trim_messages
from langchain_core.runnables import RunnablePassthrough
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain.agents.format_scratchpad.openai_tools import format_to_openai_tool_messages
from langchain.agents.output_parsers.openai_tools import OpenAIToolsAgentOutputParser
from app.agent.callback import StreamingCallbackHandler
from app.agent.memory import conversation_manager
from app.agent.prompt import prompt_manager
from app.agent.tools.factory import MoviePilotToolFactory
from app.chain import ChainBase
from app.core.config import settings
from app.helper.llm import LLMHelper
from app.helper.message import MessageHelper
from app.log import logger
from app.schemas import Notification
class AgentChain(ChainBase):
pass
class MoviePilotAgent:
"""
MoviePilot AI智能体
"""
def __init__(self, session_id: str, user_id: str = None,
channel: str = None, source: str = None, username: str = None):
self.session_id = session_id
self.user_id = user_id
self.channel = channel # 消息渠道
self.source = source # 消息来源
self.username = username # 用户名
# 消息助手
self.message_helper = MessageHelper()
# 回调处理器
self.callback_handler = StreamingCallbackHandler(
session_id=session_id
)
# LLM模型
self.llm = self._initialize_llm()
# 工具
self.tools = self._initialize_tools()
# 提示词模板
self.prompt = self._initialize_prompt()
# Agent执行器
self.agent_executor = self._create_agent_executor()
def _initialize_llm(self):
"""
初始化LLM模型
"""
return LLMHelper.get_llm(streaming=True, callbacks=[self.callback_handler])
def _initialize_tools(self) -> List:
"""
初始化工具列表
"""
return MoviePilotToolFactory.create_tools(
session_id=self.session_id,
user_id=self.user_id,
channel=self.channel,
source=self.source,
username=self.username,
callback_handler=self.callback_handler
)
@staticmethod
def _initialize_session_store() -> Dict[str, InMemoryChatMessageHistory]:
"""
初始化内存存储
"""
return {}
def get_session_history(self, session_id: str) -> InMemoryChatMessageHistory:
"""
获取会话历史
"""
chat_history = InMemoryChatMessageHistory()
messages: List[dict] = conversation_manager.get_recent_messages_for_agent(
session_id=session_id,
user_id=self.user_id
)
if messages:
loaded_messages = []
for msg in messages:
if msg.get("role") == "user":
loaded_messages.append(HumanMessage(content=msg.get("content", "")))
elif msg.get("role") == "agent":
loaded_messages.append(AIMessage(content=msg.get("content", "")))
elif msg.get("role") == "tool_call":
metadata = msg.get("metadata", {})
loaded_messages.append(
AIMessage(
content=msg.get("content", ""),
tool_calls=[
ToolCall(
id=metadata.get("call_id"),
name=metadata.get("tool_name"),
args=metadata.get("parameters"),
)
]
)
)
elif msg.get("role") == "tool_result":
metadata = msg.get("metadata", {})
loaded_messages.append(ToolMessage(
content=msg.get("content", ""),
tool_call_id=metadata.get("call_id", "unknown")
))
elif msg.get("role") == "system":
loaded_messages.append(SystemMessage(content=msg.get("content", "")))
# Validate and fix tool call integrity
validated_messages = self._ensure_tool_call_integrity(loaded_messages)
for msg in validated_messages:
chat_history.add_message(msg)
return chat_history
@staticmethod
def _initialize_prompt() -> ChatPromptTemplate:
"""
初始化提示词模板
"""
try:
prompt_template = ChatPromptTemplate.from_messages([
("system", "{system_prompt}"),
MessagesPlaceholder(variable_name="chat_history"),
("user", "{input}"),
MessagesPlaceholder(variable_name="agent_scratchpad"),
])
logger.info("LangChain提示词模板初始化成功")
return prompt_template
except Exception as e:
logger.error(f"初始化提示词失败: {e}")
raise e
@staticmethod
def _token_counter(messages: List[Union[HumanMessage, AIMessage, ToolMessage, SystemMessage]]) -> int:
"""
通用的Token计数器
"""
try:
# 尝试从模型获取编码集,如果失败则回退到 cl100k_base (大多数现代模型使用的编码)
try:
encoding = tiktoken.encoding_for_model(settings.LLM_MODEL)
except KeyError:
encoding = tiktoken.get_encoding("cl100k_base")
num_tokens = 0
for message in messages:
# 基础开销 (每个消息大约 3 个 token)
num_tokens += 3
# 1. 处理文本内容 (content)
if isinstance(message.content, str):
num_tokens += len(encoding.encode(message.content))
elif isinstance(message.content, list):
for part in message.content:
if isinstance(part, dict) and part.get("type") == "text":
num_tokens += len(encoding.encode(part.get("text", "")))
# 2. 处理工具调用 (仅 AIMessage 包含 tool_calls)
if getattr(message, "tool_calls", None):
for tool_call in message.tool_calls:
# 函数名
num_tokens += len(encoding.encode(tool_call.get("name", "")))
# 参数 (转为 JSON 估算)
args_str = json.dumps(tool_call.get("args", {}), ensure_ascii=False)
num_tokens += len(encoding.encode(args_str))
# 额外的结构开销 (ID 等)
num_tokens += 3
# 3. 处理角色权重
num_tokens += 1
# 加上回复的起始 Token (大约 3 个 token)
num_tokens += 3
return num_tokens
except Exception as e:
logger.error(f"Token计数失败: {e}")
# 发生错误时返回一个保守的估算值
return len(str(messages)) // 4
def _ensure_tool_call_integrity(self, messages: List[Union[HumanMessage, AIMessage, ToolMessage, SystemMessage]]) \
-> List[Union[HumanMessage, AIMessage, ToolMessage, SystemMessage]]:
"""
Ensure tool call integrity:
1. AIMessage with tool_calls must be followed by corresponding ToolMessages
2. Remove incomplete AIMessages (with tool_calls but missing ToolMessage responses) and their partial ToolMessages
3. Filter out orphaned ToolMessages that don't correspond to any tool_call
"""
if not messages:
return messages
validated_messages = []
i = 0
while i < len(messages):
msg = messages[i]
# Check if this is an AIMessage with tool_calls
if isinstance(msg, AIMessage) and getattr(msg, 'tool_calls', None):
# Extract tool_call IDs (ToolCall is a TypedDict, so it's a dict at runtime)
tool_call_ids = {tc.get('id') or tc.get('name') for tc in msg.tool_calls}
# Find subsequent ToolMessages
j = i + 1
found_tool_messages = []
while j < len(messages) and isinstance(messages[j], ToolMessage):
found_tool_messages.append(messages[j])
j += 1
# Check if all tool_calls have corresponding ToolMessages
found_tool_call_ids = {tm.tool_call_id for tm in found_tool_messages}
if not tool_call_ids.issubset(found_tool_call_ids):
# Missing some tool_call responses, skip this AIMessage and all its ToolMessages
logger.warning("Removing incomplete tool_call AIMessage and its partial ToolMessages: missing tool_call responses")
i = j # Skip the AIMessage and all related ToolMessages
continue
else:
# Add the AIMessage and only the ToolMessages that correspond to its tool_calls
validated_messages.append(msg)
for tm in found_tool_messages:
if tm.tool_call_id in tool_call_ids:
validated_messages.append(tm)
i = j
continue
# Skip orphaned ToolMessages (not preceded by an AIMessage with tool_calls)
elif isinstance(msg, ToolMessage):
logger.warning(f"Skipping orphaned ToolMessage with tool_call_id: {msg.tool_call_id}")
i += 1
continue
else:
validated_messages.append(msg)
i += 1
return validated_messages
def _create_agent_executor(self) -> RunnableWithMessageHistory:
"""
创建Agent执行器
"""
try:
# 消息裁剪器,防止上下文超出限制
base_trimmer = trim_messages(
max_tokens=settings.LLM_MAX_CONTEXT_TOKENS * 1000 * 0.8,
strategy="last",
token_counter=self._token_counter,
include_system=True,
allow_partial=False,
start_on="human",
)
# 包装trimmer在裁剪后验证工具调用的完整性
def validated_trimmer(messages):
trimmed = base_trimmer(messages)
return self._ensure_tool_call_integrity(trimmed)
# 创建Agent执行链
agent = (
RunnablePassthrough.assign(
agent_scratchpad=lambda x: format_to_openai_tool_messages(
x["intermediate_steps"]
)
)
| self.prompt
| validated_trimmer
| self.llm.bind_tools(self.tools)
| OpenAIToolsAgentOutputParser()
)
executor = AgentExecutor(
agent=agent,
tools=self.tools,
verbose=settings.LLM_VERBOSE,
max_iterations=settings.LLM_MAX_ITERATIONS,
return_intermediate_steps=True,
handle_parsing_errors=True,
early_stopping_method="force"
)
return RunnableWithMessageHistory(
executor,
self.get_session_history,
input_messages_key="input",
history_messages_key="chat_history"
)
except Exception as e:
logger.error(f"创建Agent执行器失败: {e}")
raise e
async def process_message(self, message: str) -> str:
"""
处理用户消息
"""
try:
# 添加用户消息到记忆
await conversation_manager.add_conversation(
self.session_id,
user_id=self.user_id,
role="user",
content=message
)
# 构建输入上下文
input_context = {
"system_prompt": prompt_manager.get_agent_prompt(channel=self.channel),
"input": message
}
# 执行Agent
logger.info(f"Agent执行推理: session_id={self.session_id}, input={message}")
result = await self._execute_agent(input_context)
# 获取Agent回复
agent_message = await self.callback_handler.get_message()
# 发送Agent回复给用户通过原渠道
if agent_message:
# 发送回复
await self.send_agent_message(agent_message)
# 添加Agent回复到记忆
await conversation_manager.add_conversation(
session_id=self.session_id,
user_id=self.user_id,
role="agent",
content=agent_message
)
else:
agent_message = result.get("output") or "很抱歉,智能体出错了,未能生成回复内容。"
await self.send_agent_message(agent_message)
return agent_message
except Exception as e:
error_message = f"处理消息时发生错误: {str(e)}"
logger.error(error_message)
# 发送错误消息给用户(通过原渠道)
await self.send_agent_message(error_message)
return error_message
async def _execute_agent(self, input_context: Dict[str, Any]) -> Dict[str, Any]:
"""
执行LangChain Agent
"""
try:
with get_openai_callback() as cb:
result = await self.agent_executor.ainvoke(
input_context,
config={"configurable": {"session_id": self.session_id}},
callbacks=[self.callback_handler]
)
logger.info(f"LLM调用消耗: \n{cb}")
if cb.total_tokens > 0:
result["token_usage"] = {
"prompt_tokens": cb.prompt_tokens,
"completion_tokens": cb.completion_tokens,
"total_tokens": cb.total_tokens
}
return result
except asyncio.CancelledError:
logger.info(f"Agent执行被取消: session_id={self.session_id}")
return {
"output": "任务已取消",
"intermediate_steps": [],
"token_usage": {}
}
except Exception as e:
logger.error(f"Agent执行失败: {e}")
return {
"output": str(e),
"intermediate_steps": [],
"token_usage": {}
}
async def send_agent_message(self, message: str, title: str = "MoviePilot助手"):
"""
通过原渠道发送消息给用户
"""
await AgentChain().async_post_message(
Notification(
channel=self.channel,
source=self.source,
userid=self.user_id,
username=self.username,
title=title,
text=message
)
)
async def cleanup(self):
"""
清理智能体资源
"""
logger.info(f"MoviePilot智能体已清理: session_id={self.session_id}")
class AgentManager:
"""
AI智能体管理器
"""
def __init__(self):
self.active_agents: Dict[str, MoviePilotAgent] = {}
@staticmethod
async def initialize():
"""
初始化管理器
"""
await conversation_manager.initialize()
async def close(self):
"""
关闭管理器
"""
await conversation_manager.close()
# 清理所有活跃的智能体
for agent in self.active_agents.values():
await agent.cleanup()
self.active_agents.clear()
async def process_message(self, session_id: str, user_id: str, message: str,
channel: str = None, source: str = None, username: str = None) -> str:
"""
处理用户消息
"""
# 获取或创建Agent实例
if session_id not in self.active_agents:
logger.info(f"创建新的AI智能体实例session_id: {session_id}, user_id: {user_id}")
agent = MoviePilotAgent(
session_id=session_id,
user_id=user_id,
channel=channel,
source=source,
username=username
)
self.active_agents[session_id] = agent
else:
agent = self.active_agents[session_id]
agent.user_id = user_id # 确保user_id是最新的
# 更新渠道信息
if channel:
agent.channel = channel
if source:
agent.source = source
if username:
agent.username = username
# 处理消息
return await agent.process_message(message)
async def clear_session(self, session_id: str, user_id: str):
"""
清空会话
"""
if session_id in self.active_agents:
agent = self.active_agents[session_id]
await agent.cleanup()
del self.active_agents[session_id]
await conversation_manager.clear_memory(session_id, user_id)
logger.info(f"会话 {session_id} 的记忆已清空")
# 全局智能体管理器实例
agent_manager = AgentManager()