mirror of
https://github.com/EstrellaXD/Auto_Bangumi.git
synced 2026-03-20 03:46:40 +08:00
feat(security): add security config UI and improve auth/MCP security
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
38
CHANGELOG.md
38
CHANGELOG.md
@@ -1,3 +1,41 @@
|
|||||||
|
# [Unreleased]
|
||||||
|
|
||||||
|
## Backend
|
||||||
|
|
||||||
|
### Added
|
||||||
|
|
||||||
|
- 新增 `Security` 配置模型,支持登录 IP 白名单、MCP IP 白名单和 Bearer Token 认证
|
||||||
|
- 新增登录端点 IP 白名单检查中间件 (`check_login_ip`)
|
||||||
|
- MCP 安全中间件升级为可配置模式:支持 CIDR 白名单 + Bearer Token 双重认证
|
||||||
|
- 认证端点支持 `Authorization: Bearer` 令牌绕过 Cookie 登录
|
||||||
|
- 配置 API `_sanitize_dict` 修复:仅对字符串值进行脱敏,避免误处理非字符串字段
|
||||||
|
|
||||||
|
- 新增番剧放送日手动设置 API (`PATCH /api/v1/bangumi/{id}/weekday`),支持锁定放送日防止日历刷新覆盖
|
||||||
|
- 数据库迁移 v9:`bangumi` 表新增 `weekday_locked` 列
|
||||||
|
|
||||||
|
### Changed
|
||||||
|
|
||||||
|
- 重构认证模块:提取 `_issue_token` 公共方法,消除 3 处重复的 JWT 签发逻辑
|
||||||
|
- `get_current_user` 简化为三级认证(DEV 绕过 → Bearer Token → Cookie JWT)
|
||||||
|
- `LocalNetworkMiddleware` 重命名为 `McpAccessMiddleware`,从硬编码 RFC 1918 改为读取配置
|
||||||
|
|
||||||
|
### Tests
|
||||||
|
|
||||||
|
- 新增 101 个单元测试覆盖安全、认证、配置、下载器和 MockDownloader 模块
|
||||||
|
|
||||||
|
## Frontend
|
||||||
|
|
||||||
|
### Added
|
||||||
|
|
||||||
|
- 新增日历拖拽排列功能:可将「未知」番剧拖入星期列,自动设置放送日并锁定
|
||||||
|
- 拖入后显示紫色图钉图标,鼠标悬停显示取消按钮
|
||||||
|
- 锁定的番剧在日历刷新时不会被覆盖
|
||||||
|
- 使用 vuedraggable 实现流畅拖拽动画
|
||||||
|
- 新增安全设置组件 (`config-security.vue`),支持在 WebUI 中配置 IP 白名单和 Token
|
||||||
|
- 前端 `Security` 类型定义和初始化配置
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
# [3.2.3] - 2026-02-23
|
# [3.2.3] - 2026-02-23
|
||||||
|
|
||||||
## Backend
|
## Backend
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "auto-bangumi"
|
name = "auto-bangumi"
|
||||||
version = "3.2.3"
|
version = "3.2.4"
|
||||||
description = "AutoBangumi - Automated anime download manager"
|
description = "AutoBangumi - Automated anime download manager"
|
||||||
requires-python = ">=3.13"
|
requires-python = ">=3.13"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from module.models.user import User, UserUpdate
|
|||||||
from module.security.api import (
|
from module.security.api import (
|
||||||
active_user,
|
active_user,
|
||||||
auth_user,
|
auth_user,
|
||||||
|
check_login_ip,
|
||||||
get_current_user,
|
get_current_user,
|
||||||
update_user_info,
|
update_user_info,
|
||||||
)
|
)
|
||||||
@@ -18,17 +19,26 @@ from .response import u_response
|
|||||||
|
|
||||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||||
|
|
||||||
|
_TOKEN_EXPIRY_DAYS = 1
|
||||||
|
_TOKEN_MAX_AGE = 86400
|
||||||
|
|
||||||
@router.post("/login", response_model=dict)
|
|
||||||
|
def _issue_token(username: str, response: Response) -> dict:
|
||||||
|
"""Create a JWT, set it as an HttpOnly cookie, and return the bearer payload."""
|
||||||
|
token = create_access_token(
|
||||||
|
data={"sub": username}, expires_delta=timedelta(days=_TOKEN_EXPIRY_DAYS)
|
||||||
|
)
|
||||||
|
response.set_cookie(key="token", value=token, httponly=True, max_age=_TOKEN_MAX_AGE)
|
||||||
|
return {"access_token": token, "token_type": "bearer"}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/login", response_model=dict, dependencies=[Depends(check_login_ip)])
|
||||||
async def login(response: Response, form_data=Depends(OAuth2PasswordRequestForm)):
|
async def login(response: Response, form_data=Depends(OAuth2PasswordRequestForm)):
|
||||||
|
"""Authenticate with username/password and issue a session token."""
|
||||||
user = User(username=form_data.username, password=form_data.password)
|
user = User(username=form_data.username, password=form_data.password)
|
||||||
resp = auth_user(user)
|
resp = auth_user(user)
|
||||||
if resp.status:
|
if resp.status:
|
||||||
token = create_access_token(
|
return _issue_token(user.username, response)
|
||||||
data={"sub": user.username}, expires_delta=timedelta(days=1)
|
|
||||||
)
|
|
||||||
response.set_cookie(key="token", value=token, httponly=True, max_age=86400)
|
|
||||||
return {"access_token": token, "token_type": "bearer"}
|
|
||||||
return u_response(resp)
|
return u_response(resp)
|
||||||
|
|
||||||
|
|
||||||
@@ -36,6 +46,7 @@ async def login(response: Response, form_data=Depends(OAuth2PasswordRequestForm)
|
|||||||
"/refresh_token", response_model=dict, dependencies=[Depends(get_current_user)]
|
"/refresh_token", response_model=dict, dependencies=[Depends(get_current_user)]
|
||||||
)
|
)
|
||||||
async def refresh(response: Response, token: str = Cookie(None)):
|
async def refresh(response: Response, token: str = Cookie(None)):
|
||||||
|
"""Refresh the current session token and update the active-user timestamp."""
|
||||||
payload = decode_token(token)
|
payload = decode_token(token)
|
||||||
username = payload.get("sub") if payload else None
|
username = payload.get("sub") if payload else None
|
||||||
if not username:
|
if not username:
|
||||||
@@ -43,17 +54,14 @@ async def refresh(response: Response, token: str = Cookie(None)):
|
|||||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Unauthorized"
|
status_code=status.HTTP_401_UNAUTHORIZED, detail="Unauthorized"
|
||||||
)
|
)
|
||||||
active_user[username] = datetime.now()
|
active_user[username] = datetime.now()
|
||||||
new_token = create_access_token(
|
return _issue_token(username, response)
|
||||||
data={"sub": username}, expires_delta=timedelta(days=1)
|
|
||||||
)
|
|
||||||
response.set_cookie(key="token", value=new_token, httponly=True, max_age=86400)
|
|
||||||
return {"access_token": new_token, "token_type": "bearer"}
|
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"/logout", response_model=APIResponse, dependencies=[Depends(get_current_user)]
|
"/logout", response_model=APIResponse, dependencies=[Depends(get_current_user)]
|
||||||
)
|
)
|
||||||
async def logout(response: Response, token: str = Cookie(None)):
|
async def logout(response: Response, token: str = Cookie(None)):
|
||||||
|
"""Invalidate the session and clear the token cookie."""
|
||||||
payload = decode_token(token)
|
payload = decode_token(token)
|
||||||
username = payload.get("sub") if payload else None
|
username = payload.get("sub") if payload else None
|
||||||
if username:
|
if username:
|
||||||
@@ -69,6 +77,7 @@ async def logout(response: Response, token: str = Cookie(None)):
|
|||||||
async def update_user(
|
async def update_user(
|
||||||
user_data: UserUpdate, response: Response, token: str = Cookie(None)
|
user_data: UserUpdate, response: Response, token: str = Cookie(None)
|
||||||
):
|
):
|
||||||
|
"""Update credentials for the current user and re-issue a fresh token."""
|
||||||
payload = decode_token(token)
|
payload = decode_token(token)
|
||||||
old_user = payload.get("sub") if payload else None
|
old_user = payload.get("sub") if payload else None
|
||||||
if not old_user:
|
if not old_user:
|
||||||
@@ -76,17 +85,4 @@ async def update_user(
|
|||||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Unauthorized"
|
status_code=status.HTTP_401_UNAUTHORIZED, detail="Unauthorized"
|
||||||
)
|
)
|
||||||
if update_user_info(user_data, old_user):
|
if update_user_info(user_data, old_user):
|
||||||
token = create_access_token(
|
return {**_issue_token(old_user, response), "message": "update success"}
|
||||||
data={"sub": old_user}, expires_delta=timedelta(days=1)
|
|
||||||
)
|
|
||||||
response.set_cookie(
|
|
||||||
key="token",
|
|
||||||
value=token,
|
|
||||||
httponly=True,
|
|
||||||
max_age=86400,
|
|
||||||
)
|
|
||||||
return {
|
|
||||||
"access_token": token,
|
|
||||||
"token_type": "bearer",
|
|
||||||
"message": "update success",
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -14,11 +14,12 @@ _SENSITIVE_KEYS = ("password", "api_key", "token", "secret")
|
|||||||
|
|
||||||
|
|
||||||
def _sanitize_dict(d: dict) -> dict:
|
def _sanitize_dict(d: dict) -> dict:
|
||||||
|
"""Recursively mask string values whose keys contain sensitive keywords."""
|
||||||
result = {}
|
result = {}
|
||||||
for k, v in d.items():
|
for k, v in d.items():
|
||||||
if isinstance(v, dict):
|
if isinstance(v, dict):
|
||||||
result[k] = _sanitize_dict(v)
|
result[k] = _sanitize_dict(v)
|
||||||
elif any(s in k.lower() for s in _SENSITIVE_KEYS):
|
elif isinstance(v, str) and any(s in k.lower() for s in _SENSITIVE_KEYS):
|
||||||
result[k] = "********"
|
result[k] = "********"
|
||||||
else:
|
else:
|
||||||
result[k] = v
|
result[k] = v
|
||||||
@@ -27,6 +28,7 @@ def _sanitize_dict(d: dict) -> dict:
|
|||||||
|
|
||||||
@router.get("/get", dependencies=[Depends(get_current_user)])
|
@router.get("/get", dependencies=[Depends(get_current_user)])
|
||||||
async def get_config():
|
async def get_config():
|
||||||
|
"""Return the current configuration with sensitive fields masked."""
|
||||||
return _sanitize_dict(settings.dict())
|
return _sanitize_dict(settings.dict())
|
||||||
|
|
||||||
|
|
||||||
@@ -34,6 +36,7 @@ async def get_config():
|
|||||||
"/update", response_model=APIResponse, dependencies=[Depends(get_current_user)]
|
"/update", response_model=APIResponse, dependencies=[Depends(get_current_user)]
|
||||||
)
|
)
|
||||||
async def update_config(config: Config):
|
async def update_config(config: Config):
|
||||||
|
"""Persist and reload configuration from the supplied payload."""
|
||||||
try:
|
try:
|
||||||
settings.save(config_dict=config.dict())
|
settings.save(config_dict=config.dict())
|
||||||
settings.load()
|
settings.load()
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ from dotenv import load_dotenv
|
|||||||
|
|
||||||
from module.models.config import Config
|
from module.models.config import Config
|
||||||
|
|
||||||
from .const import ENV_TO_ATTR
|
from .const import DEFAULT_SETTINGS, ENV_TO_ATTR
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
CONFIG_ROOT = Path("config")
|
CONFIG_ROOT = Path("config")
|
||||||
@@ -27,6 +27,15 @@ CONFIG_PATH = (
|
|||||||
|
|
||||||
|
|
||||||
class Settings(Config):
|
class Settings(Config):
|
||||||
|
"""Runtime configuration singleton.
|
||||||
|
|
||||||
|
On construction, loads from ``CONFIG_PATH`` if the file exists (and
|
||||||
|
immediately re-saves to apply any migrations), otherwise bootstraps
|
||||||
|
defaults from environment variables via ``init()``.
|
||||||
|
|
||||||
|
Use ``settings`` module-level instance rather than instantiating directly.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if CONFIG_PATH.exists():
|
if CONFIG_PATH.exists():
|
||||||
@@ -36,6 +45,7 @@ class Settings(Config):
|
|||||||
self.init()
|
self.init()
|
||||||
|
|
||||||
def load(self):
|
def load(self):
|
||||||
|
"""Load and validate configuration from ``CONFIG_PATH``, applying migrations."""
|
||||||
with open(CONFIG_PATH, "r", encoding="utf-8") as f:
|
with open(CONFIG_PATH, "r", encoding="utf-8") as f:
|
||||||
config = json.load(f)
|
config = json.load(f)
|
||||||
config = self._migrate_old_config(config)
|
config = self._migrate_old_config(config)
|
||||||
@@ -65,20 +75,27 @@ class Settings(Config):
|
|||||||
for key in ("type", "custom_url", "token", "enable_tmdb"):
|
for key in ("type", "custom_url", "token", "enable_tmdb"):
|
||||||
rss_parser.pop(key, None)
|
rss_parser.pop(key, None)
|
||||||
|
|
||||||
|
# Add security section if missing (preserves local-network MCP default)
|
||||||
|
if "security" not in config:
|
||||||
|
config["security"] = DEFAULT_SETTINGS["security"]
|
||||||
|
|
||||||
return config
|
return config
|
||||||
|
|
||||||
def save(self, config_dict: dict | None = None):
|
def save(self, config_dict: dict | None = None):
|
||||||
|
"""Write configuration to ``CONFIG_PATH``. Uses current state when no dict supplied."""
|
||||||
if not config_dict:
|
if not config_dict:
|
||||||
config_dict = self.model_dump()
|
config_dict = self.model_dump()
|
||||||
with open(CONFIG_PATH, "w", encoding="utf-8") as f:
|
with open(CONFIG_PATH, "w", encoding="utf-8") as f:
|
||||||
json.dump(config_dict, f, indent=4, ensure_ascii=False)
|
json.dump(config_dict, f, indent=4, ensure_ascii=False)
|
||||||
|
|
||||||
def init(self):
|
def init(self):
|
||||||
|
"""Bootstrap a new config file from ``.env`` and environment variables."""
|
||||||
load_dotenv(".env")
|
load_dotenv(".env")
|
||||||
self.__load_from_env()
|
self.__load_from_env()
|
||||||
self.save()
|
self.save()
|
||||||
|
|
||||||
def __load_from_env(self):
|
def __load_from_env(self):
|
||||||
|
"""Apply ``ENV_TO_ATTR`` mappings from the process environment to the config dict."""
|
||||||
config_dict = self.model_dump()
|
config_dict = self.model_dump()
|
||||||
for key, section in ENV_TO_ATTR.items():
|
for key, section in ENV_TO_ATTR.items():
|
||||||
for env, attr in section.items():
|
for env, attr in section.items():
|
||||||
@@ -97,12 +114,11 @@ class Settings(Config):
|
|||||||
logger.info("Config loaded from env")
|
logger.info("Config loaded from env")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def __val_from_env(env: str, attr: tuple):
|
def __val_from_env(env: str, attr: tuple | str):
|
||||||
|
"""Return the environment variable value, applying the converter when attr is a tuple."""
|
||||||
if isinstance(attr, tuple):
|
if isinstance(attr, tuple):
|
||||||
conv_func = attr[1]
|
return attr[1](os.environ[env])
|
||||||
return conv_func(os.environ[env])
|
return os.environ[env]
|
||||||
else:
|
|
||||||
return os.environ[env]
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def group_rules(self):
|
def group_rules(self):
|
||||||
|
|||||||
@@ -1,4 +1,8 @@
|
|||||||
# -*- encoding: utf-8 -*-
|
# -*- encoding: utf-8 -*-
|
||||||
|
# DEFAULT_SETTINGS: factory defaults written to config.json on first run.
|
||||||
|
# ENV_TO_ATTR: maps AB_* environment variables to Config model attribute paths.
|
||||||
|
# Values are either a string attr name, a (attr_name, converter) tuple, or a
|
||||||
|
# list of such tuples when a single env var sets multiple attributes.
|
||||||
DEFAULT_SETTINGS = {
|
DEFAULT_SETTINGS = {
|
||||||
"program": {
|
"program": {
|
||||||
"rss_time": 900,
|
"rss_time": 900,
|
||||||
@@ -46,6 +50,20 @@ DEFAULT_SETTINGS = {
|
|||||||
"model": "gpt-3.5-turbo",
|
"model": "gpt-3.5-turbo",
|
||||||
"deployment_id": "",
|
"deployment_id": "",
|
||||||
},
|
},
|
||||||
|
"security": {
|
||||||
|
"login_whitelist": [],
|
||||||
|
"login_tokens": [],
|
||||||
|
"mcp_whitelist": [
|
||||||
|
"127.0.0.0/8",
|
||||||
|
"10.0.0.0/8",
|
||||||
|
"172.16.0.0/12",
|
||||||
|
"192.168.0.0/16",
|
||||||
|
"::1/128",
|
||||||
|
"fe80::/10",
|
||||||
|
"fc00::/7",
|
||||||
|
],
|
||||||
|
"mcp_tokens": [],
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -99,8 +117,11 @@ ENV_TO_ATTR = {
|
|||||||
|
|
||||||
|
|
||||||
class BCOLORS:
|
class BCOLORS:
|
||||||
|
"""ANSI colour helpers for terminal output."""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _(color: str, *args: str) -> str:
|
def _(color: str, *args: str) -> str:
|
||||||
|
"""Wrap *args* in the given ANSI colour code and reset at the end."""
|
||||||
strings = [str(s) for s in args]
|
strings = [str(s) for s in args]
|
||||||
return f"{color}{', '.join(strings)}{BCOLORS.ENDC}"
|
return f"{color}{', '.join(strings)}{BCOLORS.ENDC}"
|
||||||
|
|
||||||
|
|||||||
@@ -29,10 +29,10 @@ class MockDownloader:
|
|||||||
"rss_processing_enabled": True,
|
"rss_processing_enabled": True,
|
||||||
"rss_refresh_interval": 30,
|
"rss_refresh_interval": 30,
|
||||||
}
|
}
|
||||||
logger.info("[MockDownloader] Initialized")
|
logger.debug("[MockDownloader] Initialized")
|
||||||
|
|
||||||
async def auth(self, retry=3) -> bool:
|
async def auth(self, retry=3) -> bool:
|
||||||
logger.info("[MockDownloader] Auth successful (mocked)")
|
logger.debug("[MockDownloader] Auth successful (mocked)")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
async def logout(self):
|
async def logout(self):
|
||||||
|
|||||||
@@ -11,6 +11,13 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class DownloadClient(TorrentPath):
|
class DownloadClient(TorrentPath):
|
||||||
|
"""Unified async download client.
|
||||||
|
|
||||||
|
Wraps qBittorrent, Aria2, or MockDownloader behind a common interface.
|
||||||
|
Intended to be used as an async context manager; authentication is
|
||||||
|
performed on ``__aenter__`` and the session is closed on ``__aexit__``.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.client = self.__getClient()
|
self.client = self.__getClient()
|
||||||
@@ -18,27 +25,28 @@ class DownloadClient(TorrentPath):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def __getClient():
|
def __getClient():
|
||||||
type = settings.downloader.type
|
"""Instantiate the configured downloader client (qbittorrent | aria2 | mock)."""
|
||||||
|
downloader_type = settings.downloader.type
|
||||||
host = settings.downloader.host
|
host = settings.downloader.host
|
||||||
username = settings.downloader.username
|
username = settings.downloader.username
|
||||||
password = settings.downloader.password
|
password = settings.downloader.password
|
||||||
ssl = settings.downloader.ssl
|
ssl = settings.downloader.ssl
|
||||||
if type == "qbittorrent":
|
if downloader_type == "qbittorrent":
|
||||||
from .client.qb_downloader import QbDownloader
|
from .client.qb_downloader import QbDownloader
|
||||||
|
|
||||||
return QbDownloader(host, username, password, ssl)
|
return QbDownloader(host, username, password, ssl)
|
||||||
elif type == "aria2":
|
elif downloader_type == "aria2":
|
||||||
from .client.aria2_downloader import Aria2Downloader
|
from .client.aria2_downloader import Aria2Downloader
|
||||||
|
|
||||||
return Aria2Downloader(host, username, password)
|
return Aria2Downloader(host, username, password)
|
||||||
elif type == "mock":
|
elif downloader_type == "mock":
|
||||||
from .client.mock_downloader import MockDownloader
|
from .client.mock_downloader import MockDownloader
|
||||||
|
|
||||||
logger.info("[Downloader] Using MockDownloader for local development")
|
logger.debug("[Downloader] Using MockDownloader for local development")
|
||||||
return MockDownloader()
|
return MockDownloader()
|
||||||
else:
|
else:
|
||||||
logger.error(f"[Downloader] Unsupported downloader type: {type}")
|
logger.error("[Downloader] Unsupported downloader type: %s", downloader_type)
|
||||||
raise Exception(f"Unsupported downloader type: {type}")
|
raise Exception(f"Unsupported downloader type: {downloader_type}")
|
||||||
|
|
||||||
async def __aenter__(self):
|
async def __aenter__(self):
|
||||||
if not self.authed:
|
if not self.authed:
|
||||||
@@ -65,6 +73,7 @@ class DownloadClient(TorrentPath):
|
|||||||
return await self.client.check_host()
|
return await self.client.check_host()
|
||||||
|
|
||||||
async def init_downloader(self):
|
async def init_downloader(self):
|
||||||
|
"""Apply required qBittorrent RSS preferences and create the Bangumi category."""
|
||||||
prefs = {
|
prefs = {
|
||||||
"rss_auto_downloading_enabled": True,
|
"rss_auto_downloading_enabled": True,
|
||||||
"rss_max_articles_per_feed": 500,
|
"rss_max_articles_per_feed": 500,
|
||||||
@@ -84,6 +93,7 @@ class DownloadClient(TorrentPath):
|
|||||||
settings.downloader.path = self._join_path(prefs["save_path"], "Bangumi")
|
settings.downloader.path = self._join_path(prefs["save_path"], "Bangumi")
|
||||||
|
|
||||||
async def set_rule(self, data: Bangumi):
|
async def set_rule(self, data: Bangumi):
|
||||||
|
"""Create or update a qBittorrent RSS auto-download rule for one bangumi entry."""
|
||||||
data.rule_name = self._rule_name(data)
|
data.rule_name = self._rule_name(data)
|
||||||
data.save_path = self._gen_save_path(data)
|
data.save_path = self._gen_save_path(data)
|
||||||
rule = {
|
rule = {
|
||||||
@@ -145,6 +155,12 @@ class DownloadClient(TorrentPath):
|
|||||||
await self.client.torrents_resume(hashes)
|
await self.client.torrents_resume(hashes)
|
||||||
|
|
||||||
async def add_torrent(self, torrent: Torrent | list, bangumi: Bangumi) -> bool:
|
async def add_torrent(self, torrent: Torrent | list, bangumi: Bangumi) -> bool:
|
||||||
|
"""Download a torrent (or list of torrents) for the given bangumi entry.
|
||||||
|
|
||||||
|
Handles both magnet links and .torrent file URLs, fetching file bytes
|
||||||
|
when necessary. Tags each torrent with ``ab:<bangumi_id>`` for later
|
||||||
|
episode-offset lookup during rename.
|
||||||
|
"""
|
||||||
if not bangumi.save_path:
|
if not bangumi.save_path:
|
||||||
bangumi.save_path = self._gen_save_path(bangumi)
|
bangumi.save_path = self._gen_save_path(bangumi)
|
||||||
async with RequestContent() as req:
|
async with RequestContent() as req:
|
||||||
|
|||||||
@@ -1,48 +1,68 @@
|
|||||||
"""MCP access control: restricts connections to local network addresses only."""
|
"""MCP access control: configurable IP whitelist and bearer token authentication."""
|
||||||
|
|
||||||
import ipaddress
|
import ipaddress
|
||||||
import logging
|
import logging
|
||||||
|
from functools import lru_cache
|
||||||
|
|
||||||
from starlette.middleware.base import BaseHTTPMiddleware
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
from starlette.responses import JSONResponse
|
from starlette.responses import JSONResponse
|
||||||
|
|
||||||
|
from module.conf import settings
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# RFC 1918 private ranges + loopback + IPv6 equivalents
|
|
||||||
_ALLOWED_NETWORKS = [
|
@lru_cache(maxsize=128)
|
||||||
ipaddress.ip_network("127.0.0.0/8"),
|
def _parse_network(cidr: str) -> ipaddress.IPv4Network | ipaddress.IPv6Network | None:
|
||||||
ipaddress.ip_network("10.0.0.0/8"),
|
try:
|
||||||
ipaddress.ip_network("172.16.0.0/12"),
|
return ipaddress.ip_network(cidr, strict=False)
|
||||||
ipaddress.ip_network("192.168.0.0/16"),
|
except ValueError:
|
||||||
ipaddress.ip_network("::1/128"),
|
logger.warning("[MCP] Invalid CIDR in whitelist: %s", cidr)
|
||||||
ipaddress.ip_network("fe80::/10"),
|
return None
|
||||||
ipaddress.ip_network("fc00::/7"),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def _is_local(host: str) -> bool:
|
def _is_allowed(host: str, whitelist: list[str]) -> bool:
|
||||||
"""Return True if *host* is a loopback or RFC 1918 private address."""
|
"""Return True if *host* falls within any CIDR range in *whitelist*."""
|
||||||
try:
|
try:
|
||||||
addr = ipaddress.ip_address(host)
|
addr = ipaddress.ip_address(host)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
return False
|
return False
|
||||||
return any(addr in net for net in _ALLOWED_NETWORKS)
|
for cidr in whitelist:
|
||||||
|
net = _parse_network(cidr)
|
||||||
|
if net and addr in net:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
class LocalNetworkMiddleware(BaseHTTPMiddleware):
|
def clear_network_cache():
|
||||||
"""Starlette middleware that blocks requests from non-local IP addresses.
|
"""Clear the parsed network cache (call after config reload)."""
|
||||||
|
_parse_network.cache_clear()
|
||||||
|
|
||||||
Returns HTTP 403 for any client outside loopback, RFC 1918, or IPv6
|
|
||||||
link-local/unique-local ranges.
|
class McpAccessMiddleware(BaseHTTPMiddleware):
|
||||||
|
"""Configurable access control for MCP endpoint.
|
||||||
|
|
||||||
|
Checks client IP against ``settings.security.mcp_whitelist`` CIDR ranges,
|
||||||
|
and ``Authorization`` header against ``settings.security.mcp_tokens``.
|
||||||
|
If the whitelist is empty and no tokens are configured, all access is denied.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
async def dispatch(self, request: Request, call_next):
|
async def dispatch(self, request: Request, call_next):
|
||||||
|
# Check bearer token first
|
||||||
|
auth_header = request.headers.get("authorization", "")
|
||||||
|
if auth_header.startswith("Bearer "):
|
||||||
|
token = auth_header[7:]
|
||||||
|
if token and token in settings.security.mcp_tokens:
|
||||||
|
return await call_next(request)
|
||||||
|
|
||||||
|
# Check IP whitelist
|
||||||
client_host = request.client.host if request.client else None
|
client_host = request.client.host if request.client else None
|
||||||
if not client_host or not _is_local(client_host):
|
if client_host and _is_allowed(client_host, settings.security.mcp_whitelist):
|
||||||
logger.warning("[MCP] Rejected non-local connection from %s", client_host)
|
return await call_next(request)
|
||||||
return JSONResponse(
|
|
||||||
status_code=403,
|
logger.warning("[MCP] Rejected connection from %s", client_host)
|
||||||
content={"error": "MCP access is restricted to local network"},
|
return JSONResponse(
|
||||||
)
|
status_code=403,
|
||||||
return await call_next(request)
|
content={"error": "MCP access denied"},
|
||||||
|
)
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ from starlette.requests import Request
|
|||||||
from starlette.routing import Mount, Route
|
from starlette.routing import Mount, Route
|
||||||
|
|
||||||
from .resources import RESOURCE_TEMPLATES, RESOURCES, handle_resource
|
from .resources import RESOURCE_TEMPLATES, RESOURCES, handle_resource
|
||||||
from .security import LocalNetworkMiddleware
|
from .security import McpAccessMiddleware
|
||||||
from .tools import TOOLS, handle_tool
|
from .tools import TOOLS, handle_tool
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -73,8 +73,8 @@ def create_mcp_starlette_app() -> Starlette:
|
|||||||
- ``GET /sse`` - SSE stream for MCP clients
|
- ``GET /sse`` - SSE stream for MCP clients
|
||||||
- ``POST /messages/`` - client-to-server message posting
|
- ``POST /messages/`` - client-to-server message posting
|
||||||
|
|
||||||
``LocalNetworkMiddleware`` is applied so the endpoint is only reachable
|
``McpAccessMiddleware`` is applied to enforce configurable IP whitelist
|
||||||
from loopback and RFC 1918 addresses.
|
and bearer token access control.
|
||||||
"""
|
"""
|
||||||
app = Starlette(
|
app = Starlette(
|
||||||
routes=[
|
routes=[
|
||||||
@@ -82,5 +82,5 @@ def create_mcp_starlette_app() -> Starlette:
|
|||||||
Mount("/messages", app=sse.handle_post_message),
|
Mount("/messages", app=sse.handle_post_message),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
app.add_middleware(LocalNetworkMiddleware)
|
app.add_middleware(McpAccessMiddleware)
|
||||||
return app
|
return app
|
||||||
|
|||||||
@@ -4,13 +4,27 @@ from typing import Literal, Optional
|
|||||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||||
|
|
||||||
|
|
||||||
|
def _expand(value: str | None) -> str:
|
||||||
|
"""Expand shell environment variables in *value*, returning empty string for None."""
|
||||||
|
return expandvars(value) if value else ""
|
||||||
|
|
||||||
|
|
||||||
class Program(BaseModel):
|
class Program(BaseModel):
|
||||||
|
"""Scheduler timing and WebUI port settings."""
|
||||||
|
|
||||||
rss_time: int = Field(900, description="Sleep time")
|
rss_time: int = Field(900, description="Sleep time")
|
||||||
rename_time: int = Field(60, description="Rename times in one loop")
|
rename_time: int = Field(60, description="Rename times in one loop")
|
||||||
webui_port: int = Field(7892, description="WebUI port")
|
webui_port: int = Field(7892, description="WebUI port")
|
||||||
|
|
||||||
|
|
||||||
class Downloader(BaseModel):
|
class Downloader(BaseModel):
|
||||||
|
"""Download client connection settings.
|
||||||
|
|
||||||
|
Credential fields (``host``, ``username``, ``password``) are stored with a
|
||||||
|
trailing underscore and exposed via properties that expand ``$VAR``
|
||||||
|
environment variable references at access time.
|
||||||
|
"""
|
||||||
|
|
||||||
type: str = Field("qbittorrent", description="Downloader type")
|
type: str = Field("qbittorrent", description="Downloader type")
|
||||||
host_: str = Field("172.17.0.1:8080", alias="host", description="Downloader host")
|
host_: str = Field("172.17.0.1:8080", alias="host", description="Downloader host")
|
||||||
username_: str = Field("admin", alias="username", description="Downloader username")
|
username_: str = Field("admin", alias="username", description="Downloader username")
|
||||||
@@ -22,24 +36,28 @@ class Downloader(BaseModel):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def host(self):
|
def host(self):
|
||||||
return expandvars(self.host_)
|
return _expand(self.host_)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def username(self):
|
def username(self):
|
||||||
return expandvars(self.username_)
|
return _expand(self.username_)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def password(self):
|
def password(self):
|
||||||
return expandvars(self.password_)
|
return _expand(self.password_)
|
||||||
|
|
||||||
|
|
||||||
class RSSParser(BaseModel):
|
class RSSParser(BaseModel):
|
||||||
|
"""RSS feed parsing settings."""
|
||||||
|
|
||||||
enable: bool = Field(True, description="Enable RSS parser")
|
enable: bool = Field(True, description="Enable RSS parser")
|
||||||
filter: list[str] = Field(["720", r"\d+-\d"], description="Filter")
|
filter: list[str] = Field(["720", r"\d+-\d"], description="Filter")
|
||||||
language: str = "zh"
|
language: str = "zh"
|
||||||
|
|
||||||
|
|
||||||
class BangumiManage(BaseModel):
|
class BangumiManage(BaseModel):
|
||||||
|
"""File organisation and renaming settings."""
|
||||||
|
|
||||||
enable: bool = Field(True, description="Enable bangumi manage")
|
enable: bool = Field(True, description="Enable bangumi manage")
|
||||||
eps_complete: bool = Field(False, description="Enable eps complete")
|
eps_complete: bool = Field(False, description="Enable eps complete")
|
||||||
rename_method: str = Field("pn", description="Rename method")
|
rename_method: str = Field("pn", description="Rename method")
|
||||||
@@ -48,10 +66,14 @@ class BangumiManage(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class Log(BaseModel):
|
class Log(BaseModel):
|
||||||
|
"""Logging verbosity settings."""
|
||||||
|
|
||||||
debug_enable: bool = Field(False, description="Enable debug")
|
debug_enable: bool = Field(False, description="Enable debug")
|
||||||
|
|
||||||
|
|
||||||
class Proxy(BaseModel):
|
class Proxy(BaseModel):
|
||||||
|
"""HTTP/SOCKS proxy settings. Credentials support ``$VAR`` expansion."""
|
||||||
|
|
||||||
enable: bool = Field(False, description="Enable proxy")
|
enable: bool = Field(False, description="Enable proxy")
|
||||||
type: str = Field("http", description="Proxy type")
|
type: str = Field("http", description="Proxy type")
|
||||||
host: str = Field("", description="Proxy host")
|
host: str = Field("", description="Proxy host")
|
||||||
@@ -61,11 +83,11 @@ class Proxy(BaseModel):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def username(self):
|
def username(self):
|
||||||
return expandvars(self.username_)
|
return _expand(self.username_)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def password(self):
|
def password(self):
|
||||||
return expandvars(self.password_)
|
return _expand(self.password_)
|
||||||
|
|
||||||
|
|
||||||
class NotificationProvider(BaseModel):
|
class NotificationProvider(BaseModel):
|
||||||
@@ -103,35 +125,35 @@ class NotificationProvider(BaseModel):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def token(self) -> str:
|
def token(self) -> str:
|
||||||
return expandvars(self.token_) if self.token_ else ""
|
return _expand(self.token_)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def chat_id(self) -> str:
|
def chat_id(self) -> str:
|
||||||
return expandvars(self.chat_id_) if self.chat_id_ else ""
|
return _expand(self.chat_id_)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def webhook_url(self) -> str:
|
def webhook_url(self) -> str:
|
||||||
return expandvars(self.webhook_url_) if self.webhook_url_ else ""
|
return _expand(self.webhook_url_)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def server_url(self) -> str:
|
def server_url(self) -> str:
|
||||||
return expandvars(self.server_url_) if self.server_url_ else ""
|
return _expand(self.server_url_)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def device_key(self) -> str:
|
def device_key(self) -> str:
|
||||||
return expandvars(self.device_key_) if self.device_key_ else ""
|
return _expand(self.device_key_)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def user_key(self) -> str:
|
def user_key(self) -> str:
|
||||||
return expandvars(self.user_key_) if self.user_key_ else ""
|
return _expand(self.user_key_)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def api_token(self) -> str:
|
def api_token(self) -> str:
|
||||||
return expandvars(self.api_token_) if self.api_token_ else ""
|
return _expand(self.api_token_)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def url(self) -> str:
|
def url(self) -> str:
|
||||||
return expandvars(self.url_) if self.url_ else ""
|
return _expand(self.url_)
|
||||||
|
|
||||||
|
|
||||||
class Notification(BaseModel):
|
class Notification(BaseModel):
|
||||||
@@ -149,11 +171,11 @@ class Notification(BaseModel):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def token(self) -> str:
|
def token(self) -> str:
|
||||||
return expandvars(self.token_) if self.token_ else ""
|
return _expand(self.token_)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def chat_id(self) -> str:
|
def chat_id(self) -> str:
|
||||||
return expandvars(self.chat_id_) if self.chat_id_ else ""
|
return _expand(self.chat_id_)
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def migrate_legacy_config(self) -> "Notification":
|
def migrate_legacy_config(self) -> "Notification":
|
||||||
@@ -197,7 +219,35 @@ class ExperimentalOpenAI(BaseModel):
|
|||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
class Security(BaseModel):
|
||||||
|
"""Access control configuration for the login endpoint and MCP server.
|
||||||
|
|
||||||
|
Both ``login_whitelist`` and ``mcp_whitelist`` accept IPv4/IPv6 CIDR ranges.
|
||||||
|
An empty ``login_whitelist`` allows all IPs; an empty ``mcp_whitelist``
|
||||||
|
denies all IP-based access (tokens still work).
|
||||||
|
"""
|
||||||
|
|
||||||
|
login_whitelist: list[str] = Field(
|
||||||
|
default_factory=list,
|
||||||
|
description="IP/CIDR whitelist for login access. Empty = allow all.",
|
||||||
|
)
|
||||||
|
login_tokens: list[str] = Field(
|
||||||
|
default_factory=list,
|
||||||
|
description="API bearer tokens that bypass login authentication.",
|
||||||
|
)
|
||||||
|
mcp_whitelist: list[str] = Field(
|
||||||
|
default_factory=list,
|
||||||
|
description="IP/CIDR whitelist for MCP access. Empty = deny all.",
|
||||||
|
)
|
||||||
|
mcp_tokens: list[str] = Field(
|
||||||
|
default_factory=list,
|
||||||
|
description="API bearer tokens for MCP access.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Config(BaseModel):
|
class Config(BaseModel):
|
||||||
|
"""Root configuration model composed of all subsection models."""
|
||||||
|
|
||||||
program: Program = Program()
|
program: Program = Program()
|
||||||
downloader: Downloader = Downloader()
|
downloader: Downloader = Downloader()
|
||||||
rss_parser: RSSParser = RSSParser()
|
rss_parser: RSSParser = RSSParser()
|
||||||
@@ -206,6 +256,7 @@ class Config(BaseModel):
|
|||||||
proxy: Proxy = Proxy()
|
proxy: Proxy = Proxy()
|
||||||
notification: Notification = Notification()
|
notification: Notification = Notification()
|
||||||
experimental_openai: ExperimentalOpenAI = ExperimentalOpenAI()
|
experimental_openai: ExperimentalOpenAI = ExperimentalOpenAI()
|
||||||
|
security: Security = Security()
|
||||||
|
|
||||||
def model_dump(self, *args, by_alias=True, **kwargs):
|
def model_dump(self, *args, by_alias=True, **kwargs):
|
||||||
return super().model_dump(*args, by_alias=by_alias, **kwargs)
|
return super().model_dump(*args, by_alias=by_alias, **kwargs)
|
||||||
|
|||||||
@@ -1,9 +1,11 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from fastapi import Cookie, Depends, HTTPException, status
|
from fastapi import Cookie, Depends, HTTPException, Request, status
|
||||||
from fastapi.security import OAuth2PasswordBearer
|
from fastapi.security import OAuth2PasswordBearer
|
||||||
|
|
||||||
|
from module.conf import settings
|
||||||
from module.database import Database
|
from module.database import Database
|
||||||
|
from module.mcp.security import _is_allowed
|
||||||
from module.models.user import User, UserUpdate
|
from module.models.user import User, UserUpdate
|
||||||
|
|
||||||
from .jwt import verify_token
|
from .jwt import verify_token
|
||||||
@@ -20,23 +22,49 @@ except ImportError:
|
|||||||
DEV_AUTH_BYPASS = VERSION == "DEV_VERSION"
|
DEV_AUTH_BYPASS = VERSION == "DEV_VERSION"
|
||||||
|
|
||||||
|
|
||||||
async def get_current_user(token: str = Cookie(None)):
|
def check_login_ip(request: Request):
|
||||||
|
"""Dependency that enforces login IP whitelist.
|
||||||
|
|
||||||
|
If ``settings.security.login_whitelist`` is empty, all IPs are allowed.
|
||||||
|
"""
|
||||||
|
whitelist = settings.security.login_whitelist
|
||||||
|
if not whitelist:
|
||||||
|
return
|
||||||
|
client_host = request.client.host if request.client else None
|
||||||
|
if not client_host or not _is_allowed(client_host, whitelist):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="IP not in login whitelist",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_current_user(request: Request, token: str = Cookie(None)):
|
||||||
|
"""FastAPI dependency that validates the current session.
|
||||||
|
|
||||||
|
Accepts authentication via (in order of precedence):
|
||||||
|
1. DEV_AUTH_BYPASS when running as DEV_VERSION.
|
||||||
|
2. ``Authorization: Bearer <token>`` header matching ``login_tokens``.
|
||||||
|
3. HttpOnly ``token`` cookie containing a valid JWT with an active session.
|
||||||
|
"""
|
||||||
if DEV_AUTH_BYPASS:
|
if DEV_AUTH_BYPASS:
|
||||||
return "dev_user"
|
return "dev_user"
|
||||||
|
# Check bearer token bypass
|
||||||
|
auth_header = request.headers.get("authorization", "")
|
||||||
|
if auth_header.startswith("Bearer "):
|
||||||
|
api_token = auth_header[7:]
|
||||||
|
if api_token and api_token in settings.security.login_tokens:
|
||||||
|
return "api_token_user"
|
||||||
if not token:
|
if not token:
|
||||||
raise UNAUTHORIZED
|
raise UNAUTHORIZED
|
||||||
payload = verify_token(token)
|
payload = verify_token(token)
|
||||||
if not payload:
|
username = payload.get("sub") if payload else None
|
||||||
raise UNAUTHORIZED
|
if not username or username not in active_user:
|
||||||
username = payload.get("sub")
|
|
||||||
if not username:
|
|
||||||
raise UNAUTHORIZED
|
|
||||||
if username not in active_user:
|
|
||||||
raise UNAUTHORIZED
|
raise UNAUTHORIZED
|
||||||
return username
|
return username
|
||||||
|
|
||||||
|
|
||||||
async def get_token_data(token: str = Depends(oauth2_scheme)):
|
async def get_token_data(token: str = Depends(oauth2_scheme)):
|
||||||
|
"""FastAPI dependency that decodes and returns the OAuth2 bearer token payload."""
|
||||||
payload = verify_token(token)
|
payload = verify_token(token)
|
||||||
if not payload:
|
if not payload:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
@@ -46,6 +74,7 @@ async def get_token_data(token: str = Depends(oauth2_scheme)):
|
|||||||
|
|
||||||
|
|
||||||
def update_user_info(user_data: UserUpdate, current_user):
|
def update_user_info(user_data: UserUpdate, current_user):
|
||||||
|
"""Persist updated credentials for *current_user* to the database."""
|
||||||
try:
|
try:
|
||||||
with Database() as db:
|
with Database() as db:
|
||||||
db.user.update_user(current_user, user_data)
|
db.user.update_user(current_user, user_data)
|
||||||
@@ -55,6 +84,7 @@ def update_user_info(user_data: UserUpdate, current_user):
|
|||||||
|
|
||||||
|
|
||||||
def auth_user(user: User):
|
def auth_user(user: User):
|
||||||
|
"""Verify credentials and register the user in ``active_user`` on success."""
|
||||||
with Database() as db:
|
with Database() as db:
|
||||||
resp = db.user.auth_user(user)
|
resp = db.user.auth_user(user)
|
||||||
if resp.status:
|
if resp.status:
|
||||||
|
|||||||
@@ -185,3 +185,98 @@ class TestUpdateCredentials:
|
|||||||
except Exception:
|
except Exception:
|
||||||
# Expected - endpoint doesn't handle failure case properly
|
# Expected - endpoint doesn't handle failure case properly
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Refresh token: cookie-based username resolution
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestRefreshTokenCookieBehavior:
|
||||||
|
def test_refresh_with_no_cookie_raises_401(self, authed_client):
|
||||||
|
"""GET /refresh_token with missing token cookie raises 401."""
|
||||||
|
# Override auth to allow route but provide no cookie token
|
||||||
|
with patch("module.api.auth.decode_token", return_value=None):
|
||||||
|
response = authed_client.get("/api/v1/auth/refresh_token")
|
||||||
|
assert response.status_code == 401
|
||||||
|
|
||||||
|
def test_refresh_with_valid_cookie_updates_active_user(self, authed_client):
|
||||||
|
"""GET /refresh_token updates the active user timestamp."""
|
||||||
|
token = create_access_token(data={"sub": "testuser"})
|
||||||
|
authed_client.cookies.set("token", token)
|
||||||
|
active_users: dict = {}
|
||||||
|
with patch("module.api.auth.active_user", active_users):
|
||||||
|
response = authed_client.get("/api/v1/auth/refresh_token")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert "testuser" in active_users
|
||||||
|
|
||||||
|
def test_refresh_returns_new_token(self, authed_client):
|
||||||
|
"""GET /refresh_token issues a valid JWT with bearer type."""
|
||||||
|
token = create_access_token(data={"sub": "testuser"})
|
||||||
|
authed_client.cookies.set("token", token)
|
||||||
|
with patch("module.api.auth.active_user", {}):
|
||||||
|
response = authed_client.get("/api/v1/auth/refresh_token")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["token_type"] == "bearer"
|
||||||
|
assert isinstance(data["access_token"], str)
|
||||||
|
assert len(data["access_token"]) > 0
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Logout: per-user removal
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestLogoutCookieBehavior:
|
||||||
|
def test_logout_removes_only_current_user(self, authed_client):
|
||||||
|
"""GET /logout removes the current user from active_user, not others."""
|
||||||
|
token = create_access_token(data={"sub": "testuser"})
|
||||||
|
authed_client.cookies.set("token", token)
|
||||||
|
active_users = {
|
||||||
|
"testuser": datetime.now(),
|
||||||
|
"otheruser": datetime.now(),
|
||||||
|
}
|
||||||
|
with patch("module.api.auth.active_user", active_users):
|
||||||
|
response = authed_client.get("/api/v1/auth/logout")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert "testuser" not in active_users
|
||||||
|
assert "otheruser" in active_users
|
||||||
|
|
||||||
|
def test_logout_with_no_cookie_still_succeeds(self, authed_client):
|
||||||
|
"""GET /logout with no cookie clears nothing but returns success."""
|
||||||
|
with patch("module.api.auth.decode_token", return_value=None):
|
||||||
|
with patch("module.api.auth.active_user", {}):
|
||||||
|
response = authed_client.get("/api/v1/auth/logout")
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Update: cookie-based user resolution
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestUpdateCookieBehavior:
|
||||||
|
def test_update_with_no_cookie_raises_401(self, authed_client):
|
||||||
|
"""POST /auth/update with no cookie raises 401."""
|
||||||
|
with patch("module.api.auth.decode_token", return_value=None):
|
||||||
|
response = authed_client.post(
|
||||||
|
"/api/v1/auth/update",
|
||||||
|
json={"old_password": "old", "new_password": "new"},
|
||||||
|
)
|
||||||
|
assert response.status_code == 401
|
||||||
|
|
||||||
|
def test_update_with_valid_cookie_succeeds(self, authed_client):
|
||||||
|
"""POST /auth/update resolves username from cookie and issues new token."""
|
||||||
|
token = create_access_token(data={"sub": "testuser"})
|
||||||
|
authed_client.cookies.set("token", token)
|
||||||
|
with patch("module.api.auth.active_user", {"testuser": datetime.now()}):
|
||||||
|
with patch("module.api.auth.update_user_info", return_value=True):
|
||||||
|
response = authed_client.post(
|
||||||
|
"/api/v1/auth/update",
|
||||||
|
json={"old_password": "oldpass", "new_password": "newpass"},
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert "access_token" in data
|
||||||
|
assert data["message"] == "update success"
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
"""Tests for Config API endpoints."""
|
"""Tests for Config API endpoints and config sanitization."""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from unittest.mock import patch, MagicMock
|
from unittest.mock import patch, MagicMock
|
||||||
@@ -7,6 +7,7 @@ from fastapi import FastAPI
|
|||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
from module.api import v1
|
from module.api import v1
|
||||||
|
from module.api.config import _sanitize_dict
|
||||||
from module.models.config import Config
|
from module.models.config import Config
|
||||||
from module.security.api import get_current_user
|
from module.security.api import get_current_user
|
||||||
|
|
||||||
@@ -263,3 +264,99 @@ class TestUpdateConfig:
|
|||||||
response = authed_client.patch("/api/v1/config/update", json=invalid_data)
|
response = authed_client.patch("/api/v1/config/update", json=invalid_data)
|
||||||
|
|
||||||
assert response.status_code == 422
|
assert response.status_code == 422
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _sanitize_dict unit tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestSanitizeDict:
|
||||||
|
def test_masks_password_key(self):
|
||||||
|
"""Keys containing 'password' are masked."""
|
||||||
|
result = _sanitize_dict({"password": "secret"})
|
||||||
|
assert result["password"] == "********"
|
||||||
|
|
||||||
|
def test_masks_api_key(self):
|
||||||
|
"""Keys containing 'api_key' are masked."""
|
||||||
|
result = _sanitize_dict({"api_key": "sk-abc123"})
|
||||||
|
assert result["api_key"] == "********"
|
||||||
|
|
||||||
|
def test_masks_token_key(self):
|
||||||
|
"""Keys containing 'token' are masked."""
|
||||||
|
result = _sanitize_dict({"token": "bearer-xyz"})
|
||||||
|
assert result["token"] == "********"
|
||||||
|
|
||||||
|
def test_masks_secret_key(self):
|
||||||
|
"""Keys containing 'secret' are masked."""
|
||||||
|
result = _sanitize_dict({"my_secret": "topsecret"})
|
||||||
|
assert result["my_secret"] == "********"
|
||||||
|
|
||||||
|
def test_case_insensitive_key_matching(self):
|
||||||
|
"""Sensitive key matching is case-insensitive."""
|
||||||
|
result = _sanitize_dict({"API_KEY": "abc"})
|
||||||
|
assert result["API_KEY"] == "********"
|
||||||
|
|
||||||
|
def test_non_sensitive_keys_pass_through(self):
|
||||||
|
"""Non-sensitive keys are returned unchanged."""
|
||||||
|
result = _sanitize_dict({"host": "localhost", "port": 8080, "enable": True})
|
||||||
|
assert result["host"] == "localhost"
|
||||||
|
assert result["port"] == 8080
|
||||||
|
assert result["enable"] is True
|
||||||
|
|
||||||
|
def test_nested_dict_recursed(self):
|
||||||
|
"""Nested dicts are processed recursively."""
|
||||||
|
result = _sanitize_dict({
|
||||||
|
"downloader": {
|
||||||
|
"host": "localhost",
|
||||||
|
"password": "secret",
|
||||||
|
}
|
||||||
|
})
|
||||||
|
assert result["downloader"]["host"] == "localhost"
|
||||||
|
assert result["downloader"]["password"] == "********"
|
||||||
|
|
||||||
|
def test_deeply_nested_dict(self):
|
||||||
|
"""Deeply nested sensitive keys are masked."""
|
||||||
|
result = _sanitize_dict({
|
||||||
|
"level1": {
|
||||||
|
"level2": {
|
||||||
|
"api_key": "deep-secret"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
assert result["level1"]["level2"]["api_key"] == "********"
|
||||||
|
|
||||||
|
def test_non_string_value_not_masked(self):
|
||||||
|
"""Non-string values with sensitive-looking keys are NOT masked."""
|
||||||
|
result = _sanitize_dict({"password": 12345})
|
||||||
|
# Only string values are masked; integers pass through
|
||||||
|
assert result["password"] == 12345
|
||||||
|
|
||||||
|
def test_empty_dict(self):
|
||||||
|
"""Empty dict returns empty dict."""
|
||||||
|
assert _sanitize_dict({}) == {}
|
||||||
|
|
||||||
|
def test_mixed_sensitive_and_plain(self):
|
||||||
|
"""Mix of sensitive and plain keys handled correctly."""
|
||||||
|
result = _sanitize_dict({
|
||||||
|
"username": "admin",
|
||||||
|
"password": "secret",
|
||||||
|
"host": "10.0.0.1",
|
||||||
|
"token": "jwt-abc",
|
||||||
|
})
|
||||||
|
assert result["username"] == "admin"
|
||||||
|
assert result["host"] == "10.0.0.1"
|
||||||
|
assert result["password"] == "********"
|
||||||
|
assert result["token"] == "********"
|
||||||
|
|
||||||
|
def test_get_config_masks_sensitive_fields(self, authed_client):
|
||||||
|
"""GET /config/get response masks password and api_key fields."""
|
||||||
|
test_config = Config()
|
||||||
|
with patch("module.api.config.settings", test_config):
|
||||||
|
response = authed_client.get("/api/v1/config/get")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
# Downloader password should be masked
|
||||||
|
assert data["downloader"]["password"] == "********"
|
||||||
|
# OpenAI api_key should be masked (it's an empty string but still masked)
|
||||||
|
assert data["experimental_openai"]["api_key"] == "********"
|
||||||
|
|||||||
@@ -151,6 +151,15 @@ class TestPasswordHashing:
|
|||||||
|
|
||||||
|
|
||||||
class TestGetCurrentUser:
|
class TestGetCurrentUser:
|
||||||
|
@staticmethod
|
||||||
|
def _mock_request(authorization=""):
|
||||||
|
"""Create a mock Request with the given Authorization header."""
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
request = MagicMock()
|
||||||
|
request.headers = {"authorization": authorization}
|
||||||
|
return request
|
||||||
|
|
||||||
@patch("module.security.api.DEV_AUTH_BYPASS", False)
|
@patch("module.security.api.DEV_AUTH_BYPASS", False)
|
||||||
async def test_no_cookie_raises_401(self):
|
async def test_no_cookie_raises_401(self):
|
||||||
"""get_current_user raises 401 when no token cookie."""
|
"""get_current_user raises 401 when no token cookie."""
|
||||||
@@ -159,7 +168,7 @@ class TestGetCurrentUser:
|
|||||||
from module.security.api import get_current_user
|
from module.security.api import get_current_user
|
||||||
|
|
||||||
with pytest.raises(HTTPException) as exc_info:
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
await get_current_user(token=None)
|
await get_current_user(request=self._mock_request(), token=None)
|
||||||
assert exc_info.value.status_code == 401
|
assert exc_info.value.status_code == 401
|
||||||
|
|
||||||
@patch("module.security.api.DEV_AUTH_BYPASS", False)
|
@patch("module.security.api.DEV_AUTH_BYPASS", False)
|
||||||
@@ -170,7 +179,7 @@ class TestGetCurrentUser:
|
|||||||
from module.security.api import get_current_user
|
from module.security.api import get_current_user
|
||||||
|
|
||||||
with pytest.raises(HTTPException) as exc_info:
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
await get_current_user(token="invalid.jwt.token")
|
await get_current_user(request=self._mock_request(), token="invalid.jwt.token")
|
||||||
assert exc_info.value.status_code == 401
|
assert exc_info.value.status_code == 401
|
||||||
|
|
||||||
@patch("module.security.api.DEV_AUTH_BYPASS", False)
|
@patch("module.security.api.DEV_AUTH_BYPASS", False)
|
||||||
@@ -186,7 +195,7 @@ class TestGetCurrentUser:
|
|||||||
active_user.clear()
|
active_user.clear()
|
||||||
|
|
||||||
with pytest.raises(HTTPException) as exc_info:
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
await get_current_user(token=token)
|
await get_current_user(request=self._mock_request(), token=token)
|
||||||
assert exc_info.value.status_code == 401
|
assert exc_info.value.status_code == 401
|
||||||
|
|
||||||
@patch("module.security.api.DEV_AUTH_BYPASS", False)
|
@patch("module.security.api.DEV_AUTH_BYPASS", False)
|
||||||
@@ -202,8 +211,113 @@ class TestGetCurrentUser:
|
|||||||
active_user.clear()
|
active_user.clear()
|
||||||
active_user["active_user"] = datetime.now()
|
active_user["active_user"] = datetime.now()
|
||||||
|
|
||||||
result = await get_current_user(token=token)
|
result = await get_current_user(request=self._mock_request(), token=token)
|
||||||
assert result == "active_user"
|
assert result == "active_user"
|
||||||
|
|
||||||
# Cleanup
|
# Cleanup
|
||||||
active_user.clear()
|
active_user.clear()
|
||||||
|
|
||||||
|
@patch("module.security.api.DEV_AUTH_BYPASS", True)
|
||||||
|
async def test_dev_bypass_skips_auth(self):
|
||||||
|
"""When DEV_AUTH_BYPASS is True, get_current_user returns 'dev_user' unconditionally."""
|
||||||
|
from module.security.api import get_current_user
|
||||||
|
|
||||||
|
result = await get_current_user(request=self._mock_request(), token=None)
|
||||||
|
assert result == "dev_user"
|
||||||
|
|
||||||
|
@patch("module.security.api.DEV_AUTH_BYPASS", False)
|
||||||
|
async def test_bearer_token_bypass_valid(self):
|
||||||
|
"""A valid login_token in Authorization header returns 'api_token_user'."""
|
||||||
|
from module.security.api import get_current_user
|
||||||
|
|
||||||
|
mock_request = self._mock_request(authorization="Bearer valid-api-token")
|
||||||
|
mock_security = type("S", (), {"login_tokens": ["valid-api-token"]})()
|
||||||
|
mock_settings = type("Settings", (), {"security": mock_security})()
|
||||||
|
|
||||||
|
with patch("module.security.api.settings", mock_settings):
|
||||||
|
result = await get_current_user(request=mock_request, token=None)
|
||||||
|
assert result == "api_token_user"
|
||||||
|
|
||||||
|
@patch("module.security.api.DEV_AUTH_BYPASS", False)
|
||||||
|
async def test_bearer_token_bypass_invalid(self):
|
||||||
|
"""An invalid login_token still falls through to cookie check."""
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
from module.security.api import get_current_user
|
||||||
|
|
||||||
|
mock_request = self._mock_request(authorization="Bearer wrong-token")
|
||||||
|
mock_security = type("S", (), {"login_tokens": ["correct-token"]})()
|
||||||
|
mock_settings = type("Settings", (), {"security": mock_security})()
|
||||||
|
|
||||||
|
with patch("module.security.api.settings", mock_settings):
|
||||||
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
await get_current_user(request=mock_request, token=None)
|
||||||
|
assert exc_info.value.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# check_login_ip
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestCheckLoginIp:
|
||||||
|
@staticmethod
|
||||||
|
def _make_request(host: str | None):
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
request = MagicMock()
|
||||||
|
if host is None:
|
||||||
|
request.client = None
|
||||||
|
else:
|
||||||
|
request.client = MagicMock()
|
||||||
|
request.client.host = host
|
||||||
|
return request
|
||||||
|
|
||||||
|
def test_empty_whitelist_allows_all(self):
|
||||||
|
"""When login_whitelist is empty, all IPs pass."""
|
||||||
|
from module.security.api import check_login_ip
|
||||||
|
|
||||||
|
mock_security = type("S", (), {"login_whitelist": []})()
|
||||||
|
mock_settings = type("Settings", (), {"security": mock_security})()
|
||||||
|
|
||||||
|
with patch("module.security.api.settings", mock_settings):
|
||||||
|
# Should not raise
|
||||||
|
check_login_ip(request=self._make_request("8.8.8.8"))
|
||||||
|
|
||||||
|
def test_allowed_ip_passes(self):
|
||||||
|
"""IP in whitelist does not raise."""
|
||||||
|
from module.security.api import check_login_ip
|
||||||
|
|
||||||
|
mock_security = type("S", (), {"login_whitelist": ["192.168.0.0/16"]})()
|
||||||
|
mock_settings = type("Settings", (), {"security": mock_security})()
|
||||||
|
|
||||||
|
with patch("module.security.api.settings", mock_settings):
|
||||||
|
check_login_ip(request=self._make_request("192.168.1.100"))
|
||||||
|
|
||||||
|
def test_blocked_ip_raises_403(self):
|
||||||
|
"""IP outside whitelist raises 403."""
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
from module.security.api import check_login_ip
|
||||||
|
|
||||||
|
mock_security = type("S", (), {"login_whitelist": ["192.168.0.0/16"]})()
|
||||||
|
mock_settings = type("Settings", (), {"security": mock_security})()
|
||||||
|
|
||||||
|
with patch("module.security.api.settings", mock_settings):
|
||||||
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
check_login_ip(request=self._make_request("8.8.8.8"))
|
||||||
|
assert exc_info.value.status_code == 403
|
||||||
|
|
||||||
|
def test_no_client_raises_403_when_whitelist_set(self):
|
||||||
|
"""Missing client info raises 403 when whitelist is non-empty."""
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
from module.security.api import check_login_ip
|
||||||
|
|
||||||
|
mock_security = type("S", (), {"login_whitelist": ["192.168.0.0/16"]})()
|
||||||
|
mock_settings = type("Settings", (), {"security": mock_security})()
|
||||||
|
|
||||||
|
with patch("module.security.api.settings", mock_settings):
|
||||||
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
check_login_ip(request=self._make_request(None))
|
||||||
|
assert exc_info.value.status_code == 403
|
||||||
|
|||||||
@@ -8,14 +8,16 @@ from unittest.mock import patch
|
|||||||
|
|
||||||
from module.models.config import (
|
from module.models.config import (
|
||||||
Config,
|
Config,
|
||||||
Program,
|
|
||||||
Downloader,
|
Downloader,
|
||||||
RSSParser,
|
|
||||||
BangumiManage,
|
|
||||||
Proxy,
|
|
||||||
Notification as NotificationConfig,
|
Notification as NotificationConfig,
|
||||||
|
NotificationProvider,
|
||||||
|
Program,
|
||||||
|
Proxy,
|
||||||
|
RSSParser,
|
||||||
|
Security,
|
||||||
)
|
)
|
||||||
from module.conf.config import Settings
|
from module.conf.config import Settings
|
||||||
|
from module.conf.const import BCOLORS, DEFAULT_SETTINGS
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -228,3 +230,281 @@ class TestEnvOverrides:
|
|||||||
s.init()
|
s.init()
|
||||||
|
|
||||||
assert "192.168.1.100:9090" in s.downloader.host
|
assert "192.168.1.100:9090" in s.downloader.host
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Security model
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestSecurityModel:
|
||||||
|
def test_security_defaults(self):
|
||||||
|
"""Security has empty whitelists and token lists by default."""
|
||||||
|
sec = Security()
|
||||||
|
assert sec.login_whitelist == []
|
||||||
|
assert sec.login_tokens == []
|
||||||
|
assert sec.mcp_whitelist == []
|
||||||
|
assert sec.mcp_tokens == []
|
||||||
|
|
||||||
|
def test_security_in_config(self):
|
||||||
|
"""Config includes a Security section with correct defaults."""
|
||||||
|
config = Config()
|
||||||
|
assert hasattr(config, "security")
|
||||||
|
assert isinstance(config.security, Security)
|
||||||
|
assert config.security.login_whitelist == []
|
||||||
|
|
||||||
|
def test_security_populated(self):
|
||||||
|
"""Security fields accept lists of CIDRs and tokens."""
|
||||||
|
sec = Security(
|
||||||
|
login_whitelist=["192.168.0.0/16"],
|
||||||
|
login_tokens=["token-abc"],
|
||||||
|
mcp_whitelist=["10.0.0.0/8"],
|
||||||
|
mcp_tokens=["mcp-secret"],
|
||||||
|
)
|
||||||
|
assert "192.168.0.0/16" in sec.login_whitelist
|
||||||
|
assert "token-abc" in sec.login_tokens
|
||||||
|
assert "10.0.0.0/8" in sec.mcp_whitelist
|
||||||
|
assert "mcp-secret" in sec.mcp_tokens
|
||||||
|
|
||||||
|
def test_security_roundtrip_serialization(self):
|
||||||
|
"""Security serializes and deserializes correctly."""
|
||||||
|
original = Security(
|
||||||
|
login_whitelist=["127.0.0.0/8"],
|
||||||
|
mcp_tokens=["tok1"],
|
||||||
|
)
|
||||||
|
data = original.model_dump()
|
||||||
|
restored = Security.model_validate(data)
|
||||||
|
assert restored.login_whitelist == ["127.0.0.0/8"]
|
||||||
|
assert restored.mcp_tokens == ["tok1"]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# NotificationProvider model
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestNotificationProvider:
|
||||||
|
def test_minimal_provider(self):
|
||||||
|
"""NotificationProvider requires only type."""
|
||||||
|
p = NotificationProvider(type="telegram")
|
||||||
|
assert p.type == "telegram"
|
||||||
|
assert p.enabled is True
|
||||||
|
|
||||||
|
def test_telegram_provider_fields(self):
|
||||||
|
"""Telegram provider stores token and chat_id."""
|
||||||
|
p = NotificationProvider(type="telegram", token="bot123", chat_id="-100456")
|
||||||
|
assert p.token == "bot123"
|
||||||
|
assert p.chat_id == "-100456"
|
||||||
|
|
||||||
|
def test_discord_provider_fields(self):
|
||||||
|
"""Discord provider stores webhook_url."""
|
||||||
|
p = NotificationProvider(
|
||||||
|
type="discord", webhook_url="https://discord.com/api/webhooks/123/abc"
|
||||||
|
)
|
||||||
|
assert p.webhook_url == "https://discord.com/api/webhooks/123/abc"
|
||||||
|
|
||||||
|
def test_bark_provider_fields(self):
|
||||||
|
"""Bark provider stores server_url and device_key."""
|
||||||
|
p = NotificationProvider(
|
||||||
|
type="bark", server_url="https://api.day.app", device_key="mykey"
|
||||||
|
)
|
||||||
|
assert p.server_url == "https://api.day.app"
|
||||||
|
assert p.device_key == "mykey"
|
||||||
|
|
||||||
|
def test_pushover_provider_fields(self):
|
||||||
|
"""Pushover provider stores user_key and api_token."""
|
||||||
|
p = NotificationProvider(type="pushover", user_key="uk1", api_token="at1")
|
||||||
|
assert p.user_key == "uk1"
|
||||||
|
assert p.api_token == "at1"
|
||||||
|
|
||||||
|
def test_url_field_property(self):
|
||||||
|
"""Webhook provider stores url."""
|
||||||
|
p = NotificationProvider(type="webhook", url="https://example.com/hook")
|
||||||
|
assert p.url == "https://example.com/hook"
|
||||||
|
|
||||||
|
def test_optional_fields_default_empty_string(self):
|
||||||
|
"""Unset optional properties return empty string, not None."""
|
||||||
|
p = NotificationProvider(type="telegram")
|
||||||
|
assert p.token == ""
|
||||||
|
assert p.chat_id == ""
|
||||||
|
assert p.webhook_url == ""
|
||||||
|
|
||||||
|
def test_provider_can_be_disabled(self):
|
||||||
|
"""Provider can be disabled without removing it."""
|
||||||
|
p = NotificationProvider(type="telegram", enabled=False)
|
||||||
|
assert p.enabled is False
|
||||||
|
|
||||||
|
def test_env_var_expansion_in_token(self, monkeypatch):
|
||||||
|
"""Token field expands shell environment variables."""
|
||||||
|
monkeypatch.setenv("TEST_BOT_TOKEN", "real-token-value")
|
||||||
|
p = NotificationProvider(type="telegram", token="$TEST_BOT_TOKEN")
|
||||||
|
assert p.token == "real-token-value"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Notification model - legacy migration
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestNotificationLegacyMigration:
|
||||||
|
def test_new_format_no_migration(self):
|
||||||
|
"""New format with providers list is not touched."""
|
||||||
|
n = NotificationConfig(
|
||||||
|
enable=True,
|
||||||
|
providers=[NotificationProvider(type="telegram", token="tok")],
|
||||||
|
)
|
||||||
|
assert len(n.providers) == 1
|
||||||
|
assert n.providers[0].type == "telegram"
|
||||||
|
|
||||||
|
def test_old_format_migrates_to_provider(self):
|
||||||
|
"""Old single-provider fields (type, token, chat_id) migrate to providers list."""
|
||||||
|
n = NotificationConfig(
|
||||||
|
enable=True,
|
||||||
|
type="telegram",
|
||||||
|
token="bot_token",
|
||||||
|
chat_id="-100123",
|
||||||
|
)
|
||||||
|
assert len(n.providers) == 1
|
||||||
|
provider = n.providers[0]
|
||||||
|
assert provider.type == "telegram"
|
||||||
|
assert provider.enabled is True
|
||||||
|
|
||||||
|
def test_old_format_no_migration_when_providers_already_set(self):
|
||||||
|
"""When providers already exist, legacy fields do not create additional providers."""
|
||||||
|
n = NotificationConfig(
|
||||||
|
enable=True,
|
||||||
|
type="telegram",
|
||||||
|
token="unused",
|
||||||
|
providers=[NotificationProvider(type="discord", webhook_url="https://d.co")],
|
||||||
|
)
|
||||||
|
assert len(n.providers) == 1
|
||||||
|
assert n.providers[0].type == "discord"
|
||||||
|
|
||||||
|
def test_notification_empty_providers_by_default(self):
|
||||||
|
"""Default Notification has no providers."""
|
||||||
|
n = NotificationConfig()
|
||||||
|
assert n.providers == []
|
||||||
|
assert n.enable is False
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Downloader env-var expansion
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestDownloaderEnvExpansion:
|
||||||
|
def test_host_expands_env_var(self, monkeypatch):
|
||||||
|
"""Downloader.host expands $VAR references."""
|
||||||
|
monkeypatch.setenv("QB_HOST", "192.168.5.10:8080")
|
||||||
|
d = Downloader(host="$QB_HOST")
|
||||||
|
assert d.host == "192.168.5.10:8080"
|
||||||
|
|
||||||
|
def test_username_expands_env_var(self, monkeypatch):
|
||||||
|
"""Downloader.username expands $VAR references."""
|
||||||
|
monkeypatch.setenv("QB_USER", "myuser")
|
||||||
|
d = Downloader(username="$QB_USER")
|
||||||
|
assert d.username == "myuser"
|
||||||
|
|
||||||
|
def test_password_expands_env_var(self, monkeypatch):
|
||||||
|
"""Downloader.password expands $VAR references."""
|
||||||
|
monkeypatch.setenv("QB_PASS", "s3cret")
|
||||||
|
d = Downloader(password="$QB_PASS")
|
||||||
|
assert d.password == "s3cret"
|
||||||
|
|
||||||
|
def test_literal_host_not_expanded(self):
|
||||||
|
"""Literal host strings without $ are returned as-is."""
|
||||||
|
d = Downloader(host="localhost:8080")
|
||||||
|
assert d.host == "localhost:8080"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# DEFAULT_SETTINGS structure
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestDefaultSettings:
|
||||||
|
def test_security_section_present(self):
|
||||||
|
"""DEFAULT_SETTINGS contains a security section."""
|
||||||
|
assert "security" in DEFAULT_SETTINGS
|
||||||
|
|
||||||
|
def test_security_default_mcp_whitelist(self):
|
||||||
|
"""Default MCP whitelist contains private network ranges."""
|
||||||
|
mcp_wl = DEFAULT_SETTINGS["security"]["mcp_whitelist"]
|
||||||
|
assert "127.0.0.0/8" in mcp_wl
|
||||||
|
assert "192.168.0.0/16" in mcp_wl
|
||||||
|
assert "10.0.0.0/8" in mcp_wl
|
||||||
|
|
||||||
|
def test_security_default_tokens_empty(self):
|
||||||
|
"""Default security token lists are empty."""
|
||||||
|
assert DEFAULT_SETTINGS["security"]["login_tokens"] == []
|
||||||
|
assert DEFAULT_SETTINGS["security"]["mcp_tokens"] == []
|
||||||
|
|
||||||
|
def test_notification_uses_providers_format(self):
|
||||||
|
"""DEFAULT_SETTINGS notification uses new providers format."""
|
||||||
|
notif = DEFAULT_SETTINGS["notification"]
|
||||||
|
assert "providers" in notif
|
||||||
|
assert notif["providers"] == []
|
||||||
|
assert "type" not in notif
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# BCOLORS utility
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestBCOLORS:
|
||||||
|
def test_wrap_single_string(self):
|
||||||
|
"""BCOLORS._() wraps a string with color codes and reset."""
|
||||||
|
result = BCOLORS._(BCOLORS.OKGREEN, "hello")
|
||||||
|
assert "hello" in result
|
||||||
|
assert BCOLORS.OKGREEN in result
|
||||||
|
assert BCOLORS.ENDC in result
|
||||||
|
|
||||||
|
def test_wrap_multiple_strings(self):
|
||||||
|
"""BCOLORS._() joins multiple args with commas."""
|
||||||
|
result = BCOLORS._(BCOLORS.WARNING, "foo", "bar")
|
||||||
|
assert "foo" in result
|
||||||
|
assert "bar" in result
|
||||||
|
|
||||||
|
def test_wrap_non_string_arg(self):
|
||||||
|
"""BCOLORS._() converts non-string args to str."""
|
||||||
|
result = BCOLORS._(BCOLORS.FAIL, 42)
|
||||||
|
assert "42" in result
|
||||||
|
|
||||||
|
def test_all_color_constants_are_strings(self):
|
||||||
|
"""All BCOLORS constants are strings."""
|
||||||
|
for attr in ["HEADER", "OKBLUE", "OKCYAN", "OKGREEN", "WARNING", "FAIL", "ENDC"]:
|
||||||
|
assert isinstance(getattr(BCOLORS, attr), str)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Migration: security section injection
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestMigrateSecuritySection:
|
||||||
|
def test_adds_security_when_missing(self):
|
||||||
|
"""_migrate_old_config injects a default security section when absent."""
|
||||||
|
old_config = {
|
||||||
|
"program": {},
|
||||||
|
"rss_parser": {},
|
||||||
|
}
|
||||||
|
result = Settings._migrate_old_config(old_config)
|
||||||
|
assert "security" in result
|
||||||
|
assert "mcp_whitelist" in result["security"]
|
||||||
|
|
||||||
|
def test_preserves_existing_security_section(self):
|
||||||
|
"""_migrate_old_config does not overwrite an existing security section."""
|
||||||
|
existing_config = {
|
||||||
|
"program": {},
|
||||||
|
"rss_parser": {},
|
||||||
|
"security": {
|
||||||
|
"login_whitelist": ["10.0.0.0/8"],
|
||||||
|
"login_tokens": ["mytoken"],
|
||||||
|
"mcp_whitelist": [],
|
||||||
|
"mcp_tokens": [],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
result = Settings._migrate_old_config(existing_config)
|
||||||
|
assert result["security"]["login_tokens"] == ["mytoken"]
|
||||||
|
assert result["security"]["login_whitelist"] == ["10.0.0.0/8"]
|
||||||
|
|||||||
@@ -293,7 +293,68 @@ class TestClientDelegation:
|
|||||||
result = await download_client.rename_torrent_file("hash1", "old.mkv", "new.mkv")
|
result = await download_client.rename_torrent_file("hash1", "old.mkv", "new.mkv")
|
||||||
assert result is False
|
assert result is False
|
||||||
|
|
||||||
|
async def test_rename_torrent_file_passes_verify_flag(
|
||||||
|
self, download_client, mock_qb_client
|
||||||
|
):
|
||||||
|
"""rename_torrent_file forwards the verify kwarg to the underlying client."""
|
||||||
|
mock_qb_client.torrents_rename_file.return_value = True
|
||||||
|
await download_client.rename_torrent_file(
|
||||||
|
"hash1", "old.mkv", "new.mkv", verify=False
|
||||||
|
)
|
||||||
|
call_kwargs = mock_qb_client.torrents_rename_file.call_args[1]
|
||||||
|
assert call_kwargs["verify"] is False
|
||||||
|
|
||||||
async def test_delete_torrent(self, download_client, mock_qb_client):
|
async def test_delete_torrent(self, download_client, mock_qb_client):
|
||||||
"""delete_torrent delegates to client.torrents_delete."""
|
"""delete_torrent delegates to client.torrents_delete."""
|
||||||
await download_client.delete_torrent("hash1", delete_files=True)
|
await download_client.delete_torrent("hash1", delete_files=True)
|
||||||
mock_qb_client.torrents_delete.assert_called_once_with("hash1", delete_files=True)
|
mock_qb_client.torrents_delete.assert_called_once_with("hash1", delete_files=True)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# add_tag
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestAddTag:
|
||||||
|
async def test_add_tag_delegates_to_client(self, download_client, mock_qb_client):
|
||||||
|
"""add_tag delegates to client.add_tag."""
|
||||||
|
mock_qb_client.add_tag = AsyncMock(return_value=None)
|
||||||
|
download_client.client = mock_qb_client
|
||||||
|
await download_client.add_tag("deadbeef12345678", "ab:42")
|
||||||
|
mock_qb_client.add_tag.assert_called_once_with("deadbeef12345678", "ab:42")
|
||||||
|
|
||||||
|
async def test_add_tag_short_hash_no_error(self, download_client, mock_qb_client):
|
||||||
|
"""add_tag with a hash shorter than 8 chars does not crash the slice."""
|
||||||
|
mock_qb_client.add_tag = AsyncMock(return_value=None)
|
||||||
|
download_client.client = mock_qb_client
|
||||||
|
# Should not raise even for short hashes
|
||||||
|
await download_client.add_tag("abc", "ab:1")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Context manager: ConnectionError on failed auth
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestContextManagerAuth:
|
||||||
|
async def test_aenter_raises_on_auth_failure(self, download_client, mock_qb_client):
|
||||||
|
"""__aenter__ raises ConnectionError when auth fails."""
|
||||||
|
mock_qb_client.auth.return_value = False
|
||||||
|
download_client.authed = False
|
||||||
|
with pytest.raises(ConnectionError, match="authentication failed"):
|
||||||
|
await download_client.__aenter__()
|
||||||
|
|
||||||
|
async def test_aenter_succeeds_when_auth_passes(self, download_client, mock_qb_client):
|
||||||
|
"""__aenter__ returns self when auth succeeds."""
|
||||||
|
mock_qb_client.auth.return_value = True
|
||||||
|
download_client.authed = False
|
||||||
|
result = await download_client.__aenter__()
|
||||||
|
assert result is download_client
|
||||||
|
assert download_client.authed is True
|
||||||
|
|
||||||
|
async def test_aexit_calls_logout_when_authed(self, download_client, mock_qb_client):
|
||||||
|
"""__aexit__ calls logout and resets authed when session was active."""
|
||||||
|
download_client.authed = True
|
||||||
|
await download_client.__aexit__(None, None, None)
|
||||||
|
mock_qb_client.logout.assert_called_once()
|
||||||
|
assert download_client.authed is False
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
"""Tests for module.mcp.security - LocalNetworkMiddleware and _is_local()."""
|
"""Tests for module.mcp.security - McpAccessMiddleware, _is_allowed(), and clear_network_cache()."""
|
||||||
|
|
||||||
from unittest.mock import AsyncMock, MagicMock
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from starlette.applications import Starlette
|
from starlette.applications import Starlette
|
||||||
@@ -8,248 +8,223 @@ from starlette.responses import PlainTextResponse
|
|||||||
from starlette.routing import Route
|
from starlette.routing import Route
|
||||||
from starlette.testclient import TestClient
|
from starlette.testclient import TestClient
|
||||||
|
|
||||||
from module.mcp.security import LocalNetworkMiddleware, _is_local
|
from module.mcp.security import McpAccessMiddleware, _is_allowed, clear_network_cache
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# _is_local() unit tests
|
# _is_allowed() unit tests
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
class TestIsLocal:
|
class TestIsAllowed:
|
||||||
"""Verify _is_local() correctly classifies IP addresses."""
|
"""Verify _is_allowed() checks IPs against a given whitelist."""
|
||||||
|
|
||||||
# --- loopback ---
|
def setup_method(self):
|
||||||
|
clear_network_cache()
|
||||||
|
|
||||||
def test_ipv4_loopback_127_0_0_1(self):
|
LOCAL_WHITELIST = [
|
||||||
"""127.0.0.1 is the canonical loopback address."""
|
"127.0.0.0/8",
|
||||||
assert _is_local("127.0.0.1") is True
|
"10.0.0.0/8",
|
||||||
|
"172.16.0.0/12",
|
||||||
|
"192.168.0.0/16",
|
||||||
|
"::1/128",
|
||||||
|
"fe80::/10",
|
||||||
|
"fc00::/7",
|
||||||
|
]
|
||||||
|
|
||||||
def test_ipv4_loopback_127_0_0_2(self):
|
# --- allowed IPs ---
|
||||||
"""127.0.0.2 is within 127.0.0.0/8 and therefore local."""
|
|
||||||
assert _is_local("127.0.0.2") is True
|
|
||||||
|
|
||||||
def test_ipv4_loopback_127_255_255_255(self):
|
def test_ipv4_loopback_allowed(self):
|
||||||
"""Top of 127.0.0.0/8 range is still local."""
|
assert _is_allowed("127.0.0.1", self.LOCAL_WHITELIST) is True
|
||||||
assert _is_local("127.255.255.255") is True
|
|
||||||
|
|
||||||
# --- RFC 1918 class-A (10.0.0.0/8) ---
|
def test_ipv4_loopback_range(self):
|
||||||
|
assert _is_allowed("127.255.255.255", self.LOCAL_WHITELIST) is True
|
||||||
|
|
||||||
def test_ipv4_10_network_start(self):
|
def test_ipv4_10_network(self):
|
||||||
"""10.0.0.1 is in 10.0.0.0/8 private range."""
|
assert _is_allowed("10.0.0.1", self.LOCAL_WHITELIST) is True
|
||||||
assert _is_local("10.0.0.1") is True
|
|
||||||
|
|
||||||
def test_ipv4_10_network_mid(self):
|
def test_ipv4_172_16_network(self):
|
||||||
"""10.10.20.30 is inside 10.0.0.0/8."""
|
assert _is_allowed("172.16.0.1", self.LOCAL_WHITELIST) is True
|
||||||
assert _is_local("10.10.20.30") is True
|
|
||||||
|
|
||||||
def test_ipv4_10_network_end(self):
|
def test_ipv4_192_168_network(self):
|
||||||
"""10.255.255.254 is the last usable address in 10.0.0.0/8."""
|
assert _is_allowed("192.168.1.100", self.LOCAL_WHITELIST) is True
|
||||||
assert _is_local("10.255.255.254") is True
|
|
||||||
|
|
||||||
# --- RFC 1918 class-B (172.16.0.0/12) ---
|
|
||||||
|
|
||||||
def test_ipv4_172_16_start(self):
|
|
||||||
"""172.16.0.1 is the first address in 172.16.0.0/12."""
|
|
||||||
assert _is_local("172.16.0.1") is True
|
|
||||||
|
|
||||||
def test_ipv4_172_31_end(self):
|
|
||||||
"""172.31.255.254 is at the top of the 172.16.0.0/12 range."""
|
|
||||||
assert _is_local("172.31.255.254") is True
|
|
||||||
|
|
||||||
def test_ipv4_172_15_not_local(self):
|
|
||||||
"""172.15.255.255 is just outside 172.16.0.0/12 (below the range)."""
|
|
||||||
assert _is_local("172.15.255.255") is False
|
|
||||||
|
|
||||||
def test_ipv4_172_32_not_local(self):
|
|
||||||
"""172.32.0.0 is just above 172.16.0.0/12 (outside the range)."""
|
|
||||||
assert _is_local("172.32.0.0") is False
|
|
||||||
|
|
||||||
# --- RFC 1918 class-C (192.168.0.0/16) ---
|
|
||||||
|
|
||||||
def test_ipv4_192_168_start(self):
|
|
||||||
"""192.168.0.1 is a typical home-router address."""
|
|
||||||
assert _is_local("192.168.0.1") is True
|
|
||||||
|
|
||||||
def test_ipv4_192_168_end(self):
|
|
||||||
"""192.168.255.254 is at the top of 192.168.0.0/16."""
|
|
||||||
assert _is_local("192.168.255.254") is True
|
|
||||||
|
|
||||||
# --- Public IPv4 ---
|
|
||||||
|
|
||||||
def test_public_ipv4_google_dns(self):
|
|
||||||
"""8.8.8.8 (Google DNS) is a public address."""
|
|
||||||
assert _is_local("8.8.8.8") is False
|
|
||||||
|
|
||||||
def test_public_ipv4_cloudflare_dns(self):
|
|
||||||
"""1.1.1.1 (Cloudflare) is a public address."""
|
|
||||||
assert _is_local("1.1.1.1") is False
|
|
||||||
|
|
||||||
def test_public_ipv4_broadcast_like(self):
|
|
||||||
"""203.0.113.1 (TEST-NET-3, RFC 5737) is not a private address."""
|
|
||||||
assert _is_local("203.0.113.1") is False
|
|
||||||
|
|
||||||
# --- IPv6 loopback ---
|
|
||||||
|
|
||||||
def test_ipv6_loopback(self):
|
def test_ipv6_loopback(self):
|
||||||
"""::1 is the IPv6 loopback address."""
|
assert _is_allowed("::1", self.LOCAL_WHITELIST) is True
|
||||||
assert _is_local("::1") is True
|
|
||||||
|
|
||||||
# --- IPv6 link-local (fe80::/10) ---
|
|
||||||
|
|
||||||
def test_ipv6_link_local(self):
|
def test_ipv6_link_local(self):
|
||||||
"""fe80::1 is an IPv6 link-local address."""
|
assert _is_allowed("fe80::1", self.LOCAL_WHITELIST) is True
|
||||||
assert _is_local("fe80::1") is True
|
|
||||||
|
|
||||||
def test_ipv6_link_local_full(self):
|
def test_ipv6_ula(self):
|
||||||
"""fe80::aabb:ccdd is also link-local."""
|
assert _is_allowed("fd00::1", self.LOCAL_WHITELIST) is True
|
||||||
assert _is_local("fe80::aabb:ccdd") is True
|
|
||||||
|
|
||||||
# --- IPv6 ULA (fc00::/7) ---
|
# --- denied IPs ---
|
||||||
|
|
||||||
def test_ipv6_ula_fc(self):
|
def test_public_ipv4_denied(self):
|
||||||
"""fc00::1 is within the ULA range fc00::/7."""
|
assert _is_allowed("8.8.8.8", self.LOCAL_WHITELIST) is False
|
||||||
assert _is_local("fc00::1") is True
|
|
||||||
|
|
||||||
def test_ipv6_ula_fd(self):
|
def test_public_ipv6_denied(self):
|
||||||
"""fd00::1 is within the ULA range fc00::/7 (fd prefix)."""
|
assert _is_allowed("2001:4860:4860::8888", self.LOCAL_WHITELIST) is False
|
||||||
assert _is_local("fd00::1") is True
|
|
||||||
|
|
||||||
# --- Public IPv6 ---
|
def test_172_outside_range(self):
|
||||||
|
assert _is_allowed("172.32.0.0", self.LOCAL_WHITELIST) is False
|
||||||
|
|
||||||
def test_public_ipv6_google(self):
|
# --- empty whitelist ---
|
||||||
"""2001:4860:4860::8888 (Google IPv6 DNS) is a public address."""
|
|
||||||
assert _is_local("2001:4860:4860::8888") is False
|
|
||||||
|
|
||||||
def test_public_ipv6_documentation(self):
|
def test_empty_whitelist_denies_all(self):
|
||||||
"""2001:db8::1 (documentation prefix, RFC 3849) is public."""
|
assert _is_allowed("127.0.0.1", []) is False
|
||||||
assert _is_local("2001:db8::1") is False
|
|
||||||
|
|
||||||
# --- Invalid inputs ---
|
# --- invalid inputs ---
|
||||||
|
|
||||||
def test_invalid_hostname_returns_false(self):
|
def test_invalid_hostname(self):
|
||||||
"""A hostname string is not parseable as an IP and must return False."""
|
assert _is_allowed("localhost", self.LOCAL_WHITELIST) is False
|
||||||
assert _is_local("localhost") is False
|
|
||||||
|
|
||||||
def test_invalid_string_returns_false(self):
|
def test_empty_string(self):
|
||||||
"""A random non-IP string returns False without raising."""
|
assert _is_allowed("", self.LOCAL_WHITELIST) is False
|
||||||
assert _is_local("not-an-ip") is False
|
|
||||||
|
|
||||||
def test_empty_string_returns_false(self):
|
def test_malformed_ipv4(self):
|
||||||
"""An empty string is not a valid IP address."""
|
assert _is_allowed("256.0.0.1", self.LOCAL_WHITELIST) is False
|
||||||
assert _is_local("") is False
|
|
||||||
|
|
||||||
def test_malformed_ipv4_returns_false(self):
|
# --- single IP whitelist ---
|
||||||
"""A string that looks like IPv4 but is malformed returns False."""
|
|
||||||
assert _is_local("256.0.0.1") is False
|
|
||||||
|
|
||||||
def test_partial_ipv4_returns_false(self):
|
def test_single_ip_whitelist(self):
|
||||||
"""An incomplete IPv4 address is not valid."""
|
assert _is_allowed("203.0.113.5", ["203.0.113.5/32"]) is True
|
||||||
assert _is_local("192.168") is False
|
assert _is_allowed("203.0.113.6", ["203.0.113.5/32"]) is False
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# LocalNetworkMiddleware integration tests
|
# McpAccessMiddleware integration tests
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def _make_app_with_middleware() -> Starlette:
|
def _make_mcp_settings(mcp_whitelist=None, mcp_tokens=None):
|
||||||
"""Build a minimal Starlette app with LocalNetworkMiddleware applied."""
|
"""Create a mock settings.security object."""
|
||||||
|
|
||||||
|
class MockSecurity:
|
||||||
|
def __init__(self):
|
||||||
|
self.mcp_whitelist = mcp_whitelist if mcp_whitelist is not None else []
|
||||||
|
self.mcp_tokens = mcp_tokens if mcp_tokens is not None else []
|
||||||
|
|
||||||
|
class MockSettings:
|
||||||
|
def __init__(self):
|
||||||
|
self.security = MockSecurity()
|
||||||
|
|
||||||
|
return MockSettings()
|
||||||
|
|
||||||
|
|
||||||
|
def _make_app() -> Starlette:
|
||||||
|
"""Build a minimal Starlette app with McpAccessMiddleware applied."""
|
||||||
|
|
||||||
async def homepage(request):
|
async def homepage(request):
|
||||||
return PlainTextResponse("ok")
|
return PlainTextResponse("ok")
|
||||||
|
|
||||||
app = Starlette(routes=[Route("/", homepage)])
|
app = Starlette(routes=[Route("/", homepage)])
|
||||||
app.add_middleware(LocalNetworkMiddleware)
|
app.add_middleware(McpAccessMiddleware)
|
||||||
return app
|
return app
|
||||||
|
|
||||||
|
|
||||||
class TestLocalNetworkMiddleware:
|
def _patch_client_ip(app, ip):
|
||||||
"""Verify LocalNetworkMiddleware allows or denies requests by client IP."""
|
"""Return a modified app that overrides the client IP in ASGI scope."""
|
||||||
|
original_build = app.build_middleware_stack
|
||||||
|
|
||||||
@pytest.fixture
|
async def patched_app(scope, receive, send):
|
||||||
def app(self):
|
if scope["type"] == "http":
|
||||||
return _make_app_with_middleware()
|
scope["client"] = (ip, 12345) if ip is not None else None
|
||||||
|
await original_build()(scope, receive, send)
|
||||||
|
|
||||||
def test_local_ipv4_loopback_allowed(self, app):
|
app.build_middleware_stack = lambda: patched_app
|
||||||
"""Requests from 127.0.0.1 are allowed through."""
|
return app
|
||||||
# Starlette's TestClient identifies itself as "testclient", not a real IP.
|
|
||||||
# Patch the scope so the middleware sees an actual loopback address.
|
|
||||||
original_build = app.build_middleware_stack
|
|
||||||
|
|
||||||
async def patched_app(scope, receive, send):
|
|
||||||
if scope["type"] == "http":
|
|
||||||
scope["client"] = ("127.0.0.1", 12345)
|
|
||||||
await original_build()(scope, receive, send)
|
|
||||||
|
|
||||||
app.build_middleware_stack = lambda: patched_app # type: ignore[method-assign]
|
class TestMcpAccessMiddleware:
|
||||||
|
"""Verify McpAccessMiddleware allows/denies requests by IP and token."""
|
||||||
|
|
||||||
client = TestClient(app, raise_server_exceptions=False)
|
def setup_method(self):
|
||||||
response = client.get("/")
|
clear_network_cache()
|
||||||
|
|
||||||
|
def test_allowed_ip_passes(self):
|
||||||
|
mock_settings = _make_mcp_settings(mcp_whitelist=["127.0.0.0/8"])
|
||||||
|
app = _patch_client_ip(_make_app(), "127.0.0.1")
|
||||||
|
with patch("module.mcp.security.settings", mock_settings):
|
||||||
|
client = TestClient(app, raise_server_exceptions=False)
|
||||||
|
response = client.get("/")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.text == "ok"
|
assert response.text == "ok"
|
||||||
|
|
||||||
def test_non_local_ip_blocked(self, app):
|
def test_denied_ip_blocked(self):
|
||||||
"""Requests from a public IP are rejected with 403."""
|
mock_settings = _make_mcp_settings(mcp_whitelist=["127.0.0.0/8"])
|
||||||
# Patch the ASGI scope to simulate a public client
|
app = _patch_client_ip(_make_app(), "8.8.8.8")
|
||||||
original_build = app.build_middleware_stack
|
with patch("module.mcp.security.settings", mock_settings):
|
||||||
|
client = TestClient(app, raise_server_exceptions=False)
|
||||||
async def patched_app(scope, receive, send):
|
response = client.get("/")
|
||||||
if scope["type"] == "http":
|
|
||||||
scope["client"] = ("8.8.8.8", 12345)
|
|
||||||
await original_build()(scope, receive, send)
|
|
||||||
|
|
||||||
app.build_middleware_stack = lambda: patched_app # type: ignore[method-assign]
|
|
||||||
|
|
||||||
client = TestClient(app, raise_server_exceptions=False)
|
|
||||||
response = client.get("/")
|
|
||||||
assert response.status_code == 403
|
assert response.status_code == 403
|
||||||
assert "MCP access is restricted to local network" in response.text
|
assert "MCP access denied" in response.text
|
||||||
|
|
||||||
def test_missing_client_blocked(self, app):
|
def test_empty_whitelist_denies_all(self):
|
||||||
"""Requests with no client information are rejected with 403."""
|
mock_settings = _make_mcp_settings(mcp_whitelist=[], mcp_tokens=[])
|
||||||
original_build = app.build_middleware_stack
|
app = _patch_client_ip(_make_app(), "127.0.0.1")
|
||||||
|
with patch("module.mcp.security.settings", mock_settings):
|
||||||
async def patched_app(scope, receive, send):
|
client = TestClient(app, raise_server_exceptions=False)
|
||||||
if scope["type"] == "http":
|
response = client.get("/")
|
||||||
scope["client"] = None
|
|
||||||
await original_build()(scope, receive, send)
|
|
||||||
|
|
||||||
app.build_middleware_stack = lambda: patched_app # type: ignore[method-assign]
|
|
||||||
|
|
||||||
client = TestClient(app, raise_server_exceptions=False)
|
|
||||||
response = client.get("/")
|
|
||||||
assert response.status_code == 403
|
assert response.status_code == 403
|
||||||
|
|
||||||
def test_blocked_response_is_json(self, app):
|
def test_missing_client_blocked(self):
|
||||||
"""The 403 error body is valid JSON with an 'error' key."""
|
mock_settings = _make_mcp_settings(mcp_whitelist=["127.0.0.0/8"])
|
||||||
|
app = _patch_client_ip(_make_app(), None)
|
||||||
|
with patch("module.mcp.security.settings", mock_settings):
|
||||||
|
client = TestClient(app, raise_server_exceptions=False)
|
||||||
|
response = client.get("/")
|
||||||
|
assert response.status_code == 403
|
||||||
|
|
||||||
|
def test_bearer_token_bypasses_ip(self):
|
||||||
|
mock_settings = _make_mcp_settings(
|
||||||
|
mcp_whitelist=[], mcp_tokens=["secret-token-123"]
|
||||||
|
)
|
||||||
|
app = _patch_client_ip(_make_app(), "8.8.8.8")
|
||||||
|
with patch("module.mcp.security.settings", mock_settings):
|
||||||
|
client = TestClient(app, raise_server_exceptions=False)
|
||||||
|
response = client.get(
|
||||||
|
"/", headers={"Authorization": "Bearer secret-token-123"}
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
def test_invalid_bearer_token_denied(self):
|
||||||
|
mock_settings = _make_mcp_settings(
|
||||||
|
mcp_whitelist=[], mcp_tokens=["secret-token-123"]
|
||||||
|
)
|
||||||
|
app = _patch_client_ip(_make_app(), "8.8.8.8")
|
||||||
|
with patch("module.mcp.security.settings", mock_settings):
|
||||||
|
client = TestClient(app, raise_server_exceptions=False)
|
||||||
|
response = client.get(
|
||||||
|
"/", headers={"Authorization": "Bearer wrong-token"}
|
||||||
|
)
|
||||||
|
assert response.status_code == 403
|
||||||
|
|
||||||
|
def test_private_network_with_default_whitelist(self):
|
||||||
|
default_whitelist = [
|
||||||
|
"127.0.0.0/8",
|
||||||
|
"10.0.0.0/8",
|
||||||
|
"172.16.0.0/12",
|
||||||
|
"192.168.0.0/16",
|
||||||
|
"::1/128",
|
||||||
|
"fe80::/10",
|
||||||
|
"fc00::/7",
|
||||||
|
]
|
||||||
|
mock_settings = _make_mcp_settings(mcp_whitelist=default_whitelist)
|
||||||
|
app = _patch_client_ip(_make_app(), "192.168.1.100")
|
||||||
|
with patch("module.mcp.security.settings", mock_settings):
|
||||||
|
client = TestClient(app, raise_server_exceptions=False)
|
||||||
|
response = client.get("/")
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
def test_blocked_response_is_json(self):
|
||||||
import json
|
import json
|
||||||
|
|
||||||
original_build = app.build_middleware_stack
|
mock_settings = _make_mcp_settings(mcp_whitelist=["127.0.0.0/8"])
|
||||||
|
app = _patch_client_ip(_make_app(), "1.2.3.4")
|
||||||
async def patched_app(scope, receive, send):
|
with patch("module.mcp.security.settings", mock_settings):
|
||||||
if scope["type"] == "http":
|
client = TestClient(app, raise_server_exceptions=False)
|
||||||
scope["client"] = ("1.2.3.4", 9999)
|
response = client.get("/")
|
||||||
await original_build()(scope, receive, send)
|
|
||||||
|
|
||||||
app.build_middleware_stack = lambda: patched_app # type: ignore[method-assign]
|
|
||||||
|
|
||||||
client = TestClient(app, raise_server_exceptions=False)
|
|
||||||
response = client.get("/")
|
|
||||||
assert response.status_code == 403
|
assert response.status_code == 403
|
||||||
body = json.loads(response.text)
|
body = json.loads(response.text)
|
||||||
assert "error" in body
|
assert "error" in body
|
||||||
|
|
||||||
def test_private_192_168_allowed(self, app):
|
|
||||||
"""Requests from a 192.168.x.x address pass through."""
|
|
||||||
original_build = app.build_middleware_stack
|
|
||||||
|
|
||||||
async def patched_app(scope, receive, send):
|
|
||||||
if scope["type"] == "http":
|
|
||||||
scope["client"] = ("192.168.1.100", 54321)
|
|
||||||
await original_build()(scope, receive, send)
|
|
||||||
|
|
||||||
app.build_middleware_stack = lambda: patched_app # type: ignore[method-assign]
|
|
||||||
|
|
||||||
client = TestClient(app, raise_server_exceptions=False)
|
|
||||||
response = client.get("/")
|
|
||||||
assert response.status_code == 200
|
|
||||||
|
|||||||
429
backend/src/test/test_mock_downloader.py
Normal file
429
backend/src/test/test_mock_downloader.py
Normal file
@@ -0,0 +1,429 @@
|
|||||||
|
"""Tests for MockDownloader - state management and API contract."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from module.downloader.client.mock_downloader import MockDownloader
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_dl() -> MockDownloader:
|
||||||
|
"""Fresh MockDownloader for each test."""
|
||||||
|
return MockDownloader()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Initialization
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestMockDownloaderInit:
|
||||||
|
def test_initial_state_is_empty(self, mock_dl):
|
||||||
|
"""MockDownloader starts with no torrents, rules, or feeds."""
|
||||||
|
state = mock_dl.get_state()
|
||||||
|
assert state["torrents"] == {}
|
||||||
|
assert state["rules"] == {}
|
||||||
|
assert state["feeds"] == {}
|
||||||
|
|
||||||
|
def test_initial_categories(self, mock_dl):
|
||||||
|
"""Default categories include Bangumi and BangumiCollection."""
|
||||||
|
state = mock_dl.get_state()
|
||||||
|
assert "Bangumi" in state["categories"]
|
||||||
|
assert "BangumiCollection" in state["categories"]
|
||||||
|
|
||||||
|
def test_initial_prefs(self, mock_dl):
|
||||||
|
"""Default prefs are populated."""
|
||||||
|
# Access private attribute directly to confirm defaults
|
||||||
|
assert mock_dl._prefs["rss_auto_downloading_enabled"] is True
|
||||||
|
assert mock_dl._prefs["rss_max_articles_per_feed"] == 500
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Auth / connection
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestMockDownloaderAuth:
|
||||||
|
async def test_auth_returns_true(self, mock_dl):
|
||||||
|
result = await mock_dl.auth()
|
||||||
|
assert result is True
|
||||||
|
|
||||||
|
async def test_auth_retry_param_accepted(self, mock_dl):
|
||||||
|
result = await mock_dl.auth(retry=5)
|
||||||
|
assert result is True
|
||||||
|
|
||||||
|
async def test_logout_does_not_raise(self, mock_dl):
|
||||||
|
await mock_dl.logout()
|
||||||
|
|
||||||
|
async def test_check_host_returns_true(self, mock_dl):
|
||||||
|
result = await mock_dl.check_host()
|
||||||
|
assert result is True
|
||||||
|
|
||||||
|
async def test_check_connection_returns_version_string(self, mock_dl):
|
||||||
|
result = await mock_dl.check_connection()
|
||||||
|
assert "mock" in result.lower()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Prefs
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestMockDownloaderPrefs:
|
||||||
|
async def test_prefs_init_updates_prefs(self, mock_dl):
|
||||||
|
"""prefs_init merges given prefs into the internal store."""
|
||||||
|
await mock_dl.prefs_init({"rss_refresh_interval": 60, "custom_key": "val"})
|
||||||
|
assert mock_dl._prefs["rss_refresh_interval"] == 60
|
||||||
|
assert mock_dl._prefs["custom_key"] == "val"
|
||||||
|
|
||||||
|
async def test_get_app_prefs_returns_dict(self, mock_dl):
|
||||||
|
result = await mock_dl.get_app_prefs()
|
||||||
|
assert isinstance(result, dict)
|
||||||
|
assert "save_path" in result
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Categories
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestMockDownloaderCategories:
|
||||||
|
async def test_add_category_persists(self, mock_dl):
|
||||||
|
await mock_dl.add_category("NewCategory")
|
||||||
|
state = mock_dl.get_state()
|
||||||
|
assert "NewCategory" in state["categories"]
|
||||||
|
|
||||||
|
async def test_add_duplicate_category_no_error(self, mock_dl):
|
||||||
|
await mock_dl.add_category("Bangumi")
|
||||||
|
state = mock_dl.get_state()
|
||||||
|
# Still only one entry for Bangumi (set semantics)
|
||||||
|
assert state["categories"].count("Bangumi") == 1
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Torrent management
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestMockDownloaderAddTorrents:
|
||||||
|
async def test_add_torrent_url_returns_true(self, mock_dl):
|
||||||
|
result = await mock_dl.add_torrents(
|
||||||
|
torrent_urls="magnet:?xt=urn:btih:abc",
|
||||||
|
torrent_files=None,
|
||||||
|
save_path="/downloads/Bangumi",
|
||||||
|
category="Bangumi",
|
||||||
|
)
|
||||||
|
assert result is True
|
||||||
|
|
||||||
|
async def test_add_torrent_stores_in_state(self, mock_dl):
|
||||||
|
await mock_dl.add_torrents(
|
||||||
|
torrent_urls="magnet:?xt=urn:btih:abc",
|
||||||
|
torrent_files=None,
|
||||||
|
save_path="/downloads/Bangumi",
|
||||||
|
category="Bangumi",
|
||||||
|
)
|
||||||
|
state = mock_dl.get_state()
|
||||||
|
assert len(state["torrents"]) == 1
|
||||||
|
|
||||||
|
async def test_add_torrent_with_tag_stored(self, mock_dl):
|
||||||
|
await mock_dl.add_torrents(
|
||||||
|
torrent_urls="magnet:?xt=urn:btih:abc",
|
||||||
|
torrent_files=None,
|
||||||
|
save_path="/downloads/Bangumi",
|
||||||
|
category="Bangumi",
|
||||||
|
tags="ab:42",
|
||||||
|
)
|
||||||
|
state = mock_dl.get_state()
|
||||||
|
torrent = list(state["torrents"].values())[0]
|
||||||
|
assert torrent["tags"] == "ab:42"
|
||||||
|
|
||||||
|
async def test_add_torrent_with_file_bytes(self, mock_dl):
|
||||||
|
result = await mock_dl.add_torrents(
|
||||||
|
torrent_urls=None,
|
||||||
|
torrent_files=b"\x00\x01\x02",
|
||||||
|
save_path="/downloads",
|
||||||
|
category="Bangumi",
|
||||||
|
)
|
||||||
|
assert result is True
|
||||||
|
|
||||||
|
async def test_two_different_torrents_stored_separately(self, mock_dl):
|
||||||
|
await mock_dl.add_torrents(
|
||||||
|
torrent_urls="magnet:?xt=urn:btih:aaa",
|
||||||
|
torrent_files=None,
|
||||||
|
save_path="/dl",
|
||||||
|
category="Bangumi",
|
||||||
|
)
|
||||||
|
await mock_dl.add_torrents(
|
||||||
|
torrent_urls="magnet:?xt=urn:btih:bbb",
|
||||||
|
torrent_files=None,
|
||||||
|
save_path="/dl",
|
||||||
|
category="Bangumi",
|
||||||
|
)
|
||||||
|
state = mock_dl.get_state()
|
||||||
|
assert len(state["torrents"]) == 2
|
||||||
|
|
||||||
|
|
||||||
|
class TestMockDownloaderTorrentsInfo:
|
||||||
|
async def test_returns_all_when_no_filter(self, mock_dl):
|
||||||
|
mock_dl.add_mock_torrent("Anime A", category="Bangumi")
|
||||||
|
mock_dl.add_mock_torrent("Anime B", category="Bangumi")
|
||||||
|
result = await mock_dl.torrents_info(status_filter=None, category=None)
|
||||||
|
assert len(result) == 2
|
||||||
|
|
||||||
|
async def test_filters_by_category(self, mock_dl):
|
||||||
|
mock_dl.add_mock_torrent("Anime A", category="Bangumi")
|
||||||
|
mock_dl.add_mock_torrent("Movie", category="BangumiCollection")
|
||||||
|
result = await mock_dl.torrents_info(status_filter=None, category="Bangumi")
|
||||||
|
assert len(result) == 1
|
||||||
|
assert result[0]["name"] == "Anime A"
|
||||||
|
|
||||||
|
async def test_filters_by_tag(self, mock_dl):
|
||||||
|
h1 = mock_dl.add_mock_torrent("Anime A", category="Bangumi")
|
||||||
|
mock_dl.add_mock_torrent("Anime B", category="Bangumi")
|
||||||
|
# Manually set the tag on first torrent
|
||||||
|
mock_dl._torrents[h1]["tags"] = ["ab:1"]
|
||||||
|
result = await mock_dl.torrents_info(
|
||||||
|
status_filter=None, category=None, tag="ab:1"
|
||||||
|
)
|
||||||
|
assert len(result) == 1
|
||||||
|
assert result[0]["name"] == "Anime A"
|
||||||
|
|
||||||
|
async def test_empty_store_returns_empty_list(self, mock_dl):
|
||||||
|
result = await mock_dl.torrents_info(status_filter=None, category="Bangumi")
|
||||||
|
assert result == []
|
||||||
|
|
||||||
|
|
||||||
|
class TestMockDownloaderTorrentsFiles:
|
||||||
|
async def test_returns_files_for_known_hash(self, mock_dl):
|
||||||
|
files = [{"name": "ep01.mkv", "size": 500_000_000}]
|
||||||
|
h = mock_dl.add_mock_torrent("Anime", files=files)
|
||||||
|
result = await mock_dl.torrents_files(torrent_hash=h)
|
||||||
|
assert len(result) == 1
|
||||||
|
assert result[0]["name"] == "ep01.mkv"
|
||||||
|
|
||||||
|
async def test_returns_empty_list_for_unknown_hash(self, mock_dl):
|
||||||
|
result = await mock_dl.torrents_files(torrent_hash="nonexistent")
|
||||||
|
assert result == []
|
||||||
|
|
||||||
|
|
||||||
|
class TestMockDownloaderDelete:
|
||||||
|
async def test_delete_single_torrent(self, mock_dl):
|
||||||
|
h = mock_dl.add_mock_torrent("Anime")
|
||||||
|
await mock_dl.torrents_delete(hash=h)
|
||||||
|
assert h not in mock_dl._torrents
|
||||||
|
|
||||||
|
async def test_delete_multiple_torrents_pipe_separated(self, mock_dl):
|
||||||
|
h1 = mock_dl.add_mock_torrent("Anime A")
|
||||||
|
h2 = mock_dl.add_mock_torrent("Anime B")
|
||||||
|
await mock_dl.torrents_delete(hash=f"{h1}|{h2}")
|
||||||
|
assert h1 not in mock_dl._torrents
|
||||||
|
assert h2 not in mock_dl._torrents
|
||||||
|
|
||||||
|
async def test_delete_nonexistent_hash_no_error(self, mock_dl):
|
||||||
|
await mock_dl.torrents_delete(hash="deadbeef")
|
||||||
|
|
||||||
|
|
||||||
|
class TestMockDownloaderPauseResume:
|
||||||
|
async def test_pause_sets_state(self, mock_dl):
|
||||||
|
h = mock_dl.add_mock_torrent("Anime", state="downloading")
|
||||||
|
await mock_dl.torrents_pause(hashes=h)
|
||||||
|
assert mock_dl._torrents[h]["state"] == "paused"
|
||||||
|
|
||||||
|
async def test_resume_sets_state(self, mock_dl):
|
||||||
|
h = mock_dl.add_mock_torrent("Anime", state="paused")
|
||||||
|
await mock_dl.torrents_resume(hashes=h)
|
||||||
|
assert mock_dl._torrents[h]["state"] == "downloading"
|
||||||
|
|
||||||
|
async def test_pause_multiple_pipe_separated(self, mock_dl):
|
||||||
|
h1 = mock_dl.add_mock_torrent("Anime A", state="downloading")
|
||||||
|
h2 = mock_dl.add_mock_torrent("Anime B", state="downloading")
|
||||||
|
await mock_dl.torrents_pause(hashes=f"{h1}|{h2}")
|
||||||
|
assert mock_dl._torrents[h1]["state"] == "paused"
|
||||||
|
assert mock_dl._torrents[h2]["state"] == "paused"
|
||||||
|
|
||||||
|
async def test_pause_unknown_hash_no_error(self, mock_dl):
|
||||||
|
await mock_dl.torrents_pause(hashes="deadbeef")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Rename
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestMockDownloaderRename:
|
||||||
|
async def test_rename_returns_true(self, mock_dl):
|
||||||
|
result = await mock_dl.torrents_rename_file(
|
||||||
|
torrent_hash="hash1",
|
||||||
|
old_path="old.mkv",
|
||||||
|
new_path="new.mkv",
|
||||||
|
)
|
||||||
|
assert result is True
|
||||||
|
|
||||||
|
async def test_rename_with_verify_flag(self, mock_dl):
|
||||||
|
result = await mock_dl.torrents_rename_file(
|
||||||
|
torrent_hash="hash1",
|
||||||
|
old_path="old.mkv",
|
||||||
|
new_path="new.mkv",
|
||||||
|
verify=False,
|
||||||
|
)
|
||||||
|
assert result is True
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# RSS feed management
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestMockDownloaderRssFeeds:
|
||||||
|
async def test_add_feed_stored(self, mock_dl):
|
||||||
|
await mock_dl.rss_add_feed(url="https://mikan.me/RSS/test", item_path="Mikan")
|
||||||
|
feeds = await mock_dl.rss_get_feeds()
|
||||||
|
assert "Mikan" in feeds
|
||||||
|
assert feeds["Mikan"]["url"] == "https://mikan.me/RSS/test"
|
||||||
|
|
||||||
|
async def test_remove_feed(self, mock_dl):
|
||||||
|
await mock_dl.rss_add_feed(url="https://example.com", item_path="Feed1")
|
||||||
|
await mock_dl.rss_remove_item(item_path="Feed1")
|
||||||
|
feeds = await mock_dl.rss_get_feeds()
|
||||||
|
assert "Feed1" not in feeds
|
||||||
|
|
||||||
|
async def test_remove_nonexistent_feed_no_error(self, mock_dl):
|
||||||
|
await mock_dl.rss_remove_item(item_path="nonexistent")
|
||||||
|
|
||||||
|
async def test_get_feeds_initially_empty(self, mock_dl):
|
||||||
|
feeds = await mock_dl.rss_get_feeds()
|
||||||
|
assert feeds == {}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Rules
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestMockDownloaderRules:
|
||||||
|
async def test_set_rule_stored(self, mock_dl):
|
||||||
|
rule_def = {"enable": True, "mustContain": "Anime"}
|
||||||
|
await mock_dl.rss_set_rule("rule1", rule_def)
|
||||||
|
rules = await mock_dl.get_download_rule()
|
||||||
|
assert "rule1" in rules
|
||||||
|
assert rules["rule1"]["mustContain"] == "Anime"
|
||||||
|
|
||||||
|
async def test_remove_rule(self, mock_dl):
|
||||||
|
await mock_dl.rss_set_rule("rule1", {"enable": True})
|
||||||
|
await mock_dl.remove_rule("rule1")
|
||||||
|
rules = await mock_dl.get_download_rule()
|
||||||
|
assert "rule1" not in rules
|
||||||
|
|
||||||
|
async def test_remove_nonexistent_rule_no_error(self, mock_dl):
|
||||||
|
await mock_dl.remove_rule("nonexistent")
|
||||||
|
|
||||||
|
async def test_get_download_rule_initially_empty(self, mock_dl):
|
||||||
|
rules = await mock_dl.get_download_rule()
|
||||||
|
assert rules == {}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Move / path
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestMockDownloaderMovePath:
|
||||||
|
async def test_move_torrent_updates_save_path(self, mock_dl):
|
||||||
|
h = mock_dl.add_mock_torrent("Anime", save_path="/old/path")
|
||||||
|
await mock_dl.move_torrent(hashes=h, new_location="/new/path")
|
||||||
|
assert mock_dl._torrents[h]["save_path"] == "/new/path"
|
||||||
|
|
||||||
|
async def test_move_multiple_pipe_separated(self, mock_dl):
|
||||||
|
h1 = mock_dl.add_mock_torrent("Anime A", save_path="/old")
|
||||||
|
h2 = mock_dl.add_mock_torrent("Anime B", save_path="/old")
|
||||||
|
await mock_dl.move_torrent(hashes=f"{h1}|{h2}", new_location="/new")
|
||||||
|
assert mock_dl._torrents[h1]["save_path"] == "/new"
|
||||||
|
assert mock_dl._torrents[h2]["save_path"] == "/new"
|
||||||
|
|
||||||
|
async def test_get_torrent_path_known_hash(self, mock_dl):
|
||||||
|
h = mock_dl.add_mock_torrent("Anime", save_path="/downloads/Bangumi")
|
||||||
|
path = await mock_dl.get_torrent_path(h)
|
||||||
|
assert path == "/downloads/Bangumi"
|
||||||
|
|
||||||
|
async def test_get_torrent_path_unknown_hash_returns_default(self, mock_dl):
|
||||||
|
path = await mock_dl.get_torrent_path("nonexistent")
|
||||||
|
assert path == "/tmp/mock-downloads"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Category assignment
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestMockDownloaderSetCategory:
|
||||||
|
async def test_set_category_updates_torrent(self, mock_dl):
|
||||||
|
h = mock_dl.add_mock_torrent("Anime", category="Bangumi")
|
||||||
|
await mock_dl.set_category(h, "BangumiCollection")
|
||||||
|
assert mock_dl._torrents[h]["category"] == "BangumiCollection"
|
||||||
|
|
||||||
|
async def test_set_category_unknown_hash_no_error(self, mock_dl):
|
||||||
|
await mock_dl.set_category("deadbeef", "Bangumi")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Tags
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestMockDownloaderTags:
|
||||||
|
async def test_add_tag_appends(self, mock_dl):
|
||||||
|
h = mock_dl.add_mock_torrent("Anime")
|
||||||
|
await mock_dl.add_tag(h, "ab:1")
|
||||||
|
assert "ab:1" in mock_dl._torrents[h]["tags"]
|
||||||
|
|
||||||
|
async def test_add_tag_no_duplicates(self, mock_dl):
|
||||||
|
h = mock_dl.add_mock_torrent("Anime")
|
||||||
|
await mock_dl.add_tag(h, "ab:1")
|
||||||
|
await mock_dl.add_tag(h, "ab:1")
|
||||||
|
assert mock_dl._torrents[h]["tags"].count("ab:1") == 1
|
||||||
|
|
||||||
|
async def test_add_tag_unknown_hash_no_error(self, mock_dl):
|
||||||
|
await mock_dl.add_tag("deadbeef", "ab:1")
|
||||||
|
|
||||||
|
async def test_multiple_tags_on_same_torrent(self, mock_dl):
|
||||||
|
h = mock_dl.add_mock_torrent("Anime")
|
||||||
|
await mock_dl.add_tag(h, "ab:1")
|
||||||
|
await mock_dl.add_tag(h, "group:sub")
|
||||||
|
assert "ab:1" in mock_dl._torrents[h]["tags"]
|
||||||
|
assert "group:sub" in mock_dl._torrents[h]["tags"]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# add_mock_torrent helper
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestAddMockTorrentHelper:
|
||||||
|
def test_generates_hash_from_name(self, mock_dl):
|
||||||
|
h = mock_dl.add_mock_torrent("Anime")
|
||||||
|
assert h is not None
|
||||||
|
assert len(h) == 40 # SHA1 hex digest
|
||||||
|
|
||||||
|
def test_explicit_hash_used(self, mock_dl):
|
||||||
|
h = mock_dl.add_mock_torrent("Anime", hash="cafebabe" + "0" * 32)
|
||||||
|
assert h == "cafebabe" + "0" * 32
|
||||||
|
|
||||||
|
def test_torrent_state_is_completed_by_default(self, mock_dl):
|
||||||
|
h = mock_dl.add_mock_torrent("Anime")
|
||||||
|
assert mock_dl._torrents[h]["state"] == "completed"
|
||||||
|
assert mock_dl._torrents[h]["progress"] == 1.0
|
||||||
|
|
||||||
|
def test_torrent_state_custom(self, mock_dl):
|
||||||
|
h = mock_dl.add_mock_torrent("Anime", state="downloading")
|
||||||
|
assert mock_dl._torrents[h]["state"] == "downloading"
|
||||||
|
assert mock_dl._torrents[h]["progress"] == 0.5
|
||||||
|
|
||||||
|
def test_default_file_is_mkv(self, mock_dl):
|
||||||
|
h = mock_dl.add_mock_torrent("My Anime")
|
||||||
|
files = mock_dl._torrents[h]["files"]
|
||||||
|
assert len(files) == 1
|
||||||
|
assert files[0]["name"].endswith(".mkv")
|
||||||
|
|
||||||
|
def test_custom_files_stored(self, mock_dl):
|
||||||
|
custom_files = [{"name": "ep01.mkv"}, {"name": "ep02.mkv"}]
|
||||||
|
h = mock_dl.add_mock_torrent("Anime", files=custom_files)
|
||||||
|
assert len(mock_dl._torrents[h]["files"]) == 2
|
||||||
2
backend/uv.lock
generated
2
backend/uv.lock
generated
@@ -61,7 +61,7 @@ wheels = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "auto-bangumi"
|
name = "auto-bangumi"
|
||||||
version = "3.2.3b5"
|
version = "3.2.3"
|
||||||
source = { virtual = "." }
|
source = { virtual = "." }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "aiosqlite" },
|
{ name = "aiosqlite" },
|
||||||
|
|||||||
68
webui/src/components/setting/config-security.vue
Normal file
68
webui/src/components/setting/config-security.vue
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
<script lang="ts" setup>
|
||||||
|
import type { Security } from '#/config';
|
||||||
|
import type { SettingItem } from '#/components';
|
||||||
|
|
||||||
|
const { t } = useMyI18n();
|
||||||
|
const { getSettingGroup } = useConfigStore();
|
||||||
|
|
||||||
|
const security = getSettingGroup('security');
|
||||||
|
|
||||||
|
const items: SettingItem<Security>[] = [
|
||||||
|
{
|
||||||
|
configKey: 'login_whitelist',
|
||||||
|
label: () => t('config.security_set.login_whitelist'),
|
||||||
|
type: 'dynamic-tags',
|
||||||
|
prop: {
|
||||||
|
placeholder: '192.168.0.0/16',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
configKey: 'login_tokens',
|
||||||
|
label: () => t('config.security_set.login_tokens'),
|
||||||
|
type: 'dynamic-tags',
|
||||||
|
prop: {
|
||||||
|
placeholder: 'your-api-token',
|
||||||
|
},
|
||||||
|
bottomLine: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
configKey: 'mcp_whitelist',
|
||||||
|
label: () => t('config.security_set.mcp_whitelist'),
|
||||||
|
type: 'dynamic-tags',
|
||||||
|
prop: {
|
||||||
|
placeholder: '127.0.0.0/8',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
configKey: 'mcp_tokens',
|
||||||
|
label: () => t('config.security_set.mcp_tokens'),
|
||||||
|
type: 'dynamic-tags',
|
||||||
|
prop: {
|
||||||
|
placeholder: 'your-mcp-token',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
];
|
||||||
|
</script>
|
||||||
|
|
||||||
|
<template>
|
||||||
|
<ab-fold-panel :title="$t('config.security_set.title')">
|
||||||
|
<p class="hint-text">{{ $t('config.security_set.hint') }}</p>
|
||||||
|
<div space-y-8>
|
||||||
|
<ab-setting
|
||||||
|
v-for="i in items"
|
||||||
|
:key="i.configKey"
|
||||||
|
v-bind="i"
|
||||||
|
v-model:data="security[i.configKey]"
|
||||||
|
></ab-setting>
|
||||||
|
</div>
|
||||||
|
</ab-fold-panel>
|
||||||
|
</template>
|
||||||
|
|
||||||
|
<style lang="scss" scoped>
|
||||||
|
.hint-text {
|
||||||
|
font-size: 12px;
|
||||||
|
color: var(--color-text-secondary);
|
||||||
|
margin-bottom: 12px;
|
||||||
|
line-height: 1.5;
|
||||||
|
}
|
||||||
|
</style>
|
||||||
@@ -49,6 +49,7 @@ onActivated(() => {
|
|||||||
<config-player></config-player>
|
<config-player></config-player>
|
||||||
<config-openai></config-openai>
|
<config-openai></config-openai>
|
||||||
<config-passkey></config-passkey>
|
<config-passkey></config-passkey>
|
||||||
|
<config-security></config-security>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
|||||||
@@ -2,10 +2,6 @@ import type { TupleToUnion } from './utils';
|
|||||||
|
|
||||||
/** 下载方式 */
|
/** 下载方式 */
|
||||||
export type DownloaderType = ['qbittorrent'];
|
export type DownloaderType = ['qbittorrent'];
|
||||||
/** rss parser 源 */
|
|
||||||
export type RssParserType = ['mikan'];
|
|
||||||
/** rss parser 方法 */
|
|
||||||
export type RssParserMethodType = ['tmdb', 'mikan', 'parser'];
|
|
||||||
/** rss parser 语言 */
|
/** rss parser 语言 */
|
||||||
export type RssParserLang = ['zh', 'en', 'jp'];
|
export type RssParserLang = ['zh', 'en', 'jp'];
|
||||||
/** 重命名方式 */
|
/** 重命名方式 */
|
||||||
@@ -44,12 +40,8 @@ export interface Downloader {
|
|||||||
}
|
}
|
||||||
export interface RssParser {
|
export interface RssParser {
|
||||||
enable: boolean;
|
enable: boolean;
|
||||||
type: TupleToUnion<RssParserType>;
|
|
||||||
token: string;
|
|
||||||
custom_url: string;
|
|
||||||
filter: Array<string>;
|
filter: Array<string>;
|
||||||
language: TupleToUnion<RssParserLang>;
|
language: TupleToUnion<RssParserLang>;
|
||||||
parser_type: TupleToUnion<RssParserMethodType>;
|
|
||||||
}
|
}
|
||||||
export interface BangumiManage {
|
export interface BangumiManage {
|
||||||
enable: boolean;
|
enable: boolean;
|
||||||
@@ -105,6 +97,17 @@ export interface ExperimentalOpenAI {
|
|||||||
deployment_id?: string;
|
deployment_id?: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** Access control for the login endpoint and MCP server.
|
||||||
|
* Whitelist entries are IPv4/IPv6 CIDR strings (e.g. "192.168.0.0/16").
|
||||||
|
* An empty login_whitelist allows all IPs; an empty mcp_whitelist denies all IP-based MCP access.
|
||||||
|
*/
|
||||||
|
export interface Security {
|
||||||
|
login_whitelist: string[];
|
||||||
|
login_tokens: string[];
|
||||||
|
mcp_whitelist: string[];
|
||||||
|
mcp_tokens: string[];
|
||||||
|
}
|
||||||
|
|
||||||
export interface Config {
|
export interface Config {
|
||||||
program: Program;
|
program: Program;
|
||||||
downloader: Downloader;
|
downloader: Downloader;
|
||||||
@@ -114,6 +117,7 @@ export interface Config {
|
|||||||
proxy: Proxy;
|
proxy: Proxy;
|
||||||
notification: Notification;
|
notification: Notification;
|
||||||
experimental_openai: ExperimentalOpenAI;
|
experimental_openai: ExperimentalOpenAI;
|
||||||
|
security: Security;
|
||||||
}
|
}
|
||||||
|
|
||||||
export const initConfig: Config = {
|
export const initConfig: Config = {
|
||||||
@@ -132,12 +136,8 @@ export const initConfig: Config = {
|
|||||||
},
|
},
|
||||||
rss_parser: {
|
rss_parser: {
|
||||||
enable: true,
|
enable: true,
|
||||||
type: 'mikan',
|
|
||||||
token: '',
|
|
||||||
custom_url: '',
|
|
||||||
filter: [],
|
filter: [],
|
||||||
language: 'zh',
|
language: 'zh',
|
||||||
parser_type: 'parser',
|
|
||||||
},
|
},
|
||||||
bangumi_manage: {
|
bangumi_manage: {
|
||||||
enable: true,
|
enable: true,
|
||||||
@@ -171,4 +171,10 @@ export const initConfig: Config = {
|
|||||||
api_version: '2020-05-03',
|
api_version: '2020-05-03',
|
||||||
deployment_id: '',
|
deployment_id: '',
|
||||||
},
|
},
|
||||||
|
security: {
|
||||||
|
login_whitelist: [],
|
||||||
|
login_tokens: [],
|
||||||
|
mcp_whitelist: [],
|
||||||
|
mcp_tokens: [],
|
||||||
|
},
|
||||||
};
|
};
|
||||||
|
|||||||
1
webui/types/dts/auto-imports.d.ts
vendored
1
webui/types/dts/auto-imports.d.ts
vendored
@@ -40,6 +40,7 @@ declare global {
|
|||||||
const getCurrentInstance: typeof import('vue')['getCurrentInstance']
|
const getCurrentInstance: typeof import('vue')['getCurrentInstance']
|
||||||
const getCurrentScope: typeof import('vue')['getCurrentScope']
|
const getCurrentScope: typeof import('vue')['getCurrentScope']
|
||||||
const h: typeof import('vue')['h']
|
const h: typeof import('vue')['h']
|
||||||
|
const i18n: typeof import('../../src/hooks/useMyI18n')['i18n']
|
||||||
const inject: typeof import('vue')['inject']
|
const inject: typeof import('vue')['inject']
|
||||||
const isProxy: typeof import('vue')['isProxy']
|
const isProxy: typeof import('vue')['isProxy']
|
||||||
const isReactive: typeof import('vue')['isReactive']
|
const isReactive: typeof import('vue')['isReactive']
|
||||||
|
|||||||
1
webui/types/dts/components.d.ts
vendored
1
webui/types/dts/components.d.ts
vendored
@@ -55,6 +55,7 @@ declare module '@vue/runtime-core' {
|
|||||||
ConfigPlayer: typeof import('./../../src/components/setting/config-player.vue')['default']
|
ConfigPlayer: typeof import('./../../src/components/setting/config-player.vue')['default']
|
||||||
ConfigProxy: typeof import('./../../src/components/setting/config-proxy.vue')['default']
|
ConfigProxy: typeof import('./../../src/components/setting/config-proxy.vue')['default']
|
||||||
ConfigSearchProvider: typeof import('./../../src/components/setting/config-search-provider.vue')['default']
|
ConfigSearchProvider: typeof import('./../../src/components/setting/config-search-provider.vue')['default']
|
||||||
|
ConfigSecurity: typeof import('./../../src/components/setting/config-security.vue')['default']
|
||||||
MediaQuery: typeof import('./../../src/components/media-query.vue')['default']
|
MediaQuery: typeof import('./../../src/components/media-query.vue')['default']
|
||||||
RouterLink: typeof import('vue-router')['RouterLink']
|
RouterLink: typeof import('vue-router')['RouterLink']
|
||||||
RouterView: typeof import('vue-router')['RouterView']
|
RouterView: typeof import('vue-router')['RouterView']
|
||||||
|
|||||||
Reference in New Issue
Block a user