mirror of
https://github.com/EstrellaXD/Auto_Bangumi.git
synced 2026-04-14 10:30:35 +08:00
fix(backend): restore sync Database interface, isolate async for passkey
The previous refactoring broke backward compatibility by converting Database from Session-extending sync class to a standalone async class. This broke RSSEngine, startup code, and auth flows. - Restore Database(Session) with sync interface for legacy code - Restore UserDatabase to sync methods - Restore security/api.py and auth.py to sync calls - Passkey API now uses async_session_factory directly - PasskeyAuthStrategy uses async sessions independently - Remove unused db_session from engine.py Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -1,85 +1,43 @@
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlmodel import SQLModel
|
||||
from sqlmodel import Session, SQLModel
|
||||
|
||||
from module.models import Bangumi, Passkey, User
|
||||
from module.models import Bangumi, User
|
||||
from module.models.passkey import Passkey
|
||||
|
||||
from .bangumi import BangumiDatabase
|
||||
from .engine import async_engine, async_session_factory, engine as e
|
||||
from .passkey import PasskeyDatabase
|
||||
from .engine import engine as e
|
||||
from .rss import RSSDatabase
|
||||
from .torrent import TorrentDatabase
|
||||
from .user import UserDatabase
|
||||
|
||||
|
||||
class Database:
|
||||
def __init__(self):
|
||||
self._session = None
|
||||
self.rss: RSSDatabase | None = None
|
||||
self.torrent: TorrentDatabase | None = None
|
||||
self.bangumi: BangumiDatabase | None = None
|
||||
self.user: UserDatabase | None = None
|
||||
self.passkey: PasskeyDatabase | None = None
|
||||
class Database(Session):
|
||||
def __init__(self, engine=e):
|
||||
self.engine = engine
|
||||
super().__init__(engine)
|
||||
self.rss = RSSDatabase(self)
|
||||
self.torrent = TorrentDatabase(self)
|
||||
self.bangumi = BangumiDatabase(self)
|
||||
self.user = UserDatabase(self)
|
||||
|
||||
# Sync context manager (for legacy code)
|
||||
def __enter__(self):
|
||||
from .engine import db_session
|
||||
def create_table(self):
|
||||
SQLModel.metadata.create_all(self.engine)
|
||||
|
||||
self._session = db_session
|
||||
self.rss = RSSDatabase(self._session)
|
||||
self.torrent = TorrentDatabase(self._session)
|
||||
self.bangumi = BangumiDatabase(self._session)
|
||||
self.user = UserDatabase(self._session)
|
||||
return self
|
||||
def drop_table(self):
|
||||
SQLModel.metadata.drop_all(self.engine)
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
pass
|
||||
|
||||
# Async context manager (for passkey and new async code)
|
||||
async def __aenter__(self):
|
||||
self._session = async_session_factory()
|
||||
self.rss = RSSDatabase(self._session)
|
||||
self.torrent = TorrentDatabase(self._session)
|
||||
self.bangumi = BangumiDatabase(self._session)
|
||||
self.user = UserDatabase(self._session)
|
||||
self.passkey = PasskeyDatabase(self._session)
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
if self._session and isinstance(self._session, AsyncSession):
|
||||
await self._session.close()
|
||||
|
||||
async def create_table(self):
|
||||
async with async_engine.begin() as conn:
|
||||
await conn.run_sync(SQLModel.metadata.create_all)
|
||||
|
||||
async def drop_table(self):
|
||||
async with async_engine.begin() as conn:
|
||||
await conn.run_sync(SQLModel.metadata.drop_all)
|
||||
|
||||
async def commit(self):
|
||||
if self._session:
|
||||
if isinstance(self._session, AsyncSession):
|
||||
await self._session.commit()
|
||||
else:
|
||||
self._session.commit()
|
||||
|
||||
async def add(self, obj):
|
||||
if self._session:
|
||||
self._session.add(obj)
|
||||
if isinstance(self._session, AsyncSession):
|
||||
await self._session.commit()
|
||||
else:
|
||||
self._session.commit()
|
||||
|
||||
async def migrate(self):
|
||||
def migrate(self):
|
||||
# Run migration online
|
||||
bangumi_data = await self.bangumi.search_all()
|
||||
bangumi_data = self.bangumi.search_all()
|
||||
user_data = self.exec("SELECT * FROM user").all()
|
||||
readd_bangumi = []
|
||||
for bangumi in bangumi_data:
|
||||
dict_data = bangumi.dict()
|
||||
del dict_data["id"]
|
||||
readd_bangumi.append(Bangumi(**dict_data))
|
||||
await self.drop_table()
|
||||
await self.create_table()
|
||||
await self.commit()
|
||||
await self.bangumi.add_all(readd_bangumi)
|
||||
self.drop_table()
|
||||
self.create_table()
|
||||
self.commit()
|
||||
bangumi_data = self.bangumi.search_all()
|
||||
self.bangumi.add_all(readd_bangumi)
|
||||
self.add(User(**user_data[0]))
|
||||
self.commit()
|
||||
|
||||
@@ -1,14 +1,13 @@
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlmodel import Session, create_engine
|
||||
from sqlmodel import create_engine
|
||||
|
||||
from module.conf import DATA_PATH
|
||||
|
||||
# Sync engine (for legacy code)
|
||||
# Sync engine (used by Database which extends Session)
|
||||
engine = create_engine(DATA_PATH)
|
||||
db_session = Session(engine)
|
||||
|
||||
# Async engine (for passkey and new async code)
|
||||
# Async engine (for passkey operations)
|
||||
ASYNC_DATA_PATH = DATA_PATH.replace("sqlite:///", "sqlite+aiosqlite:///")
|
||||
async_engine = create_async_engine(ASYNC_DATA_PATH)
|
||||
async_session_factory = sessionmaker(async_engine, class_=AsyncSession, expire_on_commit=False)
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
import logging
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlmodel import select
|
||||
from sqlmodel import Session, select
|
||||
|
||||
from module.models import ResponseModel
|
||||
from module.models.user import User, UserLogin, UserUpdate
|
||||
@@ -12,36 +11,28 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UserDatabase:
|
||||
def __init__(self, session):
|
||||
def __init__(self, session: Session):
|
||||
self.session = session
|
||||
|
||||
async def get_user(self, username):
|
||||
def get_user(self, username):
|
||||
statement = select(User).where(User.username == username)
|
||||
if isinstance(self.session, AsyncSession):
|
||||
result = await self.session.execute(statement)
|
||||
user = result.scalar_one_or_none()
|
||||
else:
|
||||
user = self.session.exec(statement).first()
|
||||
if not user:
|
||||
result = self.session.exec(statement).first()
|
||||
if not result:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
return user
|
||||
return result
|
||||
|
||||
async def auth_user(self, user: User):
|
||||
def auth_user(self, user: User):
|
||||
statement = select(User).where(User.username == user.username)
|
||||
if isinstance(self.session, AsyncSession):
|
||||
result = await self.session.execute(statement)
|
||||
db_user = result.scalar_one_or_none()
|
||||
else:
|
||||
db_user = self.session.exec(statement).first()
|
||||
result = self.session.exec(statement).first()
|
||||
if not user.password:
|
||||
return ResponseModel(
|
||||
status_code=401, status=False, msg_en="Incorrect password format", msg_zh="密码格式不正确"
|
||||
)
|
||||
if not db_user:
|
||||
if not result:
|
||||
return ResponseModel(
|
||||
status_code=401, status=False, msg_en="User not found", msg_zh="用户不存在"
|
||||
)
|
||||
if not verify_password(user.password, db_user.password):
|
||||
if not verify_password(user.password, result.password):
|
||||
return ResponseModel(
|
||||
status_code=401,
|
||||
status=False,
|
||||
@@ -52,59 +43,36 @@ class UserDatabase:
|
||||
status_code=200, status=True, msg_en="Login successfully", msg_zh="登录成功"
|
||||
)
|
||||
|
||||
async def update_user(self, username, update_user: UserUpdate):
|
||||
def update_user(self, username, update_user: UserUpdate):
|
||||
# Update username and password
|
||||
statement = select(User).where(User.username == username)
|
||||
if isinstance(self.session, AsyncSession):
|
||||
result = await self.session.execute(statement)
|
||||
db_user = result.scalar_one_or_none()
|
||||
else:
|
||||
db_user = self.session.exec(statement).first()
|
||||
if not db_user:
|
||||
result = self.session.exec(statement).first()
|
||||
if not result:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
if update_user.username:
|
||||
db_user.username = update_user.username
|
||||
result.username = update_user.username
|
||||
if update_user.password:
|
||||
db_user.password = get_password_hash(update_user.password)
|
||||
self.session.add(db_user)
|
||||
if isinstance(self.session, AsyncSession):
|
||||
await self.session.commit()
|
||||
else:
|
||||
self.session.commit()
|
||||
return db_user
|
||||
|
||||
async def add_default_user(self):
|
||||
statement = select(User)
|
||||
if isinstance(self.session, AsyncSession):
|
||||
result = await self.session.execute(statement)
|
||||
users = list(result.scalars().all())
|
||||
else:
|
||||
try:
|
||||
users = self.session.exec(statement).all()
|
||||
except Exception:
|
||||
self.merge_old_user()
|
||||
users = self.session.exec(statement).all()
|
||||
if len(users) != 0:
|
||||
return
|
||||
user = User(username="admin", password=get_password_hash("adminadmin"))
|
||||
self.session.add(user)
|
||||
if isinstance(self.session, AsyncSession):
|
||||
await self.session.commit()
|
||||
else:
|
||||
self.session.commit()
|
||||
result.password = get_password_hash(update_user.password)
|
||||
self.session.add(result)
|
||||
self.session.commit()
|
||||
return result
|
||||
|
||||
def merge_old_user(self):
|
||||
# Legacy migration - sync only
|
||||
# get old data
|
||||
statement = """
|
||||
SELECT * FROM user
|
||||
"""
|
||||
result = self.session.exec(statement).first()
|
||||
if not result:
|
||||
return
|
||||
# add new data
|
||||
user = User(username=result.username, password=result.password)
|
||||
# Drop old table
|
||||
statement = """
|
||||
DROP TABLE user
|
||||
"""
|
||||
self.session.exec(statement)
|
||||
# Create new table
|
||||
statement = """
|
||||
CREATE TABLE user (
|
||||
id INTEGER NOT NULL PRIMARY KEY,
|
||||
@@ -115,3 +83,18 @@ class UserDatabase:
|
||||
self.session.exec(statement)
|
||||
self.session.add(user)
|
||||
self.session.commit()
|
||||
|
||||
def add_default_user(self):
|
||||
# Check if user exists
|
||||
statement = select(User)
|
||||
try:
|
||||
result = self.session.exec(statement).all()
|
||||
except Exception:
|
||||
self.merge_old_user()
|
||||
result = self.session.exec(statement).all()
|
||||
if len(result) != 0:
|
||||
return
|
||||
# Add default user
|
||||
user = User(username="admin", password=get_password_hash("adminadmin"))
|
||||
self.session.add(user)
|
||||
self.session.commit()
|
||||
|
||||
Reference in New Issue
Block a user