From 9e64b4cd7fcb3fdcf78ec208ddf4a673fb9b1741 Mon Sep 17 00:00:00 2001 From: PKC278 <52959804+PKC278@users.noreply.github.com> Date: Sun, 11 Jan 2026 19:20:53 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E4=BC=98=E5=8C=96=E7=99=BB?= =?UTF-8?q?=E5=BD=95=E5=AE=89=E5=85=A8=E6=80=A7=E5=B9=B6=E9=87=8D=E6=9E=84?= =?UTF-8?q?=20PassKey=20=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 统一登录失败返回信息,防止信息泄露 - 提取 PassKeyHelper 公共函数,简化 Base64 和凭证处理 - 重构 mfa.py 端点代码,提升可读性和维护性 - 移除冗余的 origin 验证逻辑 --- app/api/endpoints/login.py | 4 +- app/api/endpoints/mfa.py | 177 ++++++++++++++++++++++--------------- app/helper/passkey.py | 177 +++++++++++++++++++------------------ 3 files changed, 201 insertions(+), 157 deletions(-) diff --git a/app/api/endpoints/login.py b/app/api/endpoints/login.py index 928ba16f..865218d7 100644 --- a/app/api/endpoints/login.py +++ b/app/api/endpoints/login.py @@ -32,11 +32,11 @@ def login_access_token( # 如果是需要MFA验证,返回特殊标识 if user_or_message == "MFA_REQUIRED": raise HTTPException( - status_code=401, + status_code=401, detail="需要双重验证,请提供验证码或使用通行密钥", headers={"X-MFA-Required": "true"} ) - raise HTTPException(status_code=401, detail=user_or_message) + raise HTTPException(status_code=401, detail="用户名或密码错误") # 用户等级 level = SitesHelper().auth_level diff --git a/app/api/endpoints/mfa.py b/app/api/endpoints/mfa.py index 98b9eae1..d480bb3f 100644 --- a/app/api/endpoints/mfa.py +++ b/app/api/endpoints/mfa.py @@ -24,6 +24,75 @@ from app.utils.otp import OtpUtils router = APIRouter() +# ==================== 辅助函数 ==================== + +def _build_credential_list(passkeys: list[PassKey]) -> list[dict[str, Any]]: + """ + 构建凭证列表 + + :param passkeys: PassKey 列表 + :return: 凭证字典列表 + """ + return [ + { + 'credential_id': pk.credential_id, + 'transports': pk.transports + } + for pk in passkeys + ] if passkeys else [] + + +def _extract_and_standardize_credential_id(credential: dict) -> str: + """ + 从凭证中提取并标准化 credential_id + + :param credential: 凭证字典 + :return: 标准化后的 credential_id + :raises ValueError: 如果凭证无效 + """ + credential_id_raw = credential.get('id') or credential.get('rawId') + if not credential_id_raw: + raise ValueError("无效的凭证") + return PassKeyHelper.standardize_credential_id(credential_id_raw) + + +def _verify_passkey_and_update( + credential: dict, + challenge: str, + passkey: PassKey +) -> tuple[bool, int]: + """ + 验证 PassKey 并更新使用时间和签名计数 + + :param credential: 凭证字典 + :param challenge: 挑战值 + :param passkey: PassKey 对象 + :return: (验证是否成功, 新的签名计数) + """ + success, new_sign_count = PassKeyHelper.verify_authentication_response( + credential=credential, + expected_challenge=challenge, + credential_public_key=passkey.public_key, + credential_current_sign_count=passkey.sign_count + ) + + if success: + passkey.update_last_used(db=None, sign_count=new_sign_count) + + return success, new_sign_count + + +async def _check_user_has_passkey(db: AsyncSession, user_id: int) -> bool: + """ + 检查用户是否有 PassKey + + :param db: 数据库会话 + :param user_id: 用户 ID + :return: 是否有 PassKey + """ + return bool(await PassKey.async_get_by_user_id(db=db, user_id=user_id)) + + # ==================== 请求模型 ==================== class OtpVerifyRequest(schemas.BaseModel): @@ -55,7 +124,7 @@ async def mfa_status(username: str, db: AsyncSession = Depends(get_async_db)) -> has_otp = user.is_otp # 检查是否有PassKey - has_passkey = bool(await PassKey.async_get_by_user_id(db=db, user_id=user.id)) + has_passkey = await _check_user_has_passkey(db, user.id) # 只要有任何一种验证方式,就需要双重验证 return schemas.Response(success=(has_otp or has_passkey)) @@ -93,7 +162,7 @@ async def otp_disable( ) -> Any: """关闭当前用户的 OTP 验证功能""" # 安全检查:如果存在 PassKey,不允许关闭 OTP - has_passkey = bool(await PassKey.async_get_by_user_id(db=db, user_id=current_user.id)) + has_passkey = await _check_user_has_passkey(db, current_user.id) if has_passkey: return schemas.Response( success=False, @@ -147,13 +216,7 @@ def passkey_register_start( # 获取用户已有的PassKey existing_passkeys = PassKey.get_by_user_id(db=None, user_id=current_user.id) - existing_credentials = [ - { - 'credential_id': pk.credential_id, - 'transports': pk.transports - } - for pk in existing_passkeys - ] if existing_passkeys else None + existing_credentials = _build_credential_list(existing_passkeys) if existing_passkeys else None # 生成注册选项 options_json, challenge = PassKeyHelper.generate_registration_options( @@ -233,26 +296,15 @@ def passkey_authenticate_start( # 如果指定了用户名,只允许该用户的PassKey if passkey_req.username: user = User.get_by_name(db=None, name=passkey_req.username) - if not user: + existing_passkeys = PassKey.get_by_user_id(db=None, user_id=user.id) if user else None + + if not user or not existing_passkeys: return schemas.Response( success=False, - message="用户不存在" + message="认证失败" ) - - existing_passkeys = PassKey.get_by_user_id(db=None, user_id=user.id) - if not existing_passkeys: - return schemas.Response( - success=False, - message="该用户未注册通行密钥" - ) - - existing_credentials = [ - { - 'credential_id': pk.credential_id, - 'transports': pk.transports - } - for pk in existing_passkeys - ] + + existing_credentials = _build_credential_list(existing_passkeys) # 生成认证选项 options_json, challenge = PassKeyHelper.generate_authentication_options( @@ -270,7 +322,7 @@ def passkey_authenticate_start( logger.error(f"生成PassKey认证选项失败: {e}") return schemas.Response( success=False, - message=f"生成认证选项失败: {str(e)}" + message="认证失败" ) @@ -280,37 +332,28 @@ def passkey_authenticate_finish( ) -> Any: """完成 PassKey 认证 - 验证凭证并返回 token""" try: - # 从credential中提取credential_id - credential_id_raw = passkey_req.credential.get('id') or passkey_req.credential.get('rawId') - if not credential_id_raw: - raise HTTPException(status_code=400, detail="无效的凭证") + # 提取并标准化凭证ID + try: + credential_id = _extract_and_standardize_credential_id(passkey_req.credential) + except ValueError as e: + logger.warning(f"PassKey认证失败,提供的凭证无效: {e}") + raise HTTPException(status_code=401, detail="认证失败") - # 标准化凭证ID - credential_id = PassKeyHelper.standardize_credential_id(credential_id_raw) - - # 查找PassKey + # 查找PassKey并获取用户 passkey = PassKey.get_by_credential_id(db=None, credential_id=credential_id) - if not passkey: - raise HTTPException(status_code=401, detail="通行密钥不存在或已失效") + user = User.get_by_id(db=None, user_id=passkey.user_id) if passkey else None + if not passkey or not user or not user.is_active: + raise HTTPException(status_code=401, detail="认证失败") - # 获取用户 - user = User.get_by_id(db=None, user_id=passkey.user_id) - if not user or not user.is_active: - raise HTTPException(status_code=401, detail="用户不存在或已禁用") - - # 验证认证响应 - success, new_sign_count = PassKeyHelper.verify_authentication_response( + # 验证认证响应并更新 + success, _ = _verify_passkey_and_update( credential=passkey_req.credential, - expected_challenge=passkey_req.challenge, - credential_public_key=passkey.public_key, - credential_current_sign_count=passkey.sign_count + challenge=passkey_req.challenge, + passkey=passkey ) if not success: - raise HTTPException(status_code=401, detail="通行密钥验证失败") - - # 更新使用时间和签名计数 - passkey.update_last_used(db=None, sign_count=new_sign_count) + raise HTTPException(status_code=401, detail="认证失败") logger.info(f"用户 {user.name} 通过PassKey认证成功") @@ -339,7 +382,7 @@ def passkey_authenticate_finish( raise except Exception as e: logger.error(f"PassKey认证失败: {e}") - raise HTTPException(status_code=401, detail=f"认证失败: {str(e)}") + raise HTTPException(status_code=401, detail="认证失败") @router.get("/passkey/list", summary="获取当前用户的 PassKey 列表", response_model=schemas.Response) @@ -413,16 +456,12 @@ def passkey_verify_mfa( ) -> Any: """使用 PassKey 进行二次验证(MFA)""" try: - # 从credential中提取credential_id - credential_id_raw = passkey_req.credential.get('id') or passkey_req.credential.get('rawId') - if not credential_id_raw: - return schemas.Response( - success=False, - message="无效的凭证" - ) - - # 标准化凭证ID - credential_id = PassKeyHelper.standardize_credential_id(credential_id_raw) + # 提取并标准化凭证ID + try: + credential_id = _extract_and_standardize_credential_id(passkey_req.credential) + except ValueError as e: + logger.warning(f"PassKey二次验证失败,提供的凭证无效: {e}") + return schemas.Response(success=False, message="验证失败") # 查找PassKey(必须属于当前用户) passkey = PassKey.get_by_credential_id(db=None, credential_id=credential_id) @@ -432,12 +471,11 @@ def passkey_verify_mfa( message="通行密钥不存在或不属于当前用户" ) - # 验证认证响应 - success, new_sign_count = PassKeyHelper.verify_authentication_response( + # 验证认证响应并更新 + success, _ = _verify_passkey_and_update( credential=passkey_req.credential, - expected_challenge=passkey_req.challenge, - credential_public_key=passkey.public_key, - credential_current_sign_count=passkey.sign_count + challenge=passkey_req.challenge, + passkey=passkey ) if not success: @@ -446,9 +484,6 @@ def passkey_verify_mfa( message="通行密钥验证失败" ) - # 更新使用时间和签名计数 - passkey.update_last_used(db=None, sign_count=new_sign_count) - logger.info(f"用户 {current_user.name} 通过PassKey二次验证成功") return schemas.Response( @@ -459,5 +494,5 @@ def passkey_verify_mfa( logger.error(f"PassKey二次验证失败: {e}") return schemas.Response( success=False, - message=f"验证失败: {str(e)}" + message="验证失败" ) diff --git a/app/helper/passkey.py b/app/helper/passkey.py index ac410710..cbac5d27 100644 --- a/app/helper/passkey.py +++ b/app/helper/passkey.py @@ -90,6 +90,79 @@ class PassKeyHelper: logger.error(f"标准化凭证ID失败: {e}") return credential_id + @staticmethod + def _base64_encode_urlsafe(data: bytes) -> str: + """ + Base64 URL Safe 编码(不带填充) + + :param data: 要编码的字节数据 + :return: Base64 URL Safe 编码的字符串 + """ + return base64.urlsafe_b64encode(data).decode('utf-8').rstrip('=') + + @staticmethod + def _base64_decode_urlsafe(data: str) -> bytes: + """ + Base64 URL Safe 解码(自动添加填充) + + :param data: Base64 URL Safe 编码的字符串 + :return: 解码后的字节数据 + """ + return base64.urlsafe_b64decode(data + '==') + + @staticmethod + def _parse_credential_list(credentials: List[Dict[str, Any]]) -> List[PublicKeyCredentialDescriptor]: + """ + 解析凭证列表为 PublicKeyCredentialDescriptor 列表 + + :param credentials: 凭证字典列表 + :return: PublicKeyCredentialDescriptor 列表 + """ + result = [] + for cred in credentials: + try: + result.append( + PublicKeyCredentialDescriptor( + id=PassKeyHelper._base64_decode_urlsafe(cred['credential_id']), + transports=[ + AuthenticatorTransport(t) for t in cred.get('transports', '').split(',') if t + ] if cred.get('transports') else None + ) + ) + except Exception as e: + logger.warning(f"解析凭证失败: {e}") + continue + return result + + @staticmethod + def _get_user_verification_requirement(user_verification: Optional[str] = None) -> UserVerificationRequirement: + """ + 获取用户验证要求 + + :param user_verification: 指定的用户验证要求,如果不指定则从配置中读取 + :return: UserVerificationRequirement + """ + if user_verification: + return UserVerificationRequirement(user_verification) + return UserVerificationRequirement.REQUIRED if settings.PASSKEY_REQUIRE_UV \ + else UserVerificationRequirement.PREFERRED + + @staticmethod + def _get_verification_params( + expected_origin: Optional[str] = None, + expected_rp_id: Optional[str] = None + ) -> Tuple[str, str]: + """ + 获取验证参数(origin 和 rp_id) + + :param expected_origin: 期望的源地址 + :param expected_rp_id: 期望的RP ID + :return: (origin, rp_id) + """ + origin = expected_origin or PassKeyHelper.get_origin() + rp_id = expected_rp_id or PassKeyHelper.get_rp_id() + return origin, rp_id + @staticmethod def generate_registration_options( user_id: int, @@ -109,27 +182,13 @@ class PassKeyHelper: try: # 用户信息 user_id_bytes = str(user_id).encode('utf-8') - + # 排除已有的凭证 - exclude_credentials = [] - if existing_credentials: - for cred in existing_credentials: - try: - exclude_credentials.append( - PublicKeyCredentialDescriptor( - id=base64.urlsafe_b64decode(cred['credential_id'] + '=='), - transports=[ - AuthenticatorTransport(t) for t in cred.get('transports', '').split(',') if t - ] if cred.get('transports') else None - ) - ) - except Exception as e: - logger.warning(f"解析凭证失败: {e}") - continue + exclude_credentials = PassKeyHelper._parse_credential_list(existing_credentials) \ + if existing_credentials else None # 用户验证要求 - uv_requirement = UserVerificationRequirement.REQUIRED if settings.PASSKEY_REQUIRE_UV \ - else UserVerificationRequirement.PREFERRED + uv_requirement = PassKeyHelper._get_user_verification_requirement() # 生成注册选项 options = generate_registration_options( @@ -138,7 +197,7 @@ class PassKeyHelper: user_id=user_id_bytes, user_name=username, user_display_name=display_name or username, - exclude_credentials=exclude_credentials if exclude_credentials else None, + exclude_credentials=exclude_credentials, authenticator_selection=AuthenticatorSelectionCriteria( authenticator_attachment=None, resident_key=ResidentKeyRequirement.REQUIRED, @@ -152,9 +211,9 @@ class PassKeyHelper: # 转换为JSON options_json = options_to_json(options) - + # 提取challenge(用于后续验证) - challenge = base64.urlsafe_b64encode(options.challenge).decode('utf-8').rstrip('=') + challenge = PassKeyHelper._base64_encode_urlsafe(options.challenge) return options_json, challenge @@ -162,29 +221,6 @@ class PassKeyHelper: logger.error(f"生成注册选项失败: {e}") raise - @staticmethod - def _get_verified_origin(credential: Dict[str, Any], rp_id: str, default_origin: str) -> str: - """ - 在 localhost 环境下获取并验证实际 Origin,否则返回默认值 - """ - if not settings.APP_DOMAIN and rp_id == 'localhost': - try: - # 解析 clientDataJSON 获取实际的 origin - client_data_json = json.loads( - base64.urlsafe_b64decode( - credential['response']['clientDataJSON'].replace('-', '+').replace('_', '/') + '==' - ).decode('utf-8') - ) - actual_origin = client_data_json.get('origin', '') - hostname = urlparse(actual_origin).hostname - - if hostname in ['localhost', '127.0.0.1']: - logger.info(f"本地环境,使用动态 origin: {actual_origin}") - return actual_origin - except Exception as e: - logger.warning(f"无法提取动态 origin: {e}") - return default_origin - @staticmethod def verify_registration_response( credential: Dict[str, Any], @@ -203,18 +239,13 @@ class PassKeyHelper: """ try: # 准备验证参数 - origin = expected_origin or PassKeyHelper.get_origin() - rp_id = expected_rp_id or PassKeyHelper.get_rp_id() - + origin, rp_id = PassKeyHelper._get_verification_params(expected_origin, expected_rp_id) # 解码challenge - challenge_bytes = base64.urlsafe_b64decode(expected_challenge + '==') + challenge_bytes = PassKeyHelper._base64_decode_urlsafe(expected_challenge) # 构建RegistrationCredential对象 registration_credential = parse_registration_credential_json(json.dumps(credential)) - # 获取并验证 Origin - origin = PassKeyHelper._get_verified_origin(credential, rp_id, origin) - # 验证注册响应 verification = verify_registration_response( credential=registration_credential, @@ -225,8 +256,8 @@ class PassKeyHelper: ) # 提取信息 - credential_id = base64.urlsafe_b64encode(verification.credential_id).decode('utf-8').rstrip('=') - public_key = base64.urlsafe_b64encode(verification.credential_public_key).decode('utf-8').rstrip('=') + credential_id = PassKeyHelper._base64_encode_urlsafe(verification.credential_id) + public_key = PassKeyHelper._base64_encode_urlsafe(verification.credential_public_key) sign_count = verification.sign_count # aaguid 可能已经是字符串格式,也可能是bytes if verification.aaguid: @@ -257,41 +288,24 @@ class PassKeyHelper: """ try: # 允许的凭证 - allow_credentials = [] - if existing_credentials: - for cred in existing_credentials: - try: - allow_credentials.append( - PublicKeyCredentialDescriptor( - id=base64.urlsafe_b64decode(cred['credential_id'] + '=='), - transports=[ - AuthenticatorTransport(t) for t in cred.get('transports', '').split(',') if t - ] if cred.get('transports') else None - ) - ) - except Exception as e: - logger.warning(f"解析凭证失败: {e}") - continue + allow_credentials = PassKeyHelper._parse_credential_list(existing_credentials) \ + if existing_credentials else None # 用户验证要求 - if not user_verification: - uv_requirement = UserVerificationRequirement.REQUIRED if settings.PASSKEY_REQUIRE_UV \ - else UserVerificationRequirement.PREFERRED - else: - uv_requirement = UserVerificationRequirement(user_verification) + uv_requirement = PassKeyHelper._get_user_verification_requirement(user_verification) # 生成认证选项 options = generate_authentication_options( rp_id=PassKeyHelper.get_rp_id(), - allow_credentials=allow_credentials if allow_credentials else None, + allow_credentials=allow_credentials, user_verification=uv_requirement ) # 转换为JSON options_json = options_to_json(options) - + # 提取challenge - challenge = base64.urlsafe_b64encode(options.challenge).decode('utf-8').rstrip('=') + challenge = PassKeyHelper._base64_encode_urlsafe(options.challenge) return options_json, challenge @@ -321,19 +335,14 @@ class PassKeyHelper: """ try: # 准备验证参数 - origin = expected_origin or PassKeyHelper.get_origin() - rp_id = expected_rp_id or PassKeyHelper.get_rp_id() - + origin, rp_id = PassKeyHelper._get_verification_params(expected_origin, expected_rp_id) # 解码 - challenge_bytes = base64.urlsafe_b64decode(expected_challenge + '==') - public_key_bytes = base64.urlsafe_b64decode(credential_public_key + '==') + challenge_bytes = PassKeyHelper._base64_decode_urlsafe(expected_challenge) + public_key_bytes = PassKeyHelper._base64_decode_urlsafe(credential_public_key) # 构建AuthenticationCredential对象 authentication_credential = parse_authentication_credential_json(json.dumps(credential)) - # 获取并验证 Origin - origin = PassKeyHelper._get_verified_origin(credential, rp_id, origin) - # 验证认证响应 verification = verify_authentication_response( credential=authentication_credential,