mirror of
https://github.com/EstrellaXD/Auto_Bangumi.git
synced 2026-04-13 17:39:52 +08:00
fix(backend): restore sync Database interface, isolate async for passkey
The previous refactoring broke backward compatibility by converting Database from Session-extending sync class to a standalone async class. This broke RSSEngine, startup code, and auth flows. - Restore Database(Session) with sync interface for legacy code - Restore UserDatabase to sync methods - Restore security/api.py and auth.py to sync calls - Passkey API now uses async_session_factory directly - PasskeyAuthStrategy uses async sessions independently - Remove unused db_session from engine.py Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -22,7 +22,7 @@ router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
@router.post("/login", response_model=dict)
|
||||
async def login(response: Response, form_data=Depends(OAuth2PasswordRequestForm)):
|
||||
user = User(username=form_data.username, password=form_data.password)
|
||||
resp = await auth_user(user)
|
||||
resp = auth_user(user)
|
||||
if resp.status:
|
||||
token = create_access_token(
|
||||
data={"sub": user.username}, expires_delta=timedelta(days=1)
|
||||
@@ -58,7 +58,7 @@ async def logout(response: Response):
|
||||
@router.post("/update", response_model=dict, dependencies=[Depends(get_current_user)])
|
||||
async def update_user(user_data: UserUpdate, response: Response):
|
||||
old_user = active_user[0]
|
||||
if await update_user_info(user_data, old_user):
|
||||
if update_user_info(user_data, old_user):
|
||||
token = create_access_token(
|
||||
data={"sub": old_user}, expires_delta=timedelta(days=1)
|
||||
)
|
||||
|
||||
@@ -7,8 +7,10 @@ from datetime import timedelta
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from fastapi.responses import JSONResponse, Response
|
||||
from sqlmodel import select
|
||||
|
||||
from module.database import Database
|
||||
from module.database.engine import async_session_factory
|
||||
from module.database.passkey import PasskeyDatabase
|
||||
from module.models import APIResponse
|
||||
from module.models.passkey import (
|
||||
PasskeyAuthFinish,
|
||||
@@ -17,6 +19,7 @@ from module.models.passkey import (
|
||||
PasskeyDelete,
|
||||
PasskeyList,
|
||||
)
|
||||
from module.models.user import User
|
||||
from module.security.api import active_user, get_current_user
|
||||
from module.security.auth_strategy import PasskeyAuthStrategy
|
||||
from module.security.jwt import create_access_token
|
||||
@@ -66,10 +69,19 @@ async def get_registration_options(
|
||||
"""
|
||||
webauthn = _get_webauthn_from_request(request)
|
||||
|
||||
async with Database() as db:
|
||||
async with async_session_factory() as session:
|
||||
try:
|
||||
user = await db.user.get_user(username)
|
||||
existing_passkeys = await db.passkey.get_passkeys_by_user_id(user.id)
|
||||
# Get user
|
||||
result = await session.execute(
|
||||
select(User).where(User.username == username)
|
||||
)
|
||||
user = result.scalar_one_or_none()
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
# Get existing passkeys
|
||||
passkey_db = PasskeyDatabase(session)
|
||||
existing_passkeys = await passkey_db.get_passkeys_by_user_id(user.id)
|
||||
|
||||
options = webauthn.generate_registration_options(
|
||||
username=username,
|
||||
@@ -79,6 +91,8 @@ async def get_registration_options(
|
||||
|
||||
return options
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate registration options: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
@@ -95,9 +109,15 @@ async def verify_registration(
|
||||
"""
|
||||
webauthn = _get_webauthn_from_request(request)
|
||||
|
||||
async with Database() as db:
|
||||
async with async_session_factory() as session:
|
||||
try:
|
||||
user = await db.user.get_user(username)
|
||||
# Get user
|
||||
result = await session.execute(
|
||||
select(User).where(User.username == username)
|
||||
)
|
||||
user = result.scalar_one_or_none()
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
# 验证 WebAuthn 响应
|
||||
passkey = webauthn.verify_registration(
|
||||
@@ -108,7 +128,8 @@ async def verify_registration(
|
||||
|
||||
# 设置 user_id 并保存
|
||||
passkey.user_id = user.id
|
||||
await db.passkey.create_passkey(passkey)
|
||||
passkey_db = PasskeyDatabase(session)
|
||||
await passkey_db.create_passkey(passkey)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=200,
|
||||
@@ -121,6 +142,8 @@ async def verify_registration(
|
||||
except ValueError as e:
|
||||
logger.warning(f"Registration verification failed for {username}: {e}")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to register passkey: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
@@ -140,10 +163,18 @@ async def get_passkey_login_options(
|
||||
"""
|
||||
webauthn = _get_webauthn_from_request(request)
|
||||
|
||||
async with Database() as db:
|
||||
async with async_session_factory() as session:
|
||||
try:
|
||||
user = await db.user.get_user(auth_data.username)
|
||||
passkeys = await db.passkey.get_passkeys_by_user_id(user.id)
|
||||
# Get user
|
||||
result = await session.execute(
|
||||
select(User).where(User.username == auth_data.username)
|
||||
)
|
||||
user = result.scalar_one_or_none()
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
passkey_db = PasskeyDatabase(session)
|
||||
passkeys = await passkey_db.get_passkeys_by_user_id(user.id)
|
||||
|
||||
if not passkeys:
|
||||
raise HTTPException(
|
||||
@@ -194,13 +225,23 @@ async def login_with_passkey(
|
||||
@router.get("/list", response_model=list[PasskeyList])
|
||||
async def list_passkeys(username: str = Depends(get_current_user)):
|
||||
"""获取用户的所有 Passkey"""
|
||||
async with Database() as db:
|
||||
async with async_session_factory() as session:
|
||||
try:
|
||||
user = await db.user.get_user(username)
|
||||
passkeys = await db.passkey.get_passkeys_by_user_id(user.id)
|
||||
# Get user
|
||||
result = await session.execute(
|
||||
select(User).where(User.username == username)
|
||||
)
|
||||
user = result.scalar_one_or_none()
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
return [db.passkey.to_list_model(pk) for pk in passkeys]
|
||||
passkey_db = PasskeyDatabase(session)
|
||||
passkeys = await passkey_db.get_passkeys_by_user_id(user.id)
|
||||
|
||||
return [passkey_db.to_list_model(pk) for pk in passkeys]
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to list passkeys: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
@@ -212,10 +253,18 @@ async def delete_passkey(
|
||||
username: str = Depends(get_current_user),
|
||||
):
|
||||
"""删除 Passkey"""
|
||||
async with Database() as db:
|
||||
async with async_session_factory() as session:
|
||||
try:
|
||||
user = await db.user.get_user(username)
|
||||
await db.passkey.delete_passkey(delete_data.passkey_id, user.id)
|
||||
# Get user
|
||||
result = await session.execute(
|
||||
select(User).where(User.username == username)
|
||||
)
|
||||
user = result.scalar_one_or_none()
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
passkey_db = PasskeyDatabase(session)
|
||||
await passkey_db.delete_passkey(delete_data.passkey_id, user.id)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=200,
|
||||
|
||||
@@ -1,85 +1,43 @@
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlmodel import SQLModel
|
||||
from sqlmodel import Session, SQLModel
|
||||
|
||||
from module.models import Bangumi, Passkey, User
|
||||
from module.models import Bangumi, User
|
||||
from module.models.passkey import Passkey
|
||||
|
||||
from .bangumi import BangumiDatabase
|
||||
from .engine import async_engine, async_session_factory, engine as e
|
||||
from .passkey import PasskeyDatabase
|
||||
from .engine import engine as e
|
||||
from .rss import RSSDatabase
|
||||
from .torrent import TorrentDatabase
|
||||
from .user import UserDatabase
|
||||
|
||||
|
||||
class Database:
|
||||
def __init__(self):
|
||||
self._session = None
|
||||
self.rss: RSSDatabase | None = None
|
||||
self.torrent: TorrentDatabase | None = None
|
||||
self.bangumi: BangumiDatabase | None = None
|
||||
self.user: UserDatabase | None = None
|
||||
self.passkey: PasskeyDatabase | None = None
|
||||
class Database(Session):
|
||||
def __init__(self, engine=e):
|
||||
self.engine = engine
|
||||
super().__init__(engine)
|
||||
self.rss = RSSDatabase(self)
|
||||
self.torrent = TorrentDatabase(self)
|
||||
self.bangumi = BangumiDatabase(self)
|
||||
self.user = UserDatabase(self)
|
||||
|
||||
# Sync context manager (for legacy code)
|
||||
def __enter__(self):
|
||||
from .engine import db_session
|
||||
def create_table(self):
|
||||
SQLModel.metadata.create_all(self.engine)
|
||||
|
||||
self._session = db_session
|
||||
self.rss = RSSDatabase(self._session)
|
||||
self.torrent = TorrentDatabase(self._session)
|
||||
self.bangumi = BangumiDatabase(self._session)
|
||||
self.user = UserDatabase(self._session)
|
||||
return self
|
||||
def drop_table(self):
|
||||
SQLModel.metadata.drop_all(self.engine)
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
pass
|
||||
|
||||
# Async context manager (for passkey and new async code)
|
||||
async def __aenter__(self):
|
||||
self._session = async_session_factory()
|
||||
self.rss = RSSDatabase(self._session)
|
||||
self.torrent = TorrentDatabase(self._session)
|
||||
self.bangumi = BangumiDatabase(self._session)
|
||||
self.user = UserDatabase(self._session)
|
||||
self.passkey = PasskeyDatabase(self._session)
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
if self._session and isinstance(self._session, AsyncSession):
|
||||
await self._session.close()
|
||||
|
||||
async def create_table(self):
|
||||
async with async_engine.begin() as conn:
|
||||
await conn.run_sync(SQLModel.metadata.create_all)
|
||||
|
||||
async def drop_table(self):
|
||||
async with async_engine.begin() as conn:
|
||||
await conn.run_sync(SQLModel.metadata.drop_all)
|
||||
|
||||
async def commit(self):
|
||||
if self._session:
|
||||
if isinstance(self._session, AsyncSession):
|
||||
await self._session.commit()
|
||||
else:
|
||||
self._session.commit()
|
||||
|
||||
async def add(self, obj):
|
||||
if self._session:
|
||||
self._session.add(obj)
|
||||
if isinstance(self._session, AsyncSession):
|
||||
await self._session.commit()
|
||||
else:
|
||||
self._session.commit()
|
||||
|
||||
async def migrate(self):
|
||||
def migrate(self):
|
||||
# Run migration online
|
||||
bangumi_data = await self.bangumi.search_all()
|
||||
bangumi_data = self.bangumi.search_all()
|
||||
user_data = self.exec("SELECT * FROM user").all()
|
||||
readd_bangumi = []
|
||||
for bangumi in bangumi_data:
|
||||
dict_data = bangumi.dict()
|
||||
del dict_data["id"]
|
||||
readd_bangumi.append(Bangumi(**dict_data))
|
||||
await self.drop_table()
|
||||
await self.create_table()
|
||||
await self.commit()
|
||||
await self.bangumi.add_all(readd_bangumi)
|
||||
self.drop_table()
|
||||
self.create_table()
|
||||
self.commit()
|
||||
bangumi_data = self.bangumi.search_all()
|
||||
self.bangumi.add_all(readd_bangumi)
|
||||
self.add(User(**user_data[0]))
|
||||
self.commit()
|
||||
|
||||
@@ -1,14 +1,13 @@
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlmodel import Session, create_engine
|
||||
from sqlmodel import create_engine
|
||||
|
||||
from module.conf import DATA_PATH
|
||||
|
||||
# Sync engine (for legacy code)
|
||||
# Sync engine (used by Database which extends Session)
|
||||
engine = create_engine(DATA_PATH)
|
||||
db_session = Session(engine)
|
||||
|
||||
# Async engine (for passkey and new async code)
|
||||
# Async engine (for passkey operations)
|
||||
ASYNC_DATA_PATH = DATA_PATH.replace("sqlite:///", "sqlite+aiosqlite:///")
|
||||
async_engine = create_async_engine(ASYNC_DATA_PATH)
|
||||
async_session_factory = sessionmaker(async_engine, class_=AsyncSession, expire_on_commit=False)
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
import logging
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlmodel import select
|
||||
from sqlmodel import Session, select
|
||||
|
||||
from module.models import ResponseModel
|
||||
from module.models.user import User, UserLogin, UserUpdate
|
||||
@@ -12,36 +11,28 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UserDatabase:
|
||||
def __init__(self, session):
|
||||
def __init__(self, session: Session):
|
||||
self.session = session
|
||||
|
||||
async def get_user(self, username):
|
||||
def get_user(self, username):
|
||||
statement = select(User).where(User.username == username)
|
||||
if isinstance(self.session, AsyncSession):
|
||||
result = await self.session.execute(statement)
|
||||
user = result.scalar_one_or_none()
|
||||
else:
|
||||
user = self.session.exec(statement).first()
|
||||
if not user:
|
||||
result = self.session.exec(statement).first()
|
||||
if not result:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
return user
|
||||
return result
|
||||
|
||||
async def auth_user(self, user: User):
|
||||
def auth_user(self, user: User):
|
||||
statement = select(User).where(User.username == user.username)
|
||||
if isinstance(self.session, AsyncSession):
|
||||
result = await self.session.execute(statement)
|
||||
db_user = result.scalar_one_or_none()
|
||||
else:
|
||||
db_user = self.session.exec(statement).first()
|
||||
result = self.session.exec(statement).first()
|
||||
if not user.password:
|
||||
return ResponseModel(
|
||||
status_code=401, status=False, msg_en="Incorrect password format", msg_zh="密码格式不正确"
|
||||
)
|
||||
if not db_user:
|
||||
if not result:
|
||||
return ResponseModel(
|
||||
status_code=401, status=False, msg_en="User not found", msg_zh="用户不存在"
|
||||
)
|
||||
if not verify_password(user.password, db_user.password):
|
||||
if not verify_password(user.password, result.password):
|
||||
return ResponseModel(
|
||||
status_code=401,
|
||||
status=False,
|
||||
@@ -52,59 +43,36 @@ class UserDatabase:
|
||||
status_code=200, status=True, msg_en="Login successfully", msg_zh="登录成功"
|
||||
)
|
||||
|
||||
async def update_user(self, username, update_user: UserUpdate):
|
||||
def update_user(self, username, update_user: UserUpdate):
|
||||
# Update username and password
|
||||
statement = select(User).where(User.username == username)
|
||||
if isinstance(self.session, AsyncSession):
|
||||
result = await self.session.execute(statement)
|
||||
db_user = result.scalar_one_or_none()
|
||||
else:
|
||||
db_user = self.session.exec(statement).first()
|
||||
if not db_user:
|
||||
result = self.session.exec(statement).first()
|
||||
if not result:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
if update_user.username:
|
||||
db_user.username = update_user.username
|
||||
result.username = update_user.username
|
||||
if update_user.password:
|
||||
db_user.password = get_password_hash(update_user.password)
|
||||
self.session.add(db_user)
|
||||
if isinstance(self.session, AsyncSession):
|
||||
await self.session.commit()
|
||||
else:
|
||||
self.session.commit()
|
||||
return db_user
|
||||
|
||||
async def add_default_user(self):
|
||||
statement = select(User)
|
||||
if isinstance(self.session, AsyncSession):
|
||||
result = await self.session.execute(statement)
|
||||
users = list(result.scalars().all())
|
||||
else:
|
||||
try:
|
||||
users = self.session.exec(statement).all()
|
||||
except Exception:
|
||||
self.merge_old_user()
|
||||
users = self.session.exec(statement).all()
|
||||
if len(users) != 0:
|
||||
return
|
||||
user = User(username="admin", password=get_password_hash("adminadmin"))
|
||||
self.session.add(user)
|
||||
if isinstance(self.session, AsyncSession):
|
||||
await self.session.commit()
|
||||
else:
|
||||
self.session.commit()
|
||||
result.password = get_password_hash(update_user.password)
|
||||
self.session.add(result)
|
||||
self.session.commit()
|
||||
return result
|
||||
|
||||
def merge_old_user(self):
|
||||
# Legacy migration - sync only
|
||||
# get old data
|
||||
statement = """
|
||||
SELECT * FROM user
|
||||
"""
|
||||
result = self.session.exec(statement).first()
|
||||
if not result:
|
||||
return
|
||||
# add new data
|
||||
user = User(username=result.username, password=result.password)
|
||||
# Drop old table
|
||||
statement = """
|
||||
DROP TABLE user
|
||||
"""
|
||||
self.session.exec(statement)
|
||||
# Create new table
|
||||
statement = """
|
||||
CREATE TABLE user (
|
||||
id INTEGER NOT NULL PRIMARY KEY,
|
||||
@@ -115,3 +83,18 @@ class UserDatabase:
|
||||
self.session.exec(statement)
|
||||
self.session.add(user)
|
||||
self.session.commit()
|
||||
|
||||
def add_default_user(self):
|
||||
# Check if user exists
|
||||
statement = select(User)
|
||||
try:
|
||||
result = self.session.exec(statement).all()
|
||||
except Exception:
|
||||
self.merge_old_user()
|
||||
result = self.session.exec(statement).all()
|
||||
if len(result) != 0:
|
||||
return
|
||||
# Add default user
|
||||
user = User(username="admin", password=get_password_hash("adminadmin"))
|
||||
self.session.add(user)
|
||||
self.session.commit()
|
||||
|
||||
@@ -34,18 +34,18 @@ async def get_token_data(token: str = Depends(oauth2_scheme)):
|
||||
return payload
|
||||
|
||||
|
||||
async def update_user_info(user_data: UserUpdate, current_user):
|
||||
def update_user_info(user_data: UserUpdate, current_user):
|
||||
try:
|
||||
async with Database() as db:
|
||||
await db.user.update_user(current_user, user_data)
|
||||
with Database() as db:
|
||||
db.user.update_user(current_user, user_data)
|
||||
return True
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||
|
||||
|
||||
async def auth_user(user: User):
|
||||
async with Database() as db:
|
||||
resp = await db.user.auth_user(user)
|
||||
def auth_user(user: User):
|
||||
with Database() as db:
|
||||
resp = db.user.auth_user(user)
|
||||
if resp.status:
|
||||
active_user.append(user.username)
|
||||
return resp
|
||||
|
||||
@@ -2,10 +2,12 @@
|
||||
认证策略抽象层
|
||||
将密码认证和 Passkey 认证统一为策略模式
|
||||
"""
|
||||
import base64
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from module.database import Database
|
||||
from sqlmodel import select
|
||||
|
||||
from module.database.engine import async_session_factory
|
||||
from module.database.passkey import PasskeyDatabase
|
||||
from module.models import ResponseModel
|
||||
from module.models.user import User
|
||||
|
||||
@@ -28,25 +30,6 @@ class AuthStrategy(ABC):
|
||||
pass
|
||||
|
||||
|
||||
class PasswordAuthStrategy(AuthStrategy):
|
||||
"""密码认证策略(保持向后兼容)"""
|
||||
|
||||
async def authenticate(self, username: str, credential: dict) -> ResponseModel:
|
||||
"""使用密码认证"""
|
||||
password = credential.get("password")
|
||||
if not password:
|
||||
return ResponseModel(
|
||||
status_code=401,
|
||||
status=False,
|
||||
msg_en="Password is required",
|
||||
msg_zh="密码不能为空",
|
||||
)
|
||||
|
||||
user = User(username=username, password=password)
|
||||
async with Database() as db:
|
||||
return await db.user.auth_user(user)
|
||||
|
||||
|
||||
class PasskeyAuthStrategy(AuthStrategy):
|
||||
"""Passkey 认证策略"""
|
||||
|
||||
@@ -55,11 +38,16 @@ class PasskeyAuthStrategy(AuthStrategy):
|
||||
|
||||
async def authenticate(self, username: str, credential: dict) -> ResponseModel:
|
||||
"""使用 WebAuthn Passkey 认证"""
|
||||
async with Database() as db:
|
||||
async with async_session_factory() as session:
|
||||
# 1. 查找用户
|
||||
try:
|
||||
user = await db.user.get_user(username)
|
||||
except Exception:
|
||||
result = await session.execute(
|
||||
select(User).where(User.username == username)
|
||||
)
|
||||
user = result.scalar_one_or_none()
|
||||
if not user:
|
||||
raise ValueError("User not found")
|
||||
except ValueError:
|
||||
return ResponseModel(
|
||||
status_code=401,
|
||||
status=False,
|
||||
@@ -78,11 +66,12 @@ class PasskeyAuthStrategy(AuthStrategy):
|
||||
self.webauthn_service.base64url_decode(raw_id)
|
||||
)
|
||||
|
||||
passkey = await db.passkey.get_passkey_by_credential_id(credential_id_str)
|
||||
passkey_db = PasskeyDatabase(session)
|
||||
passkey = await passkey_db.get_passkey_by_credential_id(credential_id_str)
|
||||
if not passkey or passkey.user_id != user.id:
|
||||
raise ValueError("Passkey not found or not owned by user")
|
||||
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
return ResponseModel(
|
||||
status_code=401,
|
||||
status=False,
|
||||
@@ -97,7 +86,7 @@ class PasskeyAuthStrategy(AuthStrategy):
|
||||
)
|
||||
|
||||
# 4. 更新使用记录
|
||||
await db.passkey.update_passkey_usage(passkey, new_sign_count)
|
||||
await passkey_db.update_passkey_usage(passkey, new_sign_count)
|
||||
|
||||
return ResponseModel(
|
||||
status_code=200,
|
||||
|
||||
Reference in New Issue
Block a user