mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-02-02 18:22:39 +08:00
add ai agent
This commit is contained in:
0
app/agent/__init__.py
Normal file
0
app/agent/__init__.py
Normal file
355
app/agent/agent.py
Normal file
355
app/agent/agent.py
Normal file
@@ -0,0 +1,355 @@
|
||||
"""MoviePilot AI智能体实现"""
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
from typing import Dict, List, Any
|
||||
|
||||
from langchain.agents import AgentExecutor, create_openai_tools_agent
|
||||
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
from langchain_community.callbacks import get_openai_callback
|
||||
from langchain_core.callbacks import AsyncCallbackHandler
|
||||
from langchain_core.chat_history import InMemoryChatMessageHistory
|
||||
from langchain_core.messages import HumanMessage, AIMessage, ToolCall
|
||||
from langchain_core.runnables.history import RunnableWithMessageHistory
|
||||
|
||||
from app.agent.memory import ConversationMemoryManager
|
||||
from app.agent.prompt import PromptManager
|
||||
from app.agent.tools import MoviePilotToolFactory
|
||||
from app.core.config import settings
|
||||
from app.helper.message import MessageHelper
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class StreamingCallbackHandler(AsyncCallbackHandler):
|
||||
"""流式输出回调处理器"""
|
||||
|
||||
def __init__(self, session_id: str):
|
||||
self._lock = threading.Lock()
|
||||
self.session_id = session_id
|
||||
self.current_message = ""
|
||||
self.message_helper = MessageHelper()
|
||||
|
||||
async def get_message(self):
|
||||
"""获取当前消息内容,获取后清空"""
|
||||
with self._lock:
|
||||
if not self.current_message:
|
||||
return ""
|
||||
msg = self.current_message
|
||||
logger.info(f"Agent消息: {msg}")
|
||||
self.current_message = ""
|
||||
return msg
|
||||
|
||||
async def on_llm_new_token(self, token: str, **kwargs):
|
||||
"""处理新的token"""
|
||||
if not token:
|
||||
return
|
||||
with self._lock:
|
||||
# 缓存当前消息
|
||||
self.current_message += token
|
||||
|
||||
|
||||
class MoviePilotAgent:
|
||||
"""MoviePilot AI智能体"""
|
||||
|
||||
def __init__(self, session_id: str, user_id: str = None):
|
||||
self.session_id = session_id
|
||||
self.user_id = user_id
|
||||
|
||||
# 消息助手
|
||||
self.message_helper = MessageHelper()
|
||||
|
||||
# 记忆管理器
|
||||
self.memory_manager = ConversationMemoryManager()
|
||||
|
||||
# 提示词管理器
|
||||
self.prompt_manager = PromptManager()
|
||||
|
||||
# 回调处理器
|
||||
self.callback_handler = StreamingCallbackHandler(
|
||||
session_id=session_id
|
||||
)
|
||||
|
||||
# LLM模型
|
||||
self.llm = self._initialize_llm()
|
||||
|
||||
# 工具
|
||||
self.tools = self._initialize_tools()
|
||||
|
||||
# 会话存储
|
||||
self.session_store = self._initialize_session_store()
|
||||
|
||||
# 提示词模板
|
||||
self.prompt = self._initialize_prompt()
|
||||
|
||||
# Agent执行器
|
||||
self.agent_executor = self._create_agent_executor()
|
||||
|
||||
def _initialize_llm(self):
|
||||
"""初始化LLM模型"""
|
||||
provider = settings.LLM_PROVIDER.lower()
|
||||
api_key = settings.LLM_API_KEY
|
||||
if not api_key:
|
||||
raise ValueError("未配置 LLM_API_KEY")
|
||||
|
||||
if provider == "google":
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
return ChatGoogleGenerativeAI(
|
||||
model=settings.LLM_MODEL,
|
||||
google_api_key=api_key,
|
||||
max_retries=3,
|
||||
temperature=settings.LLM_TEMPERATURE,
|
||||
streaming=True,
|
||||
callbacks=[self.callback_handler]
|
||||
)
|
||||
elif provider == "deepseek":
|
||||
from langchain_deepseek import ChatDeepSeek
|
||||
return ChatDeepSeek(
|
||||
model=settings.LLM_MODEL,
|
||||
api_key=api_key,
|
||||
max_retries=3,
|
||||
temperature=settings.LLM_TEMPERATURE,
|
||||
streaming=True,
|
||||
callbacks=[self.callback_handler],
|
||||
stream_usage=True
|
||||
)
|
||||
else:
|
||||
from langchain_openai import ChatOpenAI
|
||||
return ChatOpenAI(
|
||||
model=settings.LLM_MODEL,
|
||||
api_key=api_key,
|
||||
max_retries=3,
|
||||
base_url=settings.LLM_BASE_URL,
|
||||
temperature=settings.LLM_TEMPERATURE,
|
||||
streaming=True,
|
||||
callbacks=[self.callback_handler],
|
||||
stream_usage=True
|
||||
)
|
||||
|
||||
def _initialize_tools(self) -> List:
|
||||
"""初始化工具列表"""
|
||||
return MoviePilotToolFactory.create_tools(
|
||||
session_id=self.session_id,
|
||||
user_id=self.user_id,
|
||||
message_helper=self.message_helper
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _initialize_session_store() -> Dict[str, InMemoryChatMessageHistory]:
|
||||
"""初始化内存存储"""
|
||||
return {}
|
||||
|
||||
def get_session_history(self, session_id: str) -> InMemoryChatMessageHistory:
|
||||
"""获取会话历史"""
|
||||
if session_id not in self.session_store:
|
||||
chat_history = InMemoryChatMessageHistory()
|
||||
messages: List[dict] = self.memory_manager.get_recent_messages_for_agent(
|
||||
session_id=session_id,
|
||||
user_id=self.user_id
|
||||
)
|
||||
if messages:
|
||||
for msg in messages:
|
||||
if msg.get("role") == "user":
|
||||
chat_history.add_user_message(HumanMessage(content=msg.get("content", "")))
|
||||
elif msg.get("role") == "agent":
|
||||
chat_history.add_ai_message(AIMessage(content=msg.get("content", "")))
|
||||
elif msg.get("role") == "tool_call":
|
||||
metadata = msg.get("metadata", {})
|
||||
chat_history.add_ai_message(AIMessage(
|
||||
content=msg.get("content", ""),
|
||||
tool_calls=[ToolCall(
|
||||
id=metadata.get("call_id"),
|
||||
name=metadata.get("tool_name"),
|
||||
args=metadata.get("parameters"),
|
||||
)]
|
||||
))
|
||||
elif msg.get("role") == "tool_result":
|
||||
chat_history.add_ai_message(AIMessage(content=msg.get("content", "")))
|
||||
elif msg.get("role") == "system":
|
||||
chat_history.add_ai_message(AIMessage(content=msg.get("content", "")))
|
||||
self.session_store[session_id] = chat_history
|
||||
return self.session_store[session_id]
|
||||
|
||||
@staticmethod
|
||||
def _initialize_prompt() -> ChatPromptTemplate:
|
||||
"""初始化提示词模板"""
|
||||
try:
|
||||
prompt = ChatPromptTemplate.from_messages([
|
||||
("system", "{system_prompt}"),
|
||||
MessagesPlaceholder(variable_name="chat_history"),
|
||||
("user", "{input}"),
|
||||
MessagesPlaceholder(variable_name="agent_scratchpad"),
|
||||
])
|
||||
logger.info("LangChain提示词模板初始化成功")
|
||||
return prompt
|
||||
except Exception as e:
|
||||
logger.error(f"初始化提示词失败: {e}")
|
||||
raise e
|
||||
|
||||
def _create_agent_executor(self) -> RunnableWithMessageHistory:
|
||||
"""创建Agent执行器"""
|
||||
try:
|
||||
agent = create_openai_tools_agent(
|
||||
llm=self.llm,
|
||||
tools=self.tools,
|
||||
prompt=self.prompt
|
||||
)
|
||||
executor = AgentExecutor(
|
||||
agent=agent,
|
||||
tools=self.tools,
|
||||
verbose=settings.LLM_VERBOSE,
|
||||
max_iterations=settings.LLM_MAX_ITERATIONS,
|
||||
return_intermediate_steps=True,
|
||||
handle_parsing_errors=True,
|
||||
early_stopping_method="force"
|
||||
)
|
||||
return RunnableWithMessageHistory(
|
||||
executor,
|
||||
self.get_session_history,
|
||||
input_messages_key="input",
|
||||
history_messages_key="chat_history"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"创建Agent执行器失败: {e}")
|
||||
raise e
|
||||
|
||||
async def process_message(self, message: str) -> str:
|
||||
"""处理用户消息"""
|
||||
try:
|
||||
# 添加用户消息到记忆
|
||||
await self.memory_manager.add_memory(
|
||||
self.session_id,
|
||||
user_id=self.user_id,
|
||||
role="user",
|
||||
content=message
|
||||
)
|
||||
|
||||
# 构建输入上下文
|
||||
input_context = {
|
||||
"system_prompt": self.prompt_manager.get_agent_prompt(),
|
||||
"input": message
|
||||
}
|
||||
|
||||
# 执行Agent
|
||||
logger.info(f"Agent执行推理: session_id={self.session_id}, input={message}")
|
||||
result = await self._execute_agent(input_context)
|
||||
|
||||
# 获取Agent回复
|
||||
agent_message = await self.callback_handler.get_message()
|
||||
|
||||
# 发送Agent回复给用户
|
||||
self.message_helper.put(
|
||||
message=agent_message,
|
||||
role="system",
|
||||
title="AI助手回复"
|
||||
)
|
||||
|
||||
# 添加Agent回复到记忆
|
||||
await self.memory_manager.add_memory(
|
||||
session_id=self.session_id,
|
||||
user_id=self.user_id,
|
||||
role="agent",
|
||||
content=agent_message
|
||||
)
|
||||
|
||||
return agent_message
|
||||
|
||||
except Exception as e:
|
||||
error_message = f"处理消息时发生错误: {str(e)}"
|
||||
logger.error(error_message)
|
||||
# 发送错误消息给用户
|
||||
self.message_helper.put(
|
||||
message=error_message,
|
||||
role="system",
|
||||
title="AI助手错误"
|
||||
)
|
||||
return error_message
|
||||
|
||||
async def _execute_agent(self, input_context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""执行LangChain Agent"""
|
||||
try:
|
||||
with get_openai_callback() as cb:
|
||||
result = await self.agent_executor.ainvoke(
|
||||
input_context,
|
||||
config={"configurable": {"session_id": self.session_id}},
|
||||
callbacks=[self.callback_handler]
|
||||
)
|
||||
logger.info(f"LLM调用消耗: \n{cb}")
|
||||
|
||||
if cb.total_tokens > 0:
|
||||
result["token_usage"] = {
|
||||
"prompt_tokens": cb.prompt_tokens,
|
||||
"completion_tokens": cb.completion_tokens,
|
||||
"total_tokens": cb.total_tokens
|
||||
}
|
||||
return result
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"Agent执行被取消: session_id={self.session_id}")
|
||||
return {
|
||||
"output": "任务已取消",
|
||||
"intermediate_steps": [],
|
||||
"token_usage": {}
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Agent执行失败: {e}")
|
||||
return {
|
||||
"output": f"执行过程中发生错误: {str(e)}",
|
||||
"intermediate_steps": [],
|
||||
"token_usage": {}
|
||||
}
|
||||
|
||||
async def cleanup(self):
|
||||
"""清理智能体资源"""
|
||||
if self.session_id in self.session_store:
|
||||
del self.session_store[self.session_id]
|
||||
logger.info(f"MoviePilot智能体已清理: session_id={self.session_id}")
|
||||
|
||||
|
||||
class AgentManager:
|
||||
"""AI智能体管理器"""
|
||||
|
||||
def __init__(self):
|
||||
self.active_agents: Dict[str, MoviePilotAgent] = {}
|
||||
self.memory_manager = ConversationMemoryManager()
|
||||
|
||||
async def initialize(self):
|
||||
"""初始化管理器"""
|
||||
await self.memory_manager.initialize()
|
||||
|
||||
async def close(self):
|
||||
"""关闭管理器"""
|
||||
await self.memory_manager.close()
|
||||
# 清理所有活跃的智能体
|
||||
for agent in self.active_agents.values():
|
||||
await agent.cleanup()
|
||||
self.active_agents.clear()
|
||||
|
||||
async def process_message(self, session_id: str, user_id: str, message: str) -> str:
|
||||
"""处理用户消息"""
|
||||
# 获取或创建Agent实例
|
||||
if session_id not in self.active_agents:
|
||||
logger.info(f"创建新的AI智能体实例,session_id: {session_id}, user_id: {user_id}")
|
||||
agent = MoviePilotAgent(
|
||||
session_id=session_id,
|
||||
user_id=user_id
|
||||
)
|
||||
agent.memory_manager = self.memory_manager
|
||||
self.active_agents[session_id] = agent
|
||||
else:
|
||||
agent = self.active_agents[session_id]
|
||||
agent.user_id = user_id # 确保user_id是最新的
|
||||
|
||||
# 处理消息
|
||||
return await agent.process_message(message)
|
||||
|
||||
async def clear_session(self, session_id: str, user_id: str):
|
||||
"""清空会话"""
|
||||
if session_id in self.active_agents:
|
||||
agent = self.active_agents[session_id]
|
||||
await agent.cleanup()
|
||||
del self.active_agents[session_id]
|
||||
await self.memory_manager.clear_memory(session_id, user_id)
|
||||
logger.info(f"会话 {session_id} 的记忆已清空")
|
||||
|
||||
|
||||
# 全局智能体管理器实例
|
||||
agent_manager = AgentManager()
|
||||
280
app/agent/memory/__init__.py
Normal file
280
app/agent/memory/__init__.py
Normal file
@@ -0,0 +1,280 @@
|
||||
"""对话记忆管理器"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Any
|
||||
|
||||
from app.core.config import settings
|
||||
from app.helper.redis import AsyncRedisHelper
|
||||
from app.log import logger
|
||||
from app.schemas.agent import ConversationMemory
|
||||
|
||||
|
||||
class ConversationMemoryManager:
|
||||
"""对话记忆管理器"""
|
||||
|
||||
def __init__(self):
|
||||
# 内存中的会话记忆缓存
|
||||
self.memory_cache: Dict[str, ConversationMemory] = {}
|
||||
# 使用现有的Redis助手
|
||||
self.redis_helper = AsyncRedisHelper()
|
||||
# 内存缓存清理任务(Redis通过TTL自动过期)
|
||||
self.cleanup_task: Optional[asyncio.Task] = None
|
||||
|
||||
async def initialize(self):
|
||||
"""初始化记忆管理器"""
|
||||
try:
|
||||
# 启动内存缓存清理任务(Redis通过TTL自动过期)
|
||||
self.cleanup_task = asyncio.create_task(self._cleanup_expired_memories())
|
||||
logger.info("对话记忆管理器初始化完成")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Redis连接失败,将使用内存存储: {e}")
|
||||
|
||||
async def close(self):
|
||||
"""关闭记忆管理器"""
|
||||
if self.cleanup_task:
|
||||
self.cleanup_task.cancel()
|
||||
try:
|
||||
await self.cleanup_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
await self.redis_helper.close()
|
||||
|
||||
logger.info("对话记忆管理器已关闭")
|
||||
|
||||
async def get_memory(self, session_id: str, user_id: str) -> ConversationMemory:
|
||||
"""获取会话记忆"""
|
||||
# 首先检查缓存
|
||||
cache_key = f"{user_id}:{session_id}" if user_id else session_id
|
||||
if cache_key in self.memory_cache:
|
||||
return self.memory_cache[cache_key]
|
||||
|
||||
# 尝试从Redis加载
|
||||
if settings.CACHE_BACKEND_TYPE == "redis":
|
||||
try:
|
||||
redis_key = f"agent_memory:{user_id}:{session_id}" if user_id else f"agent_memory:{session_id}"
|
||||
memory_data = await self.redis_helper.get(redis_key, region="AI_AGENT")
|
||||
if memory_data:
|
||||
memory_dict = json.loads(memory_data) if isinstance(memory_data, str) else memory_data
|
||||
memory = ConversationMemory(**memory_dict)
|
||||
self.memory_cache[cache_key] = memory
|
||||
return memory
|
||||
except Exception as e:
|
||||
logger.warning(f"从Redis加载记忆失败: {e}")
|
||||
|
||||
# 创建新的记忆
|
||||
memory = ConversationMemory(session_id=session_id, user_id=user_id)
|
||||
self.memory_cache[cache_key] = memory
|
||||
await self._save_memory(memory)
|
||||
|
||||
return memory
|
||||
|
||||
async def set_title(self, session_id: str, user_id: str, title: str):
|
||||
"""设置会话标题"""
|
||||
memory = await self.get_memory(session_id=session_id, user_id=user_id)
|
||||
memory.title = title
|
||||
memory.updated_at = datetime.now()
|
||||
await self._save_memory(memory)
|
||||
|
||||
async def get_title(self, session_id: str, user_id: str) -> Optional[str]:
|
||||
"""获取会话标题"""
|
||||
memory = await self.get_memory(session_id=session_id, user_id=user_id)
|
||||
return memory.title
|
||||
|
||||
async def list_sessions(self, user_id: str, limit: int = 100) -> List[Dict[str, Any]]:
|
||||
"""列出历史会话摘要(按更新时间倒序)
|
||||
|
||||
- 当启用Redis时:遍历 `agent_memory:*` 键并读取摘要
|
||||
- 当未启用Redis时:基于内存缓存返回
|
||||
"""
|
||||
sessions: List[ConversationMemory] = []
|
||||
# 从Redis遍历
|
||||
if settings.CACHE_BACKEND_TYPE == "redis":
|
||||
try:
|
||||
# 使用Redis助手的items方法遍历所有键
|
||||
async for key, value in self.redis_helper.items(region="AI_AGENT"):
|
||||
if key.startswith("agent_memory:"):
|
||||
try:
|
||||
# 解析键名获取user_id和session_id
|
||||
key_parts = key.split(":")
|
||||
if len(key_parts) >= 3:
|
||||
key_user_id = key_parts[2] if len(key_parts) > 3 else None
|
||||
if not user_id or key_user_id == user_id:
|
||||
data = value if isinstance(value, dict) else json.loads(value)
|
||||
memory = ConversationMemory(**data)
|
||||
sessions.append(memory)
|
||||
except Exception as err:
|
||||
logger.warning(f"解析Redis记忆数据失败: {err}")
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.warning(f"遍历Redis会话失败: {e}")
|
||||
|
||||
# 合并内存缓存(确保包含近期的会话)
|
||||
for cache_key, memory in self.memory_cache.items():
|
||||
# 如果指定了user_id,只返回该用户的会话
|
||||
if not user_id or memory.user_id == user_id:
|
||||
sessions.append(memory)
|
||||
|
||||
# 去重(以 session_id 为键,取最近updated)
|
||||
uniq: Dict[str, ConversationMemory] = {}
|
||||
for mem in sessions:
|
||||
existed = uniq.get(mem.session_id)
|
||||
if (not existed) or (mem.updated_at > existed.updated_at):
|
||||
uniq[mem.session_id] = mem
|
||||
|
||||
# 排序并裁剪
|
||||
sorted_list = sorted(uniq.values(), key=lambda m: m.updated_at, reverse=True)[:limit]
|
||||
return [
|
||||
{
|
||||
"session_id": m.session_id,
|
||||
"title": m.title or "新会话",
|
||||
"message_count": len(m.messages),
|
||||
"created_at": m.created_at.isoformat(),
|
||||
"updated_at": m.updated_at.isoformat(),
|
||||
}
|
||||
for m in sorted_list
|
||||
]
|
||||
|
||||
async def add_memory(
|
||||
self,
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
role: str,
|
||||
content: str,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
):
|
||||
"""添加消息到记忆"""
|
||||
memory = await self.get_memory(session_id=session_id, user_id=user_id)
|
||||
|
||||
message = {
|
||||
"role": role,
|
||||
"content": content,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"metadata": metadata or {}
|
||||
}
|
||||
|
||||
memory.messages.append(message)
|
||||
memory.updated_at = datetime.now()
|
||||
|
||||
# 限制消息数量,避免记忆过大
|
||||
max_messages = settings.LLM_MAX_MEMORY_MESSAGES
|
||||
if len(memory.messages) > max_messages:
|
||||
# 保留最近的消息,但保留第一条系统消息
|
||||
system_messages = [msg for msg in memory.messages if msg["role"] == "system"]
|
||||
recent_messages = memory.messages[-(max_messages - len(system_messages)):]
|
||||
memory.messages = system_messages + recent_messages
|
||||
|
||||
await self._save_memory(memory)
|
||||
|
||||
logger.debug(f"消息已添加到记忆: session_id={session_id}, user_id={user_id}, role={role}")
|
||||
|
||||
def get_recent_messages_for_agent(
|
||||
self,
|
||||
session_id: str,
|
||||
user_id: str
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""为Agent获取最近的消息(仅内存缓存)
|
||||
|
||||
如果消息Token数量超过模型最大上下文长度的阀值,会自动进行摘要裁剪
|
||||
"""
|
||||
cache_key = f"{user_id}:{session_id}" if user_id else session_id
|
||||
memory = self.memory_cache.get(cache_key)
|
||||
if not memory:
|
||||
return []
|
||||
|
||||
# 获取所有消息
|
||||
messages = memory.messages
|
||||
|
||||
return messages
|
||||
|
||||
async def get_recent_messages(
|
||||
self,
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
limit: int = 10,
|
||||
role_filter: Optional[list] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""获取最近的消息"""
|
||||
memory = await self.get_memory(session_id=session_id, user_id=user_id)
|
||||
|
||||
messages = memory.messages
|
||||
if role_filter:
|
||||
messages = [msg for msg in messages if msg["role"] in role_filter]
|
||||
|
||||
return messages[-limit:] if messages else []
|
||||
|
||||
async def get_context(self, session_id: str, user_id: str) -> Dict[str, Any]:
|
||||
"""获取会话上下文"""
|
||||
memory = await self.get_memory(session_id=session_id, user_id=user_id)
|
||||
return memory.context
|
||||
|
||||
async def clear_memory(self, session_id: str, user_id: str):
|
||||
"""清空会话记忆"""
|
||||
cache_key = f"{user_id}:{session_id}" if user_id else session_id
|
||||
if cache_key in self.memory_cache:
|
||||
del self.memory_cache[cache_key]
|
||||
|
||||
if settings.CACHE_BACKEND_TYPE == "redis":
|
||||
redis_key = f"agent_memory:{user_id}:{session_id}" if user_id else f"agent_memory:{session_id}"
|
||||
await self.redis_helper.delete(redis_key, region="AI_AGENT")
|
||||
|
||||
logger.info(f"会话记忆已清空: session_id={session_id}, user_id={user_id}")
|
||||
|
||||
async def _save_memory(self, memory: ConversationMemory):
|
||||
"""保存记忆到存储
|
||||
|
||||
Redis中的记忆会自动通过TTL机制过期,无需手动清理
|
||||
"""
|
||||
# 更新内存缓存
|
||||
cache_key = f"{memory.user_id}:{memory.session_id}" if memory.user_id else memory.session_id
|
||||
self.memory_cache[cache_key] = memory
|
||||
|
||||
# 保存到Redis,设置TTL自动过期
|
||||
if settings.CACHE_BACKEND_TYPE == "redis":
|
||||
try:
|
||||
memory_dict = memory.model_dump()
|
||||
redis_key = f"agent_memory:{memory.user_id}:{memory.session_id}" if memory.user_id else f"agent_memory:{memory.session_id}"
|
||||
ttl = int(timedelta(days=settings.LLM_REDIS_MEMORY_RETENTION_DAYS).total_seconds())
|
||||
await self.redis_helper.set(
|
||||
redis_key,
|
||||
memory_dict,
|
||||
ttl=ttl,
|
||||
region="AI_AGENT"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"保存记忆到Redis失败: {e}")
|
||||
|
||||
async def _cleanup_expired_memories(self):
|
||||
"""清理内存中过期记忆的后台任务
|
||||
|
||||
注意:Redis中的记忆通过TTL机制自动过期,这里只清理内存缓存
|
||||
"""
|
||||
while True:
|
||||
try:
|
||||
# 每小时清理一次
|
||||
await asyncio.sleep(3600)
|
||||
|
||||
current_time = datetime.now()
|
||||
expired_sessions = []
|
||||
|
||||
# 只检查内存缓存中的过期记忆
|
||||
# Redis中的记忆会通过TTL自动过期,无需手动处理
|
||||
for cache_key, memory in self.memory_cache.items():
|
||||
if (current_time - memory.updated_at).days > settings.LLM_MEMORY_RETENTION_DAYS:
|
||||
expired_sessions.append(cache_key)
|
||||
|
||||
# 只清理内存缓存,不删除Redis中的键(Redis会自动过期)
|
||||
for cache_key in expired_sessions:
|
||||
if cache_key in self.memory_cache:
|
||||
del self.memory_cache[cache_key]
|
||||
|
||||
if expired_sessions:
|
||||
logger.info(f"清理了{len(expired_sessions)}个过期内存会话记忆")
|
||||
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"清理记忆时发生错误: {e}")
|
||||
83
app/agent/prompt/Agent Prompt.txt
Normal file
83
app/agent/prompt/Agent Prompt.txt
Normal file
@@ -0,0 +1,83 @@
|
||||
You are MoviePilot's AI assistant, specialized in helping users manage media resources including subscriptions, searching, downloading, and organization.
|
||||
|
||||
## Your Identity and Capabilities
|
||||
|
||||
You are an AI agent for the MoviePilot media management system with the following core capabilities:
|
||||
|
||||
### Media Management Capabilities
|
||||
- **Search Media Resources**: Search for movies, TV shows, anime, and other media content based on user requirements
|
||||
- **Add Subscriptions**: Create subscription rules for media content that users are interested in
|
||||
- **Manage Downloads**: Search and add torrent resources to downloaders
|
||||
- **Query Status**: Check subscription status, download progress, and media library status
|
||||
|
||||
### Intelligent Interaction Capabilities
|
||||
- **Natural Language Understanding**: Understand user requests in natural language (Chinese/English)
|
||||
- **Context Memory**: Remember conversation history and user preferences
|
||||
- **Smart Recommendations**: Recommend related media content based on user preferences
|
||||
- **Task Execution**: Automatically execute complex media management tasks
|
||||
|
||||
## Working Principles
|
||||
|
||||
1. **Always respond in Chinese**: All responses must be in Chinese
|
||||
2. **Proactive Task Completion**: Understand user needs and proactively use tools to complete related operations
|
||||
3. **Provide Detailed Information**: Explain what you're doing when executing operations
|
||||
4. **Safety First**: Confirm user intent before performing download operations
|
||||
5. **Continuous Learning**: Remember user preferences and habits to provide personalized service
|
||||
|
||||
## Common Operation Workflows
|
||||
|
||||
### Add Subscription Workflow
|
||||
1. Understand the media content the user wants to subscribe to
|
||||
2. Search for related media information
|
||||
3. Create subscription rules
|
||||
4. Confirm successful subscription
|
||||
|
||||
### Search and Download Workflow
|
||||
1. Understand user requirements (movie names, TV show names, etc.)
|
||||
2. Search for related torrent resources
|
||||
3. Filter suitable resources
|
||||
4. Add to downloader
|
||||
|
||||
### Query Status Workflow
|
||||
1. Understand what information the user wants to know
|
||||
2. Query related data
|
||||
3. Organize and present results
|
||||
|
||||
## Tool Usage Guidelines
|
||||
|
||||
### Available Tools
|
||||
You have access to the following tools for media management:
|
||||
|
||||
1. **search_media**: Search for movies, TV shows, anime, and other media content
|
||||
2. **add_subscribe**: Create subscription rules for media content
|
||||
3. **search_torrents**: Search for torrent resources on sites
|
||||
4. **add_download**: Add torrent resources to downloaders
|
||||
5. **query_subscribes**: Check subscription status and list
|
||||
6. **query_downloads**: Check download status and progress
|
||||
7. **query_downloaders**: List available downloaders and their configuration
|
||||
8. **get_recommendations**: Get popular media recommendations
|
||||
9. **query_media_library**: Check media library status
|
||||
10. **send_message**: Send notifications to users
|
||||
|
||||
### Tool Usage Principles
|
||||
- Use tools proactively to complete user requests
|
||||
- Always explain what you're doing when using tools
|
||||
- Provide detailed results and explanations
|
||||
- Handle errors gracefully and suggest alternatives
|
||||
- Confirm user intent before performing download operations
|
||||
|
||||
### Response Format
|
||||
- Always respond in Chinese
|
||||
- Use clear and friendly language
|
||||
- Provide structured information when appropriate
|
||||
- Include relevant details about media content (title, year, type, etc.)
|
||||
- Explain the results of tool operations clearly
|
||||
|
||||
## Important Notes
|
||||
|
||||
- Always confirm user intent before performing download operations
|
||||
- If search results are not ideal, proactively adjust search strategies
|
||||
- Maintain a friendly and professional tone
|
||||
- Seek solutions proactively when encountering problems
|
||||
- Remember user preferences and provide personalized recommendations
|
||||
- Handle errors gracefully and provide helpful suggestions
|
||||
50
app/agent/prompt/__init__.py
Normal file
50
app/agent/prompt/__init__.py
Normal file
@@ -0,0 +1,50 @@
|
||||
"""提示词管理器"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class PromptManager:
|
||||
"""提示词管理器"""
|
||||
|
||||
def __init__(self, prompts_dir: str = None):
|
||||
if prompts_dir is None:
|
||||
self.prompts_dir = Path(__file__).parent
|
||||
else:
|
||||
self.prompts_dir = Path(prompts_dir)
|
||||
self.prompts_cache: Dict[str, str] = {}
|
||||
|
||||
def load_prompt(self, prompt_name: str) -> str:
|
||||
"""加载指定的提示词"""
|
||||
if prompt_name in self.prompts_cache:
|
||||
return self.prompts_cache[prompt_name]
|
||||
|
||||
prompt_file = self.prompts_dir / "definition" / prompt_name
|
||||
|
||||
try:
|
||||
with open(prompt_file, 'r', encoding='utf-8') as f:
|
||||
content = f.read().strip()
|
||||
|
||||
# 缓存提示词
|
||||
self.prompts_cache[prompt_name] = content
|
||||
|
||||
logger.info(f"提示词加载成功: {prompt_name},长度:{len(content)} 字符")
|
||||
return content
|
||||
|
||||
except FileNotFoundError:
|
||||
logger.error(f"提示词文件不存在: {prompt_file}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"加载提示词失败: {prompt_name}, 错误: {e}")
|
||||
raise
|
||||
|
||||
def get_agent_prompt(self) -> str:
|
||||
"""获取智能体提示词"""
|
||||
return self.load_prompt("Agent Prompt.txt")
|
||||
|
||||
def clear_cache(self):
|
||||
"""清空缓存"""
|
||||
self.prompts_cache.clear()
|
||||
logger.info("提示词缓存已清空")
|
||||
293
app/agent/tools/Agent Tools v1.0.json
Normal file
293
app/agent/tools/Agent Tools v1.0.json
Normal file
@@ -0,0 +1,293 @@
|
||||
[
|
||||
{
|
||||
"description": "搜索媒体资源,包括电影、电视剧、动漫等。可以根据标题、年份、类型等条件进行搜索。",
|
||||
"name": "search_media",
|
||||
"parameters": {
|
||||
"properties": {
|
||||
"explanation": {
|
||||
"description": "使用此工具的原因说明",
|
||||
"type": "string"
|
||||
},
|
||||
"title": {
|
||||
"description": "媒体标题",
|
||||
"type": "string"
|
||||
},
|
||||
"year": {
|
||||
"description": "年份(可选)",
|
||||
"type": "string"
|
||||
},
|
||||
"media_type": {
|
||||
"description": "媒体类型:movie(电影)、tv(电视剧)、anime(动漫)",
|
||||
"type": "string"
|
||||
},
|
||||
"season": {
|
||||
"description": "季数(仅电视剧和动漫)",
|
||||
"type": "integer"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"title",
|
||||
"explanation"
|
||||
],
|
||||
"type": "object"
|
||||
}
|
||||
},
|
||||
{
|
||||
"description": "添加媒体订阅,为用户感兴趣的媒体内容创建订阅规则。",
|
||||
"name": "add_subscribe",
|
||||
"parameters": {
|
||||
"properties": {
|
||||
"explanation": {
|
||||
"description": "使用此工具的原因说明",
|
||||
"type": "string"
|
||||
},
|
||||
"title": {
|
||||
"description": "媒体标题",
|
||||
"type": "string"
|
||||
},
|
||||
"year": {
|
||||
"description": "年份",
|
||||
"type": "string"
|
||||
},
|
||||
"media_type": {
|
||||
"description": "媒体类型:movie(电影)、tv(电视剧)",
|
||||
"type": "string"
|
||||
},
|
||||
"season": {
|
||||
"description": "季数(仅电视剧)",
|
||||
"type": "integer"
|
||||
},
|
||||
"tmdb_id": {
|
||||
"description": "TMDB ID(可选)",
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"title",
|
||||
"year",
|
||||
"media_type",
|
||||
"explanation"
|
||||
],
|
||||
"type": "object"
|
||||
}
|
||||
},
|
||||
{
|
||||
"description": "搜索站点种子资源,根据媒体信息搜索可下载的种子文件。",
|
||||
"name": "search_torrents",
|
||||
"parameters": {
|
||||
"properties": {
|
||||
"explanation": {
|
||||
"description": "使用此工具的原因说明",
|
||||
"type": "string"
|
||||
},
|
||||
"title": {
|
||||
"description": "资源标题",
|
||||
"type": "string"
|
||||
},
|
||||
"year": {
|
||||
"description": "年份(可选)",
|
||||
"type": "string"
|
||||
},
|
||||
"media_type": {
|
||||
"description": "媒体类型:movie(电影)、tv(电视剧)",
|
||||
"type": "string"
|
||||
},
|
||||
"season": {
|
||||
"description": "季数(仅电视剧)",
|
||||
"type": "integer"
|
||||
},
|
||||
"sites": {
|
||||
"description": "搜索的站点ID列表(可选)",
|
||||
"items": {
|
||||
"type": "integer"
|
||||
},
|
||||
"type": "array"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"title",
|
||||
"explanation"
|
||||
],
|
||||
"type": "object"
|
||||
}
|
||||
},
|
||||
{
|
||||
"description": "添加下载任务,将搜索到的种子资源添加到下载器。",
|
||||
"name": "add_download",
|
||||
"parameters": {
|
||||
"properties": {
|
||||
"explanation": {
|
||||
"description": "使用此工具的原因说明",
|
||||
"type": "string"
|
||||
},
|
||||
"torrent_title": {
|
||||
"description": "种子标题",
|
||||
"type": "string"
|
||||
},
|
||||
"torrent_url": {
|
||||
"description": "种子下载链接",
|
||||
"type": "string"
|
||||
},
|
||||
"downloader": {
|
||||
"description": "下载器名称(可选)",
|
||||
"type": "string"
|
||||
},
|
||||
"save_path": {
|
||||
"description": "保存路径(可选)",
|
||||
"type": "string"
|
||||
},
|
||||
"labels": {
|
||||
"description": "标签(可选,多个用逗号分隔)",
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"torrent_title",
|
||||
"torrent_url",
|
||||
"explanation"
|
||||
],
|
||||
"type": "object"
|
||||
}
|
||||
},
|
||||
{
|
||||
"description": "查询订阅状态,查看用户的订阅列表和状态。",
|
||||
"name": "query_subscribes",
|
||||
"parameters": {
|
||||
"properties": {
|
||||
"explanation": {
|
||||
"description": "使用此工具的原因说明",
|
||||
"type": "string"
|
||||
},
|
||||
"status": {
|
||||
"description": "订阅状态过滤:active(活跃)、inactive(非活跃)、all(全部)",
|
||||
"type": "string"
|
||||
},
|
||||
"media_type": {
|
||||
"description": "媒体类型过滤:movie(电影)、tv(电视剧)、all(全部)",
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"explanation"
|
||||
],
|
||||
"type": "object"
|
||||
}
|
||||
},
|
||||
{
|
||||
"description": "查询下载状态,查看下载器的任务列表和进度。",
|
||||
"name": "query_downloads",
|
||||
"parameters": {
|
||||
"properties": {
|
||||
"explanation": {
|
||||
"description": "使用此工具的原因说明",
|
||||
"type": "string"
|
||||
},
|
||||
"downloader": {
|
||||
"description": "下载器名称(可选)",
|
||||
"type": "string"
|
||||
},
|
||||
"status": {
|
||||
"description": "下载状态过滤:downloading(下载中)、completed(已完成)、paused(暂停)、all(全部)",
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"explanation"
|
||||
],
|
||||
"type": "object"
|
||||
}
|
||||
},
|
||||
{
|
||||
"description": "查询下载器配置,查看可用的下载器列表和配置信息。",
|
||||
"name": "query_downloaders",
|
||||
"parameters": {
|
||||
"properties": {
|
||||
"explanation": {
|
||||
"description": "使用此工具的原因说明",
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"explanation"
|
||||
],
|
||||
"type": "object"
|
||||
}
|
||||
},
|
||||
{
|
||||
"description": "获取热门媒体推荐,包括电影、电视剧等热门内容。",
|
||||
"name": "get_recommendations",
|
||||
"parameters": {
|
||||
"properties": {
|
||||
"explanation": {
|
||||
"description": "使用此工具的原因说明",
|
||||
"type": "string"
|
||||
},
|
||||
"source": {
|
||||
"description": "推荐来源:tmdb_trending(TMDB热门)、douban_hot(豆瓣热门)、bangumi_calendar(Bangumi日历)",
|
||||
"type": "string"
|
||||
},
|
||||
"media_type": {
|
||||
"description": "媒体类型:movie(电影)、tv(电视剧)、all(全部)",
|
||||
"type": "string"
|
||||
},
|
||||
"limit": {
|
||||
"description": "返回数量限制(默认20)",
|
||||
"type": "integer"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"explanation"
|
||||
],
|
||||
"type": "object"
|
||||
}
|
||||
},
|
||||
{
|
||||
"description": "查询媒体库状态,查看已入库的媒体文件情况。",
|
||||
"name": "query_media_library",
|
||||
"parameters": {
|
||||
"properties": {
|
||||
"explanation": {
|
||||
"description": "使用此工具的原因说明",
|
||||
"type": "string"
|
||||
},
|
||||
"media_type": {
|
||||
"description": "媒体类型:movie(电影)、tv(电视剧)、all(全部)",
|
||||
"type": "string"
|
||||
},
|
||||
"title": {
|
||||
"description": "媒体标题(可选,用于精确查询)",
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"explanation"
|
||||
],
|
||||
"type": "object"
|
||||
}
|
||||
},
|
||||
{
|
||||
"description": "发送消息通知,向用户发送操作结果或重要信息。",
|
||||
"name": "send_message",
|
||||
"parameters": {
|
||||
"properties": {
|
||||
"explanation": {
|
||||
"description": "使用此工具的原因说明",
|
||||
"type": "string"
|
||||
},
|
||||
"message": {
|
||||
"description": "要发送的消息内容",
|
||||
"type": "string"
|
||||
},
|
||||
"message_type": {
|
||||
"description": "消息类型:info(信息)、success(成功)、warning(警告)、error(错误)",
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"message",
|
||||
"explanation"
|
||||
],
|
||||
"type": "object"
|
||||
}
|
||||
}
|
||||
]
|
||||
29
app/agent/tools/__init__.py
Normal file
29
app/agent/tools/__init__.py
Normal file
@@ -0,0 +1,29 @@
|
||||
"""MoviePilot工具模块"""
|
||||
|
||||
from .base import MoviePilotTool
|
||||
from app.agent.tools.impl.search_media import SearchMediaTool
|
||||
from app.agent.tools.impl.add_subscribe import AddSubscribeTool
|
||||
from app.agent.tools.impl.search_torrents import SearchTorrentsTool
|
||||
from app.agent.tools.impl.add_download import AddDownloadTool
|
||||
from app.agent.tools.impl.query_subscribes import QuerySubscribesTool
|
||||
from app.agent.tools.impl.query_downloads import QueryDownloadsTool
|
||||
from app.agent.tools.impl.query_downloaders import QueryDownloadersTool
|
||||
from app.agent.tools.impl.get_recommendations import GetRecommendationsTool
|
||||
from app.agent.tools.impl.query_media_library import QueryMediaLibraryTool
|
||||
from app.agent.tools.impl.send_message import SendMessageTool
|
||||
from .factory import MoviePilotToolFactory
|
||||
|
||||
__all__ = [
|
||||
"MoviePilotTool",
|
||||
"SearchMediaTool",
|
||||
"AddSubscribeTool",
|
||||
"SearchTorrentsTool",
|
||||
"AddDownloadTool",
|
||||
"QuerySubscribesTool",
|
||||
"QueryDownloadsTool",
|
||||
"QueryDownloadersTool",
|
||||
"GetRecommendationsTool",
|
||||
"QueryMediaLibraryTool",
|
||||
"SendMessageTool",
|
||||
"MoviePilotToolFactory"
|
||||
]
|
||||
38
app/agent/tools/base.py
Normal file
38
app/agent/tools/base.py
Normal file
@@ -0,0 +1,38 @@
|
||||
"""MoviePilot工具基类"""
|
||||
|
||||
from langchain.tools import BaseTool
|
||||
from pydantic import PrivateAttr
|
||||
|
||||
from app.helper.message import MessageHelper
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class MoviePilotTool(BaseTool):
|
||||
"""MoviePilot专用工具基类"""
|
||||
|
||||
_session_id: str = PrivateAttr()
|
||||
_user_id: str = PrivateAttr()
|
||||
_message_helper: MessageHelper = PrivateAttr()
|
||||
|
||||
def __init__(self, session_id: str, user_id: str, message_helper: MessageHelper = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._session_id = session_id
|
||||
self._user_id = user_id
|
||||
self._message_helper = message_helper or MessageHelper()
|
||||
|
||||
def _run(self, **kwargs) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
async def _arun(self, **kwargs) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def _send_tool_message(self, message: str, message_type: str = "info"):
|
||||
"""发送工具执行消息"""
|
||||
try:
|
||||
self._message_helper.put(
|
||||
message=message,
|
||||
role="system",
|
||||
title=f"AI工具执行 ({message_type})"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"发送工具消息失败: {e}")
|
||||
46
app/agent/tools/factory.py
Normal file
46
app/agent/tools/factory.py
Normal file
@@ -0,0 +1,46 @@
|
||||
"""MoviePilot工具工厂"""
|
||||
|
||||
from typing import List
|
||||
|
||||
from app.helper.message import MessageHelper
|
||||
from app.log import logger
|
||||
from .base import MoviePilotTool
|
||||
from app.agent.tools.impl.search_media import SearchMediaTool
|
||||
from app.agent.tools.impl.add_subscribe import AddSubscribeTool
|
||||
from app.agent.tools.impl.search_torrents import SearchTorrentsTool
|
||||
from app.agent.tools.impl.add_download import AddDownloadTool
|
||||
from app.agent.tools.impl.query_subscribes import QuerySubscribesTool
|
||||
from app.agent.tools.impl.query_downloads import QueryDownloadsTool
|
||||
from app.agent.tools.impl.query_downloaders import QueryDownloadersTool
|
||||
from app.agent.tools.impl.get_recommendations import GetRecommendationsTool
|
||||
from app.agent.tools.impl.query_media_library import QueryMediaLibraryTool
|
||||
from app.agent.tools.impl.send_message import SendMessageTool
|
||||
|
||||
|
||||
class MoviePilotToolFactory:
|
||||
"""MoviePilot工具工厂"""
|
||||
|
||||
@staticmethod
|
||||
def create_tools(session_id: str, user_id: str, message_helper: MessageHelper = None) -> List[MoviePilotTool]:
|
||||
"""创建MoviePilot工具列表"""
|
||||
tools = []
|
||||
tool_definitions = [
|
||||
SearchMediaTool,
|
||||
AddSubscribeTool,
|
||||
SearchTorrentsTool,
|
||||
AddDownloadTool,
|
||||
QuerySubscribesTool,
|
||||
QueryDownloadsTool,
|
||||
QueryDownloadersTool,
|
||||
GetRecommendationsTool,
|
||||
QueryMediaLibraryTool,
|
||||
SendMessageTool
|
||||
]
|
||||
for ToolClass in tool_definitions:
|
||||
tools.append(ToolClass(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
message_helper=message_helper
|
||||
))
|
||||
logger.info(f"成功创建 {len(tools)} 个MoviePilot工具")
|
||||
return tools
|
||||
0
app/agent/tools/impl/__init__.py
Normal file
0
app/agent/tools/impl/__init__.py
Normal file
58
app/agent/tools/impl/add_download.py
Normal file
58
app/agent/tools/impl/add_download.py
Normal file
@@ -0,0 +1,58 @@
|
||||
"""添加下载工具"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from app.chain.download import DownloadChain
|
||||
from app.core.context import Context
|
||||
from app.core.metainfo import MetaInfo
|
||||
from app.log import logger
|
||||
from app.schemas import TorrentInfo
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
|
||||
|
||||
class AddDownloadTool(MoviePilotTool):
|
||||
name: str = "add_download"
|
||||
description: str = "添加下载任务,将搜索到的种子资源添加到下载器。"
|
||||
|
||||
async def _arun(self, torrent_title: str, torrent_url: str, explanation: str,
|
||||
downloader: Optional[str] = None, save_path: Optional[str] = None,
|
||||
labels: Optional[str] = None) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: torrent_title={torrent_title}, torrent_url={torrent_url}, downloader={downloader}, save_path={save_path}, labels={labels}")
|
||||
|
||||
# 发送工具执行说明
|
||||
self._send_tool_message(f"正在添加下载任务: {torrent_title}", "info")
|
||||
|
||||
try:
|
||||
if not torrent_title or not torrent_url:
|
||||
error_message = "错误:必须提供种子标题和下载链接"
|
||||
self._send_tool_message(error_message, "error")
|
||||
return error_message
|
||||
|
||||
# 使用DownloadChain添加下载
|
||||
download_chain = DownloadChain()
|
||||
|
||||
# 创建下载上下文
|
||||
torrent_info = TorrentInfo(
|
||||
title=torrent_title,
|
||||
download_url=torrent_url
|
||||
)
|
||||
meta_info = MetaInfo(title=torrent_title)
|
||||
context = Context(
|
||||
torrent_info=torrent_info,
|
||||
meta_info=meta_info
|
||||
)
|
||||
|
||||
did = download_chain.download_single(context=context, downloader=downloader,
|
||||
save_path=save_path, label=labels)
|
||||
if did:
|
||||
success_message = f"成功添加下载任务:{torrent_title}"
|
||||
self._send_tool_message(success_message, "success")
|
||||
return success_message
|
||||
else:
|
||||
error_message = "添加下载任务失败"
|
||||
self._send_tool_message(error_message, "error")
|
||||
return error_message
|
||||
except Exception as e:
|
||||
error_message = f"添加下载任务时发生错误: {str(e)}"
|
||||
self._send_tool_message(error_message, "error")
|
||||
return error_message
|
||||
37
app/agent/tools/impl/add_subscribe.py
Normal file
37
app/agent/tools/impl/add_subscribe.py
Normal file
@@ -0,0 +1,37 @@
|
||||
"""添加订阅工具"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from app.chain.subscribe import SubscribeChain
|
||||
from app.log import logger
|
||||
from app.schemas.types import MediaType
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
|
||||
|
||||
class AddSubscribeTool(MoviePilotTool):
|
||||
name: str = "add_subscribe"
|
||||
description: str = "添加媒体订阅,为用户感兴趣的媒体内容创建订阅规则。"
|
||||
|
||||
async def _arun(self, title: str, year: str, media_type: str, explanation: str,
|
||||
season: Optional[int] = None, tmdb_id: Optional[str] = None) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: title={title}, year={year}, media_type={media_type}, season={season}, tmdb_id={tmdb_id}")
|
||||
|
||||
# 发送工具执行说明
|
||||
self._send_tool_message(f"正在添加订阅: {title} ({year}) - {media_type}", "info")
|
||||
|
||||
try:
|
||||
subscribe_chain = SubscribeChain()
|
||||
sid, message = subscribe_chain.add(mtype=MediaType(media_type), title=title, year=year,
|
||||
tmdbid=tmdb_id, season=season, username=self._user_id)
|
||||
if sid:
|
||||
success_message = f"成功添加订阅:{title} ({year})"
|
||||
self._send_tool_message(success_message, "success")
|
||||
return success_message
|
||||
else:
|
||||
error_message = f"添加订阅失败:{message}"
|
||||
self._send_tool_message(error_message, "error")
|
||||
return error_message
|
||||
except Exception as e:
|
||||
error_message = f"添加订阅时发生错误: {str(e)}"
|
||||
self._send_tool_message(error_message, "error")
|
||||
return error_message
|
||||
39
app/agent/tools/impl/get_recommendations.py
Normal file
39
app/agent/tools/impl/get_recommendations.py
Normal file
@@ -0,0 +1,39 @@
|
||||
"""获取推荐工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional
|
||||
|
||||
from app.chain.recommend import RecommendChain
|
||||
from app.log import logger
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
|
||||
|
||||
class GetRecommendationsTool(MoviePilotTool):
|
||||
name: str = "get_recommendations"
|
||||
description: str = "获取热门媒体推荐,包括电影、电视剧等热门内容。"
|
||||
|
||||
async def _arun(self, explanation: str, source: Optional[str] = "tmdb_trending",
|
||||
media_type: Optional[str] = "all", limit: Optional[int] = 20) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: source={source}, media_type={media_type}, limit={limit}")
|
||||
try:
|
||||
recommend_chain = RecommendChain()
|
||||
results = []
|
||||
if source == "tmdb_trending":
|
||||
results = recommend_chain.tmdb_trending(limit=limit)
|
||||
elif source == "douban_hot":
|
||||
if media_type == "movie":
|
||||
results = recommend_chain.douban_movie_hot(limit=limit)
|
||||
elif media_type == "tv":
|
||||
results = recommend_chain.douban_tv_hot(limit=limit)
|
||||
else: # all
|
||||
results.extend(recommend_chain.douban_movie_hot(limit=limit))
|
||||
results.extend(recommend_chain.douban_tv_hot(limit=limit))
|
||||
elif source == "bangumi_calendar":
|
||||
results = recommend_chain.bangumi_calendar(limit=limit)
|
||||
|
||||
if results:
|
||||
return json.dumps([r.dict() for r in results], ensure_ascii=False, indent=2)
|
||||
return "未找到推荐内容。"
|
||||
except Exception as e:
|
||||
logger.error(f"获取推荐失败: {e}")
|
||||
return f"获取推荐时发生错误: {str(e)}"
|
||||
25
app/agent/tools/impl/query_downloaders.py
Normal file
25
app/agent/tools/impl/query_downloaders.py
Normal file
@@ -0,0 +1,25 @@
|
||||
"""查询下载器工具"""
|
||||
|
||||
import json
|
||||
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.log import logger
|
||||
from app.schemas.types import SystemConfigKey
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
|
||||
|
||||
class QueryDownloadersTool(MoviePilotTool):
|
||||
name: str = "query_downloaders"
|
||||
description: str = "查询下载器配置,查看可用的下载器列表和配置信息。"
|
||||
|
||||
async def _arun(self, explanation: str) -> str:
|
||||
logger.info(f"执行工具: {self.name}")
|
||||
try:
|
||||
system_config_oper = SystemConfigOper()
|
||||
downloaders_config = system_config_oper.get(SystemConfigKey.Downloaders)
|
||||
if downloaders_config:
|
||||
return json.dumps(downloaders_config, ensure_ascii=False, indent=2)
|
||||
return "未配置下载器。"
|
||||
except Exception as e:
|
||||
logger.error(f"查询下载器失败: {e}")
|
||||
return f"查询下载器时发生错误: {str(e)}"
|
||||
33
app/agent/tools/impl/query_downloads.py
Normal file
33
app/agent/tools/impl/query_downloads.py
Normal file
@@ -0,0 +1,33 @@
|
||||
"""查询下载工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional
|
||||
|
||||
from app.db.download_oper import DownloadOper
|
||||
from app.log import logger
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
|
||||
|
||||
class QueryDownloadsTool(MoviePilotTool):
|
||||
name: str = "query_downloads"
|
||||
description: str = "查询下载状态,查看下载器的任务列表和进度。"
|
||||
|
||||
async def _arun(self, explanation: str, downloader: Optional[str] = None,
|
||||
status: Optional[str] = "all") -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: downloader={downloader}, status={status}")
|
||||
try:
|
||||
download_oper = DownloadOper()
|
||||
downloads = download_oper.list()
|
||||
filtered_downloads = []
|
||||
for dl in downloads:
|
||||
if downloader and dl.downloader != downloader:
|
||||
continue
|
||||
if status != "all" and dl.status != status:
|
||||
continue
|
||||
filtered_downloads.append(dl)
|
||||
if filtered_downloads:
|
||||
return json.dumps([d.dict() for d in filtered_downloads], ensure_ascii=False, indent=2)
|
||||
return "未找到相关下载任务。"
|
||||
except Exception as e:
|
||||
logger.error(f"查询下载失败: {e}")
|
||||
return f"查询下载时发生错误: {str(e)}"
|
||||
33
app/agent/tools/impl/query_media_library.py
Normal file
33
app/agent/tools/impl/query_media_library.py
Normal file
@@ -0,0 +1,33 @@
|
||||
"""查询媒体库工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional
|
||||
|
||||
from app.db.media_oper import MediaOper
|
||||
from app.log import logger
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
|
||||
|
||||
class QueryMediaLibraryTool(MoviePilotTool):
|
||||
name: str = "query_media_library"
|
||||
description: str = "查询媒体库状态,查看已入库的媒体文件情况。"
|
||||
|
||||
async def _arun(self, explanation: str, media_type: Optional[str] = "all",
|
||||
title: Optional[str] = None) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: media_type={media_type}, title={title}")
|
||||
try:
|
||||
media_oper = MediaOper()
|
||||
medias = media_oper.list()
|
||||
filtered_medias = []
|
||||
for media in medias:
|
||||
if media_type != "all" and media.type != media_type:
|
||||
continue
|
||||
if title and title.lower() not in media.title.lower():
|
||||
continue
|
||||
filtered_medias.append(media)
|
||||
if filtered_medias:
|
||||
return json.dumps([m.dict() for m in filtered_medias], ensure_ascii=False, indent=2)
|
||||
return "媒体库中未找到相关媒体。"
|
||||
except Exception as e:
|
||||
logger.error(f"查询媒体库失败: {e}")
|
||||
return f"查询媒体库时发生错误: {str(e)}"
|
||||
33
app/agent/tools/impl/query_subscribes.py
Normal file
33
app/agent/tools/impl/query_subscribes.py
Normal file
@@ -0,0 +1,33 @@
|
||||
"""查询订阅工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional
|
||||
|
||||
from app.db.subscribe_oper import SubscribeOper
|
||||
from app.log import logger
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
|
||||
|
||||
class QuerySubscribesTool(MoviePilotTool):
|
||||
name: str = "query_subscribes"
|
||||
description: str = "查询订阅状态,查看用户的订阅列表和状态。"
|
||||
|
||||
async def _arun(self, explanation: str, status: Optional[str] = "all",
|
||||
media_type: Optional[str] = "all") -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: status={status}, media_type={media_type}")
|
||||
try:
|
||||
subscribe_oper = SubscribeOper()
|
||||
subscribes = subscribe_oper.list()
|
||||
filtered_subscribes = []
|
||||
for sub in subscribes:
|
||||
if status != "all" and sub.status != status:
|
||||
continue
|
||||
if media_type != "all" and sub.type != media_type:
|
||||
continue
|
||||
filtered_subscribes.append(sub)
|
||||
if filtered_subscribes:
|
||||
return json.dumps([s.dict() for s in filtered_subscribes], ensure_ascii=False, indent=2)
|
||||
return "未找到相关订阅。"
|
||||
except Exception as e:
|
||||
logger.error(f"查询订阅失败: {e}")
|
||||
return f"查询订阅时发生错误: {str(e)}"
|
||||
42
app/agent/tools/impl/search_media.py
Normal file
42
app/agent/tools/impl/search_media.py
Normal file
@@ -0,0 +1,42 @@
|
||||
"""搜索媒体工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional
|
||||
|
||||
from app.chain.media import MediaChain
|
||||
from app.log import logger
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
|
||||
|
||||
class SearchMediaTool(MoviePilotTool):
|
||||
name: str = "search_media"
|
||||
description: str = "搜索媒体资源,包括电影、电视剧、动漫等。可以根据标题、年份、类型等条件进行搜索。"
|
||||
|
||||
async def _arun(self, title: str, explanation: str, year: Optional[str] = None,
|
||||
media_type: Optional[str] = None, season: Optional[int] = None) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: title={title}, year={year}, media_type={media_type}, season={season}")
|
||||
|
||||
# 发送工具执行说明
|
||||
self._send_tool_message(f"正在搜索媒体资源: {title}" + (f" ({year})" if year else ""), "info")
|
||||
|
||||
try:
|
||||
media_chain = MediaChain()
|
||||
results = media_chain.search_media(title=title, year=year, mtype=media_type, season=season)
|
||||
if results:
|
||||
result_message = f"找到 {len(results)} 个相关媒体资源"
|
||||
self._send_tool_message(result_message, "success")
|
||||
|
||||
# 发送详细结果
|
||||
for i, result in enumerate(results[:5]): # 只显示前5个结果
|
||||
media_info = f"{i+1}. {result.title} ({result.year}) - {result.type}"
|
||||
self._send_tool_message(media_info, "info")
|
||||
|
||||
return json.dumps([r.dict() for r in results], ensure_ascii=False, indent=2)
|
||||
else:
|
||||
error_message = f"未找到相关媒体资源: {title}"
|
||||
self._send_tool_message(error_message, "warning")
|
||||
return error_message
|
||||
except Exception as e:
|
||||
error_message = f"搜索媒体失败: {str(e)}"
|
||||
self._send_tool_message(error_message, "error")
|
||||
return error_message
|
||||
54
app/agent/tools/impl/search_torrents.py
Normal file
54
app/agent/tools/impl/search_torrents.py
Normal file
@@ -0,0 +1,54 @@
|
||||
"""搜索种子工具"""
|
||||
|
||||
import json
|
||||
from typing import List, Optional
|
||||
|
||||
from app.chain.search import SearchChain
|
||||
from app.log import logger
|
||||
from app.schemas.types import MediaType
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
|
||||
|
||||
class SearchTorrentsTool(MoviePilotTool):
|
||||
name: str = "search_torrents"
|
||||
description: str = "搜索站点种子资源,根据媒体信息搜索可下载的种子文件。"
|
||||
|
||||
async def _arun(self, title: str, explanation: str, year: Optional[str] = None,
|
||||
media_type: Optional[str] = None, season: Optional[int] = None,
|
||||
sites: Optional[List[int]] = None) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: title={title}, year={year}, media_type={media_type}, season={season}, sites={sites}")
|
||||
|
||||
# 发送工具执行说明
|
||||
self._send_tool_message(f"正在搜索种子资源: {title}" + (f" ({year})" if year else ""), "info")
|
||||
|
||||
try:
|
||||
search_chain = SearchChain()
|
||||
torrents = search_chain.search_by_title(title=title, sites=sites)
|
||||
filtered_torrents = []
|
||||
for torrent in torrents:
|
||||
if year and torrent.meta_info.year != year:
|
||||
continue
|
||||
if media_type and torrent.media_info and torrent.media_info.type != MediaType(media_type):
|
||||
continue
|
||||
if season and torrent.meta_info.begin_season != season:
|
||||
continue
|
||||
filtered_torrents.append(torrent)
|
||||
|
||||
if filtered_torrents:
|
||||
result_message = f"找到 {len(filtered_torrents)} 个相关种子资源"
|
||||
self._send_tool_message(result_message, "success")
|
||||
|
||||
# 发送详细结果
|
||||
for i, torrent in enumerate(filtered_torrents[:5]): # 只显示前5个结果
|
||||
torrent_info = f"{i+1}. {torrent.title} - {torrent.site_name}"
|
||||
self._send_tool_message(torrent_info, "info")
|
||||
|
||||
return json.dumps([t.dict() for t in filtered_torrents], ensure_ascii=False, indent=2)
|
||||
else:
|
||||
error_message = f"未找到相关种子资源: {title}"
|
||||
self._send_tool_message(error_message, "warning")
|
||||
return error_message
|
||||
except Exception as e:
|
||||
error_message = f"搜索种子时发生错误: {str(e)}"
|
||||
self._send_tool_message(error_message, "error")
|
||||
return error_message
|
||||
22
app/agent/tools/impl/send_message.py
Normal file
22
app/agent/tools/impl/send_message.py
Normal file
@@ -0,0 +1,22 @@
|
||||
"""发送消息工具"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from app.helper.message import MessageHelper
|
||||
from app.log import logger
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
|
||||
|
||||
class SendMessageTool(MoviePilotTool):
|
||||
name: str = "send_message"
|
||||
description: str = "发送消息通知,向用户发送操作结果或重要信息。"
|
||||
|
||||
async def _arun(self, message: str, explanation: str, message_type: Optional[str] = "info") -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: message={message}, message_type={message_type}")
|
||||
try:
|
||||
message_helper = MessageHelper()
|
||||
message_helper.put(message=message, role="system", title=f"AI助手通知 ({message_type})")
|
||||
return "消息已发送。"
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息失败: {e}")
|
||||
return f"发送消息时发生错误: {str(e)}"
|
||||
@@ -163,6 +163,10 @@ class MessageChain(ChainBase):
|
||||
original_message_id=original_message_id, original_chat_id=original_chat_id)
|
||||
else:
|
||||
logger.warning(f"渠道 {channel.value} 不支持回调,但收到了回调消息:{text}")
|
||||
elif text.startswith('/ai') or text.startswith('/AI'):
|
||||
# AI智能体处理
|
||||
self._handle_ai_message(text=text, channel=channel, source=source,
|
||||
userid=userid, username=username)
|
||||
elif text.startswith('/'):
|
||||
# 执行命令
|
||||
self.eventmanager.send_event(
|
||||
@@ -815,3 +819,63 @@ class MessageChain(ChainBase):
|
||||
buttons.append(page_buttons)
|
||||
|
||||
return buttons
|
||||
|
||||
def _handle_ai_message(self, text: str, channel: MessageChannel, source: str,
|
||||
userid: Union[str, int], username: str) -> None:
|
||||
"""
|
||||
处理AI智能体消息
|
||||
"""
|
||||
try:
|
||||
# 检查AI智能体是否启用
|
||||
if not settings.AI_AGENT_ENABLE:
|
||||
self.messagehelper.put("AI智能体功能未启用,请在系统设置中启用", role="system", title="AI助手")
|
||||
return
|
||||
|
||||
# 检查LLM配置
|
||||
if not settings.LLM_API_KEY:
|
||||
self.messagehelper.put("LLM API密钥未配置,请检查系统设置", role="system", title="AI助手")
|
||||
return
|
||||
|
||||
# 提取用户消息
|
||||
user_message = text[3:].strip() # 移除 "/ai" 前缀
|
||||
if not user_message:
|
||||
self.messagehelper.put("请输入您的问题或需求", role="system", title="AI助手")
|
||||
return
|
||||
|
||||
# 发送处理中消息
|
||||
self.messagehelper.put("正在处理您的请求,请稍候...", role="system", title="AI助手")
|
||||
|
||||
# 异步处理AI智能体请求
|
||||
import asyncio
|
||||
from app.agent.agent import agent_manager
|
||||
|
||||
# 生成会话ID
|
||||
session_id = f"user_{userid}_{hash(user_message) % 10000}"
|
||||
|
||||
# 在事件循环中处理
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
response = loop.run_until_complete(
|
||||
agent_manager.process_message(
|
||||
session_id=session_id,
|
||||
user_id=str(userid),
|
||||
message=user_message
|
||||
)
|
||||
)
|
||||
except RuntimeError:
|
||||
# 如果没有事件循环,创建新的
|
||||
response = asyncio.run(
|
||||
agent_manager.process_message(
|
||||
session_id=session_id,
|
||||
user_id=str(userid),
|
||||
message=user_message
|
||||
)
|
||||
)
|
||||
|
||||
# 发送AI智能体回复
|
||||
self.messagehelper.put(response, role="system", title="AI助手")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理AI智能体消息失败: {e}")
|
||||
self.messagehelper.put(f"AI智能体处理失败: {str(e)}", role="system", title="AI助手")
|
||||
|
||||
|
||||
@@ -406,6 +406,32 @@ class ConfigModel(BaseModel):
|
||||
# Docker Client API地址
|
||||
DOCKER_CLIENT_API: Optional[str] = "tcp://127.0.0.1:38379"
|
||||
|
||||
# ==================== AI智能体配置 ====================
|
||||
# AI智能体开关
|
||||
AI_AGENT_ENABLE: bool = False
|
||||
# LLM提供商 (openai/google/deepseek)
|
||||
LLM_PROVIDER: str = "openai"
|
||||
# LLM模型名称
|
||||
LLM_MODEL: str = "gpt-3.5-turbo"
|
||||
# LLM API密钥
|
||||
LLM_API_KEY: Optional[str] = None
|
||||
# LLM基础URL(用于自定义API端点)
|
||||
LLM_BASE_URL: Optional[str] = None
|
||||
# LLM温度参数
|
||||
LLM_TEMPERATURE: float = 0.7
|
||||
# LLM最大迭代次数
|
||||
LLM_MAX_ITERATIONS: int = 5
|
||||
# LLM工具调用超时时间(秒)
|
||||
LLM_TOOL_TIMEOUT: int = 300
|
||||
# 是否启用详细日志
|
||||
LLM_VERBOSE: bool = False
|
||||
# 最大记忆消息数量
|
||||
LLM_MAX_MEMORY_MESSAGES: int = 50
|
||||
# 记忆保留天数
|
||||
LLM_MEMORY_RETENTION_DAYS: int = 30
|
||||
# Redis记忆保留天数(如果使用Redis)
|
||||
LLM_REDIS_MEMORY_RETENTION_DAYS: int = 7
|
||||
|
||||
|
||||
class Settings(BaseSettings, ConfigModel, LogConfigModel):
|
||||
"""
|
||||
|
||||
@@ -95,4 +95,4 @@ if __name__ == '__main__':
|
||||
# 更新数据库
|
||||
update_db()
|
||||
# 启动API服务
|
||||
Server.run()
|
||||
Server.run()
|
||||
@@ -190,7 +190,7 @@ class Api:
|
||||
"""
|
||||
用户列表(仅管理员有权访问)
|
||||
"""
|
||||
if (res := self.request("/manager/user/list")) and res.success:
|
||||
if (res := self.request("/memory/user/list")) and res.success:
|
||||
if not res.data:
|
||||
return []
|
||||
return [
|
||||
|
||||
56
app/schemas/agent.py
Normal file
56
app/schemas/agent.py
Normal file
@@ -0,0 +1,56 @@
|
||||
"""AI智能体相关数据模型"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional, Any
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ConversationMemory(BaseModel):
|
||||
"""对话记忆模型"""
|
||||
|
||||
session_id: str = Field(description="会话ID")
|
||||
user_id: Optional[str] = Field(default=None, description="用户ID")
|
||||
title: Optional[str] = Field(default=None, description="会话标题")
|
||||
messages: List[Dict[str, Any]] = Field(default_factory=list, description="消息列表")
|
||||
context: Dict[str, Any] = Field(default_factory=dict, description="会话上下文")
|
||||
created_at: datetime = Field(default_factory=datetime.now, description="创建时间")
|
||||
updated_at: datetime = Field(default_factory=datetime.now, description="更新时间")
|
||||
|
||||
class Config:
|
||||
json_encoders = {
|
||||
datetime: lambda v: v.isoformat()
|
||||
}
|
||||
|
||||
|
||||
class AgentState(BaseModel):
|
||||
"""AI智能体状态模型"""
|
||||
|
||||
session_id: str = Field(description="会话ID")
|
||||
current_task: Optional[str] = Field(default=None, description="当前任务")
|
||||
is_thinking: bool = Field(default=False, description="是否正在思考")
|
||||
last_activity: datetime = Field(default_factory=datetime.now, description="最后活动时间")
|
||||
|
||||
class Config:
|
||||
json_encoders = {
|
||||
datetime: lambda v: v.isoformat()
|
||||
}
|
||||
|
||||
|
||||
class UserMessage(BaseModel):
|
||||
"""用户消息模型"""
|
||||
|
||||
session_id: str = Field(description="会话ID")
|
||||
content: str = Field(description="消息内容")
|
||||
user_id: Optional[str] = Field(default=None, description="用户ID")
|
||||
channel: Optional[str] = Field(default=None, description="消息渠道")
|
||||
source: Optional[str] = Field(default=None, description="消息来源")
|
||||
|
||||
|
||||
class ToolResult(BaseModel):
|
||||
"""工具执行结果模型"""
|
||||
|
||||
session_id: str = Field(description="会话ID")
|
||||
call_id: str = Field(description="调用ID")
|
||||
success: bool = Field(description="是否成功")
|
||||
result: Optional[str] = Field(default=None, description="执行结果")
|
||||
error: Optional[str] = Field(default=None, description="错误信息")
|
||||
@@ -194,6 +194,8 @@ class SystemConfigKey(Enum):
|
||||
FollowSubscribers = "FollowSubscribers"
|
||||
# 通知发送时间
|
||||
NotificationSendTime = "NotificationSendTime"
|
||||
# AI智能体配置
|
||||
AIAgentConfig = "AIAgentConfig"
|
||||
# 通知消息格式模板
|
||||
NotificationTemplates = "NotificationTemplates"
|
||||
# 刮削开关设置
|
||||
|
||||
105
app/startup/agent_initializer.py
Normal file
105
app/startup/agent_initializer.py
Normal file
@@ -0,0 +1,105 @@
|
||||
"""
|
||||
AI智能体初始化器
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
from typing import Optional
|
||||
|
||||
from app.core.config import settings
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class AgentInitializer:
|
||||
"""AI智能体初始化器"""
|
||||
|
||||
def __init__(self):
|
||||
self.agent_manager = None
|
||||
self._initialized = False
|
||||
|
||||
async def initialize(self) -> bool:
|
||||
"""
|
||||
初始化AI智能体管理器
|
||||
"""
|
||||
try:
|
||||
if not settings.AI_AGENT_ENABLE:
|
||||
logger.info("AI智能体功能未启用")
|
||||
return True
|
||||
|
||||
from app.agent.agent import agent_manager
|
||||
self.agent_manager = agent_manager
|
||||
|
||||
await agent_manager.initialize()
|
||||
self._initialized = True
|
||||
logger.info("AI智能体管理器初始化成功")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"AI智能体管理器初始化失败: {e}")
|
||||
return False
|
||||
|
||||
async def cleanup(self) -> None:
|
||||
"""
|
||||
清理AI智能体管理器
|
||||
"""
|
||||
try:
|
||||
if not self._initialized or not self.agent_manager:
|
||||
return
|
||||
|
||||
await self.agent_manager.close()
|
||||
self._initialized = False
|
||||
logger.info("AI智能体管理器已关闭")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"关闭AI智能体管理器时发生错误: {e}")
|
||||
|
||||
|
||||
# 全局AI智能体初始化器实例
|
||||
agent_initializer = AgentInitializer()
|
||||
|
||||
|
||||
def init_agent():
|
||||
"""
|
||||
初始化AI智能体(同步版本,用于在后台线程中运行)
|
||||
"""
|
||||
try:
|
||||
if not settings.AI_AGENT_ENABLE:
|
||||
logger.info("AI智能体功能未启用")
|
||||
return True
|
||||
|
||||
# 在新的事件循环中初始化AI智能体管理器
|
||||
def run_init():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
success = loop.run_until_complete(agent_initializer.initialize())
|
||||
if success:
|
||||
logger.info("AI智能体管理器初始化成功")
|
||||
else:
|
||||
logger.error("AI智能体管理器初始化失败")
|
||||
return success
|
||||
except Exception as e:
|
||||
logger.error(f"AI智能体管理器初始化失败: {e}")
|
||||
return False
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
# 在后台线程中初始化
|
||||
init_thread = threading.Thread(target=run_init, daemon=True)
|
||||
init_thread.start()
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"初始化AI智能体时发生错误: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def stop_agent():
|
||||
"""
|
||||
停止AI智能体(异步版本,用于在应用关闭时调用)
|
||||
"""
|
||||
try:
|
||||
await agent_initializer.cleanup()
|
||||
except Exception as e:
|
||||
logger.error(f"停止AI智能体时发生错误: {e}")
|
||||
@@ -27,6 +27,7 @@ from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.command import CommandChain
|
||||
from app.schemas import Notification, NotificationType
|
||||
from app.schemas.types import SystemConfigKey
|
||||
from app.startup.agent_initializer import init_agent, stop_agent
|
||||
|
||||
|
||||
def start_frontend():
|
||||
@@ -110,6 +111,8 @@ async def stop_modules():
|
||||
"""
|
||||
服务关闭
|
||||
"""
|
||||
# 停止AI智能体
|
||||
await stop_agent()
|
||||
# 停止模块
|
||||
ModuleManager().stop()
|
||||
# 停止事件消费
|
||||
@@ -151,6 +154,8 @@ def init_modules():
|
||||
EventManager().start()
|
||||
# 初始化订阅分享
|
||||
SubscribeHelper()
|
||||
# 初始化AI智能体
|
||||
init_agent()
|
||||
# 启动前端服务
|
||||
start_frontend()
|
||||
# 检查认证状态
|
||||
|
||||
@@ -77,4 +77,12 @@ setuptools~=78.1.0
|
||||
pympler~=1.1
|
||||
smbprotocol~=1.15.0
|
||||
setproctitle~=1.3.6
|
||||
httpx[socks]~=0.28.1
|
||||
httpx[socks]~=0.28.1
|
||||
langchain==0.3.27
|
||||
langchain-core==0.3.76
|
||||
langchain-community==0.3.29
|
||||
langchain-openai==0.3.33
|
||||
langchain-google-genai==2.0.10
|
||||
langchain-deepseek==0.1.4
|
||||
langchain-experimental==0.3.4
|
||||
openai==1.108.2
|
||||
Reference in New Issue
Block a user