mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-03-20 03:57:30 +08:00
fix agent
This commit is contained in:
@@ -1,17 +1,16 @@
|
||||
"""MoviePilot AI智能体实现"""
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
from typing import Dict, List, Any
|
||||
|
||||
from langchain.agents import AgentExecutor, create_openai_tools_agent
|
||||
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
from langchain_community.callbacks import get_openai_callback
|
||||
from langchain_core.callbacks import AsyncCallbackHandler
|
||||
from langchain_core.chat_history import InMemoryChatMessageHistory
|
||||
from langchain_core.messages import HumanMessage, AIMessage, ToolCall
|
||||
from langchain_core.runnables.history import RunnableWithMessageHistory
|
||||
|
||||
from app.agent.callback import StreamingCallbackHandler
|
||||
from app.agent.memory import ConversationMemoryManager
|
||||
from app.agent.prompt import PromptManager
|
||||
from app.agent.tools import MoviePilotToolFactory
|
||||
@@ -26,34 +25,6 @@ class AgentChain(ChainBase):
|
||||
pass
|
||||
|
||||
|
||||
class StreamingCallbackHandler(AsyncCallbackHandler):
|
||||
"""流式输出回调处理器"""
|
||||
|
||||
def __init__(self, session_id: str):
|
||||
self._lock = threading.Lock()
|
||||
self.session_id = session_id
|
||||
self.current_message = ""
|
||||
self.message_helper = MessageHelper()
|
||||
|
||||
async def get_message(self):
|
||||
"""获取当前消息内容,获取后清空"""
|
||||
with self._lock:
|
||||
if not self.current_message:
|
||||
return ""
|
||||
msg = self.current_message
|
||||
logger.info(f"Agent消息: {msg}")
|
||||
self.current_message = ""
|
||||
return msg
|
||||
|
||||
async def on_llm_new_token(self, token: str, **kwargs):
|
||||
"""处理新的token"""
|
||||
if not token:
|
||||
return
|
||||
with self._lock:
|
||||
# 缓存当前消息
|
||||
self.current_message += token
|
||||
|
||||
|
||||
class MoviePilotAgent:
|
||||
"""MoviePilot AI智能体"""
|
||||
|
||||
@@ -142,7 +113,8 @@ class MoviePilotAgent:
|
||||
user_id=self.user_id,
|
||||
channel=self.channel,
|
||||
source=self.source,
|
||||
username=self.username
|
||||
username=self.username,
|
||||
callback_handler=self.callback_handler
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -249,7 +221,7 @@ class MoviePilotAgent:
|
||||
agent_message = await self.callback_handler.get_message()
|
||||
|
||||
# 发送Agent回复给用户(通过原渠道)
|
||||
self.send_agent_message(agent_message)
|
||||
await self.send_agent_message(agent_message)
|
||||
|
||||
# 添加Agent回复到记忆
|
||||
await self.memory_manager.add_memory(
|
||||
@@ -265,7 +237,7 @@ class MoviePilotAgent:
|
||||
error_message = f"处理消息时发生错误: {str(e)}"
|
||||
logger.error(error_message)
|
||||
# 发送错误消息给用户(通过原渠道)
|
||||
self.send_agent_message(error_message)
|
||||
await self.send_agent_message(error_message)
|
||||
return error_message
|
||||
|
||||
async def _execute_agent(self, input_context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
@@ -301,9 +273,9 @@ class MoviePilotAgent:
|
||||
"token_usage": {}
|
||||
}
|
||||
|
||||
def send_agent_message(self, message: str, title: str = "MoviePilot助手"):
|
||||
async def send_agent_message(self, message: str, title: str = "MoviePilot助手"):
|
||||
"""通过原渠道发送消息给用户"""
|
||||
AgentChain().post_message(
|
||||
await AgentChain().async_post_message(
|
||||
Notification(
|
||||
channel=self.channel,
|
||||
source=self.source,
|
||||
|
||||
Reference in New Issue
Block a user