feat: 为工具管理器添加参数类型规范化处理,并基于渠道能力动态生成提示词中的格式要求

This commit is contained in:
jxxghp
2026-01-15 20:55:35 +08:00
parent 1bf9862e47
commit 1641d432dd
10 changed files with 324 additions and 282 deletions

View File

@@ -1,4 +1,3 @@
"""MoviePilot工具基类"""
import json
from abc import ABCMeta, abstractmethod
from typing import Any, Optional
@@ -6,7 +5,7 @@ from typing import Any, Optional
from langchain.tools import BaseTool
from pydantic import PrivateAttr
from app.agent import StreamingCallbackHandler, ConversationMemoryManager
from app.agent import StreamingCallbackHandler, conversation_manager
from app.chain import ChainBase
from app.log import logger
from app.schemas import Notification
@@ -17,7 +16,9 @@ class ToolChain(ChainBase):
class MoviePilotTool(BaseTool, metaclass=ABCMeta):
"""MoviePilot专用工具基类"""
"""
MoviePilot专用工具基类
"""
_session_id: str = PrivateAttr()
_user_id: str = PrivateAttr()
@@ -25,7 +26,6 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
_source: str = PrivateAttr(default=None)
_username: str = PrivateAttr(default=None)
_callback_handler: StreamingCallbackHandler = PrivateAttr(default=None)
_memory_manager: ConversationMemoryManager = PrivateAttr(default=None)
def __init__(self, session_id: str, user_id: str, **kwargs):
super().__init__(**kwargs)
@@ -36,7 +36,9 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
pass
async def _arun(self, **kwargs) -> str:
"""异步运行工具"""
"""
异步运行工具
"""
# 发送和记忆工具调用前的信息
agent_message = await self._callback_handler.get_message()
if agent_message:
@@ -44,7 +46,7 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
await self.send_tool_message(agent_message, title="MoviePilot助手")
# 记忆工具调用
await self._memory_manager.add_memory(
await conversation_manager.add_conversation(
session_id=self._session_id,
user_id=self._user_id,
role="tool_call",
@@ -77,7 +79,7 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
formated_result = str(result)
else:
formated_result = json.dumps(result, ensure_ascii=False, indent=2)
await self._memory_manager.add_memory(
await conversation_manager.add_conversation(
session_id=self._session_id,
user_id=self._user_id,
role="tool_result",
@@ -106,21 +108,23 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
raise NotImplementedError
def set_message_attr(self, channel: str, source: str, username: str):
"""设置消息属性"""
"""
设置消息属性
"""
self._channel = channel
self._source = source
self._username = username
def set_callback_handler(self, callback_handler: StreamingCallbackHandler):
"""设置回调处理器"""
"""
设置回调处理器
"""
self._callback_handler = callback_handler
def set_memory_manager(self, memory_manager: ConversationMemoryManager):
"""设置记忆客理器"""
self._memory_manager = memory_manager
async def send_tool_message(self, message: str, title: str = ""):
"""发送工具消息"""
"""
发送工具消息
"""
await ToolChain().async_post_message(
Notification(
channel=self._channel,

View File

@@ -1,5 +1,3 @@
"""MoviePilot工具工厂"""
from typing import List, Callable
from app.agent.tools.impl.add_download import AddDownloadTool
@@ -47,13 +45,17 @@ from .base import MoviePilotTool
class MoviePilotToolFactory:
"""MoviePilot工具工厂"""
"""
MoviePilot工具工厂
"""
@staticmethod
def create_tools(session_id: str, user_id: str,
channel: str = None, source: str = None, username: str = None,
callback_handler: Callable = None, memory_mananger: Callable = None) -> List[MoviePilotTool]:
"""创建MoviePilot工具列表"""
callback_handler: Callable = None) -> List[MoviePilotTool]:
"""
创建MoviePilot工具列表
"""
tools = []
tool_definitions = [
SearchMediaTool,
@@ -104,7 +106,6 @@ class MoviePilotToolFactory:
)
tool.set_message_attr(channel=channel, source=source, username=username)
tool.set_callback_handler(callback_handler=callback_handler)
tool.set_memory_manager(memory_manager=memory_mananger)
tools.append(tool)
# 加载插件提供的工具
@@ -127,7 +128,6 @@ class MoviePilotToolFactory:
)
tool.set_message_attr(channel=channel, source=source, username=username)
tool.set_callback_handler(callback_handler=callback_handler)
tool.set_memory_manager(memory_manager=memory_mananger)
tools.append(tool)
plugin_tools_count += 1
logger.debug(f"成功加载插件 {plugin_name}({plugin_id}) 的工具: {ToolClass.__name__}")

View File

@@ -1,18 +1,15 @@
"""MoviePilot工具管理器
用于HTTP API调用工具
"""
import json
import uuid
from typing import Any, Dict, List, Optional
from app.agent import ConversationMemoryManager
from app.agent.tools.factory import MoviePilotToolFactory
from app.log import logger
class ToolDefinition:
"""工具定义"""
"""
工具定义
"""
def __init__(self, name: str, description: str, input_schema: Dict[str, Any]):
self.name = name
@@ -21,7 +18,9 @@ class ToolDefinition:
class MoviePilotToolsManager:
"""MoviePilot工具管理器用于HTTP API"""
"""
MoviePilot工具管理器用于HTTP API
"""
def __init__(self, user_id: str = "api_user", session_id: str = uuid.uuid4()):
"""
@@ -34,11 +33,12 @@ class MoviePilotToolsManager:
self.user_id = user_id
self.session_id = session_id
self.tools: List[Any] = []
self.memory_manager = ConversationMemoryManager()
self._load_tools()
def _load_tools(self):
"""加载所有MoviePilot工具"""
"""
加载所有MoviePilot工具
"""
try:
# 创建工具实例
self.tools = MoviePilotToolFactory.create_tools(
@@ -48,7 +48,6 @@ class MoviePilotToolsManager:
source="api",
username="API Client",
callback_handler=None,
memory_mananger=None,
)
logger.info(f"成功加载 {len(self.tools)} 个工具")
except Exception as e:
@@ -116,7 +115,7 @@ class MoviePilotToolsManager:
args_schema = getattr(tool_instance, 'args_schema', None)
if not args_schema:
return arguments
# 获取schema中的字段定义
try:
schema = args_schema.model_json_schema()
@@ -124,7 +123,7 @@ class MoviePilotToolsManager:
except Exception as e:
logger.warning(f"获取工具schema失败: {e}")
return arguments
# 规范化参数
normalized = {}
for key, value in arguments.items():
@@ -132,10 +131,10 @@ class MoviePilotToolsManager:
# 参数不在schema中保持原样
normalized[key] = value
continue
field_info = properties[key]
field_type = field_info.get("type")
# 处理 anyOf 类型(例如 Optional[int] 会生成 anyOf
any_of = field_info.get("anyOf")
if any_of and not field_type:
@@ -144,7 +143,7 @@ class MoviePilotToolsManager:
if "type" in type_option and type_option["type"] != "null":
field_type = type_option["type"]
break
# 根据类型进行转换
if field_type == "integer" and isinstance(value, str):
try:
@@ -167,7 +166,7 @@ class MoviePilotToolsManager:
normalized[key] = True
else:
normalized[key] = value
return normalized
async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> str:
@@ -192,7 +191,7 @@ class MoviePilotToolsManager:
try:
# 规范化参数类型
normalized_arguments = self._normalize_arguments(tool_instance, arguments)
# 调用工具的run方法
result = await tool_instance.run(**normalized_arguments)
@@ -270,3 +269,6 @@ class MoviePilotToolsManager:
"properties": properties,
"required": required
}
moviepilot_tool_manager = MoviePilotToolsManager()