From 6bd6306dc653dc0067963305fd920b97c5f5aad4 Mon Sep 17 00:00:00 2001 From: EstrellaXD Date: Wed, 14 Jun 2023 20:39:19 +0800 Subject: [PATCH] feat: new orm module --- backend/src/module/database/bangumi.py | 212 ++++++++----------- backend/src/module/database/orm/connector.py | 19 +- backend/src/module/database/orm/delete.py | 11 +- backend/src/module/database/orm/insert.py | 34 +-- backend/src/module/database/orm/search.py | 48 ----- backend/src/module/database/orm/select.py | 90 ++++++++ backend/src/module/database/orm/update.py | 50 +++-- docs/wiki | 2 +- 8 files changed, 248 insertions(+), 218 deletions(-) delete mode 100644 backend/src/module/database/orm/search.py create mode 100644 backend/src/module/database/orm/select.py diff --git a/backend/src/module/database/bangumi.py b/backend/src/module/database/bangumi.py index ecd8126f..ad3a273a 100644 --- a/backend/src/module/database/bangumi.py +++ b/backend/src/module/database/bangumi.py @@ -1,5 +1,7 @@ import logging +from module.database.orm import Connector + from module.ab_decorator import locked from module.database.connector import DataConnector from module.models import BangumiData @@ -7,14 +9,12 @@ from module.models import BangumiData logger = logging.getLogger(__name__) -class BangumiDatabase(DataConnector): +class BangumiDatabase(Connector): def __init__(self): - super().__init__() - self.__table_name = "bangumi" + super().__init__(table_name="bangumi", data=self.__data_to_db(BangumiData())) def update_table(self): - db_data = self.__data_to_db(BangumiData()) - self._update_table(self.__table_name, db_data) + self.update.table() @staticmethod def __data_to_db(data: BangumiData) -> dict: @@ -36,120 +36,118 @@ class BangumiDatabase(DataConnector): db_data[key] = item.split(",") return BangumiData(**db_data) - def __fetch_data(self) -> list[BangumiData]: - values = self._cursor.fetchall() - if values is None: - return [] - keys = [x[0] for x in self._cursor.description] - dict_data = [dict(zip(keys, value)) for value in values] - return [self.__db_to_data(x) for x in dict_data] - - def insert(self, data: BangumiData): - if self.__check_exist(data): - self.update_one(data) - else: - db_data = self.__data_to_db(data) - db_data["id"] = self.gen_id() - self._insert(db_data=db_data, table_name=self.__table_name) - logger.debug(f"[Database] Insert {data.official_title} into database.") + def insert_one(self, data: BangumiData): + db_data = self.__data_to_db(data) + self.insert.one(db_data) + logger.debug(f"[Database] Insert {data.official_title} into database.") + # if self.__check_exist(data): + # self.update_one(data) + # else: + # db_data = self.__data_to_db(data) + # db_data["id"] = self.gen_id() + # self._insert(db_data=db_data, table_name=self.__table_name) + # logger.debug(f"[Database] Insert {data.official_title} into database.") def insert_list(self, data: list[BangumiData]): - _id = self.gen_id() - for i, item in enumerate(data): - item.id = _id + i data_list = [self.__data_to_db(x) for x in data] - self._insert_list(data_list=data_list, table_name=self.__table_name) + self.insert.many(data_list) + # _id = self.gen_id() + # for i, item in enumerate(data): + # item.id = _id + i + # data_list = [self.__data_to_db(x) for x in data] + # self._insert_list(data_list=data_list, table_name=self.__table_name) logger.debug(f"[Database] Insert {len(data)} bangumi into database.") def update_one(self, data: BangumiData) -> bool: db_data = self.__data_to_db(data) - return self._update(db_data=db_data, table_name=self.__table_name) + return self.update.one(db_data) def update_list(self, data: list[BangumiData]): data_list = [self.__data_to_db(x) for x in data] - self._update_list(data_list=data_list, table_name=self.__table_name) + self.update.many(data_list) - @locked def update_rss(self, title_raw, rss_set: str): # Update rss and added - self._cursor.execute( - """ - UPDATE bangumi - SET rss_link = :rss_link, added = 0 - WHERE title_raw = :title_raw - """, - {"rss_link": rss_set, "title_raw": title_raw}, - ) - self._conn.commit() + location = {"title_raw": title_raw} + set_value = {"rss_link": rss_set, "added": 0} + self.update.value(location, set_value) + # self._cursor.execute( + # """ + # UPDATE bangumi + # SET rss_link = :rss_link, added = 0 + # WHERE title_raw = :title_raw + # """, + # {"rss_link": rss_set, "title_raw": title_raw}, + # ) + # self._conn.commit() logger.debug(f"[Database] Update {title_raw} rss_link to {rss_set}.") def update_poster(self, title_raw, poster_link: str): - self._cursor.execute( - """ - UPDATE bangumi - SET poster_link = :poster_link - WHERE title_raw = :title_raw - """, - {"poster_link": poster_link, "title_raw": title_raw}, - ) - self._conn.commit() + location = {"title_raw": title_raw} + set_value = {"poster_link": poster_link} + self.update.value(location, set_value) + # self._cursor.execute( + # """ + # UPDATE bangumi + # SET poster_link = :poster_link + # WHERE title_raw = :title_raw + # """, + # {"poster_link": poster_link, "title_raw": title_raw}, + # ) + # self._conn.commit() logger.debug(f"[Database] Update {title_raw} poster_link to {poster_link}.") - def delete_one(self, _id: int) -> bool: - self._cursor.execute( - """ - DELETE FROM bangumi WHERE id = :id - """, - {"id": _id}, - ) - self._conn.commit() + def delete_one(self, _id: int): + self.delete.one(_id) + # self._cursor.execute( + # """ + # DELETE FROM bangumi WHERE id = :id + # """, + # {"id": _id}, + # ) + # self._conn.commit() logger.debug(f"[Database] Delete bangumi id: {_id}.") - return self._cursor.rowcount == 1 def delete_all(self): - self._delete_all(self.__table_name) + self.delete.all() def search_all(self) -> list[BangumiData]: - dict_data = self._search_datas(self.__table_name) - return [self.__db_to_data(x) for x in dict_data] + all_data = self.select.all() + return [self.__db_to_data(x) for x in all_data] def search_id(self, _id: int) -> BangumiData | None: - condition = {"id": _id} - dict_data = self._search_data(table_name=self.__table_name, condition=condition) + dict_data = self.select.one(_id) + # condition = {"id": _id} + # dict_data = self._search_data(table_name=self.__table_name, condition=condition) if dict_data is None: + logger.warning(f"[Database] Cannot find bangumi id: {_id}.") return None + logger.debug(f"[Database] Find bangumi id: {_id}.") return self.__db_to_data(dict_data) - def search_official_title(self, official_title: str) -> BangumiData | None: - dict_data = self._search_data( - table_name=self.__table_name, condition={"official_title": official_title} - ) - if dict_data is None: - return None - return self.__db_to_data(dict_data) + # def search_official_title(self, official_title: str) -> BangumiData | None: + # dict_data = self._search_data( + # table_name=self.__table_name, condition={"official_title": official_title} + # ) + # if dict_data is None: + # return None + # return self.__db_to_data(dict_data) def match_poster(self, bangumi_name: str) -> str: - condition = {"_custom_condition": "INSTR(:bangumi_name, official_title) > 0"} - keys = ["official_title", "poster_link"] - data = self._search_data( - table_name=self.__table_name, + condition = {"title_raw": bangumi_name} + keys = ["poster_link"] + data = self.select.one( keys=keys, condition=condition, ) if not data: return "" - official_title, poster_link = data - if not poster_link: - return "" - return poster_link + return data.get("poster_link") def match_list(self, torrent_list: list, rss_link: str) -> list: # Match title_raw in database keys = ["title_raw", "rss_link", "poster_link"] - match_datas = self._search_datas( - table_name=self.__table_name, - keys=keys, - ) + match_datas = self.select.column(keys) if not match_datas: return torrent_list # Match title @@ -160,9 +158,13 @@ class BangumiDatabase(DataConnector): if match_data.get("title_raw") in torrent.name: if rss_link not in match_data.get("rss_link"): match_data["rss_link"] += f",{rss_link}" - self.update_rss(match_data.get("title_raw"), match_data.get("rss_link")) + self.update_rss( + match_data.get("title_raw"), match_data.get("rss_link") + ) if not match_data.get("poster_link"): - self.update_poster(match_data.get("title_raw"), torrent.poster_link) + self.update_poster( + match_data.get("title_raw"), torrent.poster_link + ) torrent_list.pop(i) break else: @@ -172,50 +174,20 @@ class BangumiDatabase(DataConnector): def not_complete(self) -> list[BangumiData]: # Find eps_complete = False condition = {"eps_collect": 0} - dict_data = self._search_datas( - table_name=self.__table_name, - condition=condition, + dict_data = self.select.many( + conditions=condition, ) return [self.__db_to_data(x) for x in dict_data] def not_added(self) -> list[BangumiData]: - condition = {"added": 0, "rule_name": None, "save_path": None} - dict_data = self._search_datas( - table_name=self.__table_name, - condition=condition, - ) + conditions = {"added": 0, "rule_name": None, "save_path": None} + dict_data = self.select.many(conditions=conditions, combine_operator="OR") return [self.__db_to_data(x) for x in dict_data] - def gen_id(self) -> int: - self._cursor.execute( - """ - SELECT id FROM bangumi ORDER BY id DESC LIMIT 1 - """ - ) - data = self._cursor.fetchone() - if data is None: - return 1 - return data[0] + 1 - def __check_exist(self, data: BangumiData): - self._cursor.execute( - """ - SELECT * FROM bangumi WHERE official_title = :official_title - """, - {"official_title": data.official_title}, - ) - values = self._cursor.fetchone() - if values is None: - return False - return True - - def __check_list_exist(self, data_list: list[BangumiData]): - for data in data_list: - if self.__check_exist(data): - return True - return False - -if __name__ == '__main__': +if __name__ == "__main__": with BangumiDatabase() as db: - print(db.not_added()) - print(db.not_complete()) \ No newline at end of file + datas = db.not_added() + for data in datas: + print(data) + # print(db.not_complete()) diff --git a/backend/src/module/database/orm/connector.py b/backend/src/module/database/orm/connector.py index f55f4919..26224e71 100644 --- a/backend/src/module/database/orm/connector.py +++ b/backend/src/module/database/orm/connector.py @@ -9,13 +9,14 @@ from module.conf import DATA_PATH class Connector: - def __init__(self, database: str = DATA_PATH, table_name: str = None, data: dict = None): + def __init__(self, table_name: str, data: dict, database: str = DATA_PATH): self._conn = sqlite3.connect(database) self._cursor = self._conn.cursor() self.update = Update(self, table_name, data) self.insert = Insert(self, table_name, data) self.select = Select(self, table_name, data) self.delete = Delete(self, table_name, data) + self._columns = self.__get_columns(table_name) def __enter__(self): return self @@ -23,6 +24,10 @@ class Connector: def __exit__(self, exc_type, exc_value, traceback): self._conn.close() + def __get_columns(self, table_name: str) -> list[str]: + self._cursor.execute(f"PRAGMA table_info({table_name})") + return [x[1] for x in self._cursor.fetchall()] + def execute(self, sql: str, params: tuple = None): if params is None: self._cursor.execute(sql) @@ -34,11 +39,15 @@ class Connector: self._cursor.executemany(sql, params) self._conn.commit() - def fetchall(self): - return self._cursor.fetchall() + def fetchall(self) -> dict: + datas = self._cursor.fetchall() + for data in datas: + yield dict(zip(self._columns, data)) def fetchone(self): - return self._cursor.fetchone() + return dict(zip(self._columns, self._cursor.fetchone())) def fetchmany(self, size: int): - return self._cursor.fetchmany(size) \ No newline at end of file + datas = self._cursor.fetchmany(size) + for data in datas: + yield dict(zip(self._columns, data)) diff --git a/backend/src/module/database/orm/delete.py b/backend/src/module/database/orm/delete.py index cf7ce81b..aa5a705b 100644 --- a/backend/src/module/database/orm/delete.py +++ b/backend/src/module/database/orm/delete.py @@ -1,13 +1,11 @@ - - class Delete: - def __init__(self, connector: Connector, table_name: str, data: dict): - self.db = connector + def __init__(self, connector, table_name: str, data: dict): + self._connector = connector self._table_name = table_name self._data = data def one(self, _id: int) -> bool: - self.db.execute( + self._connector.execute( f""" DELETE FROM {self._table_name} WHERE id = :id @@ -17,10 +15,9 @@ class Delete: return True def all(self): - self.db.execute( + self._connector.execute( f""" DELETE FROM {self._table_name} """, ) return True - diff --git a/backend/src/module/database/orm/insert.py b/backend/src/module/database/orm/insert.py index 8f0986b4..f5d7c1c2 100644 --- a/backend/src/module/database/orm/insert.py +++ b/backend/src/module/database/orm/insert.py @@ -1,43 +1,31 @@ -from .connector import Connector - - class Insert: - def __init__(self, db: Connector, table_name: str, data: dict): - self.db = db + def __init__(self, connector, table_name: str, data: dict): + self._connector = connector self._table_name = table_name self._columns = data.items() def __gen_id(self) -> int: - self.db.execute(f"SELECT MAX(id) FROM {self._table_name}") - max_id = self.db.fetchone()[0] + self._connector.execute(f"SELECT MAX(id) FROM {self._table_name}") + max_id = self._connector.fetchone()[0] if max_id is None: return 1 return max_id + 1 - def one(self, data: dict) -> bool: + def one(self, data: dict): + if data["id"] is not None: + raise ValueError("id must be None") _id = self.__gen_id() data["id"] = _id columns = ", ".join(data.keys()) placeholders = ", ".join([f":{key}" for key in data.keys()]) - self.db.execute( + self._connector.execute( f""" INSERT INTO {self._table_name} ({columns}) VALUES ({placeholders}) """, data, ) - return True - - def list(self, data: list[dict]): - columns = ", ".join(data[0].keys()) - placeholders = ", ".join([f":{key}" for key in data[0].keys()]) - self.db.executemany( - f""" - INSERT INTO {self._table_name} ({columns}) - VALUES ({placeholders}) - """, - data, - ) - return True - + def many(self, data: list[dict]): + for item in data: + self.one(item) diff --git a/backend/src/module/database/orm/search.py b/backend/src/module/database/orm/search.py deleted file mode 100644 index f0e5d05d..00000000 --- a/backend/src/module/database/orm/search.py +++ /dev/null @@ -1,48 +0,0 @@ - -class Select: - def __init__(self, connector: Connector, table_name: str, data: dict): - self._connector = connector - self._table_name = table_name - self._data = data - - def id(self, _id: int): - self._connector.execute( - f""" - SELECT * FROM {self._table_name} - WHERE id = :id - """, - {"id": _id}, - ) - return self._connector.fetchone() - - def all(self): - self._connector.execute( - f""" - SELECT * FROM {self._table_name} - """, - ) - return self._connector.fetchall() - - def one(self, keys: list[str], values: list[str]): - columns = ", ".join(keys) - placeholders = ", ".join([f":{key}" for key in keys]) - self._connector.execute( - f""" - SELECT {columns} FROM {self._table_name} - WHERE {placeholders} - """, - dict(zip(keys, values)), - ) - return self._connector.fetchone() - - def list(self, keys: list[str], values: list[str]): - columns = ", ".join(keys) - placeholders = ", ".join([f":{key}" for key in keys]) - self._connector.execute( - f""" - SELECT {columns} FROM {self._table_name} - WHERE {placeholders} - """, - dict(zip(keys, values)), - ) - return self._connector.fetchall() \ No newline at end of file diff --git a/backend/src/module/database/orm/select.py b/backend/src/module/database/orm/select.py new file mode 100644 index 00000000..45e555ae --- /dev/null +++ b/backend/src/module/database/orm/select.py @@ -0,0 +1,90 @@ +class Select: + def __init__(self, connector, table_name: str, data: dict): + self._connector = connector + self._table_name = table_name + self._data = data + + def id(self, _id: int): + self._connector.execute( + f""" + SELECT * FROM {self._table_name} + WHERE id = :id + """, + {"id": _id}, + ) + return self._connector.fetchone() + + def all(self, limit: int = None): + if limit is None: + limit = 10000 + self._connector.execute( + f""" + SELECT * FROM {self._table_name} LIMIT {limit} + """, + ) + return self._connector.fetchall() + + def one( + self, + keys: list[str] | None = None, + conditions: dict = None, + combine_operator: str = "AND", + ): + if keys is None: + keys = ["*"] + columns = ", ".join(keys) + + self._connector.execute( + f""" + SELECT {columns} FROM {self._table_name} + WHERE {condition_sql} + """, + conditions, + ) + return self._connector.fetchone() + + def many( + self, + keys: list[str] | None = None, + conditions: dict = None, + combine_operator: str = "AND", + limit: int = None, + ): + if keys is None: + keys = ["*"] + if limit is None: + limit = 10000 + columns = ", ".join(keys) + condition_sql = self.__select_condition(conditions, combine_operator) + self._connector.execute( + f""" + SELECT {columns} FROM {self._table_name} + WHERE {condition_sql} + LIMIT {limit} + """, + conditions, + ) + return self._connector.fetchall() + + def column(self, keys: list[str]): + columns = ", ".join(keys) + self._connector.execute( + f""" + SELECT {columns} FROM {self._table_name} + """, + ) + return self._connector.fetchall() + + @staticmethod + def __select_condition(conditions: dict, combine_operator: str = "AND"): + if not conditions: + raise ValueError("No conditions provided.") + if combine_operator not in ["AND", "OR", "INSTR"]: + raise ValueError("Invalid combine_operator, must be 'AND' or 'OR'.") + if combine_operator == "INSTR": + condition_sql = f" {combine_operator} {' AND '.join([f'({key} = :{key})' for key in conditions.keys()])}" + else: + condition_sql = f" {combine_operator} ".join( + [f"{key} = :{key}" for key in conditions.keys()] + ) + return condition_sql diff --git a/backend/src/module/database/orm/update.py b/backend/src/module/database/orm/update.py index 2a90c9e3..39682774 100644 --- a/backend/src/module/database/orm/update.py +++ b/backend/src/module/database/orm/update.py @@ -1,12 +1,11 @@ import logging -from .connector import Connector logger = logging.getLogger(__name__) class Update: - def __init__(self, db: Connector, table_name: str, data: dict): - self.db = db + def __init__(self, connector, table_name: str, data: dict): + self._connector = connector self._table_name = table_name self._columns = data.items() @@ -18,10 +17,10 @@ class Update: ] ) create_table_sql = f"CREATE TABLE IF NOT EXISTS {self._table_name} ({columns});" - self.db.execute(create_table_sql) - self.db.execute(f"PRAGMA table_info({self._table_name})") + self._connector.execute(create_table_sql) + self._connector.execute(f"PRAGMA table_info({self._table_name})") existing_columns = { - column_info[1]: column_info for column_info in self.db.fetchall() + column_info[1]: column_info for column_info in self._connector.fetchall() } for key, value in self._columns: if key not in existing_columns: @@ -29,26 +28,49 @@ class Update: if value is None: value = "NULL" add_column_sql = f"ALTER TABLE {self._table_name} ADD COLUMN {key} {insert_column} DEFAULT {value};" - self.db.execute(add_column_sql) + self._connector.execute(add_column_sql) logger.debug(f"Create / Update table {self._table_name}.") def one(self, data: dict) -> bool: - _id = data.pop("id") + _id = data["id"] set_sql = ", ".join([f"{key} = :{key}" for key in data.keys()]) - self.db.execute( + self._connector.execute( f""" UPDATE {self._table_name} SET {set_sql} - WHERE id = {_id} + WHERE id = :id """, data, ) logger.debug(f"Update {_id} in {self._table_name}.") return True - def list(self, data: list[dict]): - for item in data: - self.one(item) + def many(self, data: list[dict]) -> bool: + columns = ", ".join(data[0].keys()) + self._connector.executemany( + f""" + UPDATE {self._table_name} + SET {columns} + WHERE id = :id + """, + data, + ) + logger.debug(f"Update {self._table_name}.") + return True + + def value(self, location: dict, set_value: dict) -> bool: + set_sql = ", ".join([f"{key} = :{key}" for key in set_value.keys()]) + params = {**location, **set_value} + self._connector.execute( + f""" + UPDATE {self._table_name} + SET {set_sql} + WHERE {location["key"]} = :{location["key"]} + """, + params, + ) + logger.debug(f"Update {self._table_name}.") + return True @staticmethod def __python_to_sqlite_type(value) -> str: @@ -65,4 +87,4 @@ class Update: elif value is None: return "TEXT" else: - raise ValueError(f"Unsupported data type: {type(value)}") \ No newline at end of file + raise ValueError(f"Unsupported data type: {type(value)}") diff --git a/docs/wiki b/docs/wiki index d0bb98f0..519e381e 160000 --- a/docs/wiki +++ b/docs/wiki @@ -1 +1 @@ -Subproject commit d0bb98f004fb292519dd56c42238ecb2f034eac9 +Subproject commit 519e381e8a1add62e76a39181ee61bad02816035