mirror of
https://github.com/EstrellaXD/Auto_Bangumi.git
synced 2026-04-09 05:29:51 +08:00
feat(database): auto-fill NULL values with model defaults during migration
When migrating from older versions, new columns may have NULL values. This adds a generic mechanism that scans all table models and fills NULL values based on field defaults defined in SQLModel, improving data consistency for upgraded databases. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -1,10 +1,15 @@
|
||||
import logging
|
||||
from typing import Any, get_args, get_origin
|
||||
|
||||
from pydantic.fields import FieldInfo
|
||||
from pydantic_core import PydanticUndefined
|
||||
from sqlalchemy import inspect, text
|
||||
from sqlmodel import Session, SQLModel
|
||||
|
||||
from module.models import Bangumi, User
|
||||
from module.models.passkey import Passkey
|
||||
from module.models.rss import RSSItem
|
||||
from module.models.torrent import Torrent
|
||||
|
||||
from .bangumi import BangumiDatabase
|
||||
from .engine import engine as e
|
||||
@@ -14,6 +19,9 @@ from .user import UserDatabase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 所有需要进行空值填充的表模型
|
||||
TABLE_MODELS: list[type[SQLModel]] = [Bangumi, RSSItem, Torrent, User, Passkey]
|
||||
|
||||
# Increment this when adding new migrations to MIGRATIONS list.
|
||||
CURRENT_SCHEMA_VERSION = 4
|
||||
|
||||
@@ -81,12 +89,14 @@ class Database(Session):
|
||||
def _ensure_schema_version_table(self):
|
||||
"""Create the schema_version table if it doesn't exist."""
|
||||
with self.engine.connect() as conn:
|
||||
conn.execute(text(
|
||||
"CREATE TABLE IF NOT EXISTS schema_version ("
|
||||
" id INTEGER PRIMARY KEY,"
|
||||
" version INTEGER NOT NULL"
|
||||
")"
|
||||
))
|
||||
conn.execute(
|
||||
text(
|
||||
"CREATE TABLE IF NOT EXISTS schema_version ("
|
||||
" id INTEGER PRIMARY KEY,"
|
||||
" version INTEGER NOT NULL"
|
||||
")"
|
||||
)
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
def _get_schema_version(self) -> int:
|
||||
@@ -95,16 +105,21 @@ class Database(Session):
|
||||
if "schema_version" not in inspector.get_table_names():
|
||||
return 0
|
||||
with self.engine.connect() as conn:
|
||||
result = conn.execute(text("SELECT version FROM schema_version WHERE id = 1"))
|
||||
result = conn.execute(
|
||||
text("SELECT version FROM schema_version WHERE id = 1")
|
||||
)
|
||||
row = result.fetchone()
|
||||
return row[0] if row else 0
|
||||
|
||||
def _set_schema_version(self, version: int):
|
||||
"""Update the schema version in the database."""
|
||||
with self.engine.connect() as conn:
|
||||
conn.execute(text(
|
||||
"INSERT OR REPLACE INTO schema_version (id, version) VALUES (1, :version)"
|
||||
), {"version": version})
|
||||
conn.execute(
|
||||
text(
|
||||
"INSERT OR REPLACE INTO schema_version (id, version) VALUES (1, :version)"
|
||||
),
|
||||
{"version": version},
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
def run_migrations(self):
|
||||
@@ -141,9 +156,102 @@ class Database(Session):
|
||||
conn.commit()
|
||||
logger.info(f"[Database] Migration v{version}: {description}")
|
||||
else:
|
||||
logger.debug(f"[Database] Migration v{version} skipped (already applied): {description}")
|
||||
logger.debug(
|
||||
f"[Database] Migration v{version} skipped (already applied): {description}"
|
||||
)
|
||||
self._set_schema_version(CURRENT_SCHEMA_VERSION)
|
||||
logger.info(f"[Database] Schema version is now {CURRENT_SCHEMA_VERSION}.")
|
||||
self._fill_null_with_defaults()
|
||||
|
||||
def _get_field_default(self, field_info: FieldInfo) -> tuple[bool, Any]:
|
||||
"""
|
||||
获取字段的默认值。
|
||||
|
||||
返回:
|
||||
(has_default, default_value) - 是否有可用的默认值,以及默认值
|
||||
"""
|
||||
# 跳过 default_factory(如 datetime.utcnow),不适合批量填充
|
||||
if field_info.default_factory is not None:
|
||||
return False, None
|
||||
# 跳过没有默认值的字段(PydanticUndefined)
|
||||
if field_info.default is PydanticUndefined:
|
||||
return False, None
|
||||
return True, field_info.default
|
||||
|
||||
def _is_optional_field(self, model: type[SQLModel], field_name: str) -> bool:
|
||||
"""检查字段是否为 Optional 类型"""
|
||||
hints = model.__annotations__.get(field_name)
|
||||
if hints is None:
|
||||
return False
|
||||
origin = get_origin(hints)
|
||||
# Optional[X] 等同于 Union[X, None]
|
||||
if origin is not None:
|
||||
args = get_args(hints)
|
||||
return type(None) in args
|
||||
return False
|
||||
|
||||
def _fill_null_with_defaults(self):
|
||||
"""
|
||||
根据模型定义的默认值,自动填充数据库中的 NULL 值。
|
||||
|
||||
规则:
|
||||
- 跳过主键字段
|
||||
- 跳过 Optional 字段且默认值为 None 的情况
|
||||
- 跳过使用 default_factory 的字段
|
||||
- 只填充有明确非 None 默认值的字段
|
||||
"""
|
||||
inspector = inspect(self.engine)
|
||||
tables = inspector.get_table_names()
|
||||
|
||||
for model in TABLE_MODELS:
|
||||
table_name = model.__tablename__
|
||||
if table_name not in tables:
|
||||
continue
|
||||
|
||||
db_columns = {col["name"] for col in inspector.get_columns(table_name)}
|
||||
fields_to_fill: list[tuple[str, Any]] = []
|
||||
|
||||
for field_name, field_info in model.model_fields.items():
|
||||
# 跳过主键
|
||||
if field_info.is_required() and field_name == "id":
|
||||
continue
|
||||
# 跳过数据库中不存在的列
|
||||
if field_name not in db_columns:
|
||||
continue
|
||||
|
||||
has_default, default_value = self._get_field_default(field_info)
|
||||
if not has_default:
|
||||
continue
|
||||
# 如果是 Optional 且默认值为 None,跳过
|
||||
if default_value is None and self._is_optional_field(model, field_name):
|
||||
continue
|
||||
|
||||
fields_to_fill.append((field_name, default_value))
|
||||
|
||||
if not fields_to_fill:
|
||||
continue
|
||||
|
||||
with self.engine.connect() as conn:
|
||||
for field_name, default_value in fields_to_fill:
|
||||
# 转换 Python 值为 SQL 值
|
||||
if isinstance(default_value, bool):
|
||||
sql_value = 1 if default_value else 0
|
||||
else:
|
||||
sql_value = default_value
|
||||
|
||||
result = conn.execute(
|
||||
text(
|
||||
f"UPDATE {table_name} SET {field_name} = :val "
|
||||
f"WHERE {field_name} IS NULL"
|
||||
),
|
||||
{"val": sql_value},
|
||||
)
|
||||
if result.rowcount > 0:
|
||||
logger.info(
|
||||
f"[Database] Filled {result.rowcount} NULL values "
|
||||
f"in {table_name}.{field_name} with default: {default_value}"
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
def drop_table(self):
|
||||
SQLModel.metadata.drop_all(self.engine)
|
||||
|
||||
Reference in New Issue
Block a user