diff --git a/src/module/api/program.py b/src/module/api/program.py index a315fe9b..e8db5e8a 100644 --- a/src/module/api/program.py +++ b/src/module/api/program.py @@ -2,33 +2,26 @@ import signal import logging import os -from contextlib import asynccontextmanager from fastapi.exceptions import HTTPException from fastapi import FastAPI -from module.core import Program, check_status, check_rss, check_downloader +from module.core import Program logger = logging.getLogger(__name__) program = Program() +router = FastAPI() -@asynccontextmanager -async def lifespan(router: FastAPI): - logger.info("Starting program...") +@router.on_event("startup") +async def startup(): program.startup() - yield - program.stop() - - -router = FastAPI(lifespan=lifespan) @router.on_event("shutdown") -def shutdown(): +async def shutdown(): program.stop() - logger.info("Stopping program...") @router.get("/api/v1/restart", tags=["program"]) @@ -78,15 +71,9 @@ async def shutdown_program(): # Check status @router.get("/api/v1/check/downloader", tags=["check"]) async def check_downloader_status(): - return check_downloader() + return program.check_downloader() @router.get("/api/v1/check/rss", tags=["check"]) async def check_rss_status(): - return check_rss() - - -@router.get("/api/v1/check", tags=["check"]) -async def check_all(): - return check_status() - + return program.check_analyser() diff --git a/src/module/checker/checker.py b/src/module/checker/checker.py index c9ce5615..e246cadf 100644 --- a/src/module/checker/checker.py +++ b/src/module/checker/checker.py @@ -47,7 +47,8 @@ class Checker: @staticmethod def check_first_run() -> bool: - if os.path.exists(DATA_PATH): + token_exist = False if settings.rss_parser.token in ["", "token"] else True + if token_exist: return False else: return True diff --git a/src/module/conf/config.py b/src/module/conf/config.py index 7eba3d6e..24a3d39a 100644 --- a/src/module/conf/config.py +++ b/src/module/conf/config.py @@ -29,10 +29,7 @@ class Settings(Config): self.load() self.save() else: - # load from env - load_dotenv(".env") - self.__load_from_env() - self.save() + self.init() def load(self): with open(CONFIG_PATH, "r", encoding="utf-8") as f: @@ -44,11 +41,14 @@ class Settings(Config): def save(self, config_dict: dict | None = None): if not config_dict: config_dict = self.dict() - if not os.path.exists("config"): - os.makedirs("config") with open(CONFIG_PATH, "w", encoding="utf-8") as f: json.dump(config_dict, f, indent=4) + def init(self): + load_dotenv(".env") + self.__load_from_env() + self.save() + @property def rss_link(self) -> str: if "://" not in self.rss_parser.custom_url: diff --git a/src/module/core/__init__.py b/src/module/core/__init__.py index 8194e6bb..82d79d10 100644 --- a/src/module/core/__init__.py +++ b/src/module/core/__init__.py @@ -1,2 +1 @@ -from .program import Program -from .check import check_status, check_rss, check_downloader \ No newline at end of file +from .program import Program \ No newline at end of file diff --git a/src/module/core/check.py b/src/module/core/check.py deleted file mode 100644 index f158453d..00000000 --- a/src/module/core/check.py +++ /dev/null @@ -1,45 +0,0 @@ -import logging - -from module.conf import settings -from module.downloader import DownloadClient -from module.network import RequestContent - -logger = logging.getLogger(__name__) - - -def check_status() -> bool: - if settings.rss_parser.token in ["", "token"]: - logger.warning("Please set RSS token") - return False - if check_downloader(): - logger.debug("All check passed") - return True - return False - - -def check_downloader(): - with DownloadClient() as client: - if client.authed: - logger.debug("Downloader is running") - return True - else: - logger.warning("Can't connect to downloader") - return False - - -def check_rss(): - rss_link = settings.rss_link() - with RequestContent() as req: - try: - torrents = req.get_torrents(rss_link) - except Exception as e: - logger.warning("Failed to get torrents from RSS") - logger.warning(e) - return False - if not torrents: - logger.warning("No torrents in RSS") - logger.warning("Please check your RSS link") - return False - else: - logger.debug("RSS is running") - return True diff --git a/src/module/core/program.py b/src/module/core/program.py index 8b0ad351..190eb6c9 100644 --- a/src/module/core/program.py +++ b/src/module/core/program.py @@ -3,6 +3,7 @@ import logging from .sub_thread import RenameThread, RSSThread from module.conf import settings, VERSION +from module.update import data_migration logger = logging.getLogger(__name__) @@ -21,6 +22,14 @@ class Program(RenameThread, RSSThread): def startup(self): self.__start_info() + if self.first_run: + logger.info("First run detected, please configure the program in webui.") + return {"status": "First run detected."} + if self.legacy_data: + logger.info( + "Legacy data detected, starting data migration, please wait patiently." + ) + data_migration() self.start() def start(self): diff --git a/src/module/core/startup.py b/src/module/core/startup.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/module/core/status.py b/src/module/core/status.py index d957a5e9..f0a83c86 100644 --- a/src/module/core/status.py +++ b/src/module/core/status.py @@ -1,7 +1,9 @@ +import os.path import threading import asyncio from module.checker import Checker +from module.conf import DATA_PATH class ProgramStatus(Checker): @@ -44,3 +46,7 @@ class ProgramStatus(Checker): @property def first_run(self): return self.check_first_run() + + @property + def legacy_data(self): + return os.path.exists("data/data.json") diff --git a/src/module/database/bangumi.py b/src/module/database/bangumi.py index 34d174e7..3b669fc8 100644 --- a/src/module/database/bangumi.py +++ b/src/module/database/bangumi.py @@ -10,6 +10,7 @@ class BangumiDatabase(DataConnector): def __init__(self): super().__init__() self.__table_name = "bangumi" + self.update_table() def update_table(self): db_data = self.__data_to_db(BangumiData()) @@ -55,16 +56,11 @@ class BangumiDatabase(DataConnector): def update_one(self, data: BangumiData) -> bool: db_data = self.__data_to_db(data) - update_columns = ", ".join([f"{key} = :{key}" for key in db_data.keys() if key != "id"]) - self._cursor.execute(f"UPDATE bangumi SET {update_columns} WHERE id = :id", db_data) - self._conn.commit() - return self._cursor.rowcount == 1 + return self._update(db_data=db_data, table_name=self.__table_name) def update_list(self, data: list[BangumiData]): - db_data = [self.__data_to_db(x) for x in data] - update_columns = ", ".join([f"{key} = :{key}" for key in db_data[0].keys() if key != "id"]) - self._cursor.executemany(f"UPDATE bangumi SET {update_columns} WHERE id = :id", db_data) - self._conn.commit() + data_list = [self.__data_to_db(x) for x in data] + self._update_list(data_list=data_list, table_name=self.__table_name) def update_rss(self, title_raw, rss_set: str): # Update rss and added @@ -79,6 +75,18 @@ class BangumiDatabase(DataConnector): self._conn.commit() logger.debug(f"Update {title_raw} rss_link to {rss_set}.") + def update_poster(self, title_raw, poster_link: str): + self._cursor.execute( + """ + UPDATE bangumi + SET poster_link = :poster_link + WHERE title_raw = :title_raw + """, + {"poster_link": poster_link, "title_raw": title_raw}, + ) + self._conn.commit() + logger.debug(f"Update {title_raw} poster_link to {poster_link}.") + def delete_one(self, _id: int) -> bool: self._cursor.execute( """ @@ -130,26 +138,28 @@ class BangumiDatabase(DataConnector): return self.__db_to_data(dict_data) def match_poster(self, bangumi_name: str) -> str: - # Find title_raw which in torrent_name self._cursor.execute( """ - SELECT official_title, poster_link FROM bangumi + SELECT official_title, poster_link + FROM bangumi + WHERE INSTR(:bangumi_name, official_title) > 0 """ + , + {"bangumi_name": bangumi_name}, ) - data = self._cursor.fetchall() + data = self._cursor.fetchone() if not data: return "" - for official_title, poster_link in data: - if official_title in bangumi_name: - if poster_link: - return poster_link - return "" + official_title, poster_link = data + if not poster_link: + return "" + return poster_link def match_list(self, torrent_list: list, rss_link: str) -> list: # Match title_raw in database self._cursor.execute( """ - SELECT title_raw, rss_link FROM bangumi + SELECT title_raw, rss_link, poster_link FROM bangumi """ ) data = self._cursor.fetchall() @@ -159,11 +169,13 @@ class BangumiDatabase(DataConnector): i = 0 while i < len(torrent_list): torrent = torrent_list[i] - for title_raw, rss_set in data: + for title_raw, rss_set, poster_link in data: if title_raw in torrent.name: if rss_link not in rss_set: rss_set += "," + rss_link self.update_rss(title_raw, rss_set) + if not poster_link: + self.update_poster(title_raw, torrent.poster_link) torrent_list.pop(i) break else: @@ -197,9 +209,3 @@ class BangumiDatabase(DataConnector): if data is None: return 1 return data[0] + 1 - - -if __name__ == '__main__': - title = "[SweetSub&LoliHouse] Heavenly Delusion - 06 [WebRip 1080p HEVC-10bit AAC ASSx2].mkv" - with BangumiDatabase() as db: - print(db.match_poster(title)) \ No newline at end of file diff --git a/src/module/database/connector.py b/src/module/database/connector.py index e68d9d6e..fad7e347 100644 --- a/src/module/database/connector.py +++ b/src/module/database/connector.py @@ -40,6 +40,35 @@ class DataConnector: self._cursor.executemany(f"INSERT INTO {table_name} ({columns}) VALUES ({values})", data_list) self._conn.commit() + def _select(self, keys: list[str], table_name: str, condition: str = None) -> dict: + if condition is None: + self._cursor.execute(f"SELECT {', '.join(keys)} FROM {table_name}") + else: + self._cursor.execute(f"SELECT {', '.join(keys)} FROM {table_name} WHERE {condition}") + return dict(zip(keys, self._cursor.fetchone())) + + def _update(self, table_name: str, db_data: dict): + _id = db_data.get("id") + if _id is None: + raise ValueError("No _id in db_data.") + set_sql = ", ".join([f"{key} = :{key}" for key in db_data.keys()]) + self._cursor.execute(f"UPDATE {table_name} SET {set_sql} WHERE id = {_id}", db_data) + self._conn.commit() + return self._cursor.rowcount == 1 + + def _update_list(self, table_name: str, data_list: list[dict]): + if len(data_list) == 0: + return + set_sql = ", ".join([f"{key} = :{key}" for key in data_list[0].keys() if key != "id"]) + self._cursor.executemany(f"UPDATE {table_name} SET {set_sql} WHERE id = :id", data_list) + self._conn.commit() + + def _update_section(self, table_name: str, location: dict, update_dict: dict): + set_sql = ", ".join([f"{key} = :{key}" for key in update_dict.keys()]) + sql_loc = f"{location['key']} = {location['value']}" + self._cursor.execute(f"UPDATE {table_name} SET {set_sql} WHERE {sql_loc}", update_dict) + self._conn.commit() + def _delete_all(self, table_name: str): self._cursor.execute(f"DELETE FROM {table_name}") self._conn.commit() diff --git a/src/module/downloader/client/qb_downloader.py b/src/module/downloader/client/qb_downloader.py index 85a64fae..85cfc747 100644 --- a/src/module/downloader/client/qb_downloader.py +++ b/src/module/downloader/client/qb_downloader.py @@ -41,7 +41,7 @@ class QbDownloader: break except APIConnectionError: logger.error(f"Cannot connect to qBittorrent Server") - logger.info(f"Please check the IP and port in qBittorrent Server") + logger.info(f"Please check the IP and port in WebUI settings") time.sleep(30) except Exception as e: logger.error(f"Unknown error: {e}") diff --git a/src/module/models/user.py b/src/module/models/user.py index 4e976be3..253e7480 100644 --- a/src/module/models/user.py +++ b/src/module/models/user.py @@ -16,9 +16,8 @@ class UserUpdate(UserBase): class User(UserBase): - user_id: int + id: int = Field(..., alias="_id") password: str = Field(..., min_length=8) - id: str = Field(..., alias="_id") class UserInDB(UserBase): diff --git a/src/module/update/__init__.py b/src/module/update/__init__.py index e69de29b..0044a5be 100644 --- a/src/module/update/__init__.py +++ b/src/module/update/__init__.py @@ -0,0 +1 @@ +from .data_migration import data_migration diff --git a/src/module/update/data_migration.py b/src/module/update/data_migration.py index b0b7fecf..21dd72f6 100644 --- a/src/module/update/data_migration.py +++ b/src/module/update/data_migration.py @@ -18,4 +18,3 @@ def data_migration(): database.update_table() database.insert_list(new_data) os.remove("data/data.json") - return True