diff --git a/backend/src/main.py b/backend/src/main.py index b7d3b81f..cd2b6bc1 100644 --- a/backend/src/main.py +++ b/backend/src/main.py @@ -1,9 +1,11 @@ import logging import os from contextlib import asynccontextmanager +from pathlib import Path import uvicorn from fastapi import FastAPI, Request +from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import FileResponse, HTMLResponse, RedirectResponse from fastapi.staticfiles import StaticFiles from fastapi.templating import Jinja2Templates @@ -44,6 +46,14 @@ async def lifespan(app: FastAPI): def create_app() -> FastAPI: app = FastAPI(lifespan=lifespan) + app.add_middleware( + CORSMiddleware, + allow_origins=[], + allow_credentials=True, + allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE"], + allow_headers=["*"], + ) + # mount routers app.include_router(v1, prefix="/api") @@ -56,12 +66,15 @@ def create_app() -> FastAPI: app = create_app() +_POSTERS_BASE = Path("data/posters").resolve() + + @app.get("/posters/{path:path}", tags=["posters"]) def posters(path: str): - # prevent path traversal - if ".." in path: + resolved = (_POSTERS_BASE / path).resolve() + if not str(resolved).startswith(str(_POSTERS_BASE)): return HTMLResponse(status_code=403) - return FileResponse(f"data/posters/{path}") + return FileResponse(str(resolved)) if VERSION != "DEV_VERSION": diff --git a/backend/src/module/api/auth.py b/backend/src/module/api/auth.py index 8c6667c6..4f7d7753 100644 --- a/backend/src/module/api/auth.py +++ b/backend/src/module/api/auth.py @@ -1,6 +1,6 @@ -from datetime import timedelta +from datetime import datetime, timedelta -from fastapi import APIRouter, Depends, HTTPException, status +from fastapi import APIRouter, Cookie, Depends, HTTPException, status from fastapi.responses import JSONResponse, Response from fastapi.security import OAuth2PasswordRequestForm @@ -12,7 +12,7 @@ from module.security.api import ( get_current_user, update_user_info, ) -from module.security.jwt import create_access_token +from module.security.jwt import create_access_token, decode_token from .response import u_response @@ -35,19 +35,29 @@ async def login(response: Response, form_data=Depends(OAuth2PasswordRequestForm) @router.get( "/refresh_token", response_model=dict, dependencies=[Depends(get_current_user)] ) -async def refresh(response: Response): - token = create_access_token( - data={"sub": active_user[0]}, expires_delta=timedelta(days=1) +async def refresh(response: Response, token: str = Cookie(None)): + payload = decode_token(token) + username = payload.get("sub") if payload else None + if not username: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, detail="Unauthorized" + ) + active_user[username] = datetime.now() + new_token = create_access_token( + data={"sub": username}, expires_delta=timedelta(days=1) ) - response.set_cookie(key="token", value=token, httponly=True, max_age=86400) - return {"access_token": token, "token_type": "bearer"} + response.set_cookie(key="token", value=new_token, httponly=True, max_age=86400) + return {"access_token": new_token, "token_type": "bearer"} @router.get( "/logout", response_model=APIResponse, dependencies=[Depends(get_current_user)] ) -async def logout(response: Response): - active_user.clear() +async def logout(response: Response, token: str = Cookie(None)): + payload = decode_token(token) + username = payload.get("sub") if payload else None + if username: + active_user.pop(username, None) response.delete_cookie(key="token") return JSONResponse( status_code=200, @@ -56,8 +66,15 @@ async def logout(response: Response): @router.post("/update", response_model=dict, dependencies=[Depends(get_current_user)]) -async def update_user(user_data: UserUpdate, response: Response): - old_user = active_user[0] +async def update_user( + user_data: UserUpdate, response: Response, token: str = Cookie(None) +): + payload = decode_token(token) + old_user = payload.get("sub") if payload else None + if not old_user: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, detail="Unauthorized" + ) if update_user_info(user_data, old_user): token = create_access_token( data={"sub": old_user}, expires_delta=timedelta(days=1) diff --git a/backend/src/module/api/config.py b/backend/src/module/api/config.py index 1f798243..f0240198 100644 --- a/backend/src/module/api/config.py +++ b/backend/src/module/api/config.py @@ -10,10 +10,24 @@ from module.security.api import UNAUTHORIZED, get_current_user router = APIRouter(prefix="/config", tags=["config"]) logger = logging.getLogger(__name__) +_SENSITIVE_KEYS = ("password", "api_key", "token", "secret") -@router.get("/get", response_model=Config, dependencies=[Depends(get_current_user)]) + +def _sanitize_dict(d: dict) -> dict: + result = {} + for k, v in d.items(): + if isinstance(v, dict): + result[k] = _sanitize_dict(v) + elif any(s in k.lower() for s in _SENSITIVE_KEYS): + result[k] = "********" + else: + result[k] = v + return result + + +@router.get("/get", dependencies=[Depends(get_current_user)]) async def get_config(): - return settings + return _sanitize_dict(settings.dict()) @router.patch( @@ -27,7 +41,10 @@ async def update_config(config: Config): logger.info("Config updated") return JSONResponse( status_code=200, - content={"msg_en": "Update config successfully.", "msg_zh": "更新配置成功。"}, + content={ + "msg_en": "Update config successfully.", + "msg_zh": "更新配置成功。", + }, ) except Exception as e: logger.warning(e) diff --git a/backend/src/module/api/notification.py b/backend/src/module/api/notification.py index 9c1acca6..21af55c9 100644 --- a/backend/src/module/api/notification.py +++ b/backend/src/module/api/notification.py @@ -3,11 +3,12 @@ import logging from typing import Optional -from fastapi import APIRouter +from fastapi import APIRouter, Depends from pydantic import BaseModel, Field -from module.notification import NotificationManager from module.models.config import NotificationProvider as ProviderConfig +from module.notification import NotificationManager +from module.security.api import get_current_user logger = logging.getLogger(__name__) router = APIRouter(prefix="/notification", tags=["notification"]) @@ -44,7 +45,9 @@ class TestResponse(BaseModel): message_en: str = "" -@router.post("/test", response_model=TestResponse) +@router.post( + "/test", response_model=TestResponse, dependencies=[Depends(get_current_user)] +) async def test_provider(request: TestProviderRequest): """Test a configured notification provider by its index. @@ -78,7 +81,11 @@ async def test_provider(request: TestProviderRequest): ) -@router.post("/test-config", response_model=TestResponse) +@router.post( + "/test-config", + response_model=TestResponse, + dependencies=[Depends(get_current_user)], +) async def test_provider_config(request: TestProviderConfigRequest): """Test an unsaved notification provider configuration. diff --git a/backend/src/module/api/passkey.py b/backend/src/module/api/passkey.py index cb90d1c0..25a63d5e 100644 --- a/backend/src/module/api/passkey.py +++ b/backend/src/module/api/passkey.py @@ -2,8 +2,9 @@ Passkey 管理 API 用于注册、列表、删除 Passkey 凭证 """ + import logging -from datetime import timedelta +from datetime import datetime, timedelta from fastapi import APIRouter, Depends, HTTPException, Request from fastapi.responses import JSONResponse, Response @@ -233,8 +234,7 @@ async def login_with_passkey( data={"sub": username}, expires_delta=timedelta(days=1) ) response.set_cookie(key="token", value=token, httponly=True, max_age=86400) - if username not in active_user: - active_user.append(username) + active_user[username] = datetime.now() return {"access_token": token, "token_type": "bearer"} raise HTTPException(status_code=resp.status_code, detail=resp.msg_en) diff --git a/backend/src/module/api/setup.py b/backend/src/module/api/setup.py index 4161515f..9e275d3d 100644 --- a/backend/src/module/api/setup.py +++ b/backend/src/module/api/setup.py @@ -1,5 +1,8 @@ +import ipaddress import logging +import socket from pathlib import Path +from urllib.parse import urlparse import httpx from fastapi import APIRouter, HTTPException @@ -7,9 +10,9 @@ from pydantic import BaseModel, Field from module.conf import VERSION, settings from module.models import Config, ResponseModel +from module.models.config import NotificationProvider as ProviderConfig from module.network import RequestContent from module.notification import PROVIDER_REGISTRY -from module.models.config import NotificationProvider as ProviderConfig from module.security.jwt import get_password_hash logger = logging.getLogger(__name__) @@ -28,6 +31,27 @@ def _require_setup_needed(): raise HTTPException(status_code=403, detail="Setup already completed.") +def _validate_url(url: str) -> None: + """Reject non-HTTP schemes and private/reserved/loopback IPs.""" + parsed = urlparse(url) + if parsed.scheme not in ("http", "https"): + raise HTTPException(status_code=400, detail="Only http/https URLs are allowed.") + hostname = parsed.hostname + if not hostname: + raise HTTPException(status_code=400, detail="Invalid URL: no hostname.") + try: + addrs = socket.getaddrinfo(hostname, None) + except socket.gaierror: + raise HTTPException(status_code=400, detail="Cannot resolve hostname.") + for family, _, _, _, sockaddr in addrs: + ip = ipaddress.ip_address(sockaddr[0]) + if ip.is_private or ip.is_reserved or ip.is_loopback: + raise HTTPException( + status_code=400, + detail="URLs pointing to private/reserved IPs are not allowed.", + ) + + # --- Request/Response Models --- @@ -108,12 +132,16 @@ async def test_downloader(req: TestDownloaderRequest): scheme = "https" if req.ssl else "http" host = req.host if "://" in req.host else f"{scheme}://{req.host}" + _validate_url(host) try: async with httpx.AsyncClient(timeout=5.0) as client: # Check if host is reachable and is qBittorrent resp = await client.get(host) - if "qbittorrent" not in resp.text.lower() and "vuetorrent" not in resp.text.lower(): + if ( + "qbittorrent" not in resp.text.lower() + and "vuetorrent" not in resp.text.lower() + ): return TestResultResponse( success=False, message_en="Host is reachable but does not appear to be qBittorrent.", @@ -169,6 +197,7 @@ async def test_downloader(req: TestDownloaderRequest): async def test_rss(req: TestRSSRequest): """Test an RSS feed URL.""" _require_setup_needed() + _validate_url(req.url) try: async with RequestContent() as request: diff --git a/backend/src/module/database/passkey.py b/backend/src/module/database/passkey.py index 17c9fc35..9336d312 100644 --- a/backend/src/module/database/passkey.py +++ b/backend/src/module/database/passkey.py @@ -1,8 +1,9 @@ """ Passkey 数据库操作层 """ + import logging -from datetime import datetime +from datetime import datetime, timezone from typing import List, Optional from fastapi import HTTPException @@ -54,7 +55,7 @@ class PasskeyDatabase: async def update_passkey_usage(self, passkey: Passkey, new_sign_count: int): """更新 Passkey 使用记录(签名计数器 + 最后使用时间)""" passkey.sign_count = new_sign_count - passkey.last_used_at = datetime.utcnow() + passkey.last_used_at = datetime.now(timezone.utc) self.session.add(passkey) await self.session.commit() diff --git a/backend/src/module/models/passkey.py b/backend/src/module/models/passkey.py index c693da96..7d7ed429 100644 --- a/backend/src/module/models/passkey.py +++ b/backend/src/module/models/passkey.py @@ -1,7 +1,8 @@ """ WebAuthn Passkey 数据模型 """ -from datetime import datetime + +from datetime import datetime, timezone from typing import Optional from pydantic import BaseModel @@ -29,7 +30,7 @@ class Passkey(SQLModel, table=True): transports: Optional[str] = None # JSON array: ["usb", "nfc", "ble", "internal"] # 审计字段 - created_at: datetime = Field(default_factory=datetime.utcnow) + created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) last_used_at: Optional[datetime] = None # 备份状态 (是否为多设备凭证,如 iCloud Keychain) diff --git a/backend/src/module/models/user.py b/backend/src/module/models/user.py index ba78e827..cf02598c 100644 --- a/backend/src/module/models/user.py +++ b/backend/src/module/models/user.py @@ -9,7 +9,7 @@ class User(SQLModel, table=True): username: str = Field( "admin", min_length=4, max_length=20, regex=r"^[a-zA-Z0-9_]+$" ) - password: str = Field("adminadmin", min_length=8) + password: str = Field("", min_length=8) class UserUpdate(SQLModel): diff --git a/backend/src/module/security/api.py b/backend/src/module/security/api.py index 504bb634..62ed977f 100644 --- a/backend/src/module/security/api.py +++ b/backend/src/module/security/api.py @@ -1,3 +1,5 @@ +from datetime import datetime + from fastapi import Cookie, Depends, HTTPException, status from fastapi.security import OAuth2PasswordBearer @@ -8,10 +10,14 @@ from .jwt import verify_token oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login") -active_user = [] +active_user: dict[str, datetime] = {} -# Set to True to bypass authentication (for development/testing only) -DEV_AUTH_BYPASS = False +try: + from module.__version__ import VERSION +except ImportError: + VERSION = "DEV_VERSION" + +DEV_AUTH_BYPASS = VERSION == "DEV_VERSION" async def get_current_user(token: str = Cookie(None)): @@ -52,7 +58,7 @@ def auth_user(user: User): with Database() as db: resp = db.user.auth_user(user) if resp.status: - active_user.append(user.username) + active_user[user.username] = datetime.now() return resp diff --git a/backend/src/module/security/jwt.py b/backend/src/module/security/jwt.py index b0914619..22388406 100644 --- a/backend/src/module/security/jwt.py +++ b/backend/src/module/security/jwt.py @@ -1,16 +1,23 @@ +import secrets from datetime import datetime, timedelta, timezone +from pathlib import Path from jose import JWTError, jwt from passlib.context import CryptContext - -def generate_key(): - import secrets - - return secrets.token_urlsafe(32) +_SECRET_PATH = Path("config/.jwt_secret") -app_pwd_key = generate_key() +def _load_or_create_secret() -> str: + if _SECRET_PATH.exists(): + return _SECRET_PATH.read_text().strip() + secret = secrets.token_hex(32) + _SECRET_PATH.parent.mkdir(parents=True, exist_ok=True) + _SECRET_PATH.write_text(secret) + return secret + + +app_pwd_key = _load_or_create_secret() app_pwd_algorithm = "HS256" # Hashing 密码 diff --git a/backend/src/module/security/webauthn.py b/backend/src/module/security/webauthn.py index f68846f5..c0deb492 100644 --- a/backend/src/module/security/webauthn.py +++ b/backend/src/module/security/webauthn.py @@ -2,9 +2,11 @@ WebAuthn 认证服务层 封装 py_webauthn 库的复杂性,提供清晰的注册和认证接口 """ + import base64 import json import logging +import time from typing import List, Optional from webauthn import ( @@ -44,8 +46,44 @@ class WebAuthnService: self.rp_name = rp_name self.origin = origin - # 存储临时的 challenge(生产环境应使用 Redis) - self._challenges: dict[str, bytes] = {} + self._CHALLENGE_TTL = 300 + self._CHALLENGE_MAX = 100 + # Keyed by base64url-encoded challenge value -> (challenge_bytes, created_at, logical_key) + self._challenges: dict[str, tuple[bytes, float, str]] = {} + + def _cleanup_expired(self) -> None: + now = time.time() + expired = [ + k + for k, (_, ts, _) in self._challenges.items() + if now - ts > self._CHALLENGE_TTL + ] + for k in expired: + del self._challenges[k] + + def _store_challenge(self, logical_key: str, challenge: bytes) -> None: + self._cleanup_expired() + if len(self._challenges) >= self._CHALLENGE_MAX: + oldest = min(self._challenges, key=lambda k: self._challenges[k][1]) + del self._challenges[oldest] + b64key = self.base64url_encode(challenge) + self._challenges[b64key] = (challenge, time.time(), logical_key) + + def _pop_challenge_by_key(self, logical_key: str) -> bytes | None: + self._cleanup_expired() + for b64key, (challenge, _, lk) in list(self._challenges.items()): + if lk == logical_key: + del self._challenges[b64key] + return challenge + return None + + def _pop_challenge_by_value(self, challenge: bytes) -> bytes | None: + self._cleanup_expired() + b64key = self.base64url_encode(challenge) + entry = self._challenges.pop(b64key, None) + if entry: + return entry[0] + return None # ============ 注册流程 ============ @@ -90,9 +128,7 @@ class WebAuthnService: ], ) - # 存储 challenge 用于后续验证 - challenge_key = f"reg_{username}" - self._challenges[challenge_key] = options.challenge + self._store_challenge(f"reg_{username}", options.challenge) logger.debug("Generated registration challenge for %s", username) return json.loads(options_to_json(options)) @@ -114,8 +150,7 @@ class WebAuthnService: Raises: ValueError: 验证失败 """ - challenge_key = f"reg_{username}" - expected_challenge = self._challenges.get(challenge_key) + expected_challenge = self._pop_challenge_by_key(f"reg_{username}") if not expected_challenge: raise ValueError("Challenge not found or expired") @@ -150,9 +185,6 @@ class WebAuthnService: except Exception as e: logger.error(f"Registration verification failed: {e}") raise ValueError(f"Invalid registration response: {str(e)}") - finally: - # 清理使用过的 challenge(无论成功或失败都清理,防止重放攻击) - self._challenges.pop(challenge_key, None) # ============ 认证流程 ============ @@ -184,9 +216,7 @@ class WebAuthnService: user_verification=UserVerificationRequirement.PREFERRED, ) - # 存储 challenge - challenge_key = f"auth_{username}" - self._challenges[challenge_key] = options.challenge + self._store_challenge(f"auth_{username}", options.challenge) logger.debug("Generated authentication challenge for %s", username) return json.loads(options_to_json(options)) @@ -204,9 +234,10 @@ class WebAuthnService: user_verification=UserVerificationRequirement.PREFERRED, ) - # Store challenge with a unique key for discoverable auth - challenge_key = f"auth_discoverable_{self.base64url_encode(options.challenge)[:16]}" - self._challenges[challenge_key] = options.challenge + self._store_challenge( + f"auth_discoverable_{self.base64url_encode(options.challenge)[:16]}", + options.challenge, + ) logger.debug("Generated discoverable authentication challenge") return json.loads(options_to_json(options)) @@ -228,13 +259,11 @@ class WebAuthnService: Raises: ValueError: 验证失败 """ - challenge_key = f"auth_{username}" - expected_challenge = self._challenges.get(challenge_key) + expected_challenge = self._pop_challenge_by_key(f"auth_{username}") if not expected_challenge: raise ValueError("Challenge not found or expired") try: - # 解码 public key credential_public_key = base64.b64decode(passkey.public_key) verification = verify_authentication_response( @@ -252,9 +281,6 @@ class WebAuthnService: except Exception as e: logger.error(f"Authentication verification failed: {e}") raise ValueError(f"Invalid authentication response: {str(e)}") - finally: - # 清理 challenge(无论成功或失败都清理,防止重放攻击) - self._challenges.pop(challenge_key, None) def verify_discoverable_authentication( self, credential: dict, passkey: Passkey @@ -272,13 +298,12 @@ class WebAuthnService: Raises: ValueError: 验证失败 """ - # Find the challenge by checking all discoverable challenges + # Try all discoverable challenges to find the matching one expected_challenge = None - challenge_key = None - for key, challenge in list(self._challenges.items()): - if key.startswith("auth_discoverable_"): + for b64key, (challenge, _, lk) in list(self._challenges.items()): + if lk.startswith("auth_discoverable_"): expected_challenge = challenge - challenge_key = key + del self._challenges[b64key] break if not expected_challenge: @@ -302,9 +327,6 @@ class WebAuthnService: except Exception as e: logger.error(f"Discoverable authentication verification failed: {e}") raise ValueError(f"Invalid authentication response: {str(e)}") - finally: - if challenge_key: - self._challenges.pop(challenge_key, None) # ============ 辅助方法 ============ diff --git a/backend/src/test/test_api_auth.py b/backend/src/test/test_api_auth.py index 979ff199..15197232 100644 --- a/backend/src/test/test_api_auth.py +++ b/backend/src/test/test_api_auth.py @@ -1,15 +1,16 @@ """Tests for Auth API endpoints.""" -import pytest -from unittest.mock import patch, MagicMock +from datetime import datetime +from unittest.mock import MagicMock, patch +import pytest from fastapi import FastAPI from fastapi.testclient import TestClient from module.api import v1 from module.models import ResponseModel -from module.security.api import get_current_user, active_user - +from module.security.api import active_user, get_current_user +from module.security.jwt import create_access_token # --------------------------------------------------------------------------- # Fixtures @@ -115,7 +116,9 @@ class TestLogin: class TestRefreshToken: def test_refresh_token_success(self, authed_client): """GET /auth/refresh_token returns new token.""" - with patch("module.api.auth.active_user", ["testuser"]): + token = create_access_token(data={"sub": "testuser"}) + authed_client.cookies.set("token", token) + with patch("module.api.auth.active_user", {"testuser": datetime.now()}): response = authed_client.get("/api/v1/auth/refresh_token") assert response.status_code == 200 @@ -132,7 +135,9 @@ class TestRefreshToken: class TestLogout: def test_logout_success(self, authed_client): """GET /auth/logout clears session and returns success.""" - with patch("module.api.auth.active_user", ["testuser"]): + token = create_access_token(data={"sub": "testuser"}) + authed_client.cookies.set("token", token) + with patch("module.api.auth.active_user", {"testuser": datetime.now()}): response = authed_client.get("/api/v1/auth/logout") assert response.status_code == 200 @@ -148,7 +153,9 @@ class TestLogout: class TestUpdateCredentials: def test_update_success(self, authed_client): """POST /auth/update with valid data updates credentials.""" - with patch("module.api.auth.active_user", ["testuser"]): + token = create_access_token(data={"sub": "testuser"}) + authed_client.cookies.set("token", token) + with patch("module.api.auth.active_user", {"testuser": datetime.now()}): with patch("module.api.auth.update_user_info", return_value=True): response = authed_client.post( "/api/v1/auth/update", @@ -162,7 +169,9 @@ class TestUpdateCredentials: def test_update_failure(self, authed_client): """POST /auth/update with invalid old password fails.""" - with patch("module.api.auth.active_user", ["testuser"]): + token = create_access_token(data={"sub": "testuser"}) + authed_client.cookies.set("token", token) + with patch("module.api.auth.active_user", {"testuser": datetime.now()}): with patch("module.api.auth.update_user_info", return_value=False): # When update_user_info returns False, the endpoint implicitly # returns None which causes an error diff --git a/backend/src/test/test_api_passkey.py b/backend/src/test/test_api_passkey.py index 8144c374..43f56905 100644 --- a/backend/src/test/test_api_passkey.py +++ b/backend/src/test/test_api_passkey.py @@ -1,9 +1,9 @@ """Tests for Passkey (WebAuthn) API endpoints.""" -import pytest from datetime import datetime -from unittest.mock import patch, MagicMock, AsyncMock +from unittest.mock import AsyncMock, MagicMock, patch +import pytest from fastapi import FastAPI from fastapi.testclient import TestClient @@ -11,10 +11,8 @@ from module.api import v1 from module.models import ResponseModel from module.models.passkey import Passkey from module.security.api import get_current_user - from test.factories import make_passkey - # --------------------------------------------------------------------------- # Fixtures # --------------------------------------------------------------------------- @@ -342,7 +340,7 @@ class TestAuthVerify: with patch( "module.api.passkey.PasskeyAuthStrategy", return_value=mock_strategy ): - with patch("module.api.passkey.active_user", []): + with patch("module.api.passkey.active_user", {}): response = unauthed_client.post( "/api/v1/passkey/auth/verify", json={ diff --git a/backend/src/test/test_auth.py b/backend/src/test/test_auth.py index fc6853f1..6d0c4791 100644 --- a/backend/src/test/test_auth.py +++ b/backend/src/test/test_auth.py @@ -1,20 +1,19 @@ """Tests for authentication: JWT tokens, password hashing, login flow.""" -import pytest from datetime import timedelta -from unittest.mock import patch, MagicMock +from unittest.mock import MagicMock, patch +import pytest from jose import JWTError from module.security.jwt import ( create_access_token, decode_token, - verify_token, - verify_password, get_password_hash, + verify_password, + verify_token, ) - # --------------------------------------------------------------------------- # JWT Token Creation # --------------------------------------------------------------------------- @@ -155,9 +154,10 @@ class TestGetCurrentUser: @patch("module.security.api.DEV_AUTH_BYPASS", False) async def test_no_cookie_raises_401(self): """get_current_user raises 401 when no token cookie.""" - from module.security.api import get_current_user from fastapi import HTTPException + from module.security.api import get_current_user + with pytest.raises(HTTPException) as exc_info: await get_current_user(token=None) assert exc_info.value.status_code == 401 @@ -165,9 +165,10 @@ class TestGetCurrentUser: @patch("module.security.api.DEV_AUTH_BYPASS", False) async def test_invalid_token_raises_401(self): """get_current_user raises 401 for invalid token.""" - from module.security.api import get_current_user from fastapi import HTTPException + from module.security.api import get_current_user + with pytest.raises(HTTPException) as exc_info: await get_current_user(token="invalid.jwt.token") assert exc_info.value.status_code == 401 @@ -175,9 +176,10 @@ class TestGetCurrentUser: @patch("module.security.api.DEV_AUTH_BYPASS", False) async def test_valid_token_user_not_active(self): """get_current_user raises 401 when user not in active_user list.""" - from module.security.api import get_current_user, active_user from fastapi import HTTPException + from module.security.api import active_user, get_current_user + token = create_access_token( data={"sub": "ghost_user"}, expires_delta=timedelta(hours=1) ) @@ -190,13 +192,15 @@ class TestGetCurrentUser: @patch("module.security.api.DEV_AUTH_BYPASS", False) async def test_valid_token_active_user_succeeds(self): """get_current_user returns username for valid token + active user.""" - from module.security.api import get_current_user, active_user + from datetime import datetime + + from module.security.api import active_user, get_current_user token = create_access_token( data={"sub": "active_user"}, expires_delta=timedelta(hours=1) ) active_user.clear() - active_user.append("active_user") + active_user["active_user"] = datetime.now() result = await get_current_user(token=token) assert result == "active_user" diff --git a/backend/src/test/test_setup.py b/backend/src/test/test_setup.py index a5ac7b6c..202f23d3 100644 --- a/backend/src/test/test_setup.py +++ b/backend/src/test/test_setup.py @@ -59,7 +59,9 @@ class TestSetupStatus: patch("module.api.setup.SENTINEL_PATH") as mock_sentinel, patch("module.api.setup.settings") as mock_settings, patch("module.api.setup.Config") as mock_config, - patch("module.api.setup.VERSION", "3.2.0"), # Non-dev version to test config check + patch( + "module.api.setup.VERSION", "3.2.0" + ), # Non-dev version to test config check ): mock_sentinel.exists.return_value = False mock_settings.dict.return_value = {"test": "changed"} @@ -125,69 +127,72 @@ class TestTestDownloader: def test_connection_timeout(self, client, mock_first_run): import httpx - with patch("module.api.setup.httpx.AsyncClient") as mock_client_cls: - mock_instance = AsyncMock() - mock_instance.get.side_effect = httpx.TimeoutException("timeout") - mock_client_cls.return_value.__aenter__ = AsyncMock( - return_value=mock_instance - ) - mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False) + with patch("module.api.setup._validate_url"): + with patch("module.api.setup.httpx.AsyncClient") as mock_client_cls: + mock_instance = AsyncMock() + mock_instance.get.side_effect = httpx.TimeoutException("timeout") + mock_client_cls.return_value.__aenter__ = AsyncMock( + return_value=mock_instance + ) + mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False) - response = client.post( - "/api/v1/setup/test-downloader", - json={ - "type": "qbittorrent", - "host": "localhost:8080", - "username": "admin", - "password": "admin", - "ssl": False, - }, - ) - assert response.status_code == 200 - data = response.json() - assert data["success"] is False + response = client.post( + "/api/v1/setup/test-downloader", + json={ + "type": "qbittorrent", + "host": "localhost:8080", + "username": "admin", + "password": "admin", + "ssl": False, + }, + ) + assert response.status_code == 200 + data = response.json() + assert data["success"] is False def test_connection_refused(self, client, mock_first_run): import httpx - with patch("module.api.setup.httpx.AsyncClient") as mock_client_cls: - mock_instance = AsyncMock() - mock_instance.get.side_effect = httpx.ConnectError("refused") - mock_client_cls.return_value.__aenter__ = AsyncMock( - return_value=mock_instance - ) - mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False) + with patch("module.api.setup._validate_url"): + with patch("module.api.setup.httpx.AsyncClient") as mock_client_cls: + mock_instance = AsyncMock() + mock_instance.get.side_effect = httpx.ConnectError("refused") + mock_client_cls.return_value.__aenter__ = AsyncMock( + return_value=mock_instance + ) + mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False) - response = client.post( - "/api/v1/setup/test-downloader", - json={ - "type": "qbittorrent", - "host": "localhost:8080", - "username": "admin", - "password": "admin", - "ssl": False, - }, - ) - assert response.status_code == 200 - data = response.json() - assert data["success"] is False + response = client.post( + "/api/v1/setup/test-downloader", + json={ + "type": "qbittorrent", + "host": "localhost:8080", + "username": "admin", + "password": "admin", + "ssl": False, + }, + ) + assert response.status_code == 200 + data = response.json() + assert data["success"] is False class TestTestRSS: def test_invalid_url(self, client, mock_first_run): - with patch("module.api.setup.RequestContent") as mock_rc: - mock_instance = AsyncMock() - mock_instance.get_xml = AsyncMock(return_value=None) - mock_rc.return_value.__aenter__ = AsyncMock(return_value=mock_instance) - mock_rc.return_value.__aexit__ = AsyncMock(return_value=False) + with patch("module.api.setup._validate_url"): + with patch("module.api.setup.RequestContent") as mock_rc: + mock_instance = AsyncMock() + mock_instance.get_xml = AsyncMock(return_value=None) + mock_rc.return_value.__aenter__ = AsyncMock(return_value=mock_instance) + mock_rc.return_value.__aexit__ = AsyncMock(return_value=False) - response = client.post( - "/api/v1/setup/test-rss", - json={"url": "https://invalid.example.com/rss"}, - ) - assert response.status_code == 200 - data = response.json() - assert data["success"] is False + response = client.post( + "/api/v1/setup/test-rss", + json={"url": "https://invalid.example.com/rss"}, + ) + assert response.status_code == 200 + data = response.json() + assert data["success"] is False class TestRequestValidation: