feat: add WebAuthn passkey authentication support

- Add passkey login as alternative authentication method
- Support multiple passkeys per user with custom names
- Backend: WebAuthn service, auth strategy pattern, API endpoints
- Frontend: passkey management UI in settings, login option
- Fix: convert downloader check from sync requests to async httpx
  to prevent blocking the event loop when downloader unavailable

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
EstrellaXD
2026-01-22 11:50:55 +01:00
parent b4d90e2a11
commit bfba010471
23 changed files with 1607 additions and 87 deletions

View File

@@ -4,6 +4,7 @@ from .auth import router as auth_router
from .bangumi import router as bangumi_router
from .config import router as config_router
from .log import router as log_router
from .passkey import router as passkey_router
from .program import router as program_router
from .rss import router as rss_router
from .search import router as search_router
@@ -13,6 +14,7 @@ __all__ = "v1"
# API 1.0
v1 = APIRouter(prefix="/v1")
v1.include_router(auth_router)
v1.include_router(passkey_router)
v1.include_router(log_router)
v1.include_router(program_router)
v1.include_router(bangumi_router)

View File

@@ -0,0 +1,229 @@
"""
Passkey 管理 API
用于注册、列表、删除 Passkey 凭证
"""
import logging
from datetime import timedelta
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import JSONResponse, Response
from module.database import Database
from module.models import APIResponse
from module.models.passkey import (
PasskeyAuthFinish,
PasskeyAuthStart,
PasskeyCreate,
PasskeyDelete,
PasskeyList,
)
from module.security.api import active_user, get_current_user
from module.security.auth_strategy import PasskeyAuthStrategy
from module.security.jwt import create_access_token
from module.security.webauthn import get_webauthn_service
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/passkey", tags=["passkey"])
def _get_webauthn_from_request(request: Request):
"""
从请求中构造 WebAuthnService
根据 Host header 动态确定 RP ID 和 origin
"""
host = request.headers.get("host", "localhost:7892")
rp_id = host.split(":")[0] # 去掉端口
# 判断协议
forwarded_proto = request.headers.get("x-forwarded-proto")
if forwarded_proto:
scheme = forwarded_proto
else:
scheme = request.url.scheme
if scheme == "https":
origin = f"https://{host}"
else:
origin = f"http://{host}"
return get_webauthn_service(rp_id, "AutoBangumi", origin)
# ============ 注册流程 ============
@router.post("/register/options", response_model=dict)
async def get_registration_options(
request: Request,
username: str = Depends(get_current_user),
):
"""
生成 Passkey 注册选项
前端调用 navigator.credentials.create() 时使用
"""
webauthn = _get_webauthn_from_request(request)
async with Database() as db:
try:
user = await db.user.get_user(username)
existing_passkeys = await db.passkey.get_passkeys_by_user_id(user.id)
options = webauthn.generate_registration_options(
username=username,
user_id=user.id,
existing_passkeys=existing_passkeys,
)
return options
except Exception as e:
logger.error(f"Failed to generate registration options: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/register/verify", response_model=APIResponse)
async def verify_registration(
passkey_data: PasskeyCreate,
request: Request,
username: str = Depends(get_current_user),
):
"""
验证 Passkey 注册响应并保存
"""
webauthn = _get_webauthn_from_request(request)
async with Database() as db:
try:
user = await db.user.get_user(username)
# 验证 WebAuthn 响应
passkey = webauthn.verify_registration(
username=username,
credential=passkey_data.attestation_response,
device_name=passkey_data.name,
)
# 设置 user_id 并保存
passkey.user_id = user.id
await db.passkey.create_passkey(passkey)
return JSONResponse(
status_code=200,
content={
"msg_en": f"Passkey '{passkey_data.name}' registered successfully",
"msg_zh": f"Passkey '{passkey_data.name}' 注册成功",
},
)
except ValueError as e:
logger.warning(f"Registration verification failed for {username}: {e}")
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.error(f"Failed to register passkey: {e}")
raise HTTPException(status_code=500, detail=str(e))
# ============ 认证流程 ============
@router.post("/auth/options", response_model=dict)
async def get_passkey_login_options(
auth_data: PasskeyAuthStart,
request: Request,
):
"""
生成 Passkey 登录选项challenge
前端先调用此接口,再调用 navigator.credentials.get()
"""
webauthn = _get_webauthn_from_request(request)
async with Database() as db:
try:
user = await db.user.get_user(auth_data.username)
passkeys = await db.passkey.get_passkeys_by_user_id(user.id)
if not passkeys:
raise HTTPException(
status_code=400, detail="No passkeys registered for this user"
)
options = webauthn.generate_authentication_options(
auth_data.username, passkeys
)
return options
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to generate login options: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/auth/verify", response_model=dict)
async def login_with_passkey(
auth_data: PasskeyAuthFinish,
response: Response,
request: Request,
):
"""
使用 Passkey 登录(替代密码登录)
"""
webauthn = _get_webauthn_from_request(request)
strategy = PasskeyAuthStrategy(webauthn)
resp = await strategy.authenticate(auth_data.username, auth_data.credential)
if resp.status:
token = create_access_token(
data={"sub": auth_data.username}, expires_delta=timedelta(days=1)
)
response.set_cookie(key="token", value=token, httponly=True, max_age=86400)
if auth_data.username not in active_user:
active_user.append(auth_data.username)
return {"access_token": token, "token_type": "bearer"}
raise HTTPException(status_code=resp.status_code, detail=resp.msg_en)
# ============ Passkey 管理 ============
@router.get("/list", response_model=list[PasskeyList])
async def list_passkeys(username: str = Depends(get_current_user)):
"""获取用户的所有 Passkey"""
async with Database() as db:
try:
user = await db.user.get_user(username)
passkeys = await db.passkey.get_passkeys_by_user_id(user.id)
return [db.passkey.to_list_model(pk) for pk in passkeys]
except Exception as e:
logger.error(f"Failed to list passkeys: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/delete", response_model=APIResponse)
async def delete_passkey(
delete_data: PasskeyDelete,
username: str = Depends(get_current_user),
):
"""删除 Passkey"""
async with Database() as db:
try:
user = await db.user.get_user(username)
await db.passkey.delete_passkey(delete_data.passkey_id, user.id)
return JSONResponse(
status_code=200,
content={
"msg_en": "Passkey deleted successfully",
"msg_zh": "Passkey 删除成功",
},
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to delete passkey: {e}")
raise HTTPException(status_code=500, detail=str(e))

View File

@@ -24,7 +24,7 @@ async def startup():
@router.on_event("shutdown")
async def shutdown():
program.stop()
await program.stop()
@router.get(
@@ -69,7 +69,8 @@ async def start():
"/stop", response_model=APIResponse, dependencies=[Depends(get_current_user)]
)
async def stop():
return u_response(program.stop())
resp = await program.stop()
return u_response(resp)
@router.get("/status", response_model=dict, dependencies=[Depends(get_current_user)])
@@ -92,7 +93,7 @@ async def program_status():
"/shutdown", response_model=APIResponse, dependencies=[Depends(get_current_user)]
)
async def shutdown_program():
program.stop()
await program.stop()
logger.info("Shutting down program...")
os.kill(os.getpid(), signal.SIGINT)
return JSONResponse(
@@ -112,4 +113,4 @@ async def shutdown_program():
dependencies=[Depends(get_current_user)],
)
async def check_downloader_status():
return program.check_downloader()
return await program.check_downloader()

View File

@@ -1,10 +1,9 @@
import logging
from pathlib import Path
import requests
import httpx
from module.conf import VERSION, settings
from module.downloader import DownloadClient
from module.models import Config
from module.update import version_check
@@ -49,27 +48,28 @@ class Checker:
return True
@staticmethod
def check_downloader() -> bool:
async def check_downloader() -> bool:
from module.downloader import DownloadClient
try:
url = (
f"http://{settings.downloader.host}"
if "://" not in settings.downloader.host
else f"{settings.downloader.host}"
)
response = requests.get(url, timeout=2)
# if settings.downloader.type in response.text.lower():
async with httpx.AsyncClient(timeout=2.0) as client:
response = await client.get(url)
if "qbittorrent" in response.text.lower() or "vuetorrent" in response.text.lower():
with DownloadClient() as client:
if client.authed:
async with DownloadClient() as dl_client:
if dl_client.authed:
return True
else:
return False
else:
return False
except requests.exceptions.ReadTimeout:
except httpx.TimeoutException:
logger.error("[Checker] Downloader connect timeout.")
return False
except requests.exceptions.ConnectionError:
except httpx.ConnectError:
logger.error("[Checker] Downloader connect failed.")
return False
except Exception as e:

View File

@@ -16,14 +16,14 @@ from .sub_thread import RenameThread, RSSThread
logger = logging.getLogger(__name__)
figlet = r"""
_ ____ _
/\ | | | _ \ (_)
/ \ _ _| |_ ___ | |_) | __ _ _ __ __ _ _ _ _ __ ___ _
/ /\ \| | | | __/ _ \| _ < / _` | '_ \ / _` | | | | '_ ` _ \| |
/ ____ \ |_| | || (_) | |_) | (_| | | | | (_| | |_| | | | | | | |
/_/ \_\__,_|\__\___/|____/ \__,_|_| |_|\__, |\__,_|_| |_| |_|_|
__/ |
|___/
_ ____ _
/\ | | | _ \ (_)
/ \ _ _| |_ ___ | |_) | __ _ _ __ __ _ _ _ _ __ ___ _
/ /\ \| | | | __/ _ \| _ < / _` | '_ \ / _` | | | | '_ ` _ \| |
/ ____ \ |_| | || (_) | |_) | (_| | | | | (_| | |_| | | | | | | |
/_/ \_\__,_|\__\___/|____/ \__,_|_| |_|\__, |\__,_|_| |_| |_|_|
__/ |
|___/
"""
@@ -61,7 +61,7 @@ class Program(RenameThread, RSSThread):
async def start(self):
self.stop_event.clear()
settings.load()
while not self.downloader_status:
while not await self.check_downloader_status():
logger.warning("Downloader is not running.")
logger.info("Waiting for downloader to start.")
await asyncio.sleep(30)
@@ -77,11 +77,11 @@ class Program(RenameThread, RSSThread):
msg_zh="程序启动成功。",
)
def stop(self):
async def stop(self):
if self.is_running:
self.stop_event.set()
self.rename_stop()
self.rss_stop()
await self.rename_stop()
await self.rss_stop()
return ResponseModel(
status=True,
status_code=200,
@@ -97,7 +97,7 @@ class Program(RenameThread, RSSThread):
)
async def restart(self):
self.stop()
await self.stop()
await self.start()
return ResponseModel(
status=True,

View File

@@ -1,5 +1,4 @@
import asyncio
import threading
from module.checker import Checker
from module.conf import LEGACY_DATA_PATH
@@ -8,8 +7,8 @@ from module.conf import LEGACY_DATA_PATH
class ProgramStatus(Checker):
def __init__(self):
super().__init__()
self.stop_event = threading.Event()
self.lock = threading.Lock()
self.stop_event = asyncio.Event()
self.lock = asyncio.Lock()
self._downloader_status = False
self._torrents_status = False
self.event = asyncio.Event()
@@ -27,8 +26,11 @@ class ProgramStatus(Checker):
@property
def downloader_status(self):
return self._downloader_status
async def check_downloader_status(self) -> bool:
if not self._downloader_status:
self._downloader_status = self.check_downloader()
self._downloader_status = await self.check_downloader()
return self._downloader_status
@property

View File

@@ -1,42 +1,64 @@
from sqlmodel import Session, SQLModel
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import SQLModel
from module.models import Bangumi, User
from module.models import Bangumi, Passkey, User
from .bangumi import BangumiDatabase
from .engine import engine as e
from .engine import async_session_factory, engine as e
from .passkey import PasskeyDatabase
from .rss import RSSDatabase
from .torrent import TorrentDatabase
from .user import UserDatabase
class Database(Session):
def __init__(self, engine=e):
self.engine = engine
super().__init__(engine)
self.rss = RSSDatabase(self)
self.torrent = TorrentDatabase(self)
self.bangumi = BangumiDatabase(self)
self.user = UserDatabase(self)
class Database:
def __init__(self):
self._session: AsyncSession | None = None
self.rss: RSSDatabase | None = None
self.torrent: TorrentDatabase | None = None
self.bangumi: BangumiDatabase | None = None
self.user: UserDatabase | None = None
self.passkey: PasskeyDatabase | None = None
def create_table(self):
SQLModel.metadata.create_all(self.engine)
async def __aenter__(self):
self._session = async_session_factory()
self.rss = RSSDatabase(self._session)
self.torrent = TorrentDatabase(self._session)
self.bangumi = BangumiDatabase(self._session)
self.user = UserDatabase(self._session)
self.passkey = PasskeyDatabase(self._session)
return self
def drop_table(self):
SQLModel.metadata.drop_all(self.engine)
async def __aexit__(self, exc_type, exc_val, exc_tb):
if self._session:
await self._session.close()
def migrate(self):
async def create_table(self):
async with e.begin() as conn:
await conn.run_sync(SQLModel.metadata.create_all)
async def drop_table(self):
async with e.begin() as conn:
await conn.run_sync(SQLModel.metadata.drop_all)
async def commit(self):
if self._session:
await self._session.commit()
async def add(self, obj):
if self._session:
self._session.add(obj)
await self._session.commit()
async def migrate(self):
# Run migration online
bangumi_data = self.bangumi.search_all()
user_data = self.exec("SELECT * FROM user").all()
bangumi_data = await self.bangumi.search_all()
readd_bangumi = []
for bangumi in bangumi_data:
dict_data = bangumi.dict()
del dict_data["id"]
readd_bangumi.append(Bangumi(**dict_data))
self.drop_table()
self.create_table()
self.commit()
bangumi_data = self.bangumi.search_all()
self.bangumi.add_all(readd_bangumi)
self.add(User(**user_data[0]))
self.commit()
await self.drop_table()
await self.create_table()
await self.commit()
await self.bangumi.add_all(readd_bangumi)

View File

@@ -0,0 +1,78 @@
"""
Passkey 数据库操作层
"""
import logging
from datetime import datetime
from typing import List, Optional
from fastapi import HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import select
from module.models.passkey import Passkey, PasskeyList
logger = logging.getLogger(__name__)
class PasskeyDatabase:
def __init__(self, session: AsyncSession):
self.session = session
async def create_passkey(self, passkey: Passkey) -> Passkey:
"""创建新的 Passkey 凭证"""
self.session.add(passkey)
await self.session.commit()
await self.session.refresh(passkey)
logger.info(f"Created passkey '{passkey.name}' for user_id={passkey.user_id}")
return passkey
async def get_passkey_by_credential_id(
self, credential_id: str
) -> Optional[Passkey]:
"""通过 credential_id 查找 Passkey用于认证"""
statement = select(Passkey).where(Passkey.credential_id == credential_id)
result = await self.session.execute(statement)
return result.scalar_one_or_none()
async def get_passkeys_by_user_id(self, user_id: int) -> List[Passkey]:
"""获取用户的所有 Passkey"""
statement = select(Passkey).where(Passkey.user_id == user_id)
result = await self.session.execute(statement)
return list(result.scalars().all())
async def get_passkey_by_id(self, passkey_id: int, user_id: int) -> Passkey:
"""获取特定 Passkey带权限检查"""
statement = select(Passkey).where(
Passkey.id == passkey_id, Passkey.user_id == user_id
)
result = await self.session.execute(statement)
passkey = result.scalar_one_or_none()
if not passkey:
raise HTTPException(status_code=404, detail="Passkey not found")
return passkey
async def update_passkey_usage(self, passkey: Passkey, new_sign_count: int):
"""更新 Passkey 使用记录(签名计数器 + 最后使用时间)"""
passkey.sign_count = new_sign_count
passkey.last_used_at = datetime.utcnow()
self.session.add(passkey)
await self.session.commit()
async def delete_passkey(self, passkey_id: int, user_id: int) -> bool:
"""删除 Passkey"""
passkey = await self.get_passkey_by_id(passkey_id, user_id)
await self.session.delete(passkey)
await self.session.commit()
logger.info(f"Deleted passkey id={passkey_id} for user_id={user_id}")
return True
def to_list_model(self, passkey: Passkey) -> PasskeyList:
"""转换为安全的列表展示模型"""
return PasskeyList(
id=passkey.id,
name=passkey.name,
created_at=passkey.created_at,
last_used_at=passkey.last_used_at,
backup_eligible=passkey.backup_eligible,
aaguid=passkey.aaguid,
)

View File

@@ -1,5 +1,6 @@
from .bangumi import Bangumi, BangumiUpdate, Episode, Notification
from .config import Config
from .passkey import Passkey, PasskeyCreate, PasskeyDelete, PasskeyList
from .response import APIResponse, ResponseModel
from .rss import RSSItem, RSSUpdate
from .torrent import EpisodeFile, SubtitleFile, Torrent, TorrentUpdate

View File

@@ -0,0 +1,75 @@
"""
WebAuthn Passkey 数据模型
"""
from datetime import datetime
from typing import Optional
from pydantic import BaseModel
from sqlmodel import Field, SQLModel
class Passkey(SQLModel, table=True):
"""存储 WebAuthn 凭证的数据库模型"""
__tablename__ = "passkey"
id: int = Field(default=None, primary_key=True)
user_id: int = Field(foreign_key="user.id", index=True)
# 用户友好的名称 (e.g., "iPhone 15", "MacBook Pro")
name: str = Field(min_length=1, max_length=64)
# WebAuthn 核心字段
credential_id: str = Field(unique=True, index=True) # Base64URL encoded
public_key: str # CBOR encoded public key, Base64 stored
sign_count: int = Field(default=0) # 防止克隆攻击
# 可选的设备信息
aaguid: Optional[str] = None # Authenticator AAGUID
transports: Optional[str] = None # JSON array: ["usb", "nfc", "ble", "internal"]
# 审计字段
created_at: datetime = Field(default_factory=datetime.utcnow)
last_used_at: Optional[datetime] = None
# 备份状态 (是否为多设备凭证,如 iCloud Keychain)
backup_eligible: bool = Field(default=False)
backup_state: bool = Field(default=False)
class PasskeyCreate(BaseModel):
"""创建 Passkey 的请求模型"""
name: str = Field(min_length=1, max_length=64)
# 注册完成后的 WebAuthn 响应
attestation_response: dict
class PasskeyList(BaseModel):
"""返回给前端的 Passkey 列表(不含敏感数据)"""
id: int
name: str
created_at: datetime
last_used_at: Optional[datetime]
backup_eligible: bool
aaguid: Optional[str]
class PasskeyDelete(BaseModel):
"""删除 Passkey 请求"""
passkey_id: int
class PasskeyAuthStart(BaseModel):
"""Passkey 认证开始请求"""
username: str
class PasskeyAuthFinish(BaseModel):
"""Passkey 认证完成请求"""
username: str
credential: dict

View File

@@ -0,0 +1,115 @@
"""
认证策略抽象层
将密码认证和 Passkey 认证统一为策略模式
"""
import base64
from abc import ABC, abstractmethod
from module.database import Database
from module.models import ResponseModel
from module.models.user import User
class AuthStrategy(ABC):
"""认证策略基类"""
@abstractmethod
async def authenticate(self, username: str, credential: dict) -> ResponseModel:
"""
执行认证
Args:
username: 用户名
credential: 认证凭证(密码或 WebAuthn 响应)
Returns:
ResponseModel with status and user info
"""
pass
class PasswordAuthStrategy(AuthStrategy):
"""密码认证策略(保持向后兼容)"""
async def authenticate(self, username: str, credential: dict) -> ResponseModel:
"""使用密码认证"""
password = credential.get("password")
if not password:
return ResponseModel(
status_code=401,
status=False,
msg_en="Password is required",
msg_zh="密码不能为空",
)
user = User(username=username, password=password)
async with Database() as db:
return await db.user.auth_user(user)
class PasskeyAuthStrategy(AuthStrategy):
"""Passkey 认证策略"""
def __init__(self, webauthn_service):
self.webauthn_service = webauthn_service
async def authenticate(self, username: str, credential: dict) -> ResponseModel:
"""使用 WebAuthn Passkey 认证"""
async with Database() as db:
# 1. 查找用户
try:
user = await db.user.get_user(username)
except Exception:
return ResponseModel(
status_code=401,
status=False,
msg_en="User not found",
msg_zh="用户不存在",
)
# 2. 提取 credential_id 并查找对应的 passkey
try:
raw_id = credential.get("rawId")
if not raw_id:
raise ValueError("Missing credential ID")
# 将 rawId 从 base64url 转换为标准格式
credential_id_str = self.webauthn_service.base64url_encode(
self.webauthn_service.base64url_decode(raw_id)
)
passkey = await db.passkey.get_passkey_by_credential_id(credential_id_str)
if not passkey or passkey.user_id != user.id:
raise ValueError("Passkey not found or not owned by user")
except Exception as e:
return ResponseModel(
status_code=401,
status=False,
msg_en="Invalid passkey credential",
msg_zh="Passkey 凭证无效",
)
# 3. 验证 WebAuthn 签名
try:
new_sign_count = self.webauthn_service.verify_authentication(
username, credential, passkey
)
# 4. 更新使用记录
await db.passkey.update_passkey_usage(passkey, new_sign_count)
return ResponseModel(
status_code=200,
status=True,
msg_en="Login successfully with passkey",
msg_zh="通过 Passkey 登录成功",
)
except ValueError as e:
return ResponseModel(
status_code=401,
status=False,
msg_en=f"Passkey verification failed: {str(e)}",
msg_zh=f"Passkey 验证失败: {str(e)}",
)

View File

@@ -0,0 +1,277 @@
"""
WebAuthn 认证服务层
封装 py_webauthn 库的复杂性,提供清晰的注册和认证接口
"""
import base64
import json
import logging
from typing import List, Optional
from webauthn import (
generate_authentication_options,
generate_registration_options,
options_to_json,
verify_authentication_response,
verify_registration_response,
)
from webauthn.helpers.cose import COSEAlgorithmIdentifier
from webauthn.helpers.structs import (
AuthenticatorSelectionCriteria,
AuthenticatorTransport,
PublicKeyCredentialDescriptor,
PublicKeyCredentialType,
ResidentKeyRequirement,
UserVerificationRequirement,
)
from module.models.passkey import Passkey
logger = logging.getLogger(__name__)
class WebAuthnService:
"""WebAuthn 核心业务逻辑"""
def __init__(self, rp_id: str, rp_name: str, origin: str):
"""
Args:
rp_id: 依赖方 ID (e.g., "localhost" or "autobangumi.example.com")
rp_name: 依赖方名称 (e.g., "AutoBangumi")
origin: 前端 origin (e.g., "http://localhost:5173")
"""
self.rp_id = rp_id
self.rp_name = rp_name
self.origin = origin
# 存储临时的 challenge生产环境应使用 Redis
self._challenges: dict[str, bytes] = {}
# ============ 注册流程 ============
def generate_registration_options(
self, username: str, user_id: int, existing_passkeys: List[Passkey]
) -> dict:
"""
生成 WebAuthn 注册选项
Args:
username: 用户名
user_id: 用户 ID转为 bytes
existing_passkeys: 用户已有的 Passkey用于排除
Returns:
JSON-serializable registration options
"""
# 将已有凭证转为排除列表
exclude_credentials = [
PublicKeyCredentialDescriptor(
id=self.base64url_decode(pk.credential_id),
type=PublicKeyCredentialType.PUBLIC_KEY,
transports=self._parse_transports(pk.transports),
)
for pk in existing_passkeys
]
options = generate_registration_options(
rp_id=self.rp_id,
rp_name=self.rp_name,
user_id=str(user_id).encode("utf-8"),
user_name=username,
user_display_name=username,
exclude_credentials=exclude_credentials if exclude_credentials else None,
authenticator_selection=AuthenticatorSelectionCriteria(
resident_key=ResidentKeyRequirement.PREFERRED,
user_verification=UserVerificationRequirement.PREFERRED,
),
supported_pub_key_algs=[
COSEAlgorithmIdentifier.ECDSA_SHA_256, # -7: ES256
COSEAlgorithmIdentifier.RSASSA_PKCS1_v1_5_SHA_256, # -257: RS256
],
)
# 存储 challenge 用于后续验证
challenge_key = f"reg_{username}"
self._challenges[challenge_key] = options.challenge
logger.debug(f"Generated registration challenge for {username}")
return json.loads(options_to_json(options))
def verify_registration(
self, username: str, credential: dict, device_name: str
) -> Passkey:
"""
验证注册响应并创建 Passkey 对象
Args:
username: 用户名
credential: 来自前端的 credential 响应
device_name: 用户输入的设备名称
Returns:
Passkey 对象(未保存到数据库)
Raises:
ValueError: 验证失败
"""
challenge_key = f"reg_{username}"
expected_challenge = self._challenges.get(challenge_key)
if not expected_challenge:
raise ValueError("Challenge not found or expired")
try:
verification = verify_registration_response(
credential=credential,
expected_challenge=expected_challenge,
expected_rp_id=self.rp_id,
expected_origin=self.origin,
)
# 构造 Passkey 对象
passkey = Passkey(
user_id=0, # 调用方设置
name=device_name,
credential_id=self.base64url_encode(verification.credential_id),
public_key=base64.b64encode(verification.credential_public_key).decode(
"utf-8"
),
sign_count=verification.sign_count,
aaguid=verification.aaguid.hex() if verification.aaguid else None,
backup_eligible=verification.credential_backup_eligible,
backup_state=verification.credential_backed_up,
)
logger.info(
f"Successfully verified registration for {username}, device: {device_name}"
)
return passkey
except Exception as e:
logger.error(f"Registration verification failed: {e}")
raise ValueError(f"Invalid registration response: {str(e)}")
finally:
# 清理使用过的 challenge无论成功或失败都清理防止重放攻击
self._challenges.pop(challenge_key, None)
# ============ 认证流程 ============
def generate_authentication_options(
self, username: str, passkeys: List[Passkey]
) -> dict:
"""
生成 WebAuthn 认证选项
Args:
username: 用户名
passkeys: 用户的 Passkey 列表(限定可用凭证)
Returns:
JSON-serializable authentication options
"""
allow_credentials = [
PublicKeyCredentialDescriptor(
id=self.base64url_decode(pk.credential_id),
type=PublicKeyCredentialType.PUBLIC_KEY,
transports=self._parse_transports(pk.transports),
)
for pk in passkeys
]
options = generate_authentication_options(
rp_id=self.rp_id,
allow_credentials=allow_credentials if allow_credentials else None,
user_verification=UserVerificationRequirement.PREFERRED,
)
# 存储 challenge
challenge_key = f"auth_{username}"
self._challenges[challenge_key] = options.challenge
logger.debug(f"Generated authentication challenge for {username}")
return json.loads(options_to_json(options))
def verify_authentication(
self, username: str, credential: dict, passkey: Passkey
) -> int:
"""
验证认证响应
Args:
username: 用户名
credential: 来自前端的 credential 响应
passkey: 对应的 Passkey 对象
Returns:
新的 sign_count用于更新数据库
Raises:
ValueError: 验证失败
"""
challenge_key = f"auth_{username}"
expected_challenge = self._challenges.get(challenge_key)
if not expected_challenge:
raise ValueError("Challenge not found or expired")
try:
# 解码 public key
credential_public_key = base64.b64decode(passkey.public_key)
credential_id = self.base64url_decode(passkey.credential_id)
verification = verify_authentication_response(
credential=credential,
expected_challenge=expected_challenge,
expected_rp_id=self.rp_id,
expected_origin=self.origin,
credential_public_key=credential_public_key,
credential_current_sign_count=passkey.sign_count,
credential_id=credential_id,
)
logger.info(f"Successfully verified authentication for {username}")
return verification.new_sign_count
except Exception as e:
logger.error(f"Authentication verification failed: {e}")
raise ValueError(f"Invalid authentication response: {str(e)}")
finally:
# 清理 challenge无论成功或失败都清理防止重放攻击
self._challenges.pop(challenge_key, None)
# ============ 辅助方法 ============
def _parse_transports(
self, transports_json: Optional[str]
) -> List[AuthenticatorTransport]:
"""解析存储的 transports JSON"""
if not transports_json:
return []
try:
transport_strings = json.loads(transports_json)
return [AuthenticatorTransport(t) for t in transport_strings]
except Exception:
return []
def base64url_encode(self, data: bytes) -> str:
"""Base64URL 编码(无 padding"""
return base64.urlsafe_b64encode(data).decode("utf-8").rstrip("=")
def base64url_decode(self, data: str) -> bytes:
"""Base64URL 解码(补齐 padding"""
padding = 4 - len(data) % 4
if padding != 4:
data += "=" * padding
return base64.urlsafe_b64decode(data)
# 全局 WebAuthn 服务实例存储
_webauthn_services: dict[str, WebAuthnService] = {}
def get_webauthn_service(rp_id: str, rp_name: str, origin: str) -> WebAuthnService:
"""
获取或创建 WebAuthnService 实例
使用缓存以保持 challenge 状态
"""
key = f"{rp_id}:{origin}"
if key not in _webauthn_services:
_webauthn_services[key] = WebAuthnService(rp_id, rp_name, origin)
return _webauthn_services[key]