mirror of
https://github.com/EstrellaXD/Auto_Bangumi.git
synced 2026-04-13 10:19:47 +08:00
fix: Userdatabase bugs
This commit is contained in:
@@ -4,7 +4,7 @@ from .config import VERSION, settings
|
||||
from .log import LOG_PATH, setup_logger
|
||||
|
||||
TMDB_API = "32b19d6a05b512190a056fa4e747cbbc"
|
||||
DATA_PATH = Path("data/data.db")
|
||||
DATA_PATH = "sqlite:///data/data.db"
|
||||
LEGACY_DATA_PATH = Path("data/data.json")
|
||||
|
||||
PLATFORM = "Windows" if "\\" in settings.downloader.path else "Unix"
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
import logging
|
||||
|
||||
from module.database.orm import Connector
|
||||
from module.models import Bangumi, BangumiUpdate
|
||||
from module.conf import DATA_PATH
|
||||
from sqlmodel import Session, select, delete, SQLModel
|
||||
from module.database.engine import engine
|
||||
from typing import Optional
|
||||
@@ -14,51 +12,18 @@ logger = logging.getLogger(__name__)
|
||||
class BangumiDatabase(Session):
|
||||
def __init__(self):
|
||||
super().__init__(engine)
|
||||
# table_name="bangumi",
|
||||
# data=self.__data_to_db(BangumiData()),
|
||||
# database=database,
|
||||
# )
|
||||
|
||||
@staticmethod
|
||||
def update_table():
|
||||
SQLModel.metadata.create_all(engine)
|
||||
|
||||
# @staticmethod
|
||||
# def __data_to_db(data: BangumiData) -> Bangumi:
|
||||
# db_data = data.dict()
|
||||
# for key, value in db_data.items():
|
||||
# if isinstance(value, bool):
|
||||
# db_data[key] = int(value)
|
||||
# elif isinstance(value, list):
|
||||
# db_data[key] = ",".join(value)
|
||||
# return db_data
|
||||
#
|
||||
# @staticmethod
|
||||
# def __db_to_data(db_data: Bangumi) -> BangumiData:
|
||||
# for key, item in db_data.items():
|
||||
# if isinstance(item, int):
|
||||
# if key not in ["id", "offset", "season", "year"]:
|
||||
# db_data[key] = bool(item)
|
||||
# elif key in ["filter", "rss_link"]:
|
||||
# db_data[key] = item.split(",")
|
||||
# return BangumiData(**db_data)
|
||||
|
||||
def insert_one(self, data: Bangumi):
|
||||
self.add(data)
|
||||
self.commit()
|
||||
# db_data = self.__data_to_db(data)
|
||||
# self.insert.one(db_data)
|
||||
logger.debug(f"[Database] Insert {data.official_title} into database.")
|
||||
|
||||
def insert_list(self, data: list[Bangumi]):
|
||||
self.add_all(data)
|
||||
# data_list = [self.__data_to_db(x) for x in data]
|
||||
# self.insert.many(data_list)
|
||||
# _id = self.gen_id()
|
||||
# for i, item in enumerate(data):
|
||||
# item.id = _id + i
|
||||
# data_list = [self.__data_to_db(x) for x in data]
|
||||
# self._insert_list(data_list=data_list, table_name=self.__table_name)
|
||||
logger.debug(f"[Database] Insert {len(data)} bangumi into database.")
|
||||
|
||||
def update_one(self, data: BangumiUpdate) -> bool:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from sqlmodel import create_engine, Session
|
||||
from module.conf import DATA_PATH
|
||||
|
||||
|
||||
engine = create_engine("sqlite:///data/data.db")
|
||||
engine = create_engine(DATA_PATH)
|
||||
|
||||
db_session = Session(engine)
|
||||
@@ -6,18 +6,22 @@ from module.models.user import User, UserUpdate, UserLogin
|
||||
from module.security.jwt import get_password_hash, verify_password
|
||||
from module.database.engine import engine
|
||||
from sqlmodel import Session, select, SQLModel
|
||||
from sqlalchemy.exc import UnboundExecutionError, OperationalError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AuthDB(Session):
|
||||
class UserDatabase(Session):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.__update_table()
|
||||
super().__init__(engine)
|
||||
statement = select(User)
|
||||
try:
|
||||
self.exec(statement)
|
||||
except OperationalError:
|
||||
SQLModel.metadata.create_all(engine)
|
||||
self.add(User())
|
||||
self.commit()
|
||||
|
||||
@staticmethod
|
||||
def __update_table():
|
||||
SQLModel.metadata.create_all(engine)
|
||||
|
||||
# @staticmethod
|
||||
# def __data_to_db(data: User) -> dict:
|
||||
@@ -61,6 +65,6 @@ class AuthDB(Session):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
with AuthDB() as db:
|
||||
with UserDatabase() as db:
|
||||
# db.update_user(UserLogin(username="admin", password="adminadmin"), User(username="admin", password="cica1234"))
|
||||
db.update_user("admin", User(username="estrella", password="cica1234"))
|
||||
db.update_user("admin", UserUpdate(username="estrella", password="cica1234"))
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
from sqlmodel import SQLModel, Field
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
|
||||
from module.database.user import AuthDB
|
||||
from module.database.user import UserDatabase
|
||||
from module.models.user import User
|
||||
|
||||
from .jwt import verify_token
|
||||
@@ -20,7 +20,7 @@ async def get_current_user(token: str = Depends(oauth2_scheme)):
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token"
|
||||
)
|
||||
username = payload.get("sub")
|
||||
with AuthDB() as user_db:
|
||||
with UserDatabase as user_db:
|
||||
user = user_db.get_user(username)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
@@ -40,7 +40,7 @@ async def get_token_data(token: str = Depends(oauth2_scheme)):
|
||||
|
||||
def update_user_info(user_data: User, current_user):
|
||||
try:
|
||||
with AuthDB() as db:
|
||||
with UserDatabase as db:
|
||||
db.update_user(current_user.username, user_data)
|
||||
return True
|
||||
except Exception as e:
|
||||
@@ -48,5 +48,5 @@ def update_user_info(user_data: User, current_user):
|
||||
|
||||
|
||||
def auth_user(username, password):
|
||||
with AuthDB() as db:
|
||||
with UserDatabase() as db:
|
||||
db.auth_user(username, password)
|
||||
|
||||
Reference in New Issue
Block a user