From d2bf733a3efb74e87799e34fa0dbbdb986411d15 Mon Sep 17 00:00:00 2001 From: Estrella Pan Date: Sun, 25 Jan 2026 10:13:17 +0100 Subject: [PATCH] feat(database): auto-fill NULL values with model defaults during migration When migrating from older versions, new columns may have NULL values. This adds a generic mechanism that scans all table models and fills NULL values based on field defaults defined in SQLModel, improving data consistency for upgraded databases. Co-Authored-By: Claude Opus 4.5 --- backend/src/module/database/combine.py | 130 ++++++++++++++++++++++--- 1 file changed, 119 insertions(+), 11 deletions(-) diff --git a/backend/src/module/database/combine.py b/backend/src/module/database/combine.py index e9e4632a..0ea186a5 100644 --- a/backend/src/module/database/combine.py +++ b/backend/src/module/database/combine.py @@ -1,10 +1,15 @@ import logging +from typing import Any, get_args, get_origin +from pydantic.fields import FieldInfo +from pydantic_core import PydanticUndefined from sqlalchemy import inspect, text from sqlmodel import Session, SQLModel from module.models import Bangumi, User from module.models.passkey import Passkey +from module.models.rss import RSSItem +from module.models.torrent import Torrent from .bangumi import BangumiDatabase from .engine import engine as e @@ -14,6 +19,9 @@ from .user import UserDatabase logger = logging.getLogger(__name__) +# 所有需要进行空值填充的表模型 +TABLE_MODELS: list[type[SQLModel]] = [Bangumi, RSSItem, Torrent, User, Passkey] + # Increment this when adding new migrations to MIGRATIONS list. CURRENT_SCHEMA_VERSION = 4 @@ -81,12 +89,14 @@ class Database(Session): def _ensure_schema_version_table(self): """Create the schema_version table if it doesn't exist.""" with self.engine.connect() as conn: - conn.execute(text( - "CREATE TABLE IF NOT EXISTS schema_version (" - " id INTEGER PRIMARY KEY," - " version INTEGER NOT NULL" - ")" - )) + conn.execute( + text( + "CREATE TABLE IF NOT EXISTS schema_version (" + " id INTEGER PRIMARY KEY," + " version INTEGER NOT NULL" + ")" + ) + ) conn.commit() def _get_schema_version(self) -> int: @@ -95,16 +105,21 @@ class Database(Session): if "schema_version" not in inspector.get_table_names(): return 0 with self.engine.connect() as conn: - result = conn.execute(text("SELECT version FROM schema_version WHERE id = 1")) + result = conn.execute( + text("SELECT version FROM schema_version WHERE id = 1") + ) row = result.fetchone() return row[0] if row else 0 def _set_schema_version(self, version: int): """Update the schema version in the database.""" with self.engine.connect() as conn: - conn.execute(text( - "INSERT OR REPLACE INTO schema_version (id, version) VALUES (1, :version)" - ), {"version": version}) + conn.execute( + text( + "INSERT OR REPLACE INTO schema_version (id, version) VALUES (1, :version)" + ), + {"version": version}, + ) conn.commit() def run_migrations(self): @@ -141,9 +156,102 @@ class Database(Session): conn.commit() logger.info(f"[Database] Migration v{version}: {description}") else: - logger.debug(f"[Database] Migration v{version} skipped (already applied): {description}") + logger.debug( + f"[Database] Migration v{version} skipped (already applied): {description}" + ) self._set_schema_version(CURRENT_SCHEMA_VERSION) logger.info(f"[Database] Schema version is now {CURRENT_SCHEMA_VERSION}.") + self._fill_null_with_defaults() + + def _get_field_default(self, field_info: FieldInfo) -> tuple[bool, Any]: + """ + 获取字段的默认值。 + + 返回: + (has_default, default_value) - 是否有可用的默认值,以及默认值 + """ + # 跳过 default_factory(如 datetime.utcnow),不适合批量填充 + if field_info.default_factory is not None: + return False, None + # 跳过没有默认值的字段(PydanticUndefined) + if field_info.default is PydanticUndefined: + return False, None + return True, field_info.default + + def _is_optional_field(self, model: type[SQLModel], field_name: str) -> bool: + """检查字段是否为 Optional 类型""" + hints = model.__annotations__.get(field_name) + if hints is None: + return False + origin = get_origin(hints) + # Optional[X] 等同于 Union[X, None] + if origin is not None: + args = get_args(hints) + return type(None) in args + return False + + def _fill_null_with_defaults(self): + """ + 根据模型定义的默认值,自动填充数据库中的 NULL 值。 + + 规则: + - 跳过主键字段 + - 跳过 Optional 字段且默认值为 None 的情况 + - 跳过使用 default_factory 的字段 + - 只填充有明确非 None 默认值的字段 + """ + inspector = inspect(self.engine) + tables = inspector.get_table_names() + + for model in TABLE_MODELS: + table_name = model.__tablename__ + if table_name not in tables: + continue + + db_columns = {col["name"] for col in inspector.get_columns(table_name)} + fields_to_fill: list[tuple[str, Any]] = [] + + for field_name, field_info in model.model_fields.items(): + # 跳过主键 + if field_info.is_required() and field_name == "id": + continue + # 跳过数据库中不存在的列 + if field_name not in db_columns: + continue + + has_default, default_value = self._get_field_default(field_info) + if not has_default: + continue + # 如果是 Optional 且默认值为 None,跳过 + if default_value is None and self._is_optional_field(model, field_name): + continue + + fields_to_fill.append((field_name, default_value)) + + if not fields_to_fill: + continue + + with self.engine.connect() as conn: + for field_name, default_value in fields_to_fill: + # 转换 Python 值为 SQL 值 + if isinstance(default_value, bool): + sql_value = 1 if default_value else 0 + else: + sql_value = default_value + + result = conn.execute( + text( + f"UPDATE {table_name} SET {field_name} = :val " + f"WHERE {field_name} IS NULL" + ), + {"val": sql_value}, + ) + if result.rowcount > 0: + logger.info( + f"[Database] Filled {result.rowcount} NULL values " + f"in {table_name}.{field_name} with default: {default_value}" + ) + conn.commit() def drop_table(self): SQLModel.metadata.drop_all(self.engine)