feat: update hand-made orm to sqlmodel (UserDatabase)

This commit is contained in:
estrella
2023-07-30 21:33:29 +08:00
parent 9ae187a05d
commit 3ca9a9737f
2 changed files with 48 additions and 45 deletions

View File

@@ -2,69 +2,62 @@ import logging
from fastapi import HTTPException
from module.database.connector import DataConnector
from module.models.user import User
from module.models.user import User, UserUpdate, UserLogin
from module.security.jwt import get_password_hash, verify_password
from module.database.engine import engine
from sqlmodel import Session, select, SQLModel
logger = logging.getLogger(__name__)
class AuthDB(DataConnector):
class AuthDB(Session):
def __init__(self):
super().__init__()
self.__table_name = "user"
if not self._table_exists(self.__table_name):
self.__update_table()
def __update_table(self):
db_data = self.__data_to_db(User())
self._update_table(self.__table_name, db_data)
self._insert(self.__table_name, db_data)
self.__update_table()
@staticmethod
def __data_to_db(data: User) -> dict:
db_data = data.dict()
db_data["password"] = get_password_hash(db_data["password"])
return db_data
def __update_table():
SQLModel.metadata.create_all(engine)
@staticmethod
def __db_to_data(db_data: dict) -> User:
return User(**db_data)
# @staticmethod
# def __data_to_db(data: User) -> dict:
# db_data = data.dict()
# db_data["password"] = get_password_hash(db_data["password"])
# return db_data
#
# @staticmethod
# def __db_to_data(db_data: dict) -> User:
# return User(**db_data)
def get_user(self, username):
self._cursor.execute(
f"SELECT * FROM {self.__table_name} WHERE username=?", (username,)
)
result = self._cursor.fetchone()
statement = select(User).where(User.username == username)
result = self.exec(statement).first()
if not result:
return None
db_data = dict(zip([x[0] for x in self._cursor.description], result))
return self.__db_to_data(db_data)
raise HTTPException(status_code=404, detail="User not found")
return result
def auth_user(self, username, password) -> bool:
self._cursor.execute(
f"SELECT username, password FROM {self.__table_name} WHERE username=?",
(username,),
)
result = self._cursor.fetchone()
def auth_user(self, user: UserLogin) -> bool:
statement = select(User).where(User.username == user.username)
result = self.exec(statement).first()
if not result:
raise HTTPException(status_code=401, detail="User not found")
if not verify_password(password, result[1]):
if not verify_password(user.password, result.password):
raise HTTPException(status_code=401, detail="Password error")
return True
def update_user(self, username, update_user: User):
def update_user(self, username, update_user: UserUpdate):
# Update username and password
new_username = update_user.username
new_password = update_user.password
self._cursor.execute(
f"""
UPDATE {self.__table_name}
SET username = '{new_username}', password = '{get_password_hash(new_password)}'
WHERE username = '{username}'
"""
)
self._conn.commit()
statement = select(User).where(User.username == username)
result = self.exec(statement).first()
if not result:
raise HTTPException(status_code=404, detail="User not found")
if update_user.username:
result.username = update_user.username
if update_user.password:
result.password = get_password_hash(update_user.password)
self.add(result)
self.commit()
return result
if __name__ == "__main__":