Files

229 lines
7.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
MoviePilot 数据库操作脚本。
脚本从项目配置读取数据库连接参数,不要求 Agent 在提示词中接触数据库密码。
默认只允许查询语句;写操作必须显式传入 --write。
"""
from __future__ import annotations
import argparse
import json
import re
import sys
from pathlib import Path
from typing import Any, Optional
from sqlalchemy import create_engine, inspect, text
from sqlalchemy.engine import Engine
from sqlalchemy.exc import SQLAlchemyError
# Absolute path of this script with symlinks resolved.
SCRIPT_PATH = Path(__file__).resolve()
# Root directory that _ensure_project_import() puts on sys.path so `app.*`
# modules can be imported. It is the fourth ancestor of this file
# (parents[3]), which ties the script to a fixed depth inside the repository;
# Path.parents raises IndexError at import time if the path is shallower.
PROJECT_ROOT = SCRIPT_PATH.parents[3]
# Alternation of keywords that mark a statement as data-modifying or DDL.
_WRITE_KEYWORDS = "insert|update|delete|drop|alter|truncate|create|replace"
# A statement whose first word is one of the write keywords.
WRITE_STATEMENT_RE = re.compile(rf"^\s*({_WRITE_KEYWORDS})\b", re.IGNORECASE)
# A write keyword anywhere in the text. This is deliberately conservative:
# it also fires inside string literals and on clauses such as FOR UPDATE.
WRITE_KEYWORD_RE = re.compile(rf"\b({_WRITE_KEYWORDS})\b", re.IGNORECASE)
# Statement prefixes accepted in read-only mode.
SELECT_STATEMENT_RE = re.compile(r"^\s*(select|with|explain)\b", re.IGNORECASE)
def _ensure_project_import() -> None:
    """Put PROJECT_ROOT at the front of sys.path so project modules import from any cwd."""
    root = str(PROJECT_ROOT)
    if root in sys.path:
        # Already importable; avoid inserting a duplicate entry.
        return
    sys.path.insert(0, root)
def _load_settings() -> Any:
    """Return MoviePilot's runtime ``settings`` object from ``app.core.config``."""
    _ensure_project_import()
    # Imported lazily: the project root only becomes importable after the call above.
    from app.core.config import settings as mp_settings  # pylint: disable=import-outside-toplevel
    return mp_settings
def _build_engine() -> Engine:
    """Create a synchronous SQLAlchemy engine from the MoviePilot configuration.

    A ``DB_TYPE`` of ``postgresql`` (case-insensitive) uses ``DB_POSTGRESQL_URL()``;
    any other value falls back to the SQLite file ``user.db`` under ``CONFIG_PATH``.
    """
    settings = _load_settings()
    # Pool behaviour shared by both backends.
    pool_options = {
        "pool_pre_ping": settings.DB_POOL_PRE_PING,
        "pool_recycle": settings.DB_POOL_RECYCLE,
    }
    if str(settings.DB_TYPE).lower() != "postgresql":
        return create_engine(
            f"sqlite:///{settings.CONFIG_PATH}/user.db",
            connect_args={"timeout": settings.DB_TIMEOUT},
            **pool_options,
        )
    return create_engine(settings.DB_POSTGRESQL_URL(), **pool_options)
def _normalize_sql(sql: str) -> str:
"""去除 SQL 首尾空白和末尾分号。"""
return sql.strip().rstrip(";")
def _contains_multiple_statements(sql: str) -> bool:
"""判断 SQL 是否包含多语句分隔符。"""
return ";" in sql
def _is_write_statement(sql: str) -> bool:
    """Return True when *sql* starts with a data-modifying or schema-changing keyword."""
    return WRITE_STATEMENT_RE.match(sql) is not None
def _is_supported_statement(sql: str, allow_write: bool) -> bool:
    """Decide whether *sql* may run under the current permission mode.

    Multi-statement input is always refused. With ``allow_write`` any single
    statement passes; otherwise it must start with SELECT/WITH/EXPLAIN and must
    not mention a write keyword anywhere in its text.
    """
    if _contains_multiple_statements(sql):
        return False
    if allow_write:
        return True
    if SELECT_STATEMENT_RE.match(sql) is None:
        return False
    return WRITE_KEYWORD_RE.search(sql) is None
def _append_limit(sql: str, limit: int) -> str:
"""为普通 SELECT 查询追加默认 LIMIT避免输出过大。"""
if limit <= 0:
return sql
lowered = sql.lower()
if not lowered.lstrip().startswith("select"):
return sql
if re.search(r"\blimit\s+\d+\b", lowered):
return sql
return f"{sql} LIMIT {limit}"
def _row_to_dict(row: Any) -> dict[str, Any]:
"""将 SQLAlchemy 行对象转为普通字典。"""
return dict(row._mapping)
def _print_json(payload: Any) -> None:
"""输出 JSON 结果。"""
print(json.dumps(payload, ensure_ascii=False, indent=2, default=str))
def list_tables() -> int:
    """Print the sorted names of all tables in the configured database as JSON.

    :return: process exit code: 0 on success, 1 if the database reported an error
    """
    engine = _build_engine()
    try:
        table_names = inspect(engine).get_table_names()
    except SQLAlchemyError as err:
        # Report errors the same way run_query does instead of a raw traceback.
        print(f"Error: {err}", file=sys.stderr)
        return 1
    finally:
        # Release pooled connections opened by the inspector.
        engine.dispose()
    _print_json({"tables": sorted(table_names)})
    return 0
def show_schema(table_name: str) -> int:
    """Print the column definitions of *table_name* as JSON.

    :param table_name: name of the table to inspect
    :return: process exit code: 0 on success, 1 if the table cannot be inspected
    """
    engine = _build_engine()
    try:
        raw_columns = inspect(engine).get_columns(table_name)
    except SQLAlchemyError as err:
        # e.g. NoSuchTableError for an unknown table, or a connection failure;
        # include the class name because NoSuchTableError's text is only the table name.
        print(
            f"Error: cannot read schema of table {table_name!r}: {type(err).__name__}: {err}",
            file=sys.stderr,
        )
        return 1
    finally:
        # Release pooled connections opened by the inspector.
        engine.dispose()
    columns = [
        {
            "name": column.get("name"),
            "type": str(column.get("type")),
            "nullable": column.get("nullable"),
            "default": column.get("default"),
            "primary_key": bool(column.get("primary_key")),
        }
        for column in raw_columns
    ]
    _print_json({"table": table_name, "columns": columns})
    return 0
def _enforce_read_only(connection: Any) -> None:
    """Make the database itself reject writes for the rest of this transaction.

    The regex checks in _is_supported_statement cannot see every way a query
    starting with SELECT can modify data (for example PostgreSQL
    ``SELECT ... INTO new_table`` or ``SELECT nextval(...)``), so read-only mode
    also asks the engine to enforce it. Dialects other than PostgreSQL and
    SQLite are left unchanged.
    """
    dialect_name = connection.dialect.name
    if dialect_name == "postgresql":
        connection.execute(text("SET TRANSACTION READ ONLY"))
    elif dialect_name == "sqlite":
        # Connection-scoped; safe here because run_query disposes its engine.
        connection.execute(text("PRAGMA query_only = ON"))


def run_query(sql: str, *, limit: int = 100, allow_write: bool = False) -> int:
    """
    Execute a single SQL statement and print the result as JSON.

    :param sql: the SQL statement to execute
    :param limit: default maximum row count appended to plain SELECT queries
    :param allow_write: whether data-modifying or schema-changing statements are allowed
    :return: process exit code
    """
    normalized_sql = _normalize_sql(sql)
    if not normalized_sql:
        print("Error: SQL is empty", file=sys.stderr)
        return 1
    # Checked separately so `write` / `--write` users get an accurate message
    # instead of "write statements require --write".
    if _contains_multiple_statements(normalized_sql):
        print("Error: only a single SQL statement is allowed", file=sys.stderr)
        return 1
    if not _is_supported_statement(normalized_sql, allow_write):
        print("Error: write statements require --write", file=sys.stderr)
        return 1
    statement_sql = _append_limit(normalized_sql, limit)
    engine = _build_engine()
    try:
        with engine.begin() as connection:
            if not allow_write:
                _enforce_read_only(connection)
            result = connection.execute(text(statement_sql))
            if result.returns_rows:
                rows = [_row_to_dict(row) for row in result.fetchall()]
                _print_json({"rows": rows, "row_count": len(rows)})
            else:
                _print_json({"row_count": result.rowcount})
    except SQLAlchemyError as err:
        print(f"Error: {err}", file=sys.stderr)
        return 1
    finally:
        # Close pooled connections (this also drops any connection-level PRAGMA).
        engine.dispose()
    return 0
def _read_sql(sql: Optional[str], sql_file: Optional[str]) -> str:
"""从参数或文件读取 SQL 文本。"""
if sql:
return sql
if sql_file:
return Path(sql_file).read_text(encoding="utf-8")
if not sys.stdin.isatty():
return sys.stdin.read()
return ""
def _build_parser() -> argparse.ArgumentParser:
"""构建命令行参数解析器。"""
parser = argparse.ArgumentParser(description="MoviePilot database operation helper")
subparsers = parser.add_subparsers(dest="command")
query_parser = subparsers.add_parser("query", help="execute a SQL statement")
query_parser.add_argument("sql", nargs="?", help="SQL statement")
query_parser.add_argument("--file", dest="sql_file", help="read SQL from file")
query_parser.add_argument("--limit", type=int, default=100, help="default SELECT row limit")
query_parser.add_argument("--write", action="store_true", help="allow write statements")
write_parser = subparsers.add_parser("write", help="execute a write statement")
write_parser.add_argument("sql", nargs="?", help="SQL statement")
write_parser.add_argument("--file", dest="sql_file", help="read SQL from file")
subparsers.add_parser("tables", help="list tables")
schema_parser = subparsers.add_parser("schema", help="show table schema")
schema_parser.add_argument("table_name", help="table name")
return parser
def main() -> int:
    """Command-line entry point: dispatch to the selected subcommand.

    Prints the help text and returns 0 when no subcommand is given.
    """
    parser = _build_parser()
    args = parser.parse_args()
    command = args.command
    if command in ("query", "write"):
        sql = _read_sql(args.sql, args.sql_file)
        if command == "write":
            # The write subcommand always allows writes and never adds a LIMIT.
            return run_query(sql, limit=0, allow_write=True)
        return run_query(sql, limit=args.limit, allow_write=args.write)
    if command == "tables":
        return list_tables()
    if command == "schema":
        return show_schema(args.table_name)
    parser.print_help()
    return 0
if __name__ == "__main__":
    # Equivalent to sys.exit(): propagate main()'s return value as the exit status.
    raise SystemExit(main())