From 32382dc7cb37c2f932478f59aa4b9c8b98722e64 Mon Sep 17 00:00:00 2001 From: EstrellaXD Date: Wed, 14 Jun 2023 21:17:42 +0800 Subject: [PATCH] fix: orm bugs. --- backend/src/module/database/orm/connector.py | 22 +++++++++++++------- backend/src/module/database/orm/insert.py | 10 +++++---- backend/src/module/database/orm/select.py | 13 ++++++------ backend/src/module/database/orm/update.py | 12 +++++------ 4 files changed, 32 insertions(+), 25 deletions(-) diff --git a/backend/src/module/database/orm/connector.py b/backend/src/module/database/orm/connector.py index 26224e71..d30088a9 100644 --- a/backend/src/module/database/orm/connector.py +++ b/backend/src/module/database/orm/connector.py @@ -39,15 +39,21 @@ class Connector: self._cursor.executemany(sql, params) self._conn.commit() - def fetchall(self) -> dict: + def fetchall(self, keys: str = None) -> list[dict]: datas = self._cursor.fetchall() - for data in datas: - yield dict(zip(self._columns, data)) + if keys: + return [dict(zip(keys, data)) for data in datas] + return [dict(zip(self._columns, data)) for data in datas] - def fetchone(self): - return dict(zip(self._columns, self._cursor.fetchone())) + def fetchone(self, keys: list[str] = None) -> dict: + data = self._cursor.fetchone() + if data: + if keys: + return dict(zip(keys, data)) + return dict(zip(self._columns, data)) - def fetchmany(self, size: int): + def fetchmany(self, keys: list[str], size: int) -> list[dict]: datas = self._cursor.fetchmany(size) - for data in datas: - yield dict(zip(self._columns, data)) + if keys: + return [dict(zip(keys, data)) for data in datas] + return [dict(zip(self._columns, data)) for data in datas] diff --git a/backend/src/module/database/orm/insert.py b/backend/src/module/database/orm/insert.py index f5d7c1c2..d78c6ca7 100644 --- a/backend/src/module/database/orm/insert.py +++ b/backend/src/module/database/orm/insert.py @@ -5,15 +5,17 @@ class Insert: self._columns = data.items() def __gen_id(self) -> int: - self._connector.execute(f"SELECT MAX(id) FROM {self._table_name}") - max_id = self._connector.fetchone()[0] + self._connector.execute( + f""" + SELECT MAX(id) FROM {self._table_name} + """, + ) + max_id = self._connector.fetchone(keys=["id"]).get("id") if max_id is None: return 1 return max_id + 1 def one(self, data: dict): - if data["id"] is not None: - raise ValueError("id must be None") _id = self.__gen_id() data["id"] = _id columns = ", ".join(data.keys()) diff --git a/backend/src/module/database/orm/select.py b/backend/src/module/database/orm/select.py index 45e555ae..21da0079 100644 --- a/backend/src/module/database/orm/select.py +++ b/backend/src/module/database/orm/select.py @@ -33,7 +33,7 @@ class Select: if keys is None: keys = ["*"] columns = ", ".join(keys) - + condition_sql = self.__select_condition(conditions, combine_operator) self._connector.execute( f""" SELECT {columns} FROM {self._table_name} @@ -41,7 +41,7 @@ class Select: """, conditions, ) - return self._connector.fetchone() + return self._connector.fetchone(keys) def many( self, @@ -51,10 +51,11 @@ class Select: limit: int = None, ): if keys is None: - keys = ["*"] + columns = "*" + else: + columns = ", ".join(keys) if limit is None: limit = 10000 - columns = ", ".join(keys) condition_sql = self.__select_condition(conditions, combine_operator) self._connector.execute( f""" @@ -64,7 +65,7 @@ class Select: """, conditions, ) - return self._connector.fetchall() + return self._connector.fetchall(keys) def column(self, keys: list[str]): columns = ", ".join(keys) @@ -73,7 +74,7 @@ class Select: SELECT {columns} FROM {self._table_name} """, ) - return self._connector.fetchall() + return self._connector.fetchall(keys) @staticmethod def __select_condition(conditions: dict, combine_operator: str = "AND"): diff --git a/backend/src/module/database/orm/update.py b/backend/src/module/database/orm/update.py index 39682774..96dc8fa6 100644 --- a/backend/src/module/database/orm/update.py +++ b/backend/src/module/database/orm/update.py @@ -7,22 +7,20 @@ class Update: def __init__(self, connector, table_name: str, data: dict): self._connector = connector self._table_name = table_name - self._columns = data.items() + self._example_data = data def table(self): columns = ", ".join( [ f"{key} {self.__python_to_sqlite_type(value)}" - for key, value in self._columns + for key, value in self._example_data.items() ] ) create_table_sql = f"CREATE TABLE IF NOT EXISTS {self._table_name} ({columns});" self._connector.execute(create_table_sql) self._connector.execute(f"PRAGMA table_info({self._table_name})") - existing_columns = { - column_info[1]: column_info for column_info in self._connector.fetchall() - } - for key, value in self._columns: + existing_columns = self._connector._columns + for key, value in self._example_data.items(): if key not in existing_columns: insert_column = self.__python_to_sqlite_type(value) if value is None: @@ -46,7 +44,7 @@ class Update: return True def many(self, data: list[dict]) -> bool: - columns = ", ".join(data[0].keys()) + columns = ", ".join([f"{key} = :{key}" for key in data[0].keys()]) self._connector.executemany( f""" UPDATE {self._table_name}