From 615f50c51a50b7b6da0b3ae38cae06110cfeea6e Mon Sep 17 00:00:00 2001 From: EstrellaXD Date: Fri, 12 May 2023 11:09:27 +0800 Subject: [PATCH] Init auth module --- src/module/api/auth.py | 53 ++++++++++++++++++++++++++++++++ src/module/database/torrent.py | 0 src/module/database/user.py | 42 +++++++++++++++++++++++++ src/module/models/__init__.py | 1 + src/module/models/user.py | 39 ++++++++++++++++++++++++ src/module/security/__init__.py | 0 src/module/security/jwt.py | 54 +++++++++++++++++++++++++++++++++ 7 files changed, 189 insertions(+) create mode 100644 src/module/api/auth.py create mode 100644 src/module/database/torrent.py create mode 100644 src/module/database/user.py create mode 100644 src/module/models/user.py create mode 100644 src/module/security/__init__.py create mode 100644 src/module/security/jwt.py diff --git a/src/module/api/auth.py b/src/module/api/auth.py new file mode 100644 index 00000000..1fe14590 --- /dev/null +++ b/src/module/api/auth.py @@ -0,0 +1,53 @@ +from fastapi import APIRouter, Depends, HTTPException, status +from fastapi.security import OAuth2PasswordRequestForm + +from module.database.user import AuthDB +from module.security.jwt import decode_token, oauth2_scheme + +from .api import router + + +async def get_current_user(token: str = Depends(oauth2_scheme)): + if not token: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token") + payload = decode_token(token) + if not payload: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token") + username = payload.get("sub") + user = user_db.get_user(username) + if not user: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid username") + return user + + +async def get_token_data(token: str = Depends(oauth2_scheme)): + payload = decode_token(token) + if not payload: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token") + return payload + + +@router.get("/api/v1/auth/login", response_model=dict, tags=["login"]) +async def login(form_data: OAuth2PasswordRequestForm = Depends()): + username = form_data.username + password = form_data.password + with AuthDB() as db: + if not db.authenticate(username, password): + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid username or password") + token = db.create_access_token(username) + return {"access_token": token, "token_type": "bearer"} + + +@router.get("/api/v1/auth/logout", response_model=dict, tags=["login"]) +async def logout(token_data: dict = Depends(get_token_data)): + pass + +@router.get("/api/v1/auth/refresh", response_model=dict, tags=["login"]) +async def refresh( + current_user: User = Depends(get_token_data), +): + if not current_user: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token") + token = create_access_token(current_user) + return {"access_token": token, "token_type": "bearer"} + diff --git a/src/module/database/torrent.py b/src/module/database/torrent.py new file mode 100644 index 00000000..e69de29b diff --git a/src/module/database/user.py b/src/module/database/user.py new file mode 100644 index 00000000..d4964e66 --- /dev/null +++ b/src/module/database/user.py @@ -0,0 +1,42 @@ +from fastapi import HTTPException + +from .connector import DataConnector + +from module.security.jwt import get_password_hash, verify_password +from module.models import UserLogin + + +class AuthDB(DataConnector): + def update_table(self): + pass + + def auth_user(self, user: UserLogin) -> bool: + username = user.username + password = user.password + self._cursor.execute(f"SELECT * FROM user WHERE username='{username}'") + result = self._cursor.fetchone() + if not result: + raise HTTPException(status_code=404, detail="User not found") + if not verify_password(password, result[2]): + raise HTTPException(status_code=401, detail="Password error") + return True + + def update_user(self, user: UserUpdate): + # Update username and password + username = user.username + new_username = user.new_username + password = user.password + new_password = user.new_password + self._cursor.execute(f"SELECT * FROM user WHERE username='{username}'") + result = self._cursor.fetchone() + if not result: + raise HTTPException(status_code=404, detail="User not found") + if not verify_password(password, result[2]): + raise HTTPException(status_code=401, detail="Password error") + self._cursor.execute(""" + UPDATE user + SET username=%s, password=%s + WHERE username=%s + """, (new_username, get_password_hash(new_password), username)) + self._conn.commit() + return True diff --git a/src/module/models/__init__.py b/src/module/models/__init__.py index 4cf8bf6a..96b1034e 100644 --- a/src/module/models/__init__.py +++ b/src/module/models/__init__.py @@ -1,3 +1,4 @@ from .bangumi import * from .config import Config from .torrent import EpisodeFile, SubtitleFile +from .user import UserLogin diff --git a/src/module/models/user.py b/src/module/models/user.py new file mode 100644 index 00000000..4e976be3 --- /dev/null +++ b/src/module/models/user.py @@ -0,0 +1,39 @@ +from typing import Optional +from pydantic import BaseModel, Field +from datetime import datetime + + +class UserBase(BaseModel): + username: str = Field(..., min_length=4, max_length=20, regex=r"^[a-zA-Z0-9_]+$") + + +class UserCreate(UserBase): + password: str = Field(..., min_length=8) + + +class UserUpdate(UserBase): + password: str = Field(..., min_length=8) + + +class User(UserBase): + user_id: int + password: str = Field(..., min_length=8) + id: str = Field(..., alias="_id") + + +class UserInDB(UserBase): + password: str = Field(..., min_length=8) + + +class UserLogin(BaseModel): + username: str + password: str = Field(..., min_length=8) + + +class Token(BaseModel): + token: str + token_type: str + + +class TokenData(BaseModel): + username: str | None = None \ No newline at end of file diff --git a/src/module/security/__init__.py b/src/module/security/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/module/security/jwt.py b/src/module/security/jwt.py new file mode 100644 index 00000000..814f91ff --- /dev/null +++ b/src/module/security/jwt.py @@ -0,0 +1,54 @@ +from datetime import datetime, timedelta +from typing import Optional +from passlib.context import CryptContext +from jose import jwt, JWTError + + +# app_pwd_key = settings.KEY +# app_pwd_algorithm = settings.ALGORITHM + +# Hashing 密码 +app_pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") + + +# 创建 JWT Token +def create_access_token(data: dict, expires_delta: timedelta | None = None): + to_encode = data.copy() + if expires_delta: + expire = datetime.utcnow() + expires_delta + else: + expire = datetime.utcnow() + timedelta(minutes=120) + to_encode.update({"exp": expire}) + encoded_jwt = jwt.encode(to_encode, app_pwd_key, algorithm=app_pwd_algorithm) + return encoded_jwt + + +# 解码 Token +def decode_token(token: str): + try: + payload = jwt.decode(token, app_pwd_key, algorithms=[app_pwd_algorithm]) + username = payload.get("sub") + if username is None: + return None + return payload + except JWTError as e: + raise e + + +def verify_token(token: str): + token_data = decode_token(token) + if token_data is None: + return None + expires = token_data.get("exp") + if datetime.utcnow() >= datetime.fromtimestamp(expires): + raise JWTError("Token expired") + return create_access_token(data={"sub": token_data.get("sub")}) + + +# 密码加密&验证 +def verify_password(plain_password, hashed_password): + return app_pwd_context.verify(plain_password, hashed_password) + + +def get_password_hash(password): + return app_pwd_context.hash(password)