feat(security): replace validation with Depends for system endpoints

This commit is contained in:
InfinityPacer
2024-10-08 18:12:40 +08:00
parent 4af57d9857
commit 4dd146d1c8
2 changed files with 33 additions and 37 deletions

View File

@@ -12,7 +12,7 @@ from app.chain.search import SearchChain
from app.chain.system import SystemChain
from app.core.config import settings, global_vars
from app.core.module import ModuleManager
from app.core.security import verify_token
from app.core.security import verify_token, verify_uri_token
from app.db.models import User
from app.db.systemconfig_oper import SystemConfigOper
from app.db.user_oper import get_current_active_superuser
@@ -137,16 +137,10 @@ def set_env_setting(env: dict,
@router.get("/progress/{process_type}", summary="实时进度")
def get_progress(process_type: str, token: str):
def get_progress(process_type: str, _: schemas.TokenPayload = Depends(verify_token)):
"""
实时获取处理进度返回格式为SSE
"""
if not token or not verify_token(token):
raise HTTPException(
status_code=403,
detail="认证失败!",
)
progress = ProgressHelper()
def event_generator():
@@ -192,16 +186,10 @@ def set_setting(key: str, value: Union[list, dict, bool, int, str] = None,
@router.get("/message", summary="实时消息")
def get_message(token: str, role: str = "system"):
def get_message(role: str = "system", _: schemas.TokenPayload = Depends(verify_uri_token)):
"""
实时获取系统消息返回格式为SSE
"""
if not token or not verify_token(token):
raise HTTPException(
status_code=403,
detail="认证失败!",
)
message = MessageHelper()
def event_generator():
@@ -216,18 +204,12 @@ def get_message(token: str, role: str = "system"):
@router.get("/logging", summary="实时日志")
def get_logging(token: str, length: int = 50, logfile: str = "moviepilot.log"):
def get_logging(length: int = 50, logfile: str = "moviepilot.log", _: schemas.TokenPayload = Depends(verify_uri_token)):
"""
实时获取系统日志
length = -1 时, 返回text/plain
否则 返回格式SSE
"""
if not token or not verify_token(token):
raise HTTPException(
status_code=403,
detail="认证失败!",
)
log_path = settings.LOG_PATH / logfile
def log_generator():

View File

@@ -27,6 +27,9 @@ oauth2_scheme = OAuth2PasswordBearer(
tokenUrl=f"{settings.API_V1_STR}/login/access-token"
)
# JWT TOKEN 通过 QUERY 认证
jwt_token_query = APIKeyQuery(name="token", auto_error=False, scheme_name="jwt_token_query")
# API TOKEN 通过 QUERY 认证
api_token_query = APIKeyQuery(name="token", auto_error=False, scheme_name="api_token_query")
@@ -75,13 +78,13 @@ def create_access_token(
return encoded_jwt
def verify_token(token: str = Security(oauth2_scheme)) -> schemas.TokenPayload:
"""
使用 JWT Token 进行身份认证并解析 Token 的内容
:param token: JWT 令牌,从请求的 Authorization 头部获取
:return: 包含用户身份信息的 Token 负载数据
:raises HTTPException: 如果令牌无效或解码失败,抛出 403 错误
def __verify_token(token: str) -> schemas.TokenPayload:
"""
使用 JWT Token 进行身份认证并解析 Token 的内容
:param token: JWT 令牌,从请求的 Authorization 头部获取
:return: 包含用户身份信息的 Token 负载数据
:raises HTTPException: 如果令牌无效或解码失败,抛出 403 错误
"""
try:
payload = jwt.decode(
token, settings.SECRET_KEY, algorithms=[ALGORITHM]
@@ -94,6 +97,26 @@ def verify_token(token: str = Security(oauth2_scheme)) -> schemas.TokenPayload:
)
def verify_token(token: str = Security(oauth2_scheme)) -> schemas.TokenPayload:
"""
使用 JWT Token 进行身份认证并解析 Token 的内容
:param token: JWT 令牌,从请求的 Authorization 头部获取
:return: 包含用户身份信息的 Token 负载数据
:raises HTTPException: 如果令牌无效或解码失败,抛出 403 错误
"""
return __verify_token(token)
def verify_uri_token(token: str = Security(jwt_token_query)) -> schemas.TokenPayload:
"""
使用 JWT Token 进行身份认证并解析 Token 的内容
:param token: JWT 令牌,从请求的 Authorization 头部获取
:return: 包含用户身份信息的 Token 负载数据
:raises HTTPException: 如果令牌无效或解码失败,抛出 403 错误
"""
return __verify_token(token)
def __get_api_token(
token_query: Annotated[str | None, Security(api_token_query)] = None
) -> str:
@@ -153,15 +176,6 @@ def verify_apikey(apikey: str = Security(__get_api_key)) -> str:
return __verify_key(apikey, settings.API_TOKEN, "API_KEY")
def verify_uri_token(token: str = Security(__get_api_token)) -> str:
"""
使用 API Token 进行身份认证
:param token: API Token从 URL 查询参数中获取
:return: 返回校验通过的 API Token
"""
return __verify_key(token, settings.API_TOKEN, "API_TOKEN")
def verify_password(plain_password: str, hashed_password: str) -> bool:
return pwd_context.verify(plain_password, hashed_password)