class MoviePilotAgent:
    """
    MoviePilot AI 智能体(基于 LangChain v1 + LangGraph)。

    每个会话(session_id)对应一个实例:推理时从 memory_manager 取出
    历史消息、流式执行 LangGraph Agent,并把最终回复通过原消息渠道发回用户。
    """

    def __init__(
        self,
        session_id: str,
        user_id: str = None,
        channel: str = None,
        source: str = None,
        username: str = None,
    ):
        self.session_id = session_id
        self.user_id = user_id
        self.channel = channel
        self.source = source
        self.username = username

        # 流式 token 缓冲:LLM 输出先积累在这里,
        # 工具调用前 / 推理结束后统一取出并发送给用户
        self.stream_handler = StreamingHandler()

    @staticmethod
    def _initialize_llm():
        """
        初始化 LLM(流式模式)
        """
        return LLMHelper.get_llm(streaming=True)

    def _initialize_tools(self) -> List:
        """
        初始化工具列表(内置工具 + 插件工具)
        """
        return MoviePilotToolFactory.create_tools(
            session_id=self.session_id,
            user_id=self.user_id,
            channel=self.channel,
            source=self.source,
            username=self.username,
            stream_handler=self.stream_handler,
        )

    def _create_agent(self):
        """
        创建 LangGraph Agent(create_agent + 中间件)。

        :raises Exception: 提示词 / 模型 / 工具初始化失败时原样抛出
        """
        try:
            # 系统提示词,注入当前日期
            system_prompt = prompt_manager.get_agent_prompt(
                channel=self.channel
            ).format(
                current_date=strftime('%Y-%m-%d')
            )

            # LLM 模型(用于 agent 执行)
            llm = self._initialize_llm()

            # 工具列表
            tools = self._initialize_tools()

            # 中间件
            middlewares = [
                # 上下文达到窗口 85% 时自动摘要压缩
                SummarizationMiddleware(
                    model=llm,
                    trigger=("fraction", 0.85)
                ),
                # 模型调用失败时自动重试
                ModelRetryMiddleware(max_retries=3),
                # 工具调用失败时自动重试
                ToolRetryMiddleware(max_retries=1),
            ]

            return create_agent(
                model=llm,
                tools=tools,
                system_prompt=system_prompt,
                middleware=middlewares,
                # NOTE(review): 每次 _create_agent 都新建 InMemorySaver,
                # checkpointer 状态不会跨多次 process() 保留;
                # 跨轮上下文依赖 memory_manager 传入的历史消息 —— 确认是否符合预期
                checkpointer=InMemorySaver(),
            )
        except Exception as e:
            logger.error(f"创建 Agent 失败: {e}")
            raise e

    async def process(self, message: str) -> str:
        """
        处理用户消息:拼接历史上下文并执行 Agent 推理,返回最终回复文本。
        异常时把错误文本发回用户并返回该文本。
        """
        try:
            logger.info(f"Agent推理: session_id={self.session_id}, input={message}")

            # 获取历史消息
            messages = memory_manager.get_agent_messages(
                session_id=self.session_id,
                user_id=self.user_id
            )

            # 增加本轮用户消息
            messages.append(HumanMessage(content=message))

            # 执行推理并返回结果
            # 修复:原实现只调用不返回,上层 AgentManager.process_message 拿到 None
            return await self._execute_agent(messages)
        except Exception as e:
            error_message = f"处理消息时发生错误: {str(e)}"
            logger.error(error_message)
            await self.send_agent_message(error_message)
            return error_message

    async def _execute_agent(self, messages: List[BaseMessage]) -> str:
        """
        流式执行 LangGraph Agent:token 写入缓冲区,
        结束后发送最终回复、持久化会话消息,并返回回复文本。
        """
        try:
            # Agent 运行配置
            # 修复:原实现为 `{ "" }`(set 字面量),不是合法的运行配置;
            # 使用 checkpointer 时必须提供 configurable.thread_id
            agent_config = {
                "configurable": {"thread_id": self.session_id}
            }

            # 创建智能体
            agent = self._create_agent()

            # 流式运行:stream_mode="messages" 时每个 chunk 为 (消息块, 元数据)
            async for token, _metadata in agent.astream(
                {"messages": messages},
                stream_mode="messages",
                config=agent_config
            ):
                # NOTE(review): 假定 token.content 为 str;
                # 多模态 content 块(list)需要额外处理 —— 待确认
                if token and token.content:
                    self.stream_handler.emit(token.content)

            # 取出缓冲区内容作为最终回复并发送给用户
            agent_message = self.stream_handler.take()
            if agent_message:
                await self.send_agent_message(agent_message)

            # 保存完整消息历史
            # 修复:StateSnapshot.values 是 dict 属性,不可调用;
            # 原 `.values("messages")` 会抛 TypeError
            state_messages = agent.get_state(agent_config).values.get("messages") or []
            if state_messages:
                memory_manager.save_agent_messages(
                    session_id=self.session_id,
                    user_id=self.user_id,
                    messages=state_messages
                )

            return agent_message
        except asyncio.CancelledError:
            logger.info(f"Agent执行被取消: session_id={self.session_id}")
            # 修复:统一返回 str(原实现返回 tuple,与成功路径类型不一致)
            return "任务已取消"
        except Exception as e:
            logger.error(f"Agent执行失败: {e}")
            return str(e)

    async def send_agent_message(self, message: str, title: str = "MoviePilot助手"):
        """
        通过原消息渠道把 Agent 回复发送给用户
        """
        # NOTE(review): 发送调用链在补丁中不可见,按 Notification 参数重建 —— 待确认
        AgentChain().post_message(
            Notification(
                channel=self.channel,
                source=self.source,
                userid=self.user_id,
                username=self.username,
                title=title,
                text=message,
            )
        )
class StreamingHandler:
    """
    流式 Token 缓冲管理器。

    线程安全地积累 LLM 流式输出的文本,供 Agent 在工具调用之间
    穿插取出并作为中间消息发送给用户。
    """

    def __init__(self):
        # 互斥锁:emit/take 可能来自不同线程或回调上下文
        self._lock = threading.Lock()
        self._buffer = ""

    def emit(self, token: str):
        """
        接收一个 LLM 流式 token,追加到缓冲区
        """
        with self._lock:
            self._buffer += token

    def take(self) -> str:
        """
        取出当前已积累的全部文本并清空缓冲区;缓冲区为空时返回空串
        """
        with self._lock:
            if not self._buffer:
                return ""
            message = self._buffer
            logger.info(f"Agent消息: {message}")
            self._buffer = ""
            return message

    def clear(self):
        """
        丢弃缓冲区内容(不返回)
        """
        with self._lock:
            self._buffer = ""
class MemoryManager:
    """
    对话记忆管理器(仅内存缓存;Redis 场景下记忆由 TTL 自动过期)
    """

    def __init__(self):
        # 内存中的会话记忆缓存,key 为 "{user_id}:{session_id}"(无 user_id 时为 session_id)
        self.memory_cache: Dict[str, "ConversationMemory"] = {}
        # 后台内存缓存清理任务
        self.cleanup_task: Optional[asyncio.Task] = None

    async def initialize(self):
        """
        初始化记忆管理器,启动内存缓存清理任务。

        修复:补丁把本方法改成了同步 def,但 AgentManager.initialize 仍然
        `await memory_manager.initialize()`,await None 会抛 TypeError,
        因此恢复为协程。
        """
        try:
            self.cleanup_task = asyncio.create_task(
                self._cleanup_expired_memories()
            )
            logger.info("对话记忆管理器初始化完成")
        except Exception as e:
            logger.error(f"初始化记忆管理器失败: {e}")

    async def close(self):
        """
        关闭管理器,取消并等待后台清理任务退出
        """
        # NOTE(review): 方法开头在补丁中不可见,按可见尾部重建 —— 待确认
        if self.cleanup_task:
            self.cleanup_task.cancel()
            try:
                await self.cleanup_task
            except asyncio.CancelledError:
                pass
        logger.info("对话记忆管理器已关闭")

    @staticmethod
    def _get_memory_key(session_id: str, user_id: str) -> str:
        """
        计算内存缓存 Key
        """
        return f"{user_id}:{session_id}" if user_id else session_id

    def get_memory(self, session_id: str, user_id: str) -> Optional["ConversationMemory"]:
        """
        获取内存中的会话记忆,不存在时返回 None
        """
        return self.memory_cache.get(self._get_memory_key(session_id, user_id))

    def get_agent_messages(
        self, session_id: str, user_id: str
    ) -> List["BaseMessage"]:
        """
        为 Agent 获取会话历史消息(仅内存缓存)。

        返回列表副本:调用方(process)会直接 append 本轮用户消息,
        返回引用会悄悄改写缓存中的消息列表。
        """
        memory = self.get_memory(session_id, user_id)
        if not memory:
            return []
        return list(memory.messages)

    def save_agent_messages(
        self, session_id: str, user_id: str, messages: List["BaseMessage"]
    ):
        """
        保存 Agent 消息到内存缓存(Redis 记忆通过 TTL 自动过期)
        """
        memory = self.get_memory(session_id, user_id)
        if not memory:
            memory = ConversationMemory(session_id=session_id, user_id=user_id)
        memory.messages = messages
        memory.updated_at = datetime.now()
        self.save_memory(memory)

    def save_memory(self, memory: "ConversationMemory"):
        """
        写入内存缓存
        """
        cache_key = self._get_memory_key(memory.session_id, memory.user_id)
        self.memory_cache[cache_key] = memory

    async def clear_memory(self, session_id: str, user_id: str):
        """
        清空会话记忆。

        修复:恢复为协程 —— AgentManager.clear_session 仍然
        `await memory_manager.clear_memory(...)`。
        """
        cache_key = self._get_memory_key(session_id, user_id)
        if cache_key in self.memory_cache:
            del self.memory_cache[cache_key]
        logger.info(f"会话记忆已清空: session_id={session_id}, user_id={user_id}")

    async def _cleanup_expired_memories(self):
        """
        后台循环:清理内存缓存中超过保留天数的会话记忆
        (Redis 中的记忆通过 TTL 自动过期,无需手动处理)
        """
        while True:
            try:
                # NOTE(review): 检查间隔在补丁中不可见,暂取 1 小时 —— 待确认
                await asyncio.sleep(3600)
                current_time = datetime.now()
                # 先收集过期 key,再删除,避免边遍历边修改字典
                expired_sessions = [
                    cache_key
                    for cache_key, memory in self.memory_cache.items()
                    if (
                        current_time - memory.updated_at
                    ).days > settings.LLM_MEMORY_RETENTION_DAYS
                ]
                for cache_key in expired_sessions:
                    del self.memory_cache[cache_key]
                if expired_sessions:
                    logger.info(f"已清理 {len(expired_sessions)} 个过期会话记忆")
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"清理记忆时发生错误: {e}")


memory_manager = MemoryManager()
class MoviePilotTool(BaseTool, metaclass=ABCMeta):
    """
    MoviePilot 专用工具基类(LangChain v1 / langchain_core)。

    子类只需实现 run();_arun 统一负责:
    1. 工具调用前把已积累的流式文本推送给用户
    2. 发送工具执行提示消息
    3. 调用子类 run() 并把结果格式化为字符串返回给模型
    """

    _session_id: str = PrivateAttr()
    _user_id: str = PrivateAttr()
    _channel: Optional[str] = PrivateAttr(default=None)
    _source: Optional[str] = PrivateAttr(default=None)
    _username: Optional[str] = PrivateAttr(default=None)
    _stream_handler: Optional[StreamingHandler] = PrivateAttr(default=None)

    def __init__(self, session_id: str, user_id: str, **kwargs):
        super().__init__(**kwargs)
        self._session_id = session_id
        self._user_id = user_id

    def _run(self, *args: Any, **kwargs: Any) -> Any:
        # 本工具族只支持异步执行路径
        raise NotImplementedError("MoviePilotTool 只支持异步调用,请使用 _arun")

    async def _arun(self, *args: Any, **kwargs: Any) -> str:
        """
        异步运行工具:发送中间消息 → 执行子类逻辑 → 格式化结果
        """
        # 获取工具调用前 Agent 已积累的流式文本
        agent_message = (
            self._stream_handler.take() if self._stream_handler else ""
        )

        # 工具执行提示:优先用子类自定义消息,回退到 explanation 参数
        tool_message = self.get_tool_message(**kwargs)
        if not tool_message:
            explanation = kwargs.get("explanation")
            if explanation:
                tool_message = explanation

        # 合并 Agent 消息和工具执行消息后一起发送
        messages = []
        if agent_message:
            messages.append(agent_message)
        if tool_message:
            messages.append(f"⚙️ => {tool_message}")
        if messages:
            merged_message = "\n\n".join(messages)
            await self.send_tool_message(merged_message, title="MoviePilot助手")

        logger.debug(f"Executing tool {self.name} with args: {kwargs}")

        # 执行具体工具逻辑;捕获异常,把错误文本作为结果返回给模型
        try:
            result = await self.run(**kwargs)
            logger.debug(f"Tool {self.name} executed with result: {result}")
        except Exception as e:
            error_message = f"工具执行异常 ({type(e).__name__}): {str(e)}"
            logger.error(f"Tool {self.name} execution failed: {e}", exc_info=True)
            result = error_message

        # 统一格式化为字符串
        if isinstance(result, str):
            formatted_result = result
        elif isinstance(result, (int, float)):
            formatted_result = str(result)
        else:
            formatted_result = json.dumps(result, ensure_ascii=False, indent=2)

        return formatted_result

    def get_tool_message(self, **kwargs) -> Optional[str]:
        """
        获取工具执行时的友好提示消息。

        子类可重写此方法,根据实际参数生成个性化提示;
        返回 None 或空字符串时,将回退使用 explanation 参数。

        Args:
            **kwargs: 工具的所有参数(包括 explanation)

        Returns:
            str: 友好的提示消息,None/空串表示使用 explanation
        """
        return None

    @abstractmethod
    async def run(self, **kwargs) -> str:
        """子类实现具体的工具执行逻辑"""
        raise NotImplementedError

    def set_message_attr(self, channel: str, source: str, username: str):
        """
        设置消息渠道属性
        """
        self._channel = channel
        self._source = source
        self._username = username

    def set_stream_handler(self, stream_handler: StreamingHandler):
        """
        设置流式缓冲管理器
        """
        self._stream_handler = stream_handler

    async def send_tool_message(self, message: str, title: str = ""):
        """
        通过原消息渠道发送工具执行提示消息
        """
        # NOTE(review): 发送调用链在补丁中不可见,按 Notification 参数重建 —— 待确认
        ToolChain().post_message(
            Notification(
                channel=self._channel,
                source=self._source,
                userid=self._user_id,
                username=self._username,
                title=title,
                text=message,
            )
        )
tool.set_message_attr(channel=channel, source=source, username=username) - tool.set_callback_handler(callback_handler=callback_handler) + tool.set_stream_handler(stream_handler=stream_handler) tools.append(tool) # 加载插件提供的工具 @@ -131,7 +131,7 @@ class MoviePilotToolFactory: user_id=user_id ) tool.set_message_attr(channel=channel, source=source, username=username) - tool.set_callback_handler(callback_handler=callback_handler) + tool.set_stream_handler(stream_handler=stream_handler) tools.append(tool) plugin_tools_count += 1 logger.debug(f"成功加载插件 {plugin_name}({plugin_id}) 的工具: {ToolClass.__name__}") diff --git a/app/agent/tools/manager.py b/app/agent/tools/manager.py index 5ea24a19..5ddf7577 100644 --- a/app/agent/tools/manager.py +++ b/app/agent/tools/manager.py @@ -25,7 +25,7 @@ class MoviePilotToolsManager: def __init__(self, user_id: str = "api_user", session_id: str = uuid.uuid4()): """ 初始化工具管理器 - + Args: user_id: 用户ID session_id: 会话ID @@ -47,7 +47,7 @@ class MoviePilotToolsManager: channel=None, source="api", username="API Client", - callback_handler=None, + stream_handler=None, ) logger.info(f"成功加载 {len(self.tools)} 个工具") except Exception as e: @@ -57,40 +57,38 @@ class MoviePilotToolsManager: def list_tools(self) -> List[ToolDefinition]: """ 列出所有可用的工具 - + Returns: 工具定义列表 """ tools_list = [] for tool in self.tools: # 获取工具的输入参数模型 - args_schema = getattr(tool, 'args_schema', None) + args_schema = getattr(tool, "args_schema", None) if args_schema: # 将Pydantic模型转换为JSON Schema input_schema = self._convert_to_json_schema(args_schema) else: # 如果没有args_schema,使用基本信息 - input_schema = { - "type": "object", - "properties": {}, - "required": [] - } + input_schema = {"type": "object", "properties": {}, "required": []} - tools_list.append(ToolDefinition( - name=tool.name, - description=tool.description or "", - input_schema=input_schema - )) + tools_list.append( + ToolDefinition( + name=tool.name, + description=tool.description or "", + input_schema=input_schema, + ) + ) 
return tools_list def get_tool(self, tool_name: str) -> Optional[Any]: """ 获取指定工具实例 - + Args: tool_name: 工具名称 - + Returns: 工具实例,如果未找到返回None """ @@ -159,23 +157,26 @@ class MoviePilotToolsManager: return [] return [ MoviePilotToolsManager._normalize_scalar_value(item_type, item.strip(), key) - for item in trimmed.split(",") if item.strip() + for item in trimmed.split(",") + if item.strip() ] @staticmethod - def _normalize_arguments(tool_instance: Any, arguments: Dict[str, Any]) -> Dict[str, Any]: + def _normalize_arguments( + tool_instance: Any, arguments: Dict[str, Any] + ) -> Dict[str, Any]: """ 根据工具的参数schema规范化参数类型 - + Args: tool_instance: 工具实例 arguments: 原始参数 - + Returns: 规范化后的参数 """ # 获取工具的参数schema - args_schema = getattr(tool_instance, 'args_schema', None) + args_schema = getattr(tool_instance, "args_schema", None) if not args_schema: return arguments @@ -201,31 +202,35 @@ class MoviePilotToolsManager: # 数组类型:将字符串解析为列表 if field_type == "array" and isinstance(value, str): item_type = field_info.get("items", {}).get("type", "string") - normalized[key] = MoviePilotToolsManager._parse_array_string(value, key, item_type) + normalized[key] = MoviePilotToolsManager._parse_array_string( + value, key, item_type + ) continue # 根据类型进行转换 - normalized[key] = MoviePilotToolsManager._normalize_scalar_value(field_type, value, key) + normalized[key] = MoviePilotToolsManager._normalize_scalar_value( + field_type, value, key + ) return normalized async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> str: """ 调用工具 - + Args: tool_name: 工具名称 arguments: 工具参数 - + Returns: 工具执行结果(字符串) """ tool_instance = self.get_tool(tool_name) if not tool_instance: - error_msg = json.dumps({ - "error": f"工具 '{tool_name}' 未找到" - }, ensure_ascii=False) + error_msg = json.dumps( + {"error": f"工具 '{tool_name}' 未找到"}, ensure_ascii=False + ) return error_msg try: @@ -238,7 +243,7 @@ class MoviePilotToolsManager: # 确保返回字符串 if isinstance(result, str): formated_result = result - elif 
isinstance(result, int, float): + elif isinstance(result, (int, float)): formated_result = str(result) else: try: @@ -250,19 +255,20 @@ class MoviePilotToolsManager: return formated_result except Exception as e: logger.error(f"调用工具 {tool_name} 时发生错误: {e}", exc_info=True) - error_msg = json.dumps({ - "error": f"调用工具 '{tool_name}' 时发生错误: {str(e)}" - }, ensure_ascii=False) + error_msg = json.dumps( + {"error": f"调用工具 '{tool_name}' 时发生错误: {str(e)}"}, + ensure_ascii=False, + ) return error_msg @staticmethod def _convert_to_json_schema(args_schema: Any) -> Dict[str, Any]: """ 将Pydantic模型转换为JSON Schema - + Args: args_schema: Pydantic模型类 - + Returns: JSON Schema字典 """ @@ -275,7 +281,9 @@ class MoviePilotToolsManager: if "properties" in schema: for field_name, field_info in schema["properties"].items(): - resolved_field_info = MoviePilotToolsManager._resolve_field_schema(field_info) + resolved_field_info = MoviePilotToolsManager._resolve_field_schema( + field_info + ) # 转换字段类型 field_type = resolved_field_info.get("type", "string") field_description = resolved_field_info.get("description", "") @@ -286,14 +294,14 @@ class MoviePilotToolsManager: default_value = resolved_field_info.get("default") properties[field_name] = { "type": field_type, - "description": field_description + "description": field_description, } if default_value is not None: properties[field_name]["default"] = default_value else: properties[field_name] = { "type": field_type, - "description": field_description + "description": field_description, } required.append(field_name) @@ -305,11 +313,7 @@ class MoviePilotToolsManager: if field_type == "array" and "items" in resolved_field_info: properties[field_name]["items"] = resolved_field_info["items"] - return { - "type": "object", - "properties": properties, - "required": required - } + return {"type": "object", "properties": properties, "required": required} moviepilot_tool_manager = MoviePilotToolsManager() diff --git a/app/helper/llm.py b/app/helper/llm.py 
index 4bccbaa5..0ddd9921 100644 --- a/app/helper/llm.py +++ b/app/helper/llm.py @@ -1,5 +1,6 @@ """LLM模型相关辅助功能""" -from typing import List, Optional + +from typing import List from app.core.config import settings from app.log import logger @@ -9,11 +10,10 @@ class LLMHelper: """LLM模型相关辅助功能""" @staticmethod - def get_llm(streaming: bool = False, callbacks: Optional[list] = None): + def get_llm(streaming: bool = False): """ 获取LLM实例 :param streaming: 是否启用流式输出 - :param callbacks: 回调处理器列表 :return: LLM实例 """ provider = settings.LLM_PROVIDER.lower() @@ -24,7 +24,9 @@ class LLMHelper: if provider == "google": if settings.PROXY_HOST: + # 通过代理使用 Google 的 OpenAI 兼容接口 from langchain_openai import ChatOpenAI + return ChatOpenAI( model=settings.LLM_MODEL, api_key=api_key, @@ -32,33 +34,34 @@ class LLMHelper: base_url="https://generativelanguage.googleapis.com/v1beta/openai", temperature=settings.LLM_TEMPERATURE, streaming=streaming, - callbacks=callbacks, stream_usage=True, - openai_proxy=settings.PROXY_HOST + openai_proxy=settings.PROXY_HOST, ) else: + # 使用 langchain-google-genai 原生接口(v4 API 变更:google_api_key → api_key,max_retries → retries) from langchain_google_genai import ChatGoogleGenerativeAI + return ChatGoogleGenerativeAI( model=settings.LLM_MODEL, - google_api_key=api_key, - max_retries=3, + api_key=api_key, + retries=3, temperature=settings.LLM_TEMPERATURE, - streaming=streaming, - callbacks=callbacks + streaming=streaming ) elif provider == "deepseek": from langchain_deepseek import ChatDeepSeek + return ChatDeepSeek( model=settings.LLM_MODEL, api_key=api_key, max_retries=3, temperature=settings.LLM_TEMPERATURE, streaming=streaming, - callbacks=callbacks, - stream_usage=True + stream_usage=True, ) else: from langchain_openai import ChatOpenAI + return ChatOpenAI( model=settings.LLM_MODEL, api_key=api_key, @@ -66,12 +69,13 @@ class LLMHelper: base_url=settings.LLM_BASE_URL, temperature=settings.LLM_TEMPERATURE, streaming=streaming, - callbacks=callbacks, 
stream_usage=True, - openai_proxy=settings.PROXY_HOST + openai_proxy=settings.PROXY_HOST, ) - def get_models(self, provider: str, api_key: str, base_url: str = None) -> List[str]: + def get_models( + self, provider: str, api_key: str, base_url: str = None + ) -> List[str]: """获取模型列表""" logger.info(f"获取 {provider} 模型列表...") if provider == "google": @@ -81,18 +85,25 @@ class LLMHelper: @staticmethod def _get_google_models(api_key: str) -> List[str]: - """获取Google模型列表""" + """获取Google模型列表(使用 google-genai SDK v1)""" try: - import google.generativeai as genai - genai.configure(api_key=api_key) - models = genai.list_models() - return [m.name for m in models if 'generateContent' in m.supported_generation_methods] + from google import genai + + client = genai.Client(api_key=api_key) + models = client.models.list() + return [ + m.name + for m in models + if m.supported_actions and "generateContent" in m.supported_actions + ] except Exception as e: logger.error(f"获取Google模型列表失败:{e}") raise e @staticmethod - def _get_openai_compatible_models(provider: str, api_key: str, base_url: str = None) -> List[str]: + def _get_openai_compatible_models( + provider: str, api_key: str, base_url: str = None + ) -> List[str]: """获取OpenAI兼容模型列表""" try: from openai import OpenAI diff --git a/app/schemas/agent.py b/app/schemas/agent.py index 38d34e63..2bc88062 100644 --- a/app/schemas/agent.py +++ b/app/schemas/agent.py @@ -1,38 +1,37 @@ """AI智能体相关数据模型""" from datetime import datetime -from typing import Dict, List, Optional, Any +from typing import List, Optional + +from langchain_core.messages import BaseMessage from pydantic import BaseModel, Field, ConfigDict, field_serializer class ConversationMemory(BaseModel): """对话记忆模型""" - + session_id: str = Field(description="会话ID") user_id: Optional[str] = Field(default=None, description="用户ID") - title: Optional[str] = Field(default=None, description="会话标题") - messages: List[Dict[str, Any]] = Field(default_factory=list, description="消息列表") - 
context: Dict[str, Any] = Field(default_factory=dict, description="会话上下文") - created_at: datetime = Field(default_factory=datetime.now, description="创建时间") + messages: List[BaseMessage] = Field(default_factory=list, description="消息列表") updated_at: datetime = Field(default_factory=datetime.now, description="更新时间") - + model_config = ConfigDict() - - @field_serializer('created_at', 'updated_at', when_used='json') + + @field_serializer('updated_at', when_used='json') def serialize_datetime(self, value: datetime) -> str: return value.isoformat() class AgentState(BaseModel): """AI智能体状态模型""" - + session_id: str = Field(description="会话ID") current_task: Optional[str] = Field(default=None, description="当前任务") is_thinking: bool = Field(default=False, description="是否正在思考") last_activity: datetime = Field(default_factory=datetime.now, description="最后活动时间") - + model_config = ConfigDict() - + @field_serializer('last_activity', when_used='json') def serialize_datetime(self, value: datetime) -> str: return value.isoformat() @@ -40,7 +39,7 @@ class AgentState(BaseModel): class UserMessage(BaseModel): """用户消息模型""" - + session_id: str = Field(description="会话ID") content: str = Field(description="消息内容") user_id: Optional[str] = Field(default=None, description="用户ID") @@ -50,7 +49,7 @@ class UserMessage(BaseModel): class ToolResult(BaseModel): """工具执行结果模型""" - + session_id: str = Field(description="会话ID") call_id: str = Field(description="调用ID") success: bool = Field(description="是否成功") diff --git a/requirements.in b/requirements.in index b7427176..432f39c8 100644 --- a/requirements.in +++ b/requirements.in @@ -82,14 +82,13 @@ pympler~=1.1 smbprotocol~=1.15.0 setproctitle~=1.3.6 httpx[socks]~=0.28.1 -langchain~=0.3.27 -langchain-core~=0.3.76 -langchain-community~=0.3.29 -langchain-openai~=0.3.33 -langchain-google-genai~=2.0.10 -langchain-deepseek~=0.1.4 -langchain-experimental~=0.3.4 -openai~=1.108.2 -google-generativeai~=0.8.5 +langchain~=1.2.13 +langchain-core~=1.2.20 
+langchain-community~=0.4.1 +langchain-openai~=1.1.11 +langchain-google-genai~=4.2.1 +langchain-deepseek~=1.0.1 +langchain-experimental~=0.4.1 +openai~=2.29.0 ddgs~=9.10.0 websocket-client~=1.8.0