fix(mfa): 修复双重验证漏洞

This commit is contained in:
PKC278
2025-12-23 14:57:43 +08:00
parent 6b4ba8bfad
commit 7b99f09810
6 changed files with 104 additions and 65 deletions

View File

@@ -33,7 +33,7 @@ def login_access_token(
if user_or_message == "MFA_REQUIRED": if user_or_message == "MFA_REQUIRED":
raise HTTPException( raise HTTPException(
status_code=401, status_code=401,
detail="需要双重验证", detail="需要双重验证,请提供验证码或使用通行密钥",
headers={"X-MFA-Required": "true"} headers={"X-MFA-Required": "true"}
) )
raise HTTPException(status_code=401, detail=user_or_message) raise HTTPException(status_code=401, detail=user_or_message)
@@ -57,7 +57,7 @@ def login_access_token(
avatar=user_or_message.avatar, avatar=user_or_message.avatar,
level=level, level=level,
permissions=user_or_message.permissions or {}, permissions=user_or_message.permissions or {},
widzard=show_wizard wizard=show_wizard
) )

View File

@@ -2,9 +2,8 @@
MFA (Multi-Factor Authentication) API 端点 MFA (Multi-Factor Authentication) API 端点
包含 OTP 和 PassKey 相关功能 包含 OTP 和 PassKey 相关功能
""" """
import base64 from datetime import timedelta
from datetime import datetime, timedelta from typing import Any, Annotated, Optional, List, Union
from typing import Any, Annotated, List, Union
from fastapi import APIRouter, Depends, HTTPException, Body from fastapi import APIRouter, Depends, HTTPException, Body
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
@@ -25,6 +24,22 @@ from app.utils.otp import OtpUtils
router = APIRouter() router = APIRouter()
# ==================== 请求模型 ====================
class OtpVerifyRequest(schemas.BaseModel):
"""OTP验证请求"""
uri: str
otpPassword: str
class OtpDisableRequest(schemas.BaseModel):
"""OTP禁用请求"""
password: str
class PassKeyDeleteRequest(schemas.BaseModel):
"""PassKey删除请求"""
passkey_id: int
password: str
# ==================== 通用 MFA 接口 ==================== # ==================== 通用 MFA 接口 ====================
@router.get('/status/{username}', summary='判断用户是否开启双重验证(MFA)', response_model=schemas.Response) @router.get('/status/{username}', summary='判断用户是否开启双重验证(MFA)', response_model=schemas.Response)
@@ -40,7 +55,7 @@ async def mfa_status(username: str, db: AsyncSession = Depends(get_async_db)) ->
has_otp = user.is_otp has_otp = user.is_otp
# 检查是否有PassKey # 检查是否有PassKey
has_passkey = bool(PassKey().get_by_user_id(db=None, user_id=user.id)) has_passkey = bool(await PassKey.async_get_by_user_id(db=db, user_id=user.id))
# 只要有任何一种验证方式,就需要双重验证 # 只要有任何一种验证方式,就需要双重验证
return schemas.Response(success=(has_otp or has_passkey)) return schemas.Response(success=(has_otp or has_passkey))
@@ -59,25 +74,35 @@ def otp_generate(
@router.post('/otp/verify', summary='绑定并验证 OTP', response_model=schemas.Response) @router.post('/otp/verify', summary='绑定并验证 OTP', response_model=schemas.Response)
async def otp_verify( async def otp_verify(
data: dict, data: OtpVerifyRequest,
db: AsyncSession = Depends(get_async_db), db: AsyncSession = Depends(get_async_db),
current_user: User = Depends(get_current_active_user_async) current_user: User = Depends(get_current_active_user_async)
) -> Any: ) -> Any:
"""验证用户输入的 OTP 码,验证通过后正式开启 OTP 验证""" """验证用户输入的 OTP 码,验证通过后正式开启 OTP 验证"""
uri = data.get("uri") if not OtpUtils.is_legal(data.uri, data.otpPassword):
otp_password = data.get("otpPassword")
if not OtpUtils.is_legal(uri, otp_password):
return schemas.Response(success=False, message="验证码错误") return schemas.Response(success=False, message="验证码错误")
await current_user.async_update_otp_by_name(db, current_user.name, True, OtpUtils.get_secret(uri)) await current_user.async_update_otp_by_name(db, current_user.name, True, OtpUtils.get_secret(data.uri))
return schemas.Response(success=True) return schemas.Response(success=True)
@router.post('/otp/disable', summary='关闭当前用户的 OTP 验证', response_model=schemas.Response) @router.post('/otp/disable', summary='关闭当前用户的 OTP 验证', response_model=schemas.Response)
async def otp_disable( async def otp_disable(
data: OtpDisableRequest,
db: AsyncSession = Depends(get_async_db), db: AsyncSession = Depends(get_async_db),
current_user: User = Depends(get_current_active_user_async) current_user: User = Depends(get_current_active_user_async)
) -> Any: ) -> Any:
"""关闭当前用户的 OTP 验证功能""" """关闭当前用户的 OTP 验证功能"""
# 安全检查:如果存在 PassKey不允许关闭 OTP
has_passkey = bool(await PassKey.async_get_by_user_id(db=db, user_id=current_user.id))
if has_passkey:
return schemas.Response(
success=False,
message="您已注册通行密钥,为了防止域名配置变更导致无法登录,请先删除所有通行密钥再关闭 OTP 验证"
)
# 验证密码
if not security.verify_password(data.password, str(current_user.hashed_password)):
return schemas.Response(success=False, message="密码错误")
await current_user.async_update_otp_by_name(db, current_user.name, False, "") await current_user.async_update_otp_by_name(db, current_user.name, False, "")
return schemas.Response(success=True) return schemas.Response(success=True)
@@ -98,7 +123,7 @@ class PassKeyRegistrationFinish(schemas.BaseModel):
class PassKeyAuthenticationStart(schemas.BaseModel): class PassKeyAuthenticationStart(schemas.BaseModel):
"""PassKey认证开始请求""" """PassKey认证开始请求"""
username: str | None = None username: Optional[str] = None
class PassKeyAuthenticationFinish(schemas.BaseModel): class PassKeyAuthenticationFinish(schemas.BaseModel):
@@ -122,7 +147,7 @@ def passkey_register_start(
) )
# 获取用户已有的PassKey # 获取用户已有的PassKey
existing_passkeys = PassKey().get_by_user_id(db=None, user_id=current_user.id) existing_passkeys = PassKey.get_by_user_id(db=None, user_id=current_user.id)
existing_credentials = [ existing_credentials = [
{ {
'credential_id': pk.credential_id, 'credential_id': pk.credential_id,
@@ -215,7 +240,7 @@ def passkey_authenticate_start(
message="用户不存在" message="用户不存在"
) )
existing_passkeys = PassKey().get_by_user_id(db=None, user_id=user.id) existing_passkeys = PassKey.get_by_user_id(db=None, user_id=user.id)
if not existing_passkeys: if not existing_passkeys:
return schemas.Response( return schemas.Response(
success=False, success=False,
@@ -265,7 +290,7 @@ def passkey_authenticate_finish(
credential_id = PassKeyHelper.standardize_credential_id(credential_id_raw) credential_id = PassKeyHelper.standardize_credential_id(credential_id_raw)
# 查找PassKey # 查找PassKey
passkey = PassKey().get_by_credential_id(db=None, credential_id=credential_id) passkey = PassKey.get_by_credential_id(db=None, credential_id=credential_id)
if not passkey: if not passkey:
raise HTTPException(status_code=401, detail="通行密钥不存在或已失效") raise HTTPException(status_code=401, detail="通行密钥不存在或已失效")
@@ -286,7 +311,7 @@ def passkey_authenticate_finish(
raise HTTPException(status_code=401, detail="通行密钥验证失败") raise HTTPException(status_code=401, detail="通行密钥验证失败")
# 更新使用时间和签名计数 # 更新使用时间和签名计数
passkey.update_last_used(db=None, credential_id=credential_id, sign_count=new_sign_count) passkey.update_last_used(db=None, sign_count=new_sign_count)
logger.info(f"用户 {user.name} 通过PassKey认证成功") logger.info(f"用户 {user.name} 通过PassKey认证成功")
@@ -309,7 +334,7 @@ def passkey_authenticate_finish(
avatar=user.avatar, avatar=user.avatar,
level=level, level=level,
permissions=user.permissions or {}, permissions=user.permissions or {},
widzard=show_wizard wizard=show_wizard
) )
except HTTPException: except HTTPException:
raise raise
@@ -324,7 +349,7 @@ def passkey_list(
) -> Any: ) -> Any:
"""获取当前用户的所有 PassKey""" """获取当前用户的所有 PassKey"""
try: try:
passkeys = PassKey().get_by_user_id(db=None, user_id=current_user.id) passkeys = PassKey.get_by_user_id(db=None, user_id=current_user.id)
passkey_list = [ passkey_list = [
{ {
@@ -350,17 +375,21 @@ def passkey_list(
) )
@router.delete("/passkey/{passkey_id}", summary="删除 PassKey", response_model=schemas.Response) @router.post("/passkey/delete", summary="删除 PassKey", response_model=schemas.Response)
def passkey_delete( async def passkey_delete(
passkey_id: int, data: PassKeyDeleteRequest,
current_user: Annotated[User, Depends(get_current_active_user)] current_user: User = Depends(get_current_active_user_async)
) -> Any: ) -> Any:
"""删除指定的 PassKey""" """删除指定的 PassKey"""
try: try:
success = PassKey().delete_by_id(db=None, passkey_id=passkey_id, user_id=current_user.id) # 验证密码
if not security.verify_password(data.password, str(current_user.hashed_password)):
return schemas.Response(success=False, message="密码错误")
success = PassKey.delete_by_id(db=None, passkey_id=data.passkey_id, user_id=current_user.id)
if success: if success:
logger.info(f"用户 {current_user.name} 删除了PassKey: {passkey_id}") logger.info(f"用户 {current_user.name} 删除了PassKey: {data.passkey_id}")
return schemas.Response( return schemas.Response(
success=True, success=True,
message="通行密钥已删除" message="通行密钥已删除"
@@ -397,7 +426,7 @@ def passkey_verify_mfa(
credential_id = PassKeyHelper.standardize_credential_id(credential_id_raw) credential_id = PassKeyHelper.standardize_credential_id(credential_id_raw)
# 查找PassKey必须属于当前用户 # 查找PassKey必须属于当前用户
passkey = PassKey().get_by_credential_id(db=None, credential_id=credential_id) passkey = PassKey.get_by_credential_id(db=None, credential_id=credential_id)
if not passkey or passkey.user_id != current_user.id: if not passkey or passkey.user_id != current_user.id:
return schemas.Response( return schemas.Response(
success=False, success=False,
@@ -419,7 +448,7 @@ def passkey_verify_mfa(
) )
# 更新使用时间和签名计数 # 更新使用时间和签名计数
passkey.update_last_used(db=None, credential_id=credential_id, sign_count=new_sign_count) passkey.update_last_used(db=None, sign_count=new_sign_count)
logger.info(f"用户 {current_user.name} 通过PassKey二次验证成功") logger.info(f"用户 {current_user.name} 通过PassKey二次验证成功")

View File

@@ -179,7 +179,7 @@ class UserChain(ChainBase):
""" """
# 检查用户是否有PassKey # 检查用户是否有PassKey
from app.db.models.passkey import PassKey from app.db.models.passkey import PassKey
has_passkey = bool(PassKey().get_by_user_id(db=None, user_id=user.id)) has_passkey = bool(PassKey.get_by_user_id(db=None, user_id=user.id))
# 如果用户既没有启用OTP也没有PassKey直接通过 # 如果用户既没有启用OTP也没有PassKey直接通过
if not user.is_otp and not has_passkey: if not user.is_otp and not has_passkey:
@@ -190,11 +190,20 @@ class UserChain(ChainBase):
logger.info(f"用户 {user.name} 已启用双重验证OTP: {user.is_otp}, PassKey: {has_passkey}),需要提供验证码") logger.info(f"用户 {user.name} 已启用双重验证OTP: {user.is_otp}, PassKey: {has_passkey}),需要提供验证码")
return "MFA_REQUIRED" return "MFA_REQUIRED"
# 如果提供了验证码验证OTP # 如果提供了验证码,且用户启用了 OTP验证 OTP
if user.is_otp: if user.is_otp:
if not OtpUtils.check(str(user.otp_secret), mfa_code): if not OtpUtils.check(str(user.otp_secret), mfa_code):
logger.info(f"用户 {user.name} 的 MFA 认证失败") logger.info(f"用户 {user.name} 的 MFA 认证失败")
return False return False
# OTP 验证成功
return True
# 用户未启用 OTP此时提供的 mfa_code 无效;如果启用了 PassKey则仍需通过 PassKey 验证
if has_passkey:
logger.info(
f"用户 {user.name} 未启用 OTP但已启用 PassKey提供的 MFA 验证码将被忽略,仍需通过 PassKey 验证"
)
return "MFA_REQUIRED"
return True return True

View File

@@ -37,14 +37,14 @@ class PassKey(Base):
@db_query @db_query
def get_by_user_id(cls, db: Session, user_id: int): def get_by_user_id(cls, db: Session, user_id: int):
"""获取用户的所有PassKey""" """获取用户的所有PassKey"""
return db.query(cls).filter(cls.user_id == user_id, cls.is_active == True).all() return db.query(cls).filter(cls.user_id == user_id, cls.is_active.is_(True)).all()
@classmethod @classmethod
@async_db_query @async_db_query
async def async_get_by_user_id(cls, db: AsyncSession, user_id: int): async def async_get_by_user_id(cls, db: AsyncSession, user_id: int):
"""异步获取用户的所有PassKey""" """异步获取用户的所有PassKey"""
result = await db.execute( result = await db.execute(
select(cls).filter(cls.user_id == user_id, cls.is_active == True) select(cls).filter(cls.user_id == user_id, cls.is_active.is_(True))
) )
return result.scalars().all() return result.scalars().all()
@@ -52,14 +52,14 @@ class PassKey(Base):
@db_query @db_query
def get_by_credential_id(cls, db: Session, credential_id: str): def get_by_credential_id(cls, db: Session, credential_id: str):
"""根据凭证ID获取PassKey""" """根据凭证ID获取PassKey"""
return db.query(cls).filter(cls.credential_id == credential_id, cls.is_active == True).first() return db.query(cls).filter(cls.credential_id == credential_id, cls.is_active.is_(True)).first()
@classmethod @classmethod
@async_db_query @async_db_query
async def async_get_by_credential_id(cls, db: AsyncSession, credential_id: str): async def async_get_by_credential_id(cls, db: AsyncSession, credential_id: str):
"""异步根据凭证ID获取PassKey""" """异步根据凭证ID获取PassKey"""
result = await db.execute( result = await db.execute(
select(cls).filter(cls.credential_id == credential_id, cls.is_active == True) select(cls).filter(cls.credential_id == credential_id, cls.is_active.is_(True))
) )
return result.scalars().first() return result.scalars().first()
@@ -78,25 +78,27 @@ class PassKey(Base):
) )
return result.scalars().first() return result.scalars().first()
@classmethod
@db_update @db_update
def delete_by_id(self, db: Session, passkey_id: int, user_id: int): def delete_by_id(cls, db: Session, passkey_id: int, user_id: int):
"""删除指定用户的PassKey""" """删除指定用户的PassKey"""
passkey = db.query(PassKey).filter( passkey = db.query(cls).filter(
PassKey.id == passkey_id, cls.id == passkey_id,
PassKey.user_id == user_id cls.user_id == user_id
).first() ).first()
if passkey: if passkey:
passkey.delete(db, passkey.id) passkey.delete(db, passkey.id)
return True return True
return False return False
@classmethod
@async_db_update @async_db_update
async def async_delete_by_id(self, db: AsyncSession, passkey_id: int, user_id: int): async def async_delete_by_id(cls, db: AsyncSession, passkey_id: int, user_id: int):
"""异步删除指定用户的PassKey""" """异步删除指定用户的PassKey"""
result = await db.execute( result = await db.execute(
select(PassKey).filter( select(cls).filter(
PassKey.id == passkey_id, cls.id == passkey_id,
PassKey.user_id == user_id cls.user_id == user_id
) )
) )
passkey = result.scalars().first() passkey = result.scalars().first()
@@ -106,25 +108,19 @@ class PassKey(Base):
return False return False
@db_update @db_update
def update_last_used(self, db: Session, credential_id: str, sign_count: int): def update_last_used(self, db: Session, sign_count: int):
"""更新最后使用时间和签名计数""" """更新最后使用时间和签名计数"""
passkey = self.get_by_credential_id(db, credential_id) self.update(db, {
if passkey: 'last_used_at': datetime.now(),
passkey.update(db, { 'sign_count': sign_count
'last_used_at': datetime.now(), })
'sign_count': sign_count return True
})
return True
return False
@async_db_update @async_db_update
async def async_update_last_used(self, db: AsyncSession, credential_id: str, sign_count: int): async def async_update_last_used(self, db: AsyncSession, sign_count: int):
"""异步更新最后使用时间和签名计数""" """异步更新最后使用时间和签名计数"""
passkey = await self.async_get_by_credential_id(db, credential_id) await self.async_update(db, {
if passkey: 'last_used_at': datetime.now(),
await passkey.async_update(db, { 'sign_count': sign_count
'last_used_at': datetime.now(), })
'sign_count': sign_count return True
})
return True
return False

View File

@@ -3,6 +3,7 @@ PassKey WebAuthn 辅助工具类
""" """
import base64 import base64
import json import json
import binascii
from typing import Optional, Tuple, List, Dict, Any from typing import Optional, Tuple, List, Dict, Any
from urllib.parse import urlparse from urllib.parse import urlparse
@@ -23,10 +24,6 @@ from webauthn.helpers.structs import (
UserVerificationRequirement, UserVerificationRequirement,
AuthenticatorAttachment, AuthenticatorAttachment,
ResidentKeyRequirement, ResidentKeyRequirement,
PublicKeyCredentialCreationOptions,
PublicKeyCredentialRequestOptions,
RegistrationCredential,
AuthenticationCredential,
AuthenticatorSelectionCriteria AuthenticatorSelectionCriteria
) )
from webauthn.helpers.cose import COSEAlgorithmIdentifier from webauthn.helpers.cose import COSEAlgorithmIdentifier
@@ -46,6 +43,14 @@ class PassKeyHelper:
获取 Relying Party ID 获取 Relying Party ID
""" """
if settings.APP_DOMAIN: if settings.APP_DOMAIN:
app_domain = settings.APP_DOMAIN.strip()
# 确保存在协议前缀,以便 urlparse 正确解析主机和端口
if not app_domain.startswith(('http://', 'https://')):
app_domain = f'https://{app_domain}'
parsed = urlparse(app_domain)
host = parsed.hostname
if host:
return host
# 从 APP_DOMAIN 中提取域名 # 从 APP_DOMAIN 中提取域名
host = settings.APP_DOMAIN.replace('https://', '').replace('http://', '') host = settings.APP_DOMAIN.replace('https://', '').replace('http://', '')
# 移除端口号 # 移除端口号
@@ -81,7 +86,7 @@ class PassKeyHelper:
# Base64解码并重新编码以标准化格式 # Base64解码并重新编码以标准化格式
decoded = base64.urlsafe_b64decode(credential_id + '==') decoded = base64.urlsafe_b64decode(credential_id + '==')
return base64.urlsafe_b64encode(decoded).decode('utf-8').rstrip('=') return base64.urlsafe_b64encode(decoded).decode('utf-8').rstrip('=')
except Exception as e: except (binascii.Error, TypeError, ValueError) as e:
logger.error(f"标准化凭证ID失败: {e}") logger.error(f"标准化凭证ID失败: {e}")
return credential_id return credential_id
@@ -135,7 +140,7 @@ class PassKeyHelper:
user_display_name=display_name or username, user_display_name=display_name or username,
exclude_credentials=exclude_credentials if exclude_credentials else None, exclude_credentials=exclude_credentials if exclude_credentials else None,
authenticator_selection=AuthenticatorSelectionCriteria( authenticator_selection=AuthenticatorSelectionCriteria(
authenticator_attachment=AuthenticatorAttachment.PLATFORM, authenticator_attachment=None,
resident_key=ResidentKeyRequirement.REQUIRED, resident_key=ResidentKeyRequirement.REQUIRED,
user_verification=uv_requirement, user_verification=uv_requirement,
), ),

View File

@@ -21,7 +21,7 @@ class Token(BaseModel):
# 详细权限 # 详细权限
permissions: Optional[dict] = Field(default_factory=dict) permissions: Optional[dict] = Field(default_factory=dict)
# 是否显示配置向导 # 是否显示配置向导
widzard: Optional[bool] = None wizard: Optional[bool] = None
class TokenPayload(BaseModel): class TokenPayload(BaseModel):