mirror of
https://github.com/EstrellaXD/Auto_Bangumi.git
synced 2026-02-11 06:16:36 +08:00
101 lines
3.3 KiB
Python
101 lines
3.3 KiB
Python
import logging
|
|
|
|
from fastapi import HTTPException
|
|
from sqlmodel import Session, select
|
|
|
|
from module.models import ResponseModel
|
|
from module.models.user import User, UserLogin, UserUpdate
|
|
from module.security.jwt import get_password_hash, verify_password
|
|
|
|
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class UserDatabase:
    """Data-access helper for the ``User`` table, bound to one SQLModel session.

    The session is supplied by the caller; this class issues queries and
    commits on it but never opens or closes it.
    """

    def __init__(self, session: Session):
        self.session = session

    def get_user(self, username: str) -> User:
        """Return the ``User`` whose username equals *username*.

        Raises:
            HTTPException: status 404 if no such user exists.
        """
        statement = select(User).where(User.username == username)
        result = self.session.exec(statement).first()
        if not result:
            raise HTTPException(status_code=404, detail="User not found")
        return result

    def auth_user(self, user: User) -> ResponseModel:
        """Check *user*'s plaintext password against the stored hash.

        Returns a ``ResponseModel`` rather than raising:
          * 401 if the supplied password is empty,
          * 401 if no user with that username exists,
          * 401 if the password does not verify,
          * 200 on success.
        """
        # Reject an empty password up front; no need to hit the database.
        if not user.password:
            return ResponseModel(
                status_code=401,
                status=False,
                msg_en="Incorrect password format",
                msg_zh="密码格式不正确",
            )
        statement = select(User).where(User.username == user.username)
        result = self.session.exec(statement).first()
        if not result:
            return ResponseModel(
                status_code=401,
                status=False,
                msg_en="User not found",
                msg_zh="用户不存在",
            )
        if not verify_password(user.password, result.password):
            return ResponseModel(
                status_code=401,
                status=False,
                msg_en="Incorrect password",
                msg_zh="密码错误",
            )
        return ResponseModel(
            status_code=200,
            status=True,
            msg_en="Login successfully",
            msg_zh="登录成功",
        )

    def update_user(self, username: str, update_user: UserUpdate) -> User:
        """Update the username and/or password of the user named *username*.

        Only truthy fields of *update_user* are applied; a new password is
        stored hashed via ``get_password_hash``. Commits and returns the row.

        Raises:
            HTTPException: status 404 if *username* does not exist.
        """
        # Same lookup (and same 404) as get_user.
        result = self.get_user(username)
        if update_user.username:
            result.username = update_user.username
        if update_user.password:
            result.password = get_password_hash(update_user.password)
        self.session.add(result)
        self.session.commit()
        return result

    def merge_old_user(self):
        """Migrate a legacy ``user`` table to the current schema.

        Reads the first row of the existing ``user`` table, drops the table,
        recreates it as (id, username, password), and re-inserts that row's
        username and (already hashed) password as a new ``User``. Returns
        without changes if the old table has no rows.
        """
        # Raw SQL must be wrapped in text(): SQLAlchemy 2.x rejects plain
        # strings as executable statements (1.4 only warns). sqlmodel
        # re-exports SQLAlchemy's text().
        from sqlmodel import text

        # Fetch the old data; only the first legacy row is carried over.
        result = self.session.exec(text("SELECT * FROM user")).first()
        if not result:
            return
        user = User(username=result.username, password=result.password)
        logger.info("Migrating legacy user table to the current schema.")
        self.session.exec(text("DROP TABLE user"))
        self.session.exec(
            text(
                """
                CREATE TABLE user (
                    id INTEGER NOT NULL PRIMARY KEY,
                    username VARCHAR NOT NULL,
                    password VARCHAR NOT NULL
                )
                """
            )
        )
        self.session.add(user)
        self.session.commit()

    def add_default_user(self):
        """Ensure at least one user exists, creating ``admin`` if none do.

        If selecting from the table fails (presumably because it still has a
        legacy schema — NOTE(review): confirm), the failed transaction is
        rolled back, ``merge_old_user`` migrates the table, and the query is
        retried. If still no users exist, an ``admin`` user with the default
        password is created and committed.
        """
        statement = select(User)
        try:
            users = self.session.exec(statement).all()
        except Exception:
            logger.info("User table query failed; attempting legacy migration.")
            # Clear the failed transaction before issuing further statements;
            # some backends refuse new statements in an aborted transaction.
            self.session.rollback()
            self.merge_old_user()
            users = self.session.exec(statement).all()
        if users:
            return
        user = User(username="admin", password=get_password_hash("adminadmin"))
        self.session.add(user)
        self.session.commit()