mirror of
https://github.com/EstrellaXD/Auto_Bangumi.git
synced 2026-04-13 17:19:56 +08:00
fix: resolve WebAuthn passkey compatibility with py_webauthn 2.7.0
- Fix aaguid type (str not bytes) in registration verification - Fix missing credential_backup_eligible field (use credential_device_type) - Remove invalid credential_id param from verify_authentication_response - Fix origin detection to use browser Origin header for WebAuthn verification - Add async database engine support (aiosqlite) for passkey operations - Convert UserDatabase to async-compatible with sync/async session detection - Update Database class to support both sync and async context managers Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -22,7 +22,7 @@ router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
@router.post("/login", response_model=dict)
|
||||
async def login(response: Response, form_data=Depends(OAuth2PasswordRequestForm)):
|
||||
user = User(username=form_data.username, password=form_data.password)
|
||||
resp = auth_user(user)
|
||||
resp = await auth_user(user)
|
||||
if resp.status:
|
||||
token = create_access_token(
|
||||
data={"sub": user.username}, expires_delta=timedelta(days=1)
|
||||
@@ -58,7 +58,7 @@ async def logout(response: Response):
|
||||
@router.post("/update", response_model=dict, dependencies=[Depends(get_current_user)])
|
||||
async def update_user(user_data: UserUpdate, response: Response):
|
||||
old_user = active_user[0]
|
||||
if update_user_info(user_data, old_user):
|
||||
if await update_user_info(user_data, old_user):
|
||||
token = create_access_token(
|
||||
data={"sub": old_user}, expires_delta=timedelta(days=1)
|
||||
)
|
||||
|
||||
@@ -29,22 +29,25 @@ router = APIRouter(prefix="/passkey", tags=["passkey"])
|
||||
def _get_webauthn_from_request(request: Request):
|
||||
"""
|
||||
从请求中构造 WebAuthnService
|
||||
根据 Host header 动态确定 RP ID 和 origin
|
||||
优先使用浏览器的 Origin header(与 clientDataJSON 中的 origin 一致)
|
||||
"""
|
||||
host = request.headers.get("host", "localhost:7892")
|
||||
rp_id = host.split(":")[0] # 去掉端口
|
||||
from urllib.parse import urlparse
|
||||
|
||||
# 判断协议
|
||||
forwarded_proto = request.headers.get("x-forwarded-proto")
|
||||
if forwarded_proto:
|
||||
scheme = forwarded_proto
|
||||
else:
|
||||
scheme = request.url.scheme
|
||||
origin = request.headers.get("origin")
|
||||
if not origin:
|
||||
# Fallback: 从 Referer 或 Host 推断
|
||||
referer = request.headers.get("referer", "")
|
||||
if referer:
|
||||
parsed = urlparse(referer)
|
||||
origin = f"{parsed.scheme}://{parsed.netloc}"
|
||||
else:
|
||||
host = request.headers.get("host", "localhost:7892")
|
||||
forwarded_proto = request.headers.get("x-forwarded-proto")
|
||||
scheme = forwarded_proto if forwarded_proto else request.url.scheme
|
||||
origin = f"{scheme}://{host}"
|
||||
|
||||
if scheme == "https":
|
||||
origin = f"https://{host}"
|
||||
else:
|
||||
origin = f"http://{host}"
|
||||
parsed_origin = urlparse(origin)
|
||||
rp_id = parsed_origin.hostname or "localhost"
|
||||
|
||||
return get_webauthn_service(rp_id, "AutoBangumi", origin)
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ from sqlmodel import SQLModel
|
||||
from module.models import Bangumi, Passkey, User
|
||||
|
||||
from .bangumi import BangumiDatabase
|
||||
from .engine import async_session_factory, engine as e
|
||||
from .engine import async_engine, async_session_factory, engine as e
|
||||
from .passkey import PasskeyDatabase
|
||||
from .rss import RSSDatabase
|
||||
from .torrent import TorrentDatabase
|
||||
@@ -13,13 +13,28 @@ from .user import UserDatabase
|
||||
|
||||
class Database:
|
||||
def __init__(self):
|
||||
self._session: AsyncSession | None = None
|
||||
self._session = None
|
||||
self.rss: RSSDatabase | None = None
|
||||
self.torrent: TorrentDatabase | None = None
|
||||
self.bangumi: BangumiDatabase | None = None
|
||||
self.user: UserDatabase | None = None
|
||||
self.passkey: PasskeyDatabase | None = None
|
||||
|
||||
# Sync context manager (for legacy code)
|
||||
def __enter__(self):
|
||||
from .engine import db_session
|
||||
|
||||
self._session = db_session
|
||||
self.rss = RSSDatabase(self._session)
|
||||
self.torrent = TorrentDatabase(self._session)
|
||||
self.bangumi = BangumiDatabase(self._session)
|
||||
self.user = UserDatabase(self._session)
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
pass
|
||||
|
||||
# Async context manager (for passkey and new async code)
|
||||
async def __aenter__(self):
|
||||
self._session = async_session_factory()
|
||||
self.rss = RSSDatabase(self._session)
|
||||
@@ -30,25 +45,31 @@ class Database:
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
if self._session:
|
||||
if self._session and isinstance(self._session, AsyncSession):
|
||||
await self._session.close()
|
||||
|
||||
async def create_table(self):
|
||||
async with e.begin() as conn:
|
||||
async with async_engine.begin() as conn:
|
||||
await conn.run_sync(SQLModel.metadata.create_all)
|
||||
|
||||
async def drop_table(self):
|
||||
async with e.begin() as conn:
|
||||
async with async_engine.begin() as conn:
|
||||
await conn.run_sync(SQLModel.metadata.drop_all)
|
||||
|
||||
async def commit(self):
|
||||
if self._session:
|
||||
await self._session.commit()
|
||||
if isinstance(self._session, AsyncSession):
|
||||
await self._session.commit()
|
||||
else:
|
||||
self._session.commit()
|
||||
|
||||
async def add(self, obj):
|
||||
if self._session:
|
||||
self._session.add(obj)
|
||||
await self._session.commit()
|
||||
if isinstance(self._session, AsyncSession):
|
||||
await self._session.commit()
|
||||
else:
|
||||
self._session.commit()
|
||||
|
||||
async def migrate(self):
|
||||
# Run migration online
|
||||
|
||||
@@ -1,7 +1,14 @@
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlmodel import Session, create_engine
|
||||
|
||||
from module.conf import DATA_PATH
|
||||
|
||||
# Sync engine (for legacy code)
|
||||
engine = create_engine(DATA_PATH)
|
||||
|
||||
db_session = Session(engine)
|
||||
|
||||
# Async engine (for passkey and new async code)
|
||||
ASYNC_DATA_PATH = DATA_PATH.replace("sqlite:///", "sqlite+aiosqlite:///")
|
||||
async_engine = create_async_engine(ASYNC_DATA_PATH)
|
||||
async_session_factory = sessionmaker(async_engine, class_=AsyncSession, expire_on_commit=False)
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import logging
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlmodel import Session, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlmodel import select
|
||||
|
||||
from module.models import ResponseModel
|
||||
from module.models.user import User, UserLogin, UserUpdate
|
||||
@@ -11,28 +12,36 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UserDatabase:
|
||||
def __init__(self, session: Session):
|
||||
def __init__(self, session):
|
||||
self.session = session
|
||||
|
||||
def get_user(self, username):
|
||||
async def get_user(self, username):
|
||||
statement = select(User).where(User.username == username)
|
||||
result = self.session.exec(statement).first()
|
||||
if not result:
|
||||
if isinstance(self.session, AsyncSession):
|
||||
result = await self.session.execute(statement)
|
||||
user = result.scalar_one_or_none()
|
||||
else:
|
||||
user = self.session.exec(statement).first()
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
return result
|
||||
return user
|
||||
|
||||
def auth_user(self, user: User):
|
||||
async def auth_user(self, user: User):
|
||||
statement = select(User).where(User.username == user.username)
|
||||
result = self.session.exec(statement).first()
|
||||
if isinstance(self.session, AsyncSession):
|
||||
result = await self.session.execute(statement)
|
||||
db_user = result.scalar_one_or_none()
|
||||
else:
|
||||
db_user = self.session.exec(statement).first()
|
||||
if not user.password:
|
||||
return ResponseModel(
|
||||
status_code=401, status=False, msg_en="Incorrect password format", msg_zh="密码格式不正确"
|
||||
)
|
||||
if not result:
|
||||
if not db_user:
|
||||
return ResponseModel(
|
||||
status_code=401, status=False, msg_en="User not found", msg_zh="用户不存在"
|
||||
)
|
||||
if not verify_password(user.password, result.password):
|
||||
if not verify_password(user.password, db_user.password):
|
||||
return ResponseModel(
|
||||
status_code=401,
|
||||
status=False,
|
||||
@@ -43,36 +52,59 @@ class UserDatabase:
|
||||
status_code=200, status=True, msg_en="Login successfully", msg_zh="登录成功"
|
||||
)
|
||||
|
||||
def update_user(self, username, update_user: UserUpdate):
|
||||
# Update username and password
|
||||
async def update_user(self, username, update_user: UserUpdate):
|
||||
statement = select(User).where(User.username == username)
|
||||
result = self.session.exec(statement).first()
|
||||
if not result:
|
||||
if isinstance(self.session, AsyncSession):
|
||||
result = await self.session.execute(statement)
|
||||
db_user = result.scalar_one_or_none()
|
||||
else:
|
||||
db_user = self.session.exec(statement).first()
|
||||
if not db_user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
if update_user.username:
|
||||
result.username = update_user.username
|
||||
db_user.username = update_user.username
|
||||
if update_user.password:
|
||||
result.password = get_password_hash(update_user.password)
|
||||
self.session.add(result)
|
||||
self.session.commit()
|
||||
return result
|
||||
db_user.password = get_password_hash(update_user.password)
|
||||
self.session.add(db_user)
|
||||
if isinstance(self.session, AsyncSession):
|
||||
await self.session.commit()
|
||||
else:
|
||||
self.session.commit()
|
||||
return db_user
|
||||
|
||||
async def add_default_user(self):
|
||||
statement = select(User)
|
||||
if isinstance(self.session, AsyncSession):
|
||||
result = await self.session.execute(statement)
|
||||
users = list(result.scalars().all())
|
||||
else:
|
||||
try:
|
||||
users = self.session.exec(statement).all()
|
||||
except Exception:
|
||||
self.merge_old_user()
|
||||
users = self.session.exec(statement).all()
|
||||
if len(users) != 0:
|
||||
return
|
||||
user = User(username="admin", password=get_password_hash("adminadmin"))
|
||||
self.session.add(user)
|
||||
if isinstance(self.session, AsyncSession):
|
||||
await self.session.commit()
|
||||
else:
|
||||
self.session.commit()
|
||||
|
||||
def merge_old_user(self):
|
||||
# get old data
|
||||
# Legacy migration - sync only
|
||||
statement = """
|
||||
SELECT * FROM user
|
||||
"""
|
||||
result = self.session.exec(statement).first()
|
||||
if not result:
|
||||
return
|
||||
# add new data
|
||||
user = User(username=result.username, password=result.password)
|
||||
# Drop old table
|
||||
statement = """
|
||||
DROP TABLE user
|
||||
"""
|
||||
self.session.exec(statement)
|
||||
# Create new table
|
||||
statement = """
|
||||
CREATE TABLE user (
|
||||
id INTEGER NOT NULL PRIMARY KEY,
|
||||
@@ -83,18 +115,3 @@ class UserDatabase:
|
||||
self.session.exec(statement)
|
||||
self.session.add(user)
|
||||
self.session.commit()
|
||||
|
||||
def add_default_user(self):
|
||||
# Check if user exists
|
||||
statement = select(User)
|
||||
try:
|
||||
result = self.session.exec(statement).all()
|
||||
except Exception:
|
||||
self.merge_old_user()
|
||||
result = self.session.exec(statement).all()
|
||||
if len(result) != 0:
|
||||
return
|
||||
# Add default user
|
||||
user = User(username="admin", password=get_password_hash("adminadmin"))
|
||||
self.session.add(user)
|
||||
self.session.commit()
|
||||
|
||||
@@ -34,18 +34,18 @@ async def get_token_data(token: str = Depends(oauth2_scheme)):
|
||||
return payload
|
||||
|
||||
|
||||
def update_user_info(user_data: UserUpdate, current_user):
|
||||
async def update_user_info(user_data: UserUpdate, current_user):
|
||||
try:
|
||||
with Database() as db:
|
||||
db.user.update_user(current_user, user_data)
|
||||
async with Database() as db:
|
||||
await db.user.update_user(current_user, user_data)
|
||||
return True
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||
|
||||
|
||||
def auth_user(user: User):
|
||||
with Database() as db:
|
||||
resp = db.user.auth_user(user)
|
||||
async def auth_user(user: User):
|
||||
async with Database() as db:
|
||||
resp = await db.user.auth_user(user)
|
||||
if resp.status:
|
||||
active_user.append(user.username)
|
||||
return resp
|
||||
|
||||
@@ -18,6 +18,7 @@ from webauthn.helpers.cose import COSEAlgorithmIdentifier
|
||||
from webauthn.helpers.structs import (
|
||||
AuthenticatorSelectionCriteria,
|
||||
AuthenticatorTransport,
|
||||
CredentialDeviceType,
|
||||
PublicKeyCredentialDescriptor,
|
||||
PublicKeyCredentialType,
|
||||
ResidentKeyRequirement,
|
||||
@@ -135,8 +136,9 @@ class WebAuthnService:
|
||||
"utf-8"
|
||||
),
|
||||
sign_count=verification.sign_count,
|
||||
aaguid=verification.aaguid.hex() if verification.aaguid else None,
|
||||
backup_eligible=verification.credential_backup_eligible,
|
||||
aaguid=verification.aaguid if verification.aaguid else None,
|
||||
backup_eligible=verification.credential_device_type
|
||||
== CredentialDeviceType.MULTI_DEVICE,
|
||||
backup_state=verification.credential_backed_up,
|
||||
)
|
||||
|
||||
@@ -214,7 +216,6 @@ class WebAuthnService:
|
||||
try:
|
||||
# 解码 public key
|
||||
credential_public_key = base64.b64decode(passkey.public_key)
|
||||
credential_id = self.base64url_decode(passkey.credential_id)
|
||||
|
||||
verification = verify_authentication_response(
|
||||
credential=credential,
|
||||
@@ -223,7 +224,6 @@ class WebAuthnService:
|
||||
expected_origin=self.origin,
|
||||
credential_public_key=credential_public_key,
|
||||
credential_current_sign_count=passkey.sign_count,
|
||||
credential_id=credential_id,
|
||||
)
|
||||
|
||||
logger.info(f"Successfully verified authentication for {username}")
|
||||
|
||||
Reference in New Issue
Block a user