fix(backend): restore sync Database interface, isolate async for passkey

The previous refactoring broke backward compatibility by converting
Database from Session-extending sync class to a standalone async class.
This broke RSSEngine, startup code, and auth flows.

- Restore Database(Session) with sync interface for legacy code
- Restore UserDatabase to sync methods
- Restore security/api.py and auth.py to sync calls
- Passkey API now uses async_session_factory directly
- PasskeyAuthStrategy uses async sessions independently
- Remove unused db_session from engine.py

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
EstrellaXD
2026-01-23 15:58:26 +01:00
parent 29e4c16b40
commit cd233881bd
7 changed files with 157 additions and 179 deletions

View File

@@ -22,7 +22,7 @@ router = APIRouter(prefix="/auth", tags=["auth"])
@router.post("/login", response_model=dict)
async def login(response: Response, form_data=Depends(OAuth2PasswordRequestForm)):
user = User(username=form_data.username, password=form_data.password)
resp = await auth_user(user)
resp = auth_user(user)
if resp.status:
token = create_access_token(
data={"sub": user.username}, expires_delta=timedelta(days=1)
@@ -58,7 +58,7 @@ async def logout(response: Response):
@router.post("/update", response_model=dict, dependencies=[Depends(get_current_user)])
async def update_user(user_data: UserUpdate, response: Response):
old_user = active_user[0]
if await update_user_info(user_data, old_user):
if update_user_info(user_data, old_user):
token = create_access_token(
data={"sub": old_user}, expires_delta=timedelta(days=1)
)

View File

@@ -7,8 +7,10 @@ from datetime import timedelta
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import JSONResponse, Response
from sqlmodel import select
from module.database import Database
from module.database.engine import async_session_factory
from module.database.passkey import PasskeyDatabase
from module.models import APIResponse
from module.models.passkey import (
PasskeyAuthFinish,
@@ -17,6 +19,7 @@ from module.models.passkey import (
PasskeyDelete,
PasskeyList,
)
from module.models.user import User
from module.security.api import active_user, get_current_user
from module.security.auth_strategy import PasskeyAuthStrategy
from module.security.jwt import create_access_token
@@ -66,10 +69,19 @@ async def get_registration_options(
"""
webauthn = _get_webauthn_from_request(request)
async with Database() as db:
async with async_session_factory() as session:
try:
user = await db.user.get_user(username)
existing_passkeys = await db.passkey.get_passkeys_by_user_id(user.id)
# Get user
result = await session.execute(
select(User).where(User.username == username)
)
user = result.scalar_one_or_none()
if not user:
raise HTTPException(status_code=404, detail="User not found")
# Get existing passkeys
passkey_db = PasskeyDatabase(session)
existing_passkeys = await passkey_db.get_passkeys_by_user_id(user.id)
options = webauthn.generate_registration_options(
username=username,
@@ -79,6 +91,8 @@ async def get_registration_options(
return options
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to generate registration options: {e}")
raise HTTPException(status_code=500, detail=str(e))
@@ -95,9 +109,15 @@ async def verify_registration(
"""
webauthn = _get_webauthn_from_request(request)
async with Database() as db:
async with async_session_factory() as session:
try:
user = await db.user.get_user(username)
# Get user
result = await session.execute(
select(User).where(User.username == username)
)
user = result.scalar_one_or_none()
if not user:
raise HTTPException(status_code=404, detail="User not found")
# 验证 WebAuthn 响应
passkey = webauthn.verify_registration(
@@ -108,7 +128,8 @@ async def verify_registration(
# 设置 user_id 并保存
passkey.user_id = user.id
await db.passkey.create_passkey(passkey)
passkey_db = PasskeyDatabase(session)
await passkey_db.create_passkey(passkey)
return JSONResponse(
status_code=200,
@@ -121,6 +142,8 @@ async def verify_registration(
except ValueError as e:
logger.warning(f"Registration verification failed for {username}: {e}")
raise HTTPException(status_code=400, detail=str(e))
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to register passkey: {e}")
raise HTTPException(status_code=500, detail=str(e))
@@ -140,10 +163,18 @@ async def get_passkey_login_options(
"""
webauthn = _get_webauthn_from_request(request)
async with Database() as db:
async with async_session_factory() as session:
try:
user = await db.user.get_user(auth_data.username)
passkeys = await db.passkey.get_passkeys_by_user_id(user.id)
# Get user
result = await session.execute(
select(User).where(User.username == auth_data.username)
)
user = result.scalar_one_or_none()
if not user:
raise HTTPException(status_code=404, detail="User not found")
passkey_db = PasskeyDatabase(session)
passkeys = await passkey_db.get_passkeys_by_user_id(user.id)
if not passkeys:
raise HTTPException(
@@ -194,13 +225,23 @@ async def login_with_passkey(
@router.get("/list", response_model=list[PasskeyList])
async def list_passkeys(username: str = Depends(get_current_user)):
"""获取用户的所有 Passkey"""
async with Database() as db:
async with async_session_factory() as session:
try:
user = await db.user.get_user(username)
passkeys = await db.passkey.get_passkeys_by_user_id(user.id)
# Get user
result = await session.execute(
select(User).where(User.username == username)
)
user = result.scalar_one_or_none()
if not user:
raise HTTPException(status_code=404, detail="User not found")
return [db.passkey.to_list_model(pk) for pk in passkeys]
passkey_db = PasskeyDatabase(session)
passkeys = await passkey_db.get_passkeys_by_user_id(user.id)
return [passkey_db.to_list_model(pk) for pk in passkeys]
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to list passkeys: {e}")
raise HTTPException(status_code=500, detail=str(e))
@@ -212,10 +253,18 @@ async def delete_passkey(
username: str = Depends(get_current_user),
):
"""删除 Passkey"""
async with Database() as db:
async with async_session_factory() as session:
try:
user = await db.user.get_user(username)
await db.passkey.delete_passkey(delete_data.passkey_id, user.id)
# Get user
result = await session.execute(
select(User).where(User.username == username)
)
user = result.scalar_one_or_none()
if not user:
raise HTTPException(status_code=404, detail="User not found")
passkey_db = PasskeyDatabase(session)
await passkey_db.delete_passkey(delete_data.passkey_id, user.id)
return JSONResponse(
status_code=200,

View File

@@ -1,85 +1,43 @@
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import SQLModel
from sqlmodel import Session, SQLModel
from module.models import Bangumi, Passkey, User
from module.models import Bangumi, User
from module.models.passkey import Passkey
from .bangumi import BangumiDatabase
from .engine import async_engine, async_session_factory, engine as e
from .passkey import PasskeyDatabase
from .engine import engine as e
from .rss import RSSDatabase
from .torrent import TorrentDatabase
from .user import UserDatabase
class Database:
def __init__(self):
self._session = None
self.rss: RSSDatabase | None = None
self.torrent: TorrentDatabase | None = None
self.bangumi: BangumiDatabase | None = None
self.user: UserDatabase | None = None
self.passkey: PasskeyDatabase | None = None
class Database(Session):
def __init__(self, engine=e):
self.engine = engine
super().__init__(engine)
self.rss = RSSDatabase(self)
self.torrent = TorrentDatabase(self)
self.bangumi = BangumiDatabase(self)
self.user = UserDatabase(self)
# Sync context manager (for legacy code)
def __enter__(self):
from .engine import db_session
def create_table(self):
SQLModel.metadata.create_all(self.engine)
self._session = db_session
self.rss = RSSDatabase(self._session)
self.torrent = TorrentDatabase(self._session)
self.bangumi = BangumiDatabase(self._session)
self.user = UserDatabase(self._session)
return self
def drop_table(self):
SQLModel.metadata.drop_all(self.engine)
def __exit__(self, exc_type, exc_val, exc_tb):
pass
# Async context manager (for passkey and new async code)
async def __aenter__(self):
self._session = async_session_factory()
self.rss = RSSDatabase(self._session)
self.torrent = TorrentDatabase(self._session)
self.bangumi = BangumiDatabase(self._session)
self.user = UserDatabase(self._session)
self.passkey = PasskeyDatabase(self._session)
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
if self._session and isinstance(self._session, AsyncSession):
await self._session.close()
async def create_table(self):
async with async_engine.begin() as conn:
await conn.run_sync(SQLModel.metadata.create_all)
async def drop_table(self):
async with async_engine.begin() as conn:
await conn.run_sync(SQLModel.metadata.drop_all)
async def commit(self):
if self._session:
if isinstance(self._session, AsyncSession):
await self._session.commit()
else:
self._session.commit()
async def add(self, obj):
if self._session:
self._session.add(obj)
if isinstance(self._session, AsyncSession):
await self._session.commit()
else:
self._session.commit()
async def migrate(self):
def migrate(self):
# Run migration online
bangumi_data = await self.bangumi.search_all()
bangumi_data = self.bangumi.search_all()
user_data = self.exec("SELECT * FROM user").all()
readd_bangumi = []
for bangumi in bangumi_data:
dict_data = bangumi.dict()
del dict_data["id"]
readd_bangumi.append(Bangumi(**dict_data))
await self.drop_table()
await self.create_table()
await self.commit()
await self.bangumi.add_all(readd_bangumi)
self.drop_table()
self.create_table()
self.commit()
bangumi_data = self.bangumi.search_all()
self.bangumi.add_all(readd_bangumi)
self.add(User(**user_data[0]))
self.commit()

View File

@@ -1,14 +1,13 @@
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from sqlmodel import Session, create_engine
from sqlmodel import create_engine
from module.conf import DATA_PATH
# Sync engine (for legacy code)
# Sync engine (used by Database which extends Session)
engine = create_engine(DATA_PATH)
db_session = Session(engine)
# Async engine (for passkey and new async code)
# Async engine (for passkey operations)
ASYNC_DATA_PATH = DATA_PATH.replace("sqlite:///", "sqlite+aiosqlite:///")
async_engine = create_async_engine(ASYNC_DATA_PATH)
async_session_factory = sessionmaker(async_engine, class_=AsyncSession, expire_on_commit=False)

View File

@@ -1,8 +1,7 @@
import logging
from fastapi import HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import select
from sqlmodel import Session, select
from module.models import ResponseModel
from module.models.user import User, UserLogin, UserUpdate
@@ -12,36 +11,28 @@ logger = logging.getLogger(__name__)
class UserDatabase:
def __init__(self, session):
def __init__(self, session: Session):
self.session = session
async def get_user(self, username):
def get_user(self, username):
statement = select(User).where(User.username == username)
if isinstance(self.session, AsyncSession):
result = await self.session.execute(statement)
user = result.scalar_one_or_none()
else:
user = self.session.exec(statement).first()
if not user:
result = self.session.exec(statement).first()
if not result:
raise HTTPException(status_code=404, detail="User not found")
return user
return result
async def auth_user(self, user: User):
def auth_user(self, user: User):
statement = select(User).where(User.username == user.username)
if isinstance(self.session, AsyncSession):
result = await self.session.execute(statement)
db_user = result.scalar_one_or_none()
else:
db_user = self.session.exec(statement).first()
result = self.session.exec(statement).first()
if not user.password:
return ResponseModel(
status_code=401, status=False, msg_en="Incorrect password format", msg_zh="密码格式不正确"
)
if not db_user:
if not result:
return ResponseModel(
status_code=401, status=False, msg_en="User not found", msg_zh="用户不存在"
)
if not verify_password(user.password, db_user.password):
if not verify_password(user.password, result.password):
return ResponseModel(
status_code=401,
status=False,
@@ -52,59 +43,36 @@ class UserDatabase:
status_code=200, status=True, msg_en="Login successfully", msg_zh="登录成功"
)
async def update_user(self, username, update_user: UserUpdate):
def update_user(self, username, update_user: UserUpdate):
# Update username and password
statement = select(User).where(User.username == username)
if isinstance(self.session, AsyncSession):
result = await self.session.execute(statement)
db_user = result.scalar_one_or_none()
else:
db_user = self.session.exec(statement).first()
if not db_user:
result = self.session.exec(statement).first()
if not result:
raise HTTPException(status_code=404, detail="User not found")
if update_user.username:
db_user.username = update_user.username
result.username = update_user.username
if update_user.password:
db_user.password = get_password_hash(update_user.password)
self.session.add(db_user)
if isinstance(self.session, AsyncSession):
await self.session.commit()
else:
self.session.commit()
return db_user
async def add_default_user(self):
statement = select(User)
if isinstance(self.session, AsyncSession):
result = await self.session.execute(statement)
users = list(result.scalars().all())
else:
try:
users = self.session.exec(statement).all()
except Exception:
self.merge_old_user()
users = self.session.exec(statement).all()
if len(users) != 0:
return
user = User(username="admin", password=get_password_hash("adminadmin"))
self.session.add(user)
if isinstance(self.session, AsyncSession):
await self.session.commit()
else:
self.session.commit()
result.password = get_password_hash(update_user.password)
self.session.add(result)
self.session.commit()
return result
def merge_old_user(self):
# Legacy migration - sync only
# get old data
statement = """
SELECT * FROM user
"""
result = self.session.exec(statement).first()
if not result:
return
# add new data
user = User(username=result.username, password=result.password)
# Drop old table
statement = """
DROP TABLE user
"""
self.session.exec(statement)
# Create new table
statement = """
CREATE TABLE user (
id INTEGER NOT NULL PRIMARY KEY,
@@ -115,3 +83,18 @@ class UserDatabase:
self.session.exec(statement)
self.session.add(user)
self.session.commit()
def add_default_user(self):
# Check if user exists
statement = select(User)
try:
result = self.session.exec(statement).all()
except Exception:
self.merge_old_user()
result = self.session.exec(statement).all()
if len(result) != 0:
return
# Add default user
user = User(username="admin", password=get_password_hash("adminadmin"))
self.session.add(user)
self.session.commit()

View File

@@ -34,18 +34,18 @@ async def get_token_data(token: str = Depends(oauth2_scheme)):
return payload
async def update_user_info(user_data: UserUpdate, current_user):
def update_user_info(user_data: UserUpdate, current_user):
try:
async with Database() as db:
await db.user.update_user(current_user, user_data)
with Database() as db:
db.user.update_user(current_user, user_data)
return True
except Exception as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
async def auth_user(user: User):
async with Database() as db:
resp = await db.user.auth_user(user)
def auth_user(user: User):
with Database() as db:
resp = db.user.auth_user(user)
if resp.status:
active_user.append(user.username)
return resp

View File

@@ -2,10 +2,12 @@
认证策略抽象层
将密码认证和 Passkey 认证统一为策略模式
"""
import base64
from abc import ABC, abstractmethod
from module.database import Database
from sqlmodel import select
from module.database.engine import async_session_factory
from module.database.passkey import PasskeyDatabase
from module.models import ResponseModel
from module.models.user import User
@@ -28,25 +30,6 @@ class AuthStrategy(ABC):
pass
class PasswordAuthStrategy(AuthStrategy):
"""密码认证策略(保持向后兼容)"""
async def authenticate(self, username: str, credential: dict) -> ResponseModel:
"""使用密码认证"""
password = credential.get("password")
if not password:
return ResponseModel(
status_code=401,
status=False,
msg_en="Password is required",
msg_zh="密码不能为空",
)
user = User(username=username, password=password)
async with Database() as db:
return await db.user.auth_user(user)
class PasskeyAuthStrategy(AuthStrategy):
"""Passkey 认证策略"""
@@ -55,11 +38,16 @@ class PasskeyAuthStrategy(AuthStrategy):
async def authenticate(self, username: str, credential: dict) -> ResponseModel:
"""使用 WebAuthn Passkey 认证"""
async with Database() as db:
async with async_session_factory() as session:
# 1. 查找用户
try:
user = await db.user.get_user(username)
except Exception:
result = await session.execute(
select(User).where(User.username == username)
)
user = result.scalar_one_or_none()
if not user:
raise ValueError("User not found")
except ValueError:
return ResponseModel(
status_code=401,
status=False,
@@ -78,11 +66,12 @@ class PasskeyAuthStrategy(AuthStrategy):
self.webauthn_service.base64url_decode(raw_id)
)
passkey = await db.passkey.get_passkey_by_credential_id(credential_id_str)
passkey_db = PasskeyDatabase(session)
passkey = await passkey_db.get_passkey_by_credential_id(credential_id_str)
if not passkey or passkey.user_id != user.id:
raise ValueError("Passkey not found or not owned by user")
except Exception as e:
except Exception:
return ResponseModel(
status_code=401,
status=False,
@@ -97,7 +86,7 @@ class PasskeyAuthStrategy(AuthStrategy):
)
# 4. 更新使用记录
await db.passkey.update_passkey_usage(passkey, new_sign_count)
await passkey_db.update_passkey_usage(passkey, new_sign_count)
return ResponseModel(
status_code=200,