diff --git a/app/api/endpoints/system.py b/app/api/endpoints/system.py index 7704c788..3b3e923d 100644 --- a/app/api/endpoints/system.py +++ b/app/api/endpoints/system.py @@ -12,7 +12,7 @@ from app.chain.search import SearchChain from app.chain.system import SystemChain from app.core.config import settings, global_vars from app.core.module import ModuleManager -from app.core.security import verify_token +from app.core.security import verify_token, verify_uri_token from app.db.models import User from app.db.systemconfig_oper import SystemConfigOper from app.db.user_oper import get_current_active_superuser @@ -137,16 +137,10 @@ def set_env_setting(env: dict, @router.get("/progress/{process_type}", summary="实时进度") -def get_progress(process_type: str, token: str): +def get_progress(process_type: str, _: schemas.TokenPayload = Depends(verify_token)): """ 实时获取处理进度,返回格式为SSE """ - if not token or not verify_token(token): - raise HTTPException( - status_code=403, - detail="认证失败!", - ) - progress = ProgressHelper() def event_generator(): @@ -192,16 +186,10 @@ def set_setting(key: str, value: Union[list, dict, bool, int, str] = None, @router.get("/message", summary="实时消息") -def get_message(token: str, role: str = "system"): +def get_message(role: str = "system", _: schemas.TokenPayload = Depends(verify_uri_token)): """ 实时获取系统消息,返回格式为SSE """ - if not token or not verify_token(token): - raise HTTPException( - status_code=403, - detail="认证失败!", - ) - message = MessageHelper() def event_generator(): @@ -216,18 +204,12 @@ def get_message(token: str, role: str = "system"): @router.get("/logging", summary="实时日志") -def get_logging(token: str, length: int = 50, logfile: str = "moviepilot.log"): +def get_logging(length: int = 50, logfile: str = "moviepilot.log", _: schemas.TokenPayload = Depends(verify_uri_token)): """ 实时获取系统日志 length = -1 时, 返回text/plain 否则 返回格式SSE """ - if not token or not verify_token(token): - raise HTTPException( - status_code=403, - detail="认证失败!", - ) - log_path = settings.LOG_PATH / logfile def log_generator(): diff --git a/app/core/security.py b/app/core/security.py index 1f96190e..2f8628c4 100644 --- a/app/core/security.py +++ b/app/core/security.py @@ -27,6 +27,9 @@ oauth2_scheme = OAuth2PasswordBearer( tokenUrl=f"{settings.API_V1_STR}/login/access-token" ) +# JWT TOKEN 通过 QUERY 认证 +jwt_token_query = APIKeyQuery(name="token", auto_error=False, scheme_name="jwt_token_query") + # API TOKEN 通过 QUERY 认证 api_token_query = APIKeyQuery(name="token", auto_error=False, scheme_name="api_token_query") @@ -75,13 +78,13 @@ def create_access_token( return encoded_jwt -def verify_token(token: str = Security(oauth2_scheme)) -> schemas.TokenPayload: - """ - 使用 JWT Token 进行身份认证并解析 Token 的内容 - :param token: JWT 令牌,从请求的 Authorization 头部获取 - :return: 包含用户身份信息的 Token 负载数据 - :raises HTTPException: 如果令牌无效或解码失败,抛出 403 错误 +def __verify_token(token: str) -> schemas.TokenPayload: """ + 使用 JWT Token 进行身份认证并解析 Token 的内容 + :param token: JWT 令牌,从请求的 Authorization 头部获取 + :return: 包含用户身份信息的 Token 负载数据 + :raises HTTPException: 如果令牌无效或解码失败,抛出 403 错误 + """ try: payload = jwt.decode( token, settings.SECRET_KEY, algorithms=[ALGORITHM] @@ -94,6 +97,26 @@ def verify_token(token: str = Security(oauth2_scheme)) -> schemas.TokenPayload: ) +def verify_token(token: str = Security(oauth2_scheme)) -> schemas.TokenPayload: + """ + 使用 JWT Token 进行身份认证并解析 Token 的内容 + :param token: JWT 令牌,从请求的 Authorization 头部获取 + :return: 包含用户身份信息的 Token 负载数据 + :raises HTTPException: 如果令牌无效或解码失败,抛出 403 错误 + """ + return __verify_token(token) + + +def verify_uri_token(token: str = Security(jwt_token_query)) -> schemas.TokenPayload: + """ + 使用 JWT Token 进行身份认证并解析 Token 的内容 + :param token: JWT 令牌,从请求的 Authorization 头部获取 + :return: 包含用户身份信息的 Token 负载数据 + :raises HTTPException: 如果令牌无效或解码失败,抛出 403 错误 + """ + return __verify_token(token) + + def __get_api_token( token_query: Annotated[str | None, Security(api_token_query)] = None ) -> str: @@ -153,15 +176,6 @@ def verify_apikey(apikey: str = Security(__get_api_key)) -> str: return __verify_key(apikey, settings.API_TOKEN, "API_KEY") -def verify_uri_token(token: str = Security(__get_api_token)) -> str: - """ - 使用 API Token 进行身份认证 - :param token: API Token,从 URL 查询参数中获取 - :return: 返回校验通过的 API Token - """ - return __verify_key(token, settings.API_TOKEN, "API_TOKEN") - - def verify_password(plain_password: str, hashed_password: str) -> bool: return pwd_context.verify(plain_password, hashed_password)