feat(database): auto-fill NULL values with model defaults during migration

When migrating from older versions, new columns may have NULL values.
This adds a generic mechanism that scans all table models and fills
NULL values based on field defaults defined in SQLModel, improving
data consistency for upgraded databases.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Estrella Pan
2026-01-25 10:13:17 +01:00
parent 683e272b4d
commit d2bf733a3e

View File

@@ -1,10 +1,15 @@
import logging
from typing import Any, get_args, get_origin
from pydantic.fields import FieldInfo
from pydantic_core import PydanticUndefined
from sqlalchemy import inspect, text
from sqlmodel import Session, SQLModel
from module.models import Bangumi, User
from module.models.passkey import Passkey
from module.models.rss import RSSItem
from module.models.torrent import Torrent
from .bangumi import BangumiDatabase
from .engine import engine as e
@@ -14,6 +19,9 @@ from .user import UserDatabase
logger = logging.getLogger(__name__)
# Every table model whose NULL column values should be back-filled from field defaults.
TABLE_MODELS: list[type[SQLModel]] = [Bangumi, RSSItem, Torrent, User, Passkey]
# Increment this when adding new migrations to MIGRATIONS list.
CURRENT_SCHEMA_VERSION = 4
@@ -81,12 +89,14 @@ class Database(Session):
def _ensure_schema_version_table(self):
    """Create the schema_version table if it doesn't exist.

    The table has an integer primary key ``id`` and a non-null integer
    ``version`` column. The change is committed before the connection is
    released.
    """
    create_stmt = text(
        "CREATE TABLE IF NOT EXISTS schema_version ("
        " id INTEGER PRIMARY KEY,"
        " version INTEGER NOT NULL"
        ")"
    )
    with self.engine.connect() as conn:
        # One execution is enough: the statement is idempotent, so issuing
        # it a second time would only cost an extra round trip.
        conn.execute(create_stmt)
        conn.commit()
def _get_schema_version(self) -> int:
@@ -95,16 +105,21 @@ class Database(Session):
if "schema_version" not in inspector.get_table_names():
return 0
with self.engine.connect() as conn:
result = conn.execute(text("SELECT version FROM schema_version WHERE id = 1"))
result = conn.execute(
text("SELECT version FROM schema_version WHERE id = 1")
)
row = result.fetchone()
return row[0] if row else 0
def _set_schema_version(self, version: int):
    """Update the schema version in the database.

    Args:
        version: Schema version number written to the row with ``id = 1``;
            an existing row is replaced.
    """
    upsert_stmt = text(
        "INSERT OR REPLACE INTO schema_version (id, version) VALUES (1, :version)"
    )
    with self.engine.connect() as conn:
        # Write the row once; repeating the same upsert with the same value
        # has no additional effect.
        conn.execute(upsert_stmt, {"version": version})
        conn.commit()
def run_migrations(self):
@@ -141,9 +156,102 @@ class Database(Session):
conn.commit()
logger.info(f"[Database] Migration v{version}: {description}")
else:
logger.debug(f"[Database] Migration v{version} skipped (already applied): {description}")
logger.debug(
f"[Database] Migration v{version} skipped (already applied): {description}"
)
self._set_schema_version(CURRENT_SCHEMA_VERSION)
logger.info(f"[Database] Schema version is now {CURRENT_SCHEMA_VERSION}.")
self._fill_null_with_defaults()
def _get_field_default(self, field_info: FieldInfo) -> tuple[bool, Any]:
    """Return the static default declared on a model field, if any.

    Args:
        field_info: Field metadata taken from a model's ``model_fields``.

    Returns:
        ``(has_default, default_value)``. ``has_default`` is ``False`` (and
        the value ``None``) when the field uses a ``default_factory`` or
        declares no default at all; otherwise it is ``True`` and
        ``default_value`` is the declared default.
    """
    # Factory-produced defaults (e.g. datetime.utcnow) are generated per row
    # and are therefore not suitable for a bulk fill.
    uses_factory = field_info.default_factory is not None
    # PydanticUndefined marks a field with no default value.
    if uses_factory or field_info.default is PydanticUndefined:
        return False, None
    return True, field_info.default
def _is_optional_field(self, model: type[SQLModel], field_name: str) -> bool:
"""检查字段是否为 Optional 类型"""
hints = model.__annotations__.get(field_name)
if hints is None:
return False
origin = get_origin(hints)
# Optional[X] 等同于 Union[X, None]
if origin is not None:
args = get_args(hints)
return type(None) in args
return False
def _fill_null_with_defaults(self):
    """Back-fill NULL column values with the defaults declared on the models.

    For every model in ``TABLE_MODELS`` whose table exists in the database,
    each eligible field gets one ``UPDATE ... WHERE <column> IS NULL`` that
    writes the field's default. All updates for a table share one connection
    and are committed together.

    A field is skipped when:
    - it is required and named ``id``;
    - the database table has no column with the field's name;
    - it uses a ``default_factory`` or declares no default;
    - its default is ``None`` (overwriting NULL with NULL changes nothing).
    """
    inspector = inspect(self.engine)
    existing_tables = set(inspector.get_table_names())

    for model in TABLE_MODELS:
        table_name = model.__tablename__
        if table_name not in existing_tables:
            continue

        db_columns = {col["name"] for col in inspector.get_columns(table_name)}
        fields_to_fill: list[tuple[str, Any]] = []
        for field_name, field_info in model.model_fields.items():
            # Skip the required "id" field.
            if field_info.is_required() and field_name == "id":
                continue
            # Skip fields that have no matching column in the database.
            if field_name not in db_columns:
                continue
            has_default, default_value = self._get_field_default(field_info)
            # A None default can never fill anything: the UPDATE would set
            # NULL where the value is already NULL and then report those
            # rows as "filled". Skip it regardless of the field's annotation.
            if not has_default or default_value is None:
                continue
            fields_to_fill.append((field_name, default_value))

        if not fields_to_fill:
            continue

        with self.engine.connect() as conn:
            for field_name, default_value in fields_to_fill:
                # Store booleans as the integers 1 / 0.
                if isinstance(default_value, bool):
                    sql_value = int(default_value)
                else:
                    sql_value = default_value
                # Table and column names come from the model definitions;
                # only the value is passed as a bound parameter.
                result = conn.execute(
                    text(
                        f"UPDATE {table_name} SET {field_name} = :val "
                        f"WHERE {field_name} IS NULL"
                    ),
                    {"val": sql_value},
                )
                if result.rowcount > 0:
                    logger.info(
                        f"[Database] Filled {result.rowcount} NULL values "
                        f"in {table_name}.{field_name} with default: {default_value}"
                    )
            conn.commit()
def drop_table(self):
    """Drop all tables known to ``SQLModel.metadata`` using this engine."""
    metadata = SQLModel.metadata
    metadata.drop_all(self.engine)