From 7b99f09810b6496b153e5c9a24ea503ebf16a2d3 Mon Sep 17 00:00:00 2001 From: PKC278 <52959804+PKC278@users.noreply.github.com> Date: Tue, 23 Dec 2025 14:57:43 +0800 Subject: [PATCH] =?UTF-8?q?fix(mfa):=20=E4=BF=AE=E5=A4=8D=E5=8F=8C?= =?UTF-8?q?=E9=87=8D=E9=AA=8C=E8=AF=81=E6=BC=8F=E6=B4=9E?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/endpoints/login.py | 4 +- app/api/endpoints/mfa.py | 77 ++++++++++++++++++++++++++------------ app/chain/user.py | 13 ++++++- app/db/models/passkey.py | 56 +++++++++++++-------------- app/helper/passkey.py | 17 ++++++--- app/schemas/token.py | 2 +- 6 files changed, 104 insertions(+), 65 deletions(-) diff --git a/app/api/endpoints/login.py b/app/api/endpoints/login.py index 769ad445..928ba16f 100644 --- a/app/api/endpoints/login.py +++ b/app/api/endpoints/login.py @@ -33,7 +33,7 @@ def login_access_token( if user_or_message == "MFA_REQUIRED": raise HTTPException( status_code=401, - detail="需要双重验证", + detail="需要双重验证,请提供验证码或使用通行密钥", headers={"X-MFA-Required": "true"} ) raise HTTPException(status_code=401, detail=user_or_message) @@ -57,7 +57,7 @@ def login_access_token( avatar=user_or_message.avatar, level=level, permissions=user_or_message.permissions or {}, - widzard=show_wizard + wizard=show_wizard ) diff --git a/app/api/endpoints/mfa.py b/app/api/endpoints/mfa.py index f4498c25..c0f63d02 100644 --- a/app/api/endpoints/mfa.py +++ b/app/api/endpoints/mfa.py @@ -2,9 +2,8 @@ MFA (Multi-Factor Authentication) API 端点 包含 OTP 和 PassKey 相关功能 """ -import base64 -from datetime import datetime, timedelta -from typing import Any, Annotated, List, Union +from datetime import timedelta +from typing import Any, Annotated, Optional, List, Union from fastapi import APIRouter, Depends, HTTPException, Body from sqlalchemy.ext.asyncio import AsyncSession @@ -25,6 +24,22 @@ from app.utils.otp import OtpUtils router = APIRouter() +# ==================== 请求模型 ==================== + +class OtpVerifyRequest(schemas.BaseModel): + """OTP验证请求""" + uri: str + otpPassword: str + +class OtpDisableRequest(schemas.BaseModel): + """OTP禁用请求""" + password: str + +class PassKeyDeleteRequest(schemas.BaseModel): + """PassKey删除请求""" + passkey_id: int + password: str + # ==================== 通用 MFA 接口 ==================== @router.get('/status/{username}', summary='判断用户是否开启双重验证(MFA)', response_model=schemas.Response) @@ -40,7 +55,7 @@ async def mfa_status(username: str, db: AsyncSession = Depends(get_async_db)) -> has_otp = user.is_otp # 检查是否有PassKey - has_passkey = bool(PassKey().get_by_user_id(db=None, user_id=user.id)) + has_passkey = bool(await PassKey.async_get_by_user_id(db=db, user_id=user.id)) # 只要有任何一种验证方式,就需要双重验证 return schemas.Response(success=(has_otp or has_passkey)) @@ -59,25 +74,35 @@ def otp_generate( @router.post('/otp/verify', summary='绑定并验证 OTP', response_model=schemas.Response) async def otp_verify( - data: dict, + data: OtpVerifyRequest, db: AsyncSession = Depends(get_async_db), current_user: User = Depends(get_current_active_user_async) ) -> Any: """验证用户输入的 OTP 码,验证通过后正式开启 OTP 验证""" - uri = data.get("uri") - otp_password = data.get("otpPassword") - if not OtpUtils.is_legal(uri, otp_password): + if not OtpUtils.is_legal(data.uri, data.otpPassword): return schemas.Response(success=False, message="验证码错误") - await current_user.async_update_otp_by_name(db, current_user.name, True, OtpUtils.get_secret(uri)) + await current_user.async_update_otp_by_name(db, current_user.name, True, OtpUtils.get_secret(data.uri)) return schemas.Response(success=True) @router.post('/otp/disable', summary='关闭当前用户的 OTP 验证', response_model=schemas.Response) async def otp_disable( + data: OtpDisableRequest, db: AsyncSession = Depends(get_async_db), current_user: User = Depends(get_current_active_user_async) ) -> Any: """关闭当前用户的 OTP 验证功能""" + # 安全检查:如果存在 PassKey,不允许关闭 OTP + has_passkey = bool(await PassKey.async_get_by_user_id(db=db, user_id=current_user.id)) + if has_passkey: + return schemas.Response( + success=False, + message="您已注册通行密钥,为了防止域名配置变更导致无法登录,请先删除所有通行密钥再关闭 OTP 验证" + ) + + # 验证密码 + if not security.verify_password(data.password, str(current_user.hashed_password)): + return schemas.Response(success=False, message="密码错误") await current_user.async_update_otp_by_name(db, current_user.name, False, "") return schemas.Response(success=True) @@ -98,7 +123,7 @@ class PassKeyRegistrationFinish(schemas.BaseModel): class PassKeyAuthenticationStart(schemas.BaseModel): """PassKey认证开始请求""" - username: str | None = None + username: Optional[str] = None class PassKeyAuthenticationFinish(schemas.BaseModel): @@ -122,7 +147,7 @@ def passkey_register_start( ) # 获取用户已有的PassKey - existing_passkeys = PassKey().get_by_user_id(db=None, user_id=current_user.id) + existing_passkeys = PassKey.get_by_user_id(db=None, user_id=current_user.id) existing_credentials = [ { 'credential_id': pk.credential_id, @@ -215,7 +240,7 @@ def passkey_authenticate_start( message="用户不存在" ) - existing_passkeys = PassKey().get_by_user_id(db=None, user_id=user.id) + existing_passkeys = PassKey.get_by_user_id(db=None, user_id=user.id) if not existing_passkeys: return schemas.Response( success=False, @@ -265,7 +290,7 @@ def passkey_authenticate_finish( credential_id = PassKeyHelper.standardize_credential_id(credential_id_raw) # 查找PassKey - passkey = PassKey().get_by_credential_id(db=None, credential_id=credential_id) + passkey = PassKey.get_by_credential_id(db=None, credential_id=credential_id) if not passkey: raise HTTPException(status_code=401, detail="通行密钥不存在或已失效") @@ -286,7 +311,7 @@ def passkey_authenticate_finish( raise HTTPException(status_code=401, detail="通行密钥验证失败") # 更新使用时间和签名计数 - passkey.update_last_used(db=None, credential_id=credential_id, sign_count=new_sign_count) + passkey.update_last_used(db=None, sign_count=new_sign_count) logger.info(f"用户 {user.name} 通过PassKey认证成功") @@ -309,7 +334,7 @@ def passkey_authenticate_finish( avatar=user.avatar, level=level, permissions=user.permissions or {}, - widzard=show_wizard + wizard=show_wizard ) except HTTPException: raise @@ -324,7 +349,7 @@ def passkey_list( ) -> Any: """获取当前用户的所有 PassKey""" try: - passkeys = PassKey().get_by_user_id(db=None, user_id=current_user.id) + passkeys = PassKey.get_by_user_id(db=None, user_id=current_user.id) passkey_list = [ { @@ -350,17 +375,21 @@ def passkey_list( ) -@router.delete("/passkey/{passkey_id}", summary="删除 PassKey", response_model=schemas.Response) -def passkey_delete( - passkey_id: int, - current_user: Annotated[User, Depends(get_current_active_user)] +@router.post("/passkey/delete", summary="删除 PassKey", response_model=schemas.Response) +async def passkey_delete( + data: PassKeyDeleteRequest, + current_user: User = Depends(get_current_active_user_async) ) -> Any: """删除指定的 PassKey""" try: - success = PassKey().delete_by_id(db=None, passkey_id=passkey_id, user_id=current_user.id) + # 验证密码 + if not security.verify_password(data.password, str(current_user.hashed_password)): + return schemas.Response(success=False, message="密码错误") + + success = PassKey.delete_by_id(db=None, passkey_id=data.passkey_id, user_id=current_user.id) if success: - logger.info(f"用户 {current_user.name} 删除了PassKey: {passkey_id}") + logger.info(f"用户 {current_user.name} 删除了PassKey: {data.passkey_id}") return schemas.Response( success=True, message="通行密钥已删除" @@ -397,7 +426,7 @@ def passkey_verify_mfa( credential_id = PassKeyHelper.standardize_credential_id(credential_id_raw) # 查找PassKey(必须属于当前用户) - passkey = PassKey().get_by_credential_id(db=None, credential_id=credential_id) + passkey = PassKey.get_by_credential_id(db=None, credential_id=credential_id) if not passkey or passkey.user_id != current_user.id: return schemas.Response( success=False, @@ -419,7 +448,7 @@ def passkey_verify_mfa( ) # 更新使用时间和签名计数 - passkey.update_last_used(db=None, credential_id=credential_id, sign_count=new_sign_count) + passkey.update_last_used(db=None, sign_count=new_sign_count) logger.info(f"用户 {current_user.name} 通过PassKey二次验证成功") diff --git a/app/chain/user.py b/app/chain/user.py index b920f490..2223487c 100644 --- a/app/chain/user.py +++ b/app/chain/user.py @@ -179,7 +179,7 @@ class UserChain(ChainBase): """ # 检查用户是否有PassKey from app.db.models.passkey import PassKey - has_passkey = bool(PassKey().get_by_user_id(db=None, user_id=user.id)) + has_passkey = bool(PassKey.get_by_user_id(db=None, user_id=user.id)) # 如果用户既没有启用OTP也没有PassKey,直接通过 if not user.is_otp and not has_passkey: @@ -190,11 +190,20 @@ class UserChain(ChainBase): logger.info(f"用户 {user.name} 已启用双重验证(OTP: {user.is_otp}, PassKey: {has_passkey}),需要提供验证码") return "MFA_REQUIRED" - # 如果提供了验证码,验证OTP + # 如果提供了验证码,且用户启用了 OTP,则验证 OTP if user.is_otp: if not OtpUtils.check(str(user.otp_secret), mfa_code): logger.info(f"用户 {user.name} 的 MFA 认证失败") return False + # OTP 验证成功 + return True + + # 用户未启用 OTP,此时提供的 mfa_code 无效;如果启用了 PassKey,则仍需通过 PassKey 验证 + if has_passkey: + logger.info( + f"用户 {user.name} 未启用 OTP,但已启用 PassKey,提供的 MFA 验证码将被忽略,仍需通过 PassKey 验证" + ) + return "MFA_REQUIRED" return True diff --git a/app/db/models/passkey.py b/app/db/models/passkey.py index 129d29e9..971a94c3 100644 --- a/app/db/models/passkey.py +++ b/app/db/models/passkey.py @@ -37,14 +37,14 @@ class PassKey(Base): @db_query def get_by_user_id(cls, db: Session, user_id: int): """获取用户的所有PassKey""" - return db.query(cls).filter(cls.user_id == user_id, cls.is_active == True).all() + return db.query(cls).filter(cls.user_id == user_id, cls.is_active.is_(True)).all() @classmethod @async_db_query async def async_get_by_user_id(cls, db: AsyncSession, user_id: int): """异步获取用户的所有PassKey""" result = await db.execute( - select(cls).filter(cls.user_id == user_id, cls.is_active == True) + select(cls).filter(cls.user_id == user_id, cls.is_active.is_(True)) ) return result.scalars().all() @@ -52,14 +52,14 @@ class PassKey(Base): @db_query def get_by_credential_id(cls, db: Session, credential_id: str): """根据凭证ID获取PassKey""" - return db.query(cls).filter(cls.credential_id == credential_id, cls.is_active == True).first() + return db.query(cls).filter(cls.credential_id == credential_id, cls.is_active.is_(True)).first() @classmethod @async_db_query async def async_get_by_credential_id(cls, db: AsyncSession, credential_id: str): """异步根据凭证ID获取PassKey""" result = await db.execute( - select(cls).filter(cls.credential_id == credential_id, cls.is_active == True) + select(cls).filter(cls.credential_id == credential_id, cls.is_active.is_(True)) ) return result.scalars().first() @@ -78,25 +78,27 @@ class PassKey(Base): ) return result.scalars().first() + @classmethod @db_update - def delete_by_id(self, db: Session, passkey_id: int, user_id: int): + def delete_by_id(cls, db: Session, passkey_id: int, user_id: int): """删除指定用户的PassKey""" - passkey = db.query(PassKey).filter( - PassKey.id == passkey_id, - PassKey.user_id == user_id + passkey = db.query(cls).filter( + cls.id == passkey_id, + cls.user_id == user_id ).first() if passkey: passkey.delete(db, passkey.id) return True return False + @classmethod @async_db_update - async def async_delete_by_id(self, db: AsyncSession, passkey_id: int, user_id: int): + async def async_delete_by_id(cls, db: AsyncSession, passkey_id: int, user_id: int): """异步删除指定用户的PassKey""" result = await db.execute( - select(PassKey).filter( - PassKey.id == passkey_id, - PassKey.user_id == user_id + select(cls).filter( + cls.id == passkey_id, + cls.user_id == user_id ) ) passkey = result.scalars().first() @@ -106,25 +108,19 @@ class PassKey(Base): return False @db_update - def update_last_used(self, db: Session, credential_id: str, sign_count: int): + def update_last_used(self, db: Session, sign_count: int): """更新最后使用时间和签名计数""" - passkey = self.get_by_credential_id(db, credential_id) - if passkey: - passkey.update(db, { - 'last_used_at': datetime.now(), - 'sign_count': sign_count - }) - return True - return False + self.update(db, { + 'last_used_at': datetime.now(), + 'sign_count': sign_count + }) + return True @async_db_update - async def async_update_last_used(self, db: AsyncSession, credential_id: str, sign_count: int): + async def async_update_last_used(self, db: AsyncSession, sign_count: int): """异步更新最后使用时间和签名计数""" - passkey = await self.async_get_by_credential_id(db, credential_id) - if passkey: - await passkey.async_update(db, { - 'last_used_at': datetime.now(), - 'sign_count': sign_count - }) - return True - return False + await self.async_update(db, { + 'last_used_at': datetime.now(), + 'sign_count': sign_count + }) + return True diff --git a/app/helper/passkey.py b/app/helper/passkey.py index 6fccc413..ac410710 100644 --- a/app/helper/passkey.py +++ b/app/helper/passkey.py @@ -3,6 +3,7 @@ PassKey WebAuthn 辅助工具类 """ import base64 import json +import binascii from typing import Optional, Tuple, List, Dict, Any from urllib.parse import urlparse @@ -23,10 +24,6 @@ from webauthn.helpers.structs import ( UserVerificationRequirement, AuthenticatorAttachment, ResidentKeyRequirement, - PublicKeyCredentialCreationOptions, - PublicKeyCredentialRequestOptions, - RegistrationCredential, - AuthenticationCredential, AuthenticatorSelectionCriteria ) from webauthn.helpers.cose import COSEAlgorithmIdentifier @@ -46,6 +43,14 @@ class PassKeyHelper: 获取 Relying Party ID """ if settings.APP_DOMAIN: + app_domain = settings.APP_DOMAIN.strip() + # 确保存在协议前缀,以便 urlparse 正确解析主机和端口 + if not app_domain.startswith(('http://', 'https://')): + app_domain = f'https://{app_domain}' + parsed = urlparse(app_domain) + host = parsed.hostname + if host: + return host # 从 APP_DOMAIN 中提取域名 host = settings.APP_DOMAIN.replace('https://', '').replace('http://', '') # 移除端口号 @@ -81,7 +86,7 @@ class PassKeyHelper: # Base64解码并重新编码以标准化格式 decoded = base64.urlsafe_b64decode(credential_id + '==') return base64.urlsafe_b64encode(decoded).decode('utf-8').rstrip('=') - except Exception as e: + except (binascii.Error, TypeError, ValueError) as e: logger.error(f"标准化凭证ID失败: {e}") return credential_id @@ -135,7 +140,7 @@ class PassKeyHelper: user_display_name=display_name or username, exclude_credentials=exclude_credentials if exclude_credentials else None, authenticator_selection=AuthenticatorSelectionCriteria( - authenticator_attachment=AuthenticatorAttachment.PLATFORM, + authenticator_attachment=None, resident_key=ResidentKeyRequirement.REQUIRED, user_verification=uv_requirement, ), diff --git a/app/schemas/token.py b/app/schemas/token.py index 5493f880..8e868aee 100644 --- a/app/schemas/token.py +++ b/app/schemas/token.py @@ -21,7 +21,7 @@ class Token(BaseModel): # 详细权限 permissions: Optional[dict] = Field(default_factory=dict) # 是否显示配置向导 - widzard: Optional[bool] = None + wizard: Optional[bool] = None class TokenPayload(BaseModel):