From cd233881bdd8dbe9055fdeddaade8f100c21ac7d Mon Sep 17 00:00:00 2001 From: EstrellaXD Date: Fri, 23 Jan 2026 15:58:26 +0100 Subject: [PATCH] fix(backend): restore sync Database interface, isolate async for passkey The previous refactoring broke backward compatibility by converting Database from Session-extending sync class to a standalone async class. This broke RSSEngine, startup code, and auth flows. - Restore Database(Session) with sync interface for legacy code - Restore UserDatabase to sync methods - Restore security/api.py and auth.py to sync calls - Passkey API now uses async_session_factory directly - PasskeyAuthStrategy uses async sessions independently - Remove unused db_session from engine.py Co-Authored-By: Claude Opus 4.5 --- backend/src/module/api/auth.py | 4 +- backend/src/module/api/passkey.py | 83 +++++++++++++---- backend/src/module/database/combine.py | 94 ++++++-------------- backend/src/module/database/engine.py | 7 +- backend/src/module/database/user.py | 93 ++++++++----------- backend/src/module/security/api.py | 12 +-- backend/src/module/security/auth_strategy.py | 43 ++++----- 7 files changed, 157 insertions(+), 179 deletions(-) diff --git a/backend/src/module/api/auth.py b/backend/src/module/api/auth.py index d3399e71..8c6667c6 100644 --- a/backend/src/module/api/auth.py +++ b/backend/src/module/api/auth.py @@ -22,7 +22,7 @@ router = APIRouter(prefix="/auth", tags=["auth"]) @router.post("/login", response_model=dict) async def login(response: Response, form_data=Depends(OAuth2PasswordRequestForm)): user = User(username=form_data.username, password=form_data.password) - resp = await auth_user(user) + resp = auth_user(user) if resp.status: token = create_access_token( data={"sub": user.username}, expires_delta=timedelta(days=1) @@ -58,7 +58,7 @@ async def logout(response: Response): @router.post("/update", response_model=dict, dependencies=[Depends(get_current_user)]) async def update_user(user_data: UserUpdate, response: Response): old_user = active_user[0] - if await update_user_info(user_data, old_user): + if update_user_info(user_data, old_user): token = create_access_token( data={"sub": old_user}, expires_delta=timedelta(days=1) ) diff --git a/backend/src/module/api/passkey.py b/backend/src/module/api/passkey.py index 56b66ac4..0275a9a2 100644 --- a/backend/src/module/api/passkey.py +++ b/backend/src/module/api/passkey.py @@ -7,8 +7,10 @@ from datetime import timedelta from fastapi import APIRouter, Depends, HTTPException, Request from fastapi.responses import JSONResponse, Response +from sqlmodel import select -from module.database import Database +from module.database.engine import async_session_factory +from module.database.passkey import PasskeyDatabase from module.models import APIResponse from module.models.passkey import ( PasskeyAuthFinish, @@ -17,6 +19,7 @@ from module.models.passkey import ( PasskeyDelete, PasskeyList, ) +from module.models.user import User from module.security.api import active_user, get_current_user from module.security.auth_strategy import PasskeyAuthStrategy from module.security.jwt import create_access_token @@ -66,10 +69,19 @@ async def get_registration_options( """ webauthn = _get_webauthn_from_request(request) - async with Database() as db: + async with async_session_factory() as session: try: - user = await db.user.get_user(username) - existing_passkeys = await db.passkey.get_passkeys_by_user_id(user.id) + # Get user + result = await session.execute( + select(User).where(User.username == username) + ) + user = result.scalar_one_or_none() + if not user: + raise HTTPException(status_code=404, detail="User not found") + + # Get existing passkeys + passkey_db = PasskeyDatabase(session) + existing_passkeys = await passkey_db.get_passkeys_by_user_id(user.id) options = webauthn.generate_registration_options( username=username, @@ -79,6 +91,8 @@ async def get_registration_options( return options + except HTTPException: + raise except Exception as e: logger.error(f"Failed to generate registration options: {e}") raise HTTPException(status_code=500, detail=str(e)) @@ -95,9 +109,15 @@ async def verify_registration( """ webauthn = _get_webauthn_from_request(request) - async with Database() as db: + async with async_session_factory() as session: try: - user = await db.user.get_user(username) + # Get user + result = await session.execute( + select(User).where(User.username == username) + ) + user = result.scalar_one_or_none() + if not user: + raise HTTPException(status_code=404, detail="User not found") # 验证 WebAuthn 响应 passkey = webauthn.verify_registration( @@ -108,7 +128,8 @@ async def verify_registration( # 设置 user_id 并保存 passkey.user_id = user.id - await db.passkey.create_passkey(passkey) + passkey_db = PasskeyDatabase(session) + await passkey_db.create_passkey(passkey) return JSONResponse( status_code=200, @@ -121,6 +142,8 @@ async def verify_registration( except ValueError as e: logger.warning(f"Registration verification failed for {username}: {e}") raise HTTPException(status_code=400, detail=str(e)) + except HTTPException: + raise except Exception as e: logger.error(f"Failed to register passkey: {e}") raise HTTPException(status_code=500, detail=str(e)) @@ -140,10 +163,18 @@ async def get_passkey_login_options( """ webauthn = _get_webauthn_from_request(request) - async with Database() as db: + async with async_session_factory() as session: try: - user = await db.user.get_user(auth_data.username) - passkeys = await db.passkey.get_passkeys_by_user_id(user.id) + # Get user + result = await session.execute( + select(User).where(User.username == auth_data.username) + ) + user = result.scalar_one_or_none() + if not user: + raise HTTPException(status_code=404, detail="User not found") + + passkey_db = PasskeyDatabase(session) + passkeys = await passkey_db.get_passkeys_by_user_id(user.id) if not passkeys: raise HTTPException( @@ -194,13 +225,23 @@ async def login_with_passkey( @router.get("/list", response_model=list[PasskeyList]) async def list_passkeys(username: str = Depends(get_current_user)): """获取用户的所有 Passkey""" - async with Database() as db: + async with async_session_factory() as session: try: - user = await db.user.get_user(username) - passkeys = await db.passkey.get_passkeys_by_user_id(user.id) + # Get user + result = await session.execute( + select(User).where(User.username == username) + ) + user = result.scalar_one_or_none() + if not user: + raise HTTPException(status_code=404, detail="User not found") - return [db.passkey.to_list_model(pk) for pk in passkeys] + passkey_db = PasskeyDatabase(session) + passkeys = await passkey_db.get_passkeys_by_user_id(user.id) + return [passkey_db.to_list_model(pk) for pk in passkeys] + + except HTTPException: + raise except Exception as e: logger.error(f"Failed to list passkeys: {e}") raise HTTPException(status_code=500, detail=str(e)) @@ -212,10 +253,18 @@ async def delete_passkey( username: str = Depends(get_current_user), ): """删除 Passkey""" - async with Database() as db: + async with async_session_factory() as session: try: - user = await db.user.get_user(username) - await db.passkey.delete_passkey(delete_data.passkey_id, user.id) + # Get user + result = await session.execute( + select(User).where(User.username == username) + ) + user = result.scalar_one_or_none() + if not user: + raise HTTPException(status_code=404, detail="User not found") + + passkey_db = PasskeyDatabase(session) + await passkey_db.delete_passkey(delete_data.passkey_id, user.id) return JSONResponse( status_code=200, diff --git a/backend/src/module/database/combine.py b/backend/src/module/database/combine.py index 7b06196b..57979322 100644 --- a/backend/src/module/database/combine.py +++ b/backend/src/module/database/combine.py @@ -1,85 +1,43 @@ -from sqlalchemy.ext.asyncio import AsyncSession -from sqlmodel import SQLModel +from sqlmodel import Session, SQLModel -from module.models import Bangumi, Passkey, User +from module.models import Bangumi, User +from module.models.passkey import Passkey from .bangumi import BangumiDatabase -from .engine import async_engine, async_session_factory, engine as e -from .passkey import PasskeyDatabase +from .engine import engine as e from .rss import RSSDatabase from .torrent import TorrentDatabase from .user import UserDatabase -class Database: - def __init__(self): - self._session = None - self.rss: RSSDatabase | None = None - self.torrent: TorrentDatabase | None = None - self.bangumi: BangumiDatabase | None = None - self.user: UserDatabase | None = None - self.passkey: PasskeyDatabase | None = None +class Database(Session): + def __init__(self, engine=e): + self.engine = engine + super().__init__(engine) + self.rss = RSSDatabase(self) + self.torrent = TorrentDatabase(self) + self.bangumi = BangumiDatabase(self) + self.user = UserDatabase(self) - # Sync context manager (for legacy code) - def __enter__(self): - from .engine import db_session + def create_table(self): + SQLModel.metadata.create_all(self.engine) - self._session = db_session - self.rss = RSSDatabase(self._session) - self.torrent = TorrentDatabase(self._session) - self.bangumi = BangumiDatabase(self._session) - self.user = UserDatabase(self._session) - return self + def drop_table(self): + SQLModel.metadata.drop_all(self.engine) - def __exit__(self, exc_type, exc_val, exc_tb): - pass - - # Async context manager (for passkey and new async code) - async def __aenter__(self): - self._session = async_session_factory() - self.rss = RSSDatabase(self._session) - self.torrent = TorrentDatabase(self._session) - self.bangumi = BangumiDatabase(self._session) - self.user = UserDatabase(self._session) - self.passkey = PasskeyDatabase(self._session) - return self - - async def __aexit__(self, exc_type, exc_val, exc_tb): - if self._session and isinstance(self._session, AsyncSession): - await self._session.close() - - async def create_table(self): - async with async_engine.begin() as conn: - await conn.run_sync(SQLModel.metadata.create_all) - - async def drop_table(self): - async with async_engine.begin() as conn: - await conn.run_sync(SQLModel.metadata.drop_all) - - async def commit(self): - if self._session: - if isinstance(self._session, AsyncSession): - await self._session.commit() - else: - self._session.commit() - - async def add(self, obj): - if self._session: - self._session.add(obj) - if isinstance(self._session, AsyncSession): - await self._session.commit() - else: - self._session.commit() - - async def migrate(self): + def migrate(self): # Run migration online - bangumi_data = await self.bangumi.search_all() + bangumi_data = self.bangumi.search_all() + user_data = self.exec("SELECT * FROM user").all() readd_bangumi = [] for bangumi in bangumi_data: dict_data = bangumi.dict() del dict_data["id"] readd_bangumi.append(Bangumi(**dict_data)) - await self.drop_table() - await self.create_table() - await self.commit() - await self.bangumi.add_all(readd_bangumi) + self.drop_table() + self.create_table() + self.commit() + bangumi_data = self.bangumi.search_all() + self.bangumi.add_all(readd_bangumi) + self.add(User(**user_data[0])) + self.commit() diff --git a/backend/src/module/database/engine.py b/backend/src/module/database/engine.py index adf98eaa..5bdb6f64 100644 --- a/backend/src/module/database/engine.py +++ b/backend/src/module/database/engine.py @@ -1,14 +1,13 @@ from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine from sqlalchemy.orm import sessionmaker -from sqlmodel import Session, create_engine +from sqlmodel import create_engine from module.conf import DATA_PATH -# Sync engine (for legacy code) +# Sync engine (used by Database which extends Session) engine = create_engine(DATA_PATH) -db_session = Session(engine) -# Async engine (for passkey and new async code) +# Async engine (for passkey operations) ASYNC_DATA_PATH = DATA_PATH.replace("sqlite:///", "sqlite+aiosqlite:///") async_engine = create_async_engine(ASYNC_DATA_PATH) async_session_factory = sessionmaker(async_engine, class_=AsyncSession, expire_on_commit=False) diff --git a/backend/src/module/database/user.py b/backend/src/module/database/user.py index ac303e23..336f81f4 100644 --- a/backend/src/module/database/user.py +++ b/backend/src/module/database/user.py @@ -1,8 +1,7 @@ import logging from fastapi import HTTPException -from sqlalchemy.ext.asyncio import AsyncSession -from sqlmodel import select +from sqlmodel import Session, select from module.models import ResponseModel from module.models.user import User, UserLogin, UserUpdate @@ -12,36 +11,28 @@ logger = logging.getLogger(__name__) class UserDatabase: - def __init__(self, session): + def __init__(self, session: Session): self.session = session - async def get_user(self, username): + def get_user(self, username): statement = select(User).where(User.username == username) - if isinstance(self.session, AsyncSession): - result = await self.session.execute(statement) - user = result.scalar_one_or_none() - else: - user = self.session.exec(statement).first() - if not user: + result = self.session.exec(statement).first() + if not result: raise HTTPException(status_code=404, detail="User not found") - return user + return result - async def auth_user(self, user: User): + def auth_user(self, user: User): statement = select(User).where(User.username == user.username) - if isinstance(self.session, AsyncSession): - result = await self.session.execute(statement) - db_user = result.scalar_one_or_none() - else: - db_user = self.session.exec(statement).first() + result = self.session.exec(statement).first() if not user.password: return ResponseModel( status_code=401, status=False, msg_en="Incorrect password format", msg_zh="密码格式不正确" ) - if not db_user: + if not result: return ResponseModel( status_code=401, status=False, msg_en="User not found", msg_zh="用户不存在" ) - if not verify_password(user.password, db_user.password): + if not verify_password(user.password, result.password): return ResponseModel( status_code=401, status=False, @@ -52,59 +43,36 @@ class UserDatabase: status_code=200, status=True, msg_en="Login successfully", msg_zh="登录成功" ) - async def update_user(self, username, update_user: UserUpdate): + def update_user(self, username, update_user: UserUpdate): + # Update username and password statement = select(User).where(User.username == username) - if isinstance(self.session, AsyncSession): - result = await self.session.execute(statement) - db_user = result.scalar_one_or_none() - else: - db_user = self.session.exec(statement).first() - if not db_user: + result = self.session.exec(statement).first() + if not result: raise HTTPException(status_code=404, detail="User not found") if update_user.username: - db_user.username = update_user.username + result.username = update_user.username if update_user.password: - db_user.password = get_password_hash(update_user.password) - self.session.add(db_user) - if isinstance(self.session, AsyncSession): - await self.session.commit() - else: - self.session.commit() - return db_user - - async def add_default_user(self): - statement = select(User) - if isinstance(self.session, AsyncSession): - result = await self.session.execute(statement) - users = list(result.scalars().all()) - else: - try: - users = self.session.exec(statement).all() - except Exception: - self.merge_old_user() - users = self.session.exec(statement).all() - if len(users) != 0: - return - user = User(username="admin", password=get_password_hash("adminadmin")) - self.session.add(user) - if isinstance(self.session, AsyncSession): - await self.session.commit() - else: - self.session.commit() + result.password = get_password_hash(update_user.password) + self.session.add(result) + self.session.commit() + return result def merge_old_user(self): - # Legacy migration - sync only + # get old data statement = """ SELECT * FROM user """ result = self.session.exec(statement).first() if not result: return + # add new data user = User(username=result.username, password=result.password) + # Drop old table statement = """ DROP TABLE user """ self.session.exec(statement) + # Create new table statement = """ CREATE TABLE user ( id INTEGER NOT NULL PRIMARY KEY, @@ -115,3 +83,18 @@ class UserDatabase: self.session.exec(statement) self.session.add(user) self.session.commit() + + def add_default_user(self): + # Check if user exists + statement = select(User) + try: + result = self.session.exec(statement).all() + except Exception: + self.merge_old_user() + result = self.session.exec(statement).all() + if len(result) != 0: + return + # Add default user + user = User(username="admin", password=get_password_hash("adminadmin")) + self.session.add(user) + self.session.commit() diff --git a/backend/src/module/security/api.py b/backend/src/module/security/api.py index 4a1ba47d..e0150e68 100644 --- a/backend/src/module/security/api.py +++ b/backend/src/module/security/api.py @@ -34,18 +34,18 @@ async def get_token_data(token: str = Depends(oauth2_scheme)): return payload -async def update_user_info(user_data: UserUpdate, current_user): +def update_user_info(user_data: UserUpdate, current_user): try: - async with Database() as db: - await db.user.update_user(current_user, user_data) + with Database() as db: + db.user.update_user(current_user, user_data) return True except Exception as e: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) -async def auth_user(user: User): - async with Database() as db: - resp = await db.user.auth_user(user) +def auth_user(user: User): + with Database() as db: + resp = db.user.auth_user(user) if resp.status: active_user.append(user.username) return resp diff --git a/backend/src/module/security/auth_strategy.py b/backend/src/module/security/auth_strategy.py index e1f25d78..4b17a24c 100644 --- a/backend/src/module/security/auth_strategy.py +++ b/backend/src/module/security/auth_strategy.py @@ -2,10 +2,12 @@ 认证策略抽象层 将密码认证和 Passkey 认证统一为策略模式 """ -import base64 from abc import ABC, abstractmethod -from module.database import Database +from sqlmodel import select + +from module.database.engine import async_session_factory +from module.database.passkey import PasskeyDatabase from module.models import ResponseModel from module.models.user import User @@ -28,25 +30,6 @@ class AuthStrategy(ABC): pass -class PasswordAuthStrategy(AuthStrategy): - """密码认证策略(保持向后兼容)""" - - async def authenticate(self, username: str, credential: dict) -> ResponseModel: - """使用密码认证""" - password = credential.get("password") - if not password: - return ResponseModel( - status_code=401, - status=False, - msg_en="Password is required", - msg_zh="密码不能为空", - ) - - user = User(username=username, password=password) - async with Database() as db: - return await db.user.auth_user(user) - - class PasskeyAuthStrategy(AuthStrategy): """Passkey 认证策略""" @@ -55,11 +38,16 @@ class PasskeyAuthStrategy(AuthStrategy): async def authenticate(self, username: str, credential: dict) -> ResponseModel: """使用 WebAuthn Passkey 认证""" - async with Database() as db: + async with async_session_factory() as session: # 1. 查找用户 try: - user = await db.user.get_user(username) - except Exception: + result = await session.execute( + select(User).where(User.username == username) + ) + user = result.scalar_one_or_none() + if not user: + raise ValueError("User not found") + except ValueError: return ResponseModel( status_code=401, status=False, @@ -78,11 +66,12 @@ class PasskeyAuthStrategy(AuthStrategy): self.webauthn_service.base64url_decode(raw_id) ) - passkey = await db.passkey.get_passkey_by_credential_id(credential_id_str) + passkey_db = PasskeyDatabase(session) + passkey = await passkey_db.get_passkey_by_credential_id(credential_id_str) if not passkey or passkey.user_id != user.id: raise ValueError("Passkey not found or not owned by user") - except Exception as e: + except Exception: return ResponseModel( status_code=401, status=False, @@ -97,7 +86,7 @@ class PasskeyAuthStrategy(AuthStrategy): ) # 4. 更新使用记录 - await db.passkey.update_passkey_usage(passkey, new_sign_count) + await passkey_db.update_passkey_usage(passkey, new_sign_count) return ResponseModel( status_code=200,