From 0c6cfc5020f18ebcedd2829ad6a69e5a0c5bcab5 Mon Sep 17 00:00:00 2001 From: PKC278 <52959804+PKC278@users.noreply.github.com> Date: Tue, 23 Dec 2025 13:53:54 +0800 Subject: [PATCH] =?UTF-8?q?feat(passkey):=20=E6=B7=BB=E5=8A=A0PassKey?= =?UTF-8?q?=E6=94=AF=E6=8C=81=E5=B9=B6=E4=BC=98=E5=8C=96=E5=8F=8C=E9=87=8D?= =?UTF-8?q?=E9=AA=8C=E8=AF=81=E7=99=BB=E5=BD=95=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/apiv1.py | 3 +- app/api/endpoints/login.py | 7 + app/api/endpoints/mfa.py | 435 +++++++++++++++++++++++++++++++++++++ app/api/endpoints/user.py | 39 ---- app/chain/user.py | 43 +++- app/core/config.py | 2 + app/db/models/__init__.py | 1 + app/db/models/passkey.py | 130 +++++++++++ app/helper/passkey.py | 347 +++++++++++++++++++++++++++++ requirements.in | 1 + 10 files changed, 957 insertions(+), 51 deletions(-) create mode 100644 app/api/endpoints/mfa.py create mode 100644 app/db/models/passkey.py create mode 100644 app/helper/passkey.py diff --git a/app/api/apiv1.py b/app/api/apiv1.py index 8453d35a..b519acb9 100644 --- a/app/api/apiv1.py +++ b/app/api/apiv1.py @@ -2,11 +2,12 @@ from fastapi import APIRouter from app.api.endpoints import login, user, webhook, message, site, subscribe, \ media, douban, search, plugin, tmdb, history, system, download, dashboard, \ - transfer, mediaserver, bangumi, storage, discover, recommend, workflow, torrent, mcp + transfer, mediaserver, bangumi, storage, discover, recommend, workflow, torrent, mcp, mfa api_router = APIRouter() api_router.include_router(login.router, prefix="/login", tags=["login"]) api_router.include_router(user.router, prefix="/user", tags=["user"]) +api_router.include_router(mfa.router, prefix="/mfa", tags=["mfa"]) api_router.include_router(site.router, prefix="/site", tags=["site"]) api_router.include_router(message.router, prefix="/message", tags=["message"]) api_router.include_router(webhook.router, prefix="/webhook", tags=["webhook"]) diff --git a/app/api/endpoints/login.py b/app/api/endpoints/login.py index f4791d0a..769ad445 100644 --- a/app/api/endpoints/login.py +++ b/app/api/endpoints/login.py @@ -29,6 +29,13 @@ def login_access_token( mfa_code=otp_password) if not success: + # 如果是需要MFA验证,返回特殊标识 + if user_or_message == "MFA_REQUIRED": + raise HTTPException( + status_code=401, + detail="需要双重验证", + headers={"X-MFA-Required": "true"} + ) raise HTTPException(status_code=401, detail=user_or_message) # 用户等级 diff --git a/app/api/endpoints/mfa.py b/app/api/endpoints/mfa.py new file mode 100644 index 00000000..f4498c25 --- /dev/null +++ b/app/api/endpoints/mfa.py @@ -0,0 +1,435 @@ +""" +MFA (Multi-Factor Authentication) API 端点 +包含 OTP 和 PassKey 相关功能 +""" +import base64 +from datetime import datetime, timedelta +from typing import Any, Annotated, List, Union + +from fastapi import APIRouter, Depends, HTTPException, Body +from sqlalchemy.ext.asyncio import AsyncSession + +from app import schemas +from app.core import security +from app.core.config import settings +from app.db import get_async_db +from app.db.models.passkey import PassKey +from app.db.models.user import User +from app.db.user_oper import get_current_active_user, get_current_active_user_async +from app.db.systemconfig_oper import SystemConfigOper +from app.helper.passkey import PassKeyHelper +from app.helper.sites import SitesHelper +from app.log import logger +from app.schemas.types import SystemConfigKey +from app.utils.otp import OtpUtils + +router = APIRouter() + +# ==================== 通用 MFA 接口 ==================== + +@router.get('/status/{username}', summary='判断用户是否开启双重验证(MFA)', response_model=schemas.Response) +async def mfa_status(username: str, db: AsyncSession = Depends(get_async_db)) -> Any: + """ + 检查指定用户是否启用了任何双重验证方式(OTP 或 PassKey) + """ + user: User = await User.async_get_by_name(db, username) + if not user: + return schemas.Response(success=False) + + # 检查是否启用了OTP + has_otp = user.is_otp + + # 检查是否有PassKey + has_passkey = bool(PassKey().get_by_user_id(db=None, user_id=user.id)) + + # 只要有任何一种验证方式,就需要双重验证 + return schemas.Response(success=(has_otp or has_passkey)) + + +# ==================== OTP 相关接口 ==================== + +@router.post('/otp/generate', summary='生成 OTP 验证 URI', response_model=schemas.Response) +def otp_generate( + current_user: Annotated[User, Depends(get_current_active_user)] +) -> Any: + """生成 OTP 密钥及对应的 URI""" + secret, uri = OtpUtils.generate_secret_key(current_user.name) + return schemas.Response(success=secret != "", data={'secret': secret, 'uri': uri}) + + +@router.post('/otp/verify', summary='绑定并验证 OTP', response_model=schemas.Response) +async def otp_verify( + data: dict, + db: AsyncSession = Depends(get_async_db), + current_user: User = Depends(get_current_active_user_async) +) -> Any: + """验证用户输入的 OTP 码,验证通过后正式开启 OTP 验证""" + uri = data.get("uri") + otp_password = data.get("otpPassword") + if not OtpUtils.is_legal(uri, otp_password): + return schemas.Response(success=False, message="验证码错误") + await current_user.async_update_otp_by_name(db, current_user.name, True, OtpUtils.get_secret(uri)) + return schemas.Response(success=True) + + +@router.post('/otp/disable', summary='关闭当前用户的 OTP 验证', response_model=schemas.Response) +async def otp_disable( + db: AsyncSession = Depends(get_async_db), + current_user: User = Depends(get_current_active_user_async) +) -> Any: + """关闭当前用户的 OTP 验证功能""" + await current_user.async_update_otp_by_name(db, current_user.name, False, "") + return schemas.Response(success=True) + + +# ==================== PassKey 相关接口 ==================== + +class PassKeyRegistrationStart(schemas.BaseModel): + """PassKey注册开始请求""" + name: str = "通行密钥" + + +class PassKeyRegistrationFinish(schemas.BaseModel): + """PassKey注册完成请求""" + credential: dict + challenge: str + name: str = "通行密钥" + + +class PassKeyAuthenticationStart(schemas.BaseModel): + """PassKey认证开始请求""" + username: str | None = None + + +class PassKeyAuthenticationFinish(schemas.BaseModel): + """PassKey认证完成请求""" + credential: dict + challenge: str + + +@router.post("/passkey/register/start", summary="开始注册 PassKey", response_model=schemas.Response) +def passkey_register_start( + passkey_req: PassKeyRegistrationStart, + current_user: Annotated[User, Depends(get_current_active_user)] +) -> Any: + """开始注册 PassKey - 生成注册选项""" + try: + # 安全检查:必须先启用 OTP + if not current_user.is_otp: + return schemas.Response( + success=False, + message="为了确保在域名配置错误时仍能找回访问权限,请先启用 OTP 验证码再注册通行密钥" + ) + + # 获取用户已有的PassKey + existing_passkeys = PassKey().get_by_user_id(db=None, user_id=current_user.id) + existing_credentials = [ + { + 'credential_id': pk.credential_id, + 'transports': pk.transports + } + for pk in existing_passkeys + ] if existing_passkeys else None + + # 生成注册选项 + options_json, challenge = PassKeyHelper.generate_registration_options( + user_id=current_user.id, + username=current_user.name, + display_name=current_user.settings.get('nickname') if current_user.settings else None, + existing_credentials=existing_credentials + ) + + return schemas.Response( + success=True, + data={ + 'options': options_json, + 'challenge': challenge + } + ) + except Exception as e: + logger.error(f"生成PassKey注册选项失败: {e}") + return schemas.Response( + success=False, + message=f"生成注册选项失败: {str(e)}" + ) + + +@router.post("/passkey/register/finish", summary="完成注册 PassKey", response_model=schemas.Response) +def passkey_register_finish( + passkey_req: PassKeyRegistrationFinish, + current_user: Annotated[User, Depends(get_current_active_user)] +) -> Any: + """完成注册 PassKey - 验证并保存凭证""" + try: + # 验证注册响应 + credential_id, public_key, sign_count, aaguid = PassKeyHelper.verify_registration_response( + credential=passkey_req.credential, + expected_challenge=passkey_req.challenge + ) + + # 提取transports + transports = None + if 'response' in passkey_req.credential and 'transports' in passkey_req.credential['response']: + transports = ','.join(passkey_req.credential['response']['transports']) + + # 保存到数据库 + passkey = PassKey( + user_id=current_user.id, + credential_id=credential_id, + public_key=public_key, + sign_count=sign_count, + name=passkey_req.name or "通行密钥", + aaguid=aaguid, + transports=transports + ) + passkey.create() + + logger.info(f"用户 {current_user.name} 成功注册PassKey: {passkey_req.name}") + + return schemas.Response( + success=True, + message="通行密钥注册成功" + ) + except Exception as e: + logger.error(f"注册PassKey失败: {e}") + return schemas.Response( + success=False, + message=f"注册失败: {str(e)}" + ) + + +@router.post("/passkey/authenticate/start", summary="开始 PassKey 认证", response_model=schemas.Response) +def passkey_authenticate_start( + passkey_req: PassKeyAuthenticationStart = Body(...) +) -> Any: + """开始 PassKey 认证 - 生成认证选项""" + try: + existing_credentials = None + + # 如果指定了用户名,只允许该用户的PassKey + if passkey_req.username: + user = User.get_by_name(db=None, name=passkey_req.username) + if not user: + return schemas.Response( + success=False, + message="用户不存在" + ) + + existing_passkeys = PassKey().get_by_user_id(db=None, user_id=user.id) + if not existing_passkeys: + return schemas.Response( + success=False, + message="该用户未注册通行密钥" + ) + + existing_credentials = [ + { + 'credential_id': pk.credential_id, + 'transports': pk.transports + } + for pk in existing_passkeys + ] + + # 生成认证选项 + options_json, challenge = PassKeyHelper.generate_authentication_options( + existing_credentials=existing_credentials + ) + + return schemas.Response( + success=True, + data={ + 'options': options_json, + 'challenge': challenge + } + ) + except Exception as e: + logger.error(f"生成PassKey认证选项失败: {e}") + return schemas.Response( + success=False, + message=f"生成认证选项失败: {str(e)}" + ) + + +@router.post("/passkey/authenticate/finish", summary="完成 PassKey 认证", response_model=schemas.Token) +def passkey_authenticate_finish( + passkey_req: PassKeyAuthenticationFinish +) -> Any: + """完成 PassKey 认证 - 验证凭证并返回 token""" + try: + # 从credential中提取credential_id + credential_id_raw = passkey_req.credential.get('id') or passkey_req.credential.get('rawId') + if not credential_id_raw: + raise HTTPException(status_code=400, detail="无效的凭证") + + # 标准化凭证ID + credential_id = PassKeyHelper.standardize_credential_id(credential_id_raw) + + # 查找PassKey + passkey = PassKey().get_by_credential_id(db=None, credential_id=credential_id) + if not passkey: + raise HTTPException(status_code=401, detail="通行密钥不存在或已失效") + + # 获取用户 + user = User.get_by_id(db=None, user_id=passkey.user_id) + if not user or not user.is_active: + raise HTTPException(status_code=401, detail="用户不存在或已禁用") + + # 验证认证响应 + success, new_sign_count = PassKeyHelper.verify_authentication_response( + credential=passkey_req.credential, + expected_challenge=passkey_req.challenge, + credential_public_key=passkey.public_key, + credential_current_sign_count=passkey.sign_count + ) + + if not success: + raise HTTPException(status_code=401, detail="通行密钥验证失败") + + # 更新使用时间和签名计数 + passkey.update_last_used(db=None, credential_id=credential_id, sign_count=new_sign_count) + + logger.info(f"用户 {user.name} 通过PassKey认证成功") + + # 生成token + level = SitesHelper().auth_level + show_wizard = not SystemConfigOper().get(SystemConfigKey.SetupWizardState) and not settings.ADVANCED_MODE + + return schemas.Token( + access_token=security.create_access_token( + userid=user.id, + username=user.name, + super_user=user.is_superuser, + expires_delta=timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES), + level=level + ), + token_type="bearer", + super_user=user.is_superuser, + user_id=user.id, + user_name=user.name, + avatar=user.avatar, + level=level, + permissions=user.permissions or {}, + widzard=show_wizard + ) + except HTTPException: + raise + except Exception as e: + logger.error(f"PassKey认证失败: {e}") + raise HTTPException(status_code=401, detail=f"认证失败: {str(e)}") + + +@router.get("/passkey/list", summary="获取当前用户的 PassKey 列表", response_model=schemas.Response) +def passkey_list( + current_user: Annotated[User, Depends(get_current_active_user)] +) -> Any: + """获取当前用户的所有 PassKey""" + try: + passkeys = PassKey().get_by_user_id(db=None, user_id=current_user.id) + + passkey_list = [ + { + 'id': pk.id, + 'name': pk.name, + 'created_at': pk.created_at.isoformat() if pk.created_at else None, + 'last_used_at': pk.last_used_at.isoformat() if pk.last_used_at else None, + 'aaguid': pk.aaguid, + 'transports': pk.transports + } + for pk in passkeys + ] if passkeys else [] + + return schemas.Response( + success=True, + data=passkey_list + ) + except Exception as e: + logger.error(f"获取PassKey列表失败: {e}") + return schemas.Response( + success=False, + message=f"获取列表失败: {str(e)}" + ) + + +@router.delete("/passkey/{passkey_id}", summary="删除 PassKey", response_model=schemas.Response) +def passkey_delete( + passkey_id: int, + current_user: Annotated[User, Depends(get_current_active_user)] +) -> Any: + """删除指定的 PassKey""" + try: + success = PassKey().delete_by_id(db=None, passkey_id=passkey_id, user_id=current_user.id) + + if success: + logger.info(f"用户 {current_user.name} 删除了PassKey: {passkey_id}") + return schemas.Response( + success=True, + message="通行密钥已删除" + ) + else: + return schemas.Response( + success=False, + message="通行密钥不存在或无权删除" + ) + except Exception as e: + logger.error(f"删除PassKey失败: {e}") + return schemas.Response( + success=False, + message=f"删除失败: {str(e)}" + ) + + +@router.post("/passkey/verify", summary="PassKey 二次验证", response_model=schemas.Response) +def passkey_verify_mfa( + passkey_req: PassKeyAuthenticationFinish, + current_user: Annotated[User, Depends(get_current_active_user)] +) -> Any: + """使用 PassKey 进行二次验证(MFA)""" + try: + # 从credential中提取credential_id + credential_id_raw = passkey_req.credential.get('id') or passkey_req.credential.get('rawId') + if not credential_id_raw: + return schemas.Response( + success=False, + message="无效的凭证" + ) + + # 标准化凭证ID + credential_id = PassKeyHelper.standardize_credential_id(credential_id_raw) + + # 查找PassKey(必须属于当前用户) + passkey = PassKey().get_by_credential_id(db=None, credential_id=credential_id) + if not passkey or passkey.user_id != current_user.id: + return schemas.Response( + success=False, + message="通行密钥不存在或不属于当前用户" + ) + + # 验证认证响应 + success, new_sign_count = PassKeyHelper.verify_authentication_response( + credential=passkey_req.credential, + expected_challenge=passkey_req.challenge, + credential_public_key=passkey.public_key, + credential_current_sign_count=passkey.sign_count + ) + + if not success: + return schemas.Response( + success=False, + message="通行密钥验证失败" + ) + + # 更新使用时间和签名计数 + passkey.update_last_used(db=None, credential_id=credential_id, sign_count=new_sign_count) + + logger.info(f"用户 {current_user.name} 通过PassKey二次验证成功") + + return schemas.Response( + success=True, + message="二次验证成功" + ) + except Exception as e: + logger.error(f"PassKey二次验证失败: {e}") + return schemas.Response( + success=False, + message=f"验证失败: {str(e)}" + ) diff --git a/app/api/endpoints/user.py b/app/api/endpoints/user.py index fe6aa1f9..56d756fc 100644 --- a/app/api/endpoints/user.py +++ b/app/api/endpoints/user.py @@ -111,45 +111,6 @@ async def upload_avatar(user_id: int, db: AsyncSession = Depends(get_async_db), return schemas.Response(success=True, message=file.filename) -@router.post('/otp/generate', summary='生成otp验证uri', response_model=schemas.Response) -def otp_generate( - current_user: User = Depends(get_current_active_user) -) -> Any: - secret, uri = OtpUtils.generate_secret_key(current_user.name) - return schemas.Response(success=secret != "", data={'secret': secret, 'uri': uri}) - - -@router.post('/otp/judge', summary='判断otp验证是否通过', response_model=schemas.Response) -async def otp_judge( - data: dict, - db: AsyncSession = Depends(get_async_db), - current_user: User = Depends(get_current_active_user_async) -) -> Any: - uri = data.get("uri") - otp_password = data.get("otpPassword") - if not OtpUtils.is_legal(uri, otp_password): - return schemas.Response(success=False, message="验证码错误") - await current_user.async_update_otp_by_name(db, current_user.name, True, OtpUtils.get_secret(uri)) - return schemas.Response(success=True) - - -@router.post('/otp/disable', summary='关闭当前用户的otp验证', response_model=schemas.Response) -async def otp_disable( - db: AsyncSession = Depends(get_async_db), - current_user: User = Depends(get_current_active_user_async) -) -> Any: - await current_user.async_update_otp_by_name(db, current_user.name, False, "") - return schemas.Response(success=True) - - -@router.get('/otp/{userid}', summary='判断当前用户是否开启otp验证', response_model=schemas.Response) -async def otp_enable(userid: str, db: AsyncSession = Depends(get_async_db)) -> Any: - user: User = await User.async_get_by_name(db, userid) - if not user: - return schemas.Response(success=False) - return schemas.Response(success=user.is_otp) - - @router.get("/config/{key}", summary="查询用户配置", response_model=schemas.Response) def get_config(key: str, current_user: User = Depends(get_current_active_user)): diff --git a/app/chain/user.py b/app/chain/user.py index 331b1836..b920f490 100644 --- a/app/chain/user.py +++ b/app/chain/user.py @@ -52,7 +52,10 @@ class UserChain(ChainBase): success, user_or_message = self.password_authenticate(credentials=credentials) if success: # 如果用户启用了二次验证码,则进一步验证 - if not self._verify_mfa(user_or_message, credentials.mfa_code): + mfa_result = self._verify_mfa(user_or_message, credentials.mfa_code) + if mfa_result == "MFA_REQUIRED": + return False, "MFA_REQUIRED" + elif not mfa_result: return False, PASSWORD_INVALID_CREDENTIALS_MESSAGE logger.info(f"用户 {username} 通过密码认证成功") return True, user_or_message @@ -63,7 +66,10 @@ class UserChain(ChainBase): aux_success, aux_user_or_message = self.auxiliary_authenticate(credentials=credentials) if aux_success: # 辅助认证成功后再验证二次验证码 - if not self._verify_mfa(aux_user_or_message, credentials.mfa_code): + mfa_result = self._verify_mfa(aux_user_or_message, credentials.mfa_code) + if mfa_result == "MFA_REQUIRED": + return False, "MFA_REQUIRED" + elif not mfa_result: return False, PASSWORD_INVALID_CREDENTIALS_MESSAGE return True, aux_user_or_message else: @@ -159,22 +165,37 @@ class UserChain(ChainBase): return False, PASSWORD_INVALID_CREDENTIALS_MESSAGE @staticmethod - def _verify_mfa(user: User, mfa_code: Optional[str]) -> bool: + def _verify_mfa(user: User, mfa_code: Optional[str]) -> Union[bool, str]: """ 验证 MFA(二次验证码) + 检查用户是否启用了 OTP 或 PassKey,如果启用了任何一种,都需要提供验证 :param user: 用户对象 - :param mfa_code: 二次验证码 - :return: 如果验证成功返回 True,否则返回 False + :param mfa_code: 二次验证码(如果提供了则验证OTP) + :return: + - 如果验证成功返回 True + - 如果需要MFA但未提供,返回 "MFA_REQUIRED" + - 如果MFA验证失败,返回 False """ - if not user.is_otp: + # 检查用户是否有PassKey + from app.db.models.passkey import PassKey + has_passkey = bool(PassKey().get_by_user_id(db=None, user_id=user.id)) + + # 如果用户既没有启用OTP也没有PassKey,直接通过 + if not user.is_otp and not has_passkey: return True + + # 如果用户启用了OTP或PassKey,但没有提供验证码,需要进行二次验证 if not mfa_code: - logger.info(f"用户 {user.name} 缺少 MFA 认证码") - return False - if not OtpUtils.check(str(user.otp_secret), mfa_code): - logger.info(f"用户 {user.name} 的 MFA 认证失败") - return False + logger.info(f"用户 {user.name} 已启用双重验证(OTP: {user.is_otp}, PassKey: {has_passkey}),需要提供验证码") + return "MFA_REQUIRED" + + # 如果提供了验证码,验证OTP + if user.is_otp: + if not OtpUtils.check(str(user.otp_secret), mfa_code): + logger.info(f"用户 {user.name} 的 MFA 认证失败") + return False + return True def _process_auth_success(self, username: str, credentials: AuthCredentials) -> bool: diff --git a/app/core/config.py b/app/core/config.py index 394d259e..090cb0ce 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -393,6 +393,8 @@ class ConfigModel(BaseModel): ]) # 允许的图片文件后缀格式 SECURITY_IMAGE_SUFFIXES: list = Field(default=[".jpg", ".jpeg", ".png", ".webp", ".gif", ".svg", ".avif"]) + # PassKey 是否强制用户验证(生物识别等) + PASSKEY_REQUIRE_UV: bool = True # ==================== 工作流配置 ==================== # 工作流数据共享 diff --git a/app/db/models/__init__.py b/app/db/models/__init__.py index 09ecf69f..a7bb7cd8 100644 --- a/app/db/models/__init__.py +++ b/app/db/models/__init__.py @@ -1,5 +1,6 @@ from .downloadhistory import DownloadHistory, DownloadFiles from .mediaserver import MediaServerItem +from .passkey import PassKey from .plugindata import PluginData from .site import Site from .siteicon import SiteIcon diff --git a/app/db/models/passkey.py b/app/db/models/passkey.py new file mode 100644 index 00000000..129d29e9 --- /dev/null +++ b/app/db/models/passkey.py @@ -0,0 +1,130 @@ +from sqlalchemy import Column, Integer, String, Boolean, DateTime, Text, select, ForeignKey +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import Session +from datetime import datetime + +from app.db import Base, db_query, db_update, async_db_query, async_db_update, get_id_column + + +class PassKey(Base): + """ + 用户PassKey凭证表 + """ + # ID + id = get_id_column() + # 用户ID + user_id = Column(Integer, ForeignKey('user.id'), nullable=False, index=True) + # 凭证ID (credential_id) + credential_id = Column(String, nullable=False, unique=True, index=True) + # 凭证公钥 + public_key = Column(Text, nullable=False) + # 签名计数器 + sign_count = Column(Integer, default=0) + # 凭证名称(用户自定义) + name = Column(String, default="通行密钥") + # AAGUID (Authenticator Attestation GUID) + aaguid = Column(String, nullable=True) + # 创建时间 + created_at = Column(DateTime, default=datetime.now) + # 最后使用时间 + last_used_at = Column(DateTime, nullable=True) + # 是否启用 + is_active = Column(Boolean, default=True) + # 传输方式 (usb, nfc, ble, internal) + transports = Column(String, nullable=True) + + @classmethod + @db_query + def get_by_user_id(cls, db: Session, user_id: int): + """获取用户的所有PassKey""" + return db.query(cls).filter(cls.user_id == user_id, cls.is_active == True).all() + + @classmethod + @async_db_query + async def async_get_by_user_id(cls, db: AsyncSession, user_id: int): + """异步获取用户的所有PassKey""" + result = await db.execute( + select(cls).filter(cls.user_id == user_id, cls.is_active == True) + ) + return result.scalars().all() + + @classmethod + @db_query + def get_by_credential_id(cls, db: Session, credential_id: str): + """根据凭证ID获取PassKey""" + return db.query(cls).filter(cls.credential_id == credential_id, cls.is_active == True).first() + + @classmethod + @async_db_query + async def async_get_by_credential_id(cls, db: AsyncSession, credential_id: str): + """异步根据凭证ID获取PassKey""" + result = await db.execute( + select(cls).filter(cls.credential_id == credential_id, cls.is_active == True) + ) + return result.scalars().first() + + @classmethod + @db_query + def get_by_id(cls, db: Session, passkey_id: int): + """根据ID获取PassKey""" + return db.query(cls).filter(cls.id == passkey_id).first() + + @classmethod + @async_db_query + async def async_get_by_id(cls, db: AsyncSession, passkey_id: int): + """异步根据ID获取PassKey""" + result = await db.execute( + select(cls).filter(cls.id == passkey_id) + ) + return result.scalars().first() + + @db_update + def delete_by_id(self, db: Session, passkey_id: int, user_id: int): + """删除指定用户的PassKey""" + passkey = db.query(PassKey).filter( + PassKey.id == passkey_id, + PassKey.user_id == user_id + ).first() + if passkey: + passkey.delete(db, passkey.id) + return True + return False + + @async_db_update + async def async_delete_by_id(self, db: AsyncSession, passkey_id: int, user_id: int): + """异步删除指定用户的PassKey""" + result = await db.execute( + select(PassKey).filter( + PassKey.id == passkey_id, + PassKey.user_id == user_id + ) + ) + passkey = result.scalars().first() + if passkey: + await passkey.async_delete(db, passkey.id) + return True + return False + + @db_update + def update_last_used(self, db: Session, credential_id: str, sign_count: int): + """更新最后使用时间和签名计数""" + passkey = self.get_by_credential_id(db, credential_id) + if passkey: + passkey.update(db, { + 'last_used_at': datetime.now(), + 'sign_count': sign_count + }) + return True + return False + + @async_db_update + async def async_update_last_used(self, db: AsyncSession, credential_id: str, sign_count: int): + """异步更新最后使用时间和签名计数""" + passkey = await self.async_get_by_credential_id(db, credential_id) + if passkey: + await passkey.async_update(db, { + 'last_used_at': datetime.now(), + 'sign_count': sign_count + }) + return True + return False diff --git a/app/helper/passkey.py b/app/helper/passkey.py new file mode 100644 index 00000000..6fccc413 --- /dev/null +++ b/app/helper/passkey.py @@ -0,0 +1,347 @@ +""" +PassKey WebAuthn 辅助工具类 +""" +import base64 +import json +from typing import Optional, Tuple, List, Dict, Any +from urllib.parse import urlparse + +from webauthn import ( + generate_registration_options, + verify_registration_response, + generate_authentication_options, + verify_authentication_response, + options_to_json +) +from webauthn.helpers import ( + parse_registration_credential_json, + parse_authentication_credential_json +) +from webauthn.helpers.structs import ( + PublicKeyCredentialDescriptor, + AuthenticatorTransport, + UserVerificationRequirement, + AuthenticatorAttachment, + ResidentKeyRequirement, + PublicKeyCredentialCreationOptions, + PublicKeyCredentialRequestOptions, + RegistrationCredential, + AuthenticationCredential, + AuthenticatorSelectionCriteria +) +from webauthn.helpers.cose import COSEAlgorithmIdentifier + +from app.core.config import settings +from app.log import logger + + +class PassKeyHelper: + """ + PassKey WebAuthn 辅助类 + """ + + @staticmethod + def get_rp_id() -> str: + """ + 获取 Relying Party ID + """ + if settings.APP_DOMAIN: + # 从 APP_DOMAIN 中提取域名 + host = settings.APP_DOMAIN.replace('https://', '').replace('http://', '') + # 移除端口号 + if ':' in host: + host = host.split(':')[0] + return host + # 只有在未配置 APP_DOMAIN 时,才默认为 localhost + return 'localhost' + + @staticmethod + def get_rp_name() -> str: + """ + 获取 Relying Party 名称 + """ + return "MoviePilot" + + @staticmethod + def get_origin() -> str: + """ + 获取源地址 + """ + if settings.APP_DOMAIN: + return settings.APP_DOMAIN.rstrip('/') + # 如果未配置APP_DOMAIN,使用默认的localhost地址 + return f'http://localhost:{settings.NGINX_PORT}' + + @staticmethod + def standardize_credential_id(credential_id: str) -> str: + """ + 标准化凭证ID(Base64 URL Safe) + """ + try: + # Base64解码并重新编码以标准化格式 + decoded = base64.urlsafe_b64decode(credential_id + '==') + return base64.urlsafe_b64encode(decoded).decode('utf-8').rstrip('=') + except Exception as e: + logger.error(f"标准化凭证ID失败: {e}") + return credential_id + + @staticmethod + def generate_registration_options( + user_id: int, + username: str, + display_name: Optional[str] = None, + existing_credentials: Optional[List[Dict[str, Any]]] = None + ) -> Tuple[str, str]: + """ + 生成注册选项 + + :param user_id: 用户ID + :param username: 用户名 + :param display_name: 显示名称 + :param existing_credentials: 已存在的凭证列表 + :return: (options_json, challenge) + """ + try: + # 用户信息 + user_id_bytes = str(user_id).encode('utf-8') + + # 排除已有的凭证 + exclude_credentials = [] + if existing_credentials: + for cred in existing_credentials: + try: + exclude_credentials.append( + PublicKeyCredentialDescriptor( + id=base64.urlsafe_b64decode(cred['credential_id'] + '=='), + transports=[ + AuthenticatorTransport(t) for t in cred.get('transports', '').split(',') if t + ] if cred.get('transports') else None + ) + ) + except Exception as e: + logger.warning(f"解析凭证失败: {e}") + continue + + # 用户验证要求 + uv_requirement = UserVerificationRequirement.REQUIRED if settings.PASSKEY_REQUIRE_UV \ + else UserVerificationRequirement.PREFERRED + + # 生成注册选项 + options = generate_registration_options( + rp_id=PassKeyHelper.get_rp_id(), + rp_name=PassKeyHelper.get_rp_name(), + user_id=user_id_bytes, + user_name=username, + user_display_name=display_name or username, + exclude_credentials=exclude_credentials if exclude_credentials else None, + authenticator_selection=AuthenticatorSelectionCriteria( + authenticator_attachment=AuthenticatorAttachment.PLATFORM, + resident_key=ResidentKeyRequirement.REQUIRED, + user_verification=uv_requirement, + ), + supported_pub_key_algs=[ + COSEAlgorithmIdentifier.ECDSA_SHA_256, + COSEAlgorithmIdentifier.RSASSA_PKCS1_v1_5_SHA_256, + ] + ) + + # 转换为JSON + options_json = options_to_json(options) + + # 提取challenge(用于后续验证) + challenge = base64.urlsafe_b64encode(options.challenge).decode('utf-8').rstrip('=') + + return options_json, challenge + + except Exception as e: + logger.error(f"生成注册选项失败: {e}") + raise + + @staticmethod + def _get_verified_origin(credential: Dict[str, Any], rp_id: str, default_origin: str) -> str: + """ + 在 localhost 环境下获取并验证实际 Origin,否则返回默认值 + """ + if not settings.APP_DOMAIN and rp_id == 'localhost': + try: + # 解析 clientDataJSON 获取实际的 origin + client_data_json = json.loads( + base64.urlsafe_b64decode( + credential['response']['clientDataJSON'].replace('-', '+').replace('_', '/') + '==' + ).decode('utf-8') + ) + actual_origin = client_data_json.get('origin', '') + hostname = urlparse(actual_origin).hostname + + if hostname in ['localhost', '127.0.0.1']: + logger.info(f"本地环境,使用动态 origin: {actual_origin}") + return actual_origin + except Exception as e: + logger.warning(f"无法提取动态 origin: {e}") + return default_origin + + @staticmethod + def verify_registration_response( + credential: Dict[str, Any], + expected_challenge: str, + expected_origin: Optional[str] = None, + expected_rp_id: Optional[str] = None + ) -> Tuple[str, str, int, Optional[str]]: + """ + 验证注册响应 + + :param credential: 客户端返回的凭证 + :param expected_challenge: 期望的challenge + :param expected_origin: 期望的源地址 + :param expected_rp_id: 期望的RP ID + :return: (credential_id, public_key, sign_count, aaguid) + """ + try: + # 准备验证参数 + origin = expected_origin or PassKeyHelper.get_origin() + rp_id = expected_rp_id or PassKeyHelper.get_rp_id() + + # 解码challenge + challenge_bytes = base64.urlsafe_b64decode(expected_challenge + '==') + + # 构建RegistrationCredential对象 + registration_credential = parse_registration_credential_json(json.dumps(credential)) + + # 获取并验证 Origin + origin = PassKeyHelper._get_verified_origin(credential, rp_id, origin) + + # 验证注册响应 + verification = verify_registration_response( + credential=registration_credential, + expected_challenge=challenge_bytes, + expected_rp_id=rp_id, + expected_origin=origin, + require_user_verification=settings.PASSKEY_REQUIRE_UV + ) + + # 提取信息 + credential_id = base64.urlsafe_b64encode(verification.credential_id).decode('utf-8').rstrip('=') + public_key = base64.urlsafe_b64encode(verification.credential_public_key).decode('utf-8').rstrip('=') + sign_count = verification.sign_count + # aaguid 可能已经是字符串格式,也可能是bytes + if verification.aaguid: + if isinstance(verification.aaguid, bytes): + aaguid = verification.aaguid.hex() + else: + aaguid = str(verification.aaguid) + else: + aaguid = None + + return credential_id, public_key, sign_count, aaguid + + except Exception as e: + logger.error(f"验证注册响应失败: {e}") + raise + + @staticmethod + def generate_authentication_options( + existing_credentials: Optional[List[Dict[str, Any]]] = None, + user_verification: Optional[str] = None + ) -> Tuple[str, str]: + """ + 生成认证选项 + + :param existing_credentials: 已存在的凭证列表(用于限制可用凭证) + :param user_verification: 用户验证要求,如果不指定则从配置中读取 + :return: (options_json, challenge) + """ + try: + # 允许的凭证 + allow_credentials = [] + if existing_credentials: + for cred in existing_credentials: + try: + allow_credentials.append( + PublicKeyCredentialDescriptor( + id=base64.urlsafe_b64decode(cred['credential_id'] + '=='), + transports=[ + AuthenticatorTransport(t) for t in cred.get('transports', '').split(',') if t + ] if cred.get('transports') else None + ) + ) + except Exception as e: + logger.warning(f"解析凭证失败: {e}") + continue + + # 用户验证要求 + if not user_verification: + uv_requirement = UserVerificationRequirement.REQUIRED if settings.PASSKEY_REQUIRE_UV \ + else UserVerificationRequirement.PREFERRED + else: + uv_requirement = UserVerificationRequirement(user_verification) + + # 生成认证选项 + options = generate_authentication_options( + rp_id=PassKeyHelper.get_rp_id(), + allow_credentials=allow_credentials if allow_credentials else None, + user_verification=uv_requirement + ) + + # 转换为JSON + options_json = options_to_json(options) + + # 提取challenge + challenge = base64.urlsafe_b64encode(options.challenge).decode('utf-8').rstrip('=') + + return options_json, challenge + + except Exception as e: + logger.error(f"生成认证选项失败: {e}") + raise + + @staticmethod + def verify_authentication_response( + credential: Dict[str, Any], + expected_challenge: str, + credential_public_key: str, + credential_current_sign_count: int, + expected_origin: Optional[str] = None, + expected_rp_id: Optional[str] = None + ) -> Tuple[bool, int]: + """ + 验证认证响应 + + :param credential: 客户端返回的凭证 + :param expected_challenge: 期望的challenge + :param credential_public_key: 凭证公钥 + :param credential_current_sign_count: 当前签名计数 + :param expected_origin: 期望的源地址 + :param expected_rp_id: 期望的RP ID + :return: (验证成功, 新的签名计数) + """ + try: + # 准备验证参数 + origin = expected_origin or PassKeyHelper.get_origin() + rp_id = expected_rp_id or PassKeyHelper.get_rp_id() + + # 解码 + challenge_bytes = base64.urlsafe_b64decode(expected_challenge + '==') + public_key_bytes = base64.urlsafe_b64decode(credential_public_key + '==') + + # 构建AuthenticationCredential对象 + authentication_credential = parse_authentication_credential_json(json.dumps(credential)) + + # 获取并验证 Origin + origin = PassKeyHelper._get_verified_origin(credential, rp_id, origin) + + # 验证认证响应 + verification = verify_authentication_response( + credential=authentication_credential, + expected_challenge=challenge_bytes, + expected_rp_id=rp_id, + expected_origin=origin, + credential_public_key=public_key_bytes, + credential_current_sign_count=credential_current_sign_count, + require_user_verification=settings.PASSKEY_REQUIRE_UV + ) + + return True, verification.new_sign_count + + except Exception as e: + logger.error(f"验证认证响应失败: {e}") + return False, credential_current_sign_count diff --git a/requirements.in b/requirements.in index 7ca6ad66..6eb374f3 100644 --- a/requirements.in +++ b/requirements.in @@ -62,6 +62,7 @@ cachetools~=6.1.0 fast-bencode~=1.1.7 pystray~=0.19.5 pyotp~=2.9.0 +webauthn~=2.7.0 Pinyin2Hanzi~=0.1.1 pywebpush~=2.0.3 aiopathlib~=0.6.0