mirror of
https://github.com/EstrellaXD/Auto_Bangumi.git
synced 2026-04-14 02:20:53 +08:00
Refactor new program class
重构程序启动逻辑。 增加数据迁移后的海报搜索功能。 修复 AB 自身作为代理的情况无法启动的情况。
This commit is contained in:
@@ -2,33 +2,26 @@ import signal
|
||||
import logging
|
||||
import os
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi.exceptions import HTTPException
|
||||
|
||||
|
||||
from fastapi import FastAPI
|
||||
|
||||
from module.core import Program, check_status, check_rss, check_downloader
|
||||
from module.core import Program
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
program = Program()
|
||||
router = FastAPI()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(router: FastAPI):
|
||||
logger.info("Starting program...")
|
||||
@router.on_event("startup")
|
||||
async def startup():
|
||||
program.startup()
|
||||
yield
|
||||
program.stop()
|
||||
|
||||
|
||||
router = FastAPI(lifespan=lifespan)
|
||||
|
||||
|
||||
@router.on_event("shutdown")
|
||||
def shutdown():
|
||||
async def shutdown():
|
||||
program.stop()
|
||||
logger.info("Stopping program...")
|
||||
|
||||
|
||||
@router.get("/api/v1/restart", tags=["program"])
|
||||
@@ -78,15 +71,9 @@ async def shutdown_program():
|
||||
# Check status
|
||||
@router.get("/api/v1/check/downloader", tags=["check"])
|
||||
async def check_downloader_status():
|
||||
return check_downloader()
|
||||
return program.check_downloader()
|
||||
|
||||
|
||||
@router.get("/api/v1/check/rss", tags=["check"])
|
||||
async def check_rss_status():
|
||||
return check_rss()
|
||||
|
||||
|
||||
@router.get("/api/v1/check", tags=["check"])
|
||||
async def check_all():
|
||||
return check_status()
|
||||
|
||||
return program.check_analyser()
|
||||
|
||||
@@ -47,7 +47,8 @@ class Checker:
|
||||
|
||||
@staticmethod
|
||||
def check_first_run() -> bool:
|
||||
if os.path.exists(DATA_PATH):
|
||||
token_exist = False if settings.rss_parser.token in ["", "token"] else True
|
||||
if token_exist:
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
@@ -29,10 +29,7 @@ class Settings(Config):
|
||||
self.load()
|
||||
self.save()
|
||||
else:
|
||||
# load from env
|
||||
load_dotenv(".env")
|
||||
self.__load_from_env()
|
||||
self.save()
|
||||
self.init()
|
||||
|
||||
def load(self):
|
||||
with open(CONFIG_PATH, "r", encoding="utf-8") as f:
|
||||
@@ -44,11 +41,14 @@ class Settings(Config):
|
||||
def save(self, config_dict: dict | None = None):
|
||||
if not config_dict:
|
||||
config_dict = self.dict()
|
||||
if not os.path.exists("config"):
|
||||
os.makedirs("config")
|
||||
with open(CONFIG_PATH, "w", encoding="utf-8") as f:
|
||||
json.dump(config_dict, f, indent=4)
|
||||
|
||||
def init(self):
|
||||
load_dotenv(".env")
|
||||
self.__load_from_env()
|
||||
self.save()
|
||||
|
||||
@property
|
||||
def rss_link(self) -> str:
|
||||
if "://" not in self.rss_parser.custom_url:
|
||||
|
||||
@@ -1,2 +1 @@
|
||||
from .program import Program
|
||||
from .check import check_status, check_rss, check_downloader
|
||||
from .program import Program
|
||||
@@ -1,45 +0,0 @@
|
||||
import logging
|
||||
|
||||
from module.conf import settings
|
||||
from module.downloader import DownloadClient
|
||||
from module.network import RequestContent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def check_status() -> bool:
|
||||
if settings.rss_parser.token in ["", "token"]:
|
||||
logger.warning("Please set RSS token")
|
||||
return False
|
||||
if check_downloader():
|
||||
logger.debug("All check passed")
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def check_downloader():
|
||||
with DownloadClient() as client:
|
||||
if client.authed:
|
||||
logger.debug("Downloader is running")
|
||||
return True
|
||||
else:
|
||||
logger.warning("Can't connect to downloader")
|
||||
return False
|
||||
|
||||
|
||||
def check_rss():
|
||||
rss_link = settings.rss_link()
|
||||
with RequestContent() as req:
|
||||
try:
|
||||
torrents = req.get_torrents(rss_link)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to get torrents from RSS")
|
||||
logger.warning(e)
|
||||
return False
|
||||
if not torrents:
|
||||
logger.warning("No torrents in RSS")
|
||||
logger.warning("Please check your RSS link")
|
||||
return False
|
||||
else:
|
||||
logger.debug("RSS is running")
|
||||
return True
|
||||
@@ -3,6 +3,7 @@ import logging
|
||||
from .sub_thread import RenameThread, RSSThread
|
||||
|
||||
from module.conf import settings, VERSION
|
||||
from module.update import data_migration
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -21,6 +22,14 @@ class Program(RenameThread, RSSThread):
|
||||
|
||||
def startup(self):
|
||||
self.__start_info()
|
||||
if self.first_run:
|
||||
logger.info("First run detected, please configure the program in webui.")
|
||||
return {"status": "First run detected."}
|
||||
if self.legacy_data:
|
||||
logger.info(
|
||||
"Legacy data detected, starting data migration, please wait patiently."
|
||||
)
|
||||
data_migration()
|
||||
self.start()
|
||||
|
||||
def start(self):
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import os.path
|
||||
import threading
|
||||
import asyncio
|
||||
|
||||
from module.checker import Checker
|
||||
from module.conf import DATA_PATH
|
||||
|
||||
|
||||
class ProgramStatus(Checker):
|
||||
@@ -44,3 +46,7 @@ class ProgramStatus(Checker):
|
||||
@property
|
||||
def first_run(self):
|
||||
return self.check_first_run()
|
||||
|
||||
@property
|
||||
def legacy_data(self):
|
||||
return os.path.exists("data/data.json")
|
||||
|
||||
@@ -10,6 +10,7 @@ class BangumiDatabase(DataConnector):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.__table_name = "bangumi"
|
||||
self.update_table()
|
||||
|
||||
def update_table(self):
|
||||
db_data = self.__data_to_db(BangumiData())
|
||||
@@ -55,16 +56,11 @@ class BangumiDatabase(DataConnector):
|
||||
|
||||
def update_one(self, data: BangumiData) -> bool:
|
||||
db_data = self.__data_to_db(data)
|
||||
update_columns = ", ".join([f"{key} = :{key}" for key in db_data.keys() if key != "id"])
|
||||
self._cursor.execute(f"UPDATE bangumi SET {update_columns} WHERE id = :id", db_data)
|
||||
self._conn.commit()
|
||||
return self._cursor.rowcount == 1
|
||||
return self._update(db_data=db_data, table_name=self.__table_name)
|
||||
|
||||
def update_list(self, data: list[BangumiData]):
|
||||
db_data = [self.__data_to_db(x) for x in data]
|
||||
update_columns = ", ".join([f"{key} = :{key}" for key in db_data[0].keys() if key != "id"])
|
||||
self._cursor.executemany(f"UPDATE bangumi SET {update_columns} WHERE id = :id", db_data)
|
||||
self._conn.commit()
|
||||
data_list = [self.__data_to_db(x) for x in data]
|
||||
self._update_list(data_list=data_list, table_name=self.__table_name)
|
||||
|
||||
def update_rss(self, title_raw, rss_set: str):
|
||||
# Update rss and added
|
||||
@@ -79,6 +75,18 @@ class BangumiDatabase(DataConnector):
|
||||
self._conn.commit()
|
||||
logger.debug(f"Update {title_raw} rss_link to {rss_set}.")
|
||||
|
||||
def update_poster(self, title_raw, poster_link: str):
|
||||
self._cursor.execute(
|
||||
"""
|
||||
UPDATE bangumi
|
||||
SET poster_link = :poster_link
|
||||
WHERE title_raw = :title_raw
|
||||
""",
|
||||
{"poster_link": poster_link, "title_raw": title_raw},
|
||||
)
|
||||
self._conn.commit()
|
||||
logger.debug(f"Update {title_raw} poster_link to {poster_link}.")
|
||||
|
||||
def delete_one(self, _id: int) -> bool:
|
||||
self._cursor.execute(
|
||||
"""
|
||||
@@ -130,26 +138,28 @@ class BangumiDatabase(DataConnector):
|
||||
return self.__db_to_data(dict_data)
|
||||
|
||||
def match_poster(self, bangumi_name: str) -> str:
|
||||
# Find title_raw which in torrent_name
|
||||
self._cursor.execute(
|
||||
"""
|
||||
SELECT official_title, poster_link FROM bangumi
|
||||
SELECT official_title, poster_link
|
||||
FROM bangumi
|
||||
WHERE INSTR(:bangumi_name, official_title) > 0
|
||||
"""
|
||||
,
|
||||
{"bangumi_name": bangumi_name},
|
||||
)
|
||||
data = self._cursor.fetchall()
|
||||
data = self._cursor.fetchone()
|
||||
if not data:
|
||||
return ""
|
||||
for official_title, poster_link in data:
|
||||
if official_title in bangumi_name:
|
||||
if poster_link:
|
||||
return poster_link
|
||||
return ""
|
||||
official_title, poster_link = data
|
||||
if not poster_link:
|
||||
return ""
|
||||
return poster_link
|
||||
|
||||
def match_list(self, torrent_list: list, rss_link: str) -> list:
|
||||
# Match title_raw in database
|
||||
self._cursor.execute(
|
||||
"""
|
||||
SELECT title_raw, rss_link FROM bangumi
|
||||
SELECT title_raw, rss_link, poster_link FROM bangumi
|
||||
"""
|
||||
)
|
||||
data = self._cursor.fetchall()
|
||||
@@ -159,11 +169,13 @@ class BangumiDatabase(DataConnector):
|
||||
i = 0
|
||||
while i < len(torrent_list):
|
||||
torrent = torrent_list[i]
|
||||
for title_raw, rss_set in data:
|
||||
for title_raw, rss_set, poster_link in data:
|
||||
if title_raw in torrent.name:
|
||||
if rss_link not in rss_set:
|
||||
rss_set += "," + rss_link
|
||||
self.update_rss(title_raw, rss_set)
|
||||
if not poster_link:
|
||||
self.update_poster(title_raw, torrent.poster_link)
|
||||
torrent_list.pop(i)
|
||||
break
|
||||
else:
|
||||
@@ -197,9 +209,3 @@ class BangumiDatabase(DataConnector):
|
||||
if data is None:
|
||||
return 1
|
||||
return data[0] + 1
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
title = "[SweetSub&LoliHouse] Heavenly Delusion - 06 [WebRip 1080p HEVC-10bit AAC ASSx2].mkv"
|
||||
with BangumiDatabase() as db:
|
||||
print(db.match_poster(title))
|
||||
@@ -40,6 +40,35 @@ class DataConnector:
|
||||
self._cursor.executemany(f"INSERT INTO {table_name} ({columns}) VALUES ({values})", data_list)
|
||||
self._conn.commit()
|
||||
|
||||
def _select(self, keys: list[str], table_name: str, condition: str = None) -> dict:
|
||||
if condition is None:
|
||||
self._cursor.execute(f"SELECT {', '.join(keys)} FROM {table_name}")
|
||||
else:
|
||||
self._cursor.execute(f"SELECT {', '.join(keys)} FROM {table_name} WHERE {condition}")
|
||||
return dict(zip(keys, self._cursor.fetchone()))
|
||||
|
||||
def _update(self, table_name: str, db_data: dict):
|
||||
_id = db_data.get("id")
|
||||
if _id is None:
|
||||
raise ValueError("No _id in db_data.")
|
||||
set_sql = ", ".join([f"{key} = :{key}" for key in db_data.keys()])
|
||||
self._cursor.execute(f"UPDATE {table_name} SET {set_sql} WHERE id = {_id}", db_data)
|
||||
self._conn.commit()
|
||||
return self._cursor.rowcount == 1
|
||||
|
||||
def _update_list(self, table_name: str, data_list: list[dict]):
|
||||
if len(data_list) == 0:
|
||||
return
|
||||
set_sql = ", ".join([f"{key} = :{key}" for key in data_list[0].keys() if key != "id"])
|
||||
self._cursor.executemany(f"UPDATE {table_name} SET {set_sql} WHERE id = :id", data_list)
|
||||
self._conn.commit()
|
||||
|
||||
def _update_section(self, table_name: str, location: dict, update_dict: dict):
|
||||
set_sql = ", ".join([f"{key} = :{key}" for key in update_dict.keys()])
|
||||
sql_loc = f"{location['key']} = {location['value']}"
|
||||
self._cursor.execute(f"UPDATE {table_name} SET {set_sql} WHERE {sql_loc}", update_dict)
|
||||
self._conn.commit()
|
||||
|
||||
def _delete_all(self, table_name: str):
|
||||
self._cursor.execute(f"DELETE FROM {table_name}")
|
||||
self._conn.commit()
|
||||
|
||||
@@ -41,7 +41,7 @@ class QbDownloader:
|
||||
break
|
||||
except APIConnectionError:
|
||||
logger.error(f"Cannot connect to qBittorrent Server")
|
||||
logger.info(f"Please check the IP and port in qBittorrent Server")
|
||||
logger.info(f"Please check the IP and port in WebUI settings")
|
||||
time.sleep(30)
|
||||
except Exception as e:
|
||||
logger.error(f"Unknown error: {e}")
|
||||
|
||||
@@ -16,9 +16,8 @@ class UserUpdate(UserBase):
|
||||
|
||||
|
||||
class User(UserBase):
|
||||
user_id: int
|
||||
id: int = Field(..., alias="_id")
|
||||
password: str = Field(..., min_length=8)
|
||||
id: str = Field(..., alias="_id")
|
||||
|
||||
|
||||
class UserInDB(UserBase):
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
from .data_migration import data_migration
|
||||
|
||||
@@ -18,4 +18,3 @@ def data_migration():
|
||||
database.update_table()
|
||||
database.insert_list(new_data)
|
||||
os.remove("data/data.json")
|
||||
return True
|
||||
|
||||
Reference in New Issue
Block a user