Refactor new program class

重构程序启动逻辑。
增加数据迁移后的海报搜索功能。
修复 AB 自身作为代理时无法启动的问题。
This commit is contained in:
EstrellaXD
2023-05-18 16:43:51 +08:00
parent 710cb40e41
commit 306aef0e39
14 changed files with 93 additions and 102 deletions

View File

@@ -2,33 +2,26 @@ import signal
import logging
import os
from contextlib import asynccontextmanager
from fastapi.exceptions import HTTPException
from fastapi import FastAPI
from module.core import Program, check_status, check_rss, check_downloader
from module.core import Program
logger = logging.getLogger(__name__)
program = Program()
router = FastAPI()
@asynccontextmanager
async def lifespan(router: FastAPI):
logger.info("Starting program...")
@router.on_event("startup")
async def startup():
program.startup()
yield
program.stop()
router = FastAPI(lifespan=lifespan)
@router.on_event("shutdown")
def shutdown():
async def shutdown():
program.stop()
logger.info("Stopping program...")
@router.get("/api/v1/restart", tags=["program"])
@@ -78,15 +71,9 @@ async def shutdown_program():
# Check status
@router.get("/api/v1/check/downloader", tags=["check"])
async def check_downloader_status():
return check_downloader()
return program.check_downloader()
@router.get("/api/v1/check/rss", tags=["check"])
async def check_rss_status():
return check_rss()
@router.get("/api/v1/check", tags=["check"])
async def check_all():
return check_status()
return program.check_analyser()

View File

@@ -47,7 +47,8 @@ class Checker:
@staticmethod
def check_first_run() -> bool:
    """Return True when this looks like a first run.

    A token left at its default ("" or the placeholder "token") means the
    user has not finished configuration, so setup must run first.
    """
    if os.path.exists(DATA_PATH):
        # Collapses the original bool-ternary plus if/else-return into a
        # single expression: first run iff the token was never customized.
        return settings.rss_parser.token in ("", "token")
    # NOTE(review): the original falls through here (implicit None, falsy)
    # when DATA_PATH is missing — preserved; confirm whether returning
    # True was intended for a genuinely fresh install.

View File

@@ -29,10 +29,7 @@ class Settings(Config):
self.load()
self.save()
else:
# load from env
load_dotenv(".env")
self.__load_from_env()
self.save()
self.init()
def load(self):
with open(CONFIG_PATH, "r", encoding="utf-8") as f:
@@ -44,11 +41,14 @@ class Settings(Config):
def save(self, config_dict: dict | None = None):
    """Persist settings to CONFIG_PATH as JSON.

    Args:
        config_dict: Settings to write; a falsy value (None or an empty
            dict — original behavior preserved) falls back to the live
            settings from self.dict().
    """
    if not config_dict:
        config_dict = self.dict()
    # exist_ok avoids the original's check-then-create race (TOCTOU).
    # NOTE(review): hard-coded "config" — presumably the parent directory
    # of CONFIG_PATH; confirm they always agree.
    os.makedirs("config", exist_ok=True)
    with open(CONFIG_PATH, "w", encoding="utf-8") as f:
        json.dump(config_dict, f, indent=4)
def init(self):
# Seed settings from the environment: load .env into os.environ, copy
# the variables onto this object, then persist the result via save().
load_dotenv(".env")
self.__load_from_env()
self.save()
@property
def rss_link(self) -> str:
if "://" not in self.rss_parser.custom_url:

View File

@@ -1,2 +1 @@
from .program import Program
from .check import check_status, check_rss, check_downloader
from .program import Program

View File

@@ -1,45 +0,0 @@
import logging
from module.conf import settings
from module.downloader import DownloadClient
from module.network import RequestContent
logger = logging.getLogger(__name__)
def check_status() -> bool:
    """Overall health check: RSS token configured and downloader reachable."""
    token = settings.rss_parser.token
    if token in ("", "token"):
        # Default/placeholder token means setup was never completed.
        logger.warning("Please set RSS token")
        return False
    if not check_downloader():
        return False
    logger.debug("All check passed")
    return True
def check_downloader():
    """Return True when an authenticated downloader session can be opened."""
    with DownloadClient() as client:
        if not client.authed:
            logger.warning("Can't connect to downloader")
            return False
        logger.debug("Downloader is running")
        return True
def check_rss():
    """Check that the configured RSS feed is reachable and non-empty."""
    feed_url = settings.rss_link()
    with RequestContent() as req:
        try:
            torrents = req.get_torrents(feed_url)
        except Exception as e:
            # Best-effort check: any fetch/parse failure is reported, not raised.
            logger.warning("Failed to get torrents from RSS")
            logger.warning(e)
            return False
        if torrents:
            logger.debug("RSS is running")
            return True
        logger.warning("No torrents in RSS")
        logger.warning("Please check your RSS link")
        return False

View File

@@ -3,6 +3,7 @@ import logging
from .sub_thread import RenameThread, RSSThread
from module.conf import settings, VERSION
from module.update import data_migration
logger = logging.getLogger(__name__)
@@ -21,6 +22,14 @@ class Program(RenameThread, RSSThread):
def startup(self):
# Boot sequence: print the start banner, then guard on first-run and
# legacy-data conditions before launching the worker threads.
self.__start_info()
if self.first_run:
# No RSS token configured yet — nothing can run until the user
# finishes setup in the webui.
logger.info("First run detected, please configure the program in webui.")
return {"status": "First run detected."}
if self.legacy_data:
logger.info(
"Legacy data detected, starting data migration, please wait patiently."
)
# One-shot migration of the old data/data.json store into the database.
data_migration()
self.start()
def start(self):

View File

@@ -1,7 +1,9 @@
import os.path
import threading
import asyncio
from module.checker import Checker
from module.conf import DATA_PATH
class ProgramStatus(Checker):
@@ -44,3 +46,7 @@ class ProgramStatus(Checker):
@property
def first_run(self):
# Delegates to Checker.check_first_run: first run iff the RSS token
# was never customized.
return self.check_first_run()
@property
def legacy_data(self):
# Pre-database versions kept everything in data/data.json; its presence
# means a migration is still pending (data_migration removes it when done).
return os.path.exists("data/data.json")

View File

@@ -10,6 +10,7 @@ class BangumiDatabase(DataConnector):
def __init__(self):
# DataConnector.__init__ presumably opens the DB connection that backs
# self._cursor / self._conn — confirm against the base class.
super().__init__()
self.__table_name = "bangumi"
# Make sure the bangumi table schema is current before any query runs.
self.update_table()
def update_table(self):
db_data = self.__data_to_db(BangumiData())
@@ -55,16 +56,11 @@ class BangumiDatabase(DataConnector):
def update_one(self, data: BangumiData) -> bool:
db_data = self.__data_to_db(data)
update_columns = ", ".join([f"{key} = :{key}" for key in db_data.keys() if key != "id"])
self._cursor.execute(f"UPDATE bangumi SET {update_columns} WHERE id = :id", db_data)
self._conn.commit()
return self._cursor.rowcount == 1
return self._update(db_data=db_data, table_name=self.__table_name)
def update_list(self, data: list[BangumiData]):
db_data = [self.__data_to_db(x) for x in data]
update_columns = ", ".join([f"{key} = :{key}" for key in db_data[0].keys() if key != "id"])
self._cursor.executemany(f"UPDATE bangumi SET {update_columns} WHERE id = :id", db_data)
self._conn.commit()
data_list = [self.__data_to_db(x) for x in data]
self._update_list(data_list=data_list, table_name=self.__table_name)
def update_rss(self, title_raw, rss_set: str):
# Update rss and added
@@ -79,6 +75,18 @@ class BangumiDatabase(DataConnector):
self._conn.commit()
logger.debug(f"Update {title_raw} rss_link to {rss_set}.")
def update_poster(self, title_raw, poster_link: str):
# Set poster_link on every bangumi row whose title_raw matches exactly.
self._cursor.execute(
"""
UPDATE bangumi
SET poster_link = :poster_link
WHERE title_raw = :title_raw
""",
{"poster_link": poster_link, "title_raw": title_raw},
)
self._conn.commit()
logger.debug(f"Update {title_raw} poster_link to {poster_link}.")
def delete_one(self, _id: int) -> bool:
self._cursor.execute(
"""
@@ -130,26 +138,28 @@ class BangumiDatabase(DataConnector):
return self.__db_to_data(dict_data)
def match_poster(self, bangumi_name: str) -> str:
# Find title_raw which in torrent_name
self._cursor.execute(
"""
SELECT official_title, poster_link FROM bangumi
SELECT official_title, poster_link
FROM bangumi
WHERE INSTR(:bangumi_name, official_title) > 0
"""
,
{"bangumi_name": bangumi_name},
)
data = self._cursor.fetchall()
data = self._cursor.fetchone()
if not data:
return ""
for official_title, poster_link in data:
if official_title in bangumi_name:
if poster_link:
return poster_link
return ""
official_title, poster_link = data
if not poster_link:
return ""
return poster_link
def match_list(self, torrent_list: list, rss_link: str) -> list:
# Match title_raw in database
self._cursor.execute(
"""
SELECT title_raw, rss_link FROM bangumi
SELECT title_raw, rss_link, poster_link FROM bangumi
"""
)
data = self._cursor.fetchall()
@@ -159,11 +169,13 @@ class BangumiDatabase(DataConnector):
i = 0
while i < len(torrent_list):
torrent = torrent_list[i]
for title_raw, rss_set in data:
for title_raw, rss_set, poster_link in data:
if title_raw in torrent.name:
if rss_link not in rss_set:
rss_set += "," + rss_link
self.update_rss(title_raw, rss_set)
if not poster_link:
self.update_poster(title_raw, torrent.poster_link)
torrent_list.pop(i)
break
else:
@@ -197,9 +209,3 @@ class BangumiDatabase(DataConnector):
if data is None:
return 1
return data[0] + 1
if __name__ == '__main__':
title = "[SweetSub&LoliHouse] Heavenly Delusion - 06 [WebRip 1080p HEVC-10bit AAC ASSx2].mkv"
with BangumiDatabase() as db:
print(db.match_poster(title))

View File

@@ -40,6 +40,35 @@ class DataConnector:
self._cursor.executemany(f"INSERT INTO {table_name} ({columns}) VALUES ({values})", data_list)
self._conn.commit()
def _select(self, keys: list[str], table_name: str, condition: str | None = None) -> dict:
    """Fetch the first matching row as a {column: value} dict.

    Args:
        keys: Column names to select.
        table_name: Table to query (internal callers only — interpolated
            into the SQL verbatim).
        condition: Optional raw WHERE clause, also interpolated verbatim;
            must never carry untrusted input.

    Returns:
        The first row mapped by column name, or {} when no row matches
        (the original zipped None and raised TypeError on empty results).
    """
    columns = ", ".join(keys)
    if condition is None:
        self._cursor.execute(f"SELECT {columns} FROM {table_name}")
    else:
        self._cursor.execute(f"SELECT {columns} FROM {table_name} WHERE {condition}")
    row = self._cursor.fetchone()
    # fetchone() returns None on an empty result set — guard before zipping.
    return dict(zip(keys, row)) if row is not None else {}
def _update(self, table_name: str, db_data: dict) -> bool:
    """Update the single row identified by db_data["id"].

    Args:
        table_name: Target table (internal callers only).
        db_data: Column values including the row "id".

    Returns:
        True when exactly one row was changed.

    Raises:
        ValueError: if db_data carries no id.
    """
    if db_data.get("id") is None:
        raise ValueError("No _id in db_data.")
    # Exclude id from the SET list (consistent with _update_list) and bind
    # it as a named parameter instead of f-string interpolation, avoiding
    # quoting/injection issues for non-integer ids. Setting id to itself
    # in the original was a no-op, so behavior for callers is unchanged.
    set_sql = ", ".join(f"{key} = :{key}" for key in db_data if key != "id")
    self._cursor.execute(f"UPDATE {table_name} SET {set_sql} WHERE id = :id", db_data)
    self._conn.commit()
    return self._cursor.rowcount == 1
def _update_list(self, table_name: str, data_list: list[dict]):
    """Batch-update rows by id; silently a no-op for an empty list."""
    if not data_list:
        return
    # All dicts are assumed to share the first row's keys; id is the locator.
    assignments = [f"{column} = :{column}" for column in data_list[0] if column != "id"]
    sql = f"UPDATE {table_name} SET {', '.join(assignments)} WHERE id = :id"
    self._cursor.executemany(sql, data_list)
    self._conn.commit()
def _update_section(self, table_name: str, location: dict, update_dict: dict):
    """Update the columns in update_dict on rows matching the locator.

    Args:
        table_name: Target table (internal callers only).
        location: {"key": column_name, "value": match_value} locator.
        update_dict: Column values to set.
    """
    set_sql = ", ".join(f"{key} = :{key}" for key in update_dict)
    # Bind the locator value as a parameter instead of interpolating it:
    # the original produced invalid/unquoted SQL for string values and was
    # injectable. update_dict is not mutated; a merged copy is passed.
    self._cursor.execute(
        f"UPDATE {table_name} SET {set_sql} WHERE {location['key']} = :_loc_value",
        {**update_dict, "_loc_value": location["value"]},
    )
    self._conn.commit()
def _delete_all(self, table_name: str):
    """Remove every row from table_name (table itself is kept)."""
    query = f"DELETE FROM {table_name}"
    self._cursor.execute(query)
    self._conn.commit()

View File

@@ -41,7 +41,7 @@ class QbDownloader:
break
except APIConnectionError:
logger.error(f"Cannot connect to qBittorrent Server")
logger.info(f"Please check the IP and port in qBittorrent Server")
logger.info(f"Please check the IP and port in WebUI settings")
time.sleep(30)
except Exception as e:
logger.error(f"Unknown error: {e}")

View File

@@ -16,9 +16,8 @@ class UserUpdate(UserBase):
class User(UserBase):
user_id: int
id: int = Field(..., alias="_id")
password: str = Field(..., min_length=8)
id: str = Field(..., alias="_id")
class UserInDB(UserBase):

View File

@@ -0,0 +1 @@
from .data_migration import data_migration

View File

@@ -18,4 +18,3 @@ def data_migration():
database.update_table()
database.insert_list(new_data)
os.remove("data/data.json")
return True