mirror of
https://github.com/EstrellaXD/Auto_Bangumi.git
synced 2026-04-23 18:11:37 +08:00
fix: old data support problem.
This commit is contained in:
@@ -4,34 +4,25 @@ from fastapi import HTTPException
|
||||
|
||||
from module.models.user import User, UserUpdate, UserLogin
|
||||
from module.security.jwt import get_password_hash, verify_password
|
||||
from module.database.engine import engine
|
||||
from sqlmodel import Session, select, SQLModel
|
||||
from sqlalchemy.exc import UnboundExecutionError, OperationalError
|
||||
from sqlmodel import Session, select
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UserDatabase(Session):
|
||||
def __init__(self):
|
||||
super().__init__(engine)
|
||||
statement = select(User)
|
||||
try:
|
||||
self.exec(statement)
|
||||
except OperationalError:
|
||||
SQLModel.metadata.create_all(engine)
|
||||
self.add(User())
|
||||
self.commit()
|
||||
class UserDatabase:
|
||||
def __init__(self, session: Session):
|
||||
self.session = session
|
||||
|
||||
def get_user(self, username):
|
||||
statement = select(User).where(User.username == username)
|
||||
result = self.exec(statement).first()
|
||||
result = self.session.exec(statement).first()
|
||||
if not result:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
return result
|
||||
|
||||
def auth_user(self, user: UserLogin) -> bool:
|
||||
statement = select(User).where(User.username == user.username)
|
||||
result = self.exec(statement).first()
|
||||
result = self.session.exec(statement).first()
|
||||
if not result:
|
||||
raise HTTPException(status_code=401, detail="User not found")
|
||||
if not verify_password(user.password, result.password):
|
||||
@@ -41,19 +32,55 @@ class UserDatabase(Session):
|
||||
def update_user(self, username, update_user: UserUpdate):
|
||||
# Update username and password
|
||||
statement = select(User).where(User.username == username)
|
||||
result = self.exec(statement).first()
|
||||
result = self.session.exec(statement).first()
|
||||
if not result:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
if update_user.username:
|
||||
result.username = update_user.username
|
||||
if update_user.password:
|
||||
result.password = get_password_hash(update_user.password)
|
||||
self.add(result)
|
||||
self.commit()
|
||||
self.session.add(result)
|
||||
self.session.commit()
|
||||
return result
|
||||
|
||||
def merge_old_user(self):
|
||||
# get old data
|
||||
statement = """
|
||||
SELECT * FROM user
|
||||
"""
|
||||
result = self.session.exec(statement).first()
|
||||
if not result:
|
||||
return
|
||||
# add new data
|
||||
user = User(username=result.username, password=result.password)
|
||||
# Drop old table
|
||||
statement = """
|
||||
DROP TABLE user
|
||||
"""
|
||||
self.session.exec(statement)
|
||||
# Create new table
|
||||
statement = """
|
||||
CREATE TABLE user (
|
||||
id INTEGER NOT NULL PRIMARY KEY,
|
||||
username VARCHAR NOT NULL,
|
||||
password VARCHAR NOT NULL
|
||||
)
|
||||
"""
|
||||
self.session.exec(statement)
|
||||
self.session.add(user)
|
||||
self.session.commit()
|
||||
|
||||
if __name__ == "__main__":
|
||||
with UserDatabase() as db:
|
||||
# db.update_user(UserLogin(username="admin", password="adminadmin"), User(username="admin", password="cica1234"))
|
||||
db.update_user("admin", UserUpdate(username="estrella", password="cica1234"))
|
||||
def add_default_user(self):
|
||||
# Check if user exists
|
||||
statement = select(User)
|
||||
try:
|
||||
result = self.session.exec(statement).all()
|
||||
except Exception as e:
|
||||
self.merge_old_user()
|
||||
result = self.session.exec(statement).all()
|
||||
if len(result) != 0:
|
||||
return
|
||||
# Add default user
|
||||
user = User(username="admin", password=get_password_hash("adminadmin"))
|
||||
self.session.add(user)
|
||||
self.session.commit()
|
||||
|
||||
Reference in New Issue
Block a user