From 3ca9a9737fcf4266fe1f1a71cf22f0a271e2505c Mon Sep 17 00:00:00 2001 From: estrella Date: Sun, 30 Jul 2023 21:33:29 +0800 Subject: [PATCH] feat: update hand-made orm to sqlmodel (UserDatabase) --- backend/src/module/database/user.py | 79 +++++++++++++---------------- backend/src/module/models/user.py | 14 ++++- 2 files changed, 48 insertions(+), 45 deletions(-) diff --git a/backend/src/module/database/user.py b/backend/src/module/database/user.py index 69dd9dea..382ada60 100644 --- a/backend/src/module/database/user.py +++ b/backend/src/module/database/user.py @@ -2,69 +2,62 @@ import logging from fastapi import HTTPException -from module.database.connector import DataConnector -from module.models.user import User +from module.models.user import User, UserUpdate, UserLogin from module.security.jwt import get_password_hash, verify_password +from module.database.engine import engine +from sqlmodel import Session, select, SQLModel logger = logging.getLogger(__name__) -class AuthDB(DataConnector): +class AuthDB(Session): def __init__(self): super().__init__() - self.__table_name = "user" - if not self._table_exists(self.__table_name): - self.__update_table() - - def __update_table(self): - db_data = self.__data_to_db(User()) - self._update_table(self.__table_name, db_data) - self._insert(self.__table_name, db_data) + self.__update_table() @staticmethod - def __data_to_db(data: User) -> dict: - db_data = data.dict() - db_data["password"] = get_password_hash(db_data["password"]) - return db_data + def __update_table(): + SQLModel.metadata.create_all(engine) - @staticmethod - def __db_to_data(db_data: dict) -> User: - return User(**db_data) + # @staticmethod + # def __data_to_db(data: User) -> dict: + # db_data = data.dict() + # db_data["password"] = get_password_hash(db_data["password"]) + # return db_data + # + # @staticmethod + # def __db_to_data(db_data: dict) -> User: + # return User(**db_data) def get_user(self, username): - self._cursor.execute( - f"SELECT * FROM {self.__table_name} WHERE username=?", (username,) - ) - result = self._cursor.fetchone() + statement = select(User).where(User.username == username) + result = self.exec(statement).first() if not result: - return None - db_data = dict(zip([x[0] for x in self._cursor.description], result)) - return self.__db_to_data(db_data) + raise HTTPException(status_code=404, detail="User not found") + return result - def auth_user(self, username, password) -> bool: - self._cursor.execute( - f"SELECT username, password FROM {self.__table_name} WHERE username=?", - (username,), - ) - result = self._cursor.fetchone() + def auth_user(self, user: UserLogin) -> bool: + statement = select(User).where(User.username == user.username) + result = self.exec(statement).first() if not result: raise HTTPException(status_code=401, detail="User not found") - if not verify_password(password, result[1]): + if not verify_password(user.password, result.password): raise HTTPException(status_code=401, detail="Password error") return True - def update_user(self, username, update_user: User): + def update_user(self, username, update_user: UserUpdate): # Update username and password - new_username = update_user.username - new_password = update_user.password - self._cursor.execute( - f""" - UPDATE {self.__table_name} - SET username = '{new_username}', password = '{get_password_hash(new_password)}' - WHERE username = '{username}' - """ - ) - self._conn.commit() + statement = select(User).where(User.username == username) + result = self.exec(statement).first() + if not result: + raise HTTPException(status_code=404, detail="User not found") + if update_user.username: + result.username = update_user.username + if update_user.password: + result.password = get_password_hash(update_user.password) + self.add(result) + self.commit() + return result if __name__ == "__main__": diff --git a/backend/src/module/models/user.py b/backend/src/module/models/user.py index 36512642..662622cd 100644 --- a/backend/src/module/models/user.py +++ b/backend/src/module/models/user.py @@ -1,14 +1,24 @@ from pydantic import BaseModel, Field +from typing import Optional +from sqlmodel import SQLModel, Field -class User(BaseModel): +class User(SQLModel, table=True): + id: int = Field(default=None, primary_key=True) username: str = Field( "admin", min_length=4, max_length=20, regex=r"^[a-zA-Z0-9_]+$" ) password: str = Field("adminadmin", min_length=8) -class UserLogin(BaseModel): +class UserUpdate(SQLModel): + username: Optional[str] = Field( + None, min_length=4, max_length=20, regex=r"^[a-zA-Z0-9_]+$" + ) + password: Optional[str] = Field(None, min_length=8) + + +class UserLogin(SQLModel): username: str password: str = Field(..., min_length=8)