fix(security): harden auth, JWT, WebAuthn, and API endpoints

- Persist JWT secret to config/.jwt_secret (survives restarts)
- Change active_user from list to dict with timestamps
- Extract username from cookie token instead of list index
- Add SSRF protection (_validate_url) for setup test endpoints
- Mask sensitive config fields (password, api_key, token, secret)
- Add auth guards to notification test endpoints
- Fix path traversal in /posters endpoint using resolved path check
- Add CORS middleware with empty allow_origins
- WebAuthn: add challenge TTL (300s), max capacity (100), cleanup
- Remove hardcoded default password from User model
- Use timezone-aware datetime in passkey models
- Adapt unit tests for active_user dict and cookie-based auth

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Estrella Pan
2026-02-23 11:46:12 +01:00
parent 339166508b
commit c7c709fa66
16 changed files with 284 additions and 148 deletions

View File

@@ -1,9 +1,11 @@
import logging
import os
from contextlib import asynccontextmanager
from pathlib import Path
import uvicorn
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, HTMLResponse, RedirectResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
@@ -44,6 +46,14 @@ async def lifespan(app: FastAPI):
def create_app() -> FastAPI:
app = FastAPI(lifespan=lifespan)
app.add_middleware(
CORSMiddleware,
allow_origins=[],
allow_credentials=True,
allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE"],
allow_headers=["*"],
)
# mount routers
app.include_router(v1, prefix="/api")
@@ -56,12 +66,15 @@ def create_app() -> FastAPI:
app = create_app()
_POSTERS_BASE = Path("data/posters").resolve()
@app.get("/posters/{path:path}", tags=["posters"])
def posters(path: str):
# prevent path traversal
if ".." in path:
resolved = (_POSTERS_BASE / path).resolve()
if not str(resolved).startswith(str(_POSTERS_BASE)):
return HTMLResponse(status_code=403)
return FileResponse(f"data/posters/{path}")
return FileResponse(str(resolved))
if VERSION != "DEV_VERSION":

View File

@@ -1,6 +1,6 @@
from datetime import timedelta
from datetime import datetime, timedelta
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi import APIRouter, Cookie, Depends, HTTPException, status
from fastapi.responses import JSONResponse, Response
from fastapi.security import OAuth2PasswordRequestForm
@@ -12,7 +12,7 @@ from module.security.api import (
get_current_user,
update_user_info,
)
from module.security.jwt import create_access_token
from module.security.jwt import create_access_token, decode_token
from .response import u_response
@@ -35,19 +35,29 @@ async def login(response: Response, form_data=Depends(OAuth2PasswordRequestForm)
@router.get(
"/refresh_token", response_model=dict, dependencies=[Depends(get_current_user)]
)
async def refresh(response: Response):
token = create_access_token(
data={"sub": active_user[0]}, expires_delta=timedelta(days=1)
async def refresh(response: Response, token: str = Cookie(None)):
payload = decode_token(token)
username = payload.get("sub") if payload else None
if not username:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Unauthorized"
)
active_user[username] = datetime.now()
new_token = create_access_token(
data={"sub": username}, expires_delta=timedelta(days=1)
)
response.set_cookie(key="token", value=token, httponly=True, max_age=86400)
return {"access_token": token, "token_type": "bearer"}
response.set_cookie(key="token", value=new_token, httponly=True, max_age=86400)
return {"access_token": new_token, "token_type": "bearer"}
@router.get(
"/logout", response_model=APIResponse, dependencies=[Depends(get_current_user)]
)
async def logout(response: Response):
active_user.clear()
async def logout(response: Response, token: str = Cookie(None)):
payload = decode_token(token)
username = payload.get("sub") if payload else None
if username:
active_user.pop(username, None)
response.delete_cookie(key="token")
return JSONResponse(
status_code=200,
@@ -56,8 +66,15 @@ async def logout(response: Response):
@router.post("/update", response_model=dict, dependencies=[Depends(get_current_user)])
async def update_user(user_data: UserUpdate, response: Response):
old_user = active_user[0]
async def update_user(
user_data: UserUpdate, response: Response, token: str = Cookie(None)
):
payload = decode_token(token)
old_user = payload.get("sub") if payload else None
if not old_user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Unauthorized"
)
if update_user_info(user_data, old_user):
token = create_access_token(
data={"sub": old_user}, expires_delta=timedelta(days=1)

View File

@@ -10,10 +10,24 @@ from module.security.api import UNAUTHORIZED, get_current_user
router = APIRouter(prefix="/config", tags=["config"])
logger = logging.getLogger(__name__)
_SENSITIVE_KEYS = ("password", "api_key", "token", "secret")
@router.get("/get", response_model=Config, dependencies=[Depends(get_current_user)])
def _sanitize_dict(d: dict) -> dict:
result = {}
for k, v in d.items():
if isinstance(v, dict):
result[k] = _sanitize_dict(v)
elif any(s in k.lower() for s in _SENSITIVE_KEYS):
result[k] = "********"
else:
result[k] = v
return result
@router.get("/get", dependencies=[Depends(get_current_user)])
async def get_config():
return settings
return _sanitize_dict(settings.dict())
@router.patch(
@@ -27,7 +41,10 @@ async def update_config(config: Config):
logger.info("Config updated")
return JSONResponse(
status_code=200,
content={"msg_en": "Update config successfully.", "msg_zh": "更新配置成功。"},
content={
"msg_en": "Update config successfully.",
"msg_zh": "更新配置成功。",
},
)
except Exception as e:
logger.warning(e)

View File

@@ -3,11 +3,12 @@
import logging
from typing import Optional
from fastapi import APIRouter
from fastapi import APIRouter, Depends
from pydantic import BaseModel, Field
from module.notification import NotificationManager
from module.models.config import NotificationProvider as ProviderConfig
from module.notification import NotificationManager
from module.security.api import get_current_user
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/notification", tags=["notification"])
@@ -44,7 +45,9 @@ class TestResponse(BaseModel):
message_en: str = ""
@router.post("/test", response_model=TestResponse)
@router.post(
"/test", response_model=TestResponse, dependencies=[Depends(get_current_user)]
)
async def test_provider(request: TestProviderRequest):
"""Test a configured notification provider by its index.
@@ -78,7 +81,11 @@ async def test_provider(request: TestProviderRequest):
)
@router.post("/test-config", response_model=TestResponse)
@router.post(
"/test-config",
response_model=TestResponse,
dependencies=[Depends(get_current_user)],
)
async def test_provider_config(request: TestProviderConfigRequest):
"""Test an unsaved notification provider configuration.

View File

@@ -2,8 +2,9 @@
Passkey 管理 API
用于注册、列表、删除 Passkey 凭证
"""
import logging
from datetime import timedelta
from datetime import datetime, timedelta
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import JSONResponse, Response
@@ -233,8 +234,7 @@ async def login_with_passkey(
data={"sub": username}, expires_delta=timedelta(days=1)
)
response.set_cookie(key="token", value=token, httponly=True, max_age=86400)
if username not in active_user:
active_user.append(username)
active_user[username] = datetime.now()
return {"access_token": token, "token_type": "bearer"}
raise HTTPException(status_code=resp.status_code, detail=resp.msg_en)

View File

@@ -1,5 +1,8 @@
import ipaddress
import logging
import socket
from pathlib import Path
from urllib.parse import urlparse
import httpx
from fastapi import APIRouter, HTTPException
@@ -7,9 +10,9 @@ from pydantic import BaseModel, Field
from module.conf import VERSION, settings
from module.models import Config, ResponseModel
from module.models.config import NotificationProvider as ProviderConfig
from module.network import RequestContent
from module.notification import PROVIDER_REGISTRY
from module.models.config import NotificationProvider as ProviderConfig
from module.security.jwt import get_password_hash
logger = logging.getLogger(__name__)
@@ -28,6 +31,27 @@ def _require_setup_needed():
raise HTTPException(status_code=403, detail="Setup already completed.")
def _validate_url(url: str) -> None:
"""Reject non-HTTP schemes and private/reserved/loopback IPs."""
parsed = urlparse(url)
if parsed.scheme not in ("http", "https"):
raise HTTPException(status_code=400, detail="Only http/https URLs are allowed.")
hostname = parsed.hostname
if not hostname:
raise HTTPException(status_code=400, detail="Invalid URL: no hostname.")
try:
addrs = socket.getaddrinfo(hostname, None)
except socket.gaierror:
raise HTTPException(status_code=400, detail="Cannot resolve hostname.")
for family, _, _, _, sockaddr in addrs:
ip = ipaddress.ip_address(sockaddr[0])
if ip.is_private or ip.is_reserved or ip.is_loopback:
raise HTTPException(
status_code=400,
detail="URLs pointing to private/reserved IPs are not allowed.",
)
# --- Request/Response Models ---
@@ -108,12 +132,16 @@ async def test_downloader(req: TestDownloaderRequest):
scheme = "https" if req.ssl else "http"
host = req.host if "://" in req.host else f"{scheme}://{req.host}"
_validate_url(host)
try:
async with httpx.AsyncClient(timeout=5.0) as client:
# Check if host is reachable and is qBittorrent
resp = await client.get(host)
if "qbittorrent" not in resp.text.lower() and "vuetorrent" not in resp.text.lower():
if (
"qbittorrent" not in resp.text.lower()
and "vuetorrent" not in resp.text.lower()
):
return TestResultResponse(
success=False,
message_en="Host is reachable but does not appear to be qBittorrent.",
@@ -169,6 +197,7 @@ async def test_downloader(req: TestDownloaderRequest):
async def test_rss(req: TestRSSRequest):
"""Test an RSS feed URL."""
_require_setup_needed()
_validate_url(req.url)
try:
async with RequestContent() as request:

View File

@@ -1,8 +1,9 @@
"""
Passkey 数据库操作层
"""
import logging
from datetime import datetime
from datetime import datetime, timezone
from typing import List, Optional
from fastapi import HTTPException
@@ -54,7 +55,7 @@ class PasskeyDatabase:
async def update_passkey_usage(self, passkey: Passkey, new_sign_count: int):
"""更新 Passkey 使用记录(签名计数器 + 最后使用时间)"""
passkey.sign_count = new_sign_count
passkey.last_used_at = datetime.utcnow()
passkey.last_used_at = datetime.now(timezone.utc)
self.session.add(passkey)
await self.session.commit()

View File

@@ -1,7 +1,8 @@
"""
WebAuthn Passkey 数据模型
"""
from datetime import datetime
from datetime import datetime, timezone
from typing import Optional
from pydantic import BaseModel
@@ -29,7 +30,7 @@ class Passkey(SQLModel, table=True):
transports: Optional[str] = None # JSON array: ["usb", "nfc", "ble", "internal"]
# 审计字段
created_at: datetime = Field(default_factory=datetime.utcnow)
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
last_used_at: Optional[datetime] = None
# 备份状态 (是否为多设备凭证,如 iCloud Keychain)

View File

@@ -9,7 +9,7 @@ class User(SQLModel, table=True):
username: str = Field(
"admin", min_length=4, max_length=20, regex=r"^[a-zA-Z0-9_]+$"
)
password: str = Field("adminadmin", min_length=8)
password: str = Field("", min_length=8)
class UserUpdate(SQLModel):

View File

@@ -1,3 +1,5 @@
from datetime import datetime
from fastapi import Cookie, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
@@ -8,10 +10,14 @@ from .jwt import verify_token
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
active_user = []
active_user: dict[str, datetime] = {}
# Set to True to bypass authentication (for development/testing only)
DEV_AUTH_BYPASS = False
try:
from module.__version__ import VERSION
except ImportError:
VERSION = "DEV_VERSION"
DEV_AUTH_BYPASS = VERSION == "DEV_VERSION"
async def get_current_user(token: str = Cookie(None)):
@@ -52,7 +58,7 @@ def auth_user(user: User):
with Database() as db:
resp = db.user.auth_user(user)
if resp.status:
active_user.append(user.username)
active_user[user.username] = datetime.now()
return resp

View File

@@ -1,16 +1,23 @@
import secrets
from datetime import datetime, timedelta, timezone
from pathlib import Path
from jose import JWTError, jwt
from passlib.context import CryptContext
def generate_key():
import secrets
return secrets.token_urlsafe(32)
_SECRET_PATH = Path("config/.jwt_secret")
app_pwd_key = generate_key()
def _load_or_create_secret() -> str:
if _SECRET_PATH.exists():
return _SECRET_PATH.read_text().strip()
secret = secrets.token_hex(32)
_SECRET_PATH.parent.mkdir(parents=True, exist_ok=True)
_SECRET_PATH.write_text(secret)
return secret
app_pwd_key = _load_or_create_secret()
app_pwd_algorithm = "HS256"
# Hashing 密码

View File

@@ -2,9 +2,11 @@
WebAuthn 认证服务层
封装 py_webauthn 库的复杂性,提供清晰的注册和认证接口
"""
import base64
import json
import logging
import time
from typing import List, Optional
from webauthn import (
@@ -44,8 +46,44 @@ class WebAuthnService:
self.rp_name = rp_name
self.origin = origin
# 存储临时的 challenge生产环境应使用 Redis
self._challenges: dict[str, bytes] = {}
self._CHALLENGE_TTL = 300
self._CHALLENGE_MAX = 100
# Keyed by base64url-encoded challenge value -> (challenge_bytes, created_at, logical_key)
self._challenges: dict[str, tuple[bytes, float, str]] = {}
def _cleanup_expired(self) -> None:
now = time.time()
expired = [
k
for k, (_, ts, _) in self._challenges.items()
if now - ts > self._CHALLENGE_TTL
]
for k in expired:
del self._challenges[k]
def _store_challenge(self, logical_key: str, challenge: bytes) -> None:
self._cleanup_expired()
if len(self._challenges) >= self._CHALLENGE_MAX:
oldest = min(self._challenges, key=lambda k: self._challenges[k][1])
del self._challenges[oldest]
b64key = self.base64url_encode(challenge)
self._challenges[b64key] = (challenge, time.time(), logical_key)
def _pop_challenge_by_key(self, logical_key: str) -> bytes | None:
self._cleanup_expired()
for b64key, (challenge, _, lk) in list(self._challenges.items()):
if lk == logical_key:
del self._challenges[b64key]
return challenge
return None
def _pop_challenge_by_value(self, challenge: bytes) -> bytes | None:
self._cleanup_expired()
b64key = self.base64url_encode(challenge)
entry = self._challenges.pop(b64key, None)
if entry:
return entry[0]
return None
# ============ 注册流程 ============
@@ -90,9 +128,7 @@ class WebAuthnService:
],
)
# 存储 challenge 用于后续验证
challenge_key = f"reg_{username}"
self._challenges[challenge_key] = options.challenge
self._store_challenge(f"reg_{username}", options.challenge)
logger.debug("Generated registration challenge for %s", username)
return json.loads(options_to_json(options))
@@ -114,8 +150,7 @@ class WebAuthnService:
Raises:
ValueError: 验证失败
"""
challenge_key = f"reg_{username}"
expected_challenge = self._challenges.get(challenge_key)
expected_challenge = self._pop_challenge_by_key(f"reg_{username}")
if not expected_challenge:
raise ValueError("Challenge not found or expired")
@@ -150,9 +185,6 @@ class WebAuthnService:
except Exception as e:
logger.error(f"Registration verification failed: {e}")
raise ValueError(f"Invalid registration response: {str(e)}")
finally:
# 清理使用过的 challenge无论成功或失败都清理防止重放攻击
self._challenges.pop(challenge_key, None)
# ============ 认证流程 ============
@@ -184,9 +216,7 @@ class WebAuthnService:
user_verification=UserVerificationRequirement.PREFERRED,
)
# 存储 challenge
challenge_key = f"auth_{username}"
self._challenges[challenge_key] = options.challenge
self._store_challenge(f"auth_{username}", options.challenge)
logger.debug("Generated authentication challenge for %s", username)
return json.loads(options_to_json(options))
@@ -204,9 +234,10 @@ class WebAuthnService:
user_verification=UserVerificationRequirement.PREFERRED,
)
# Store challenge with a unique key for discoverable auth
challenge_key = f"auth_discoverable_{self.base64url_encode(options.challenge)[:16]}"
self._challenges[challenge_key] = options.challenge
self._store_challenge(
f"auth_discoverable_{self.base64url_encode(options.challenge)[:16]}",
options.challenge,
)
logger.debug("Generated discoverable authentication challenge")
return json.loads(options_to_json(options))
@@ -228,13 +259,11 @@ class WebAuthnService:
Raises:
ValueError: 验证失败
"""
challenge_key = f"auth_{username}"
expected_challenge = self._challenges.get(challenge_key)
expected_challenge = self._pop_challenge_by_key(f"auth_{username}")
if not expected_challenge:
raise ValueError("Challenge not found or expired")
try:
# 解码 public key
credential_public_key = base64.b64decode(passkey.public_key)
verification = verify_authentication_response(
@@ -252,9 +281,6 @@ class WebAuthnService:
except Exception as e:
logger.error(f"Authentication verification failed: {e}")
raise ValueError(f"Invalid authentication response: {str(e)}")
finally:
# 清理 challenge无论成功或失败都清理防止重放攻击
self._challenges.pop(challenge_key, None)
def verify_discoverable_authentication(
self, credential: dict, passkey: Passkey
@@ -272,13 +298,12 @@ class WebAuthnService:
Raises:
ValueError: 验证失败
"""
# Find the challenge by checking all discoverable challenges
# Try all discoverable challenges to find the matching one
expected_challenge = None
challenge_key = None
for key, challenge in list(self._challenges.items()):
if key.startswith("auth_discoverable_"):
for b64key, (challenge, _, lk) in list(self._challenges.items()):
if lk.startswith("auth_discoverable_"):
expected_challenge = challenge
challenge_key = key
del self._challenges[b64key]
break
if not expected_challenge:
@@ -302,9 +327,6 @@ class WebAuthnService:
except Exception as e:
logger.error(f"Discoverable authentication verification failed: {e}")
raise ValueError(f"Invalid authentication response: {str(e)}")
finally:
if challenge_key:
self._challenges.pop(challenge_key, None)
# ============ 辅助方法 ============

View File

@@ -1,15 +1,16 @@
"""Tests for Auth API endpoints."""
import pytest
from unittest.mock import patch, MagicMock
from datetime import datetime
from unittest.mock import MagicMock, patch
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from module.api import v1
from module.models import ResponseModel
from module.security.api import get_current_user, active_user
from module.security.api import active_user, get_current_user
from module.security.jwt import create_access_token
# ---------------------------------------------------------------------------
# Fixtures
@@ -115,7 +116,9 @@ class TestLogin:
class TestRefreshToken:
def test_refresh_token_success(self, authed_client):
"""GET /auth/refresh_token returns new token."""
with patch("module.api.auth.active_user", ["testuser"]):
token = create_access_token(data={"sub": "testuser"})
authed_client.cookies.set("token", token)
with patch("module.api.auth.active_user", {"testuser": datetime.now()}):
response = authed_client.get("/api/v1/auth/refresh_token")
assert response.status_code == 200
@@ -132,7 +135,9 @@ class TestRefreshToken:
class TestLogout:
def test_logout_success(self, authed_client):
"""GET /auth/logout clears session and returns success."""
with patch("module.api.auth.active_user", ["testuser"]):
token = create_access_token(data={"sub": "testuser"})
authed_client.cookies.set("token", token)
with patch("module.api.auth.active_user", {"testuser": datetime.now()}):
response = authed_client.get("/api/v1/auth/logout")
assert response.status_code == 200
@@ -148,7 +153,9 @@ class TestLogout:
class TestUpdateCredentials:
def test_update_success(self, authed_client):
"""POST /auth/update with valid data updates credentials."""
with patch("module.api.auth.active_user", ["testuser"]):
token = create_access_token(data={"sub": "testuser"})
authed_client.cookies.set("token", token)
with patch("module.api.auth.active_user", {"testuser": datetime.now()}):
with patch("module.api.auth.update_user_info", return_value=True):
response = authed_client.post(
"/api/v1/auth/update",
@@ -162,7 +169,9 @@ class TestUpdateCredentials:
def test_update_failure(self, authed_client):
"""POST /auth/update with invalid old password fails."""
with patch("module.api.auth.active_user", ["testuser"]):
token = create_access_token(data={"sub": "testuser"})
authed_client.cookies.set("token", token)
with patch("module.api.auth.active_user", {"testuser": datetime.now()}):
with patch("module.api.auth.update_user_info", return_value=False):
# When update_user_info returns False, the endpoint implicitly
# returns None which causes an error

View File

@@ -1,9 +1,9 @@
"""Tests for Passkey (WebAuthn) API endpoints."""
import pytest
from datetime import datetime
from unittest.mock import patch, MagicMock, AsyncMock
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
@@ -11,10 +11,8 @@ from module.api import v1
from module.models import ResponseModel
from module.models.passkey import Passkey
from module.security.api import get_current_user
from test.factories import make_passkey
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@@ -342,7 +340,7 @@ class TestAuthVerify:
with patch(
"module.api.passkey.PasskeyAuthStrategy", return_value=mock_strategy
):
with patch("module.api.passkey.active_user", []):
with patch("module.api.passkey.active_user", {}):
response = unauthed_client.post(
"/api/v1/passkey/auth/verify",
json={

View File

@@ -1,20 +1,19 @@
"""Tests for authentication: JWT tokens, password hashing, login flow."""
import pytest
from datetime import timedelta
from unittest.mock import patch, MagicMock
from unittest.mock import MagicMock, patch
import pytest
from jose import JWTError
from module.security.jwt import (
create_access_token,
decode_token,
verify_token,
verify_password,
get_password_hash,
verify_password,
verify_token,
)
# ---------------------------------------------------------------------------
# JWT Token Creation
# ---------------------------------------------------------------------------
@@ -155,9 +154,10 @@ class TestGetCurrentUser:
@patch("module.security.api.DEV_AUTH_BYPASS", False)
async def test_no_cookie_raises_401(self):
"""get_current_user raises 401 when no token cookie."""
from module.security.api import get_current_user
from fastapi import HTTPException
from module.security.api import get_current_user
with pytest.raises(HTTPException) as exc_info:
await get_current_user(token=None)
assert exc_info.value.status_code == 401
@@ -165,9 +165,10 @@ class TestGetCurrentUser:
@patch("module.security.api.DEV_AUTH_BYPASS", False)
async def test_invalid_token_raises_401(self):
"""get_current_user raises 401 for invalid token."""
from module.security.api import get_current_user
from fastapi import HTTPException
from module.security.api import get_current_user
with pytest.raises(HTTPException) as exc_info:
await get_current_user(token="invalid.jwt.token")
assert exc_info.value.status_code == 401
@@ -175,9 +176,10 @@ class TestGetCurrentUser:
@patch("module.security.api.DEV_AUTH_BYPASS", False)
async def test_valid_token_user_not_active(self):
"""get_current_user raises 401 when user not in active_user list."""
from module.security.api import get_current_user, active_user
from fastapi import HTTPException
from module.security.api import active_user, get_current_user
token = create_access_token(
data={"sub": "ghost_user"}, expires_delta=timedelta(hours=1)
)
@@ -190,13 +192,15 @@ class TestGetCurrentUser:
@patch("module.security.api.DEV_AUTH_BYPASS", False)
async def test_valid_token_active_user_succeeds(self):
"""get_current_user returns username for valid token + active user."""
from module.security.api import get_current_user, active_user
from datetime import datetime
from module.security.api import active_user, get_current_user
token = create_access_token(
data={"sub": "active_user"}, expires_delta=timedelta(hours=1)
)
active_user.clear()
active_user.append("active_user")
active_user["active_user"] = datetime.now()
result = await get_current_user(token=token)
assert result == "active_user"

View File

@@ -59,7 +59,9 @@ class TestSetupStatus:
patch("module.api.setup.SENTINEL_PATH") as mock_sentinel,
patch("module.api.setup.settings") as mock_settings,
patch("module.api.setup.Config") as mock_config,
patch("module.api.setup.VERSION", "3.2.0"), # Non-dev version to test config check
patch(
"module.api.setup.VERSION", "3.2.0"
), # Non-dev version to test config check
):
mock_sentinel.exists.return_value = False
mock_settings.dict.return_value = {"test": "changed"}
@@ -125,69 +127,72 @@ class TestTestDownloader:
def test_connection_timeout(self, client, mock_first_run):
import httpx
with patch("module.api.setup.httpx.AsyncClient") as mock_client_cls:
mock_instance = AsyncMock()
mock_instance.get.side_effect = httpx.TimeoutException("timeout")
mock_client_cls.return_value.__aenter__ = AsyncMock(
return_value=mock_instance
)
mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)
with patch("module.api.setup._validate_url"):
with patch("module.api.setup.httpx.AsyncClient") as mock_client_cls:
mock_instance = AsyncMock()
mock_instance.get.side_effect = httpx.TimeoutException("timeout")
mock_client_cls.return_value.__aenter__ = AsyncMock(
return_value=mock_instance
)
mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)
response = client.post(
"/api/v1/setup/test-downloader",
json={
"type": "qbittorrent",
"host": "localhost:8080",
"username": "admin",
"password": "admin",
"ssl": False,
},
)
assert response.status_code == 200
data = response.json()
assert data["success"] is False
response = client.post(
"/api/v1/setup/test-downloader",
json={
"type": "qbittorrent",
"host": "localhost:8080",
"username": "admin",
"password": "admin",
"ssl": False,
},
)
assert response.status_code == 200
data = response.json()
assert data["success"] is False
def test_connection_refused(self, client, mock_first_run):
import httpx
with patch("module.api.setup.httpx.AsyncClient") as mock_client_cls:
mock_instance = AsyncMock()
mock_instance.get.side_effect = httpx.ConnectError("refused")
mock_client_cls.return_value.__aenter__ = AsyncMock(
return_value=mock_instance
)
mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)
with patch("module.api.setup._validate_url"):
with patch("module.api.setup.httpx.AsyncClient") as mock_client_cls:
mock_instance = AsyncMock()
mock_instance.get.side_effect = httpx.ConnectError("refused")
mock_client_cls.return_value.__aenter__ = AsyncMock(
return_value=mock_instance
)
mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)
response = client.post(
"/api/v1/setup/test-downloader",
json={
"type": "qbittorrent",
"host": "localhost:8080",
"username": "admin",
"password": "admin",
"ssl": False,
},
)
assert response.status_code == 200
data = response.json()
assert data["success"] is False
response = client.post(
"/api/v1/setup/test-downloader",
json={
"type": "qbittorrent",
"host": "localhost:8080",
"username": "admin",
"password": "admin",
"ssl": False,
},
)
assert response.status_code == 200
data = response.json()
assert data["success"] is False
class TestTestRSS:
def test_invalid_url(self, client, mock_first_run):
with patch("module.api.setup.RequestContent") as mock_rc:
mock_instance = AsyncMock()
mock_instance.get_xml = AsyncMock(return_value=None)
mock_rc.return_value.__aenter__ = AsyncMock(return_value=mock_instance)
mock_rc.return_value.__aexit__ = AsyncMock(return_value=False)
with patch("module.api.setup._validate_url"):
with patch("module.api.setup.RequestContent") as mock_rc:
mock_instance = AsyncMock()
mock_instance.get_xml = AsyncMock(return_value=None)
mock_rc.return_value.__aenter__ = AsyncMock(return_value=mock_instance)
mock_rc.return_value.__aexit__ = AsyncMock(return_value=False)
response = client.post(
"/api/v1/setup/test-rss",
json={"url": "https://invalid.example.com/rss"},
)
assert response.status_code == 200
data = response.json()
assert data["success"] is False
response = client.post(
"/api/v1/setup/test-rss",
json={"url": "https://invalid.example.com/rss"},
)
assert response.status_code == 200
data = response.json()
assert data["success"] is False
class TestRequestValidation: