diff --git a/app/agent/__init__.py b/app/agent/__init__.py
index 0ef1f196..e46b3ee4 100644
--- a/app/agent/__init__.py
+++ b/app/agent/__init__.py
@@ -1,26 +1,25 @@
import asyncio
-from typing import Dict, List, Any, Union
-import json
-import tiktoken
+from time import strftime
+from typing import Dict, List
-from langchain.agents import AgentExecutor
-from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
-from langchain_community.callbacks import get_openai_callback
-from langchain_core.chat_history import InMemoryChatMessageHistory
-from langchain_core.messages import HumanMessage, AIMessage, ToolCall, ToolMessage, SystemMessage, trim_messages
-from langchain_core.runnables import RunnablePassthrough, RunnableLambda
-from langchain_core.runnables.history import RunnableWithMessageHistory
-from langchain.agents.format_scratchpad.openai_tools import format_to_openai_tool_messages
-from langchain.agents.output_parsers.openai_tools import OpenAIToolsAgentOutputParser
+from langchain.agents import create_agent
+from langchain.agents.middleware import (
+ SummarizationMiddleware,
+ ModelRetryMiddleware,
+ ToolRetryMiddleware,
+)
+from langchain_core.messages import (
+ HumanMessage,
+ BaseMessage,
+)
+from langgraph.checkpoint.memory import InMemorySaver
-from app.agent.callback import StreamingCallbackHandler
-from app.agent.memory import conversation_manager
+from app.agent.callback import StreamingHandler
+from app.agent.memory import memory_manager
from app.agent.prompt import prompt_manager
from app.agent.tools.factory import MoviePilotToolFactory
from app.chain import ChainBase
-from app.core.config import settings
from app.helper.llm import LLMHelper
-from app.helper.message import MessageHelper
from app.log import logger
from app.schemas import Notification
@@ -31,42 +30,32 @@ class AgentChain(ChainBase):
class MoviePilotAgent:
"""
- MoviePilot AI智能体
+ MoviePilot AI智能体(基于 LangChain v1 + LangGraph)
"""
- def __init__(self, session_id: str, user_id: str = None,
- channel: str = None, source: str = None, username: str = None):
+ def __init__(
+ self,
+ session_id: str,
+ user_id: str = None,
+ channel: str = None,
+ source: str = None,
+ username: str = None,
+ ):
self.session_id = session_id
self.user_id = user_id
- self.channel = channel # 消息渠道
- self.source = source # 消息来源
- self.username = username # 用户名
+ self.channel = channel
+ self.source = source
+ self.username = username
- # 消息助手
- self.message_helper = MessageHelper()
+ # 流式token管理
+ self.stream_handler = StreamingHandler()
- # 回调处理器
- self.callback_handler = StreamingCallbackHandler(
- session_id=session_id
- )
-
- # LLM模型
- self.llm = self._initialize_llm()
-
- # 工具
- self.tools = self._initialize_tools()
-
- # 提示词模板
- self.prompt = self._initialize_prompt()
-
- # Agent执行器
- self.agent_executor = self._create_agent_executor()
-
- def _initialize_llm(self):
+ @staticmethod
+ def _initialize_llm():
"""
- 初始化LLM模型
+        初始化 LLM(流式输出)
"""
- return LLMHelper.get_llm(streaming=True, callbacks=[self.callback_handler])
+ return LLMHelper.get_llm(streaming=True)
def _initialize_tools(self) -> List:
"""
@@ -78,384 +67,118 @@ class MoviePilotAgent:
channel=self.channel,
source=self.source,
username=self.username,
- callback_handler=self.callback_handler
+ stream_handler=self.stream_handler,
)
- @staticmethod
- def _initialize_session_store() -> Dict[str, InMemoryChatMessageHistory]:
+ def _create_agent(self):
"""
- 初始化内存存储
- """
- return {}
-
- def get_session_history(self, session_id: str) -> InMemoryChatMessageHistory:
- """
- 获取会话历史
- """
- chat_history = InMemoryChatMessageHistory()
- messages: List[dict] = conversation_manager.get_recent_messages_for_agent(
- session_id=session_id,
- user_id=self.user_id
- )
- if messages:
- for msg in messages:
- if msg.get("role") == "user":
- chat_history.add_message(HumanMessage(content=msg.get("content", "")))
- elif msg.get("role") == "agent":
- chat_history.add_message(AIMessage(content=msg.get("content", "")))
- elif msg.get("role") == "tool_call":
- metadata = msg.get("metadata", {})
- chat_history.add_message(
- AIMessage(
- content=msg.get("content", ""),
- tool_calls=[
- ToolCall(
- id=metadata.get("call_id"),
- name=metadata.get("tool_name"),
- args=metadata.get("parameters"),
- )
- ]
- )
- )
- elif msg.get("role") == "tool_result":
- metadata = msg.get("metadata", {})
- chat_history.add_message(ToolMessage(
- content=msg.get("content", ""),
- tool_call_id=metadata.get("call_id", "unknown")
- ))
- elif msg.get("role") == "system":
- chat_history.add_message(SystemMessage(content=msg.get("content", "")))
-
- return chat_history
-
- @staticmethod
- def _initialize_prompt() -> ChatPromptTemplate:
- """
- 初始化提示词模板
+ 创建 LangGraph Agent(使用 create_agent + SummarizationMiddleware)
"""
try:
- prompt_template = ChatPromptTemplate.from_messages([
- ("system", "{system_prompt}"),
- MessagesPlaceholder(variable_name="chat_history"),
- ("user", "{input}"),
- MessagesPlaceholder(variable_name="agent_scratchpad"),
- ])
- logger.info("LangChain提示词模板初始化成功")
- return prompt_template
+ # 系统提示词
+ system_prompt = prompt_manager.get_agent_prompt(
+ channel=self.channel
+ ).format(
+ current_date=strftime('%Y-%m-%d')
+ )
+
+ # LLM 模型(用于 agent 执行)
+ llm = self._initialize_llm()
+
+ # 工具列表
+ tools = self._initialize_tools()
+
+ # 中间件
+ middlewares = [
+ # 上下文压缩
+ SummarizationMiddleware(
+ model=llm,
+ trigger=("fraction", 0.85)
+ ),
+ # 模型调用失败时自动重试
+ ModelRetryMiddleware(max_retries=3),
+ # 工具调用失败时自动重试
+ ToolRetryMiddleware(max_retries=1)
+ ]
+
+ return create_agent(
+ model=llm,
+ tools=tools,
+ system_prompt=system_prompt,
+ middleware=middlewares,
+ checkpointer=InMemorySaver(),
+ )
except Exception as e:
- logger.error(f"初始化提示词失败: {e}")
+ logger.error(f"创建 Agent 失败: {e}")
raise e
- @staticmethod
- def _token_counter(messages: List[Union[HumanMessage, AIMessage, ToolMessage, SystemMessage]]) -> int:
+ async def process(self, message: str) -> str:
"""
- 通用的Token计数器
+ 处理用户消息,流式推理并返回 Agent 回复
"""
try:
- # 尝试从模型获取编码集,如果失败则回退到 cl100k_base (大多数现代模型使用的编码)
- try:
- encoding = tiktoken.encoding_for_model(settings.LLM_MODEL)
- except KeyError:
- encoding = tiktoken.get_encoding("cl100k_base")
+ logger.info(f"Agent推理: session_id={self.session_id}, input={message}")
- num_tokens = 0
- for message in messages:
- # 基础开销 (每个消息大约 3 个 token)
- num_tokens += 3
-
- # 1. 处理文本内容 (content)
- if isinstance(message.content, str):
- num_tokens += len(encoding.encode(message.content))
- elif isinstance(message.content, list):
- for part in message.content:
- if isinstance(part, dict) and part.get("type") == "text":
- num_tokens += len(encoding.encode(part.get("text", "")))
-
- # 2. 处理工具调用 (仅 AIMessage 包含 tool_calls)
- if getattr(message, "tool_calls", None):
- for tool_call in message.tool_calls:
- # 函数名
- num_tokens += len(encoding.encode(tool_call.get("name", "")))
- # 参数 (转为 JSON 估算)
- args_str = json.dumps(tool_call.get("args", {}), ensure_ascii=False)
- num_tokens += len(encoding.encode(args_str))
- # 额外的结构开销 (ID 等)
- num_tokens += 3
-
- # 3. 处理角色权重
- num_tokens += 1
-
- # 加上回复的起始 Token (大约 3 个 token)
- num_tokens += 3
- return num_tokens
- except Exception as e:
- logger.error(f"Token计数失败: {e}")
- # 发生错误时返回一个保守的估算值
- return len(str(messages)) // 4
-
- def _create_agent_executor(self) -> RunnableWithMessageHistory:
- """
- 创建Agent执行器
- """
- try:
- # 消息裁剪器,防止上下文超出限制
- base_trimmer = trim_messages(
- max_tokens=settings.LLM_MAX_CONTEXT_TOKENS * 1000 * 0.8,
- strategy="last",
- token_counter=self._token_counter,
- include_system=True,
- allow_partial=False,
- start_on="human",
- )
-
- # 包装trimmer,在裁剪后验证工具调用的完整性
- def validated_trimmer(messages):
- # 如果输入是 PromptValue,转换为消息列表
- if hasattr(messages, "to_messages"):
- messages = messages.to_messages()
- trimmed = base_trimmer.invoke(messages)
-
- # 二次校验:确保不出现 broken tool chains
- # 1. AIMessage with tool_calls 必须紧跟着对应的 ToolMessage
- # 2. ToolMessage 必须有对应的 AIMessage 前置
- safe_messages = []
- i = 0
- while i < len(trimmed):
- msg = trimmed[i]
-
- if isinstance(msg, AIMessage) and getattr(msg, "tool_calls", None):
- # 检查工具调用序列是否完整
- tool_calls = msg.tool_calls
- is_valid_sequence = True
- tool_results = []
-
- # 向后查找对应的 ToolMessage
- temp_i = i + 1
- for tool_call in tool_calls:
- if temp_i >= len(trimmed):
- is_valid_sequence = False
- break
-
- next_msg = trimmed[temp_i]
- if isinstance(next_msg, ToolMessage) and next_msg.tool_call_id == tool_call.get("id"):
- tool_results.append(next_msg)
- temp_i += 1
- else:
- is_valid_sequence = False
- break
-
- if is_valid_sequence:
- # 序列完整,保留消息
- safe_messages.append(msg)
- safe_messages.extend(tool_results)
- i = temp_i # 跳过已处理的工具结果
- else:
- # 序列不完整,丢弃该 AIMessage(后续的孤立 ToolMessage 会在下一次循环被当做 orphaned 处理掉)
- logger.warning(f"移除无效的工具调用链: {len(tool_calls)} calls, incomplete results")
- i += 1
- continue
-
- if isinstance(msg, ToolMessage):
- # 如果在这里遇到 ToolMessage,说明它没有被上面的逻辑消费,则是孤立的(或者顺序错乱)
- logger.warning("移除孤立的 ToolMessage")
- i += 1
- continue
-
- # 其他类型的消息直接保留
- safe_messages.append(msg)
- i += 1
-
- if len(safe_messages) < len(messages):
- logger.info(f"LangChain消息上下文已裁剪: {len(messages)} -> {len(safe_messages)}")
- return safe_messages
-
- # 创建Agent执行链
- agent = (
- RunnablePassthrough.assign(
- agent_scratchpad=lambda x: format_to_openai_tool_messages(
- x["intermediate_steps"]
- )
- )
- | self.prompt
- | RunnableLambda(validated_trimmer)
- | self.llm.bind_tools(self.tools)
- | OpenAIToolsAgentOutputParser()
- )
- executor = AgentExecutor(
- agent=agent,
- tools=self.tools,
- verbose=settings.LLM_VERBOSE,
- max_iterations=settings.LLM_MAX_ITERATIONS,
- return_intermediate_steps=True,
- handle_parsing_errors=True,
- early_stopping_method="force"
- )
- return RunnableWithMessageHistory(
- executor,
- self.get_session_history,
- input_messages_key="input",
- history_messages_key="chat_history"
- )
- except Exception as e:
- logger.error(f"创建Agent执行器失败: {e}")
- raise e
-
- async def _summarize_history(self):
- """
- 总结提炼之前的对话和工具执行情况,并把会话总结变成新的系统提示词取代之前的对话
- """
- try:
- # 获取当前历史记录
- chat_history = self.get_session_history(self.session_id)
- messages = chat_history.messages
- if not messages:
- return
-
- logger.info(f"会话 {self.session_id} 历史消息长度已超过 90%,开始总结并重置上下文...")
-
- # 将消息转换为摘要所需的文本格式
- history_text = ""
- for msg in messages:
- if isinstance(msg, HumanMessage):
- history_text += f"用户: {msg.content}\n"
- elif isinstance(msg, AIMessage):
- history_text += f"智能体: {msg.content}\n"
- if getattr(msg, "tool_calls", None):
- for tool_call in msg.tool_calls:
- history_text += f"智能体调用工具: {tool_call.get('name')},参数: {tool_call.get('args')}\n"
- elif isinstance(msg, ToolMessage):
- history_text += f"工具响应: {msg.content}\n"
- elif isinstance(msg, SystemMessage):
- history_text += f"系统: {msg.content}\n"
-
- # 摘要提示词
- summary_prompt = (
- "Please provide a comprehensive and highly informational summary of the preceding conversation and tool executions. "
- "Your goal is to condense the history while retaining all critical details for future reference. "
- "Ensure you include:\n"
- "1. User's core intents, specific requests, and any mentioned preferences.\n"
- "2. Names of movies, TV shows, or other key entities discussed.\n"
- "3. A concise log of tool calls made and their specific results/outcomes.\n"
- "4. The current status of any tasks and any pending actions.\n"
- "5. Any important context that would be necessary for the agent to continue the conversation seamlessly.\n"
- "The summary should be dense with information and serve as the primary context for the next stage of the interaction."
- )
-
- # 调用 LLM 进行总结 (非流式)
- summary_llm = LLMHelper.get_llm(streaming=False)
- response = await summary_llm.ainvoke([
- SystemMessage(content=summary_prompt),
- HumanMessage(content=f"Here is the conversation history to summarize:\n{history_text}")
- ])
- summary_content = str(response.content)
-
- if not summary_content:
- logger.warning("总结生成失败,跳过重置逻辑。")
- return
-
- # 清空原有的会话记录并插入新的系统总结
- await conversation_manager.clear_memory(self.session_id, self.user_id)
- await conversation_manager.add_conversation(
+ # 获取历史消息
+ messages = memory_manager.get_agent_messages(
session_id=self.session_id,
- user_id=self.user_id,
- role="system",
- content=f"\n{summary_content}\n"
- )
- logger.info(f"会话 {self.session_id} 历史摘要替换完成。")
- except Exception as e:
- logger.error(f"执行会话总结出错: {str(e)}")
-
- async def process_message(self, message: str) -> str:
- """
- 处理用户消息
- """
- try:
- # 检查上下文长度是否超过 90%
- history = self.get_session_history(self.session_id)
- if self._token_counter(history.messages) > settings.LLM_MAX_CONTEXT_TOKENS * 1000 * 0.9:
- await self._summarize_history()
-
- # 添加用户消息到记忆
- await conversation_manager.add_conversation(
- self.session_id,
- user_id=self.user_id,
- role="user",
- content=message
+ user_id=self.user_id
)
- # 构建输入上下文
- input_context = {
- "system_prompt": prompt_manager.get_agent_prompt(channel=self.channel),
- "input": message
- }
+ # 增加用户消息
+ messages.append(HumanMessage(content=message))
- # 执行Agent
- logger.info(f"Agent执行推理: session_id={self.session_id}, input={message}")
-
- result = await self._execute_agent(input_context)
-
- # 获取Agent回复
- agent_message = await self.callback_handler.get_message()
-
- # 发送Agent回复给用户(通过原渠道)
- if agent_message:
- # 发送回复
- await self.send_agent_message(agent_message)
-
- # 添加Agent回复到记忆
- await conversation_manager.add_conversation(
- session_id=self.session_id,
- user_id=self.user_id,
- role="agent",
- content=agent_message
- )
- else:
- agent_message = result.get("output") or "很抱歉,智能体出错了,未能生成回复内容。"
- await self.send_agent_message(agent_message)
-
- return agent_message
+ # 执行推理
+            return await self._execute_agent(messages)
except Exception as e:
error_message = f"处理消息时发生错误: {str(e)}"
logger.error(error_message)
- # 发送错误消息给用户(通过原渠道)
await self.send_agent_message(error_message)
return error_message
- async def _execute_agent(self, input_context: Dict[str, Any]) -> Dict[str, Any]:
+ async def _execute_agent(self, messages: List[BaseMessage]):
"""
- 执行LangChain Agent
+        调用 LangGraph Agent,通过 astream(stream_mode="messages")流式获取 token,
+        结束后发送最终回复并保存会话消息。
"""
try:
- with get_openai_callback() as cb:
- result = await self.agent_executor.ainvoke(
- input_context,
- config={"configurable": {"session_id": self.session_id}},
- callbacks=[self.callback_handler]
- )
- logger.info(f"LLM调用消耗: \n{cb}")
+ # Agent运行配置
+            agent_config = {
+                "configurable": {"thread_id": self.session_id}
+            }
+
+ # 创建智能体
+ agent = self._create_agent()
+
+ # 流式运行智能体
+ async for chunk in agent.astream(
+ {"messages": messages},
+ stream_mode="messages",
+ config=agent_config
+ ):
+ token, metadata = chunk
+ if token:
+ self.stream_handler.emit(token.content)
+
+            # 发送最终消息给用户
+            final_message = self.stream_handler.take()
+            await self.send_agent_message(final_message)
+
+            # 保存消息(StateSnapshot.values 是状态字典,按键读取消息列表)
+            memory_manager.save_agent_messages(
+                session_id=self.session_id,
+                user_id=self.user_id,
+                messages=agent.get_state(agent_config).values["messages"]
+            )
+            return final_message
- if cb.total_tokens > 0:
- result["token_usage"] = {
- "prompt_tokens": cb.prompt_tokens,
- "completion_tokens": cb.completion_tokens,
- "total_tokens": cb.total_tokens
- }
- return result
except asyncio.CancelledError:
logger.info(f"Agent执行被取消: session_id={self.session_id}")
- return {
- "output": "任务已取消",
- "intermediate_steps": [],
- "token_usage": {}
- }
+            return "任务已取消"
except Exception as e:
logger.error(f"Agent执行失败: {e}")
- return {
- "output": str(e),
- "intermediate_steps": [],
- "token_usage": {}
- }
+            return str(e)
async def send_agent_message(self, message: str, title: str = "MoviePilot助手"):
"""
@@ -468,7 +191,7 @@ class MoviePilotAgent:
userid=self.user_id,
username=self.username,
title=title,
- text=message
+ text=message,
)
)
@@ -492,38 +215,44 @@ class AgentManager:
"""
初始化管理器
"""
- await conversation_manager.initialize()
+ await memory_manager.initialize()
async def close(self):
"""
关闭管理器
"""
- await conversation_manager.close()
- # 清理所有活跃的智能体
+ await memory_manager.close()
for agent in self.active_agents.values():
await agent.cleanup()
self.active_agents.clear()
- async def process_message(self, session_id: str, user_id: str, message: str,
- channel: str = None, source: str = None, username: str = None) -> str:
+ async def process_message(
+ self,
+ session_id: str,
+ user_id: str,
+ message: str,
+ channel: str = None,
+ source: str = None,
+ username: str = None,
+ ) -> str:
"""
处理用户消息
"""
- # 获取或创建Agent实例
if session_id not in self.active_agents:
- logger.info(f"创建新的AI智能体实例,session_id: {session_id}, user_id: {user_id}")
+ logger.info(
+ f"创建新的AI智能体实例,session_id: {session_id}, user_id: {user_id}"
+ )
agent = MoviePilotAgent(
session_id=session_id,
user_id=user_id,
channel=channel,
source=source,
- username=username
+ username=username,
)
self.active_agents[session_id] = agent
else:
agent = self.active_agents[session_id]
- agent.user_id = user_id # 确保user_id是最新的
- # 更新渠道信息
+ agent.user_id = user_id
if channel:
agent.channel = channel
if source:
@@ -531,8 +260,7 @@ class AgentManager:
if username:
agent.username = username
- # 处理消息
- return await agent.process_message(message)
+ return await agent.process(message)
async def clear_session(self, session_id: str, user_id: str):
"""
@@ -542,7 +270,7 @@ class AgentManager:
agent = self.active_agents[session_id]
await agent.cleanup()
del self.active_agents[session_id]
- await conversation_manager.clear_memory(session_id, user_id)
+ await memory_manager.clear_memory(session_id, user_id)
logger.info(f"会话 {session_id} 的记忆已清空")
diff --git a/app/agent/callback/__init__.py b/app/agent/callback/__init__.py
index 5da3957d..963834c4 100644
--- a/app/agent/callback/__init__.py
+++ b/app/agent/callback/__init__.py
@@ -1,39 +1,41 @@
import threading
-from langchain_core.callbacks import AsyncCallbackHandler
-
from app.log import logger
-class StreamingCallbackHandler(AsyncCallbackHandler):
+class StreamingHandler:
"""
- 流式输出回调处理器
+ 流式Token缓冲管理器
+
+ 负责从 LLM 流式 token 中积累文本,供 Agent 在工具调用之间穿插发送中间消息。
"""
- def __init__(self, session_id: str):
+ def __init__(self):
self._lock = threading.Lock()
- self.session_id = session_id
- self.current_message = ""
+ self._buffer = ""
- async def get_message(self):
+ def emit(self, token: str):
"""
- 获取当前消息内容,获取后清空
+ 接收 LLM 流式 token,积累到缓冲区。
"""
with self._lock:
- if not self.current_message:
+ self._buffer += token
+
+ def take(self) -> str:
+ """
+ 获取当前已积累的消息内容,获取后清空缓冲区。
+ """
+ with self._lock:
+ if not self._buffer:
return ""
- msg = self.current_message
- logger.info(f"Agent消息: {msg}")
- self.current_message = ""
- return msg
+ message = self._buffer
+ logger.info(f"Agent消息: {message}")
+ self._buffer = ""
+ return message
- async def on_llm_new_token(self, token: str, **kwargs):
+ def clear(self):
"""
- 处理新的token
+ 清空缓冲区(不返回内容)
"""
- if not token:
- return
with self._lock:
- # 缓存当前消息
- self.current_message += token
-
+ self._buffer = ""
diff --git a/app/agent/memory/__init__.py b/app/agent/memory/__init__.py
index d2957a4a..3fba4f6e 100644
--- a/app/agent/memory/__init__.py
+++ b/app/agent/memory/__init__.py
@@ -1,17 +1,17 @@
"""对话记忆管理器"""
import asyncio
-import json
-from datetime import datetime, timedelta
-from typing import Dict, List, Optional, Any
+from datetime import datetime
+from typing import Dict, List, Optional
+
+from langchain_core.messages import BaseMessage
from app.core.config import settings
-from app.helper.redis import AsyncRedisHelper
from app.log import logger
from app.schemas.agent import ConversationMemory
-class ConversationMemoryManager:
+class MemoryManager:
"""
对话记忆管理器
"""
@@ -19,18 +19,18 @@ class ConversationMemoryManager:
def __init__(self):
# 内存中的会话记忆缓存
self.memory_cache: Dict[str, ConversationMemory] = {}
- # 使用现有的Redis助手
- self.redis_helper = AsyncRedisHelper()
- # 内存缓存清理任务(Redis通过TTL自动过期)
+ # 内存缓存清理任务
self.cleanup_task: Optional[asyncio.Task] = None
- async def initialize(self):
+    async def initialize(self):
"""
初始化记忆管理器
"""
try:
# 启动内存缓存清理任务(Redis通过TTL自动过期)
- self.cleanup_task = asyncio.create_task(self._cleanup_expired_memories())
+ self.cleanup_task = asyncio.create_task(
+ self._cleanup_expired_memories()
+ )
logger.info("对话记忆管理器初始化完成")
except Exception as e:
@@ -47,8 +47,6 @@ class ConversationMemoryManager:
except asyncio.CancelledError:
pass
- await self.redis_helper.close()
-
logger.info("对话记忆管理器已关闭")
@staticmethod
@@ -58,258 +56,64 @@ class ConversationMemoryManager:
"""
return f"{user_id}:{session_id}" if user_id else session_id
- @staticmethod
- def _get_redis_key(session_id: str, user_id: str):
- """
- 计算Redis Key
- """
- return f"agent_memory:{user_id}:{session_id}" if user_id else f"agent_memory:{session_id}"
-
- def _get_memory(self, session_id: str, user_id: str):
+ def get_memory(self, session_id: str, user_id: str) -> Optional[ConversationMemory]:
"""
获取内存中的记忆
"""
cache_key = self._get_memory_key(session_id, user_id)
return self.memory_cache.get(cache_key)
-
- async def _get_redis(self, session_id: str, user_id: str) -> Optional[ConversationMemory]:
- """
- 从Redis获取记忆
- """
- if settings.CACHE_BACKEND_TYPE == "redis":
- try:
- redis_key = self._get_redis_key(session_id, user_id)
- memory_data = await self.redis_helper.get(redis_key, region="AI_AGENT")
- if memory_data:
- memory_dict = json.loads(memory_data) if isinstance(memory_data, str) else memory_data
- memory = ConversationMemory(**memory_dict)
- return memory
- except Exception as e:
- logger.warning(f"从Redis加载记忆失败: {e}")
- return None
- async def get_conversation(self, session_id: str, user_id: str) -> ConversationMemory:
- """
- 获取会话记忆
- """
- # 首先检查缓存
- conversion = self._get_memory(session_id, user_id)
- if conversion:
- return conversion
-
- # 尝试从Redis加载
- memory = await self._get_redis(session_id, user_id)
- if memory:
- # 加载到内存缓存
- self._save_memory(memory)
- return memory
-
- # 创建新的记忆
- memory = ConversationMemory(session_id=session_id, user_id=user_id)
- await self._save_conversation(memory)
-
- return memory
-
- async def set_title(self, session_id: str, user_id: str, title: str):
- """
- 设置会话标题
- """
- memory = await self.get_conversation(session_id=session_id, user_id=user_id)
- memory.title = title
- memory.updated_at = datetime.now()
- await self._save_conversation(memory)
-
- async def get_title(self, session_id: str, user_id: str) -> Optional[str]:
- """
- 获取会话标题
- """
- memory = await self.get_conversation(session_id=session_id, user_id=user_id)
- return memory.title
-
- async def list_sessions(self, user_id: str, limit: int = 100) -> List[Dict[str, Any]]:
- """
- 列出历史会话摘要(按更新时间倒序)
-
- - 当启用Redis时:遍历 `agent_memory:*` 键并读取摘要
- - 当未启用Redis时:基于内存缓存返回
- """
- sessions: List[ConversationMemory] = []
- # 从Redis遍历
- if settings.CACHE_BACKEND_TYPE == "redis":
- try:
- # 使用Redis助手的items方法遍历所有键
- async for key, value in self.redis_helper.items(region="AI_AGENT"):
- if key.startswith("agent_memory:"):
- try:
- # 解析键名获取user_id和session_id
- key_parts = key.split(":")
- if len(key_parts) >= 3:
- key_user_id = key_parts[2] if len(key_parts) > 3 else None
- if not user_id or key_user_id == user_id:
- data = value if isinstance(value, dict) else json.loads(value)
- memory = ConversationMemory(**data)
- sessions.append(memory)
- except Exception as err:
- logger.warning(f"解析Redis记忆数据失败: {err}")
- continue
- except Exception as e:
- logger.warning(f"遍历Redis会话失败: {e}")
-
- # 合并内存缓存(确保包含近期的会话)
- for cache_key, memory in self.memory_cache.items():
- # 如果指定了user_id,只返回该用户的会话
- if not user_id or memory.user_id == user_id:
- sessions.append(memory)
-
- # 去重(以 session_id 为键,取最近updated)
- uniq: Dict[str, ConversationMemory] = {}
- for mem in sessions:
- existed = uniq.get(mem.session_id)
- if (not existed) or (mem.updated_at > existed.updated_at):
- uniq[mem.session_id] = mem
-
- # 排序并裁剪
- sorted_list = sorted(uniq.values(), key=lambda m: m.updated_at, reverse=True)[:limit]
- return [
- {
- "session_id": m.session_id,
- "title": m.title or "新会话",
- "message_count": len(m.messages),
- "created_at": m.created_at.isoformat(),
- "updated_at": m.updated_at.isoformat(),
- }
- for m in sorted_list
- ]
-
- async def add_conversation(
- self,
- session_id: str,
- user_id: str,
- role: str,
- content: str,
- metadata: Optional[Dict[str, Any]] = None
- ):
- """
- 添加消息到记忆
- """
- memory = await self.get_conversation(session_id=session_id, user_id=user_id)
-
- message = {
- "role": role,
- "content": content,
- "timestamp": datetime.now().isoformat(),
- "metadata": metadata or {}
- }
-
- memory.messages.append(message)
- memory.updated_at = datetime.now()
-
- # 限制消息数量,避免记忆过大
- max_messages = settings.LLM_MAX_MEMORY_MESSAGES
- if len(memory.messages) > max_messages:
- # 保留最近的消息,但保留第一条系统消息
- system_messages = [msg for msg in memory.messages if msg["role"] == "system"]
- recent_messages = memory.messages[-(max_messages - len(system_messages)):]
- memory.messages = system_messages + recent_messages
-
- await self._save_conversation(memory)
-
- logger.debug(f"消息已添加到记忆: session_id={session_id}, user_id={user_id}, role={role}")
-
- def get_recent_messages_for_agent(
- self,
- session_id: str,
- user_id: str
- ) -> List[Dict[str, Any]]:
+ def get_agent_messages(
+ self, session_id: str, user_id: str
+ ) -> List[BaseMessage]:
"""
为Agent获取最近的消息(仅内存缓存)
如果消息Token数量超过模型最大上下文长度的阀值,会自动进行摘要裁剪
"""
- cache_key = self._get_memory_key(session_id, user_id)
- memory = self.memory_cache.get(cache_key)
+ memory = self.get_memory(session_id, user_id)
if not memory:
return []
# 获取所有消息
- return memory.messages[:-1]
+ return memory.messages
- async def get_recent_messages(
- self,
- session_id: str,
- user_id: str,
- limit: int = 10,
- role_filter: Optional[list] = None
- ) -> List[Dict[str, Any]]:
+ def save_agent_messages(
+ self, session_id: str, user_id: str, messages: List[BaseMessage]
+ ):
"""
- 获取最近的消息
+ 保存Agent消息(仅内存缓存)
+
+        注意:当前实现仅更新内存缓存,过期记忆由后台清理任务定期清理
"""
- memory = await self.get_conversation(session_id=session_id, user_id=user_id)
+ memory = self.get_memory(session_id, user_id)
+ if not memory:
+ memory = ConversationMemory(session_id=session_id, user_id=user_id)
- messages = memory.messages
- if role_filter:
- messages = [msg for msg in messages if msg["role"] in role_filter]
+ memory.messages = messages
+ memory.updated_at = datetime.now()
- return messages[-limit:] if messages else []
+ # 更新内存缓存
+ self.save_memory(memory)
- async def get_context(self, session_id: str, user_id: str) -> Dict[str, Any]:
+ def save_memory(self, memory: ConversationMemory):
"""
- 获取会话上下文
- """
- memory = await self.get_conversation(session_id=session_id, user_id=user_id)
- return memory.context
+ 保存记忆到内存缓存
- async def clear_memory(self, session_id: str, user_id: str):
- """
- 清空会话记忆
- """
- cache_key = f"{user_id}:{session_id}" if user_id else session_id
- if cache_key in self.memory_cache:
- del self.memory_cache[cache_key]
-
- if settings.CACHE_BACKEND_TYPE == "redis":
- redis_key = self._get_redis_key(session_id, user_id)
- await self.redis_helper.delete(redis_key, region="AI_AGENT")
-
- logger.info(f"会话记忆已清空: session_id={session_id}, user_id={user_id}")
-
- def _save_memory(self, memory: ConversationMemory):
- """
- 保存记忆到内存
+        注意:当前实现仅维护内存缓存,过期记忆由后台清理任务定期清理
"""
cache_key = self._get_memory_key(memory.session_id, memory.user_id)
self.memory_cache[cache_key] = memory
- async def _save_redis(self, memory: ConversationMemory):
+    async def clear_memory(self, session_id: str, user_id: str):
"""
- 保存记忆到Redis
+ 清空会话记忆
"""
- if settings.CACHE_BACKEND_TYPE == "redis":
- try:
- memory_dict = memory.model_dump()
- redis_key = self._get_redis_key(memory.session_id, memory.user_id)
- ttl = int(timedelta(days=settings.LLM_REDIS_MEMORY_RETENTION_DAYS).total_seconds())
- await self.redis_helper.set(
- redis_key,
- memory_dict,
- ttl=ttl,
- region="AI_AGENT"
- )
- except Exception as e:
- logger.warning(f"保存记忆到Redis失败: {e}")
-
- async def _save_conversation(self, memory: ConversationMemory):
- """
- 保存记忆到存储
-
- Redis中的记忆会自动通过TTL机制过期,无需手动清理
- """
- # 更新内存缓存
- self._save_memory(memory)
-
- # 保存到Redis,设置TTL自动过期
- await self._save_redis(memory)
+ cache_key = self._get_memory_key(session_id, user_id)
+ if cache_key in self.memory_cache:
+ del self.memory_cache[cache_key]
+ logger.info(f"会话记忆已清空: session_id={session_id}, user_id={user_id}")
async def _cleanup_expired_memories(self):
"""
@@ -328,7 +132,9 @@ class ConversationMemoryManager:
# 只检查内存缓存中的过期记忆
# Redis中的记忆会通过TTL自动过期,无需手动处理
for cache_key, memory in self.memory_cache.items():
- if (current_time - memory.updated_at).days > settings.LLM_MEMORY_RETENTION_DAYS:
+ if (
+ current_time - memory.updated_at
+ ).days > settings.LLM_MEMORY_RETENTION_DAYS:
expired_sessions.append(cache_key)
# 只清理内存缓存,不删除Redis中的键(Redis会自动过期)
@@ -344,4 +150,5 @@ class ConversationMemoryManager:
except Exception as e:
logger.error(f"清理记忆时发生错误: {e}")
-conversation_manager = ConversationMemoryManager()
+
+memory_manager = MemoryManager()
diff --git a/app/agent/prompt/Agent Prompt.txt b/app/agent/prompt/Agent Prompt.txt
index 0f46c8c8..65682069 100644
--- a/app/agent/prompt/Agent Prompt.txt
+++ b/app/agent/prompt/Agent Prompt.txt
@@ -69,4 +69,6 @@ At the end of your session/turn, provide a concise summary of your actions.
Specific markdown rules:
{markdown_spec}
-
\ No newline at end of file
+
+
+Today's date: {current_date}
diff --git a/app/agent/tools/base.py b/app/agent/tools/base.py
index 5d05dd9a..bb1b12d5 100644
--- a/app/agent/tools/base.py
+++ b/app/agent/tools/base.py
@@ -1,12 +1,11 @@
import json
-import uuid
from abc import ABCMeta, abstractmethod
from typing import Any, Optional
-from langchain.tools import BaseTool
+from langchain_core.tools import BaseTool
from pydantic import PrivateAttr
-from app.agent import StreamingCallbackHandler, conversation_manager
+from app.agent import StreamingHandler
from app.chain import ChainBase
from app.log import logger
from app.schemas import Notification
@@ -18,15 +17,15 @@ class ToolChain(ChainBase):
class MoviePilotTool(BaseTool, metaclass=ABCMeta):
"""
- MoviePilot专用工具基类
+ MoviePilot专用工具基类(LangChain v1 / langchain_core)
"""
_session_id: str = PrivateAttr()
_user_id: str = PrivateAttr()
- _channel: str = PrivateAttr(default=None)
- _source: str = PrivateAttr(default=None)
- _username: str = PrivateAttr(default=None)
- _callback_handler: StreamingCallbackHandler = PrivateAttr(default=None)
+ _channel: Optional[str] = PrivateAttr(default=None)
+ _source: Optional[str] = PrivateAttr(default=None)
+ _username: Optional[str] = PrivateAttr(default=None)
+ _stream_handler: Optional[StreamingHandler] = PrivateAttr(default=None)
def __init__(self, session_id: str, user_id: str, **kwargs):
super().__init__(**kwargs)
@@ -34,93 +33,70 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
self._user_id = user_id
def _run(self, *args: Any, **kwargs: Any) -> Any:
- pass
+ raise NotImplementedError("MoviePilotTool 只支持异步调用,请使用 _arun")
- async def _arun(self, **kwargs) -> str:
+ async def _arun(self, *args: Any, **kwargs: Any) -> str:
"""
- 异步运行工具
+ 异步运行工具,负责:
+ 1. 在工具调用前将流式消息推送给用户
+ 2. 持久化工具调用记录到会话记忆
+ 3. 调用具体工具逻辑(子类实现的 execute 方法)
+ 4. 持久化工具结果到会话记忆
"""
- # 获取工具调用前的agent消息
- agent_message = await self._callback_handler.get_message()
-
- # 生成唯一的工具调用ID
- call_id = f"call_{str(uuid.uuid4())[:16]}"
-
- # 记忆工具调用
- await conversation_manager.add_conversation(
- session_id=self._session_id,
- user_id=self._user_id,
- role="tool_call",
- content=agent_message,
- metadata={
- "call_id": call_id,
- "tool_name": self.name,
- "parameters": kwargs
- }
+ # 获取工具调用前 Agent 已积累的流式文本
+ agent_message = (
+ self._stream_handler.take() if self._stream_handler else ""
)
- # 获取执行工具说明,优先使用工具自定义的提示消息,如果没有则使用 explanation
+ # 获取工具执行提示消息
tool_message = self.get_tool_message(**kwargs)
if not tool_message:
explanation = kwargs.get("explanation")
if explanation:
tool_message = explanation
- # 合并agent消息和工具执行消息,一起发送
+ # 合并 Agent 消息和工具执行消息后一起发送
messages = []
if agent_message:
messages.append(agent_message)
if tool_message:
messages.append(f"⚙️ => {tool_message}")
- # 发送合并后的消息
if messages:
merged_message = "\n\n".join(messages)
await self.send_tool_message(merged_message, title="MoviePilot助手")
- logger.debug(f'Executing tool {self.name} with args: {kwargs}')
+ logger.debug(f"Executing tool {self.name} with args: {kwargs}")
- # 执行工具,捕获异常确保结果总是被存储到记忆中
+ # 执行具体工具逻辑
try:
result = await self.run(**kwargs)
- logger.debug(f'Tool {self.name} executed with result: {result}')
+ logger.debug(f"Tool {self.name} executed with result: {result}")
except Exception as e:
- # 记录异常详情
error_message = f"工具执行异常 ({type(e).__name__}): {str(e)}"
- logger.error(f'Tool {self.name} execution failed: {e}', exc_info=True)
+ logger.error(f"Tool {self.name} execution failed: {e}", exc_info=True)
result = error_message
- # 记忆工具调用结果
+ # 格式化结果
if isinstance(result, str):
- formated_result = result
+ formatted_result = result
elif isinstance(result, (int, float)):
- formated_result = str(result)
+ formatted_result = str(result)
else:
- formated_result = json.dumps(result, ensure_ascii=False, indent=2)
+ formatted_result = json.dumps(result, ensure_ascii=False, indent=2)
- await conversation_manager.add_conversation(
- session_id=self._session_id,
- user_id=self._user_id,
- role="tool_result",
- content=formated_result,
- metadata={
- "call_id": call_id,
- "tool_name": self.name,
- }
- )
-
- return result
+ return formatted_result
def get_tool_message(self, **kwargs) -> Optional[str]:
"""
- 获取工具执行时的友好提示消息
-
+ 获取工具执行时的友好提示消息。
+
子类可以重写此方法,根据实际参数生成个性化的提示消息。
如果返回 None 或空字符串,将回退使用 explanation 参数。
-
+
Args:
**kwargs: 工具的所有参数(包括 explanation)
-
+
Returns:
str: 友好的提示消息,如果返回 None 或空字符串则使用 explanation
"""
@@ -128,6 +104,7 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
@abstractmethod
async def run(self, **kwargs) -> str:
+ """子类实现具体的工具执行逻辑"""
raise NotImplementedError
def set_message_attr(self, channel: str, source: str, username: str):
@@ -138,11 +115,11 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
self._source = source
self._username = username
- def set_callback_handler(self, callback_handler: StreamingCallbackHandler):
+ def set_stream_handler(self, stream_handler: StreamingHandler):
"""
设置回调处理器
"""
- self._callback_handler = callback_handler
+ self._stream_handler = stream_handler
async def send_tool_message(self, message: str, title: str = ""):
"""
@@ -155,6 +132,6 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
userid=self._user_id,
username=self._username,
title=title,
- text=message
+ text=message,
)
)
diff --git a/app/agent/tools/factory.py b/app/agent/tools/factory.py
index 5e84891a..bcbe45f4 100644
--- a/app/agent/tools/factory.py
+++ b/app/agent/tools/factory.py
@@ -54,7 +54,7 @@ class MoviePilotToolFactory:
@staticmethod
def create_tools(session_id: str, user_id: str,
channel: str = None, source: str = None, username: str = None,
- callback_handler: Callable = None) -> List[MoviePilotTool]:
+ stream_handler: Callable = None) -> List[MoviePilotTool]:
"""
创建MoviePilot工具列表
"""
@@ -109,7 +109,7 @@ class MoviePilotToolFactory:
user_id=user_id
)
tool.set_message_attr(channel=channel, source=source, username=username)
- tool.set_callback_handler(callback_handler=callback_handler)
+ tool.set_stream_handler(stream_handler=stream_handler)
tools.append(tool)
# 加载插件提供的工具
@@ -131,7 +131,7 @@ class MoviePilotToolFactory:
user_id=user_id
)
tool.set_message_attr(channel=channel, source=source, username=username)
- tool.set_callback_handler(callback_handler=callback_handler)
+ tool.set_stream_handler(stream_handler=stream_handler)
tools.append(tool)
plugin_tools_count += 1
logger.debug(f"成功加载插件 {plugin_name}({plugin_id}) 的工具: {ToolClass.__name__}")
diff --git a/app/agent/tools/manager.py b/app/agent/tools/manager.py
index 5ea24a19..5ddf7577 100644
--- a/app/agent/tools/manager.py
+++ b/app/agent/tools/manager.py
@@ -25,7 +25,7 @@ class MoviePilotToolsManager:
def __init__(self, user_id: str = "api_user", session_id: str = uuid.uuid4()):
"""
初始化工具管理器
-
+
Args:
user_id: 用户ID
session_id: 会话ID
@@ -47,7 +47,7 @@ class MoviePilotToolsManager:
channel=None,
source="api",
username="API Client",
- callback_handler=None,
+ stream_handler=None,
)
logger.info(f"成功加载 {len(self.tools)} 个工具")
except Exception as e:
@@ -57,40 +57,38 @@ class MoviePilotToolsManager:
def list_tools(self) -> List[ToolDefinition]:
"""
列出所有可用的工具
-
+
Returns:
工具定义列表
"""
tools_list = []
for tool in self.tools:
# 获取工具的输入参数模型
- args_schema = getattr(tool, 'args_schema', None)
+ args_schema = getattr(tool, "args_schema", None)
if args_schema:
# 将Pydantic模型转换为JSON Schema
input_schema = self._convert_to_json_schema(args_schema)
else:
# 如果没有args_schema,使用基本信息
- input_schema = {
- "type": "object",
- "properties": {},
- "required": []
- }
+ input_schema = {"type": "object", "properties": {}, "required": []}
- tools_list.append(ToolDefinition(
- name=tool.name,
- description=tool.description or "",
- input_schema=input_schema
- ))
+ tools_list.append(
+ ToolDefinition(
+ name=tool.name,
+ description=tool.description or "",
+ input_schema=input_schema,
+ )
+ )
return tools_list
def get_tool(self, tool_name: str) -> Optional[Any]:
"""
获取指定工具实例
-
+
Args:
tool_name: 工具名称
-
+
Returns:
工具实例,如果未找到返回None
"""
@@ -159,23 +157,26 @@ class MoviePilotToolsManager:
return []
return [
MoviePilotToolsManager._normalize_scalar_value(item_type, item.strip(), key)
- for item in trimmed.split(",") if item.strip()
+ for item in trimmed.split(",")
+ if item.strip()
]
@staticmethod
- def _normalize_arguments(tool_instance: Any, arguments: Dict[str, Any]) -> Dict[str, Any]:
+ def _normalize_arguments(
+ tool_instance: Any, arguments: Dict[str, Any]
+ ) -> Dict[str, Any]:
"""
根据工具的参数schema规范化参数类型
-
+
Args:
tool_instance: 工具实例
arguments: 原始参数
-
+
Returns:
规范化后的参数
"""
# 获取工具的参数schema
- args_schema = getattr(tool_instance, 'args_schema', None)
+ args_schema = getattr(tool_instance, "args_schema", None)
if not args_schema:
return arguments
@@ -201,31 +202,35 @@ class MoviePilotToolsManager:
# 数组类型:将字符串解析为列表
if field_type == "array" and isinstance(value, str):
item_type = field_info.get("items", {}).get("type", "string")
- normalized[key] = MoviePilotToolsManager._parse_array_string(value, key, item_type)
+ normalized[key] = MoviePilotToolsManager._parse_array_string(
+ value, key, item_type
+ )
continue
# 根据类型进行转换
- normalized[key] = MoviePilotToolsManager._normalize_scalar_value(field_type, value, key)
+ normalized[key] = MoviePilotToolsManager._normalize_scalar_value(
+ field_type, value, key
+ )
return normalized
async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> str:
"""
调用工具
-
+
Args:
tool_name: 工具名称
arguments: 工具参数
-
+
Returns:
工具执行结果(字符串)
"""
tool_instance = self.get_tool(tool_name)
if not tool_instance:
- error_msg = json.dumps({
- "error": f"工具 '{tool_name}' 未找到"
- }, ensure_ascii=False)
+ error_msg = json.dumps(
+ {"error": f"工具 '{tool_name}' 未找到"}, ensure_ascii=False
+ )
return error_msg
try:
@@ -238,7 +243,7 @@ class MoviePilotToolsManager:
# 确保返回字符串
if isinstance(result, str):
formated_result = result
- elif isinstance(result, int, float):
+ elif isinstance(result, (int, float)):
formated_result = str(result)
else:
try:
@@ -250,19 +255,20 @@ class MoviePilotToolsManager:
return formated_result
except Exception as e:
logger.error(f"调用工具 {tool_name} 时发生错误: {e}", exc_info=True)
- error_msg = json.dumps({
- "error": f"调用工具 '{tool_name}' 时发生错误: {str(e)}"
- }, ensure_ascii=False)
+ error_msg = json.dumps(
+ {"error": f"调用工具 '{tool_name}' 时发生错误: {str(e)}"},
+ ensure_ascii=False,
+ )
return error_msg
@staticmethod
def _convert_to_json_schema(args_schema: Any) -> Dict[str, Any]:
"""
将Pydantic模型转换为JSON Schema
-
+
Args:
args_schema: Pydantic模型类
-
+
Returns:
JSON Schema字典
"""
@@ -275,7 +281,9 @@ class MoviePilotToolsManager:
if "properties" in schema:
for field_name, field_info in schema["properties"].items():
- resolved_field_info = MoviePilotToolsManager._resolve_field_schema(field_info)
+ resolved_field_info = MoviePilotToolsManager._resolve_field_schema(
+ field_info
+ )
# 转换字段类型
field_type = resolved_field_info.get("type", "string")
field_description = resolved_field_info.get("description", "")
@@ -286,14 +294,14 @@ class MoviePilotToolsManager:
default_value = resolved_field_info.get("default")
properties[field_name] = {
"type": field_type,
- "description": field_description
+ "description": field_description,
}
if default_value is not None:
properties[field_name]["default"] = default_value
else:
properties[field_name] = {
"type": field_type,
- "description": field_description
+ "description": field_description,
}
required.append(field_name)
@@ -305,11 +313,7 @@ class MoviePilotToolsManager:
if field_type == "array" and "items" in resolved_field_info:
properties[field_name]["items"] = resolved_field_info["items"]
- return {
- "type": "object",
- "properties": properties,
- "required": required
- }
+ return {"type": "object", "properties": properties, "required": required}
moviepilot_tool_manager = MoviePilotToolsManager()
diff --git a/app/helper/llm.py b/app/helper/llm.py
index 4bccbaa5..0ddd9921 100644
--- a/app/helper/llm.py
+++ b/app/helper/llm.py
@@ -1,5 +1,6 @@
"""LLM模型相关辅助功能"""
-from typing import List, Optional
+
+from typing import List
from app.core.config import settings
from app.log import logger
@@ -9,11 +10,10 @@ class LLMHelper:
"""LLM模型相关辅助功能"""
@staticmethod
- def get_llm(streaming: bool = False, callbacks: Optional[list] = None):
+ def get_llm(streaming: bool = False):
"""
获取LLM实例
:param streaming: 是否启用流式输出
- :param callbacks: 回调处理器列表
:return: LLM实例
"""
provider = settings.LLM_PROVIDER.lower()
@@ -24,7 +24,9 @@ class LLMHelper:
if provider == "google":
if settings.PROXY_HOST:
+ # 通过代理使用 Google 的 OpenAI 兼容接口
from langchain_openai import ChatOpenAI
+
return ChatOpenAI(
model=settings.LLM_MODEL,
api_key=api_key,
@@ -32,33 +34,34 @@ class LLMHelper:
base_url="https://generativelanguage.googleapis.com/v1beta/openai",
temperature=settings.LLM_TEMPERATURE,
streaming=streaming,
- callbacks=callbacks,
stream_usage=True,
- openai_proxy=settings.PROXY_HOST
+ openai_proxy=settings.PROXY_HOST,
)
else:
+ # 使用 langchain-google-genai 原生接口(v4 API 变更:google_api_key → api_key,max_retries → retries)
from langchain_google_genai import ChatGoogleGenerativeAI
+
return ChatGoogleGenerativeAI(
model=settings.LLM_MODEL,
- google_api_key=api_key,
- max_retries=3,
+ api_key=api_key,
+ retries=3,
temperature=settings.LLM_TEMPERATURE,
- streaming=streaming,
- callbacks=callbacks
+            streaming=streaming,
)
elif provider == "deepseek":
from langchain_deepseek import ChatDeepSeek
+
return ChatDeepSeek(
model=settings.LLM_MODEL,
api_key=api_key,
max_retries=3,
temperature=settings.LLM_TEMPERATURE,
streaming=streaming,
- callbacks=callbacks,
- stream_usage=True
+ stream_usage=True,
)
else:
from langchain_openai import ChatOpenAI
+
return ChatOpenAI(
model=settings.LLM_MODEL,
api_key=api_key,
@@ -66,12 +69,13 @@ class LLMHelper:
base_url=settings.LLM_BASE_URL,
temperature=settings.LLM_TEMPERATURE,
streaming=streaming,
- callbacks=callbacks,
stream_usage=True,
- openai_proxy=settings.PROXY_HOST
+ openai_proxy=settings.PROXY_HOST,
)
- def get_models(self, provider: str, api_key: str, base_url: str = None) -> List[str]:
+ def get_models(
+ self, provider: str, api_key: str, base_url: str = None
+ ) -> List[str]:
"""获取模型列表"""
logger.info(f"获取 {provider} 模型列表...")
if provider == "google":
@@ -81,18 +85,25 @@ class LLMHelper:
@staticmethod
def _get_google_models(api_key: str) -> List[str]:
- """获取Google模型列表"""
+ """获取Google模型列表(使用 google-genai SDK v1)"""
try:
- import google.generativeai as genai
- genai.configure(api_key=api_key)
- models = genai.list_models()
- return [m.name for m in models if 'generateContent' in m.supported_generation_methods]
+ from google import genai
+
+ client = genai.Client(api_key=api_key)
+ models = client.models.list()
+ return [
+ m.name
+ for m in models
+ if m.supported_actions and "generateContent" in m.supported_actions
+ ]
except Exception as e:
logger.error(f"获取Google模型列表失败:{e}")
raise e
@staticmethod
- def _get_openai_compatible_models(provider: str, api_key: str, base_url: str = None) -> List[str]:
+ def _get_openai_compatible_models(
+ provider: str, api_key: str, base_url: str = None
+ ) -> List[str]:
"""获取OpenAI兼容模型列表"""
try:
from openai import OpenAI
diff --git a/app/schemas/agent.py b/app/schemas/agent.py
index 38d34e63..2bc88062 100644
--- a/app/schemas/agent.py
+++ b/app/schemas/agent.py
@@ -1,38 +1,37 @@
"""AI智能体相关数据模型"""
from datetime import datetime
-from typing import Dict, List, Optional, Any
+from typing import List, Optional
+
+from langchain_core.messages import BaseMessage
from pydantic import BaseModel, Field, ConfigDict, field_serializer
class ConversationMemory(BaseModel):
"""对话记忆模型"""
-
+
session_id: str = Field(description="会话ID")
user_id: Optional[str] = Field(default=None, description="用户ID")
- title: Optional[str] = Field(default=None, description="会话标题")
- messages: List[Dict[str, Any]] = Field(default_factory=list, description="消息列表")
- context: Dict[str, Any] = Field(default_factory=dict, description="会话上下文")
- created_at: datetime = Field(default_factory=datetime.now, description="创建时间")
+ messages: List[BaseMessage] = Field(default_factory=list, description="消息列表")
updated_at: datetime = Field(default_factory=datetime.now, description="更新时间")
-
+
model_config = ConfigDict()
-
- @field_serializer('created_at', 'updated_at', when_used='json')
+
+ @field_serializer('updated_at', when_used='json')
def serialize_datetime(self, value: datetime) -> str:
return value.isoformat()
class AgentState(BaseModel):
"""AI智能体状态模型"""
-
+
session_id: str = Field(description="会话ID")
current_task: Optional[str] = Field(default=None, description="当前任务")
is_thinking: bool = Field(default=False, description="是否正在思考")
last_activity: datetime = Field(default_factory=datetime.now, description="最后活动时间")
-
+
model_config = ConfigDict()
-
+
@field_serializer('last_activity', when_used='json')
def serialize_datetime(self, value: datetime) -> str:
return value.isoformat()
@@ -40,7 +39,7 @@ class AgentState(BaseModel):
class UserMessage(BaseModel):
"""用户消息模型"""
-
+
session_id: str = Field(description="会话ID")
content: str = Field(description="消息内容")
user_id: Optional[str] = Field(default=None, description="用户ID")
@@ -50,7 +49,7 @@ class UserMessage(BaseModel):
class ToolResult(BaseModel):
"""工具执行结果模型"""
-
+
session_id: str = Field(description="会话ID")
call_id: str = Field(description="调用ID")
success: bool = Field(description="是否成功")
diff --git a/requirements.in b/requirements.in
index b7427176..432f39c8 100644
--- a/requirements.in
+++ b/requirements.in
@@ -82,14 +82,14 @@ pympler~=1.1
smbprotocol~=1.15.0
setproctitle~=1.3.6
httpx[socks]~=0.28.1
-langchain~=0.3.27
-langchain-core~=0.3.76
-langchain-community~=0.3.29
-langchain-openai~=0.3.33
-langchain-google-genai~=2.0.10
-langchain-deepseek~=0.1.4
-langchain-experimental~=0.3.4
-openai~=1.108.2
-google-generativeai~=0.8.5
+langchain~=1.2.13
+langchain-core~=1.2.20
+langchain-community~=0.4.1
+langchain-openai~=1.1.11
+langchain-google-genai~=4.2.1
+langchain-deepseek~=1.0.1
+langchain-experimental~=0.4.1
+openai~=2.29.0
+google-genai~=1.0
ddgs~=9.10.0
websocket-client~=1.8.0