mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-03-20 03:57:30 +08:00
119 lines
3.7 KiB
Python
119 lines
3.7 KiB
Python
"""2.2.0
|
||
|
||
Revision ID: 5b3355c964bb
|
||
Revises: d58298a0879f
|
||
Create Date: 2025-08-19 12:27:08.451371
|
||
|
||
"""
|
||
import sqlalchemy as sa
|
||
from alembic import op
|
||
|
||
from app.log import logger
|
||
from app.core.config import settings
|
||
|
||
# revision identifiers, used by Alembic.
|
||
revision = '5b3355c964bb'
|
||
down_revision = 'd58298a0879f'
|
||
branch_labels = None
|
||
depends_on = None
|
||
|
||
|
||
def upgrade() -> None:
|
||
# ### commands auto generated by Alembic - please adjust! ###
|
||
if settings.DB_TYPE.lower() == "postgresql":
|
||
# 将SQLite的Sequence转换为PostgreSQL的Identity
|
||
fix_postgresql_sequences()
|
||
# ### end Alembic commands ###
|
||
|
||
|
||
def fix_postgresql_sequences():
|
||
"""
|
||
修复PostgreSQL数据库中的序列问题
|
||
将SQLite迁移过来的Sequence转换为PostgreSQL的Identity
|
||
"""
|
||
connection = op.get_bind()
|
||
|
||
# 获取所有表名
|
||
result = connection.execute(sa.text("""
|
||
SELECT table_name
|
||
FROM information_schema.tables
|
||
WHERE table_schema = 'public'
|
||
AND table_type = 'BASE TABLE'
|
||
"""))
|
||
tables = [row[0] for row in result.fetchall()]
|
||
|
||
logger.info(f"发现 {len(tables)} 个表需要检查序列")
|
||
|
||
for table_name in tables:
|
||
fix_table_sequence(connection, table_name)
|
||
|
||
|
||
def fix_table_sequence(connection, table_name):
|
||
"""
|
||
修复单个表的序列
|
||
"""
|
||
try:
|
||
# 跳过alembic_version表,它没有id列
|
||
if table_name == 'alembic_version':
|
||
logger.debug(f"跳过表 {table_name},这是Alembic版本表")
|
||
return
|
||
|
||
# 检查表是否有id列
|
||
result = connection.execute(sa.text(f"""
|
||
SELECT is_identity, column_default
|
||
FROM information_schema.columns
|
||
WHERE table_name = '{table_name}'
|
||
AND column_name = 'id'
|
||
"""))
|
||
|
||
id_column = result.fetchone()
|
||
if not id_column:
|
||
logger.debug(f"表 {table_name} 没有id列,跳过")
|
||
return
|
||
|
||
is_identity, column_default = id_column
|
||
|
||
# 检查是否已经是Identity类型
|
||
if is_identity == 'YES' or (column_default and 'GENERATED BY DEFAULT AS IDENTITY' in column_default):
|
||
logger.debug(f"表 {table_name} 的id列已经是Identity类型,跳过")
|
||
return
|
||
|
||
# 检查是否有序列
|
||
logger.info(f"表 {table_name} 存在序列,需要修复")
|
||
convert_to_identity(connection, table_name)
|
||
|
||
except Exception as e:
|
||
logger.error(f"修复表 {table_name} 序列时出错: {e}")
|
||
# 回滚当前事务,避免影响后续操作
|
||
connection.rollback()
|
||
|
||
|
||
def convert_to_identity(connection, table_name):
|
||
"""
|
||
将序列转换为Identity,保持原有约束不变
|
||
"""
|
||
try:
|
||
# 获取当前序列的最大值
|
||
result = connection.execute(sa.text(f"""
|
||
SELECT COALESCE(MAX(id), 0) + 1 as next_value
|
||
FROM "{table_name}"
|
||
"""))
|
||
next_value = result.fetchone()[0]
|
||
|
||
# 直接修改列属性,添加Identity,保持其他约束不变
|
||
# 这种方式不会删除主键约束和索引
|
||
connection.execute(sa.text(f"""
|
||
ALTER TABLE "{table_name}"
|
||
ALTER COLUMN id ADD GENERATED BY DEFAULT AS IDENTITY (START WITH {next_value})
|
||
"""))
|
||
|
||
logger.info(f"表 {table_name} 序列已转换为Identity,起始值为 {next_value}")
|
||
|
||
except Exception as e:
|
||
# 如果是已经存在的Identity错误,则忽略
|
||
if "already an identity column" in str(e):
|
||
logger.warn(f"表 {table_name} 的id列已经是Identity类型,忽略此错误: {e}")
|
||
return
|
||
logger.error(f"转换表 {table_name} 序列时出错: {e}")
|
||
raise
|