From 87813c853b36a6ba211d3598d93d03fe41cb9d99 Mon Sep 17 00:00:00 2001 From: InfinityPacer <160988576+InfinityPacer@users.noreply.github.com> Date: Sun, 20 Oct 2024 23:39:20 +0800 Subject: [PATCH] fix(config): improve env update logic --- app/core/config.py | 43 ++++++++++++++++++++++++++----------------- 1 file changed, 26 insertions(+), 17 deletions(-) diff --git a/app/core/config.py b/app/core/config.py index 86ceb897..88009064 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -228,17 +228,17 @@ class Settings(BaseSettings, ConfigModel): if not app_env_path.exists(): SystemUtils.copy(self.INNER_CONFIG_PATH / "app.env", app_env_path) - @validator("API_TOKEN", pre=True, always=True) - def validate_api_token(cls, value: Any, field): + @staticmethod + def validate_api_token(original_value: Any) -> Tuple[Any, bool]: + value = original_value.strip() if not value or len(value) < 16: new_token = secrets.token_urlsafe(16) if not value: logger.info(f"'API_TOKEN' 未设置,已随机生成新的【API_TOKEN】{new_token}") else: logger.warning(f"'API_TOKEN' 长度不足 16 个字符,存在安全隐患,已随机生成新的【API_TOKEN】{new_token}") - cls.update_env_config(field, original_value=value or "", converted_value=new_token) - return new_token - return value + return new_token, True + return value, value != original_value @staticmethod def generic_type_converter(value: Any, original_value: Any, expected_type: Type, default: Any, field_name: str, @@ -263,7 +263,7 @@ class Settings(BaseSettings, ConfigModel): "true": True, "yes": True, "1": True, "on": True } if value_clean in bool_map: - return bool_map[value_clean], value_clean != original_value.lower() + return bool_map[value_clean], value_clean != str(original_value).lower() elif isinstance(value, (int, float)): return bool(value), False return default, False @@ -307,8 +307,11 @@ class Settings(BaseSettings, ConfigModel): """ 通用校验器,尝试将配置值转换为期望的类型 """ - converted_value, needs_update = cls.generic_type_converter(value, value, field.type_, field.default, - field.name) + if field.name == "API_TOKEN": + converted_value, needs_update = cls.validate_api_token(value) + else: + converted_value, needs_update = cls.generic_type_converter(value, value, field.type_, field.default, + field.name) if needs_update: cls.update_env_config(field, value, converted_value) return converted_value @@ -318,21 +321,22 @@ class Settings(BaseSettings, ConfigModel): """ 更新 env 配置 """ - is_converted = original_value is not None and original_value != converted_value + message = None + is_converted = original_value is not None and str(original_value) != str(converted_value) if is_converted: - logger.warning(f"配置项 '{field.name}' 的值 '{original_value}' 无效,已替换为 '{converted_value}'") + message = f"配置项 '{field.name}' 的值 '{original_value}' 无效,已替换为 '{converted_value}'" + logger.warning(message) if field.name in os.environ: if is_converted: message = f"配置项 '{field.name}' 已在环境变量中设置,请手动更新以保持一致性" logger.warning(message) return False, message - return True, "" else: set_key(SystemUtils.get_env_path(), field.name, str(converted_value) if converted_value is not None else "") if is_converted: logger.info(f"配置项 '{field.name}' 已自动修正并写入到 'app.env' 文件") - return True, "" + return True, message def update_setting(self, key: str, value: Any) -> Tuple[bool, str]: """ @@ -343,11 +347,16 @@ class Settings(BaseSettings, ConfigModel): try: field = self.__fields__[key] - converted_value, _ = self.generic_type_converter(value, getattr(self, key), field.type_, - field.default, key) - # 如果没有抛出异常,则统一使用 converted_value 进行更新 - setattr(self, key, converted_value) - return self.update_env_config(field, value, converted_value) + if field.name == "API_TOKEN": + converted_value, needs_update = self.validate_api_token(value) + else: + converted_value, needs_update = self.generic_type_converter(value, getattr(self, key), field.type_, + field.default, key) + # 如果没有抛出异常且需要更新,则统一使用 converted_value 进行更新 + if needs_update: + setattr(self, key, converted_value) + return self.update_env_config(field, value, converted_value) + return True, "" except Exception as e: return False, str(e)