添加对 PostgreSQL 的支持

This commit is contained in:
jxxghp
2025-08-18 11:19:17 +08:00
parent e8058c8813
commit b6cf54d57f
23 changed files with 304 additions and 78 deletions

View File

@@ -82,10 +82,16 @@ class ConfigModel(BaseModel):
PORT: int = 3001
# 前端监听端口
NGINX_PORT: int = 3000
# 配置文件目录
CONFIG_DIR: Optional[str] = None
# 超级管理员
SUPERUSER: str = "admin"
# 是否调试模式
DEBUG: bool = False
# 是否开发模式
DEV: bool = False
# 数据库类型,支持 sqlite 和 postgresql默认使用 sqlite
DB_TYPE: str = "sqlite"
# 是否在控制台输出 SQL 语句,默认关闭
DB_ECHO: bool = False
# 数据库连接池类型QueuePool, NullPool
@@ -100,16 +106,26 @@ class ConfigModel(BaseModel):
DB_TIMEOUT: int = 60
# SQLite 是否启用 WAL 模式,默认开启
DB_WAL_ENABLE: bool = True
# PostgreSQL 主机地址
DB_POSTGRESQL_HOST: str = "localhost"
# PostgreSQL 端口
DB_POSTGRESQL_PORT: int = 5432
# PostgreSQL 数据库名
DB_POSTGRESQL_DATABASE: str = "moviepilot"
# PostgreSQL 用户名
DB_POSTGRESQL_USERNAME: str = "moviepilot"
# PostgreSQL 密码
DB_POSTGRESQL_PASSWORD: str = "moviepilot"
# PostgreSQL 连接池大小
DB_POSTGRESQL_POOL_SIZE: int = 20
# PostgreSQL 连接池溢出数量
DB_POSTGRESQL_MAX_OVERFLOW: int = 30
# 缓存类型,支持 cachetools 和 redis默认使用 cachetools
CACHE_BACKEND_TYPE: str = "cachetools"
# 缓存连接字符串,仅外部缓存(如 Redis、Memcached需要
CACHE_BACKEND_URL: Optional[str] = None
# Redis 缓存最大内存限制,未配置时,如开启大内存模式时为 "1024mb",未开启时为 "256mb"
CACHE_REDIS_MAXMEMORY: Optional[str] = None
# 配置文件目录
CONFIG_DIR: Optional[str] = None
# 超级管理员
SUPERUSER: str = "admin"
# 辅助认证,允许通过外部服务进行认证、单点登录以及自动创建用户
AUXILIARY_AUTH_ENABLE: bool = False
# API密钥需要更换
@@ -124,6 +140,10 @@ class ConfigModel(BaseModel):
SEARCH_SOURCE: str = "themoviedb,douban,bangumi"
# 媒体识别来源 themoviedb/douban
RECOGNIZE_SOURCE: str = "themoviedb"
# 元数据识别缓存过期时间(小时)
META_CACHE_EXPIRE: int = 0
# 电视剧动漫的分类genre_ids
ANIME_GENREIDS: List[int] = Field(default=[16])
# 刮削来源 themoviedb/douban
SCRAP_SOURCE: str = "themoviedb"
# 新增已入库媒体是否跟随TMDB信息变化
@@ -151,10 +171,6 @@ class ConfigModel(BaseModel):
U115_APP_ID: str = "100196807"
# Alipan AppId
ALIPAN_APP_ID: str = "ac1bf04dc9fd4d9aaabb65b4a668d403"
# 元数据识别缓存过期时间(小时)
META_CACHE_EXPIRE: int = 0
# 电视剧动漫的分类genre_ids
ANIME_GENREIDS: List[int] = Field(default=[16])
# 用户认证站点
AUTH_SITE: str = ""
# 重启自动升级

View File

@@ -1,19 +1,43 @@
import asyncio
from typing import Any, Generator, List, Optional, Self, Tuple, AsyncGenerator, Sequence, Union
from typing import Any, Generator, List, Optional, Self, Tuple, AsyncGenerator, Union
from sqlalchemy import NullPool, QueuePool, and_, create_engine, inspect, text, select, delete
from sqlalchemy import NullPool, QueuePool, and_, create_engine, inspect, text, select, delete, Column, Integer, \
Sequence
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
from sqlalchemy.orm import Session, as_declarative, declared_attr, scoped_session, sessionmaker
from app.core.config import settings
def get_id_column():
"""
根据数据库类型返回合适的ID列定义
"""
if settings.DB_TYPE.lower() == "postgresql":
# PostgreSQL使用SERIAL类型
return Column(Integer, primary_key=True, index=True)
else:
# SQLite使用Sequence
return Column(Integer, Sequence('id'), primary_key=True, index=True)
def _get_database_engine(is_async: bool = False):
"""
获取数据库连接参数并设置WAL模式
:param is_async: 是否创建异步引擎True - 异步引擎, False - 同步引擎
:return: 返回对应的数据库引擎
"""
# 根据数据库类型选择连接方式
if settings.DB_TYPE.lower() == "postgresql":
return _get_postgresql_engine(is_async)
else:
return _get_sqlite_engine(is_async)
def _get_sqlite_engine(is_async: bool = False):
"""
获取SQLite数据库引擎
"""
# 连接参数
_connect_args = {
"timeout": settings.DB_TIMEOUT,
@@ -52,7 +76,7 @@ def _get_database_engine(is_async: bool = False):
_journal_mode = "WAL" if settings.DB_WAL_ENABLE else "DELETE"
with engine.connect() as connection:
current_mode = connection.execute(text(f"PRAGMA journal_mode={_journal_mode};")).scalar()
print(f"Database journal mode set to: {current_mode}")
print(f"SQLite database journal mode set to: {current_mode}")
return engine
else:
@@ -78,12 +102,68 @@ def _get_database_engine(is_async: bool = False):
async with async_engine.connect() as _connection:
result = await _connection.execute(text(f"PRAGMA journal_mode={_journal_mode};"))
_current_mode = result.scalar()
print(f"Async database journal mode set to: {_current_mode}")
print(f"Async SQLite database journal mode set to: {_current_mode}")
try:
asyncio.run(set_async_wal_mode())
except Exception as e:
print(f"Failed to set async WAL mode: {e}")
print(f"Failed to set async SQLite WAL mode: {e}")
return async_engine
def _get_postgresql_engine(is_async: bool = False):
"""
获取PostgreSQL数据库引擎
"""
# 构建PostgreSQL连接URL
if settings.DB_POSTGRESQL_PASSWORD:
db_url = f"postgresql://{settings.DB_POSTGRESQL_USERNAME}:{settings.DB_POSTGRESQL_PASSWORD}@{settings.DB_POSTGRESQL_HOST}:{settings.DB_POSTGRESQL_PORT}/{settings.DB_POSTGRESQL_DATABASE}"
else:
db_url = f"postgresql://{settings.DB_POSTGRESQL_USERNAME}@{settings.DB_POSTGRESQL_HOST}:{settings.DB_POSTGRESQL_PORT}/{settings.DB_POSTGRESQL_DATABASE}"
# 创建同步引擎
if not is_async:
# 根据池类型设置 poolclass 和相关参数
_pool_class = NullPool if settings.DB_POOL_TYPE == "NullPool" else QueuePool
# 数据库参数
_db_kwargs = {
"url": db_url,
"pool_pre_ping": settings.DB_POOL_PRE_PING,
"echo": settings.DB_ECHO,
"poolclass": _pool_class,
"pool_recycle": settings.DB_POOL_RECYCLE,
}
# 当使用 QueuePool 时,添加 QueuePool 特有的参数
if _pool_class == QueuePool:
_db_kwargs.update({
"pool_size": settings.DB_POSTGRESQL_POOL_SIZE,
"pool_timeout": settings.DB_POOL_TIMEOUT,
"max_overflow": settings.DB_POSTGRESQL_MAX_OVERFLOW
})
# 创建数据库引擎
engine = create_engine(**_db_kwargs)
print(f"PostgreSQL database connected to {settings.DB_POSTGRESQL_HOST}:{settings.DB_POSTGRESQL_PORT}/{settings.DB_POSTGRESQL_DATABASE}")
return engine
else:
# 构建异步PostgreSQL连接URL
async_db_url = f"postgresql+asyncpg://{settings.DB_POSTGRESQL_USERNAME}:{settings.DB_POSTGRESQL_PASSWORD}@{settings.DB_POSTGRESQL_HOST}:{settings.DB_POSTGRESQL_PORT}/{settings.DB_POSTGRESQL_DATABASE}"
# 数据库参数,只能使用 NullPool
_db_kwargs = {
"url": async_db_url,
"pool_pre_ping": settings.DB_POOL_PRE_PING,
"echo": settings.DB_ECHO,
"poolclass": NullPool,
"pool_recycle": settings.DB_POOL_RECYCLE,
}
# 创建异步数据库引擎
async_engine = create_async_engine(**_db_kwargs)
print(f"Async PostgreSQL database connected to {settings.DB_POSTGRESQL_HOST}:{settings.DB_POSTGRESQL_PORT}/{settings.DB_POSTGRESQL_DATABASE}")
return async_engine

View File

@@ -18,12 +18,22 @@ def update_db():
"""
更新数据库
"""
db_location = settings.CONFIG_PATH / 'user.db'
script_location = settings.ROOT_PATH / 'database'
try:
alembic_cfg = Config()
alembic_cfg.set_main_option('script_location', str(script_location))
alembic_cfg.set_main_option('sqlalchemy.url', f"sqlite:///{db_location}")
# 根据数据库类型设置不同的URL
if settings.DB_TYPE.lower() == "postgresql":
if settings.DB_POSTGRESQL_PASSWORD:
db_url = f"postgresql://{settings.DB_POSTGRESQL_USERNAME}:{settings.DB_POSTGRESQL_PASSWORD}@{settings.DB_POSTGRESQL_HOST}:{settings.DB_POSTGRESQL_PORT}/{settings.DB_POSTGRESQL_DATABASE}"
else:
db_url = f"postgresql://{settings.DB_POSTGRESQL_USERNAME}@{settings.DB_POSTGRESQL_HOST}:{settings.DB_POSTGRESQL_PORT}/{settings.DB_POSTGRESQL_DATABASE}"
else:
db_location = settings.CONFIG_PATH / 'user.db'
db_url = f"sqlite:///{db_location}"
alembic_cfg.set_main_option('sqlalchemy.url', db_url)
upgrade(alembic_cfg, 'head')
except Exception as e:
logger.error(f'数据库更新失败:{str(e)}')

View File

@@ -1,18 +1,18 @@
import time
from typing import Optional
from sqlalchemy import Column, Integer, String, Sequence, JSON, select
from sqlalchemy import Column, Integer, String, JSON, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from app.db import db_query, db_update, Base, async_db_query
from app.db import db_query, db_update, get_id_column, Base, async_db_query
class DownloadHistory(Base):
"""
下载历史记录
"""
id = Column(Integer, Sequence('id'), primary_key=True, index=True)
id = get_id_column()
# 保存路径
path = Column(String, nullable=False, index=True)
# 类型 电影/电视剧
@@ -188,7 +188,7 @@ class DownloadFiles(Base):
"""
下载文件记录
"""
id = Column(Integer, Sequence('id'), primary_key=True, index=True)
id = get_id_column()
# 下载器
downloader = Column(String)
# 下载任务Hash

View File

@@ -1,19 +1,19 @@
from datetime import datetime
from typing import Optional
from sqlalchemy import Column, Integer, String, Sequence, JSON
from sqlalchemy import Column, Integer, String, JSON
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from app.db import db_query, db_update, async_db_query, Base
from app.db import db_query, db_update, get_id_column, async_db_query, Base
class MediaServerItem(Base):
"""
媒体服务器媒体条目表
"""
id = Column(Integer, Sequence('id'), primary_key=True, index=True)
id = get_id_column()
# 服务器类型
server = Column(String)
# 媒体库ID

View File

@@ -1,17 +1,17 @@
from typing import Optional
from sqlalchemy import Column, Integer, String, Sequence, JSON, select
from sqlalchemy import Column, Integer, String, JSON, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from app.db import db_query, Base, async_db_query
from app.db import db_query, Base, get_id_column, async_db_query
class Message(Base):
"""
消息表
"""
id = Column(Integer, Sequence('id'), primary_key=True, index=True)
id = get_id_column()
# 消息渠道
channel = Column(String)
# 消息来源

View File

@@ -1,14 +1,14 @@
from sqlalchemy import Column, Integer, String, Sequence, JSON
from sqlalchemy import Column, String, JSON
from sqlalchemy.orm import Session
from app.db import db_query, db_update, Base
from app.db import db_query, db_update, get_id_column, Base
class PluginData(Base):
"""
插件数据表
"""
id = Column(Integer, Sequence('id'), primary_key=True, index=True)
id = get_id_column()
plugin_id = Column(String, nullable=False, index=True)
key = Column(String, index=True, nullable=False)
value = Column(JSON)

View File

@@ -1,17 +1,17 @@
from datetime import datetime
from sqlalchemy import Boolean, Column, Integer, String, Sequence, JSON, select, delete
from sqlalchemy import Boolean, Column, Integer, String, JSON, select, delete
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from app.db import db_query, db_update, Base, async_db_query, async_db_update
from app.db import db_query, db_update, Base, async_db_query, async_db_update, get_id_column
class Site(Base):
"""
站点表
"""
id = Column(Integer, Sequence('id'), primary_key=True, index=True)
id = get_id_column()
# 站点名
name = Column(String, nullable=False)
# 域名Key

View File

@@ -1,15 +1,15 @@
from sqlalchemy import Column, Integer, String, Sequence, select
from sqlalchemy import Column, String, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from app.db import db_query, Base, async_db_query
from app.db import db_query, Base, get_id_column, async_db_query
class SiteIcon(Base):
"""
站点图标表
"""
id = Column(Integer, Sequence('id'), primary_key=True, index=True)
id = get_id_column()
# 站点名称
name = Column(String, nullable=False)
# 域名Key

View File

@@ -1,17 +1,17 @@
from datetime import datetime
from sqlalchemy import Column, Integer, String, Sequence, JSON, select
from sqlalchemy import Column, Integer, String, JSON, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from app.db import db_query, db_update, Base, async_db_query
from app.db import db_query, db_update, get_id_column, Base, async_db_query
class SiteStatistic(Base):
"""
站点统计表
"""
id = Column(Integer, Sequence('id'), primary_key=True, index=True)
id = get_id_column()
# 域名Key
domain = Column(String, index=True)
# 成功次数

View File

@@ -1,18 +1,18 @@
from datetime import datetime
from typing import Optional
from sqlalchemy import Column, Integer, String, Sequence, Float, JSON, func, or_, select
from sqlalchemy import Column, Integer, String, Float, JSON, func, or_, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from app.db import db_query, Base, async_db_query
from app.db import db_query, Base, get_id_column, async_db_query
class SiteUserData(Base):
"""
站点数据表
"""
id = Column(Integer, Sequence('id'), primary_key=True, index=True)
id = get_id_column()
# 站点域名
domain = Column(String, index=True)
# 站点名称

View File

@@ -1,18 +1,18 @@
import time
from typing import Optional
from sqlalchemy import Column, Integer, String, Sequence, Float, JSON, select
from sqlalchemy import Column, Integer, String, Float, JSON, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from app.db import db_query, db_update, Base, async_db_query, async_db_update
from app.db import db_query, db_update, get_id_column, Base, async_db_query, async_db_update
class Subscribe(Base):
"""
订阅表
"""
id = Column(Integer, Sequence('id'), primary_key=True, index=True)
id = get_id_column()
# 标题
name = Column(String, nullable=False, index=True)
# 年份

View File

@@ -1,17 +1,17 @@
from typing import Optional
from sqlalchemy import Column, Integer, String, Sequence, Float, JSON, select
from sqlalchemy import Column, Integer, String, Float, JSON, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from app.db import db_query, Base, async_db_query
from app.db import db_query, Base, get_id_column, async_db_query
class SubscribeHistory(Base):
"""
订阅历史表
"""
id = Column(Integer, Sequence('id'), primary_key=True, index=True)
id = get_id_column()
# 标题
name = Column(String, nullable=False, index=True)
# 年份

View File

@@ -1,15 +1,15 @@
from sqlalchemy import Column, Integer, String, Sequence, JSON, select
from sqlalchemy import Column, String, JSON, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from app.db import db_query, db_update, Base, async_db_query
from app.db import db_query, db_update, Base, async_db_query, get_id_column
class SystemConfig(Base):
"""
配置表
"""
id = Column(Integer, Sequence('id'), primary_key=True, index=True)
id = get_id_column()
# 主键
key = Column(String, index=True)
# 值

View File

@@ -1,18 +1,18 @@
import time
from typing import Optional
from sqlalchemy import Column, Integer, String, Sequence, Boolean, func, or_, JSON, select
from sqlalchemy import Column, Integer, String, Boolean, func, or_, JSON, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from app.db import db_query, db_update, Base, async_db_query
from app.db import db_query, db_update, get_id_column, Base, async_db_query
class TransferHistory(Base):
"""
整理记录
"""
id = Column(Integer, Sequence('id'), primary_key=True, index=True)
id = get_id_column()
# 源路径
src = Column(String, index=True)
# 源存储

View File

@@ -1,8 +1,8 @@
from sqlalchemy import Boolean, Column, Integer, JSON, Sequence, String, select
from sqlalchemy import Boolean, Column, JSON, String, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from app.db import Base, db_query, db_update, async_db_query, async_db_update
from app.db import Base, db_query, db_update, async_db_query, async_db_update, get_id_column
class User(Base):
@@ -10,7 +10,7 @@ class User(Base):
用户表
"""
# ID
id = Column(Integer, Sequence('id'), primary_key=True, index=True)
id = get_id_column()
# 用户名,唯一值
name = Column(String, index=True, nullable=False)
# 邮箱

View File

@@ -1,14 +1,14 @@
from sqlalchemy import Column, Integer, String, Sequence, UniqueConstraint, Index, JSON
from sqlalchemy import Column, String, UniqueConstraint, Index, JSON
from sqlalchemy.orm import Session
from app.db import db_query, db_update, Base
from app.db import db_query, db_update, get_id_column, Base
class UserConfig(Base):
"""
用户配置表
"""
id = Column(Integer, Sequence('id'), primary_key=True, index=True)
id = get_id_column()
# 用户名
username = Column(String, index=True)
# 配置键

View File

@@ -1,10 +1,10 @@
from datetime import datetime
from typing import Optional
from sqlalchemy import Column, Integer, JSON, Sequence, String, and_, or_, select
from sqlalchemy import Column, Integer, JSON, String, and_, or_, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import Base, db_query, db_update, async_db_query, async_db_update
from app.db import Base, db_query, get_id_column, db_update, async_db_query, async_db_update
class Workflow(Base):
@@ -12,7 +12,7 @@ class Workflow(Base):
工作流表
"""
# ID
id = Column(Integer, Sequence('id'), primary_key=True, index=True)
id = get_id_column()
# 名称
name = Column(String, index=True, nullable=False)
# 描述