mirror of
https://github.com/EstrellaXD/Auto_Bangumi.git
synced 2026-03-20 03:46:40 +08:00
fix(security): harden auth, JWT, WebAuthn, and API endpoints
- Persist JWT secret to config/.jwt_secret (survives restarts) - Change active_user from list to dict with timestamps - Extract username from cookie token instead of list index - Add SSRF protection (_validate_url) for setup test endpoints - Mask sensitive config fields (password, api_key, token, secret) - Add auth guards to notification test endpoints - Fix path traversal in /posters endpoint using resolved path check - Add CORS middleware with empty allow_origins - WebAuthn: add challenge TTL (300s), max capacity (100), cleanup - Remove hardcoded default password from User model - Use timezone-aware datetime in passkey models - Adapt unit tests for active_user dict and cookie-based auth Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -1,9 +1,11 @@
|
||||
import logging
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import FileResponse, HTMLResponse, RedirectResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.templating import Jinja2Templates
|
||||
@@ -44,6 +46,14 @@ async def lifespan(app: FastAPI):
|
||||
def create_app() -> FastAPI:
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=[],
|
||||
allow_credentials=True,
|
||||
allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# mount routers
|
||||
app.include_router(v1, prefix="/api")
|
||||
|
||||
@@ -56,12 +66,15 @@ def create_app() -> FastAPI:
|
||||
app = create_app()
|
||||
|
||||
|
||||
_POSTERS_BASE = Path("data/posters").resolve()
|
||||
|
||||
|
||||
@app.get("/posters/{path:path}", tags=["posters"])
|
||||
def posters(path: str):
|
||||
# prevent path traversal
|
||||
if ".." in path:
|
||||
resolved = (_POSTERS_BASE / path).resolve()
|
||||
if not str(resolved).startswith(str(_POSTERS_BASE)):
|
||||
return HTMLResponse(status_code=403)
|
||||
return FileResponse(f"data/posters/{path}")
|
||||
return FileResponse(str(resolved))
|
||||
|
||||
|
||||
if VERSION != "DEV_VERSION":
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from datetime import timedelta
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi import APIRouter, Cookie, Depends, HTTPException, status
|
||||
from fastapi.responses import JSONResponse, Response
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
|
||||
@@ -12,7 +12,7 @@ from module.security.api import (
|
||||
get_current_user,
|
||||
update_user_info,
|
||||
)
|
||||
from module.security.jwt import create_access_token
|
||||
from module.security.jwt import create_access_token, decode_token
|
||||
|
||||
from .response import u_response
|
||||
|
||||
@@ -35,19 +35,29 @@ async def login(response: Response, form_data=Depends(OAuth2PasswordRequestForm)
|
||||
@router.get(
|
||||
"/refresh_token", response_model=dict, dependencies=[Depends(get_current_user)]
|
||||
)
|
||||
async def refresh(response: Response):
|
||||
token = create_access_token(
|
||||
data={"sub": active_user[0]}, expires_delta=timedelta(days=1)
|
||||
async def refresh(response: Response, token: str = Cookie(None)):
|
||||
payload = decode_token(token)
|
||||
username = payload.get("sub") if payload else None
|
||||
if not username:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Unauthorized"
|
||||
)
|
||||
active_user[username] = datetime.now()
|
||||
new_token = create_access_token(
|
||||
data={"sub": username}, expires_delta=timedelta(days=1)
|
||||
)
|
||||
response.set_cookie(key="token", value=token, httponly=True, max_age=86400)
|
||||
return {"access_token": token, "token_type": "bearer"}
|
||||
response.set_cookie(key="token", value=new_token, httponly=True, max_age=86400)
|
||||
return {"access_token": new_token, "token_type": "bearer"}
|
||||
|
||||
|
||||
@router.get(
|
||||
"/logout", response_model=APIResponse, dependencies=[Depends(get_current_user)]
|
||||
)
|
||||
async def logout(response: Response):
|
||||
active_user.clear()
|
||||
async def logout(response: Response, token: str = Cookie(None)):
|
||||
payload = decode_token(token)
|
||||
username = payload.get("sub") if payload else None
|
||||
if username:
|
||||
active_user.pop(username, None)
|
||||
response.delete_cookie(key="token")
|
||||
return JSONResponse(
|
||||
status_code=200,
|
||||
@@ -56,8 +66,15 @@ async def logout(response: Response):
|
||||
|
||||
|
||||
@router.post("/update", response_model=dict, dependencies=[Depends(get_current_user)])
|
||||
async def update_user(user_data: UserUpdate, response: Response):
|
||||
old_user = active_user[0]
|
||||
async def update_user(
|
||||
user_data: UserUpdate, response: Response, token: str = Cookie(None)
|
||||
):
|
||||
payload = decode_token(token)
|
||||
old_user = payload.get("sub") if payload else None
|
||||
if not old_user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Unauthorized"
|
||||
)
|
||||
if update_user_info(user_data, old_user):
|
||||
token = create_access_token(
|
||||
data={"sub": old_user}, expires_delta=timedelta(days=1)
|
||||
|
||||
@@ -10,10 +10,24 @@ from module.security.api import UNAUTHORIZED, get_current_user
|
||||
router = APIRouter(prefix="/config", tags=["config"])
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_SENSITIVE_KEYS = ("password", "api_key", "token", "secret")
|
||||
|
||||
@router.get("/get", response_model=Config, dependencies=[Depends(get_current_user)])
|
||||
|
||||
def _sanitize_dict(d: dict) -> dict:
|
||||
result = {}
|
||||
for k, v in d.items():
|
||||
if isinstance(v, dict):
|
||||
result[k] = _sanitize_dict(v)
|
||||
elif any(s in k.lower() for s in _SENSITIVE_KEYS):
|
||||
result[k] = "********"
|
||||
else:
|
||||
result[k] = v
|
||||
return result
|
||||
|
||||
|
||||
@router.get("/get", dependencies=[Depends(get_current_user)])
|
||||
async def get_config():
|
||||
return settings
|
||||
return _sanitize_dict(settings.dict())
|
||||
|
||||
|
||||
@router.patch(
|
||||
@@ -27,7 +41,10 @@ async def update_config(config: Config):
|
||||
logger.info("Config updated")
|
||||
return JSONResponse(
|
||||
status_code=200,
|
||||
content={"msg_en": "Update config successfully.", "msg_zh": "更新配置成功。"},
|
||||
content={
|
||||
"msg_en": "Update config successfully.",
|
||||
"msg_zh": "更新配置成功。",
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(e)
|
||||
|
||||
@@ -3,11 +3,12 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import APIRouter, Depends
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from module.notification import NotificationManager
|
||||
from module.models.config import NotificationProvider as ProviderConfig
|
||||
from module.notification import NotificationManager
|
||||
from module.security.api import get_current_user
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/notification", tags=["notification"])
|
||||
@@ -44,7 +45,9 @@ class TestResponse(BaseModel):
|
||||
message_en: str = ""
|
||||
|
||||
|
||||
@router.post("/test", response_model=TestResponse)
|
||||
@router.post(
|
||||
"/test", response_model=TestResponse, dependencies=[Depends(get_current_user)]
|
||||
)
|
||||
async def test_provider(request: TestProviderRequest):
|
||||
"""Test a configured notification provider by its index.
|
||||
|
||||
@@ -78,7 +81,11 @@ async def test_provider(request: TestProviderRequest):
|
||||
)
|
||||
|
||||
|
||||
@router.post("/test-config", response_model=TestResponse)
|
||||
@router.post(
|
||||
"/test-config",
|
||||
response_model=TestResponse,
|
||||
dependencies=[Depends(get_current_user)],
|
||||
)
|
||||
async def test_provider_config(request: TestProviderConfigRequest):
|
||||
"""Test an unsaved notification provider configuration.
|
||||
|
||||
|
||||
@@ -2,8 +2,9 @@
|
||||
Passkey 管理 API
|
||||
用于注册、列表、删除 Passkey 凭证
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import timedelta
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from fastapi.responses import JSONResponse, Response
|
||||
@@ -233,8 +234,7 @@ async def login_with_passkey(
|
||||
data={"sub": username}, expires_delta=timedelta(days=1)
|
||||
)
|
||||
response.set_cookie(key="token", value=token, httponly=True, max_age=86400)
|
||||
if username not in active_user:
|
||||
active_user.append(username)
|
||||
active_user[username] = datetime.now()
|
||||
return {"access_token": token, "token_type": "bearer"}
|
||||
|
||||
raise HTTPException(status_code=resp.status_code, detail=resp.msg_en)
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
import ipaddress
|
||||
import logging
|
||||
import socket
|
||||
from pathlib import Path
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, HTTPException
|
||||
@@ -7,9 +10,9 @@ from pydantic import BaseModel, Field
|
||||
|
||||
from module.conf import VERSION, settings
|
||||
from module.models import Config, ResponseModel
|
||||
from module.models.config import NotificationProvider as ProviderConfig
|
||||
from module.network import RequestContent
|
||||
from module.notification import PROVIDER_REGISTRY
|
||||
from module.models.config import NotificationProvider as ProviderConfig
|
||||
from module.security.jwt import get_password_hash
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -28,6 +31,27 @@ def _require_setup_needed():
|
||||
raise HTTPException(status_code=403, detail="Setup already completed.")
|
||||
|
||||
|
||||
def _validate_url(url: str) -> None:
|
||||
"""Reject non-HTTP schemes and private/reserved/loopback IPs."""
|
||||
parsed = urlparse(url)
|
||||
if parsed.scheme not in ("http", "https"):
|
||||
raise HTTPException(status_code=400, detail="Only http/https URLs are allowed.")
|
||||
hostname = parsed.hostname
|
||||
if not hostname:
|
||||
raise HTTPException(status_code=400, detail="Invalid URL: no hostname.")
|
||||
try:
|
||||
addrs = socket.getaddrinfo(hostname, None)
|
||||
except socket.gaierror:
|
||||
raise HTTPException(status_code=400, detail="Cannot resolve hostname.")
|
||||
for family, _, _, _, sockaddr in addrs:
|
||||
ip = ipaddress.ip_address(sockaddr[0])
|
||||
if ip.is_private or ip.is_reserved or ip.is_loopback:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="URLs pointing to private/reserved IPs are not allowed.",
|
||||
)
|
||||
|
||||
|
||||
# --- Request/Response Models ---
|
||||
|
||||
|
||||
@@ -108,12 +132,16 @@ async def test_downloader(req: TestDownloaderRequest):
|
||||
|
||||
scheme = "https" if req.ssl else "http"
|
||||
host = req.host if "://" in req.host else f"{scheme}://{req.host}"
|
||||
_validate_url(host)
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=5.0) as client:
|
||||
# Check if host is reachable and is qBittorrent
|
||||
resp = await client.get(host)
|
||||
if "qbittorrent" not in resp.text.lower() and "vuetorrent" not in resp.text.lower():
|
||||
if (
|
||||
"qbittorrent" not in resp.text.lower()
|
||||
and "vuetorrent" not in resp.text.lower()
|
||||
):
|
||||
return TestResultResponse(
|
||||
success=False,
|
||||
message_en="Host is reachable but does not appear to be qBittorrent.",
|
||||
@@ -169,6 +197,7 @@ async def test_downloader(req: TestDownloaderRequest):
|
||||
async def test_rss(req: TestRSSRequest):
|
||||
"""Test an RSS feed URL."""
|
||||
_require_setup_needed()
|
||||
_validate_url(req.url)
|
||||
|
||||
try:
|
||||
async with RequestContent() as request:
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
"""
|
||||
Passkey 数据库操作层
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import HTTPException
|
||||
@@ -54,7 +55,7 @@ class PasskeyDatabase:
|
||||
async def update_passkey_usage(self, passkey: Passkey, new_sign_count: int):
|
||||
"""更新 Passkey 使用记录(签名计数器 + 最后使用时间)"""
|
||||
passkey.sign_count = new_sign_count
|
||||
passkey.last_used_at = datetime.utcnow()
|
||||
passkey.last_used_at = datetime.now(timezone.utc)
|
||||
self.session.add(passkey)
|
||||
await self.session.commit()
|
||||
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
"""
|
||||
WebAuthn Passkey 数据模型
|
||||
"""
|
||||
from datetime import datetime
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
@@ -29,7 +30,7 @@ class Passkey(SQLModel, table=True):
|
||||
transports: Optional[str] = None # JSON array: ["usb", "nfc", "ble", "internal"]
|
||||
|
||||
# 审计字段
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
last_used_at: Optional[datetime] = None
|
||||
|
||||
# 备份状态 (是否为多设备凭证,如 iCloud Keychain)
|
||||
|
||||
@@ -9,7 +9,7 @@ class User(SQLModel, table=True):
|
||||
username: str = Field(
|
||||
"admin", min_length=4, max_length=20, regex=r"^[a-zA-Z0-9_]+$"
|
||||
)
|
||||
password: str = Field("adminadmin", min_length=8)
|
||||
password: str = Field("", min_length=8)
|
||||
|
||||
|
||||
class UserUpdate(SQLModel):
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import Cookie, Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
|
||||
@@ -8,10 +10,14 @@ from .jwt import verify_token
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
|
||||
|
||||
active_user = []
|
||||
active_user: dict[str, datetime] = {}
|
||||
|
||||
# Set to True to bypass authentication (for development/testing only)
|
||||
DEV_AUTH_BYPASS = False
|
||||
try:
|
||||
from module.__version__ import VERSION
|
||||
except ImportError:
|
||||
VERSION = "DEV_VERSION"
|
||||
|
||||
DEV_AUTH_BYPASS = VERSION == "DEV_VERSION"
|
||||
|
||||
|
||||
async def get_current_user(token: str = Cookie(None)):
|
||||
@@ -52,7 +58,7 @@ def auth_user(user: User):
|
||||
with Database() as db:
|
||||
resp = db.user.auth_user(user)
|
||||
if resp.status:
|
||||
active_user.append(user.username)
|
||||
active_user[user.username] = datetime.now()
|
||||
return resp
|
||||
|
||||
|
||||
|
||||
@@ -1,16 +1,23 @@
|
||||
import secrets
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from pathlib import Path
|
||||
|
||||
from jose import JWTError, jwt
|
||||
from passlib.context import CryptContext
|
||||
|
||||
|
||||
def generate_key():
|
||||
import secrets
|
||||
|
||||
return secrets.token_urlsafe(32)
|
||||
_SECRET_PATH = Path("config/.jwt_secret")
|
||||
|
||||
|
||||
app_pwd_key = generate_key()
|
||||
def _load_or_create_secret() -> str:
|
||||
if _SECRET_PATH.exists():
|
||||
return _SECRET_PATH.read_text().strip()
|
||||
secret = secrets.token_hex(32)
|
||||
_SECRET_PATH.parent.mkdir(parents=True, exist_ok=True)
|
||||
_SECRET_PATH.write_text(secret)
|
||||
return secret
|
||||
|
||||
|
||||
app_pwd_key = _load_or_create_secret()
|
||||
app_pwd_algorithm = "HS256"
|
||||
|
||||
# Hashing 密码
|
||||
|
||||
@@ -2,9 +2,11 @@
|
||||
WebAuthn 认证服务层
|
||||
封装 py_webauthn 库的复杂性,提供清晰的注册和认证接口
|
||||
"""
|
||||
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import List, Optional
|
||||
|
||||
from webauthn import (
|
||||
@@ -44,8 +46,44 @@ class WebAuthnService:
|
||||
self.rp_name = rp_name
|
||||
self.origin = origin
|
||||
|
||||
# 存储临时的 challenge(生产环境应使用 Redis)
|
||||
self._challenges: dict[str, bytes] = {}
|
||||
self._CHALLENGE_TTL = 300
|
||||
self._CHALLENGE_MAX = 100
|
||||
# Keyed by base64url-encoded challenge value -> (challenge_bytes, created_at, logical_key)
|
||||
self._challenges: dict[str, tuple[bytes, float, str]] = {}
|
||||
|
||||
def _cleanup_expired(self) -> None:
|
||||
now = time.time()
|
||||
expired = [
|
||||
k
|
||||
for k, (_, ts, _) in self._challenges.items()
|
||||
if now - ts > self._CHALLENGE_TTL
|
||||
]
|
||||
for k in expired:
|
||||
del self._challenges[k]
|
||||
|
||||
def _store_challenge(self, logical_key: str, challenge: bytes) -> None:
|
||||
self._cleanup_expired()
|
||||
if len(self._challenges) >= self._CHALLENGE_MAX:
|
||||
oldest = min(self._challenges, key=lambda k: self._challenges[k][1])
|
||||
del self._challenges[oldest]
|
||||
b64key = self.base64url_encode(challenge)
|
||||
self._challenges[b64key] = (challenge, time.time(), logical_key)
|
||||
|
||||
def _pop_challenge_by_key(self, logical_key: str) -> bytes | None:
|
||||
self._cleanup_expired()
|
||||
for b64key, (challenge, _, lk) in list(self._challenges.items()):
|
||||
if lk == logical_key:
|
||||
del self._challenges[b64key]
|
||||
return challenge
|
||||
return None
|
||||
|
||||
def _pop_challenge_by_value(self, challenge: bytes) -> bytes | None:
|
||||
self._cleanup_expired()
|
||||
b64key = self.base64url_encode(challenge)
|
||||
entry = self._challenges.pop(b64key, None)
|
||||
if entry:
|
||||
return entry[0]
|
||||
return None
|
||||
|
||||
# ============ 注册流程 ============
|
||||
|
||||
@@ -90,9 +128,7 @@ class WebAuthnService:
|
||||
],
|
||||
)
|
||||
|
||||
# 存储 challenge 用于后续验证
|
||||
challenge_key = f"reg_{username}"
|
||||
self._challenges[challenge_key] = options.challenge
|
||||
self._store_challenge(f"reg_{username}", options.challenge)
|
||||
logger.debug("Generated registration challenge for %s", username)
|
||||
|
||||
return json.loads(options_to_json(options))
|
||||
@@ -114,8 +150,7 @@ class WebAuthnService:
|
||||
Raises:
|
||||
ValueError: 验证失败
|
||||
"""
|
||||
challenge_key = f"reg_{username}"
|
||||
expected_challenge = self._challenges.get(challenge_key)
|
||||
expected_challenge = self._pop_challenge_by_key(f"reg_{username}")
|
||||
if not expected_challenge:
|
||||
raise ValueError("Challenge not found or expired")
|
||||
|
||||
@@ -150,9 +185,6 @@ class WebAuthnService:
|
||||
except Exception as e:
|
||||
logger.error(f"Registration verification failed: {e}")
|
||||
raise ValueError(f"Invalid registration response: {str(e)}")
|
||||
finally:
|
||||
# 清理使用过的 challenge(无论成功或失败都清理,防止重放攻击)
|
||||
self._challenges.pop(challenge_key, None)
|
||||
|
||||
# ============ 认证流程 ============
|
||||
|
||||
@@ -184,9 +216,7 @@ class WebAuthnService:
|
||||
user_verification=UserVerificationRequirement.PREFERRED,
|
||||
)
|
||||
|
||||
# 存储 challenge
|
||||
challenge_key = f"auth_{username}"
|
||||
self._challenges[challenge_key] = options.challenge
|
||||
self._store_challenge(f"auth_{username}", options.challenge)
|
||||
logger.debug("Generated authentication challenge for %s", username)
|
||||
|
||||
return json.loads(options_to_json(options))
|
||||
@@ -204,9 +234,10 @@ class WebAuthnService:
|
||||
user_verification=UserVerificationRequirement.PREFERRED,
|
||||
)
|
||||
|
||||
# Store challenge with a unique key for discoverable auth
|
||||
challenge_key = f"auth_discoverable_{self.base64url_encode(options.challenge)[:16]}"
|
||||
self._challenges[challenge_key] = options.challenge
|
||||
self._store_challenge(
|
||||
f"auth_discoverable_{self.base64url_encode(options.challenge)[:16]}",
|
||||
options.challenge,
|
||||
)
|
||||
logger.debug("Generated discoverable authentication challenge")
|
||||
|
||||
return json.loads(options_to_json(options))
|
||||
@@ -228,13 +259,11 @@ class WebAuthnService:
|
||||
Raises:
|
||||
ValueError: 验证失败
|
||||
"""
|
||||
challenge_key = f"auth_{username}"
|
||||
expected_challenge = self._challenges.get(challenge_key)
|
||||
expected_challenge = self._pop_challenge_by_key(f"auth_{username}")
|
||||
if not expected_challenge:
|
||||
raise ValueError("Challenge not found or expired")
|
||||
|
||||
try:
|
||||
# 解码 public key
|
||||
credential_public_key = base64.b64decode(passkey.public_key)
|
||||
|
||||
verification = verify_authentication_response(
|
||||
@@ -252,9 +281,6 @@ class WebAuthnService:
|
||||
except Exception as e:
|
||||
logger.error(f"Authentication verification failed: {e}")
|
||||
raise ValueError(f"Invalid authentication response: {str(e)}")
|
||||
finally:
|
||||
# 清理 challenge(无论成功或失败都清理,防止重放攻击)
|
||||
self._challenges.pop(challenge_key, None)
|
||||
|
||||
def verify_discoverable_authentication(
|
||||
self, credential: dict, passkey: Passkey
|
||||
@@ -272,13 +298,12 @@ class WebAuthnService:
|
||||
Raises:
|
||||
ValueError: 验证失败
|
||||
"""
|
||||
# Find the challenge by checking all discoverable challenges
|
||||
# Try all discoverable challenges to find the matching one
|
||||
expected_challenge = None
|
||||
challenge_key = None
|
||||
for key, challenge in list(self._challenges.items()):
|
||||
if key.startswith("auth_discoverable_"):
|
||||
for b64key, (challenge, _, lk) in list(self._challenges.items()):
|
||||
if lk.startswith("auth_discoverable_"):
|
||||
expected_challenge = challenge
|
||||
challenge_key = key
|
||||
del self._challenges[b64key]
|
||||
break
|
||||
|
||||
if not expected_challenge:
|
||||
@@ -302,9 +327,6 @@ class WebAuthnService:
|
||||
except Exception as e:
|
||||
logger.error(f"Discoverable authentication verification failed: {e}")
|
||||
raise ValueError(f"Invalid authentication response: {str(e)}")
|
||||
finally:
|
||||
if challenge_key:
|
||||
self._challenges.pop(challenge_key, None)
|
||||
|
||||
# ============ 辅助方法 ============
|
||||
|
||||
|
||||
@@ -1,15 +1,16 @@
|
||||
"""Tests for Auth API endpoints."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from module.api import v1
|
||||
from module.models import ResponseModel
|
||||
from module.security.api import get_current_user, active_user
|
||||
|
||||
from module.security.api import active_user, get_current_user
|
||||
from module.security.jwt import create_access_token
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
@@ -115,7 +116,9 @@ class TestLogin:
|
||||
class TestRefreshToken:
|
||||
def test_refresh_token_success(self, authed_client):
|
||||
"""GET /auth/refresh_token returns new token."""
|
||||
with patch("module.api.auth.active_user", ["testuser"]):
|
||||
token = create_access_token(data={"sub": "testuser"})
|
||||
authed_client.cookies.set("token", token)
|
||||
with patch("module.api.auth.active_user", {"testuser": datetime.now()}):
|
||||
response = authed_client.get("/api/v1/auth/refresh_token")
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -132,7 +135,9 @@ class TestRefreshToken:
|
||||
class TestLogout:
|
||||
def test_logout_success(self, authed_client):
|
||||
"""GET /auth/logout clears session and returns success."""
|
||||
with patch("module.api.auth.active_user", ["testuser"]):
|
||||
token = create_access_token(data={"sub": "testuser"})
|
||||
authed_client.cookies.set("token", token)
|
||||
with patch("module.api.auth.active_user", {"testuser": datetime.now()}):
|
||||
response = authed_client.get("/api/v1/auth/logout")
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -148,7 +153,9 @@ class TestLogout:
|
||||
class TestUpdateCredentials:
|
||||
def test_update_success(self, authed_client):
|
||||
"""POST /auth/update with valid data updates credentials."""
|
||||
with patch("module.api.auth.active_user", ["testuser"]):
|
||||
token = create_access_token(data={"sub": "testuser"})
|
||||
authed_client.cookies.set("token", token)
|
||||
with patch("module.api.auth.active_user", {"testuser": datetime.now()}):
|
||||
with patch("module.api.auth.update_user_info", return_value=True):
|
||||
response = authed_client.post(
|
||||
"/api/v1/auth/update",
|
||||
@@ -162,7 +169,9 @@ class TestUpdateCredentials:
|
||||
|
||||
def test_update_failure(self, authed_client):
|
||||
"""POST /auth/update with invalid old password fails."""
|
||||
with patch("module.api.auth.active_user", ["testuser"]):
|
||||
token = create_access_token(data={"sub": "testuser"})
|
||||
authed_client.cookies.set("token", token)
|
||||
with patch("module.api.auth.active_user", {"testuser": datetime.now()}):
|
||||
with patch("module.api.auth.update_user_info", return_value=False):
|
||||
# When update_user_info returns False, the endpoint implicitly
|
||||
# returns None which causes an error
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
"""Tests for Passkey (WebAuthn) API endpoints."""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime
|
||||
from unittest.mock import patch, MagicMock, AsyncMock
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
@@ -11,10 +11,8 @@ from module.api import v1
|
||||
from module.models import ResponseModel
|
||||
from module.models.passkey import Passkey
|
||||
from module.security.api import get_current_user
|
||||
|
||||
from test.factories import make_passkey
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -342,7 +340,7 @@ class TestAuthVerify:
|
||||
with patch(
|
||||
"module.api.passkey.PasskeyAuthStrategy", return_value=mock_strategy
|
||||
):
|
||||
with patch("module.api.passkey.active_user", []):
|
||||
with patch("module.api.passkey.active_user", {}):
|
||||
response = unauthed_client.post(
|
||||
"/api/v1/passkey/auth/verify",
|
||||
json={
|
||||
|
||||
@@ -1,20 +1,19 @@
|
||||
"""Tests for authentication: JWT tokens, password hashing, login flow."""
|
||||
|
||||
import pytest
|
||||
from datetime import timedelta
|
||||
from unittest.mock import patch, MagicMock
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from jose import JWTError
|
||||
|
||||
from module.security.jwt import (
|
||||
create_access_token,
|
||||
decode_token,
|
||||
verify_token,
|
||||
verify_password,
|
||||
get_password_hash,
|
||||
verify_password,
|
||||
verify_token,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# JWT Token Creation
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -155,9 +154,10 @@ class TestGetCurrentUser:
|
||||
@patch("module.security.api.DEV_AUTH_BYPASS", False)
|
||||
async def test_no_cookie_raises_401(self):
|
||||
"""get_current_user raises 401 when no token cookie."""
|
||||
from module.security.api import get_current_user
|
||||
from fastapi import HTTPException
|
||||
|
||||
from module.security.api import get_current_user
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_current_user(token=None)
|
||||
assert exc_info.value.status_code == 401
|
||||
@@ -165,9 +165,10 @@ class TestGetCurrentUser:
|
||||
@patch("module.security.api.DEV_AUTH_BYPASS", False)
|
||||
async def test_invalid_token_raises_401(self):
|
||||
"""get_current_user raises 401 for invalid token."""
|
||||
from module.security.api import get_current_user
|
||||
from fastapi import HTTPException
|
||||
|
||||
from module.security.api import get_current_user
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_current_user(token="invalid.jwt.token")
|
||||
assert exc_info.value.status_code == 401
|
||||
@@ -175,9 +176,10 @@ class TestGetCurrentUser:
|
||||
@patch("module.security.api.DEV_AUTH_BYPASS", False)
|
||||
async def test_valid_token_user_not_active(self):
|
||||
"""get_current_user raises 401 when user not in active_user list."""
|
||||
from module.security.api import get_current_user, active_user
|
||||
from fastapi import HTTPException
|
||||
|
||||
from module.security.api import active_user, get_current_user
|
||||
|
||||
token = create_access_token(
|
||||
data={"sub": "ghost_user"}, expires_delta=timedelta(hours=1)
|
||||
)
|
||||
@@ -190,13 +192,15 @@ class TestGetCurrentUser:
|
||||
@patch("module.security.api.DEV_AUTH_BYPASS", False)
|
||||
async def test_valid_token_active_user_succeeds(self):
|
||||
"""get_current_user returns username for valid token + active user."""
|
||||
from module.security.api import get_current_user, active_user
|
||||
from datetime import datetime
|
||||
|
||||
from module.security.api import active_user, get_current_user
|
||||
|
||||
token = create_access_token(
|
||||
data={"sub": "active_user"}, expires_delta=timedelta(hours=1)
|
||||
)
|
||||
active_user.clear()
|
||||
active_user.append("active_user")
|
||||
active_user["active_user"] = datetime.now()
|
||||
|
||||
result = await get_current_user(token=token)
|
||||
assert result == "active_user"
|
||||
|
||||
@@ -59,7 +59,9 @@ class TestSetupStatus:
|
||||
patch("module.api.setup.SENTINEL_PATH") as mock_sentinel,
|
||||
patch("module.api.setup.settings") as mock_settings,
|
||||
patch("module.api.setup.Config") as mock_config,
|
||||
patch("module.api.setup.VERSION", "3.2.0"), # Non-dev version to test config check
|
||||
patch(
|
||||
"module.api.setup.VERSION", "3.2.0"
|
||||
), # Non-dev version to test config check
|
||||
):
|
||||
mock_sentinel.exists.return_value = False
|
||||
mock_settings.dict.return_value = {"test": "changed"}
|
||||
@@ -125,69 +127,72 @@ class TestTestDownloader:
|
||||
def test_connection_timeout(self, client, mock_first_run):
|
||||
import httpx
|
||||
|
||||
with patch("module.api.setup.httpx.AsyncClient") as mock_client_cls:
|
||||
mock_instance = AsyncMock()
|
||||
mock_instance.get.side_effect = httpx.TimeoutException("timeout")
|
||||
mock_client_cls.return_value.__aenter__ = AsyncMock(
|
||||
return_value=mock_instance
|
||||
)
|
||||
mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)
|
||||
with patch("module.api.setup._validate_url"):
|
||||
with patch("module.api.setup.httpx.AsyncClient") as mock_client_cls:
|
||||
mock_instance = AsyncMock()
|
||||
mock_instance.get.side_effect = httpx.TimeoutException("timeout")
|
||||
mock_client_cls.return_value.__aenter__ = AsyncMock(
|
||||
return_value=mock_instance
|
||||
)
|
||||
mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/setup/test-downloader",
|
||||
json={
|
||||
"type": "qbittorrent",
|
||||
"host": "localhost:8080",
|
||||
"username": "admin",
|
||||
"password": "admin",
|
||||
"ssl": False,
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["success"] is False
|
||||
response = client.post(
|
||||
"/api/v1/setup/test-downloader",
|
||||
json={
|
||||
"type": "qbittorrent",
|
||||
"host": "localhost:8080",
|
||||
"username": "admin",
|
||||
"password": "admin",
|
||||
"ssl": False,
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["success"] is False
|
||||
|
||||
def test_connection_refused(self, client, mock_first_run):
|
||||
import httpx
|
||||
|
||||
with patch("module.api.setup.httpx.AsyncClient") as mock_client_cls:
|
||||
mock_instance = AsyncMock()
|
||||
mock_instance.get.side_effect = httpx.ConnectError("refused")
|
||||
mock_client_cls.return_value.__aenter__ = AsyncMock(
|
||||
return_value=mock_instance
|
||||
)
|
||||
mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)
|
||||
with patch("module.api.setup._validate_url"):
|
||||
with patch("module.api.setup.httpx.AsyncClient") as mock_client_cls:
|
||||
mock_instance = AsyncMock()
|
||||
mock_instance.get.side_effect = httpx.ConnectError("refused")
|
||||
mock_client_cls.return_value.__aenter__ = AsyncMock(
|
||||
return_value=mock_instance
|
||||
)
|
||||
mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/setup/test-downloader",
|
||||
json={
|
||||
"type": "qbittorrent",
|
||||
"host": "localhost:8080",
|
||||
"username": "admin",
|
||||
"password": "admin",
|
||||
"ssl": False,
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["success"] is False
|
||||
response = client.post(
|
||||
"/api/v1/setup/test-downloader",
|
||||
json={
|
||||
"type": "qbittorrent",
|
||||
"host": "localhost:8080",
|
||||
"username": "admin",
|
||||
"password": "admin",
|
||||
"ssl": False,
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["success"] is False
|
||||
|
||||
|
||||
class TestTestRSS:
|
||||
def test_invalid_url(self, client, mock_first_run):
|
||||
with patch("module.api.setup.RequestContent") as mock_rc:
|
||||
mock_instance = AsyncMock()
|
||||
mock_instance.get_xml = AsyncMock(return_value=None)
|
||||
mock_rc.return_value.__aenter__ = AsyncMock(return_value=mock_instance)
|
||||
mock_rc.return_value.__aexit__ = AsyncMock(return_value=False)
|
||||
with patch("module.api.setup._validate_url"):
|
||||
with patch("module.api.setup.RequestContent") as mock_rc:
|
||||
mock_instance = AsyncMock()
|
||||
mock_instance.get_xml = AsyncMock(return_value=None)
|
||||
mock_rc.return_value.__aenter__ = AsyncMock(return_value=mock_instance)
|
||||
mock_rc.return_value.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/setup/test-rss",
|
||||
json={"url": "https://invalid.example.com/rss"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["success"] is False
|
||||
response = client.post(
|
||||
"/api/v1/setup/test-rss",
|
||||
json={"url": "https://invalid.example.com/rss"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["success"] is False
|
||||
|
||||
|
||||
class TestRequestValidation:
|
||||
|
||||
Reference in New Issue
Block a user