diff --git a/backend/src/module/database/bangumi.py b/backend/src/module/database/bangumi.py index dd25edb5..9209e717 100644 --- a/backend/src/module/database/bangumi.py +++ b/backend/src/module/database/bangumi.py @@ -1,58 +1,59 @@ import logging from module.database.orm import Connector -from module.models import BangumiData +from module.models import Bangumi, BangumiUpdate from module.conf import DATA_PATH +from sqlmodel import Session, select, delete, SQLModel +from module.database.engine import engine +from typing import Optional +from sqlalchemy.exc import IntegrityError, NoResultFound logger = logging.getLogger(__name__) -class BangumiDatabase(Connector): - def __init__(self, database: str = DATA_PATH): - super().__init__( - table_name="bangumi", - data=self.__data_to_db(BangumiData()), - database=database, - ) - - def update_table(self): - self.update.table() +class BangumiDatabase(Session): + def __init__(self): + super().__init__(engine) + # table_name="bangumi", + # data=self.__data_to_db(BangumiData()), + # database=database, + # ) @staticmethod - def __data_to_db(data: BangumiData) -> dict: - db_data = data.dict() - for key, value in db_data.items(): - if isinstance(value, bool): - db_data[key] = int(value) - elif isinstance(value, list): - db_data[key] = ",".join(value) - return db_data + def update_table(): + SQLModel.metadata.create_all(engine) - @staticmethod - def __db_to_data(db_data: dict) -> BangumiData: - for key, item in db_data.items(): - if isinstance(item, int): - if key not in ["id", "offset", "season", "year"]: - db_data[key] = bool(item) - elif key in ["filter", "rss_link"]: - db_data[key] = item.split(",") - return BangumiData(**db_data) + # @staticmethod + # def __data_to_db(data: BangumiData) -> Bangumi: + # db_data = data.dict() + # for key, value in db_data.items(): + # if isinstance(value, bool): + # db_data[key] = int(value) + # elif isinstance(value, list): + # db_data[key] = ",".join(value) + # return db_data + # + # @staticmethod + # def __db_to_data(db_data: Bangumi) -> BangumiData: + # for key, item in db_data.items(): + # if isinstance(item, int): + # if key not in ["id", "offset", "season", "year"]: + # db_data[key] = bool(item) + # elif key in ["filter", "rss_link"]: + # db_data[key] = item.split(",") + # return BangumiData(**db_data) - def insert_one(self, data: BangumiData): - db_data = self.__data_to_db(data) - self.insert.one(db_data) + def insert_one(self, data: Bangumi): + self.add(data) + self.commit() + # db_data = self.__data_to_db(data) + # self.insert.one(db_data) logger.debug(f"[Database] Insert {data.official_title} into database.") - # if self.__check_exist(data): - # self.update_one(data) - # else: - # db_data = self.__data_to_db(data) - # db_data["id"] = self.gen_id() - # self._insert(db_data=db_data, table_name=self.__table_name) - # logger.debug(f"[Database] Insert {data.official_title} into database.") - def insert_list(self, data: list[BangumiData]): - data_list = [self.__data_to_db(x) for x in data] - self.insert.many(data_list) + def insert_list(self, data: list[Bangumi]): + self.add_all(data) + # data_list = [self.__data_to_db(x) for x in data] + # self.insert.many(data_list) # _id = self.gen_id() # for i, item in enumerate(data): # item.id = _id + i @@ -60,62 +61,92 @@ class BangumiDatabase(Connector): # self._insert_list(data_list=data_list, table_name=self.__table_name) logger.debug(f"[Database] Insert {len(data)} bangumi into database.") - def update_one(self, data: BangumiData) -> bool: - db_data = self.__data_to_db(data) - return self.update.one(db_data) + def update_one(self, data: BangumiUpdate) -> bool: + db_data = self.get(Bangumi, data.id) + if not db_data: + return False + bangumi_data = data.dict(exclude_unset=True) + for key, value in bangumi_data.items(): + setattr(db_data, key, value) + self.add(db_data) + self.commit() + self.refresh(db_data) + logger.debug(f"[Database] Update {data.official_title}") + return True - def update_list(self, data: list[BangumiData]): - data_list = [self.__data_to_db(x) for x in data] - self.update.many(data_list) + def update_list(self, datas: list[BangumiUpdate]): + for data in datas: + self.update_one(data) def update_rss(self, title_raw, rss_set: str): # Update rss and added - location = {"title_raw": title_raw} - set_value = {"rss_link": rss_set, "added": 0} - self.update.value(location, set_value) + statement = select(Bangumi).where(Bangumi.title_raw == title_raw) + bangumi = self.exec(statement).first() + bangumi.rss_link = rss_set + bangumi.added = 0 + self.add(bangumi) + self.commit() + self.refresh(bangumi) + # location = {"title_raw": title_raw} + # set_value = {"rss_link": rss_set, "added": 0} + # self.update.value(location, set_value) logger.debug(f"[Database] Update {title_raw} rss_link to {rss_set}.") def update_poster(self, title_raw, poster_link: str): - location = {"title_raw": title_raw} - set_value = {"poster_link": poster_link} - self.update.value(location, set_value) + statement = select(Bangumi).where(Bangumi.title_raw == title_raw) + bangumi = self.exec(statement).first() + bangumi.poster_link = poster_link + self.add(bangumi) + self.commit() + self.refresh(bangumi) logger.debug(f"[Database] Update {title_raw} poster_link to {poster_link}.") def delete_one(self, _id: int): - self.delete.one(_id) + statement = select(Bangumi).where(Bangumi.id == _id) + bangumi = self.exec(statement).first() + self.delete(bangumi) + self.commit() logger.debug(f"[Database] Delete bangumi id: {_id}.") def delete_all(self): - self.delete.all() + statement = delete(Bangumi) + self.exec(statement) + self.commit() - def search_all(self) -> list[BangumiData]: - all_data = self.select.all() - return [self.__db_to_data(x) for x in all_data] + def search_all(self) -> list[Bangumi]: + statement = select(Bangumi) + return self.exec(statement).all() - def search_id(self, _id: int) -> BangumiData | None: - dict_data = self.select.one(conditions={"id": _id}) - if dict_data is None: + def search_id(self, _id: int) -> Optional[Bangumi]: + statement = select(Bangumi).where(Bangumi.id == _id) + bangumi = self.exec(statement).first() + if bangumi is None: logger.warning(f"[Database] Cannot find bangumi id: {_id}.") return None - logger.debug(f"[Database] Find bangumi id: {_id}.") - return self.__db_to_data(dict_data) + else: + logger.debug(f"[Database] Find bangumi id: {_id}.") + return self.exec(statement).first() def match_poster(self, bangumi_name: str) -> str: - condition = {"official_title": bangumi_name} - keys = ["poster_link"] - data = self.select.one( - keys=keys, - conditions=condition, - combine_operator="INSTR", - ) - if not data: + # condition = {"official_title": bangumi_name} + statement = select(Bangumi).where(Bangumi.official_title == bangumi_name) + # keys = ["poster_link"] + # data = self.select.one( + # keys=keys, + # conditions=condition, + # combine_operator="INSTR", + # ) + data = self.exec(statement).first() + if data: + return data.poster_link + else: return "" - return data.get("poster_link") def match_list(self, torrent_list: list, rss_link: str) -> list: # Match title_raw in database - keys = ["title_raw", "rss_link", "poster_link"] - match_datas = self.select.column(keys) + # keys = ["title_raw", "rss_link", "poster_link"] + # match_datas = self.select.column(keys) + match_datas = self.search_all() if not match_datas: return torrent_list # Match title @@ -123,15 +154,15 @@ class BangumiDatabase(Connector): while i < len(torrent_list): torrent = torrent_list[i] for match_data in match_datas: - if match_data.get("title_raw") in torrent.name: - if rss_link not in match_data.get("rss_link"): - match_data["rss_link"] += f",{rss_link}" + if match_data.title_raw in torrent.name: + if rss_link not in match_data.rss_link: + match_data.rss_link += f",{rss_link}" self.update_rss( - match_data.get("title_raw"), match_data.get("rss_link") + match_data.title_raw, match_data.rss_link ) - if not match_data.get("poster_link"): + if not match_data.poster_link: self.update_poster( - match_data.get("title_raw"), torrent.poster_link + match_data.title_raw, torrent.poster_link ) torrent_list.pop(i) break @@ -139,20 +170,23 @@ class BangumiDatabase(Connector): i += 1 return torrent_list - def not_complete(self) -> list[BangumiData]: + def not_complete(self) -> list[Bangumi]: # Find eps_complete = False - condition = {"eps_collect": 0} - dict_data = self.select.many( - conditions=condition, - ) - return [self.__db_to_data(x) for x in dict_data] + condition = select(Bangumi).where(Bangumi.eps_collect == 0) + datas = self.exec(condition).all() + return datas - def not_added(self) -> list[BangumiData]: - conditions = {"added": 0, "rule_name": None, "save_path": None} - dict_data = self.select.many(conditions=conditions, combine_operator="OR") - return [self.__db_to_data(x) for x in dict_data] + def not_added(self) -> list[Bangumi]: + conditions = select(Bangumi).where( + Bangumi.added == 0 or + Bangumi.rule_name is None or + Bangumi.save_path is None + ) + datas = self.exec(conditions).all() + # dict_data = self.select.many(conditions=conditions, combine_operator="OR") + return datas if __name__ == "__main__": with BangumiDatabase() as db: - print(db.match_poster("久保")) + print(db.not_complete()) diff --git a/backend/src/module/database/engine.py b/backend/src/module/database/engine.py new file mode 100644 index 00000000..e4bd2f5d --- /dev/null +++ b/backend/src/module/database/engine.py @@ -0,0 +1,6 @@ +from sqlmodel import create_engine, Session + + +engine = create_engine("sqlite:///data/data.db") + +db_session = Session(engine) \ No newline at end of file diff --git a/backend/src/module/models/bangumi.py b/backend/src/module/models/bangumi.py index b8af3670..02a74ce9 100644 --- a/backend/src/module/models/bangumi.py +++ b/backend/src/module/models/bangumi.py @@ -1,27 +1,46 @@ from dataclasses import dataclass -from pydantic import BaseModel, Field +from pydantic import BaseModel +from sqlmodel import SQLModel, Field +from typing import Optional -class BangumiData(BaseModel): - id: int = Field(0, alias="id", title="番剧ID") - official_title: str = Field("official_title", alias="official_title", title="番剧中文名") - year: str | None = Field(None, alias="year", title="番剧年份") - title_raw: str = Field("title_raw", alias="title_raw", title="番剧原名") - season: int = Field(1, alias="season", title="番剧季度") - season_raw: str | None = Field(None, alias="season_raw", title="番剧季度原名") - group_name: str | None = Field(None, alias="group_name", title="字幕组") - dpi: str | None = Field(None, alias="dpi", title="分辨率") - source: str | None = Field(None, alias="source", title="来源") - subtitle: str | None = Field(None, alias="subtitle", title="字幕") - eps_collect: bool = Field(False, alias="eps_collect", title="是否已收集") - offset: int = Field(0, alias="offset", title="番剧偏移量") - filter: list[str] = Field(["720", "\\d+-\\d+"], alias="filter", title="番剧过滤器") - rss_link: list[str] = Field([], alias="rss_link", title="番剧RSS链接") - poster_link: str | None = Field(None, alias="poster_link", title="番剧海报链接") - added: bool = Field(False, alias="added", title="是否已添加") - rule_name: str | None = Field(None, alias="rule_name", title="番剧规则名") - save_path: str | None = Field(None, alias="save_path", title="番剧保存路径") +class Bangumi(SQLModel, table=True): + id: int = Field(default=None, primary_key=True) + official_title: str = Field(default="official_title", alias="official_title", title="番剧中文名") + year: Optional[str] = Field(alias="year", title="番剧年份") + title_raw: str = Field(default="title_raw", alias="title_raw", title="番剧原名") + season: int = Field(default=1, alias="season", title="番剧季度") + season_raw: Optional[str] = Field(alias="season_raw", title="番剧季度原名") + group_name: Optional[str] = Field(alias="group_name", title="字幕组") + dpi: Optional[str] = Field(alias="dpi", title="分辨率") + source: Optional[str] = Field(alias="source", title="来源") + subtitle: Optional[str] = Field(alias="subtitle", title="字幕") + eps_collect: bool = Field(default=False, alias="eps_collect", title="是否已收集") + offset: int = Field(default=0, alias="offset", title="番剧偏移量") + filter: str = Field(default="720, \\d+-\\d+", alias="filter", title="番剧过滤器") + rss_link: str = Field(default="", alias="rss_link", title="番剧RSS链接") + poster_link: Optional[str] = Field(alias="poster_link", title="番剧海报链接") + added: bool = Field(default=False, alias="added", title="是否已添加") + rule_name: Optional[str] = Field(alias="rule_name", title="番剧规则名") + save_path: Optional[str] = Field(alias="save_path", title="番剧保存路径") + deleted: bool = Field(False, alias="deleted", title="是否已删除") + + +class BangumiUpdate(SQLModel): + official_title: str = Field(default="official_title", alias="official_title", title="番剧中文名") + year: Optional[str] = Field(alias="year", title="番剧年份") + season: int = Field(default=1, alias="season", title="番剧季度") + season_raw: Optional[str] = Field(alias="season_raw", title="番剧季度原名") + group_name: Optional[str] = Field(alias="group_name", title="字幕组") + dpi: Optional[str] = Field(alias="dpi", title="分辨率") + source: Optional[str] = Field(alias="source", title="来源") + subtitle: Optional[str] = Field(alias="subtitle", title="字幕") + eps_collect: bool = Field(default=False, alias="eps_collect", title="是否已收集") + offset: int = Field(default=0, alias="offset", title="番剧偏移量") + filter: str = Field(default="720, \\d+-\\d+", alias="filter", title="番剧过滤器") + rss_link: str = Field(default="", alias="rss_link", title="番剧RSS链接") + added: bool = Field(default=False, alias="added", title="是否已添加") deleted: bool = Field(False, alias="deleted", title="是否已删除") @@ -29,14 +48,14 @@ class Notification(BaseModel): official_title: str = Field(..., alias="official_title", title="番剧名") season: int = Field(..., alias="season", title="番剧季度") episode: int = Field(..., alias="episode", title="番剧集数") - poster_path: str | None = Field(None, alias="poster_path", title="番剧海报路径") + poster_path: Optional[str] = Field(None, alias="poster_path", title="番剧海报路径") @dataclass class Episode: - title_en: str | None - title_zh: str | None - title_jp: str | None + title_en: Optional[str] + title_zh: Optional[str] + title_jp: Optional[str] season: int season_raw: str episode: int