fix(backend): restore sync Database interface, isolate async for passkey

The previous refactoring broke backward compatibility by converting
Database from Session-extending sync class to a standalone async class.
This broke RSSEngine, startup code, and auth flows.

- Restore Database(Session) with sync interface for legacy code
- Restore UserDatabase to sync methods
- Restore security/api.py and auth.py to sync calls
- Passkey API now uses async_session_factory directly
- PasskeyAuthStrategy uses async sessions independently
- Remove unused db_session from engine.py

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
EstrellaXD
2026-01-23 15:58:26 +01:00
parent 29e4c16b40
commit cd233881bd
7 changed files with 157 additions and 179 deletions

View File

@@ -1,85 +1,43 @@
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import SQLModel
from sqlmodel import Session, SQLModel
from module.models import Bangumi, Passkey, User
from module.models import Bangumi, User
from module.models.passkey import Passkey
from .bangumi import BangumiDatabase
from .engine import async_engine, async_session_factory, engine as e
from .passkey import PasskeyDatabase
from .engine import engine as e
from .rss import RSSDatabase
from .torrent import TorrentDatabase
from .user import UserDatabase
class Database:
def __init__(self):
self._session = None
self.rss: RSSDatabase | None = None
self.torrent: TorrentDatabase | None = None
self.bangumi: BangumiDatabase | None = None
self.user: UserDatabase | None = None
self.passkey: PasskeyDatabase | None = None
class Database(Session):
def __init__(self, engine=e):
self.engine = engine
super().__init__(engine)
self.rss = RSSDatabase(self)
self.torrent = TorrentDatabase(self)
self.bangumi = BangumiDatabase(self)
self.user = UserDatabase(self)
# Sync context manager (for legacy code)
def __enter__(self):
from .engine import db_session
def create_table(self):
SQLModel.metadata.create_all(self.engine)
self._session = db_session
self.rss = RSSDatabase(self._session)
self.torrent = TorrentDatabase(self._session)
self.bangumi = BangumiDatabase(self._session)
self.user = UserDatabase(self._session)
return self
def drop_table(self):
SQLModel.metadata.drop_all(self.engine)
def __exit__(self, exc_type, exc_val, exc_tb):
pass
# Async context manager (for passkey and new async code)
async def __aenter__(self):
self._session = async_session_factory()
self.rss = RSSDatabase(self._session)
self.torrent = TorrentDatabase(self._session)
self.bangumi = BangumiDatabase(self._session)
self.user = UserDatabase(self._session)
self.passkey = PasskeyDatabase(self._session)
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
if self._session and isinstance(self._session, AsyncSession):
await self._session.close()
async def create_table(self):
async with async_engine.begin() as conn:
await conn.run_sync(SQLModel.metadata.create_all)
async def drop_table(self):
async with async_engine.begin() as conn:
await conn.run_sync(SQLModel.metadata.drop_all)
async def commit(self):
if self._session:
if isinstance(self._session, AsyncSession):
await self._session.commit()
else:
self._session.commit()
async def add(self, obj):
if self._session:
self._session.add(obj)
if isinstance(self._session, AsyncSession):
await self._session.commit()
else:
self._session.commit()
async def migrate(self):
def migrate(self):
# Run migration online
bangumi_data = await self.bangumi.search_all()
bangumi_data = self.bangumi.search_all()
user_data = self.exec("SELECT * FROM user").all()
readd_bangumi = []
for bangumi in bangumi_data:
dict_data = bangumi.dict()
del dict_data["id"]
readd_bangumi.append(Bangumi(**dict_data))
await self.drop_table()
await self.create_table()
await self.commit()
await self.bangumi.add_all(readd_bangumi)
self.drop_table()
self.create_table()
self.commit()
bangumi_data = self.bangumi.search_all()
self.bangumi.add_all(readd_bangumi)
self.add(User(**user_data[0]))
self.commit()

View File

@@ -1,14 +1,13 @@
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from sqlmodel import Session, create_engine
from sqlmodel import create_engine
from module.conf import DATA_PATH
# Sync engine (for legacy code)
# Sync engine (used by Database which extends Session)
engine = create_engine(DATA_PATH)
db_session = Session(engine)
# Async engine (for passkey and new async code)
# Async engine (for passkey operations)
ASYNC_DATA_PATH = DATA_PATH.replace("sqlite:///", "sqlite+aiosqlite:///")
async_engine = create_async_engine(ASYNC_DATA_PATH)
async_session_factory = sessionmaker(async_engine, class_=AsyncSession, expire_on_commit=False)

View File

@@ -1,8 +1,7 @@
import logging
from fastapi import HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import select
from sqlmodel import Session, select
from module.models import ResponseModel
from module.models.user import User, UserLogin, UserUpdate
@@ -12,36 +11,28 @@ logger = logging.getLogger(__name__)
class UserDatabase:
def __init__(self, session):
def __init__(self, session: Session):
self.session = session
async def get_user(self, username):
def get_user(self, username):
statement = select(User).where(User.username == username)
if isinstance(self.session, AsyncSession):
result = await self.session.execute(statement)
user = result.scalar_one_or_none()
else:
user = self.session.exec(statement).first()
if not user:
result = self.session.exec(statement).first()
if not result:
raise HTTPException(status_code=404, detail="User not found")
return user
return result
async def auth_user(self, user: User):
def auth_user(self, user: User):
statement = select(User).where(User.username == user.username)
if isinstance(self.session, AsyncSession):
result = await self.session.execute(statement)
db_user = result.scalar_one_or_none()
else:
db_user = self.session.exec(statement).first()
result = self.session.exec(statement).first()
if not user.password:
return ResponseModel(
status_code=401, status=False, msg_en="Incorrect password format", msg_zh="密码格式不正确"
)
if not db_user:
if not result:
return ResponseModel(
status_code=401, status=False, msg_en="User not found", msg_zh="用户不存在"
)
if not verify_password(user.password, db_user.password):
if not verify_password(user.password, result.password):
return ResponseModel(
status_code=401,
status=False,
@@ -52,59 +43,36 @@ class UserDatabase:
status_code=200, status=True, msg_en="Login successfully", msg_zh="登录成功"
)
async def update_user(self, username, update_user: UserUpdate):
def update_user(self, username, update_user: UserUpdate):
# Update username and password
statement = select(User).where(User.username == username)
if isinstance(self.session, AsyncSession):
result = await self.session.execute(statement)
db_user = result.scalar_one_or_none()
else:
db_user = self.session.exec(statement).first()
if not db_user:
result = self.session.exec(statement).first()
if not result:
raise HTTPException(status_code=404, detail="User not found")
if update_user.username:
db_user.username = update_user.username
result.username = update_user.username
if update_user.password:
db_user.password = get_password_hash(update_user.password)
self.session.add(db_user)
if isinstance(self.session, AsyncSession):
await self.session.commit()
else:
self.session.commit()
return db_user
async def add_default_user(self):
statement = select(User)
if isinstance(self.session, AsyncSession):
result = await self.session.execute(statement)
users = list(result.scalars().all())
else:
try:
users = self.session.exec(statement).all()
except Exception:
self.merge_old_user()
users = self.session.exec(statement).all()
if len(users) != 0:
return
user = User(username="admin", password=get_password_hash("adminadmin"))
self.session.add(user)
if isinstance(self.session, AsyncSession):
await self.session.commit()
else:
self.session.commit()
result.password = get_password_hash(update_user.password)
self.session.add(result)
self.session.commit()
return result
def merge_old_user(self):
# Legacy migration - sync only
# get old data
statement = """
SELECT * FROM user
"""
result = self.session.exec(statement).first()
if not result:
return
# add new data
user = User(username=result.username, password=result.password)
# Drop old table
statement = """
DROP TABLE user
"""
self.session.exec(statement)
# Create new table
statement = """
CREATE TABLE user (
id INTEGER NOT NULL PRIMARY KEY,
@@ -115,3 +83,18 @@ class UserDatabase:
self.session.exec(statement)
self.session.add(user)
self.session.commit()
def add_default_user(self):
# Check if user exists
statement = select(User)
try:
result = self.session.exec(statement).all()
except Exception:
self.merge_old_user()
result = self.session.exec(statement).all()
if len(result) != 0:
return
# Add default user
user = User(username="admin", password=get_password_hash("adminadmin"))
self.session.add(user)
self.session.commit()