mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-04-13 17:52:28 +08:00
Merge pull request #5273 from PKC278/v2
This commit is contained in:
@@ -33,7 +33,7 @@ def login_access_token(
|
||||
if user_or_message == "MFA_REQUIRED":
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="需要双重验证",
|
||||
detail="需要双重验证,请提供验证码或使用通行密钥",
|
||||
headers={"X-MFA-Required": "true"}
|
||||
)
|
||||
raise HTTPException(status_code=401, detail=user_or_message)
|
||||
@@ -57,7 +57,7 @@ def login_access_token(
|
||||
avatar=user_or_message.avatar,
|
||||
level=level,
|
||||
permissions=user_or_message.permissions or {},
|
||||
widzard=show_wizard
|
||||
wizard=show_wizard
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
通过HTTP API暴露MoviePilot的智能体工具功能
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from typing import List, Any, Dict, Annotated, Optional, Union
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, Header
|
||||
@@ -25,64 +24,18 @@ router = APIRouter()
|
||||
MCP_PROTOCOL_VERSIONS = ["2025-11-25", "2025-06-18", "2024-11-05"]
|
||||
MCP_PROTOCOL_VERSION = MCP_PROTOCOL_VERSIONS[0] # 默认使用最新版本
|
||||
|
||||
# 全局会话管理器
|
||||
_sessions: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
# 全局工具管理器实例(单例模式,按用户ID缓存)
|
||||
_tools_managers: Dict[str, MoviePilotToolsManager] = {}
|
||||
|
||||
|
||||
def get_tools_manager(user_id: str = "mcp_user", session_id: str = "mcp_session") -> MoviePilotToolsManager:
|
||||
def get_tools_manager() -> MoviePilotToolsManager:
|
||||
"""
|
||||
获取工具管理器实例(按用户ID缓存)
|
||||
|
||||
获取工具管理器实例
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
session_id: 会话ID
|
||||
|
||||
|
||||
Returns:
|
||||
MoviePilotToolsManager实例
|
||||
"""
|
||||
global _tools_managers
|
||||
# 使用用户ID作为缓存键
|
||||
cache_key = f"{user_id}_{session_id}"
|
||||
if cache_key not in _tools_managers:
|
||||
_tools_managers[cache_key] = MoviePilotToolsManager(
|
||||
user_id=user_id,
|
||||
session_id=session_id
|
||||
)
|
||||
return _tools_managers[cache_key]
|
||||
|
||||
|
||||
def get_session(session_id: Optional[str]) -> Optional[Dict[str, Any]]:
|
||||
"""获取会话"""
|
||||
if not session_id:
|
||||
return None
|
||||
return _sessions.get(session_id)
|
||||
|
||||
|
||||
def create_session(user_id: str) -> Dict[str, Any]:
|
||||
"""创建新会话"""
|
||||
session_id = str(uuid.uuid4())
|
||||
session = {
|
||||
"id": session_id,
|
||||
"user_id": user_id,
|
||||
"initialized": False,
|
||||
"protocol_version": None,
|
||||
"capabilities": {}
|
||||
}
|
||||
_sessions[session_id] = session
|
||||
return session
|
||||
|
||||
|
||||
def delete_session(session_id: str):
|
||||
"""删除会话"""
|
||||
if session_id in _sessions:
|
||||
del _sessions[session_id]
|
||||
# 同时清理工具管理器缓存
|
||||
cache_key = f"{_sessions.get(session_id, {}).get('user_id', 'mcp_user')}_{session_id}"
|
||||
if cache_key in _tools_managers:
|
||||
del _tools_managers[cache_key]
|
||||
return MoviePilotToolsManager()
|
||||
|
||||
|
||||
def create_jsonrpc_response(request_id: Union[str, int, None], result: Any) -> Dict[str, Any]:
|
||||
@@ -112,13 +65,11 @@ def create_jsonrpc_error(request_id: Union[str, int, None], code: int, message:
|
||||
|
||||
# ==================== MCP JSON-RPC 端点 ====================
|
||||
|
||||
@router.post("", summary="MCP JSON-RPC 端点")
|
||||
@router.post("", summary="MCP JSON-RPC 端点", response_model=None)
|
||||
async def mcp_jsonrpc(
|
||||
request: Request,
|
||||
mcp_session_id: Optional[str] = Header(None, alias="MCP-Session-Id"),
|
||||
mcp_protocol_version: Optional[str] = Header(None, alias="MCP-Protocol-Version"),
|
||||
_: Annotated[str, Depends(verify_apikey)] = None
|
||||
) -> JSONResponse:
|
||||
) -> Union[JSONResponse, Response]:
|
||||
"""
|
||||
MCP 标准 JSON-RPC 2.0 端点
|
||||
|
||||
@@ -150,48 +101,27 @@ async def mcp_jsonrpc(
|
||||
try:
|
||||
# 处理初始化请求
|
||||
if method == "initialize":
|
||||
result = await handle_initialize(params, mcp_session_id)
|
||||
response = create_jsonrpc_response(request_id, result["result"])
|
||||
# 如果创建了新会话,在响应头中返回
|
||||
if "session_id" in result:
|
||||
headers = {"MCP-Session-Id": result["session_id"]}
|
||||
return JSONResponse(content=response, headers=headers)
|
||||
return JSONResponse(content=response)
|
||||
result = await handle_initialize(params)
|
||||
return JSONResponse(content=create_jsonrpc_response(request_id, result))
|
||||
|
||||
# 处理已初始化通知
|
||||
elif method == "notifications/initialized":
|
||||
if is_notification:
|
||||
session = get_session(mcp_session_id)
|
||||
if session:
|
||||
session["initialized"] = True
|
||||
# 通知不需要响应
|
||||
return Response(status_code=202)
|
||||
return Response(status_code=204)
|
||||
else:
|
||||
return JSONResponse(
|
||||
content=create_jsonrpc_error(request_id, -32600, "initialized must be a notification")
|
||||
)
|
||||
|
||||
# 验证会话(除了 initialize 和 ping)
|
||||
if method not in ["initialize", "ping"]:
|
||||
session = get_session(mcp_session_id)
|
||||
if not session:
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content=create_jsonrpc_error(request_id, -32002, "Session not found")
|
||||
)
|
||||
if not session.get("initialized") and method != "notifications/initialized":
|
||||
return JSONResponse(
|
||||
content=create_jsonrpc_error(request_id, -32003, "Not initialized")
|
||||
status_code=400,
|
||||
content={"error": "initialized must be a notification"}
|
||||
)
|
||||
|
||||
# 处理工具列表请求
|
||||
if method == "tools/list":
|
||||
result = await handle_tools_list(mcp_session_id)
|
||||
result = await handle_tools_list()
|
||||
return JSONResponse(content=create_jsonrpc_response(request_id, result))
|
||||
|
||||
# 处理工具调用请求
|
||||
elif method == "tools/call":
|
||||
result = await handle_tools_call(params, mcp_session_id)
|
||||
result = await handle_tools_call(params)
|
||||
return JSONResponse(content=create_jsonrpc_response(request_id, result))
|
||||
|
||||
# 处理 ping 请求
|
||||
@@ -204,6 +134,12 @@ async def mcp_jsonrpc(
|
||||
content=create_jsonrpc_error(request_id, -32601, f"Method not found: {method}")
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
logger.warning(f"MCP 请求参数错误: {e}")
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content=create_jsonrpc_error(request_id, -32602, "Invalid params", str(e))
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"处理 MCP 请求失败: {e}", exc_info=True)
|
||||
return JSONResponse(
|
||||
@@ -212,24 +148,13 @@ async def mcp_jsonrpc(
|
||||
)
|
||||
|
||||
|
||||
async def handle_initialize(params: Dict[str, Any], session_id: Optional[str]) -> Dict[str, Any]:
|
||||
async def handle_initialize(params: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""处理初始化请求"""
|
||||
protocol_version = params.get("protocolVersion")
|
||||
client_info = params.get("clientInfo", {})
|
||||
client_capabilities = params.get("capabilities", {})
|
||||
|
||||
logger.info(f"MCP 初始化请求: 客户端={client_info.get('name')}, 协议版本={protocol_version}")
|
||||
|
||||
# 如果没有提供会话ID,创建新会话
|
||||
new_session = None
|
||||
if not session_id:
|
||||
new_session = create_session(user_id="mcp_user")
|
||||
session_id = new_session["id"]
|
||||
|
||||
session = get_session(session_id) or new_session
|
||||
if not session:
|
||||
raise ValueError("Failed to create session")
|
||||
|
||||
# 版本协商:选择客户端和服务器都支持的版本
|
||||
negotiated_version = MCP_PROTOCOL_VERSION
|
||||
if protocol_version in MCP_PROTOCOL_VERSIONS:
|
||||
@@ -240,40 +165,26 @@ async def handle_initialize(params: Dict[str, Any], session_id: Optional[str]) -
|
||||
# 客户端版本不支持,使用服务器默认版本
|
||||
logger.warning(f"协议版本不匹配: 客户端={protocol_version}, 使用服务器版本={negotiated_version}")
|
||||
|
||||
session["protocol_version"] = negotiated_version
|
||||
session["capabilities"] = client_capabilities
|
||||
|
||||
result = {
|
||||
"result": {
|
||||
"protocolVersion": negotiated_version,
|
||||
"capabilities": {
|
||||
"tools": {
|
||||
"listChanged": False # 暂不支持工具列表变更通知
|
||||
},
|
||||
"logging": {}
|
||||
return {
|
||||
"protocolVersion": negotiated_version,
|
||||
"capabilities": {
|
||||
"tools": {
|
||||
"listChanged": False # 暂不支持工具列表变更通知
|
||||
},
|
||||
"serverInfo": {
|
||||
"name": "MoviePilot",
|
||||
"version": APP_VERSION,
|
||||
"description": "MoviePilot MCP Server - 电影自动化管理工具",
|
||||
},
|
||||
"instructions": "MoviePilot MCP 服务器,提供媒体管理、订阅、下载等工具。使用 tools/list 查看所有可用工具。"
|
||||
}
|
||||
"logging": {}
|
||||
},
|
||||
"serverInfo": {
|
||||
"name": "MoviePilot",
|
||||
"version": APP_VERSION,
|
||||
"description": "MoviePilot MCP Server - 电影自动化管理工具",
|
||||
},
|
||||
"instructions": "MoviePilot MCP 服务器,提供媒体管理、订阅、下载等工具。"
|
||||
}
|
||||
|
||||
# 如果是新创建的会话,返回会话ID
|
||||
if new_session:
|
||||
result["session_id"] = session_id
|
||||
|
||||
return result
|
||||
|
||||
|
||||
async def handle_tools_list(session_id: Optional[str]) -> Dict[str, Any]:
|
||||
async def handle_tools_list() -> Dict[str, Any]:
|
||||
"""处理工具列表请求"""
|
||||
session = get_session(session_id)
|
||||
user_id = session.get("user_id", "mcp_user") if session else "mcp_user"
|
||||
|
||||
manager = get_tools_manager(user_id=user_id, session_id=session_id or "default")
|
||||
manager = get_tools_manager()
|
||||
tools = manager.list_tools()
|
||||
|
||||
# 转换为 MCP 工具格式
|
||||
@@ -291,7 +202,7 @@ async def handle_tools_list(session_id: Optional[str]) -> Dict[str, Any]:
|
||||
}
|
||||
|
||||
|
||||
async def handle_tools_call(params: Dict[str, Any], session_id: Optional[str]) -> Dict[str, Any]:
|
||||
async def handle_tools_call(params: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""处理工具调用请求"""
|
||||
tool_name = params.get("name")
|
||||
arguments = params.get("arguments", {})
|
||||
@@ -299,10 +210,7 @@ async def handle_tools_call(params: Dict[str, Any], session_id: Optional[str]) -
|
||||
if not tool_name:
|
||||
raise ValueError("Missing tool name")
|
||||
|
||||
session = get_session(session_id)
|
||||
user_id = session.get("user_id", "mcp_user") if session else "mcp_user"
|
||||
|
||||
manager = get_tools_manager(user_id=user_id, session_id=session_id or "default")
|
||||
manager = get_tools_manager()
|
||||
|
||||
try:
|
||||
result_text = await manager.call_tool(tool_name, arguments)
|
||||
@@ -328,26 +236,18 @@ async def handle_tools_call(params: Dict[str, Any], session_id: Optional[str]) -
|
||||
}
|
||||
|
||||
|
||||
@router.delete("", summary="终止 MCP 会话")
|
||||
@router.delete("", summary="终止 MCP 会话", response_model=None)
|
||||
async def delete_mcp_session(
|
||||
mcp_session_id: Optional[str] = Header(None, alias="MCP-Session-Id"),
|
||||
_: Annotated[str, Depends(verify_apikey)] = None
|
||||
) -> JSONResponse:
|
||||
) -> Union[JSONResponse, Response]:
|
||||
"""
|
||||
终止 MCP 会话(可选实现)
|
||||
|
||||
客户端可以主动调用此接口终止会话
|
||||
终止 MCP 会话(无状态模式下仅返回成功)
|
||||
"""
|
||||
if not mcp_session_id:
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content={"detail": "Missing MCP-Session-Id header"}
|
||||
)
|
||||
|
||||
delete_session(mcp_session_id)
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
|
||||
|
||||
# ==================== 兼容的 RESTful API 端点 ====================
|
||||
|
||||
@router.get("/tools", summary="列出所有可用工具", response_model=List[Dict[str, Any]])
|
||||
|
||||
@@ -2,9 +2,8 @@
|
||||
MFA (Multi-Factor Authentication) API 端点
|
||||
包含 OTP 和 PassKey 相关功能
|
||||
"""
|
||||
import base64
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Annotated, List, Union
|
||||
from datetime import timedelta
|
||||
from typing import Any, Annotated, Optional, List, Union
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Body
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
@@ -25,6 +24,22 @@ from app.utils.otp import OtpUtils
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# ==================== 请求模型 ====================
|
||||
|
||||
class OtpVerifyRequest(schemas.BaseModel):
|
||||
"""OTP验证请求"""
|
||||
uri: str
|
||||
otpPassword: str
|
||||
|
||||
class OtpDisableRequest(schemas.BaseModel):
|
||||
"""OTP禁用请求"""
|
||||
password: str
|
||||
|
||||
class PassKeyDeleteRequest(schemas.BaseModel):
|
||||
"""PassKey删除请求"""
|
||||
passkey_id: int
|
||||
password: str
|
||||
|
||||
# ==================== 通用 MFA 接口 ====================
|
||||
|
||||
@router.get('/status/{username}', summary='判断用户是否开启双重验证(MFA)', response_model=schemas.Response)
|
||||
@@ -40,7 +55,7 @@ async def mfa_status(username: str, db: AsyncSession = Depends(get_async_db)) ->
|
||||
has_otp = user.is_otp
|
||||
|
||||
# 检查是否有PassKey
|
||||
has_passkey = bool(PassKey().get_by_user_id(db=None, user_id=user.id))
|
||||
has_passkey = bool(await PassKey.async_get_by_user_id(db=db, user_id=user.id))
|
||||
|
||||
# 只要有任何一种验证方式,就需要双重验证
|
||||
return schemas.Response(success=(has_otp or has_passkey))
|
||||
@@ -59,25 +74,35 @@ def otp_generate(
|
||||
|
||||
@router.post('/otp/verify', summary='绑定并验证 OTP', response_model=schemas.Response)
|
||||
async def otp_verify(
|
||||
data: dict,
|
||||
data: OtpVerifyRequest,
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
current_user: User = Depends(get_current_active_user_async)
|
||||
) -> Any:
|
||||
"""验证用户输入的 OTP 码,验证通过后正式开启 OTP 验证"""
|
||||
uri = data.get("uri")
|
||||
otp_password = data.get("otpPassword")
|
||||
if not OtpUtils.is_legal(uri, otp_password):
|
||||
if not OtpUtils.is_legal(data.uri, data.otpPassword):
|
||||
return schemas.Response(success=False, message="验证码错误")
|
||||
await current_user.async_update_otp_by_name(db, current_user.name, True, OtpUtils.get_secret(uri))
|
||||
await current_user.async_update_otp_by_name(db, current_user.name, True, OtpUtils.get_secret(data.uri))
|
||||
return schemas.Response(success=True)
|
||||
|
||||
|
||||
@router.post('/otp/disable', summary='关闭当前用户的 OTP 验证', response_model=schemas.Response)
|
||||
async def otp_disable(
|
||||
data: OtpDisableRequest,
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
current_user: User = Depends(get_current_active_user_async)
|
||||
) -> Any:
|
||||
"""关闭当前用户的 OTP 验证功能"""
|
||||
# 安全检查:如果存在 PassKey,不允许关闭 OTP
|
||||
has_passkey = bool(await PassKey.async_get_by_user_id(db=db, user_id=current_user.id))
|
||||
if has_passkey:
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message="您已注册通行密钥,为了防止域名配置变更导致无法登录,请先删除所有通行密钥再关闭 OTP 验证"
|
||||
)
|
||||
|
||||
# 验证密码
|
||||
if not security.verify_password(data.password, str(current_user.hashed_password)):
|
||||
return schemas.Response(success=False, message="密码错误")
|
||||
await current_user.async_update_otp_by_name(db, current_user.name, False, "")
|
||||
return schemas.Response(success=True)
|
||||
|
||||
@@ -98,7 +123,7 @@ class PassKeyRegistrationFinish(schemas.BaseModel):
|
||||
|
||||
class PassKeyAuthenticationStart(schemas.BaseModel):
|
||||
"""PassKey认证开始请求"""
|
||||
username: str | None = None
|
||||
username: Optional[str] = None
|
||||
|
||||
|
||||
class PassKeyAuthenticationFinish(schemas.BaseModel):
|
||||
@@ -122,7 +147,7 @@ def passkey_register_start(
|
||||
)
|
||||
|
||||
# 获取用户已有的PassKey
|
||||
existing_passkeys = PassKey().get_by_user_id(db=None, user_id=current_user.id)
|
||||
existing_passkeys = PassKey.get_by_user_id(db=None, user_id=current_user.id)
|
||||
existing_credentials = [
|
||||
{
|
||||
'credential_id': pk.credential_id,
|
||||
@@ -215,7 +240,7 @@ def passkey_authenticate_start(
|
||||
message="用户不存在"
|
||||
)
|
||||
|
||||
existing_passkeys = PassKey().get_by_user_id(db=None, user_id=user.id)
|
||||
existing_passkeys = PassKey.get_by_user_id(db=None, user_id=user.id)
|
||||
if not existing_passkeys:
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
@@ -265,7 +290,7 @@ def passkey_authenticate_finish(
|
||||
credential_id = PassKeyHelper.standardize_credential_id(credential_id_raw)
|
||||
|
||||
# 查找PassKey
|
||||
passkey = PassKey().get_by_credential_id(db=None, credential_id=credential_id)
|
||||
passkey = PassKey.get_by_credential_id(db=None, credential_id=credential_id)
|
||||
if not passkey:
|
||||
raise HTTPException(status_code=401, detail="通行密钥不存在或已失效")
|
||||
|
||||
@@ -286,7 +311,7 @@ def passkey_authenticate_finish(
|
||||
raise HTTPException(status_code=401, detail="通行密钥验证失败")
|
||||
|
||||
# 更新使用时间和签名计数
|
||||
passkey.update_last_used(db=None, credential_id=credential_id, sign_count=new_sign_count)
|
||||
passkey.update_last_used(db=None, sign_count=new_sign_count)
|
||||
|
||||
logger.info(f"用户 {user.name} 通过PassKey认证成功")
|
||||
|
||||
@@ -309,7 +334,7 @@ def passkey_authenticate_finish(
|
||||
avatar=user.avatar,
|
||||
level=level,
|
||||
permissions=user.permissions or {},
|
||||
widzard=show_wizard
|
||||
wizard=show_wizard
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
@@ -324,7 +349,7 @@ def passkey_list(
|
||||
) -> Any:
|
||||
"""获取当前用户的所有 PassKey"""
|
||||
try:
|
||||
passkeys = PassKey().get_by_user_id(db=None, user_id=current_user.id)
|
||||
passkeys = PassKey.get_by_user_id(db=None, user_id=current_user.id)
|
||||
|
||||
passkey_list = [
|
||||
{
|
||||
@@ -350,17 +375,21 @@ def passkey_list(
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/passkey/{passkey_id}", summary="删除 PassKey", response_model=schemas.Response)
|
||||
def passkey_delete(
|
||||
passkey_id: int,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)]
|
||||
@router.post("/passkey/delete", summary="删除 PassKey", response_model=schemas.Response)
|
||||
async def passkey_delete(
|
||||
data: PassKeyDeleteRequest,
|
||||
current_user: User = Depends(get_current_active_user_async)
|
||||
) -> Any:
|
||||
"""删除指定的 PassKey"""
|
||||
try:
|
||||
success = PassKey().delete_by_id(db=None, passkey_id=passkey_id, user_id=current_user.id)
|
||||
# 验证密码
|
||||
if not security.verify_password(data.password, str(current_user.hashed_password)):
|
||||
return schemas.Response(success=False, message="密码错误")
|
||||
|
||||
success = PassKey.delete_by_id(db=None, passkey_id=data.passkey_id, user_id=current_user.id)
|
||||
|
||||
if success:
|
||||
logger.info(f"用户 {current_user.name} 删除了PassKey: {passkey_id}")
|
||||
logger.info(f"用户 {current_user.name} 删除了PassKey: {data.passkey_id}")
|
||||
return schemas.Response(
|
||||
success=True,
|
||||
message="通行密钥已删除"
|
||||
@@ -397,7 +426,7 @@ def passkey_verify_mfa(
|
||||
credential_id = PassKeyHelper.standardize_credential_id(credential_id_raw)
|
||||
|
||||
# 查找PassKey(必须属于当前用户)
|
||||
passkey = PassKey().get_by_credential_id(db=None, credential_id=credential_id)
|
||||
passkey = PassKey.get_by_credential_id(db=None, credential_id=credential_id)
|
||||
if not passkey or passkey.user_id != current_user.id:
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
@@ -419,7 +448,7 @@ def passkey_verify_mfa(
|
||||
)
|
||||
|
||||
# 更新使用时间和签名计数
|
||||
passkey.update_last_used(db=None, credential_id=credential_id, sign_count=new_sign_count)
|
||||
passkey.update_last_used(db=None, sign_count=new_sign_count)
|
||||
|
||||
logger.info(f"用户 {current_user.name} 通过PassKey二次验证成功")
|
||||
|
||||
|
||||
@@ -179,7 +179,7 @@ class UserChain(ChainBase):
|
||||
"""
|
||||
# 检查用户是否有PassKey
|
||||
from app.db.models.passkey import PassKey
|
||||
has_passkey = bool(PassKey().get_by_user_id(db=None, user_id=user.id))
|
||||
has_passkey = bool(PassKey.get_by_user_id(db=None, user_id=user.id))
|
||||
|
||||
# 如果用户既没有启用OTP也没有PassKey,直接通过
|
||||
if not user.is_otp and not has_passkey:
|
||||
@@ -190,11 +190,20 @@ class UserChain(ChainBase):
|
||||
logger.info(f"用户 {user.name} 已启用双重验证(OTP: {user.is_otp}, PassKey: {has_passkey}),需要提供验证码")
|
||||
return "MFA_REQUIRED"
|
||||
|
||||
# 如果提供了验证码,验证OTP
|
||||
# 如果提供了验证码,且用户启用了 OTP,则验证 OTP
|
||||
if user.is_otp:
|
||||
if not OtpUtils.check(str(user.otp_secret), mfa_code):
|
||||
logger.info(f"用户 {user.name} 的 MFA 认证失败")
|
||||
return False
|
||||
# OTP 验证成功
|
||||
return True
|
||||
|
||||
# 用户未启用 OTP,此时提供的 mfa_code 无效;如果启用了 PassKey,则仍需通过 PassKey 验证
|
||||
if has_passkey:
|
||||
logger.info(
|
||||
f"用户 {user.name} 未启用 OTP,但已启用 PassKey,提供的 MFA 验证码将被忽略,仍需通过 PassKey 验证"
|
||||
)
|
||||
return "MFA_REQUIRED"
|
||||
|
||||
return True
|
||||
|
||||
|
||||
@@ -37,14 +37,14 @@ class PassKey(Base):
|
||||
@db_query
|
||||
def get_by_user_id(cls, db: Session, user_id: int):
|
||||
"""获取用户的所有PassKey"""
|
||||
return db.query(cls).filter(cls.user_id == user_id, cls.is_active == True).all()
|
||||
return db.query(cls).filter(cls.user_id == user_id, cls.is_active.is_(True)).all()
|
||||
|
||||
@classmethod
|
||||
@async_db_query
|
||||
async def async_get_by_user_id(cls, db: AsyncSession, user_id: int):
|
||||
"""异步获取用户的所有PassKey"""
|
||||
result = await db.execute(
|
||||
select(cls).filter(cls.user_id == user_id, cls.is_active == True)
|
||||
select(cls).filter(cls.user_id == user_id, cls.is_active.is_(True))
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
@@ -52,14 +52,14 @@ class PassKey(Base):
|
||||
@db_query
|
||||
def get_by_credential_id(cls, db: Session, credential_id: str):
|
||||
"""根据凭证ID获取PassKey"""
|
||||
return db.query(cls).filter(cls.credential_id == credential_id, cls.is_active == True).first()
|
||||
return db.query(cls).filter(cls.credential_id == credential_id, cls.is_active.is_(True)).first()
|
||||
|
||||
@classmethod
|
||||
@async_db_query
|
||||
async def async_get_by_credential_id(cls, db: AsyncSession, credential_id: str):
|
||||
"""异步根据凭证ID获取PassKey"""
|
||||
result = await db.execute(
|
||||
select(cls).filter(cls.credential_id == credential_id, cls.is_active == True)
|
||||
select(cls).filter(cls.credential_id == credential_id, cls.is_active.is_(True))
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
@@ -78,25 +78,27 @@ class PassKey(Base):
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
@classmethod
|
||||
@db_update
|
||||
def delete_by_id(self, db: Session, passkey_id: int, user_id: int):
|
||||
def delete_by_id(cls, db: Session, passkey_id: int, user_id: int):
|
||||
"""删除指定用户的PassKey"""
|
||||
passkey = db.query(PassKey).filter(
|
||||
PassKey.id == passkey_id,
|
||||
PassKey.user_id == user_id
|
||||
passkey = db.query(cls).filter(
|
||||
cls.id == passkey_id,
|
||||
cls.user_id == user_id
|
||||
).first()
|
||||
if passkey:
|
||||
passkey.delete(db, passkey.id)
|
||||
return True
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
@async_db_update
|
||||
async def async_delete_by_id(self, db: AsyncSession, passkey_id: int, user_id: int):
|
||||
async def async_delete_by_id(cls, db: AsyncSession, passkey_id: int, user_id: int):
|
||||
"""异步删除指定用户的PassKey"""
|
||||
result = await db.execute(
|
||||
select(PassKey).filter(
|
||||
PassKey.id == passkey_id,
|
||||
PassKey.user_id == user_id
|
||||
select(cls).filter(
|
||||
cls.id == passkey_id,
|
||||
cls.user_id == user_id
|
||||
)
|
||||
)
|
||||
passkey = result.scalars().first()
|
||||
@@ -106,25 +108,19 @@ class PassKey(Base):
|
||||
return False
|
||||
|
||||
@db_update
|
||||
def update_last_used(self, db: Session, credential_id: str, sign_count: int):
|
||||
def update_last_used(self, db: Session, sign_count: int):
|
||||
"""更新最后使用时间和签名计数"""
|
||||
passkey = self.get_by_credential_id(db, credential_id)
|
||||
if passkey:
|
||||
passkey.update(db, {
|
||||
'last_used_at': datetime.now(),
|
||||
'sign_count': sign_count
|
||||
})
|
||||
return True
|
||||
return False
|
||||
self.update(db, {
|
||||
'last_used_at': datetime.now(),
|
||||
'sign_count': sign_count
|
||||
})
|
||||
return True
|
||||
|
||||
@async_db_update
|
||||
async def async_update_last_used(self, db: AsyncSession, credential_id: str, sign_count: int):
|
||||
async def async_update_last_used(self, db: AsyncSession, sign_count: int):
|
||||
"""异步更新最后使用时间和签名计数"""
|
||||
passkey = await self.async_get_by_credential_id(db, credential_id)
|
||||
if passkey:
|
||||
await passkey.async_update(db, {
|
||||
'last_used_at': datetime.now(),
|
||||
'sign_count': sign_count
|
||||
})
|
||||
return True
|
||||
return False
|
||||
await self.async_update(db, {
|
||||
'last_used_at': datetime.now(),
|
||||
'sign_count': sign_count
|
||||
})
|
||||
return True
|
||||
|
||||
@@ -3,6 +3,7 @@ PassKey WebAuthn 辅助工具类
|
||||
"""
|
||||
import base64
|
||||
import json
|
||||
import binascii
|
||||
from typing import Optional, Tuple, List, Dict, Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
@@ -23,10 +24,6 @@ from webauthn.helpers.structs import (
|
||||
UserVerificationRequirement,
|
||||
AuthenticatorAttachment,
|
||||
ResidentKeyRequirement,
|
||||
PublicKeyCredentialCreationOptions,
|
||||
PublicKeyCredentialRequestOptions,
|
||||
RegistrationCredential,
|
||||
AuthenticationCredential,
|
||||
AuthenticatorSelectionCriteria
|
||||
)
|
||||
from webauthn.helpers.cose import COSEAlgorithmIdentifier
|
||||
@@ -46,6 +43,14 @@ class PassKeyHelper:
|
||||
获取 Relying Party ID
|
||||
"""
|
||||
if settings.APP_DOMAIN:
|
||||
app_domain = settings.APP_DOMAIN.strip()
|
||||
# 确保存在协议前缀,以便 urlparse 正确解析主机和端口
|
||||
if not app_domain.startswith(('http://', 'https://')):
|
||||
app_domain = f'https://{app_domain}'
|
||||
parsed = urlparse(app_domain)
|
||||
host = parsed.hostname
|
||||
if host:
|
||||
return host
|
||||
# 从 APP_DOMAIN 中提取域名
|
||||
host = settings.APP_DOMAIN.replace('https://', '').replace('http://', '')
|
||||
# 移除端口号
|
||||
@@ -81,7 +86,7 @@ class PassKeyHelper:
|
||||
# Base64解码并重新编码以标准化格式
|
||||
decoded = base64.urlsafe_b64decode(credential_id + '==')
|
||||
return base64.urlsafe_b64encode(decoded).decode('utf-8').rstrip('=')
|
||||
except Exception as e:
|
||||
except (binascii.Error, TypeError, ValueError) as e:
|
||||
logger.error(f"标准化凭证ID失败: {e}")
|
||||
return credential_id
|
||||
|
||||
@@ -135,7 +140,7 @@ class PassKeyHelper:
|
||||
user_display_name=display_name or username,
|
||||
exclude_credentials=exclude_credentials if exclude_credentials else None,
|
||||
authenticator_selection=AuthenticatorSelectionCriteria(
|
||||
authenticator_attachment=AuthenticatorAttachment.PLATFORM,
|
||||
authenticator_attachment=None,
|
||||
resident_key=ResidentKeyRequirement.REQUIRED,
|
||||
user_verification=uv_requirement,
|
||||
),
|
||||
|
||||
@@ -21,7 +21,7 @@ class Token(BaseModel):
|
||||
# 详细权限
|
||||
permissions: Optional[dict] = Field(default_factory=dict)
|
||||
# 是否显示配置向导
|
||||
widzard: Optional[bool] = None
|
||||
wizard: Optional[bool] = None
|
||||
|
||||
|
||||
class TokenPayload(BaseModel):
|
||||
|
||||
@@ -16,11 +16,12 @@ MoviePilot 实现了标准的 **Model Context Protocol (MCP)**,允许 AI 智
|
||||
### 端点
|
||||
**POST** `/api/v1/mcp`
|
||||
|
||||
## 3. 会话管理
|
||||
|
||||
* **会话维持**: 在标准 MCP 流程中,通过 HTTP Header `MCP-Session-Id` 识别会话。
|
||||
* **主动终止**:
|
||||
**DELETE** `/api/v1/mcp` (携带 `MCP-Session-Id` Header)
|
||||
### 支持的方法
|
||||
- `initialize`: 初始化会话,协商协议版本和能力。
|
||||
- `notifications/initialized`: 客户端确认初始化完成。
|
||||
- `tools/list`: 获取可用工具列表。
|
||||
- `tools/call`: 调用特定工具。
|
||||
- `ping`: 连接存活检测。
|
||||
|
||||
---
|
||||
|
||||
@@ -65,6 +66,7 @@ MoviePilot 实现了标准的 **Model Context Protocol (MCP)**,允许 AI 智
|
||||
| -32700 | Parse error | JSON 格式错误 |
|
||||
| -32600 | Invalid Request | 无效的 JSON-RPC 请求 |
|
||||
| -32601 | Method not found | 方法不存在 |
|
||||
| -32602 | Invalid params | 参数验证失败 |
|
||||
| -32002 | Session not found | 会话不存在或已过期 |
|
||||
| -32003 | Not initialized | 会话未完成初始化流程 |
|
||||
| -32603 | Internal error | 服务器内部错误 |
|
||||
@@ -203,10 +205,3 @@ MoviePilot 实现了标准的 **Model Context Protocol (MCP)**,允许 AI 智
|
||||
"required": ["title", "year", "media_type"]
|
||||
}
|
||||
```
|
||||
|
||||
## 7. 注意事项
|
||||
|
||||
1. **用户上下文**: API调用会使用当前认证用户的ID作为工具执行的用户上下文
|
||||
2. **会话隔离**: 每个API请求使用独立的会话ID
|
||||
3. **参数验证**: 工具参数会根据JSON Schema进行验证
|
||||
4. **错误日志**: 所有工具调用错误都会记录到MoviePilot日志系统
|
||||
|
||||
Reference in New Issue
Block a user