diff --git a/backend/src/module/api/bangumi.py b/backend/src/module/api/bangumi.py index 16c206f3..7254fe45 100644 --- a/backend/src/module/api/bangumi.py +++ b/backend/src/module/api/bangumi.py @@ -2,13 +2,13 @@ from fastapi import APIRouter, Depends, HTTPException, status from fastapi.responses import JSONResponse from module.manager import TorrentManager -from module.models import BangumiData +from module.models import Bangumi from module.security import get_current_user router = APIRouter(prefix="/bangumi", tags=["bangumi"]) -@router.get("/getAll", response_model=list[BangumiData]) +@router.get("/getAll", response_model=list[Bangumi]) async def get_all_data(current_user=Depends(get_current_user)): if not current_user: raise HTTPException( @@ -18,7 +18,7 @@ async def get_all_data(current_user=Depends(get_current_user)): return torrent.search_all() -@router.get("/getData/{bangumi_id}", response_model=BangumiData) +@router.get("/getData/{bangumi_id}", response_model=Bangumi) async def get_data(bangumi_id: str, current_user=Depends(get_current_user)): if not current_user: raise HTTPException( @@ -29,7 +29,7 @@ async def get_data(bangumi_id: str, current_user=Depends(get_current_user)): @router.post("/updateRule") -async def update_rule(data: BangumiData, current_user=Depends(get_current_user)): +async def update_rule(data: Bangumi, current_user=Depends(get_current_user)): if not current_user: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token" diff --git a/backend/src/module/api/download.py b/backend/src/module/api/download.py index 278b1bb5..dcd36f1e 100644 --- a/backend/src/module/api/download.py +++ b/backend/src/module/api/download.py @@ -1,7 +1,7 @@ from fastapi import APIRouter, Depends, HTTPException, status from module.manager import SeasonCollector -from module.models import BangumiData +from module.models import Bangumi from module.models.api import RssLink from module.rss import analyser from module.security import get_current_user @@ -23,9 +23,7 @@ async def analysis(link: RssLink, current_user=Depends(get_current_user)): @router.post("/collection") -async def download_collection( - data: BangumiData, current_user=Depends(get_current_user) -): +async def download_collection(data: Bangumi, current_user=Depends(get_current_user)): if not current_user: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token" @@ -41,7 +39,7 @@ async def download_collection( @router.post("/subscribe") -async def subscribe(data: BangumiData, current_user=Depends(get_current_user)): +async def subscribe(data: Bangumi, current_user=Depends(get_current_user)): if not current_user: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token" diff --git a/backend/src/module/database/bangumi.py b/backend/src/module/database/bangumi.py index 986ccee5..d750746e 100644 --- a/backend/src/module/database/bangumi.py +++ b/backend/src/module/database/bangumi.py @@ -1,21 +1,18 @@ import logging -from module.models import Bangumi, BangumiUpdate -from sqlmodel import Session, select, delete, SQLModel -from module.database.engine import engine +from sqlmodel import Session, select, delete, SQLModel, or_, and_ from typing import Optional from sqlalchemy.exc import IntegrityError, NoResultFound +from .engine import engine +from module.models import Bangumi + logger = logging.getLogger(__name__) class BangumiDatabase(Session): - def __init__(self): - super().__init__(engine) - - @staticmethod - def update_table(): - SQLModel.metadata.create_all(engine) + def __init__(self, _engine=engine): + super().__init__(_engine) def insert_one(self, data: Bangumi): self.add(data) @@ -26,7 +23,7 @@ class BangumiDatabase(Session): self.add_all(data) logger.debug(f"[Database] Insert {len(data)} bangumi into database.") - def update_one(self, data: BangumiUpdate) -> bool: + def update_one(self, data: Bangumi) -> bool: db_data = self.get(Bangumi, data.id) if not db_data: return False @@ -39,7 +36,7 @@ class BangumiDatabase(Session): logger.debug(f"[Database] Update {data.official_title}") return True - def update_list(self, datas: list[BangumiUpdate]): + def update_list(self, datas: list[Bangumi]): for data in datas: self.update_one(data) @@ -48,7 +45,7 @@ class BangumiDatabase(Session): statement = select(Bangumi).where(Bangumi.title_raw == title_raw) bangumi = self.exec(statement).first() bangumi.rss_link = rss_set - bangumi.added = 0 + bangumi.added = False self.add(bangumi) self.commit() self.refresh(bangumi) @@ -93,14 +90,8 @@ class BangumiDatabase(Session): return self.exec(statement).first() def match_poster(self, bangumi_name: str) -> str: - # condition = {"official_title": bangumi_name} - statement = select(Bangumi).where(Bangumi.official_title == bangumi_name) - # keys = ["poster_link"] - # data = self.select.one( - # keys=keys, - # conditions=condition, - # combine_operator="INSTR", - # ) + # Use like to match + statement = select(Bangumi).where(Bangumi.title_raw.like(f"%{bangumi_name}%")) data = self.exec(statement).first() if data: return data.poster_link @@ -108,9 +99,6 @@ class BangumiDatabase(Session): return "" def match_list(self, torrent_list: list, rss_link: str) -> list: - # Match title_raw in database - # keys = ["title_raw", "rss_link", "poster_link"] - # match_datas = self.select.column(keys) match_datas = self.search_all() if not match_datas: return torrent_list @@ -122,13 +110,9 @@ class BangumiDatabase(Session): if match_data.title_raw in torrent.name: if rss_link not in match_data.rss_link: match_data.rss_link += f",{rss_link}" - self.update_rss( - match_data.title_raw, match_data.rss_link - ) + self.update_rss(match_data.title_raw, match_data.rss_link) if not match_data.poster_link: - self.update_poster( - match_data.title_raw, torrent.poster_link - ) + self.update_poster(match_data.title_raw, torrent.poster_link) torrent_list.pop(i) break else: @@ -143,14 +127,23 @@ class BangumiDatabase(Session): def not_added(self) -> list[Bangumi]: conditions = select(Bangumi).where( - Bangumi.added == 0 or - Bangumi.rule_name is None or - Bangumi.save_path is None + or_( + Bangumi.added == 0, Bangumi.rule_name is None, Bangumi.save_path is None ) + ) datas = self.exec(conditions).all() # dict_data = self.select.many(conditions=conditions, combine_operator="OR") return datas + def disable_rule(self, _id: int): + statement = select(Bangumi).where(Bangumi.id == _id) + bangumi = self.exec(statement).first() + bangumi.deleted = True + self.add(bangumi) + self.commit() + self.refresh(bangumi) + logger.debug(f"[Database] Disable rule {bangumi.title_raw}.") + if __name__ == "__main__": with BangumiDatabase() as db: diff --git a/backend/src/module/database/engine.py b/backend/src/module/database/engine.py index 369c53b3..94fa37b0 100644 --- a/backend/src/module/database/engine.py +++ b/backend/src/module/database/engine.py @@ -4,4 +4,4 @@ from module.conf import DATA_PATH engine = create_engine(DATA_PATH) -db_session = Session(engine) \ No newline at end of file +db_session = Session(engine) diff --git a/backend/src/module/database/user.py b/backend/src/module/database/user.py index c763346f..6fe77d2f 100644 --- a/backend/src/module/database/user.py +++ b/backend/src/module/database/user.py @@ -22,17 +22,6 @@ class UserDatabase(Session): self.add(User()) self.commit() - - # @staticmethod - # def __data_to_db(data: User) -> dict: - # db_data = data.dict() - # db_data["password"] = get_password_hash(db_data["password"]) - # return db_data - # - # @staticmethod - # def __db_to_data(db_data: dict) -> User: - # return User(**db_data) - def get_user(self, username): statement = select(User).where(User.username == username) result = self.exec(statement).first() diff --git a/backend/src/module/downloader/download_client.py b/backend/src/module/downloader/download_client.py index 675e180a..105f97d3 100644 --- a/backend/src/module/downloader/download_client.py +++ b/backend/src/module/downloader/download_client.py @@ -1,7 +1,7 @@ import logging from module.conf import settings -from module.models import BangumiData +from module.models import Bangumi from .path import TorrentPath @@ -68,7 +68,7 @@ class DownloadClient(TorrentPath): prefs = self.client.get_app_prefs() settings.downloader.path = self._join_path(prefs["save_path"], "Bangumi") - def set_rule(self, data: BangumiData): + def set_rule(self, data: Bangumi): data.rule_name = self._rule_name(data) data.save_path = self._gen_save_path(data) rule = { @@ -92,7 +92,7 @@ class DownloadClient(TorrentPath): f"[Downloader] Add {data.official_title} Season {data.season} to auto download rules." ) - def set_rules(self, bangumi_info: list[BangumiData]): + def set_rules(self, bangumi_info: list[Bangumi]): logger.debug("[Downloader] Start adding rules.") for info in bangumi_info: self.set_rule(info) diff --git a/backend/src/module/downloader/path.py b/backend/src/module/downloader/path.py index f1099191..8f6fcc16 100644 --- a/backend/src/module/downloader/path.py +++ b/backend/src/module/downloader/path.py @@ -4,8 +4,7 @@ import re from pathlib import Path from module.conf import settings -from module.models import BangumiData - +from module.models import Bangumi logger = logging.getLogger(__name__) @@ -50,7 +49,7 @@ class TorrentPath: return self._file_depth(file_path) <= 2 @staticmethod - def _gen_save_path(data: BangumiData): + def _gen_save_path(data: Bangumi): folder = ( f"{data.official_title} ({data.year})" if data.year else data.official_title ) @@ -58,7 +57,7 @@ class TorrentPath: return str(save_path) @staticmethod - def _rule_name(data: BangumiData): + def _rule_name(data: Bangumi): rule_name = ( f"[{data.group_name}] {data.official_title} S{data.season}" if settings.bangumi_manage.group_tag diff --git a/backend/src/module/manager/collector.py b/backend/src/module/manager/collector.py index 979aa5cb..b0bb9e15 100644 --- a/backend/src/module/manager/collector.py +++ b/backend/src/module/manager/collector.py @@ -2,14 +2,14 @@ import logging from module.database import BangumiDatabase from module.downloader import DownloadClient -from module.models import BangumiData +from module.models import Bangumi from module.searcher import SearchTorrent logger = logging.getLogger(__name__) class SeasonCollector(DownloadClient): - def add_season_torrents(self, data: BangumiData, torrents, torrent_files=None): + def add_season_torrents(self, data: Bangumi, torrents, torrent_files=None): if torrent_files: download_info = { "torrent_files": torrent_files, @@ -23,7 +23,7 @@ class SeasonCollector(DownloadClient): } return self.add_torrent(download_info) - def collect_season(self, data: BangumiData, link: str = None, proxy: bool = False): + def collect_season(self, data: Bangumi, link: str = None, proxy: bool = False): logger.info(f"Start collecting {data.official_title} Season {data.season}...") with SearchTorrent() as st: if not link: @@ -39,7 +39,7 @@ class SeasonCollector(DownloadClient): data=data, torrents=torrents, torrent_files=torrent_files ) - def subscribe_season(self, data: BangumiData): + def subscribe_season(self, data: Bangumi): with BangumiDatabase() as db: data.added = True data.eps_collect = True diff --git a/backend/src/module/manager/torrent.py b/backend/src/module/manager/torrent.py index f3a69edb..3385e11e 100644 --- a/backend/src/module/manager/torrent.py +++ b/backend/src/module/manager/torrent.py @@ -4,21 +4,21 @@ from fastapi.responses import JSONResponse from module.database import BangumiDatabase from module.downloader import DownloadClient -from module.models import BangumiData +from module.models import Bangumi logger = logging.getLogger(__name__) class TorrentManager(BangumiDatabase): @staticmethod - def __match_torrents_list(data: BangumiData) -> list: + def __match_torrents_list(data: Bangumi) -> list: with DownloadClient() as client: torrents = client.get_torrent_info(status_filter=None) return [ torrent.hash for torrent in torrents if torrent.save_path == data.save_path ] - def delete_torrents(self, data: BangumiData, client: DownloadClient): + def delete_torrents(self, data: Bangumi, client: DownloadClient): hash_list = self.__match_torrents_list(data) if hash_list: client.delete_torrent(hash_list) @@ -29,7 +29,7 @@ class TorrentManager(BangumiDatabase): def delete_rule(self, _id: int | str, file: bool = False): data = self.search_id(int(_id)) - if isinstance(data, BangumiData): + if isinstance(data, Bangumi): with DownloadClient() as client: client.remove_rule(data.rule_name) client.remove_rss_feed(data.official_title) @@ -54,7 +54,7 @@ class TorrentManager(BangumiDatabase): def disable_rule(self, _id: str | int, file: bool = False): data = self.search_id(int(_id)) - if isinstance(data, BangumiData): + if isinstance(data, Bangumi): with DownloadClient() as client: client.remove_rule(data.rule_name) data.deleted = True @@ -81,7 +81,7 @@ class TorrentManager(BangumiDatabase): def enable_rule(self, _id: str | int): data = self.search_id(int(_id)) - if isinstance(data, BangumiData): + if isinstance(data, Bangumi): data.deleted = False self.update_one(data) with DownloadClient() as client: @@ -98,7 +98,7 @@ class TorrentManager(BangumiDatabase): status_code=406, content={"msg": f"Can't find bangumi id {_id}"} ) - def update_rule(self, data: BangumiData): + def update_rule(self, data: Bangumi): old_data = self.search_id(data.id) if not old_data: logger.error(f"[Manager] Can't find data with {data.id}") diff --git a/backend/src/module/models/__init__.py b/backend/src/module/models/__init__.py index a73f18ed..0b10feaf 100644 --- a/backend/src/module/models/__init__.py +++ b/backend/src/module/models/__init__.py @@ -1,4 +1,4 @@ -from .bangumi import * +from .bangumi import Bangumi, Episode from .config import Config from .rss import RSSTorrents from .torrent import EpisodeFile, SubtitleFile, TorrentBase diff --git a/backend/src/module/models/bangumi.py b/backend/src/module/models/bangumi.py index 02a74ce9..b7484887 100644 --- a/backend/src/module/models/bangumi.py +++ b/backend/src/module/models/bangumi.py @@ -7,7 +7,9 @@ from typing import Optional class Bangumi(SQLModel, table=True): id: int = Field(default=None, primary_key=True) - official_title: str = Field(default="official_title", alias="official_title", title="番剧中文名") + official_title: str = Field( + default="official_title", alias="official_title", title="番剧中文名" + ) year: Optional[str] = Field(alias="year", title="番剧年份") title_raw: str = Field(default="title_raw", alias="title_raw", title="番剧原名") season: int = Field(default=1, alias="season", title="番剧季度") @@ -28,7 +30,9 @@ class Bangumi(SQLModel, table=True): class BangumiUpdate(SQLModel): - official_title: str = Field(default="official_title", alias="official_title", title="番剧中文名") + official_title: str = Field( + default="official_title", alias="official_title", title="番剧中文名" + ) year: Optional[str] = Field(alias="year", title="番剧年份") season: int = Field(default=1, alias="season", title="番剧季度") season_raw: Optional[str] = Field(alias="season_raw", title="番剧季度原名") diff --git a/backend/src/module/models/config.py b/backend/src/module/models/config.py index a2e4d2a6..76ec036f 100644 --- a/backend/src/module/models/config.py +++ b/backend/src/module/models/config.py @@ -14,7 +14,9 @@ class Downloader(BaseModel): type: str = Field("qbittorrent", description="Downloader type") host: str = Field("172.17.0.1:8080", description="Downloader host") username_: str = Field("admin", alias="username", description="Downloader username") - password_: str = Field("adminadmin", alias="password", description="Downloader password") + password_: str = Field( + "adminadmin", alias="password", description="Downloader password" + ) path: str = Field("/downloads/Bangumi", description="Downloader path") ssl: bool = Field(False, description="Downloader ssl") @@ -26,6 +28,7 @@ class Downloader(BaseModel): def password(self): return expandvars(self.password_) + class RSSParser(BaseModel): enable: bool = Field(True, description="Enable RSS parser") type: str = Field("mikan", description="RSS parser type") @@ -39,6 +42,7 @@ class RSSParser(BaseModel): def token(self): return expandvars(self.token_) + class BangumiManage(BaseModel): enable: bool = Field(True, description="Enable bangumi manage") eps_complete: bool = Field(False, description="Enable eps complete") @@ -82,6 +86,7 @@ class Notification(BaseModel): def chat_id(self): return expandvars(self.chat_id_) + class Config(BaseModel): program: Program = Program() downloader: Downloader = Downloader() diff --git a/backend/src/module/parser/title_parser.py b/backend/src/module/parser/title_parser.py index 33ada84b..4d1e7f00 100644 --- a/backend/src/module/parser/title_parser.py +++ b/backend/src/module/parser/title_parser.py @@ -1,7 +1,7 @@ import logging from module.conf import settings -from module.models import BangumiData +from module.models import Bangumi from .analyser import raw_parser, tmdb_parser, torrent_parser @@ -39,7 +39,7 @@ class TitleParser: return official_title, tmdb_season, year @staticmethod - def raw_parser(raw: str, rss_link: str) -> BangumiData | None: + def raw_parser(raw: str, rss_link: str) -> Bangumi | None: language = settings.rss_parser.language try: episode = raw_parser(raw) @@ -60,7 +60,7 @@ class TitleParser: else: official_title = title_raw _season = episode.season - data = BangumiData( + data = Bangumi( official_title=official_title, title_raw=title_raw, season=_season, diff --git a/backend/src/module/rss/analyser.py b/backend/src/module/rss/analyser.py index ccd398be..00bde7a6 100644 --- a/backend/src/module/rss/analyser.py +++ b/backend/src/module/rss/analyser.py @@ -3,7 +3,7 @@ import re from module.conf import settings from module.database import BangumiDatabase -from module.models import BangumiData +from module.models import Bangumi from module.network import RequestContent, TorrentInfo from module.parser import TitleParser @@ -16,7 +16,7 @@ class RSSAnalyser: with BangumiDatabase() as db: db.update_table() - def official_title_parser(self, data: BangumiData, mikan_title: str): + def official_title_parser(self, data: Bangumi, mikan_title: str): if settings.rss_parser.parser_type == "mikan": data.official_title = mikan_title if mikan_title else data.official_title elif settings.rss_parser.parser_type == "tmdb": @@ -63,7 +63,7 @@ class RSSAnalyser: def torrent_to_data( self, torrent: TorrentInfo, rss_link: str | None = None - ) -> BangumiData: + ) -> Bangumi: data = self._title_analyser.raw_parser(raw=torrent.name, rss_link=rss_link) if data: try: @@ -79,7 +79,7 @@ class RSSAnalyser: def rss_to_data( self, rss_link: str, database: BangumiDatabase, full_parse: bool = True - ) -> list[BangumiData]: + ) -> list[Bangumi]: rss_torrents = self.get_rss_torrents(rss_link, full_parse) torrents_to_add = database.match_list(rss_torrents, rss_link) if not torrents_to_add: @@ -92,7 +92,7 @@ class RSSAnalyser: else: return [] - def link_to_data(self, link: str) -> BangumiData: + def link_to_data(self, link: str) -> Bangumi: torrents = self.get_rss_torrents(link, False) for torrent in torrents: data = self.torrent_to_data(torrent, link) diff --git a/backend/src/module/rss/filter.py b/backend/src/module/rss/filter.py index e70d4602..d4717e38 100644 --- a/backend/src/module/rss/filter.py +++ b/backend/src/module/rss/filter.py @@ -3,7 +3,7 @@ import logging from module.conf import settings from module.database import BangumiDatabase from module.downloader import DownloadClient -from module.models import BangumiData +from module.models import Bangumi from module.network import RequestContent logger = logging.getLogger(__name__) @@ -14,7 +14,7 @@ def matched(torrent_title: str): return db.match_torrent(torrent_title) -def save_path(data: BangumiData): +def save_path(data: Bangumi): folder = ( f"{data.official_title}({data.year})" if data.year else f"{data.official_title}" ) diff --git a/backend/src/module/rss/poller.py b/backend/src/module/rss/poller.py index 90a81cb0..3efaab2f 100644 --- a/backend/src/module/rss/poller.py +++ b/backend/src/module/rss/poller.py @@ -1,7 +1,7 @@ import re from module.database import RSSDatabase -from module.models import BangumiData, RSSTorrents +from module.models import Bangumi, RSSTorrents from module.network import RequestContent, TorrentInfo @@ -11,7 +11,7 @@ class RSSPoller(RSSDatabase): return req.get_torrents(rss_link) @staticmethod - def filter_torrent(data: BangumiData, torrent: TorrentInfo) -> bool: + def filter_torrent(data: Bangumi, torrent: TorrentInfo) -> bool: if data.title_raw in torrent.name: _filter = "|".join(data.filter) if not re.search(_filter, torrent.name): diff --git a/backend/src/module/searcher/searcher.py b/backend/src/module/searcher/searcher.py index a52e379e..dead9466 100644 --- a/backend/src/module/searcher/searcher.py +++ b/backend/src/module/searcher/searcher.py @@ -1,4 +1,4 @@ -from module.models import BangumiData, TorrentBase +from module.models import Bangumi, TorrentBase from module.network import RequestContent from module.searcher.plugin import search_url @@ -30,7 +30,7 @@ class SearchTorrent(RequestContent): return [TorrentBase(**d) for d in to_dict()] - def search_season(self, data: BangumiData): + def search_season(self, data: Bangumi): keywords = [getattr(data, key) for key in SEARCH_KEY if getattr(data, key)] torrents = self.search_torrents(keywords) return [torrent for torrent in torrents if data.title_raw in torrent.name] diff --git a/backend/src/module/update/data_migration.py b/backend/src/module/update/data_migration.py index b9ddf8db..f9739f98 100644 --- a/backend/src/module/update/data_migration.py +++ b/backend/src/module/update/data_migration.py @@ -2,7 +2,7 @@ import os from module.conf import LEGACY_DATA_PATH from module.database import BangumiDatabase -from module.models import BangumiData +from module.models import Bangumi from module.utils import json_config @@ -14,7 +14,7 @@ def data_migration(): rss_link = old_data["rss_link"] new_data = [] for info in infos: - new_data.append(BangumiData(**info, rss_link=[rss_link])) + new_data.append(Bangumi(**info, rss_link=[rss_link])) with BangumiDatabase() as database: database.update_table() database.insert_list(new_data) diff --git a/backend/src/test/test_database.py b/backend/src/test/test_database.py index 80493bf4..b7088ef1 100644 --- a/backend/src/test/test_database.py +++ b/backend/src/test/test_database.py @@ -1,11 +1,17 @@ +from sqlmodel import create_engine, SQLModel, Session +from sqlmodel.pool import StaticPool + from module.database import BangumiDatabase -from module.models import BangumiData +from module.models import Bangumi def test_database(): - TEST_PATH = "test/test.db" - test_data = BangumiData( - id=1, + # sqlite mock engine + engine = create_engine( + "sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool + ) + SQLModel.metadata.create_all(engine) + test_data = Bangumi( official_title="test", year="2021", title_raw="test", @@ -17,18 +23,15 @@ def test_database(): subtitle="test", eps_collect=False, offset=0, - filter=["720p", "\\d+-\\d+"], - rss_link=["test"], + filter="720p,\\d+-\\d+", + rss_link="test", poster_link="/test/test.jpg", added=False, rule_name=None, save_path=None, deleted=False, ) - with BangumiDatabase(database=TEST_PATH) as database: - # create table - database.update_table() - with BangumiDatabase(database=TEST_PATH) as database: + with BangumiDatabase(engine) as database: # insert database.insert_one(test_data) assert database.search_id(1) == test_data @@ -44,8 +47,3 @@ def test_database(): # delete database.delete_one(1) assert database.search_id(1) is None - - # Delete test database - import os - - os.remove(TEST_PATH)