From b57d3c49ae76ccb3257058213cd2e932575dc10e Mon Sep 17 00:00:00 2001 From: Estrella Pan Date: Mon, 23 Feb 2026 17:18:23 +0100 Subject: [PATCH] feat(security): add security config UI and improve auth/MCP security Co-Authored-By: Claude Opus 4.6 --- CHANGELOG.md | 38 ++ backend/pyproject.toml | 2 +- backend/src/module/api/auth.py | 46 +- backend/src/module/api/config.py | 5 +- backend/src/module/conf/config.py | 28 +- backend/src/module/conf/const.py | 21 + .../downloader/client/mock_downloader.py | 4 +- .../src/module/downloader/download_client.py | 30 +- backend/src/module/mcp/security.py | 70 ++- backend/src/module/mcp/server.py | 8 +- backend/src/module/models/config.py | 81 +++- backend/src/module/security/api.py | 46 +- backend/src/test/test_api_auth.py | 95 ++++ backend/src/test/test_api_config.py | 99 +++- backend/src/test/test_auth.py | 122 ++++- backend/src/test/test_config.py | 288 +++++++++++- backend/src/test/test_download_client.py | 61 +++ backend/src/test/test_mcp_security.py | 347 +++++++------- backend/src/test/test_mock_downloader.py | 429 ++++++++++++++++++ backend/uv.lock | 2 +- .../components/setting/config-security.vue | 68 +++ webui/src/pages/index/config.vue | 1 + webui/types/config.ts | 30 +- webui/types/dts/auto-imports.d.ts | 1 + webui/types/dts/components.d.ts | 1 + 25 files changed, 1621 insertions(+), 302 deletions(-) create mode 100644 backend/src/test/test_mock_downloader.py create mode 100644 webui/src/components/setting/config-security.vue diff --git a/CHANGELOG.md b/CHANGELOG.md index 349ef388..0ac6f0af 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,41 @@ +# [Unreleased] + +## Backend + +### Added + +- 新增 `Security` 配置模型,支持登录 IP 白名单、MCP IP 白名单和 Bearer Token 认证 +- 新增登录端点 IP 白名单检查中间件 (`check_login_ip`) +- MCP 安全中间件升级为可配置模式:支持 CIDR 白名单 + Bearer Token 双重认证 +- 认证端点支持 `Authorization: Bearer` 令牌绕过 Cookie 登录 +- 配置 API `_sanitize_dict` 修复:仅对字符串值进行脱敏,避免误处理非字符串字段 + +- 新增番剧放送日手动设置 API (`PATCH /api/v1/bangumi/{id}/weekday`),支持锁定放送日防止日历刷新覆盖 +- 数据库迁移 v9:`bangumi` 表新增 `weekday_locked` 列 + +### Changed + +- 重构认证模块:提取 `_issue_token` 公共方法,消除 3 处重复的 JWT 签发逻辑 +- `get_current_user` 简化为三级认证(DEV 绕过 → Bearer Token → Cookie JWT) +- `LocalNetworkMiddleware` 重命名为 `McpAccessMiddleware`,从硬编码 RFC 1918 改为读取配置 + +### Tests + +- 新增 101 个单元测试覆盖安全、认证、配置、下载器和 MockDownloader 模块 + +## Frontend + +### Added + +- 新增日历拖拽排列功能:可将「未知」番剧拖入星期列,自动设置放送日并锁定 + - 拖入后显示紫色图钉图标,鼠标悬停显示取消按钮 + - 锁定的番剧在日历刷新时不会被覆盖 + - 使用 vuedraggable 实现流畅拖拽动画 +- 新增安全设置组件 (`config-security.vue`),支持在 WebUI 中配置 IP 白名单和 Token +- 前端 `Security` 类型定义和初始化配置 + +--- + # [3.2.3] - 2026-02-23 ## Backend diff --git a/backend/pyproject.toml b/backend/pyproject.toml index ecefe4ae..b9c5b4f5 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "auto-bangumi" -version = "3.2.3" +version = "3.2.4" description = "AutoBangumi - Automated anime download manager" requires-python = ">=3.13" dependencies = [ diff --git a/backend/src/module/api/auth.py b/backend/src/module/api/auth.py index 4f7d7753..6b362070 100644 --- a/backend/src/module/api/auth.py +++ b/backend/src/module/api/auth.py @@ -9,6 +9,7 @@ from module.models.user import User, UserUpdate from module.security.api import ( active_user, auth_user, + check_login_ip, get_current_user, update_user_info, ) @@ -18,17 +19,26 @@ from .response import u_response router = APIRouter(prefix="/auth", tags=["auth"]) +_TOKEN_EXPIRY_DAYS = 1 +_TOKEN_MAX_AGE = 86400 -@router.post("/login", response_model=dict) + +def _issue_token(username: str, response: Response) -> dict: + """Create a JWT, set it as an HttpOnly cookie, and return the bearer payload.""" + token = create_access_token( + data={"sub": username}, expires_delta=timedelta(days=_TOKEN_EXPIRY_DAYS) + ) + response.set_cookie(key="token", value=token, httponly=True, max_age=_TOKEN_MAX_AGE) + return {"access_token": token, "token_type": "bearer"} + + +@router.post("/login", response_model=dict, dependencies=[Depends(check_login_ip)]) async def login(response: Response, form_data=Depends(OAuth2PasswordRequestForm)): + """Authenticate with username/password and issue a session token.""" user = User(username=form_data.username, password=form_data.password) resp = auth_user(user) if resp.status: - token = create_access_token( - data={"sub": user.username}, expires_delta=timedelta(days=1) - ) - response.set_cookie(key="token", value=token, httponly=True, max_age=86400) - return {"access_token": token, "token_type": "bearer"} + return _issue_token(user.username, response) return u_response(resp) @@ -36,6 +46,7 @@ async def login(response: Response, form_data=Depends(OAuth2PasswordRequestForm) "/refresh_token", response_model=dict, dependencies=[Depends(get_current_user)] ) async def refresh(response: Response, token: str = Cookie(None)): + """Refresh the current session token and update the active-user timestamp.""" payload = decode_token(token) username = payload.get("sub") if payload else None if not username: @@ -43,17 +54,14 @@ async def refresh(response: Response, token: str = Cookie(None)): status_code=status.HTTP_401_UNAUTHORIZED, detail="Unauthorized" ) active_user[username] = datetime.now() - new_token = create_access_token( - data={"sub": username}, expires_delta=timedelta(days=1) - ) - response.set_cookie(key="token", value=new_token, httponly=True, max_age=86400) - return {"access_token": new_token, "token_type": "bearer"} + return _issue_token(username, response) @router.get( "/logout", response_model=APIResponse, dependencies=[Depends(get_current_user)] ) async def logout(response: Response, token: str = Cookie(None)): + """Invalidate the session and clear the token cookie.""" payload = decode_token(token) username = payload.get("sub") if payload else None if username: @@ -69,6 +77,7 @@ async def logout(response: Response, token: str = Cookie(None)): async def update_user( user_data: UserUpdate, response: Response, token: str = Cookie(None) ): + """Update credentials for the current user and re-issue a fresh token.""" payload = decode_token(token) old_user = payload.get("sub") if payload else None if not old_user: @@ -76,17 +85,4 @@ async def update_user( status_code=status.HTTP_401_UNAUTHORIZED, detail="Unauthorized" ) if update_user_info(user_data, old_user): - token = create_access_token( - data={"sub": old_user}, expires_delta=timedelta(days=1) - ) - response.set_cookie( - key="token", - value=token, - httponly=True, - max_age=86400, - ) - return { - "access_token": token, - "token_type": "bearer", - "message": "update success", - } + return {**_issue_token(old_user, response), "message": "update success"} diff --git a/backend/src/module/api/config.py b/backend/src/module/api/config.py index f0240198..a0fadf59 100644 --- a/backend/src/module/api/config.py +++ b/backend/src/module/api/config.py @@ -14,11 +14,12 @@ _SENSITIVE_KEYS = ("password", "api_key", "token", "secret") def _sanitize_dict(d: dict) -> dict: + """Recursively mask string values whose keys contain sensitive keywords.""" result = {} for k, v in d.items(): if isinstance(v, dict): result[k] = _sanitize_dict(v) - elif any(s in k.lower() for s in _SENSITIVE_KEYS): + elif isinstance(v, str) and any(s in k.lower() for s in _SENSITIVE_KEYS): result[k] = "********" else: result[k] = v @@ -27,6 +28,7 @@ def _sanitize_dict(d: dict) -> dict: @router.get("/get", dependencies=[Depends(get_current_user)]) async def get_config(): + """Return the current configuration with sensitive fields masked.""" return _sanitize_dict(settings.dict()) @@ -34,6 +36,7 @@ async def get_config(): "/update", response_model=APIResponse, dependencies=[Depends(get_current_user)] ) async def update_config(config: Config): + """Persist and reload configuration from the supplied payload.""" try: settings.save(config_dict=config.dict()) settings.load() diff --git a/backend/src/module/conf/config.py b/backend/src/module/conf/config.py index 1bb2f19e..fd6fa2da 100644 --- a/backend/src/module/conf/config.py +++ b/backend/src/module/conf/config.py @@ -7,7 +7,7 @@ from dotenv import load_dotenv from module.models.config import Config -from .const import ENV_TO_ATTR +from .const import DEFAULT_SETTINGS, ENV_TO_ATTR logger = logging.getLogger(__name__) CONFIG_ROOT = Path("config") @@ -27,6 +27,15 @@ CONFIG_PATH = ( class Settings(Config): + """Runtime configuration singleton. + + On construction, loads from ``CONFIG_PATH`` if the file exists (and + immediately re-saves to apply any migrations), otherwise bootstraps + defaults from environment variables via ``init()``. + + Use ``settings`` module-level instance rather than instantiating directly. + """ + def __init__(self): super().__init__() if CONFIG_PATH.exists(): @@ -36,6 +45,7 @@ class Settings(Config): self.init() def load(self): + """Load and validate configuration from ``CONFIG_PATH``, applying migrations.""" with open(CONFIG_PATH, "r", encoding="utf-8") as f: config = json.load(f) config = self._migrate_old_config(config) @@ -65,20 +75,27 @@ class Settings(Config): for key in ("type", "custom_url", "token", "enable_tmdb"): rss_parser.pop(key, None) + # Add security section if missing (preserves local-network MCP default) + if "security" not in config: + config["security"] = DEFAULT_SETTINGS["security"] + return config def save(self, config_dict: dict | None = None): + """Write configuration to ``CONFIG_PATH``. Uses current state when no dict supplied.""" if not config_dict: config_dict = self.model_dump() with open(CONFIG_PATH, "w", encoding="utf-8") as f: json.dump(config_dict, f, indent=4, ensure_ascii=False) def init(self): + """Bootstrap a new config file from ``.env`` and environment variables.""" load_dotenv(".env") self.__load_from_env() self.save() def __load_from_env(self): + """Apply ``ENV_TO_ATTR`` mappings from the process environment to the config dict.""" config_dict = self.model_dump() for key, section in ENV_TO_ATTR.items(): for env, attr in section.items(): @@ -97,12 +114,11 @@ class Settings(Config): logger.info("Config loaded from env") @staticmethod - def __val_from_env(env: str, attr: tuple): + def __val_from_env(env: str, attr: tuple | str): + """Return the environment variable value, applying the converter when attr is a tuple.""" if isinstance(attr, tuple): - conv_func = attr[1] - return conv_func(os.environ[env]) - else: - return os.environ[env] + return attr[1](os.environ[env]) + return os.environ[env] @property def group_rules(self): diff --git a/backend/src/module/conf/const.py b/backend/src/module/conf/const.py index c4383f71..69f25181 100644 --- a/backend/src/module/conf/const.py +++ b/backend/src/module/conf/const.py @@ -1,4 +1,8 @@ # -*- encoding: utf-8 -*- +# DEFAULT_SETTINGS: factory defaults written to config.json on first run. +# ENV_TO_ATTR: maps AB_* environment variables to Config model attribute paths. +# Values are either a string attr name, a (attr_name, converter) tuple, or a +# list of such tuples when a single env var sets multiple attributes. DEFAULT_SETTINGS = { "program": { "rss_time": 900, @@ -46,6 +50,20 @@ DEFAULT_SETTINGS = { "model": "gpt-3.5-turbo", "deployment_id": "", }, + "security": { + "login_whitelist": [], + "login_tokens": [], + "mcp_whitelist": [ + "127.0.0.0/8", + "10.0.0.0/8", + "172.16.0.0/12", + "192.168.0.0/16", + "::1/128", + "fe80::/10", + "fc00::/7", + ], + "mcp_tokens": [], + }, } @@ -99,8 +117,11 @@ ENV_TO_ATTR = { class BCOLORS: + """ANSI colour helpers for terminal output.""" + @staticmethod def _(color: str, *args: str) -> str: + """Wrap *args* in the given ANSI colour code and reset at the end.""" strings = [str(s) for s in args] return f"{color}{', '.join(strings)}{BCOLORS.ENDC}" diff --git a/backend/src/module/downloader/client/mock_downloader.py b/backend/src/module/downloader/client/mock_downloader.py index 117a92c2..e900242d 100644 --- a/backend/src/module/downloader/client/mock_downloader.py +++ b/backend/src/module/downloader/client/mock_downloader.py @@ -29,10 +29,10 @@ class MockDownloader: "rss_processing_enabled": True, "rss_refresh_interval": 30, } - logger.info("[MockDownloader] Initialized") + logger.debug("[MockDownloader] Initialized") async def auth(self, retry=3) -> bool: - logger.info("[MockDownloader] Auth successful (mocked)") + logger.debug("[MockDownloader] Auth successful (mocked)") return True async def logout(self): diff --git a/backend/src/module/downloader/download_client.py b/backend/src/module/downloader/download_client.py index c90ee2bf..a587eec9 100644 --- a/backend/src/module/downloader/download_client.py +++ b/backend/src/module/downloader/download_client.py @@ -11,6 +11,13 @@ logger = logging.getLogger(__name__) class DownloadClient(TorrentPath): + """Unified async download client. + + Wraps qBittorrent, Aria2, or MockDownloader behind a common interface. + Intended to be used as an async context manager; authentication is + performed on ``__aenter__`` and the session is closed on ``__aexit__``. + """ + def __init__(self): super().__init__() self.client = self.__getClient() @@ -18,27 +25,28 @@ class DownloadClient(TorrentPath): @staticmethod def __getClient(): - type = settings.downloader.type + """Instantiate the configured downloader client (qbittorrent | aria2 | mock).""" + downloader_type = settings.downloader.type host = settings.downloader.host username = settings.downloader.username password = settings.downloader.password ssl = settings.downloader.ssl - if type == "qbittorrent": + if downloader_type == "qbittorrent": from .client.qb_downloader import QbDownloader return QbDownloader(host, username, password, ssl) - elif type == "aria2": + elif downloader_type == "aria2": from .client.aria2_downloader import Aria2Downloader return Aria2Downloader(host, username, password) - elif type == "mock": + elif downloader_type == "mock": from .client.mock_downloader import MockDownloader - logger.info("[Downloader] Using MockDownloader for local development") + logger.debug("[Downloader] Using MockDownloader for local development") return MockDownloader() else: - logger.error(f"[Downloader] Unsupported downloader type: {type}") - raise Exception(f"Unsupported downloader type: {type}") + logger.error("[Downloader] Unsupported downloader type: %s", downloader_type) + raise Exception(f"Unsupported downloader type: {downloader_type}") async def __aenter__(self): if not self.authed: @@ -65,6 +73,7 @@ class DownloadClient(TorrentPath): return await self.client.check_host() async def init_downloader(self): + """Apply required qBittorrent RSS preferences and create the Bangumi category.""" prefs = { "rss_auto_downloading_enabled": True, "rss_max_articles_per_feed": 500, @@ -84,6 +93,7 @@ class DownloadClient(TorrentPath): settings.downloader.path = self._join_path(prefs["save_path"], "Bangumi") async def set_rule(self, data: Bangumi): + """Create or update a qBittorrent RSS auto-download rule for one bangumi entry.""" data.rule_name = self._rule_name(data) data.save_path = self._gen_save_path(data) rule = { @@ -145,6 +155,12 @@ class DownloadClient(TorrentPath): await self.client.torrents_resume(hashes) async def add_torrent(self, torrent: Torrent | list, bangumi: Bangumi) -> bool: + """Download a torrent (or list of torrents) for the given bangumi entry. + + Handles both magnet links and .torrent file URLs, fetching file bytes + when necessary. Tags each torrent with ``ab:`` for later + episode-offset lookup during rename. + """ if not bangumi.save_path: bangumi.save_path = self._gen_save_path(bangumi) async with RequestContent() as req: diff --git a/backend/src/module/mcp/security.py b/backend/src/module/mcp/security.py index 293e8f23..24b30227 100644 --- a/backend/src/module/mcp/security.py +++ b/backend/src/module/mcp/security.py @@ -1,48 +1,68 @@ -"""MCP access control: restricts connections to local network addresses only.""" +"""MCP access control: configurable IP whitelist and bearer token authentication.""" import ipaddress import logging +from functools import lru_cache from starlette.middleware.base import BaseHTTPMiddleware from starlette.requests import Request from starlette.responses import JSONResponse +from module.conf import settings + logger = logging.getLogger(__name__) -# RFC 1918 private ranges + loopback + IPv6 equivalents -_ALLOWED_NETWORKS = [ - ipaddress.ip_network("127.0.0.0/8"), - ipaddress.ip_network("10.0.0.0/8"), - ipaddress.ip_network("172.16.0.0/12"), - ipaddress.ip_network("192.168.0.0/16"), - ipaddress.ip_network("::1/128"), - ipaddress.ip_network("fe80::/10"), - ipaddress.ip_network("fc00::/7"), -] + +@lru_cache(maxsize=128) +def _parse_network(cidr: str) -> ipaddress.IPv4Network | ipaddress.IPv6Network | None: + try: + return ipaddress.ip_network(cidr, strict=False) + except ValueError: + logger.warning("[MCP] Invalid CIDR in whitelist: %s", cidr) + return None -def _is_local(host: str) -> bool: - """Return True if *host* is a loopback or RFC 1918 private address.""" +def _is_allowed(host: str, whitelist: list[str]) -> bool: + """Return True if *host* falls within any CIDR range in *whitelist*.""" try: addr = ipaddress.ip_address(host) except ValueError: return False - return any(addr in net for net in _ALLOWED_NETWORKS) + for cidr in whitelist: + net = _parse_network(cidr) + if net and addr in net: + return True + return False -class LocalNetworkMiddleware(BaseHTTPMiddleware): - """Starlette middleware that blocks requests from non-local IP addresses. +def clear_network_cache(): + """Clear the parsed network cache (call after config reload).""" + _parse_network.cache_clear() - Returns HTTP 403 for any client outside loopback, RFC 1918, or IPv6 - link-local/unique-local ranges. + +class McpAccessMiddleware(BaseHTTPMiddleware): + """Configurable access control for MCP endpoint. + + Checks client IP against ``settings.security.mcp_whitelist`` CIDR ranges, + and ``Authorization`` header against ``settings.security.mcp_tokens``. + If the whitelist is empty and no tokens are configured, all access is denied. """ async def dispatch(self, request: Request, call_next): + # Check bearer token first + auth_header = request.headers.get("authorization", "") + if auth_header.startswith("Bearer "): + token = auth_header[7:] + if token and token in settings.security.mcp_tokens: + return await call_next(request) + + # Check IP whitelist client_host = request.client.host if request.client else None - if not client_host or not _is_local(client_host): - logger.warning("[MCP] Rejected non-local connection from %s", client_host) - return JSONResponse( - status_code=403, - content={"error": "MCP access is restricted to local network"}, - ) - return await call_next(request) + if client_host and _is_allowed(client_host, settings.security.mcp_whitelist): + return await call_next(request) + + logger.warning("[MCP] Rejected connection from %s", client_host) + return JSONResponse( + status_code=403, + content={"error": "MCP access denied"}, + ) diff --git a/backend/src/module/mcp/server.py b/backend/src/module/mcp/server.py index b26910a3..5e250906 100644 --- a/backend/src/module/mcp/server.py +++ b/backend/src/module/mcp/server.py @@ -18,7 +18,7 @@ from starlette.requests import Request from starlette.routing import Mount, Route from .resources import RESOURCE_TEMPLATES, RESOURCES, handle_resource -from .security import LocalNetworkMiddleware +from .security import McpAccessMiddleware from .tools import TOOLS, handle_tool logger = logging.getLogger(__name__) @@ -73,8 +73,8 @@ def create_mcp_starlette_app() -> Starlette: - ``GET /sse`` - SSE stream for MCP clients - ``POST /messages/`` - client-to-server message posting - ``LocalNetworkMiddleware`` is applied so the endpoint is only reachable - from loopback and RFC 1918 addresses. + ``McpAccessMiddleware`` is applied to enforce configurable IP whitelist + and bearer token access control. """ app = Starlette( routes=[ @@ -82,5 +82,5 @@ def create_mcp_starlette_app() -> Starlette: Mount("/messages", app=sse.handle_post_message), ], ) - app.add_middleware(LocalNetworkMiddleware) + app.add_middleware(McpAccessMiddleware) return app diff --git a/backend/src/module/models/config.py b/backend/src/module/models/config.py index 63de5924..2cfe8b53 100644 --- a/backend/src/module/models/config.py +++ b/backend/src/module/models/config.py @@ -4,13 +4,27 @@ from typing import Literal, Optional from pydantic import BaseModel, Field, field_validator, model_validator +def _expand(value: str | None) -> str: + """Expand shell environment variables in *value*, returning empty string for None.""" + return expandvars(value) if value else "" + + class Program(BaseModel): + """Scheduler timing and WebUI port settings.""" + rss_time: int = Field(900, description="Sleep time") rename_time: int = Field(60, description="Rename times in one loop") webui_port: int = Field(7892, description="WebUI port") class Downloader(BaseModel): + """Download client connection settings. + + Credential fields (``host``, ``username``, ``password``) are stored with a + trailing underscore and exposed via properties that expand ``$VAR`` + environment variable references at access time. + """ + type: str = Field("qbittorrent", description="Downloader type") host_: str = Field("172.17.0.1:8080", alias="host", description="Downloader host") username_: str = Field("admin", alias="username", description="Downloader username") @@ -22,24 +36,28 @@ class Downloader(BaseModel): @property def host(self): - return expandvars(self.host_) + return _expand(self.host_) @property def username(self): - return expandvars(self.username_) + return _expand(self.username_) @property def password(self): - return expandvars(self.password_) + return _expand(self.password_) class RSSParser(BaseModel): + """RSS feed parsing settings.""" + enable: bool = Field(True, description="Enable RSS parser") filter: list[str] = Field(["720", r"\d+-\d"], description="Filter") language: str = "zh" class BangumiManage(BaseModel): + """File organisation and renaming settings.""" + enable: bool = Field(True, description="Enable bangumi manage") eps_complete: bool = Field(False, description="Enable eps complete") rename_method: str = Field("pn", description="Rename method") @@ -48,10 +66,14 @@ class BangumiManage(BaseModel): class Log(BaseModel): + """Logging verbosity settings.""" + debug_enable: bool = Field(False, description="Enable debug") class Proxy(BaseModel): + """HTTP/SOCKS proxy settings. Credentials support ``$VAR`` expansion.""" + enable: bool = Field(False, description="Enable proxy") type: str = Field("http", description="Proxy type") host: str = Field("", description="Proxy host") @@ -61,11 +83,11 @@ class Proxy(BaseModel): @property def username(self): - return expandvars(self.username_) + return _expand(self.username_) @property def password(self): - return expandvars(self.password_) + return _expand(self.password_) class NotificationProvider(BaseModel): @@ -103,35 +125,35 @@ class NotificationProvider(BaseModel): @property def token(self) -> str: - return expandvars(self.token_) if self.token_ else "" + return _expand(self.token_) @property def chat_id(self) -> str: - return expandvars(self.chat_id_) if self.chat_id_ else "" + return _expand(self.chat_id_) @property def webhook_url(self) -> str: - return expandvars(self.webhook_url_) if self.webhook_url_ else "" + return _expand(self.webhook_url_) @property def server_url(self) -> str: - return expandvars(self.server_url_) if self.server_url_ else "" + return _expand(self.server_url_) @property def device_key(self) -> str: - return expandvars(self.device_key_) if self.device_key_ else "" + return _expand(self.device_key_) @property def user_key(self) -> str: - return expandvars(self.user_key_) if self.user_key_ else "" + return _expand(self.user_key_) @property def api_token(self) -> str: - return expandvars(self.api_token_) if self.api_token_ else "" + return _expand(self.api_token_) @property def url(self) -> str: - return expandvars(self.url_) if self.url_ else "" + return _expand(self.url_) class Notification(BaseModel): @@ -149,11 +171,11 @@ class Notification(BaseModel): @property def token(self) -> str: - return expandvars(self.token_) if self.token_ else "" + return _expand(self.token_) @property def chat_id(self) -> str: - return expandvars(self.chat_id_) if self.chat_id_ else "" + return _expand(self.chat_id_) @model_validator(mode="after") def migrate_legacy_config(self) -> "Notification": @@ -197,7 +219,35 @@ class ExperimentalOpenAI(BaseModel): return value +class Security(BaseModel): + """Access control configuration for the login endpoint and MCP server. + + Both ``login_whitelist`` and ``mcp_whitelist`` accept IPv4/IPv6 CIDR ranges. + An empty ``login_whitelist`` allows all IPs; an empty ``mcp_whitelist`` + denies all IP-based access (tokens still work). + """ + + login_whitelist: list[str] = Field( + default_factory=list, + description="IP/CIDR whitelist for login access. Empty = allow all.", + ) + login_tokens: list[str] = Field( + default_factory=list, + description="API bearer tokens that bypass login authentication.", + ) + mcp_whitelist: list[str] = Field( + default_factory=list, + description="IP/CIDR whitelist for MCP access. Empty = deny all.", + ) + mcp_tokens: list[str] = Field( + default_factory=list, + description="API bearer tokens for MCP access.", + ) + + class Config(BaseModel): + """Root configuration model composed of all subsection models.""" + program: Program = Program() downloader: Downloader = Downloader() rss_parser: RSSParser = RSSParser() @@ -206,6 +256,7 @@ class Config(BaseModel): proxy: Proxy = Proxy() notification: Notification = Notification() experimental_openai: ExperimentalOpenAI = ExperimentalOpenAI() + security: Security = Security() def model_dump(self, *args, by_alias=True, **kwargs): return super().model_dump(*args, by_alias=by_alias, **kwargs) diff --git a/backend/src/module/security/api.py b/backend/src/module/security/api.py index 62ed977f..23b2b154 100644 --- a/backend/src/module/security/api.py +++ b/backend/src/module/security/api.py @@ -1,9 +1,11 @@ from datetime import datetime -from fastapi import Cookie, Depends, HTTPException, status +from fastapi import Cookie, Depends, HTTPException, Request, status from fastapi.security import OAuth2PasswordBearer +from module.conf import settings from module.database import Database +from module.mcp.security import _is_allowed from module.models.user import User, UserUpdate from .jwt import verify_token @@ -20,23 +22,49 @@ except ImportError: DEV_AUTH_BYPASS = VERSION == "DEV_VERSION" -async def get_current_user(token: str = Cookie(None)): +def check_login_ip(request: Request): + """Dependency that enforces login IP whitelist. + + If ``settings.security.login_whitelist`` is empty, all IPs are allowed. + """ + whitelist = settings.security.login_whitelist + if not whitelist: + return + client_host = request.client.host if request.client else None + if not client_host or not _is_allowed(client_host, whitelist): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="IP not in login whitelist", + ) + + +async def get_current_user(request: Request, token: str = Cookie(None)): + """FastAPI dependency that validates the current session. + + Accepts authentication via (in order of precedence): + 1. DEV_AUTH_BYPASS when running as DEV_VERSION. + 2. ``Authorization: Bearer `` header matching ``login_tokens``. + 3. HttpOnly ``token`` cookie containing a valid JWT with an active session. + """ if DEV_AUTH_BYPASS: return "dev_user" + # Check bearer token bypass + auth_header = request.headers.get("authorization", "") + if auth_header.startswith("Bearer "): + api_token = auth_header[7:] + if api_token and api_token in settings.security.login_tokens: + return "api_token_user" if not token: raise UNAUTHORIZED payload = verify_token(token) - if not payload: - raise UNAUTHORIZED - username = payload.get("sub") - if not username: - raise UNAUTHORIZED - if username not in active_user: + username = payload.get("sub") if payload else None + if not username or username not in active_user: raise UNAUTHORIZED return username async def get_token_data(token: str = Depends(oauth2_scheme)): + """FastAPI dependency that decodes and returns the OAuth2 bearer token payload.""" payload = verify_token(token) if not payload: raise HTTPException( @@ -46,6 +74,7 @@ async def get_token_data(token: str = Depends(oauth2_scheme)): def update_user_info(user_data: UserUpdate, current_user): + """Persist updated credentials for *current_user* to the database.""" try: with Database() as db: db.user.update_user(current_user, user_data) @@ -55,6 +84,7 @@ def update_user_info(user_data: UserUpdate, current_user): def auth_user(user: User): + """Verify credentials and register the user in ``active_user`` on success.""" with Database() as db: resp = db.user.auth_user(user) if resp.status: diff --git a/backend/src/test/test_api_auth.py b/backend/src/test/test_api_auth.py index 15197232..d03d01d9 100644 --- a/backend/src/test/test_api_auth.py +++ b/backend/src/test/test_api_auth.py @@ -185,3 +185,98 @@ class TestUpdateCredentials: except Exception: # Expected - endpoint doesn't handle failure case properly pass + + +# --------------------------------------------------------------------------- +# Refresh token: cookie-based username resolution +# --------------------------------------------------------------------------- + + +class TestRefreshTokenCookieBehavior: + def test_refresh_with_no_cookie_raises_401(self, authed_client): + """GET /refresh_token with missing token cookie raises 401.""" + # Override auth to allow route but provide no cookie token + with patch("module.api.auth.decode_token", return_value=None): + response = authed_client.get("/api/v1/auth/refresh_token") + assert response.status_code == 401 + + def test_refresh_with_valid_cookie_updates_active_user(self, authed_client): + """GET /refresh_token updates the active user timestamp.""" + token = create_access_token(data={"sub": "testuser"}) + authed_client.cookies.set("token", token) + active_users: dict = {} + with patch("module.api.auth.active_user", active_users): + response = authed_client.get("/api/v1/auth/refresh_token") + assert response.status_code == 200 + assert "testuser" in active_users + + def test_refresh_returns_new_token(self, authed_client): + """GET /refresh_token issues a valid JWT with bearer type.""" + token = create_access_token(data={"sub": "testuser"}) + authed_client.cookies.set("token", token) + with patch("module.api.auth.active_user", {}): + response = authed_client.get("/api/v1/auth/refresh_token") + assert response.status_code == 200 + data = response.json() + assert data["token_type"] == "bearer" + assert isinstance(data["access_token"], str) + assert len(data["access_token"]) > 0 + + +# --------------------------------------------------------------------------- +# Logout: per-user removal +# --------------------------------------------------------------------------- + + +class TestLogoutCookieBehavior: + def test_logout_removes_only_current_user(self, authed_client): + """GET /logout removes the current user from active_user, not others.""" + token = create_access_token(data={"sub": "testuser"}) + authed_client.cookies.set("token", token) + active_users = { + "testuser": datetime.now(), + "otheruser": datetime.now(), + } + with patch("module.api.auth.active_user", active_users): + response = authed_client.get("/api/v1/auth/logout") + assert response.status_code == 200 + assert "testuser" not in active_users + assert "otheruser" in active_users + + def test_logout_with_no_cookie_still_succeeds(self, authed_client): + """GET /logout with no cookie clears nothing but returns success.""" + with patch("module.api.auth.decode_token", return_value=None): + with patch("module.api.auth.active_user", {}): + response = authed_client.get("/api/v1/auth/logout") + assert response.status_code == 200 + + +# --------------------------------------------------------------------------- +# Update: cookie-based user resolution +# --------------------------------------------------------------------------- + + +class TestUpdateCookieBehavior: + def test_update_with_no_cookie_raises_401(self, authed_client): + """POST /auth/update with no cookie raises 401.""" + with patch("module.api.auth.decode_token", return_value=None): + response = authed_client.post( + "/api/v1/auth/update", + json={"old_password": "old", "new_password": "new"}, + ) + assert response.status_code == 401 + + def test_update_with_valid_cookie_succeeds(self, authed_client): + """POST /auth/update resolves username from cookie and issues new token.""" + token = create_access_token(data={"sub": "testuser"}) + authed_client.cookies.set("token", token) + with patch("module.api.auth.active_user", {"testuser": datetime.now()}): + with patch("module.api.auth.update_user_info", return_value=True): + response = authed_client.post( + "/api/v1/auth/update", + json={"old_password": "oldpass", "new_password": "newpass"}, + ) + assert response.status_code == 200 + data = response.json() + assert "access_token" in data + assert data["message"] == "update success" diff --git a/backend/src/test/test_api_config.py b/backend/src/test/test_api_config.py index bd67454f..36f16d50 100644 --- a/backend/src/test/test_api_config.py +++ b/backend/src/test/test_api_config.py @@ -1,4 +1,4 @@ -"""Tests for Config API endpoints.""" +"""Tests for Config API endpoints and config sanitization.""" import pytest from unittest.mock import patch, MagicMock @@ -7,6 +7,7 @@ from fastapi import FastAPI from fastapi.testclient import TestClient from module.api import v1 +from module.api.config import _sanitize_dict from module.models.config import Config from module.security.api import get_current_user @@ -263,3 +264,99 @@ class TestUpdateConfig: response = authed_client.patch("/api/v1/config/update", json=invalid_data) assert response.status_code == 422 + + +# --------------------------------------------------------------------------- +# _sanitize_dict unit tests +# --------------------------------------------------------------------------- + + +class TestSanitizeDict: + def test_masks_password_key(self): + """Keys containing 'password' are masked.""" + result = _sanitize_dict({"password": "secret"}) + assert result["password"] == "********" + + def test_masks_api_key(self): + """Keys containing 'api_key' are masked.""" + result = _sanitize_dict({"api_key": "sk-abc123"}) + assert result["api_key"] == "********" + + def test_masks_token_key(self): + """Keys containing 'token' are masked.""" + result = _sanitize_dict({"token": "bearer-xyz"}) + assert result["token"] == "********" + + def test_masks_secret_key(self): + """Keys containing 'secret' are masked.""" + result = _sanitize_dict({"my_secret": "topsecret"}) + assert result["my_secret"] == "********" + + def test_case_insensitive_key_matching(self): + """Sensitive key matching is case-insensitive.""" + result = _sanitize_dict({"API_KEY": "abc"}) + assert result["API_KEY"] == "********" + + def test_non_sensitive_keys_pass_through(self): + """Non-sensitive keys are returned unchanged.""" + result = _sanitize_dict({"host": "localhost", "port": 8080, "enable": True}) + assert result["host"] == "localhost" + assert result["port"] == 8080 + assert result["enable"] is True + + def test_nested_dict_recursed(self): + """Nested dicts are processed recursively.""" + result = _sanitize_dict({ + "downloader": { + "host": "localhost", + "password": "secret", + } + }) + assert result["downloader"]["host"] == "localhost" + assert result["downloader"]["password"] == "********" + + def test_deeply_nested_dict(self): + """Deeply nested sensitive keys are masked.""" + result = _sanitize_dict({ + "level1": { + "level2": { + "api_key": "deep-secret" + } + } + }) + assert result["level1"]["level2"]["api_key"] == "********" + + def test_non_string_value_not_masked(self): + """Non-string values with sensitive-looking keys are NOT masked.""" + result = _sanitize_dict({"password": 12345}) + # Only string values are masked; integers pass through + assert result["password"] == 12345 + + def test_empty_dict(self): + """Empty dict returns empty dict.""" + assert _sanitize_dict({}) == {} + + def test_mixed_sensitive_and_plain(self): + """Mix of sensitive and plain keys handled correctly.""" + result = _sanitize_dict({ + "username": "admin", + "password": "secret", + "host": "10.0.0.1", + "token": "jwt-abc", + }) + assert result["username"] == "admin" + assert result["host"] == "10.0.0.1" + assert result["password"] == "********" + assert result["token"] == "********" + + def test_get_config_masks_sensitive_fields(self, authed_client): + """GET /config/get response masks password and api_key fields.""" + test_config = Config() + with patch("module.api.config.settings", test_config): + response = authed_client.get("/api/v1/config/get") + assert response.status_code == 200 + data = response.json() + # Downloader password should be masked + assert data["downloader"]["password"] == "********" + # OpenAI api_key should be masked (it's an empty string but still masked) + assert data["experimental_openai"]["api_key"] == "********" diff --git a/backend/src/test/test_auth.py b/backend/src/test/test_auth.py index 6d0c4791..aab56cca 100644 --- a/backend/src/test/test_auth.py +++ b/backend/src/test/test_auth.py @@ -151,6 +151,15 @@ class TestPasswordHashing: class TestGetCurrentUser: + @staticmethod + def _mock_request(authorization=""): + """Create a mock Request with the given Authorization header.""" + from unittest.mock import MagicMock + + request = MagicMock() + request.headers = {"authorization": authorization} + return request + @patch("module.security.api.DEV_AUTH_BYPASS", False) async def test_no_cookie_raises_401(self): """get_current_user raises 401 when no token cookie.""" @@ -159,7 +168,7 @@ class TestGetCurrentUser: from module.security.api import get_current_user with pytest.raises(HTTPException) as exc_info: - await get_current_user(token=None) + await get_current_user(request=self._mock_request(), token=None) assert exc_info.value.status_code == 401 @patch("module.security.api.DEV_AUTH_BYPASS", False) @@ -170,7 +179,7 @@ class TestGetCurrentUser: from module.security.api import get_current_user with pytest.raises(HTTPException) as exc_info: - await get_current_user(token="invalid.jwt.token") + await get_current_user(request=self._mock_request(), token="invalid.jwt.token") assert exc_info.value.status_code == 401 @patch("module.security.api.DEV_AUTH_BYPASS", False) @@ -186,7 +195,7 @@ class TestGetCurrentUser: active_user.clear() with pytest.raises(HTTPException) as exc_info: - await get_current_user(token=token) + await get_current_user(request=self._mock_request(), token=token) assert exc_info.value.status_code == 401 @patch("module.security.api.DEV_AUTH_BYPASS", False) @@ -202,8 +211,113 @@ class TestGetCurrentUser: active_user.clear() active_user["active_user"] = datetime.now() - result = await get_current_user(token=token) + result = await get_current_user(request=self._mock_request(), token=token) assert result == "active_user" # Cleanup active_user.clear() + + @patch("module.security.api.DEV_AUTH_BYPASS", True) + async def test_dev_bypass_skips_auth(self): + """When DEV_AUTH_BYPASS is True, get_current_user returns 'dev_user' unconditionally.""" + from module.security.api import get_current_user + + result = await get_current_user(request=self._mock_request(), token=None) + assert result == "dev_user" + + @patch("module.security.api.DEV_AUTH_BYPASS", False) + async def test_bearer_token_bypass_valid(self): + """A valid login_token in Authorization header returns 'api_token_user'.""" + from module.security.api import get_current_user + + mock_request = self._mock_request(authorization="Bearer valid-api-token") + mock_security = type("S", (), {"login_tokens": ["valid-api-token"]})() + mock_settings = type("Settings", (), {"security": mock_security})() + + with patch("module.security.api.settings", mock_settings): + result = await get_current_user(request=mock_request, token=None) + assert result == "api_token_user" + + @patch("module.security.api.DEV_AUTH_BYPASS", False) + async def test_bearer_token_bypass_invalid(self): + """An invalid login_token still falls through to cookie check.""" + from fastapi import HTTPException + + from module.security.api import get_current_user + + mock_request = self._mock_request(authorization="Bearer wrong-token") + mock_security = type("S", (), {"login_tokens": ["correct-token"]})() + mock_settings = type("Settings", (), {"security": mock_security})() + + with patch("module.security.api.settings", mock_settings): + with pytest.raises(HTTPException) as exc_info: + await get_current_user(request=mock_request, token=None) + assert exc_info.value.status_code == 401 + + +# --------------------------------------------------------------------------- +# check_login_ip +# --------------------------------------------------------------------------- + + +class TestCheckLoginIp: + @staticmethod + def _make_request(host: str | None): + from unittest.mock import MagicMock + + request = MagicMock() + if host is None: + request.client = None + else: + request.client = MagicMock() + request.client.host = host + return request + + def test_empty_whitelist_allows_all(self): + """When login_whitelist is empty, all IPs pass.""" + from module.security.api import check_login_ip + + mock_security = type("S", (), {"login_whitelist": []})() + mock_settings = type("Settings", (), {"security": mock_security})() + + with patch("module.security.api.settings", mock_settings): + # Should not raise + check_login_ip(request=self._make_request("8.8.8.8")) + + def test_allowed_ip_passes(self): + """IP in whitelist does not raise.""" + from module.security.api import check_login_ip + + mock_security = type("S", (), {"login_whitelist": ["192.168.0.0/16"]})() + mock_settings = type("Settings", (), {"security": mock_security})() + + with patch("module.security.api.settings", mock_settings): + check_login_ip(request=self._make_request("192.168.1.100")) + + def test_blocked_ip_raises_403(self): + """IP outside whitelist raises 403.""" + from fastapi import HTTPException + + from module.security.api import check_login_ip + + mock_security = type("S", (), {"login_whitelist": ["192.168.0.0/16"]})() + mock_settings = type("Settings", (), {"security": mock_security})() + + with patch("module.security.api.settings", mock_settings): + with pytest.raises(HTTPException) as exc_info: + check_login_ip(request=self._make_request("8.8.8.8")) + assert exc_info.value.status_code == 403 + + def test_no_client_raises_403_when_whitelist_set(self): + """Missing client info raises 403 when whitelist is non-empty.""" + from fastapi import HTTPException + + from module.security.api import check_login_ip + + mock_security = type("S", (), {"login_whitelist": ["192.168.0.0/16"]})() + mock_settings = type("Settings", (), {"security": mock_security})() + + with patch("module.security.api.settings", mock_settings): + with pytest.raises(HTTPException) as exc_info: + check_login_ip(request=self._make_request(None)) + assert exc_info.value.status_code == 403 diff --git a/backend/src/test/test_config.py b/backend/src/test/test_config.py index 081474ab..e20f6176 100644 --- a/backend/src/test/test_config.py +++ b/backend/src/test/test_config.py @@ -8,14 +8,16 @@ from unittest.mock import patch from module.models.config import ( Config, - Program, Downloader, - RSSParser, - BangumiManage, - Proxy, Notification as NotificationConfig, + NotificationProvider, + Program, + Proxy, + RSSParser, + Security, ) from module.conf.config import Settings +from module.conf.const import BCOLORS, DEFAULT_SETTINGS # --------------------------------------------------------------------------- @@ -228,3 +230,281 @@ class TestEnvOverrides: s.init() assert "192.168.1.100:9090" in s.downloader.host + + +# --------------------------------------------------------------------------- +# Security model +# --------------------------------------------------------------------------- + + +class TestSecurityModel: + def test_security_defaults(self): + """Security has empty whitelists and token lists by default.""" + sec = Security() + assert sec.login_whitelist == [] + assert sec.login_tokens == [] + assert sec.mcp_whitelist == [] + assert sec.mcp_tokens == [] + + def test_security_in_config(self): + """Config includes a Security section with correct defaults.""" + config = Config() + assert hasattr(config, "security") + assert isinstance(config.security, Security) + assert config.security.login_whitelist == [] + + def test_security_populated(self): + """Security fields accept lists of CIDRs and tokens.""" + sec = Security( + login_whitelist=["192.168.0.0/16"], + login_tokens=["token-abc"], + mcp_whitelist=["10.0.0.0/8"], + mcp_tokens=["mcp-secret"], + ) + assert "192.168.0.0/16" in sec.login_whitelist + assert "token-abc" in sec.login_tokens + assert "10.0.0.0/8" in sec.mcp_whitelist + assert "mcp-secret" in sec.mcp_tokens + + def test_security_roundtrip_serialization(self): + """Security serializes and deserializes correctly.""" + original = Security( + login_whitelist=["127.0.0.0/8"], + mcp_tokens=["tok1"], + ) + data = original.model_dump() + restored = Security.model_validate(data) + assert restored.login_whitelist == ["127.0.0.0/8"] + assert restored.mcp_tokens == ["tok1"] + + +# --------------------------------------------------------------------------- +# NotificationProvider model +# --------------------------------------------------------------------------- + + +class TestNotificationProvider: + def test_minimal_provider(self): + """NotificationProvider requires only type.""" + p = NotificationProvider(type="telegram") + assert p.type == "telegram" + assert p.enabled is True + + def test_telegram_provider_fields(self): + """Telegram provider stores token and chat_id.""" + p = NotificationProvider(type="telegram", token="bot123", chat_id="-100456") + assert p.token == "bot123" + assert p.chat_id == "-100456" + + def test_discord_provider_fields(self): + """Discord provider stores webhook_url.""" + p = NotificationProvider( + type="discord", webhook_url="https://discord.com/api/webhooks/123/abc" + ) + assert p.webhook_url == "https://discord.com/api/webhooks/123/abc" + + def test_bark_provider_fields(self): + """Bark provider stores server_url and device_key.""" + p = NotificationProvider( + type="bark", server_url="https://api.day.app", device_key="mykey" + ) + assert p.server_url == "https://api.day.app" + assert p.device_key == "mykey" + + def test_pushover_provider_fields(self): + """Pushover provider stores user_key and api_token.""" + p = NotificationProvider(type="pushover", user_key="uk1", api_token="at1") + assert p.user_key == "uk1" + assert p.api_token == "at1" + + def test_url_field_property(self): + """Webhook provider stores url.""" + p = NotificationProvider(type="webhook", url="https://example.com/hook") + assert p.url == "https://example.com/hook" + + def test_optional_fields_default_empty_string(self): + """Unset optional properties return empty string, not None.""" + p = NotificationProvider(type="telegram") + assert p.token == "" + assert p.chat_id == "" + assert p.webhook_url == "" + + def test_provider_can_be_disabled(self): + """Provider can be disabled without removing it.""" + p = NotificationProvider(type="telegram", enabled=False) + assert p.enabled is False + + def test_env_var_expansion_in_token(self, monkeypatch): + """Token field expands shell environment variables.""" + monkeypatch.setenv("TEST_BOT_TOKEN", "real-token-value") + p = NotificationProvider(type="telegram", token="$TEST_BOT_TOKEN") + assert p.token == "real-token-value" + + +# --------------------------------------------------------------------------- +# Notification model - legacy migration +# --------------------------------------------------------------------------- + + +class TestNotificationLegacyMigration: + def test_new_format_no_migration(self): + """New format with providers list is not touched.""" + n = NotificationConfig( + enable=True, + providers=[NotificationProvider(type="telegram", token="tok")], + ) + assert len(n.providers) == 1 + assert n.providers[0].type == "telegram" + + def test_old_format_migrates_to_provider(self): + """Old single-provider fields (type, token, chat_id) migrate to providers list.""" + n = NotificationConfig( + enable=True, + type="telegram", + token="bot_token", + chat_id="-100123", + ) + assert len(n.providers) == 1 + provider = n.providers[0] + assert provider.type == "telegram" + assert provider.enabled is True + + def test_old_format_no_migration_when_providers_already_set(self): + """When providers already exist, legacy fields do not create additional providers.""" + n = NotificationConfig( + enable=True, + type="telegram", + token="unused", + providers=[NotificationProvider(type="discord", webhook_url="https://d.co")], + ) + assert len(n.providers) == 1 + assert n.providers[0].type == "discord" + + def test_notification_empty_providers_by_default(self): + """Default Notification has no providers.""" + n = NotificationConfig() + assert n.providers == [] + assert n.enable is False + + +# --------------------------------------------------------------------------- +# Downloader env-var expansion +# --------------------------------------------------------------------------- + + +class TestDownloaderEnvExpansion: + def test_host_expands_env_var(self, monkeypatch): + """Downloader.host expands $VAR references.""" + monkeypatch.setenv("QB_HOST", "192.168.5.10:8080") + d = Downloader(host="$QB_HOST") + assert d.host == "192.168.5.10:8080" + + def test_username_expands_env_var(self, monkeypatch): + """Downloader.username expands $VAR references.""" + monkeypatch.setenv("QB_USER", "myuser") + d = Downloader(username="$QB_USER") + assert d.username == "myuser" + + def test_password_expands_env_var(self, monkeypatch): + """Downloader.password expands $VAR references.""" + monkeypatch.setenv("QB_PASS", "s3cret") + d = Downloader(password="$QB_PASS") + assert d.password == "s3cret" + + def test_literal_host_not_expanded(self): + """Literal host strings without $ are returned as-is.""" + d = Downloader(host="localhost:8080") + assert d.host == "localhost:8080" + + +# --------------------------------------------------------------------------- +# DEFAULT_SETTINGS structure +# --------------------------------------------------------------------------- + + +class TestDefaultSettings: + def test_security_section_present(self): + """DEFAULT_SETTINGS contains a security section.""" + assert "security" in DEFAULT_SETTINGS + + def test_security_default_mcp_whitelist(self): + """Default MCP whitelist contains private network ranges.""" + mcp_wl = DEFAULT_SETTINGS["security"]["mcp_whitelist"] + assert "127.0.0.0/8" in mcp_wl + assert "192.168.0.0/16" in mcp_wl + assert "10.0.0.0/8" in mcp_wl + + def test_security_default_tokens_empty(self): + """Default security token lists are empty.""" + assert DEFAULT_SETTINGS["security"]["login_tokens"] == [] + assert DEFAULT_SETTINGS["security"]["mcp_tokens"] == [] + + def test_notification_uses_providers_format(self): + """DEFAULT_SETTINGS notification uses new providers format.""" + notif = DEFAULT_SETTINGS["notification"] + assert "providers" in notif + assert notif["providers"] == [] + assert "type" not in notif + + +# --------------------------------------------------------------------------- +# BCOLORS utility +# --------------------------------------------------------------------------- + + +class TestBCOLORS: + def test_wrap_single_string(self): + """BCOLORS._() wraps a string with color codes and reset.""" + result = BCOLORS._(BCOLORS.OKGREEN, "hello") + assert "hello" in result + assert BCOLORS.OKGREEN in result + assert BCOLORS.ENDC in result + + def test_wrap_multiple_strings(self): + """BCOLORS._() joins multiple args with commas.""" + result = BCOLORS._(BCOLORS.WARNING, "foo", "bar") + assert "foo" in result + assert "bar" in result + + def test_wrap_non_string_arg(self): + """BCOLORS._() converts non-string args to str.""" + result = BCOLORS._(BCOLORS.FAIL, 42) + assert "42" in result + + def test_all_color_constants_are_strings(self): + """All BCOLORS constants are strings.""" + for attr in ["HEADER", "OKBLUE", "OKCYAN", "OKGREEN", "WARNING", "FAIL", "ENDC"]: + assert isinstance(getattr(BCOLORS, attr), str) + + +# --------------------------------------------------------------------------- +# Migration: security section injection +# --------------------------------------------------------------------------- + + +class TestMigrateSecuritySection: + def test_adds_security_when_missing(self): + """_migrate_old_config injects a default security section when absent.""" + old_config = { + "program": {}, + "rss_parser": {}, + } + result = Settings._migrate_old_config(old_config) + assert "security" in result + assert "mcp_whitelist" in result["security"] + + def test_preserves_existing_security_section(self): + """_migrate_old_config does not overwrite an existing security section.""" + existing_config = { + "program": {}, + "rss_parser": {}, + "security": { + "login_whitelist": ["10.0.0.0/8"], + "login_tokens": ["mytoken"], + "mcp_whitelist": [], + "mcp_tokens": [], + }, + } + result = Settings._migrate_old_config(existing_config) + assert result["security"]["login_tokens"] == ["mytoken"] + assert result["security"]["login_whitelist"] == ["10.0.0.0/8"] diff --git a/backend/src/test/test_download_client.py b/backend/src/test/test_download_client.py index af00e068..2ce43af9 100644 --- a/backend/src/test/test_download_client.py +++ b/backend/src/test/test_download_client.py @@ -293,7 +293,68 @@ class TestClientDelegation: result = await download_client.rename_torrent_file("hash1", "old.mkv", "new.mkv") assert result is False + async def test_rename_torrent_file_passes_verify_flag( + self, download_client, mock_qb_client + ): + """rename_torrent_file forwards the verify kwarg to the underlying client.""" + mock_qb_client.torrents_rename_file.return_value = True + await download_client.rename_torrent_file( + "hash1", "old.mkv", "new.mkv", verify=False + ) + call_kwargs = mock_qb_client.torrents_rename_file.call_args[1] + assert call_kwargs["verify"] is False + async def test_delete_torrent(self, download_client, mock_qb_client): """delete_torrent delegates to client.torrents_delete.""" await download_client.delete_torrent("hash1", delete_files=True) mock_qb_client.torrents_delete.assert_called_once_with("hash1", delete_files=True) + + +# --------------------------------------------------------------------------- +# add_tag +# --------------------------------------------------------------------------- + + +class TestAddTag: + async def test_add_tag_delegates_to_client(self, download_client, mock_qb_client): + """add_tag delegates to client.add_tag.""" + mock_qb_client.add_tag = AsyncMock(return_value=None) + download_client.client = mock_qb_client + await download_client.add_tag("deadbeef12345678", "ab:42") + mock_qb_client.add_tag.assert_called_once_with("deadbeef12345678", "ab:42") + + async def test_add_tag_short_hash_no_error(self, download_client, mock_qb_client): + """add_tag with a hash shorter than 8 chars does not crash the slice.""" + mock_qb_client.add_tag = AsyncMock(return_value=None) + download_client.client = mock_qb_client + # Should not raise even for short hashes + await download_client.add_tag("abc", "ab:1") + + +# --------------------------------------------------------------------------- +# Context manager: ConnectionError on failed auth +# --------------------------------------------------------------------------- + + +class TestContextManagerAuth: + async def test_aenter_raises_on_auth_failure(self, download_client, mock_qb_client): + """__aenter__ raises ConnectionError when auth fails.""" + mock_qb_client.auth.return_value = False + download_client.authed = False + with pytest.raises(ConnectionError, match="authentication failed"): + await download_client.__aenter__() + + async def test_aenter_succeeds_when_auth_passes(self, download_client, mock_qb_client): + """__aenter__ returns self when auth succeeds.""" + mock_qb_client.auth.return_value = True + download_client.authed = False + result = await download_client.__aenter__() + assert result is download_client + assert download_client.authed is True + + async def test_aexit_calls_logout_when_authed(self, download_client, mock_qb_client): + """__aexit__ calls logout and resets authed when session was active.""" + download_client.authed = True + await download_client.__aexit__(None, None, None) + mock_qb_client.logout.assert_called_once() + assert download_client.authed is False diff --git a/backend/src/test/test_mcp_security.py b/backend/src/test/test_mcp_security.py index 7367fa77..6d96397d 100644 --- a/backend/src/test/test_mcp_security.py +++ b/backend/src/test/test_mcp_security.py @@ -1,6 +1,6 @@ -"""Tests for module.mcp.security - LocalNetworkMiddleware and _is_local().""" +"""Tests for module.mcp.security - McpAccessMiddleware, _is_allowed(), and clear_network_cache().""" -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import patch import pytest from starlette.applications import Starlette @@ -8,248 +8,223 @@ from starlette.responses import PlainTextResponse from starlette.routing import Route from starlette.testclient import TestClient -from module.mcp.security import LocalNetworkMiddleware, _is_local +from module.mcp.security import McpAccessMiddleware, _is_allowed, clear_network_cache + # --------------------------------------------------------------------------- -# _is_local() unit tests +# _is_allowed() unit tests # --------------------------------------------------------------------------- -class TestIsLocal: - """Verify _is_local() correctly classifies IP addresses.""" +class TestIsAllowed: + """Verify _is_allowed() checks IPs against a given whitelist.""" - # --- loopback --- + def setup_method(self): + clear_network_cache() - def test_ipv4_loopback_127_0_0_1(self): - """127.0.0.1 is the canonical loopback address.""" - assert _is_local("127.0.0.1") is True + LOCAL_WHITELIST = [ + "127.0.0.0/8", + "10.0.0.0/8", + "172.16.0.0/12", + "192.168.0.0/16", + "::1/128", + "fe80::/10", + "fc00::/7", + ] - def test_ipv4_loopback_127_0_0_2(self): - """127.0.0.2 is within 127.0.0.0/8 and therefore local.""" - assert _is_local("127.0.0.2") is True + # --- allowed IPs --- - def test_ipv4_loopback_127_255_255_255(self): - """Top of 127.0.0.0/8 range is still local.""" - assert _is_local("127.255.255.255") is True + def test_ipv4_loopback_allowed(self): + assert _is_allowed("127.0.0.1", self.LOCAL_WHITELIST) is True - # --- RFC 1918 class-A (10.0.0.0/8) --- + def test_ipv4_loopback_range(self): + assert _is_allowed("127.255.255.255", self.LOCAL_WHITELIST) is True - def test_ipv4_10_network_start(self): - """10.0.0.1 is in 10.0.0.0/8 private range.""" - assert _is_local("10.0.0.1") is True + def test_ipv4_10_network(self): + assert _is_allowed("10.0.0.1", self.LOCAL_WHITELIST) is True - def test_ipv4_10_network_mid(self): - """10.10.20.30 is inside 10.0.0.0/8.""" - assert _is_local("10.10.20.30") is True + def test_ipv4_172_16_network(self): + assert _is_allowed("172.16.0.1", self.LOCAL_WHITELIST) is True - def test_ipv4_10_network_end(self): - """10.255.255.254 is the last usable address in 10.0.0.0/8.""" - assert _is_local("10.255.255.254") is True - - # --- RFC 1918 class-B (172.16.0.0/12) --- - - def test_ipv4_172_16_start(self): - """172.16.0.1 is the first address in 172.16.0.0/12.""" - assert _is_local("172.16.0.1") is True - - def test_ipv4_172_31_end(self): - """172.31.255.254 is at the top of the 172.16.0.0/12 range.""" - assert _is_local("172.31.255.254") is True - - def test_ipv4_172_15_not_local(self): - """172.15.255.255 is just outside 172.16.0.0/12 (below the range).""" - assert _is_local("172.15.255.255") is False - - def test_ipv4_172_32_not_local(self): - """172.32.0.0 is just above 172.16.0.0/12 (outside the range).""" - assert _is_local("172.32.0.0") is False - - # --- RFC 1918 class-C (192.168.0.0/16) --- - - def test_ipv4_192_168_start(self): - """192.168.0.1 is a typical home-router address.""" - assert _is_local("192.168.0.1") is True - - def test_ipv4_192_168_end(self): - """192.168.255.254 is at the top of 192.168.0.0/16.""" - assert _is_local("192.168.255.254") is True - - # --- Public IPv4 --- - - def test_public_ipv4_google_dns(self): - """8.8.8.8 (Google DNS) is a public address.""" - assert _is_local("8.8.8.8") is False - - def test_public_ipv4_cloudflare_dns(self): - """1.1.1.1 (Cloudflare) is a public address.""" - assert _is_local("1.1.1.1") is False - - def test_public_ipv4_broadcast_like(self): - """203.0.113.1 (TEST-NET-3, RFC 5737) is not a private address.""" - assert _is_local("203.0.113.1") is False - - # --- IPv6 loopback --- + def test_ipv4_192_168_network(self): + assert _is_allowed("192.168.1.100", self.LOCAL_WHITELIST) is True def test_ipv6_loopback(self): - """::1 is the IPv6 loopback address.""" - assert _is_local("::1") is True - - # --- IPv6 link-local (fe80::/10) --- + assert _is_allowed("::1", self.LOCAL_WHITELIST) is True def test_ipv6_link_local(self): - """fe80::1 is an IPv6 link-local address.""" - assert _is_local("fe80::1") is True + assert _is_allowed("fe80::1", self.LOCAL_WHITELIST) is True - def test_ipv6_link_local_full(self): - """fe80::aabb:ccdd is also link-local.""" - assert _is_local("fe80::aabb:ccdd") is True + def test_ipv6_ula(self): + assert _is_allowed("fd00::1", self.LOCAL_WHITELIST) is True - # --- IPv6 ULA (fc00::/7) --- + # --- denied IPs --- - def test_ipv6_ula_fc(self): - """fc00::1 is within the ULA range fc00::/7.""" - assert _is_local("fc00::1") is True + def test_public_ipv4_denied(self): + assert _is_allowed("8.8.8.8", self.LOCAL_WHITELIST) is False - def test_ipv6_ula_fd(self): - """fd00::1 is within the ULA range fc00::/7 (fd prefix).""" - assert _is_local("fd00::1") is True + def test_public_ipv6_denied(self): + assert _is_allowed("2001:4860:4860::8888", self.LOCAL_WHITELIST) is False - # --- Public IPv6 --- + def test_172_outside_range(self): + assert _is_allowed("172.32.0.0", self.LOCAL_WHITELIST) is False - def test_public_ipv6_google(self): - """2001:4860:4860::8888 (Google IPv6 DNS) is a public address.""" - assert _is_local("2001:4860:4860::8888") is False + # --- empty whitelist --- - def test_public_ipv6_documentation(self): - """2001:db8::1 (documentation prefix, RFC 3849) is public.""" - assert _is_local("2001:db8::1") is False + def test_empty_whitelist_denies_all(self): + assert _is_allowed("127.0.0.1", []) is False - # --- Invalid inputs --- + # --- invalid inputs --- - def test_invalid_hostname_returns_false(self): - """A hostname string is not parseable as an IP and must return False.""" - assert _is_local("localhost") is False + def test_invalid_hostname(self): + assert _is_allowed("localhost", self.LOCAL_WHITELIST) is False - def test_invalid_string_returns_false(self): - """A random non-IP string returns False without raising.""" - assert _is_local("not-an-ip") is False + def test_empty_string(self): + assert _is_allowed("", self.LOCAL_WHITELIST) is False - def test_empty_string_returns_false(self): - """An empty string is not a valid IP address.""" - assert _is_local("") is False + def test_malformed_ipv4(self): + assert _is_allowed("256.0.0.1", self.LOCAL_WHITELIST) is False - def test_malformed_ipv4_returns_false(self): - """A string that looks like IPv4 but is malformed returns False.""" - assert _is_local("256.0.0.1") is False + # --- single IP whitelist --- - def test_partial_ipv4_returns_false(self): - """An incomplete IPv4 address is not valid.""" - assert _is_local("192.168") is False + def test_single_ip_whitelist(self): + assert _is_allowed("203.0.113.5", ["203.0.113.5/32"]) is True + assert _is_allowed("203.0.113.6", ["203.0.113.5/32"]) is False # --------------------------------------------------------------------------- -# LocalNetworkMiddleware integration tests +# McpAccessMiddleware integration tests # --------------------------------------------------------------------------- -def _make_app_with_middleware() -> Starlette: - """Build a minimal Starlette app with LocalNetworkMiddleware applied.""" +def _make_mcp_settings(mcp_whitelist=None, mcp_tokens=None): + """Create a mock settings.security object.""" + + class MockSecurity: + def __init__(self): + self.mcp_whitelist = mcp_whitelist if mcp_whitelist is not None else [] + self.mcp_tokens = mcp_tokens if mcp_tokens is not None else [] + + class MockSettings: + def __init__(self): + self.security = MockSecurity() + + return MockSettings() + + +def _make_app() -> Starlette: + """Build a minimal Starlette app with McpAccessMiddleware applied.""" async def homepage(request): return PlainTextResponse("ok") app = Starlette(routes=[Route("/", homepage)]) - app.add_middleware(LocalNetworkMiddleware) + app.add_middleware(McpAccessMiddleware) return app -class TestLocalNetworkMiddleware: - """Verify LocalNetworkMiddleware allows or denies requests by client IP.""" +def _patch_client_ip(app, ip): + """Return a modified app that overrides the client IP in ASGI scope.""" + original_build = app.build_middleware_stack - @pytest.fixture - def app(self): - return _make_app_with_middleware() + async def patched_app(scope, receive, send): + if scope["type"] == "http": + scope["client"] = (ip, 12345) if ip is not None else None + await original_build()(scope, receive, send) - def test_local_ipv4_loopback_allowed(self, app): - """Requests from 127.0.0.1 are allowed through.""" - # Starlette's TestClient identifies itself as "testclient", not a real IP. - # Patch the scope so the middleware sees an actual loopback address. - original_build = app.build_middleware_stack + app.build_middleware_stack = lambda: patched_app + return app - async def patched_app(scope, receive, send): - if scope["type"] == "http": - scope["client"] = ("127.0.0.1", 12345) - await original_build()(scope, receive, send) - app.build_middleware_stack = lambda: patched_app # type: ignore[method-assign] +class TestMcpAccessMiddleware: + """Verify McpAccessMiddleware allows/denies requests by IP and token.""" - client = TestClient(app, raise_server_exceptions=False) - response = client.get("/") + def setup_method(self): + clear_network_cache() + + def test_allowed_ip_passes(self): + mock_settings = _make_mcp_settings(mcp_whitelist=["127.0.0.0/8"]) + app = _patch_client_ip(_make_app(), "127.0.0.1") + with patch("module.mcp.security.settings", mock_settings): + client = TestClient(app, raise_server_exceptions=False) + response = client.get("/") assert response.status_code == 200 assert response.text == "ok" - def test_non_local_ip_blocked(self, app): - """Requests from a public IP are rejected with 403.""" - # Patch the ASGI scope to simulate a public client - original_build = app.build_middleware_stack - - async def patched_app(scope, receive, send): - if scope["type"] == "http": - scope["client"] = ("8.8.8.8", 12345) - await original_build()(scope, receive, send) - - app.build_middleware_stack = lambda: patched_app # type: ignore[method-assign] - - client = TestClient(app, raise_server_exceptions=False) - response = client.get("/") + def test_denied_ip_blocked(self): + mock_settings = _make_mcp_settings(mcp_whitelist=["127.0.0.0/8"]) + app = _patch_client_ip(_make_app(), "8.8.8.8") + with patch("module.mcp.security.settings", mock_settings): + client = TestClient(app, raise_server_exceptions=False) + response = client.get("/") assert response.status_code == 403 - assert "MCP access is restricted to local network" in response.text + assert "MCP access denied" in response.text - def test_missing_client_blocked(self, app): - """Requests with no client information are rejected with 403.""" - original_build = app.build_middleware_stack - - async def patched_app(scope, receive, send): - if scope["type"] == "http": - scope["client"] = None - await original_build()(scope, receive, send) - - app.build_middleware_stack = lambda: patched_app # type: ignore[method-assign] - - client = TestClient(app, raise_server_exceptions=False) - response = client.get("/") + def test_empty_whitelist_denies_all(self): + mock_settings = _make_mcp_settings(mcp_whitelist=[], mcp_tokens=[]) + app = _patch_client_ip(_make_app(), "127.0.0.1") + with patch("module.mcp.security.settings", mock_settings): + client = TestClient(app, raise_server_exceptions=False) + response = client.get("/") assert response.status_code == 403 - def test_blocked_response_is_json(self, app): - """The 403 error body is valid JSON with an 'error' key.""" + def test_missing_client_blocked(self): + mock_settings = _make_mcp_settings(mcp_whitelist=["127.0.0.0/8"]) + app = _patch_client_ip(_make_app(), None) + with patch("module.mcp.security.settings", mock_settings): + client = TestClient(app, raise_server_exceptions=False) + response = client.get("/") + assert response.status_code == 403 + + def test_bearer_token_bypasses_ip(self): + mock_settings = _make_mcp_settings( + mcp_whitelist=[], mcp_tokens=["secret-token-123"] + ) + app = _patch_client_ip(_make_app(), "8.8.8.8") + with patch("module.mcp.security.settings", mock_settings): + client = TestClient(app, raise_server_exceptions=False) + response = client.get( + "/", headers={"Authorization": "Bearer secret-token-123"} + ) + assert response.status_code == 200 + + def test_invalid_bearer_token_denied(self): + mock_settings = _make_mcp_settings( + mcp_whitelist=[], mcp_tokens=["secret-token-123"] + ) + app = _patch_client_ip(_make_app(), "8.8.8.8") + with patch("module.mcp.security.settings", mock_settings): + client = TestClient(app, raise_server_exceptions=False) + response = client.get( + "/", headers={"Authorization": "Bearer wrong-token"} + ) + assert response.status_code == 403 + + def test_private_network_with_default_whitelist(self): + default_whitelist = [ + "127.0.0.0/8", + "10.0.0.0/8", + "172.16.0.0/12", + "192.168.0.0/16", + "::1/128", + "fe80::/10", + "fc00::/7", + ] + mock_settings = _make_mcp_settings(mcp_whitelist=default_whitelist) + app = _patch_client_ip(_make_app(), "192.168.1.100") + with patch("module.mcp.security.settings", mock_settings): + client = TestClient(app, raise_server_exceptions=False) + response = client.get("/") + assert response.status_code == 200 + + def test_blocked_response_is_json(self): import json - original_build = app.build_middleware_stack - - async def patched_app(scope, receive, send): - if scope["type"] == "http": - scope["client"] = ("1.2.3.4", 9999) - await original_build()(scope, receive, send) - - app.build_middleware_stack = lambda: patched_app # type: ignore[method-assign] - - client = TestClient(app, raise_server_exceptions=False) - response = client.get("/") + mock_settings = _make_mcp_settings(mcp_whitelist=["127.0.0.0/8"]) + app = _patch_client_ip(_make_app(), "1.2.3.4") + with patch("module.mcp.security.settings", mock_settings): + client = TestClient(app, raise_server_exceptions=False) + response = client.get("/") assert response.status_code == 403 body = json.loads(response.text) assert "error" in body - - def test_private_192_168_allowed(self, app): - """Requests from a 192.168.x.x address pass through.""" - original_build = app.build_middleware_stack - - async def patched_app(scope, receive, send): - if scope["type"] == "http": - scope["client"] = ("192.168.1.100", 54321) - await original_build()(scope, receive, send) - - app.build_middleware_stack = lambda: patched_app # type: ignore[method-assign] - - client = TestClient(app, raise_server_exceptions=False) - response = client.get("/") - assert response.status_code == 200 diff --git a/backend/src/test/test_mock_downloader.py b/backend/src/test/test_mock_downloader.py new file mode 100644 index 00000000..aa21d22d --- /dev/null +++ b/backend/src/test/test_mock_downloader.py @@ -0,0 +1,429 @@ +"""Tests for MockDownloader - state management and API contract.""" + +import pytest + +from module.downloader.client.mock_downloader import MockDownloader + + +@pytest.fixture +def mock_dl() -> MockDownloader: + """Fresh MockDownloader for each test.""" + return MockDownloader() + + +# --------------------------------------------------------------------------- +# Initialization +# --------------------------------------------------------------------------- + + +class TestMockDownloaderInit: + def test_initial_state_is_empty(self, mock_dl): + """MockDownloader starts with no torrents, rules, or feeds.""" + state = mock_dl.get_state() + assert state["torrents"] == {} + assert state["rules"] == {} + assert state["feeds"] == {} + + def test_initial_categories(self, mock_dl): + """Default categories include Bangumi and BangumiCollection.""" + state = mock_dl.get_state() + assert "Bangumi" in state["categories"] + assert "BangumiCollection" in state["categories"] + + def test_initial_prefs(self, mock_dl): + """Default prefs are populated.""" + # Access private attribute directly to confirm defaults + assert mock_dl._prefs["rss_auto_downloading_enabled"] is True + assert mock_dl._prefs["rss_max_articles_per_feed"] == 500 + + +# --------------------------------------------------------------------------- +# Auth / connection +# --------------------------------------------------------------------------- + + +class TestMockDownloaderAuth: + async def test_auth_returns_true(self, mock_dl): + result = await mock_dl.auth() + assert result is True + + async def test_auth_retry_param_accepted(self, mock_dl): + result = await mock_dl.auth(retry=5) + assert result is True + + async def test_logout_does_not_raise(self, mock_dl): + await mock_dl.logout() + + async def test_check_host_returns_true(self, mock_dl): + result = await mock_dl.check_host() + assert result is True + + async def test_check_connection_returns_version_string(self, mock_dl): + result = await mock_dl.check_connection() + assert "mock" in result.lower() + + +# --------------------------------------------------------------------------- +# Prefs +# --------------------------------------------------------------------------- + + +class TestMockDownloaderPrefs: + async def test_prefs_init_updates_prefs(self, mock_dl): + """prefs_init merges given prefs into the internal store.""" + await mock_dl.prefs_init({"rss_refresh_interval": 60, "custom_key": "val"}) + assert mock_dl._prefs["rss_refresh_interval"] == 60 + assert mock_dl._prefs["custom_key"] == "val" + + async def test_get_app_prefs_returns_dict(self, mock_dl): + result = await mock_dl.get_app_prefs() + assert isinstance(result, dict) + assert "save_path" in result + + +# --------------------------------------------------------------------------- +# Categories +# --------------------------------------------------------------------------- + + +class TestMockDownloaderCategories: + async def test_add_category_persists(self, mock_dl): + await mock_dl.add_category("NewCategory") + state = mock_dl.get_state() + assert "NewCategory" in state["categories"] + + async def test_add_duplicate_category_no_error(self, mock_dl): + await mock_dl.add_category("Bangumi") + state = mock_dl.get_state() + # Still only one entry for Bangumi (set semantics) + assert state["categories"].count("Bangumi") == 1 + + +# --------------------------------------------------------------------------- +# Torrent management +# --------------------------------------------------------------------------- + + +class TestMockDownloaderAddTorrents: + async def test_add_torrent_url_returns_true(self, mock_dl): + result = await mock_dl.add_torrents( + torrent_urls="magnet:?xt=urn:btih:abc", + torrent_files=None, + save_path="/downloads/Bangumi", + category="Bangumi", + ) + assert result is True + + async def test_add_torrent_stores_in_state(self, mock_dl): + await mock_dl.add_torrents( + torrent_urls="magnet:?xt=urn:btih:abc", + torrent_files=None, + save_path="/downloads/Bangumi", + category="Bangumi", + ) + state = mock_dl.get_state() + assert len(state["torrents"]) == 1 + + async def test_add_torrent_with_tag_stored(self, mock_dl): + await mock_dl.add_torrents( + torrent_urls="magnet:?xt=urn:btih:abc", + torrent_files=None, + save_path="/downloads/Bangumi", + category="Bangumi", + tags="ab:42", + ) + state = mock_dl.get_state() + torrent = list(state["torrents"].values())[0] + assert torrent["tags"] == "ab:42" + + async def test_add_torrent_with_file_bytes(self, mock_dl): + result = await mock_dl.add_torrents( + torrent_urls=None, + torrent_files=b"\x00\x01\x02", + save_path="/downloads", + category="Bangumi", + ) + assert result is True + + async def test_two_different_torrents_stored_separately(self, mock_dl): + await mock_dl.add_torrents( + torrent_urls="magnet:?xt=urn:btih:aaa", + torrent_files=None, + save_path="/dl", + category="Bangumi", + ) + await mock_dl.add_torrents( + torrent_urls="magnet:?xt=urn:btih:bbb", + torrent_files=None, + save_path="/dl", + category="Bangumi", + ) + state = mock_dl.get_state() + assert len(state["torrents"]) == 2 + + +class TestMockDownloaderTorrentsInfo: + async def test_returns_all_when_no_filter(self, mock_dl): + mock_dl.add_mock_torrent("Anime A", category="Bangumi") + mock_dl.add_mock_torrent("Anime B", category="Bangumi") + result = await mock_dl.torrents_info(status_filter=None, category=None) + assert len(result) == 2 + + async def test_filters_by_category(self, mock_dl): + mock_dl.add_mock_torrent("Anime A", category="Bangumi") + mock_dl.add_mock_torrent("Movie", category="BangumiCollection") + result = await mock_dl.torrents_info(status_filter=None, category="Bangumi") + assert len(result) == 1 + assert result[0]["name"] == "Anime A" + + async def test_filters_by_tag(self, mock_dl): + h1 = mock_dl.add_mock_torrent("Anime A", category="Bangumi") + mock_dl.add_mock_torrent("Anime B", category="Bangumi") + # Manually set the tag on first torrent + mock_dl._torrents[h1]["tags"] = ["ab:1"] + result = await mock_dl.torrents_info( + status_filter=None, category=None, tag="ab:1" + ) + assert len(result) == 1 + assert result[0]["name"] == "Anime A" + + async def test_empty_store_returns_empty_list(self, mock_dl): + result = await mock_dl.torrents_info(status_filter=None, category="Bangumi") + assert result == [] + + +class TestMockDownloaderTorrentsFiles: + async def test_returns_files_for_known_hash(self, mock_dl): + files = [{"name": "ep01.mkv", "size": 500_000_000}] + h = mock_dl.add_mock_torrent("Anime", files=files) + result = await mock_dl.torrents_files(torrent_hash=h) + assert len(result) == 1 + assert result[0]["name"] == "ep01.mkv" + + async def test_returns_empty_list_for_unknown_hash(self, mock_dl): + result = await mock_dl.torrents_files(torrent_hash="nonexistent") + assert result == [] + + +class TestMockDownloaderDelete: + async def test_delete_single_torrent(self, mock_dl): + h = mock_dl.add_mock_torrent("Anime") + await mock_dl.torrents_delete(hash=h) + assert h not in mock_dl._torrents + + async def test_delete_multiple_torrents_pipe_separated(self, mock_dl): + h1 = mock_dl.add_mock_torrent("Anime A") + h2 = mock_dl.add_mock_torrent("Anime B") + await mock_dl.torrents_delete(hash=f"{h1}|{h2}") + assert h1 not in mock_dl._torrents + assert h2 not in mock_dl._torrents + + async def test_delete_nonexistent_hash_no_error(self, mock_dl): + await mock_dl.torrents_delete(hash="deadbeef") + + +class TestMockDownloaderPauseResume: + async def test_pause_sets_state(self, mock_dl): + h = mock_dl.add_mock_torrent("Anime", state="downloading") + await mock_dl.torrents_pause(hashes=h) + assert mock_dl._torrents[h]["state"] == "paused" + + async def test_resume_sets_state(self, mock_dl): + h = mock_dl.add_mock_torrent("Anime", state="paused") + await mock_dl.torrents_resume(hashes=h) + assert mock_dl._torrents[h]["state"] == "downloading" + + async def test_pause_multiple_pipe_separated(self, mock_dl): + h1 = mock_dl.add_mock_torrent("Anime A", state="downloading") + h2 = mock_dl.add_mock_torrent("Anime B", state="downloading") + await mock_dl.torrents_pause(hashes=f"{h1}|{h2}") + assert mock_dl._torrents[h1]["state"] == "paused" + assert mock_dl._torrents[h2]["state"] == "paused" + + async def test_pause_unknown_hash_no_error(self, mock_dl): + await mock_dl.torrents_pause(hashes="deadbeef") + + +# --------------------------------------------------------------------------- +# Rename +# --------------------------------------------------------------------------- + + +class TestMockDownloaderRename: + async def test_rename_returns_true(self, mock_dl): + result = await mock_dl.torrents_rename_file( + torrent_hash="hash1", + old_path="old.mkv", + new_path="new.mkv", + ) + assert result is True + + async def test_rename_with_verify_flag(self, mock_dl): + result = await mock_dl.torrents_rename_file( + torrent_hash="hash1", + old_path="old.mkv", + new_path="new.mkv", + verify=False, + ) + assert result is True + + +# --------------------------------------------------------------------------- +# RSS feed management +# --------------------------------------------------------------------------- + + +class TestMockDownloaderRssFeeds: + async def test_add_feed_stored(self, mock_dl): + await mock_dl.rss_add_feed(url="https://mikan.me/RSS/test", item_path="Mikan") + feeds = await mock_dl.rss_get_feeds() + assert "Mikan" in feeds + assert feeds["Mikan"]["url"] == "https://mikan.me/RSS/test" + + async def test_remove_feed(self, mock_dl): + await mock_dl.rss_add_feed(url="https://example.com", item_path="Feed1") + await mock_dl.rss_remove_item(item_path="Feed1") + feeds = await mock_dl.rss_get_feeds() + assert "Feed1" not in feeds + + async def test_remove_nonexistent_feed_no_error(self, mock_dl): + await mock_dl.rss_remove_item(item_path="nonexistent") + + async def test_get_feeds_initially_empty(self, mock_dl): + feeds = await mock_dl.rss_get_feeds() + assert feeds == {} + + +# --------------------------------------------------------------------------- +# Rules +# --------------------------------------------------------------------------- + + +class TestMockDownloaderRules: + async def test_set_rule_stored(self, mock_dl): + rule_def = {"enable": True, "mustContain": "Anime"} + await mock_dl.rss_set_rule("rule1", rule_def) + rules = await mock_dl.get_download_rule() + assert "rule1" in rules + assert rules["rule1"]["mustContain"] == "Anime" + + async def test_remove_rule(self, mock_dl): + await mock_dl.rss_set_rule("rule1", {"enable": True}) + await mock_dl.remove_rule("rule1") + rules = await mock_dl.get_download_rule() + assert "rule1" not in rules + + async def test_remove_nonexistent_rule_no_error(self, mock_dl): + await mock_dl.remove_rule("nonexistent") + + async def test_get_download_rule_initially_empty(self, mock_dl): + rules = await mock_dl.get_download_rule() + assert rules == {} + + +# --------------------------------------------------------------------------- +# Move / path +# --------------------------------------------------------------------------- + + +class TestMockDownloaderMovePath: + async def test_move_torrent_updates_save_path(self, mock_dl): + h = mock_dl.add_mock_torrent("Anime", save_path="/old/path") + await mock_dl.move_torrent(hashes=h, new_location="/new/path") + assert mock_dl._torrents[h]["save_path"] == "/new/path" + + async def test_move_multiple_pipe_separated(self, mock_dl): + h1 = mock_dl.add_mock_torrent("Anime A", save_path="/old") + h2 = mock_dl.add_mock_torrent("Anime B", save_path="/old") + await mock_dl.move_torrent(hashes=f"{h1}|{h2}", new_location="/new") + assert mock_dl._torrents[h1]["save_path"] == "/new" + assert mock_dl._torrents[h2]["save_path"] == "/new" + + async def test_get_torrent_path_known_hash(self, mock_dl): + h = mock_dl.add_mock_torrent("Anime", save_path="/downloads/Bangumi") + path = await mock_dl.get_torrent_path(h) + assert path == "/downloads/Bangumi" + + async def test_get_torrent_path_unknown_hash_returns_default(self, mock_dl): + path = await mock_dl.get_torrent_path("nonexistent") + assert path == "/tmp/mock-downloads" + + +# --------------------------------------------------------------------------- +# Category assignment +# --------------------------------------------------------------------------- + + +class TestMockDownloaderSetCategory: + async def test_set_category_updates_torrent(self, mock_dl): + h = mock_dl.add_mock_torrent("Anime", category="Bangumi") + await mock_dl.set_category(h, "BangumiCollection") + assert mock_dl._torrents[h]["category"] == "BangumiCollection" + + async def test_set_category_unknown_hash_no_error(self, mock_dl): + await mock_dl.set_category("deadbeef", "Bangumi") + + +# --------------------------------------------------------------------------- +# Tags +# --------------------------------------------------------------------------- + + +class TestMockDownloaderTags: + async def test_add_tag_appends(self, mock_dl): + h = mock_dl.add_mock_torrent("Anime") + await mock_dl.add_tag(h, "ab:1") + assert "ab:1" in mock_dl._torrents[h]["tags"] + + async def test_add_tag_no_duplicates(self, mock_dl): + h = mock_dl.add_mock_torrent("Anime") + await mock_dl.add_tag(h, "ab:1") + await mock_dl.add_tag(h, "ab:1") + assert mock_dl._torrents[h]["tags"].count("ab:1") == 1 + + async def test_add_tag_unknown_hash_no_error(self, mock_dl): + await mock_dl.add_tag("deadbeef", "ab:1") + + async def test_multiple_tags_on_same_torrent(self, mock_dl): + h = mock_dl.add_mock_torrent("Anime") + await mock_dl.add_tag(h, "ab:1") + await mock_dl.add_tag(h, "group:sub") + assert "ab:1" in mock_dl._torrents[h]["tags"] + assert "group:sub" in mock_dl._torrents[h]["tags"] + + +# --------------------------------------------------------------------------- +# add_mock_torrent helper +# --------------------------------------------------------------------------- + + +class TestAddMockTorrentHelper: + def test_generates_hash_from_name(self, mock_dl): + h = mock_dl.add_mock_torrent("Anime") + assert h is not None + assert len(h) == 40 # SHA1 hex digest + + def test_explicit_hash_used(self, mock_dl): + h = mock_dl.add_mock_torrent("Anime", hash="cafebabe" + "0" * 32) + assert h == "cafebabe" + "0" * 32 + + def test_torrent_state_is_completed_by_default(self, mock_dl): + h = mock_dl.add_mock_torrent("Anime") + assert mock_dl._torrents[h]["state"] == "completed" + assert mock_dl._torrents[h]["progress"] == 1.0 + + def test_torrent_state_custom(self, mock_dl): + h = mock_dl.add_mock_torrent("Anime", state="downloading") + assert mock_dl._torrents[h]["state"] == "downloading" + assert mock_dl._torrents[h]["progress"] == 0.5 + + def test_default_file_is_mkv(self, mock_dl): + h = mock_dl.add_mock_torrent("My Anime") + files = mock_dl._torrents[h]["files"] + assert len(files) == 1 + assert files[0]["name"].endswith(".mkv") + + def test_custom_files_stored(self, mock_dl): + custom_files = [{"name": "ep01.mkv"}, {"name": "ep02.mkv"}] + h = mock_dl.add_mock_torrent("Anime", files=custom_files) + assert len(mock_dl._torrents[h]["files"]) == 2 diff --git a/backend/uv.lock b/backend/uv.lock index d456bec0..ddea71e3 100644 --- a/backend/uv.lock +++ b/backend/uv.lock @@ -61,7 +61,7 @@ wheels = [ [[package]] name = "auto-bangumi" -version = "3.2.3b5" +version = "3.2.3" source = { virtual = "." } dependencies = [ { name = "aiosqlite" }, diff --git a/webui/src/components/setting/config-security.vue b/webui/src/components/setting/config-security.vue new file mode 100644 index 00000000..c02515d3 --- /dev/null +++ b/webui/src/components/setting/config-security.vue @@ -0,0 +1,68 @@ + + + + + diff --git a/webui/src/pages/index/config.vue b/webui/src/pages/index/config.vue index 09ffcb75..cb319baa 100644 --- a/webui/src/pages/index/config.vue +++ b/webui/src/pages/index/config.vue @@ -49,6 +49,7 @@ onActivated(() => { + diff --git a/webui/types/config.ts b/webui/types/config.ts index 6673e0a8..1bf42dc4 100644 --- a/webui/types/config.ts +++ b/webui/types/config.ts @@ -2,10 +2,6 @@ import type { TupleToUnion } from './utils'; /** 下载方式 */ export type DownloaderType = ['qbittorrent']; -/** rss parser 源 */ -export type RssParserType = ['mikan']; -/** rss parser 方法 */ -export type RssParserMethodType = ['tmdb', 'mikan', 'parser']; /** rss parser 语言 */ export type RssParserLang = ['zh', 'en', 'jp']; /** 重命名方式 */ @@ -44,12 +40,8 @@ export interface Downloader { } export interface RssParser { enable: boolean; - type: TupleToUnion; - token: string; - custom_url: string; filter: Array; language: TupleToUnion; - parser_type: TupleToUnion; } export interface BangumiManage { enable: boolean; @@ -105,6 +97,17 @@ export interface ExperimentalOpenAI { deployment_id?: string; } +/** Access control for the login endpoint and MCP server. + * Whitelist entries are IPv4/IPv6 CIDR strings (e.g. "192.168.0.0/16"). + * An empty login_whitelist allows all IPs; an empty mcp_whitelist denies all IP-based MCP access. + */ +export interface Security { + login_whitelist: string[]; + login_tokens: string[]; + mcp_whitelist: string[]; + mcp_tokens: string[]; +} + export interface Config { program: Program; downloader: Downloader; @@ -114,6 +117,7 @@ export interface Config { proxy: Proxy; notification: Notification; experimental_openai: ExperimentalOpenAI; + security: Security; } export const initConfig: Config = { @@ -132,12 +136,8 @@ export const initConfig: Config = { }, rss_parser: { enable: true, - type: 'mikan', - token: '', - custom_url: '', filter: [], language: 'zh', - parser_type: 'parser', }, bangumi_manage: { enable: true, @@ -171,4 +171,10 @@ export const initConfig: Config = { api_version: '2020-05-03', deployment_id: '', }, + security: { + login_whitelist: [], + login_tokens: [], + mcp_whitelist: [], + mcp_tokens: [], + }, }; diff --git a/webui/types/dts/auto-imports.d.ts b/webui/types/dts/auto-imports.d.ts index 434d5576..14d25398 100644 --- a/webui/types/dts/auto-imports.d.ts +++ b/webui/types/dts/auto-imports.d.ts @@ -40,6 +40,7 @@ declare global { const getCurrentInstance: typeof import('vue')['getCurrentInstance'] const getCurrentScope: typeof import('vue')['getCurrentScope'] const h: typeof import('vue')['h'] + const i18n: typeof import('../../src/hooks/useMyI18n')['i18n'] const inject: typeof import('vue')['inject'] const isProxy: typeof import('vue')['isProxy'] const isReactive: typeof import('vue')['isReactive'] diff --git a/webui/types/dts/components.d.ts b/webui/types/dts/components.d.ts index 467a1316..348e352c 100644 --- a/webui/types/dts/components.d.ts +++ b/webui/types/dts/components.d.ts @@ -55,6 +55,7 @@ declare module '@vue/runtime-core' { ConfigPlayer: typeof import('./../../src/components/setting/config-player.vue')['default'] ConfigProxy: typeof import('./../../src/components/setting/config-proxy.vue')['default'] ConfigSearchProvider: typeof import('./../../src/components/setting/config-search-provider.vue')['default'] + ConfigSecurity: typeof import('./../../src/components/setting/config-security.vue')['default'] MediaQuery: typeof import('./../../src/components/media-query.vue')['default'] RouterLink: typeof import('vue-router')['RouterLink'] RouterView: typeof import('vue-router')['RouterView']