test: add mock test to database.

This commit is contained in:
estrella
2023-07-31 16:00:54 +08:00
parent d24acc60d5
commit 9b99e6c591
19 changed files with 92 additions and 106 deletions

View File

@@ -2,13 +2,13 @@ from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.responses import JSONResponse
from module.manager import TorrentManager
from module.models import BangumiData
from module.models import Bangumi
from module.security import get_current_user
router = APIRouter(prefix="/bangumi", tags=["bangumi"])
@router.get("/getAll", response_model=list[BangumiData])
@router.get("/getAll", response_model=list[Bangumi])
async def get_all_data(current_user=Depends(get_current_user)):
if not current_user:
raise HTTPException(
@@ -18,7 +18,7 @@ async def get_all_data(current_user=Depends(get_current_user)):
return torrent.search_all()
@router.get("/getData/{bangumi_id}", response_model=BangumiData)
@router.get("/getData/{bangumi_id}", response_model=Bangumi)
async def get_data(bangumi_id: str, current_user=Depends(get_current_user)):
if not current_user:
raise HTTPException(
@@ -29,7 +29,7 @@ async def get_data(bangumi_id: str, current_user=Depends(get_current_user)):
@router.post("/updateRule")
async def update_rule(data: BangumiData, current_user=Depends(get_current_user)):
async def update_rule(data: Bangumi, current_user=Depends(get_current_user)):
if not current_user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token"

View File

@@ -1,7 +1,7 @@
from fastapi import APIRouter, Depends, HTTPException, status
from module.manager import SeasonCollector
from module.models import BangumiData
from module.models import Bangumi
from module.models.api import RssLink
from module.rss import analyser
from module.security import get_current_user
@@ -23,9 +23,7 @@ async def analysis(link: RssLink, current_user=Depends(get_current_user)):
@router.post("/collection")
async def download_collection(
data: BangumiData, current_user=Depends(get_current_user)
):
async def download_collection(data: Bangumi, current_user=Depends(get_current_user)):
if not current_user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token"
@@ -41,7 +39,7 @@ async def download_collection(
@router.post("/subscribe")
async def subscribe(data: BangumiData, current_user=Depends(get_current_user)):
async def subscribe(data: Bangumi, current_user=Depends(get_current_user)):
if not current_user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token"

View File

@@ -1,21 +1,18 @@
import logging
from module.models import Bangumi, BangumiUpdate
from sqlmodel import Session, select, delete, SQLModel
from module.database.engine import engine
from sqlmodel import Session, select, delete, SQLModel, or_, and_
from typing import Optional
from sqlalchemy.exc import IntegrityError, NoResultFound
from .engine import engine
from module.models import Bangumi
logger = logging.getLogger(__name__)
class BangumiDatabase(Session):
def __init__(self):
super().__init__(engine)
@staticmethod
def update_table():
SQLModel.metadata.create_all(engine)
def __init__(self, _engine=engine):
super().__init__(_engine)
def insert_one(self, data: Bangumi):
self.add(data)
@@ -26,7 +23,7 @@ class BangumiDatabase(Session):
self.add_all(data)
logger.debug(f"[Database] Insert {len(data)} bangumi into database.")
def update_one(self, data: BangumiUpdate) -> bool:
def update_one(self, data: Bangumi) -> bool:
db_data = self.get(Bangumi, data.id)
if not db_data:
return False
@@ -39,7 +36,7 @@ class BangumiDatabase(Session):
logger.debug(f"[Database] Update {data.official_title}")
return True
def update_list(self, datas: list[BangumiUpdate]):
def update_list(self, datas: list[Bangumi]):
for data in datas:
self.update_one(data)
@@ -48,7 +45,7 @@ class BangumiDatabase(Session):
statement = select(Bangumi).where(Bangumi.title_raw == title_raw)
bangumi = self.exec(statement).first()
bangumi.rss_link = rss_set
bangumi.added = 0
bangumi.added = False
self.add(bangumi)
self.commit()
self.refresh(bangumi)
@@ -93,14 +90,8 @@ class BangumiDatabase(Session):
return self.exec(statement).first()
def match_poster(self, bangumi_name: str) -> str:
# condition = {"official_title": bangumi_name}
statement = select(Bangumi).where(Bangumi.official_title == bangumi_name)
# keys = ["poster_link"]
# data = self.select.one(
# keys=keys,
# conditions=condition,
# combine_operator="INSTR",
# )
# Use like to match
statement = select(Bangumi).where(Bangumi.title_raw.like(f"%{bangumi_name}%"))
data = self.exec(statement).first()
if data:
return data.poster_link
@@ -108,9 +99,6 @@ class BangumiDatabase(Session):
return ""
def match_list(self, torrent_list: list, rss_link: str) -> list:
# Match title_raw in database
# keys = ["title_raw", "rss_link", "poster_link"]
# match_datas = self.select.column(keys)
match_datas = self.search_all()
if not match_datas:
return torrent_list
@@ -122,13 +110,9 @@ class BangumiDatabase(Session):
if match_data.title_raw in torrent.name:
if rss_link not in match_data.rss_link:
match_data.rss_link += f",{rss_link}"
self.update_rss(
match_data.title_raw, match_data.rss_link
)
self.update_rss(match_data.title_raw, match_data.rss_link)
if not match_data.poster_link:
self.update_poster(
match_data.title_raw, torrent.poster_link
)
self.update_poster(match_data.title_raw, torrent.poster_link)
torrent_list.pop(i)
break
else:
@@ -143,14 +127,23 @@ class BangumiDatabase(Session):
def not_added(self) -> list[Bangumi]:
conditions = select(Bangumi).where(
Bangumi.added == 0 or
Bangumi.rule_name is None or
Bangumi.save_path is None
or_(
Bangumi.added == 0, Bangumi.rule_name is None, Bangumi.save_path is None
)
)
datas = self.exec(conditions).all()
# dict_data = self.select.many(conditions=conditions, combine_operator="OR")
return datas
def disable_rule(self, _id: int):
statement = select(Bangumi).where(Bangumi.id == _id)
bangumi = self.exec(statement).first()
bangumi.deleted = True
self.add(bangumi)
self.commit()
self.refresh(bangumi)
logger.debug(f"[Database] Disable rule {bangumi.title_raw}.")
if __name__ == "__main__":
with BangumiDatabase() as db:

View File

@@ -4,4 +4,4 @@ from module.conf import DATA_PATH
engine = create_engine(DATA_PATH)
db_session = Session(engine)
db_session = Session(engine)

View File

@@ -22,17 +22,6 @@ class UserDatabase(Session):
self.add(User())
self.commit()
# @staticmethod
# def __data_to_db(data: User) -> dict:
# db_data = data.dict()
# db_data["password"] = get_password_hash(db_data["password"])
# return db_data
#
# @staticmethod
# def __db_to_data(db_data: dict) -> User:
# return User(**db_data)
def get_user(self, username):
statement = select(User).where(User.username == username)
result = self.exec(statement).first()

View File

@@ -1,7 +1,7 @@
import logging
from module.conf import settings
from module.models import BangumiData
from module.models import Bangumi
from .path import TorrentPath
@@ -68,7 +68,7 @@ class DownloadClient(TorrentPath):
prefs = self.client.get_app_prefs()
settings.downloader.path = self._join_path(prefs["save_path"], "Bangumi")
def set_rule(self, data: BangumiData):
def set_rule(self, data: Bangumi):
data.rule_name = self._rule_name(data)
data.save_path = self._gen_save_path(data)
rule = {
@@ -92,7 +92,7 @@ class DownloadClient(TorrentPath):
f"[Downloader] Add {data.official_title} Season {data.season} to auto download rules."
)
def set_rules(self, bangumi_info: list[BangumiData]):
def set_rules(self, bangumi_info: list[Bangumi]):
logger.debug("[Downloader] Start adding rules.")
for info in bangumi_info:
self.set_rule(info)

View File

@@ -4,8 +4,7 @@ import re
from pathlib import Path
from module.conf import settings
from module.models import BangumiData
from module.models import Bangumi
logger = logging.getLogger(__name__)
@@ -50,7 +49,7 @@ class TorrentPath:
return self._file_depth(file_path) <= 2
@staticmethod
def _gen_save_path(data: BangumiData):
def _gen_save_path(data: Bangumi):
folder = (
f"{data.official_title} ({data.year})" if data.year else data.official_title
)
@@ -58,7 +57,7 @@ class TorrentPath:
return str(save_path)
@staticmethod
def _rule_name(data: BangumiData):
def _rule_name(data: Bangumi):
rule_name = (
f"[{data.group_name}] {data.official_title} S{data.season}"
if settings.bangumi_manage.group_tag

View File

@@ -2,14 +2,14 @@ import logging
from module.database import BangumiDatabase
from module.downloader import DownloadClient
from module.models import BangumiData
from module.models import Bangumi
from module.searcher import SearchTorrent
logger = logging.getLogger(__name__)
class SeasonCollector(DownloadClient):
def add_season_torrents(self, data: BangumiData, torrents, torrent_files=None):
def add_season_torrents(self, data: Bangumi, torrents, torrent_files=None):
if torrent_files:
download_info = {
"torrent_files": torrent_files,
@@ -23,7 +23,7 @@ class SeasonCollector(DownloadClient):
}
return self.add_torrent(download_info)
def collect_season(self, data: BangumiData, link: str = None, proxy: bool = False):
def collect_season(self, data: Bangumi, link: str = None, proxy: bool = False):
logger.info(f"Start collecting {data.official_title} Season {data.season}...")
with SearchTorrent() as st:
if not link:
@@ -39,7 +39,7 @@ class SeasonCollector(DownloadClient):
data=data, torrents=torrents, torrent_files=torrent_files
)
def subscribe_season(self, data: BangumiData):
def subscribe_season(self, data: Bangumi):
with BangumiDatabase() as db:
data.added = True
data.eps_collect = True

View File

@@ -4,21 +4,21 @@ from fastapi.responses import JSONResponse
from module.database import BangumiDatabase
from module.downloader import DownloadClient
from module.models import BangumiData
from module.models import Bangumi
logger = logging.getLogger(__name__)
class TorrentManager(BangumiDatabase):
@staticmethod
def __match_torrents_list(data: BangumiData) -> list:
def __match_torrents_list(data: Bangumi) -> list:
with DownloadClient() as client:
torrents = client.get_torrent_info(status_filter=None)
return [
torrent.hash for torrent in torrents if torrent.save_path == data.save_path
]
def delete_torrents(self, data: BangumiData, client: DownloadClient):
def delete_torrents(self, data: Bangumi, client: DownloadClient):
hash_list = self.__match_torrents_list(data)
if hash_list:
client.delete_torrent(hash_list)
@@ -29,7 +29,7 @@ class TorrentManager(BangumiDatabase):
def delete_rule(self, _id: int | str, file: bool = False):
data = self.search_id(int(_id))
if isinstance(data, BangumiData):
if isinstance(data, Bangumi):
with DownloadClient() as client:
client.remove_rule(data.rule_name)
client.remove_rss_feed(data.official_title)
@@ -54,7 +54,7 @@ class TorrentManager(BangumiDatabase):
def disable_rule(self, _id: str | int, file: bool = False):
data = self.search_id(int(_id))
if isinstance(data, BangumiData):
if isinstance(data, Bangumi):
with DownloadClient() as client:
client.remove_rule(data.rule_name)
data.deleted = True
@@ -81,7 +81,7 @@ class TorrentManager(BangumiDatabase):
def enable_rule(self, _id: str | int):
data = self.search_id(int(_id))
if isinstance(data, BangumiData):
if isinstance(data, Bangumi):
data.deleted = False
self.update_one(data)
with DownloadClient() as client:
@@ -98,7 +98,7 @@ class TorrentManager(BangumiDatabase):
status_code=406, content={"msg": f"Can't find bangumi id {_id}"}
)
def update_rule(self, data: BangumiData):
def update_rule(self, data: Bangumi):
old_data = self.search_id(data.id)
if not old_data:
logger.error(f"[Manager] Can't find data with {data.id}")

View File

@@ -1,4 +1,4 @@
from .bangumi import *
from .bangumi import Bangumi, Episode
from .config import Config
from .rss import RSSTorrents
from .torrent import EpisodeFile, SubtitleFile, TorrentBase

View File

@@ -7,7 +7,9 @@ from typing import Optional
class Bangumi(SQLModel, table=True):
id: int = Field(default=None, primary_key=True)
official_title: str = Field(default="official_title", alias="official_title", title="番剧中文名")
official_title: str = Field(
default="official_title", alias="official_title", title="番剧中文名"
)
year: Optional[str] = Field(alias="year", title="番剧年份")
title_raw: str = Field(default="title_raw", alias="title_raw", title="番剧原名")
season: int = Field(default=1, alias="season", title="番剧季度")
@@ -28,7 +30,9 @@ class Bangumi(SQLModel, table=True):
class BangumiUpdate(SQLModel):
official_title: str = Field(default="official_title", alias="official_title", title="番剧中文名")
official_title: str = Field(
default="official_title", alias="official_title", title="番剧中文名"
)
year: Optional[str] = Field(alias="year", title="番剧年份")
season: int = Field(default=1, alias="season", title="番剧季度")
season_raw: Optional[str] = Field(alias="season_raw", title="番剧季度原名")

View File

@@ -14,7 +14,9 @@ class Downloader(BaseModel):
type: str = Field("qbittorrent", description="Downloader type")
host: str = Field("172.17.0.1:8080", description="Downloader host")
username_: str = Field("admin", alias="username", description="Downloader username")
password_: str = Field("adminadmin", alias="password", description="Downloader password")
password_: str = Field(
"adminadmin", alias="password", description="Downloader password"
)
path: str = Field("/downloads/Bangumi", description="Downloader path")
ssl: bool = Field(False, description="Downloader ssl")
@@ -26,6 +28,7 @@ class Downloader(BaseModel):
def password(self):
return expandvars(self.password_)
class RSSParser(BaseModel):
enable: bool = Field(True, description="Enable RSS parser")
type: str = Field("mikan", description="RSS parser type")
@@ -39,6 +42,7 @@ class RSSParser(BaseModel):
def token(self):
return expandvars(self.token_)
class BangumiManage(BaseModel):
enable: bool = Field(True, description="Enable bangumi manage")
eps_complete: bool = Field(False, description="Enable eps complete")
@@ -82,6 +86,7 @@ class Notification(BaseModel):
def chat_id(self):
return expandvars(self.chat_id_)
class Config(BaseModel):
program: Program = Program()
downloader: Downloader = Downloader()

View File

@@ -1,7 +1,7 @@
import logging
from module.conf import settings
from module.models import BangumiData
from module.models import Bangumi
from .analyser import raw_parser, tmdb_parser, torrent_parser
@@ -39,7 +39,7 @@ class TitleParser:
return official_title, tmdb_season, year
@staticmethod
def raw_parser(raw: str, rss_link: str) -> BangumiData | None:
def raw_parser(raw: str, rss_link: str) -> Bangumi | None:
language = settings.rss_parser.language
try:
episode = raw_parser(raw)
@@ -60,7 +60,7 @@ class TitleParser:
else:
official_title = title_raw
_season = episode.season
data = BangumiData(
data = Bangumi(
official_title=official_title,
title_raw=title_raw,
season=_season,

View File

@@ -3,7 +3,7 @@ import re
from module.conf import settings
from module.database import BangumiDatabase
from module.models import BangumiData
from module.models import Bangumi
from module.network import RequestContent, TorrentInfo
from module.parser import TitleParser
@@ -16,7 +16,7 @@ class RSSAnalyser:
with BangumiDatabase() as db:
db.update_table()
def official_title_parser(self, data: BangumiData, mikan_title: str):
def official_title_parser(self, data: Bangumi, mikan_title: str):
if settings.rss_parser.parser_type == "mikan":
data.official_title = mikan_title if mikan_title else data.official_title
elif settings.rss_parser.parser_type == "tmdb":
@@ -63,7 +63,7 @@ class RSSAnalyser:
def torrent_to_data(
self, torrent: TorrentInfo, rss_link: str | None = None
) -> BangumiData:
) -> Bangumi:
data = self._title_analyser.raw_parser(raw=torrent.name, rss_link=rss_link)
if data:
try:
@@ -79,7 +79,7 @@ class RSSAnalyser:
def rss_to_data(
self, rss_link: str, database: BangumiDatabase, full_parse: bool = True
) -> list[BangumiData]:
) -> list[Bangumi]:
rss_torrents = self.get_rss_torrents(rss_link, full_parse)
torrents_to_add = database.match_list(rss_torrents, rss_link)
if not torrents_to_add:
@@ -92,7 +92,7 @@ class RSSAnalyser:
else:
return []
def link_to_data(self, link: str) -> BangumiData:
def link_to_data(self, link: str) -> Bangumi:
torrents = self.get_rss_torrents(link, False)
for torrent in torrents:
data = self.torrent_to_data(torrent, link)

View File

@@ -3,7 +3,7 @@ import logging
from module.conf import settings
from module.database import BangumiDatabase
from module.downloader import DownloadClient
from module.models import BangumiData
from module.models import Bangumi
from module.network import RequestContent
logger = logging.getLogger(__name__)
@@ -14,7 +14,7 @@ def matched(torrent_title: str):
return db.match_torrent(torrent_title)
def save_path(data: BangumiData):
def save_path(data: Bangumi):
folder = (
f"{data.official_title}({data.year})" if data.year else f"{data.official_title}"
)

View File

@@ -1,7 +1,7 @@
import re
from module.database import RSSDatabase
from module.models import BangumiData, RSSTorrents
from module.models import Bangumi, RSSTorrents
from module.network import RequestContent, TorrentInfo
@@ -11,7 +11,7 @@ class RSSPoller(RSSDatabase):
return req.get_torrents(rss_link)
@staticmethod
def filter_torrent(data: BangumiData, torrent: TorrentInfo) -> bool:
def filter_torrent(data: Bangumi, torrent: TorrentInfo) -> bool:
if data.title_raw in torrent.name:
_filter = "|".join(data.filter)
if not re.search(_filter, torrent.name):

View File

@@ -1,4 +1,4 @@
from module.models import BangumiData, TorrentBase
from module.models import Bangumi, TorrentBase
from module.network import RequestContent
from module.searcher.plugin import search_url
@@ -30,7 +30,7 @@ class SearchTorrent(RequestContent):
return [TorrentBase(**d) for d in to_dict()]
def search_season(self, data: BangumiData):
def search_season(self, data: Bangumi):
keywords = [getattr(data, key) for key in SEARCH_KEY if getattr(data, key)]
torrents = self.search_torrents(keywords)
return [torrent for torrent in torrents if data.title_raw in torrent.name]

View File

@@ -2,7 +2,7 @@ import os
from module.conf import LEGACY_DATA_PATH
from module.database import BangumiDatabase
from module.models import BangumiData
from module.models import Bangumi
from module.utils import json_config
@@ -14,7 +14,7 @@ def data_migration():
rss_link = old_data["rss_link"]
new_data = []
for info in infos:
new_data.append(BangumiData(**info, rss_link=[rss_link]))
new_data.append(Bangumi(**info, rss_link=[rss_link]))
with BangumiDatabase() as database:
database.update_table()
database.insert_list(new_data)

View File

@@ -1,11 +1,17 @@
from sqlmodel import create_engine, SQLModel, Session
from sqlmodel.pool import StaticPool
from module.database import BangumiDatabase
from module.models import BangumiData
from module.models import Bangumi
def test_database():
TEST_PATH = "test/test.db"
test_data = BangumiData(
id=1,
# sqlite mock engine
engine = create_engine(
"sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool
)
SQLModel.metadata.create_all(engine)
test_data = Bangumi(
official_title="test",
year="2021",
title_raw="test",
@@ -17,18 +23,15 @@ def test_database():
subtitle="test",
eps_collect=False,
offset=0,
filter=["720p", "\\d+-\\d+"],
rss_link=["test"],
filter="720p,\\d+-\\d+",
rss_link="test",
poster_link="/test/test.jpg",
added=False,
rule_name=None,
save_path=None,
deleted=False,
)
with BangumiDatabase(database=TEST_PATH) as database:
# create table
database.update_table()
with BangumiDatabase(database=TEST_PATH) as database:
with BangumiDatabase(engine) as database:
# insert
database.insert_one(test_data)
assert database.search_id(1) == test_data
@@ -44,8 +47,3 @@ def test_database():
# delete
database.delete_one(1)
assert database.search_id(1) is None
# Delete test database
import os
os.remove(TEST_PATH)