mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-03-31 17:32:30 +08:00
feat(agent): upgrade langchain to v1.0+
This commit is contained in:
@@ -1,26 +1,25 @@
|
||||
import asyncio
|
||||
from typing import Dict, List, Any, Union
|
||||
import json
|
||||
import tiktoken
|
||||
from time import strftime
|
||||
from typing import Dict, List
|
||||
|
||||
from langchain.agents import AgentExecutor
|
||||
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
from langchain_community.callbacks import get_openai_callback
|
||||
from langchain_core.chat_history import InMemoryChatMessageHistory
|
||||
from langchain_core.messages import HumanMessage, AIMessage, ToolCall, ToolMessage, SystemMessage, trim_messages
|
||||
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
|
||||
from langchain_core.runnables.history import RunnableWithMessageHistory
|
||||
from langchain.agents.format_scratchpad.openai_tools import format_to_openai_tool_messages
|
||||
from langchain.agents.output_parsers.openai_tools import OpenAIToolsAgentOutputParser
|
||||
from langchain.agents import create_agent
|
||||
from langchain.agents.middleware import (
|
||||
SummarizationMiddleware,
|
||||
ModelRetryMiddleware,
|
||||
ToolRetryMiddleware,
|
||||
)
|
||||
from langchain_core.messages import (
|
||||
HumanMessage,
|
||||
BaseMessage,
|
||||
)
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
|
||||
from app.agent.callback import StreamingCallbackHandler
|
||||
from app.agent.memory import conversation_manager
|
||||
from app.agent.callback import StreamingHandler, StreamingHandler
|
||||
from app.agent.memory import memory_manager
|
||||
from app.agent.prompt import prompt_manager
|
||||
from app.agent.tools.factory import MoviePilotToolFactory
|
||||
from app.chain import ChainBase
|
||||
from app.core.config import settings
|
||||
from app.helper.llm import LLMHelper
|
||||
from app.helper.message import MessageHelper
|
||||
from app.log import logger
|
||||
from app.schemas import Notification
|
||||
|
||||
@@ -31,42 +30,32 @@ class AgentChain(ChainBase):
|
||||
|
||||
class MoviePilotAgent:
|
||||
"""
|
||||
MoviePilot AI智能体
|
||||
MoviePilot AI智能体(基于 LangChain v1 + LangGraph)
|
||||
"""
|
||||
|
||||
def __init__(self, session_id: str, user_id: str = None,
|
||||
channel: str = None, source: str = None, username: str = None):
|
||||
def __init__(
|
||||
self,
|
||||
session_id: str,
|
||||
user_id: str = None,
|
||||
channel: str = None,
|
||||
source: str = None,
|
||||
username: str = None,
|
||||
):
|
||||
self.session_id = session_id
|
||||
self.user_id = user_id
|
||||
self.channel = channel # 消息渠道
|
||||
self.source = source # 消息来源
|
||||
self.username = username # 用户名
|
||||
self.channel = channel
|
||||
self.source = source
|
||||
self.username = username
|
||||
|
||||
# 消息助手
|
||||
self.message_helper = MessageHelper()
|
||||
# 流式token管理
|
||||
self.stream_handler = StreamingHandler()
|
||||
|
||||
# 回调处理器
|
||||
self.callback_handler = StreamingCallbackHandler(
|
||||
session_id=session_id
|
||||
)
|
||||
|
||||
# LLM模型
|
||||
self.llm = self._initialize_llm()
|
||||
|
||||
# 工具
|
||||
self.tools = self._initialize_tools()
|
||||
|
||||
# 提示词模板
|
||||
self.prompt = self._initialize_prompt()
|
||||
|
||||
# Agent执行器
|
||||
self.agent_executor = self._create_agent_executor()
|
||||
|
||||
def _initialize_llm(self):
|
||||
@staticmethod
|
||||
def _initialize_llm():
|
||||
"""
|
||||
初始化LLM模型
|
||||
初始化 LLM(带流式回调)
|
||||
"""
|
||||
return LLMHelper.get_llm(streaming=True, callbacks=[self.callback_handler])
|
||||
return LLMHelper.get_llm(streaming=True)
|
||||
|
||||
def _initialize_tools(self) -> List:
|
||||
"""
|
||||
@@ -78,384 +67,118 @@ class MoviePilotAgent:
|
||||
channel=self.channel,
|
||||
source=self.source,
|
||||
username=self.username,
|
||||
callback_handler=self.callback_handler
|
||||
stream_handler=self.stream_handler,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _initialize_session_store() -> Dict[str, InMemoryChatMessageHistory]:
|
||||
def _create_agent(self):
|
||||
"""
|
||||
初始化内存存储
|
||||
"""
|
||||
return {}
|
||||
|
||||
def get_session_history(self, session_id: str) -> InMemoryChatMessageHistory:
|
||||
"""
|
||||
获取会话历史
|
||||
"""
|
||||
chat_history = InMemoryChatMessageHistory()
|
||||
messages: List[dict] = conversation_manager.get_recent_messages_for_agent(
|
||||
session_id=session_id,
|
||||
user_id=self.user_id
|
||||
)
|
||||
if messages:
|
||||
for msg in messages:
|
||||
if msg.get("role") == "user":
|
||||
chat_history.add_message(HumanMessage(content=msg.get("content", "")))
|
||||
elif msg.get("role") == "agent":
|
||||
chat_history.add_message(AIMessage(content=msg.get("content", "")))
|
||||
elif msg.get("role") == "tool_call":
|
||||
metadata = msg.get("metadata", {})
|
||||
chat_history.add_message(
|
||||
AIMessage(
|
||||
content=msg.get("content", ""),
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
id=metadata.get("call_id"),
|
||||
name=metadata.get("tool_name"),
|
||||
args=metadata.get("parameters"),
|
||||
)
|
||||
]
|
||||
)
|
||||
)
|
||||
elif msg.get("role") == "tool_result":
|
||||
metadata = msg.get("metadata", {})
|
||||
chat_history.add_message(ToolMessage(
|
||||
content=msg.get("content", ""),
|
||||
tool_call_id=metadata.get("call_id", "unknown")
|
||||
))
|
||||
elif msg.get("role") == "system":
|
||||
chat_history.add_message(SystemMessage(content=msg.get("content", "")))
|
||||
|
||||
return chat_history
|
||||
|
||||
@staticmethod
|
||||
def _initialize_prompt() -> ChatPromptTemplate:
|
||||
"""
|
||||
初始化提示词模板
|
||||
创建 LangGraph Agent(使用 create_agent + SummarizationMiddleware)
|
||||
"""
|
||||
try:
|
||||
prompt_template = ChatPromptTemplate.from_messages([
|
||||
("system", "{system_prompt}"),
|
||||
MessagesPlaceholder(variable_name="chat_history"),
|
||||
("user", "{input}"),
|
||||
MessagesPlaceholder(variable_name="agent_scratchpad"),
|
||||
])
|
||||
logger.info("LangChain提示词模板初始化成功")
|
||||
return prompt_template
|
||||
# 系统提示词
|
||||
system_prompt = prompt_manager.get_agent_prompt(
|
||||
channel=self.channel
|
||||
).format(
|
||||
current_date=strftime('%Y-%m-%d')
|
||||
)
|
||||
|
||||
# LLM 模型(用于 agent 执行)
|
||||
llm = self._initialize_llm()
|
||||
|
||||
# 工具列表
|
||||
tools = self._initialize_tools()
|
||||
|
||||
# 中间件
|
||||
middlewares = [
|
||||
# 上下文压缩
|
||||
SummarizationMiddleware(
|
||||
model=llm,
|
||||
trigger=("fraction", 0.85)
|
||||
),
|
||||
# 模型调用失败时自动重试
|
||||
ModelRetryMiddleware(max_retries=3),
|
||||
# 工具调用失败时自动重试
|
||||
ToolRetryMiddleware(max_retries=1)
|
||||
]
|
||||
|
||||
return create_agent(
|
||||
model=llm,
|
||||
tools=tools,
|
||||
system_prompt=system_prompt,
|
||||
middleware=middlewares,
|
||||
checkpointer=InMemorySaver(),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"初始化提示词失败: {e}")
|
||||
logger.error(f"创建 Agent 失败: {e}")
|
||||
raise e
|
||||
|
||||
@staticmethod
|
||||
def _token_counter(messages: List[Union[HumanMessage, AIMessage, ToolMessage, SystemMessage]]) -> int:
|
||||
async def process(self, message: str) -> str:
|
||||
"""
|
||||
通用的Token计数器
|
||||
处理用户消息,流式推理并返回 Agent 回复
|
||||
"""
|
||||
try:
|
||||
# 尝试从模型获取编码集,如果失败则回退到 cl100k_base (大多数现代模型使用的编码)
|
||||
try:
|
||||
encoding = tiktoken.encoding_for_model(settings.LLM_MODEL)
|
||||
except KeyError:
|
||||
encoding = tiktoken.get_encoding("cl100k_base")
|
||||
logger.info(f"Agent推理: session_id={self.session_id}, input={message}")
|
||||
|
||||
num_tokens = 0
|
||||
for message in messages:
|
||||
# 基础开销 (每个消息大约 3 个 token)
|
||||
num_tokens += 3
|
||||
|
||||
# 1. 处理文本内容 (content)
|
||||
if isinstance(message.content, str):
|
||||
num_tokens += len(encoding.encode(message.content))
|
||||
elif isinstance(message.content, list):
|
||||
for part in message.content:
|
||||
if isinstance(part, dict) and part.get("type") == "text":
|
||||
num_tokens += len(encoding.encode(part.get("text", "")))
|
||||
|
||||
# 2. 处理工具调用 (仅 AIMessage 包含 tool_calls)
|
||||
if getattr(message, "tool_calls", None):
|
||||
for tool_call in message.tool_calls:
|
||||
# 函数名
|
||||
num_tokens += len(encoding.encode(tool_call.get("name", "")))
|
||||
# 参数 (转为 JSON 估算)
|
||||
args_str = json.dumps(tool_call.get("args", {}), ensure_ascii=False)
|
||||
num_tokens += len(encoding.encode(args_str))
|
||||
# 额外的结构开销 (ID 等)
|
||||
num_tokens += 3
|
||||
|
||||
# 3. 处理角色权重
|
||||
num_tokens += 1
|
||||
|
||||
# 加上回复的起始 Token (大约 3 个 token)
|
||||
num_tokens += 3
|
||||
return num_tokens
|
||||
except Exception as e:
|
||||
logger.error(f"Token计数失败: {e}")
|
||||
# 发生错误时返回一个保守的估算值
|
||||
return len(str(messages)) // 4
|
||||
|
||||
def _create_agent_executor(self) -> RunnableWithMessageHistory:
|
||||
"""
|
||||
创建Agent执行器
|
||||
"""
|
||||
try:
|
||||
# 消息裁剪器,防止上下文超出限制
|
||||
base_trimmer = trim_messages(
|
||||
max_tokens=settings.LLM_MAX_CONTEXT_TOKENS * 1000 * 0.8,
|
||||
strategy="last",
|
||||
token_counter=self._token_counter,
|
||||
include_system=True,
|
||||
allow_partial=False,
|
||||
start_on="human",
|
||||
)
|
||||
|
||||
# 包装trimmer,在裁剪后验证工具调用的完整性
|
||||
def validated_trimmer(messages):
|
||||
# 如果输入是 PromptValue,转换为消息列表
|
||||
if hasattr(messages, "to_messages"):
|
||||
messages = messages.to_messages()
|
||||
trimmed = base_trimmer.invoke(messages)
|
||||
|
||||
# 二次校验:确保不出现 broken tool chains
|
||||
# 1. AIMessage with tool_calls 必须紧跟着对应的 ToolMessage
|
||||
# 2. ToolMessage 必须有对应的 AIMessage 前置
|
||||
safe_messages = []
|
||||
i = 0
|
||||
while i < len(trimmed):
|
||||
msg = trimmed[i]
|
||||
|
||||
if isinstance(msg, AIMessage) and getattr(msg, "tool_calls", None):
|
||||
# 检查工具调用序列是否完整
|
||||
tool_calls = msg.tool_calls
|
||||
is_valid_sequence = True
|
||||
tool_results = []
|
||||
|
||||
# 向后查找对应的 ToolMessage
|
||||
temp_i = i + 1
|
||||
for tool_call in tool_calls:
|
||||
if temp_i >= len(trimmed):
|
||||
is_valid_sequence = False
|
||||
break
|
||||
|
||||
next_msg = trimmed[temp_i]
|
||||
if isinstance(next_msg, ToolMessage) and next_msg.tool_call_id == tool_call.get("id"):
|
||||
tool_results.append(next_msg)
|
||||
temp_i += 1
|
||||
else:
|
||||
is_valid_sequence = False
|
||||
break
|
||||
|
||||
if is_valid_sequence:
|
||||
# 序列完整,保留消息
|
||||
safe_messages.append(msg)
|
||||
safe_messages.extend(tool_results)
|
||||
i = temp_i # 跳过已处理的工具结果
|
||||
else:
|
||||
# 序列不完整,丢弃该 AIMessage(后续的孤立 ToolMessage 会在下一次循环被当做 orphaned 处理掉)
|
||||
logger.warning(f"移除无效的工具调用链: {len(tool_calls)} calls, incomplete results")
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if isinstance(msg, ToolMessage):
|
||||
# 如果在这里遇到 ToolMessage,说明它没有被上面的逻辑消费,则是孤立的(或者顺序错乱)
|
||||
logger.warning("移除孤立的 ToolMessage")
|
||||
i += 1
|
||||
continue
|
||||
|
||||
# 其他类型的消息直接保留
|
||||
safe_messages.append(msg)
|
||||
i += 1
|
||||
|
||||
if len(safe_messages) < len(messages):
|
||||
logger.info(f"LangChain消息上下文已裁剪: {len(messages)} -> {len(safe_messages)}")
|
||||
return safe_messages
|
||||
|
||||
# 创建Agent执行链
|
||||
agent = (
|
||||
RunnablePassthrough.assign(
|
||||
agent_scratchpad=lambda x: format_to_openai_tool_messages(
|
||||
x["intermediate_steps"]
|
||||
)
|
||||
)
|
||||
| self.prompt
|
||||
| RunnableLambda(validated_trimmer)
|
||||
| self.llm.bind_tools(self.tools)
|
||||
| OpenAIToolsAgentOutputParser()
|
||||
)
|
||||
executor = AgentExecutor(
|
||||
agent=agent,
|
||||
tools=self.tools,
|
||||
verbose=settings.LLM_VERBOSE,
|
||||
max_iterations=settings.LLM_MAX_ITERATIONS,
|
||||
return_intermediate_steps=True,
|
||||
handle_parsing_errors=True,
|
||||
early_stopping_method="force"
|
||||
)
|
||||
return RunnableWithMessageHistory(
|
||||
executor,
|
||||
self.get_session_history,
|
||||
input_messages_key="input",
|
||||
history_messages_key="chat_history"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"创建Agent执行器失败: {e}")
|
||||
raise e
|
||||
|
||||
async def _summarize_history(self):
|
||||
"""
|
||||
总结提炼之前的对话和工具执行情况,并把会话总结变成新的系统提示词取代之前的对话
|
||||
"""
|
||||
try:
|
||||
# 获取当前历史记录
|
||||
chat_history = self.get_session_history(self.session_id)
|
||||
messages = chat_history.messages
|
||||
if not messages:
|
||||
return
|
||||
|
||||
logger.info(f"会话 {self.session_id} 历史消息长度已超过 90%,开始总结并重置上下文...")
|
||||
|
||||
# 将消息转换为摘要所需的文本格式
|
||||
history_text = ""
|
||||
for msg in messages:
|
||||
if isinstance(msg, HumanMessage):
|
||||
history_text += f"用户: {msg.content}\n"
|
||||
elif isinstance(msg, AIMessage):
|
||||
history_text += f"智能体: {msg.content}\n"
|
||||
if getattr(msg, "tool_calls", None):
|
||||
for tool_call in msg.tool_calls:
|
||||
history_text += f"智能体调用工具: {tool_call.get('name')},参数: {tool_call.get('args')}\n"
|
||||
elif isinstance(msg, ToolMessage):
|
||||
history_text += f"工具响应: {msg.content}\n"
|
||||
elif isinstance(msg, SystemMessage):
|
||||
history_text += f"系统: {msg.content}\n"
|
||||
|
||||
# 摘要提示词
|
||||
summary_prompt = (
|
||||
"Please provide a comprehensive and highly informational summary of the preceding conversation and tool executions. "
|
||||
"Your goal is to condense the history while retaining all critical details for future reference. "
|
||||
"Ensure you include:\n"
|
||||
"1. User's core intents, specific requests, and any mentioned preferences.\n"
|
||||
"2. Names of movies, TV shows, or other key entities discussed.\n"
|
||||
"3. A concise log of tool calls made and their specific results/outcomes.\n"
|
||||
"4. The current status of any tasks and any pending actions.\n"
|
||||
"5. Any important context that would be necessary for the agent to continue the conversation seamlessly.\n"
|
||||
"The summary should be dense with information and serve as the primary context for the next stage of the interaction."
|
||||
)
|
||||
|
||||
# 调用 LLM 进行总结 (非流式)
|
||||
summary_llm = LLMHelper.get_llm(streaming=False)
|
||||
response = await summary_llm.ainvoke([
|
||||
SystemMessage(content=summary_prompt),
|
||||
HumanMessage(content=f"Here is the conversation history to summarize:\n{history_text}")
|
||||
])
|
||||
summary_content = str(response.content)
|
||||
|
||||
if not summary_content:
|
||||
logger.warning("总结生成失败,跳过重置逻辑。")
|
||||
return
|
||||
|
||||
# 清空原有的会话记录并插入新的系统总结
|
||||
await conversation_manager.clear_memory(self.session_id, self.user_id)
|
||||
await conversation_manager.add_conversation(
|
||||
# 获取历史消息
|
||||
messages = memory_manager.get_agent_messages(
|
||||
session_id=self.session_id,
|
||||
user_id=self.user_id,
|
||||
role="system",
|
||||
content=f"<history_summary>\n{summary_content}\n</history_summary>"
|
||||
)
|
||||
logger.info(f"会话 {self.session_id} 历史摘要替换完成。")
|
||||
except Exception as e:
|
||||
logger.error(f"执行会话总结出错: {str(e)}")
|
||||
|
||||
async def process_message(self, message: str) -> str:
|
||||
"""
|
||||
处理用户消息
|
||||
"""
|
||||
try:
|
||||
# 检查上下文长度是否超过 90%
|
||||
history = self.get_session_history(self.session_id)
|
||||
if self._token_counter(history.messages) > settings.LLM_MAX_CONTEXT_TOKENS * 1000 * 0.9:
|
||||
await self._summarize_history()
|
||||
|
||||
# 添加用户消息到记忆
|
||||
await conversation_manager.add_conversation(
|
||||
self.session_id,
|
||||
user_id=self.user_id,
|
||||
role="user",
|
||||
content=message
|
||||
user_id=self.user_id
|
||||
)
|
||||
|
||||
# 构建输入上下文
|
||||
input_context = {
|
||||
"system_prompt": prompt_manager.get_agent_prompt(channel=self.channel),
|
||||
"input": message
|
||||
}
|
||||
# 增加用户消息
|
||||
messages.append(HumanMessage(content=message))
|
||||
|
||||
# 执行Agent
|
||||
logger.info(f"Agent执行推理: session_id={self.session_id}, input={message}")
|
||||
|
||||
result = await self._execute_agent(input_context)
|
||||
|
||||
# 获取Agent回复
|
||||
agent_message = await self.callback_handler.get_message()
|
||||
|
||||
# 发送Agent回复给用户(通过原渠道)
|
||||
if agent_message:
|
||||
# 发送回复
|
||||
await self.send_agent_message(agent_message)
|
||||
|
||||
# 添加Agent回复到记忆
|
||||
await conversation_manager.add_conversation(
|
||||
session_id=self.session_id,
|
||||
user_id=self.user_id,
|
||||
role="agent",
|
||||
content=agent_message
|
||||
)
|
||||
else:
|
||||
agent_message = result.get("output") or "很抱歉,智能体出错了,未能生成回复内容。"
|
||||
await self.send_agent_message(agent_message)
|
||||
|
||||
return agent_message
|
||||
# 执行推理
|
||||
await self._execute_agent(messages)
|
||||
|
||||
except Exception as e:
|
||||
error_message = f"处理消息时发生错误: {str(e)}"
|
||||
logger.error(error_message)
|
||||
# 发送错误消息给用户(通过原渠道)
|
||||
await self.send_agent_message(error_message)
|
||||
return error_message
|
||||
|
||||
async def _execute_agent(self, input_context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
async def _execute_agent(self, messages: List[BaseMessage]):
|
||||
"""
|
||||
执行LangChain Agent
|
||||
调用 LangGraph Agent,通过 astream_events 流式获取 token,
|
||||
同时用 UsageMetadataCallbackHandler 统计 token 用量。
|
||||
"""
|
||||
try:
|
||||
with get_openai_callback() as cb:
|
||||
result = await self.agent_executor.ainvoke(
|
||||
input_context,
|
||||
config={"configurable": {"session_id": self.session_id}},
|
||||
callbacks=[self.callback_handler]
|
||||
)
|
||||
logger.info(f"LLM调用消耗: \n{cb}")
|
||||
# Agent运行配置
|
||||
agent_config = {
|
||||
""
|
||||
}
|
||||
|
||||
# 创建智能体
|
||||
agent = self._create_agent()
|
||||
|
||||
# 流式运行智能体
|
||||
async for chunk in agent.astream(
|
||||
{"messages": messages},
|
||||
stream_mode="messages",
|
||||
config=agent_config
|
||||
):
|
||||
token, metadata = chunk
|
||||
if token:
|
||||
self.stream_handler.emit(token.content)
|
||||
|
||||
# 发送最终消息给用户
|
||||
await self.send_agent_message(
|
||||
self.stream_handler.take()
|
||||
)
|
||||
|
||||
# 保存消息
|
||||
memory_manager.save_agent_messages(
|
||||
session_id=self.session_id,
|
||||
user_id=self.user_id,
|
||||
messages=agent.get_state(agent_config).values("messages")
|
||||
)
|
||||
|
||||
if cb.total_tokens > 0:
|
||||
result["token_usage"] = {
|
||||
"prompt_tokens": cb.prompt_tokens,
|
||||
"completion_tokens": cb.completion_tokens,
|
||||
"total_tokens": cb.total_tokens
|
||||
}
|
||||
return result
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"Agent执行被取消: session_id={self.session_id}")
|
||||
return {
|
||||
"output": "任务已取消",
|
||||
"intermediate_steps": [],
|
||||
"token_usage": {}
|
||||
}
|
||||
return "任务已取消", {}
|
||||
except Exception as e:
|
||||
logger.error(f"Agent执行失败: {e}")
|
||||
return {
|
||||
"output": str(e),
|
||||
"intermediate_steps": [],
|
||||
"token_usage": {}
|
||||
}
|
||||
return str(e), {}
|
||||
|
||||
async def send_agent_message(self, message: str, title: str = "MoviePilot助手"):
|
||||
"""
|
||||
@@ -468,7 +191,7 @@ class MoviePilotAgent:
|
||||
userid=self.user_id,
|
||||
username=self.username,
|
||||
title=title,
|
||||
text=message
|
||||
text=message,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -492,38 +215,44 @@ class AgentManager:
|
||||
"""
|
||||
初始化管理器
|
||||
"""
|
||||
await conversation_manager.initialize()
|
||||
await memory_manager.initialize()
|
||||
|
||||
async def close(self):
|
||||
"""
|
||||
关闭管理器
|
||||
"""
|
||||
await conversation_manager.close()
|
||||
# 清理所有活跃的智能体
|
||||
await memory_manager.close()
|
||||
for agent in self.active_agents.values():
|
||||
await agent.cleanup()
|
||||
self.active_agents.clear()
|
||||
|
||||
async def process_message(self, session_id: str, user_id: str, message: str,
|
||||
channel: str = None, source: str = None, username: str = None) -> str:
|
||||
async def process_message(
|
||||
self,
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
message: str,
|
||||
channel: str = None,
|
||||
source: str = None,
|
||||
username: str = None,
|
||||
) -> str:
|
||||
"""
|
||||
处理用户消息
|
||||
"""
|
||||
# 获取或创建Agent实例
|
||||
if session_id not in self.active_agents:
|
||||
logger.info(f"创建新的AI智能体实例,session_id: {session_id}, user_id: {user_id}")
|
||||
logger.info(
|
||||
f"创建新的AI智能体实例,session_id: {session_id}, user_id: {user_id}"
|
||||
)
|
||||
agent = MoviePilotAgent(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
channel=channel,
|
||||
source=source,
|
||||
username=username
|
||||
username=username,
|
||||
)
|
||||
self.active_agents[session_id] = agent
|
||||
else:
|
||||
agent = self.active_agents[session_id]
|
||||
agent.user_id = user_id # 确保user_id是最新的
|
||||
# 更新渠道信息
|
||||
agent.user_id = user_id
|
||||
if channel:
|
||||
agent.channel = channel
|
||||
if source:
|
||||
@@ -531,8 +260,7 @@ class AgentManager:
|
||||
if username:
|
||||
agent.username = username
|
||||
|
||||
# 处理消息
|
||||
return await agent.process_message(message)
|
||||
return await agent.process(message)
|
||||
|
||||
async def clear_session(self, session_id: str, user_id: str):
|
||||
"""
|
||||
@@ -542,7 +270,7 @@ class AgentManager:
|
||||
agent = self.active_agents[session_id]
|
||||
await agent.cleanup()
|
||||
del self.active_agents[session_id]
|
||||
await conversation_manager.clear_memory(session_id, user_id)
|
||||
await memory_manager.clear_memory(session_id, user_id)
|
||||
logger.info(f"会话 {session_id} 的记忆已清空")
|
||||
|
||||
|
||||
|
||||
@@ -1,39 +1,41 @@
|
||||
import threading
|
||||
|
||||
from langchain_core.callbacks import AsyncCallbackHandler
|
||||
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class StreamingCallbackHandler(AsyncCallbackHandler):
|
||||
class StreamingHandler:
|
||||
"""
|
||||
流式输出回调处理器
|
||||
流式Token缓冲管理器
|
||||
|
||||
负责从 LLM 流式 token 中积累文本,供 Agent 在工具调用之间穿插发送中间消息。
|
||||
"""
|
||||
|
||||
def __init__(self, session_id: str):
|
||||
def __init__(self):
|
||||
self._lock = threading.Lock()
|
||||
self.session_id = session_id
|
||||
self.current_message = ""
|
||||
self._buffer = ""
|
||||
|
||||
async def get_message(self):
|
||||
def emit(self, token: str):
|
||||
"""
|
||||
获取当前消息内容,获取后清空
|
||||
接收 LLM 流式 token,积累到缓冲区。
|
||||
"""
|
||||
with self._lock:
|
||||
if not self.current_message:
|
||||
self._buffer += token
|
||||
|
||||
def take(self) -> str:
|
||||
"""
|
||||
获取当前已积累的消息内容,获取后清空缓冲区。
|
||||
"""
|
||||
with self._lock:
|
||||
if not self._buffer:
|
||||
return ""
|
||||
msg = self.current_message
|
||||
logger.info(f"Agent消息: {msg}")
|
||||
self.current_message = ""
|
||||
return msg
|
||||
message = self._buffer
|
||||
logger.info(f"Agent消息: {message}")
|
||||
self._buffer = ""
|
||||
return message
|
||||
|
||||
async def on_llm_new_token(self, token: str, **kwargs):
|
||||
def clear(self):
|
||||
"""
|
||||
处理新的token
|
||||
清空缓冲区(不返回内容)
|
||||
"""
|
||||
if not token:
|
||||
return
|
||||
with self._lock:
|
||||
# 缓存当前消息
|
||||
self.current_message += token
|
||||
|
||||
self._buffer = ""
|
||||
|
||||
@@ -1,17 +1,17 @@
|
||||
"""对话记忆管理器"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Any
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from langchain_core.messages import BaseMessage
|
||||
|
||||
from app.core.config import settings
|
||||
from app.helper.redis import AsyncRedisHelper
|
||||
from app.log import logger
|
||||
from app.schemas.agent import ConversationMemory
|
||||
|
||||
|
||||
class ConversationMemoryManager:
|
||||
class MemoryManager:
|
||||
"""
|
||||
对话记忆管理器
|
||||
"""
|
||||
@@ -19,18 +19,18 @@ class ConversationMemoryManager:
|
||||
def __init__(self):
|
||||
# 内存中的会话记忆缓存
|
||||
self.memory_cache: Dict[str, ConversationMemory] = {}
|
||||
# 使用现有的Redis助手
|
||||
self.redis_helper = AsyncRedisHelper()
|
||||
# 内存缓存清理任务(Redis通过TTL自动过期)
|
||||
# 内存缓存清理任务
|
||||
self.cleanup_task: Optional[asyncio.Task] = None
|
||||
|
||||
async def initialize(self):
|
||||
def initialize(self):
|
||||
"""
|
||||
初始化记忆管理器
|
||||
"""
|
||||
try:
|
||||
# 启动内存缓存清理任务(Redis通过TTL自动过期)
|
||||
self.cleanup_task = asyncio.create_task(self._cleanup_expired_memories())
|
||||
self.cleanup_task = asyncio.create_task(
|
||||
self._cleanup_expired_memories()
|
||||
)
|
||||
logger.info("对话记忆管理器初始化完成")
|
||||
|
||||
except Exception as e:
|
||||
@@ -47,8 +47,6 @@ class ConversationMemoryManager:
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
await self.redis_helper.close()
|
||||
|
||||
logger.info("对话记忆管理器已关闭")
|
||||
|
||||
@staticmethod
|
||||
@@ -58,258 +56,64 @@ class ConversationMemoryManager:
|
||||
"""
|
||||
return f"{user_id}:{session_id}" if user_id else session_id
|
||||
|
||||
@staticmethod
|
||||
def _get_redis_key(session_id: str, user_id: str):
|
||||
"""
|
||||
计算Redis Key
|
||||
"""
|
||||
return f"agent_memory:{user_id}:{session_id}" if user_id else f"agent_memory:{session_id}"
|
||||
|
||||
def _get_memory(self, session_id: str, user_id: str):
|
||||
def get_memory(self, session_id: str, user_id: str) -> Optional[ConversationMemory]:
|
||||
"""
|
||||
获取内存中的记忆
|
||||
"""
|
||||
cache_key = self._get_memory_key(session_id, user_id)
|
||||
return self.memory_cache.get(cache_key)
|
||||
|
||||
async def _get_redis(self, session_id: str, user_id: str) -> Optional[ConversationMemory]:
|
||||
"""
|
||||
从Redis获取记忆
|
||||
"""
|
||||
if settings.CACHE_BACKEND_TYPE == "redis":
|
||||
try:
|
||||
redis_key = self._get_redis_key(session_id, user_id)
|
||||
memory_data = await self.redis_helper.get(redis_key, region="AI_AGENT")
|
||||
if memory_data:
|
||||
memory_dict = json.loads(memory_data) if isinstance(memory_data, str) else memory_data
|
||||
memory = ConversationMemory(**memory_dict)
|
||||
return memory
|
||||
except Exception as e:
|
||||
logger.warning(f"从Redis加载记忆失败: {e}")
|
||||
return None
|
||||
|
||||
async def get_conversation(self, session_id: str, user_id: str) -> ConversationMemory:
|
||||
"""
|
||||
获取会话记忆
|
||||
"""
|
||||
# 首先检查缓存
|
||||
conversion = self._get_memory(session_id, user_id)
|
||||
if conversion:
|
||||
return conversion
|
||||
|
||||
# 尝试从Redis加载
|
||||
memory = await self._get_redis(session_id, user_id)
|
||||
if memory:
|
||||
# 加载到内存缓存
|
||||
self._save_memory(memory)
|
||||
return memory
|
||||
|
||||
# 创建新的记忆
|
||||
memory = ConversationMemory(session_id=session_id, user_id=user_id)
|
||||
await self._save_conversation(memory)
|
||||
|
||||
return memory
|
||||
|
||||
async def set_title(self, session_id: str, user_id: str, title: str):
|
||||
"""
|
||||
设置会话标题
|
||||
"""
|
||||
memory = await self.get_conversation(session_id=session_id, user_id=user_id)
|
||||
memory.title = title
|
||||
memory.updated_at = datetime.now()
|
||||
await self._save_conversation(memory)
|
||||
|
||||
async def get_title(self, session_id: str, user_id: str) -> Optional[str]:
|
||||
"""
|
||||
获取会话标题
|
||||
"""
|
||||
memory = await self.get_conversation(session_id=session_id, user_id=user_id)
|
||||
return memory.title
|
||||
|
||||
async def list_sessions(self, user_id: str, limit: int = 100) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
列出历史会话摘要(按更新时间倒序)
|
||||
|
||||
- 当启用Redis时:遍历 `agent_memory:*` 键并读取摘要
|
||||
- 当未启用Redis时:基于内存缓存返回
|
||||
"""
|
||||
sessions: List[ConversationMemory] = []
|
||||
# 从Redis遍历
|
||||
if settings.CACHE_BACKEND_TYPE == "redis":
|
||||
try:
|
||||
# 使用Redis助手的items方法遍历所有键
|
||||
async for key, value in self.redis_helper.items(region="AI_AGENT"):
|
||||
if key.startswith("agent_memory:"):
|
||||
try:
|
||||
# 解析键名获取user_id和session_id
|
||||
key_parts = key.split(":")
|
||||
if len(key_parts) >= 3:
|
||||
key_user_id = key_parts[2] if len(key_parts) > 3 else None
|
||||
if not user_id or key_user_id == user_id:
|
||||
data = value if isinstance(value, dict) else json.loads(value)
|
||||
memory = ConversationMemory(**data)
|
||||
sessions.append(memory)
|
||||
except Exception as err:
|
||||
logger.warning(f"解析Redis记忆数据失败: {err}")
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.warning(f"遍历Redis会话失败: {e}")
|
||||
|
||||
# 合并内存缓存(确保包含近期的会话)
|
||||
for cache_key, memory in self.memory_cache.items():
|
||||
# 如果指定了user_id,只返回该用户的会话
|
||||
if not user_id or memory.user_id == user_id:
|
||||
sessions.append(memory)
|
||||
|
||||
# 去重(以 session_id 为键,取最近updated)
|
||||
uniq: Dict[str, ConversationMemory] = {}
|
||||
for mem in sessions:
|
||||
existed = uniq.get(mem.session_id)
|
||||
if (not existed) or (mem.updated_at > existed.updated_at):
|
||||
uniq[mem.session_id] = mem
|
||||
|
||||
# 排序并裁剪
|
||||
sorted_list = sorted(uniq.values(), key=lambda m: m.updated_at, reverse=True)[:limit]
|
||||
return [
|
||||
{
|
||||
"session_id": m.session_id,
|
||||
"title": m.title or "新会话",
|
||||
"message_count": len(m.messages),
|
||||
"created_at": m.created_at.isoformat(),
|
||||
"updated_at": m.updated_at.isoformat(),
|
||||
}
|
||||
for m in sorted_list
|
||||
]
|
||||
|
||||
async def add_conversation(
|
||||
self,
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
role: str,
|
||||
content: str,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
):
|
||||
"""
|
||||
添加消息到记忆
|
||||
"""
|
||||
memory = await self.get_conversation(session_id=session_id, user_id=user_id)
|
||||
|
||||
message = {
|
||||
"role": role,
|
||||
"content": content,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"metadata": metadata or {}
|
||||
}
|
||||
|
||||
memory.messages.append(message)
|
||||
memory.updated_at = datetime.now()
|
||||
|
||||
# 限制消息数量,避免记忆过大
|
||||
max_messages = settings.LLM_MAX_MEMORY_MESSAGES
|
||||
if len(memory.messages) > max_messages:
|
||||
# 保留最近的消息,但保留第一条系统消息
|
||||
system_messages = [msg for msg in memory.messages if msg["role"] == "system"]
|
||||
recent_messages = memory.messages[-(max_messages - len(system_messages)):]
|
||||
memory.messages = system_messages + recent_messages
|
||||
|
||||
await self._save_conversation(memory)
|
||||
|
||||
logger.debug(f"消息已添加到记忆: session_id={session_id}, user_id={user_id}, role={role}")
|
||||
|
||||
def get_recent_messages_for_agent(
|
||||
self,
|
||||
session_id: str,
|
||||
user_id: str
|
||||
) -> List[Dict[str, Any]]:
|
||||
def get_agent_messages(
|
||||
self, session_id: str, user_id: str
|
||||
) -> List[BaseMessage]:
|
||||
"""
|
||||
为Agent获取最近的消息(仅内存缓存)
|
||||
|
||||
如果消息Token数量超过模型最大上下文长度的阀值,会自动进行摘要裁剪
|
||||
"""
|
||||
cache_key = self._get_memory_key(session_id, user_id)
|
||||
memory = self.memory_cache.get(cache_key)
|
||||
memory = self.get_memory(session_id, user_id)
|
||||
if not memory:
|
||||
return []
|
||||
|
||||
# 获取所有消息
|
||||
return memory.messages[:-1]
|
||||
return memory.messages
|
||||
|
||||
async def get_recent_messages(
|
||||
self,
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
limit: int = 10,
|
||||
role_filter: Optional[list] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
def save_agent_messages(
|
||||
self, session_id: str, user_id: str, messages: List[BaseMessage]
|
||||
):
|
||||
"""
|
||||
获取最近的消息
|
||||
保存Agent消息(仅内存缓存)
|
||||
|
||||
注意:Redis中的记忆通过TTL机制自动过期,这里只更新内存缓存,Redis会在下次访问时自动过期
|
||||
"""
|
||||
memory = await self.get_conversation(session_id=session_id, user_id=user_id)
|
||||
memory = self.get_memory(session_id, user_id)
|
||||
if not memory:
|
||||
memory = ConversationMemory(session_id=session_id, user_id=user_id)
|
||||
|
||||
messages = memory.messages
|
||||
if role_filter:
|
||||
messages = [msg for msg in messages if msg["role"] in role_filter]
|
||||
memory.messages = messages
|
||||
memory.updated_at = datetime.now()
|
||||
|
||||
return messages[-limit:] if messages else []
|
||||
# 更新内存缓存
|
||||
self.save_memory(memory)
|
||||
|
||||
async def get_context(self, session_id: str, user_id: str) -> Dict[str, Any]:
|
||||
def save_memory(self, memory: ConversationMemory):
|
||||
"""
|
||||
获取会话上下文
|
||||
"""
|
||||
memory = await self.get_conversation(session_id=session_id, user_id=user_id)
|
||||
return memory.context
|
||||
保存记忆到内存缓存
|
||||
|
||||
async def clear_memory(self, session_id: str, user_id: str):
|
||||
"""
|
||||
清空会话记忆
|
||||
"""
|
||||
cache_key = f"{user_id}:{session_id}" if user_id else session_id
|
||||
if cache_key in self.memory_cache:
|
||||
del self.memory_cache[cache_key]
|
||||
|
||||
if settings.CACHE_BACKEND_TYPE == "redis":
|
||||
redis_key = self._get_redis_key(session_id, user_id)
|
||||
await self.redis_helper.delete(redis_key, region="AI_AGENT")
|
||||
|
||||
logger.info(f"会话记忆已清空: session_id={session_id}, user_id={user_id}")
|
||||
|
||||
def _save_memory(self, memory: ConversationMemory):
|
||||
"""
|
||||
保存记忆到内存
|
||||
注意:Redis中的记忆通过TTL机制自动过期,这里只更新内存缓存,Redis会在下次访问时自动过期
|
||||
"""
|
||||
cache_key = self._get_memory_key(memory.session_id, memory.user_id)
|
||||
self.memory_cache[cache_key] = memory
|
||||
|
||||
async def _save_redis(self, memory: ConversationMemory):
|
||||
def clear_memory(self, session_id: str, user_id: str):
|
||||
"""
|
||||
保存记忆到Redis
|
||||
清空会话记忆
|
||||
"""
|
||||
if settings.CACHE_BACKEND_TYPE == "redis":
|
||||
try:
|
||||
memory_dict = memory.model_dump()
|
||||
redis_key = self._get_redis_key(memory.session_id, memory.user_id)
|
||||
ttl = int(timedelta(days=settings.LLM_REDIS_MEMORY_RETENTION_DAYS).total_seconds())
|
||||
await self.redis_helper.set(
|
||||
redis_key,
|
||||
memory_dict,
|
||||
ttl=ttl,
|
||||
region="AI_AGENT"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"保存记忆到Redis失败: {e}")
|
||||
|
||||
async def _save_conversation(self, memory: ConversationMemory):
|
||||
"""
|
||||
保存记忆到存储
|
||||
|
||||
Redis中的记忆会自动通过TTL机制过期,无需手动清理
|
||||
"""
|
||||
# 更新内存缓存
|
||||
self._save_memory(memory)
|
||||
|
||||
# 保存到Redis,设置TTL自动过期
|
||||
await self._save_redis(memory)
|
||||
cache_key = self._get_memory_key(session_id, user_id)
|
||||
if cache_key in self.memory_cache:
|
||||
del self.memory_cache[cache_key]
|
||||
|
||||
logger.info(f"会话记忆已清空: session_id={session_id}, user_id={user_id}")
|
||||
|
||||
async def _cleanup_expired_memories(self):
|
||||
"""
|
||||
@@ -328,7 +132,9 @@ class ConversationMemoryManager:
|
||||
# 只检查内存缓存中的过期记忆
|
||||
# Redis中的记忆会通过TTL自动过期,无需手动处理
|
||||
for cache_key, memory in self.memory_cache.items():
|
||||
if (current_time - memory.updated_at).days > settings.LLM_MEMORY_RETENTION_DAYS:
|
||||
if (
|
||||
current_time - memory.updated_at
|
||||
).days > settings.LLM_MEMORY_RETENTION_DAYS:
|
||||
expired_sessions.append(cache_key)
|
||||
|
||||
# 只清理内存缓存,不删除Redis中的键(Redis会自动过期)
|
||||
@@ -344,4 +150,5 @@ class ConversationMemoryManager:
|
||||
except Exception as e:
|
||||
logger.error(f"清理记忆时发生错误: {e}")
|
||||
|
||||
conversation_manager = ConversationMemoryManager()
|
||||
|
||||
memory_manager = MemoryManager()
|
||||
|
||||
@@ -69,4 +69,6 @@ At the end of your session/turn, provide a concise summary of your actions.
|
||||
<markdown_spec>
|
||||
Specific markdown rules:
|
||||
{markdown_spec}
|
||||
</markdown_spec>
|
||||
</markdown_spec>
|
||||
|
||||
Today's date: {current_date}
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
import json
|
||||
import uuid
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from typing import Any, Optional
|
||||
|
||||
from langchain.tools import BaseTool
|
||||
from langchain_core.tools import BaseTool
|
||||
from pydantic import PrivateAttr
|
||||
|
||||
from app.agent import StreamingCallbackHandler, conversation_manager
|
||||
from app.agent import StreamingHandler
|
||||
from app.chain import ChainBase
|
||||
from app.log import logger
|
||||
from app.schemas import Notification
|
||||
@@ -18,15 +17,15 @@ class ToolChain(ChainBase):
|
||||
|
||||
class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
"""
|
||||
MoviePilot专用工具基类
|
||||
MoviePilot专用工具基类(LangChain v1 / langchain_core)
|
||||
"""
|
||||
|
||||
_session_id: str = PrivateAttr()
|
||||
_user_id: str = PrivateAttr()
|
||||
_channel: str = PrivateAttr(default=None)
|
||||
_source: str = PrivateAttr(default=None)
|
||||
_username: str = PrivateAttr(default=None)
|
||||
_callback_handler: StreamingCallbackHandler = PrivateAttr(default=None)
|
||||
_channel: Optional[str] = PrivateAttr(default=None)
|
||||
_source: Optional[str] = PrivateAttr(default=None)
|
||||
_username: Optional[str] = PrivateAttr(default=None)
|
||||
_stream_handler: Optional[StreamingHandler] = PrivateAttr(default=None)
|
||||
|
||||
def __init__(self, session_id: str, user_id: str, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
@@ -34,93 +33,70 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
self._user_id = user_id
|
||||
|
||||
def _run(self, *args: Any, **kwargs: Any) -> Any:
|
||||
pass
|
||||
raise NotImplementedError("MoviePilotTool 只支持异步调用,请使用 _arun")
|
||||
|
||||
async def _arun(self, **kwargs) -> str:
|
||||
async def _arun(self, *args: Any, **kwargs: Any) -> str:
|
||||
"""
|
||||
异步运行工具
|
||||
异步运行工具,负责:
|
||||
1. 在工具调用前将流式消息推送给用户
|
||||
2. 持久化工具调用记录到会话记忆
|
||||
3. 调用具体工具逻辑(子类实现的 execute 方法)
|
||||
4. 持久化工具结果到会话记忆
|
||||
"""
|
||||
# 获取工具调用前的agent消息
|
||||
agent_message = await self._callback_handler.get_message()
|
||||
|
||||
# 生成唯一的工具调用ID
|
||||
call_id = f"call_{str(uuid.uuid4())[:16]}"
|
||||
|
||||
# 记忆工具调用
|
||||
await conversation_manager.add_conversation(
|
||||
session_id=self._session_id,
|
||||
user_id=self._user_id,
|
||||
role="tool_call",
|
||||
content=agent_message,
|
||||
metadata={
|
||||
"call_id": call_id,
|
||||
"tool_name": self.name,
|
||||
"parameters": kwargs
|
||||
}
|
||||
# 获取工具调用前 Agent 已积累的流式文本
|
||||
agent_message = (
|
||||
self._stream_handler.take() if self._stream_handler else ""
|
||||
)
|
||||
|
||||
# 获取执行工具说明,优先使用工具自定义的提示消息,如果没有则使用 explanation
|
||||
# 获取工具执行提示消息
|
||||
tool_message = self.get_tool_message(**kwargs)
|
||||
if not tool_message:
|
||||
explanation = kwargs.get("explanation")
|
||||
if explanation:
|
||||
tool_message = explanation
|
||||
|
||||
# 合并agent消息和工具执行消息,一起发送
|
||||
# 合并 Agent 消息和工具执行消息后一起发送
|
||||
messages = []
|
||||
if agent_message:
|
||||
messages.append(agent_message)
|
||||
if tool_message:
|
||||
messages.append(f"⚙️ => {tool_message}")
|
||||
|
||||
# 发送合并后的消息
|
||||
if messages:
|
||||
merged_message = "\n\n".join(messages)
|
||||
await self.send_tool_message(merged_message, title="MoviePilot助手")
|
||||
|
||||
logger.debug(f'Executing tool {self.name} with args: {kwargs}')
|
||||
logger.debug(f"Executing tool {self.name} with args: {kwargs}")
|
||||
|
||||
# 执行工具,捕获异常确保结果总是被存储到记忆中
|
||||
# 执行具体工具逻辑
|
||||
try:
|
||||
result = await self.run(**kwargs)
|
||||
logger.debug(f'Tool {self.name} executed with result: {result}')
|
||||
logger.debug(f"Tool {self.name} executed with result: {result}")
|
||||
except Exception as e:
|
||||
# 记录异常详情
|
||||
error_message = f"工具执行异常 ({type(e).__name__}): {str(e)}"
|
||||
logger.error(f'Tool {self.name} execution failed: {e}', exc_info=True)
|
||||
logger.error(f"Tool {self.name} execution failed: {e}", exc_info=True)
|
||||
result = error_message
|
||||
|
||||
# 记忆工具调用结果
|
||||
# 格式化结果
|
||||
if isinstance(result, str):
|
||||
formated_result = result
|
||||
formatted_result = result
|
||||
elif isinstance(result, (int, float)):
|
||||
formated_result = str(result)
|
||||
formatted_result = str(result)
|
||||
else:
|
||||
formated_result = json.dumps(result, ensure_ascii=False, indent=2)
|
||||
formatted_result = json.dumps(result, ensure_ascii=False, indent=2)
|
||||
|
||||
await conversation_manager.add_conversation(
|
||||
session_id=self._session_id,
|
||||
user_id=self._user_id,
|
||||
role="tool_result",
|
||||
content=formated_result,
|
||||
metadata={
|
||||
"call_id": call_id,
|
||||
"tool_name": self.name,
|
||||
}
|
||||
)
|
||||
|
||||
return result
|
||||
return formatted_result
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""
|
||||
获取工具执行时的友好提示消息
|
||||
|
||||
获取工具执行时的友好提示消息。
|
||||
|
||||
子类可以重写此方法,根据实际参数生成个性化的提示消息。
|
||||
如果返回 None 或空字符串,将回退使用 explanation 参数。
|
||||
|
||||
|
||||
Args:
|
||||
**kwargs: 工具的所有参数(包括 explanation)
|
||||
|
||||
|
||||
Returns:
|
||||
str: 友好的提示消息,如果返回 None 或空字符串则使用 explanation
|
||||
"""
|
||||
@@ -128,6 +104,7 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
|
||||
@abstractmethod
|
||||
async def run(self, **kwargs) -> str:
|
||||
"""子类实现具体的工具执行逻辑"""
|
||||
raise NotImplementedError
|
||||
|
||||
def set_message_attr(self, channel: str, source: str, username: str):
|
||||
@@ -138,11 +115,11 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
self._source = source
|
||||
self._username = username
|
||||
|
||||
def set_callback_handler(self, callback_handler: StreamingCallbackHandler):
|
||||
def set_stream_handler(self, stream_handler: StreamingHandler):
|
||||
"""
|
||||
设置回调处理器
|
||||
"""
|
||||
self._callback_handler = callback_handler
|
||||
self._stream_handler = stream_handler
|
||||
|
||||
async def send_tool_message(self, message: str, title: str = ""):
|
||||
"""
|
||||
@@ -155,6 +132,6 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
userid=self._user_id,
|
||||
username=self._username,
|
||||
title=title,
|
||||
text=message
|
||||
text=message,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -54,7 +54,7 @@ class MoviePilotToolFactory:
|
||||
@staticmethod
|
||||
def create_tools(session_id: str, user_id: str,
|
||||
channel: str = None, source: str = None, username: str = None,
|
||||
callback_handler: Callable = None) -> List[MoviePilotTool]:
|
||||
stream_handler: Callable = None) -> List[MoviePilotTool]:
|
||||
"""
|
||||
创建MoviePilot工具列表
|
||||
"""
|
||||
@@ -109,7 +109,7 @@ class MoviePilotToolFactory:
|
||||
user_id=user_id
|
||||
)
|
||||
tool.set_message_attr(channel=channel, source=source, username=username)
|
||||
tool.set_callback_handler(callback_handler=callback_handler)
|
||||
tool.set_stream_handler(stream_handler=stream_handler)
|
||||
tools.append(tool)
|
||||
|
||||
# 加载插件提供的工具
|
||||
@@ -131,7 +131,7 @@ class MoviePilotToolFactory:
|
||||
user_id=user_id
|
||||
)
|
||||
tool.set_message_attr(channel=channel, source=source, username=username)
|
||||
tool.set_callback_handler(callback_handler=callback_handler)
|
||||
tool.set_stream_handler(stream_handler=stream_handler)
|
||||
tools.append(tool)
|
||||
plugin_tools_count += 1
|
||||
logger.debug(f"成功加载插件 {plugin_name}({plugin_id}) 的工具: {ToolClass.__name__}")
|
||||
|
||||
@@ -25,7 +25,7 @@ class MoviePilotToolsManager:
|
||||
def __init__(self, user_id: str = "api_user", session_id: str = uuid.uuid4()):
|
||||
"""
|
||||
初始化工具管理器
|
||||
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
session_id: 会话ID
|
||||
@@ -47,7 +47,7 @@ class MoviePilotToolsManager:
|
||||
channel=None,
|
||||
source="api",
|
||||
username="API Client",
|
||||
callback_handler=None,
|
||||
stream_handler=None,
|
||||
)
|
||||
logger.info(f"成功加载 {len(self.tools)} 个工具")
|
||||
except Exception as e:
|
||||
@@ -57,40 +57,38 @@ class MoviePilotToolsManager:
|
||||
def list_tools(self) -> List[ToolDefinition]:
|
||||
"""
|
||||
列出所有可用的工具
|
||||
|
||||
|
||||
Returns:
|
||||
工具定义列表
|
||||
"""
|
||||
tools_list = []
|
||||
for tool in self.tools:
|
||||
# 获取工具的输入参数模型
|
||||
args_schema = getattr(tool, 'args_schema', None)
|
||||
args_schema = getattr(tool, "args_schema", None)
|
||||
if args_schema:
|
||||
# 将Pydantic模型转换为JSON Schema
|
||||
input_schema = self._convert_to_json_schema(args_schema)
|
||||
else:
|
||||
# 如果没有args_schema,使用基本信息
|
||||
input_schema = {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": []
|
||||
}
|
||||
input_schema = {"type": "object", "properties": {}, "required": []}
|
||||
|
||||
tools_list.append(ToolDefinition(
|
||||
name=tool.name,
|
||||
description=tool.description or "",
|
||||
input_schema=input_schema
|
||||
))
|
||||
tools_list.append(
|
||||
ToolDefinition(
|
||||
name=tool.name,
|
||||
description=tool.description or "",
|
||||
input_schema=input_schema,
|
||||
)
|
||||
)
|
||||
|
||||
return tools_list
|
||||
|
||||
def get_tool(self, tool_name: str) -> Optional[Any]:
|
||||
"""
|
||||
获取指定工具实例
|
||||
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称
|
||||
|
||||
|
||||
Returns:
|
||||
工具实例,如果未找到返回None
|
||||
"""
|
||||
@@ -159,23 +157,26 @@ class MoviePilotToolsManager:
|
||||
return []
|
||||
return [
|
||||
MoviePilotToolsManager._normalize_scalar_value(item_type, item.strip(), key)
|
||||
for item in trimmed.split(",") if item.strip()
|
||||
for item in trimmed.split(",")
|
||||
if item.strip()
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _normalize_arguments(tool_instance: Any, arguments: Dict[str, Any]) -> Dict[str, Any]:
|
||||
def _normalize_arguments(
|
||||
tool_instance: Any, arguments: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
根据工具的参数schema规范化参数类型
|
||||
|
||||
|
||||
Args:
|
||||
tool_instance: 工具实例
|
||||
arguments: 原始参数
|
||||
|
||||
|
||||
Returns:
|
||||
规范化后的参数
|
||||
"""
|
||||
# 获取工具的参数schema
|
||||
args_schema = getattr(tool_instance, 'args_schema', None)
|
||||
args_schema = getattr(tool_instance, "args_schema", None)
|
||||
if not args_schema:
|
||||
return arguments
|
||||
|
||||
@@ -201,31 +202,35 @@ class MoviePilotToolsManager:
|
||||
# 数组类型:将字符串解析为列表
|
||||
if field_type == "array" and isinstance(value, str):
|
||||
item_type = field_info.get("items", {}).get("type", "string")
|
||||
normalized[key] = MoviePilotToolsManager._parse_array_string(value, key, item_type)
|
||||
normalized[key] = MoviePilotToolsManager._parse_array_string(
|
||||
value, key, item_type
|
||||
)
|
||||
continue
|
||||
|
||||
# 根据类型进行转换
|
||||
normalized[key] = MoviePilotToolsManager._normalize_scalar_value(field_type, value, key)
|
||||
normalized[key] = MoviePilotToolsManager._normalize_scalar_value(
|
||||
field_type, value, key
|
||||
)
|
||||
|
||||
return normalized
|
||||
|
||||
async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> str:
|
||||
"""
|
||||
调用工具
|
||||
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称
|
||||
arguments: 工具参数
|
||||
|
||||
|
||||
Returns:
|
||||
工具执行结果(字符串)
|
||||
"""
|
||||
tool_instance = self.get_tool(tool_name)
|
||||
|
||||
if not tool_instance:
|
||||
error_msg = json.dumps({
|
||||
"error": f"工具 '{tool_name}' 未找到"
|
||||
}, ensure_ascii=False)
|
||||
error_msg = json.dumps(
|
||||
{"error": f"工具 '{tool_name}' 未找到"}, ensure_ascii=False
|
||||
)
|
||||
return error_msg
|
||||
|
||||
try:
|
||||
@@ -238,7 +243,7 @@ class MoviePilotToolsManager:
|
||||
# 确保返回字符串
|
||||
if isinstance(result, str):
|
||||
formated_result = result
|
||||
elif isinstance(result, int, float):
|
||||
elif isinstance(result, (int, float)):
|
||||
formated_result = str(result)
|
||||
else:
|
||||
try:
|
||||
@@ -250,19 +255,20 @@ class MoviePilotToolsManager:
|
||||
return formated_result
|
||||
except Exception as e:
|
||||
logger.error(f"调用工具 {tool_name} 时发生错误: {e}", exc_info=True)
|
||||
error_msg = json.dumps({
|
||||
"error": f"调用工具 '{tool_name}' 时发生错误: {str(e)}"
|
||||
}, ensure_ascii=False)
|
||||
error_msg = json.dumps(
|
||||
{"error": f"调用工具 '{tool_name}' 时发生错误: {str(e)}"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
return error_msg
|
||||
|
||||
@staticmethod
|
||||
def _convert_to_json_schema(args_schema: Any) -> Dict[str, Any]:
|
||||
"""
|
||||
将Pydantic模型转换为JSON Schema
|
||||
|
||||
|
||||
Args:
|
||||
args_schema: Pydantic模型类
|
||||
|
||||
|
||||
Returns:
|
||||
JSON Schema字典
|
||||
"""
|
||||
@@ -275,7 +281,9 @@ class MoviePilotToolsManager:
|
||||
|
||||
if "properties" in schema:
|
||||
for field_name, field_info in schema["properties"].items():
|
||||
resolved_field_info = MoviePilotToolsManager._resolve_field_schema(field_info)
|
||||
resolved_field_info = MoviePilotToolsManager._resolve_field_schema(
|
||||
field_info
|
||||
)
|
||||
# 转换字段类型
|
||||
field_type = resolved_field_info.get("type", "string")
|
||||
field_description = resolved_field_info.get("description", "")
|
||||
@@ -286,14 +294,14 @@ class MoviePilotToolsManager:
|
||||
default_value = resolved_field_info.get("default")
|
||||
properties[field_name] = {
|
||||
"type": field_type,
|
||||
"description": field_description
|
||||
"description": field_description,
|
||||
}
|
||||
if default_value is not None:
|
||||
properties[field_name]["default"] = default_value
|
||||
else:
|
||||
properties[field_name] = {
|
||||
"type": field_type,
|
||||
"description": field_description
|
||||
"description": field_description,
|
||||
}
|
||||
required.append(field_name)
|
||||
|
||||
@@ -305,11 +313,7 @@ class MoviePilotToolsManager:
|
||||
if field_type == "array" and "items" in resolved_field_info:
|
||||
properties[field_name]["items"] = resolved_field_info["items"]
|
||||
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
"required": required
|
||||
}
|
||||
return {"type": "object", "properties": properties, "required": required}
|
||||
|
||||
|
||||
moviepilot_tool_manager = MoviePilotToolsManager()
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""LLM模型相关辅助功能"""
|
||||
from typing import List, Optional
|
||||
|
||||
from typing import List
|
||||
|
||||
from app.core.config import settings
|
||||
from app.log import logger
|
||||
@@ -9,11 +10,10 @@ class LLMHelper:
|
||||
"""LLM模型相关辅助功能"""
|
||||
|
||||
@staticmethod
|
||||
def get_llm(streaming: bool = False, callbacks: Optional[list] = None):
|
||||
def get_llm(streaming: bool = False):
|
||||
"""
|
||||
获取LLM实例
|
||||
:param streaming: 是否启用流式输出
|
||||
:param callbacks: 回调处理器列表
|
||||
:return: LLM实例
|
||||
"""
|
||||
provider = settings.LLM_PROVIDER.lower()
|
||||
@@ -24,7 +24,9 @@ class LLMHelper:
|
||||
|
||||
if provider == "google":
|
||||
if settings.PROXY_HOST:
|
||||
# 通过代理使用 Google 的 OpenAI 兼容接口
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
return ChatOpenAI(
|
||||
model=settings.LLM_MODEL,
|
||||
api_key=api_key,
|
||||
@@ -32,33 +34,34 @@ class LLMHelper:
|
||||
base_url="https://generativelanguage.googleapis.com/v1beta/openai",
|
||||
temperature=settings.LLM_TEMPERATURE,
|
||||
streaming=streaming,
|
||||
callbacks=callbacks,
|
||||
stream_usage=True,
|
||||
openai_proxy=settings.PROXY_HOST
|
||||
openai_proxy=settings.PROXY_HOST,
|
||||
)
|
||||
else:
|
||||
# 使用 langchain-google-genai 原生接口(v4 API 变更:google_api_key → api_key,max_retries → retries)
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
|
||||
return ChatGoogleGenerativeAI(
|
||||
model=settings.LLM_MODEL,
|
||||
google_api_key=api_key,
|
||||
max_retries=3,
|
||||
api_key=api_key,
|
||||
retries=3,
|
||||
temperature=settings.LLM_TEMPERATURE,
|
||||
streaming=streaming,
|
||||
callbacks=callbacks
|
||||
streaming=streaming
|
||||
)
|
||||
elif provider == "deepseek":
|
||||
from langchain_deepseek import ChatDeepSeek
|
||||
|
||||
return ChatDeepSeek(
|
||||
model=settings.LLM_MODEL,
|
||||
api_key=api_key,
|
||||
max_retries=3,
|
||||
temperature=settings.LLM_TEMPERATURE,
|
||||
streaming=streaming,
|
||||
callbacks=callbacks,
|
||||
stream_usage=True
|
||||
stream_usage=True,
|
||||
)
|
||||
else:
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
return ChatOpenAI(
|
||||
model=settings.LLM_MODEL,
|
||||
api_key=api_key,
|
||||
@@ -66,12 +69,13 @@ class LLMHelper:
|
||||
base_url=settings.LLM_BASE_URL,
|
||||
temperature=settings.LLM_TEMPERATURE,
|
||||
streaming=streaming,
|
||||
callbacks=callbacks,
|
||||
stream_usage=True,
|
||||
openai_proxy=settings.PROXY_HOST
|
||||
openai_proxy=settings.PROXY_HOST,
|
||||
)
|
||||
|
||||
def get_models(self, provider: str, api_key: str, base_url: str = None) -> List[str]:
|
||||
def get_models(
|
||||
self, provider: str, api_key: str, base_url: str = None
|
||||
) -> List[str]:
|
||||
"""获取模型列表"""
|
||||
logger.info(f"获取 {provider} 模型列表...")
|
||||
if provider == "google":
|
||||
@@ -81,18 +85,25 @@ class LLMHelper:
|
||||
|
||||
@staticmethod
|
||||
def _get_google_models(api_key: str) -> List[str]:
|
||||
"""获取Google模型列表"""
|
||||
"""获取Google模型列表(使用 google-genai SDK v1)"""
|
||||
try:
|
||||
import google.generativeai as genai
|
||||
genai.configure(api_key=api_key)
|
||||
models = genai.list_models()
|
||||
return [m.name for m in models if 'generateContent' in m.supported_generation_methods]
|
||||
from google import genai
|
||||
|
||||
client = genai.Client(api_key=api_key)
|
||||
models = client.models.list()
|
||||
return [
|
||||
m.name
|
||||
for m in models
|
||||
if m.supported_actions and "generateContent" in m.supported_actions
|
||||
]
|
||||
except Exception as e:
|
||||
logger.error(f"获取Google模型列表失败:{e}")
|
||||
raise e
|
||||
|
||||
@staticmethod
|
||||
def _get_openai_compatible_models(provider: str, api_key: str, base_url: str = None) -> List[str]:
|
||||
def _get_openai_compatible_models(
|
||||
provider: str, api_key: str, base_url: str = None
|
||||
) -> List[str]:
|
||||
"""获取OpenAI兼容模型列表"""
|
||||
try:
|
||||
from openai import OpenAI
|
||||
|
||||
@@ -1,38 +1,37 @@
|
||||
"""AI智能体相关数据模型"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional, Any
|
||||
from typing import List, Optional
|
||||
|
||||
from langchain_core.messages import BaseMessage
|
||||
from pydantic import BaseModel, Field, ConfigDict, field_serializer
|
||||
|
||||
|
||||
class ConversationMemory(BaseModel):
|
||||
"""对话记忆模型"""
|
||||
|
||||
|
||||
session_id: str = Field(description="会话ID")
|
||||
user_id: Optional[str] = Field(default=None, description="用户ID")
|
||||
title: Optional[str] = Field(default=None, description="会话标题")
|
||||
messages: List[Dict[str, Any]] = Field(default_factory=list, description="消息列表")
|
||||
context: Dict[str, Any] = Field(default_factory=dict, description="会话上下文")
|
||||
created_at: datetime = Field(default_factory=datetime.now, description="创建时间")
|
||||
messages: List[BaseMessage] = Field(default_factory=list, description="消息列表")
|
||||
updated_at: datetime = Field(default_factory=datetime.now, description="更新时间")
|
||||
|
||||
|
||||
model_config = ConfigDict()
|
||||
|
||||
@field_serializer('created_at', 'updated_at', when_used='json')
|
||||
|
||||
@field_serializer('updated_at', when_used='json')
|
||||
def serialize_datetime(self, value: datetime) -> str:
|
||||
return value.isoformat()
|
||||
|
||||
|
||||
class AgentState(BaseModel):
|
||||
"""AI智能体状态模型"""
|
||||
|
||||
|
||||
session_id: str = Field(description="会话ID")
|
||||
current_task: Optional[str] = Field(default=None, description="当前任务")
|
||||
is_thinking: bool = Field(default=False, description="是否正在思考")
|
||||
last_activity: datetime = Field(default_factory=datetime.now, description="最后活动时间")
|
||||
|
||||
|
||||
model_config = ConfigDict()
|
||||
|
||||
|
||||
@field_serializer('last_activity', when_used='json')
|
||||
def serialize_datetime(self, value: datetime) -> str:
|
||||
return value.isoformat()
|
||||
@@ -40,7 +39,7 @@ class AgentState(BaseModel):
|
||||
|
||||
class UserMessage(BaseModel):
|
||||
"""用户消息模型"""
|
||||
|
||||
|
||||
session_id: str = Field(description="会话ID")
|
||||
content: str = Field(description="消息内容")
|
||||
user_id: Optional[str] = Field(default=None, description="用户ID")
|
||||
@@ -50,7 +49,7 @@ class UserMessage(BaseModel):
|
||||
|
||||
class ToolResult(BaseModel):
|
||||
"""工具执行结果模型"""
|
||||
|
||||
|
||||
session_id: str = Field(description="会话ID")
|
||||
call_id: str = Field(description="调用ID")
|
||||
success: bool = Field(description="是否成功")
|
||||
|
||||
@@ -82,14 +82,13 @@ pympler~=1.1
|
||||
smbprotocol~=1.15.0
|
||||
setproctitle~=1.3.6
|
||||
httpx[socks]~=0.28.1
|
||||
langchain~=0.3.27
|
||||
langchain-core~=0.3.76
|
||||
langchain-community~=0.3.29
|
||||
langchain-openai~=0.3.33
|
||||
langchain-google-genai~=2.0.10
|
||||
langchain-deepseek~=0.1.4
|
||||
langchain-experimental~=0.3.4
|
||||
openai~=1.108.2
|
||||
google-generativeai~=0.8.5
|
||||
langchain~=1.2.13
|
||||
langchain-core~=1.2.20
|
||||
langchain-community~=0.4.1
|
||||
langchain-openai~=1.1.11
|
||||
langchain-google-genai~=4.2.1
|
||||
langchain-deepseek~=1.0.1
|
||||
langchain-experimental~=0.4.1
|
||||
openai~=2.29.0
|
||||
ddgs~=9.10.0
|
||||
websocket-client~=1.8.0
|
||||
|
||||
Reference in New Issue
Block a user