mirror of
https://github.com/EstrellaXD/Auto_Bangumi.git
synced 2026-04-15 02:50:42 +08:00
Merge pull request #411 from EstrellaXD/orm-change
feat: Change orm from hand-made code to SQLModel
This commit is contained in:
@@ -23,3 +23,4 @@ python-jose==3.3.0
|
||||
passlib==1.7.4
|
||||
bcrypt==4.0.1
|
||||
python-multipart==0.0.6
|
||||
sqlmodel
|
||||
|
||||
@@ -2,13 +2,13 @@ from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from module.manager import TorrentManager
|
||||
from module.models import BangumiData
|
||||
from module.models import Bangumi
|
||||
from module.security import get_current_user
|
||||
|
||||
router = APIRouter(prefix="/bangumi", tags=["bangumi"])
|
||||
|
||||
|
||||
@router.get("/getAll", response_model=list[BangumiData])
|
||||
@router.get("/getAll", response_model=list[Bangumi])
|
||||
async def get_all_data(current_user=Depends(get_current_user)):
|
||||
if not current_user:
|
||||
raise HTTPException(
|
||||
@@ -18,7 +18,7 @@ async def get_all_data(current_user=Depends(get_current_user)):
|
||||
return torrent.search_all()
|
||||
|
||||
|
||||
@router.get("/getData/{bangumi_id}", response_model=BangumiData)
|
||||
@router.get("/getData/{bangumi_id}", response_model=Bangumi)
|
||||
async def get_data(bangumi_id: str, current_user=Depends(get_current_user)):
|
||||
if not current_user:
|
||||
raise HTTPException(
|
||||
@@ -29,7 +29,7 @@ async def get_data(bangumi_id: str, current_user=Depends(get_current_user)):
|
||||
|
||||
|
||||
@router.post("/updateRule")
|
||||
async def update_rule(data: BangumiData, current_user=Depends(get_current_user)):
|
||||
async def update_rule(data: Bangumi, current_user=Depends(get_current_user)):
|
||||
if not current_user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token"
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
|
||||
from module.manager import SeasonCollector
|
||||
from module.models import BangumiData
|
||||
from module.models import Bangumi
|
||||
from module.models.api import RssLink
|
||||
from module.rss import analyser
|
||||
from module.security import get_current_user
|
||||
@@ -23,9 +23,7 @@ async def analysis(link: RssLink, current_user=Depends(get_current_user)):
|
||||
|
||||
|
||||
@router.post("/collection")
|
||||
async def download_collection(
|
||||
data: BangumiData, current_user=Depends(get_current_user)
|
||||
):
|
||||
async def download_collection(data: Bangumi, current_user=Depends(get_current_user)):
|
||||
if not current_user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token"
|
||||
@@ -41,7 +39,7 @@ async def download_collection(
|
||||
|
||||
|
||||
@router.post("/subscribe")
|
||||
async def subscribe(data: BangumiData, current_user=Depends(get_current_user)):
|
||||
async def subscribe(data: Bangumi, current_user=Depends(get_current_user)):
|
||||
if not current_user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token"
|
||||
|
||||
@@ -4,7 +4,7 @@ from .config import VERSION, settings
|
||||
from .log import LOG_PATH, setup_logger
|
||||
|
||||
TMDB_API = "32b19d6a05b512190a056fa4e747cbbc"
|
||||
DATA_PATH = Path("data/data.db")
|
||||
DATA_PATH = "sqlite:///data/data.db"
|
||||
LEGACY_DATA_PATH = Path("data/data.json")
|
||||
|
||||
PLATFORM = "Windows" if "\\" in settings.downloader.path else "Unix"
|
||||
|
||||
@@ -1,121 +1,102 @@
|
||||
import logging
|
||||
|
||||
from module.database.orm import Connector
|
||||
from module.models import BangumiData
|
||||
from module.conf import DATA_PATH
|
||||
from sqlmodel import Session, select, delete, or_
|
||||
from sqlalchemy.sql import func
|
||||
from typing import Optional
|
||||
|
||||
from .engine import engine
|
||||
from module.models import Bangumi
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BangumiDatabase(Connector):
|
||||
def __init__(self, database: str = DATA_PATH):
|
||||
super().__init__(
|
||||
table_name="bangumi",
|
||||
data=self.__data_to_db(BangumiData()),
|
||||
database=database,
|
||||
)
|
||||
class BangumiDatabase(Session):
|
||||
def __init__(self, _engine=engine):
|
||||
super().__init__(_engine)
|
||||
|
||||
def update_table(self):
|
||||
self.update.table()
|
||||
|
||||
@staticmethod
|
||||
def __data_to_db(data: BangumiData) -> dict:
|
||||
db_data = data.dict()
|
||||
for key, value in db_data.items():
|
||||
if isinstance(value, bool):
|
||||
db_data[key] = int(value)
|
||||
elif isinstance(value, list):
|
||||
db_data[key] = ",".join(value)
|
||||
return db_data
|
||||
|
||||
@staticmethod
|
||||
def __db_to_data(db_data: dict) -> BangumiData:
|
||||
for key, item in db_data.items():
|
||||
if isinstance(item, int):
|
||||
if key not in ["id", "offset", "season", "year"]:
|
||||
db_data[key] = bool(item)
|
||||
elif key in ["filter", "rss_link"]:
|
||||
db_data[key] = item.split(",")
|
||||
return BangumiData(**db_data)
|
||||
|
||||
def insert_one(self, data: BangumiData):
|
||||
db_data = self.__data_to_db(data)
|
||||
self.insert.one(db_data)
|
||||
def insert_one(self, data: Bangumi):
|
||||
self.add(data)
|
||||
self.commit()
|
||||
logger.debug(f"[Database] Insert {data.official_title} into database.")
|
||||
# if self.__check_exist(data):
|
||||
# self.update_one(data)
|
||||
# else:
|
||||
# db_data = self.__data_to_db(data)
|
||||
# db_data["id"] = self.gen_id()
|
||||
# self._insert(db_data=db_data, table_name=self.__table_name)
|
||||
# logger.debug(f"[Database] Insert {data.official_title} into database.")
|
||||
|
||||
def insert_list(self, data: list[BangumiData]):
|
||||
data_list = [self.__data_to_db(x) for x in data]
|
||||
self.insert.many(data_list)
|
||||
# _id = self.gen_id()
|
||||
# for i, item in enumerate(data):
|
||||
# item.id = _id + i
|
||||
# data_list = [self.__data_to_db(x) for x in data]
|
||||
# self._insert_list(data_list=data_list, table_name=self.__table_name)
|
||||
def insert_list(self, data: list[Bangumi]):
|
||||
self.add_all(data)
|
||||
logger.debug(f"[Database] Insert {len(data)} bangumi into database.")
|
||||
|
||||
def update_one(self, data: BangumiData) -> bool:
|
||||
db_data = self.__data_to_db(data)
|
||||
return self.update.one(db_data)
|
||||
def update_one(self, data: Bangumi) -> bool:
|
||||
db_data = self.get(Bangumi, data.id)
|
||||
if not db_data:
|
||||
return False
|
||||
bangumi_data = data.dict(exclude_unset=True)
|
||||
for key, value in bangumi_data.items():
|
||||
setattr(db_data, key, value)
|
||||
self.add(db_data)
|
||||
self.commit()
|
||||
self.refresh(db_data)
|
||||
logger.debug(f"[Database] Update {data.official_title}")
|
||||
return True
|
||||
|
||||
def update_list(self, data: list[BangumiData]):
|
||||
data_list = [self.__data_to_db(x) for x in data]
|
||||
self.update.many(data_list)
|
||||
def update_list(self, datas: list[Bangumi]):
|
||||
for data in datas:
|
||||
self.update_one(data)
|
||||
|
||||
def update_rss(self, title_raw, rss_set: str):
|
||||
# Update rss and added
|
||||
location = {"title_raw": title_raw}
|
||||
set_value = {"rss_link": rss_set, "added": 0}
|
||||
self.update.value(location, set_value)
|
||||
statement = select(Bangumi).where(Bangumi.title_raw == title_raw)
|
||||
bangumi = self.exec(statement).first()
|
||||
bangumi.rss_link = rss_set
|
||||
bangumi.added = False
|
||||
self.add(bangumi)
|
||||
self.commit()
|
||||
self.refresh(bangumi)
|
||||
logger.debug(f"[Database] Update {title_raw} rss_link to {rss_set}.")
|
||||
|
||||
def update_poster(self, title_raw, poster_link: str):
|
||||
location = {"title_raw": title_raw}
|
||||
set_value = {"poster_link": poster_link}
|
||||
self.update.value(location, set_value)
|
||||
statement = select(Bangumi).where(Bangumi.title_raw == title_raw)
|
||||
bangumi = self.exec(statement).first()
|
||||
bangumi.poster_link = poster_link
|
||||
self.add(bangumi)
|
||||
self.commit()
|
||||
self.refresh(bangumi)
|
||||
logger.debug(f"[Database] Update {title_raw} poster_link to {poster_link}.")
|
||||
|
||||
def delete_one(self, _id: int):
|
||||
self.delete.one(_id)
|
||||
statement = select(Bangumi).where(Bangumi.id == _id)
|
||||
bangumi = self.exec(statement).first()
|
||||
self.delete(bangumi)
|
||||
self.commit()
|
||||
logger.debug(f"[Database] Delete bangumi id: {_id}.")
|
||||
|
||||
def delete_all(self):
|
||||
self.delete.all()
|
||||
statement = delete(Bangumi)
|
||||
self.exec(statement)
|
||||
self.commit()
|
||||
|
||||
def search_all(self) -> list[BangumiData]:
|
||||
all_data = self.select.all()
|
||||
return [self.__db_to_data(x) for x in all_data]
|
||||
def search_all(self) -> list[Bangumi]:
|
||||
statement = select(Bangumi)
|
||||
return self.exec(statement).all()
|
||||
|
||||
def search_id(self, _id: int) -> BangumiData | None:
|
||||
dict_data = self.select.one(conditions={"id": _id})
|
||||
if dict_data is None:
|
||||
def search_id(self, _id: int) -> Optional[Bangumi]:
|
||||
statement = select(Bangumi).where(Bangumi.id == _id)
|
||||
bangumi = self.exec(statement).first()
|
||||
if bangumi is None:
|
||||
logger.warning(f"[Database] Cannot find bangumi id: {_id}.")
|
||||
return None
|
||||
logger.debug(f"[Database] Find bangumi id: {_id}.")
|
||||
return self.__db_to_data(dict_data)
|
||||
else:
|
||||
logger.debug(f"[Database] Find bangumi id: {_id}.")
|
||||
return self.exec(statement).first()
|
||||
|
||||
def match_poster(self, bangumi_name: str) -> str:
|
||||
condition = {"official_title": bangumi_name}
|
||||
keys = ["poster_link"]
|
||||
data = self.select.one(
|
||||
keys=keys,
|
||||
conditions=condition,
|
||||
combine_operator="INSTR",
|
||||
)
|
||||
if not data:
|
||||
# Use like to match
|
||||
statement = select(Bangumi).where(func.instr(bangumi_name, Bangumi.title_raw) > 0)
|
||||
data = self.exec(statement).first()
|
||||
if data:
|
||||
return data.poster_link
|
||||
else:
|
||||
return ""
|
||||
return data.get("poster_link")
|
||||
|
||||
def match_list(self, torrent_list: list, rss_link: str) -> list:
|
||||
# Match title_raw in database
|
||||
keys = ["title_raw", "rss_link", "poster_link"]
|
||||
match_datas = self.select.column(keys)
|
||||
match_datas = self.search_all()
|
||||
if not match_datas:
|
||||
return torrent_list
|
||||
# Match title
|
||||
@@ -123,36 +104,43 @@ class BangumiDatabase(Connector):
|
||||
while i < len(torrent_list):
|
||||
torrent = torrent_list[i]
|
||||
for match_data in match_datas:
|
||||
if match_data.get("title_raw") in torrent.name:
|
||||
if rss_link not in match_data.get("rss_link"):
|
||||
match_data["rss_link"] += f",{rss_link}"
|
||||
self.update_rss(
|
||||
match_data.get("title_raw"), match_data.get("rss_link")
|
||||
)
|
||||
if not match_data.get("poster_link"):
|
||||
self.update_poster(
|
||||
match_data.get("title_raw"), torrent.poster_link
|
||||
)
|
||||
if match_data.title_raw in torrent.name:
|
||||
if rss_link not in match_data.rss_link:
|
||||
match_data.rss_link += f",{rss_link}"
|
||||
self.update_rss(match_data.title_raw, match_data.rss_link)
|
||||
if not match_data.poster_link:
|
||||
self.update_poster(match_data.title_raw, torrent.poster_link)
|
||||
torrent_list.pop(i)
|
||||
break
|
||||
else:
|
||||
i += 1
|
||||
return torrent_list
|
||||
|
||||
def not_complete(self) -> list[BangumiData]:
|
||||
def not_complete(self) -> list[Bangumi]:
|
||||
# Find eps_complete = False
|
||||
condition = {"eps_collect": 0}
|
||||
dict_data = self.select.many(
|
||||
conditions=condition,
|
||||
)
|
||||
return [self.__db_to_data(x) for x in dict_data]
|
||||
condition = select(Bangumi).where(Bangumi.eps_collect == 0)
|
||||
datas = self.exec(condition).all()
|
||||
return datas
|
||||
|
||||
def not_added(self) -> list[BangumiData]:
|
||||
conditions = {"added": 0, "rule_name": None, "save_path": None}
|
||||
dict_data = self.select.many(conditions=conditions, combine_operator="OR")
|
||||
return [self.__db_to_data(x) for x in dict_data]
|
||||
def not_added(self) -> list[Bangumi]:
|
||||
conditions = select(Bangumi).where(
|
||||
or_(
|
||||
Bangumi.added == 0, Bangumi.rule_name is None, Bangumi.save_path is None
|
||||
)
|
||||
)
|
||||
datas = self.exec(conditions).all()
|
||||
return datas
|
||||
|
||||
def disable_rule(self, _id: int):
|
||||
statement = select(Bangumi).where(Bangumi.id == _id)
|
||||
bangumi = self.exec(statement).first()
|
||||
bangumi.deleted = True
|
||||
self.add(bangumi)
|
||||
self.commit()
|
||||
self.refresh(bangumi)
|
||||
logger.debug(f"[Database] Disable rule {bangumi.title_raw}.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
with BangumiDatabase() as db:
|
||||
print(db.match_poster("久保"))
|
||||
print(db.not_complete())
|
||||
|
||||
@@ -1,174 +0,0 @@
|
||||
import logging
|
||||
import os
|
||||
import sqlite3
|
||||
|
||||
from module.conf import DATA_PATH
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DataConnector:
|
||||
def __init__(self):
|
||||
# Create folder if not exists
|
||||
DATA_PATH.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self._conn = sqlite3.connect(DATA_PATH)
|
||||
self._cursor = self._conn.cursor()
|
||||
|
||||
def _update_table(self, table_name: str, db_data: dict):
|
||||
columns = ", ".join(
|
||||
[
|
||||
f"{key} {self.__python_to_sqlite_type(value)}"
|
||||
for key, value in db_data.items()
|
||||
]
|
||||
)
|
||||
create_table_sql = f"CREATE TABLE IF NOT EXISTS {table_name} ({columns});"
|
||||
self._cursor.execute(create_table_sql)
|
||||
self._cursor.execute(f"PRAGMA table_info({table_name})")
|
||||
existing_columns = {
|
||||
column_info[1]: column_info for column_info in self._cursor.fetchall()
|
||||
}
|
||||
for key, value in db_data.items():
|
||||
if key not in existing_columns:
|
||||
insert_column = self.__python_to_sqlite_type(value)
|
||||
if value is None:
|
||||
value = "NULL"
|
||||
add_column_sql = f"ALTER TABLE {table_name} ADD COLUMN {key} {insert_column} DEFAULT {value};"
|
||||
self._cursor.execute(add_column_sql)
|
||||
self._conn.commit()
|
||||
logger.debug(f"Create / Update table {table_name}.")
|
||||
|
||||
def _insert(self, table_name: str, db_data: dict):
|
||||
columns = ", ".join(db_data.keys())
|
||||
values = ", ".join([f":{key}" for key in db_data.keys()])
|
||||
self._cursor.execute(
|
||||
f"INSERT INTO {table_name} ({columns}) VALUES ({values})", db_data
|
||||
)
|
||||
self._conn.commit()
|
||||
|
||||
def _insert_list(self, table_name: str, data_list: list[dict]):
|
||||
columns = ", ".join(data_list[0].keys())
|
||||
values = ", ".join([f":{key}" for key in data_list[0].keys()])
|
||||
self._cursor.executemany(
|
||||
f"INSERT INTO {table_name} ({columns}) VALUES ({values})", data_list
|
||||
)
|
||||
self._conn.commit()
|
||||
|
||||
def _select(self, keys: list[str], table_name: str, condition: str = None) -> dict:
|
||||
if condition is None:
|
||||
self._cursor.execute(f"SELECT {', '.join(keys)} FROM {table_name}")
|
||||
else:
|
||||
self._cursor.execute(
|
||||
f"SELECT {', '.join(keys)} FROM {table_name} WHERE {condition}"
|
||||
)
|
||||
return dict(zip(keys, self._cursor.fetchone()))
|
||||
|
||||
def _update(self, table_name: str, db_data: dict):
|
||||
_id = db_data.get("id")
|
||||
if _id is None:
|
||||
raise ValueError("No _id in db_data.")
|
||||
set_sql = ", ".join([f"{key} = :{key}" for key in db_data.keys()])
|
||||
self._cursor.execute(
|
||||
f"UPDATE {table_name} SET {set_sql} WHERE id = {_id}", db_data
|
||||
)
|
||||
self._conn.commit()
|
||||
return self._cursor.rowcount == 1
|
||||
|
||||
def _update_list(self, table_name: str, data_list: list[dict]):
|
||||
if len(data_list) == 0:
|
||||
return
|
||||
set_sql = ", ".join(
|
||||
[f"{key} = :{key}" for key in data_list[0].keys() if key != "id"]
|
||||
)
|
||||
self._cursor.executemany(
|
||||
f"UPDATE {table_name} SET {set_sql} WHERE id = :id", data_list
|
||||
)
|
||||
self._conn.commit()
|
||||
|
||||
def _update_section(self, table_name: str, location: dict, update_dict: dict):
|
||||
set_sql = ", ".join([f"{key} = :{key}" for key in update_dict.keys()])
|
||||
sql_loc = f"{location['key']} = {location['value']}"
|
||||
self._cursor.execute(
|
||||
f"UPDATE {table_name} SET {set_sql} WHERE {sql_loc}", update_dict
|
||||
)
|
||||
self._conn.commit()
|
||||
|
||||
def _delete_all(self, table_name: str):
|
||||
self._cursor.execute(f"DELETE FROM {table_name}")
|
||||
self._conn.commit()
|
||||
|
||||
def _delete(self, table_name: str, condition: dict):
|
||||
condition_sql = " AND ".join([f"{key} = :{key}" for key in condition.keys()])
|
||||
self._cursor.execute(
|
||||
f"DELETE FROM {table_name} WHERE {condition_sql}", condition
|
||||
)
|
||||
self._conn.commit()
|
||||
|
||||
def _search(
|
||||
self, table_name: str, keys: list[str] | None = None, condition: dict = None
|
||||
):
|
||||
if keys is None:
|
||||
select_sql = "*"
|
||||
else:
|
||||
select_sql = ", ".join(keys)
|
||||
if condition is None:
|
||||
self._cursor.execute(f"SELECT {select_sql} FROM {table_name}")
|
||||
else:
|
||||
custom_condition = condition.pop("_custom_condition", None)
|
||||
condition_sql = " AND ".join(
|
||||
[f"{key} = :{key}" for key in condition.keys()]
|
||||
) + (f" AND {custom_condition}" if custom_condition else "")
|
||||
self._cursor.execute(
|
||||
f"SELECT {select_sql} FROM {table_name} WHERE {condition_sql}",
|
||||
condition,
|
||||
)
|
||||
|
||||
def _search_data(
|
||||
self, table_name: str, keys: list[str] | None = None, condition: dict = None
|
||||
) -> dict:
|
||||
if keys is None:
|
||||
keys = self.__get_table_columns(table_name)
|
||||
self._search(table_name, keys, condition)
|
||||
return dict(zip(keys, self._cursor.fetchone()))
|
||||
|
||||
def _search_datas(
|
||||
self, table_name: str, keys: list[str] | None = None, condition: dict = None
|
||||
) -> list[dict]:
|
||||
if keys is None:
|
||||
keys = self.__get_table_columns(table_name)
|
||||
self._search(table_name, keys, condition)
|
||||
return [dict(zip(keys, row)) for row in self._cursor.fetchall()]
|
||||
|
||||
def _table_exists(self, table_name: str) -> bool:
|
||||
self._cursor.execute(
|
||||
"SELECT name FROM sqlite_master WHERE type='table' AND name=?;",
|
||||
(table_name,),
|
||||
)
|
||||
return len(self._cursor.fetchall()) == 1
|
||||
|
||||
def __get_table_columns(self, table_name: str) -> list[str]:
|
||||
self._cursor.execute(f"PRAGMA table_info({table_name})")
|
||||
return [column_info[1] for column_info in self._cursor.fetchall()]
|
||||
|
||||
@staticmethod
|
||||
def __python_to_sqlite_type(value) -> str:
|
||||
if isinstance(value, int):
|
||||
return "INTEGER NOT NULL"
|
||||
elif isinstance(value, float):
|
||||
return "REAL NOT NULL"
|
||||
elif isinstance(value, str):
|
||||
return "TEXT NOT NULL"
|
||||
elif isinstance(value, bool):
|
||||
return "INTEGER NOT NULL"
|
||||
elif isinstance(value, list):
|
||||
return "TEXT NOT NULL"
|
||||
elif value is None:
|
||||
return "TEXT"
|
||||
else:
|
||||
raise ValueError(f"Unsupported data type: {type(value)}")
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self._conn.close()
|
||||
7
backend/src/module/database/engine.py
Normal file
7
backend/src/module/database/engine.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from sqlmodel import create_engine, Session
|
||||
from module.conf import DATA_PATH
|
||||
|
||||
|
||||
engine = create_engine(DATA_PATH)
|
||||
|
||||
db_session = Session(engine)
|
||||
@@ -1 +0,0 @@
|
||||
from .connector import Connector
|
||||
@@ -1,71 +0,0 @@
|
||||
from os import PathLike
|
||||
from pathlib import Path
|
||||
import sqlite3
|
||||
|
||||
from .delete import Delete
|
||||
from .insert import Insert
|
||||
from .select import Select
|
||||
from .update import Update
|
||||
|
||||
from module.conf import DATA_PATH
|
||||
|
||||
|
||||
class Connector:
|
||||
def __init__(
|
||||
self, table_name: str, data: dict, database: PathLike[str] | Path = DATA_PATH
|
||||
):
|
||||
# Create folder if not exists
|
||||
if isinstance(database, (PathLike, str)):
|
||||
database = Path(database)
|
||||
database.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self._conn = sqlite3.connect(database)
|
||||
self._cursor = self._conn.cursor()
|
||||
self.update = Update(self, table_name, data)
|
||||
self.insert = Insert(self, table_name, data)
|
||||
self.select = Select(self, table_name, data)
|
||||
self.delete = Delete(self, table_name, data)
|
||||
self._columns = self.__get_columns(table_name)
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
self._conn.close()
|
||||
|
||||
def __get_columns(self, table_name: str) -> list[str]:
|
||||
self._cursor.execute(f"PRAGMA table_info({table_name})")
|
||||
return [x[1] for x in self._cursor.fetchall()]
|
||||
|
||||
def execute(self, sql: str, params: tuple = None):
|
||||
if params is None:
|
||||
self._cursor.execute(sql)
|
||||
else:
|
||||
self._cursor.execute(sql, params)
|
||||
self._conn.commit()
|
||||
|
||||
def executemany(self, sql: str, params: list[tuple]):
|
||||
self._cursor.executemany(sql, params)
|
||||
self._conn.commit()
|
||||
|
||||
def fetchall(self, keys: str = None) -> list[dict]:
|
||||
datas = self._cursor.fetchall()
|
||||
if keys:
|
||||
return [dict(zip(keys, data)) for data in datas]
|
||||
return [dict(zip(self._columns, data)) for data in datas]
|
||||
|
||||
def fetchone(self, keys: list[str] = None) -> dict:
|
||||
data = self._cursor.fetchone()
|
||||
if data:
|
||||
if keys:
|
||||
return dict(zip(keys, data))
|
||||
return dict(zip(self._columns, data))
|
||||
|
||||
def fetchmany(self, keys: list[str], size: int) -> list[dict]:
|
||||
datas = self._cursor.fetchmany(size)
|
||||
if keys:
|
||||
return [dict(zip(keys, data)) for data in datas]
|
||||
return [dict(zip(self._columns, data)) for data in datas]
|
||||
|
||||
def fetch(self):
|
||||
return self._cursor.fetchall()
|
||||
@@ -1,23 +0,0 @@
|
||||
class Delete:
|
||||
def __init__(self, connector, table_name: str, data: dict):
|
||||
self._connector = connector
|
||||
self._table_name = table_name
|
||||
self._data = data
|
||||
|
||||
def one(self, _id: int) -> bool:
|
||||
self._connector.execute(
|
||||
f"""
|
||||
DELETE FROM {self._table_name}
|
||||
WHERE id = :id
|
||||
""",
|
||||
{"id": _id},
|
||||
)
|
||||
return True
|
||||
|
||||
def all(self):
|
||||
self._connector.execute(
|
||||
f"""
|
||||
DELETE FROM {self._table_name}
|
||||
""",
|
||||
)
|
||||
return True
|
||||
@@ -1,33 +0,0 @@
|
||||
class Insert:
|
||||
def __init__(self, connector, table_name: str, data: dict):
|
||||
self._connector = connector
|
||||
self._table_name = table_name
|
||||
self._columns = data.items()
|
||||
|
||||
def __gen_id(self) -> int:
|
||||
self._connector.execute(
|
||||
f"""
|
||||
SELECT MAX(id) FROM {self._table_name}
|
||||
""",
|
||||
)
|
||||
max_id = self._connector.fetchone(keys=["id"]).get("id")
|
||||
if max_id is None:
|
||||
return 1
|
||||
return max_id + 1
|
||||
|
||||
def one(self, data: dict):
|
||||
_id = self.__gen_id()
|
||||
data["id"] = _id
|
||||
columns = ", ".join(data.keys())
|
||||
placeholders = ", ".join([f":{key}" for key in data.keys()])
|
||||
self._connector.execute(
|
||||
f"""
|
||||
INSERT INTO {self._table_name} ({columns})
|
||||
VALUES ({placeholders})
|
||||
""",
|
||||
data,
|
||||
)
|
||||
|
||||
def many(self, data: list[dict]):
|
||||
for item in data:
|
||||
self.one(item)
|
||||
@@ -1,96 +0,0 @@
|
||||
class Select:
|
||||
def __init__(self, connector, table_name: str, data: dict):
|
||||
self._connector = connector
|
||||
self._table_name = table_name
|
||||
self._data = data
|
||||
|
||||
def id(self, _id: int):
|
||||
self._connector.execute(
|
||||
f"""
|
||||
SELECT * FROM {self._table_name}
|
||||
WHERE id = :id
|
||||
""",
|
||||
{"id": _id},
|
||||
)
|
||||
return self._connector.fetchone()
|
||||
|
||||
def all(self, limit: int = None):
|
||||
if limit is None:
|
||||
limit = 10000
|
||||
self._connector.execute(
|
||||
f"""
|
||||
SELECT * FROM {self._table_name} LIMIT {limit}
|
||||
""",
|
||||
)
|
||||
return self._connector.fetchall()
|
||||
|
||||
def one(
|
||||
self,
|
||||
keys: list[str] | None = None,
|
||||
conditions: dict = None,
|
||||
combine_operator: str = "AND",
|
||||
):
|
||||
if keys is None:
|
||||
columns = "*"
|
||||
else:
|
||||
columns = ", ".join(keys)
|
||||
condition_sql = self.__select_condition(conditions, combine_operator)
|
||||
self._connector.execute(
|
||||
f"""
|
||||
SELECT {columns} FROM {self._table_name}
|
||||
WHERE {condition_sql}
|
||||
""",
|
||||
conditions,
|
||||
)
|
||||
return self._connector.fetchone(keys)
|
||||
|
||||
def many(
|
||||
self,
|
||||
keys: list[str] | None = None,
|
||||
conditions: dict = None,
|
||||
combine_operator: str = "AND",
|
||||
limit: int = None,
|
||||
):
|
||||
if keys is None:
|
||||
columns = "*"
|
||||
else:
|
||||
columns = ", ".join(keys)
|
||||
if limit is None:
|
||||
limit = 10000
|
||||
condition_sql = self.__select_condition(conditions, combine_operator)
|
||||
self._connector.execute(
|
||||
f"""
|
||||
SELECT {columns} FROM {self._table_name}
|
||||
WHERE {condition_sql}
|
||||
LIMIT {limit}
|
||||
""",
|
||||
conditions,
|
||||
)
|
||||
return self._connector.fetchall(keys)
|
||||
|
||||
def column(self, keys: list[str]):
|
||||
columns = ", ".join(keys)
|
||||
self._connector.execute(
|
||||
f"""
|
||||
SELECT {columns} FROM {self._table_name}
|
||||
""",
|
||||
)
|
||||
return self._connector.fetchall(keys)
|
||||
|
||||
@staticmethod
|
||||
def __select_condition(conditions: dict, combine_operator: str = "AND"):
|
||||
if not conditions:
|
||||
raise ValueError("No conditions provided.")
|
||||
if combine_operator not in ["AND", "OR", "INSTR"]:
|
||||
raise ValueError(
|
||||
"Invalid combine_operator, must be 'AND' or 'OR' or 'INSTR'."
|
||||
)
|
||||
if combine_operator == "INSTR":
|
||||
condition_sql = f" AND ".join(
|
||||
[f"INSTR({key}, :{key})" for key in conditions.keys()]
|
||||
)
|
||||
else:
|
||||
condition_sql = f" {combine_operator} ".join(
|
||||
[f"{key} = :{key}" for key in conditions.keys()]
|
||||
)
|
||||
return condition_sql
|
||||
@@ -1,98 +0,0 @@
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Update:
|
||||
def __init__(self, connector, table_name: str, data: dict):
|
||||
self._connector = connector
|
||||
self._table_name = table_name
|
||||
self._example_data = data
|
||||
|
||||
def __table_exists(self) -> bool:
|
||||
self._connector.execute(
|
||||
f"""
|
||||
SELECT name FROM sqlite_master
|
||||
WHERE type='table' AND name='{self._table_name}'
|
||||
"""
|
||||
)
|
||||
return self._connector.fetch() is not None
|
||||
|
||||
def table(self):
|
||||
columns = ", ".join(
|
||||
[
|
||||
f"{key} {self.__python_to_sqlite_type(value)}"
|
||||
for key, value in self._example_data.items()
|
||||
]
|
||||
)
|
||||
create_table_sql = f"CREATE TABLE IF NOT EXISTS {self._table_name} ({columns});"
|
||||
self._connector.execute(create_table_sql)
|
||||
logger.debug(f"Create table {self._table_name}.")
|
||||
self._connector.execute(f"PRAGMA table_info({self._table_name})")
|
||||
existing_columns = [x[1] for x in self._connector.fetch()]
|
||||
for key, value in self._example_data.items():
|
||||
if key not in existing_columns:
|
||||
insert_column = self.__python_to_sqlite_type(value)
|
||||
if value is None:
|
||||
value = "NULL"
|
||||
add_column_sql = f"ALTER TABLE {self._table_name} ADD COLUMN {key} {insert_column} DEFAULT {value};"
|
||||
self._connector.execute(add_column_sql)
|
||||
logger.debug(f"Update table {self._table_name}.")
|
||||
|
||||
def one(self, data: dict) -> bool:
|
||||
_id = data["id"]
|
||||
set_sql = ", ".join([f"{key} = :{key}" for key in data.keys()])
|
||||
self._connector.execute(
|
||||
f"""
|
||||
UPDATE {self._table_name}
|
||||
SET {set_sql}
|
||||
WHERE id = :id
|
||||
""",
|
||||
data,
|
||||
)
|
||||
logger.debug(f"Update {_id} in {self._table_name}.")
|
||||
return True
|
||||
|
||||
def many(self, data: list[dict]) -> bool:
|
||||
columns = ", ".join([f"{key} = :{key}" for key in data[0].keys()])
|
||||
self._connector.executemany(
|
||||
f"""
|
||||
UPDATE {self._table_name}
|
||||
SET {columns}
|
||||
WHERE id = :id
|
||||
""",
|
||||
data,
|
||||
)
|
||||
logger.debug(f"Update {self._table_name}.")
|
||||
return True
|
||||
|
||||
def value(self, location: dict, set_value: dict) -> bool:
|
||||
set_sql = ", ".join([f"{key} = :{key}" for key in set_value.keys()])
|
||||
params = {**location, **set_value}
|
||||
self._connector.execute(
|
||||
f"""
|
||||
UPDATE {self._table_name}
|
||||
SET {set_sql}
|
||||
WHERE {location["key"]} = :{location["key"]}
|
||||
""",
|
||||
params,
|
||||
)
|
||||
logger.debug(f"Update {self._table_name}.")
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def __python_to_sqlite_type(value) -> str:
|
||||
if isinstance(value, int):
|
||||
return "INTEGER NOT NULL"
|
||||
elif isinstance(value, float):
|
||||
return "REAL NOT NULL"
|
||||
elif isinstance(value, str):
|
||||
return "TEXT NOT NULL"
|
||||
elif isinstance(value, bool):
|
||||
return "INTEGER NOT NULL"
|
||||
elif isinstance(value, list):
|
||||
return "TEXT NOT NULL"
|
||||
elif value is None:
|
||||
return "TEXT"
|
||||
else:
|
||||
raise ValueError(f"Unsupported data type: {type(value)}")
|
||||
@@ -2,72 +2,58 @@ import logging
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from module.database.connector import DataConnector
|
||||
from module.models.user import User
|
||||
from module.models.user import User, UserUpdate, UserLogin
|
||||
from module.security.jwt import get_password_hash, verify_password
|
||||
from module.database.engine import engine
|
||||
from sqlmodel import Session, select, SQLModel
|
||||
from sqlalchemy.exc import UnboundExecutionError, OperationalError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AuthDB(DataConnector):
|
||||
class UserDatabase(Session):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.__table_name = "user"
|
||||
if not self._table_exists(self.__table_name):
|
||||
self.__update_table()
|
||||
|
||||
def __update_table(self):
|
||||
db_data = self.__data_to_db(User())
|
||||
self._update_table(self.__table_name, db_data)
|
||||
self._insert(self.__table_name, db_data)
|
||||
|
||||
@staticmethod
|
||||
def __data_to_db(data: User) -> dict:
|
||||
db_data = data.dict()
|
||||
db_data["password"] = get_password_hash(db_data["password"])
|
||||
return db_data
|
||||
|
||||
@staticmethod
|
||||
def __db_to_data(db_data: dict) -> User:
|
||||
return User(**db_data)
|
||||
super().__init__(engine)
|
||||
statement = select(User)
|
||||
try:
|
||||
self.exec(statement)
|
||||
except OperationalError:
|
||||
SQLModel.metadata.create_all(engine)
|
||||
self.add(User())
|
||||
self.commit()
|
||||
|
||||
def get_user(self, username):
|
||||
self._cursor.execute(
|
||||
f"SELECT * FROM {self.__table_name} WHERE username=?", (username,)
|
||||
)
|
||||
result = self._cursor.fetchone()
|
||||
statement = select(User).where(User.username == username)
|
||||
result = self.exec(statement).first()
|
||||
if not result:
|
||||
return None
|
||||
db_data = dict(zip([x[0] for x in self._cursor.description], result))
|
||||
return self.__db_to_data(db_data)
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
return result
|
||||
|
||||
def auth_user(self, username, password) -> bool:
|
||||
self._cursor.execute(
|
||||
f"SELECT username, password FROM {self.__table_name} WHERE username=?",
|
||||
(username,),
|
||||
)
|
||||
result = self._cursor.fetchone()
|
||||
def auth_user(self, user: UserLogin) -> bool:
|
||||
statement = select(User).where(User.username == user.username)
|
||||
result = self.exec(statement).first()
|
||||
if not result:
|
||||
raise HTTPException(status_code=401, detail="User not found")
|
||||
if not verify_password(password, result[1]):
|
||||
if not verify_password(user.password, result.password):
|
||||
raise HTTPException(status_code=401, detail="Password error")
|
||||
return True
|
||||
|
||||
def update_user(self, username, update_user: User):
|
||||
def update_user(self, username, update_user: UserUpdate):
|
||||
# Update username and password
|
||||
new_username = update_user.username
|
||||
new_password = update_user.password
|
||||
self._cursor.execute(
|
||||
f"""
|
||||
UPDATE {self.__table_name}
|
||||
SET username = '{new_username}', password = '{get_password_hash(new_password)}'
|
||||
WHERE username = '{username}'
|
||||
"""
|
||||
)
|
||||
self._conn.commit()
|
||||
statement = select(User).where(User.username == username)
|
||||
result = self.exec(statement).first()
|
||||
if not result:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
if update_user.username:
|
||||
result.username = update_user.username
|
||||
if update_user.password:
|
||||
result.password = get_password_hash(update_user.password)
|
||||
self.add(result)
|
||||
self.commit()
|
||||
return result
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
with AuthDB() as db:
|
||||
with UserDatabase() as db:
|
||||
# db.update_user(UserLogin(username="admin", password="adminadmin"), User(username="admin", password="cica1234"))
|
||||
db.update_user("admin", User(username="estrella", password="cica1234"))
|
||||
db.update_user("admin", UserUpdate(username="estrella", password="cica1234"))
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import logging
|
||||
|
||||
from module.conf import settings
|
||||
from module.models import BangumiData
|
||||
from module.models import Bangumi
|
||||
|
||||
from .path import TorrentPath
|
||||
|
||||
@@ -68,7 +68,7 @@ class DownloadClient(TorrentPath):
|
||||
prefs = self.client.get_app_prefs()
|
||||
settings.downloader.path = self._join_path(prefs["save_path"], "Bangumi")
|
||||
|
||||
def set_rule(self, data: BangumiData):
|
||||
def set_rule(self, data: Bangumi):
|
||||
data.rule_name = self._rule_name(data)
|
||||
data.save_path = self._gen_save_path(data)
|
||||
rule = {
|
||||
@@ -92,7 +92,7 @@ class DownloadClient(TorrentPath):
|
||||
f"[Downloader] Add {data.official_title} Season {data.season} to auto download rules."
|
||||
)
|
||||
|
||||
def set_rules(self, bangumi_info: list[BangumiData]):
|
||||
def set_rules(self, bangumi_info: list[Bangumi]):
|
||||
logger.debug("[Downloader] Start adding rules.")
|
||||
for info in bangumi_info:
|
||||
self.set_rule(info)
|
||||
|
||||
@@ -4,8 +4,7 @@ import re
|
||||
from pathlib import Path
|
||||
|
||||
from module.conf import settings
|
||||
from module.models import BangumiData
|
||||
|
||||
from module.models import Bangumi
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -50,7 +49,7 @@ class TorrentPath:
|
||||
return self._file_depth(file_path) <= 2
|
||||
|
||||
@staticmethod
|
||||
def _gen_save_path(data: BangumiData):
|
||||
def _gen_save_path(data: Bangumi):
|
||||
folder = (
|
||||
f"{data.official_title} ({data.year})" if data.year else data.official_title
|
||||
)
|
||||
@@ -58,7 +57,7 @@ class TorrentPath:
|
||||
return str(save_path)
|
||||
|
||||
@staticmethod
|
||||
def _rule_name(data: BangumiData):
|
||||
def _rule_name(data: Bangumi):
|
||||
rule_name = (
|
||||
f"[{data.group_name}] {data.official_title} S{data.season}"
|
||||
if settings.bangumi_manage.group_tag
|
||||
|
||||
@@ -2,14 +2,14 @@ import logging
|
||||
|
||||
from module.database import BangumiDatabase
|
||||
from module.downloader import DownloadClient
|
||||
from module.models import BangumiData
|
||||
from module.models import Bangumi
|
||||
from module.searcher import SearchTorrent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SeasonCollector(DownloadClient):
|
||||
def add_season_torrents(self, data: BangumiData, torrents, torrent_files=None):
|
||||
def add_season_torrents(self, data: Bangumi, torrents, torrent_files=None):
|
||||
if torrent_files:
|
||||
download_info = {
|
||||
"torrent_files": torrent_files,
|
||||
@@ -23,7 +23,7 @@ class SeasonCollector(DownloadClient):
|
||||
}
|
||||
return self.add_torrent(download_info)
|
||||
|
||||
def collect_season(self, data: BangumiData, link: str = None, proxy: bool = False):
|
||||
def collect_season(self, data: Bangumi, link: str = None, proxy: bool = False):
|
||||
logger.info(f"Start collecting {data.official_title} Season {data.season}...")
|
||||
with SearchTorrent() as st:
|
||||
if not link:
|
||||
@@ -39,7 +39,7 @@ class SeasonCollector(DownloadClient):
|
||||
data=data, torrents=torrents, torrent_files=torrent_files
|
||||
)
|
||||
|
||||
def subscribe_season(self, data: BangumiData):
|
||||
def subscribe_season(self, data: Bangumi):
|
||||
with BangumiDatabase() as db:
|
||||
data.added = True
|
||||
data.eps_collect = True
|
||||
|
||||
@@ -4,21 +4,21 @@ from fastapi.responses import JSONResponse
|
||||
|
||||
from module.database import BangumiDatabase
|
||||
from module.downloader import DownloadClient
|
||||
from module.models import BangumiData
|
||||
from module.models import Bangumi
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TorrentManager(BangumiDatabase):
|
||||
@staticmethod
|
||||
def __match_torrents_list(data: BangumiData) -> list:
|
||||
def __match_torrents_list(data: Bangumi) -> list:
|
||||
with DownloadClient() as client:
|
||||
torrents = client.get_torrent_info(status_filter=None)
|
||||
return [
|
||||
torrent.hash for torrent in torrents if torrent.save_path == data.save_path
|
||||
]
|
||||
|
||||
def delete_torrents(self, data: BangumiData, client: DownloadClient):
|
||||
def delete_torrents(self, data: Bangumi, client: DownloadClient):
|
||||
hash_list = self.__match_torrents_list(data)
|
||||
if hash_list:
|
||||
client.delete_torrent(hash_list)
|
||||
@@ -29,7 +29,7 @@ class TorrentManager(BangumiDatabase):
|
||||
|
||||
def delete_rule(self, _id: int | str, file: bool = False):
|
||||
data = self.search_id(int(_id))
|
||||
if isinstance(data, BangumiData):
|
||||
if isinstance(data, Bangumi):
|
||||
with DownloadClient() as client:
|
||||
client.remove_rule(data.rule_name)
|
||||
client.remove_rss_feed(data.official_title)
|
||||
@@ -54,7 +54,7 @@ class TorrentManager(BangumiDatabase):
|
||||
|
||||
def disable_rule(self, _id: str | int, file: bool = False):
|
||||
data = self.search_id(int(_id))
|
||||
if isinstance(data, BangumiData):
|
||||
if isinstance(data, Bangumi):
|
||||
with DownloadClient() as client:
|
||||
client.remove_rule(data.rule_name)
|
||||
data.deleted = True
|
||||
@@ -81,7 +81,7 @@ class TorrentManager(BangumiDatabase):
|
||||
|
||||
def enable_rule(self, _id: str | int):
|
||||
data = self.search_id(int(_id))
|
||||
if isinstance(data, BangumiData):
|
||||
if isinstance(data, Bangumi):
|
||||
data.deleted = False
|
||||
self.update_one(data)
|
||||
with DownloadClient() as client:
|
||||
@@ -98,7 +98,7 @@ class TorrentManager(BangumiDatabase):
|
||||
status_code=406, content={"msg": f"Can't find bangumi id {_id}"}
|
||||
)
|
||||
|
||||
def update_rule(self, data: BangumiData):
|
||||
def update_rule(self, data: Bangumi):
|
||||
old_data = self.search_id(data.id)
|
||||
if not old_data:
|
||||
logger.error(f"[Manager] Can't find data with {data.id}")
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from .bangumi import *
|
||||
from .bangumi import Bangumi, Episode
|
||||
from .config import Config
|
||||
from .rss import RSSTorrents
|
||||
from .torrent import EpisodeFile, SubtitleFile, TorrentBase
|
||||
|
||||
@@ -1,27 +1,50 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel
|
||||
from sqlmodel import SQLModel, Field
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class BangumiData(BaseModel):
|
||||
id: int = Field(0, alias="id", title="番剧ID")
|
||||
official_title: str = Field("official_title", alias="official_title", title="番剧中文名")
|
||||
year: str | None = Field(None, alias="year", title="番剧年份")
|
||||
title_raw: str = Field("title_raw", alias="title_raw", title="番剧原名")
|
||||
season: int = Field(1, alias="season", title="番剧季度")
|
||||
season_raw: str | None = Field(None, alias="season_raw", title="番剧季度原名")
|
||||
group_name: str | None = Field(None, alias="group_name", title="字幕组")
|
||||
dpi: str | None = Field(None, alias="dpi", title="分辨率")
|
||||
source: str | None = Field(None, alias="source", title="来源")
|
||||
subtitle: str | None = Field(None, alias="subtitle", title="字幕")
|
||||
eps_collect: bool = Field(False, alias="eps_collect", title="是否已收集")
|
||||
offset: int = Field(0, alias="offset", title="番剧偏移量")
|
||||
filter: list[str] = Field(["720", "\\d+-\\d+"], alias="filter", title="番剧过滤器")
|
||||
rss_link: list[str] = Field([], alias="rss_link", title="番剧RSS链接")
|
||||
poster_link: str | None = Field(None, alias="poster_link", title="番剧海报链接")
|
||||
added: bool = Field(False, alias="added", title="是否已添加")
|
||||
rule_name: str | None = Field(None, alias="rule_name", title="番剧规则名")
|
||||
save_path: str | None = Field(None, alias="save_path", title="番剧保存路径")
|
||||
class Bangumi(SQLModel, table=True):
|
||||
id: int = Field(default=None, primary_key=True)
|
||||
official_title: str = Field(
|
||||
default="official_title", alias="official_title", title="番剧中文名"
|
||||
)
|
||||
year: Optional[str] = Field(alias="year", title="番剧年份")
|
||||
title_raw: str = Field(default="title_raw", alias="title_raw", title="番剧原名")
|
||||
season: int = Field(default=1, alias="season", title="番剧季度")
|
||||
season_raw: Optional[str] = Field(alias="season_raw", title="番剧季度原名")
|
||||
group_name: Optional[str] = Field(alias="group_name", title="字幕组")
|
||||
dpi: Optional[str] = Field(alias="dpi", title="分辨率")
|
||||
source: Optional[str] = Field(alias="source", title="来源")
|
||||
subtitle: Optional[str] = Field(alias="subtitle", title="字幕")
|
||||
eps_collect: bool = Field(default=False, alias="eps_collect", title="是否已收集")
|
||||
offset: int = Field(default=0, alias="offset", title="番剧偏移量")
|
||||
filter: str = Field(default="720, \\d+-\\d+", alias="filter", title="番剧过滤器")
|
||||
rss_link: str = Field(default="", alias="rss_link", title="番剧RSS链接")
|
||||
poster_link: Optional[str] = Field(alias="poster_link", title="番剧海报链接")
|
||||
added: bool = Field(default=False, alias="added", title="是否已添加")
|
||||
rule_name: Optional[str] = Field(alias="rule_name", title="番剧规则名")
|
||||
save_path: Optional[str] = Field(alias="save_path", title="番剧保存路径")
|
||||
deleted: bool = Field(False, alias="deleted", title="是否已删除")
|
||||
|
||||
|
||||
class BangumiUpdate(SQLModel):
|
||||
official_title: str = Field(
|
||||
default="official_title", alias="official_title", title="番剧中文名"
|
||||
)
|
||||
year: Optional[str] = Field(alias="year", title="番剧年份")
|
||||
season: int = Field(default=1, alias="season", title="番剧季度")
|
||||
season_raw: Optional[str] = Field(alias="season_raw", title="番剧季度原名")
|
||||
group_name: Optional[str] = Field(alias="group_name", title="字幕组")
|
||||
dpi: Optional[str] = Field(alias="dpi", title="分辨率")
|
||||
source: Optional[str] = Field(alias="source", title="来源")
|
||||
subtitle: Optional[str] = Field(alias="subtitle", title="字幕")
|
||||
eps_collect: bool = Field(default=False, alias="eps_collect", title="是否已收集")
|
||||
offset: int = Field(default=0, alias="offset", title="番剧偏移量")
|
||||
filter: str = Field(default="720, \\d+-\\d+", alias="filter", title="番剧过滤器")
|
||||
rss_link: str = Field(default="", alias="rss_link", title="番剧RSS链接")
|
||||
added: bool = Field(default=False, alias="added", title="是否已添加")
|
||||
deleted: bool = Field(False, alias="deleted", title="是否已删除")
|
||||
|
||||
|
||||
@@ -29,14 +52,14 @@ class Notification(BaseModel):
|
||||
official_title: str = Field(..., alias="official_title", title="番剧名")
|
||||
season: int = Field(..., alias="season", title="番剧季度")
|
||||
episode: int = Field(..., alias="episode", title="番剧集数")
|
||||
poster_path: str | None = Field(None, alias="poster_path", title="番剧海报路径")
|
||||
poster_path: Optional[str] = Field(None, alias="poster_path", title="番剧海报路径")
|
||||
|
||||
|
||||
@dataclass
|
||||
class Episode:
|
||||
title_en: str | None
|
||||
title_zh: str | None
|
||||
title_jp: str | None
|
||||
title_en: Optional[str]
|
||||
title_zh: Optional[str]
|
||||
title_jp: Optional[str]
|
||||
season: int
|
||||
season_raw: str
|
||||
episode: int
|
||||
|
||||
@@ -14,7 +14,9 @@ class Downloader(BaseModel):
|
||||
type: str = Field("qbittorrent", description="Downloader type")
|
||||
host: str = Field("172.17.0.1:8080", description="Downloader host")
|
||||
username_: str = Field("admin", alias="username", description="Downloader username")
|
||||
password_: str = Field("adminadmin", alias="password", description="Downloader password")
|
||||
password_: str = Field(
|
||||
"adminadmin", alias="password", description="Downloader password"
|
||||
)
|
||||
path: str = Field("/downloads/Bangumi", description="Downloader path")
|
||||
ssl: bool = Field(False, description="Downloader ssl")
|
||||
|
||||
@@ -26,6 +28,7 @@ class Downloader(BaseModel):
|
||||
def password(self):
|
||||
return expandvars(self.password_)
|
||||
|
||||
|
||||
class RSSParser(BaseModel):
|
||||
enable: bool = Field(True, description="Enable RSS parser")
|
||||
type: str = Field("mikan", description="RSS parser type")
|
||||
@@ -39,6 +42,7 @@ class RSSParser(BaseModel):
|
||||
def token(self):
|
||||
return expandvars(self.token_)
|
||||
|
||||
|
||||
class BangumiManage(BaseModel):
|
||||
enable: bool = Field(True, description="Enable bangumi manage")
|
||||
eps_complete: bool = Field(False, description="Enable eps complete")
|
||||
@@ -82,6 +86,7 @@ class Notification(BaseModel):
|
||||
def chat_id(self):
|
||||
return expandvars(self.chat_id_)
|
||||
|
||||
|
||||
class Config(BaseModel):
|
||||
program: Program = Program()
|
||||
downloader: Downloader = Downloader()
|
||||
|
||||
@@ -1,14 +1,24 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
from sqlmodel import SQLModel, Field
|
||||
|
||||
|
||||
class User(BaseModel):
|
||||
class User(SQLModel, table=True):
|
||||
id: int = Field(default=None, primary_key=True)
|
||||
username: str = Field(
|
||||
"admin", min_length=4, max_length=20, regex=r"^[a-zA-Z0-9_]+$"
|
||||
)
|
||||
password: str = Field("adminadmin", min_length=8)
|
||||
|
||||
|
||||
class UserLogin(BaseModel):
|
||||
class UserUpdate(SQLModel):
|
||||
username: Optional[str] = Field(
|
||||
None, min_length=4, max_length=20, regex=r"^[a-zA-Z0-9_]+$"
|
||||
)
|
||||
password: Optional[str] = Field(None, min_length=8)
|
||||
|
||||
|
||||
class UserLogin(SQLModel):
|
||||
username: str
|
||||
password: str = Field(..., min_length=8)
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import logging
|
||||
|
||||
from module.conf import settings
|
||||
from module.models import BangumiData
|
||||
from module.models import Bangumi
|
||||
|
||||
from .analyser import raw_parser, tmdb_parser, torrent_parser
|
||||
|
||||
@@ -39,7 +39,7 @@ class TitleParser:
|
||||
return official_title, tmdb_season, year
|
||||
|
||||
@staticmethod
|
||||
def raw_parser(raw: str, rss_link: str) -> BangumiData | None:
|
||||
def raw_parser(raw: str, rss_link: str) -> Bangumi | None:
|
||||
language = settings.rss_parser.language
|
||||
try:
|
||||
episode = raw_parser(raw)
|
||||
@@ -60,7 +60,7 @@ class TitleParser:
|
||||
else:
|
||||
official_title = title_raw
|
||||
_season = episode.season
|
||||
data = BangumiData(
|
||||
data = Bangumi(
|
||||
official_title=official_title,
|
||||
title_raw=title_raw,
|
||||
season=_season,
|
||||
|
||||
@@ -3,7 +3,7 @@ import re
|
||||
|
||||
from module.conf import settings
|
||||
from module.database import BangumiDatabase
|
||||
from module.models import BangumiData
|
||||
from module.models import Bangumi
|
||||
from module.network import RequestContent, TorrentInfo
|
||||
from module.parser import TitleParser
|
||||
|
||||
@@ -16,7 +16,7 @@ class RSSAnalyser:
|
||||
with BangumiDatabase() as db:
|
||||
db.update_table()
|
||||
|
||||
def official_title_parser(self, data: BangumiData, mikan_title: str):
|
||||
def official_title_parser(self, data: Bangumi, mikan_title: str):
|
||||
if settings.rss_parser.parser_type == "mikan":
|
||||
data.official_title = mikan_title if mikan_title else data.official_title
|
||||
elif settings.rss_parser.parser_type == "tmdb":
|
||||
@@ -63,7 +63,7 @@ class RSSAnalyser:
|
||||
|
||||
def torrent_to_data(
|
||||
self, torrent: TorrentInfo, rss_link: str | None = None
|
||||
) -> BangumiData:
|
||||
) -> Bangumi:
|
||||
data = self._title_analyser.raw_parser(raw=torrent.name, rss_link=rss_link)
|
||||
if data:
|
||||
try:
|
||||
@@ -79,7 +79,7 @@ class RSSAnalyser:
|
||||
|
||||
def rss_to_data(
|
||||
self, rss_link: str, database: BangumiDatabase, full_parse: bool = True
|
||||
) -> list[BangumiData]:
|
||||
) -> list[Bangumi]:
|
||||
rss_torrents = self.get_rss_torrents(rss_link, full_parse)
|
||||
torrents_to_add = database.match_list(rss_torrents, rss_link)
|
||||
if not torrents_to_add:
|
||||
@@ -92,7 +92,7 @@ class RSSAnalyser:
|
||||
else:
|
||||
return []
|
||||
|
||||
def link_to_data(self, link: str) -> BangumiData:
|
||||
def link_to_data(self, link: str) -> Bangumi:
|
||||
torrents = self.get_rss_torrents(link, False)
|
||||
for torrent in torrents:
|
||||
data = self.torrent_to_data(torrent, link)
|
||||
|
||||
@@ -3,7 +3,7 @@ import logging
|
||||
from module.conf import settings
|
||||
from module.database import BangumiDatabase
|
||||
from module.downloader import DownloadClient
|
||||
from module.models import BangumiData
|
||||
from module.models import Bangumi
|
||||
from module.network import RequestContent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -14,7 +14,7 @@ def matched(torrent_title: str):
|
||||
return db.match_torrent(torrent_title)
|
||||
|
||||
|
||||
def save_path(data: BangumiData):
|
||||
def save_path(data: Bangumi):
|
||||
folder = (
|
||||
f"{data.official_title}({data.year})" if data.year else f"{data.official_title}"
|
||||
)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import re
|
||||
|
||||
from module.database import RSSDatabase
|
||||
from module.models import BangumiData, RSSTorrents
|
||||
from module.models import Bangumi, RSSTorrents
|
||||
from module.network import RequestContent, TorrentInfo
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ class RSSPoller(RSSDatabase):
|
||||
return req.get_torrents(rss_link)
|
||||
|
||||
@staticmethod
|
||||
def filter_torrent(data: BangumiData, torrent: TorrentInfo) -> bool:
|
||||
def filter_torrent(data: Bangumi, torrent: TorrentInfo) -> bool:
|
||||
if data.title_raw in torrent.name:
|
||||
_filter = "|".join(data.filter)
|
||||
if not re.search(_filter, torrent.name):
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from module.models import BangumiData, TorrentBase
|
||||
from module.models import Bangumi, TorrentBase
|
||||
from module.network import RequestContent
|
||||
from module.searcher.plugin import search_url
|
||||
|
||||
@@ -30,7 +30,7 @@ class SearchTorrent(RequestContent):
|
||||
|
||||
return [TorrentBase(**d) for d in to_dict()]
|
||||
|
||||
def search_season(self, data: BangumiData):
|
||||
def search_season(self, data: Bangumi):
|
||||
keywords = [getattr(data, key) for key in SEARCH_KEY if getattr(data, key)]
|
||||
torrents = self.search_torrents(keywords)
|
||||
return [torrent for torrent in torrents if data.title_raw in torrent.name]
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
|
||||
from module.database.user import AuthDB
|
||||
from module.database.user import UserDatabase
|
||||
from module.models.user import User
|
||||
|
||||
from .jwt import verify_token
|
||||
@@ -20,7 +20,7 @@ async def get_current_user(token: str = Depends(oauth2_scheme)):
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token"
|
||||
)
|
||||
username = payload.get("sub")
|
||||
with AuthDB() as user_db:
|
||||
with UserDatabase as user_db:
|
||||
user = user_db.get_user(username)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
@@ -40,7 +40,7 @@ async def get_token_data(token: str = Depends(oauth2_scheme)):
|
||||
|
||||
def update_user_info(user_data: User, current_user):
|
||||
try:
|
||||
with AuthDB() as db:
|
||||
with UserDatabase as db:
|
||||
db.update_user(current_user.username, user_data)
|
||||
return True
|
||||
except Exception as e:
|
||||
@@ -48,5 +48,5 @@ def update_user_info(user_data: User, current_user):
|
||||
|
||||
|
||||
def auth_user(username, password):
|
||||
with AuthDB() as db:
|
||||
with UserDatabase() as db:
|
||||
db.auth_user(username, password)
|
||||
|
||||
@@ -2,7 +2,7 @@ import os
|
||||
|
||||
from module.conf import LEGACY_DATA_PATH
|
||||
from module.database import BangumiDatabase
|
||||
from module.models import BangumiData
|
||||
from module.models import Bangumi
|
||||
from module.utils import json_config
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ def data_migration():
|
||||
rss_link = old_data["rss_link"]
|
||||
new_data = []
|
||||
for info in infos:
|
||||
new_data.append(BangumiData(**info, rss_link=[rss_link]))
|
||||
new_data.append(Bangumi(**info, rss_link=[rss_link]))
|
||||
with BangumiDatabase() as database:
|
||||
database.update_table()
|
||||
database.insert_list(new_data)
|
||||
|
||||
@@ -1,11 +1,17 @@
|
||||
from sqlmodel import create_engine, SQLModel
|
||||
from sqlmodel.pool import StaticPool
|
||||
|
||||
from module.database import BangumiDatabase
|
||||
from module.models import BangumiData
|
||||
from module.models import Bangumi
|
||||
|
||||
|
||||
def test_database():
|
||||
TEST_PATH = "test/test.db"
|
||||
test_data = BangumiData(
|
||||
id=1,
|
||||
def test_bangumi_database():
|
||||
# sqlite mock engine
|
||||
engine = create_engine(
|
||||
"sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool
|
||||
)
|
||||
SQLModel.metadata.create_all(engine)
|
||||
test_data = Bangumi(
|
||||
official_title="test",
|
||||
year="2021",
|
||||
title_raw="test",
|
||||
@@ -17,18 +23,15 @@ def test_database():
|
||||
subtitle="test",
|
||||
eps_collect=False,
|
||||
offset=0,
|
||||
filter=["720p", "\\d+-\\d+"],
|
||||
rss_link=["test"],
|
||||
filter="720p,\\d+-\\d+",
|
||||
rss_link="test",
|
||||
poster_link="/test/test.jpg",
|
||||
added=False,
|
||||
rule_name=None,
|
||||
save_path=None,
|
||||
deleted=False,
|
||||
)
|
||||
with BangumiDatabase(database=TEST_PATH) as database:
|
||||
# create table
|
||||
database.update_table()
|
||||
with BangumiDatabase(database=TEST_PATH) as database:
|
||||
with BangumiDatabase(engine) as database:
|
||||
# insert
|
||||
database.insert_one(test_data)
|
||||
assert database.search_id(1) == test_data
|
||||
@@ -39,13 +42,8 @@ def test_database():
|
||||
assert database.search_id(1) == test_data
|
||||
|
||||
# search poster
|
||||
assert database.match_poster("test") == "/test/test.jpg"
|
||||
assert database.match_poster("test2 (2021)") == "/test/test.jpg"
|
||||
|
||||
# delete
|
||||
database.delete_one(1)
|
||||
assert database.search_id(1) is None
|
||||
|
||||
# Delete test database
|
||||
import os
|
||||
|
||||
os.remove(TEST_PATH)
|
||||
|
||||
Reference in New Issue
Block a user