From 9ae187a05d61e02c2f3f887e8c9f72e484f30905 Mon Sep 17 00:00:00 2001 From: estrella Date: Sun, 30 Jul 2023 21:27:03 +0800 Subject: [PATCH 1/6] feat: update hand-made orm to sqlmodel (BangumiDatabase) --- backend/src/module/database/bangumi.py | 216 ++++++++++++++----------- backend/src/module/database/engine.py | 6 + backend/src/module/models/bangumi.py | 67 +++++--- 3 files changed, 174 insertions(+), 115 deletions(-) create mode 100644 backend/src/module/database/engine.py diff --git a/backend/src/module/database/bangumi.py b/backend/src/module/database/bangumi.py index dd25edb5..9209e717 100644 --- a/backend/src/module/database/bangumi.py +++ b/backend/src/module/database/bangumi.py @@ -1,58 +1,59 @@ import logging from module.database.orm import Connector -from module.models import BangumiData +from module.models import Bangumi, BangumiUpdate from module.conf import DATA_PATH +from sqlmodel import Session, select, delete, SQLModel +from module.database.engine import engine +from typing import Optional +from sqlalchemy.exc import IntegrityError, NoResultFound logger = logging.getLogger(__name__) -class BangumiDatabase(Connector): - def __init__(self, database: str = DATA_PATH): - super().__init__( - table_name="bangumi", - data=self.__data_to_db(BangumiData()), - database=database, - ) - - def update_table(self): - self.update.table() +class BangumiDatabase(Session): + def __init__(self): + super().__init__(engine) + # table_name="bangumi", + # data=self.__data_to_db(BangumiData()), + # database=database, + # ) @staticmethod - def __data_to_db(data: BangumiData) -> dict: - db_data = data.dict() - for key, value in db_data.items(): - if isinstance(value, bool): - db_data[key] = int(value) - elif isinstance(value, list): - db_data[key] = ",".join(value) - return db_data + def update_table(): + SQLModel.metadata.create_all(engine) - @staticmethod - def __db_to_data(db_data: dict) -> BangumiData: - for key, item in db_data.items(): - if isinstance(item, int): - if key not in ["id", "offset", "season", "year"]: - db_data[key] = bool(item) - elif key in ["filter", "rss_link"]: - db_data[key] = item.split(",") - return BangumiData(**db_data) + # @staticmethod + # def __data_to_db(data: BangumiData) -> Bangumi: + # db_data = data.dict() + # for key, value in db_data.items(): + # if isinstance(value, bool): + # db_data[key] = int(value) + # elif isinstance(value, list): + # db_data[key] = ",".join(value) + # return db_data + # + # @staticmethod + # def __db_to_data(db_data: Bangumi) -> BangumiData: + # for key, item in db_data.items(): + # if isinstance(item, int): + # if key not in ["id", "offset", "season", "year"]: + # db_data[key] = bool(item) + # elif key in ["filter", "rss_link"]: + # db_data[key] = item.split(",") + # return BangumiData(**db_data) - def insert_one(self, data: BangumiData): - db_data = self.__data_to_db(data) - self.insert.one(db_data) + def insert_one(self, data: Bangumi): + self.add(data) + self.commit() + # db_data = self.__data_to_db(data) + # self.insert.one(db_data) logger.debug(f"[Database] Insert {data.official_title} into database.") - # if self.__check_exist(data): - # self.update_one(data) - # else: - # db_data = self.__data_to_db(data) - # db_data["id"] = self.gen_id() - # self._insert(db_data=db_data, table_name=self.__table_name) - # logger.debug(f"[Database] Insert {data.official_title} into database.") - def insert_list(self, data: list[BangumiData]): - data_list = [self.__data_to_db(x) for x in data] - self.insert.many(data_list) + def insert_list(self, data: list[Bangumi]): + self.add_all(data) + # data_list = [self.__data_to_db(x) for x in data] + # self.insert.many(data_list) # _id = self.gen_id() # for i, item in enumerate(data): # item.id = _id + i @@ -60,62 +61,92 @@ class BangumiDatabase(Connector): # self._insert_list(data_list=data_list, table_name=self.__table_name) logger.debug(f"[Database] Insert {len(data)} bangumi into database.") - def update_one(self, data: BangumiData) -> bool: - db_data = self.__data_to_db(data) - return self.update.one(db_data) + def update_one(self, data: BangumiUpdate) -> bool: + db_data = self.get(Bangumi, data.id) + if not db_data: + return False + bangumi_data = data.dict(exclude_unset=True) + for key, value in bangumi_data.items(): + setattr(db_data, key, value) + self.add(db_data) + self.commit() + self.refresh(db_data) + logger.debug(f"[Database] Update {data.official_title}") + return True - def update_list(self, data: list[BangumiData]): - data_list = [self.__data_to_db(x) for x in data] - self.update.many(data_list) + def update_list(self, datas: list[BangumiUpdate]): + for data in datas: + self.update_one(data) def update_rss(self, title_raw, rss_set: str): # Update rss and added - location = {"title_raw": title_raw} - set_value = {"rss_link": rss_set, "added": 0} - self.update.value(location, set_value) + statement = select(Bangumi).where(Bangumi.title_raw == title_raw) + bangumi = self.exec(statement).first() + bangumi.rss_link = rss_set + bangumi.added = 0 + self.add(bangumi) + self.commit() + self.refresh(bangumi) + # location = {"title_raw": title_raw} + # set_value = {"rss_link": rss_set, "added": 0} + # self.update.value(location, set_value) logger.debug(f"[Database] Update {title_raw} rss_link to {rss_set}.") def update_poster(self, title_raw, poster_link: str): - location = {"title_raw": title_raw} - set_value = {"poster_link": poster_link} - self.update.value(location, set_value) + statement = select(Bangumi).where(Bangumi.title_raw == title_raw) + bangumi = self.exec(statement).first() + bangumi.poster_link = poster_link + self.add(bangumi) + self.commit() + self.refresh(bangumi) logger.debug(f"[Database] Update {title_raw} poster_link to {poster_link}.") def delete_one(self, _id: int): - self.delete.one(_id) + statement = select(Bangumi).where(Bangumi.id == _id) + bangumi = self.exec(statement).first() + self.delete(bangumi) + self.commit() logger.debug(f"[Database] Delete bangumi id: {_id}.") def delete_all(self): - self.delete.all() + statement = delete(Bangumi) + self.exec(statement) + self.commit() - def search_all(self) -> list[BangumiData]: - all_data = self.select.all() - return [self.__db_to_data(x) for x in all_data] + def search_all(self) -> list[Bangumi]: + statement = select(Bangumi) + return self.exec(statement).all() - def search_id(self, _id: int) -> BangumiData | None: - dict_data = self.select.one(conditions={"id": _id}) - if dict_data is None: + def search_id(self, _id: int) -> Optional[Bangumi]: + statement = select(Bangumi).where(Bangumi.id == _id) + bangumi = self.exec(statement).first() + if bangumi is None: logger.warning(f"[Database] Cannot find bangumi id: {_id}.") return None - logger.debug(f"[Database] Find bangumi id: {_id}.") - return self.__db_to_data(dict_data) + else: + logger.debug(f"[Database] Find bangumi id: {_id}.") + return self.exec(statement).first() def match_poster(self, bangumi_name: str) -> str: - condition = {"official_title": bangumi_name} - keys = ["poster_link"] - data = self.select.one( - keys=keys, - conditions=condition, - combine_operator="INSTR", - ) - if not data: + # condition = {"official_title": bangumi_name} + statement = select(Bangumi).where(Bangumi.official_title == bangumi_name) + # keys = ["poster_link"] + # data = self.select.one( + # keys=keys, + # conditions=condition, + # combine_operator="INSTR", + # ) + data = self.exec(statement).first() + if data: + return data.poster_link + else: return "" - return data.get("poster_link") def match_list(self, torrent_list: list, rss_link: str) -> list: # Match title_raw in database - keys = ["title_raw", "rss_link", "poster_link"] - match_datas = self.select.column(keys) + # keys = ["title_raw", "rss_link", "poster_link"] + # match_datas = self.select.column(keys) + match_datas = self.search_all() if not match_datas: return torrent_list # Match title @@ -123,15 +154,15 @@ class BangumiDatabase(Connector): while i < len(torrent_list): torrent = torrent_list[i] for match_data in match_datas: - if match_data.get("title_raw") in torrent.name: - if rss_link not in match_data.get("rss_link"): - match_data["rss_link"] += f",{rss_link}" + if match_data.title_raw in torrent.name: + if rss_link not in match_data.rss_link: + match_data.rss_link += f",{rss_link}" self.update_rss( - match_data.get("title_raw"), match_data.get("rss_link") + match_data.title_raw, match_data.rss_link ) - if not match_data.get("poster_link"): + if not match_data.poster_link: self.update_poster( - match_data.get("title_raw"), torrent.poster_link + match_data.title_raw, torrent.poster_link ) torrent_list.pop(i) break @@ -139,20 +170,23 @@ class BangumiDatabase(Connector): i += 1 return torrent_list - def not_complete(self) -> list[BangumiData]: + def not_complete(self) -> list[Bangumi]: # Find eps_complete = False - condition = {"eps_collect": 0} - dict_data = self.select.many( - conditions=condition, - ) - return [self.__db_to_data(x) for x in dict_data] + condition = select(Bangumi).where(Bangumi.eps_collect == 0) + datas = self.exec(condition).all() + return datas - def not_added(self) -> list[BangumiData]: - conditions = {"added": 0, "rule_name": None, "save_path": None} - dict_data = self.select.many(conditions=conditions, combine_operator="OR") - return [self.__db_to_data(x) for x in dict_data] + def not_added(self) -> list[Bangumi]: + conditions = select(Bangumi).where( + Bangumi.added == 0 or + Bangumi.rule_name is None or + Bangumi.save_path is None + ) + datas = self.exec(conditions).all() + # dict_data = self.select.many(conditions=conditions, combine_operator="OR") + return datas if __name__ == "__main__": with BangumiDatabase() as db: - print(db.match_poster("久保")) + print(db.not_complete()) diff --git a/backend/src/module/database/engine.py b/backend/src/module/database/engine.py new file mode 100644 index 00000000..e4bd2f5d --- /dev/null +++ b/backend/src/module/database/engine.py @@ -0,0 +1,6 @@ +from sqlmodel import create_engine, Session + + +engine = create_engine("sqlite:///data/data.db") + +db_session = Session(engine) \ No newline at end of file diff --git a/backend/src/module/models/bangumi.py b/backend/src/module/models/bangumi.py index b8af3670..02a74ce9 100644 --- a/backend/src/module/models/bangumi.py +++ b/backend/src/module/models/bangumi.py @@ -1,27 +1,46 @@ from dataclasses import dataclass -from pydantic import BaseModel, Field +from pydantic import BaseModel +from sqlmodel import SQLModel, Field +from typing import Optional -class BangumiData(BaseModel): - id: int = Field(0, alias="id", title="番剧ID") - official_title: str = Field("official_title", alias="official_title", title="番剧中文名") - year: str | None = Field(None, alias="year", title="番剧年份") - title_raw: str = Field("title_raw", alias="title_raw", title="番剧原名") - season: int = Field(1, alias="season", title="番剧季度") - season_raw: str | None = Field(None, alias="season_raw", title="番剧季度原名") - group_name: str | None = Field(None, alias="group_name", title="字幕组") - dpi: str | None = Field(None, alias="dpi", title="分辨率") - source: str | None = Field(None, alias="source", title="来源") - subtitle: str | None = Field(None, alias="subtitle", title="字幕") - eps_collect: bool = Field(False, alias="eps_collect", title="是否已收集") - offset: int = Field(0, alias="offset", title="番剧偏移量") - filter: list[str] = Field(["720", "\\d+-\\d+"], alias="filter", title="番剧过滤器") - rss_link: list[str] = Field([], alias="rss_link", title="番剧RSS链接") - poster_link: str | None = Field(None, alias="poster_link", title="番剧海报链接") - added: bool = Field(False, alias="added", title="是否已添加") - rule_name: str | None = Field(None, alias="rule_name", title="番剧规则名") - save_path: str | None = Field(None, alias="save_path", title="番剧保存路径") +class Bangumi(SQLModel, table=True): + id: int = Field(default=None, primary_key=True) + official_title: str = Field(default="official_title", alias="official_title", title="番剧中文名") + year: Optional[str] = Field(alias="year", title="番剧年份") + title_raw: str = Field(default="title_raw", alias="title_raw", title="番剧原名") + season: int = Field(default=1, alias="season", title="番剧季度") + season_raw: Optional[str] = Field(alias="season_raw", title="番剧季度原名") + group_name: Optional[str] = Field(alias="group_name", title="字幕组") + dpi: Optional[str] = Field(alias="dpi", title="分辨率") + source: Optional[str] = Field(alias="source", title="来源") + subtitle: Optional[str] = Field(alias="subtitle", title="字幕") + eps_collect: bool = Field(default=False, alias="eps_collect", title="是否已收集") + offset: int = Field(default=0, alias="offset", title="番剧偏移量") + filter: str = Field(default="720, \\d+-\\d+", alias="filter", title="番剧过滤器") + rss_link: str = Field(default="", alias="rss_link", title="番剧RSS链接") + poster_link: Optional[str] = Field(alias="poster_link", title="番剧海报链接") + added: bool = Field(default=False, alias="added", title="是否已添加") + rule_name: Optional[str] = Field(alias="rule_name", title="番剧规则名") + save_path: Optional[str] = Field(alias="save_path", title="番剧保存路径") + deleted: bool = Field(False, alias="deleted", title="是否已删除") + + +class BangumiUpdate(SQLModel): + official_title: str = Field(default="official_title", alias="official_title", title="番剧中文名") + year: Optional[str] = Field(alias="year", title="番剧年份") + season: int = Field(default=1, alias="season", title="番剧季度") + season_raw: Optional[str] = Field(alias="season_raw", title="番剧季度原名") + group_name: Optional[str] = Field(alias="group_name", title="字幕组") + dpi: Optional[str] = Field(alias="dpi", title="分辨率") + source: Optional[str] = Field(alias="source", title="来源") + subtitle: Optional[str] = Field(alias="subtitle", title="字幕") + eps_collect: bool = Field(default=False, alias="eps_collect", title="是否已收集") + offset: int = Field(default=0, alias="offset", title="番剧偏移量") + filter: str = Field(default="720, \\d+-\\d+", alias="filter", title="番剧过滤器") + rss_link: str = Field(default="", alias="rss_link", title="番剧RSS链接") + added: bool = Field(default=False, alias="added", title="是否已添加") deleted: bool = Field(False, alias="deleted", title="是否已删除") @@ -29,14 +48,14 @@ class Notification(BaseModel): official_title: str = Field(..., alias="official_title", title="番剧名") season: int = Field(..., alias="season", title="番剧季度") episode: int = Field(..., alias="episode", title="番剧集数") - poster_path: str | None = Field(None, alias="poster_path", title="番剧海报路径") + poster_path: Optional[str] = Field(None, alias="poster_path", title="番剧海报路径") @dataclass class Episode: - title_en: str | None - title_zh: str | None - title_jp: str | None + title_en: Optional[str] + title_zh: Optional[str] + title_jp: Optional[str] season: int season_raw: str episode: int From 3ca9a9737fcf4266fe1f1a71cf22f0a271e2505c Mon Sep 17 00:00:00 2001 From: estrella Date: Sun, 30 Jul 2023 21:33:29 +0800 Subject: [PATCH 2/6] feat: update hand-made orm to sqlmodel (UserDatabase) --- backend/src/module/database/user.py | 79 +++++++++++++---------------- backend/src/module/models/user.py | 14 ++++- 2 files changed, 48 insertions(+), 45 deletions(-) diff --git a/backend/src/module/database/user.py b/backend/src/module/database/user.py index 69dd9dea..382ada60 100644 --- a/backend/src/module/database/user.py +++ b/backend/src/module/database/user.py @@ -2,69 +2,62 @@ import logging from fastapi import HTTPException -from module.database.connector import DataConnector -from module.models.user import User +from module.models.user import User, UserUpdate, UserLogin from module.security.jwt import get_password_hash, verify_password +from module.database.engine import engine +from sqlmodel import Session, select, SQLModel logger = logging.getLogger(__name__) -class AuthDB(DataConnector): +class AuthDB(Session): def __init__(self): super().__init__() - self.__table_name = "user" - if not self._table_exists(self.__table_name): - self.__update_table() - - def __update_table(self): - db_data = self.__data_to_db(User()) - self._update_table(self.__table_name, db_data) - self._insert(self.__table_name, db_data) + self.__update_table() @staticmethod - def __data_to_db(data: User) -> dict: - db_data = data.dict() - db_data["password"] = get_password_hash(db_data["password"]) - return db_data + def __update_table(): + SQLModel.metadata.create_all(engine) - @staticmethod - def __db_to_data(db_data: dict) -> User: - return User(**db_data) + # @staticmethod + # def __data_to_db(data: User) -> dict: + # db_data = data.dict() + # db_data["password"] = get_password_hash(db_data["password"]) + # return db_data + # + # @staticmethod + # def __db_to_data(db_data: dict) -> User: + # return User(**db_data) def get_user(self, username): - self._cursor.execute( - f"SELECT * FROM {self.__table_name} WHERE username=?", (username,) - ) - result = self._cursor.fetchone() + statement = select(User).where(User.username == username) + result = self.exec(statement).first() if not result: - return None - db_data = dict(zip([x[0] for x in self._cursor.description], result)) - return self.__db_to_data(db_data) + raise HTTPException(status_code=404, detail="User not found") + return result - def auth_user(self, username, password) -> bool: - self._cursor.execute( - f"SELECT username, password FROM {self.__table_name} WHERE username=?", - (username,), - ) - result = self._cursor.fetchone() + def auth_user(self, user: UserLogin) -> bool: + statement = select(User).where(User.username == user.username) + result = self.exec(statement).first() if not result: raise HTTPException(status_code=401, detail="User not found") - if not verify_password(password, result[1]): + if not verify_password(user.password, result.password): raise HTTPException(status_code=401, detail="Password error") return True - def update_user(self, username, update_user: User): + def update_user(self, username, update_user: UserUpdate): # Update username and password - new_username = update_user.username - new_password = update_user.password - self._cursor.execute( - f""" - UPDATE {self.__table_name} - SET username = '{new_username}', password = '{get_password_hash(new_password)}' - WHERE username = '{username}' - """ - ) - self._conn.commit() + statement = select(User).where(User.username == username) + result = self.exec(statement).first() + if not result: + raise HTTPException(status_code=404, detail="User not found") + if update_user.username: + result.username = update_user.username + if update_user.password: + result.password = get_password_hash(update_user.password) + self.add(result) + self.commit() + return result if __name__ == "__main__": diff --git a/backend/src/module/models/user.py b/backend/src/module/models/user.py index 36512642..662622cd 100644 --- a/backend/src/module/models/user.py +++ b/backend/src/module/models/user.py @@ -1,14 +1,24 @@ from pydantic import BaseModel, Field +from typing import Optional +from sqlmodel import SQLModel, Field -class User(BaseModel): +class User(SQLModel, table=True): + id: int = Field(default=None, primary_key=True) username: str = Field( "admin", min_length=4, max_length=20, regex=r"^[a-zA-Z0-9_]+$" ) password: str = Field("adminadmin", min_length=8) -class UserLogin(BaseModel): +class UserUpdate(SQLModel): + username: Optional[str] = Field( + None, min_length=4, max_length=20, regex=r"^[a-zA-Z0-9_]+$" + ) + password: Optional[str] = Field(None, min_length=8) + + +class UserLogin(SQLModel): username: str password: str = Field(..., min_length=8) From d24acc60d5d5dd805e90da8933bff20938af1bd4 Mon Sep 17 00:00:00 2001 From: estrella Date: Sun, 30 Jul 2023 21:49:34 +0800 Subject: [PATCH 3/6] fix: Userdatabase bugs --- backend/src/module/conf/__init__.py | 2 +- backend/src/module/database/bangumi.py | 35 -------------------------- backend/src/module/database/engine.py | 3 ++- backend/src/module/database/user.py | 20 +++++++++------ backend/src/module/models/user.py | 2 +- backend/src/module/security/api.py | 8 +++--- 6 files changed, 20 insertions(+), 50 deletions(-) diff --git a/backend/src/module/conf/__init__.py b/backend/src/module/conf/__init__.py index 00dbc6fe..b6bc7662 100644 --- a/backend/src/module/conf/__init__.py +++ b/backend/src/module/conf/__init__.py @@ -4,7 +4,7 @@ from .config import VERSION, settings from .log import LOG_PATH, setup_logger TMDB_API = "32b19d6a05b512190a056fa4e747cbbc" -DATA_PATH = Path("data/data.db") +DATA_PATH = "sqlite:///data/data.db" LEGACY_DATA_PATH = Path("data/data.json") PLATFORM = "Windows" if "\\" in settings.downloader.path else "Unix" diff --git a/backend/src/module/database/bangumi.py b/backend/src/module/database/bangumi.py index 9209e717..986ccee5 100644 --- a/backend/src/module/database/bangumi.py +++ b/backend/src/module/database/bangumi.py @@ -1,8 +1,6 @@ import logging -from module.database.orm import Connector from module.models import Bangumi, BangumiUpdate -from module.conf import DATA_PATH from sqlmodel import Session, select, delete, SQLModel from module.database.engine import engine from typing import Optional @@ -14,51 +12,18 @@ logger = logging.getLogger(__name__) class BangumiDatabase(Session): def __init__(self): super().__init__(engine) - # table_name="bangumi", - # data=self.__data_to_db(BangumiData()), - # database=database, - # ) @staticmethod def update_table(): SQLModel.metadata.create_all(engine) - # @staticmethod - # def __data_to_db(data: BangumiData) -> Bangumi: - # db_data = data.dict() - # for key, value in db_data.items(): - # if isinstance(value, bool): - # db_data[key] = int(value) - # elif isinstance(value, list): - # db_data[key] = ",".join(value) - # return db_data - # - # @staticmethod - # def __db_to_data(db_data: Bangumi) -> BangumiData: - # for key, item in db_data.items(): - # if isinstance(item, int): - # if key not in ["id", "offset", "season", "year"]: - # db_data[key] = bool(item) - # elif key in ["filter", "rss_link"]: - # db_data[key] = item.split(",") - # return BangumiData(**db_data) - def insert_one(self, data: Bangumi): self.add(data) self.commit() - # db_data = self.__data_to_db(data) - # self.insert.one(db_data) logger.debug(f"[Database] Insert {data.official_title} into database.") def insert_list(self, data: list[Bangumi]): self.add_all(data) - # data_list = [self.__data_to_db(x) for x in data] - # self.insert.many(data_list) - # _id = self.gen_id() - # for i, item in enumerate(data): - # item.id = _id + i - # data_list = [self.__data_to_db(x) for x in data] - # self._insert_list(data_list=data_list, table_name=self.__table_name) logger.debug(f"[Database] Insert {len(data)} bangumi into database.") def update_one(self, data: BangumiUpdate) -> bool: diff --git a/backend/src/module/database/engine.py b/backend/src/module/database/engine.py index e4bd2f5d..369c53b3 100644 --- a/backend/src/module/database/engine.py +++ b/backend/src/module/database/engine.py @@ -1,6 +1,7 @@ from sqlmodel import create_engine, Session +from module.conf import DATA_PATH -engine = create_engine("sqlite:///data/data.db") +engine = create_engine(DATA_PATH) db_session = Session(engine) \ No newline at end of file diff --git a/backend/src/module/database/user.py b/backend/src/module/database/user.py index 382ada60..c763346f 100644 --- a/backend/src/module/database/user.py +++ b/backend/src/module/database/user.py @@ -6,18 +6,22 @@ from module.models.user import User, UserUpdate, UserLogin from module.security.jwt import get_password_hash, verify_password from module.database.engine import engine from sqlmodel import Session, select, SQLModel +from sqlalchemy.exc import UnboundExecutionError, OperationalError logger = logging.getLogger(__name__) -class AuthDB(Session): +class UserDatabase(Session): def __init__(self): - super().__init__() - self.__update_table() + super().__init__(engine) + statement = select(User) + try: + self.exec(statement) + except OperationalError: + SQLModel.metadata.create_all(engine) + self.add(User()) + self.commit() - @staticmethod - def __update_table(): - SQLModel.metadata.create_all(engine) # @staticmethod # def __data_to_db(data: User) -> dict: @@ -61,6 +65,6 @@ class AuthDB(Session): if __name__ == "__main__": - with AuthDB() as db: + with UserDatabase() as db: # db.update_user(UserLogin(username="admin", password="adminadmin"), User(username="admin", password="cica1234")) - db.update_user("admin", User(username="estrella", password="cica1234")) + db.update_user("admin", UserUpdate(username="estrella", password="cica1234")) diff --git a/backend/src/module/models/user.py b/backend/src/module/models/user.py index 662622cd..3e5cef29 100644 --- a/backend/src/module/models/user.py +++ b/backend/src/module/models/user.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel, Field +from pydantic import BaseModel from typing import Optional from sqlmodel import SQLModel, Field diff --git a/backend/src/module/security/api.py b/backend/src/module/security/api.py index 368abb0e..b41dcafa 100644 --- a/backend/src/module/security/api.py +++ b/backend/src/module/security/api.py @@ -1,7 +1,7 @@ from fastapi import Depends, HTTPException, status from fastapi.security import OAuth2PasswordBearer -from module.database.user import AuthDB +from module.database.user import UserDatabase from module.models.user import User from .jwt import verify_token @@ -20,7 +20,7 @@ async def get_current_user(token: str = Depends(oauth2_scheme)): status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token" ) username = payload.get("sub") - with AuthDB() as user_db: + with UserDatabase as user_db: user = user_db.get_user(username) if not user: raise HTTPException( @@ -40,7 +40,7 @@ async def get_token_data(token: str = Depends(oauth2_scheme)): def update_user_info(user_data: User, current_user): try: - with AuthDB() as db: + with UserDatabase as db: db.update_user(current_user.username, user_data) return True except Exception as e: @@ -48,5 +48,5 @@ def update_user_info(user_data: User, current_user): def auth_user(username, password): - with AuthDB() as db: + with UserDatabase() as db: db.auth_user(username, password) From 9b99e6c59107ccd6d72644cd2aa7699d79559842 Mon Sep 17 00:00:00 2001 From: estrella Date: Mon, 31 Jul 2023 16:00:54 +0800 Subject: [PATCH 4/6] test: add mock test to database. --- backend/src/module/api/bangumi.py | 8 +-- backend/src/module/api/download.py | 8 +-- backend/src/module/database/bangumi.py | 57 ++++++++----------- backend/src/module/database/engine.py | 2 +- backend/src/module/database/user.py | 11 ---- .../src/module/downloader/download_client.py | 6 +- backend/src/module/downloader/path.py | 7 +-- backend/src/module/manager/collector.py | 8 +-- backend/src/module/manager/torrent.py | 14 ++--- backend/src/module/models/__init__.py | 2 +- backend/src/module/models/bangumi.py | 8 ++- backend/src/module/models/config.py | 7 ++- backend/src/module/parser/title_parser.py | 6 +- backend/src/module/rss/analyser.py | 10 ++-- backend/src/module/rss/filter.py | 4 +- backend/src/module/rss/poller.py | 4 +- backend/src/module/searcher/searcher.py | 4 +- backend/src/module/update/data_migration.py | 4 +- backend/src/test/test_database.py | 28 +++++---- 19 files changed, 92 insertions(+), 106 deletions(-) diff --git a/backend/src/module/api/bangumi.py b/backend/src/module/api/bangumi.py index 16c206f3..7254fe45 100644 --- a/backend/src/module/api/bangumi.py +++ b/backend/src/module/api/bangumi.py @@ -2,13 +2,13 @@ from fastapi import APIRouter, Depends, HTTPException, status from fastapi.responses import JSONResponse from module.manager import TorrentManager -from module.models import BangumiData +from module.models import Bangumi from module.security import get_current_user router = APIRouter(prefix="/bangumi", tags=["bangumi"]) -@router.get("/getAll", response_model=list[BangumiData]) +@router.get("/getAll", response_model=list[Bangumi]) async def get_all_data(current_user=Depends(get_current_user)): if not current_user: raise HTTPException( @@ -18,7 +18,7 @@ async def get_all_data(current_user=Depends(get_current_user)): return torrent.search_all() -@router.get("/getData/{bangumi_id}", response_model=BangumiData) +@router.get("/getData/{bangumi_id}", response_model=Bangumi) async def get_data(bangumi_id: str, current_user=Depends(get_current_user)): if not current_user: raise HTTPException( @@ -29,7 +29,7 @@ async def get_data(bangumi_id: str, current_user=Depends(get_current_user)): @router.post("/updateRule") -async def update_rule(data: BangumiData, current_user=Depends(get_current_user)): +async def update_rule(data: Bangumi, current_user=Depends(get_current_user)): if not current_user: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token" diff --git a/backend/src/module/api/download.py b/backend/src/module/api/download.py index 278b1bb5..dcd36f1e 100644 --- a/backend/src/module/api/download.py +++ b/backend/src/module/api/download.py @@ -1,7 +1,7 @@ from fastapi import APIRouter, Depends, HTTPException, status from module.manager import SeasonCollector -from module.models import BangumiData +from module.models import Bangumi from module.models.api import RssLink from module.rss import analyser from module.security import get_current_user @@ -23,9 +23,7 @@ async def analysis(link: RssLink, current_user=Depends(get_current_user)): @router.post("/collection") -async def download_collection( - data: BangumiData, current_user=Depends(get_current_user) -): +async def download_collection(data: Bangumi, current_user=Depends(get_current_user)): if not current_user: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token" @@ -41,7 +39,7 @@ async def download_collection( @router.post("/subscribe") -async def subscribe(data: BangumiData, current_user=Depends(get_current_user)): +async def subscribe(data: Bangumi, current_user=Depends(get_current_user)): if not current_user: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token" diff --git a/backend/src/module/database/bangumi.py b/backend/src/module/database/bangumi.py index 986ccee5..d750746e 100644 --- a/backend/src/module/database/bangumi.py +++ b/backend/src/module/database/bangumi.py @@ -1,21 +1,18 @@ import logging -from module.models import Bangumi, BangumiUpdate -from sqlmodel import Session, select, delete, SQLModel -from module.database.engine import engine +from sqlmodel import Session, select, delete, SQLModel, or_, and_ from typing import Optional from sqlalchemy.exc import IntegrityError, NoResultFound +from .engine import engine +from module.models import Bangumi + logger = logging.getLogger(__name__) class BangumiDatabase(Session): - def __init__(self): - super().__init__(engine) - - @staticmethod - def update_table(): - SQLModel.metadata.create_all(engine) + def __init__(self, _engine=engine): + super().__init__(_engine) def insert_one(self, data: Bangumi): self.add(data) @@ -26,7 +23,7 @@ class BangumiDatabase(Session): self.add_all(data) logger.debug(f"[Database] Insert {len(data)} bangumi into database.") - def update_one(self, data: BangumiUpdate) -> bool: + def update_one(self, data: Bangumi) -> bool: db_data = self.get(Bangumi, data.id) if not db_data: return False @@ -39,7 +36,7 @@ class BangumiDatabase(Session): logger.debug(f"[Database] Update {data.official_title}") return True - def update_list(self, datas: list[BangumiUpdate]): + def update_list(self, datas: list[Bangumi]): for data in datas: self.update_one(data) @@ -48,7 +45,7 @@ class BangumiDatabase(Session): statement = select(Bangumi).where(Bangumi.title_raw == title_raw) bangumi = self.exec(statement).first() bangumi.rss_link = rss_set - bangumi.added = 0 + bangumi.added = False self.add(bangumi) self.commit() self.refresh(bangumi) @@ -93,14 +90,8 @@ class BangumiDatabase(Session): return self.exec(statement).first() def match_poster(self, bangumi_name: str) -> str: - # condition = {"official_title": bangumi_name} - statement = select(Bangumi).where(Bangumi.official_title == bangumi_name) - # keys = ["poster_link"] - # data = self.select.one( - # keys=keys, - # conditions=condition, - # combine_operator="INSTR", - # ) + # Use like to match + statement = select(Bangumi).where(Bangumi.title_raw.like(f"%{bangumi_name}%")) data = self.exec(statement).first() if data: return data.poster_link @@ -108,9 +99,6 @@ class BangumiDatabase(Session): return "" def match_list(self, torrent_list: list, rss_link: str) -> list: - # Match title_raw in database - # keys = ["title_raw", "rss_link", "poster_link"] - # match_datas = self.select.column(keys) match_datas = self.search_all() if not match_datas: return torrent_list @@ -122,13 +110,9 @@ class BangumiDatabase(Session): if match_data.title_raw in torrent.name: if rss_link not in match_data.rss_link: match_data.rss_link += f",{rss_link}" - self.update_rss( - match_data.title_raw, match_data.rss_link - ) + self.update_rss(match_data.title_raw, match_data.rss_link) if not match_data.poster_link: - self.update_poster( - match_data.title_raw, torrent.poster_link - ) + self.update_poster(match_data.title_raw, torrent.poster_link) torrent_list.pop(i) break else: @@ -143,14 +127,23 @@ class BangumiDatabase(Session): def not_added(self) -> list[Bangumi]: conditions = select(Bangumi).where( - Bangumi.added == 0 or - Bangumi.rule_name is None or - Bangumi.save_path is None + or_( + Bangumi.added == 0, Bangumi.rule_name is None, Bangumi.save_path is None ) + ) datas = self.exec(conditions).all() # dict_data = self.select.many(conditions=conditions, combine_operator="OR") return datas + def disable_rule(self, _id: int): + statement = select(Bangumi).where(Bangumi.id == _id) + bangumi = self.exec(statement).first() + bangumi.deleted = True + self.add(bangumi) + self.commit() + self.refresh(bangumi) + logger.debug(f"[Database] Disable rule {bangumi.title_raw}.") + if __name__ == "__main__": with BangumiDatabase() as db: diff --git a/backend/src/module/database/engine.py b/backend/src/module/database/engine.py index 369c53b3..94fa37b0 100644 --- a/backend/src/module/database/engine.py +++ b/backend/src/module/database/engine.py @@ -4,4 +4,4 @@ from module.conf import DATA_PATH engine = create_engine(DATA_PATH) -db_session = Session(engine) \ No newline at end of file +db_session = Session(engine) diff --git a/backend/src/module/database/user.py b/backend/src/module/database/user.py index c763346f..6fe77d2f 100644 --- a/backend/src/module/database/user.py +++ b/backend/src/module/database/user.py @@ -22,17 +22,6 @@ class UserDatabase(Session): self.add(User()) self.commit() - - # @staticmethod - # def __data_to_db(data: User) -> dict: - # db_data = data.dict() - # db_data["password"] = get_password_hash(db_data["password"]) - # return db_data - # - # @staticmethod - # def __db_to_data(db_data: dict) -> User: - # return User(**db_data) - def get_user(self, username): statement = select(User).where(User.username == username) result = self.exec(statement).first() diff --git a/backend/src/module/downloader/download_client.py b/backend/src/module/downloader/download_client.py index 675e180a..105f97d3 100644 --- a/backend/src/module/downloader/download_client.py +++ b/backend/src/module/downloader/download_client.py @@ -1,7 +1,7 @@ import logging from module.conf import settings -from module.models import BangumiData +from module.models import Bangumi from .path import TorrentPath @@ -68,7 +68,7 @@ class DownloadClient(TorrentPath): prefs = self.client.get_app_prefs() settings.downloader.path = self._join_path(prefs["save_path"], "Bangumi") - def set_rule(self, data: BangumiData): + def set_rule(self, data: Bangumi): data.rule_name = self._rule_name(data) data.save_path = self._gen_save_path(data) rule = { @@ -92,7 +92,7 @@ class DownloadClient(TorrentPath): f"[Downloader] Add {data.official_title} Season {data.season} to auto download rules." ) - def set_rules(self, bangumi_info: list[BangumiData]): + def set_rules(self, bangumi_info: list[Bangumi]): logger.debug("[Downloader] Start adding rules.") for info in bangumi_info: self.set_rule(info) diff --git a/backend/src/module/downloader/path.py b/backend/src/module/downloader/path.py index f1099191..8f6fcc16 100644 --- a/backend/src/module/downloader/path.py +++ b/backend/src/module/downloader/path.py @@ -4,8 +4,7 @@ import re from pathlib import Path from module.conf import settings -from module.models import BangumiData - +from module.models import Bangumi logger = logging.getLogger(__name__) @@ -50,7 +49,7 @@ class TorrentPath: return self._file_depth(file_path) <= 2 @staticmethod - def _gen_save_path(data: BangumiData): + def _gen_save_path(data: Bangumi): folder = ( f"{data.official_title} ({data.year})" if data.year else data.official_title ) @@ -58,7 +57,7 @@ class TorrentPath: return str(save_path) @staticmethod - def _rule_name(data: BangumiData): + def _rule_name(data: Bangumi): rule_name = ( f"[{data.group_name}] {data.official_title} S{data.season}" if settings.bangumi_manage.group_tag diff --git a/backend/src/module/manager/collector.py b/backend/src/module/manager/collector.py index 979aa5cb..b0bb9e15 100644 --- a/backend/src/module/manager/collector.py +++ b/backend/src/module/manager/collector.py @@ -2,14 +2,14 @@ import logging from module.database import BangumiDatabase from module.downloader import DownloadClient -from module.models import BangumiData +from module.models import Bangumi from module.searcher import SearchTorrent logger = logging.getLogger(__name__) class SeasonCollector(DownloadClient): - def add_season_torrents(self, data: BangumiData, torrents, torrent_files=None): + def add_season_torrents(self, data: Bangumi, torrents, torrent_files=None): if torrent_files: download_info = { "torrent_files": torrent_files, @@ -23,7 +23,7 @@ class SeasonCollector(DownloadClient): } return self.add_torrent(download_info) - def collect_season(self, data: BangumiData, link: str = None, proxy: bool = False): + def collect_season(self, data: Bangumi, link: str = None, proxy: bool = False): logger.info(f"Start collecting {data.official_title} Season {data.season}...") with SearchTorrent() as st: if not link: @@ -39,7 +39,7 @@ class SeasonCollector(DownloadClient): data=data, torrents=torrents, torrent_files=torrent_files ) - def subscribe_season(self, data: BangumiData): + def subscribe_season(self, data: Bangumi): with BangumiDatabase() as db: data.added = True data.eps_collect = True diff --git a/backend/src/module/manager/torrent.py b/backend/src/module/manager/torrent.py index f3a69edb..3385e11e 100644 --- a/backend/src/module/manager/torrent.py +++ b/backend/src/module/manager/torrent.py @@ -4,21 +4,21 @@ from fastapi.responses import JSONResponse from module.database import BangumiDatabase from module.downloader import DownloadClient -from module.models import BangumiData +from module.models import Bangumi logger = logging.getLogger(__name__) class TorrentManager(BangumiDatabase): @staticmethod - def __match_torrents_list(data: BangumiData) -> list: + def __match_torrents_list(data: Bangumi) -> list: with DownloadClient() as client: torrents = client.get_torrent_info(status_filter=None) return [ torrent.hash for torrent in torrents if torrent.save_path == data.save_path ] - def delete_torrents(self, data: BangumiData, client: DownloadClient): + def delete_torrents(self, data: Bangumi, client: DownloadClient): hash_list = self.__match_torrents_list(data) if hash_list: client.delete_torrent(hash_list) @@ -29,7 +29,7 @@ class TorrentManager(BangumiDatabase): def delete_rule(self, _id: int | str, file: bool = False): data = self.search_id(int(_id)) - if isinstance(data, BangumiData): + if isinstance(data, Bangumi): with DownloadClient() as client: client.remove_rule(data.rule_name) client.remove_rss_feed(data.official_title) @@ -54,7 +54,7 @@ class TorrentManager(BangumiDatabase): def disable_rule(self, _id: str | int, file: bool = False): data = self.search_id(int(_id)) - if isinstance(data, BangumiData): + if isinstance(data, Bangumi): with DownloadClient() as client: client.remove_rule(data.rule_name) data.deleted = True @@ -81,7 +81,7 @@ class TorrentManager(BangumiDatabase): def enable_rule(self, _id: str | int): data = self.search_id(int(_id)) - if isinstance(data, BangumiData): + if isinstance(data, Bangumi): data.deleted = False self.update_one(data) with DownloadClient() as client: @@ -98,7 +98,7 @@ class TorrentManager(BangumiDatabase): status_code=406, content={"msg": f"Can't find bangumi id {_id}"} ) - def update_rule(self, data: BangumiData): + def update_rule(self, data: Bangumi): old_data = self.search_id(data.id) if not old_data: logger.error(f"[Manager] Can't find data with {data.id}") diff --git a/backend/src/module/models/__init__.py b/backend/src/module/models/__init__.py index a73f18ed..0b10feaf 100644 --- a/backend/src/module/models/__init__.py +++ b/backend/src/module/models/__init__.py @@ -1,4 +1,4 @@ -from .bangumi import * +from .bangumi import Bangumi, Episode from .config import Config from .rss import RSSTorrents from .torrent import EpisodeFile, SubtitleFile, TorrentBase diff --git a/backend/src/module/models/bangumi.py b/backend/src/module/models/bangumi.py index 02a74ce9..b7484887 100644 --- a/backend/src/module/models/bangumi.py +++ b/backend/src/module/models/bangumi.py @@ -7,7 +7,9 @@ from typing import Optional class Bangumi(SQLModel, table=True): id: int = Field(default=None, primary_key=True) - official_title: str = Field(default="official_title", alias="official_title", title="番剧中文名") + official_title: str = Field( + default="official_title", alias="official_title", title="番剧中文名" + ) year: Optional[str] = Field(alias="year", title="番剧年份") title_raw: str = Field(default="title_raw", alias="title_raw", title="番剧原名") season: int = Field(default=1, alias="season", title="番剧季度") @@ -28,7 +30,9 @@ class Bangumi(SQLModel, table=True): class BangumiUpdate(SQLModel): - official_title: str = Field(default="official_title", alias="official_title", title="番剧中文名") + official_title: str = Field( + default="official_title", alias="official_title", title="番剧中文名" + ) year: Optional[str] = Field(alias="year", title="番剧年份") season: int = Field(default=1, alias="season", title="番剧季度") season_raw: Optional[str] = Field(alias="season_raw", title="番剧季度原名") diff --git a/backend/src/module/models/config.py b/backend/src/module/models/config.py index a2e4d2a6..76ec036f 100644 --- a/backend/src/module/models/config.py +++ b/backend/src/module/models/config.py @@ -14,7 +14,9 @@ class Downloader(BaseModel): type: str = Field("qbittorrent", description="Downloader type") host: str = Field("172.17.0.1:8080", description="Downloader host") username_: str = Field("admin", alias="username", description="Downloader username") - password_: str = Field("adminadmin", alias="password", description="Downloader password") + password_: str = Field( + "adminadmin", alias="password", description="Downloader password" + ) path: str = Field("/downloads/Bangumi", description="Downloader path") ssl: bool = Field(False, description="Downloader ssl") @@ -26,6 +28,7 @@ class Downloader(BaseModel): def password(self): return expandvars(self.password_) + class RSSParser(BaseModel): enable: bool = Field(True, description="Enable RSS parser") type: str = Field("mikan", description="RSS parser type") @@ -39,6 +42,7 @@ class RSSParser(BaseModel): def token(self): return expandvars(self.token_) + class BangumiManage(BaseModel): enable: bool = Field(True, description="Enable bangumi manage") eps_complete: bool = Field(False, description="Enable eps complete") @@ -82,6 +86,7 @@ class Notification(BaseModel): def chat_id(self): return expandvars(self.chat_id_) + class Config(BaseModel): program: Program = Program() downloader: Downloader = Downloader() diff --git a/backend/src/module/parser/title_parser.py b/backend/src/module/parser/title_parser.py index 33ada84b..4d1e7f00 100644 --- a/backend/src/module/parser/title_parser.py +++ b/backend/src/module/parser/title_parser.py @@ -1,7 +1,7 @@ import logging from module.conf import settings -from module.models import BangumiData +from module.models import Bangumi from .analyser import raw_parser, tmdb_parser, torrent_parser @@ -39,7 +39,7 @@ class TitleParser: return official_title, tmdb_season, year @staticmethod - def raw_parser(raw: str, rss_link: str) -> BangumiData | None: + def raw_parser(raw: str, rss_link: str) -> Bangumi | None: language = settings.rss_parser.language try: episode = raw_parser(raw) @@ -60,7 +60,7 @@ class TitleParser: else: official_title = title_raw _season = episode.season - data = BangumiData( + data = Bangumi( official_title=official_title, title_raw=title_raw, season=_season, diff --git a/backend/src/module/rss/analyser.py b/backend/src/module/rss/analyser.py index ccd398be..00bde7a6 100644 --- a/backend/src/module/rss/analyser.py +++ b/backend/src/module/rss/analyser.py @@ -3,7 +3,7 @@ import re from module.conf import settings from module.database import BangumiDatabase -from module.models import BangumiData +from module.models import Bangumi from module.network import RequestContent, TorrentInfo from module.parser import TitleParser @@ -16,7 +16,7 @@ class RSSAnalyser: with BangumiDatabase() as db: db.update_table() - def official_title_parser(self, data: BangumiData, mikan_title: str): + def official_title_parser(self, data: Bangumi, mikan_title: str): if settings.rss_parser.parser_type == "mikan": data.official_title = mikan_title if mikan_title else data.official_title elif settings.rss_parser.parser_type == "tmdb": @@ -63,7 +63,7 @@ class RSSAnalyser: def torrent_to_data( self, torrent: TorrentInfo, rss_link: str | None = None - ) -> BangumiData: + ) -> Bangumi: data = self._title_analyser.raw_parser(raw=torrent.name, rss_link=rss_link) if data: try: @@ -79,7 +79,7 @@ class RSSAnalyser: def rss_to_data( self, rss_link: str, database: BangumiDatabase, full_parse: bool = True - ) -> list[BangumiData]: + ) -> list[Bangumi]: rss_torrents = self.get_rss_torrents(rss_link, full_parse) torrents_to_add = database.match_list(rss_torrents, rss_link) if not torrents_to_add: @@ -92,7 +92,7 @@ class RSSAnalyser: else: return [] - def link_to_data(self, link: str) -> BangumiData: + def link_to_data(self, link: str) -> Bangumi: torrents = self.get_rss_torrents(link, False) for torrent in torrents: data = self.torrent_to_data(torrent, link) diff --git a/backend/src/module/rss/filter.py b/backend/src/module/rss/filter.py index e70d4602..d4717e38 100644 --- a/backend/src/module/rss/filter.py +++ b/backend/src/module/rss/filter.py @@ -3,7 +3,7 @@ import logging from module.conf import settings from module.database import BangumiDatabase from module.downloader import DownloadClient -from module.models import BangumiData +from module.models import Bangumi from module.network import RequestContent logger = logging.getLogger(__name__) @@ -14,7 +14,7 @@ def matched(torrent_title: str): return db.match_torrent(torrent_title) -def save_path(data: BangumiData): +def save_path(data: Bangumi): folder = ( f"{data.official_title}({data.year})" if data.year else f"{data.official_title}" ) diff --git a/backend/src/module/rss/poller.py b/backend/src/module/rss/poller.py index 90a81cb0..3efaab2f 100644 --- a/backend/src/module/rss/poller.py +++ b/backend/src/module/rss/poller.py @@ -1,7 +1,7 @@ import re from module.database import RSSDatabase -from module.models import BangumiData, RSSTorrents +from module.models import Bangumi, RSSTorrents from module.network import RequestContent, TorrentInfo @@ -11,7 +11,7 @@ class RSSPoller(RSSDatabase): return req.get_torrents(rss_link) @staticmethod - def filter_torrent(data: BangumiData, torrent: TorrentInfo) -> bool: + def filter_torrent(data: Bangumi, torrent: TorrentInfo) -> bool: if data.title_raw in torrent.name: _filter = "|".join(data.filter) if not re.search(_filter, torrent.name): diff --git a/backend/src/module/searcher/searcher.py b/backend/src/module/searcher/searcher.py index a52e379e..dead9466 100644 --- a/backend/src/module/searcher/searcher.py +++ b/backend/src/module/searcher/searcher.py @@ -1,4 +1,4 @@ -from module.models import BangumiData, TorrentBase +from module.models import Bangumi, TorrentBase from module.network import RequestContent from module.searcher.plugin import search_url @@ -30,7 +30,7 @@ class SearchTorrent(RequestContent): return [TorrentBase(**d) for d in to_dict()] - def search_season(self, data: BangumiData): + def search_season(self, data: Bangumi): keywords = [getattr(data, key) for key in SEARCH_KEY if getattr(data, key)] torrents = self.search_torrents(keywords) return [torrent for torrent in torrents if data.title_raw in torrent.name] diff --git a/backend/src/module/update/data_migration.py b/backend/src/module/update/data_migration.py index b9ddf8db..f9739f98 100644 --- a/backend/src/module/update/data_migration.py +++ b/backend/src/module/update/data_migration.py @@ -2,7 +2,7 @@ import os from module.conf import LEGACY_DATA_PATH from module.database import BangumiDatabase -from module.models import BangumiData +from module.models import Bangumi from module.utils import json_config @@ -14,7 +14,7 @@ def data_migration(): rss_link = old_data["rss_link"] new_data = [] for info in infos: - new_data.append(BangumiData(**info, rss_link=[rss_link])) + new_data.append(Bangumi(**info, rss_link=[rss_link])) with BangumiDatabase() as database: database.update_table() database.insert_list(new_data) diff --git a/backend/src/test/test_database.py b/backend/src/test/test_database.py index 80493bf4..b7088ef1 100644 --- a/backend/src/test/test_database.py +++ b/backend/src/test/test_database.py @@ -1,11 +1,17 @@ +from sqlmodel import create_engine, SQLModel, Session +from sqlmodel.pool import StaticPool + from module.database import BangumiDatabase -from module.models import BangumiData +from module.models import Bangumi def test_database(): - TEST_PATH = "test/test.db" - test_data = BangumiData( - id=1, + # sqlite mock engine + engine = create_engine( + "sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool + ) + SQLModel.metadata.create_all(engine) + test_data = Bangumi( official_title="test", year="2021", title_raw="test", @@ -17,18 +23,15 @@ def test_database(): subtitle="test", eps_collect=False, offset=0, - filter=["720p", "\\d+-\\d+"], - rss_link=["test"], + filter="720p,\\d+-\\d+", + rss_link="test", poster_link="/test/test.jpg", added=False, rule_name=None, save_path=None, deleted=False, ) - with BangumiDatabase(database=TEST_PATH) as database: - # create table - database.update_table() - with BangumiDatabase(database=TEST_PATH) as database: + with BangumiDatabase(engine) as database: # insert database.insert_one(test_data) assert database.search_id(1) == test_data @@ -44,8 +47,3 @@ def test_database(): # delete database.delete_one(1) assert database.search_id(1) is None - - # Delete test database - import os - - os.remove(TEST_PATH) From 05dc4ba7ac0c4860c4ed9e7bd19f03f9a4a1a8ef Mon Sep 17 00:00:00 2001 From: estrella Date: Mon, 31 Jul 2023 16:02:17 +0800 Subject: [PATCH 5/6] fix: fix requirements.txt --- backend/requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/backend/requirements.txt b/backend/requirements.txt index 53166a42..783fe560 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -23,3 +23,4 @@ python-jose==3.3.0 passlib==1.7.4 bcrypt==4.0.1 python-multipart==0.0.6 +sqlmodel From c91c79912b855513e1736456adc9b14dc34f1b88 Mon Sep 17 00:00:00 2001 From: estrella Date: Mon, 31 Jul 2023 16:17:10 +0800 Subject: [PATCH 6/6] chore: remove unnecessary code. fix data search. --- backend/src/module/database/bangumi.py | 10 +- backend/src/module/database/connector.py | 174 ------------------- backend/src/module/database/orm/__init__.py | 1 - backend/src/module/database/orm/connector.py | 71 -------- backend/src/module/database/orm/delete.py | 23 --- backend/src/module/database/orm/insert.py | 33 ---- backend/src/module/database/orm/select.py | 96 ---------- backend/src/module/database/orm/update.py | 98 ----------- backend/src/test/test_database.py | 6 +- 9 files changed, 6 insertions(+), 506 deletions(-) delete mode 100644 backend/src/module/database/connector.py delete mode 100644 backend/src/module/database/orm/__init__.py delete mode 100644 backend/src/module/database/orm/connector.py delete mode 100644 backend/src/module/database/orm/delete.py delete mode 100644 backend/src/module/database/orm/insert.py delete mode 100644 backend/src/module/database/orm/select.py delete mode 100644 backend/src/module/database/orm/update.py diff --git a/backend/src/module/database/bangumi.py b/backend/src/module/database/bangumi.py index d750746e..9ab0b092 100644 --- a/backend/src/module/database/bangumi.py +++ b/backend/src/module/database/bangumi.py @@ -1,8 +1,8 @@ import logging -from sqlmodel import Session, select, delete, SQLModel, or_, and_ +from sqlmodel import Session, select, delete, or_ +from sqlalchemy.sql import func from typing import Optional -from sqlalchemy.exc import IntegrityError, NoResultFound from .engine import engine from module.models import Bangumi @@ -49,9 +49,6 @@ class BangumiDatabase(Session): self.add(bangumi) self.commit() self.refresh(bangumi) - # location = {"title_raw": title_raw} - # set_value = {"rss_link": rss_set, "added": 0} - # self.update.value(location, set_value) logger.debug(f"[Database] Update {title_raw} rss_link to {rss_set}.") def update_poster(self, title_raw, poster_link: str): @@ -91,7 +88,7 @@ class BangumiDatabase(Session): def match_poster(self, bangumi_name: str) -> str: # Use like to match - statement = select(Bangumi).where(Bangumi.title_raw.like(f"%{bangumi_name}%")) + statement = select(Bangumi).where(func.instr(bangumi_name, Bangumi.title_raw) > 0) data = self.exec(statement).first() if data: return data.poster_link @@ -132,7 +129,6 @@ class BangumiDatabase(Session): ) ) datas = self.exec(conditions).all() - # dict_data = self.select.many(conditions=conditions, combine_operator="OR") return datas def disable_rule(self, _id: int): diff --git a/backend/src/module/database/connector.py b/backend/src/module/database/connector.py deleted file mode 100644 index 506bb7b1..00000000 --- a/backend/src/module/database/connector.py +++ /dev/null @@ -1,174 +0,0 @@ -import logging -import os -import sqlite3 - -from module.conf import DATA_PATH - -logger = logging.getLogger(__name__) - - -class DataConnector: - def __init__(self): - # Create folder if not exists - DATA_PATH.parent.mkdir(parents=True, exist_ok=True) - - self._conn = sqlite3.connect(DATA_PATH) - self._cursor = self._conn.cursor() - - def _update_table(self, table_name: str, db_data: dict): - columns = ", ".join( - [ - f"{key} {self.__python_to_sqlite_type(value)}" - for key, value in db_data.items() - ] - ) - create_table_sql = f"CREATE TABLE IF NOT EXISTS {table_name} ({columns});" - self._cursor.execute(create_table_sql) - self._cursor.execute(f"PRAGMA table_info({table_name})") - existing_columns = { - column_info[1]: column_info for column_info in self._cursor.fetchall() - } - for key, value in db_data.items(): - if key not in existing_columns: - insert_column = self.__python_to_sqlite_type(value) - if value is None: - value = "NULL" - add_column_sql = f"ALTER TABLE {table_name} ADD COLUMN {key} {insert_column} DEFAULT {value};" - self._cursor.execute(add_column_sql) - self._conn.commit() - logger.debug(f"Create / Update table {table_name}.") - - def _insert(self, table_name: str, db_data: dict): - columns = ", ".join(db_data.keys()) - values = ", ".join([f":{key}" for key in db_data.keys()]) - self._cursor.execute( - f"INSERT INTO {table_name} ({columns}) VALUES ({values})", db_data - ) - self._conn.commit() - - def _insert_list(self, table_name: str, data_list: list[dict]): - columns = ", ".join(data_list[0].keys()) - values = ", ".join([f":{key}" for key in data_list[0].keys()]) - self._cursor.executemany( - f"INSERT INTO {table_name} ({columns}) VALUES ({values})", data_list - ) - self._conn.commit() - - def _select(self, keys: list[str], table_name: str, condition: str = None) -> dict: - if condition is None: - self._cursor.execute(f"SELECT {', '.join(keys)} FROM {table_name}") - else: - self._cursor.execute( - f"SELECT {', '.join(keys)} FROM {table_name} WHERE {condition}" - ) - return dict(zip(keys, self._cursor.fetchone())) - - def _update(self, table_name: str, db_data: dict): - _id = db_data.get("id") - if _id is None: - raise ValueError("No _id in db_data.") - set_sql = ", ".join([f"{key} = :{key}" for key in db_data.keys()]) - self._cursor.execute( - f"UPDATE {table_name} SET {set_sql} WHERE id = {_id}", db_data - ) - self._conn.commit() - return self._cursor.rowcount == 1 - - def _update_list(self, table_name: str, data_list: list[dict]): - if len(data_list) == 0: - return - set_sql = ", ".join( - [f"{key} = :{key}" for key in data_list[0].keys() if key != "id"] - ) - self._cursor.executemany( - f"UPDATE {table_name} SET {set_sql} WHERE id = :id", data_list - ) - self._conn.commit() - - def _update_section(self, table_name: str, location: dict, update_dict: dict): - set_sql = ", ".join([f"{key} = :{key}" for key in update_dict.keys()]) - sql_loc = f"{location['key']} = {location['value']}" - self._cursor.execute( - f"UPDATE {table_name} SET {set_sql} WHERE {sql_loc}", update_dict - ) - self._conn.commit() - - def _delete_all(self, table_name: str): - self._cursor.execute(f"DELETE FROM {table_name}") - self._conn.commit() - - def _delete(self, table_name: str, condition: dict): - condition_sql = " AND ".join([f"{key} = :{key}" for key in condition.keys()]) - self._cursor.execute( - f"DELETE FROM {table_name} WHERE {condition_sql}", condition - ) - self._conn.commit() - - def _search( - self, table_name: str, keys: list[str] | None = None, condition: dict = None - ): - if keys is None: - select_sql = "*" - else: - select_sql = ", ".join(keys) - if condition is None: - self._cursor.execute(f"SELECT {select_sql} FROM {table_name}") - else: - custom_condition = condition.pop("_custom_condition", None) - condition_sql = " AND ".join( - [f"{key} = :{key}" for key in condition.keys()] - ) + (f" AND {custom_condition}" if custom_condition else "") - self._cursor.execute( - f"SELECT {select_sql} FROM {table_name} WHERE {condition_sql}", - condition, - ) - - def _search_data( - self, table_name: str, keys: list[str] | None = None, condition: dict = None - ) -> dict: - if keys is None: - keys = self.__get_table_columns(table_name) - self._search(table_name, keys, condition) - return dict(zip(keys, self._cursor.fetchone())) - - def _search_datas( - self, table_name: str, keys: list[str] | None = None, condition: dict = None - ) -> list[dict]: - if keys is None: - keys = self.__get_table_columns(table_name) - self._search(table_name, keys, condition) - return [dict(zip(keys, row)) for row in self._cursor.fetchall()] - - def _table_exists(self, table_name: str) -> bool: - self._cursor.execute( - "SELECT name FROM sqlite_master WHERE type='table' AND name=?;", - (table_name,), - ) - return len(self._cursor.fetchall()) == 1 - - def __get_table_columns(self, table_name: str) -> list[str]: - self._cursor.execute(f"PRAGMA table_info({table_name})") - return [column_info[1] for column_info in self._cursor.fetchall()] - - @staticmethod - def __python_to_sqlite_type(value) -> str: - if isinstance(value, int): - return "INTEGER NOT NULL" - elif isinstance(value, float): - return "REAL NOT NULL" - elif isinstance(value, str): - return "TEXT NOT NULL" - elif isinstance(value, bool): - return "INTEGER NOT NULL" - elif isinstance(value, list): - return "TEXT NOT NULL" - elif value is None: - return "TEXT" - else: - raise ValueError(f"Unsupported data type: {type(value)}") - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self._conn.close() diff --git a/backend/src/module/database/orm/__init__.py b/backend/src/module/database/orm/__init__.py deleted file mode 100644 index 4b56580f..00000000 --- a/backend/src/module/database/orm/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .connector import Connector diff --git a/backend/src/module/database/orm/connector.py b/backend/src/module/database/orm/connector.py deleted file mode 100644 index 7f10106a..00000000 --- a/backend/src/module/database/orm/connector.py +++ /dev/null @@ -1,71 +0,0 @@ -from os import PathLike -from pathlib import Path -import sqlite3 - -from .delete import Delete -from .insert import Insert -from .select import Select -from .update import Update - -from module.conf import DATA_PATH - - -class Connector: - def __init__( - self, table_name: str, data: dict, database: PathLike[str] | Path = DATA_PATH - ): - # Create folder if not exists - if isinstance(database, (PathLike, str)): - database = Path(database) - database.parent.mkdir(parents=True, exist_ok=True) - - self._conn = sqlite3.connect(database) - self._cursor = self._conn.cursor() - self.update = Update(self, table_name, data) - self.insert = Insert(self, table_name, data) - self.select = Select(self, table_name, data) - self.delete = Delete(self, table_name, data) - self._columns = self.__get_columns(table_name) - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_value, traceback): - self._conn.close() - - def __get_columns(self, table_name: str) -> list[str]: - self._cursor.execute(f"PRAGMA table_info({table_name})") - return [x[1] for x in self._cursor.fetchall()] - - def execute(self, sql: str, params: tuple = None): - if params is None: - self._cursor.execute(sql) - else: - self._cursor.execute(sql, params) - self._conn.commit() - - def executemany(self, sql: str, params: list[tuple]): - self._cursor.executemany(sql, params) - self._conn.commit() - - def fetchall(self, keys: str = None) -> list[dict]: - datas = self._cursor.fetchall() - if keys: - return [dict(zip(keys, data)) for data in datas] - return [dict(zip(self._columns, data)) for data in datas] - - def fetchone(self, keys: list[str] = None) -> dict: - data = self._cursor.fetchone() - if data: - if keys: - return dict(zip(keys, data)) - return dict(zip(self._columns, data)) - - def fetchmany(self, keys: list[str], size: int) -> list[dict]: - datas = self._cursor.fetchmany(size) - if keys: - return [dict(zip(keys, data)) for data in datas] - return [dict(zip(self._columns, data)) for data in datas] - - def fetch(self): - return self._cursor.fetchall() diff --git a/backend/src/module/database/orm/delete.py b/backend/src/module/database/orm/delete.py deleted file mode 100644 index aa5a705b..00000000 --- a/backend/src/module/database/orm/delete.py +++ /dev/null @@ -1,23 +0,0 @@ -class Delete: - def __init__(self, connector, table_name: str, data: dict): - self._connector = connector - self._table_name = table_name - self._data = data - - def one(self, _id: int) -> bool: - self._connector.execute( - f""" - DELETE FROM {self._table_name} - WHERE id = :id - """, - {"id": _id}, - ) - return True - - def all(self): - self._connector.execute( - f""" - DELETE FROM {self._table_name} - """, - ) - return True diff --git a/backend/src/module/database/orm/insert.py b/backend/src/module/database/orm/insert.py deleted file mode 100644 index d78c6ca7..00000000 --- a/backend/src/module/database/orm/insert.py +++ /dev/null @@ -1,33 +0,0 @@ -class Insert: - def __init__(self, connector, table_name: str, data: dict): - self._connector = connector - self._table_name = table_name - self._columns = data.items() - - def __gen_id(self) -> int: - self._connector.execute( - f""" - SELECT MAX(id) FROM {self._table_name} - """, - ) - max_id = self._connector.fetchone(keys=["id"]).get("id") - if max_id is None: - return 1 - return max_id + 1 - - def one(self, data: dict): - _id = self.__gen_id() - data["id"] = _id - columns = ", ".join(data.keys()) - placeholders = ", ".join([f":{key}" for key in data.keys()]) - self._connector.execute( - f""" - INSERT INTO {self._table_name} ({columns}) - VALUES ({placeholders}) - """, - data, - ) - - def many(self, data: list[dict]): - for item in data: - self.one(item) diff --git a/backend/src/module/database/orm/select.py b/backend/src/module/database/orm/select.py deleted file mode 100644 index 198c2ff5..00000000 --- a/backend/src/module/database/orm/select.py +++ /dev/null @@ -1,96 +0,0 @@ -class Select: - def __init__(self, connector, table_name: str, data: dict): - self._connector = connector - self._table_name = table_name - self._data = data - - def id(self, _id: int): - self._connector.execute( - f""" - SELECT * FROM {self._table_name} - WHERE id = :id - """, - {"id": _id}, - ) - return self._connector.fetchone() - - def all(self, limit: int = None): - if limit is None: - limit = 10000 - self._connector.execute( - f""" - SELECT * FROM {self._table_name} LIMIT {limit} - """, - ) - return self._connector.fetchall() - - def one( - self, - keys: list[str] | None = None, - conditions: dict = None, - combine_operator: str = "AND", - ): - if keys is None: - columns = "*" - else: - columns = ", ".join(keys) - condition_sql = self.__select_condition(conditions, combine_operator) - self._connector.execute( - f""" - SELECT {columns} FROM {self._table_name} - WHERE {condition_sql} - """, - conditions, - ) - return self._connector.fetchone(keys) - - def many( - self, - keys: list[str] | None = None, - conditions: dict = None, - combine_operator: str = "AND", - limit: int = None, - ): - if keys is None: - columns = "*" - else: - columns = ", ".join(keys) - if limit is None: - limit = 10000 - condition_sql = self.__select_condition(conditions, combine_operator) - self._connector.execute( - f""" - SELECT {columns} FROM {self._table_name} - WHERE {condition_sql} - LIMIT {limit} - """, - conditions, - ) - return self._connector.fetchall(keys) - - def column(self, keys: list[str]): - columns = ", ".join(keys) - self._connector.execute( - f""" - SELECT {columns} FROM {self._table_name} - """, - ) - return self._connector.fetchall(keys) - - @staticmethod - def __select_condition(conditions: dict, combine_operator: str = "AND"): - if not conditions: - raise ValueError("No conditions provided.") - if combine_operator not in ["AND", "OR", "INSTR"]: - raise ValueError( - "Invalid combine_operator, must be 'AND' or 'OR' or 'INSTR'." - ) - if combine_operator == "INSTR": - condition_sql = f" AND ".join( - [f"INSTR({key}, :{key})" for key in conditions.keys()] - ) - else: - condition_sql = f" {combine_operator} ".join( - [f"{key} = :{key}" for key in conditions.keys()] - ) - return condition_sql diff --git a/backend/src/module/database/orm/update.py b/backend/src/module/database/orm/update.py deleted file mode 100644 index 7b022418..00000000 --- a/backend/src/module/database/orm/update.py +++ /dev/null @@ -1,98 +0,0 @@ -import logging - -logger = logging.getLogger(__name__) - - -class Update: - def __init__(self, connector, table_name: str, data: dict): - self._connector = connector - self._table_name = table_name - self._example_data = data - - def __table_exists(self) -> bool: - self._connector.execute( - f""" - SELECT name FROM sqlite_master - WHERE type='table' AND name='{self._table_name}' - """ - ) - return self._connector.fetch() is not None - - def table(self): - columns = ", ".join( - [ - f"{key} {self.__python_to_sqlite_type(value)}" - for key, value in self._example_data.items() - ] - ) - create_table_sql = f"CREATE TABLE IF NOT EXISTS {self._table_name} ({columns});" - self._connector.execute(create_table_sql) - logger.debug(f"Create table {self._table_name}.") - self._connector.execute(f"PRAGMA table_info({self._table_name})") - existing_columns = [x[1] for x in self._connector.fetch()] - for key, value in self._example_data.items(): - if key not in existing_columns: - insert_column = self.__python_to_sqlite_type(value) - if value is None: - value = "NULL" - add_column_sql = f"ALTER TABLE {self._table_name} ADD COLUMN {key} {insert_column} DEFAULT {value};" - self._connector.execute(add_column_sql) - logger.debug(f"Update table {self._table_name}.") - - def one(self, data: dict) -> bool: - _id = data["id"] - set_sql = ", ".join([f"{key} = :{key}" for key in data.keys()]) - self._connector.execute( - f""" - UPDATE {self._table_name} - SET {set_sql} - WHERE id = :id - """, - data, - ) - logger.debug(f"Update {_id} in {self._table_name}.") - return True - - def many(self, data: list[dict]) -> bool: - columns = ", ".join([f"{key} = :{key}" for key in data[0].keys()]) - self._connector.executemany( - f""" - UPDATE {self._table_name} - SET {columns} - WHERE id = :id - """, - data, - ) - logger.debug(f"Update {self._table_name}.") - return True - - def value(self, location: dict, set_value: dict) -> bool: - set_sql = ", ".join([f"{key} = :{key}" for key in set_value.keys()]) - params = {**location, **set_value} - self._connector.execute( - f""" - UPDATE {self._table_name} - SET {set_sql} - WHERE {location["key"]} = :{location["key"]} - """, - params, - ) - logger.debug(f"Update {self._table_name}.") - return True - - @staticmethod - def __python_to_sqlite_type(value) -> str: - if isinstance(value, int): - return "INTEGER NOT NULL" - elif isinstance(value, float): - return "REAL NOT NULL" - elif isinstance(value, str): - return "TEXT NOT NULL" - elif isinstance(value, bool): - return "INTEGER NOT NULL" - elif isinstance(value, list): - return "TEXT NOT NULL" - elif value is None: - return "TEXT" - else: - raise ValueError(f"Unsupported data type: {type(value)}") diff --git a/backend/src/test/test_database.py b/backend/src/test/test_database.py index b7088ef1..91f58f14 100644 --- a/backend/src/test/test_database.py +++ b/backend/src/test/test_database.py @@ -1,11 +1,11 @@ -from sqlmodel import create_engine, SQLModel, Session +from sqlmodel import create_engine, SQLModel from sqlmodel.pool import StaticPool from module.database import BangumiDatabase from module.models import Bangumi -def test_database(): +def test_bangumi_database(): # sqlite mock engine engine = create_engine( "sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool @@ -42,7 +42,7 @@ def test_database(): assert database.search_id(1) == test_data # search poster - assert database.match_poster("test") == "/test/test.jpg" + assert database.match_poster("test2 (2021)") == "/test/test.jpg" # delete database.delete_one(1)