diff --git a/backend/src/module/database/__init__.py b/backend/src/module/database/__init__.py index cd90685d..87d8425a 100644 --- a/backend/src/module/database/__init__.py +++ b/backend/src/module/database/__init__.py @@ -1,3 +1,2 @@ -from .bangumi import BangumiDatabase -from .rss import RSSDatabase -from .torrent import TorrentDatabase \ No newline at end of file +from .combine import Database +from .engine import engine diff --git a/backend/src/module/database/bangumi.py b/backend/src/module/database/bangumi.py index a432dc94..29273102 100644 --- a/backend/src/module/database/bangumi.py +++ b/backend/src/module/database/bangumi.py @@ -13,16 +13,16 @@ class BangumiDatabase: def __init__(self, session: Session): self.session = session - def insert_one(self, data: Bangumi): + def add(self, data: Bangumi): self.session.add(data) self.session.commit() logger.debug(f"[Database] Insert {data.official_title} into database.") - def insert_list(self, data: list[Bangumi]): + def add_all(self, data: list[Bangumi]): self.session.add_all(data) logger.debug(f"[Database] Insert {len(data)} bangumi into database.") - def update_one(self, data: Bangumi) -> bool: + def update(self, data: Bangumi) -> bool: db_data = self.session.get(Bangumi, data.id) if not db_data: return False @@ -35,9 +35,10 @@ class BangumiDatabase: logger.debug(f"[Database] Update {data.official_title}") return True - def update_list(self, datas: list[Bangumi]): - for data in datas: - self.update_one(data) + def update_all(self, datas: list[Bangumi]): + self.session.add_all(datas) + self.session.commit() + logger.debug(f"[Database] Update {len(datas)} bangumi.") def update_rss(self, title_raw, rss_set: str): # Update rss and added @@ -119,13 +120,16 @@ class BangumiDatabase: def match_torrent(self, torrent_name: str) -> Optional[Bangumi]: statement = select(Bangumi).where( - and_(func.instr(torrent_name, Bangumi.title_raw) > 0, not Bangumi.deleted) + and_( + func.instr(torrent_name, Bangumi.title_raw) > 0, + Bangumi.deleted == False, + ) ) return self.session.exec(statement).first() def not_complete(self) -> list[Bangumi]: # Find eps_complete = False - condition = select(Bangumi).where(Bangumi.eps_collect == 0) + condition = select(Bangumi).where(Bangumi.eps_collect == False) datas = self.session.exec(condition).all() return datas diff --git a/backend/src/module/database/combine.py b/backend/src/module/database/combine.py index 9066f378..e6dce3a7 100644 --- a/backend/src/module/database/combine.py +++ b/backend/src/module/database/combine.py @@ -1,18 +1,17 @@ from sqlmodel import Session, SQLModel -from .engine import engine from .rss import RSSDatabase from .torrent import TorrentDatabase from .bangumi import BangumiDatabase class Database(Session): - def __init__(self, _engine=engine): - super().__init__(_engine) + def __init__(self, engine): + self.engine = engine + super().__init__(engine) self.rss = RSSDatabase(self) self.torrent = TorrentDatabase(self) self.bangumi = BangumiDatabase(self) - @staticmethod - def create_table(): - SQLModel.metadata.create_all(engine) + def create_table(self): + SQLModel.metadata.create_all(self.engine) diff --git a/backend/src/module/database/torrent.py b/backend/src/module/database/torrent.py index 2c32dfca..3ccafb14 100644 --- a/backend/src/module/database/torrent.py +++ b/backend/src/module/database/torrent.py @@ -22,13 +22,13 @@ class TorrentDatabase: self.session.commit() logger.debug(f"Insert {len(datas)} torrents in database.") - def update_one_sys(self, data: Torrent): + def update(self, data: Torrent): self.session.add(data) self.session.commit() self.session.refresh(data) logger.debug(f"Update {data.name} in database.") - def update_many_sys(self, datas: list[Torrent]): + def update_all(self, datas: list[Torrent]): self.session.add_all(datas) self.session.commit() @@ -38,7 +38,7 @@ class TorrentDatabase: self.session.refresh(data) logger.debug(f"Update {data.name} in database.") - def search_one(self, _id: int) -> Torrent: + def search(self, _id: int) -> Torrent: return self.session.exec(select(Torrent).where(Torrent.id == _id)).first() def search_all(self) -> list[Torrent]: diff --git a/backend/src/module/rss/analyser.py b/backend/src/module/rss/analyser.py index 00bde7a6..9746547d 100644 --- a/backend/src/module/rss/analyser.py +++ b/backend/src/module/rss/analyser.py @@ -2,8 +2,8 @@ import logging import re from module.conf import settings -from module.database import BangumiDatabase from module.models import Bangumi +from module.database import Database from module.network import RequestContent, TorrentInfo from module.parser import TitleParser @@ -13,8 +13,6 @@ logger = logging.getLogger(__name__) class RSSAnalyser: def __init__(self): self._title_analyser = TitleParser() - with BangumiDatabase() as db: - db.update_table() def official_title_parser(self, data: Bangumi, mikan_title: str): if settings.rss_parser.parser_type == "mikan": @@ -78,10 +76,10 @@ class RSSAnalyser: return data def rss_to_data( - self, rss_link: str, database: BangumiDatabase, full_parse: bool = True + self, rss_link: str, database: Database, full_parse: bool = True ) -> list[Bangumi]: rss_torrents = self.get_rss_torrents(rss_link, full_parse) - torrents_to_add = database.match_list(rss_torrents, rss_link) + torrents_to_add = database.bangumi.match_list(rss_torrents, rss_link) if not torrents_to_add: logger.debug("[RSS] No new title has been found.") return [] diff --git a/backend/src/module/rss/engine.py b/backend/src/module/rss/engine.py index 1c775ce7..5b608798 100644 --- a/backend/src/module/rss/engine.py +++ b/backend/src/module/rss/engine.py @@ -2,13 +2,18 @@ import re import logging from module.models import Bangumi, RSSItem, Torrent -from module.network import RequestContent, TorrentInfo +from module.network import RequestContent from module.downloader import DownloadClient -from module.database.combine import Database +from module.database import Database, engine + +logger = logging.getLogger(__name__) class RSSEngine(Database): + def __init__(self, _engine=engine): + super().__init__(_engine) + @staticmethod def _get_torrents(rss_link: str) -> list[Torrent]: with RequestContent() as req: diff --git a/backend/src/test/test_database.py b/backend/src/test/test_database.py index 4f3e983b..cfe9289e 100644 --- a/backend/src/test/test_database.py +++ b/backend/src/test/test_database.py @@ -2,15 +2,16 @@ from sqlmodel import create_engine, SQLModel from sqlmodel.pool import StaticPool from module.database.combine import Database -from module.models import Bangumi +from module.models import Bangumi, Torrent, RSSItem + + +# sqlite mock engine +engine = create_engine( + "sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool +) def test_bangumi_database(): - # sqlite mock engine - engine = create_engine( - "sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool - ) - SQLModel.metadata.create_all(engine) test_data = Bangumi( official_title="test", year="2021", @@ -32,18 +33,46 @@ def test_bangumi_database(): deleted=False, ) with Database(engine) as db: + db.create_table() # insert - db.bangumi.insert_one(test_data) + db.bangumi.add(test_data) assert db.bangumi.search_id(1) == test_data # update test_data.official_title = "test2" - db.bangumi.update_one(test_data) + db.bangumi.update(test_data) assert db.bangumi.search_id(1) == test_data # search poster assert db.bangumi.match_poster("test2 (2021)") == "/test/test.jpg" + # match torrent + result = db.bangumi.match_torrent("[Sub Group]test S02 01 [720p].mkv") + assert result.official_title == "test2" + # delete db.bangumi.delete_one(1) assert db.bangumi.search_id(1) is None + + +def test_torrent_database(): + test_data = Torrent( + name="[Sub Group]test S02 01 [720p].mkv", + url="https://test.com/test.mkv", + ) + with Database(engine) as db: + # insert + db.torrent.add(test_data) + assert db.torrent.search(1) == test_data + + # update + test_data.downloaded = True + db.torrent.update(test_data) + assert db.torrent.search(1) == test_data + + +def test_rss_database(): + rss_url = "https://test.com/test.xml" + + with Database(engine) as db: + db.rss.add(RSSItem(url=rss_url)) diff --git a/backend/src/test/test_rss_engine.py b/backend/src/test/test_rss_engine.py new file mode 100644 index 00000000..e2e8f90f --- /dev/null +++ b/backend/src/test/test_rss_engine.py @@ -0,0 +1,22 @@ +from module.rss.engine import RSSEngine + +from .test_database import engine as e + + +def test_rss_engine(): + with RSSEngine(e) as engine: + rss_link = "https://mikanani.me/RSS/Bangumi?bangumiId=2353&subgroupid=552" + + engine.add_rss(rss_link, combine=False) + + result = engine.rss.search_active() + + assert result[1].item_path == "Mikan Project - 无职转生~到了异世界就拿出真本事~" + + new_torrents = engine.pull_rss(result[1]) + torrent = new_torrents[0] + + assert torrent.name == "[Lilith-Raws] 无职转生,到了异世界就拿出真本事 / Mushoku Tensei - 11 [Baha][WEB-DL][1080p][AVC AAC][CHT][MP4]" + + +