mirror of
https://github.com/EstrellaXD/Auto_Bangumi.git
synced 2026-07-05 19:47:32 +08:00
test: add mock test to database.
This commit is contained in:
@@ -2,13 +2,13 @@ from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from module.manager import TorrentManager
|
||||
from module.models import BangumiData
|
||||
from module.models import Bangumi
|
||||
from module.security import get_current_user
|
||||
|
||||
router = APIRouter(prefix="/bangumi", tags=["bangumi"])
|
||||
|
||||
|
||||
@router.get("/getAll", response_model=list[BangumiData])
|
||||
@router.get("/getAll", response_model=list[Bangumi])
|
||||
async def get_all_data(current_user=Depends(get_current_user)):
|
||||
if not current_user:
|
||||
raise HTTPException(
|
||||
@@ -18,7 +18,7 @@ async def get_all_data(current_user=Depends(get_current_user)):
|
||||
return torrent.search_all()
|
||||
|
||||
|
||||
@router.get("/getData/{bangumi_id}", response_model=BangumiData)
|
||||
@router.get("/getData/{bangumi_id}", response_model=Bangumi)
|
||||
async def get_data(bangumi_id: str, current_user=Depends(get_current_user)):
|
||||
if not current_user:
|
||||
raise HTTPException(
|
||||
@@ -29,7 +29,7 @@ async def get_data(bangumi_id: str, current_user=Depends(get_current_user)):
|
||||
|
||||
|
||||
@router.post("/updateRule")
|
||||
async def update_rule(data: BangumiData, current_user=Depends(get_current_user)):
|
||||
async def update_rule(data: Bangumi, current_user=Depends(get_current_user)):
|
||||
if not current_user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token"
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
|
||||
from module.manager import SeasonCollector
|
||||
from module.models import BangumiData
|
||||
from module.models import Bangumi
|
||||
from module.models.api import RssLink
|
||||
from module.rss import analyser
|
||||
from module.security import get_current_user
|
||||
@@ -23,9 +23,7 @@ async def analysis(link: RssLink, current_user=Depends(get_current_user)):
|
||||
|
||||
|
||||
@router.post("/collection")
|
||||
async def download_collection(
|
||||
data: BangumiData, current_user=Depends(get_current_user)
|
||||
):
|
||||
async def download_collection(data: Bangumi, current_user=Depends(get_current_user)):
|
||||
if not current_user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token"
|
||||
@@ -41,7 +39,7 @@ async def download_collection(
|
||||
|
||||
|
||||
@router.post("/subscribe")
|
||||
async def subscribe(data: BangumiData, current_user=Depends(get_current_user)):
|
||||
async def subscribe(data: Bangumi, current_user=Depends(get_current_user)):
|
||||
if not current_user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token"
|
||||
|
||||
@@ -1,21 +1,18 @@
|
||||
import logging
|
||||
|
||||
from module.models import Bangumi, BangumiUpdate
|
||||
from sqlmodel import Session, select, delete, SQLModel
|
||||
from module.database.engine import engine
|
||||
from sqlmodel import Session, select, delete, SQLModel, or_, and_
|
||||
from typing import Optional
|
||||
from sqlalchemy.exc import IntegrityError, NoResultFound
|
||||
|
||||
from .engine import engine
|
||||
from module.models import Bangumi
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BangumiDatabase(Session):
|
||||
def __init__(self):
|
||||
super().__init__(engine)
|
||||
|
||||
@staticmethod
|
||||
def update_table():
|
||||
SQLModel.metadata.create_all(engine)
|
||||
def __init__(self, _engine=engine):
|
||||
super().__init__(_engine)
|
||||
|
||||
def insert_one(self, data: Bangumi):
|
||||
self.add(data)
|
||||
@@ -26,7 +23,7 @@ class BangumiDatabase(Session):
|
||||
self.add_all(data)
|
||||
logger.debug(f"[Database] Insert {len(data)} bangumi into database.")
|
||||
|
||||
def update_one(self, data: BangumiUpdate) -> bool:
|
||||
def update_one(self, data: Bangumi) -> bool:
|
||||
db_data = self.get(Bangumi, data.id)
|
||||
if not db_data:
|
||||
return False
|
||||
@@ -39,7 +36,7 @@ class BangumiDatabase(Session):
|
||||
logger.debug(f"[Database] Update {data.official_title}")
|
||||
return True
|
||||
|
||||
def update_list(self, datas: list[BangumiUpdate]):
|
||||
def update_list(self, datas: list[Bangumi]):
|
||||
for data in datas:
|
||||
self.update_one(data)
|
||||
|
||||
@@ -48,7 +45,7 @@ class BangumiDatabase(Session):
|
||||
statement = select(Bangumi).where(Bangumi.title_raw == title_raw)
|
||||
bangumi = self.exec(statement).first()
|
||||
bangumi.rss_link = rss_set
|
||||
bangumi.added = 0
|
||||
bangumi.added = False
|
||||
self.add(bangumi)
|
||||
self.commit()
|
||||
self.refresh(bangumi)
|
||||
@@ -93,14 +90,8 @@ class BangumiDatabase(Session):
|
||||
return self.exec(statement).first()
|
||||
|
||||
def match_poster(self, bangumi_name: str) -> str:
|
||||
# condition = {"official_title": bangumi_name}
|
||||
statement = select(Bangumi).where(Bangumi.official_title == bangumi_name)
|
||||
# keys = ["poster_link"]
|
||||
# data = self.select.one(
|
||||
# keys=keys,
|
||||
# conditions=condition,
|
||||
# combine_operator="INSTR",
|
||||
# )
|
||||
# Use like to match
|
||||
statement = select(Bangumi).where(Bangumi.title_raw.like(f"%{bangumi_name}%"))
|
||||
data = self.exec(statement).first()
|
||||
if data:
|
||||
return data.poster_link
|
||||
@@ -108,9 +99,6 @@ class BangumiDatabase(Session):
|
||||
return ""
|
||||
|
||||
def match_list(self, torrent_list: list, rss_link: str) -> list:
|
||||
# Match title_raw in database
|
||||
# keys = ["title_raw", "rss_link", "poster_link"]
|
||||
# match_datas = self.select.column(keys)
|
||||
match_datas = self.search_all()
|
||||
if not match_datas:
|
||||
return torrent_list
|
||||
@@ -122,13 +110,9 @@ class BangumiDatabase(Session):
|
||||
if match_data.title_raw in torrent.name:
|
||||
if rss_link not in match_data.rss_link:
|
||||
match_data.rss_link += f",{rss_link}"
|
||||
self.update_rss(
|
||||
match_data.title_raw, match_data.rss_link
|
||||
)
|
||||
self.update_rss(match_data.title_raw, match_data.rss_link)
|
||||
if not match_data.poster_link:
|
||||
self.update_poster(
|
||||
match_data.title_raw, torrent.poster_link
|
||||
)
|
||||
self.update_poster(match_data.title_raw, torrent.poster_link)
|
||||
torrent_list.pop(i)
|
||||
break
|
||||
else:
|
||||
@@ -143,14 +127,23 @@ class BangumiDatabase(Session):
|
||||
|
||||
def not_added(self) -> list[Bangumi]:
|
||||
conditions = select(Bangumi).where(
|
||||
Bangumi.added == 0 or
|
||||
Bangumi.rule_name is None or
|
||||
Bangumi.save_path is None
|
||||
or_(
|
||||
Bangumi.added == 0, Bangumi.rule_name is None, Bangumi.save_path is None
|
||||
)
|
||||
)
|
||||
datas = self.exec(conditions).all()
|
||||
# dict_data = self.select.many(conditions=conditions, combine_operator="OR")
|
||||
return datas
|
||||
|
||||
def disable_rule(self, _id: int):
|
||||
statement = select(Bangumi).where(Bangumi.id == _id)
|
||||
bangumi = self.exec(statement).first()
|
||||
bangumi.deleted = True
|
||||
self.add(bangumi)
|
||||
self.commit()
|
||||
self.refresh(bangumi)
|
||||
logger.debug(f"[Database] Disable rule {bangumi.title_raw}.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
with BangumiDatabase() as db:
|
||||
|
||||
@@ -4,4 +4,4 @@ from module.conf import DATA_PATH
|
||||
|
||||
engine = create_engine(DATA_PATH)
|
||||
|
||||
db_session = Session(engine)
|
||||
db_session = Session(engine)
|
||||
|
||||
@@ -22,17 +22,6 @@ class UserDatabase(Session):
|
||||
self.add(User())
|
||||
self.commit()
|
||||
|
||||
|
||||
# @staticmethod
|
||||
# def __data_to_db(data: User) -> dict:
|
||||
# db_data = data.dict()
|
||||
# db_data["password"] = get_password_hash(db_data["password"])
|
||||
# return db_data
|
||||
#
|
||||
# @staticmethod
|
||||
# def __db_to_data(db_data: dict) -> User:
|
||||
# return User(**db_data)
|
||||
|
||||
def get_user(self, username):
|
||||
statement = select(User).where(User.username == username)
|
||||
result = self.exec(statement).first()
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import logging
|
||||
|
||||
from module.conf import settings
|
||||
from module.models import BangumiData
|
||||
from module.models import Bangumi
|
||||
|
||||
from .path import TorrentPath
|
||||
|
||||
@@ -68,7 +68,7 @@ class DownloadClient(TorrentPath):
|
||||
prefs = self.client.get_app_prefs()
|
||||
settings.downloader.path = self._join_path(prefs["save_path"], "Bangumi")
|
||||
|
||||
def set_rule(self, data: BangumiData):
|
||||
def set_rule(self, data: Bangumi):
|
||||
data.rule_name = self._rule_name(data)
|
||||
data.save_path = self._gen_save_path(data)
|
||||
rule = {
|
||||
@@ -92,7 +92,7 @@ class DownloadClient(TorrentPath):
|
||||
f"[Downloader] Add {data.official_title} Season {data.season} to auto download rules."
|
||||
)
|
||||
|
||||
def set_rules(self, bangumi_info: list[BangumiData]):
|
||||
def set_rules(self, bangumi_info: list[Bangumi]):
|
||||
logger.debug("[Downloader] Start adding rules.")
|
||||
for info in bangumi_info:
|
||||
self.set_rule(info)
|
||||
|
||||
@@ -4,8 +4,7 @@ import re
|
||||
from pathlib import Path
|
||||
|
||||
from module.conf import settings
|
||||
from module.models import BangumiData
|
||||
|
||||
from module.models import Bangumi
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -50,7 +49,7 @@ class TorrentPath:
|
||||
return self._file_depth(file_path) <= 2
|
||||
|
||||
@staticmethod
|
||||
def _gen_save_path(data: BangumiData):
|
||||
def _gen_save_path(data: Bangumi):
|
||||
folder = (
|
||||
f"{data.official_title} ({data.year})" if data.year else data.official_title
|
||||
)
|
||||
@@ -58,7 +57,7 @@ class TorrentPath:
|
||||
return str(save_path)
|
||||
|
||||
@staticmethod
|
||||
def _rule_name(data: BangumiData):
|
||||
def _rule_name(data: Bangumi):
|
||||
rule_name = (
|
||||
f"[{data.group_name}] {data.official_title} S{data.season}"
|
||||
if settings.bangumi_manage.group_tag
|
||||
|
||||
@@ -2,14 +2,14 @@ import logging
|
||||
|
||||
from module.database import BangumiDatabase
|
||||
from module.downloader import DownloadClient
|
||||
from module.models import BangumiData
|
||||
from module.models import Bangumi
|
||||
from module.searcher import SearchTorrent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SeasonCollector(DownloadClient):
|
||||
def add_season_torrents(self, data: BangumiData, torrents, torrent_files=None):
|
||||
def add_season_torrents(self, data: Bangumi, torrents, torrent_files=None):
|
||||
if torrent_files:
|
||||
download_info = {
|
||||
"torrent_files": torrent_files,
|
||||
@@ -23,7 +23,7 @@ class SeasonCollector(DownloadClient):
|
||||
}
|
||||
return self.add_torrent(download_info)
|
||||
|
||||
def collect_season(self, data: BangumiData, link: str = None, proxy: bool = False):
|
||||
def collect_season(self, data: Bangumi, link: str = None, proxy: bool = False):
|
||||
logger.info(f"Start collecting {data.official_title} Season {data.season}...")
|
||||
with SearchTorrent() as st:
|
||||
if not link:
|
||||
@@ -39,7 +39,7 @@ class SeasonCollector(DownloadClient):
|
||||
data=data, torrents=torrents, torrent_files=torrent_files
|
||||
)
|
||||
|
||||
def subscribe_season(self, data: BangumiData):
|
||||
def subscribe_season(self, data: Bangumi):
|
||||
with BangumiDatabase() as db:
|
||||
data.added = True
|
||||
data.eps_collect = True
|
||||
|
||||
@@ -4,21 +4,21 @@ from fastapi.responses import JSONResponse
|
||||
|
||||
from module.database import BangumiDatabase
|
||||
from module.downloader import DownloadClient
|
||||
from module.models import BangumiData
|
||||
from module.models import Bangumi
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TorrentManager(BangumiDatabase):
|
||||
@staticmethod
|
||||
def __match_torrents_list(data: BangumiData) -> list:
|
||||
def __match_torrents_list(data: Bangumi) -> list:
|
||||
with DownloadClient() as client:
|
||||
torrents = client.get_torrent_info(status_filter=None)
|
||||
return [
|
||||
torrent.hash for torrent in torrents if torrent.save_path == data.save_path
|
||||
]
|
||||
|
||||
def delete_torrents(self, data: BangumiData, client: DownloadClient):
|
||||
def delete_torrents(self, data: Bangumi, client: DownloadClient):
|
||||
hash_list = self.__match_torrents_list(data)
|
||||
if hash_list:
|
||||
client.delete_torrent(hash_list)
|
||||
@@ -29,7 +29,7 @@ class TorrentManager(BangumiDatabase):
|
||||
|
||||
def delete_rule(self, _id: int | str, file: bool = False):
|
||||
data = self.search_id(int(_id))
|
||||
if isinstance(data, BangumiData):
|
||||
if isinstance(data, Bangumi):
|
||||
with DownloadClient() as client:
|
||||
client.remove_rule(data.rule_name)
|
||||
client.remove_rss_feed(data.official_title)
|
||||
@@ -54,7 +54,7 @@ class TorrentManager(BangumiDatabase):
|
||||
|
||||
def disable_rule(self, _id: str | int, file: bool = False):
|
||||
data = self.search_id(int(_id))
|
||||
if isinstance(data, BangumiData):
|
||||
if isinstance(data, Bangumi):
|
||||
with DownloadClient() as client:
|
||||
client.remove_rule(data.rule_name)
|
||||
data.deleted = True
|
||||
@@ -81,7 +81,7 @@ class TorrentManager(BangumiDatabase):
|
||||
|
||||
def enable_rule(self, _id: str | int):
|
||||
data = self.search_id(int(_id))
|
||||
if isinstance(data, BangumiData):
|
||||
if isinstance(data, Bangumi):
|
||||
data.deleted = False
|
||||
self.update_one(data)
|
||||
with DownloadClient() as client:
|
||||
@@ -98,7 +98,7 @@ class TorrentManager(BangumiDatabase):
|
||||
status_code=406, content={"msg": f"Can't find bangumi id {_id}"}
|
||||
)
|
||||
|
||||
def update_rule(self, data: BangumiData):
|
||||
def update_rule(self, data: Bangumi):
|
||||
old_data = self.search_id(data.id)
|
||||
if not old_data:
|
||||
logger.error(f"[Manager] Can't find data with {data.id}")
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from .bangumi import *
|
||||
from .bangumi import Bangumi, Episode
|
||||
from .config import Config
|
||||
from .rss import RSSTorrents
|
||||
from .torrent import EpisodeFile, SubtitleFile, TorrentBase
|
||||
|
||||
@@ -7,7 +7,9 @@ from typing import Optional
|
||||
|
||||
class Bangumi(SQLModel, table=True):
|
||||
id: int = Field(default=None, primary_key=True)
|
||||
official_title: str = Field(default="official_title", alias="official_title", title="番剧中文名")
|
||||
official_title: str = Field(
|
||||
default="official_title", alias="official_title", title="番剧中文名"
|
||||
)
|
||||
year: Optional[str] = Field(alias="year", title="番剧年份")
|
||||
title_raw: str = Field(default="title_raw", alias="title_raw", title="番剧原名")
|
||||
season: int = Field(default=1, alias="season", title="番剧季度")
|
||||
@@ -28,7 +30,9 @@ class Bangumi(SQLModel, table=True):
|
||||
|
||||
|
||||
class BangumiUpdate(SQLModel):
|
||||
official_title: str = Field(default="official_title", alias="official_title", title="番剧中文名")
|
||||
official_title: str = Field(
|
||||
default="official_title", alias="official_title", title="番剧中文名"
|
||||
)
|
||||
year: Optional[str] = Field(alias="year", title="番剧年份")
|
||||
season: int = Field(default=1, alias="season", title="番剧季度")
|
||||
season_raw: Optional[str] = Field(alias="season_raw", title="番剧季度原名")
|
||||
|
||||
@@ -14,7 +14,9 @@ class Downloader(BaseModel):
|
||||
type: str = Field("qbittorrent", description="Downloader type")
|
||||
host: str = Field("172.17.0.1:8080", description="Downloader host")
|
||||
username_: str = Field("admin", alias="username", description="Downloader username")
|
||||
password_: str = Field("adminadmin", alias="password", description="Downloader password")
|
||||
password_: str = Field(
|
||||
"adminadmin", alias="password", description="Downloader password"
|
||||
)
|
||||
path: str = Field("/downloads/Bangumi", description="Downloader path")
|
||||
ssl: bool = Field(False, description="Downloader ssl")
|
||||
|
||||
@@ -26,6 +28,7 @@ class Downloader(BaseModel):
|
||||
def password(self):
|
||||
return expandvars(self.password_)
|
||||
|
||||
|
||||
class RSSParser(BaseModel):
|
||||
enable: bool = Field(True, description="Enable RSS parser")
|
||||
type: str = Field("mikan", description="RSS parser type")
|
||||
@@ -39,6 +42,7 @@ class RSSParser(BaseModel):
|
||||
def token(self):
|
||||
return expandvars(self.token_)
|
||||
|
||||
|
||||
class BangumiManage(BaseModel):
|
||||
enable: bool = Field(True, description="Enable bangumi manage")
|
||||
eps_complete: bool = Field(False, description="Enable eps complete")
|
||||
@@ -82,6 +86,7 @@ class Notification(BaseModel):
|
||||
def chat_id(self):
|
||||
return expandvars(self.chat_id_)
|
||||
|
||||
|
||||
class Config(BaseModel):
|
||||
program: Program = Program()
|
||||
downloader: Downloader = Downloader()
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import logging
|
||||
|
||||
from module.conf import settings
|
||||
from module.models import BangumiData
|
||||
from module.models import Bangumi
|
||||
|
||||
from .analyser import raw_parser, tmdb_parser, torrent_parser
|
||||
|
||||
@@ -39,7 +39,7 @@ class TitleParser:
|
||||
return official_title, tmdb_season, year
|
||||
|
||||
@staticmethod
|
||||
def raw_parser(raw: str, rss_link: str) -> BangumiData | None:
|
||||
def raw_parser(raw: str, rss_link: str) -> Bangumi | None:
|
||||
language = settings.rss_parser.language
|
||||
try:
|
||||
episode = raw_parser(raw)
|
||||
@@ -60,7 +60,7 @@ class TitleParser:
|
||||
else:
|
||||
official_title = title_raw
|
||||
_season = episode.season
|
||||
data = BangumiData(
|
||||
data = Bangumi(
|
||||
official_title=official_title,
|
||||
title_raw=title_raw,
|
||||
season=_season,
|
||||
|
||||
@@ -3,7 +3,7 @@ import re
|
||||
|
||||
from module.conf import settings
|
||||
from module.database import BangumiDatabase
|
||||
from module.models import BangumiData
|
||||
from module.models import Bangumi
|
||||
from module.network import RequestContent, TorrentInfo
|
||||
from module.parser import TitleParser
|
||||
|
||||
@@ -16,7 +16,7 @@ class RSSAnalyser:
|
||||
with BangumiDatabase() as db:
|
||||
db.update_table()
|
||||
|
||||
def official_title_parser(self, data: BangumiData, mikan_title: str):
|
||||
def official_title_parser(self, data: Bangumi, mikan_title: str):
|
||||
if settings.rss_parser.parser_type == "mikan":
|
||||
data.official_title = mikan_title if mikan_title else data.official_title
|
||||
elif settings.rss_parser.parser_type == "tmdb":
|
||||
@@ -63,7 +63,7 @@ class RSSAnalyser:
|
||||
|
||||
def torrent_to_data(
|
||||
self, torrent: TorrentInfo, rss_link: str | None = None
|
||||
) -> BangumiData:
|
||||
) -> Bangumi:
|
||||
data = self._title_analyser.raw_parser(raw=torrent.name, rss_link=rss_link)
|
||||
if data:
|
||||
try:
|
||||
@@ -79,7 +79,7 @@ class RSSAnalyser:
|
||||
|
||||
def rss_to_data(
|
||||
self, rss_link: str, database: BangumiDatabase, full_parse: bool = True
|
||||
) -> list[BangumiData]:
|
||||
) -> list[Bangumi]:
|
||||
rss_torrents = self.get_rss_torrents(rss_link, full_parse)
|
||||
torrents_to_add = database.match_list(rss_torrents, rss_link)
|
||||
if not torrents_to_add:
|
||||
@@ -92,7 +92,7 @@ class RSSAnalyser:
|
||||
else:
|
||||
return []
|
||||
|
||||
def link_to_data(self, link: str) -> BangumiData:
|
||||
def link_to_data(self, link: str) -> Bangumi:
|
||||
torrents = self.get_rss_torrents(link, False)
|
||||
for torrent in torrents:
|
||||
data = self.torrent_to_data(torrent, link)
|
||||
|
||||
@@ -3,7 +3,7 @@ import logging
|
||||
from module.conf import settings
|
||||
from module.database import BangumiDatabase
|
||||
from module.downloader import DownloadClient
|
||||
from module.models import BangumiData
|
||||
from module.models import Bangumi
|
||||
from module.network import RequestContent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -14,7 +14,7 @@ def matched(torrent_title: str):
|
||||
return db.match_torrent(torrent_title)
|
||||
|
||||
|
||||
def save_path(data: BangumiData):
|
||||
def save_path(data: Bangumi):
|
||||
folder = (
|
||||
f"{data.official_title}({data.year})" if data.year else f"{data.official_title}"
|
||||
)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import re
|
||||
|
||||
from module.database import RSSDatabase
|
||||
from module.models import BangumiData, RSSTorrents
|
||||
from module.models import Bangumi, RSSTorrents
|
||||
from module.network import RequestContent, TorrentInfo
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ class RSSPoller(RSSDatabase):
|
||||
return req.get_torrents(rss_link)
|
||||
|
||||
@staticmethod
|
||||
def filter_torrent(data: BangumiData, torrent: TorrentInfo) -> bool:
|
||||
def filter_torrent(data: Bangumi, torrent: TorrentInfo) -> bool:
|
||||
if data.title_raw in torrent.name:
|
||||
_filter = "|".join(data.filter)
|
||||
if not re.search(_filter, torrent.name):
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from module.models import BangumiData, TorrentBase
|
||||
from module.models import Bangumi, TorrentBase
|
||||
from module.network import RequestContent
|
||||
from module.searcher.plugin import search_url
|
||||
|
||||
@@ -30,7 +30,7 @@ class SearchTorrent(RequestContent):
|
||||
|
||||
return [TorrentBase(**d) for d in to_dict()]
|
||||
|
||||
def search_season(self, data: BangumiData):
|
||||
def search_season(self, data: Bangumi):
|
||||
keywords = [getattr(data, key) for key in SEARCH_KEY if getattr(data, key)]
|
||||
torrents = self.search_torrents(keywords)
|
||||
return [torrent for torrent in torrents if data.title_raw in torrent.name]
|
||||
|
||||
@@ -2,7 +2,7 @@ import os
|
||||
|
||||
from module.conf import LEGACY_DATA_PATH
|
||||
from module.database import BangumiDatabase
|
||||
from module.models import BangumiData
|
||||
from module.models import Bangumi
|
||||
from module.utils import json_config
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ def data_migration():
|
||||
rss_link = old_data["rss_link"]
|
||||
new_data = []
|
||||
for info in infos:
|
||||
new_data.append(BangumiData(**info, rss_link=[rss_link]))
|
||||
new_data.append(Bangumi(**info, rss_link=[rss_link]))
|
||||
with BangumiDatabase() as database:
|
||||
database.update_table()
|
||||
database.insert_list(new_data)
|
||||
|
||||
@@ -1,11 +1,17 @@
|
||||
from sqlmodel import create_engine, SQLModel, Session
|
||||
from sqlmodel.pool import StaticPool
|
||||
|
||||
from module.database import BangumiDatabase
|
||||
from module.models import BangumiData
|
||||
from module.models import Bangumi
|
||||
|
||||
|
||||
def test_database():
|
||||
TEST_PATH = "test/test.db"
|
||||
test_data = BangumiData(
|
||||
id=1,
|
||||
# sqlite mock engine
|
||||
engine = create_engine(
|
||||
"sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool
|
||||
)
|
||||
SQLModel.metadata.create_all(engine)
|
||||
test_data = Bangumi(
|
||||
official_title="test",
|
||||
year="2021",
|
||||
title_raw="test",
|
||||
@@ -17,18 +23,15 @@ def test_database():
|
||||
subtitle="test",
|
||||
eps_collect=False,
|
||||
offset=0,
|
||||
filter=["720p", "\\d+-\\d+"],
|
||||
rss_link=["test"],
|
||||
filter="720p,\\d+-\\d+",
|
||||
rss_link="test",
|
||||
poster_link="/test/test.jpg",
|
||||
added=False,
|
||||
rule_name=None,
|
||||
save_path=None,
|
||||
deleted=False,
|
||||
)
|
||||
with BangumiDatabase(database=TEST_PATH) as database:
|
||||
# create table
|
||||
database.update_table()
|
||||
with BangumiDatabase(database=TEST_PATH) as database:
|
||||
with BangumiDatabase(engine) as database:
|
||||
# insert
|
||||
database.insert_one(test_data)
|
||||
assert database.search_id(1) == test_data
|
||||
@@ -44,8 +47,3 @@ def test_database():
|
||||
# delete
|
||||
database.delete_one(1)
|
||||
assert database.search_id(1) is None
|
||||
|
||||
# Delete test database
|
||||
import os
|
||||
|
||||
os.remove(TEST_PATH)
|
||||
|
||||
Reference in New Issue
Block a user