mirror of
https://github.com/EstrellaXD/Auto_Bangumi.git
synced 2026-07-05 03:27:57 +08:00
refactor: database and rss engine
This commit is contained in:
@@ -4,35 +4,34 @@ from sqlmodel import Session, select, delete, or_
|
||||
from sqlalchemy.sql import func
|
||||
from typing import Optional
|
||||
|
||||
from .engine import engine
|
||||
from module.models import Bangumi
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BangumiDatabase(Session):
|
||||
def __init__(self, _engine=engine):
|
||||
super().__init__(_engine)
|
||||
class BangumiDatabase:
|
||||
def __init__(self, session: Session):
|
||||
self.session = session
|
||||
|
||||
def insert_one(self, data: Bangumi):
|
||||
self.add(data)
|
||||
self.commit()
|
||||
self.session.add(data)
|
||||
self.session.commit()
|
||||
logger.debug(f"[Database] Insert {data.official_title} into database.")
|
||||
|
||||
def insert_list(self, data: list[Bangumi]):
|
||||
self.add_all(data)
|
||||
self.session.add_all(data)
|
||||
logger.debug(f"[Database] Insert {len(data)} bangumi into database.")
|
||||
|
||||
def update_one(self, data: Bangumi) -> bool:
|
||||
db_data = self.get(Bangumi, data.id)
|
||||
db_data = self.session.get(Bangumi, data.id)
|
||||
if not db_data:
|
||||
return False
|
||||
bangumi_data = data.dict(exclude_unset=True)
|
||||
for key, value in bangumi_data.items():
|
||||
setattr(db_data, key, value)
|
||||
self.add(db_data)
|
||||
self.commit()
|
||||
self.refresh(db_data)
|
||||
self.session.add(db_data)
|
||||
self.session.commit()
|
||||
self.session.refresh(db_data)
|
||||
logger.debug(f"[Database] Update {data.official_title}")
|
||||
return True
|
||||
|
||||
@@ -43,53 +42,53 @@ class BangumiDatabase(Session):
|
||||
def update_rss(self, title_raw, rss_set: str):
|
||||
# Update rss and added
|
||||
statement = select(Bangumi).where(Bangumi.title_raw == title_raw)
|
||||
bangumi = self.exec(statement).first()
|
||||
bangumi = self.session.exec(statement).first()
|
||||
bangumi.rss_link = rss_set
|
||||
bangumi.added = False
|
||||
self.add(bangumi)
|
||||
self.commit()
|
||||
self.refresh(bangumi)
|
||||
self.session.add(bangumi)
|
||||
self.session.commit()
|
||||
self.session.refresh(bangumi)
|
||||
logger.debug(f"[Database] Update {title_raw} rss_link to {rss_set}.")
|
||||
|
||||
def update_poster(self, title_raw, poster_link: str):
|
||||
statement = select(Bangumi).where(Bangumi.title_raw == title_raw)
|
||||
bangumi = self.exec(statement).first()
|
||||
bangumi = self.session.exec(statement).first()
|
||||
bangumi.poster_link = poster_link
|
||||
self.add(bangumi)
|
||||
self.commit()
|
||||
self.refresh(bangumi)
|
||||
self.session.add(bangumi)
|
||||
self.session.commit()
|
||||
self.session.refresh(bangumi)
|
||||
logger.debug(f"[Database] Update {title_raw} poster_link to {poster_link}.")
|
||||
|
||||
def delete_one(self, _id: int):
|
||||
statement = select(Bangumi).where(Bangumi.id == _id)
|
||||
bangumi = self.exec(statement).first()
|
||||
self.delete(bangumi)
|
||||
self.commit()
|
||||
bangumi = self.session.exec(statement).first()
|
||||
self.session.delete(bangumi)
|
||||
self.session.commit()
|
||||
logger.debug(f"[Database] Delete bangumi id: {_id}.")
|
||||
|
||||
def delete_all(self):
|
||||
statement = delete(Bangumi)
|
||||
self.exec(statement)
|
||||
self.commit()
|
||||
self.session.exec(statement)
|
||||
self.session.commit()
|
||||
|
||||
def search_all(self) -> list[Bangumi]:
|
||||
statement = select(Bangumi)
|
||||
return self.exec(statement).all()
|
||||
return self.session.exec(statement).all()
|
||||
|
||||
def search_id(self, _id: int) -> Optional[Bangumi]:
|
||||
statement = select(Bangumi).where(Bangumi.id == _id)
|
||||
bangumi = self.exec(statement).first()
|
||||
bangumi = self.session.exec(statement).first()
|
||||
if bangumi is None:
|
||||
logger.warning(f"[Database] Cannot find bangumi id: {_id}.")
|
||||
return None
|
||||
else:
|
||||
logger.debug(f"[Database] Find bangumi id: {_id}.")
|
||||
return self.exec(statement).first()
|
||||
return self.session.exec(statement).first()
|
||||
|
||||
def match_poster(self, bangumi_name: str) -> str:
|
||||
# Use like to match
|
||||
statement = select(Bangumi).where(func.instr(bangumi_name, Bangumi.title_raw) > 0)
|
||||
data = self.exec(statement).first()
|
||||
data = self.session.exec(statement).first()
|
||||
if data:
|
||||
return data.poster_link
|
||||
else:
|
||||
@@ -119,7 +118,7 @@ class BangumiDatabase(Session):
|
||||
def not_complete(self) -> list[Bangumi]:
|
||||
# Find eps_complete = False
|
||||
condition = select(Bangumi).where(Bangumi.eps_collect == 0)
|
||||
datas = self.exec(condition).all()
|
||||
datas = self.session.exec(condition).all()
|
||||
return datas
|
||||
|
||||
def not_added(self) -> list[Bangumi]:
|
||||
@@ -128,19 +127,20 @@ class BangumiDatabase(Session):
|
||||
Bangumi.added == 0, Bangumi.rule_name is None, Bangumi.save_path is None
|
||||
)
|
||||
)
|
||||
datas = self.exec(conditions).all()
|
||||
datas = self.session.exec(conditions).all()
|
||||
return datas
|
||||
|
||||
def disable_rule(self, _id: int):
|
||||
statement = select(Bangumi).where(Bangumi.id == _id)
|
||||
bangumi = self.exec(statement).first()
|
||||
bangumi = self.session.exec(statement).first()
|
||||
bangumi.deleted = True
|
||||
self.add(bangumi)
|
||||
self.commit()
|
||||
self.refresh(bangumi)
|
||||
self.session.add(bangumi)
|
||||
self.session.commit()
|
||||
self.session.refresh(bangumi)
|
||||
logger.debug(f"[Database] Disable rule {bangumi.title_raw}.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
with BangumiDatabase() as db:
|
||||
print(db.not_complete())
|
||||
def search_rss(self, rss_link: str) -> list[Bangumi]:
|
||||
statement = select(Bangumi).where(
|
||||
func.instr(rss_link, Bangumi.rss_link) > 0
|
||||
)
|
||||
return self.session.exec(statement).all()
|
||||
18
backend/src/module/database/combine.py
Normal file
18
backend/src/module/database/combine.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from sqlmodel import Session,SQLModel
|
||||
|
||||
from .engine import engine
|
||||
from .rss import RSSDatabase
|
||||
from .torrent import TorrentDatabase
|
||||
from .bangumi import BangumiDatabase
|
||||
|
||||
|
||||
class Database(Session):
|
||||
def __init__(self, _engine=engine):
|
||||
super().__init__(_engine)
|
||||
self.rss = RSSDatabase(self)
|
||||
self.torrent = TorrentDatabase(self)
|
||||
self.bangumi = BangumiDatabase(self)
|
||||
|
||||
@staticmethod
|
||||
def create_table():
|
||||
SQLModel.metadata.create_all(engine)
|
||||
@@ -1,44 +1,33 @@
|
||||
from .orm import Connector
|
||||
import logging
|
||||
|
||||
from sqlmodel import Session, select, delete
|
||||
|
||||
from .engine import engine
|
||||
from module.models import RSSItem
|
||||
from module.conf import DATA_PATH
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RSSDatabase(Connector):
|
||||
def __init__(self, database: str = DATA_PATH):
|
||||
super().__init__(
|
||||
table_name="RSSItem",
|
||||
data=RSSItem().dict(),
|
||||
database=database
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def __data_to_db(data: RSSItem) -> dict:
|
||||
db_data = data.dict()
|
||||
for key, value in db_data.items():
|
||||
if isinstance(value, bool):
|
||||
db_data[key] = int(value)
|
||||
elif isinstance(value, list):
|
||||
db_data[key] = ",".join(value)
|
||||
return db_data
|
||||
|
||||
@staticmethod
|
||||
def __db_to_data(db_data: dict) -> RSSItem:
|
||||
for key, item in db_data.items():
|
||||
if isinstance(item, int):
|
||||
db_data[key] = bool(item)
|
||||
return RSSItem(**db_data)
|
||||
|
||||
def update_table(self):
|
||||
self.update.table()
|
||||
class RSSDatabase:
|
||||
def __init__(self, session: Session):
|
||||
self.session = session
|
||||
|
||||
def insert_one(self, data: RSSItem):
|
||||
dict_data = self.__data_to_db(data)
|
||||
self.insert.one(data=dict_data)
|
||||
self.session.add(data)
|
||||
self.session.commit()
|
||||
self.session.refresh(data)
|
||||
|
||||
def get_all(self) -> list[RSSItem]:
|
||||
dict_datas = self.select.all()
|
||||
return [self.__db_to_data(x) for x in dict_datas]
|
||||
def search_all(self) -> list[RSSItem]:
|
||||
return self.session.exec(select(RSSItem)).all()
|
||||
|
||||
def delete_one(self, _id: int):
|
||||
self.delete.one(_id)
|
||||
condition = delete(RSSItem).where(RSSItem.id == _id)
|
||||
self.session.exec(condition)
|
||||
self.session.commit()
|
||||
|
||||
def delete_all(self):
|
||||
condition = delete(RSSItem)
|
||||
self.session.exec(condition)
|
||||
self.session.commit()
|
||||
|
||||
|
||||
|
||||
@@ -1,51 +1,54 @@
|
||||
import logging
|
||||
|
||||
from module.database.orm import Connector
|
||||
from module.models import TorrentData
|
||||
from module.conf import DATA_PATH
|
||||
from sqlmodel import Session, select
|
||||
|
||||
from module.models import Torrent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TorrentDatabase(Connector):
|
||||
def __init__(self, database: str = DATA_PATH):
|
||||
super().__init__(
|
||||
table_name="torrent", data=TorrentData().dict(), database=database
|
||||
)
|
||||
class TorrentDatabase:
|
||||
def __init__(self, session: Session):
|
||||
self.session = session
|
||||
|
||||
def update_table(self):
|
||||
self.update.table()
|
||||
def insert_one(self, data: Torrent):
|
||||
self.session.add(data)
|
||||
self.session.commit()
|
||||
self.session.refresh(data)
|
||||
logger.debug(f"Insert {data.name} in database.")
|
||||
|
||||
def __data_to_db(self, data: TorrentData) -> dict:
|
||||
db_data = data.dict()
|
||||
for key, value in db_data.items():
|
||||
if isinstance(value, bool):
|
||||
db_data[key] = int(value)
|
||||
elif isinstance(value, list):
|
||||
db_data[key] = ",".join(value)
|
||||
return db_data
|
||||
def insert_many(self, datas: list[Torrent]):
|
||||
self.session.add_all(datas)
|
||||
self.session.commit()
|
||||
logger.debug(f"Insert {len(datas)} torrents in database.")
|
||||
|
||||
def __db_to_data(self, db_data: dict) -> TorrentData:
|
||||
for key, item in db_data.items():
|
||||
if isinstance(item, int):
|
||||
db_data[key] = bool(item)
|
||||
elif key in ["filter", "rss_link"]:
|
||||
db_data[key] = item.split(",")
|
||||
return TorrentData(**db_data)
|
||||
def update_one_sys(self, data: Torrent):
|
||||
self.session.add(data)
|
||||
self.session.commit()
|
||||
self.session.refresh(data)
|
||||
logger.debug(f"Update {data.name} in database.")
|
||||
|
||||
def insert_many(self, data_list: list[TorrentData]):
|
||||
dict_datas = [self.__data_to_db(data) for data in data_list]
|
||||
self.insert.many(dict_datas)
|
||||
def update_many_sys(self, datas: list[Torrent]):
|
||||
self.session.add_all(datas)
|
||||
self.session.commit()
|
||||
|
||||
def get_all(self) -> list[TorrentData]:
|
||||
dict_datas = self.select.all()
|
||||
return [self.__db_to_data(data) for data in dict_datas]
|
||||
def update_one_user(self, data: Torrent):
|
||||
self.session.add(data)
|
||||
self.session.commit()
|
||||
self.session.refresh(data)
|
||||
logger.debug(f"Update {data.name} in database.")
|
||||
|
||||
def get_torrent_name(self) -> list[str]:
|
||||
dict_data = self.select.all()
|
||||
return [data["name"] for data in dict_data]
|
||||
def search_one(self, _id: int) -> Torrent:
|
||||
return self.session.exec(select(Torrent).where(Torrent.id == _id)).first()
|
||||
|
||||
def search_all(self) -> list[Torrent]:
|
||||
return self.session.exec(select(Torrent)).all()
|
||||
|
||||
if __name__ == "__main__":
|
||||
with TorrentDatabase() as db:
|
||||
db.update_table()
|
||||
def check_new(self, torrents_list: list[Torrent]) -> list[Torrent]:
|
||||
new_torrents = []
|
||||
for torrent in torrents_list:
|
||||
statement = select(Torrent).where(Torrent.name == torrent.name)
|
||||
db_torrent = self.session.exec(statement).first()
|
||||
if not db_torrent:
|
||||
new_torrents.append(torrent)
|
||||
return new_torrents
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from .bangumi import Bangumi, Episode
|
||||
from .bangumi import Bangumi, Episode, BangumiUpdate
|
||||
from .config import Config
|
||||
from .rss import RSSTorrents
|
||||
from .torrent import EpisodeFile, SubtitleFile, TorrentBase
|
||||
from .rss import RSSItem, RSSUpdate
|
||||
from .torrent import EpisodeFile, SubtitleFile, Torrent, TorrentUpdate
|
||||
from .user import UserLogin
|
||||
|
||||
@@ -34,6 +34,7 @@ class BangumiUpdate(SQLModel):
|
||||
default="official_title", alias="official_title", title="番剧中文名"
|
||||
)
|
||||
year: Optional[str] = Field(alias="year", title="番剧年份")
|
||||
title_raw: str = Field(default="title_raw", alias="title_raw", title="番剧原名")
|
||||
season: int = Field(default=1, alias="season", title="番剧季度")
|
||||
season_raw: Optional[str] = Field(alias="season_raw", title="番剧季度原名")
|
||||
group_name: Optional[str] = Field(alias="group_name", title="字幕组")
|
||||
@@ -44,7 +45,10 @@ class BangumiUpdate(SQLModel):
|
||||
offset: int = Field(default=0, alias="offset", title="番剧偏移量")
|
||||
filter: str = Field(default="720, \\d+-\\d+", alias="filter", title="番剧过滤器")
|
||||
rss_link: str = Field(default="", alias="rss_link", title="番剧RSS链接")
|
||||
poster_link: Optional[str] = Field(alias="poster_link", title="番剧海报链接")
|
||||
added: bool = Field(default=False, alias="added", title="是否已添加")
|
||||
rule_name: Optional[str] = Field(alias="rule_name", title="番剧规则名")
|
||||
save_path: Optional[str] = Field(alias="save_path", title="番剧保存路径")
|
||||
deleted: bool = Field(False, alias="deleted", title="是否已删除")
|
||||
|
||||
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
from os.path import expandvars
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# Sub config
|
||||
|
||||
|
||||
class Program(BaseModel):
|
||||
rss_time: int = Field(7200, description="Sleep time")
|
||||
|
||||
@@ -1,17 +1,21 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlmodel import SQLModel, Field
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class RSSItem(BaseModel):
|
||||
id: int = Field(0, alias="id", title="id")
|
||||
class RSSItem(SQLModel, table=True):
|
||||
id: int = Field(default=None, primary_key=True, alias="id")
|
||||
item_path: str = Field("example path", alias="item_path")
|
||||
url: str = Field("https://mikanani.me", alias="url")
|
||||
combine: bool = Field(True, alias="combine")
|
||||
enabled: bool = Field(True, alias="enabled")
|
||||
|
||||
|
||||
class TorrentData(BaseModel):
|
||||
id: int = Field(0, alias="id")
|
||||
rss_id: int = Field(0, alias="rss_id")
|
||||
name: str = Field("", alias="name")
|
||||
url: str = Field("https://example.com/torrent", alias="url")
|
||||
save_path: str = Field("path/to/save", alias="save_path")
|
||||
class RSSUpdate(SQLModel):
|
||||
item_path: Optional[str] = Field("example path", alias="item_path")
|
||||
url: Optional[str] = Field("https://mikanani.me", alias="url")
|
||||
combine: Optional[bool] = Field(True, alias="combine")
|
||||
enabled: Optional[bool] = Field(True, alias="enabled")
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -1,16 +1,19 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel
|
||||
from sqlmodel import SQLModel, Field
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class TorrentBase(BaseModel):
|
||||
name: str = Field(...)
|
||||
torrent_link: str = Field(...)
|
||||
homepage: str | None = Field(None)
|
||||
class Torrent(SQLModel, table=True):
|
||||
id: int = Field(default=None, primary_key=True, alias="id")
|
||||
refer_id: Optional[int] = Field(None, alias="refer_id")
|
||||
name: str = Field("", alias="name")
|
||||
url: str = Field("https://example.com/torrent", alias="url")
|
||||
homepage: Optional[str] = Field(None, alias="homepage")
|
||||
downloaded: bool = Field(False, alias="downloaded")
|
||||
|
||||
|
||||
class FileSet(BaseModel):
|
||||
media_path: str = Field(...)
|
||||
sc_subtitle: str | None = Field(None)
|
||||
tc_subtitle: str | None = Field(None)
|
||||
class TorrentUpdate(SQLModel):
|
||||
downloaded: bool = Field(False, alias="downloaded")
|
||||
|
||||
|
||||
class EpisodeFile(BaseModel):
|
||||
|
||||
@@ -37,7 +37,6 @@ class TorrentInfo:
|
||||
|
||||
|
||||
class RequestContent(RequestURL):
|
||||
# Mikanani RSS
|
||||
def get_torrents(
|
||||
self,
|
||||
_url: str,
|
||||
@@ -103,4 +102,6 @@ class RequestContent(RequestURL):
|
||||
|
||||
if __name__ == '__main__':
|
||||
with RequestContent() as req:
|
||||
req.get_xml("https://mikanani.me/RSS/Classic")
|
||||
ts = req.get_torrents("https://mikanani.me/RSS/Classic")
|
||||
for t in ts:
|
||||
print(t)
|
||||
|
||||
@@ -1,20 +1,35 @@
|
||||
import re
|
||||
|
||||
from module.database import RSSDatabase, BangumiDatabase, TorrentDatabase
|
||||
from module.models import BangumiData, RSSItem, TorrentData
|
||||
from module.models import Bangumi, RSSItem, Torrent
|
||||
from module.network import RequestContent, TorrentInfo
|
||||
|
||||
from module.database.combine import Database
|
||||
|
||||
class RSSEngine(RequestContent):
|
||||
@staticmethod
|
||||
def _get_rss_items() -> list[RSSItem]:
|
||||
with RSSDatabase() as db:
|
||||
return db.get_all()
|
||||
|
||||
@staticmethod
|
||||
def _get_bangumi_data(rss_link: str) -> list[BangumiData]:
|
||||
with BangumiDatabase() as db:
|
||||
return db.get_rss(rss_link)
|
||||
class RSSEngine(Database):
|
||||
def _get_rss_items(self) -> list[RSSItem]:
|
||||
return self.rss.search_all()
|
||||
|
||||
def _get_bangumi_data(self, rss_link: str) -> list[Bangumi]:
|
||||
return self.bangumi.search_rss(rss_link)
|
||||
|
||||
def get_torrent(self, rss_link: str) -> list[Torrent]:
|
||||
with RequestContent() as req:
|
||||
torrent_infos = req.get_torrents(rss_link)
|
||||
torrents: list[Torrent] = []
|
||||
for torrent_info in torrent_infos:
|
||||
torrents.append(
|
||||
Torrent(
|
||||
name=torrent_info.name,
|
||||
url=torrent_info.torrent_link,
|
||||
homepage=torrent_info.homepage,
|
||||
)
|
||||
)
|
||||
return torrents
|
||||
|
||||
def check_new_torrents(self, torrents_list: list[list[Torrent]]) -> list[Torrent]:
|
||||
return self.torrent.check_new(torrents_list)
|
||||
|
||||
def add_rss(self, rss_link: str, name: str, combine: bool):
|
||||
if not name:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from sqlmodel import create_engine, SQLModel
|
||||
from sqlmodel.pool import StaticPool
|
||||
|
||||
from module.database import BangumiDatabase
|
||||
from module.database.combine import Database
|
||||
from module.models import Bangumi
|
||||
|
||||
|
||||
@@ -31,19 +31,19 @@ def test_bangumi_database():
|
||||
save_path=None,
|
||||
deleted=False,
|
||||
)
|
||||
with BangumiDatabase(engine) as database:
|
||||
with Database(engine) as db:
|
||||
# insert
|
||||
database.insert_one(test_data)
|
||||
assert database.search_id(1) == test_data
|
||||
db.bangumi.insert_one(test_data)
|
||||
assert db.bangumi.search_id(1) == test_data
|
||||
|
||||
# update
|
||||
test_data.official_title = "test2"
|
||||
database.update_one(test_data)
|
||||
assert database.search_id(1) == test_data
|
||||
db.bangumi.update_one(test_data)
|
||||
assert db.bangumi.search_id(1) == test_data
|
||||
|
||||
# search poster
|
||||
assert database.match_poster("test2 (2021)") == "/test/test.jpg"
|
||||
assert db.bangumi.match_poster("test2 (2021)") == "/test/test.jpg"
|
||||
|
||||
# delete
|
||||
database.delete_one(1)
|
||||
assert database.search_id(1) is None
|
||||
db.bangumi.delete_one(1)
|
||||
assert db.bangumi.search_id(1) is None
|
||||
|
||||
Reference in New Issue
Block a user