From 472a5093e956623f12d5ce1ed7600fb7fea609e4 Mon Sep 17 00:00:00 2001 From: estrella Date: Thu, 3 Aug 2023 14:16:52 +0800 Subject: [PATCH] refactor: database and rss engine --- backend/src/module/database/bangumi.py | 76 +++++++++--------- backend/src/module/database/combine.py | 18 +++++ backend/src/module/database/rss.py | 59 ++++++-------- backend/src/module/database/torrent.py | 77 ++++++++++--------- backend/src/module/models/__init__.py | 6 +- backend/src/module/models/bangumi.py | 4 + backend/src/module/models/config.py | 2 - backend/src/module/models/rss.py | 22 +++--- backend/src/module/models/torrent.py | 21 ++--- .../src/module/network/request_contents.py | 5 +- backend/src/module/rss/engine.py | 35 ++++++--- backend/src/test/test_database.py | 18 ++--- 12 files changed, 189 insertions(+), 154 deletions(-) create mode 100644 backend/src/module/database/combine.py diff --git a/backend/src/module/database/bangumi.py b/backend/src/module/database/bangumi.py index 9ab0b092..64de1994 100644 --- a/backend/src/module/database/bangumi.py +++ b/backend/src/module/database/bangumi.py @@ -4,35 +4,34 @@ from sqlmodel import Session, select, delete, or_ from sqlalchemy.sql import func from typing import Optional -from .engine import engine from module.models import Bangumi logger = logging.getLogger(__name__) -class BangumiDatabase(Session): - def __init__(self, _engine=engine): - super().__init__(_engine) +class BangumiDatabase: + def __init__(self, session: Session): + self.session = session def insert_one(self, data: Bangumi): - self.add(data) - self.commit() + self.session.add(data) + self.session.commit() logger.debug(f"[Database] Insert {data.official_title} into database.") def insert_list(self, data: list[Bangumi]): - self.add_all(data) + self.session.add_all(data) logger.debug(f"[Database] Insert {len(data)} bangumi into database.") def update_one(self, data: Bangumi) -> bool: - db_data = self.get(Bangumi, data.id) + db_data = self.session.get(Bangumi, data.id) if not db_data: return False bangumi_data = data.dict(exclude_unset=True) for key, value in bangumi_data.items(): setattr(db_data, key, value) - self.add(db_data) - self.commit() - self.refresh(db_data) + self.session.add(db_data) + self.session.commit() + self.session.refresh(db_data) logger.debug(f"[Database] Update {data.official_title}") return True @@ -43,53 +42,53 @@ class BangumiDatabase(Session): def update_rss(self, title_raw, rss_set: str): # Update rss and added statement = select(Bangumi).where(Bangumi.title_raw == title_raw) - bangumi = self.exec(statement).first() + bangumi = self.session.exec(statement).first() bangumi.rss_link = rss_set bangumi.added = False - self.add(bangumi) - self.commit() - self.refresh(bangumi) + self.session.add(bangumi) + self.session.commit() + self.session.refresh(bangumi) logger.debug(f"[Database] Update {title_raw} rss_link to {rss_set}.") def update_poster(self, title_raw, poster_link: str): statement = select(Bangumi).where(Bangumi.title_raw == title_raw) - bangumi = self.exec(statement).first() + bangumi = self.session.exec(statement).first() bangumi.poster_link = poster_link - self.add(bangumi) - self.commit() - self.refresh(bangumi) + self.session.add(bangumi) + self.session.commit() + self.session.refresh(bangumi) logger.debug(f"[Database] Update {title_raw} poster_link to {poster_link}.") def delete_one(self, _id: int): statement = select(Bangumi).where(Bangumi.id == _id) - bangumi = self.exec(statement).first() - self.delete(bangumi) - self.commit() + bangumi = self.session.exec(statement).first() + self.session.delete(bangumi) + self.session.commit() logger.debug(f"[Database] Delete bangumi id: {_id}.") def delete_all(self): statement = delete(Bangumi) - self.exec(statement) - self.commit() + self.session.exec(statement) + self.session.commit() def search_all(self) -> list[Bangumi]: statement = select(Bangumi) - return self.exec(statement).all() + return self.session.exec(statement).all() def search_id(self, _id: int) -> Optional[Bangumi]: statement = select(Bangumi).where(Bangumi.id == _id) - bangumi = self.exec(statement).first() + bangumi = self.session.exec(statement).first() if bangumi is None: logger.warning(f"[Database] Cannot find bangumi id: {_id}.") return None else: logger.debug(f"[Database] Find bangumi id: {_id}.") - return self.exec(statement).first() + return self.session.exec(statement).first() def match_poster(self, bangumi_name: str) -> str: # Use like to match statement = select(Bangumi).where(func.instr(bangumi_name, Bangumi.title_raw) > 0) - data = self.exec(statement).first() + data = self.session.exec(statement).first() if data: return data.poster_link else: @@ -119,7 +118,7 @@ class BangumiDatabase(Session): def not_complete(self) -> list[Bangumi]: # Find eps_complete = False condition = select(Bangumi).where(Bangumi.eps_collect == 0) - datas = self.exec(condition).all() + datas = self.session.exec(condition).all() return datas def not_added(self) -> list[Bangumi]: @@ -128,19 +127,20 @@ class BangumiDatabase(Session): Bangumi.added == 0, Bangumi.rule_name is None, Bangumi.save_path is None ) ) - datas = self.exec(conditions).all() + datas = self.session.exec(conditions).all() return datas def disable_rule(self, _id: int): statement = select(Bangumi).where(Bangumi.id == _id) - bangumi = self.exec(statement).first() + bangumi = self.session.exec(statement).first() bangumi.deleted = True - self.add(bangumi) - self.commit() - self.refresh(bangumi) + self.session.add(bangumi) + self.session.commit() + self.session.refresh(bangumi) logger.debug(f"[Database] Disable rule {bangumi.title_raw}.") - -if __name__ == "__main__": - with BangumiDatabase() as db: - print(db.not_complete()) + def search_rss(self, rss_link: str) -> list[Bangumi]: + statement = select(Bangumi).where( + func.instr(rss_link, Bangumi.rss_link) > 0 + ) + return self.session.exec(statement).all() \ No newline at end of file diff --git a/backend/src/module/database/combine.py b/backend/src/module/database/combine.py new file mode 100644 index 00000000..20c2e32e --- /dev/null +++ b/backend/src/module/database/combine.py @@ -0,0 +1,18 @@ +from sqlmodel import Session,SQLModel + +from .engine import engine +from .rss import RSSDatabase +from .torrent import TorrentDatabase +from .bangumi import BangumiDatabase + + +class Database(Session): + def __init__(self, _engine=engine): + super().__init__(_engine) + self.rss = RSSDatabase(self) + self.torrent = TorrentDatabase(self) + self.bangumi = BangumiDatabase(self) + + @staticmethod + def create_table(): + SQLModel.metadata.create_all(engine) \ No newline at end of file diff --git a/backend/src/module/database/rss.py b/backend/src/module/database/rss.py index 8c86ef66..72b08a0f 100644 --- a/backend/src/module/database/rss.py +++ b/backend/src/module/database/rss.py @@ -1,44 +1,33 @@ -from .orm import Connector +import logging +from sqlmodel import Session, select, delete + +from .engine import engine from module.models import RSSItem -from module.conf import DATA_PATH + +logger = logging.getLogger(__name__) -class RSSDatabase(Connector): - def __init__(self, database: str = DATA_PATH): - super().__init__( - table_name="RSSItem", - data=RSSItem().dict(), - database=database - ) - - @staticmethod - def __data_to_db(data: RSSItem) -> dict: - db_data = data.dict() - for key, value in db_data.items(): - if isinstance(value, bool): - db_data[key] = int(value) - elif isinstance(value, list): - db_data[key] = ",".join(value) - return db_data - - @staticmethod - def __db_to_data(db_data: dict) -> RSSItem: - for key, item in db_data.items(): - if isinstance(item, int): - db_data[key] = bool(item) - return RSSItem(**db_data) - - def update_table(self): - self.update.table() +class RSSDatabase: + def __init__(self, session: Session): + self.session = session def insert_one(self, data: RSSItem): - dict_data = self.__data_to_db(data) - self.insert.one(data=dict_data) + self.session.add(data) + self.session.commit() + self.session.refresh(data) - def get_all(self) -> list[RSSItem]: - dict_datas = self.select.all() - return [self.__db_to_data(x) for x in dict_datas] + def search_all(self) -> list[RSSItem]: + return self.session.exec(select(RSSItem)).all() def delete_one(self, _id: int): - self.delete.one(_id) + condition = delete(RSSItem).where(RSSItem.id == _id) + self.session.exec(condition) + self.session.commit() + + def delete_all(self): + condition = delete(RSSItem) + self.session.exec(condition) + self.session.commit() + + diff --git a/backend/src/module/database/torrent.py b/backend/src/module/database/torrent.py index dd16823a..e6fc2d02 100644 --- a/backend/src/module/database/torrent.py +++ b/backend/src/module/database/torrent.py @@ -1,51 +1,54 @@ import logging -from module.database.orm import Connector -from module.models import TorrentData -from module.conf import DATA_PATH +from sqlmodel import Session, select + +from module.models import Torrent logger = logging.getLogger(__name__) -class TorrentDatabase(Connector): - def __init__(self, database: str = DATA_PATH): - super().__init__( - table_name="torrent", data=TorrentData().dict(), database=database - ) +class TorrentDatabase: + def __init__(self, session: Session): + self.session = session - def update_table(self): - self.update.table() + def insert_one(self, data: Torrent): + self.session.add(data) + self.session.commit() + self.session.refresh(data) + logger.debug(f"Insert {data.name} in database.") - def __data_to_db(self, data: TorrentData) -> dict: - db_data = data.dict() - for key, value in db_data.items(): - if isinstance(value, bool): - db_data[key] = int(value) - elif isinstance(value, list): - db_data[key] = ",".join(value) - return db_data + def insert_many(self, datas: list[Torrent]): + self.session.add_all(datas) + self.session.commit() + logger.debug(f"Insert {len(datas)} torrents in database.") - def __db_to_data(self, db_data: dict) -> TorrentData: - for key, item in db_data.items(): - if isinstance(item, int): - db_data[key] = bool(item) - elif key in ["filter", "rss_link"]: - db_data[key] = item.split(",") - return TorrentData(**db_data) + def update_one_sys(self, data: Torrent): + self.session.add(data) + self.session.commit() + self.session.refresh(data) + logger.debug(f"Update {data.name} in database.") - def insert_many(self, data_list: list[TorrentData]): - dict_datas = [self.__data_to_db(data) for data in data_list] - self.insert.many(dict_datas) + def update_many_sys(self, datas: list[Torrent]): + self.session.add_all(datas) + self.session.commit() - def get_all(self) -> list[TorrentData]: - dict_datas = self.select.all() - return [self.__db_to_data(data) for data in dict_datas] + def update_one_user(self, data: Torrent): + self.session.add(data) + self.session.commit() + self.session.refresh(data) + logger.debug(f"Update {data.name} in database.") - def get_torrent_name(self) -> list[str]: - dict_data = self.select.all() - return [data["name"] for data in dict_data] + def search_one(self, _id: int) -> Torrent: + return self.session.exec(select(Torrent).where(Torrent.id == _id)).first() + def search_all(self) -> list[Torrent]: + return self.session.exec(select(Torrent)).all() -if __name__ == "__main__": - with TorrentDatabase() as db: - db.update_table() + def check_new(self, torrents_list: list[Torrent]) -> list[Torrent]: + new_torrents = [] + for torrent in torrents_list: + statement = select(Torrent).where(Torrent.name == torrent.name) + db_torrent = self.session.exec(statement).first() + if not db_torrent: + new_torrents.append(torrent) + return new_torrents diff --git a/backend/src/module/models/__init__.py b/backend/src/module/models/__init__.py index 0b10feaf..587d548f 100644 --- a/backend/src/module/models/__init__.py +++ b/backend/src/module/models/__init__.py @@ -1,5 +1,5 @@ -from .bangumi import Bangumi, Episode +from .bangumi import Bangumi, Episode, BangumiUpdate from .config import Config -from .rss import RSSTorrents -from .torrent import EpisodeFile, SubtitleFile, TorrentBase +from .rss import RSSItem, RSSUpdate +from .torrent import EpisodeFile, SubtitleFile, Torrent, TorrentUpdate from .user import UserLogin diff --git a/backend/src/module/models/bangumi.py b/backend/src/module/models/bangumi.py index b7484887..2613ff38 100644 --- a/backend/src/module/models/bangumi.py +++ b/backend/src/module/models/bangumi.py @@ -34,6 +34,7 @@ class BangumiUpdate(SQLModel): default="official_title", alias="official_title", title="番剧中文名" ) year: Optional[str] = Field(alias="year", title="番剧年份") + title_raw: str = Field(default="title_raw", alias="title_raw", title="番剧原名") season: int = Field(default=1, alias="season", title="番剧季度") season_raw: Optional[str] = Field(alias="season_raw", title="番剧季度原名") group_name: Optional[str] = Field(alias="group_name", title="字幕组") @@ -44,7 +45,10 @@ class BangumiUpdate(SQLModel): offset: int = Field(default=0, alias="offset", title="番剧偏移量") filter: str = Field(default="720, \\d+-\\d+", alias="filter", title="番剧过滤器") rss_link: str = Field(default="", alias="rss_link", title="番剧RSS链接") + poster_link: Optional[str] = Field(alias="poster_link", title="番剧海报链接") added: bool = Field(default=False, alias="added", title="是否已添加") + rule_name: Optional[str] = Field(alias="rule_name", title="番剧规则名") + save_path: Optional[str] = Field(alias="save_path", title="番剧保存路径") deleted: bool = Field(False, alias="deleted", title="是否已删除") diff --git a/backend/src/module/models/config.py b/backend/src/module/models/config.py index 76ec036f..a1169ad2 100644 --- a/backend/src/module/models/config.py +++ b/backend/src/module/models/config.py @@ -1,8 +1,6 @@ from os.path import expandvars from pydantic import BaseModel, Field -# Sub config - class Program(BaseModel): rss_time: int = Field(7200, description="Sleep time") diff --git a/backend/src/module/models/rss.py b/backend/src/module/models/rss.py index f9fe9aed..aa98f028 100644 --- a/backend/src/module/models/rss.py +++ b/backend/src/module/models/rss.py @@ -1,17 +1,21 @@ -from pydantic import BaseModel, Field +from sqlmodel import SQLModel, Field +from typing import Optional -class RSSItem(BaseModel): - id: int = Field(0, alias="id", title="id") +class RSSItem(SQLModel, table=True): + id: int = Field(default=None, primary_key=True, alias="id") item_path: str = Field("example path", alias="item_path") url: str = Field("https://mikanani.me", alias="url") combine: bool = Field(True, alias="combine") enabled: bool = Field(True, alias="enabled") -class TorrentData(BaseModel): - id: int = Field(0, alias="id") - rss_id: int = Field(0, alias="rss_id") - name: str = Field("", alias="name") - url: str = Field("https://example.com/torrent", alias="url") - save_path: str = Field("path/to/save", alias="save_path") +class RSSUpdate(SQLModel): + item_path: Optional[str] = Field("example path", alias="item_path") + url: Optional[str] = Field("https://mikanani.me", alias="url") + combine: Optional[bool] = Field(True, alias="combine") + enabled: Optional[bool] = Field(True, alias="enabled") + + + + diff --git a/backend/src/module/models/torrent.py b/backend/src/module/models/torrent.py index 892d66d6..e44f6104 100644 --- a/backend/src/module/models/torrent.py +++ b/backend/src/module/models/torrent.py @@ -1,16 +1,19 @@ -from pydantic import BaseModel, Field +from pydantic import BaseModel +from sqlmodel import SQLModel, Field +from typing import Optional -class TorrentBase(BaseModel): - name: str = Field(...) - torrent_link: str = Field(...) - homepage: str | None = Field(None) +class Torrent(SQLModel, table=True): + id: int = Field(default=None, primary_key=True, alias="id") + refer_id: Optional[int] = Field(None, alias="refer_id") + name: str = Field("", alias="name") + url: str = Field("https://example.com/torrent", alias="url") + homepage: Optional[str] = Field(None, alias="homepage") + downloaded: bool = Field(False, alias="downloaded") -class FileSet(BaseModel): - media_path: str = Field(...) - sc_subtitle: str | None = Field(None) - tc_subtitle: str | None = Field(None) +class TorrentUpdate(SQLModel): + downloaded: bool = Field(False, alias="downloaded") class EpisodeFile(BaseModel): diff --git a/backend/src/module/network/request_contents.py b/backend/src/module/network/request_contents.py index 7f00a1d1..c42018f0 100644 --- a/backend/src/module/network/request_contents.py +++ b/backend/src/module/network/request_contents.py @@ -37,7 +37,6 @@ class TorrentInfo: class RequestContent(RequestURL): - # Mikanani RSS def get_torrents( self, _url: str, @@ -103,4 +102,6 @@ class RequestContent(RequestURL): if __name__ == '__main__': with RequestContent() as req: - req.get_xml("https://mikanani.me/RSS/Classic") + ts = req.get_torrents("https://mikanani.me/RSS/Classic") + for t in ts: + print(t) diff --git a/backend/src/module/rss/engine.py b/backend/src/module/rss/engine.py index 5fb7ff97..da68c701 100644 --- a/backend/src/module/rss/engine.py +++ b/backend/src/module/rss/engine.py @@ -1,20 +1,35 @@ import re from module.database import RSSDatabase, BangumiDatabase, TorrentDatabase -from module.models import BangumiData, RSSItem, TorrentData +from module.models import Bangumi, RSSItem, Torrent from module.network import RequestContent, TorrentInfo +from module.database.combine import Database -class RSSEngine(RequestContent): - @staticmethod - def _get_rss_items() -> list[RSSItem]: - with RSSDatabase() as db: - return db.get_all() - @staticmethod - def _get_bangumi_data(rss_link: str) -> list[BangumiData]: - with BangumiDatabase() as db: - return db.get_rss(rss_link) +class RSSEngine(Database): + def _get_rss_items(self) -> list[RSSItem]: + return self.rss.search_all() + + def _get_bangumi_data(self, rss_link: str) -> list[Bangumi]: + return self.bangumi.search_rss(rss_link) + + def get_torrent(self, rss_link: str) -> list[Torrent]: + with RequestContent() as req: + torrent_infos = req.get_torrents(rss_link) + torrents: list[Torrent] = [] + for torrent_info in torrent_infos: + torrents.append( + Torrent( + name=torrent_info.name, + url=torrent_info.torrent_link, + homepage=torrent_info.homepage, + ) + ) + return torrents + + def check_new_torrents(self, torrents_list: list[list[Torrent]]) -> list[Torrent]: + return self.torrent.check_new(torrents_list) def add_rss(self, rss_link: str, name: str, combine: bool): if not name: diff --git a/backend/src/test/test_database.py b/backend/src/test/test_database.py index 91f58f14..4f3e983b 100644 --- a/backend/src/test/test_database.py +++ b/backend/src/test/test_database.py @@ -1,7 +1,7 @@ from sqlmodel import create_engine, SQLModel from sqlmodel.pool import StaticPool -from module.database import BangumiDatabase +from module.database.combine import Database from module.models import Bangumi @@ -31,19 +31,19 @@ def test_bangumi_database(): save_path=None, deleted=False, ) - with BangumiDatabase(engine) as database: + with Database(engine) as db: # insert - database.insert_one(test_data) - assert database.search_id(1) == test_data + db.bangumi.insert_one(test_data) + assert db.bangumi.search_id(1) == test_data # update test_data.official_title = "test2" - database.update_one(test_data) - assert database.search_id(1) == test_data + db.bangumi.update_one(test_data) + assert db.bangumi.search_id(1) == test_data # search poster - assert database.match_poster("test2 (2021)") == "/test/test.jpg" + assert db.bangumi.match_poster("test2 (2021)") == "/test/test.jpg" # delete - database.delete_one(1) - assert database.search_id(1) is None + db.bangumi.delete_one(1) + assert db.bangumi.search_id(1) is None