fix: resolve WebAuthn passkey compatibility with py_webauthn 2.7.0

- Fix aaguid type (str not bytes) in registration verification
- Fix missing credential_backup_eligible field (use credential_device_type)
- Remove invalid credential_id param from verify_authentication_response
- Fix origin detection to use browser Origin header for WebAuthn verification
- Add async database engine support (aiosqlite) for passkey operations
- Convert UserDatabase to async-compatible with sync/async session detection
- Update Database class to support both sync and async context managers

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
EstrellaXD
2026-01-23 15:07:18 +01:00
parent d2cfd9b150
commit 027222a24d
7 changed files with 119 additions and 71 deletions

View File

@@ -22,7 +22,7 @@ router = APIRouter(prefix="/auth", tags=["auth"])
@router.post("/login", response_model=dict)
async def login(response: Response, form_data=Depends(OAuth2PasswordRequestForm)):
user = User(username=form_data.username, password=form_data.password)
resp = auth_user(user)
resp = await auth_user(user)
if resp.status:
token = create_access_token(
data={"sub": user.username}, expires_delta=timedelta(days=1)
@@ -58,7 +58,7 @@ async def logout(response: Response):
@router.post("/update", response_model=dict, dependencies=[Depends(get_current_user)])
async def update_user(user_data: UserUpdate, response: Response):
old_user = active_user[0]
if update_user_info(user_data, old_user):
if await update_user_info(user_data, old_user):
token = create_access_token(
data={"sub": old_user}, expires_delta=timedelta(days=1)
)

View File

@@ -29,22 +29,25 @@ router = APIRouter(prefix="/passkey", tags=["passkey"])
def _get_webauthn_from_request(request: Request):
"""
从请求中构造 WebAuthnService
根据 Host header 动态确定 RP ID 和 origin
优先使用浏览器的 Origin header与 clientDataJSON 中的 origin 一致)
"""
host = request.headers.get("host", "localhost:7892")
rp_id = host.split(":")[0] # 去掉端口
from urllib.parse import urlparse
# 判断协议
forwarded_proto = request.headers.get("x-forwarded-proto")
if forwarded_proto:
scheme = forwarded_proto
else:
scheme = request.url.scheme
origin = request.headers.get("origin")
if not origin:
# Fallback: 从 Referer 或 Host 推断
referer = request.headers.get("referer", "")
if referer:
parsed = urlparse(referer)
origin = f"{parsed.scheme}://{parsed.netloc}"
else:
host = request.headers.get("host", "localhost:7892")
forwarded_proto = request.headers.get("x-forwarded-proto")
scheme = forwarded_proto if forwarded_proto else request.url.scheme
origin = f"{scheme}://{host}"
if scheme == "https":
origin = f"https://{host}"
else:
origin = f"http://{host}"
parsed_origin = urlparse(origin)
rp_id = parsed_origin.hostname or "localhost"
return get_webauthn_service(rp_id, "AutoBangumi", origin)

View File

@@ -4,7 +4,7 @@ from sqlmodel import SQLModel
from module.models import Bangumi, Passkey, User
from .bangumi import BangumiDatabase
from .engine import async_session_factory, engine as e
from .engine import async_engine, async_session_factory, engine as e
from .passkey import PasskeyDatabase
from .rss import RSSDatabase
from .torrent import TorrentDatabase
@@ -13,13 +13,28 @@ from .user import UserDatabase
class Database:
def __init__(self):
self._session: AsyncSession | None = None
self._session = None
self.rss: RSSDatabase | None = None
self.torrent: TorrentDatabase | None = None
self.bangumi: BangumiDatabase | None = None
self.user: UserDatabase | None = None
self.passkey: PasskeyDatabase | None = None
# Sync context manager (for legacy code)
def __enter__(self):
from .engine import db_session
self._session = db_session
self.rss = RSSDatabase(self._session)
self.torrent = TorrentDatabase(self._session)
self.bangumi = BangumiDatabase(self._session)
self.user = UserDatabase(self._session)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
pass
# Async context manager (for passkey and new async code)
async def __aenter__(self):
self._session = async_session_factory()
self.rss = RSSDatabase(self._session)
@@ -30,25 +45,31 @@ class Database:
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
if self._session:
if self._session and isinstance(self._session, AsyncSession):
await self._session.close()
async def create_table(self):
async with e.begin() as conn:
async with async_engine.begin() as conn:
await conn.run_sync(SQLModel.metadata.create_all)
async def drop_table(self):
async with e.begin() as conn:
async with async_engine.begin() as conn:
await conn.run_sync(SQLModel.metadata.drop_all)
async def commit(self):
if self._session:
await self._session.commit()
if isinstance(self._session, AsyncSession):
await self._session.commit()
else:
self._session.commit()
async def add(self, obj):
if self._session:
self._session.add(obj)
await self._session.commit()
if isinstance(self._session, AsyncSession):
await self._session.commit()
else:
self._session.commit()
async def migrate(self):
# Run migration online

View File

@@ -1,7 +1,14 @@
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from sqlmodel import Session, create_engine
from module.conf import DATA_PATH
# Sync engine (for legacy code)
engine = create_engine(DATA_PATH)
db_session = Session(engine)
# Async engine (for passkey and new async code)
ASYNC_DATA_PATH = DATA_PATH.replace("sqlite:///", "sqlite+aiosqlite:///")
async_engine = create_async_engine(ASYNC_DATA_PATH)
async_session_factory = sessionmaker(async_engine, class_=AsyncSession, expire_on_commit=False)

View File

@@ -1,7 +1,8 @@
import logging
from fastapi import HTTPException
from sqlmodel import Session, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import select
from module.models import ResponseModel
from module.models.user import User, UserLogin, UserUpdate
@@ -11,28 +12,36 @@ logger = logging.getLogger(__name__)
class UserDatabase:
def __init__(self, session: Session):
def __init__(self, session):
self.session = session
def get_user(self, username):
async def get_user(self, username):
statement = select(User).where(User.username == username)
result = self.session.exec(statement).first()
if not result:
if isinstance(self.session, AsyncSession):
result = await self.session.execute(statement)
user = result.scalar_one_or_none()
else:
user = self.session.exec(statement).first()
if not user:
raise HTTPException(status_code=404, detail="User not found")
return result
return user
def auth_user(self, user: User):
async def auth_user(self, user: User):
statement = select(User).where(User.username == user.username)
result = self.session.exec(statement).first()
if isinstance(self.session, AsyncSession):
result = await self.session.execute(statement)
db_user = result.scalar_one_or_none()
else:
db_user = self.session.exec(statement).first()
if not user.password:
return ResponseModel(
status_code=401, status=False, msg_en="Incorrect password format", msg_zh="密码格式不正确"
)
if not result:
if not db_user:
return ResponseModel(
status_code=401, status=False, msg_en="User not found", msg_zh="用户不存在"
)
if not verify_password(user.password, result.password):
if not verify_password(user.password, db_user.password):
return ResponseModel(
status_code=401,
status=False,
@@ -43,36 +52,59 @@ class UserDatabase:
status_code=200, status=True, msg_en="Login successfully", msg_zh="登录成功"
)
def update_user(self, username, update_user: UserUpdate):
# Update username and password
async def update_user(self, username, update_user: UserUpdate):
statement = select(User).where(User.username == username)
result = self.session.exec(statement).first()
if not result:
if isinstance(self.session, AsyncSession):
result = await self.session.execute(statement)
db_user = result.scalar_one_or_none()
else:
db_user = self.session.exec(statement).first()
if not db_user:
raise HTTPException(status_code=404, detail="User not found")
if update_user.username:
result.username = update_user.username
db_user.username = update_user.username
if update_user.password:
result.password = get_password_hash(update_user.password)
self.session.add(result)
self.session.commit()
return result
db_user.password = get_password_hash(update_user.password)
self.session.add(db_user)
if isinstance(self.session, AsyncSession):
await self.session.commit()
else:
self.session.commit()
return db_user
async def add_default_user(self):
statement = select(User)
if isinstance(self.session, AsyncSession):
result = await self.session.execute(statement)
users = list(result.scalars().all())
else:
try:
users = self.session.exec(statement).all()
except Exception:
self.merge_old_user()
users = self.session.exec(statement).all()
if len(users) != 0:
return
user = User(username="admin", password=get_password_hash("adminadmin"))
self.session.add(user)
if isinstance(self.session, AsyncSession):
await self.session.commit()
else:
self.session.commit()
def merge_old_user(self):
# get old data
# Legacy migration - sync only
statement = """
SELECT * FROM user
"""
result = self.session.exec(statement).first()
if not result:
return
# add new data
user = User(username=result.username, password=result.password)
# Drop old table
statement = """
DROP TABLE user
"""
self.session.exec(statement)
# Create new table
statement = """
CREATE TABLE user (
id INTEGER NOT NULL PRIMARY KEY,
@@ -83,18 +115,3 @@ class UserDatabase:
self.session.exec(statement)
self.session.add(user)
self.session.commit()
def add_default_user(self):
# Check if user exists
statement = select(User)
try:
result = self.session.exec(statement).all()
except Exception:
self.merge_old_user()
result = self.session.exec(statement).all()
if len(result) != 0:
return
# Add default user
user = User(username="admin", password=get_password_hash("adminadmin"))
self.session.add(user)
self.session.commit()

View File

@@ -34,18 +34,18 @@ async def get_token_data(token: str = Depends(oauth2_scheme)):
return payload
def update_user_info(user_data: UserUpdate, current_user):
async def update_user_info(user_data: UserUpdate, current_user):
try:
with Database() as db:
db.user.update_user(current_user, user_data)
async with Database() as db:
await db.user.update_user(current_user, user_data)
return True
except Exception as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
def auth_user(user: User):
with Database() as db:
resp = db.user.auth_user(user)
async def auth_user(user: User):
async with Database() as db:
resp = await db.user.auth_user(user)
if resp.status:
active_user.append(user.username)
return resp

View File

@@ -18,6 +18,7 @@ from webauthn.helpers.cose import COSEAlgorithmIdentifier
from webauthn.helpers.structs import (
AuthenticatorSelectionCriteria,
AuthenticatorTransport,
CredentialDeviceType,
PublicKeyCredentialDescriptor,
PublicKeyCredentialType,
ResidentKeyRequirement,
@@ -135,8 +136,9 @@ class WebAuthnService:
"utf-8"
),
sign_count=verification.sign_count,
aaguid=verification.aaguid.hex() if verification.aaguid else None,
backup_eligible=verification.credential_backup_eligible,
aaguid=verification.aaguid if verification.aaguid else None,
backup_eligible=verification.credential_device_type
== CredentialDeviceType.MULTI_DEVICE,
backup_state=verification.credential_backed_up,
)
@@ -214,7 +216,6 @@ class WebAuthnService:
try:
# 解码 public key
credential_public_key = base64.b64decode(passkey.public_key)
credential_id = self.base64url_decode(passkey.credential_id)
verification = verify_authentication_response(
credential=credential,
@@ -223,7 +224,6 @@ class WebAuthnService:
expected_origin=self.origin,
credential_public_key=credential_public_key,
credential_current_sign_count=passkey.sign_count,
credential_id=credential_id,
)
logger.info(f"Successfully verified authentication for {username}")