feat(manager): 添加工具调用参数格式自动转换功能

This commit is contained in:
PKC278
2025-12-22 21:04:13 +08:00
parent ec375a19ae
commit c3a5106adc

View File

@@ -100,6 +100,72 @@ class MoviePilotToolsManager:
return tool
return None
def _normalize_arguments(self, tool_instance: Any, arguments: Dict[str, Any]) -> Dict[str, Any]:
"""
根据工具的参数schema规范化参数类型
Args:
tool_instance: 工具实例
arguments: 原始参数
Returns:
规范化后的参数
"""
# 获取工具的参数schema
args_schema = getattr(tool_instance, 'args_schema', None)
if not args_schema:
return arguments
# 获取schema中的字段定义
try:
schema = args_schema.model_json_schema()
properties = schema.get("properties", {})
except Exception as e:
logger.warning(f"获取工具schema失败: {e}")
return arguments
# 规范化参数
normalized = {}
for key, value in arguments.items():
if key not in properties:
# 参数不在schema中保持原样
normalized[key] = value
continue
field_info = properties[key]
field_type = field_info.get("type")
# 处理 anyOf 类型(例如 Optional[int] 会生成 anyOf
any_of = field_info.get("anyOf")
if any_of and not field_type:
# 从 anyOf 中提取实际类型
for type_option in any_of:
if "type" in type_option and type_option["type"] != "null":
field_type = type_option["type"]
break
# 根据类型进行转换
if field_type == "integer" and isinstance(value, str):
try:
normalized[key] = int(value)
except (ValueError, TypeError):
logger.warning(f"无法将参数 {key}='{value}' 转换为整数,保持原值")
normalized[key] = value
elif field_type == "number" and isinstance(value, str):
try:
normalized[key] = float(value)
except (ValueError, TypeError):
logger.warning(f"无法将参数 {key}='{value}' 转换为浮点数,保持原值")
normalized[key] = value
elif field_type == "boolean" and isinstance(value, str):
# 转换字符串为布尔值
normalized[key] = value.lower() in ("true", "1", "yes", "on")
else:
# 其他类型保持原样
normalized[key] = value
return normalized
async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> str:
"""
调用工具
@@ -120,8 +186,11 @@ class MoviePilotToolsManager:
return error_msg
try:
# 规范化参数类型
normalized_arguments = self._normalize_arguments(tool_instance, arguments)
# 调用工具的run方法
result = await tool_instance.run(**arguments)
result = await tool_instance.run(**normalized_arguments)
# 确保返回字符串
if isinstance(result, str):