feat(security): add security config UI and improve auth/MCP security

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Estrella Pan
2026-02-23 17:18:23 +01:00
parent a0bf878b7e
commit b57d3c49ae
25 changed files with 1621 additions and 302 deletions

View File

@@ -1,3 +1,41 @@
# [Unreleased]
## Backend
### Added
- 新增 `Security` 配置模型,支持登录 IP 白名单、MCP IP 白名单和 Bearer Token 认证
- 新增登录端点 IP 白名单检查中间件 (`check_login_ip`)
- MCP 安全中间件升级为可配置模式:支持 CIDR 白名单 + Bearer Token 双重认证
- 认证端点支持 `Authorization: Bearer` 令牌绕过 Cookie 登录
- 配置 API `_sanitize_dict` 修复:仅对字符串值进行脱敏,避免误处理非字符串字段
- 新增番剧放送日手动设置 API (`PATCH /api/v1/bangumi/{id}/weekday`),支持锁定放送日防止日历刷新覆盖
- 数据库迁移 v9:`bangumi` 表新增 `weekday_locked` 字段
### Changed
- 重构认证模块:提取 `_issue_token` 公共方法,消除 3 处重复的 JWT 签发逻辑
- `get_current_user` 简化为三级认证:DEV 绕过 → Bearer Token → Cookie JWT
- `LocalNetworkMiddleware` 重命名为 `McpAccessMiddleware`,从硬编码 RFC 1918 改为读取配置
### Tests
- 新增 101 个单元测试覆盖安全、认证、配置、下载器和 MockDownloader 模块
## Frontend
### Added
- 新增日历拖拽排列功能:可将「未知」番剧拖入星期列,自动设置放送日并锁定
- 拖入后显示紫色图钉图标,鼠标悬停显示取消按钮
- 锁定的番剧在日历刷新时不会被覆盖
- 使用 vuedraggable 实现流畅拖拽动画
- 新增安全设置组件 (`config-security.vue`),支持在 WebUI 中配置 IP 白名单和 Token
- 前端 `Security` 类型定义和初始化配置
---
# [3.2.3] - 2026-02-23
## Backend

View File

@@ -1,6 +1,6 @@
[project]
name = "auto-bangumi"
version = "3.2.3"
version = "3.2.4"
description = "AutoBangumi - Automated anime download manager"
requires-python = ">=3.13"
dependencies = [

View File

@@ -9,6 +9,7 @@ from module.models.user import User, UserUpdate
from module.security.api import (
active_user,
auth_user,
check_login_ip,
get_current_user,
update_user_info,
)
@@ -18,17 +19,26 @@ from .response import u_response
router = APIRouter(prefix="/auth", tags=["auth"])
_TOKEN_EXPIRY_DAYS = 1
_TOKEN_MAX_AGE = 86400
@router.post("/login", response_model=dict)
def _issue_token(username: str, response: Response) -> dict:
    """Sign a JWT for *username*, attach it as an HttpOnly cookie, and return the bearer payload."""
    expiry = timedelta(days=_TOKEN_EXPIRY_DAYS)
    access_token = create_access_token(data={"sub": username}, expires_delta=expiry)
    # Cookie and payload carry the same token so both cookie-based and
    # Authorization-header clients can authenticate.
    response.set_cookie(
        key="token",
        value=access_token,
        httponly=True,
        max_age=_TOKEN_MAX_AGE,
    )
    return {"access_token": access_token, "token_type": "bearer"}
@router.post("/login", response_model=dict, dependencies=[Depends(check_login_ip)])
async def login(response: Response, form_data=Depends(OAuth2PasswordRequestForm)):
"""Authenticate with username/password and issue a session token."""
user = User(username=form_data.username, password=form_data.password)
resp = auth_user(user)
if resp.status:
token = create_access_token(
data={"sub": user.username}, expires_delta=timedelta(days=1)
)
response.set_cookie(key="token", value=token, httponly=True, max_age=86400)
return {"access_token": token, "token_type": "bearer"}
return _issue_token(user.username, response)
return u_response(resp)
@@ -36,6 +46,7 @@ async def login(response: Response, form_data=Depends(OAuth2PasswordRequestForm)
"/refresh_token", response_model=dict, dependencies=[Depends(get_current_user)]
)
async def refresh(response: Response, token: str = Cookie(None)):
"""Refresh the current session token and update the active-user timestamp."""
payload = decode_token(token)
username = payload.get("sub") if payload else None
if not username:
@@ -43,17 +54,14 @@ async def refresh(response: Response, token: str = Cookie(None)):
status_code=status.HTTP_401_UNAUTHORIZED, detail="Unauthorized"
)
active_user[username] = datetime.now()
new_token = create_access_token(
data={"sub": username}, expires_delta=timedelta(days=1)
)
response.set_cookie(key="token", value=new_token, httponly=True, max_age=86400)
return {"access_token": new_token, "token_type": "bearer"}
return _issue_token(username, response)
@router.get(
"/logout", response_model=APIResponse, dependencies=[Depends(get_current_user)]
)
async def logout(response: Response, token: str = Cookie(None)):
"""Invalidate the session and clear the token cookie."""
payload = decode_token(token)
username = payload.get("sub") if payload else None
if username:
@@ -69,6 +77,7 @@ async def logout(response: Response, token: str = Cookie(None)):
async def update_user(
user_data: UserUpdate, response: Response, token: str = Cookie(None)
):
"""Update credentials for the current user and re-issue a fresh token."""
payload = decode_token(token)
old_user = payload.get("sub") if payload else None
if not old_user:
@@ -76,17 +85,4 @@ async def update_user(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Unauthorized"
)
if update_user_info(user_data, old_user):
token = create_access_token(
data={"sub": old_user}, expires_delta=timedelta(days=1)
)
response.set_cookie(
key="token",
value=token,
httponly=True,
max_age=86400,
)
return {
"access_token": token,
"token_type": "bearer",
"message": "update success",
}
return {**_issue_token(old_user, response), "message": "update success"}

View File

@@ -14,11 +14,12 @@ _SENSITIVE_KEYS = ("password", "api_key", "token", "secret")
def _sanitize_dict(d: dict) -> dict:
"""Recursively mask string values whose keys contain sensitive keywords."""
result = {}
for k, v in d.items():
if isinstance(v, dict):
result[k] = _sanitize_dict(v)
elif any(s in k.lower() for s in _SENSITIVE_KEYS):
elif isinstance(v, str) and any(s in k.lower() for s in _SENSITIVE_KEYS):
result[k] = "********"
else:
result[k] = v
@@ -27,6 +28,7 @@ def _sanitize_dict(d: dict) -> dict:
@router.get("/get", dependencies=[Depends(get_current_user)])
async def get_config():
"""Return the current configuration with sensitive fields masked."""
return _sanitize_dict(settings.dict())
@@ -34,6 +36,7 @@ async def get_config():
"/update", response_model=APIResponse, dependencies=[Depends(get_current_user)]
)
async def update_config(config: Config):
"""Persist and reload configuration from the supplied payload."""
try:
settings.save(config_dict=config.dict())
settings.load()

View File

@@ -7,7 +7,7 @@ from dotenv import load_dotenv
from module.models.config import Config
from .const import ENV_TO_ATTR
from .const import DEFAULT_SETTINGS, ENV_TO_ATTR
logger = logging.getLogger(__name__)
CONFIG_ROOT = Path("config")
@@ -27,6 +27,15 @@ CONFIG_PATH = (
class Settings(Config):
"""Runtime configuration singleton.
On construction, loads from ``CONFIG_PATH`` if the file exists (and
immediately re-saves to apply any migrations), otherwise bootstraps
defaults from environment variables via ``init()``.
Use ``settings`` module-level instance rather than instantiating directly.
"""
def __init__(self):
super().__init__()
if CONFIG_PATH.exists():
@@ -36,6 +45,7 @@ class Settings(Config):
self.init()
def load(self):
"""Load and validate configuration from ``CONFIG_PATH``, applying migrations."""
with open(CONFIG_PATH, "r", encoding="utf-8") as f:
config = json.load(f)
config = self._migrate_old_config(config)
@@ -65,20 +75,27 @@ class Settings(Config):
for key in ("type", "custom_url", "token", "enable_tmdb"):
rss_parser.pop(key, None)
# Add security section if missing (preserves local-network MCP default)
if "security" not in config:
config["security"] = DEFAULT_SETTINGS["security"]
return config
def save(self, config_dict: dict | None = None):
"""Write configuration to ``CONFIG_PATH``. Uses current state when no dict supplied."""
if not config_dict:
config_dict = self.model_dump()
with open(CONFIG_PATH, "w", encoding="utf-8") as f:
json.dump(config_dict, f, indent=4, ensure_ascii=False)
def init(self):
"""Bootstrap a new config file from ``.env`` and environment variables."""
load_dotenv(".env")
self.__load_from_env()
self.save()
def __load_from_env(self):
"""Apply ``ENV_TO_ATTR`` mappings from the process environment to the config dict."""
config_dict = self.model_dump()
for key, section in ENV_TO_ATTR.items():
for env, attr in section.items():
@@ -97,12 +114,11 @@ class Settings(Config):
logger.info("Config loaded from env")
@staticmethod
def __val_from_env(env: str, attr: tuple):
def __val_from_env(env: str, attr: tuple | str):
"""Return the environment variable value, applying the converter when attr is a tuple."""
if isinstance(attr, tuple):
conv_func = attr[1]
return conv_func(os.environ[env])
else:
return os.environ[env]
return attr[1](os.environ[env])
return os.environ[env]
@property
def group_rules(self):

View File

@@ -1,4 +1,8 @@
# -*- encoding: utf-8 -*-
# DEFAULT_SETTINGS: factory defaults written to config.json on first run.
# ENV_TO_ATTR: maps AB_* environment variables to Config model attribute paths.
# Values are either a string attr name, a (attr_name, converter) tuple, or a
# list of such tuples when a single env var sets multiple attributes.
DEFAULT_SETTINGS = {
"program": {
"rss_time": 900,
@@ -46,6 +50,20 @@ DEFAULT_SETTINGS = {
"model": "gpt-3.5-turbo",
"deployment_id": "",
},
"security": {
"login_whitelist": [],
"login_tokens": [],
"mcp_whitelist": [
"127.0.0.0/8",
"10.0.0.0/8",
"172.16.0.0/12",
"192.168.0.0/16",
"::1/128",
"fe80::/10",
"fc00::/7",
],
"mcp_tokens": [],
},
}
@@ -99,8 +117,11 @@ ENV_TO_ATTR = {
class BCOLORS:
"""ANSI colour helpers for terminal output."""
@staticmethod
def _(color: str, *args: str) -> str:
"""Wrap *args* in the given ANSI colour code and reset at the end."""
strings = [str(s) for s in args]
return f"{color}{', '.join(strings)}{BCOLORS.ENDC}"

View File

@@ -29,10 +29,10 @@ class MockDownloader:
"rss_processing_enabled": True,
"rss_refresh_interval": 30,
}
logger.info("[MockDownloader] Initialized")
logger.debug("[MockDownloader] Initialized")
async def auth(self, retry=3) -> bool:
logger.info("[MockDownloader] Auth successful (mocked)")
logger.debug("[MockDownloader] Auth successful (mocked)")
return True
async def logout(self):

View File

@@ -11,6 +11,13 @@ logger = logging.getLogger(__name__)
class DownloadClient(TorrentPath):
"""Unified async download client.
Wraps qBittorrent, Aria2, or MockDownloader behind a common interface.
Intended to be used as an async context manager; authentication is
performed on ``__aenter__`` and the session is closed on ``__aexit__``.
"""
def __init__(self):
super().__init__()
self.client = self.__getClient()
@@ -18,27 +25,28 @@ class DownloadClient(TorrentPath):
@staticmethod
def __getClient():
type = settings.downloader.type
"""Instantiate the configured downloader client (qbittorrent | aria2 | mock)."""
downloader_type = settings.downloader.type
host = settings.downloader.host
username = settings.downloader.username
password = settings.downloader.password
ssl = settings.downloader.ssl
if type == "qbittorrent":
if downloader_type == "qbittorrent":
from .client.qb_downloader import QbDownloader
return QbDownloader(host, username, password, ssl)
elif type == "aria2":
elif downloader_type == "aria2":
from .client.aria2_downloader import Aria2Downloader
return Aria2Downloader(host, username, password)
elif type == "mock":
elif downloader_type == "mock":
from .client.mock_downloader import MockDownloader
logger.info("[Downloader] Using MockDownloader for local development")
logger.debug("[Downloader] Using MockDownloader for local development")
return MockDownloader()
else:
logger.error(f"[Downloader] Unsupported downloader type: {type}")
raise Exception(f"Unsupported downloader type: {type}")
logger.error("[Downloader] Unsupported downloader type: %s", downloader_type)
raise Exception(f"Unsupported downloader type: {downloader_type}")
async def __aenter__(self):
if not self.authed:
@@ -65,6 +73,7 @@ class DownloadClient(TorrentPath):
return await self.client.check_host()
async def init_downloader(self):
"""Apply required qBittorrent RSS preferences and create the Bangumi category."""
prefs = {
"rss_auto_downloading_enabled": True,
"rss_max_articles_per_feed": 500,
@@ -84,6 +93,7 @@ class DownloadClient(TorrentPath):
settings.downloader.path = self._join_path(prefs["save_path"], "Bangumi")
async def set_rule(self, data: Bangumi):
"""Create or update a qBittorrent RSS auto-download rule for one bangumi entry."""
data.rule_name = self._rule_name(data)
data.save_path = self._gen_save_path(data)
rule = {
@@ -145,6 +155,12 @@ class DownloadClient(TorrentPath):
await self.client.torrents_resume(hashes)
async def add_torrent(self, torrent: Torrent | list, bangumi: Bangumi) -> bool:
"""Download a torrent (or list of torrents) for the given bangumi entry.
Handles both magnet links and .torrent file URLs, fetching file bytes
when necessary. Tags each torrent with ``ab:<bangumi_id>`` for later
episode-offset lookup during rename.
"""
if not bangumi.save_path:
bangumi.save_path = self._gen_save_path(bangumi)
async with RequestContent() as req:

View File

@@ -1,48 +1,68 @@
"""MCP access control: restricts connections to local network addresses only."""
"""MCP access control: configurable IP whitelist and bearer token authentication."""
import ipaddress
import logging
from functools import lru_cache
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse
from module.conf import settings
logger = logging.getLogger(__name__)
# RFC 1918 private ranges + loopback + IPv6 equivalents
_ALLOWED_NETWORKS = [
ipaddress.ip_network("127.0.0.0/8"),
ipaddress.ip_network("10.0.0.0/8"),
ipaddress.ip_network("172.16.0.0/12"),
ipaddress.ip_network("192.168.0.0/16"),
ipaddress.ip_network("::1/128"),
ipaddress.ip_network("fe80::/10"),
ipaddress.ip_network("fc00::/7"),
]
@lru_cache(maxsize=128)
def _parse_network(cidr: str) -> ipaddress.IPv4Network | ipaddress.IPv6Network | None:
try:
return ipaddress.ip_network(cidr, strict=False)
except ValueError:
logger.warning("[MCP] Invalid CIDR in whitelist: %s", cidr)
return None
def _is_local(host: str) -> bool:
"""Return True if *host* is a loopback or RFC 1918 private address."""
def _is_allowed(host: str, whitelist: list[str]) -> bool:
"""Return True if *host* falls within any CIDR range in *whitelist*."""
try:
addr = ipaddress.ip_address(host)
except ValueError:
return False
return any(addr in net for net in _ALLOWED_NETWORKS)
for cidr in whitelist:
net = _parse_network(cidr)
if net and addr in net:
return True
return False
class LocalNetworkMiddleware(BaseHTTPMiddleware):
"""Starlette middleware that blocks requests from non-local IP addresses.
def clear_network_cache():
    """Clear the parsed network cache (call after config reload)."""
    # _parse_network memoizes CIDR parses via lru_cache; a config reload may
    # change the whitelist, so drop the cache to avoid serving stale networks.
    _parse_network.cache_clear()
Returns HTTP 403 for any client outside loopback, RFC 1918, or IPv6
link-local/unique-local ranges.
class McpAccessMiddleware(BaseHTTPMiddleware):
"""Configurable access control for MCP endpoint.
Checks client IP against ``settings.security.mcp_whitelist`` CIDR ranges,
and ``Authorization`` header against ``settings.security.mcp_tokens``.
If the whitelist is empty and no tokens are configured, all access is denied.
"""
async def dispatch(self, request: Request, call_next):
# Check bearer token first
auth_header = request.headers.get("authorization", "")
if auth_header.startswith("Bearer "):
token = auth_header[7:]
if token and token in settings.security.mcp_tokens:
return await call_next(request)
# Check IP whitelist
client_host = request.client.host if request.client else None
if not client_host or not _is_local(client_host):
logger.warning("[MCP] Rejected non-local connection from %s", client_host)
return JSONResponse(
status_code=403,
content={"error": "MCP access is restricted to local network"},
)
return await call_next(request)
if client_host and _is_allowed(client_host, settings.security.mcp_whitelist):
return await call_next(request)
logger.warning("[MCP] Rejected connection from %s", client_host)
return JSONResponse(
status_code=403,
content={"error": "MCP access denied"},
)

View File

@@ -18,7 +18,7 @@ from starlette.requests import Request
from starlette.routing import Mount, Route
from .resources import RESOURCE_TEMPLATES, RESOURCES, handle_resource
from .security import LocalNetworkMiddleware
from .security import McpAccessMiddleware
from .tools import TOOLS, handle_tool
logger = logging.getLogger(__name__)
@@ -73,8 +73,8 @@ def create_mcp_starlette_app() -> Starlette:
- ``GET /sse`` - SSE stream for MCP clients
- ``POST /messages/`` - client-to-server message posting
``LocalNetworkMiddleware`` is applied so the endpoint is only reachable
from loopback and RFC 1918 addresses.
``McpAccessMiddleware`` is applied to enforce configurable IP whitelist
and bearer token access control.
"""
app = Starlette(
routes=[
@@ -82,5 +82,5 @@ def create_mcp_starlette_app() -> Starlette:
Mount("/messages", app=sse.handle_post_message),
],
)
app.add_middleware(LocalNetworkMiddleware)
app.add_middleware(McpAccessMiddleware)
return app

View File

@@ -4,13 +4,27 @@ from typing import Literal, Optional
from pydantic import BaseModel, Field, field_validator, model_validator
def _expand(value: str | None) -> str:
"""Expand shell environment variables in *value*, returning empty string for None."""
return expandvars(value) if value else ""
class Program(BaseModel):
    """Scheduler timing and WebUI port settings."""

    # Sleep interval between RSS refresh cycles (presumably seconds — confirm against scheduler).
    rss_time: int = Field(900, description="Sleep time")
    # Number of rename passes performed within one loop iteration.
    rename_time: int = Field(60, description="Rename times in one loop")
    # TCP port the WebUI server binds to.
    webui_port: int = Field(7892, description="WebUI port")
class Downloader(BaseModel):
"""Download client connection settings.
Credential fields (``host``, ``username``, ``password``) are stored with a
trailing underscore and exposed via properties that expand ``$VAR``
environment variable references at access time.
"""
type: str = Field("qbittorrent", description="Downloader type")
host_: str = Field("172.17.0.1:8080", alias="host", description="Downloader host")
username_: str = Field("admin", alias="username", description="Downloader username")
@@ -22,24 +36,28 @@ class Downloader(BaseModel):
@property
def host(self):
return expandvars(self.host_)
return _expand(self.host_)
@property
def username(self):
return expandvars(self.username_)
return _expand(self.username_)
@property
def password(self):
return expandvars(self.password_)
return _expand(self.password_)
class RSSParser(BaseModel):
"""RSS feed parsing settings."""
enable: bool = Field(True, description="Enable RSS parser")
filter: list[str] = Field(["720", r"\d+-\d"], description="Filter")
language: str = "zh"
class BangumiManage(BaseModel):
"""File organisation and renaming settings."""
enable: bool = Field(True, description="Enable bangumi manage")
eps_complete: bool = Field(False, description="Enable eps complete")
rename_method: str = Field("pn", description="Rename method")
@@ -48,10 +66,14 @@ class BangumiManage(BaseModel):
class Log(BaseModel):
"""Logging verbosity settings."""
debug_enable: bool = Field(False, description="Enable debug")
class Proxy(BaseModel):
"""HTTP/SOCKS proxy settings. Credentials support ``$VAR`` expansion."""
enable: bool = Field(False, description="Enable proxy")
type: str = Field("http", description="Proxy type")
host: str = Field("", description="Proxy host")
@@ -61,11 +83,11 @@ class Proxy(BaseModel):
@property
def username(self):
return expandvars(self.username_)
return _expand(self.username_)
@property
def password(self):
return expandvars(self.password_)
return _expand(self.password_)
class NotificationProvider(BaseModel):
@@ -103,35 +125,35 @@ class NotificationProvider(BaseModel):
@property
def token(self) -> str:
return expandvars(self.token_) if self.token_ else ""
return _expand(self.token_)
@property
def chat_id(self) -> str:
return expandvars(self.chat_id_) if self.chat_id_ else ""
return _expand(self.chat_id_)
@property
def webhook_url(self) -> str:
return expandvars(self.webhook_url_) if self.webhook_url_ else ""
return _expand(self.webhook_url_)
@property
def server_url(self) -> str:
return expandvars(self.server_url_) if self.server_url_ else ""
return _expand(self.server_url_)
@property
def device_key(self) -> str:
return expandvars(self.device_key_) if self.device_key_ else ""
return _expand(self.device_key_)
@property
def user_key(self) -> str:
return expandvars(self.user_key_) if self.user_key_ else ""
return _expand(self.user_key_)
@property
def api_token(self) -> str:
return expandvars(self.api_token_) if self.api_token_ else ""
return _expand(self.api_token_)
@property
def url(self) -> str:
return expandvars(self.url_) if self.url_ else ""
return _expand(self.url_)
class Notification(BaseModel):
@@ -149,11 +171,11 @@ class Notification(BaseModel):
@property
def token(self) -> str:
return expandvars(self.token_) if self.token_ else ""
return _expand(self.token_)
@property
def chat_id(self) -> str:
return expandvars(self.chat_id_) if self.chat_id_ else ""
return _expand(self.chat_id_)
@model_validator(mode="after")
def migrate_legacy_config(self) -> "Notification":
@@ -197,7 +219,35 @@ class ExperimentalOpenAI(BaseModel):
return value
class Security(BaseModel):
    """Access control configuration for the login endpoint and MCP server.

    Both ``login_whitelist`` and ``mcp_whitelist`` accept IPv4/IPv6 CIDR ranges.
    An empty ``login_whitelist`` allows all IPs; an empty ``mcp_whitelist``
    denies all IP-based access (tokens still work).
    """

    # CIDR ranges permitted to reach the login endpoint; empty disables the check.
    login_whitelist: list[str] = Field(
        default_factory=list,
        description="IP/CIDR whitelist for login access. Empty = allow all.",
    )
    # Static bearer tokens accepted instead of a session cookie.
    login_tokens: list[str] = Field(
        default_factory=list,
        description="API bearer tokens that bypass login authentication.",
    )
    # CIDR ranges permitted to reach the MCP endpoint; empty denies all IP-based access.
    mcp_whitelist: list[str] = Field(
        default_factory=list,
        description="IP/CIDR whitelist for MCP access. Empty = deny all.",
    )
    # Static bearer tokens granting MCP access regardless of client IP.
    mcp_tokens: list[str] = Field(
        default_factory=list,
        description="API bearer tokens for MCP access.",
    )
class Config(BaseModel):
"""Root configuration model composed of all subsection models."""
program: Program = Program()
downloader: Downloader = Downloader()
rss_parser: RSSParser = RSSParser()
@@ -206,6 +256,7 @@ class Config(BaseModel):
proxy: Proxy = Proxy()
notification: Notification = Notification()
experimental_openai: ExperimentalOpenAI = ExperimentalOpenAI()
security: Security = Security()
def model_dump(self, *args, by_alias=True, **kwargs):
return super().model_dump(*args, by_alias=by_alias, **kwargs)

View File

@@ -1,9 +1,11 @@
from datetime import datetime
from fastapi import Cookie, Depends, HTTPException, status
from fastapi import Cookie, Depends, HTTPException, Request, status
from fastapi.security import OAuth2PasswordBearer
from module.conf import settings
from module.database import Database
from module.mcp.security import _is_allowed
from module.models.user import User, UserUpdate
from .jwt import verify_token
@@ -20,23 +22,49 @@ except ImportError:
DEV_AUTH_BYPASS = VERSION == "DEV_VERSION"
async def get_current_user(token: str = Cookie(None)):
def check_login_ip(request: Request):
    """FastAPI dependency enforcing the login IP whitelist.

    An empty ``settings.security.login_whitelist`` disables the check entirely.
    Raises HTTP 403 when the client address cannot be determined or falls
    outside every whitelisted CIDR range.
    """
    allowed_ranges = settings.security.login_whitelist
    if allowed_ranges:
        client = request.client
        host = client.host if client else None
        if host and _is_allowed(host, allowed_ranges):
            return
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="IP not in login whitelist",
        )
async def get_current_user(request: Request, token: str = Cookie(None)):
"""FastAPI dependency that validates the current session.
Accepts authentication via (in order of precedence):
1. DEV_AUTH_BYPASS when running as DEV_VERSION.
2. ``Authorization: Bearer <token>`` header matching ``login_tokens``.
3. HttpOnly ``token`` cookie containing a valid JWT with an active session.
"""
if DEV_AUTH_BYPASS:
return "dev_user"
# Check bearer token bypass
auth_header = request.headers.get("authorization", "")
if auth_header.startswith("Bearer "):
api_token = auth_header[7:]
if api_token and api_token in settings.security.login_tokens:
return "api_token_user"
if not token:
raise UNAUTHORIZED
payload = verify_token(token)
if not payload:
raise UNAUTHORIZED
username = payload.get("sub")
if not username:
raise UNAUTHORIZED
if username not in active_user:
username = payload.get("sub") if payload else None
if not username or username not in active_user:
raise UNAUTHORIZED
return username
async def get_token_data(token: str = Depends(oauth2_scheme)):
"""FastAPI dependency that decodes and returns the OAuth2 bearer token payload."""
payload = verify_token(token)
if not payload:
raise HTTPException(
@@ -46,6 +74,7 @@ async def get_token_data(token: str = Depends(oauth2_scheme)):
def update_user_info(user_data: UserUpdate, current_user):
"""Persist updated credentials for *current_user* to the database."""
try:
with Database() as db:
db.user.update_user(current_user, user_data)
@@ -55,6 +84,7 @@ def update_user_info(user_data: UserUpdate, current_user):
def auth_user(user: User):
"""Verify credentials and register the user in ``active_user`` on success."""
with Database() as db:
resp = db.user.auth_user(user)
if resp.status:

View File

@@ -185,3 +185,98 @@ class TestUpdateCredentials:
except Exception:
# Expected - endpoint doesn't handle failure case properly
pass
# ---------------------------------------------------------------------------
# Refresh token: cookie-based username resolution
# ---------------------------------------------------------------------------
class TestRefreshTokenCookieBehavior:
    """Refresh token: cookie-based username resolution."""

    def test_refresh_with_no_cookie_raises_401(self, authed_client):
        """GET /refresh_token with missing token cookie raises 401."""
        # Override auth to allow route but provide no cookie token
        with patch("module.api.auth.decode_token", return_value=None):
            response = authed_client.get("/api/v1/auth/refresh_token")
            assert response.status_code == 401

    def test_refresh_with_valid_cookie_updates_active_user(self, authed_client):
        """GET /refresh_token updates the active user timestamp."""
        token = create_access_token(data={"sub": "testuser"})
        authed_client.cookies.set("token", token)
        active_users: dict = {}
        with patch("module.api.auth.active_user", active_users):
            response = authed_client.get("/api/v1/auth/refresh_token")
            assert response.status_code == 200
            # Refresh must re-register the user in the active session map.
            assert "testuser" in active_users

    def test_refresh_returns_new_token(self, authed_client):
        """GET /refresh_token issues a valid JWT with bearer type."""
        token = create_access_token(data={"sub": "testuser"})
        authed_client.cookies.set("token", token)
        with patch("module.api.auth.active_user", {}):
            response = authed_client.get("/api/v1/auth/refresh_token")
            assert response.status_code == 200
            data = response.json()
            assert data["token_type"] == "bearer"
            assert isinstance(data["access_token"], str)
            assert len(data["access_token"]) > 0
# ---------------------------------------------------------------------------
# Logout: per-user removal
# ---------------------------------------------------------------------------
class TestLogoutCookieBehavior:
    """Logout: per-user removal from the active session map."""

    def test_logout_removes_only_current_user(self, authed_client):
        """GET /logout removes the current user from active_user, not others."""
        token = create_access_token(data={"sub": "testuser"})
        authed_client.cookies.set("token", token)
        active_users = {
            "testuser": datetime.now(),
            "otheruser": datetime.now(),
        }
        with patch("module.api.auth.active_user", active_users):
            response = authed_client.get("/api/v1/auth/logout")
            assert response.status_code == 200
            # Only the session named in the cookie is invalidated.
            assert "testuser" not in active_users
            assert "otheruser" in active_users

    def test_logout_with_no_cookie_still_succeeds(self, authed_client):
        """GET /logout with no cookie clears nothing but returns success."""
        with patch("module.api.auth.decode_token", return_value=None):
            with patch("module.api.auth.active_user", {}):
                response = authed_client.get("/api/v1/auth/logout")
                assert response.status_code == 200
# ---------------------------------------------------------------------------
# Update: cookie-based user resolution
# ---------------------------------------------------------------------------
class TestUpdateCookieBehavior:
    """Update: cookie-based user resolution."""

    def test_update_with_no_cookie_raises_401(self, authed_client):
        """POST /auth/update with no cookie raises 401."""
        with patch("module.api.auth.decode_token", return_value=None):
            response = authed_client.post(
                "/api/v1/auth/update",
                json={"old_password": "old", "new_password": "new"},
            )
            assert response.status_code == 401

    def test_update_with_valid_cookie_succeeds(self, authed_client):
        """POST /auth/update resolves username from cookie and issues new token."""
        token = create_access_token(data={"sub": "testuser"})
        authed_client.cookies.set("token", token)
        with patch("module.api.auth.active_user", {"testuser": datetime.now()}):
            with patch("module.api.auth.update_user_info", return_value=True):
                response = authed_client.post(
                    "/api/v1/auth/update",
                    json={"old_password": "oldpass", "new_password": "newpass"},
                )
                assert response.status_code == 200
                data = response.json()
                assert "access_token" in data
                assert data["message"] == "update success"

View File

@@ -1,4 +1,4 @@
"""Tests for Config API endpoints."""
"""Tests for Config API endpoints and config sanitization."""
import pytest
from unittest.mock import patch, MagicMock
@@ -7,6 +7,7 @@ from fastapi import FastAPI
from fastapi.testclient import TestClient
from module.api import v1
from module.api.config import _sanitize_dict
from module.models.config import Config
from module.security.api import get_current_user
@@ -263,3 +264,99 @@ class TestUpdateConfig:
response = authed_client.patch("/api/v1/config/update", json=invalid_data)
assert response.status_code == 422
# ---------------------------------------------------------------------------
# _sanitize_dict unit tests
# ---------------------------------------------------------------------------
class TestSanitizeDict:
    """Unit tests for the _sanitize_dict masking helper."""

    def test_masks_password_key(self):
        """Keys containing 'password' are masked."""
        result = _sanitize_dict({"password": "secret"})
        assert result["password"] == "********"

    def test_masks_api_key(self):
        """Keys containing 'api_key' are masked."""
        result = _sanitize_dict({"api_key": "sk-abc123"})
        assert result["api_key"] == "********"

    def test_masks_token_key(self):
        """Keys containing 'token' are masked."""
        result = _sanitize_dict({"token": "bearer-xyz"})
        assert result["token"] == "********"

    def test_masks_secret_key(self):
        """Keys containing 'secret' are masked."""
        result = _sanitize_dict({"my_secret": "topsecret"})
        assert result["my_secret"] == "********"

    def test_case_insensitive_key_matching(self):
        """Sensitive key matching is case-insensitive."""
        result = _sanitize_dict({"API_KEY": "abc"})
        assert result["API_KEY"] == "********"

    def test_non_sensitive_keys_pass_through(self):
        """Non-sensitive keys are returned unchanged."""
        result = _sanitize_dict({"host": "localhost", "port": 8080, "enable": True})
        assert result["host"] == "localhost"
        assert result["port"] == 8080
        assert result["enable"] is True

    def test_nested_dict_recursed(self):
        """Nested dicts are processed recursively."""
        result = _sanitize_dict({
            "downloader": {
                "host": "localhost",
                "password": "secret",
            }
        })
        assert result["downloader"]["host"] == "localhost"
        assert result["downloader"]["password"] == "********"

    def test_deeply_nested_dict(self):
        """Deeply nested sensitive keys are masked."""
        result = _sanitize_dict({
            "level1": {
                "level2": {
                    "api_key": "deep-secret"
                }
            }
        })
        assert result["level1"]["level2"]["api_key"] == "********"

    def test_non_string_value_not_masked(self):
        """Non-string values with sensitive-looking keys are NOT masked."""
        result = _sanitize_dict({"password": 12345})
        # Only string values are masked; integers pass through
        assert result["password"] == 12345

    def test_empty_dict(self):
        """Empty dict returns empty dict."""
        assert _sanitize_dict({}) == {}

    def test_mixed_sensitive_and_plain(self):
        """Mix of sensitive and plain keys handled correctly."""
        result = _sanitize_dict({
            "username": "admin",
            "password": "secret",
            "host": "10.0.0.1",
            "token": "jwt-abc",
        })
        assert result["username"] == "admin"
        assert result["host"] == "10.0.0.1"
        assert result["password"] == "********"
        assert result["token"] == "********"

    def test_get_config_masks_sensitive_fields(self, authed_client):
        """GET /config/get response masks password and api_key fields."""
        test_config = Config()
        with patch("module.api.config.settings", test_config):
            response = authed_client.get("/api/v1/config/get")
            assert response.status_code == 200
            data = response.json()
            # Downloader password should be masked
            assert data["downloader"]["password"] == "********"
            # OpenAI api_key should be masked (it's an empty string but still masked)
            assert data["experimental_openai"]["api_key"] == "********"

View File

@@ -151,6 +151,15 @@ class TestPasswordHashing:
class TestGetCurrentUser:
@staticmethod
def _mock_request(authorization=""):
"""Create a mock Request with the given Authorization header."""
from unittest.mock import MagicMock
request = MagicMock()
request.headers = {"authorization": authorization}
return request
@patch("module.security.api.DEV_AUTH_BYPASS", False)
async def test_no_cookie_raises_401(self):
"""get_current_user raises 401 when no token cookie."""
@@ -159,7 +168,7 @@ class TestGetCurrentUser:
from module.security.api import get_current_user
with pytest.raises(HTTPException) as exc_info:
await get_current_user(token=None)
await get_current_user(request=self._mock_request(), token=None)
assert exc_info.value.status_code == 401
@patch("module.security.api.DEV_AUTH_BYPASS", False)
@@ -170,7 +179,7 @@ class TestGetCurrentUser:
from module.security.api import get_current_user
with pytest.raises(HTTPException) as exc_info:
await get_current_user(token="invalid.jwt.token")
await get_current_user(request=self._mock_request(), token="invalid.jwt.token")
assert exc_info.value.status_code == 401
@patch("module.security.api.DEV_AUTH_BYPASS", False)
@@ -186,7 +195,7 @@ class TestGetCurrentUser:
active_user.clear()
with pytest.raises(HTTPException) as exc_info:
await get_current_user(token=token)
await get_current_user(request=self._mock_request(), token=token)
assert exc_info.value.status_code == 401
@patch("module.security.api.DEV_AUTH_BYPASS", False)
@@ -202,8 +211,113 @@ class TestGetCurrentUser:
active_user.clear()
active_user["active_user"] = datetime.now()
result = await get_current_user(token=token)
result = await get_current_user(request=self._mock_request(), token=token)
assert result == "active_user"
# Cleanup
active_user.clear()
@patch("module.security.api.DEV_AUTH_BYPASS", True)
async def test_dev_bypass_skips_auth(self):
"""When DEV_AUTH_BYPASS is True, get_current_user returns 'dev_user' unconditionally."""
from module.security.api import get_current_user
result = await get_current_user(request=self._mock_request(), token=None)
assert result == "dev_user"
@patch("module.security.api.DEV_AUTH_BYPASS", False)
async def test_bearer_token_bypass_valid(self):
"""A valid login_token in Authorization header returns 'api_token_user'."""
from module.security.api import get_current_user
mock_request = self._mock_request(authorization="Bearer valid-api-token")
mock_security = type("S", (), {"login_tokens": ["valid-api-token"]})()
mock_settings = type("Settings", (), {"security": mock_security})()
with patch("module.security.api.settings", mock_settings):
result = await get_current_user(request=mock_request, token=None)
assert result == "api_token_user"
@patch("module.security.api.DEV_AUTH_BYPASS", False)
async def test_bearer_token_bypass_invalid(self):
"""An invalid login_token still falls through to cookie check."""
from fastapi import HTTPException
from module.security.api import get_current_user
mock_request = self._mock_request(authorization="Bearer wrong-token")
mock_security = type("S", (), {"login_tokens": ["correct-token"]})()
mock_settings = type("Settings", (), {"security": mock_security})()
with patch("module.security.api.settings", mock_settings):
with pytest.raises(HTTPException) as exc_info:
await get_current_user(request=mock_request, token=None)
assert exc_info.value.status_code == 401
# ---------------------------------------------------------------------------
# check_login_ip
# ---------------------------------------------------------------------------
class TestCheckLoginIp:
@staticmethod
def _make_request(host: str | None):
from unittest.mock import MagicMock
request = MagicMock()
if host is None:
request.client = None
else:
request.client = MagicMock()
request.client.host = host
return request
def test_empty_whitelist_allows_all(self):
"""When login_whitelist is empty, all IPs pass."""
from module.security.api import check_login_ip
mock_security = type("S", (), {"login_whitelist": []})()
mock_settings = type("Settings", (), {"security": mock_security})()
with patch("module.security.api.settings", mock_settings):
# Should not raise
check_login_ip(request=self._make_request("8.8.8.8"))
def test_allowed_ip_passes(self):
"""IP in whitelist does not raise."""
from module.security.api import check_login_ip
mock_security = type("S", (), {"login_whitelist": ["192.168.0.0/16"]})()
mock_settings = type("Settings", (), {"security": mock_security})()
with patch("module.security.api.settings", mock_settings):
check_login_ip(request=self._make_request("192.168.1.100"))
def test_blocked_ip_raises_403(self):
"""IP outside whitelist raises 403."""
from fastapi import HTTPException
from module.security.api import check_login_ip
mock_security = type("S", (), {"login_whitelist": ["192.168.0.0/16"]})()
mock_settings = type("Settings", (), {"security": mock_security})()
with patch("module.security.api.settings", mock_settings):
with pytest.raises(HTTPException) as exc_info:
check_login_ip(request=self._make_request("8.8.8.8"))
assert exc_info.value.status_code == 403
def test_no_client_raises_403_when_whitelist_set(self):
"""Missing client info raises 403 when whitelist is non-empty."""
from fastapi import HTTPException
from module.security.api import check_login_ip
mock_security = type("S", (), {"login_whitelist": ["192.168.0.0/16"]})()
mock_settings = type("Settings", (), {"security": mock_security})()
with patch("module.security.api.settings", mock_settings):
with pytest.raises(HTTPException) as exc_info:
check_login_ip(request=self._make_request(None))
assert exc_info.value.status_code == 403

View File

@@ -8,14 +8,16 @@ from unittest.mock import patch
from module.models.config import (
Config,
Program,
Downloader,
RSSParser,
BangumiManage,
Proxy,
Notification as NotificationConfig,
NotificationProvider,
Program,
Proxy,
RSSParser,
Security,
)
from module.conf.config import Settings
from module.conf.const import BCOLORS, DEFAULT_SETTINGS
# ---------------------------------------------------------------------------
@@ -228,3 +230,281 @@ class TestEnvOverrides:
s.init()
assert "192.168.1.100:9090" in s.downloader.host
# ---------------------------------------------------------------------------
# Security model
# ---------------------------------------------------------------------------
class TestSecurityModel:
def test_security_defaults(self):
"""Security has empty whitelists and token lists by default."""
sec = Security()
assert sec.login_whitelist == []
assert sec.login_tokens == []
assert sec.mcp_whitelist == []
assert sec.mcp_tokens == []
def test_security_in_config(self):
"""Config includes a Security section with correct defaults."""
config = Config()
assert hasattr(config, "security")
assert isinstance(config.security, Security)
assert config.security.login_whitelist == []
def test_security_populated(self):
"""Security fields accept lists of CIDRs and tokens."""
sec = Security(
login_whitelist=["192.168.0.0/16"],
login_tokens=["token-abc"],
mcp_whitelist=["10.0.0.0/8"],
mcp_tokens=["mcp-secret"],
)
assert "192.168.0.0/16" in sec.login_whitelist
assert "token-abc" in sec.login_tokens
assert "10.0.0.0/8" in sec.mcp_whitelist
assert "mcp-secret" in sec.mcp_tokens
def test_security_roundtrip_serialization(self):
"""Security serializes and deserializes correctly."""
original = Security(
login_whitelist=["127.0.0.0/8"],
mcp_tokens=["tok1"],
)
data = original.model_dump()
restored = Security.model_validate(data)
assert restored.login_whitelist == ["127.0.0.0/8"]
assert restored.mcp_tokens == ["tok1"]
# ---------------------------------------------------------------------------
# NotificationProvider model
# ---------------------------------------------------------------------------
class TestNotificationProvider:
def test_minimal_provider(self):
"""NotificationProvider requires only type."""
p = NotificationProvider(type="telegram")
assert p.type == "telegram"
assert p.enabled is True
def test_telegram_provider_fields(self):
"""Telegram provider stores token and chat_id."""
p = NotificationProvider(type="telegram", token="bot123", chat_id="-100456")
assert p.token == "bot123"
assert p.chat_id == "-100456"
def test_discord_provider_fields(self):
"""Discord provider stores webhook_url."""
p = NotificationProvider(
type="discord", webhook_url="https://discord.com/api/webhooks/123/abc"
)
assert p.webhook_url == "https://discord.com/api/webhooks/123/abc"
def test_bark_provider_fields(self):
"""Bark provider stores server_url and device_key."""
p = NotificationProvider(
type="bark", server_url="https://api.day.app", device_key="mykey"
)
assert p.server_url == "https://api.day.app"
assert p.device_key == "mykey"
def test_pushover_provider_fields(self):
"""Pushover provider stores user_key and api_token."""
p = NotificationProvider(type="pushover", user_key="uk1", api_token="at1")
assert p.user_key == "uk1"
assert p.api_token == "at1"
def test_url_field_property(self):
"""Webhook provider stores url."""
p = NotificationProvider(type="webhook", url="https://example.com/hook")
assert p.url == "https://example.com/hook"
def test_optional_fields_default_empty_string(self):
"""Unset optional properties return empty string, not None."""
p = NotificationProvider(type="telegram")
assert p.token == ""
assert p.chat_id == ""
assert p.webhook_url == ""
def test_provider_can_be_disabled(self):
"""Provider can be disabled without removing it."""
p = NotificationProvider(type="telegram", enabled=False)
assert p.enabled is False
def test_env_var_expansion_in_token(self, monkeypatch):
"""Token field expands shell environment variables."""
monkeypatch.setenv("TEST_BOT_TOKEN", "real-token-value")
p = NotificationProvider(type="telegram", token="$TEST_BOT_TOKEN")
assert p.token == "real-token-value"
# ---------------------------------------------------------------------------
# Notification model - legacy migration
# ---------------------------------------------------------------------------
class TestNotificationLegacyMigration:
def test_new_format_no_migration(self):
"""New format with providers list is not touched."""
n = NotificationConfig(
enable=True,
providers=[NotificationProvider(type="telegram", token="tok")],
)
assert len(n.providers) == 1
assert n.providers[0].type == "telegram"
def test_old_format_migrates_to_provider(self):
"""Old single-provider fields (type, token, chat_id) migrate to providers list."""
n = NotificationConfig(
enable=True,
type="telegram",
token="bot_token",
chat_id="-100123",
)
assert len(n.providers) == 1
provider = n.providers[0]
assert provider.type == "telegram"
assert provider.enabled is True
def test_old_format_no_migration_when_providers_already_set(self):
"""When providers already exist, legacy fields do not create additional providers."""
n = NotificationConfig(
enable=True,
type="telegram",
token="unused",
providers=[NotificationProvider(type="discord", webhook_url="https://d.co")],
)
assert len(n.providers) == 1
assert n.providers[0].type == "discord"
def test_notification_empty_providers_by_default(self):
"""Default Notification has no providers."""
n = NotificationConfig()
assert n.providers == []
assert n.enable is False
# ---------------------------------------------------------------------------
# Downloader env-var expansion
# ---------------------------------------------------------------------------
class TestDownloaderEnvExpansion:
def test_host_expands_env_var(self, monkeypatch):
"""Downloader.host expands $VAR references."""
monkeypatch.setenv("QB_HOST", "192.168.5.10:8080")
d = Downloader(host="$QB_HOST")
assert d.host == "192.168.5.10:8080"
def test_username_expands_env_var(self, monkeypatch):
"""Downloader.username expands $VAR references."""
monkeypatch.setenv("QB_USER", "myuser")
d = Downloader(username="$QB_USER")
assert d.username == "myuser"
def test_password_expands_env_var(self, monkeypatch):
"""Downloader.password expands $VAR references."""
monkeypatch.setenv("QB_PASS", "s3cret")
d = Downloader(password="$QB_PASS")
assert d.password == "s3cret"
def test_literal_host_not_expanded(self):
"""Literal host strings without $ are returned as-is."""
d = Downloader(host="localhost:8080")
assert d.host == "localhost:8080"
# ---------------------------------------------------------------------------
# DEFAULT_SETTINGS structure
# ---------------------------------------------------------------------------
class TestDefaultSettings:
def test_security_section_present(self):
"""DEFAULT_SETTINGS contains a security section."""
assert "security" in DEFAULT_SETTINGS
def test_security_default_mcp_whitelist(self):
"""Default MCP whitelist contains private network ranges."""
mcp_wl = DEFAULT_SETTINGS["security"]["mcp_whitelist"]
assert "127.0.0.0/8" in mcp_wl
assert "192.168.0.0/16" in mcp_wl
assert "10.0.0.0/8" in mcp_wl
def test_security_default_tokens_empty(self):
"""Default security token lists are empty."""
assert DEFAULT_SETTINGS["security"]["login_tokens"] == []
assert DEFAULT_SETTINGS["security"]["mcp_tokens"] == []
def test_notification_uses_providers_format(self):
"""DEFAULT_SETTINGS notification uses new providers format."""
notif = DEFAULT_SETTINGS["notification"]
assert "providers" in notif
assert notif["providers"] == []
assert "type" not in notif
# ---------------------------------------------------------------------------
# BCOLORS utility
# ---------------------------------------------------------------------------
class TestBCOLORS:
def test_wrap_single_string(self):
"""BCOLORS._() wraps a string with color codes and reset."""
result = BCOLORS._(BCOLORS.OKGREEN, "hello")
assert "hello" in result
assert BCOLORS.OKGREEN in result
assert BCOLORS.ENDC in result
def test_wrap_multiple_strings(self):
"""BCOLORS._() joins multiple args with commas."""
result = BCOLORS._(BCOLORS.WARNING, "foo", "bar")
assert "foo" in result
assert "bar" in result
def test_wrap_non_string_arg(self):
"""BCOLORS._() converts non-string args to str."""
result = BCOLORS._(BCOLORS.FAIL, 42)
assert "42" in result
def test_all_color_constants_are_strings(self):
"""All BCOLORS constants are strings."""
for attr in ["HEADER", "OKBLUE", "OKCYAN", "OKGREEN", "WARNING", "FAIL", "ENDC"]:
assert isinstance(getattr(BCOLORS, attr), str)
# ---------------------------------------------------------------------------
# Migration: security section injection
# ---------------------------------------------------------------------------
class TestMigrateSecuritySection:
def test_adds_security_when_missing(self):
"""_migrate_old_config injects a default security section when absent."""
old_config = {
"program": {},
"rss_parser": {},
}
result = Settings._migrate_old_config(old_config)
assert "security" in result
assert "mcp_whitelist" in result["security"]
def test_preserves_existing_security_section(self):
"""_migrate_old_config does not overwrite an existing security section."""
existing_config = {
"program": {},
"rss_parser": {},
"security": {
"login_whitelist": ["10.0.0.0/8"],
"login_tokens": ["mytoken"],
"mcp_whitelist": [],
"mcp_tokens": [],
},
}
result = Settings._migrate_old_config(existing_config)
assert result["security"]["login_tokens"] == ["mytoken"]
assert result["security"]["login_whitelist"] == ["10.0.0.0/8"]

View File

@@ -293,7 +293,68 @@ class TestClientDelegation:
result = await download_client.rename_torrent_file("hash1", "old.mkv", "new.mkv")
assert result is False
async def test_rename_torrent_file_passes_verify_flag(
self, download_client, mock_qb_client
):
"""rename_torrent_file forwards the verify kwarg to the underlying client."""
mock_qb_client.torrents_rename_file.return_value = True
await download_client.rename_torrent_file(
"hash1", "old.mkv", "new.mkv", verify=False
)
call_kwargs = mock_qb_client.torrents_rename_file.call_args[1]
assert call_kwargs["verify"] is False
async def test_delete_torrent(self, download_client, mock_qb_client):
"""delete_torrent delegates to client.torrents_delete."""
await download_client.delete_torrent("hash1", delete_files=True)
mock_qb_client.torrents_delete.assert_called_once_with("hash1", delete_files=True)
# ---------------------------------------------------------------------------
# add_tag
# ---------------------------------------------------------------------------
class TestAddTag:
async def test_add_tag_delegates_to_client(self, download_client, mock_qb_client):
"""add_tag delegates to client.add_tag."""
mock_qb_client.add_tag = AsyncMock(return_value=None)
download_client.client = mock_qb_client
await download_client.add_tag("deadbeef12345678", "ab:42")
mock_qb_client.add_tag.assert_called_once_with("deadbeef12345678", "ab:42")
async def test_add_tag_short_hash_no_error(self, download_client, mock_qb_client):
"""add_tag with a hash shorter than 8 chars does not crash the slice."""
mock_qb_client.add_tag = AsyncMock(return_value=None)
download_client.client = mock_qb_client
# Should not raise even for short hashes
await download_client.add_tag("abc", "ab:1")
# ---------------------------------------------------------------------------
# Context manager: ConnectionError on failed auth
# ---------------------------------------------------------------------------
class TestContextManagerAuth:
async def test_aenter_raises_on_auth_failure(self, download_client, mock_qb_client):
"""__aenter__ raises ConnectionError when auth fails."""
mock_qb_client.auth.return_value = False
download_client.authed = False
with pytest.raises(ConnectionError, match="authentication failed"):
await download_client.__aenter__()
async def test_aenter_succeeds_when_auth_passes(self, download_client, mock_qb_client):
"""__aenter__ returns self when auth succeeds."""
mock_qb_client.auth.return_value = True
download_client.authed = False
result = await download_client.__aenter__()
assert result is download_client
assert download_client.authed is True
async def test_aexit_calls_logout_when_authed(self, download_client, mock_qb_client):
"""__aexit__ calls logout and resets authed when session was active."""
download_client.authed = True
await download_client.__aexit__(None, None, None)
mock_qb_client.logout.assert_called_once()
assert download_client.authed is False

View File

@@ -1,6 +1,6 @@
"""Tests for module.mcp.security - LocalNetworkMiddleware and _is_local()."""
"""Tests for module.mcp.security - McpAccessMiddleware, _is_allowed(), and clear_network_cache()."""
from unittest.mock import AsyncMock, MagicMock
from unittest.mock import patch
import pytest
from starlette.applications import Starlette
@@ -8,248 +8,223 @@ from starlette.responses import PlainTextResponse
from starlette.routing import Route
from starlette.testclient import TestClient
from module.mcp.security import LocalNetworkMiddleware, _is_local
from module.mcp.security import McpAccessMiddleware, _is_allowed, clear_network_cache
# ---------------------------------------------------------------------------
# _is_local() unit tests
# _is_allowed() unit tests
# ---------------------------------------------------------------------------
class TestIsLocal:
"""Verify _is_local() correctly classifies IP addresses."""
class TestIsAllowed:
"""Verify _is_allowed() checks IPs against a given whitelist."""
# --- loopback ---
def setup_method(self):
clear_network_cache()
def test_ipv4_loopback_127_0_0_1(self):
"""127.0.0.1 is the canonical loopback address."""
assert _is_local("127.0.0.1") is True
LOCAL_WHITELIST = [
"127.0.0.0/8",
"10.0.0.0/8",
"172.16.0.0/12",
"192.168.0.0/16",
"::1/128",
"fe80::/10",
"fc00::/7",
]
def test_ipv4_loopback_127_0_0_2(self):
"""127.0.0.2 is within 127.0.0.0/8 and therefore local."""
assert _is_local("127.0.0.2") is True
# --- allowed IPs ---
def test_ipv4_loopback_127_255_255_255(self):
"""Top of 127.0.0.0/8 range is still local."""
assert _is_local("127.255.255.255") is True
def test_ipv4_loopback_allowed(self):
assert _is_allowed("127.0.0.1", self.LOCAL_WHITELIST) is True
# --- RFC 1918 class-A (10.0.0.0/8) ---
def test_ipv4_loopback_range(self):
assert _is_allowed("127.255.255.255", self.LOCAL_WHITELIST) is True
def test_ipv4_10_network_start(self):
"""10.0.0.1 is in 10.0.0.0/8 private range."""
assert _is_local("10.0.0.1") is True
def test_ipv4_10_network(self):
assert _is_allowed("10.0.0.1", self.LOCAL_WHITELIST) is True
def test_ipv4_10_network_mid(self):
"""10.10.20.30 is inside 10.0.0.0/8."""
assert _is_local("10.10.20.30") is True
def test_ipv4_172_16_network(self):
assert _is_allowed("172.16.0.1", self.LOCAL_WHITELIST) is True
def test_ipv4_10_network_end(self):
"""10.255.255.254 is the last usable address in 10.0.0.0/8."""
assert _is_local("10.255.255.254") is True
# --- RFC 1918 class-B (172.16.0.0/12) ---
def test_ipv4_172_16_start(self):
"""172.16.0.1 is the first address in 172.16.0.0/12."""
assert _is_local("172.16.0.1") is True
def test_ipv4_172_31_end(self):
"""172.31.255.254 is at the top of the 172.16.0.0/12 range."""
assert _is_local("172.31.255.254") is True
def test_ipv4_172_15_not_local(self):
"""172.15.255.255 is just outside 172.16.0.0/12 (below the range)."""
assert _is_local("172.15.255.255") is False
def test_ipv4_172_32_not_local(self):
"""172.32.0.0 is just above 172.16.0.0/12 (outside the range)."""
assert _is_local("172.32.0.0") is False
# --- RFC 1918 class-C (192.168.0.0/16) ---
def test_ipv4_192_168_start(self):
"""192.168.0.1 is a typical home-router address."""
assert _is_local("192.168.0.1") is True
def test_ipv4_192_168_end(self):
"""192.168.255.254 is at the top of 192.168.0.0/16."""
assert _is_local("192.168.255.254") is True
# --- Public IPv4 ---
def test_public_ipv4_google_dns(self):
"""8.8.8.8 (Google DNS) is a public address."""
assert _is_local("8.8.8.8") is False
def test_public_ipv4_cloudflare_dns(self):
"""1.1.1.1 (Cloudflare) is a public address."""
assert _is_local("1.1.1.1") is False
def test_public_ipv4_broadcast_like(self):
"""203.0.113.1 (TEST-NET-3, RFC 5737) is not a private address."""
assert _is_local("203.0.113.1") is False
# --- IPv6 loopback ---
def test_ipv4_192_168_network(self):
assert _is_allowed("192.168.1.100", self.LOCAL_WHITELIST) is True
def test_ipv6_loopback(self):
"""::1 is the IPv6 loopback address."""
assert _is_local("::1") is True
# --- IPv6 link-local (fe80::/10) ---
assert _is_allowed("::1", self.LOCAL_WHITELIST) is True
def test_ipv6_link_local(self):
"""fe80::1 is an IPv6 link-local address."""
assert _is_local("fe80::1") is True
assert _is_allowed("fe80::1", self.LOCAL_WHITELIST) is True
def test_ipv6_link_local_full(self):
"""fe80::aabb:ccdd is also link-local."""
assert _is_local("fe80::aabb:ccdd") is True
def test_ipv6_ula(self):
assert _is_allowed("fd00::1", self.LOCAL_WHITELIST) is True
# --- IPv6 ULA (fc00::/7) ---
# --- denied IPs ---
def test_ipv6_ula_fc(self):
"""fc00::1 is within the ULA range fc00::/7."""
assert _is_local("fc00::1") is True
def test_public_ipv4_denied(self):
assert _is_allowed("8.8.8.8", self.LOCAL_WHITELIST) is False
def test_ipv6_ula_fd(self):
"""fd00::1 is within the ULA range fc00::/7 (fd prefix)."""
assert _is_local("fd00::1") is True
def test_public_ipv6_denied(self):
assert _is_allowed("2001:4860:4860::8888", self.LOCAL_WHITELIST) is False
# --- Public IPv6 ---
def test_172_outside_range(self):
assert _is_allowed("172.32.0.0", self.LOCAL_WHITELIST) is False
def test_public_ipv6_google(self):
"""2001:4860:4860::8888 (Google IPv6 DNS) is a public address."""
assert _is_local("2001:4860:4860::8888") is False
# --- empty whitelist ---
def test_public_ipv6_documentation(self):
"""2001:db8::1 (documentation prefix, RFC 3849) is public."""
assert _is_local("2001:db8::1") is False
def test_empty_whitelist_denies_all(self):
assert _is_allowed("127.0.0.1", []) is False
# --- Invalid inputs ---
# --- invalid inputs ---
def test_invalid_hostname_returns_false(self):
"""A hostname string is not parseable as an IP and must return False."""
assert _is_local("localhost") is False
def test_invalid_hostname(self):
assert _is_allowed("localhost", self.LOCAL_WHITELIST) is False
def test_invalid_string_returns_false(self):
"""A random non-IP string returns False without raising."""
assert _is_local("not-an-ip") is False
def test_empty_string(self):
assert _is_allowed("", self.LOCAL_WHITELIST) is False
def test_empty_string_returns_false(self):
"""An empty string is not a valid IP address."""
assert _is_local("") is False
def test_malformed_ipv4(self):
assert _is_allowed("256.0.0.1", self.LOCAL_WHITELIST) is False
def test_malformed_ipv4_returns_false(self):
"""A string that looks like IPv4 but is malformed returns False."""
assert _is_local("256.0.0.1") is False
# --- single IP whitelist ---
def test_partial_ipv4_returns_false(self):
"""An incomplete IPv4 address is not valid."""
assert _is_local("192.168") is False
def test_single_ip_whitelist(self):
assert _is_allowed("203.0.113.5", ["203.0.113.5/32"]) is True
assert _is_allowed("203.0.113.6", ["203.0.113.5/32"]) is False
# ---------------------------------------------------------------------------
# LocalNetworkMiddleware integration tests
# McpAccessMiddleware integration tests
# ---------------------------------------------------------------------------
def _make_app_with_middleware() -> Starlette:
"""Build a minimal Starlette app with LocalNetworkMiddleware applied."""
def _make_mcp_settings(mcp_whitelist=None, mcp_tokens=None):
"""Create a mock settings.security object."""
class MockSecurity:
def __init__(self):
self.mcp_whitelist = mcp_whitelist if mcp_whitelist is not None else []
self.mcp_tokens = mcp_tokens if mcp_tokens is not None else []
class MockSettings:
def __init__(self):
self.security = MockSecurity()
return MockSettings()
def _make_app() -> Starlette:
"""Build a minimal Starlette app with McpAccessMiddleware applied."""
async def homepage(request):
return PlainTextResponse("ok")
app = Starlette(routes=[Route("/", homepage)])
app.add_middleware(LocalNetworkMiddleware)
app.add_middleware(McpAccessMiddleware)
return app
class TestLocalNetworkMiddleware:
"""Verify LocalNetworkMiddleware allows or denies requests by client IP."""
def _patch_client_ip(app, ip):
"""Return a modified app that overrides the client IP in ASGI scope."""
original_build = app.build_middleware_stack
@pytest.fixture
def app(self):
return _make_app_with_middleware()
async def patched_app(scope, receive, send):
if scope["type"] == "http":
scope["client"] = (ip, 12345) if ip is not None else None
await original_build()(scope, receive, send)
def test_local_ipv4_loopback_allowed(self, app):
"""Requests from 127.0.0.1 are allowed through."""
# Starlette's TestClient identifies itself as "testclient", not a real IP.
# Patch the scope so the middleware sees an actual loopback address.
original_build = app.build_middleware_stack
app.build_middleware_stack = lambda: patched_app
return app
async def patched_app(scope, receive, send):
if scope["type"] == "http":
scope["client"] = ("127.0.0.1", 12345)
await original_build()(scope, receive, send)
app.build_middleware_stack = lambda: patched_app # type: ignore[method-assign]
class TestMcpAccessMiddleware:
"""Verify McpAccessMiddleware allows/denies requests by IP and token."""
client = TestClient(app, raise_server_exceptions=False)
response = client.get("/")
def setup_method(self):
clear_network_cache()
def test_allowed_ip_passes(self):
mock_settings = _make_mcp_settings(mcp_whitelist=["127.0.0.0/8"])
app = _patch_client_ip(_make_app(), "127.0.0.1")
with patch("module.mcp.security.settings", mock_settings):
client = TestClient(app, raise_server_exceptions=False)
response = client.get("/")
assert response.status_code == 200
assert response.text == "ok"
def test_non_local_ip_blocked(self, app):
"""Requests from a public IP are rejected with 403."""
# Patch the ASGI scope to simulate a public client
original_build = app.build_middleware_stack
async def patched_app(scope, receive, send):
if scope["type"] == "http":
scope["client"] = ("8.8.8.8", 12345)
await original_build()(scope, receive, send)
app.build_middleware_stack = lambda: patched_app # type: ignore[method-assign]
client = TestClient(app, raise_server_exceptions=False)
response = client.get("/")
def test_denied_ip_blocked(self):
mock_settings = _make_mcp_settings(mcp_whitelist=["127.0.0.0/8"])
app = _patch_client_ip(_make_app(), "8.8.8.8")
with patch("module.mcp.security.settings", mock_settings):
client = TestClient(app, raise_server_exceptions=False)
response = client.get("/")
assert response.status_code == 403
assert "MCP access is restricted to local network" in response.text
assert "MCP access denied" in response.text
def test_missing_client_blocked(self, app):
"""Requests with no client information are rejected with 403."""
original_build = app.build_middleware_stack
async def patched_app(scope, receive, send):
if scope["type"] == "http":
scope["client"] = None
await original_build()(scope, receive, send)
app.build_middleware_stack = lambda: patched_app # type: ignore[method-assign]
client = TestClient(app, raise_server_exceptions=False)
response = client.get("/")
def test_empty_whitelist_denies_all(self):
mock_settings = _make_mcp_settings(mcp_whitelist=[], mcp_tokens=[])
app = _patch_client_ip(_make_app(), "127.0.0.1")
with patch("module.mcp.security.settings", mock_settings):
client = TestClient(app, raise_server_exceptions=False)
response = client.get("/")
assert response.status_code == 403
def test_blocked_response_is_json(self, app):
"""The 403 error body is valid JSON with an 'error' key."""
def test_missing_client_blocked(self):
mock_settings = _make_mcp_settings(mcp_whitelist=["127.0.0.0/8"])
app = _patch_client_ip(_make_app(), None)
with patch("module.mcp.security.settings", mock_settings):
client = TestClient(app, raise_server_exceptions=False)
response = client.get("/")
assert response.status_code == 403
def test_bearer_token_bypasses_ip(self):
mock_settings = _make_mcp_settings(
mcp_whitelist=[], mcp_tokens=["secret-token-123"]
)
app = _patch_client_ip(_make_app(), "8.8.8.8")
with patch("module.mcp.security.settings", mock_settings):
client = TestClient(app, raise_server_exceptions=False)
response = client.get(
"/", headers={"Authorization": "Bearer secret-token-123"}
)
assert response.status_code == 200
def test_invalid_bearer_token_denied(self):
mock_settings = _make_mcp_settings(
mcp_whitelist=[], mcp_tokens=["secret-token-123"]
)
app = _patch_client_ip(_make_app(), "8.8.8.8")
with patch("module.mcp.security.settings", mock_settings):
client = TestClient(app, raise_server_exceptions=False)
response = client.get(
"/", headers={"Authorization": "Bearer wrong-token"}
)
assert response.status_code == 403
def test_private_network_with_default_whitelist(self):
default_whitelist = [
"127.0.0.0/8",
"10.0.0.0/8",
"172.16.0.0/12",
"192.168.0.0/16",
"::1/128",
"fe80::/10",
"fc00::/7",
]
mock_settings = _make_mcp_settings(mcp_whitelist=default_whitelist)
app = _patch_client_ip(_make_app(), "192.168.1.100")
with patch("module.mcp.security.settings", mock_settings):
client = TestClient(app, raise_server_exceptions=False)
response = client.get("/")
assert response.status_code == 200
def test_blocked_response_is_json(self):
import json
original_build = app.build_middleware_stack
async def patched_app(scope, receive, send):
if scope["type"] == "http":
scope["client"] = ("1.2.3.4", 9999)
await original_build()(scope, receive, send)
app.build_middleware_stack = lambda: patched_app # type: ignore[method-assign]
client = TestClient(app, raise_server_exceptions=False)
response = client.get("/")
mock_settings = _make_mcp_settings(mcp_whitelist=["127.0.0.0/8"])
app = _patch_client_ip(_make_app(), "1.2.3.4")
with patch("module.mcp.security.settings", mock_settings):
client = TestClient(app, raise_server_exceptions=False)
response = client.get("/")
assert response.status_code == 403
body = json.loads(response.text)
assert "error" in body
    def test_private_192_168_allowed(self, app):
        """Requests from a 192.168.x.x address pass through."""
        # NOTE(review): pre-refactor pattern -- depends on an `app` fixture and
        # replaces build_middleware_stack so every HTTP scope reports a fixed
        # client address before the middleware chain runs. Newer tests in this
        # class use _patch_client_ip instead; consider migrating this one.
        original_build = app.build_middleware_stack

        async def patched_app(scope, receive, send):
            # Force the client address seen by the IP-whitelist middleware.
            if scope["type"] == "http":
                scope["client"] = ("192.168.1.100", 54321)
            await original_build()(scope, receive, send)

        app.build_middleware_stack = lambda: patched_app  # type: ignore[method-assign]
        client = TestClient(app, raise_server_exceptions=False)
        response = client.get("/")
        assert response.status_code == 200

View File

@@ -0,0 +1,429 @@
"""Tests for MockDownloader - state management and API contract."""
import pytest
from module.downloader.client.mock_downloader import MockDownloader
@pytest.fixture
def mock_dl() -> MockDownloader:
    """Provide a pristine, isolated MockDownloader instance per test."""
    downloader = MockDownloader()
    return downloader
# ---------------------------------------------------------------------------
# Initialization
# ---------------------------------------------------------------------------
class TestMockDownloaderInit:
    """Default state of a freshly constructed MockDownloader."""

    def test_initial_state_is_empty(self, mock_dl):
        """MockDownloader starts with no torrents, rules, or feeds."""
        state = mock_dl.get_state()
        assert state["torrents"] == {}
        assert state["rules"] == {}
        assert state["feeds"] == {}

    def test_initial_categories(self, mock_dl):
        """Default categories include Bangumi and BangumiCollection."""
        state = mock_dl.get_state()
        assert "Bangumi" in state["categories"]
        assert "BangumiCollection" in state["categories"]

    def test_initial_prefs(self, mock_dl):
        """Default prefs are populated."""
        # Access private attribute directly to confirm defaults
        assert mock_dl._prefs["rss_auto_downloading_enabled"] is True
        assert mock_dl._prefs["rss_max_articles_per_feed"] == 500
# ---------------------------------------------------------------------------
# Auth / connection
# ---------------------------------------------------------------------------
class TestMockDownloaderAuth:
    """Authentication and connectivity stubs of the mock always succeed."""

    async def test_auth_returns_true(self, mock_dl):
        assert await mock_dl.auth() is True

    async def test_auth_retry_param_accepted(self, mock_dl):
        # The retry keyword must be accepted without changing the outcome.
        assert await mock_dl.auth(retry=5) is True

    async def test_logout_does_not_raise(self, mock_dl):
        await mock_dl.logout()

    async def test_check_host_returns_true(self, mock_dl):
        assert await mock_dl.check_host() is True

    async def test_check_connection_returns_version_string(self, mock_dl):
        version = await mock_dl.check_connection()
        assert "mock" in version.lower()
# ---------------------------------------------------------------------------
# Prefs
# ---------------------------------------------------------------------------
class TestMockDownloaderPrefs:
    """Application-preference initialization and retrieval."""

    async def test_prefs_init_updates_prefs(self, mock_dl):
        """prefs_init merges given prefs into the internal store."""
        await mock_dl.prefs_init({"rss_refresh_interval": 60, "custom_key": "val"})
        assert mock_dl._prefs["rss_refresh_interval"] == 60
        # Unknown keys are accepted and stored as-is.
        assert mock_dl._prefs["custom_key"] == "val"

    async def test_get_app_prefs_returns_dict(self, mock_dl):
        """get_app_prefs returns a dict containing at least save_path."""
        result = await mock_dl.get_app_prefs()
        assert isinstance(result, dict)
        assert "save_path" in result
# ---------------------------------------------------------------------------
# Categories
# ---------------------------------------------------------------------------
class TestMockDownloaderCategories:
    """Category creation semantics."""

    async def test_add_category_persists(self, mock_dl):
        """A newly added category shows up in the exported state."""
        await mock_dl.add_category("NewCategory")
        state = mock_dl.get_state()
        assert "NewCategory" in state["categories"]

    async def test_add_duplicate_category_no_error(self, mock_dl):
        """Re-adding an existing category neither raises nor duplicates it."""
        await mock_dl.add_category("Bangumi")
        state = mock_dl.get_state()
        # Still only one entry for Bangumi (set semantics)
        assert state["categories"].count("Bangumi") == 1
# ---------------------------------------------------------------------------
# Torrent management
# ---------------------------------------------------------------------------
class TestMockDownloaderAddTorrents:
    """add_torrents() contract: return value, state storage, tags, and inputs."""

    async def test_add_torrent_url_returns_true(self, mock_dl):
        """Adding by magnet URL reports success."""
        result = await mock_dl.add_torrents(
            torrent_urls="magnet:?xt=urn:btih:abc",
            torrent_files=None,
            save_path="/downloads/Bangumi",
            category="Bangumi",
        )
        assert result is True

    async def test_add_torrent_stores_in_state(self, mock_dl):
        """An added torrent is recorded in the exported state."""
        await mock_dl.add_torrents(
            torrent_urls="magnet:?xt=urn:btih:abc",
            torrent_files=None,
            save_path="/downloads/Bangumi",
            category="Bangumi",
        )
        state = mock_dl.get_state()
        assert len(state["torrents"]) == 1

    async def test_add_torrent_with_tag_stored(self, mock_dl):
        """The tags argument is persisted verbatim on the stored torrent."""
        await mock_dl.add_torrents(
            torrent_urls="magnet:?xt=urn:btih:abc",
            torrent_files=None,
            save_path="/downloads/Bangumi",
            category="Bangumi",
            tags="ab:42",
        )
        state = mock_dl.get_state()
        torrent = list(state["torrents"].values())[0]
        assert torrent["tags"] == "ab:42"

    async def test_add_torrent_with_file_bytes(self, mock_dl):
        """Raw .torrent bytes are accepted in place of a URL."""
        result = await mock_dl.add_torrents(
            torrent_urls=None,
            torrent_files=b"\x00\x01\x02",
            save_path="/downloads",
            category="Bangumi",
        )
        assert result is True

    async def test_two_different_torrents_stored_separately(self, mock_dl):
        """Distinct magnet URLs produce distinct stored entries."""
        await mock_dl.add_torrents(
            torrent_urls="magnet:?xt=urn:btih:aaa",
            torrent_files=None,
            save_path="/dl",
            category="Bangumi",
        )
        await mock_dl.add_torrents(
            torrent_urls="magnet:?xt=urn:btih:bbb",
            torrent_files=None,
            save_path="/dl",
            category="Bangumi",
        )
        state = mock_dl.get_state()
        assert len(state["torrents"]) == 2
class TestMockDownloaderTorrentsInfo:
    """torrents_info() filtering by category and tag."""

    async def test_returns_all_when_no_filter(self, mock_dl):
        """With no filters every stored torrent is returned."""
        mock_dl.add_mock_torrent("Anime A", category="Bangumi")
        mock_dl.add_mock_torrent("Anime B", category="Bangumi")
        result = await mock_dl.torrents_info(status_filter=None, category=None)
        assert len(result) == 2

    async def test_filters_by_category(self, mock_dl):
        """Only torrents in the requested category are returned."""
        mock_dl.add_mock_torrent("Anime A", category="Bangumi")
        mock_dl.add_mock_torrent("Movie", category="BangumiCollection")
        result = await mock_dl.torrents_info(status_filter=None, category="Bangumi")
        assert len(result) == 1
        assert result[0]["name"] == "Anime A"

    async def test_filters_by_tag(self, mock_dl):
        """The tag filter matches against the torrent's tag list."""
        h1 = mock_dl.add_mock_torrent("Anime A", category="Bangumi")
        mock_dl.add_mock_torrent("Anime B", category="Bangumi")
        # Manually set the tag on first torrent
        mock_dl._torrents[h1]["tags"] = ["ab:1"]
        result = await mock_dl.torrents_info(
            status_filter=None, category=None, tag="ab:1"
        )
        assert len(result) == 1
        assert result[0]["name"] == "Anime A"

    async def test_empty_store_returns_empty_list(self, mock_dl):
        """Filtering an empty store yields an empty list, not None."""
        result = await mock_dl.torrents_info(status_filter=None, category="Bangumi")
        assert result == []
class TestMockDownloaderTorrentsFiles:
    """File-listing lookups by torrent hash."""

    async def test_returns_files_for_known_hash(self, mock_dl):
        seeded = mock_dl.add_mock_torrent(
            "Anime", files=[{"name": "ep01.mkv", "size": 500_000_000}]
        )
        listed = await mock_dl.torrents_files(torrent_hash=seeded)
        assert len(listed) == 1
        assert listed[0]["name"] == "ep01.mkv"

    async def test_returns_empty_list_for_unknown_hash(self, mock_dl):
        # Unknown hashes yield an empty list rather than raising.
        assert await mock_dl.torrents_files(torrent_hash="nonexistent") == []
class TestMockDownloaderDelete:
    """Torrent deletion, including the pipe-separated multi-hash form."""

    async def test_delete_single_torrent(self, mock_dl):
        torrent_hash = mock_dl.add_mock_torrent("Anime")
        await mock_dl.torrents_delete(hash=torrent_hash)
        assert torrent_hash not in mock_dl._torrents

    async def test_delete_multiple_torrents_pipe_separated(self, mock_dl):
        # qBittorrent-style "|"-joined hashes delete several at once.
        first = mock_dl.add_mock_torrent("Anime A")
        second = mock_dl.add_mock_torrent("Anime B")
        await mock_dl.torrents_delete(hash="|".join((first, second)))
        assert first not in mock_dl._torrents
        assert second not in mock_dl._torrents

    async def test_delete_nonexistent_hash_no_error(self, mock_dl):
        # Deleting an unknown hash is a silent no-op.
        await mock_dl.torrents_delete(hash="deadbeef")
class TestMockDownloaderPauseResume:
    """Pause/resume transitions on stored torrent state."""

    async def test_pause_sets_state(self, mock_dl):
        """Pausing a downloading torrent marks it paused."""
        h = mock_dl.add_mock_torrent("Anime", state="downloading")
        await mock_dl.torrents_pause(hashes=h)
        assert mock_dl._torrents[h]["state"] == "paused"

    async def test_resume_sets_state(self, mock_dl):
        """Resuming a paused torrent marks it downloading again."""
        h = mock_dl.add_mock_torrent("Anime", state="paused")
        await mock_dl.torrents_resume(hashes=h)
        assert mock_dl._torrents[h]["state"] == "downloading"

    async def test_pause_multiple_pipe_separated(self, mock_dl):
        """'|'-joined hashes pause several torrents in one call."""
        h1 = mock_dl.add_mock_torrent("Anime A", state="downloading")
        h2 = mock_dl.add_mock_torrent("Anime B", state="downloading")
        await mock_dl.torrents_pause(hashes=f"{h1}|{h2}")
        assert mock_dl._torrents[h1]["state"] == "paused"
        assert mock_dl._torrents[h2]["state"] == "paused"

    async def test_pause_unknown_hash_no_error(self, mock_dl):
        # Pausing an unknown hash is a silent no-op.
        await mock_dl.torrents_pause(hashes="deadbeef")
# ---------------------------------------------------------------------------
# Rename
# ---------------------------------------------------------------------------
class TestMockDownloaderRename:
    """File renaming always reports success in the mock."""

    async def test_rename_returns_true(self, mock_dl):
        ok = await mock_dl.torrents_rename_file(
            torrent_hash="hash1", old_path="old.mkv", new_path="new.mkv"
        )
        assert ok is True

    async def test_rename_with_verify_flag(self, mock_dl):
        # verify=False must be accepted without changing the outcome.
        ok = await mock_dl.torrents_rename_file(
            torrent_hash="hash1",
            old_path="old.mkv",
            new_path="new.mkv",
            verify=False,
        )
        assert ok is True
# ---------------------------------------------------------------------------
# RSS feed management
# ---------------------------------------------------------------------------
class TestMockDownloaderRssFeeds:
    """RSS feed registration and removal."""

    async def test_add_feed_stored(self, mock_dl):
        """An added feed is retrievable by its item_path with its URL intact."""
        await mock_dl.rss_add_feed(url="https://mikan.me/RSS/test", item_path="Mikan")
        feeds = await mock_dl.rss_get_feeds()
        assert "Mikan" in feeds
        assert feeds["Mikan"]["url"] == "https://mikan.me/RSS/test"

    async def test_remove_feed(self, mock_dl):
        """Removing a feed deletes its entry."""
        await mock_dl.rss_add_feed(url="https://example.com", item_path="Feed1")
        await mock_dl.rss_remove_item(item_path="Feed1")
        feeds = await mock_dl.rss_get_feeds()
        assert "Feed1" not in feeds

    async def test_remove_nonexistent_feed_no_error(self, mock_dl):
        # Removing a feed that was never added is a silent no-op.
        await mock_dl.rss_remove_item(item_path="nonexistent")

    async def test_get_feeds_initially_empty(self, mock_dl):
        feeds = await mock_dl.rss_get_feeds()
        assert feeds == {}
# ---------------------------------------------------------------------------
# Rules
# ---------------------------------------------------------------------------
class TestMockDownloaderRules:
    """RSS auto-download rule storage."""

    async def test_set_rule_stored(self, mock_dl):
        """A set rule is retrievable with its definition intact."""
        rule_def = {"enable": True, "mustContain": "Anime"}
        await mock_dl.rss_set_rule("rule1", rule_def)
        rules = await mock_dl.get_download_rule()
        assert "rule1" in rules
        assert rules["rule1"]["mustContain"] == "Anime"

    async def test_remove_rule(self, mock_dl):
        """Removing a rule deletes its entry."""
        await mock_dl.rss_set_rule("rule1", {"enable": True})
        await mock_dl.remove_rule("rule1")
        rules = await mock_dl.get_download_rule()
        assert "rule1" not in rules

    async def test_remove_nonexistent_rule_no_error(self, mock_dl):
        # Removing an unknown rule is a silent no-op.
        await mock_dl.remove_rule("nonexistent")

    async def test_get_download_rule_initially_empty(self, mock_dl):
        rules = await mock_dl.get_download_rule()
        assert rules == {}
# ---------------------------------------------------------------------------
# Move / path
# ---------------------------------------------------------------------------
class TestMockDownloaderMovePath:
    """Relocating torrents and querying their save paths."""

    async def test_move_torrent_updates_save_path(self, mock_dl):
        """move_torrent rewrites the stored save_path."""
        h = mock_dl.add_mock_torrent("Anime", save_path="/old/path")
        await mock_dl.move_torrent(hashes=h, new_location="/new/path")
        assert mock_dl._torrents[h]["save_path"] == "/new/path"

    async def test_move_multiple_pipe_separated(self, mock_dl):
        """'|'-joined hashes move several torrents in one call."""
        h1 = mock_dl.add_mock_torrent("Anime A", save_path="/old")
        h2 = mock_dl.add_mock_torrent("Anime B", save_path="/old")
        await mock_dl.move_torrent(hashes=f"{h1}|{h2}", new_location="/new")
        assert mock_dl._torrents[h1]["save_path"] == "/new"
        assert mock_dl._torrents[h2]["save_path"] == "/new"

    async def test_get_torrent_path_known_hash(self, mock_dl):
        """get_torrent_path returns the stored save_path."""
        h = mock_dl.add_mock_torrent("Anime", save_path="/downloads/Bangumi")
        path = await mock_dl.get_torrent_path(h)
        assert path == "/downloads/Bangumi"

    async def test_get_torrent_path_unknown_hash_returns_default(self, mock_dl):
        # Unknown hashes fall back to the mock's default download dir.
        path = await mock_dl.get_torrent_path("nonexistent")
        assert path == "/tmp/mock-downloads"
# ---------------------------------------------------------------------------
# Category assignment
# ---------------------------------------------------------------------------
class TestMockDownloaderSetCategory:
    """Reassigning a torrent's category."""

    async def test_set_category_updates_torrent(self, mock_dl):
        torrent_hash = mock_dl.add_mock_torrent("Anime", category="Bangumi")
        await mock_dl.set_category(torrent_hash, "BangumiCollection")
        assert mock_dl._torrents[torrent_hash]["category"] == "BangumiCollection"

    async def test_set_category_unknown_hash_no_error(self, mock_dl):
        # Unknown hashes are ignored silently.
        await mock_dl.set_category("deadbeef", "Bangumi")
# ---------------------------------------------------------------------------
# Tags
# ---------------------------------------------------------------------------
class TestMockDownloaderTags:
    """Tag management on stored torrents."""

    async def test_add_tag_appends(self, mock_dl):
        """A new tag is appended to the torrent's tag list."""
        h = mock_dl.add_mock_torrent("Anime")
        await mock_dl.add_tag(h, "ab:1")
        assert "ab:1" in mock_dl._torrents[h]["tags"]

    async def test_add_tag_no_duplicates(self, mock_dl):
        """Re-adding the same tag must not create a duplicate entry."""
        h = mock_dl.add_mock_torrent("Anime")
        await mock_dl.add_tag(h, "ab:1")
        await mock_dl.add_tag(h, "ab:1")
        assert mock_dl._torrents[h]["tags"].count("ab:1") == 1

    async def test_add_tag_unknown_hash_no_error(self, mock_dl):
        # Tagging an unknown hash is a silent no-op.
        await mock_dl.add_tag("deadbeef", "ab:1")

    async def test_multiple_tags_on_same_torrent(self, mock_dl):
        """Several distinct tags coexist on one torrent."""
        h = mock_dl.add_mock_torrent("Anime")
        await mock_dl.add_tag(h, "ab:1")
        await mock_dl.add_tag(h, "group:sub")
        assert "ab:1" in mock_dl._torrents[h]["tags"]
        assert "group:sub" in mock_dl._torrents[h]["tags"]
# ---------------------------------------------------------------------------
# add_mock_torrent helper
# ---------------------------------------------------------------------------
class TestAddMockTorrentHelper:
    """Behaviour of the add_mock_torrent() test-seeding helper."""

    def test_generates_hash_from_name(self, mock_dl):
        """Omitting the hash derives one from the torrent name."""
        h = mock_dl.add_mock_torrent("Anime")
        assert h is not None
        assert len(h) == 40  # SHA1 hex digest

    def test_explicit_hash_used(self, mock_dl):
        """An explicitly supplied hash is stored verbatim."""
        h = mock_dl.add_mock_torrent("Anime", hash="cafebabe" + "0" * 32)
        assert h == "cafebabe" + "0" * 32

    def test_torrent_state_is_completed_by_default(self, mock_dl):
        """Seeded torrents default to a finished download."""
        h = mock_dl.add_mock_torrent("Anime")
        assert mock_dl._torrents[h]["state"] == "completed"
        assert mock_dl._torrents[h]["progress"] == 1.0

    def test_torrent_state_custom(self, mock_dl):
        """A non-completed state comes with partial (0.5) progress."""
        h = mock_dl.add_mock_torrent("Anime", state="downloading")
        assert mock_dl._torrents[h]["state"] == "downloading"
        assert mock_dl._torrents[h]["progress"] == 0.5

    def test_default_file_is_mkv(self, mock_dl):
        """When no files are given, a single .mkv placeholder is created."""
        h = mock_dl.add_mock_torrent("My Anime")
        files = mock_dl._torrents[h]["files"]
        assert len(files) == 1
        assert files[0]["name"].endswith(".mkv")

    def test_custom_files_stored(self, mock_dl):
        """An explicit file list replaces the default placeholder."""
        custom_files = [{"name": "ep01.mkv"}, {"name": "ep02.mkv"}]
        h = mock_dl.add_mock_torrent("Anime", files=custom_files)
        assert len(mock_dl._torrents[h]["files"]) == 2

2
backend/uv.lock generated
View File

@@ -61,7 +61,7 @@ wheels = [
[[package]]
name = "auto-bangumi"
version = "3.2.3b5"
version = "3.2.3"
source = { virtual = "." }
dependencies = [
{ name = "aiosqlite" },

View File

@@ -0,0 +1,68 @@
<script lang="ts" setup>
// Security settings panel: renders the four `security` config lists
// (login/MCP IP whitelists and bearer tokens) as dynamic-tag editors.
import type { Security } from '#/config';
import type { SettingItem } from '#/components';

const { t } = useMyI18n();
const { getSettingGroup } = useConfigStore();
// Reactive slice of the global config store holding the `security` section.
const security = getSettingGroup('security');

// One entry per Security key; each becomes an <ab-setting> row below.
const items: SettingItem<Security>[] = [
  {
    // CIDR ranges allowed to reach the login endpoint.
    configKey: 'login_whitelist',
    label: () => t('config.security_set.login_whitelist'),
    type: 'dynamic-tags',
    prop: {
      placeholder: '192.168.0.0/16',
    },
  },
  {
    // Bearer tokens accepted by the auth endpoints.
    configKey: 'login_tokens',
    label: () => t('config.security_set.login_tokens'),
    type: 'dynamic-tags',
    prop: {
      placeholder: 'your-api-token',
    },
    bottomLine: true,
  },
  {
    // CIDR ranges granted IP-based access to the MCP server.
    configKey: 'mcp_whitelist',
    label: () => t('config.security_set.mcp_whitelist'),
    type: 'dynamic-tags',
    prop: {
      placeholder: '127.0.0.0/8',
    },
  },
  {
    // Bearer tokens accepted by the MCP middleware.
    configKey: 'mcp_tokens',
    label: () => t('config.security_set.mcp_tokens'),
    type: 'dynamic-tags',
    prop: {
      placeholder: 'your-mcp-token',
    },
  },
];
</script>

<template>
  <ab-fold-panel :title="$t('config.security_set.title')">
    <!-- Localized explanation of whitelist/token semantics. -->
    <p class="hint-text">{{ $t('config.security_set.hint') }}</p>
    <div space-y-8>
      <ab-setting
        v-for="i in items"
        :key="i.configKey"
        v-bind="i"
        v-model:data="security[i.configKey]"
      ></ab-setting>
    </div>
  </ab-fold-panel>
</template>

<style lang="scss" scoped>
.hint-text {
  font-size: 12px;
  color: var(--color-text-secondary);
  margin-bottom: 12px;
  line-height: 1.5;
}
</style>

View File

@@ -49,6 +49,7 @@ onActivated(() => {
<config-player></config-player>
<config-openai></config-openai>
<config-passkey></config-passkey>
<config-security></config-security>
</div>
</div>

View File

@@ -2,10 +2,6 @@ import type { TupleToUnion } from './utils';
/** 下载方式 */
export type DownloaderType = ['qbittorrent'];
/** rss parser 源 */
export type RssParserType = ['mikan'];
/** rss parser 方法 */
export type RssParserMethodType = ['tmdb', 'mikan', 'parser'];
/** rss parser 语言 */
export type RssParserLang = ['zh', 'en', 'jp'];
/** 重命名方式 */
@@ -44,12 +40,8 @@ export interface Downloader {
}
export interface RssParser {
enable: boolean;
type: TupleToUnion<RssParserType>;
token: string;
custom_url: string;
filter: Array<string>;
language: TupleToUnion<RssParserLang>;
parser_type: TupleToUnion<RssParserMethodType>;
}
export interface BangumiManage {
enable: boolean;
@@ -105,6 +97,17 @@ export interface ExperimentalOpenAI {
deployment_id?: string;
}
/** Access control for the login endpoint and MCP server.
 * Whitelist entries are IPv4/IPv6 CIDR strings (e.g. "192.168.0.0/16").
 * An empty login_whitelist allows all IPs; an empty mcp_whitelist denies all IP-based MCP access.
 */
export interface Security {
  /** CIDR ranges allowed to reach the login endpoint (empty = allow all). */
  login_whitelist: string[];
  /** Bearer tokens accepted by the auth endpoints in place of a cookie login. */
  login_tokens: string[];
  /** CIDR ranges granted IP-based access to the MCP server (empty = deny all). */
  mcp_whitelist: string[];
  /** Bearer tokens accepted by the MCP security middleware. */
  mcp_tokens: string[];
}
export interface Config {
program: Program;
downloader: Downloader;
@@ -114,6 +117,7 @@ export interface Config {
proxy: Proxy;
notification: Notification;
experimental_openai: ExperimentalOpenAI;
security: Security;
}
export const initConfig: Config = {
@@ -132,12 +136,8 @@ export const initConfig: Config = {
},
rss_parser: {
enable: true,
type: 'mikan',
token: '',
custom_url: '',
filter: [],
language: 'zh',
parser_type: 'parser',
},
bangumi_manage: {
enable: true,
@@ -171,4 +171,10 @@ export const initConfig: Config = {
api_version: '2020-05-03',
deployment_id: '',
},
security: {
login_whitelist: [],
login_tokens: [],
mcp_whitelist: [],
mcp_tokens: [],
},
};

View File

@@ -40,6 +40,7 @@ declare global {
const getCurrentInstance: typeof import('vue')['getCurrentInstance']
const getCurrentScope: typeof import('vue')['getCurrentScope']
const h: typeof import('vue')['h']
const i18n: typeof import('../../src/hooks/useMyI18n')['i18n']
const inject: typeof import('vue')['inject']
const isProxy: typeof import('vue')['isProxy']
const isReactive: typeof import('vue')['isReactive']

View File

@@ -55,6 +55,7 @@ declare module '@vue/runtime-core' {
ConfigPlayer: typeof import('./../../src/components/setting/config-player.vue')['default']
ConfigProxy: typeof import('./../../src/components/setting/config-proxy.vue')['default']
ConfigSearchProvider: typeof import('./../../src/components/setting/config-search-provider.vue')['default']
ConfigSecurity: typeof import('./../../src/components/setting/config-security.vue')['default']
MediaQuery: typeof import('./../../src/components/media-query.vue')['default']
RouterLink: typeof import('vue-router')['RouterLink']
RouterView: typeof import('vue-router')['RouterView']