refactor: database and rss engine

This commit is contained in:
estrella
2023-08-03 14:16:52 +08:00
parent f80c3bffb4
commit 472a5093e9
12 changed files with 189 additions and 154 deletions

View File

@@ -4,35 +4,34 @@ from sqlmodel import Session, select, delete, or_
from sqlalchemy.sql import func
from typing import Optional
from .engine import engine
from module.models import Bangumi
logger = logging.getLogger(__name__)
class BangumiDatabase(Session):
def __init__(self, _engine=engine):
super().__init__(_engine)
class BangumiDatabase:
def __init__(self, session: Session):
self.session = session
def insert_one(self, data: Bangumi):
self.add(data)
self.commit()
self.session.add(data)
self.session.commit()
logger.debug(f"[Database] Insert {data.official_title} into database.")
def insert_list(self, data: list[Bangumi]):
self.add_all(data)
self.session.add_all(data)
logger.debug(f"[Database] Insert {len(data)} bangumi into database.")
def update_one(self, data: Bangumi) -> bool:
db_data = self.get(Bangumi, data.id)
db_data = self.session.get(Bangumi, data.id)
if not db_data:
return False
bangumi_data = data.dict(exclude_unset=True)
for key, value in bangumi_data.items():
setattr(db_data, key, value)
self.add(db_data)
self.commit()
self.refresh(db_data)
self.session.add(db_data)
self.session.commit()
self.session.refresh(db_data)
logger.debug(f"[Database] Update {data.official_title}")
return True
@@ -43,53 +42,53 @@ class BangumiDatabase(Session):
def update_rss(self, title_raw, rss_set: str):
# Update rss and added
statement = select(Bangumi).where(Bangumi.title_raw == title_raw)
bangumi = self.exec(statement).first()
bangumi = self.session.exec(statement).first()
bangumi.rss_link = rss_set
bangumi.added = False
self.add(bangumi)
self.commit()
self.refresh(bangumi)
self.session.add(bangumi)
self.session.commit()
self.session.refresh(bangumi)
logger.debug(f"[Database] Update {title_raw} rss_link to {rss_set}.")
def update_poster(self, title_raw, poster_link: str):
statement = select(Bangumi).where(Bangumi.title_raw == title_raw)
bangumi = self.exec(statement).first()
bangumi = self.session.exec(statement).first()
bangumi.poster_link = poster_link
self.add(bangumi)
self.commit()
self.refresh(bangumi)
self.session.add(bangumi)
self.session.commit()
self.session.refresh(bangumi)
logger.debug(f"[Database] Update {title_raw} poster_link to {poster_link}.")
def delete_one(self, _id: int):
statement = select(Bangumi).where(Bangumi.id == _id)
bangumi = self.exec(statement).first()
self.delete(bangumi)
self.commit()
bangumi = self.session.exec(statement).first()
self.session.delete(bangumi)
self.session.commit()
logger.debug(f"[Database] Delete bangumi id: {_id}.")
def delete_all(self):
statement = delete(Bangumi)
self.exec(statement)
self.commit()
self.session.exec(statement)
self.session.commit()
def search_all(self) -> list[Bangumi]:
statement = select(Bangumi)
return self.exec(statement).all()
return self.session.exec(statement).all()
def search_id(self, _id: int) -> Optional[Bangumi]:
statement = select(Bangumi).where(Bangumi.id == _id)
bangumi = self.exec(statement).first()
bangumi = self.session.exec(statement).first()
if bangumi is None:
logger.warning(f"[Database] Cannot find bangumi id: {_id}.")
return None
else:
logger.debug(f"[Database] Find bangumi id: {_id}.")
return self.exec(statement).first()
return self.session.exec(statement).first()
def match_poster(self, bangumi_name: str) -> str:
# Use like to match
statement = select(Bangumi).where(func.instr(bangumi_name, Bangumi.title_raw) > 0)
data = self.exec(statement).first()
data = self.session.exec(statement).first()
if data:
return data.poster_link
else:
@@ -119,7 +118,7 @@ class BangumiDatabase(Session):
def not_complete(self) -> list[Bangumi]:
# Find eps_complete = False
condition = select(Bangumi).where(Bangumi.eps_collect == 0)
datas = self.exec(condition).all()
datas = self.session.exec(condition).all()
return datas
def not_added(self) -> list[Bangumi]:
@@ -128,19 +127,20 @@ class BangumiDatabase(Session):
Bangumi.added == 0, Bangumi.rule_name is None, Bangumi.save_path is None
)
)
datas = self.exec(conditions).all()
datas = self.session.exec(conditions).all()
return datas
def disable_rule(self, _id: int):
statement = select(Bangumi).where(Bangumi.id == _id)
bangumi = self.exec(statement).first()
bangumi = self.session.exec(statement).first()
bangumi.deleted = True
self.add(bangumi)
self.commit()
self.refresh(bangumi)
self.session.add(bangumi)
self.session.commit()
self.session.refresh(bangumi)
logger.debug(f"[Database] Disable rule {bangumi.title_raw}.")
if __name__ == "__main__":
with BangumiDatabase() as db:
print(db.not_complete())
def search_rss(self, rss_link: str) -> list[Bangumi]:
statement = select(Bangumi).where(
func.instr(rss_link, Bangumi.rss_link) > 0
)
return self.session.exec(statement).all()

View File

@@ -0,0 +1,18 @@
from sqlmodel import Session,SQLModel
from .engine import engine
from .rss import RSSDatabase
from .torrent import TorrentDatabase
from .bangumi import BangumiDatabase
class Database(Session):
def __init__(self, _engine=engine):
super().__init__(_engine)
self.rss = RSSDatabase(self)
self.torrent = TorrentDatabase(self)
self.bangumi = BangumiDatabase(self)
@staticmethod
def create_table():
SQLModel.metadata.create_all(engine)

View File

@@ -1,44 +1,33 @@
from .orm import Connector
import logging
from sqlmodel import Session, select, delete
from .engine import engine
from module.models import RSSItem
from module.conf import DATA_PATH
logger = logging.getLogger(__name__)
class RSSDatabase(Connector):
def __init__(self, database: str = DATA_PATH):
super().__init__(
table_name="RSSItem",
data=RSSItem().dict(),
database=database
)
@staticmethod
def __data_to_db(data: RSSItem) -> dict:
db_data = data.dict()
for key, value in db_data.items():
if isinstance(value, bool):
db_data[key] = int(value)
elif isinstance(value, list):
db_data[key] = ",".join(value)
return db_data
@staticmethod
def __db_to_data(db_data: dict) -> RSSItem:
for key, item in db_data.items():
if isinstance(item, int):
db_data[key] = bool(item)
return RSSItem(**db_data)
def update_table(self):
self.update.table()
class RSSDatabase:
def __init__(self, session: Session):
self.session = session
def insert_one(self, data: RSSItem):
dict_data = self.__data_to_db(data)
self.insert.one(data=dict_data)
self.session.add(data)
self.session.commit()
self.session.refresh(data)
def get_all(self) -> list[RSSItem]:
dict_datas = self.select.all()
return [self.__db_to_data(x) for x in dict_datas]
def search_all(self) -> list[RSSItem]:
return self.session.exec(select(RSSItem)).all()
def delete_one(self, _id: int):
self.delete.one(_id)
condition = delete(RSSItem).where(RSSItem.id == _id)
self.session.exec(condition)
self.session.commit()
def delete_all(self):
condition = delete(RSSItem)
self.session.exec(condition)
self.session.commit()

View File

@@ -1,51 +1,54 @@
import logging
from module.database.orm import Connector
from module.models import TorrentData
from module.conf import DATA_PATH
from sqlmodel import Session, select
from module.models import Torrent
logger = logging.getLogger(__name__)
class TorrentDatabase(Connector):
def __init__(self, database: str = DATA_PATH):
super().__init__(
table_name="torrent", data=TorrentData().dict(), database=database
)
class TorrentDatabase:
def __init__(self, session: Session):
self.session = session
def update_table(self):
self.update.table()
def insert_one(self, data: Torrent):
self.session.add(data)
self.session.commit()
self.session.refresh(data)
logger.debug(f"Insert {data.name} in database.")
def __data_to_db(self, data: TorrentData) -> dict:
db_data = data.dict()
for key, value in db_data.items():
if isinstance(value, bool):
db_data[key] = int(value)
elif isinstance(value, list):
db_data[key] = ",".join(value)
return db_data
def insert_many(self, datas: list[Torrent]):
self.session.add_all(datas)
self.session.commit()
logger.debug(f"Insert {len(datas)} torrents in database.")
def __db_to_data(self, db_data: dict) -> TorrentData:
for key, item in db_data.items():
if isinstance(item, int):
db_data[key] = bool(item)
elif key in ["filter", "rss_link"]:
db_data[key] = item.split(",")
return TorrentData(**db_data)
def update_one_sys(self, data: Torrent):
self.session.add(data)
self.session.commit()
self.session.refresh(data)
logger.debug(f"Update {data.name} in database.")
def insert_many(self, data_list: list[TorrentData]):
dict_datas = [self.__data_to_db(data) for data in data_list]
self.insert.many(dict_datas)
def update_many_sys(self, datas: list[Torrent]):
self.session.add_all(datas)
self.session.commit()
def get_all(self) -> list[TorrentData]:
dict_datas = self.select.all()
return [self.__db_to_data(data) for data in dict_datas]
def update_one_user(self, data: Torrent):
self.session.add(data)
self.session.commit()
self.session.refresh(data)
logger.debug(f"Update {data.name} in database.")
def get_torrent_name(self) -> list[str]:
dict_data = self.select.all()
return [data["name"] for data in dict_data]
def search_one(self, _id: int) -> Torrent:
return self.session.exec(select(Torrent).where(Torrent.id == _id)).first()
def search_all(self) -> list[Torrent]:
return self.session.exec(select(Torrent)).all()
if __name__ == "__main__":
with TorrentDatabase() as db:
db.update_table()
def check_new(self, torrents_list: list[Torrent]) -> list[Torrent]:
new_torrents = []
for torrent in torrents_list:
statement = select(Torrent).where(Torrent.name == torrent.name)
db_torrent = self.session.exec(statement).first()
if not db_torrent:
new_torrents.append(torrent)
return new_torrents

View File

@@ -1,5 +1,5 @@
from .bangumi import Bangumi, Episode
from .bangumi import Bangumi, Episode, BangumiUpdate
from .config import Config
from .rss import RSSTorrents
from .torrent import EpisodeFile, SubtitleFile, TorrentBase
from .rss import RSSItem, RSSUpdate
from .torrent import EpisodeFile, SubtitleFile, Torrent, TorrentUpdate
from .user import UserLogin

View File

@@ -34,6 +34,7 @@ class BangumiUpdate(SQLModel):
default="official_title", alias="official_title", title="番剧中文名"
)
year: Optional[str] = Field(alias="year", title="番剧年份")
title_raw: str = Field(default="title_raw", alias="title_raw", title="番剧原名")
season: int = Field(default=1, alias="season", title="番剧季度")
season_raw: Optional[str] = Field(alias="season_raw", title="番剧季度原名")
group_name: Optional[str] = Field(alias="group_name", title="字幕组")
@@ -44,7 +45,10 @@ class BangumiUpdate(SQLModel):
offset: int = Field(default=0, alias="offset", title="番剧偏移量")
filter: str = Field(default="720, \\d+-\\d+", alias="filter", title="番剧过滤器")
rss_link: str = Field(default="", alias="rss_link", title="番剧RSS链接")
poster_link: Optional[str] = Field(alias="poster_link", title="番剧海报链接")
added: bool = Field(default=False, alias="added", title="是否已添加")
rule_name: Optional[str] = Field(alias="rule_name", title="番剧规则名")
save_path: Optional[str] = Field(alias="save_path", title="番剧保存路径")
deleted: bool = Field(False, alias="deleted", title="是否已删除")

View File

@@ -1,8 +1,6 @@
from os.path import expandvars
from pydantic import BaseModel, Field
# Sub config
class Program(BaseModel):
rss_time: int = Field(7200, description="Sleep time")

View File

@@ -1,17 +1,21 @@
from pydantic import BaseModel, Field
from sqlmodel import SQLModel, Field
from typing import Optional
class RSSItem(BaseModel):
id: int = Field(0, alias="id", title="id")
class RSSItem(SQLModel, table=True):
id: int = Field(default=None, primary_key=True, alias="id")
item_path: str = Field("example path", alias="item_path")
url: str = Field("https://mikanani.me", alias="url")
combine: bool = Field(True, alias="combine")
enabled: bool = Field(True, alias="enabled")
class TorrentData(BaseModel):
id: int = Field(0, alias="id")
rss_id: int = Field(0, alias="rss_id")
name: str = Field("", alias="name")
url: str = Field("https://example.com/torrent", alias="url")
save_path: str = Field("path/to/save", alias="save_path")
class RSSUpdate(SQLModel):
item_path: Optional[str] = Field("example path", alias="item_path")
url: Optional[str] = Field("https://mikanani.me", alias="url")
combine: Optional[bool] = Field(True, alias="combine")
enabled: Optional[bool] = Field(True, alias="enabled")

View File

@@ -1,16 +1,19 @@
from pydantic import BaseModel, Field
from pydantic import BaseModel
from sqlmodel import SQLModel, Field
from typing import Optional
class TorrentBase(BaseModel):
name: str = Field(...)
torrent_link: str = Field(...)
homepage: str | None = Field(None)
class Torrent(SQLModel, table=True):
id: int = Field(default=None, primary_key=True, alias="id")
refer_id: Optional[int] = Field(None, alias="refer_id")
name: str = Field("", alias="name")
url: str = Field("https://example.com/torrent", alias="url")
homepage: Optional[str] = Field(None, alias="homepage")
downloaded: bool = Field(False, alias="downloaded")
class FileSet(BaseModel):
media_path: str = Field(...)
sc_subtitle: str | None = Field(None)
tc_subtitle: str | None = Field(None)
class TorrentUpdate(SQLModel):
downloaded: bool = Field(False, alias="downloaded")
class EpisodeFile(BaseModel):

View File

@@ -37,7 +37,6 @@ class TorrentInfo:
class RequestContent(RequestURL):
# Mikanani RSS
def get_torrents(
self,
_url: str,
@@ -103,4 +102,6 @@ class RequestContent(RequestURL):
if __name__ == '__main__':
with RequestContent() as req:
req.get_xml("https://mikanani.me/RSS/Classic")
ts = req.get_torrents("https://mikanani.me/RSS/Classic")
for t in ts:
print(t)

View File

@@ -1,20 +1,35 @@
import re
from module.database import RSSDatabase, BangumiDatabase, TorrentDatabase
from module.models import BangumiData, RSSItem, TorrentData
from module.models import Bangumi, RSSItem, Torrent
from module.network import RequestContent, TorrentInfo
from module.database.combine import Database
class RSSEngine(RequestContent):
@staticmethod
def _get_rss_items() -> list[RSSItem]:
with RSSDatabase() as db:
return db.get_all()
@staticmethod
def _get_bangumi_data(rss_link: str) -> list[BangumiData]:
with BangumiDatabase() as db:
return db.get_rss(rss_link)
class RSSEngine(Database):
def _get_rss_items(self) -> list[RSSItem]:
return self.rss.search_all()
def _get_bangumi_data(self, rss_link: str) -> list[Bangumi]:
return self.bangumi.search_rss(rss_link)
def get_torrent(self, rss_link: str) -> list[Torrent]:
with RequestContent() as req:
torrent_infos = req.get_torrents(rss_link)
torrents: list[Torrent] = []
for torrent_info in torrent_infos:
torrents.append(
Torrent(
name=torrent_info.name,
url=torrent_info.torrent_link,
homepage=torrent_info.homepage,
)
)
return torrents
def check_new_torrents(self, torrents_list: list[list[Torrent]]) -> list[Torrent]:
return self.torrent.check_new(torrents_list)
def add_rss(self, rss_link: str, name: str, combine: bool):
if not name:

View File

@@ -1,7 +1,7 @@
from sqlmodel import create_engine, SQLModel
from sqlmodel.pool import StaticPool
from module.database import BangumiDatabase
from module.database.combine import Database
from module.models import Bangumi
@@ -31,19 +31,19 @@ def test_bangumi_database():
save_path=None,
deleted=False,
)
with BangumiDatabase(engine) as database:
with Database(engine) as db:
# insert
database.insert_one(test_data)
assert database.search_id(1) == test_data
db.bangumi.insert_one(test_data)
assert db.bangumi.search_id(1) == test_data
# update
test_data.official_title = "test2"
database.update_one(test_data)
assert database.search_id(1) == test_data
db.bangumi.update_one(test_data)
assert db.bangumi.search_id(1) == test_data
# search poster
assert database.match_poster("test2 (2021)") == "/test/test.jpg"
assert db.bangumi.match_poster("test2 (2021)") == "/test/test.jpg"
# delete
database.delete_one(1)
assert database.search_id(1) is None
db.bangumi.delete_one(1)
assert db.bangumi.search_id(1) is None