From 027222a24d72772152daa576e4a90984c4dbfffa Mon Sep 17 00:00:00 2001 From: EstrellaXD Date: Fri, 23 Jan 2026 15:07:18 +0100 Subject: [PATCH] fix: resolve WebAuthn passkey compatibility with py_webauthn 2.7.0 - Fix aaguid type (str not bytes) in registration verification - Fix missing credential_backup_eligible field (use credential_device_type) - Remove invalid credential_id param from verify_authentication_response - Fix origin detection to use browser Origin header for WebAuthn verification - Add async database engine support (aiosqlite) for passkey operations - Convert UserDatabase to async-compatible with sync/async session detection - Update Database class to support both sync and async context managers Co-Authored-By: Claude Opus 4.5 --- backend/src/module/api/auth.py | 4 +- backend/src/module/api/passkey.py | 29 ++++---- backend/src/module/database/combine.py | 35 ++++++++-- backend/src/module/database/engine.py | 9 ++- backend/src/module/database/user.py | 93 +++++++++++++++---------- backend/src/module/security/api.py | 12 ++-- backend/src/module/security/webauthn.py | 8 +-- 7 files changed, 119 insertions(+), 71 deletions(-) diff --git a/backend/src/module/api/auth.py b/backend/src/module/api/auth.py index 8c6667c6..d3399e71 100644 --- a/backend/src/module/api/auth.py +++ b/backend/src/module/api/auth.py @@ -22,7 +22,7 @@ router = APIRouter(prefix="/auth", tags=["auth"]) @router.post("/login", response_model=dict) async def login(response: Response, form_data=Depends(OAuth2PasswordRequestForm)): user = User(username=form_data.username, password=form_data.password) - resp = auth_user(user) + resp = await auth_user(user) if resp.status: token = create_access_token( data={"sub": user.username}, expires_delta=timedelta(days=1) @@ -58,7 +58,7 @@ async def logout(response: Response): @router.post("/update", response_model=dict, dependencies=[Depends(get_current_user)]) async def update_user(user_data: UserUpdate, response: Response): old_user = active_user[0] - if update_user_info(user_data, old_user): + if await update_user_info(user_data, old_user): token = create_access_token( data={"sub": old_user}, expires_delta=timedelta(days=1) ) diff --git a/backend/src/module/api/passkey.py b/backend/src/module/api/passkey.py index 16e739d9..56b66ac4 100644 --- a/backend/src/module/api/passkey.py +++ b/backend/src/module/api/passkey.py @@ -29,22 +29,25 @@ router = APIRouter(prefix="/passkey", tags=["passkey"]) def _get_webauthn_from_request(request: Request): """ 从请求中构造 WebAuthnService - 根据 Host header 动态确定 RP ID 和 origin + 优先使用浏览器的 Origin header(与 clientDataJSON 中的 origin 一致) """ - host = request.headers.get("host", "localhost:7892") - rp_id = host.split(":")[0] # 去掉端口 + from urllib.parse import urlparse - # 判断协议 - forwarded_proto = request.headers.get("x-forwarded-proto") - if forwarded_proto: - scheme = forwarded_proto - else: - scheme = request.url.scheme + origin = request.headers.get("origin") + if not origin: + # Fallback: 从 Referer 或 Host 推断 + referer = request.headers.get("referer", "") + if referer: + parsed = urlparse(referer) + origin = f"{parsed.scheme}://{parsed.netloc}" + else: + host = request.headers.get("host", "localhost:7892") + forwarded_proto = request.headers.get("x-forwarded-proto") + scheme = forwarded_proto if forwarded_proto else request.url.scheme + origin = f"{scheme}://{host}" - if scheme == "https": - origin = f"https://{host}" - else: - origin = f"http://{host}" + parsed_origin = urlparse(origin) + rp_id = parsed_origin.hostname or "localhost" return get_webauthn_service(rp_id, "AutoBangumi", origin) diff --git a/backend/src/module/database/combine.py b/backend/src/module/database/combine.py index 4fc0fe60..7b06196b 100644 --- a/backend/src/module/database/combine.py +++ b/backend/src/module/database/combine.py @@ -4,7 +4,7 @@ from sqlmodel import SQLModel from module.models import Bangumi, Passkey, User from .bangumi import BangumiDatabase -from .engine import async_session_factory, engine as e +from .engine import async_engine, async_session_factory, engine as e from .passkey import PasskeyDatabase from .rss import RSSDatabase from .torrent import TorrentDatabase @@ -13,13 +13,28 @@ from .user import UserDatabase class Database: def __init__(self): - self._session: AsyncSession | None = None + self._session = None self.rss: RSSDatabase | None = None self.torrent: TorrentDatabase | None = None self.bangumi: BangumiDatabase | None = None self.user: UserDatabase | None = None self.passkey: PasskeyDatabase | None = None + # Sync context manager (for legacy code) + def __enter__(self): + from .engine import db_session + + self._session = db_session + self.rss = RSSDatabase(self._session) + self.torrent = TorrentDatabase(self._session) + self.bangumi = BangumiDatabase(self._session) + self.user = UserDatabase(self._session) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + pass + + # Async context manager (for passkey and new async code) async def __aenter__(self): self._session = async_session_factory() self.rss = RSSDatabase(self._session) @@ -30,25 +45,31 @@ class Database: return self async def __aexit__(self, exc_type, exc_val, exc_tb): - if self._session: + if self._session and isinstance(self._session, AsyncSession): await self._session.close() async def create_table(self): - async with e.begin() as conn: + async with async_engine.begin() as conn: await conn.run_sync(SQLModel.metadata.create_all) async def drop_table(self): - async with e.begin() as conn: + async with async_engine.begin() as conn: await conn.run_sync(SQLModel.metadata.drop_all) async def commit(self): if self._session: - await self._session.commit() + if isinstance(self._session, AsyncSession): + await self._session.commit() + else: + self._session.commit() async def add(self, obj): if self._session: self._session.add(obj) - await self._session.commit() + if isinstance(self._session, AsyncSession): + await self._session.commit() + else: + self._session.commit() async def migrate(self): # Run migration online diff --git a/backend/src/module/database/engine.py b/backend/src/module/database/engine.py index c94e792c..adf98eaa 100644 --- a/backend/src/module/database/engine.py +++ b/backend/src/module/database/engine.py @@ -1,7 +1,14 @@ +from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine +from sqlalchemy.orm import sessionmaker from sqlmodel import Session, create_engine from module.conf import DATA_PATH +# Sync engine (for legacy code) engine = create_engine(DATA_PATH) - db_session = Session(engine) + +# Async engine (for passkey and new async code) +ASYNC_DATA_PATH = DATA_PATH.replace("sqlite:///", "sqlite+aiosqlite:///") +async_engine = create_async_engine(ASYNC_DATA_PATH) +async_session_factory = sessionmaker(async_engine, class_=AsyncSession, expire_on_commit=False) diff --git a/backend/src/module/database/user.py b/backend/src/module/database/user.py index 336f81f4..ac303e23 100644 --- a/backend/src/module/database/user.py +++ b/backend/src/module/database/user.py @@ -1,7 +1,8 @@ import logging from fastapi import HTTPException -from sqlmodel import Session, select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlmodel import select from module.models import ResponseModel from module.models.user import User, UserLogin, UserUpdate @@ -11,28 +12,36 @@ logger = logging.getLogger(__name__) class UserDatabase: - def __init__(self, session: Session): + def __init__(self, session): self.session = session - def get_user(self, username): + async def get_user(self, username): statement = select(User).where(User.username == username) - result = self.session.exec(statement).first() - if not result: + if isinstance(self.session, AsyncSession): + result = await self.session.execute(statement) + user = result.scalar_one_or_none() + else: + user = self.session.exec(statement).first() + if not user: raise HTTPException(status_code=404, detail="User not found") - return result + return user - def auth_user(self, user: User): + async def auth_user(self, user: User): statement = select(User).where(User.username == user.username) - result = self.session.exec(statement).first() + if isinstance(self.session, AsyncSession): + result = await self.session.execute(statement) + db_user = result.scalar_one_or_none() + else: + db_user = self.session.exec(statement).first() if not user.password: return ResponseModel( status_code=401, status=False, msg_en="Incorrect password format", msg_zh="密码格式不正确" ) - if not result: + if not db_user: return ResponseModel( status_code=401, status=False, msg_en="User not found", msg_zh="用户不存在" ) - if not verify_password(user.password, result.password): + if not verify_password(user.password, db_user.password): return ResponseModel( status_code=401, status=False, @@ -43,36 +52,59 @@ class UserDatabase: status_code=200, status=True, msg_en="Login successfully", msg_zh="登录成功" ) - def update_user(self, username, update_user: UserUpdate): - # Update username and password + async def update_user(self, username, update_user: UserUpdate): statement = select(User).where(User.username == username) - result = self.session.exec(statement).first() - if not result: + if isinstance(self.session, AsyncSession): + result = await self.session.execute(statement) + db_user = result.scalar_one_or_none() + else: + db_user = self.session.exec(statement).first() + if not db_user: raise HTTPException(status_code=404, detail="User not found") if update_user.username: - result.username = update_user.username + db_user.username = update_user.username if update_user.password: - result.password = get_password_hash(update_user.password) - self.session.add(result) - self.session.commit() - return result + db_user.password = get_password_hash(update_user.password) + self.session.add(db_user) + if isinstance(self.session, AsyncSession): + await self.session.commit() + else: + self.session.commit() + return db_user + + async def add_default_user(self): + statement = select(User) + if isinstance(self.session, AsyncSession): + result = await self.session.execute(statement) + users = list(result.scalars().all()) + else: + try: + users = self.session.exec(statement).all() + except Exception: + self.merge_old_user() + users = self.session.exec(statement).all() + if len(users) != 0: + return + user = User(username="admin", password=get_password_hash("adminadmin")) + self.session.add(user) + if isinstance(self.session, AsyncSession): + await self.session.commit() + else: + self.session.commit() def merge_old_user(self): - # get old data + # Legacy migration - sync only statement = """ SELECT * FROM user """ result = self.session.exec(statement).first() if not result: return - # add new data user = User(username=result.username, password=result.password) - # Drop old table statement = """ DROP TABLE user """ self.session.exec(statement) - # Create new table statement = """ CREATE TABLE user ( id INTEGER NOT NULL PRIMARY KEY, @@ -83,18 +115,3 @@ class UserDatabase: self.session.exec(statement) self.session.add(user) self.session.commit() - - def add_default_user(self): - # Check if user exists - statement = select(User) - try: - result = self.session.exec(statement).all() - except Exception: - self.merge_old_user() - result = self.session.exec(statement).all() - if len(result) != 0: - return - # Add default user - user = User(username="admin", password=get_password_hash("adminadmin")) - self.session.add(user) - self.session.commit() diff --git a/backend/src/module/security/api.py b/backend/src/module/security/api.py index e0150e68..4a1ba47d 100644 --- a/backend/src/module/security/api.py +++ b/backend/src/module/security/api.py @@ -34,18 +34,18 @@ async def get_token_data(token: str = Depends(oauth2_scheme)): return payload -def update_user_info(user_data: UserUpdate, current_user): +async def update_user_info(user_data: UserUpdate, current_user): try: - with Database() as db: - db.user.update_user(current_user, user_data) + async with Database() as db: + await db.user.update_user(current_user, user_data) return True except Exception as e: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) -def auth_user(user: User): - with Database() as db: - resp = db.user.auth_user(user) +async def auth_user(user: User): + async with Database() as db: + resp = await db.user.auth_user(user) if resp.status: active_user.append(user.username) return resp diff --git a/backend/src/module/security/webauthn.py b/backend/src/module/security/webauthn.py index f3370ee8..c7228fc1 100644 --- a/backend/src/module/security/webauthn.py +++ b/backend/src/module/security/webauthn.py @@ -18,6 +18,7 @@ from webauthn.helpers.cose import COSEAlgorithmIdentifier from webauthn.helpers.structs import ( AuthenticatorSelectionCriteria, AuthenticatorTransport, + CredentialDeviceType, PublicKeyCredentialDescriptor, PublicKeyCredentialType, ResidentKeyRequirement, @@ -135,8 +136,9 @@ class WebAuthnService: "utf-8" ), sign_count=verification.sign_count, - aaguid=verification.aaguid.hex() if verification.aaguid else None, - backup_eligible=verification.credential_backup_eligible, + aaguid=verification.aaguid if verification.aaguid else None, + backup_eligible=verification.credential_device_type + == CredentialDeviceType.MULTI_DEVICE, backup_state=verification.credential_backed_up, ) @@ -214,7 +216,6 @@ class WebAuthnService: try: # 解码 public key credential_public_key = base64.b64decode(passkey.public_key) - credential_id = self.base64url_decode(passkey.credential_id) verification = verify_authentication_response( credential=credential, @@ -223,7 +224,6 @@ class WebAuthnService: expected_origin=self.origin, credential_public_key=credential_public_key, credential_current_sign_count=passkey.sign_count, - credential_id=credential_id, ) logger.info(f"Successfully verified authentication for {username}")