From 30b332ac7e3999b6ee884baa78c52cc75b68e526 Mon Sep 17 00:00:00 2001 From: jxxghp Date: Sun, 22 Mar 2026 23:35:34 +0800 Subject: [PATCH] feat: Introduce `MemoryMiddleware` and `PatchToolCallsMiddleware` to the agent, and add `EditFileTool` and `WriteFileTool` for file manipulation. --- app/agent/__init__.py | 17 +- app/agent/middleware/__init__.py | 0 app/agent/middleware/memory.py | 230 +++++++++++++++++++++++ app/agent/middleware/patch_tool_calls.py | 43 +++++ app/agent/tools/factory.py | 6 +- app/agent/tools/impl/edit_file.py | 75 ++++++++ app/agent/tools/impl/write_file.py | 52 +++++ 7 files changed, 415 insertions(+), 8 deletions(-) create mode 100644 app/agent/middleware/__init__.py create mode 100644 app/agent/middleware/memory.py create mode 100644 app/agent/middleware/patch_tool_calls.py create mode 100644 app/agent/tools/impl/edit_file.py create mode 100644 app/agent/tools/impl/write_file.py diff --git a/app/agent/__init__.py b/app/agent/__init__.py index c5ffdc25..55f2fea9 100644 --- a/app/agent/__init__.py +++ b/app/agent/__init__.py @@ -6,8 +6,6 @@ from typing import Dict, List from langchain.agents import create_agent from langchain.agents.middleware import ( SummarizationMiddleware, - ModelRetryMiddleware, - ToolRetryMiddleware, ) from langchain_core.messages import ( HumanMessage, @@ -17,9 +15,12 @@ from langgraph.checkpoint.memory import InMemorySaver from app.agent.callback import StreamingHandler, StreamingHandler from app.agent.memory import memory_manager +from app.agent.middleware.memory import MemoryMiddleware +from app.agent.middleware.patch_tool_calls import PatchToolCallsMiddleware from app.agent.prompt import prompt_manager from app.agent.tools.factory import MoviePilotToolFactory from app.chain import ChainBase +from app.core.config import settings from app.helper.llm import LLMHelper from app.log import logger from app.schemas import Notification @@ -91,15 +92,17 @@ class MoviePilotAgent: # 中间件 middlewares = [ + # 记忆管理 + 
MemoryMiddleware( + sources=[str(settings.CONFIG_PATH / "agent" / "MEMORY.md")] + ), # 上下文压缩 SummarizationMiddleware( model=llm, trigger=("fraction", 0.85) ), - # 模型调用失败时自动重试 - ModelRetryMiddleware(max_retries=3), - # 工具调用失败时自动重试 - ToolRetryMiddleware(max_retries=1) + # 错误工具调用修复 + PatchToolCallsMiddleware() ] return create_agent( @@ -174,7 +177,7 @@ class MoviePilotAgent: memory_manager.save_agent_messages( session_id=self.session_id, user_id=self.user_id, - messages=agent.get_state(agent_config).values().get("messages", []) + messages=agent.get_state(agent_config).values.get("messages", []) ) except asyncio.CancelledError:
diff --git a/app/agent/middleware/__init__.py b/app/agent/middleware/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/app/agent/middleware/memory.py b/app/agent/middleware/memory.py
new file mode 100644
index 00000000..5616314c
--- /dev/null
+++ b/app/agent/middleware/memory.py
@@ -0,0 +1,230 @@
+from collections.abc import Awaitable, Callable
+from typing import Annotated, NotRequired, TypedDict, Dict
+
+from aiopath import AsyncPath
+from langchain.agents.middleware.types import (
+    AgentMiddleware,
+    AgentState,
+    ContextT,
+    ModelRequest,
+    ModelResponse,
+    PrivateStateAttr,  # noqa
+    ResponseT,
+)
+from langchain_core.messages import SystemMessage, ContentBlock
+from langchain_core.runnables import RunnableConfig
+from langgraph.runtime import Runtime
+
+from app.log import logger
+
+
+class MemoryState(AgentState):
+    """State schema for `MemoryMiddleware`.
+
+    Attributes:
+        memory_contents: Dict mapping each source path to its loaded content.
+            Marked private, so it is not included in the final agent state.
+    """
+
+    memory_contents: NotRequired[Annotated[dict[str, str], PrivateStateAttr]]
+
+
+class MemoryStateUpdate(TypedDict):
+    """State update returned by `MemoryMiddleware`."""
+
+    memory_contents: dict[str, str]
+
+
+MEMORY_SYSTEM_PROMPT = """
+{agent_memory}
+
+
+
+ The above was loaded in from files in your filesystem.
As you learn from your interactions with the user, you can save new knowledge by calling the `edit_file` or `write_file` tool. + + **Learning from feedback:** + - One of your MAIN PRIORITIES is to learn from your interactions with the user. These learnings can be implicit or explicit. This means that in the future, you will remember this important information. + - When you need to remember something, updating memory must be your FIRST, IMMEDIATE action - before responding to the user, before calling other tools, before doing anything else. Just update memory immediately. + - When user says something is better/worse, capture WHY and encode it as a pattern. + - Each correction is a chance to improve permanently - don't just fix the immediate issue, update your instructions. + - A great opportunity to update your memories is when the user interrupts a tool call and provides feedback. You should update your memories immediately before revising the tool call. + - Look for the underlying principle behind corrections, not just the specific mistake. + - The user might not explicitly ask you to remember something, but if they provide information that is useful for future use, you should update your memories immediately. + + **Asking for information:** + - If you lack context to perform an action (e.g. send a Slack DM, requires a user ID/email) you should explicitly ask the user for this information. + - It is preferred for you to ask for information, don't assume anything that you do not know! + - When the user provides information that is useful for future use, you should update your memories immediately. 
+ + **When to update memories:** + - When the user explicitly asks you to remember something (e.g., "remember my email", "save this preference") + - When the user describes your role or how you should behave (e.g., "you are a web researcher", "always do X") + - When the user gives feedback on your work - capture what was wrong and how to improve + - When the user provides information required for tool use (e.g., slack channel ID, email addresses) + - When the user provides context useful for future tasks, such as how to use tools, or which actions to take in a particular situation + - When you discover new patterns or preferences (coding styles, conventions, workflows) + + **When to NOT update memories:** + - When the information is temporary or transient (e.g., "I'm running late", "I'm on my phone right now") + - When the information is a one-time task request (e.g., "Find me a recipe", "What's 25 * 4?") + - When the information is a simple question that doesn't reveal lasting preferences (e.g., "What day is it?", "Can you explain X?") + - When the information is an acknowledgment or small talk (e.g., "Sounds good!", "Hello", "Thanks for that") + - When the information is stale or irrelevant in future conversations + - Never store API keys, access tokens, passwords, or any other credentials in any file, memory, or system prompt. + - If the user asks where to put API keys or provides an API key, do NOT echo or save it. + + **Examples:** + Example 1 (remembering user information): + User: Can you connect to my google account? + Agent: Sure, I'll connect to your google account, what's your google account email? + User: john@example.com + Agent: Let me save this to my memory. + Tool Call: edit_file(...) -> remembers that the user's google account email is john@example.com + + Example 2 (remembering implicit user preferences): + User: Can you write me an example for creating a deep agent in LangChain? 
+ Agent: Sure, I'll write you an example for creating a deep agent in LangChain
+ User: Can you do this in JavaScript
+ Agent: Let me save this to my memory.
+ Tool Call: edit_file(...) -> remembers that the user prefers to get LangChain code examples in JavaScript
+ Agent: Sure, here is the JavaScript example
+
+ Example 3 (do not remember transient information):
+ User: I'm going to play basketball tonight so I will be offline for a few hours.
+ Agent: Okay I'll add a block to your calendar.
+ Tool Call: create_calendar_event(...) -> just calls a tool, does not commit anything to memory, as it is transient information
+
+"""
+
+
+def append_to_system_message(
+    system_message: SystemMessage | None,
+    text: str,
+) -> SystemMessage:
+    """Return a new system message with ``text`` appended as a text block.
+
+    Args:
+        system_message: Existing system message, or ``None``.
+        text: Text to append. It is prefixed with a blank line when the
+            message already has content blocks.
+
+    Returns:
+        A new ``SystemMessage`` built from the existing blocks plus ``text``.
+    """
+    blocks: list[ContentBlock] = list(system_message.content_blocks) if system_message else []  # noqa
+    if blocks:
+        text = f"\n\n{text}"
+    blocks.append({"type": "text", "text": text})
+    return SystemMessage(content_blocks=blocks)
+
+
+class MemoryMiddleware(AgentMiddleware[MemoryState, ContextT, ResponseT]):  # noqa
+    """Middleware that injects agent memory files into the system prompt.
+
+    The configured files are read before the agent runs, kept in the
+    ``memory_contents`` state key, and appended to the system message of
+    every model call. Multiple sources are concatenated in order.
+
+    Args:
+        sources: Paths of the memory files to load.
+    """
+
+    state_schema = MemoryState
+
+    def __init__(
+        self,
+        *,
+        sources: list[str],
+    ) -> None:
+        """Store the memory file paths.
+
+        Args:
+            sources: Paths of the memory files to load, in prompt order. The
+                first entry is suggested to the model when no memory exists.
+        """
+        self.sources = sources
+
+    def _format_agent_memory(self, contents: dict[str, str]) -> str:
+        """Render the loaded memory into ``MEMORY_SYSTEM_PROMPT``.
+
+        Args:
+            contents: Mapping of source path to file content.
+
+        Returns:
+            The prompt filled with one "path + content" section per non-empty
+            source, or with a "(No memory loaded)" placeholder otherwise.
+        """
+        sections = [f"{path}\n{contents[path]}" for path in self.sources if contents.get(path)]
+        if sections:
+            return MEMORY_SYSTEM_PROMPT.format(agent_memory="\n\n".join(sections))
+
+        # Nothing usable was loaded. ``sources`` may be an empty list, so only
+        # point the model at a file to create when one is configured.
+        if not self.sources:
+            return MEMORY_SYSTEM_PROMPT.format(agent_memory="(No memory loaded)")
+        return MEMORY_SYSTEM_PROMPT.format(
+            agent_memory=f"(No memory loaded), but you can add some by calling the `write_file` tool to the file: {self.sources[0]}.")
+
+    async def abefore_agent(self, state: MemoryState, runtime: Runtime,  # noqa
+                            config: RunnableConfig) -> MemoryStateUpdate | None:
+        """Load the memory files into the state before the agent runs.
+
+        Source files that do not exist are skipped.
+        Loading is skipped when ``memory_contents`` is already in the state.
+        NOTE(review): a memory file edited after the first load is therefore
+        not re-read while that state is reused - confirm this is intended.
+
+        Args:
+            state: Current agent state.
+            runtime: Runtime context (not used here).
+            config: Runnable config (not used here).
+
+        Returns:
+            A state update with the loaded ``memory_contents``, or ``None``
+            when the memory was already loaded.
+        """
+        if "memory_contents" in state:
+            return None
+
+        contents: Dict[str, str] = {}
+        for path in self.sources:
+            file_path = AsyncPath(path)
+            if not await file_path.exists():
+                continue
+            # UTF-8 is what the write_file / edit_file tools write.
+            contents[path] = await file_path.read_text(encoding="utf-8")
+            logger.debug("Loaded memory from: %s", path)
+
+        return MemoryStateUpdate(memory_contents=contents)
+
+    def modify_request(self, request: ModelRequest[ContextT]) -> ModelRequest[ContextT]:
+        """Return ``request`` with the memory appended to its system message.
+
+        Args:
+            request: Model request to extend.
+
+        Returns:
+            The request overridden with the extended system message.
+        """
+        contents = request.state.get("memory_contents", {})  # noqa
+        agent_memory = self._format_agent_memory(contents)
+        new_system_message = append_to_system_message(request.system_message, agent_memory)
+
+        return request.override(system_message=new_system_message)
+
+    async def awrap_model_call(
+        self,
+        request: ModelRequest[ContextT],
+        handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]],
+    ) -> ModelResponse[ResponseT]:
+        """Inject the memory into the system prompt, then call the model.
+
+        Args:
+            request: Model request being processed.
+            handler: Async handler invoked with the modified request.
+
+        Returns:
+            The model response returned by ``handler``.
+        """
+        modified_request = self.modify_request(request)
+        return await handler(modified_request)
diff --git a/app/agent/middleware/patch_tool_calls.py b/app/agent/middleware/patch_tool_calls.py
new file mode 100644
index 00000000..0bf6a48f
--- /dev/null
+++ b/app/agent/middleware/patch_tool_calls.py
@@ -0,0 +1,43 @@
+from typing import Any
+
+from langchain.agents.middleware
import AgentMiddleware, AgentState
+from langchain_core.messages import AIMessage, ToolMessage
+from langgraph.runtime import Runtime
+from langgraph.types import Overwrite
+
+
+class PatchToolCallsMiddleware(AgentMiddleware):
+    """Middleware that repairs dangling tool calls in the message history."""
+
+    def before_agent(self, state: AgentState, runtime: Runtime[Any]) -> dict[str, Any] | None:  # noqa: ARG002
+        """Answer every dangling tool call with a cancellation ``ToolMessage``.
+
+        Returns:
+            ``None`` when the history is empty or nothing had to be patched,
+            otherwise a ``messages`` update that overwrites the history.
+        """
+        messages = state["messages"]
+        if not messages:
+            return None
+
+        patched_messages = []
+        for i, msg in enumerate(messages):
+            patched_messages.append(msg)
+            if not (isinstance(msg, AIMessage) and msg.tool_calls):
+                continue
+            # Only tool results that come after this AI message can answer it.
+            answered = {m.tool_call_id for m in messages[i + 1:] if m.type == "tool"}
+            for tool_call in msg.tool_calls:
+                if tool_call["id"] in answered:
+                    continue
+                tool_msg = (
+                    f"Tool call {tool_call['name']} with id {tool_call['id']} was "
+                    "cancelled - another message came in before it could be completed."
+                )
+                patched_messages.append(
+                    ToolMessage(content=tool_msg, name=tool_call["name"], tool_call_id=tool_call["id"])
+                )
+        # Same length means no call was dangling, so the state needs no update.
+        if len(patched_messages) == len(messages):
+            return None
+        return {"messages": Overwrite(patched_messages)}
diff --git a/app/agent/tools/factory.py b/app/agent/tools/factory.py
index bcbe45f4..39f05a21 100644
--- a/app/agent/tools/factory.py
+++ b/app/agent/tools/factory.py
@@ -41,6 +41,8 @@ from app.agent.tools.impl.list_directory import ListDirectoryTool
 from app.agent.tools.impl.query_transfer_history import QueryTransferHistoryTool
 from app.agent.tools.impl.transfer_file import TransferFileTool
 from app.agent.tools.impl.execute_command import ExecuteCommandTool
+from app.agent.tools.impl.edit_file import EditFileTool
+from app.agent.tools.impl.write_file import WriteFileTool
 from app.core.plugin import PluginManager
 from app.log import logger
 from .base import MoviePilotTool
@@ -100,7 +102,9 @@ class MoviePilotToolFactory:
             RunSchedulerTool,
             QueryWorkflowsTool,
             RunWorkflowTool,
-            ExecuteCommandTool
+            ExecuteCommandTool,
+            EditFileTool,
+            WriteFileTool
         ]
         # 创建内置工具
         for ToolClass in tool_definitions:
diff --git a/app/agent/tools/impl/edit_file.py b/app/agent/tools/impl/edit_file.py
new file mode 100644
index 00000000..ddf97770
--- /dev/null
+++ b/app/agent/tools/impl/edit_file.py
@@ -0,0 +1,88 @@
+"""File editing tool."""
+
+from pathlib import Path
+from typing import Optional, Type
+
+from anyio import Path as AsyncPath
+from pydantic import BaseModel, Field
+
+from app.agent.tools.base import MoviePilotTool
+from app.log import logger
+
+
+class EditFileInput(BaseModel):
+    """Input parameters for edit file tool"""
+    file_path: str = Field(..., description="The absolute path of the file to edit")
+    old_text: str = Field(..., description="The exact old text to be replaced")
+    new_text: str = Field(..., description="The new text to replace with")
+
+
+class EditFileTool(MoviePilotTool):
+    """Agent tool that edits a text file by exact-substring replacement."""
+
+    name: str = "edit_file"
+    description: str = "Edit a file by replacing specific old text with new text. Useful for modifying configuration files, code, or scripts."
+    args_schema: Type[BaseModel] = EditFileInput
+
+    def get_tool_message(self, **kwargs) -> Optional[str]:
+        """Build a friendly status message from the call arguments."""
+        file_path = kwargs.get("file_path", "")
+        file_name = Path(file_path).name if file_path else "未知文件"
+        return f"正在编辑文件: {file_name}"
+
+    async def run(self, file_path: str, old_text: str, new_text: str, **kwargs) -> str:
+        """Replace every occurrence of ``old_text`` in the file with ``new_text``.
+
+        An empty ``old_text`` is accepted only when the file does not exist yet;
+        in that case the file is created with ``new_text`` as its content.
+
+        Args:
+            file_path: Path of the file to edit.
+            old_text: Exact text to search for.
+            new_text: Replacement text.
+
+        Returns:
+            A human-readable success or error message; failures are reported
+            in the returned text instead of being raised.
+        """
+        logger.info(f"执行工具: {self.name}, 参数: file_path={file_path}")
+
+        try:
+            path = AsyncPath(file_path)
+            exists = await path.exists()
+
+            if exists and not await path.is_file():
+                return f"错误:{file_path} 不是一个文件"
+
+            if not exists:
+                if old_text:
+                    return f"错误:文件 {file_path} 不存在,无法进行内容替换。"
+                # Missing file and empty old_text: initialise a new file.
+                await path.parent.mkdir(parents=True, exist_ok=True)
+                await path.write_text(new_text, encoding="utf-8")
+                logger.info(f"成功创建文件 {file_path}")
+                return f"成功创建文件 {file_path}"
+
+            if not old_text:
+                # str.replace("", x) would insert x between every character
+                # and corrupt the file, so an empty pattern is rejected.
+                return f"错误:old_text 不能为空。文件 {file_path} 已存在,如需整体覆盖请使用 write_file 工具。"
+
+            content = await path.read_text(encoding="utf-8")
+            occurrences = content.count(old_text)
+            if occurrences == 0:
+                logger.warning(f"编辑文件 {file_path} 失败:未找到指定的旧文本块")
+                return f"错误:在文件 {file_path} 中未找到指定的旧文本。请确保包含所有的空格、缩进和换行符。"
+
+            await path.write_text(content.replace(old_text, new_text), encoding="utf-8")
+
+            logger.info(f"成功编辑文件 {file_path},替换了 {occurrences} 处内容")
+            return f"成功编辑文件 {file_path} (替换了 {occurrences} 处匹配内容)"
+
+        except PermissionError:
+            return f"错误:没有访问/修改 {file_path} 的权限"
+        except UnicodeDecodeError:
+            return f"错误:{file_path} 不是文本文件,无法编辑"
+        except Exception as e:
+            logger.error(f"编辑文件 {file_path} 时发生错误: {str(e)}", exc_info=True)
+            return f"操作失败: {str(e)}"
diff --git a/app/agent/tools/impl/write_file.py b/app/agent/tools/impl/write_file.py
new file mode 100644
index 00000000..41be6e3b
--- /dev/null
+++ b/app/agent/tools/impl/write_file.py
@@ -0,0 +1,52 @@
+"""File writing tool."""
+
+from pathlib import Path
+from typing import Optional, Type
+
+from anyio import Path
as AsyncPath
+from pydantic import BaseModel, Field
+
+from app.agent.tools.base import MoviePilotTool
+from app.log import logger
+
+
+class WriteFileInput(BaseModel):
+    """Input parameters for write file tool"""
+    file_path: str = Field(..., description="The absolute path of the file to write")
+    content: str = Field(..., description="The content to write into the file")
+
+
+class WriteFileTool(MoviePilotTool):
+    """Agent tool that writes text to a file, replacing existing content."""
+
+    name: str = "write_file"
+    description: str = "Write full content to a file. If the file already exists, it will be overwritten. Automatically creates parent directories if they don't exist."
+    args_schema: Type[BaseModel] = WriteFileInput
+
+    def get_tool_message(self, **kwargs) -> Optional[str]:
+        """Build a friendly status message from the call arguments."""
+        target = kwargs.get("file_path", "")
+        shown_name = "未知文件"
+        if target:
+            shown_name = Path(target).name
+        return f"正在写入文件: {shown_name}"
+
+    async def run(self, file_path: str, content: str, **kwargs) -> str:
+        """Write ``content`` to ``file_path`` as UTF-8 and report the outcome as text."""
+        logger.info(f"执行工具: {self.name}, 参数: file_path={file_path}")
+        try:
+            destination = AsyncPath(file_path)
+            # An existing path that is not a regular file must not be overwritten.
+            blocked = await destination.exists() and not await destination.is_file()
+            if blocked:
+                return f"错误:{file_path} 路径已存在但不是一个文件"
+            # Create missing parent directories before writing.
+            await destination.parent.mkdir(parents=True, exist_ok=True)
+            await destination.write_text(content, encoding="utf-8")
+            logger.info(f"成功写入文件 {file_path}")
+            return f"成功写入文件 {file_path}"
+        except PermissionError:
+            return f"错误:没有权限写入 {file_path}"
+        except Exception as e:
+            logger.error(f"写入文件 {file_path} 时发生错误: {str(e)}", exc_info=True)
+            return f"操作失败: {str(e)}"