Merge pull request #411 from EstrellaXD/orm-change

feat: Change orm from hand-made code to SQLModel
This commit is contained in:
Estrella Pan
2023-07-31 16:29:39 +08:00
committed by GitHub
30 changed files with 264 additions and 745 deletions

View File

@@ -23,3 +23,4 @@ python-jose==3.3.0
passlib==1.7.4
bcrypt==4.0.1
python-multipart==0.0.6
sqlmodel  # TODO: pin an exact version to match the other pinned requirements

View File

@@ -2,13 +2,13 @@ from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.responses import JSONResponse
from module.manager import TorrentManager
from module.models import BangumiData
from module.models import Bangumi
from module.security import get_current_user
router = APIRouter(prefix="/bangumi", tags=["bangumi"])
@router.get("/getAll", response_model=list[BangumiData])
@router.get("/getAll", response_model=list[Bangumi])
async def get_all_data(current_user=Depends(get_current_user)):
if not current_user:
raise HTTPException(
@@ -18,7 +18,7 @@ async def get_all_data(current_user=Depends(get_current_user)):
return torrent.search_all()
@router.get("/getData/{bangumi_id}", response_model=BangumiData)
@router.get("/getData/{bangumi_id}", response_model=Bangumi)
async def get_data(bangumi_id: str, current_user=Depends(get_current_user)):
if not current_user:
raise HTTPException(
@@ -29,7 +29,7 @@ async def get_data(bangumi_id: str, current_user=Depends(get_current_user)):
@router.post("/updateRule")
async def update_rule(data: BangumiData, current_user=Depends(get_current_user)):
async def update_rule(data: Bangumi, current_user=Depends(get_current_user)):
if not current_user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token"

View File

@@ -1,7 +1,7 @@
from fastapi import APIRouter, Depends, HTTPException, status
from module.manager import SeasonCollector
from module.models import BangumiData
from module.models import Bangumi
from module.models.api import RssLink
from module.rss import analyser
from module.security import get_current_user
@@ -23,9 +23,7 @@ async def analysis(link: RssLink, current_user=Depends(get_current_user)):
@router.post("/collection")
async def download_collection(
data: BangumiData, current_user=Depends(get_current_user)
):
async def download_collection(data: Bangumi, current_user=Depends(get_current_user)):
if not current_user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token"
@@ -41,7 +39,7 @@ async def download_collection(
@router.post("/subscribe")
async def subscribe(data: BangumiData, current_user=Depends(get_current_user)):
async def subscribe(data: Bangumi, current_user=Depends(get_current_user)):
if not current_user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token"

View File

@@ -4,7 +4,7 @@ from .config import VERSION, settings
from .log import LOG_PATH, setup_logger
TMDB_API = "32b19d6a05b512190a056fa4e747cbbc"
DATA_PATH = Path("data/data.db")
DATA_PATH = "sqlite:///data/data.db"
LEGACY_DATA_PATH = Path("data/data.json")
PLATFORM = "Windows" if "\\" in settings.downloader.path else "Unix"

View File

@@ -1,121 +1,102 @@
import logging
from module.database.orm import Connector
from module.models import BangumiData
from module.conf import DATA_PATH
from sqlmodel import Session, select, delete, or_
from sqlalchemy.sql import func
from typing import Optional
from .engine import engine
from module.models import Bangumi
logger = logging.getLogger(__name__)
class BangumiDatabase(Connector):
def __init__(self, database: str = DATA_PATH):
super().__init__(
table_name="bangumi",
data=self.__data_to_db(BangumiData()),
database=database,
)
class BangumiDatabase(Session):
def __init__(self, _engine=engine):
super().__init__(_engine)
def update_table(self):
self.update.table()
@staticmethod
def __data_to_db(data: BangumiData) -> dict:
db_data = data.dict()
for key, value in db_data.items():
if isinstance(value, bool):
db_data[key] = int(value)
elif isinstance(value, list):
db_data[key] = ",".join(value)
return db_data
@staticmethod
def __db_to_data(db_data: dict) -> BangumiData:
for key, item in db_data.items():
if isinstance(item, int):
if key not in ["id", "offset", "season", "year"]:
db_data[key] = bool(item)
elif key in ["filter", "rss_link"]:
db_data[key] = item.split(",")
return BangumiData(**db_data)
def insert_one(self, data: BangumiData):
db_data = self.__data_to_db(data)
self.insert.one(db_data)
def insert_one(self, data: Bangumi):
self.add(data)
self.commit()
logger.debug(f"[Database] Insert {data.official_title} into database.")
# if self.__check_exist(data):
# self.update_one(data)
# else:
# db_data = self.__data_to_db(data)
# db_data["id"] = self.gen_id()
# self._insert(db_data=db_data, table_name=self.__table_name)
# logger.debug(f"[Database] Insert {data.official_title} into database.")
def insert_list(self, data: list[BangumiData]):
data_list = [self.__data_to_db(x) for x in data]
self.insert.many(data_list)
# _id = self.gen_id()
# for i, item in enumerate(data):
# item.id = _id + i
# data_list = [self.__data_to_db(x) for x in data]
# self._insert_list(data_list=data_list, table_name=self.__table_name)
def insert_list(self, data: list[Bangumi]):
self.add_all(data)
logger.debug(f"[Database] Insert {len(data)} bangumi into database.")
def update_one(self, data: BangumiData) -> bool:
db_data = self.__data_to_db(data)
return self.update.one(db_data)
def update_one(self, data: Bangumi) -> bool:
db_data = self.get(Bangumi, data.id)
if not db_data:
return False
bangumi_data = data.dict(exclude_unset=True)
for key, value in bangumi_data.items():
setattr(db_data, key, value)
self.add(db_data)
self.commit()
self.refresh(db_data)
logger.debug(f"[Database] Update {data.official_title}")
return True
def update_list(self, data: list[BangumiData]):
data_list = [self.__data_to_db(x) for x in data]
self.update.many(data_list)
def update_list(self, datas: list[Bangumi]):
for data in datas:
self.update_one(data)
def update_rss(self, title_raw, rss_set: str):
# Update rss and added
location = {"title_raw": title_raw}
set_value = {"rss_link": rss_set, "added": 0}
self.update.value(location, set_value)
statement = select(Bangumi).where(Bangumi.title_raw == title_raw)
bangumi = self.exec(statement).first()
bangumi.rss_link = rss_set
bangumi.added = False
self.add(bangumi)
self.commit()
self.refresh(bangumi)
logger.debug(f"[Database] Update {title_raw} rss_link to {rss_set}.")
def update_poster(self, title_raw, poster_link: str):
location = {"title_raw": title_raw}
set_value = {"poster_link": poster_link}
self.update.value(location, set_value)
statement = select(Bangumi).where(Bangumi.title_raw == title_raw)
bangumi = self.exec(statement).first()
bangumi.poster_link = poster_link
self.add(bangumi)
self.commit()
self.refresh(bangumi)
logger.debug(f"[Database] Update {title_raw} poster_link to {poster_link}.")
def delete_one(self, _id: int):
self.delete.one(_id)
statement = select(Bangumi).where(Bangumi.id == _id)
bangumi = self.exec(statement).first()
self.delete(bangumi)
self.commit()
logger.debug(f"[Database] Delete bangumi id: {_id}.")
def delete_all(self):
self.delete.all()
statement = delete(Bangumi)
self.exec(statement)
self.commit()
def search_all(self) -> list[BangumiData]:
all_data = self.select.all()
return [self.__db_to_data(x) for x in all_data]
def search_all(self) -> list[Bangumi]:
statement = select(Bangumi)
return self.exec(statement).all()
def search_id(self, _id: int) -> Optional[Bangumi]:
    """Return the Bangumi row with primary key *_id*, or None when absent.

    Fix: return the row already fetched instead of re-executing the same
    SELECT a second time just to build the return value.
    """
    statement = select(Bangumi).where(Bangumi.id == _id)
    bangumi = self.exec(statement).first()
    if bangumi is None:
        logger.warning(f"[Database] Cannot find bangumi id: {_id}.")
        return None
    logger.debug(f"[Database] Find bangumi id: {_id}.")
    return bangumi
def match_poster(self, bangumi_name: str) -> str:
condition = {"official_title": bangumi_name}
keys = ["poster_link"]
data = self.select.one(
keys=keys,
conditions=condition,
combine_operator="INSTR",
)
if not data:
# Use like to match
statement = select(Bangumi).where(func.instr(bangumi_name, Bangumi.title_raw) > 0)
data = self.exec(statement).first()
if data:
return data.poster_link
else:
return ""
return data.get("poster_link")
def match_list(self, torrent_list: list, rss_link: str) -> list:
# Match title_raw in database
keys = ["title_raw", "rss_link", "poster_link"]
match_datas = self.select.column(keys)
match_datas = self.search_all()
if not match_datas:
return torrent_list
# Match title
@@ -123,36 +104,43 @@ class BangumiDatabase(Connector):
while i < len(torrent_list):
torrent = torrent_list[i]
for match_data in match_datas:
if match_data.get("title_raw") in torrent.name:
if rss_link not in match_data.get("rss_link"):
match_data["rss_link"] += f",{rss_link}"
self.update_rss(
match_data.get("title_raw"), match_data.get("rss_link")
)
if not match_data.get("poster_link"):
self.update_poster(
match_data.get("title_raw"), torrent.poster_link
)
if match_data.title_raw in torrent.name:
if rss_link not in match_data.rss_link:
match_data.rss_link += f",{rss_link}"
self.update_rss(match_data.title_raw, match_data.rss_link)
if not match_data.poster_link:
self.update_poster(match_data.title_raw, torrent.poster_link)
torrent_list.pop(i)
break
else:
i += 1
return torrent_list
def not_complete(self) -> list[BangumiData]:
def not_complete(self) -> list[Bangumi]:
# Find eps_complete = False
condition = {"eps_collect": 0}
dict_data = self.select.many(
conditions=condition,
)
return [self.__db_to_data(x) for x in dict_data]
condition = select(Bangumi).where(Bangumi.eps_collect == 0)
datas = self.exec(condition).all()
return datas
def not_added(self) -> list[BangumiData]:
conditions = {"added": 0, "rule_name": None, "save_path": None}
dict_data = self.select.many(conditions=conditions, combine_operator="OR")
return [self.__db_to_data(x) for x in dict_data]
def not_added(self) -> list[Bangumi]:
    """Return bangumi not yet added to the downloader.

    Bug fix: ``Bangumi.rule_name is None`` / ``Bangumi.save_path is None``
    are evaluated by Python as identity checks against the column object
    (always False), so those conditions never reached the generated SQL.
    Use ``.is_(None)`` so SQLAlchemy renders ``IS NULL``.
    """
    statement = select(Bangumi).where(
        or_(
            Bangumi.added == 0,
            Bangumi.rule_name.is_(None),
            Bangumi.save_path.is_(None),
        )
    )
    return self.exec(statement).all()
def disable_rule(self, _id: int):
    """Mark the bangumi with primary key *_id* as deleted (rule disabled).

    Robustness fix: guard against an unknown id — ``.first()`` returns
    None in that case, which previously raised AttributeError.
    """
    statement = select(Bangumi).where(Bangumi.id == _id)
    bangumi = self.exec(statement).first()
    if bangumi is None:
        logger.warning(f"[Database] Cannot find bangumi id: {_id}.")
        return
    bangumi.deleted = True
    self.add(bangumi)
    self.commit()
    self.refresh(bangumi)
    logger.debug(f"[Database] Disable rule {bangumi.title_raw}.")
if __name__ == "__main__":
with BangumiDatabase() as db:
print(db.match_poster("久保"))
print(db.not_complete())

View File

@@ -1,174 +0,0 @@
import logging
import os
import sqlite3
from module.conf import DATA_PATH
logger = logging.getLogger(__name__)
class DataConnector:
    """Legacy hand-written sqlite3 access layer (superseded by SQLModel).

    SQL is assembled from python dicts: table and column names are
    interpolated directly into the statement text (they must come from
    trusted code, never user input), while values are bound as named
    parameters. One connection to ``DATA_PATH`` is opened at construction
    and closed on context-manager exit.
    """

    def __init__(self):
        # Create folder if not exists
        DATA_PATH.parent.mkdir(parents=True, exist_ok=True)
        self._conn = sqlite3.connect(DATA_PATH)
        self._cursor = self._conn.cursor()

    def _update_table(self, table_name: str, db_data: dict):
        """Create *table_name* if missing and add any columns present in
        *db_data* but absent from the live schema (lightweight migration).

        Column SQL types are inferred from the python types of the values.
        """
        columns = ", ".join(
            [
                f"{key} {self.__python_to_sqlite_type(value)}"
                for key, value in db_data.items()
            ]
        )
        create_table_sql = f"CREATE TABLE IF NOT EXISTS {table_name} ({columns});"
        self._cursor.execute(create_table_sql)
        # Read back the existing columns to detect schema drift.
        self._cursor.execute(f"PRAGMA table_info({table_name})")
        existing_columns = {
            column_info[1]: column_info for column_info in self._cursor.fetchall()
        }
        for key, value in db_data.items():
            if key not in existing_columns:
                insert_column = self.__python_to_sqlite_type(value)
                if value is None:
                    value = "NULL"
                # NOTE(review): the DEFAULT value is interpolated unquoted into
                # the DDL, so a string default becomes a bare identifier —
                # confirm this was only ever used with numeric/NULL defaults.
                add_column_sql = f"ALTER TABLE {table_name} ADD COLUMN {key} {insert_column} DEFAULT {value};"
                self._cursor.execute(add_column_sql)
        self._conn.commit()
        logger.debug(f"Create / Update table {table_name}.")

    def _insert(self, table_name: str, db_data: dict):
        """Insert one row; dict keys become columns, values are bound."""
        columns = ", ".join(db_data.keys())
        values = ", ".join([f":{key}" for key in db_data.keys()])
        self._cursor.execute(
            f"INSERT INTO {table_name} ({columns}) VALUES ({values})", db_data
        )
        self._conn.commit()

    def _insert_list(self, table_name: str, data_list: list[dict]):
        """Insert many rows; all dicts must share the first dict's key set."""
        columns = ", ".join(data_list[0].keys())
        values = ", ".join([f":{key}" for key in data_list[0].keys()])
        self._cursor.executemany(
            f"INSERT INTO {table_name} ({columns}) VALUES ({values})", data_list
        )
        self._conn.commit()

    def _select(self, keys: list[str], table_name: str, condition: str = None) -> dict:
        """Fetch one row as ``{key: value}``.

        *condition* is a raw SQL WHERE-clause body (trusted input only).
        NOTE(review): when no row matches, ``fetchone()`` returns None and the
        ``zip`` raises TypeError — callers appear to guarantee a match; confirm.
        """
        if condition is None:
            self._cursor.execute(f"SELECT {', '.join(keys)} FROM {table_name}")
        else:
            self._cursor.execute(
                f"SELECT {', '.join(keys)} FROM {table_name} WHERE {condition}"
            )
        return dict(zip(keys, self._cursor.fetchone()))

    def _update(self, table_name: str, db_data: dict):
        """Update the row matching ``db_data["id"]``; returns True when
        exactly one row was changed."""
        _id = db_data.get("id")
        if _id is None:
            raise ValueError("No _id in db_data.")
        set_sql = ", ".join([f"{key} = :{key}" for key in db_data.keys()])
        self._cursor.execute(
            f"UPDATE {table_name} SET {set_sql} WHERE id = {_id}", db_data
        )
        self._conn.commit()
        return self._cursor.rowcount == 1

    def _update_list(self, table_name: str, data_list: list[dict]):
        """Batch-update rows by their ``id`` key; no-op on an empty list."""
        if len(data_list) == 0:
            return
        set_sql = ", ".join(
            [f"{key} = :{key}" for key in data_list[0].keys() if key != "id"]
        )
        self._cursor.executemany(
            f"UPDATE {table_name} SET {set_sql} WHERE id = :id", data_list
        )
        self._conn.commit()

    def _update_section(self, table_name: str, location: dict, update_dict: dict):
        """Update *update_dict* columns on rows selected by *location*.

        NOTE(review): *location* must be of the form
        ``{"key": <column>, "value": <match value>}``, and the match value is
        interpolated unquoted into the SQL — injection-prone for strings;
        confirm all callers pass trusted numeric values.
        """
        set_sql = ", ".join([f"{key} = :{key}" for key in update_dict.keys()])
        sql_loc = f"{location['key']} = {location['value']}"
        self._cursor.execute(
            f"UPDATE {table_name} SET {set_sql} WHERE {sql_loc}", update_dict
        )
        self._conn.commit()

    def _delete_all(self, table_name: str):
        """Delete every row in *table_name*."""
        self._cursor.execute(f"DELETE FROM {table_name}")
        self._conn.commit()

    def _delete(self, table_name: str, condition: dict):
        """Delete rows matching ALL equality conditions (values bound)."""
        condition_sql = " AND ".join([f"{key} = :{key}" for key in condition.keys()])
        self._cursor.execute(
            f"DELETE FROM {table_name} WHERE {condition_sql}", condition
        )
        self._conn.commit()

    def _search(
        self, table_name: str, keys: list[str] | None = None, condition: dict = None
    ):
        """Execute a SELECT; results are read from the cursor by the
        ``_search_data`` / ``_search_datas`` wrappers (no return here).

        NOTE(review): ``condition.pop("_custom_condition", ...)`` mutates the
        caller's dict, and the popped value is appended as raw SQL — trusted
        input only.
        """
        if keys is None:
            select_sql = "*"
        else:
            select_sql = ", ".join(keys)
        if condition is None:
            self._cursor.execute(f"SELECT {select_sql} FROM {table_name}")
        else:
            custom_condition = condition.pop("_custom_condition", None)
            condition_sql = " AND ".join(
                [f"{key} = :{key}" for key in condition.keys()]
            ) + (f" AND {custom_condition}" if custom_condition else "")
            self._cursor.execute(
                f"SELECT {select_sql} FROM {table_name} WHERE {condition_sql}",
                condition,
            )

    def _search_data(
        self, table_name: str, keys: list[str] | None = None, condition: dict = None
    ) -> dict:
        """Return the first matching row as a dict (all columns when *keys*
        is None). NOTE(review): TypeError when no row matches — confirm."""
        if keys is None:
            keys = self.__get_table_columns(table_name)
        self._search(table_name, keys, condition)
        return dict(zip(keys, self._cursor.fetchone()))

    def _search_datas(
        self, table_name: str, keys: list[str] | None = None, condition: dict = None
    ) -> list[dict]:
        """Return all matching rows as dicts (empty list when none match)."""
        if keys is None:
            keys = self.__get_table_columns(table_name)
        self._search(table_name, keys, condition)
        return [dict(zip(keys, row)) for row in self._cursor.fetchall()]

    def _table_exists(self, table_name: str) -> bool:
        """True when *table_name* exists in sqlite_master."""
        self._cursor.execute(
            "SELECT name FROM sqlite_master WHERE type='table' AND name=?;",
            (table_name,),
        )
        return len(self._cursor.fetchall()) == 1

    def __get_table_columns(self, table_name: str) -> list[str]:
        """Return the column names of *table_name* via PRAGMA table_info."""
        self._cursor.execute(f"PRAGMA table_info({table_name})")
        return [column_info[1] for column_info in self._cursor.fetchall()]

    @staticmethod
    def __python_to_sqlite_type(value) -> str:
        """Map a python value to a sqlite column type for DDL generation."""
        # NOTE: bool is a subclass of int, so the bool branch below is
        # unreachable — booleans already map to INTEGER via the first branch
        # (harmless, same result).
        if isinstance(value, int):
            return "INTEGER NOT NULL"
        elif isinstance(value, float):
            return "REAL NOT NULL"
        elif isinstance(value, str):
            return "TEXT NOT NULL"
        elif isinstance(value, bool):
            return "INTEGER NOT NULL"
        elif isinstance(value, list):
            return "TEXT NOT NULL"
        elif value is None:
            return "TEXT"
        else:
            raise ValueError(f"Unsupported data type: {type(value)}")

    def __enter__(self):
        # Context-manager support: `with DataConnector() as db: ...`
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Close the sqlite connection on exit (no commit/rollback here).
        self._conn.close()

View File

@@ -0,0 +1,7 @@
from sqlmodel import create_engine, Session
from module.conf import DATA_PATH

# Shared SQLAlchemy engine bound to the database URL in DATA_PATH
# (a sqlite URL such as "sqlite:///data/data.db").
engine = create_engine(DATA_PATH)
# NOTE(review): a single module-level, long-lived Session shared by all
# importers — confirm the intended lifecycle/thread-safety; most call sites
# appear to open their own Session subclass instead.
db_session = Session(engine)

View File

@@ -1 +0,0 @@
from .connector import Connector

View File

@@ -1,71 +0,0 @@
from os import PathLike
from pathlib import Path
import sqlite3
from .delete import Delete
from .insert import Insert
from .select import Select
from .update import Update
from module.conf import DATA_PATH
class Connector:
    """Facade over sqlite3 exposing insert/select/update/delete helpers
    bound to one table.

    The helper objects (``self.insert`` etc.) build SQL and call back into
    this connector's ``execute``/``fetch*`` methods. Rows are returned as
    dicts keyed by column name.
    """

    def __init__(
        self, table_name: str, data: dict, database: PathLike[str] | Path = DATA_PATH
    ):
        # Create folder if not exists
        if isinstance(database, (PathLike, str)):
            database = Path(database)
        database.parent.mkdir(parents=True, exist_ok=True)
        self._conn = sqlite3.connect(database)
        self._cursor = self._conn.cursor()
        # Per-operation helpers, each bound to this connector and table.
        self.update = Update(self, table_name, data)
        self.insert = Insert(self, table_name, data)
        self.select = Select(self, table_name, data)
        self.delete = Delete(self, table_name, data)
        # Cached column names, used as default dict keys by fetch*().
        self._columns = self.__get_columns(table_name)

    def __enter__(self):
        # Context-manager support: `with Connector(...) as db: ...`
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Close the sqlite connection on exit.
        self._conn.close()

    def __get_columns(self, table_name: str) -> list[str]:
        """Return the table's column names via PRAGMA table_info."""
        self._cursor.execute(f"PRAGMA table_info({table_name})")
        return [x[1] for x in self._cursor.fetchall()]

    def execute(self, sql: str, params: tuple = None):
        """Execute one statement (bound params optional) and commit."""
        if params is None:
            self._cursor.execute(sql)
        else:
            self._cursor.execute(sql, params)
        self._conn.commit()

    def executemany(self, sql: str, params: list[tuple]):
        """Execute one statement per parameter set, then commit once."""
        self._cursor.executemany(sql, params)
        self._conn.commit()

    def fetchall(self, keys: str = None) -> list[dict]:
        """Return all pending rows as dicts (keys default to table columns)."""
        datas = self._cursor.fetchall()
        if keys:
            return [dict(zip(keys, data)) for data in datas]
        return [dict(zip(self._columns, data)) for data in datas]

    def fetchone(self, keys: list[str] = None) -> dict:
        """Return the next row as a dict.

        NOTE(review): implicitly returns None when the cursor is exhausted
        (no explicit else branch) — callers must handle None.
        """
        data = self._cursor.fetchone()
        if data:
            if keys:
                return dict(zip(keys, data))
            return dict(zip(self._columns, data))

    def fetchmany(self, keys: list[str], size: int) -> list[dict]:
        """Return up to *size* rows as dicts."""
        datas = self._cursor.fetchmany(size)
        if keys:
            return [dict(zip(keys, data)) for data in datas]
        return [dict(zip(self._columns, data)) for data in datas]

    def fetch(self):
        """Return all pending rows as raw tuples (no dict conversion)."""
        return self._cursor.fetchall()

View File

@@ -1,23 +0,0 @@
class Delete:
    """SQL DELETE helper for one table, executed through a shared connector."""

    def __init__(self, connector, table_name: str, data: dict):
        self._connector = connector
        self._table_name = table_name
        self._data = data

    def one(self, _id: int) -> bool:
        """Delete the row whose ``id`` column equals *_id*; always True."""
        sql = f"""
            DELETE FROM {self._table_name}
            WHERE id = :id
            """
        self._connector.execute(sql, {"id": _id})
        return True

    def all(self):
        """Delete every row in the table; always True."""
        sql = f"""
            DELETE FROM {self._table_name}
            """
        self._connector.execute(sql)
        return True

View File

@@ -1,33 +0,0 @@
class Insert:
    """SQL INSERT helper for one table, executed through a shared connector.

    Row ids are generated as ``MAX(id) + 1`` at insert time; column names
    come from dict keys (trusted input only), values are bound parameters.
    """

    def __init__(self, connector, table_name: str, data: dict):
        self._connector = connector
        self._table_name = table_name
        self._columns = data.items()

    def __gen_id(self) -> int:
        """Return the next free integer id (MAX(id) + 1, or 1 when empty)."""
        self._connector.execute(
            f"""
            SELECT MAX(id) FROM {self._table_name}
            """,
        )
        max_id = self._connector.fetchone(keys=["id"]).get("id")
        if max_id is None:
            return 1
        return max_id + 1

    def one(self, data: dict):
        """Insert *data* as a new row with a freshly generated id.

        Fix: the generated id is added to a copy — the caller's dict is no
        longer mutated as a side effect.
        """
        row = {**data, "id": self.__gen_id()}
        columns = ", ".join(row.keys())
        placeholders = ", ".join([f":{key}" for key in row.keys()])
        self._connector.execute(
            f"""
            INSERT INTO {self._table_name} ({columns})
            VALUES ({placeholders})
            """,
            row,
        )

    def many(self, data: list[dict]):
        """Insert each dict as its own row (one INSERT — and one id
        lookup — per row)."""
        for item in data:
            self.one(item)

View File

@@ -1,96 +0,0 @@
class Select:
def __init__(self, connector, table_name: str, data: dict):
self._connector = connector
self._table_name = table_name
self._data = data
def id(self, _id: int):
self._connector.execute(
f"""
SELECT * FROM {self._table_name}
WHERE id = :id
""",
{"id": _id},
)
return self._connector.fetchone()
def all(self, limit: int = None):
if limit is None:
limit = 10000
self._connector.execute(
f"""
SELECT * FROM {self._table_name} LIMIT {limit}
""",
)
return self._connector.fetchall()
def one(
self,
keys: list[str] | None = None,
conditions: dict = None,
combine_operator: str = "AND",
):
if keys is None:
columns = "*"
else:
columns = ", ".join(keys)
condition_sql = self.__select_condition(conditions, combine_operator)
self._connector.execute(
f"""
SELECT {columns} FROM {self._table_name}
WHERE {condition_sql}
""",
conditions,
)
return self._connector.fetchone(keys)
def many(
self,
keys: list[str] | None = None,
conditions: dict = None,
combine_operator: str = "AND",
limit: int = None,
):
if keys is None:
columns = "*"
else:
columns = ", ".join(keys)
if limit is None:
limit = 10000
condition_sql = self.__select_condition(conditions, combine_operator)
self._connector.execute(
f"""
SELECT {columns} FROM {self._table_name}
WHERE {condition_sql}
LIMIT {limit}
""",
conditions,
)
return self._connector.fetchall(keys)
def column(self, keys: list[str]):
columns = ", ".join(keys)
self._connector.execute(
f"""
SELECT {columns} FROM {self._table_name}
""",
)
return self._connector.fetchall(keys)
@staticmethod
def __select_condition(conditions: dict, combine_operator: str = "AND"):
if not conditions:
raise ValueError("No conditions provided.")
if combine_operator not in ["AND", "OR", "INSTR"]:
raise ValueError(
"Invalid combine_operator, must be 'AND' or 'OR' or 'INSTR'."
)
if combine_operator == "INSTR":
condition_sql = f" AND ".join(
[f"INSTR({key}, :{key})" for key in conditions.keys()]
)
else:
condition_sql = f" {combine_operator} ".join(
[f"{key} = :{key}" for key in conditions.keys()]
)
return condition_sql

View File

@@ -1,98 +0,0 @@
import logging
logger = logging.getLogger(__name__)
class Update:
    """SQL UPDATE / DDL helper for one table, executed through a shared
    connector.

    ``data`` is an example row dict used to derive the table schema for
    ``table()``. Table/column names are interpolated into the SQL (trusted
    code only); row values are bound as named parameters.
    """

    def __init__(self, connector, table_name: str, data: dict):
        self._connector = connector
        self._table_name = table_name
        self._example_data = data

    def __table_exists(self) -> bool:
        """True when this table exists in sqlite_master.

        NOTE(review): ``fetch()`` wraps ``cursor.fetchall()``, which returns
        ``[]`` (never None) when there is no match — so this currently
        returns True unconditionally. Likely a latent bug; a non-empty
        check was probably intended.
        """
        self._connector.execute(
            f"""
            SELECT name FROM sqlite_master
            WHERE type='table' AND name='{self._table_name}'
            """
        )
        return self._connector.fetch() is not None

    def table(self):
        """Create the table if missing, then add any columns present in the
        example data but absent from the live schema (lightweight migration).
        """
        columns = ", ".join(
            [
                f"{key} {self.__python_to_sqlite_type(value)}"
                for key, value in self._example_data.items()
            ]
        )
        create_table_sql = f"CREATE TABLE IF NOT EXISTS {self._table_name} ({columns});"
        self._connector.execute(create_table_sql)
        logger.debug(f"Create table {self._table_name}.")
        # Compare the live schema against the example row.
        self._connector.execute(f"PRAGMA table_info({self._table_name})")
        existing_columns = [x[1] for x in self._connector.fetch()]
        for key, value in self._example_data.items():
            if key not in existing_columns:
                insert_column = self.__python_to_sqlite_type(value)
                if value is None:
                    value = "NULL"
                # NOTE(review): the DEFAULT value is interpolated unquoted into
                # the DDL — string defaults become bare identifiers; confirm.
                add_column_sql = f"ALTER TABLE {self._table_name} ADD COLUMN {key} {insert_column} DEFAULT {value};"
                self._connector.execute(add_column_sql)
        logger.debug(f"Update table {self._table_name}.")

    def one(self, data: dict) -> bool:
        """Update the row matching ``data["id"]`` with every key in *data*.
        Always returns True (success is not verified)."""
        _id = data["id"]
        set_sql = ", ".join([f"{key} = :{key}" for key in data.keys()])
        self._connector.execute(
            f"""
            UPDATE {self._table_name}
            SET {set_sql}
            WHERE id = :id
            """,
            data,
        )
        logger.debug(f"Update {_id} in {self._table_name}.")
        return True

    def many(self, data: list[dict]) -> bool:
        """Batch-update rows; every dict must carry an ``id`` key and share
        the first dict's key set. Always returns True."""
        columns = ", ".join([f"{key} = :{key}" for key in data[0].keys()])
        self._connector.executemany(
            f"""
            UPDATE {self._table_name}
            SET {columns}
            WHERE id = :id
            """,
            data,
        )
        logger.debug(f"Update {self._table_name}.")
        return True

    def value(self, location: dict, set_value: dict) -> bool:
        """Update *set_value* columns on rows selected by *location*.

        NOTE(review): this expects ``location`` of the form
        ``{"key": <column>, "value": ...}`` (it reads ``location["key"]``),
        but at least one legacy call site passed ``{<column>: <value>}``
        directly, which would raise KeyError here — confirm the contract.
        """
        set_sql = ", ".join([f"{key} = :{key}" for key in set_value.keys()])
        params = {**location, **set_value}
        self._connector.execute(
            f"""
            UPDATE {self._table_name}
            SET {set_sql}
            WHERE {location["key"]} = :{location["key"]}
            """,
            params,
        )
        logger.debug(f"Update {self._table_name}.")
        return True

    @staticmethod
    def __python_to_sqlite_type(value) -> str:
        """Map a python value to a sqlite column type for DDL generation."""
        # NOTE: bool is a subclass of int, so the bool branch below is
        # unreachable — booleans already map to INTEGER via the first branch.
        if isinstance(value, int):
            return "INTEGER NOT NULL"
        elif isinstance(value, float):
            return "REAL NOT NULL"
        elif isinstance(value, str):
            return "TEXT NOT NULL"
        elif isinstance(value, bool):
            return "INTEGER NOT NULL"
        elif isinstance(value, list):
            return "TEXT NOT NULL"
        elif value is None:
            return "TEXT"
        else:
            raise ValueError(f"Unsupported data type: {type(value)}")

View File

@@ -2,72 +2,58 @@ import logging
from fastapi import HTTPException
from module.database.connector import DataConnector
from module.models.user import User
from module.models.user import User, UserUpdate, UserLogin
from module.security.jwt import get_password_hash, verify_password
from module.database.engine import engine
from sqlmodel import Session, select, SQLModel
from sqlalchemy.exc import UnboundExecutionError, OperationalError
logger = logging.getLogger(__name__)
class AuthDB(DataConnector):
class UserDatabase(Session):
def __init__(self):
super().__init__()
self.__table_name = "user"
if not self._table_exists(self.__table_name):
self.__update_table()
def __update_table(self):
db_data = self.__data_to_db(User())
self._update_table(self.__table_name, db_data)
self._insert(self.__table_name, db_data)
@staticmethod
def __data_to_db(data: User) -> dict:
db_data = data.dict()
db_data["password"] = get_password_hash(db_data["password"])
return db_data
@staticmethod
def __db_to_data(db_data: dict) -> User:
return User(**db_data)
super().__init__(engine)
statement = select(User)
try:
self.exec(statement)
except OperationalError:
SQLModel.metadata.create_all(engine)
self.add(User())
self.commit()
def get_user(self, username):
self._cursor.execute(
f"SELECT * FROM {self.__table_name} WHERE username=?", (username,)
)
result = self._cursor.fetchone()
statement = select(User).where(User.username == username)
result = self.exec(statement).first()
if not result:
return None
db_data = dict(zip([x[0] for x in self._cursor.description], result))
return self.__db_to_data(db_data)
raise HTTPException(status_code=404, detail="User not found")
return result
def auth_user(self, username, password) -> bool:
self._cursor.execute(
f"SELECT username, password FROM {self.__table_name} WHERE username=?",
(username,),
)
result = self._cursor.fetchone()
def auth_user(self, user: UserLogin) -> bool:
statement = select(User).where(User.username == user.username)
result = self.exec(statement).first()
if not result:
raise HTTPException(status_code=401, detail="User not found")
if not verify_password(password, result[1]):
if not verify_password(user.password, result.password):
raise HTTPException(status_code=401, detail="Password error")
return True
def update_user(self, username, update_user: User):
def update_user(self, username, update_user: UserUpdate):
# Update username and password
new_username = update_user.username
new_password = update_user.password
self._cursor.execute(
f"""
UPDATE {self.__table_name}
SET username = '{new_username}', password = '{get_password_hash(new_password)}'
WHERE username = '{username}'
"""
)
self._conn.commit()
statement = select(User).where(User.username == username)
result = self.exec(statement).first()
if not result:
raise HTTPException(status_code=404, detail="User not found")
if update_user.username:
result.username = update_user.username
if update_user.password:
result.password = get_password_hash(update_user.password)
self.add(result)
self.commit()
return result
if __name__ == "__main__":
with AuthDB() as db:
with UserDatabase() as db:
# db.update_user(UserLogin(username="admin", password="adminadmin"), User(username="admin", password="cica1234"))
db.update_user("admin", User(username="estrella", password="cica1234"))
db.update_user("admin", UserUpdate(username="estrella", password="cica1234"))

View File

@@ -1,7 +1,7 @@
import logging
from module.conf import settings
from module.models import BangumiData
from module.models import Bangumi
from .path import TorrentPath
@@ -68,7 +68,7 @@ class DownloadClient(TorrentPath):
prefs = self.client.get_app_prefs()
settings.downloader.path = self._join_path(prefs["save_path"], "Bangumi")
def set_rule(self, data: BangumiData):
def set_rule(self, data: Bangumi):
data.rule_name = self._rule_name(data)
data.save_path = self._gen_save_path(data)
rule = {
@@ -92,7 +92,7 @@ class DownloadClient(TorrentPath):
f"[Downloader] Add {data.official_title} Season {data.season} to auto download rules."
)
def set_rules(self, bangumi_info: list[BangumiData]):
def set_rules(self, bangumi_info: list[Bangumi]):
logger.debug("[Downloader] Start adding rules.")
for info in bangumi_info:
self.set_rule(info)

View File

@@ -4,8 +4,7 @@ import re
from pathlib import Path
from module.conf import settings
from module.models import BangumiData
from module.models import Bangumi
logger = logging.getLogger(__name__)
@@ -50,7 +49,7 @@ class TorrentPath:
return self._file_depth(file_path) <= 2
@staticmethod
def _gen_save_path(data: BangumiData):
def _gen_save_path(data: Bangumi):
folder = (
f"{data.official_title} ({data.year})" if data.year else data.official_title
)
@@ -58,7 +57,7 @@ class TorrentPath:
return str(save_path)
@staticmethod
def _rule_name(data: BangumiData):
def _rule_name(data: Bangumi):
rule_name = (
f"[{data.group_name}] {data.official_title} S{data.season}"
if settings.bangumi_manage.group_tag

View File

@@ -2,14 +2,14 @@ import logging
from module.database import BangumiDatabase
from module.downloader import DownloadClient
from module.models import BangumiData
from module.models import Bangumi
from module.searcher import SearchTorrent
logger = logging.getLogger(__name__)
class SeasonCollector(DownloadClient):
def add_season_torrents(self, data: BangumiData, torrents, torrent_files=None):
def add_season_torrents(self, data: Bangumi, torrents, torrent_files=None):
if torrent_files:
download_info = {
"torrent_files": torrent_files,
@@ -23,7 +23,7 @@ class SeasonCollector(DownloadClient):
}
return self.add_torrent(download_info)
def collect_season(self, data: BangumiData, link: str = None, proxy: bool = False):
def collect_season(self, data: Bangumi, link: str = None, proxy: bool = False):
logger.info(f"Start collecting {data.official_title} Season {data.season}...")
with SearchTorrent() as st:
if not link:
@@ -39,7 +39,7 @@ class SeasonCollector(DownloadClient):
data=data, torrents=torrents, torrent_files=torrent_files
)
def subscribe_season(self, data: BangumiData):
def subscribe_season(self, data: Bangumi):
with BangumiDatabase() as db:
data.added = True
data.eps_collect = True

View File

@@ -4,21 +4,21 @@ from fastapi.responses import JSONResponse
from module.database import BangumiDatabase
from module.downloader import DownloadClient
from module.models import BangumiData
from module.models import Bangumi
logger = logging.getLogger(__name__)
class TorrentManager(BangumiDatabase):
@staticmethod
def __match_torrents_list(data: BangumiData) -> list:
def __match_torrents_list(data: Bangumi) -> list:
with DownloadClient() as client:
torrents = client.get_torrent_info(status_filter=None)
return [
torrent.hash for torrent in torrents if torrent.save_path == data.save_path
]
def delete_torrents(self, data: BangumiData, client: DownloadClient):
def delete_torrents(self, data: Bangumi, client: DownloadClient):
hash_list = self.__match_torrents_list(data)
if hash_list:
client.delete_torrent(hash_list)
@@ -29,7 +29,7 @@ class TorrentManager(BangumiDatabase):
def delete_rule(self, _id: int | str, file: bool = False):
data = self.search_id(int(_id))
if isinstance(data, BangumiData):
if isinstance(data, Bangumi):
with DownloadClient() as client:
client.remove_rule(data.rule_name)
client.remove_rss_feed(data.official_title)
@@ -54,7 +54,7 @@ class TorrentManager(BangumiDatabase):
def disable_rule(self, _id: str | int, file: bool = False):
data = self.search_id(int(_id))
if isinstance(data, BangumiData):
if isinstance(data, Bangumi):
with DownloadClient() as client:
client.remove_rule(data.rule_name)
data.deleted = True
@@ -81,7 +81,7 @@ class TorrentManager(BangumiDatabase):
def enable_rule(self, _id: str | int):
data = self.search_id(int(_id))
if isinstance(data, BangumiData):
if isinstance(data, Bangumi):
data.deleted = False
self.update_one(data)
with DownloadClient() as client:
@@ -98,7 +98,7 @@ class TorrentManager(BangumiDatabase):
status_code=406, content={"msg": f"Can't find bangumi id {_id}"}
)
def update_rule(self, data: BangumiData):
def update_rule(self, data: Bangumi):
old_data = self.search_id(data.id)
if not old_data:
logger.error(f"[Manager] Can't find data with {data.id}")

View File

@@ -1,4 +1,4 @@
from .bangumi import *
from .bangumi import Bangumi, Episode
from .config import Config
from .rss import RSSTorrents
from .torrent import EpisodeFile, SubtitleFile, TorrentBase

View File

@@ -1,27 +1,50 @@
from dataclasses import dataclass
from pydantic import BaseModel, Field
from pydantic import BaseModel
from sqlmodel import SQLModel, Field
from typing import Optional
class BangumiData(BaseModel):
    """Legacy pydantic model for one tracked bangumi (anime) entry.

    Superseded by the SQLModel-backed ``Bangumi`` table in this change.
    The ``title=`` metadata carries the human-readable (Chinese) labels
    surfaced in the generated API schema; do not translate them.
    """

    # NOTE(review): defaults such as "official_title"/"title_raw" look like
    # placeholders rather than meaningful values — confirm against callers.
    id: int = Field(0, alias="id", title="番剧ID")
    official_title: str = Field("official_title", alias="official_title", title="番剧中文名")
    year: str | None = Field(None, alias="year", title="番剧年份")
    title_raw: str = Field("title_raw", alias="title_raw", title="番剧原名")
    season: int = Field(1, alias="season", title="番剧季度")
    season_raw: str | None = Field(None, alias="season_raw", title="番剧季度原名")
    group_name: str | None = Field(None, alias="group_name", title="字幕组")
    dpi: str | None = Field(None, alias="dpi", title="分辨率")
    source: str | None = Field(None, alias="source", title="来源")
    subtitle: str | None = Field(None, alias="subtitle", title="字幕")
    eps_collect: bool = Field(False, alias="eps_collect", title="是否已收集")
    offset: int = Field(0, alias="offset", title="番剧偏移量")
    # Filters/links are real lists here; the SQLModel successor flattens
    # them to comma-separated strings for SQL-column storage.
    filter: list[str] = Field(["720", "\\d+-\\d+"], alias="filter", title="番剧过滤器")
    rss_link: list[str] = Field([], alias="rss_link", title="番剧RSS链接")
    poster_link: str | None = Field(None, alias="poster_link", title="番剧海报链接")
    added: bool = Field(False, alias="added", title="是否已添加")
    rule_name: str | None = Field(None, alias="rule_name", title="番剧规则名")
    save_path: str | None = Field(None, alias="save_path", title="番剧保存路径")
class Bangumi(SQLModel, table=True):
    """SQLModel table holding one tracked bangumi (anime) entry.

    Replaces the hand-made ORM's ``BangumiData``.  ``filter`` and
    ``rss_link`` are flattened from lists to comma-separated strings so
    they map to plain SQL columns.  The ``title=`` metadata carries the
    human-readable (Chinese) labels used in the generated API schema;
    do not translate them.
    """

    # The primary key is assigned by the database, so it is None until the
    # row is persisted — hence Optional[int] per SQLModel convention.
    # (The previous `id: int = Field(default=None, ...)` annotated a
    # non-optional int with a None default.)
    id: Optional[int] = Field(default=None, primary_key=True)
    official_title: str = Field(
        default="official_title", alias="official_title", title="番剧中文名"
    )
    # Optional fields default to None; made explicit for readability
    # (pydantic v1 already implies default=None for Optional annotations).
    year: Optional[str] = Field(default=None, alias="year", title="番剧年份")
    title_raw: str = Field(default="title_raw", alias="title_raw", title="番剧原名")
    season: int = Field(default=1, alias="season", title="番剧季度")
    season_raw: Optional[str] = Field(default=None, alias="season_raw", title="番剧季度原名")
    group_name: Optional[str] = Field(default=None, alias="group_name", title="字幕组")
    dpi: Optional[str] = Field(default=None, alias="dpi", title="分辨率")
    source: Optional[str] = Field(default=None, alias="source", title="来源")
    subtitle: Optional[str] = Field(default=None, alias="subtitle", title="字幕")
    eps_collect: bool = Field(default=False, alias="eps_collect", title="是否已收集")
    offset: int = Field(default=0, alias="offset", title="番剧偏移量")
    # Comma-separated regex fragments; consumers are expected to split/join
    # this string when matching torrent names.
    filter: str = Field(default="720, \\d+-\\d+", alias="filter", title="番剧过滤器")
    rss_link: str = Field(default="", alias="rss_link", title="番剧RSS链接")
    poster_link: Optional[str] = Field(default=None, alias="poster_link", title="番剧海报链接")
    added: bool = Field(default=False, alias="added", title="是否已添加")
    rule_name: Optional[str] = Field(default=None, alias="rule_name", title="番剧规则名")
    save_path: Optional[str] = Field(default=None, alias="save_path", title="番剧保存路径")
    # Soft-delete flag: rows are disabled, not removed (see disable_rule).
    deleted: bool = Field(default=False, alias="deleted", title="是否已删除")
class BangumiUpdate(SQLModel):
    """Non-table payload model for partial bangumi updates.

    Mirrors ``Bangumi`` minus the identity/derived columns (``id``,
    ``title_raw``, ``poster_link``, ``rule_name``, ``save_path``), so the
    API can accept edits without letting clients overwrite them.
    The ``title=`` metadata carries the human-readable (Chinese) labels
    used in the generated API schema; do not translate them.
    """

    official_title: str = Field(
        default="official_title", alias="official_title", title="番剧中文名"
    )
    year: Optional[str] = Field(alias="year", title="番剧年份")
    season: int = Field(default=1, alias="season", title="番剧季度")
    season_raw: Optional[str] = Field(alias="season_raw", title="番剧季度原名")
    group_name: Optional[str] = Field(alias="group_name", title="字幕组")
    dpi: Optional[str] = Field(alias="dpi", title="分辨率")
    source: Optional[str] = Field(alias="source", title="来源")
    subtitle: Optional[str] = Field(alias="subtitle", title="字幕")
    eps_collect: bool = Field(default=False, alias="eps_collect", title="是否已收集")
    offset: int = Field(default=0, alias="offset", title="番剧偏移量")
    # Comma-separated regex fragments used to filter torrent names.
    filter: str = Field(default="720, \\d+-\\d+", alias="filter", title="番剧过滤器")
    rss_link: str = Field(default="", alias="rss_link", title="番剧RSS链接")
    added: bool = Field(default=False, alias="added", title="是否已添加")
    # Soft-delete flag: disables the entry without removing the row.
    deleted: bool = Field(False, alias="deleted", title="是否已删除")
@@ -29,14 +52,14 @@ class Notification(BaseModel):
official_title: str = Field(..., alias="official_title", title="番剧名")
season: int = Field(..., alias="season", title="番剧季度")
episode: int = Field(..., alias="episode", title="番剧集数")
poster_path: str | None = Field(None, alias="poster_path", title="番剧海报路径")
poster_path: Optional[str] = Field(None, alias="poster_path", title="番剧海报路径")
@dataclass
class Episode:
title_en: str | None
title_zh: str | None
title_jp: str | None
title_en: Optional[str]
title_zh: Optional[str]
title_jp: Optional[str]
season: int
season_raw: str
episode: int

View File

@@ -14,7 +14,9 @@ class Downloader(BaseModel):
type: str = Field("qbittorrent", description="Downloader type")
host: str = Field("172.17.0.1:8080", description="Downloader host")
username_: str = Field("admin", alias="username", description="Downloader username")
password_: str = Field("adminadmin", alias="password", description="Downloader password")
password_: str = Field(
"adminadmin", alias="password", description="Downloader password"
)
path: str = Field("/downloads/Bangumi", description="Downloader path")
ssl: bool = Field(False, description="Downloader ssl")
@@ -26,6 +28,7 @@ class Downloader(BaseModel):
def password(self):
return expandvars(self.password_)
class RSSParser(BaseModel):
enable: bool = Field(True, description="Enable RSS parser")
type: str = Field("mikan", description="RSS parser type")
@@ -39,6 +42,7 @@ class RSSParser(BaseModel):
def token(self):
return expandvars(self.token_)
class BangumiManage(BaseModel):
enable: bool = Field(True, description="Enable bangumi manage")
eps_complete: bool = Field(False, description="Enable eps complete")
@@ -82,6 +86,7 @@ class Notification(BaseModel):
def chat_id(self):
return expandvars(self.chat_id_)
class Config(BaseModel):
program: Program = Program()
downloader: Downloader = Downloader()

View File

@@ -1,14 +1,24 @@
from pydantic import BaseModel, Field
from pydantic import BaseModel
from typing import Optional
from sqlmodel import SQLModel, Field
class User(BaseModel):
class User(SQLModel, table=True):
    """SQLModel table holding the panel's login credentials.

    Defaults ("admin"/"adminadmin") seed the initial account; the
    username is restricted to 4-20 word characters.
    """

    # The primary key is assigned by the database, so it is None until the
    # row is persisted — hence Optional[int] per SQLModel convention.
    # (The previous `id: int = Field(default=None, ...)` annotated a
    # non-optional int with a None default.)
    id: Optional[int] = Field(default=None, primary_key=True)
    username: str = Field(
        "admin", min_length=4, max_length=20, regex=r"^[a-zA-Z0-9_]+$"
    )
    # NOTE(review): appears to be stored as given here — confirm the
    # database layer hashes it before persisting.
    password: str = Field("adminadmin", min_length=8)
class UserLogin(BaseModel):
class UserUpdate(SQLModel):
    """Partial-update payload for the ``User`` table.

    Both fields are optional so a request may change the username, the
    password, or both; validation rules match ``User``.
    """

    username: Optional[str] = Field(
        None, min_length=4, max_length=20, regex=r"^[a-zA-Z0-9_]+$"
    )
    password: Optional[str] = Field(None, min_length=8)
class UserLogin(SQLModel):
    """Login request payload: both fields are required (no table row)."""

    username: str
    password: str = Field(..., min_length=8)

View File

@@ -1,7 +1,7 @@
import logging
from module.conf import settings
from module.models import BangumiData
from module.models import Bangumi
from .analyser import raw_parser, tmdb_parser, torrent_parser
@@ -39,7 +39,7 @@ class TitleParser:
return official_title, tmdb_season, year
@staticmethod
def raw_parser(raw: str, rss_link: str) -> BangumiData | None:
def raw_parser(raw: str, rss_link: str) -> Bangumi | None:
language = settings.rss_parser.language
try:
episode = raw_parser(raw)
@@ -60,7 +60,7 @@ class TitleParser:
else:
official_title = title_raw
_season = episode.season
data = BangumiData(
data = Bangumi(
official_title=official_title,
title_raw=title_raw,
season=_season,

View File

@@ -3,7 +3,7 @@ import re
from module.conf import settings
from module.database import BangumiDatabase
from module.models import BangumiData
from module.models import Bangumi
from module.network import RequestContent, TorrentInfo
from module.parser import TitleParser
@@ -16,7 +16,7 @@ class RSSAnalyser:
with BangumiDatabase() as db:
db.update_table()
def official_title_parser(self, data: BangumiData, mikan_title: str):
def official_title_parser(self, data: Bangumi, mikan_title: str):
if settings.rss_parser.parser_type == "mikan":
data.official_title = mikan_title if mikan_title else data.official_title
elif settings.rss_parser.parser_type == "tmdb":
@@ -63,7 +63,7 @@ class RSSAnalyser:
def torrent_to_data(
self, torrent: TorrentInfo, rss_link: str | None = None
) -> BangumiData:
) -> Bangumi:
data = self._title_analyser.raw_parser(raw=torrent.name, rss_link=rss_link)
if data:
try:
@@ -79,7 +79,7 @@ class RSSAnalyser:
def rss_to_data(
self, rss_link: str, database: BangumiDatabase, full_parse: bool = True
) -> list[BangumiData]:
) -> list[Bangumi]:
rss_torrents = self.get_rss_torrents(rss_link, full_parse)
torrents_to_add = database.match_list(rss_torrents, rss_link)
if not torrents_to_add:
@@ -92,7 +92,7 @@ class RSSAnalyser:
else:
return []
def link_to_data(self, link: str) -> BangumiData:
def link_to_data(self, link: str) -> Bangumi:
torrents = self.get_rss_torrents(link, False)
for torrent in torrents:
data = self.torrent_to_data(torrent, link)

View File

@@ -3,7 +3,7 @@ import logging
from module.conf import settings
from module.database import BangumiDatabase
from module.downloader import DownloadClient
from module.models import BangumiData
from module.models import Bangumi
from module.network import RequestContent
logger = logging.getLogger(__name__)
@@ -14,7 +14,7 @@ def matched(torrent_title: str):
return db.match_torrent(torrent_title)
def save_path(data: BangumiData):
def save_path(data: Bangumi):
folder = (
f"{data.official_title}({data.year})" if data.year else f"{data.official_title}"
)

View File

@@ -1,7 +1,7 @@
import re
from module.database import RSSDatabase
from module.models import BangumiData, RSSTorrents
from module.models import Bangumi, RSSTorrents
from module.network import RequestContent, TorrentInfo
@@ -11,7 +11,7 @@ class RSSPoller(RSSDatabase):
return req.get_torrents(rss_link)
@staticmethod
def filter_torrent(data: BangumiData, torrent: TorrentInfo) -> bool:
def filter_torrent(data: Bangumi, torrent: TorrentInfo) -> bool:
if data.title_raw in torrent.name:
_filter = "|".join(data.filter)
if not re.search(_filter, torrent.name):

View File

@@ -1,4 +1,4 @@
from module.models import BangumiData, TorrentBase
from module.models import Bangumi, TorrentBase
from module.network import RequestContent
from module.searcher.plugin import search_url
@@ -30,7 +30,7 @@ class SearchTorrent(RequestContent):
return [TorrentBase(**d) for d in to_dict()]
def search_season(self, data: BangumiData):
def search_season(self, data: Bangumi):
keywords = [getattr(data, key) for key in SEARCH_KEY if getattr(data, key)]
torrents = self.search_torrents(keywords)
return [torrent for torrent in torrents if data.title_raw in torrent.name]

View File

@@ -1,7 +1,7 @@
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from module.database.user import AuthDB
from module.database.user import UserDatabase
from module.models.user import User
from .jwt import verify_token
@@ -20,7 +20,7 @@ async def get_current_user(token: str = Depends(oauth2_scheme)):
status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token"
)
username = payload.get("sub")
with AuthDB() as user_db:
with UserDatabase as user_db:
user = user_db.get_user(username)
if not user:
raise HTTPException(
@@ -40,7 +40,7 @@ async def get_token_data(token: str = Depends(oauth2_scheme)):
def update_user_info(user_data: User, current_user):
try:
with AuthDB() as db:
with UserDatabase as db:
db.update_user(current_user.username, user_data)
return True
except Exception as e:
@@ -48,5 +48,5 @@ def update_user_info(user_data: User, current_user):
def auth_user(username, password):
with AuthDB() as db:
with UserDatabase() as db:
db.auth_user(username, password)

View File

@@ -2,7 +2,7 @@ import os
from module.conf import LEGACY_DATA_PATH
from module.database import BangumiDatabase
from module.models import BangumiData
from module.models import Bangumi
from module.utils import json_config
@@ -14,7 +14,7 @@ def data_migration():
rss_link = old_data["rss_link"]
new_data = []
for info in infos:
new_data.append(BangumiData(**info, rss_link=[rss_link]))
new_data.append(Bangumi(**info, rss_link=[rss_link]))
with BangumiDatabase() as database:
database.update_table()
database.insert_list(new_data)

View File

@@ -1,11 +1,17 @@
from sqlmodel import create_engine, SQLModel
from sqlmodel.pool import StaticPool
from module.database import BangumiDatabase
from module.models import BangumiData
from module.models import Bangumi
def test_database():
TEST_PATH = "test/test.db"
test_data = BangumiData(
id=1,
def test_bangumi_database():
# sqlite mock engine
engine = create_engine(
"sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool
)
SQLModel.metadata.create_all(engine)
test_data = Bangumi(
official_title="test",
year="2021",
title_raw="test",
@@ -17,18 +23,15 @@ def test_database():
subtitle="test",
eps_collect=False,
offset=0,
filter=["720p", "\\d+-\\d+"],
rss_link=["test"],
filter="720p,\\d+-\\d+",
rss_link="test",
poster_link="/test/test.jpg",
added=False,
rule_name=None,
save_path=None,
deleted=False,
)
with BangumiDatabase(database=TEST_PATH) as database:
# create table
database.update_table()
with BangumiDatabase(database=TEST_PATH) as database:
with BangumiDatabase(engine) as database:
# insert
database.insert_one(test_data)
assert database.search_id(1) == test_data
@@ -39,13 +42,8 @@ def test_database():
assert database.search_id(1) == test_data
# search poster
assert database.match_poster("test") == "/test/test.jpg"
assert database.match_poster("test2 (2021)") == "/test/test.jpg"
# delete
database.delete_one(1)
assert database.search_id(1) is None
# Delete test database
import os
os.remove(TEST_PATH)