diff --git a/requirements.txt b/requirements.txt index 1c2ff2be..2781feb3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ anyio -beautifulsoup4 +bs4 certifi charset-normalizer click @@ -19,3 +19,7 @@ uvicorn attrdict jinja2 python-dotenv +python-jose +passlib +bcrypt +python-multipart diff --git a/src/module/api/auth.py b/src/module/api/auth.py index f2ab6a68..88229779 100644 --- a/src/module/api/auth.py +++ b/src/module/api/auth.py @@ -1,10 +1,13 @@ -from fastapi import APIRouter, Depends, HTTPException, status -from fastapi.security import OAuth2PasswordRequestForm +from fastapi import Depends, HTTPException, status +from fastapi.security import OAuth2PasswordRequestForm, OAuth2PasswordBearer from module.database.user import AuthDB -from module.security.jwt import decode_token, oauth2_scheme +from module.security.jwt import create_access_token, decode_token +from module.models.user import User -from .log import router +from .program import router + +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login") async def get_current_user(token: str = Depends(oauth2_scheme)): @@ -14,7 +17,8 @@ async def get_current_user(token: str = Depends(oauth2_scheme)): if not payload: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token") username = payload.get("sub") - user = user_db.get_user(username) + with AuthDB() as user_db: + user = user_db.get_user(username) if not user: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid username") return user @@ -27,27 +31,43 @@ async def get_token_data(token: str = Depends(oauth2_scheme)): return payload -@router.get("/api/v1/auth/login", response_model=dict, tags=["login"]) +@router.post("/api/v1/auth/login", response_model=dict, tags=["login"]) async def login(form_data: OAuth2PasswordRequestForm = Depends()): username = form_data.username password = form_data.password with AuthDB() as db: - if not db.authenticate(username, password): + if not db.auth_user(username, password): raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid username or password") - token = db.create_access_token(username) - return {"access_token": token, "token_type": "bearer"} + token = create_access_token({"sub": username}) + return {"access_token": token, "token_type": "bearer", "expire": 86400} -@router.get("/api/v1/auth/logout", response_model=dict, tags=["login"]) -async def logout(token_data: dict = Depends(get_token_data)): - pass - -@router.get("/api/v1/auth/refresh", response_model=dict, tags=["login"]) +@router.get("/api/v1/auth/refresh_token", response_model=dict, tags=["login"]) async def refresh( - current_user: User = Depends(get_token_data), + current_user: User = Depends(get_current_user) ): if not current_user: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token") - token = create_access_token(current_user) - return {"access_token": token, "token_type": "bearer"} + token = create_access_token({"sub": current_user.username}) + return {"access_token": token, "token_type": "bearer", "expire": 86400} + +@router.get("/api/v1/auth/logout", response_model=dict, tags=["login"]) +async def logout( + current_user: User = Depends(get_current_user) +): + if not current_user: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token") + return {"message": "logout success"} + + +@router.post("/api/v1/auth/update", response_model=dict, tags=["users"]) +async def update_user(data: User, current_user: User = Depends(get_current_user)): + if not current_user: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token") + try: + with AuthDB() as db: + db.update_user(current_user.username, data) + return {"message": "update success"} + except Exception as e: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) diff --git a/src/module/api/log.py b/src/module/api/log.py index 9415ac82..bd406365 100644 --- a/src/module/api/log.py +++ b/src/module/api/log.py @@ -1,7 +1,7 @@ import os from fastapi import Response -from .program import router +from .auth import router from module.conf import LOG_PATH diff --git a/src/module/core/status.py b/src/module/core/status.py index f0a83c86..1e2f3fc2 100644 --- a/src/module/core/status.py +++ b/src/module/core/status.py @@ -3,7 +3,6 @@ import threading import asyncio from module.checker import Checker -from module.conf import DATA_PATH class ProgramStatus(Checker): diff --git a/src/module/database/bangumi.py b/src/module/database/bangumi.py index 3b669fc8..9588ad4d 100644 --- a/src/module/database/bangumi.py +++ b/src/module/database/bangumi.py @@ -10,7 +10,10 @@ class BangumiDatabase(DataConnector): def __init__(self): super().__init__() self.__table_name = "bangumi" - self.update_table() + self.__updated = False + if not self.__updated: + self.update_table() + self.__updated = True def update_table(self): db_data = self.__data_to_db(BangumiData()) diff --git a/src/module/database/connector.py b/src/module/database/connector.py index fad7e347..03060f53 100644 --- a/src/module/database/connector.py +++ b/src/module/database/connector.py @@ -73,6 +73,10 @@ class DataConnector: self._cursor.execute(f"DELETE FROM {table_name}") self._conn.commit() + def _table_exists(self, table_name: str) -> bool: + self._cursor.execute(f"SELECT name FROM sqlite_master WHERE type='table' AND name='{table_name}'") + return len(self._cursor.fetchall()) == 1 + @staticmethod def __python_to_sqlite_type(value) -> str: if isinstance(value, int): diff --git a/src/module/database/user.py b/src/module/database/user.py index 31bb9eaf..ddca0cc8 100644 --- a/src/module/database/user.py +++ b/src/module/database/user.py @@ -1,54 +1,67 @@ +import logging + from fastapi import HTTPException -from .connector import DataConnector +from module.database.connector import DataConnector from module.security.jwt import get_password_hash, verify_password -from module.models import UserLogin +from module.models.user import UserLogin, User + +logger = logging.getLogger(__name__) class AuthDB(DataConnector): - def update_table(self): - table_name = "user" - db_data = self.__data_to_db(UserLogin()) - self._update_table(table_name, db_data) + def __init__(self): + super().__init__() + self.__table_name = "user" + if not self._table_exists(self.__table_name): + self.__update_table() + + def __update_table(self): + db_data = self.__data_to_db(User()) + self._update_table(self.__table_name, db_data) + self._insert(self.__table_name, db_data) @staticmethod - def __data_to_db(data: UserLogin) -> dict: + def __data_to_db(data: User) -> dict: db_data = data.dict() db_data["password"] = get_password_hash(db_data["password"]) return db_data @staticmethod - def __db_to_data(db_data: dict) -> UserLogin: - return UserLogin(**db_data) + def __db_to_data(db_data: dict) -> User: + return User(**db_data) - def auth_user(self, user: UserLogin) -> bool: - username = user.username - password = user.password - self._cursor.execute(f"SELECT * FROM user WHERE username='{username}'") + def get_user(self, username): + self._cursor.execute(f"SELECT * FROM {self.__table_name} WHERE username='{username}'") + result = self._cursor.fetchone() + if not result: + return None + db_data = dict(zip([x[0] for x in self._cursor.description], result)) + return self.__db_to_data(db_data) + + def auth_user(self, username, password) -> bool: + self._cursor.execute(f"SELECT username, password FROM {self.__table_name} WHERE username='{username}'") result = self._cursor.fetchone() if not result: raise HTTPException(status_code=404, detail="User not found") - if not verify_password(password, result[2]): + if not verify_password(password, result[1]): raise HTTPException(status_code=401, detail="Password error") return True - def update_user(self, user: UserUpdate): + def update_user(self, username, update_user: User): # Update username and password - username = user.username - new_username = user.new_username - password = user.password - new_password = user.new_password - self._cursor.execute(f"SELECT * FROM user WHERE username='{username}'") - result = self._cursor.fetchone() - if not result: - raise HTTPException(status_code=404, detail="User not found") - if not verify_password(password, result[2]): - raise HTTPException(status_code=401, detail="Password error") - self._cursor.execute(""" - UPDATE user - SET username=%s, password=%s - WHERE username=%s - """, (new_username, get_password_hash(new_password), username)) + new_username = update_user.username + new_password = update_user.password + self._cursor.execute(f""" + UPDATE {self.__table_name} + SET username = '{new_username}', password = '{get_password_hash(new_password)}' + WHERE username = '{username}' + """) self._conn.commit() - return True + + +if __name__ == "__main__": + with AuthDB() as db: + # db.update_user(UserLogin(username="admin", password="adminadmin"), User(username="admin", password="cica1234")) + db.update_user("admin", User(username="estrella", password="cica1234")) \ No newline at end of file diff --git a/src/module/models/user.py b/src/module/models/user.py index 253e7480..12031bfd 100644 --- a/src/module/models/user.py +++ b/src/module/models/user.py @@ -1,27 +1,9 @@ -from typing import Optional from pydantic import BaseModel, Field -from datetime import datetime -class UserBase(BaseModel): - username: str = Field(..., min_length=4, max_length=20, regex=r"^[a-zA-Z0-9_]+$") - - -class UserCreate(UserBase): - password: str = Field(..., min_length=8) - - -class UserUpdate(UserBase): - password: str = Field(..., min_length=8) - - -class User(UserBase): - id: int = Field(..., alias="_id") - password: str = Field(..., min_length=8) - - -class UserInDB(UserBase): - password: str = Field(..., min_length=8) +class User(BaseModel): + username: str = Field("admin", min_length=4, max_length=20, regex=r"^[a-zA-Z0-9_]+$") + password: str = Field("adminadmin", min_length=8) class UserLogin(BaseModel): diff --git a/src/module/security/jwt.py b/src/module/security/jwt.py index 814f91ff..aaeb04b0 100644 --- a/src/module/security/jwt.py +++ b/src/module/security/jwt.py @@ -1,11 +1,10 @@ from datetime import datetime, timedelta -from typing import Optional from passlib.context import CryptContext from jose import jwt, JWTError -# app_pwd_key = settings.KEY -# app_pwd_algorithm = settings.ALGORITHM +app_pwd_key = "auto_bangumi" +app_pwd_algorithm = "HS256" # Hashing 密码 app_pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") @@ -17,7 +16,7 @@ def create_access_token(data: dict, expires_delta: timedelta | None = None): if expires_delta: expire = datetime.utcnow() + expires_delta else: - expire = datetime.utcnow() + timedelta(minutes=120) + expire = datetime.utcnow() + timedelta(minutes=1440) to_encode.update({"exp": expire}) encoded_jwt = jwt.encode(to_encode, app_pwd_key, algorithm=app_pwd_algorithm) return encoded_jwt