fix: Userdatabase bugs

This commit is contained in:
estrella
2023-07-30 21:49:34 +08:00
parent 3ca9a9737f
commit d24acc60d5
6 changed files with 20 additions and 50 deletions

View File

@@ -4,7 +4,7 @@ from .config import VERSION, settings
from .log import LOG_PATH, setup_logger
TMDB_API = "32b19d6a05b512190a056fa4e747cbbc"
DATA_PATH = Path("data/data.db")
DATA_PATH = "sqlite:///data/data.db"
LEGACY_DATA_PATH = Path("data/data.json")
PLATFORM = "Windows" if "\\" in settings.downloader.path else "Unix"

View File

@@ -1,8 +1,6 @@
import logging
from module.database.orm import Connector
from module.models import Bangumi, BangumiUpdate
from module.conf import DATA_PATH
from sqlmodel import Session, select, delete, SQLModel
from module.database.engine import engine
from typing import Optional
@@ -14,51 +12,18 @@ logger = logging.getLogger(__name__)
class BangumiDatabase(Session):
def __init__(self):
super().__init__(engine)
# table_name="bangumi",
# data=self.__data_to_db(BangumiData()),
# database=database,
# )
@staticmethod
def update_table():
SQLModel.metadata.create_all(engine)
# @staticmethod
# def __data_to_db(data: BangumiData) -> Bangumi:
# db_data = data.dict()
# for key, value in db_data.items():
# if isinstance(value, bool):
# db_data[key] = int(value)
# elif isinstance(value, list):
# db_data[key] = ",".join(value)
# return db_data
#
# @staticmethod
# def __db_to_data(db_data: Bangumi) -> BangumiData:
# for key, item in db_data.items():
# if isinstance(item, int):
# if key not in ["id", "offset", "season", "year"]:
# db_data[key] = bool(item)
# elif key in ["filter", "rss_link"]:
# db_data[key] = item.split(",")
# return BangumiData(**db_data)
def insert_one(self, data: Bangumi):
self.add(data)
self.commit()
# db_data = self.__data_to_db(data)
# self.insert.one(db_data)
logger.debug(f"[Database] Insert {data.official_title} into database.")
def insert_list(self, data: list[Bangumi]):
self.add_all(data)
# data_list = [self.__data_to_db(x) for x in data]
# self.insert.many(data_list)
# _id = self.gen_id()
# for i, item in enumerate(data):
# item.id = _id + i
# data_list = [self.__data_to_db(x) for x in data]
# self._insert_list(data_list=data_list, table_name=self.__table_name)
logger.debug(f"[Database] Insert {len(data)} bangumi into database.")
def update_one(self, data: BangumiUpdate) -> bool:

View File

@@ -1,6 +1,7 @@
from sqlmodel import create_engine, Session
from module.conf import DATA_PATH
engine = create_engine("sqlite:///data/data.db")
engine = create_engine(DATA_PATH)
db_session = Session(engine)

View File

@@ -6,18 +6,22 @@ from module.models.user import User, UserUpdate, UserLogin
from module.security.jwt import get_password_hash, verify_password
from module.database.engine import engine
from sqlmodel import Session, select, SQLModel
from sqlalchemy.exc import UnboundExecutionError, OperationalError
logger = logging.getLogger(__name__)
class AuthDB(Session):
class UserDatabase(Session):
def __init__(self):
super().__init__()
self.__update_table()
super().__init__(engine)
statement = select(User)
try:
self.exec(statement)
except OperationalError:
SQLModel.metadata.create_all(engine)
self.add(User())
self.commit()
@staticmethod
def __update_table():
SQLModel.metadata.create_all(engine)
# @staticmethod
# def __data_to_db(data: User) -> dict:
@@ -61,6 +65,6 @@ class AuthDB(Session):
if __name__ == "__main__":
with AuthDB() as db:
with UserDatabase() as db:
# db.update_user(UserLogin(username="admin", password="adminadmin"), User(username="admin", password="cica1234"))
db.update_user("admin", User(username="estrella", password="cica1234"))
db.update_user("admin", UserUpdate(username="estrella", password="cica1234"))

View File

@@ -1,4 +1,4 @@
from pydantic import BaseModel, Field
from pydantic import BaseModel
from typing import Optional
from sqlmodel import SQLModel, Field

View File

@@ -1,7 +1,7 @@
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from module.database.user import AuthDB
from module.database.user import UserDatabase
from module.models.user import User
from .jwt import verify_token
@@ -20,7 +20,7 @@ async def get_current_user(token: str = Depends(oauth2_scheme)):
status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token"
)
username = payload.get("sub")
with AuthDB() as user_db:
with UserDatabase as user_db:
user = user_db.get_user(username)
if not user:
raise HTTPException(
@@ -40,7 +40,7 @@ async def get_token_data(token: str = Depends(oauth2_scheme)):
def update_user_info(user_data: User, current_user):
try:
with AuthDB() as db:
with UserDatabase as db:
db.update_user(current_user.username, user_data)
return True
except Exception as e:
@@ -48,5 +48,5 @@ def update_user_info(user_data: User, current_user):
def auth_user(username, password):
with AuthDB() as db:
with UserDatabase() as db:
db.auth_user(username, password)