From d24acc60d5d5dd805e90da8933bff20938af1bd4 Mon Sep 17 00:00:00 2001 From: estrella Date: Sun, 30 Jul 2023 21:49:34 +0800 Subject: [PATCH] fix: Userdatabase bugs --- backend/src/module/conf/__init__.py | 2 +- backend/src/module/database/bangumi.py | 35 -------------------------- backend/src/module/database/engine.py | 3 ++- backend/src/module/database/user.py | 20 +++++++++------ backend/src/module/models/user.py | 2 +- backend/src/module/security/api.py | 8 +++--- 6 files changed, 20 insertions(+), 50 deletions(-) diff --git a/backend/src/module/conf/__init__.py b/backend/src/module/conf/__init__.py index 00dbc6fe..b6bc7662 100644 --- a/backend/src/module/conf/__init__.py +++ b/backend/src/module/conf/__init__.py @@ -4,7 +4,7 @@ from .config import VERSION, settings from .log import LOG_PATH, setup_logger TMDB_API = "32b19d6a05b512190a056fa4e747cbbc" -DATA_PATH = Path("data/data.db") +DATA_PATH = "sqlite:///data/data.db" LEGACY_DATA_PATH = Path("data/data.json") PLATFORM = "Windows" if "\\" in settings.downloader.path else "Unix" diff --git a/backend/src/module/database/bangumi.py b/backend/src/module/database/bangumi.py index 9209e717..986ccee5 100644 --- a/backend/src/module/database/bangumi.py +++ b/backend/src/module/database/bangumi.py @@ -1,8 +1,6 @@ import logging -from module.database.orm import Connector from module.models import Bangumi, BangumiUpdate -from module.conf import DATA_PATH from sqlmodel import Session, select, delete, SQLModel from module.database.engine import engine from typing import Optional @@ -14,51 +12,18 @@ logger = logging.getLogger(__name__) class BangumiDatabase(Session): def __init__(self): super().__init__(engine) - # table_name="bangumi", - # data=self.__data_to_db(BangumiData()), - # database=database, - # ) @staticmethod def update_table(): SQLModel.metadata.create_all(engine) - # @staticmethod - # def __data_to_db(data: BangumiData) -> Bangumi: - # db_data = data.dict() - # for key, value in db_data.items(): - # if isinstance(value, bool): - # db_data[key] = int(value) - # elif isinstance(value, list): - # db_data[key] = ",".join(value) - # return db_data - # - # @staticmethod - # def __db_to_data(db_data: Bangumi) -> BangumiData: - # for key, item in db_data.items(): - # if isinstance(item, int): - # if key not in ["id", "offset", "season", "year"]: - # db_data[key] = bool(item) - # elif key in ["filter", "rss_link"]: - # db_data[key] = item.split(",") - # return BangumiData(**db_data) - def insert_one(self, data: Bangumi): self.add(data) self.commit() - # db_data = self.__data_to_db(data) - # self.insert.one(db_data) logger.debug(f"[Database] Insert {data.official_title} into database.") def insert_list(self, data: list[Bangumi]): self.add_all(data) - # data_list = [self.__data_to_db(x) for x in data] - # self.insert.many(data_list) - # _id = self.gen_id() - # for i, item in enumerate(data): - # item.id = _id + i - # data_list = [self.__data_to_db(x) for x in data] - # self._insert_list(data_list=data_list, table_name=self.__table_name) logger.debug(f"[Database] Insert {len(data)} bangumi into database.") def update_one(self, data: BangumiUpdate) -> bool: diff --git a/backend/src/module/database/engine.py b/backend/src/module/database/engine.py index e4bd2f5d..369c53b3 100644 --- a/backend/src/module/database/engine.py +++ b/backend/src/module/database/engine.py @@ -1,6 +1,7 @@ from sqlmodel import create_engine, Session +from module.conf import DATA_PATH -engine = create_engine("sqlite:///data/data.db") +engine = create_engine(DATA_PATH) db_session = Session(engine) \ No newline at end of file diff --git a/backend/src/module/database/user.py b/backend/src/module/database/user.py index 382ada60..c763346f 100644 --- a/backend/src/module/database/user.py +++ b/backend/src/module/database/user.py @@ -6,18 +6,22 @@ from module.models.user import User, UserUpdate, UserLogin from module.security.jwt import get_password_hash, verify_password from module.database.engine import engine from sqlmodel import Session, select, SQLModel +from sqlalchemy.exc import UnboundExecutionError, OperationalError logger = logging.getLogger(__name__) -class AuthDB(Session): +class UserDatabase(Session): def __init__(self): - super().__init__() - self.__update_table() + super().__init__(engine) + statement = select(User) + try: + self.exec(statement) + except OperationalError: + SQLModel.metadata.create_all(engine) + self.add(User()) + self.commit() - @staticmethod - def __update_table(): - SQLModel.metadata.create_all(engine) # @staticmethod # def __data_to_db(data: User) -> dict: @@ -61,6 +65,6 @@ class AuthDB(Session): if __name__ == "__main__": - with AuthDB() as db: + with UserDatabase() as db: # db.update_user(UserLogin(username="admin", password="adminadmin"), User(username="admin", password="cica1234")) - db.update_user("admin", User(username="estrella", password="cica1234")) + db.update_user("admin", UserUpdate(username="estrella", password="cica1234")) diff --git a/backend/src/module/models/user.py b/backend/src/module/models/user.py index 662622cd..3e5cef29 100644 --- a/backend/src/module/models/user.py +++ b/backend/src/module/models/user.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel, Field +from pydantic import BaseModel from typing import Optional from sqlmodel import SQLModel, Field diff --git a/backend/src/module/security/api.py b/backend/src/module/security/api.py index 368abb0e..b41dcafa 100644 --- a/backend/src/module/security/api.py +++ b/backend/src/module/security/api.py @@ -1,7 +1,7 @@ from fastapi import Depends, HTTPException, status from fastapi.security import OAuth2PasswordBearer -from module.database.user import AuthDB +from module.database.user import UserDatabase from module.models.user import User from .jwt import verify_token @@ -20,7 +20,7 @@ async def get_current_user(token: str = Depends(oauth2_scheme)): status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token" ) username = payload.get("sub") - with AuthDB() as user_db: + with UserDatabase as user_db: user = user_db.get_user(username) if not user: raise HTTPException( @@ -40,7 +40,7 @@ async def get_token_data(token: str = Depends(oauth2_scheme)): def update_user_info(user_data: User, current_user): try: - with AuthDB() as db: + with UserDatabase as db: db.update_user(current_user.username, user_data) return True except Exception as e: @@ -48,5 +48,5 @@ def update_user_info(user_data: User, current_user): def auth_user(username, password): - with AuthDB() as db: + with UserDatabase() as db: db.auth_user(username, password)