From 5c0584550046ea20feecc744130dd64c1bfcfc16 Mon Sep 17 00:00:00 2001 From: InfinityPacer <160988576+InfinityPacer@users.noreply.github.com> Date: Mon, 7 Oct 2024 16:35:39 +0800 Subject: [PATCH] refactor(security): replace Depends with Security and define schemes --- app/core/security.py | 126 ++++++++++++++++++++++++++++++------------- 1 file changed, 88 insertions(+), 38 deletions(-) diff --git a/app/core/security.py b/app/core/security.py index 06bf02fa..1f96190e 100644 --- a/app/core/security.py +++ b/app/core/security.py @@ -5,39 +5,64 @@ import json import os import traceback from datetime import datetime, timedelta -from typing import Any, Union, Optional, Annotated +from typing import Any, Union, Annotated, Optional + import jwt from Crypto.Cipher import AES from Crypto.Util.Padding import pad -from fastapi import HTTPException, status, Depends, Header -from fastapi.security import OAuth2PasswordBearer +from cryptography.fernet import Fernet +from fastapi import HTTPException, status, Security +from fastapi.security import OAuth2PasswordBearer, APIKeyHeader, APIKeyQuery from passlib.context import CryptContext from app import schemas from app.core.config import settings -from cryptography.fernet import Fernet - from app.log import logger pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") ALGORITHM = "HS256" -# Token认证 -reusable_oauth2 = OAuth2PasswordBearer( +# OAuth2PasswordBearer 用于 JWT Token 认证 +oauth2_scheme = OAuth2PasswordBearer( tokenUrl=f"{settings.API_V1_STR}/login/access-token" ) +# API TOKEN 通过 QUERY 认证 +api_token_query = APIKeyQuery(name="token", auto_error=False, scheme_name="api_token_query") + +# API KEY 通过 Header 认证 +api_key_header = APIKeyHeader(name="X-API-KEY", auto_error=False, scheme_name="api_key_header") + +# API KEY 通过 QUERY 认证 +api_key_query = APIKeyQuery(name="apikey", auto_error=False, scheme_name="api_key_query") + def create_access_token( - userid: Union[str, Any], username: str, super_user: bool = False, - expires_delta: timedelta = None, level: int = 1 + userid: Union[str, Any], + username: str, + super_user: bool = False, + expires_delta: Optional[timedelta] = None, + level: int = 1 ) -> str: - if expires_delta: + """ + 创建 JWT 访问令牌,包含用户 ID、用户名、是否为超级用户以及权限等级 + :param userid: 用户的唯一标识符,通常是字符串或整数 + :param username: 用户名,用于标识用户的账户名 + :param super_user: 是否为超级用户,默认值为 False + :param expires_delta: 令牌的有效期时长,如果不提供则使用默认过期时间 + :param level: 用户的权限级别,默认为 1 + :return: 编码后的 JWT 令牌字符串 + :raises ValueError: 如果 expires_delta 为负数 + """ + if expires_delta is not None: + if expires_delta.total_seconds() <= 0: + raise ValueError("过期时间必须为正数") expire = datetime.utcnow() + expires_delta else: expire = datetime.utcnow() + timedelta( minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES ) + to_encode = { "exp": expire, "sub": str(userid), @@ -45,11 +70,18 @@ def create_access_token( "super_user": super_user, "level": level } + encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=ALGORITHM) return encoded_jwt -def verify_token(token: str = Depends(reusable_oauth2)) -> schemas.TokenPayload: +def verify_token(token: str = Security(oauth2_scheme)) -> schemas.TokenPayload: + """ + 使用 JWT Token 进行身份认证并解析 Token 的内容 + :param token: JWT 令牌,从请求的 Authorization 头部获取 + :return: 包含用户身份信息的 Token 负载数据 + :raises HTTPException: 如果令牌无效或解码失败,抛出 403 错误 + """ try: payload = jwt.decode( token, settings.SECRET_KEY, algorithms=[ALGORITHM] @@ -62,54 +94,72 @@ def verify_token(token: str = Depends(reusable_oauth2)) -> schemas.TokenPayload: ) -def __get_token(token: str = None) -> str: +def __get_api_token( + token_query: Annotated[str | None, Security(api_token_query)] = None +) -> str: """ - 从请求URL中获取token + 从 URL 查询参数中获取 API Token + :param token_query: 从 URL 中的 `token` 查询参数获取 API Token + :return: 返回获取到的 API Token,若无则返回 None """ - return token + return token_query -def __get_apikey(apikey: str = None, x_api_key: Annotated[str | None, Header()] = None) -> str: +def __get_api_key( + key_query: Annotated[str | None, Security(api_key_query)] = None, + key_header: Annotated[str | None, Security(api_key_header)] = None +) -> str: """ - 从请求URL中获取apikey + 从 URL 查询参数或请求头部获取 API Key,优先使用 URL 参数 + :param key_query: URL 中的 `apikey` 查询参数 + :param key_header: 请求头中的 `X-API-KEY` 参数 + :return: 返回从 URL 或请求头中获取的 API Key,若无则返回 None """ - return apikey or x_api_key + return key_query or key_header -def verify_apitoken(token: str = Depends(__get_token)) -> str: +def __verify_key(key: str, expected_key: str, key_type: str) -> str: """ - 通过依赖项使用token进行身份认证 + 通用的 API Key 或 Token 验证函数 + :param key: 从请求中获取的 API Key 或 Token + :param expected_key: 系统配置中的期望值,用于验证的 API Key 或 Token + :param key_type: 键的类型(例如 "API_KEY" 或 "API_TOKEN"),用于错误消息 + :return: 返回校验通过的 API Key 或 Token + :raises HTTPException: 如果校验不通过,抛出 401 错误 """ - if token != settings.API_TOKEN: + if key != expected_key: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, - detail="token校验不通过" + detail=f"{key_type} 校验不通过" ) - return token + return key -def verify_apikey(apikey: str = Depends(__get_apikey)) -> str: +def verify_apitoken(token: str = Security(__get_api_token)) -> str: """ - 通过依赖项使用apikey进行身份认证 + 使用 API Token 进行身份认证 + :param token: API Token,从 URL 查询参数中获取 + :return: 返回校验通过的 API Token """ - if apikey != settings.API_TOKEN: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="apikey校验不通过" - ) - return apikey + return __verify_key(token, settings.API_TOKEN, "API_TOKEN") -def verify_uri_token(token: str = Depends(__get_token)) -> str: +def verify_apikey(apikey: str = Security(__get_api_key)) -> str: """ - 通过依赖项使用token进行身份认证 + 使用 API Key 进行身份认证 + :param apikey: API Key,从 URL 查询参数或请求头中获取 + :return: 返回校验通过的 API Key """ - if not verify_token(token): - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="token校验不通过" - ) - return token + return __verify_key(apikey, settings.API_TOKEN, "API_KEY") + + +def verify_uri_token(token: str = Security(__get_api_token)) -> str: + """ + 使用 API Token 进行身份认证 + :param token: API Token,从 URL 查询参数中获取 + :return: 返回校验通过的 API Token + """ + return __verify_key(token, settings.API_TOKEN, "API_TOKEN") def verify_password(plain_password: str, hashed_password: str) -> bool: