diff --git a/docs/cli.md b/docs/cli.md index 903ad31e..0ceec341 100644 --- a/docs/cli.md +++ b/docs/cli.md @@ -206,7 +206,12 @@ moviepilot setup --config-dir /path/to/moviepilot-config `--wizard` 会进入交互式初始化向导,支持配置: - `API_TOKEN` +- 数据库类型 + 默认 `SQLite` + 可切换为 `PostgreSQL`,并填写主机、端口、数据库名、用户名、密码 - 默认下载目录与媒体库目录 +- AI Agent + 可按需启用,并配置 `LLM_PROVIDER`、`LLM_MODEL`、`LLM_API_KEY`、`LLM_BASE_URL` - 下载器 - 媒体服务器 - 消息通知渠道 diff --git a/scripts/local_setup.py b/scripts/local_setup.py index 9bcf7c7f..7d2bc2a1 100644 --- a/scripts/local_setup.py +++ b/scripts/local_setup.py @@ -40,6 +40,20 @@ DEFAULT_NODE_VERSION = "20.12.1" FRONTEND_LATEST_API = "https://api.github.com/repos/jxxghp/MoviePilot-Frontend/releases/latest" FRONTEND_TAG_API = "https://api.github.com/repos/jxxghp/MoviePilot-Frontend/releases/tags/{tag}" RESOURCES_MAIN_ZIP = "https://github.com/jxxghp/MoviePilot-Resources/archive/refs/heads/main.zip" +LLM_PROVIDER_DEFAULTS = { + "deepseek": { + "model": "deepseek-chat", + "base_url": "https://api.deepseek.com", + }, + "openai": { + "model": "gpt-4o-mini", + "base_url": "https://api.openai.com/v1", + }, + "google": { + "model": "gemini-2.5-flash", + "base_url": "", + }, +} RUNTIME_PACKAGE = { "name": "moviepilot-frontend-runtime", "private": True, @@ -218,6 +232,14 @@ def _load_env_lines() -> list[str]: return ENV_FILE.read_text(encoding="utf-8").splitlines(keepends=True) +def _serialize_env_value(value: Any) -> str: + if isinstance(value, Path): + value = str(value) + if value is None: + return '""' + return json.dumps(value, ensure_ascii=False) + + def read_env_value(key: str) -> Optional[str]: for line in _load_env_lines(): stripped = line.strip() @@ -232,7 +254,7 @@ def read_env_value(key: str) -> Optional[str]: def write_env_value(key: str, value: str) -> None: ensure_local_dirs() lines = _load_env_lines() - new_line = f"{key}={json.dumps(str(value), ensure_ascii=False)}\n" + new_line = f"{key}={_serialize_env_value(value)}\n" for index, line in enumerate(lines): stripped = line.strip() @@ -250,6 +272,11 @@ def write_env_value(key: str, value: str) -> None: ENV_FILE.write_text("".join(lines), encoding="utf-8") +def write_env_values(values: dict[str, Any]) -> None: + for key, value in values.items(): + write_env_value(key, value) + + def ensure_api_token(force_token: bool = False, token: Optional[str] = None) -> str: ensure_local_dirs() current_token = read_env_value("API_TOKEN") or "" @@ -546,6 +573,30 @@ def _normalize_choice(value: str) -> str: return value.strip().lower().replace("_", "").replace("-", "") +def _env_default(key: str, default: str = "") -> str: + value = read_env_value(key) + if value is None or value == "": + return default + return value + + +def _env_bool(key: str, default: bool) -> bool: + value = read_env_value(key) + if value is None or value == "": + return default + return value.strip().lower() in {"1", "true", "yes", "y", "on"} + + +def _env_int(key: str, default: int) -> int: + value = read_env_value(key) + if value is None or value == "": + return default + try: + return int(value) + except (TypeError, ValueError): + return default + + def _prompt_text( label: str, *, @@ -568,6 +619,29 @@ def _prompt_text( print("请输入有效内容,或使用回车接受默认值。") +def _prompt_secret_text( + label: str, + *, + current_value: Optional[str] = None, + allow_empty: bool = False, + required: bool = False, +) -> str: + while True: + suffix = " [留空保持现有值]" if current_value not in (None, "") else "" + prompt = f"{label}{suffix}: " + value = getpass.getpass(prompt).strip() + + if value: + return value + if current_value is not None and current_value != "": + return current_value + if allow_empty and not required: + return "" + if not required: + return "" + print("请输入有效内容。") + + def _prompt_yes_no(label: str, default: bool = True) -> bool: suffix = "Y/n" if default else "y/N" while True: @@ -652,6 +726,55 @@ def _collect_directory_config() -> dict[str, Any]: } +def _collect_database_config() -> dict[str, Any]: + print_step("数据库配置") + current_db_type = _env_default("DB_TYPE", "sqlite").lower() + if current_db_type not in {"sqlite", "postgresql"}: + current_db_type = "sqlite" + + db_type = _prompt_choice( + "选择数据库类型", + { + "sqlite": "SQLite", + "postgresql": "PostgreSQL", + }, + default=current_db_type, + ) + + config: dict[str, Any] = { + "DB_TYPE": db_type, + } + if db_type == "sqlite": + return config + + config.update( + { + "DB_POSTGRESQL_HOST": _prompt_text( + "PostgreSQL 主机地址", + default=_env_default("DB_POSTGRESQL_HOST", "localhost"), + ), + "DB_POSTGRESQL_PORT": _prompt_text( + "PostgreSQL 端口", + default=str(_env_int("DB_POSTGRESQL_PORT", 5432)), + ), + "DB_POSTGRESQL_DATABASE": _prompt_text( + "PostgreSQL 数据库名(需已创建)", + default=_env_default("DB_POSTGRESQL_DATABASE", "moviepilot"), + ), + "DB_POSTGRESQL_USERNAME": _prompt_text( + "PostgreSQL 用户名", + default=_env_default("DB_POSTGRESQL_USERNAME", "moviepilot"), + ), + "DB_POSTGRESQL_PASSWORD": _prompt_secret_text( + "PostgreSQL 密码", + current_value=read_env_value("DB_POSTGRESQL_PASSWORD"), + allow_empty=True, + ), + } + ) + return config + + def _collect_downloader_config() -> Optional[dict[str, Any]]: print_step("下载器配置") downloader_type = _prompt_choice( @@ -808,6 +931,73 @@ def _collect_notification_config() -> Optional[dict[str, Any]]: } +def _collect_agent_config() -> dict[str, Any]: + print_step("AI Agent 配置") + enabled = _prompt_yes_no( + "是否启用 AI 智能体", + default=_env_bool("AI_AGENT_ENABLE", False), + ) + if not enabled: + return { + "AI_AGENT_ENABLE": False, + "AI_AGENT_GLOBAL": False, + } + + current_provider = _env_default("LLM_PROVIDER", "deepseek").lower() + if current_provider not in LLM_PROVIDER_DEFAULTS: + current_provider = "deepseek" + + provider = _prompt_choice( + "选择 LLM 提供商", + { + "deepseek": "DeepSeek", + "openai": "OpenAI", + "google": "Google", + }, + default=current_provider, + ) + defaults = LLM_PROVIDER_DEFAULTS[provider] + current_model = _env_default("LLM_MODEL", defaults["model"]) + current_base_url = _env_default("LLM_BASE_URL", defaults["base_url"]) + + config: dict[str, Any] = { + "AI_AGENT_ENABLE": True, + "AI_AGENT_GLOBAL": _prompt_yes_no( + "是否启用全局 AI 智能体", + default=_env_bool("AI_AGENT_GLOBAL", False), + ), + "LLM_PROVIDER": provider, + "LLM_MODEL": _prompt_text( + "LLM 模型名称", + default=current_model, + ), + "LLM_API_KEY": _prompt_secret_text( + "LLM API Key", + current_value=read_env_value("LLM_API_KEY"), + required=True, + ), + "LLM_SUPPORT_IMAGE_INPUT": _prompt_yes_no( + "是否启用图片输入支持", + default=_env_bool("LLM_SUPPORT_IMAGE_INPUT", True), + ), + } + + if provider == "google": + config["LLM_BASE_URL"] = _prompt_text( + "自定义 Google API Base URL(可选)", + default=current_base_url, + allow_empty=True, + ) + else: + config["LLM_BASE_URL"] = _prompt_text( + "LLM Base URL", + default=current_base_url, + allow_empty=True, + ) + + return config + + def run_setup_wizard(force_token: bool) -> dict[str, Any]: if not _is_interactive(): raise RuntimeError("交互式向导需要在终端中运行,请直接执行 moviepilot setup --wizard 或 moviepilot init --wizard") @@ -841,6 +1031,10 @@ def run_setup_wizard(force_token: bool) -> dict[str, Any]: return { "api_token": api_token, + "env_settings": { + **_collect_database_config(), + **_collect_agent_config(), + }, "directories": [_collect_directory_config()], "downloader": _collect_downloader_config(), "mediaserver": _collect_media_server_config(), @@ -1006,6 +1200,10 @@ def init_local( else: ensure_api_token(force_token=force_token) + if wizard_payload and wizard_payload.get("env_settings"): + write_env_values(wizard_payload["env_settings"]) + print_step(f"已写入环境配置到 {ENV_FILE}") + if skip_resources: if resources_ready: print_step("资源文件已完成同步")