mirror of
https://github.com/EstrellaXD/Auto_Bangumi.git
synced 2026-03-20 03:46:40 +08:00
feat(security): add security config UI and improve auth/MCP security
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
38
CHANGELOG.md
38
CHANGELOG.md
@@ -1,3 +1,41 @@
|
||||
# [Unreleased]
|
||||
|
||||
## Backend
|
||||
|
||||
### Added
|
||||
|
||||
- 新增 `Security` 配置模型,支持登录 IP 白名单、MCP IP 白名单和 Bearer Token 认证
|
||||
- 新增登录端点 IP 白名单检查中间件 (`check_login_ip`)
|
||||
- MCP 安全中间件升级为可配置模式:支持 CIDR 白名单 + Bearer Token 双重认证
|
||||
- 认证端点支持 `Authorization: Bearer` 令牌绕过 Cookie 登录
|
||||
- 配置 API `_sanitize_dict` 修复:仅对字符串值进行脱敏,避免误处理非字符串字段
|
||||
|
||||
- 新增番剧放送日手动设置 API (`PATCH /api/v1/bangumi/{id}/weekday`),支持锁定放送日防止日历刷新覆盖
|
||||
- 数据库迁移 v9:`bangumi` 表新增 `weekday_locked` 列
|
||||
|
||||
### Changed
|
||||
|
||||
- 重构认证模块:提取 `_issue_token` 公共方法,消除 3 处重复的 JWT 签发逻辑
|
||||
- `get_current_user` 简化为三级认证(DEV 绕过 → Bearer Token → Cookie JWT)
|
||||
- `LocalNetworkMiddleware` 重命名为 `McpAccessMiddleware`,从硬编码 RFC 1918 改为读取配置
|
||||
|
||||
### Tests
|
||||
|
||||
- 新增 101 个单元测试覆盖安全、认证、配置、下载器和 MockDownloader 模块
|
||||
|
||||
## Frontend
|
||||
|
||||
### Added
|
||||
|
||||
- 新增日历拖拽排列功能:可将「未知」番剧拖入星期列,自动设置放送日并锁定
|
||||
- 拖入后显示紫色图钉图标,鼠标悬停显示取消按钮
|
||||
- 锁定的番剧在日历刷新时不会被覆盖
|
||||
- 使用 vuedraggable 实现流畅拖拽动画
|
||||
- 新增安全设置组件 (`config-security.vue`),支持在 WebUI 中配置 IP 白名单和 Token
|
||||
- 前端 `Security` 类型定义和初始化配置
|
||||
|
||||
---
|
||||
|
||||
# [3.2.3] - 2026-02-23
|
||||
|
||||
## Backend
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "auto-bangumi"
|
||||
version = "3.2.3"
|
||||
version = "3.2.4"
|
||||
description = "AutoBangumi - Automated anime download manager"
|
||||
requires-python = ">=3.13"
|
||||
dependencies = [
|
||||
|
||||
@@ -9,6 +9,7 @@ from module.models.user import User, UserUpdate
|
||||
from module.security.api import (
|
||||
active_user,
|
||||
auth_user,
|
||||
check_login_ip,
|
||||
get_current_user,
|
||||
update_user_info,
|
||||
)
|
||||
@@ -18,17 +19,26 @@ from .response import u_response
|
||||
|
||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
|
||||
_TOKEN_EXPIRY_DAYS = 1
|
||||
_TOKEN_MAX_AGE = 86400
|
||||
|
||||
@router.post("/login", response_model=dict)
|
||||
|
||||
def _issue_token(username: str, response: Response) -> dict:
|
||||
"""Create a JWT, set it as an HttpOnly cookie, and return the bearer payload."""
|
||||
token = create_access_token(
|
||||
data={"sub": username}, expires_delta=timedelta(days=_TOKEN_EXPIRY_DAYS)
|
||||
)
|
||||
response.set_cookie(key="token", value=token, httponly=True, max_age=_TOKEN_MAX_AGE)
|
||||
return {"access_token": token, "token_type": "bearer"}
|
||||
|
||||
|
||||
@router.post("/login", response_model=dict, dependencies=[Depends(check_login_ip)])
|
||||
async def login(response: Response, form_data=Depends(OAuth2PasswordRequestForm)):
|
||||
"""Authenticate with username/password and issue a session token."""
|
||||
user = User(username=form_data.username, password=form_data.password)
|
||||
resp = auth_user(user)
|
||||
if resp.status:
|
||||
token = create_access_token(
|
||||
data={"sub": user.username}, expires_delta=timedelta(days=1)
|
||||
)
|
||||
response.set_cookie(key="token", value=token, httponly=True, max_age=86400)
|
||||
return {"access_token": token, "token_type": "bearer"}
|
||||
return _issue_token(user.username, response)
|
||||
return u_response(resp)
|
||||
|
||||
|
||||
@@ -36,6 +46,7 @@ async def login(response: Response, form_data=Depends(OAuth2PasswordRequestForm)
|
||||
"/refresh_token", response_model=dict, dependencies=[Depends(get_current_user)]
|
||||
)
|
||||
async def refresh(response: Response, token: str = Cookie(None)):
|
||||
"""Refresh the current session token and update the active-user timestamp."""
|
||||
payload = decode_token(token)
|
||||
username = payload.get("sub") if payload else None
|
||||
if not username:
|
||||
@@ -43,17 +54,14 @@ async def refresh(response: Response, token: str = Cookie(None)):
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Unauthorized"
|
||||
)
|
||||
active_user[username] = datetime.now()
|
||||
new_token = create_access_token(
|
||||
data={"sub": username}, expires_delta=timedelta(days=1)
|
||||
)
|
||||
response.set_cookie(key="token", value=new_token, httponly=True, max_age=86400)
|
||||
return {"access_token": new_token, "token_type": "bearer"}
|
||||
return _issue_token(username, response)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/logout", response_model=APIResponse, dependencies=[Depends(get_current_user)]
|
||||
)
|
||||
async def logout(response: Response, token: str = Cookie(None)):
|
||||
"""Invalidate the session and clear the token cookie."""
|
||||
payload = decode_token(token)
|
||||
username = payload.get("sub") if payload else None
|
||||
if username:
|
||||
@@ -69,6 +77,7 @@ async def logout(response: Response, token: str = Cookie(None)):
|
||||
async def update_user(
|
||||
user_data: UserUpdate, response: Response, token: str = Cookie(None)
|
||||
):
|
||||
"""Update credentials for the current user and re-issue a fresh token."""
|
||||
payload = decode_token(token)
|
||||
old_user = payload.get("sub") if payload else None
|
||||
if not old_user:
|
||||
@@ -76,17 +85,4 @@ async def update_user(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Unauthorized"
|
||||
)
|
||||
if update_user_info(user_data, old_user):
|
||||
token = create_access_token(
|
||||
data={"sub": old_user}, expires_delta=timedelta(days=1)
|
||||
)
|
||||
response.set_cookie(
|
||||
key="token",
|
||||
value=token,
|
||||
httponly=True,
|
||||
max_age=86400,
|
||||
)
|
||||
return {
|
||||
"access_token": token,
|
||||
"token_type": "bearer",
|
||||
"message": "update success",
|
||||
}
|
||||
return {**_issue_token(old_user, response), "message": "update success"}
|
||||
|
||||
@@ -14,11 +14,12 @@ _SENSITIVE_KEYS = ("password", "api_key", "token", "secret")
|
||||
|
||||
|
||||
def _sanitize_dict(d: dict) -> dict:
|
||||
"""Recursively mask string values whose keys contain sensitive keywords."""
|
||||
result = {}
|
||||
for k, v in d.items():
|
||||
if isinstance(v, dict):
|
||||
result[k] = _sanitize_dict(v)
|
||||
elif any(s in k.lower() for s in _SENSITIVE_KEYS):
|
||||
elif isinstance(v, str) and any(s in k.lower() for s in _SENSITIVE_KEYS):
|
||||
result[k] = "********"
|
||||
else:
|
||||
result[k] = v
|
||||
@@ -27,6 +28,7 @@ def _sanitize_dict(d: dict) -> dict:
|
||||
|
||||
@router.get("/get", dependencies=[Depends(get_current_user)])
|
||||
async def get_config():
|
||||
"""Return the current configuration with sensitive fields masked."""
|
||||
return _sanitize_dict(settings.dict())
|
||||
|
||||
|
||||
@@ -34,6 +36,7 @@ async def get_config():
|
||||
"/update", response_model=APIResponse, dependencies=[Depends(get_current_user)]
|
||||
)
|
||||
async def update_config(config: Config):
|
||||
"""Persist and reload configuration from the supplied payload."""
|
||||
try:
|
||||
settings.save(config_dict=config.dict())
|
||||
settings.load()
|
||||
|
||||
@@ -7,7 +7,7 @@ from dotenv import load_dotenv
|
||||
|
||||
from module.models.config import Config
|
||||
|
||||
from .const import ENV_TO_ATTR
|
||||
from .const import DEFAULT_SETTINGS, ENV_TO_ATTR
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
CONFIG_ROOT = Path("config")
|
||||
@@ -27,6 +27,15 @@ CONFIG_PATH = (
|
||||
|
||||
|
||||
class Settings(Config):
|
||||
"""Runtime configuration singleton.
|
||||
|
||||
On construction, loads from ``CONFIG_PATH`` if the file exists (and
|
||||
immediately re-saves to apply any migrations), otherwise bootstraps
|
||||
defaults from environment variables via ``init()``.
|
||||
|
||||
Use ``settings`` module-level instance rather than instantiating directly.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
if CONFIG_PATH.exists():
|
||||
@@ -36,6 +45,7 @@ class Settings(Config):
|
||||
self.init()
|
||||
|
||||
def load(self):
|
||||
"""Load and validate configuration from ``CONFIG_PATH``, applying migrations."""
|
||||
with open(CONFIG_PATH, "r", encoding="utf-8") as f:
|
||||
config = json.load(f)
|
||||
config = self._migrate_old_config(config)
|
||||
@@ -65,20 +75,27 @@ class Settings(Config):
|
||||
for key in ("type", "custom_url", "token", "enable_tmdb"):
|
||||
rss_parser.pop(key, None)
|
||||
|
||||
# Add security section if missing (preserves local-network MCP default)
|
||||
if "security" not in config:
|
||||
config["security"] = DEFAULT_SETTINGS["security"]
|
||||
|
||||
return config
|
||||
|
||||
def save(self, config_dict: dict | None = None):
|
||||
"""Write configuration to ``CONFIG_PATH``. Uses current state when no dict supplied."""
|
||||
if not config_dict:
|
||||
config_dict = self.model_dump()
|
||||
with open(CONFIG_PATH, "w", encoding="utf-8") as f:
|
||||
json.dump(config_dict, f, indent=4, ensure_ascii=False)
|
||||
|
||||
def init(self):
|
||||
"""Bootstrap a new config file from ``.env`` and environment variables."""
|
||||
load_dotenv(".env")
|
||||
self.__load_from_env()
|
||||
self.save()
|
||||
|
||||
def __load_from_env(self):
|
||||
"""Apply ``ENV_TO_ATTR`` mappings from the process environment to the config dict."""
|
||||
config_dict = self.model_dump()
|
||||
for key, section in ENV_TO_ATTR.items():
|
||||
for env, attr in section.items():
|
||||
@@ -97,12 +114,11 @@ class Settings(Config):
|
||||
logger.info("Config loaded from env")
|
||||
|
||||
@staticmethod
|
||||
def __val_from_env(env: str, attr: tuple):
|
||||
def __val_from_env(env: str, attr: tuple | str):
|
||||
"""Return the environment variable value, applying the converter when attr is a tuple."""
|
||||
if isinstance(attr, tuple):
|
||||
conv_func = attr[1]
|
||||
return conv_func(os.environ[env])
|
||||
else:
|
||||
return os.environ[env]
|
||||
return attr[1](os.environ[env])
|
||||
return os.environ[env]
|
||||
|
||||
@property
|
||||
def group_rules(self):
|
||||
|
||||
@@ -1,4 +1,8 @@
|
||||
# -*- encoding: utf-8 -*-
|
||||
# DEFAULT_SETTINGS: factory defaults written to config.json on first run.
|
||||
# ENV_TO_ATTR: maps AB_* environment variables to Config model attribute paths.
|
||||
# Values are either a string attr name, a (attr_name, converter) tuple, or a
|
||||
# list of such tuples when a single env var sets multiple attributes.
|
||||
DEFAULT_SETTINGS = {
|
||||
"program": {
|
||||
"rss_time": 900,
|
||||
@@ -46,6 +50,20 @@ DEFAULT_SETTINGS = {
|
||||
"model": "gpt-3.5-turbo",
|
||||
"deployment_id": "",
|
||||
},
|
||||
"security": {
|
||||
"login_whitelist": [],
|
||||
"login_tokens": [],
|
||||
"mcp_whitelist": [
|
||||
"127.0.0.0/8",
|
||||
"10.0.0.0/8",
|
||||
"172.16.0.0/12",
|
||||
"192.168.0.0/16",
|
||||
"::1/128",
|
||||
"fe80::/10",
|
||||
"fc00::/7",
|
||||
],
|
||||
"mcp_tokens": [],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -99,8 +117,11 @@ ENV_TO_ATTR = {
|
||||
|
||||
|
||||
class BCOLORS:
|
||||
"""ANSI colour helpers for terminal output."""
|
||||
|
||||
@staticmethod
|
||||
def _(color: str, *args: str) -> str:
|
||||
"""Wrap *args* in the given ANSI colour code and reset at the end."""
|
||||
strings = [str(s) for s in args]
|
||||
return f"{color}{', '.join(strings)}{BCOLORS.ENDC}"
|
||||
|
||||
|
||||
@@ -29,10 +29,10 @@ class MockDownloader:
|
||||
"rss_processing_enabled": True,
|
||||
"rss_refresh_interval": 30,
|
||||
}
|
||||
logger.info("[MockDownloader] Initialized")
|
||||
logger.debug("[MockDownloader] Initialized")
|
||||
|
||||
async def auth(self, retry=3) -> bool:
|
||||
logger.info("[MockDownloader] Auth successful (mocked)")
|
||||
logger.debug("[MockDownloader] Auth successful (mocked)")
|
||||
return True
|
||||
|
||||
async def logout(self):
|
||||
|
||||
@@ -11,6 +11,13 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DownloadClient(TorrentPath):
|
||||
"""Unified async download client.
|
||||
|
||||
Wraps qBittorrent, Aria2, or MockDownloader behind a common interface.
|
||||
Intended to be used as an async context manager; authentication is
|
||||
performed on ``__aenter__`` and the session is closed on ``__aexit__``.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.client = self.__getClient()
|
||||
@@ -18,27 +25,28 @@ class DownloadClient(TorrentPath):
|
||||
|
||||
@staticmethod
|
||||
def __getClient():
|
||||
type = settings.downloader.type
|
||||
"""Instantiate the configured downloader client (qbittorrent | aria2 | mock)."""
|
||||
downloader_type = settings.downloader.type
|
||||
host = settings.downloader.host
|
||||
username = settings.downloader.username
|
||||
password = settings.downloader.password
|
||||
ssl = settings.downloader.ssl
|
||||
if type == "qbittorrent":
|
||||
if downloader_type == "qbittorrent":
|
||||
from .client.qb_downloader import QbDownloader
|
||||
|
||||
return QbDownloader(host, username, password, ssl)
|
||||
elif type == "aria2":
|
||||
elif downloader_type == "aria2":
|
||||
from .client.aria2_downloader import Aria2Downloader
|
||||
|
||||
return Aria2Downloader(host, username, password)
|
||||
elif type == "mock":
|
||||
elif downloader_type == "mock":
|
||||
from .client.mock_downloader import MockDownloader
|
||||
|
||||
logger.info("[Downloader] Using MockDownloader for local development")
|
||||
logger.debug("[Downloader] Using MockDownloader for local development")
|
||||
return MockDownloader()
|
||||
else:
|
||||
logger.error(f"[Downloader] Unsupported downloader type: {type}")
|
||||
raise Exception(f"Unsupported downloader type: {type}")
|
||||
logger.error("[Downloader] Unsupported downloader type: %s", downloader_type)
|
||||
raise Exception(f"Unsupported downloader type: {downloader_type}")
|
||||
|
||||
async def __aenter__(self):
|
||||
if not self.authed:
|
||||
@@ -65,6 +73,7 @@ class DownloadClient(TorrentPath):
|
||||
return await self.client.check_host()
|
||||
|
||||
async def init_downloader(self):
|
||||
"""Apply required qBittorrent RSS preferences and create the Bangumi category."""
|
||||
prefs = {
|
||||
"rss_auto_downloading_enabled": True,
|
||||
"rss_max_articles_per_feed": 500,
|
||||
@@ -84,6 +93,7 @@ class DownloadClient(TorrentPath):
|
||||
settings.downloader.path = self._join_path(prefs["save_path"], "Bangumi")
|
||||
|
||||
async def set_rule(self, data: Bangumi):
|
||||
"""Create or update a qBittorrent RSS auto-download rule for one bangumi entry."""
|
||||
data.rule_name = self._rule_name(data)
|
||||
data.save_path = self._gen_save_path(data)
|
||||
rule = {
|
||||
@@ -145,6 +155,12 @@ class DownloadClient(TorrentPath):
|
||||
await self.client.torrents_resume(hashes)
|
||||
|
||||
async def add_torrent(self, torrent: Torrent | list, bangumi: Bangumi) -> bool:
|
||||
"""Download a torrent (or list of torrents) for the given bangumi entry.
|
||||
|
||||
Handles both magnet links and .torrent file URLs, fetching file bytes
|
||||
when necessary. Tags each torrent with ``ab:<bangumi_id>`` for later
|
||||
episode-offset lookup during rename.
|
||||
"""
|
||||
if not bangumi.save_path:
|
||||
bangumi.save_path = self._gen_save_path(bangumi)
|
||||
async with RequestContent() as req:
|
||||
|
||||
@@ -1,48 +1,68 @@
|
||||
"""MCP access control: restricts connections to local network addresses only."""
|
||||
"""MCP access control: configurable IP whitelist and bearer token authentication."""
|
||||
|
||||
import ipaddress
|
||||
import logging
|
||||
from functools import lru_cache
|
||||
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse
|
||||
|
||||
from module.conf import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# RFC 1918 private ranges + loopback + IPv6 equivalents
|
||||
_ALLOWED_NETWORKS = [
|
||||
ipaddress.ip_network("127.0.0.0/8"),
|
||||
ipaddress.ip_network("10.0.0.0/8"),
|
||||
ipaddress.ip_network("172.16.0.0/12"),
|
||||
ipaddress.ip_network("192.168.0.0/16"),
|
||||
ipaddress.ip_network("::1/128"),
|
||||
ipaddress.ip_network("fe80::/10"),
|
||||
ipaddress.ip_network("fc00::/7"),
|
||||
]
|
||||
|
||||
@lru_cache(maxsize=128)
|
||||
def _parse_network(cidr: str) -> ipaddress.IPv4Network | ipaddress.IPv6Network | None:
|
||||
try:
|
||||
return ipaddress.ip_network(cidr, strict=False)
|
||||
except ValueError:
|
||||
logger.warning("[MCP] Invalid CIDR in whitelist: %s", cidr)
|
||||
return None
|
||||
|
||||
|
||||
def _is_local(host: str) -> bool:
|
||||
"""Return True if *host* is a loopback or RFC 1918 private address."""
|
||||
def _is_allowed(host: str, whitelist: list[str]) -> bool:
|
||||
"""Return True if *host* falls within any CIDR range in *whitelist*."""
|
||||
try:
|
||||
addr = ipaddress.ip_address(host)
|
||||
except ValueError:
|
||||
return False
|
||||
return any(addr in net for net in _ALLOWED_NETWORKS)
|
||||
for cidr in whitelist:
|
||||
net = _parse_network(cidr)
|
||||
if net and addr in net:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class LocalNetworkMiddleware(BaseHTTPMiddleware):
|
||||
"""Starlette middleware that blocks requests from non-local IP addresses.
|
||||
def clear_network_cache():
|
||||
"""Clear the parsed network cache (call after config reload)."""
|
||||
_parse_network.cache_clear()
|
||||
|
||||
Returns HTTP 403 for any client outside loopback, RFC 1918, or IPv6
|
||||
link-local/unique-local ranges.
|
||||
|
||||
class McpAccessMiddleware(BaseHTTPMiddleware):
|
||||
"""Configurable access control for MCP endpoint.
|
||||
|
||||
Checks client IP against ``settings.security.mcp_whitelist`` CIDR ranges,
|
||||
and ``Authorization`` header against ``settings.security.mcp_tokens``.
|
||||
If the whitelist is empty and no tokens are configured, all access is denied.
|
||||
"""
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
# Check bearer token first
|
||||
auth_header = request.headers.get("authorization", "")
|
||||
if auth_header.startswith("Bearer "):
|
||||
token = auth_header[7:]
|
||||
if token and token in settings.security.mcp_tokens:
|
||||
return await call_next(request)
|
||||
|
||||
# Check IP whitelist
|
||||
client_host = request.client.host if request.client else None
|
||||
if not client_host or not _is_local(client_host):
|
||||
logger.warning("[MCP] Rejected non-local connection from %s", client_host)
|
||||
return JSONResponse(
|
||||
status_code=403,
|
||||
content={"error": "MCP access is restricted to local network"},
|
||||
)
|
||||
return await call_next(request)
|
||||
if client_host and _is_allowed(client_host, settings.security.mcp_whitelist):
|
||||
return await call_next(request)
|
||||
|
||||
logger.warning("[MCP] Rejected connection from %s", client_host)
|
||||
return JSONResponse(
|
||||
status_code=403,
|
||||
content={"error": "MCP access denied"},
|
||||
)
|
||||
|
||||
@@ -18,7 +18,7 @@ from starlette.requests import Request
|
||||
from starlette.routing import Mount, Route
|
||||
|
||||
from .resources import RESOURCE_TEMPLATES, RESOURCES, handle_resource
|
||||
from .security import LocalNetworkMiddleware
|
||||
from .security import McpAccessMiddleware
|
||||
from .tools import TOOLS, handle_tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -73,8 +73,8 @@ def create_mcp_starlette_app() -> Starlette:
|
||||
- ``GET /sse`` - SSE stream for MCP clients
|
||||
- ``POST /messages/`` - client-to-server message posting
|
||||
|
||||
``LocalNetworkMiddleware`` is applied so the endpoint is only reachable
|
||||
from loopback and RFC 1918 addresses.
|
||||
``McpAccessMiddleware`` is applied to enforce configurable IP whitelist
|
||||
and bearer token access control.
|
||||
"""
|
||||
app = Starlette(
|
||||
routes=[
|
||||
@@ -82,5 +82,5 @@ def create_mcp_starlette_app() -> Starlette:
|
||||
Mount("/messages", app=sse.handle_post_message),
|
||||
],
|
||||
)
|
||||
app.add_middleware(LocalNetworkMiddleware)
|
||||
app.add_middleware(McpAccessMiddleware)
|
||||
return app
|
||||
|
||||
@@ -4,13 +4,27 @@ from typing import Literal, Optional
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
|
||||
|
||||
def _expand(value: str | None) -> str:
|
||||
"""Expand shell environment variables in *value*, returning empty string for None."""
|
||||
return expandvars(value) if value else ""
|
||||
|
||||
|
||||
class Program(BaseModel):
|
||||
"""Scheduler timing and WebUI port settings."""
|
||||
|
||||
rss_time: int = Field(900, description="Sleep time")
|
||||
rename_time: int = Field(60, description="Rename times in one loop")
|
||||
webui_port: int = Field(7892, description="WebUI port")
|
||||
|
||||
|
||||
class Downloader(BaseModel):
|
||||
"""Download client connection settings.
|
||||
|
||||
Credential fields (``host``, ``username``, ``password``) are stored with a
|
||||
trailing underscore and exposed via properties that expand ``$VAR``
|
||||
environment variable references at access time.
|
||||
"""
|
||||
|
||||
type: str = Field("qbittorrent", description="Downloader type")
|
||||
host_: str = Field("172.17.0.1:8080", alias="host", description="Downloader host")
|
||||
username_: str = Field("admin", alias="username", description="Downloader username")
|
||||
@@ -22,24 +36,28 @@ class Downloader(BaseModel):
|
||||
|
||||
@property
|
||||
def host(self):
|
||||
return expandvars(self.host_)
|
||||
return _expand(self.host_)
|
||||
|
||||
@property
|
||||
def username(self):
|
||||
return expandvars(self.username_)
|
||||
return _expand(self.username_)
|
||||
|
||||
@property
|
||||
def password(self):
|
||||
return expandvars(self.password_)
|
||||
return _expand(self.password_)
|
||||
|
||||
|
||||
class RSSParser(BaseModel):
|
||||
"""RSS feed parsing settings."""
|
||||
|
||||
enable: bool = Field(True, description="Enable RSS parser")
|
||||
filter: list[str] = Field(["720", r"\d+-\d"], description="Filter")
|
||||
language: str = "zh"
|
||||
|
||||
|
||||
class BangumiManage(BaseModel):
|
||||
"""File organisation and renaming settings."""
|
||||
|
||||
enable: bool = Field(True, description="Enable bangumi manage")
|
||||
eps_complete: bool = Field(False, description="Enable eps complete")
|
||||
rename_method: str = Field("pn", description="Rename method")
|
||||
@@ -48,10 +66,14 @@ class BangumiManage(BaseModel):
|
||||
|
||||
|
||||
class Log(BaseModel):
|
||||
"""Logging verbosity settings."""
|
||||
|
||||
debug_enable: bool = Field(False, description="Enable debug")
|
||||
|
||||
|
||||
class Proxy(BaseModel):
|
||||
"""HTTP/SOCKS proxy settings. Credentials support ``$VAR`` expansion."""
|
||||
|
||||
enable: bool = Field(False, description="Enable proxy")
|
||||
type: str = Field("http", description="Proxy type")
|
||||
host: str = Field("", description="Proxy host")
|
||||
@@ -61,11 +83,11 @@ class Proxy(BaseModel):
|
||||
|
||||
@property
|
||||
def username(self):
|
||||
return expandvars(self.username_)
|
||||
return _expand(self.username_)
|
||||
|
||||
@property
|
||||
def password(self):
|
||||
return expandvars(self.password_)
|
||||
return _expand(self.password_)
|
||||
|
||||
|
||||
class NotificationProvider(BaseModel):
|
||||
@@ -103,35 +125,35 @@ class NotificationProvider(BaseModel):
|
||||
|
||||
@property
|
||||
def token(self) -> str:
|
||||
return expandvars(self.token_) if self.token_ else ""
|
||||
return _expand(self.token_)
|
||||
|
||||
@property
|
||||
def chat_id(self) -> str:
|
||||
return expandvars(self.chat_id_) if self.chat_id_ else ""
|
||||
return _expand(self.chat_id_)
|
||||
|
||||
@property
|
||||
def webhook_url(self) -> str:
|
||||
return expandvars(self.webhook_url_) if self.webhook_url_ else ""
|
||||
return _expand(self.webhook_url_)
|
||||
|
||||
@property
|
||||
def server_url(self) -> str:
|
||||
return expandvars(self.server_url_) if self.server_url_ else ""
|
||||
return _expand(self.server_url_)
|
||||
|
||||
@property
|
||||
def device_key(self) -> str:
|
||||
return expandvars(self.device_key_) if self.device_key_ else ""
|
||||
return _expand(self.device_key_)
|
||||
|
||||
@property
|
||||
def user_key(self) -> str:
|
||||
return expandvars(self.user_key_) if self.user_key_ else ""
|
||||
return _expand(self.user_key_)
|
||||
|
||||
@property
|
||||
def api_token(self) -> str:
|
||||
return expandvars(self.api_token_) if self.api_token_ else ""
|
||||
return _expand(self.api_token_)
|
||||
|
||||
@property
|
||||
def url(self) -> str:
|
||||
return expandvars(self.url_) if self.url_ else ""
|
||||
return _expand(self.url_)
|
||||
|
||||
|
||||
class Notification(BaseModel):
|
||||
@@ -149,11 +171,11 @@ class Notification(BaseModel):
|
||||
|
||||
@property
|
||||
def token(self) -> str:
|
||||
return expandvars(self.token_) if self.token_ else ""
|
||||
return _expand(self.token_)
|
||||
|
||||
@property
|
||||
def chat_id(self) -> str:
|
||||
return expandvars(self.chat_id_) if self.chat_id_ else ""
|
||||
return _expand(self.chat_id_)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def migrate_legacy_config(self) -> "Notification":
|
||||
@@ -197,7 +219,35 @@ class ExperimentalOpenAI(BaseModel):
|
||||
return value
|
||||
|
||||
|
||||
class Security(BaseModel):
|
||||
"""Access control configuration for the login endpoint and MCP server.
|
||||
|
||||
Both ``login_whitelist`` and ``mcp_whitelist`` accept IPv4/IPv6 CIDR ranges.
|
||||
An empty ``login_whitelist`` allows all IPs; an empty ``mcp_whitelist``
|
||||
denies all IP-based access (tokens still work).
|
||||
"""
|
||||
|
||||
login_whitelist: list[str] = Field(
|
||||
default_factory=list,
|
||||
description="IP/CIDR whitelist for login access. Empty = allow all.",
|
||||
)
|
||||
login_tokens: list[str] = Field(
|
||||
default_factory=list,
|
||||
description="API bearer tokens that bypass login authentication.",
|
||||
)
|
||||
mcp_whitelist: list[str] = Field(
|
||||
default_factory=list,
|
||||
description="IP/CIDR whitelist for MCP access. Empty = deny all.",
|
||||
)
|
||||
mcp_tokens: list[str] = Field(
|
||||
default_factory=list,
|
||||
description="API bearer tokens for MCP access.",
|
||||
)
|
||||
|
||||
|
||||
class Config(BaseModel):
|
||||
"""Root configuration model composed of all subsection models."""
|
||||
|
||||
program: Program = Program()
|
||||
downloader: Downloader = Downloader()
|
||||
rss_parser: RSSParser = RSSParser()
|
||||
@@ -206,6 +256,7 @@ class Config(BaseModel):
|
||||
proxy: Proxy = Proxy()
|
||||
notification: Notification = Notification()
|
||||
experimental_openai: ExperimentalOpenAI = ExperimentalOpenAI()
|
||||
security: Security = Security()
|
||||
|
||||
def model_dump(self, *args, by_alias=True, **kwargs):
|
||||
return super().model_dump(*args, by_alias=by_alias, **kwargs)
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import Cookie, Depends, HTTPException, status
|
||||
from fastapi import Cookie, Depends, HTTPException, Request, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
|
||||
from module.conf import settings
|
||||
from module.database import Database
|
||||
from module.mcp.security import _is_allowed
|
||||
from module.models.user import User, UserUpdate
|
||||
|
||||
from .jwt import verify_token
|
||||
@@ -20,23 +22,49 @@ except ImportError:
|
||||
DEV_AUTH_BYPASS = VERSION == "DEV_VERSION"
|
||||
|
||||
|
||||
async def get_current_user(token: str = Cookie(None)):
|
||||
def check_login_ip(request: Request):
|
||||
"""Dependency that enforces login IP whitelist.
|
||||
|
||||
If ``settings.security.login_whitelist`` is empty, all IPs are allowed.
|
||||
"""
|
||||
whitelist = settings.security.login_whitelist
|
||||
if not whitelist:
|
||||
return
|
||||
client_host = request.client.host if request.client else None
|
||||
if not client_host or not _is_allowed(client_host, whitelist):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="IP not in login whitelist",
|
||||
)
|
||||
|
||||
|
||||
async def get_current_user(request: Request, token: str = Cookie(None)):
|
||||
"""FastAPI dependency that validates the current session.
|
||||
|
||||
Accepts authentication via (in order of precedence):
|
||||
1. DEV_AUTH_BYPASS when running as DEV_VERSION.
|
||||
2. ``Authorization: Bearer <token>`` header matching ``login_tokens``.
|
||||
3. HttpOnly ``token`` cookie containing a valid JWT with an active session.
|
||||
"""
|
||||
if DEV_AUTH_BYPASS:
|
||||
return "dev_user"
|
||||
# Check bearer token bypass
|
||||
auth_header = request.headers.get("authorization", "")
|
||||
if auth_header.startswith("Bearer "):
|
||||
api_token = auth_header[7:]
|
||||
if api_token and api_token in settings.security.login_tokens:
|
||||
return "api_token_user"
|
||||
if not token:
|
||||
raise UNAUTHORIZED
|
||||
payload = verify_token(token)
|
||||
if not payload:
|
||||
raise UNAUTHORIZED
|
||||
username = payload.get("sub")
|
||||
if not username:
|
||||
raise UNAUTHORIZED
|
||||
if username not in active_user:
|
||||
username = payload.get("sub") if payload else None
|
||||
if not username or username not in active_user:
|
||||
raise UNAUTHORIZED
|
||||
return username
|
||||
|
||||
|
||||
async def get_token_data(token: str = Depends(oauth2_scheme)):
|
||||
"""FastAPI dependency that decodes and returns the OAuth2 bearer token payload."""
|
||||
payload = verify_token(token)
|
||||
if not payload:
|
||||
raise HTTPException(
|
||||
@@ -46,6 +74,7 @@ async def get_token_data(token: str = Depends(oauth2_scheme)):
|
||||
|
||||
|
||||
def update_user_info(user_data: UserUpdate, current_user):
|
||||
"""Persist updated credentials for *current_user* to the database."""
|
||||
try:
|
||||
with Database() as db:
|
||||
db.user.update_user(current_user, user_data)
|
||||
@@ -55,6 +84,7 @@ def update_user_info(user_data: UserUpdate, current_user):
|
||||
|
||||
|
||||
def auth_user(user: User):
|
||||
"""Verify credentials and register the user in ``active_user`` on success."""
|
||||
with Database() as db:
|
||||
resp = db.user.auth_user(user)
|
||||
if resp.status:
|
||||
|
||||
@@ -185,3 +185,98 @@ class TestUpdateCredentials:
|
||||
except Exception:
|
||||
# Expected - endpoint doesn't handle failure case properly
|
||||
pass
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Refresh token: cookie-based username resolution
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRefreshTokenCookieBehavior:
|
||||
def test_refresh_with_no_cookie_raises_401(self, authed_client):
|
||||
"""GET /refresh_token with missing token cookie raises 401."""
|
||||
# Override auth to allow route but provide no cookie token
|
||||
with patch("module.api.auth.decode_token", return_value=None):
|
||||
response = authed_client.get("/api/v1/auth/refresh_token")
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_refresh_with_valid_cookie_updates_active_user(self, authed_client):
|
||||
"""GET /refresh_token updates the active user timestamp."""
|
||||
token = create_access_token(data={"sub": "testuser"})
|
||||
authed_client.cookies.set("token", token)
|
||||
active_users: dict = {}
|
||||
with patch("module.api.auth.active_user", active_users):
|
||||
response = authed_client.get("/api/v1/auth/refresh_token")
|
||||
assert response.status_code == 200
|
||||
assert "testuser" in active_users
|
||||
|
||||
def test_refresh_returns_new_token(self, authed_client):
|
||||
"""GET /refresh_token issues a valid JWT with bearer type."""
|
||||
token = create_access_token(data={"sub": "testuser"})
|
||||
authed_client.cookies.set("token", token)
|
||||
with patch("module.api.auth.active_user", {}):
|
||||
response = authed_client.get("/api/v1/auth/refresh_token")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["token_type"] == "bearer"
|
||||
assert isinstance(data["access_token"], str)
|
||||
assert len(data["access_token"]) > 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Logout: per-user removal
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLogoutCookieBehavior:
|
||||
def test_logout_removes_only_current_user(self, authed_client):
|
||||
"""GET /logout removes the current user from active_user, not others."""
|
||||
token = create_access_token(data={"sub": "testuser"})
|
||||
authed_client.cookies.set("token", token)
|
||||
active_users = {
|
||||
"testuser": datetime.now(),
|
||||
"otheruser": datetime.now(),
|
||||
}
|
||||
with patch("module.api.auth.active_user", active_users):
|
||||
response = authed_client.get("/api/v1/auth/logout")
|
||||
assert response.status_code == 200
|
||||
assert "testuser" not in active_users
|
||||
assert "otheruser" in active_users
|
||||
|
||||
def test_logout_with_no_cookie_still_succeeds(self, authed_client):
|
||||
"""GET /logout with no cookie clears nothing but returns success."""
|
||||
with patch("module.api.auth.decode_token", return_value=None):
|
||||
with patch("module.api.auth.active_user", {}):
|
||||
response = authed_client.get("/api/v1/auth/logout")
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Update: cookie-based user resolution
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestUpdateCookieBehavior:
|
||||
def test_update_with_no_cookie_raises_401(self, authed_client):
|
||||
"""POST /auth/update with no cookie raises 401."""
|
||||
with patch("module.api.auth.decode_token", return_value=None):
|
||||
response = authed_client.post(
|
||||
"/api/v1/auth/update",
|
||||
json={"old_password": "old", "new_password": "new"},
|
||||
)
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_update_with_valid_cookie_succeeds(self, authed_client):
|
||||
"""POST /auth/update resolves username from cookie and issues new token."""
|
||||
token = create_access_token(data={"sub": "testuser"})
|
||||
authed_client.cookies.set("token", token)
|
||||
with patch("module.api.auth.active_user", {"testuser": datetime.now()}):
|
||||
with patch("module.api.auth.update_user_info", return_value=True):
|
||||
response = authed_client.post(
|
||||
"/api/v1/auth/update",
|
||||
json={"old_password": "oldpass", "new_password": "newpass"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "access_token" in data
|
||||
assert data["message"] == "update success"
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Tests for Config API endpoints."""
|
||||
"""Tests for Config API endpoints and config sanitization."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
@@ -7,6 +7,7 @@ from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from module.api import v1
|
||||
from module.api.config import _sanitize_dict
|
||||
from module.models.config import Config
|
||||
from module.security.api import get_current_user
|
||||
|
||||
@@ -263,3 +264,99 @@ class TestUpdateConfig:
|
||||
response = authed_client.patch("/api/v1/config/update", json=invalid_data)
|
||||
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _sanitize_dict unit tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSanitizeDict:
|
||||
def test_masks_password_key(self):
|
||||
"""Keys containing 'password' are masked."""
|
||||
result = _sanitize_dict({"password": "secret"})
|
||||
assert result["password"] == "********"
|
||||
|
||||
def test_masks_api_key(self):
|
||||
"""Keys containing 'api_key' are masked."""
|
||||
result = _sanitize_dict({"api_key": "sk-abc123"})
|
||||
assert result["api_key"] == "********"
|
||||
|
||||
def test_masks_token_key(self):
|
||||
"""Keys containing 'token' are masked."""
|
||||
result = _sanitize_dict({"token": "bearer-xyz"})
|
||||
assert result["token"] == "********"
|
||||
|
||||
def test_masks_secret_key(self):
|
||||
"""Keys containing 'secret' are masked."""
|
||||
result = _sanitize_dict({"my_secret": "topsecret"})
|
||||
assert result["my_secret"] == "********"
|
||||
|
||||
def test_case_insensitive_key_matching(self):
|
||||
"""Sensitive key matching is case-insensitive."""
|
||||
result = _sanitize_dict({"API_KEY": "abc"})
|
||||
assert result["API_KEY"] == "********"
|
||||
|
||||
def test_non_sensitive_keys_pass_through(self):
|
||||
"""Non-sensitive keys are returned unchanged."""
|
||||
result = _sanitize_dict({"host": "localhost", "port": 8080, "enable": True})
|
||||
assert result["host"] == "localhost"
|
||||
assert result["port"] == 8080
|
||||
assert result["enable"] is True
|
||||
|
||||
def test_nested_dict_recursed(self):
|
||||
"""Nested dicts are processed recursively."""
|
||||
result = _sanitize_dict({
|
||||
"downloader": {
|
||||
"host": "localhost",
|
||||
"password": "secret",
|
||||
}
|
||||
})
|
||||
assert result["downloader"]["host"] == "localhost"
|
||||
assert result["downloader"]["password"] == "********"
|
||||
|
||||
def test_deeply_nested_dict(self):
|
||||
"""Deeply nested sensitive keys are masked."""
|
||||
result = _sanitize_dict({
|
||||
"level1": {
|
||||
"level2": {
|
||||
"api_key": "deep-secret"
|
||||
}
|
||||
}
|
||||
})
|
||||
assert result["level1"]["level2"]["api_key"] == "********"
|
||||
|
||||
def test_non_string_value_not_masked(self):
|
||||
"""Non-string values with sensitive-looking keys are NOT masked."""
|
||||
result = _sanitize_dict({"password": 12345})
|
||||
# Only string values are masked; integers pass through
|
||||
assert result["password"] == 12345
|
||||
|
||||
def test_empty_dict(self):
|
||||
"""Empty dict returns empty dict."""
|
||||
assert _sanitize_dict({}) == {}
|
||||
|
||||
def test_mixed_sensitive_and_plain(self):
|
||||
"""Mix of sensitive and plain keys handled correctly."""
|
||||
result = _sanitize_dict({
|
||||
"username": "admin",
|
||||
"password": "secret",
|
||||
"host": "10.0.0.1",
|
||||
"token": "jwt-abc",
|
||||
})
|
||||
assert result["username"] == "admin"
|
||||
assert result["host"] == "10.0.0.1"
|
||||
assert result["password"] == "********"
|
||||
assert result["token"] == "********"
|
||||
|
||||
def test_get_config_masks_sensitive_fields(self, authed_client):
|
||||
"""GET /config/get response masks password and api_key fields."""
|
||||
test_config = Config()
|
||||
with patch("module.api.config.settings", test_config):
|
||||
response = authed_client.get("/api/v1/config/get")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
# Downloader password should be masked
|
||||
assert data["downloader"]["password"] == "********"
|
||||
# OpenAI api_key should be masked (it's an empty string but still masked)
|
||||
assert data["experimental_openai"]["api_key"] == "********"
|
||||
|
||||
@@ -151,6 +151,15 @@ class TestPasswordHashing:
|
||||
|
||||
|
||||
class TestGetCurrentUser:
|
||||
@staticmethod
|
||||
def _mock_request(authorization=""):
|
||||
"""Create a mock Request with the given Authorization header."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
request = MagicMock()
|
||||
request.headers = {"authorization": authorization}
|
||||
return request
|
||||
|
||||
@patch("module.security.api.DEV_AUTH_BYPASS", False)
|
||||
async def test_no_cookie_raises_401(self):
|
||||
"""get_current_user raises 401 when no token cookie."""
|
||||
@@ -159,7 +168,7 @@ class TestGetCurrentUser:
|
||||
from module.security.api import get_current_user
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_current_user(token=None)
|
||||
await get_current_user(request=self._mock_request(), token=None)
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
@patch("module.security.api.DEV_AUTH_BYPASS", False)
|
||||
@@ -170,7 +179,7 @@ class TestGetCurrentUser:
|
||||
from module.security.api import get_current_user
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_current_user(token="invalid.jwt.token")
|
||||
await get_current_user(request=self._mock_request(), token="invalid.jwt.token")
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
@patch("module.security.api.DEV_AUTH_BYPASS", False)
|
||||
@@ -186,7 +195,7 @@ class TestGetCurrentUser:
|
||||
active_user.clear()
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_current_user(token=token)
|
||||
await get_current_user(request=self._mock_request(), token=token)
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
@patch("module.security.api.DEV_AUTH_BYPASS", False)
|
||||
@@ -202,8 +211,113 @@ class TestGetCurrentUser:
|
||||
active_user.clear()
|
||||
active_user["active_user"] = datetime.now()
|
||||
|
||||
result = await get_current_user(token=token)
|
||||
result = await get_current_user(request=self._mock_request(), token=token)
|
||||
assert result == "active_user"
|
||||
|
||||
# Cleanup
|
||||
active_user.clear()
|
||||
|
||||
@patch("module.security.api.DEV_AUTH_BYPASS", True)
|
||||
async def test_dev_bypass_skips_auth(self):
|
||||
"""When DEV_AUTH_BYPASS is True, get_current_user returns 'dev_user' unconditionally."""
|
||||
from module.security.api import get_current_user
|
||||
|
||||
result = await get_current_user(request=self._mock_request(), token=None)
|
||||
assert result == "dev_user"
|
||||
|
||||
@patch("module.security.api.DEV_AUTH_BYPASS", False)
|
||||
async def test_bearer_token_bypass_valid(self):
|
||||
"""A valid login_token in Authorization header returns 'api_token_user'."""
|
||||
from module.security.api import get_current_user
|
||||
|
||||
mock_request = self._mock_request(authorization="Bearer valid-api-token")
|
||||
mock_security = type("S", (), {"login_tokens": ["valid-api-token"]})()
|
||||
mock_settings = type("Settings", (), {"security": mock_security})()
|
||||
|
||||
with patch("module.security.api.settings", mock_settings):
|
||||
result = await get_current_user(request=mock_request, token=None)
|
||||
assert result == "api_token_user"
|
||||
|
||||
@patch("module.security.api.DEV_AUTH_BYPASS", False)
|
||||
async def test_bearer_token_bypass_invalid(self):
|
||||
"""An invalid login_token still falls through to cookie check."""
|
||||
from fastapi import HTTPException
|
||||
|
||||
from module.security.api import get_current_user
|
||||
|
||||
mock_request = self._mock_request(authorization="Bearer wrong-token")
|
||||
mock_security = type("S", (), {"login_tokens": ["correct-token"]})()
|
||||
mock_settings = type("Settings", (), {"security": mock_security})()
|
||||
|
||||
with patch("module.security.api.settings", mock_settings):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_current_user(request=mock_request, token=None)
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# check_login_ip
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCheckLoginIp:
|
||||
@staticmethod
|
||||
def _make_request(host: str | None):
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
request = MagicMock()
|
||||
if host is None:
|
||||
request.client = None
|
||||
else:
|
||||
request.client = MagicMock()
|
||||
request.client.host = host
|
||||
return request
|
||||
|
||||
def test_empty_whitelist_allows_all(self):
|
||||
"""When login_whitelist is empty, all IPs pass."""
|
||||
from module.security.api import check_login_ip
|
||||
|
||||
mock_security = type("S", (), {"login_whitelist": []})()
|
||||
mock_settings = type("Settings", (), {"security": mock_security})()
|
||||
|
||||
with patch("module.security.api.settings", mock_settings):
|
||||
# Should not raise
|
||||
check_login_ip(request=self._make_request("8.8.8.8"))
|
||||
|
||||
def test_allowed_ip_passes(self):
|
||||
"""IP in whitelist does not raise."""
|
||||
from module.security.api import check_login_ip
|
||||
|
||||
mock_security = type("S", (), {"login_whitelist": ["192.168.0.0/16"]})()
|
||||
mock_settings = type("Settings", (), {"security": mock_security})()
|
||||
|
||||
with patch("module.security.api.settings", mock_settings):
|
||||
check_login_ip(request=self._make_request("192.168.1.100"))
|
||||
|
||||
def test_blocked_ip_raises_403(self):
|
||||
"""IP outside whitelist raises 403."""
|
||||
from fastapi import HTTPException
|
||||
|
||||
from module.security.api import check_login_ip
|
||||
|
||||
mock_security = type("S", (), {"login_whitelist": ["192.168.0.0/16"]})()
|
||||
mock_settings = type("Settings", (), {"security": mock_security})()
|
||||
|
||||
with patch("module.security.api.settings", mock_settings):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
check_login_ip(request=self._make_request("8.8.8.8"))
|
||||
assert exc_info.value.status_code == 403
|
||||
|
||||
def test_no_client_raises_403_when_whitelist_set(self):
|
||||
"""Missing client info raises 403 when whitelist is non-empty."""
|
||||
from fastapi import HTTPException
|
||||
|
||||
from module.security.api import check_login_ip
|
||||
|
||||
mock_security = type("S", (), {"login_whitelist": ["192.168.0.0/16"]})()
|
||||
mock_settings = type("Settings", (), {"security": mock_security})()
|
||||
|
||||
with patch("module.security.api.settings", mock_settings):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
check_login_ip(request=self._make_request(None))
|
||||
assert exc_info.value.status_code == 403
|
||||
|
||||
@@ -8,14 +8,16 @@ from unittest.mock import patch
|
||||
|
||||
from module.models.config import (
|
||||
Config,
|
||||
Program,
|
||||
Downloader,
|
||||
RSSParser,
|
||||
BangumiManage,
|
||||
Proxy,
|
||||
Notification as NotificationConfig,
|
||||
NotificationProvider,
|
||||
Program,
|
||||
Proxy,
|
||||
RSSParser,
|
||||
Security,
|
||||
)
|
||||
from module.conf.config import Settings
|
||||
from module.conf.const import BCOLORS, DEFAULT_SETTINGS
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -228,3 +230,281 @@ class TestEnvOverrides:
|
||||
s.init()
|
||||
|
||||
assert "192.168.1.100:9090" in s.downloader.host
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Security model
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSecurityModel:
|
||||
def test_security_defaults(self):
|
||||
"""Security has empty whitelists and token lists by default."""
|
||||
sec = Security()
|
||||
assert sec.login_whitelist == []
|
||||
assert sec.login_tokens == []
|
||||
assert sec.mcp_whitelist == []
|
||||
assert sec.mcp_tokens == []
|
||||
|
||||
def test_security_in_config(self):
|
||||
"""Config includes a Security section with correct defaults."""
|
||||
config = Config()
|
||||
assert hasattr(config, "security")
|
||||
assert isinstance(config.security, Security)
|
||||
assert config.security.login_whitelist == []
|
||||
|
||||
def test_security_populated(self):
|
||||
"""Security fields accept lists of CIDRs and tokens."""
|
||||
sec = Security(
|
||||
login_whitelist=["192.168.0.0/16"],
|
||||
login_tokens=["token-abc"],
|
||||
mcp_whitelist=["10.0.0.0/8"],
|
||||
mcp_tokens=["mcp-secret"],
|
||||
)
|
||||
assert "192.168.0.0/16" in sec.login_whitelist
|
||||
assert "token-abc" in sec.login_tokens
|
||||
assert "10.0.0.0/8" in sec.mcp_whitelist
|
||||
assert "mcp-secret" in sec.mcp_tokens
|
||||
|
||||
def test_security_roundtrip_serialization(self):
|
||||
"""Security serializes and deserializes correctly."""
|
||||
original = Security(
|
||||
login_whitelist=["127.0.0.0/8"],
|
||||
mcp_tokens=["tok1"],
|
||||
)
|
||||
data = original.model_dump()
|
||||
restored = Security.model_validate(data)
|
||||
assert restored.login_whitelist == ["127.0.0.0/8"]
|
||||
assert restored.mcp_tokens == ["tok1"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# NotificationProvider model
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestNotificationProvider:
|
||||
def test_minimal_provider(self):
|
||||
"""NotificationProvider requires only type."""
|
||||
p = NotificationProvider(type="telegram")
|
||||
assert p.type == "telegram"
|
||||
assert p.enabled is True
|
||||
|
||||
def test_telegram_provider_fields(self):
|
||||
"""Telegram provider stores token and chat_id."""
|
||||
p = NotificationProvider(type="telegram", token="bot123", chat_id="-100456")
|
||||
assert p.token == "bot123"
|
||||
assert p.chat_id == "-100456"
|
||||
|
||||
def test_discord_provider_fields(self):
|
||||
"""Discord provider stores webhook_url."""
|
||||
p = NotificationProvider(
|
||||
type="discord", webhook_url="https://discord.com/api/webhooks/123/abc"
|
||||
)
|
||||
assert p.webhook_url == "https://discord.com/api/webhooks/123/abc"
|
||||
|
||||
def test_bark_provider_fields(self):
|
||||
"""Bark provider stores server_url and device_key."""
|
||||
p = NotificationProvider(
|
||||
type="bark", server_url="https://api.day.app", device_key="mykey"
|
||||
)
|
||||
assert p.server_url == "https://api.day.app"
|
||||
assert p.device_key == "mykey"
|
||||
|
||||
def test_pushover_provider_fields(self):
|
||||
"""Pushover provider stores user_key and api_token."""
|
||||
p = NotificationProvider(type="pushover", user_key="uk1", api_token="at1")
|
||||
assert p.user_key == "uk1"
|
||||
assert p.api_token == "at1"
|
||||
|
||||
def test_url_field_property(self):
|
||||
"""Webhook provider stores url."""
|
||||
p = NotificationProvider(type="webhook", url="https://example.com/hook")
|
||||
assert p.url == "https://example.com/hook"
|
||||
|
||||
def test_optional_fields_default_empty_string(self):
|
||||
"""Unset optional properties return empty string, not None."""
|
||||
p = NotificationProvider(type="telegram")
|
||||
assert p.token == ""
|
||||
assert p.chat_id == ""
|
||||
assert p.webhook_url == ""
|
||||
|
||||
def test_provider_can_be_disabled(self):
|
||||
"""Provider can be disabled without removing it."""
|
||||
p = NotificationProvider(type="telegram", enabled=False)
|
||||
assert p.enabled is False
|
||||
|
||||
def test_env_var_expansion_in_token(self, monkeypatch):
|
||||
"""Token field expands shell environment variables."""
|
||||
monkeypatch.setenv("TEST_BOT_TOKEN", "real-token-value")
|
||||
p = NotificationProvider(type="telegram", token="$TEST_BOT_TOKEN")
|
||||
assert p.token == "real-token-value"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Notification model - legacy migration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestNotificationLegacyMigration:
|
||||
def test_new_format_no_migration(self):
|
||||
"""New format with providers list is not touched."""
|
||||
n = NotificationConfig(
|
||||
enable=True,
|
||||
providers=[NotificationProvider(type="telegram", token="tok")],
|
||||
)
|
||||
assert len(n.providers) == 1
|
||||
assert n.providers[0].type == "telegram"
|
||||
|
||||
def test_old_format_migrates_to_provider(self):
|
||||
"""Old single-provider fields (type, token, chat_id) migrate to providers list."""
|
||||
n = NotificationConfig(
|
||||
enable=True,
|
||||
type="telegram",
|
||||
token="bot_token",
|
||||
chat_id="-100123",
|
||||
)
|
||||
assert len(n.providers) == 1
|
||||
provider = n.providers[0]
|
||||
assert provider.type == "telegram"
|
||||
assert provider.enabled is True
|
||||
|
||||
def test_old_format_no_migration_when_providers_already_set(self):
|
||||
"""When providers already exist, legacy fields do not create additional providers."""
|
||||
n = NotificationConfig(
|
||||
enable=True,
|
||||
type="telegram",
|
||||
token="unused",
|
||||
providers=[NotificationProvider(type="discord", webhook_url="https://d.co")],
|
||||
)
|
||||
assert len(n.providers) == 1
|
||||
assert n.providers[0].type == "discord"
|
||||
|
||||
def test_notification_empty_providers_by_default(self):
|
||||
"""Default Notification has no providers."""
|
||||
n = NotificationConfig()
|
||||
assert n.providers == []
|
||||
assert n.enable is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Downloader env-var expansion
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDownloaderEnvExpansion:
|
||||
def test_host_expands_env_var(self, monkeypatch):
|
||||
"""Downloader.host expands $VAR references."""
|
||||
monkeypatch.setenv("QB_HOST", "192.168.5.10:8080")
|
||||
d = Downloader(host="$QB_HOST")
|
||||
assert d.host == "192.168.5.10:8080"
|
||||
|
||||
def test_username_expands_env_var(self, monkeypatch):
|
||||
"""Downloader.username expands $VAR references."""
|
||||
monkeypatch.setenv("QB_USER", "myuser")
|
||||
d = Downloader(username="$QB_USER")
|
||||
assert d.username == "myuser"
|
||||
|
||||
def test_password_expands_env_var(self, monkeypatch):
|
||||
"""Downloader.password expands $VAR references."""
|
||||
monkeypatch.setenv("QB_PASS", "s3cret")
|
||||
d = Downloader(password="$QB_PASS")
|
||||
assert d.password == "s3cret"
|
||||
|
||||
def test_literal_host_not_expanded(self):
|
||||
"""Literal host strings without $ are returned as-is."""
|
||||
d = Downloader(host="localhost:8080")
|
||||
assert d.host == "localhost:8080"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DEFAULT_SETTINGS structure
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDefaultSettings:
|
||||
def test_security_section_present(self):
|
||||
"""DEFAULT_SETTINGS contains a security section."""
|
||||
assert "security" in DEFAULT_SETTINGS
|
||||
|
||||
def test_security_default_mcp_whitelist(self):
|
||||
"""Default MCP whitelist contains private network ranges."""
|
||||
mcp_wl = DEFAULT_SETTINGS["security"]["mcp_whitelist"]
|
||||
assert "127.0.0.0/8" in mcp_wl
|
||||
assert "192.168.0.0/16" in mcp_wl
|
||||
assert "10.0.0.0/8" in mcp_wl
|
||||
|
||||
def test_security_default_tokens_empty(self):
|
||||
"""Default security token lists are empty."""
|
||||
assert DEFAULT_SETTINGS["security"]["login_tokens"] == []
|
||||
assert DEFAULT_SETTINGS["security"]["mcp_tokens"] == []
|
||||
|
||||
def test_notification_uses_providers_format(self):
|
||||
"""DEFAULT_SETTINGS notification uses new providers format."""
|
||||
notif = DEFAULT_SETTINGS["notification"]
|
||||
assert "providers" in notif
|
||||
assert notif["providers"] == []
|
||||
assert "type" not in notif
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# BCOLORS utility
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBCOLORS:
|
||||
def test_wrap_single_string(self):
|
||||
"""BCOLORS._() wraps a string with color codes and reset."""
|
||||
result = BCOLORS._(BCOLORS.OKGREEN, "hello")
|
||||
assert "hello" in result
|
||||
assert BCOLORS.OKGREEN in result
|
||||
assert BCOLORS.ENDC in result
|
||||
|
||||
def test_wrap_multiple_strings(self):
|
||||
"""BCOLORS._() joins multiple args with commas."""
|
||||
result = BCOLORS._(BCOLORS.WARNING, "foo", "bar")
|
||||
assert "foo" in result
|
||||
assert "bar" in result
|
||||
|
||||
def test_wrap_non_string_arg(self):
|
||||
"""BCOLORS._() converts non-string args to str."""
|
||||
result = BCOLORS._(BCOLORS.FAIL, 42)
|
||||
assert "42" in result
|
||||
|
||||
def test_all_color_constants_are_strings(self):
|
||||
"""All BCOLORS constants are strings."""
|
||||
for attr in ["HEADER", "OKBLUE", "OKCYAN", "OKGREEN", "WARNING", "FAIL", "ENDC"]:
|
||||
assert isinstance(getattr(BCOLORS, attr), str)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Migration: security section injection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMigrateSecuritySection:
|
||||
def test_adds_security_when_missing(self):
|
||||
"""_migrate_old_config injects a default security section when absent."""
|
||||
old_config = {
|
||||
"program": {},
|
||||
"rss_parser": {},
|
||||
}
|
||||
result = Settings._migrate_old_config(old_config)
|
||||
assert "security" in result
|
||||
assert "mcp_whitelist" in result["security"]
|
||||
|
||||
def test_preserves_existing_security_section(self):
|
||||
"""_migrate_old_config does not overwrite an existing security section."""
|
||||
existing_config = {
|
||||
"program": {},
|
||||
"rss_parser": {},
|
||||
"security": {
|
||||
"login_whitelist": ["10.0.0.0/8"],
|
||||
"login_tokens": ["mytoken"],
|
||||
"mcp_whitelist": [],
|
||||
"mcp_tokens": [],
|
||||
},
|
||||
}
|
||||
result = Settings._migrate_old_config(existing_config)
|
||||
assert result["security"]["login_tokens"] == ["mytoken"]
|
||||
assert result["security"]["login_whitelist"] == ["10.0.0.0/8"]
|
||||
|
||||
@@ -293,7 +293,68 @@ class TestClientDelegation:
|
||||
result = await download_client.rename_torrent_file("hash1", "old.mkv", "new.mkv")
|
||||
assert result is False
|
||||
|
||||
async def test_rename_torrent_file_passes_verify_flag(
|
||||
self, download_client, mock_qb_client
|
||||
):
|
||||
"""rename_torrent_file forwards the verify kwarg to the underlying client."""
|
||||
mock_qb_client.torrents_rename_file.return_value = True
|
||||
await download_client.rename_torrent_file(
|
||||
"hash1", "old.mkv", "new.mkv", verify=False
|
||||
)
|
||||
call_kwargs = mock_qb_client.torrents_rename_file.call_args[1]
|
||||
assert call_kwargs["verify"] is False
|
||||
|
||||
async def test_delete_torrent(self, download_client, mock_qb_client):
|
||||
"""delete_torrent delegates to client.torrents_delete."""
|
||||
await download_client.delete_torrent("hash1", delete_files=True)
|
||||
mock_qb_client.torrents_delete.assert_called_once_with("hash1", delete_files=True)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# add_tag
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAddTag:
|
||||
async def test_add_tag_delegates_to_client(self, download_client, mock_qb_client):
|
||||
"""add_tag delegates to client.add_tag."""
|
||||
mock_qb_client.add_tag = AsyncMock(return_value=None)
|
||||
download_client.client = mock_qb_client
|
||||
await download_client.add_tag("deadbeef12345678", "ab:42")
|
||||
mock_qb_client.add_tag.assert_called_once_with("deadbeef12345678", "ab:42")
|
||||
|
||||
async def test_add_tag_short_hash_no_error(self, download_client, mock_qb_client):
|
||||
"""add_tag with a hash shorter than 8 chars does not crash the slice."""
|
||||
mock_qb_client.add_tag = AsyncMock(return_value=None)
|
||||
download_client.client = mock_qb_client
|
||||
# Should not raise even for short hashes
|
||||
await download_client.add_tag("abc", "ab:1")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Context manager: ConnectionError on failed auth
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestContextManagerAuth:
|
||||
async def test_aenter_raises_on_auth_failure(self, download_client, mock_qb_client):
|
||||
"""__aenter__ raises ConnectionError when auth fails."""
|
||||
mock_qb_client.auth.return_value = False
|
||||
download_client.authed = False
|
||||
with pytest.raises(ConnectionError, match="authentication failed"):
|
||||
await download_client.__aenter__()
|
||||
|
||||
async def test_aenter_succeeds_when_auth_passes(self, download_client, mock_qb_client):
|
||||
"""__aenter__ returns self when auth succeeds."""
|
||||
mock_qb_client.auth.return_value = True
|
||||
download_client.authed = False
|
||||
result = await download_client.__aenter__()
|
||||
assert result is download_client
|
||||
assert download_client.authed is True
|
||||
|
||||
async def test_aexit_calls_logout_when_authed(self, download_client, mock_qb_client):
|
||||
"""__aexit__ calls logout and resets authed when session was active."""
|
||||
download_client.authed = True
|
||||
await download_client.__aexit__(None, None, None)
|
||||
mock_qb_client.logout.assert_called_once()
|
||||
assert download_client.authed is False
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Tests for module.mcp.security - LocalNetworkMiddleware and _is_local()."""
|
||||
"""Tests for module.mcp.security - McpAccessMiddleware, _is_allowed(), and clear_network_cache()."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from starlette.applications import Starlette
|
||||
@@ -8,248 +8,223 @@ from starlette.responses import PlainTextResponse
|
||||
from starlette.routing import Route
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
from module.mcp.security import LocalNetworkMiddleware, _is_local
|
||||
from module.mcp.security import McpAccessMiddleware, _is_allowed, clear_network_cache
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _is_local() unit tests
|
||||
# _is_allowed() unit tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestIsLocal:
|
||||
"""Verify _is_local() correctly classifies IP addresses."""
|
||||
class TestIsAllowed:
|
||||
"""Verify _is_allowed() checks IPs against a given whitelist."""
|
||||
|
||||
# --- loopback ---
|
||||
def setup_method(self):
|
||||
clear_network_cache()
|
||||
|
||||
def test_ipv4_loopback_127_0_0_1(self):
|
||||
"""127.0.0.1 is the canonical loopback address."""
|
||||
assert _is_local("127.0.0.1") is True
|
||||
LOCAL_WHITELIST = [
|
||||
"127.0.0.0/8",
|
||||
"10.0.0.0/8",
|
||||
"172.16.0.0/12",
|
||||
"192.168.0.0/16",
|
||||
"::1/128",
|
||||
"fe80::/10",
|
||||
"fc00::/7",
|
||||
]
|
||||
|
||||
def test_ipv4_loopback_127_0_0_2(self):
|
||||
"""127.0.0.2 is within 127.0.0.0/8 and therefore local."""
|
||||
assert _is_local("127.0.0.2") is True
|
||||
# --- allowed IPs ---
|
||||
|
||||
def test_ipv4_loopback_127_255_255_255(self):
|
||||
"""Top of 127.0.0.0/8 range is still local."""
|
||||
assert _is_local("127.255.255.255") is True
|
||||
def test_ipv4_loopback_allowed(self):
|
||||
assert _is_allowed("127.0.0.1", self.LOCAL_WHITELIST) is True
|
||||
|
||||
# --- RFC 1918 class-A (10.0.0.0/8) ---
|
||||
def test_ipv4_loopback_range(self):
|
||||
assert _is_allowed("127.255.255.255", self.LOCAL_WHITELIST) is True
|
||||
|
||||
def test_ipv4_10_network_start(self):
|
||||
"""10.0.0.1 is in 10.0.0.0/8 private range."""
|
||||
assert _is_local("10.0.0.1") is True
|
||||
def test_ipv4_10_network(self):
|
||||
assert _is_allowed("10.0.0.1", self.LOCAL_WHITELIST) is True
|
||||
|
||||
def test_ipv4_10_network_mid(self):
|
||||
"""10.10.20.30 is inside 10.0.0.0/8."""
|
||||
assert _is_local("10.10.20.30") is True
|
||||
def test_ipv4_172_16_network(self):
|
||||
assert _is_allowed("172.16.0.1", self.LOCAL_WHITELIST) is True
|
||||
|
||||
def test_ipv4_10_network_end(self):
|
||||
"""10.255.255.254 is the last usable address in 10.0.0.0/8."""
|
||||
assert _is_local("10.255.255.254") is True
|
||||
|
||||
# --- RFC 1918 class-B (172.16.0.0/12) ---
|
||||
|
||||
def test_ipv4_172_16_start(self):
|
||||
"""172.16.0.1 is the first address in 172.16.0.0/12."""
|
||||
assert _is_local("172.16.0.1") is True
|
||||
|
||||
def test_ipv4_172_31_end(self):
|
||||
"""172.31.255.254 is at the top of the 172.16.0.0/12 range."""
|
||||
assert _is_local("172.31.255.254") is True
|
||||
|
||||
def test_ipv4_172_15_not_local(self):
|
||||
"""172.15.255.255 is just outside 172.16.0.0/12 (below the range)."""
|
||||
assert _is_local("172.15.255.255") is False
|
||||
|
||||
def test_ipv4_172_32_not_local(self):
|
||||
"""172.32.0.0 is just above 172.16.0.0/12 (outside the range)."""
|
||||
assert _is_local("172.32.0.0") is False
|
||||
|
||||
# --- RFC 1918 class-C (192.168.0.0/16) ---
|
||||
|
||||
def test_ipv4_192_168_start(self):
|
||||
"""192.168.0.1 is a typical home-router address."""
|
||||
assert _is_local("192.168.0.1") is True
|
||||
|
||||
def test_ipv4_192_168_end(self):
|
||||
"""192.168.255.254 is at the top of 192.168.0.0/16."""
|
||||
assert _is_local("192.168.255.254") is True
|
||||
|
||||
# --- Public IPv4 ---
|
||||
|
||||
def test_public_ipv4_google_dns(self):
|
||||
"""8.8.8.8 (Google DNS) is a public address."""
|
||||
assert _is_local("8.8.8.8") is False
|
||||
|
||||
def test_public_ipv4_cloudflare_dns(self):
|
||||
"""1.1.1.1 (Cloudflare) is a public address."""
|
||||
assert _is_local("1.1.1.1") is False
|
||||
|
||||
def test_public_ipv4_broadcast_like(self):
|
||||
"""203.0.113.1 (TEST-NET-3, RFC 5737) is not a private address."""
|
||||
assert _is_local("203.0.113.1") is False
|
||||
|
||||
# --- IPv6 loopback ---
|
||||
def test_ipv4_192_168_network(self):
|
||||
assert _is_allowed("192.168.1.100", self.LOCAL_WHITELIST) is True
|
||||
|
||||
def test_ipv6_loopback(self):
|
||||
"""::1 is the IPv6 loopback address."""
|
||||
assert _is_local("::1") is True
|
||||
|
||||
# --- IPv6 link-local (fe80::/10) ---
|
||||
assert _is_allowed("::1", self.LOCAL_WHITELIST) is True
|
||||
|
||||
def test_ipv6_link_local(self):
|
||||
"""fe80::1 is an IPv6 link-local address."""
|
||||
assert _is_local("fe80::1") is True
|
||||
assert _is_allowed("fe80::1", self.LOCAL_WHITELIST) is True
|
||||
|
||||
def test_ipv6_link_local_full(self):
|
||||
"""fe80::aabb:ccdd is also link-local."""
|
||||
assert _is_local("fe80::aabb:ccdd") is True
|
||||
def test_ipv6_ula(self):
|
||||
assert _is_allowed("fd00::1", self.LOCAL_WHITELIST) is True
|
||||
|
||||
# --- IPv6 ULA (fc00::/7) ---
|
||||
# --- denied IPs ---
|
||||
|
||||
def test_ipv6_ula_fc(self):
|
||||
"""fc00::1 is within the ULA range fc00::/7."""
|
||||
assert _is_local("fc00::1") is True
|
||||
def test_public_ipv4_denied(self):
|
||||
assert _is_allowed("8.8.8.8", self.LOCAL_WHITELIST) is False
|
||||
|
||||
def test_ipv6_ula_fd(self):
|
||||
"""fd00::1 is within the ULA range fc00::/7 (fd prefix)."""
|
||||
assert _is_local("fd00::1") is True
|
||||
def test_public_ipv6_denied(self):
|
||||
assert _is_allowed("2001:4860:4860::8888", self.LOCAL_WHITELIST) is False
|
||||
|
||||
# --- Public IPv6 ---
|
||||
def test_172_outside_range(self):
|
||||
assert _is_allowed("172.32.0.0", self.LOCAL_WHITELIST) is False
|
||||
|
||||
def test_public_ipv6_google(self):
|
||||
"""2001:4860:4860::8888 (Google IPv6 DNS) is a public address."""
|
||||
assert _is_local("2001:4860:4860::8888") is False
|
||||
# --- empty whitelist ---
|
||||
|
||||
def test_public_ipv6_documentation(self):
|
||||
"""2001:db8::1 (documentation prefix, RFC 3849) is public."""
|
||||
assert _is_local("2001:db8::1") is False
|
||||
def test_empty_whitelist_denies_all(self):
|
||||
assert _is_allowed("127.0.0.1", []) is False
|
||||
|
||||
# --- Invalid inputs ---
|
||||
# --- invalid inputs ---
|
||||
|
||||
def test_invalid_hostname_returns_false(self):
|
||||
"""A hostname string is not parseable as an IP and must return False."""
|
||||
assert _is_local("localhost") is False
|
||||
def test_invalid_hostname(self):
|
||||
assert _is_allowed("localhost", self.LOCAL_WHITELIST) is False
|
||||
|
||||
def test_invalid_string_returns_false(self):
|
||||
"""A random non-IP string returns False without raising."""
|
||||
assert _is_local("not-an-ip") is False
|
||||
def test_empty_string(self):
|
||||
assert _is_allowed("", self.LOCAL_WHITELIST) is False
|
||||
|
||||
def test_empty_string_returns_false(self):
|
||||
"""An empty string is not a valid IP address."""
|
||||
assert _is_local("") is False
|
||||
def test_malformed_ipv4(self):
|
||||
assert _is_allowed("256.0.0.1", self.LOCAL_WHITELIST) is False
|
||||
|
||||
def test_malformed_ipv4_returns_false(self):
|
||||
"""A string that looks like IPv4 but is malformed returns False."""
|
||||
assert _is_local("256.0.0.1") is False
|
||||
# --- single IP whitelist ---
|
||||
|
||||
def test_partial_ipv4_returns_false(self):
|
||||
"""An incomplete IPv4 address is not valid."""
|
||||
assert _is_local("192.168") is False
|
||||
def test_single_ip_whitelist(self):
|
||||
assert _is_allowed("203.0.113.5", ["203.0.113.5/32"]) is True
|
||||
assert _is_allowed("203.0.113.6", ["203.0.113.5/32"]) is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LocalNetworkMiddleware integration tests
|
||||
# McpAccessMiddleware integration tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_app_with_middleware() -> Starlette:
|
||||
"""Build a minimal Starlette app with LocalNetworkMiddleware applied."""
|
||||
def _make_mcp_settings(mcp_whitelist=None, mcp_tokens=None):
|
||||
"""Create a mock settings.security object."""
|
||||
|
||||
class MockSecurity:
|
||||
def __init__(self):
|
||||
self.mcp_whitelist = mcp_whitelist if mcp_whitelist is not None else []
|
||||
self.mcp_tokens = mcp_tokens if mcp_tokens is not None else []
|
||||
|
||||
class MockSettings:
|
||||
def __init__(self):
|
||||
self.security = MockSecurity()
|
||||
|
||||
return MockSettings()
|
||||
|
||||
|
||||
def _make_app() -> Starlette:
|
||||
"""Build a minimal Starlette app with McpAccessMiddleware applied."""
|
||||
|
||||
async def homepage(request):
|
||||
return PlainTextResponse("ok")
|
||||
|
||||
app = Starlette(routes=[Route("/", homepage)])
|
||||
app.add_middleware(LocalNetworkMiddleware)
|
||||
app.add_middleware(McpAccessMiddleware)
|
||||
return app
|
||||
|
||||
|
||||
class TestLocalNetworkMiddleware:
|
||||
"""Verify LocalNetworkMiddleware allows or denies requests by client IP."""
|
||||
def _patch_client_ip(app, ip):
|
||||
"""Return a modified app that overrides the client IP in ASGI scope."""
|
||||
original_build = app.build_middleware_stack
|
||||
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
return _make_app_with_middleware()
|
||||
async def patched_app(scope, receive, send):
|
||||
if scope["type"] == "http":
|
||||
scope["client"] = (ip, 12345) if ip is not None else None
|
||||
await original_build()(scope, receive, send)
|
||||
|
||||
def test_local_ipv4_loopback_allowed(self, app):
|
||||
"""Requests from 127.0.0.1 are allowed through."""
|
||||
# Starlette's TestClient identifies itself as "testclient", not a real IP.
|
||||
# Patch the scope so the middleware sees an actual loopback address.
|
||||
original_build = app.build_middleware_stack
|
||||
app.build_middleware_stack = lambda: patched_app
|
||||
return app
|
||||
|
||||
async def patched_app(scope, receive, send):
|
||||
if scope["type"] == "http":
|
||||
scope["client"] = ("127.0.0.1", 12345)
|
||||
await original_build()(scope, receive, send)
|
||||
|
||||
app.build_middleware_stack = lambda: patched_app # type: ignore[method-assign]
|
||||
class TestMcpAccessMiddleware:
|
||||
"""Verify McpAccessMiddleware allows/denies requests by IP and token."""
|
||||
|
||||
client = TestClient(app, raise_server_exceptions=False)
|
||||
response = client.get("/")
|
||||
def setup_method(self):
|
||||
clear_network_cache()
|
||||
|
||||
def test_allowed_ip_passes(self):
|
||||
mock_settings = _make_mcp_settings(mcp_whitelist=["127.0.0.0/8"])
|
||||
app = _patch_client_ip(_make_app(), "127.0.0.1")
|
||||
with patch("module.mcp.security.settings", mock_settings):
|
||||
client = TestClient(app, raise_server_exceptions=False)
|
||||
response = client.get("/")
|
||||
assert response.status_code == 200
|
||||
assert response.text == "ok"
|
||||
|
||||
def test_non_local_ip_blocked(self, app):
|
||||
"""Requests from a public IP are rejected with 403."""
|
||||
# Patch the ASGI scope to simulate a public client
|
||||
original_build = app.build_middleware_stack
|
||||
|
||||
async def patched_app(scope, receive, send):
|
||||
if scope["type"] == "http":
|
||||
scope["client"] = ("8.8.8.8", 12345)
|
||||
await original_build()(scope, receive, send)
|
||||
|
||||
app.build_middleware_stack = lambda: patched_app # type: ignore[method-assign]
|
||||
|
||||
client = TestClient(app, raise_server_exceptions=False)
|
||||
response = client.get("/")
|
||||
def test_denied_ip_blocked(self):
|
||||
mock_settings = _make_mcp_settings(mcp_whitelist=["127.0.0.0/8"])
|
||||
app = _patch_client_ip(_make_app(), "8.8.8.8")
|
||||
with patch("module.mcp.security.settings", mock_settings):
|
||||
client = TestClient(app, raise_server_exceptions=False)
|
||||
response = client.get("/")
|
||||
assert response.status_code == 403
|
||||
assert "MCP access is restricted to local network" in response.text
|
||||
assert "MCP access denied" in response.text
|
||||
|
||||
def test_missing_client_blocked(self, app):
|
||||
"""Requests with no client information are rejected with 403."""
|
||||
original_build = app.build_middleware_stack
|
||||
|
||||
async def patched_app(scope, receive, send):
|
||||
if scope["type"] == "http":
|
||||
scope["client"] = None
|
||||
await original_build()(scope, receive, send)
|
||||
|
||||
app.build_middleware_stack = lambda: patched_app # type: ignore[method-assign]
|
||||
|
||||
client = TestClient(app, raise_server_exceptions=False)
|
||||
response = client.get("/")
|
||||
def test_empty_whitelist_denies_all(self):
|
||||
mock_settings = _make_mcp_settings(mcp_whitelist=[], mcp_tokens=[])
|
||||
app = _patch_client_ip(_make_app(), "127.0.0.1")
|
||||
with patch("module.mcp.security.settings", mock_settings):
|
||||
client = TestClient(app, raise_server_exceptions=False)
|
||||
response = client.get("/")
|
||||
assert response.status_code == 403
|
||||
|
||||
def test_blocked_response_is_json(self, app):
|
||||
"""The 403 error body is valid JSON with an 'error' key."""
|
||||
def test_missing_client_blocked(self):
|
||||
mock_settings = _make_mcp_settings(mcp_whitelist=["127.0.0.0/8"])
|
||||
app = _patch_client_ip(_make_app(), None)
|
||||
with patch("module.mcp.security.settings", mock_settings):
|
||||
client = TestClient(app, raise_server_exceptions=False)
|
||||
response = client.get("/")
|
||||
assert response.status_code == 403
|
||||
|
||||
def test_bearer_token_bypasses_ip(self):
|
||||
mock_settings = _make_mcp_settings(
|
||||
mcp_whitelist=[], mcp_tokens=["secret-token-123"]
|
||||
)
|
||||
app = _patch_client_ip(_make_app(), "8.8.8.8")
|
||||
with patch("module.mcp.security.settings", mock_settings):
|
||||
client = TestClient(app, raise_server_exceptions=False)
|
||||
response = client.get(
|
||||
"/", headers={"Authorization": "Bearer secret-token-123"}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_invalid_bearer_token_denied(self):
|
||||
mock_settings = _make_mcp_settings(
|
||||
mcp_whitelist=[], mcp_tokens=["secret-token-123"]
|
||||
)
|
||||
app = _patch_client_ip(_make_app(), "8.8.8.8")
|
||||
with patch("module.mcp.security.settings", mock_settings):
|
||||
client = TestClient(app, raise_server_exceptions=False)
|
||||
response = client.get(
|
||||
"/", headers={"Authorization": "Bearer wrong-token"}
|
||||
)
|
||||
assert response.status_code == 403
|
||||
|
||||
def test_private_network_with_default_whitelist(self):
|
||||
default_whitelist = [
|
||||
"127.0.0.0/8",
|
||||
"10.0.0.0/8",
|
||||
"172.16.0.0/12",
|
||||
"192.168.0.0/16",
|
||||
"::1/128",
|
||||
"fe80::/10",
|
||||
"fc00::/7",
|
||||
]
|
||||
mock_settings = _make_mcp_settings(mcp_whitelist=default_whitelist)
|
||||
app = _patch_client_ip(_make_app(), "192.168.1.100")
|
||||
with patch("module.mcp.security.settings", mock_settings):
|
||||
client = TestClient(app, raise_server_exceptions=False)
|
||||
response = client.get("/")
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_blocked_response_is_json(self):
|
||||
import json
|
||||
|
||||
original_build = app.build_middleware_stack
|
||||
|
||||
async def patched_app(scope, receive, send):
|
||||
if scope["type"] == "http":
|
||||
scope["client"] = ("1.2.3.4", 9999)
|
||||
await original_build()(scope, receive, send)
|
||||
|
||||
app.build_middleware_stack = lambda: patched_app # type: ignore[method-assign]
|
||||
|
||||
client = TestClient(app, raise_server_exceptions=False)
|
||||
response = client.get("/")
|
||||
mock_settings = _make_mcp_settings(mcp_whitelist=["127.0.0.0/8"])
|
||||
app = _patch_client_ip(_make_app(), "1.2.3.4")
|
||||
with patch("module.mcp.security.settings", mock_settings):
|
||||
client = TestClient(app, raise_server_exceptions=False)
|
||||
response = client.get("/")
|
||||
assert response.status_code == 403
|
||||
body = json.loads(response.text)
|
||||
assert "error" in body
|
||||
|
||||
def test_private_192_168_allowed(self, app):
|
||||
"""Requests from a 192.168.x.x address pass through."""
|
||||
original_build = app.build_middleware_stack
|
||||
|
||||
async def patched_app(scope, receive, send):
|
||||
if scope["type"] == "http":
|
||||
scope["client"] = ("192.168.1.100", 54321)
|
||||
await original_build()(scope, receive, send)
|
||||
|
||||
app.build_middleware_stack = lambda: patched_app # type: ignore[method-assign]
|
||||
|
||||
client = TestClient(app, raise_server_exceptions=False)
|
||||
response = client.get("/")
|
||||
assert response.status_code == 200
|
||||
|
||||
429
backend/src/test/test_mock_downloader.py
Normal file
429
backend/src/test/test_mock_downloader.py
Normal file
@@ -0,0 +1,429 @@
|
||||
"""Tests for MockDownloader - state management and API contract."""
|
||||
|
||||
import pytest
|
||||
|
||||
from module.downloader.client.mock_downloader import MockDownloader
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dl() -> MockDownloader:
|
||||
"""Fresh MockDownloader for each test."""
|
||||
return MockDownloader()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Initialization
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMockDownloaderInit:
|
||||
def test_initial_state_is_empty(self, mock_dl):
|
||||
"""MockDownloader starts with no torrents, rules, or feeds."""
|
||||
state = mock_dl.get_state()
|
||||
assert state["torrents"] == {}
|
||||
assert state["rules"] == {}
|
||||
assert state["feeds"] == {}
|
||||
|
||||
def test_initial_categories(self, mock_dl):
|
||||
"""Default categories include Bangumi and BangumiCollection."""
|
||||
state = mock_dl.get_state()
|
||||
assert "Bangumi" in state["categories"]
|
||||
assert "BangumiCollection" in state["categories"]
|
||||
|
||||
def test_initial_prefs(self, mock_dl):
|
||||
"""Default prefs are populated."""
|
||||
# Access private attribute directly to confirm defaults
|
||||
assert mock_dl._prefs["rss_auto_downloading_enabled"] is True
|
||||
assert mock_dl._prefs["rss_max_articles_per_feed"] == 500
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Auth / connection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMockDownloaderAuth:
|
||||
async def test_auth_returns_true(self, mock_dl):
|
||||
result = await mock_dl.auth()
|
||||
assert result is True
|
||||
|
||||
async def test_auth_retry_param_accepted(self, mock_dl):
|
||||
result = await mock_dl.auth(retry=5)
|
||||
assert result is True
|
||||
|
||||
async def test_logout_does_not_raise(self, mock_dl):
|
||||
await mock_dl.logout()
|
||||
|
||||
async def test_check_host_returns_true(self, mock_dl):
|
||||
result = await mock_dl.check_host()
|
||||
assert result is True
|
||||
|
||||
async def test_check_connection_returns_version_string(self, mock_dl):
|
||||
result = await mock_dl.check_connection()
|
||||
assert "mock" in result.lower()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Prefs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMockDownloaderPrefs:
|
||||
async def test_prefs_init_updates_prefs(self, mock_dl):
|
||||
"""prefs_init merges given prefs into the internal store."""
|
||||
await mock_dl.prefs_init({"rss_refresh_interval": 60, "custom_key": "val"})
|
||||
assert mock_dl._prefs["rss_refresh_interval"] == 60
|
||||
assert mock_dl._prefs["custom_key"] == "val"
|
||||
|
||||
async def test_get_app_prefs_returns_dict(self, mock_dl):
|
||||
result = await mock_dl.get_app_prefs()
|
||||
assert isinstance(result, dict)
|
||||
assert "save_path" in result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Categories
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMockDownloaderCategories:
|
||||
async def test_add_category_persists(self, mock_dl):
|
||||
await mock_dl.add_category("NewCategory")
|
||||
state = mock_dl.get_state()
|
||||
assert "NewCategory" in state["categories"]
|
||||
|
||||
async def test_add_duplicate_category_no_error(self, mock_dl):
|
||||
await mock_dl.add_category("Bangumi")
|
||||
state = mock_dl.get_state()
|
||||
# Still only one entry for Bangumi (set semantics)
|
||||
assert state["categories"].count("Bangumi") == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Torrent management
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMockDownloaderAddTorrents:
|
||||
async def test_add_torrent_url_returns_true(self, mock_dl):
|
||||
result = await mock_dl.add_torrents(
|
||||
torrent_urls="magnet:?xt=urn:btih:abc",
|
||||
torrent_files=None,
|
||||
save_path="/downloads/Bangumi",
|
||||
category="Bangumi",
|
||||
)
|
||||
assert result is True
|
||||
|
||||
async def test_add_torrent_stores_in_state(self, mock_dl):
|
||||
await mock_dl.add_torrents(
|
||||
torrent_urls="magnet:?xt=urn:btih:abc",
|
||||
torrent_files=None,
|
||||
save_path="/downloads/Bangumi",
|
||||
category="Bangumi",
|
||||
)
|
||||
state = mock_dl.get_state()
|
||||
assert len(state["torrents"]) == 1
|
||||
|
||||
async def test_add_torrent_with_tag_stored(self, mock_dl):
|
||||
await mock_dl.add_torrents(
|
||||
torrent_urls="magnet:?xt=urn:btih:abc",
|
||||
torrent_files=None,
|
||||
save_path="/downloads/Bangumi",
|
||||
category="Bangumi",
|
||||
tags="ab:42",
|
||||
)
|
||||
state = mock_dl.get_state()
|
||||
torrent = list(state["torrents"].values())[0]
|
||||
assert torrent["tags"] == "ab:42"
|
||||
|
||||
async def test_add_torrent_with_file_bytes(self, mock_dl):
|
||||
result = await mock_dl.add_torrents(
|
||||
torrent_urls=None,
|
||||
torrent_files=b"\x00\x01\x02",
|
||||
save_path="/downloads",
|
||||
category="Bangumi",
|
||||
)
|
||||
assert result is True
|
||||
|
||||
async def test_two_different_torrents_stored_separately(self, mock_dl):
|
||||
await mock_dl.add_torrents(
|
||||
torrent_urls="magnet:?xt=urn:btih:aaa",
|
||||
torrent_files=None,
|
||||
save_path="/dl",
|
||||
category="Bangumi",
|
||||
)
|
||||
await mock_dl.add_torrents(
|
||||
torrent_urls="magnet:?xt=urn:btih:bbb",
|
||||
torrent_files=None,
|
||||
save_path="/dl",
|
||||
category="Bangumi",
|
||||
)
|
||||
state = mock_dl.get_state()
|
||||
assert len(state["torrents"]) == 2
|
||||
|
||||
|
||||
class TestMockDownloaderTorrentsInfo:
|
||||
async def test_returns_all_when_no_filter(self, mock_dl):
|
||||
mock_dl.add_mock_torrent("Anime A", category="Bangumi")
|
||||
mock_dl.add_mock_torrent("Anime B", category="Bangumi")
|
||||
result = await mock_dl.torrents_info(status_filter=None, category=None)
|
||||
assert len(result) == 2
|
||||
|
||||
async def test_filters_by_category(self, mock_dl):
|
||||
mock_dl.add_mock_torrent("Anime A", category="Bangumi")
|
||||
mock_dl.add_mock_torrent("Movie", category="BangumiCollection")
|
||||
result = await mock_dl.torrents_info(status_filter=None, category="Bangumi")
|
||||
assert len(result) == 1
|
||||
assert result[0]["name"] == "Anime A"
|
||||
|
||||
async def test_filters_by_tag(self, mock_dl):
|
||||
h1 = mock_dl.add_mock_torrent("Anime A", category="Bangumi")
|
||||
mock_dl.add_mock_torrent("Anime B", category="Bangumi")
|
||||
# Manually set the tag on first torrent
|
||||
mock_dl._torrents[h1]["tags"] = ["ab:1"]
|
||||
result = await mock_dl.torrents_info(
|
||||
status_filter=None, category=None, tag="ab:1"
|
||||
)
|
||||
assert len(result) == 1
|
||||
assert result[0]["name"] == "Anime A"
|
||||
|
||||
async def test_empty_store_returns_empty_list(self, mock_dl):
|
||||
result = await mock_dl.torrents_info(status_filter=None, category="Bangumi")
|
||||
assert result == []
|
||||
|
||||
|
||||
class TestMockDownloaderTorrentsFiles:
|
||||
async def test_returns_files_for_known_hash(self, mock_dl):
|
||||
files = [{"name": "ep01.mkv", "size": 500_000_000}]
|
||||
h = mock_dl.add_mock_torrent("Anime", files=files)
|
||||
result = await mock_dl.torrents_files(torrent_hash=h)
|
||||
assert len(result) == 1
|
||||
assert result[0]["name"] == "ep01.mkv"
|
||||
|
||||
async def test_returns_empty_list_for_unknown_hash(self, mock_dl):
|
||||
result = await mock_dl.torrents_files(torrent_hash="nonexistent")
|
||||
assert result == []
|
||||
|
||||
|
||||
class TestMockDownloaderDelete:
|
||||
async def test_delete_single_torrent(self, mock_dl):
|
||||
h = mock_dl.add_mock_torrent("Anime")
|
||||
await mock_dl.torrents_delete(hash=h)
|
||||
assert h not in mock_dl._torrents
|
||||
|
||||
async def test_delete_multiple_torrents_pipe_separated(self, mock_dl):
|
||||
h1 = mock_dl.add_mock_torrent("Anime A")
|
||||
h2 = mock_dl.add_mock_torrent("Anime B")
|
||||
await mock_dl.torrents_delete(hash=f"{h1}|{h2}")
|
||||
assert h1 not in mock_dl._torrents
|
||||
assert h2 not in mock_dl._torrents
|
||||
|
||||
async def test_delete_nonexistent_hash_no_error(self, mock_dl):
|
||||
await mock_dl.torrents_delete(hash="deadbeef")
|
||||
|
||||
|
||||
class TestMockDownloaderPauseResume:
|
||||
async def test_pause_sets_state(self, mock_dl):
|
||||
h = mock_dl.add_mock_torrent("Anime", state="downloading")
|
||||
await mock_dl.torrents_pause(hashes=h)
|
||||
assert mock_dl._torrents[h]["state"] == "paused"
|
||||
|
||||
async def test_resume_sets_state(self, mock_dl):
|
||||
h = mock_dl.add_mock_torrent("Anime", state="paused")
|
||||
await mock_dl.torrents_resume(hashes=h)
|
||||
assert mock_dl._torrents[h]["state"] == "downloading"
|
||||
|
||||
async def test_pause_multiple_pipe_separated(self, mock_dl):
|
||||
h1 = mock_dl.add_mock_torrent("Anime A", state="downloading")
|
||||
h2 = mock_dl.add_mock_torrent("Anime B", state="downloading")
|
||||
await mock_dl.torrents_pause(hashes=f"{h1}|{h2}")
|
||||
assert mock_dl._torrents[h1]["state"] == "paused"
|
||||
assert mock_dl._torrents[h2]["state"] == "paused"
|
||||
|
||||
async def test_pause_unknown_hash_no_error(self, mock_dl):
|
||||
await mock_dl.torrents_pause(hashes="deadbeef")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Rename
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMockDownloaderRename:
|
||||
async def test_rename_returns_true(self, mock_dl):
|
||||
result = await mock_dl.torrents_rename_file(
|
||||
torrent_hash="hash1",
|
||||
old_path="old.mkv",
|
||||
new_path="new.mkv",
|
||||
)
|
||||
assert result is True
|
||||
|
||||
async def test_rename_with_verify_flag(self, mock_dl):
|
||||
result = await mock_dl.torrents_rename_file(
|
||||
torrent_hash="hash1",
|
||||
old_path="old.mkv",
|
||||
new_path="new.mkv",
|
||||
verify=False,
|
||||
)
|
||||
assert result is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RSS feed management
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMockDownloaderRssFeeds:
|
||||
async def test_add_feed_stored(self, mock_dl):
|
||||
await mock_dl.rss_add_feed(url="https://mikan.me/RSS/test", item_path="Mikan")
|
||||
feeds = await mock_dl.rss_get_feeds()
|
||||
assert "Mikan" in feeds
|
||||
assert feeds["Mikan"]["url"] == "https://mikan.me/RSS/test"
|
||||
|
||||
async def test_remove_feed(self, mock_dl):
|
||||
await mock_dl.rss_add_feed(url="https://example.com", item_path="Feed1")
|
||||
await mock_dl.rss_remove_item(item_path="Feed1")
|
||||
feeds = await mock_dl.rss_get_feeds()
|
||||
assert "Feed1" not in feeds
|
||||
|
||||
async def test_remove_nonexistent_feed_no_error(self, mock_dl):
|
||||
await mock_dl.rss_remove_item(item_path="nonexistent")
|
||||
|
||||
async def test_get_feeds_initially_empty(self, mock_dl):
|
||||
feeds = await mock_dl.rss_get_feeds()
|
||||
assert feeds == {}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Rules
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMockDownloaderRules:
|
||||
async def test_set_rule_stored(self, mock_dl):
|
||||
rule_def = {"enable": True, "mustContain": "Anime"}
|
||||
await mock_dl.rss_set_rule("rule1", rule_def)
|
||||
rules = await mock_dl.get_download_rule()
|
||||
assert "rule1" in rules
|
||||
assert rules["rule1"]["mustContain"] == "Anime"
|
||||
|
||||
async def test_remove_rule(self, mock_dl):
|
||||
await mock_dl.rss_set_rule("rule1", {"enable": True})
|
||||
await mock_dl.remove_rule("rule1")
|
||||
rules = await mock_dl.get_download_rule()
|
||||
assert "rule1" not in rules
|
||||
|
||||
async def test_remove_nonexistent_rule_no_error(self, mock_dl):
|
||||
await mock_dl.remove_rule("nonexistent")
|
||||
|
||||
async def test_get_download_rule_initially_empty(self, mock_dl):
|
||||
rules = await mock_dl.get_download_rule()
|
||||
assert rules == {}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Move / path
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMockDownloaderMovePath:
|
||||
async def test_move_torrent_updates_save_path(self, mock_dl):
|
||||
h = mock_dl.add_mock_torrent("Anime", save_path="/old/path")
|
||||
await mock_dl.move_torrent(hashes=h, new_location="/new/path")
|
||||
assert mock_dl._torrents[h]["save_path"] == "/new/path"
|
||||
|
||||
async def test_move_multiple_pipe_separated(self, mock_dl):
|
||||
h1 = mock_dl.add_mock_torrent("Anime A", save_path="/old")
|
||||
h2 = mock_dl.add_mock_torrent("Anime B", save_path="/old")
|
||||
await mock_dl.move_torrent(hashes=f"{h1}|{h2}", new_location="/new")
|
||||
assert mock_dl._torrents[h1]["save_path"] == "/new"
|
||||
assert mock_dl._torrents[h2]["save_path"] == "/new"
|
||||
|
||||
async def test_get_torrent_path_known_hash(self, mock_dl):
|
||||
h = mock_dl.add_mock_torrent("Anime", save_path="/downloads/Bangumi")
|
||||
path = await mock_dl.get_torrent_path(h)
|
||||
assert path == "/downloads/Bangumi"
|
||||
|
||||
async def test_get_torrent_path_unknown_hash_returns_default(self, mock_dl):
|
||||
path = await mock_dl.get_torrent_path("nonexistent")
|
||||
assert path == "/tmp/mock-downloads"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Category assignment
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMockDownloaderSetCategory:
|
||||
async def test_set_category_updates_torrent(self, mock_dl):
|
||||
h = mock_dl.add_mock_torrent("Anime", category="Bangumi")
|
||||
await mock_dl.set_category(h, "BangumiCollection")
|
||||
assert mock_dl._torrents[h]["category"] == "BangumiCollection"
|
||||
|
||||
async def test_set_category_unknown_hash_no_error(self, mock_dl):
|
||||
await mock_dl.set_category("deadbeef", "Bangumi")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tags
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMockDownloaderTags:
|
||||
async def test_add_tag_appends(self, mock_dl):
|
||||
h = mock_dl.add_mock_torrent("Anime")
|
||||
await mock_dl.add_tag(h, "ab:1")
|
||||
assert "ab:1" in mock_dl._torrents[h]["tags"]
|
||||
|
||||
async def test_add_tag_no_duplicates(self, mock_dl):
|
||||
h = mock_dl.add_mock_torrent("Anime")
|
||||
await mock_dl.add_tag(h, "ab:1")
|
||||
await mock_dl.add_tag(h, "ab:1")
|
||||
assert mock_dl._torrents[h]["tags"].count("ab:1") == 1
|
||||
|
||||
async def test_add_tag_unknown_hash_no_error(self, mock_dl):
|
||||
await mock_dl.add_tag("deadbeef", "ab:1")
|
||||
|
||||
async def test_multiple_tags_on_same_torrent(self, mock_dl):
|
||||
h = mock_dl.add_mock_torrent("Anime")
|
||||
await mock_dl.add_tag(h, "ab:1")
|
||||
await mock_dl.add_tag(h, "group:sub")
|
||||
assert "ab:1" in mock_dl._torrents[h]["tags"]
|
||||
assert "group:sub" in mock_dl._torrents[h]["tags"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# add_mock_torrent helper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAddMockTorrentHelper:
|
||||
def test_generates_hash_from_name(self, mock_dl):
|
||||
h = mock_dl.add_mock_torrent("Anime")
|
||||
assert h is not None
|
||||
assert len(h) == 40 # SHA1 hex digest
|
||||
|
||||
def test_explicit_hash_used(self, mock_dl):
|
||||
h = mock_dl.add_mock_torrent("Anime", hash="cafebabe" + "0" * 32)
|
||||
assert h == "cafebabe" + "0" * 32
|
||||
|
||||
def test_torrent_state_is_completed_by_default(self, mock_dl):
|
||||
h = mock_dl.add_mock_torrent("Anime")
|
||||
assert mock_dl._torrents[h]["state"] == "completed"
|
||||
assert mock_dl._torrents[h]["progress"] == 1.0
|
||||
|
||||
def test_torrent_state_custom(self, mock_dl):
|
||||
h = mock_dl.add_mock_torrent("Anime", state="downloading")
|
||||
assert mock_dl._torrents[h]["state"] == "downloading"
|
||||
assert mock_dl._torrents[h]["progress"] == 0.5
|
||||
|
||||
def test_default_file_is_mkv(self, mock_dl):
|
||||
h = mock_dl.add_mock_torrent("My Anime")
|
||||
files = mock_dl._torrents[h]["files"]
|
||||
assert len(files) == 1
|
||||
assert files[0]["name"].endswith(".mkv")
|
||||
|
||||
def test_custom_files_stored(self, mock_dl):
|
||||
custom_files = [{"name": "ep01.mkv"}, {"name": "ep02.mkv"}]
|
||||
h = mock_dl.add_mock_torrent("Anime", files=custom_files)
|
||||
assert len(mock_dl._torrents[h]["files"]) == 2
|
||||
2
backend/uv.lock
generated
2
backend/uv.lock
generated
@@ -61,7 +61,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "auto-bangumi"
|
||||
version = "3.2.3b5"
|
||||
version = "3.2.3"
|
||||
source = { virtual = "." }
|
||||
dependencies = [
|
||||
{ name = "aiosqlite" },
|
||||
|
||||
68
webui/src/components/setting/config-security.vue
Normal file
68
webui/src/components/setting/config-security.vue
Normal file
@@ -0,0 +1,68 @@
|
||||
<script lang="ts" setup>
|
||||
import type { Security } from '#/config';
|
||||
import type { SettingItem } from '#/components';
|
||||
|
||||
const { t } = useMyI18n();
|
||||
const { getSettingGroup } = useConfigStore();
|
||||
|
||||
const security = getSettingGroup('security');
|
||||
|
||||
const items: SettingItem<Security>[] = [
|
||||
{
|
||||
configKey: 'login_whitelist',
|
||||
label: () => t('config.security_set.login_whitelist'),
|
||||
type: 'dynamic-tags',
|
||||
prop: {
|
||||
placeholder: '192.168.0.0/16',
|
||||
},
|
||||
},
|
||||
{
|
||||
configKey: 'login_tokens',
|
||||
label: () => t('config.security_set.login_tokens'),
|
||||
type: 'dynamic-tags',
|
||||
prop: {
|
||||
placeholder: 'your-api-token',
|
||||
},
|
||||
bottomLine: true,
|
||||
},
|
||||
{
|
||||
configKey: 'mcp_whitelist',
|
||||
label: () => t('config.security_set.mcp_whitelist'),
|
||||
type: 'dynamic-tags',
|
||||
prop: {
|
||||
placeholder: '127.0.0.0/8',
|
||||
},
|
||||
},
|
||||
{
|
||||
configKey: 'mcp_tokens',
|
||||
label: () => t('config.security_set.mcp_tokens'),
|
||||
type: 'dynamic-tags',
|
||||
prop: {
|
||||
placeholder: 'your-mcp-token',
|
||||
},
|
||||
},
|
||||
];
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<ab-fold-panel :title="$t('config.security_set.title')">
|
||||
<p class="hint-text">{{ $t('config.security_set.hint') }}</p>
|
||||
<div space-y-8>
|
||||
<ab-setting
|
||||
v-for="i in items"
|
||||
:key="i.configKey"
|
||||
v-bind="i"
|
||||
v-model:data="security[i.configKey]"
|
||||
></ab-setting>
|
||||
</div>
|
||||
</ab-fold-panel>
|
||||
</template>
|
||||
|
||||
<style lang="scss" scoped>
|
||||
.hint-text {
|
||||
font-size: 12px;
|
||||
color: var(--color-text-secondary);
|
||||
margin-bottom: 12px;
|
||||
line-height: 1.5;
|
||||
}
|
||||
</style>
|
||||
@@ -49,6 +49,7 @@ onActivated(() => {
|
||||
<config-player></config-player>
|
||||
<config-openai></config-openai>
|
||||
<config-passkey></config-passkey>
|
||||
<config-security></config-security>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
|
||||
@@ -2,10 +2,6 @@ import type { TupleToUnion } from './utils';
|
||||
|
||||
/** 下载方式 */
|
||||
export type DownloaderType = ['qbittorrent'];
|
||||
/** rss parser 源 */
|
||||
export type RssParserType = ['mikan'];
|
||||
/** rss parser 方法 */
|
||||
export type RssParserMethodType = ['tmdb', 'mikan', 'parser'];
|
||||
/** rss parser 语言 */
|
||||
export type RssParserLang = ['zh', 'en', 'jp'];
|
||||
/** 重命名方式 */
|
||||
@@ -44,12 +40,8 @@ export interface Downloader {
|
||||
}
|
||||
export interface RssParser {
|
||||
enable: boolean;
|
||||
type: TupleToUnion<RssParserType>;
|
||||
token: string;
|
||||
custom_url: string;
|
||||
filter: Array<string>;
|
||||
language: TupleToUnion<RssParserLang>;
|
||||
parser_type: TupleToUnion<RssParserMethodType>;
|
||||
}
|
||||
export interface BangumiManage {
|
||||
enable: boolean;
|
||||
@@ -105,6 +97,17 @@ export interface ExperimentalOpenAI {
|
||||
deployment_id?: string;
|
||||
}
|
||||
|
||||
/** Access control for the login endpoint and MCP server.
|
||||
* Whitelist entries are IPv4/IPv6 CIDR strings (e.g. "192.168.0.0/16").
|
||||
* An empty login_whitelist allows all IPs; an empty mcp_whitelist denies all IP-based MCP access.
|
||||
*/
|
||||
export interface Security {
|
||||
login_whitelist: string[];
|
||||
login_tokens: string[];
|
||||
mcp_whitelist: string[];
|
||||
mcp_tokens: string[];
|
||||
}
|
||||
|
||||
export interface Config {
|
||||
program: Program;
|
||||
downloader: Downloader;
|
||||
@@ -114,6 +117,7 @@ export interface Config {
|
||||
proxy: Proxy;
|
||||
notification: Notification;
|
||||
experimental_openai: ExperimentalOpenAI;
|
||||
security: Security;
|
||||
}
|
||||
|
||||
export const initConfig: Config = {
|
||||
@@ -132,12 +136,8 @@ export const initConfig: Config = {
|
||||
},
|
||||
rss_parser: {
|
||||
enable: true,
|
||||
type: 'mikan',
|
||||
token: '',
|
||||
custom_url: '',
|
||||
filter: [],
|
||||
language: 'zh',
|
||||
parser_type: 'parser',
|
||||
},
|
||||
bangumi_manage: {
|
||||
enable: true,
|
||||
@@ -171,4 +171,10 @@ export const initConfig: Config = {
|
||||
api_version: '2020-05-03',
|
||||
deployment_id: '',
|
||||
},
|
||||
security: {
|
||||
login_whitelist: [],
|
||||
login_tokens: [],
|
||||
mcp_whitelist: [],
|
||||
mcp_tokens: [],
|
||||
},
|
||||
};
|
||||
|
||||
1
webui/types/dts/auto-imports.d.ts
vendored
1
webui/types/dts/auto-imports.d.ts
vendored
@@ -40,6 +40,7 @@ declare global {
|
||||
const getCurrentInstance: typeof import('vue')['getCurrentInstance']
|
||||
const getCurrentScope: typeof import('vue')['getCurrentScope']
|
||||
const h: typeof import('vue')['h']
|
||||
const i18n: typeof import('../../src/hooks/useMyI18n')['i18n']
|
||||
const inject: typeof import('vue')['inject']
|
||||
const isProxy: typeof import('vue')['isProxy']
|
||||
const isReactive: typeof import('vue')['isReactive']
|
||||
|
||||
1
webui/types/dts/components.d.ts
vendored
1
webui/types/dts/components.d.ts
vendored
@@ -55,6 +55,7 @@ declare module '@vue/runtime-core' {
|
||||
ConfigPlayer: typeof import('./../../src/components/setting/config-player.vue')['default']
|
||||
ConfigProxy: typeof import('./../../src/components/setting/config-proxy.vue')['default']
|
||||
ConfigSearchProvider: typeof import('./../../src/components/setting/config-search-provider.vue')['default']
|
||||
ConfigSecurity: typeof import('./../../src/components/setting/config-security.vue')['default']
|
||||
MediaQuery: typeof import('./../../src/components/media-query.vue')['default']
|
||||
RouterLink: typeof import('vue-router')['RouterLink']
|
||||
RouterView: typeof import('vue-router')['RouterView']
|
||||
|
||||
Reference in New Issue
Block a user