mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-03-20 03:57:30 +08:00
fix(agent):修复智能体工具调用,优化媒体库查询工具
This commit is contained in:
@@ -98,15 +98,14 @@ class MoviePilotAgent:
|
||||
user_id=self.user_id
|
||||
)
|
||||
if messages:
|
||||
loaded_messages = []
|
||||
for msg in messages:
|
||||
if msg.get("role") == "user":
|
||||
loaded_messages.append(HumanMessage(content=msg.get("content", "")))
|
||||
chat_history.add_message(HumanMessage(content=msg.get("content", "")))
|
||||
elif msg.get("role") == "agent":
|
||||
loaded_messages.append(AIMessage(content=msg.get("content", "")))
|
||||
chat_history.add_message(AIMessage(content=msg.get("content", "")))
|
||||
elif msg.get("role") == "tool_call":
|
||||
metadata = msg.get("metadata", {})
|
||||
loaded_messages.append(
|
||||
chat_history.add_message(
|
||||
AIMessage(
|
||||
content=msg.get("content", ""),
|
||||
tool_calls=[
|
||||
@@ -120,17 +119,12 @@ class MoviePilotAgent:
|
||||
)
|
||||
elif msg.get("role") == "tool_result":
|
||||
metadata = msg.get("metadata", {})
|
||||
loaded_messages.append(ToolMessage(
|
||||
chat_history.add_message(ToolMessage(
|
||||
content=msg.get("content", ""),
|
||||
tool_call_id=metadata.get("call_id", "unknown")
|
||||
))
|
||||
elif msg.get("role") == "system":
|
||||
loaded_messages.append(SystemMessage(content=msg.get("content", "")))
|
||||
|
||||
# 验证并修复工具调用的完整性
|
||||
validated_messages = self._ensure_tool_call_integrity(loaded_messages)
|
||||
for msg in validated_messages:
|
||||
chat_history.add_message(msg)
|
||||
chat_history.add_message(SystemMessage(content=msg.get("content", "")))
|
||||
|
||||
return chat_history
|
||||
|
||||
@@ -199,71 +193,6 @@ class MoviePilotAgent:
|
||||
# 发生错误时返回一个保守的估算值
|
||||
return len(str(messages)) // 4
|
||||
|
||||
def _ensure_tool_call_integrity(self, messages: List[Union[HumanMessage, AIMessage, ToolMessage, SystemMessage]]) \
|
||||
-> List[Union[HumanMessage, AIMessage, ToolMessage, SystemMessage]]:
|
||||
"""
|
||||
确保工具调用的完整性:
|
||||
1. 带有 tool_calls 的 AIMessage 必须后跟相应的 ToolMessages
|
||||
2. 移除不完整的 AIMessages(有 tool_calls 但缺少 ToolMessage 响应)及其部分 ToolMessages
|
||||
3. 过滤掉不对应任何 tool_call 的孤立 ToolMessages
|
||||
"""
|
||||
if not messages:
|
||||
return messages
|
||||
|
||||
validated_messages = []
|
||||
i = 0
|
||||
|
||||
while i < len(messages):
|
||||
msg = messages[i]
|
||||
|
||||
# 检查这是否是一个带有 tool_calls 的 AIMessage
|
||||
if isinstance(msg, AIMessage) and getattr(msg, 'tool_calls', None):
|
||||
# 提取 tool_call IDs(ToolCall 是 TypedDict,在运行时是 dict)
|
||||
tool_call_ids = {
|
||||
tc.get('id') if tc.get('id') is not None else tc.get('name')
|
||||
for tc in msg.tool_calls
|
||||
}
|
||||
|
||||
# 查找后续的 ToolMessages
|
||||
j = i + 1
|
||||
found_tool_messages = []
|
||||
while j < len(messages) and isinstance(messages[j], ToolMessage):
|
||||
found_tool_messages.append(messages[j])
|
||||
j += 1
|
||||
|
||||
# 检查所有 tool_calls 是否都有对应的 ToolMessages
|
||||
found_tool_call_ids = {tm.tool_call_id for tm in found_tool_messages}
|
||||
|
||||
# 警告:如果存在不对应任何 tool_call 的额外 ToolMessages
|
||||
extra_tool_messages = found_tool_call_ids - tool_call_ids
|
||||
if extra_tool_messages:
|
||||
logger.warning(f"Found extra ToolMessages that don't correspond to any tool_call: {extra_tool_messages}")
|
||||
|
||||
if not tool_call_ids.issubset(found_tool_call_ids):
|
||||
# 缺少部分 tool_call 响应,跳过此 AIMessage 及其所有 ToolMessages
|
||||
logger.warning("Removing incomplete tool_call AIMessage and its partial ToolMessages: missing tool_call responses")
|
||||
i = j # 跳过此 AIMessage 及其所有相关的 ToolMessages
|
||||
continue
|
||||
else:
|
||||
# 添加此 AIMessage 及其对应的 ToolMessages
|
||||
validated_messages.append(msg)
|
||||
for tm in found_tool_messages:
|
||||
if tm.tool_call_id in tool_call_ids:
|
||||
validated_messages.append(tm)
|
||||
i = j
|
||||
continue
|
||||
# 跳过孤立的 ToolMessages(前面没有带 tool_calls 的 AIMessage)
|
||||
elif isinstance(msg, ToolMessage):
|
||||
logger.warning(f"Skipping orphaned ToolMessage with tool_call_id: {msg.tool_call_id}")
|
||||
i += 1
|
||||
continue
|
||||
else:
|
||||
validated_messages.append(msg)
|
||||
|
||||
i += 1
|
||||
|
||||
return validated_messages
|
||||
|
||||
def _create_agent_executor(self) -> RunnableWithMessageHistory:
|
||||
"""
|
||||
创建Agent执行器
|
||||
@@ -281,8 +210,13 @@ class MoviePilotAgent:
|
||||
|
||||
# 包装trimmer,在裁剪后验证工具调用的完整性
|
||||
def validated_trimmer(messages):
|
||||
trimmed = base_trimmer(messages)
|
||||
return self._ensure_tool_call_integrity(trimmed)
|
||||
# 如果输入是 PromptValue,转换为消息列表
|
||||
if hasattr(messages, "to_messages"):
|
||||
messages = messages.to_messages()
|
||||
trimmed = base_trimmer.invoke(messages)
|
||||
if len(trimmed) < len(messages):
|
||||
logger.info(f"LangChain消息上下文已裁剪: {len(messages)} -> {len(trimmed)}")
|
||||
return trimmed
|
||||
|
||||
# 创建Agent执行链
|
||||
agent = (
|
||||
|
||||
@@ -232,7 +232,7 @@ class ConversationMemoryManager:
|
||||
return []
|
||||
|
||||
# 获取所有消息
|
||||
return memory.messages
|
||||
return memory.messages[:-1]
|
||||
|
||||
async def get_recent_messages(
|
||||
self,
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import json
|
||||
import uuid
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from typing import Any, Optional
|
||||
|
||||
@@ -42,6 +43,9 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
# 获取工具调用前的agent消息
|
||||
agent_message = await self._callback_handler.get_message()
|
||||
|
||||
# 生成唯一的工具调用ID
|
||||
call_id = f"call_{str(uuid.uuid4())[:16]}"
|
||||
|
||||
# 记忆工具调用
|
||||
await conversation_manager.add_conversation(
|
||||
session_id=self._session_id,
|
||||
@@ -49,8 +53,8 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
role="tool_call",
|
||||
content=agent_message,
|
||||
metadata={
|
||||
"call_id": self.__class__.__name__,
|
||||
"tool_name": self.__class__.__name__,
|
||||
"call_id": call_id,
|
||||
"tool_name": self.name,
|
||||
"parameters": kwargs
|
||||
}
|
||||
)
|
||||
@@ -61,21 +65,21 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
explanation = kwargs.get("explanation")
|
||||
if explanation:
|
||||
tool_message = explanation
|
||||
|
||||
|
||||
# 合并agent消息和工具执行消息,一起发送
|
||||
messages = []
|
||||
if agent_message:
|
||||
messages.append(agent_message)
|
||||
if tool_message:
|
||||
messages.append(f"⚙️ => {tool_message}")
|
||||
|
||||
|
||||
# 发送合并后的消息
|
||||
if messages:
|
||||
merged_message = "\n\n".join(messages)
|
||||
await self.send_tool_message(merged_message, title="MoviePilot助手")
|
||||
|
||||
logger.debug(f'Executing tool {self.name} with args: {kwargs}')
|
||||
|
||||
|
||||
# 执行工具,捕获异常确保结果总是被存储到记忆中
|
||||
try:
|
||||
result = await self.run(**kwargs)
|
||||
@@ -93,13 +97,15 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
formated_result = str(result)
|
||||
else:
|
||||
formated_result = json.dumps(result, ensure_ascii=False, indent=2)
|
||||
|
||||
await conversation_manager.add_conversation(
|
||||
session_id=self._session_id,
|
||||
user_id=self._user_id,
|
||||
role="tool_result",
|
||||
content=formated_result,
|
||||
metadata={
|
||||
"call_id": self.__class__.__name__
|
||||
"call_id": call_id,
|
||||
"tool_name": self.name,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ from pydantic import BaseModel, Field
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.mediaserver import MediaServerChain
|
||||
from app.core.context import MediaInfo
|
||||
from app.core.meta import MetaBase
|
||||
from app.log import logger
|
||||
from app.schemas.types import MediaType
|
||||
|
||||
@@ -51,47 +52,88 @@ class QueryLibraryExistsTool(MoviePilotTool):
|
||||
try:
|
||||
if not title:
|
||||
return "请提供媒体标题进行查询"
|
||||
|
||||
# 创建 MediaInfo 对象
|
||||
mediainfo = MediaInfo()
|
||||
mediainfo.title = title
|
||||
mediainfo.year = year
|
||||
|
||||
# 转换媒体类型
|
||||
if media_type == "电影":
|
||||
mediainfo.type = MediaType.MOVIE
|
||||
elif media_type == "电视剧":
|
||||
mediainfo.type = MediaType.TV
|
||||
# media_type == "all" 时不设置类型,让媒体服务器自动判断
|
||||
|
||||
# 调用媒体服务器接口实时查询
|
||||
|
||||
media_chain = MediaServerChain()
|
||||
|
||||
# 1. 识别媒体信息(获取 TMDB ID 和各季的总集数等元数据)
|
||||
meta = MetaBase(title=title)
|
||||
if year:
|
||||
meta.year = str(year)
|
||||
if media_type == "电影":
|
||||
meta.type = MediaType.MOVIE
|
||||
elif media_type == "电视剧":
|
||||
meta.type = MediaType.TV
|
||||
|
||||
# 使用识别方法补充信息
|
||||
recognize_info = media_chain.recognize_media(meta=meta)
|
||||
if recognize_info:
|
||||
mediainfo = recognize_info
|
||||
else:
|
||||
# 识别失败,创建基本信息的 MediaInfo
|
||||
mediainfo = MediaInfo()
|
||||
mediainfo.title = title
|
||||
mediainfo.year = year
|
||||
if media_type == "电影":
|
||||
mediainfo.type = MediaType.MOVIE
|
||||
elif media_type == "电视剧":
|
||||
mediainfo.type = MediaType.TV
|
||||
|
||||
# 2. 调用媒体服务器接口实时查询存在信息
|
||||
existsinfo = media_chain.media_exists(mediainfo=mediainfo)
|
||||
|
||||
|
||||
if not existsinfo:
|
||||
return "媒体库中未找到相关媒体"
|
||||
|
||||
# 如果找到了,获取详细信息
|
||||
|
||||
# 3. 如果找到了,获取详细信息并组装结果
|
||||
result_items = []
|
||||
if existsinfo.itemid and existsinfo.server:
|
||||
iteminfo = media_chain.iteminfo(server=existsinfo.server, item_id=existsinfo.itemid)
|
||||
if iteminfo:
|
||||
# 使用 model_dump() 转换为字典格式
|
||||
item_dict = iteminfo.model_dump(exclude_none=True)
|
||||
|
||||
# 对于电视剧,补充已存在的季集详情及进度统计
|
||||
if existsinfo.type == MediaType.TV:
|
||||
# 注入已存在集信息 (Dict[int, list])
|
||||
item_dict["seasoninfo"] = existsinfo.seasons
|
||||
|
||||
# 统计库中已存在的季集总数
|
||||
if existsinfo.seasons:
|
||||
item_dict["existing_episodes_count"] = sum(len(e) for e in existsinfo.seasons.values())
|
||||
item_dict["seasons_existing_count"] = {str(s): len(e) for s, e in existsinfo.seasons.items()}
|
||||
|
||||
# 如果识别到了元数据,补充总计对比和进度概览
|
||||
if mediainfo.seasons:
|
||||
item_dict["seasons_total_count"] = {str(s): len(e) for s, e in mediainfo.seasons.items()}
|
||||
# 进度概览,例如 "Season 1": "3/12"
|
||||
item_dict["seasons_progress"] = {
|
||||
f"第{s}季": f"{len(existsinfo.seasons.get(s, []))}/{len(mediainfo.seasons.get(s, []))} 集"
|
||||
for s in mediainfo.seasons.keys() if (s in existsinfo.seasons or s > 0)
|
||||
}
|
||||
|
||||
result_items.append(item_dict)
|
||||
|
||||
|
||||
if result_items:
|
||||
return json.dumps(result_items, ensure_ascii=False)
|
||||
|
||||
# 如果找到了但没有详细信息,返回基本信息
|
||||
|
||||
# 如果找到了但没有获取到 iteminfo,返回基本信息
|
||||
result_dict = {
|
||||
"title": mediainfo.title,
|
||||
"year": mediainfo.year,
|
||||
"type": existsinfo.type.value if existsinfo.type else None,
|
||||
"server": existsinfo.server,
|
||||
"server_type": existsinfo.server_type,
|
||||
"itemid": existsinfo.itemid,
|
||||
"seasons": existsinfo.seasons if existsinfo.seasons else {}
|
||||
}
|
||||
if existsinfo.type == MediaType.TV and existsinfo.seasons:
|
||||
result_dict["existing_episodes_count"] = sum(len(e) for e in existsinfo.seasons.values())
|
||||
result_dict["seasons_existing_count"] = {str(s): len(e) for s, e in existsinfo.seasons.items()}
|
||||
if mediainfo.seasons:
|
||||
result_dict["seasons_total_count"] = {str(s): len(e) for s, e in mediainfo.seasons.items()}
|
||||
|
||||
return json.dumps([result_dict], ensure_ascii=False)
|
||||
except Exception as e:
|
||||
logger.error(f"查询媒体库失败: {e}", exc_info=True)
|
||||
return f"查询媒体库时发生错误: {str(e)}"
|
||||
|
||||
|
||||
@@ -842,8 +842,7 @@ class MessageChain(ChainBase):
|
||||
|
||||
return buttons
|
||||
|
||||
@staticmethod
|
||||
def _get_or_create_session_id(userid: Union[str, int]) -> str:
|
||||
def _get_or_create_session_id(self, userid: Union[str, int]) -> str:
|
||||
"""
|
||||
获取或创建会话ID
|
||||
如果用户上次会话在15分钟内,则复用相同的会话ID;否则创建新的会话ID
|
||||
@@ -851,8 +850,8 @@ class MessageChain(ChainBase):
|
||||
current_time = datetime.now()
|
||||
|
||||
# 检查用户是否有已存在的会话
|
||||
if userid in MessageChain._user_sessions:
|
||||
session_id, last_time = MessageChain._user_sessions[userid]
|
||||
if userid in self._user_sessions:
|
||||
session_id, last_time = self._user_sessions[userid]
|
||||
|
||||
# 计算时间差
|
||||
time_diff = current_time - last_time
|
||||
@@ -860,25 +859,24 @@ class MessageChain(ChainBase):
|
||||
# 如果时间差小于等于15分钟,复用会话ID
|
||||
if time_diff <= timedelta(minutes=MessageChain._session_timeout_minutes):
|
||||
# 更新最后使用时间
|
||||
MessageChain._user_sessions[userid] = (session_id, current_time)
|
||||
self._user_sessions[userid] = (session_id, current_time)
|
||||
logger.info(
|
||||
f"复用会话ID: {session_id}, 用户: {userid}, 距离上次会话: {time_diff.total_seconds() / 60:.1f}分钟")
|
||||
return session_id
|
||||
|
||||
# 创建新的会话ID
|
||||
new_session_id = f"user_{userid}_{int(time.time())}"
|
||||
MessageChain._user_sessions[userid] = (new_session_id, current_time)
|
||||
self._user_sessions[userid] = (new_session_id, current_time)
|
||||
logger.info(f"创建新会话ID: {new_session_id}, 用户: {userid}")
|
||||
return new_session_id
|
||||
|
||||
@staticmethod
|
||||
def clear_user_session(userid: Union[str, int]) -> bool:
|
||||
def clear_user_session(self, userid: Union[str, int]) -> bool:
|
||||
"""
|
||||
清除指定用户的会话信息
|
||||
返回是否成功清除
|
||||
"""
|
||||
if userid in MessageChain._user_sessions:
|
||||
session_id, _ = MessageChain._user_sessions.pop(userid)
|
||||
if userid in self._user_sessions:
|
||||
session_id, _ = self._user_sessions.pop(userid)
|
||||
logger.info(f"已清除用户 {userid} 的会话: {session_id}")
|
||||
return True
|
||||
return False
|
||||
@@ -889,8 +887,8 @@ class MessageChain(ChainBase):
|
||||
"""
|
||||
# 获取并清除会话信息
|
||||
session_id = None
|
||||
if userid in MessageChain._user_sessions:
|
||||
session_id, _ = MessageChain._user_sessions.pop(userid)
|
||||
if userid in self._user_sessions:
|
||||
session_id, _ = self._user_sessions.pop(userid)
|
||||
logger.info(f"已清除用户 {userid} 的会话: {session_id}")
|
||||
|
||||
# 如果有会话ID,同时清除智能体的会话记忆
|
||||
|
||||
Reference in New Issue
Block a user