From 7b99f09810b6496b153e5c9a24ea503ebf16a2d3 Mon Sep 17 00:00:00 2001 From: PKC278 <52959804+PKC278@users.noreply.github.com> Date: Tue, 23 Dec 2025 14:57:43 +0800 Subject: [PATCH 1/2] =?UTF-8?q?fix(mfa):=20=E4=BF=AE=E5=A4=8D=E5=8F=8C?= =?UTF-8?q?=E9=87=8D=E9=AA=8C=E8=AF=81=E6=BC=8F=E6=B4=9E?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/endpoints/login.py | 4 +- app/api/endpoints/mfa.py | 77 ++++++++++++++++++++++++++------------ app/chain/user.py | 13 ++++++- app/db/models/passkey.py | 56 +++++++++++++-------------- app/helper/passkey.py | 17 ++++++--- app/schemas/token.py | 2 +- 6 files changed, 104 insertions(+), 65 deletions(-) diff --git a/app/api/endpoints/login.py b/app/api/endpoints/login.py index 769ad445..928ba16f 100644 --- a/app/api/endpoints/login.py +++ b/app/api/endpoints/login.py @@ -33,7 +33,7 @@ def login_access_token( if user_or_message == "MFA_REQUIRED": raise HTTPException( status_code=401, - detail="需要双重验证", + detail="需要双重验证,请提供验证码或使用通行密钥", headers={"X-MFA-Required": "true"} ) raise HTTPException(status_code=401, detail=user_or_message) @@ -57,7 +57,7 @@ def login_access_token( avatar=user_or_message.avatar, level=level, permissions=user_or_message.permissions or {}, - widzard=show_wizard + wizard=show_wizard ) diff --git a/app/api/endpoints/mfa.py b/app/api/endpoints/mfa.py index f4498c25..c0f63d02 100644 --- a/app/api/endpoints/mfa.py +++ b/app/api/endpoints/mfa.py @@ -2,9 +2,8 @@ MFA (Multi-Factor Authentication) API 端点 包含 OTP 和 PassKey 相关功能 """ -import base64 -from datetime import datetime, timedelta -from typing import Any, Annotated, List, Union +from datetime import timedelta +from typing import Any, Annotated, Optional, List, Union from fastapi import APIRouter, Depends, HTTPException, Body from sqlalchemy.ext.asyncio import AsyncSession @@ -25,6 +24,22 @@ from app.utils.otp import OtpUtils router = APIRouter() +# ==================== 请求模型 ==================== + +class OtpVerifyRequest(schemas.BaseModel): + """OTP验证请求""" + uri: str + otpPassword: str + +class OtpDisableRequest(schemas.BaseModel): + """OTP禁用请求""" + password: str + +class PassKeyDeleteRequest(schemas.BaseModel): + """PassKey删除请求""" + passkey_id: int + password: str + # ==================== 通用 MFA 接口 ==================== @router.get('/status/{username}', summary='判断用户是否开启双重验证(MFA)', response_model=schemas.Response) @@ -40,7 +55,7 @@ async def mfa_status(username: str, db: AsyncSession = Depends(get_async_db)) -> has_otp = user.is_otp # 检查是否有PassKey - has_passkey = bool(PassKey().get_by_user_id(db=None, user_id=user.id)) + has_passkey = bool(await PassKey.async_get_by_user_id(db=db, user_id=user.id)) # 只要有任何一种验证方式,就需要双重验证 return schemas.Response(success=(has_otp or has_passkey)) @@ -59,25 +74,35 @@ def otp_generate( @router.post('/otp/verify', summary='绑定并验证 OTP', response_model=schemas.Response) async def otp_verify( - data: dict, + data: OtpVerifyRequest, db: AsyncSession = Depends(get_async_db), current_user: User = Depends(get_current_active_user_async) ) -> Any: """验证用户输入的 OTP 码,验证通过后正式开启 OTP 验证""" - uri = data.get("uri") - otp_password = data.get("otpPassword") - if not OtpUtils.is_legal(uri, otp_password): + if not OtpUtils.is_legal(data.uri, data.otpPassword): return schemas.Response(success=False, message="验证码错误") - await current_user.async_update_otp_by_name(db, current_user.name, True, OtpUtils.get_secret(uri)) + await current_user.async_update_otp_by_name(db, current_user.name, True, OtpUtils.get_secret(data.uri)) return schemas.Response(success=True) @router.post('/otp/disable', summary='关闭当前用户的 OTP 验证', response_model=schemas.Response) async def otp_disable( + data: OtpDisableRequest, db: AsyncSession = Depends(get_async_db), current_user: User = Depends(get_current_active_user_async) ) -> Any: """关闭当前用户的 OTP 验证功能""" + # 安全检查:如果存在 PassKey,不允许关闭 OTP + has_passkey = bool(await PassKey.async_get_by_user_id(db=db, user_id=current_user.id)) + if has_passkey: + return schemas.Response( + success=False, + message="您已注册通行密钥,为了防止域名配置变更导致无法登录,请先删除所有通行密钥再关闭 OTP 验证" + ) + + # 验证密码 + if not security.verify_password(data.password, str(current_user.hashed_password)): + return schemas.Response(success=False, message="密码错误") await current_user.async_update_otp_by_name(db, current_user.name, False, "") return schemas.Response(success=True) @@ -98,7 +123,7 @@ class PassKeyRegistrationFinish(schemas.BaseModel): class PassKeyAuthenticationStart(schemas.BaseModel): """PassKey认证开始请求""" - username: str | None = None + username: Optional[str] = None class PassKeyAuthenticationFinish(schemas.BaseModel): @@ -122,7 +147,7 @@ def passkey_register_start( ) # 获取用户已有的PassKey - existing_passkeys = PassKey().get_by_user_id(db=None, user_id=current_user.id) + existing_passkeys = PassKey.get_by_user_id(db=None, user_id=current_user.id) existing_credentials = [ { 'credential_id': pk.credential_id, @@ -215,7 +240,7 @@ def passkey_authenticate_start( message="用户不存在" ) - existing_passkeys = PassKey().get_by_user_id(db=None, user_id=user.id) + existing_passkeys = PassKey.get_by_user_id(db=None, user_id=user.id) if not existing_passkeys: return schemas.Response( success=False, @@ -265,7 +290,7 @@ def passkey_authenticate_finish( credential_id = PassKeyHelper.standardize_credential_id(credential_id_raw) # 查找PassKey - passkey = PassKey().get_by_credential_id(db=None, credential_id=credential_id) + passkey = PassKey.get_by_credential_id(db=None, credential_id=credential_id) if not passkey: raise HTTPException(status_code=401, detail="通行密钥不存在或已失效") @@ -286,7 +311,7 @@ def passkey_authenticate_finish( raise HTTPException(status_code=401, detail="通行密钥验证失败") # 更新使用时间和签名计数 - passkey.update_last_used(db=None, credential_id=credential_id, sign_count=new_sign_count) + passkey.update_last_used(db=None, sign_count=new_sign_count) logger.info(f"用户 {user.name} 通过PassKey认证成功") @@ -309,7 +334,7 @@ def passkey_authenticate_finish( avatar=user.avatar, level=level, permissions=user.permissions or {}, - widzard=show_wizard + wizard=show_wizard ) except HTTPException: raise @@ -324,7 +349,7 @@ def passkey_list( ) -> Any: """获取当前用户的所有 PassKey""" try: - passkeys = PassKey().get_by_user_id(db=None, user_id=current_user.id) + passkeys = PassKey.get_by_user_id(db=None, user_id=current_user.id) passkey_list = [ { @@ -350,17 +375,21 @@ def passkey_list( ) -@router.delete("/passkey/{passkey_id}", summary="删除 PassKey", response_model=schemas.Response) -def passkey_delete( - passkey_id: int, - current_user: Annotated[User, Depends(get_current_active_user)] +@router.post("/passkey/delete", summary="删除 PassKey", response_model=schemas.Response) +async def passkey_delete( + data: PassKeyDeleteRequest, + current_user: User = Depends(get_current_active_user_async) ) -> Any: """删除指定的 PassKey""" try: - success = PassKey().delete_by_id(db=None, passkey_id=passkey_id, user_id=current_user.id) + # 验证密码 + if not security.verify_password(data.password, str(current_user.hashed_password)): + return schemas.Response(success=False, message="密码错误") + + success = PassKey.delete_by_id(db=None, passkey_id=data.passkey_id, user_id=current_user.id) if success: - logger.info(f"用户 {current_user.name} 删除了PassKey: {passkey_id}") + logger.info(f"用户 {current_user.name} 删除了PassKey: {data.passkey_id}") return schemas.Response( success=True, message="通行密钥已删除" @@ -397,7 +426,7 @@ def passkey_verify_mfa( credential_id = PassKeyHelper.standardize_credential_id(credential_id_raw) # 查找PassKey(必须属于当前用户) - passkey = PassKey().get_by_credential_id(db=None, credential_id=credential_id) + passkey = PassKey.get_by_credential_id(db=None, credential_id=credential_id) if not passkey or passkey.user_id != current_user.id: return schemas.Response( success=False, @@ -419,7 +448,7 @@ def passkey_verify_mfa( ) # 更新使用时间和签名计数 - passkey.update_last_used(db=None, credential_id=credential_id, sign_count=new_sign_count) + passkey.update_last_used(db=None, sign_count=new_sign_count) logger.info(f"用户 {current_user.name} 通过PassKey二次验证成功") diff --git a/app/chain/user.py b/app/chain/user.py index b920f490..2223487c 100644 --- a/app/chain/user.py +++ b/app/chain/user.py @@ -179,7 +179,7 @@ class UserChain(ChainBase): """ # 检查用户是否有PassKey from app.db.models.passkey import PassKey - has_passkey = bool(PassKey().get_by_user_id(db=None, user_id=user.id)) + has_passkey = bool(PassKey.get_by_user_id(db=None, user_id=user.id)) # 如果用户既没有启用OTP也没有PassKey,直接通过 if not user.is_otp and not has_passkey: @@ -190,11 +190,20 @@ class UserChain(ChainBase): logger.info(f"用户 {user.name} 已启用双重验证(OTP: {user.is_otp}, PassKey: {has_passkey}),需要提供验证码") return "MFA_REQUIRED" - # 如果提供了验证码,验证OTP + # 如果提供了验证码,且用户启用了 OTP,则验证 OTP if user.is_otp: if not OtpUtils.check(str(user.otp_secret), mfa_code): logger.info(f"用户 {user.name} 的 MFA 认证失败") return False + # OTP 验证成功 + return True + + # 用户未启用 OTP,此时提供的 mfa_code 无效;如果启用了 PassKey,则仍需通过 PassKey 验证 + if has_passkey: + logger.info( + f"用户 {user.name} 未启用 OTP,但已启用 PassKey,提供的 MFA 验证码将被忽略,仍需通过 PassKey 验证" + ) + return "MFA_REQUIRED" return True diff --git a/app/db/models/passkey.py b/app/db/models/passkey.py index 129d29e9..971a94c3 100644 --- a/app/db/models/passkey.py +++ b/app/db/models/passkey.py @@ -37,14 +37,14 @@ class PassKey(Base): @db_query def get_by_user_id(cls, db: Session, user_id: int): """获取用户的所有PassKey""" - return db.query(cls).filter(cls.user_id == user_id, cls.is_active == True).all() + return db.query(cls).filter(cls.user_id == user_id, cls.is_active.is_(True)).all() @classmethod @async_db_query async def async_get_by_user_id(cls, db: AsyncSession, user_id: int): """异步获取用户的所有PassKey""" result = await db.execute( - select(cls).filter(cls.user_id == user_id, cls.is_active == True) + select(cls).filter(cls.user_id == user_id, cls.is_active.is_(True)) ) return result.scalars().all() @@ -52,14 +52,14 @@ class PassKey(Base): @db_query def get_by_credential_id(cls, db: Session, credential_id: str): """根据凭证ID获取PassKey""" - return db.query(cls).filter(cls.credential_id == credential_id, cls.is_active == True).first() + return db.query(cls).filter(cls.credential_id == credential_id, cls.is_active.is_(True)).first() @classmethod @async_db_query async def async_get_by_credential_id(cls, db: AsyncSession, credential_id: str): """异步根据凭证ID获取PassKey""" result = await db.execute( - select(cls).filter(cls.credential_id == credential_id, cls.is_active == True) + select(cls).filter(cls.credential_id == credential_id, cls.is_active.is_(True)) ) return result.scalars().first() @@ -78,25 +78,27 @@ class PassKey(Base): ) return result.scalars().first() + @classmethod @db_update - def delete_by_id(self, db: Session, passkey_id: int, user_id: int): + def delete_by_id(cls, db: Session, passkey_id: int, user_id: int): """删除指定用户的PassKey""" - passkey = db.query(PassKey).filter( - PassKey.id == passkey_id, - PassKey.user_id == user_id + passkey = db.query(cls).filter( + cls.id == passkey_id, + cls.user_id == user_id ).first() if passkey: passkey.delete(db, passkey.id) return True return False + @classmethod @async_db_update - async def async_delete_by_id(self, db: AsyncSession, passkey_id: int, user_id: int): + async def async_delete_by_id(cls, db: AsyncSession, passkey_id: int, user_id: int): """异步删除指定用户的PassKey""" result = await db.execute( - select(PassKey).filter( - PassKey.id == passkey_id, - PassKey.user_id == user_id + select(cls).filter( + cls.id == passkey_id, + cls.user_id == user_id ) ) passkey = result.scalars().first() @@ -106,25 +108,19 @@ class PassKey(Base): return False @db_update - def update_last_used(self, db: Session, credential_id: str, sign_count: int): + def update_last_used(self, db: Session, sign_count: int): """更新最后使用时间和签名计数""" - passkey = self.get_by_credential_id(db, credential_id) - if passkey: - passkey.update(db, { - 'last_used_at': datetime.now(), - 'sign_count': sign_count - }) - return True - return False + self.update(db, { + 'last_used_at': datetime.now(), + 'sign_count': sign_count + }) + return True @async_db_update - async def async_update_last_used(self, db: AsyncSession, credential_id: str, sign_count: int): + async def async_update_last_used(self, db: AsyncSession, sign_count: int): """异步更新最后使用时间和签名计数""" - passkey = await self.async_get_by_credential_id(db, credential_id) - if passkey: - await passkey.async_update(db, { - 'last_used_at': datetime.now(), - 'sign_count': sign_count - }) - return True - return False + await self.async_update(db, { + 'last_used_at': datetime.now(), + 'sign_count': sign_count + }) + return True diff --git a/app/helper/passkey.py b/app/helper/passkey.py index 6fccc413..ac410710 100644 --- a/app/helper/passkey.py +++ b/app/helper/passkey.py @@ -3,6 +3,7 @@ PassKey WebAuthn 辅助工具类 """ import base64 import json +import binascii from typing import Optional, Tuple, List, Dict, Any from urllib.parse import urlparse @@ -23,10 +24,6 @@ from webauthn.helpers.structs import ( UserVerificationRequirement, AuthenticatorAttachment, ResidentKeyRequirement, - PublicKeyCredentialCreationOptions, - PublicKeyCredentialRequestOptions, - RegistrationCredential, - AuthenticationCredential, AuthenticatorSelectionCriteria ) from webauthn.helpers.cose import COSEAlgorithmIdentifier @@ -46,6 +43,14 @@ class PassKeyHelper: 获取 Relying Party ID """ if settings.APP_DOMAIN: + app_domain = settings.APP_DOMAIN.strip() + # 确保存在协议前缀,以便 urlparse 正确解析主机和端口 + if not app_domain.startswith(('http://', 'https://')): + app_domain = f'https://{app_domain}' + parsed = urlparse(app_domain) + host = parsed.hostname + if host: + return host # 从 APP_DOMAIN 中提取域名 host = settings.APP_DOMAIN.replace('https://', '').replace('http://', '') # 移除端口号 @@ -81,7 +86,7 @@ class PassKeyHelper: # Base64解码并重新编码以标准化格式 decoded = base64.urlsafe_b64decode(credential_id + '==') return base64.urlsafe_b64encode(decoded).decode('utf-8').rstrip('=') - except Exception as e: + except (binascii.Error, TypeError, ValueError) as e: logger.error(f"标准化凭证ID失败: {e}") return credential_id @@ -135,7 +140,7 @@ class PassKeyHelper: user_display_name=display_name or username, exclude_credentials=exclude_credentials if exclude_credentials else None, authenticator_selection=AuthenticatorSelectionCriteria( - authenticator_attachment=AuthenticatorAttachment.PLATFORM, + authenticator_attachment=None, resident_key=ResidentKeyRequirement.REQUIRED, user_verification=uv_requirement, ), diff --git a/app/schemas/token.py b/app/schemas/token.py index 5493f880..8e868aee 100644 --- a/app/schemas/token.py +++ b/app/schemas/token.py @@ -21,7 +21,7 @@ class Token(BaseModel): # 详细权限 permissions: Optional[dict] = Field(default_factory=dict) # 是否显示配置向导 - widzard: Optional[bool] = None + wizard: Optional[bool] = None class TokenPayload(BaseModel): From 2de83c44abd578964d3d79c34ecbeeb6ba1a71ad Mon Sep 17 00:00:00 2001 From: PKC278 <52959804+PKC278@users.noreply.github.com> Date: Tue, 23 Dec 2025 17:06:17 +0800 Subject: [PATCH 2/2] =?UTF-8?q?refactor(mcp):=20=E7=B2=BE=E7=AE=80?= =?UTF-8?q?=E4=BC=9A=E8=AF=9D=E7=AE=A1=E7=90=86=E9=80=BB=E8=BE=91=E5=B9=B6?= =?UTF-8?q?=E6=9B=B4=E6=96=B0API=E6=96=87=E6=A1=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/endpoints/mcp.py | 186 +++++++++------------------------------ docs/mcp-api.md | 19 ++-- 2 files changed, 50 insertions(+), 155 deletions(-) diff --git a/app/api/endpoints/mcp.py b/app/api/endpoints/mcp.py index be238fc2..889b9a3f 100644 --- a/app/api/endpoints/mcp.py +++ b/app/api/endpoints/mcp.py @@ -2,7 +2,6 @@ 通过HTTP API暴露MoviePilot的智能体工具功能 """ -import uuid from typing import List, Any, Dict, Annotated, Optional, Union from fastapi import APIRouter, Depends, HTTPException, Request, Header @@ -25,64 +24,18 @@ router = APIRouter() MCP_PROTOCOL_VERSIONS = ["2025-11-25", "2025-06-18", "2024-11-05"] MCP_PROTOCOL_VERSION = MCP_PROTOCOL_VERSIONS[0] # 默认使用最新版本 -# 全局会话管理器 -_sessions: Dict[str, Dict[str, Any]] = {} -# 全局工具管理器实例(单例模式,按用户ID缓存) -_tools_managers: Dict[str, MoviePilotToolsManager] = {} - - -def get_tools_manager(user_id: str = "mcp_user", session_id: str = "mcp_session") -> MoviePilotToolsManager: +def get_tools_manager() -> MoviePilotToolsManager: """ - 获取工具管理器实例(按用户ID缓存) - + 获取工具管理器实例 + Args: user_id: 用户ID - session_id: 会话ID - + Returns: MoviePilotToolsManager实例 """ - global _tools_managers - # 使用用户ID作为缓存键 - cache_key = f"{user_id}_{session_id}" - if cache_key not in _tools_managers: - _tools_managers[cache_key] = MoviePilotToolsManager( - user_id=user_id, - session_id=session_id - ) - return _tools_managers[cache_key] - - -def get_session(session_id: Optional[str]) -> Optional[Dict[str, Any]]: - """获取会话""" - if not session_id: - return None - return _sessions.get(session_id) - - -def create_session(user_id: str) -> Dict[str, Any]: - """创建新会话""" - session_id = str(uuid.uuid4()) - session = { - "id": session_id, - "user_id": user_id, - "initialized": False, - "protocol_version": None, - "capabilities": {} - } - _sessions[session_id] = session - return session - - -def delete_session(session_id: str): - """删除会话""" - if session_id in _sessions: - del _sessions[session_id] - # 同时清理工具管理器缓存 - cache_key = f"{_sessions.get(session_id, {}).get('user_id', 'mcp_user')}_{session_id}" - if cache_key in _tools_managers: - del _tools_managers[cache_key] + return MoviePilotToolsManager() def create_jsonrpc_response(request_id: Union[str, int, None], result: Any) -> Dict[str, Any]: @@ -112,13 +65,11 @@ def create_jsonrpc_error(request_id: Union[str, int, None], code: int, message: # ==================== MCP JSON-RPC 端点 ==================== -@router.post("", summary="MCP JSON-RPC 端点") +@router.post("", summary="MCP JSON-RPC 端点", response_model=None) async def mcp_jsonrpc( request: Request, - mcp_session_id: Optional[str] = Header(None, alias="MCP-Session-Id"), - mcp_protocol_version: Optional[str] = Header(None, alias="MCP-Protocol-Version"), _: Annotated[str, Depends(verify_apikey)] = None -) -> JSONResponse: +) -> Union[JSONResponse, Response]: """ MCP 标准 JSON-RPC 2.0 端点 @@ -150,48 +101,27 @@ async def mcp_jsonrpc( try: # 处理初始化请求 if method == "initialize": - result = await handle_initialize(params, mcp_session_id) - response = create_jsonrpc_response(request_id, result["result"]) - # 如果创建了新会话,在响应头中返回 - if "session_id" in result: - headers = {"MCP-Session-Id": result["session_id"]} - return JSONResponse(content=response, headers=headers) - return JSONResponse(content=response) + result = await handle_initialize(params) + return JSONResponse(content=create_jsonrpc_response(request_id, result)) # 处理已初始化通知 elif method == "notifications/initialized": if is_notification: - session = get_session(mcp_session_id) - if session: - session["initialized"] = True - # 通知不需要响应 - return Response(status_code=202) + return Response(status_code=204) else: return JSONResponse( - content=create_jsonrpc_error(request_id, -32600, "initialized must be a notification") - ) - - # 验证会话(除了 initialize 和 ping) - if method not in ["initialize", "ping"]: - session = get_session(mcp_session_id) - if not session: - return JSONResponse( - status_code=404, - content=create_jsonrpc_error(request_id, -32002, "Session not found") - ) - if not session.get("initialized") and method != "notifications/initialized": - return JSONResponse( - content=create_jsonrpc_error(request_id, -32003, "Not initialized") + status_code=400, + content={"error": "initialized must be a notification"} ) # 处理工具列表请求 if method == "tools/list": - result = await handle_tools_list(mcp_session_id) + result = await handle_tools_list() return JSONResponse(content=create_jsonrpc_response(request_id, result)) # 处理工具调用请求 elif method == "tools/call": - result = await handle_tools_call(params, mcp_session_id) + result = await handle_tools_call(params) return JSONResponse(content=create_jsonrpc_response(request_id, result)) # 处理 ping 请求 @@ -204,6 +134,12 @@ async def mcp_jsonrpc( content=create_jsonrpc_error(request_id, -32601, f"Method not found: {method}") ) + except ValueError as e: + logger.warning(f"MCP 请求参数错误: {e}") + return JSONResponse( + status_code=400, + content=create_jsonrpc_error(request_id, -32602, "Invalid params", str(e)) + ) except Exception as e: logger.error(f"处理 MCP 请求失败: {e}", exc_info=True) return JSONResponse( @@ -212,24 +148,13 @@ async def mcp_jsonrpc( ) -async def handle_initialize(params: Dict[str, Any], session_id: Optional[str]) -> Dict[str, Any]: +async def handle_initialize(params: Dict[str, Any]) -> Dict[str, Any]: """处理初始化请求""" protocol_version = params.get("protocolVersion") client_info = params.get("clientInfo", {}) - client_capabilities = params.get("capabilities", {}) logger.info(f"MCP 初始化请求: 客户端={client_info.get('name')}, 协议版本={protocol_version}") - # 如果没有提供会话ID,创建新会话 - new_session = None - if not session_id: - new_session = create_session(user_id="mcp_user") - session_id = new_session["id"] - - session = get_session(session_id) or new_session - if not session: - raise ValueError("Failed to create session") - # 版本协商:选择客户端和服务器都支持的版本 negotiated_version = MCP_PROTOCOL_VERSION if protocol_version in MCP_PROTOCOL_VERSIONS: @@ -240,40 +165,26 @@ async def handle_initialize(params: Dict[str, Any], session_id: Optional[str]) - # 客户端版本不支持,使用服务器默认版本 logger.warning(f"协议版本不匹配: 客户端={protocol_version}, 使用服务器版本={negotiated_version}") - session["protocol_version"] = negotiated_version - session["capabilities"] = client_capabilities - - result = { - "result": { - "protocolVersion": negotiated_version, - "capabilities": { - "tools": { - "listChanged": False # 暂不支持工具列表变更通知 - }, - "logging": {} + return { + "protocolVersion": negotiated_version, + "capabilities": { + "tools": { + "listChanged": False # 暂不支持工具列表变更通知 }, - "serverInfo": { - "name": "MoviePilot", - "version": APP_VERSION, - "description": "MoviePilot MCP Server - 电影自动化管理工具", - }, - "instructions": "MoviePilot MCP 服务器,提供媒体管理、订阅、下载等工具。使用 tools/list 查看所有可用工具。" - } + "logging": {} + }, + "serverInfo": { + "name": "MoviePilot", + "version": APP_VERSION, + "description": "MoviePilot MCP Server - 电影自动化管理工具", + }, + "instructions": "MoviePilot MCP 服务器,提供媒体管理、订阅、下载等工具。" } - # 如果是新创建的会话,返回会话ID - if new_session: - result["session_id"] = session_id - return result - - -async def handle_tools_list(session_id: Optional[str]) -> Dict[str, Any]: +async def handle_tools_list() -> Dict[str, Any]: """处理工具列表请求""" - session = get_session(session_id) - user_id = session.get("user_id", "mcp_user") if session else "mcp_user" - - manager = get_tools_manager(user_id=user_id, session_id=session_id or "default") + manager = get_tools_manager() tools = manager.list_tools() # 转换为 MCP 工具格式 @@ -291,7 +202,7 @@ async def handle_tools_list(session_id: Optional[str]) -> Dict[str, Any]: } -async def handle_tools_call(params: Dict[str, Any], session_id: Optional[str]) -> Dict[str, Any]: +async def handle_tools_call(params: Dict[str, Any]) -> Dict[str, Any]: """处理工具调用请求""" tool_name = params.get("name") arguments = params.get("arguments", {}) @@ -299,10 +210,7 @@ async def handle_tools_call(params: Dict[str, Any], session_id: Optional[str]) - if not tool_name: raise ValueError("Missing tool name") - session = get_session(session_id) - user_id = session.get("user_id", "mcp_user") if session else "mcp_user" - - manager = get_tools_manager(user_id=user_id, session_id=session_id or "default") + manager = get_tools_manager() try: result_text = await manager.call_tool(tool_name, arguments) @@ -328,26 +236,18 @@ async def handle_tools_call(params: Dict[str, Any], session_id: Optional[str]) - } -@router.delete("", summary="终止 MCP 会话") +@router.delete("", summary="终止 MCP 会话", response_model=None) async def delete_mcp_session( - mcp_session_id: Optional[str] = Header(None, alias="MCP-Session-Id"), _: Annotated[str, Depends(verify_apikey)] = None -) -> JSONResponse: +) -> Union[JSONResponse, Response]: """ - 终止 MCP 会话(可选实现) - - 客户端可以主动调用此接口终止会话 + 终止 MCP 会话(无状态模式下仅返回成功) """ - if not mcp_session_id: - return JSONResponse( - status_code=400, - content={"detail": "Missing MCP-Session-Id header"} - ) - - delete_session(mcp_session_id) return Response(status_code=204) + + # ==================== 兼容的 RESTful API 端点 ==================== @router.get("/tools", summary="列出所有可用工具", response_model=List[Dict[str, Any]]) diff --git a/docs/mcp-api.md b/docs/mcp-api.md index 01d0cea6..b5f27ba6 100644 --- a/docs/mcp-api.md +++ b/docs/mcp-api.md @@ -16,11 +16,12 @@ MoviePilot 实现了标准的 **Model Context Protocol (MCP)**,允许 AI 智 ### 端点 **POST** `/api/v1/mcp` -## 3. 会话管理 - -* **会话维持**: 在标准 MCP 流程中,通过 HTTP Header `MCP-Session-Id` 识别会话。 -* **主动终止**: - **DELETE** `/api/v1/mcp` (携带 `MCP-Session-Id` Header) +### 支持的方法 +- `initialize`: 初始化会话,协商协议版本和能力。 +- `notifications/initialized`: 客户端确认初始化完成。 +- `tools/list`: 获取可用工具列表。 +- `tools/call`: 调用特定工具。 +- `ping`: 连接存活检测。 --- @@ -65,6 +66,7 @@ MoviePilot 实现了标准的 **Model Context Protocol (MCP)**,允许 AI 智 | -32700 | Parse error | JSON 格式错误 | | -32600 | Invalid Request | 无效的 JSON-RPC 请求 | | -32601 | Method not found | 方法不存在 | +| -32602 | Invalid params | 参数验证失败 | | -32002 | Session not found | 会话不存在或已过期 | | -32003 | Not initialized | 会话未完成初始化流程 | | -32603 | Internal error | 服务器内部错误 | @@ -203,10 +205,3 @@ MoviePilot 实现了标准的 **Model Context Protocol (MCP)**,允许 AI 智 "required": ["title", "year", "media_type"] } ``` - -## 7. 注意事项 - -1. **用户上下文**: API调用会使用当前认证用户的ID作为工具执行的用户上下文 -2. **会话隔离**: 每个API请求使用独立的会话ID -3. **参数验证**: 工具参数会根据JSON Schema进行验证 -4. **错误日志**: 所有工具调用错误都会记录到MoviePilot日志系统