diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml index 430d121d..14514ac3 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.yml +++ b/.github/ISSUE_TEMPLATE/bug_report.yml @@ -1,17 +1,31 @@ name: 问题反馈 description: File a bug report -title: "[错误报告]" +title: "[错误报告] 请在此处简单描述你的问题" labels: ["bug"] body: - type: markdown attributes: value: | - 描述问题前,请先更新到最新版本。 - 最新版本: [version](https://img.shields.io/docker/v/estrellaxd/auto_bangumi) - 如果更新到最新版本后仍然有问题,请先查阅 [FAQ](https://github.com/EstrellaXD/Auto_Bangumi/wiki/FAQ)。 - 确认非上述问题后,请详细描述你所遇到的问题,并附上相应信息。 - 如果问题已经列在 [FAQ](https://github.com/EstrellaXD/Auto_Bangumi/wiki/FAQ) 中,会直接关闭 issue。 - 解析器问题请转到[专用模板](https://github.com/EstrellaXD/Auto_Bangumi/issues/new?assignees=&labels=bug&template=parser_bug.yml&title=%5B解析器错误%5D),重命名问题请到[专用模板](https://github.com/EstrellaXD/Auto_Bangumi/issues/new?assignees=&labels=bug&template=rename_bug.yml&title=%5B重命名错误%5D) + 描述问题前,请先更新到最新版本。2.5 之前的版本升级请参考 [升级指南](https://github.com/EstrellaXD/Auto_Bangumi/wiki/2.6更新说明#如何从老版本更新的注意事项) + 请确认以下信息,如果你的问题可以直接在文档中找到,那么你的 issue 将会被直接关闭。 + 解析器问题请转到[专用模板](https://github.com/EstrellaXD/Auto_Bangumi/issues/new?assignees=&labels=bug&template=parser_bug.yml&title=%5B解析器错误%5D), + 重命名问题请到[专用模板](https://github.com/EstrellaXD/Auto_Bangumi/issues/new?assignees=&labels=bug&template=rename_bug.yml&title=%5B重命名错误%5D) + - type: checkboxes + id: ensure + attributes: + label: 确认 + description: 在提交 issue 之前,请确认你已经阅读并确认以下内容 + options: + - label: 我的版本是最新版本,我的版本号与 [version](https://github.com/EstrellaXD/Auto_Bangumi/releases/latest) 相同。 + required: true + - label: 我已经查阅了[排错流程](https://github.com/EstrellaXD/Auto_Bangumi/wiki/排错流程),确保提出的问题不在其中。 + required: true + - label: 我已经查阅了[已知问题](https://github.com/EstrellaXD/Auto_Bangumi/wiki/常见问题),并确认我的问题不在其中。 + required: true + - label: 我已经 [issue](https://github.com/EstrellaXD/Auto_Bangumi/issues) 中搜索过,确认我的问题没有被提出过。 + required: true + - label: 我已经修改标题,将标题中的 描述 替换为我遇到的问题。 + required: true - type: input id: version attributes: @@ -43,4 +57,4 @@ body: attributes: label: 发生问题时系统日志 description: 问题出现时,程序运行日志请复制到这里。 - render: shell + render: bash \ No newline at end of file diff --git a/README.md b/README.md index cbabfda7..cc968cd3 100644 --- a/README.md +++ b/README.md @@ -75,18 +75,14 @@ ***开发中的功能:*** -- Web UI #57 ✅ -- 文件统一整理,对单个规则或者文件微调文件夹可以自动调整所有对应的文件。 -- 通知功能,可以通过 IFTTT 等方式通知用户番剧更新进度。✅ -- 剧场版以及合集的支持。✅ -- 各类 API 接口。 +- 内置 RSS 推送更新器。 +- 搜索功能 ***计划开发的功能:*** - 对其他站点种子的解析归类。 - 本地化番剧订阅方式。 - Transmission & Aria2 的支持。 -- 更完善的 WebUI。 # 声明 diff --git a/src/module/ab_decorator/__init__.py b/src/module/ab_decorator/__init__.py index f469894c..f25a083b 100644 --- a/src/module/ab_decorator/__init__.py +++ b/src/module/ab_decorator/__init__.py @@ -1,7 +1,9 @@ import logging import time +import threading logger = logging.getLogger(__name__) +lock = threading.Lock() def qb_connect_failed_wait(func): @@ -30,3 +32,10 @@ def api_failed(func): logger.debug(e) return wrapper + + +def locked(func): + def wrapper(*args, **kwargs): + with lock: + return func(*args, **kwargs) + return wrapper \ No newline at end of file diff --git a/src/module/api/download.py b/src/module/api/download.py index ed3bc9df..6ce80569 100644 --- a/src/module/api/download.py +++ b/src/module/api/download.py @@ -32,8 +32,10 @@ async def download_collection( ) if data: with SeasonCollector() as collector: - collector.collect_season(data, data.rss_link[0]) - return {"status": "Success"} + if collector.collect_season(data, data.rss_link[0], proxy=True): + return {"status": "Success"} + else: + return {"status": "Failed to add torrent"} else: return {"status": "Failed to parse link"} diff --git a/src/module/conf/config.py b/src/module/conf/config.py index 24a3d39a..90503783 100644 --- a/src/module/conf/config.py +++ b/src/module/conf/config.py @@ -42,7 +42,7 @@ class Settings(Config): if not config_dict: config_dict = self.dict() with open(CONFIG_PATH, "w", encoding="utf-8") as f: - json.dump(config_dict, f, indent=4) + json.dump(config_dict, f, indent=4, ensure_ascii=False) def init(self): load_dotenv(".env") diff --git a/src/module/database/bangumi.py b/src/module/database/bangumi.py index 74e7fec5..f2129826 100644 --- a/src/module/database/bangumi.py +++ b/src/module/database/bangumi.py @@ -2,6 +2,7 @@ import logging from module.database.connector import DataConnector from module.models import BangumiData +from module.ab_decorator import locked logger = logging.getLogger(__name__) @@ -68,6 +69,7 @@ class BangumiDatabase(DataConnector): data_list = [self.__data_to_db(x) for x in data] self._update_list(data_list=data_list, table_name=self.__table_name) + @locked def update_rss(self, title_raw, rss_set: str): # Update rss and added self._cursor.execute( @@ -108,51 +110,59 @@ class BangumiDatabase(DataConnector): self._delete_all(self.__table_name) def search_all(self) -> list[BangumiData]: - self._cursor.execute( - """ - SELECT * FROM bangumi - """ - ) - return self.__fetch_data() + dict_data = self._search_datas(self.__table_name) + return [self.__db_to_data(x) for x in dict_data] def search_id(self, _id: int) -> BangumiData | None: - self._cursor.execute( - """ - SELECT * FROM bangumi WHERE id = :id - """, - {"id": _id}, - ) - values = self._cursor.fetchone() - if values is None: + condition = {"id": _id} + value = self._search_data(table_name=self.__table_name, condition=condition) + # self._cursor.execute( + # """ + # SELECT * FROM bangumi WHERE id = :id + # """, + # {"id": _id}, + # ) + # values = self._cursor.fetchone() + if value is None: return None keys = [x[0] for x in self._cursor.description] - dict_data = dict(zip(keys, values)) + dict_data = dict(zip(keys, value)) return self.__db_to_data(dict_data) def search_official_title(self, official_title: str) -> BangumiData | None: - self._cursor.execute( - """ - SELECT * FROM bangumi WHERE official_title = :official_title - """, - {"official_title": official_title}, + value = self._search_data( + table_name=self.__table_name, condition={"official_title": official_title} ) - values = self._cursor.fetchone() - if values is None: + # self._cursor.execute( + # """ + # SELECT * FROM bangumi WHERE official_title = :official_title + # """, + # {"official_title": official_title}, + # ) + # values = self._cursor.fetchone() + if value is None: return None keys = [x[0] for x in self._cursor.description] - dict_data = dict(zip(keys, values)) + dict_data = dict(zip(keys, value)) return self.__db_to_data(dict_data) def match_poster(self, bangumi_name: str) -> str: - self._cursor.execute( - """ - SELECT official_title, poster_link - FROM bangumi - WHERE INSTR(:bangumi_name, official_title) > 0 - """, - {"bangumi_name": bangumi_name}, + condition = f"INSTR({bangumi_name}, official_title) > 0" + keys = ["official_title", "poster_link"] + data = self._search_data( + table_name=self.__table_name, + keys=keys, + condition=condition, ) - data = self._cursor.fetchone() + # self._cursor.execute( + # """ + # SELECT official_title, poster_link + # FROM bangumi + # WHERE INSTR(:bangumi_name, official_title) > 0 + # """, + # {"bangumi_name": bangumi_name}, + # ) + # data = self._cursor.fetchone() if not data: return "" official_title, poster_link = data @@ -160,14 +170,20 @@ class BangumiDatabase(DataConnector): return "" return poster_link + @locked def match_list(self, torrent_list: list, rss_link: str) -> list: # Match title_raw in database - self._cursor.execute( - """ - SELECT title_raw, rss_link, poster_link FROM bangumi - """ + keys = ["title_raw", "rss_link", "poster_link"] + data = self._search_datas( + table_name=self.__table_name, + keys=keys, ) - data = self._cursor.fetchall() + # self._cursor.execute( + # """ + # SELECT title_raw, rss_link, poster_link FROM bangumi + # """ + # ) + # data = self._cursor.fetchall() if not data: return torrent_list # Match title @@ -189,6 +205,12 @@ class BangumiDatabase(DataConnector): def not_complete(self) -> list[BangumiData]: # Find eps_complete = False + condition = "eps_complete = 0" + data = self._search_datas( + table_name=self.__table_name, + condition=condition, + ) + self._cursor.execute( """ SELECT * FROM bangumi WHERE eps_collect = 0 diff --git a/src/module/database/connector.py b/src/module/database/connector.py index d74bb159..84a6ae56 100644 --- a/src/module/database/connector.py +++ b/src/module/database/connector.py @@ -2,7 +2,9 @@ import os import sqlite3 import logging + from module.conf import DATA_PATH +from module.ab_decorator import locked logger = logging.getLogger(__name__) @@ -15,6 +17,7 @@ class DataConnector: self._conn = sqlite3.connect(DATA_PATH) self._cursor = self._conn.cursor() + @locked def _update_table(self, table_name: str, db_data: dict): columns = ", ".join( [ @@ -38,6 +41,7 @@ class DataConnector: self._conn.commit() logger.debug(f"Create / Update table {table_name}.") + @locked def _insert(self, table_name: str, db_data: dict): columns = ", ".join(db_data.keys()) values = ", ".join([f":{key}" for key in db_data.keys()]) @@ -46,6 +50,7 @@ class DataConnector: ) self._conn.commit() + @locked def _insert_list(self, table_name: str, data_list: list[dict]): columns = ", ".join(data_list[0].keys()) values = ", ".join([f":{key}" for key in data_list[0].keys()]) @@ -54,6 +59,7 @@ class DataConnector: ) self._conn.commit() + @locked def _select(self, keys: list[str], table_name: str, condition: str = None) -> dict: if condition is None: self._cursor.execute(f"SELECT {', '.join(keys)} FROM {table_name}") @@ -63,6 +69,7 @@ class DataConnector: ) return dict(zip(keys, self._cursor.fetchone())) + @locked def _update(self, table_name: str, db_data: dict): _id = db_data.get("id") if _id is None: @@ -74,6 +81,7 @@ class DataConnector: self._conn.commit() return self._cursor.rowcount == 1 + @locked def _update_list(self, table_name: str, data_list: list[dict]): if len(data_list) == 0: return @@ -85,6 +93,7 @@ class DataConnector: ) self._conn.commit() + @locked def _update_section(self, table_name: str, location: dict, update_dict: dict): set_sql = ", ".join([f"{key} = :{key}" for key in update_dict.keys()]) sql_loc = f"{location['key']} = {location['value']}" @@ -93,10 +102,35 @@ class DataConnector: ) self._conn.commit() + @locked def _delete_all(self, table_name: str): self._cursor.execute(f"DELETE FROM {table_name}") self._conn.commit() + @locked + def _search_data(self, table_name: str, keys: list[str] | None, condition: str) -> dict: + if keys is None: + self._cursor.execute(f"SELECT * FROM {table_name} WHERE {condition}") + else: + self._cursor.execute( + f"SELECT {', '.join(keys)} FROM {table_name} WHERE {condition}" + ) + return dict(zip(keys, self._cursor.fetchone())) + + @locked + def _search_datas(self, table_name: str, keys: list[str] | None, condition: str = None) -> list[dict]: + if keys is None: + select_sql = "*" + else: + select_sql = ", ".join(keys) + if condition is None: + self._cursor.execute(f"SELECT {select_sql} FROM {table_name}") + else: + self._cursor.execute( + f"SELECT {select_sql} FROM {table_name} WHERE {condition}" + ) + return [dict(zip(keys, row)) for row in self._cursor.fetchall()] + def _table_exists(self, table_name: str) -> bool: self._cursor.execute( f"SELECT name FROM sqlite_master WHERE type='table' AND name=?;", diff --git a/src/module/database/user.py b/src/module/database/user.py index 22a97020..76dcade4 100644 --- a/src/module/database/user.py +++ b/src/module/database/user.py @@ -49,7 +49,7 @@ class AuthDB(DataConnector): ) result = self._cursor.fetchone() if not result: - raise HTTPException(status_code=404, detail="User not found") + raise HTTPException(status_code=401, detail="User not found") if not verify_password(password, result[1]): raise HTTPException(status_code=401, detail="Password error") return True diff --git a/src/module/downloader/client/qb_downloader.py b/src/module/downloader/client/qb_downloader.py index d3563185..8b22b813 100644 --- a/src/module/downloader/client/qb_downloader.py +++ b/src/module/downloader/client/qb_downloader.py @@ -80,14 +80,16 @@ class QbDownloader: def torrents_info(self, status_filter, category, tag=None): return self._client.torrents_info(status_filter=status_filter, category=category, tag=tag) - def torrents_add(self, urls, save_path, category): - return self._client.torrents_add( + def torrents_add(self, urls, save_path, category, torrent_files=None): + resp = self._client.torrents_add( is_paused=False, urls=urls, + torrent_files=torrent_files, save_path=save_path, category=category, use_auto_torrent_management=False ) + return resp == "Ok." def torrents_delete(self, hash): return self._client.torrents_delete(delete_files=True, torrent_hashes=hash) diff --git a/src/module/downloader/download_client.py b/src/module/downloader/download_client.py index 37c843f2..c30c1cd3 100644 --- a/src/module/downloader/download_client.py +++ b/src/module/downloader/download_client.py @@ -45,7 +45,7 @@ class DownloadClient(TorrentPath): def auth(self): self.authed = self.client.auth() if self.authed: - logger.info("[Downloader] Authed.") + logger.debug("[Downloader] Authed.") else: logger.error("[Downloader] Auth failed.") @@ -112,9 +112,17 @@ class DownloadClient(TorrentPath): logger.info(f"[Downloader] Remove torrents.") def add_torrent(self, torrent: dict): - self.client.torrents_add( - urls=torrent["url"], save_path=torrent["save_path"], category="Bangumi" - ) + if self.client.torrents_add( + urls=torrent.get("urls"), + torrent_files=torrent.get("torrent_files"), + save_path=torrent.get("save_path"), + category="Bangumi" + ): + logger.debug(f"[Downloader] Add torrent: {torrent.get('save_path')}") + return True + else: + logger.error(f"[Downloader] Add torrent failed: {torrent.get('save_path')}") + return False def move_torrent(self, hashes, location): self.client.move_torrent(hashes=hashes, new_location=location) diff --git a/src/module/downloader/path.py b/src/module/downloader/path.py index de4b9348..e0ed8043 100644 --- a/src/module/downloader/path.py +++ b/src/module/downloader/path.py @@ -13,8 +13,8 @@ logger = logging.getLogger(__name__) class TorrentPath: - def __init__(self, download_path: str = settings.downloader.path): - self.download_path = download_path + def __init__(self): + pass @staticmethod def check_files(info): @@ -29,10 +29,11 @@ class TorrentPath: subtitle_list.append(file_name) return media_list, subtitle_list - def _path_to_bangumi(self, save_path): + @staticmethod + def _path_to_bangumi(save_path): # Split save path and download path save_parts = save_path.split(path.sep) - download_parts = self.download_path.split(path.sep) + download_parts = settings.downloader.path.split(path.sep) # Get bangumi name and season bangumi_name = "" season = 1 @@ -50,11 +51,12 @@ class TorrentPath: def is_ep(self, file_path): return self._file_depth(file_path) <= 2 - def _gen_save_path(self, data: BangumiData): + @staticmethod + def _gen_save_path(data: BangumiData): folder = ( f"{data.official_title} ({data.year})" if data.year else data.official_title ) - save_path = path.join(self.download_path, folder, f"Season {data.season}") + save_path = path.join(settings.downloader.path, folder, f"Season {data.season}") return save_path @staticmethod diff --git a/src/module/manager/collector.py b/src/module/manager/collector.py index 1b341e3e..7350108a 100644 --- a/src/module/manager/collector.py +++ b/src/module/manager/collector.py @@ -11,23 +11,31 @@ logger = logging.getLogger(__name__) class SeasonCollector(DownloadClient): - def add_season_torrents(self, data: BangumiData, torrents): - for torrent in torrents: + def add_season_torrents(self, data: BangumiData, torrents, torrent_files=None): + if torrent_files: download_info = { - "url": torrent.torrent_link, + "torrent_files": torrent_files, "save_path": self._gen_save_path(data), } - self.add_torrent(download_info) + return self.add_torrent(download_info) + else: + download_info = { + "urls": [torrent.torrent_link for torrent in torrents], + "save_path": self._gen_save_path(data), + } + return self.add_torrent(download_info) - def collect_season(self, data: BangumiData, link: str = None): + def collect_season(self, data: BangumiData, link: str = None, proxy: bool = False): logger.info(f"Start collecting {data.official_title} Season {data.season}...") with SearchTorrent() as st: if not link: torrents = st.search_season(data) else: - torrents = st.get_torrents(link) - self.add_season_torrents(data, torrents) - logger.info("Completed!") + torrents = st.get_torrents(link, _filter="|".join(data.filter)) + torrent_files = None + if proxy: + torrent_files = [st.get_content(torrent.torrent_link) for torrent in torrents] + return self.add_season_torrents(data=data, torrents=torrents, torrent_files=torrent_files) def subscribe_season(self, data: BangumiData): with BangumiDatabase() as db: diff --git a/src/module/manager/renamer.py b/src/module/manager/renamer.py index d88de5fb..c6b25b5e 100644 --- a/src/module/manager/renamer.py +++ b/src/module/manager/renamer.py @@ -139,12 +139,12 @@ class Renamer(DownloadClient): if not renamed: logger.warning(f"[Renamer] {subtitle_path} rename failed") - def rename(self): + def rename(self) -> list[Notification]: # Get torrent info logger.debug("[Renamer] Start rename process.") rename_method = settings.bangumi_manage.rename_method torrents_info = self.get_torrent_info() - renamed_info = [] + renamed_info: list[Notification] = [] for info in torrents_info: media_list, subtitle_list = self.check_files(info) bangumi_name, season = self._path_to_bangumi(info.save_path) diff --git a/src/module/network/request_contents.py b/src/module/network/request_contents.py index b9b2b7af..0c2096b6 100644 --- a/src/module/network/request_contents.py +++ b/src/module/network/request_contents.py @@ -42,12 +42,11 @@ class RequestContent(RequestURL): _url: str, _filter: str = "|".join(settings.rss_parser.filter), retry: int = 3, - ) -> [TorrentInfo]: + ) -> list[TorrentInfo]: try: soup = self.get_xml(_url, retry) torrent_titles, torrent_urls, torrent_homepage = mikan_parser(soup) - - torrents = [] + torrents: list[TorrentInfo] = [] for _title, torrent_url, homepage in zip( torrent_titles, torrent_urls, torrent_homepage ): diff --git a/src/module/network/request_url.py b/src/module/network/request_url.py index 85dadb7e..16da3d4c 100644 --- a/src/module/network/request_url.py +++ b/src/module/network/request_url.py @@ -13,6 +13,7 @@ logger = logging.getLogger(__name__) class RequestURL: def __init__(self): self.header = {"user-agent": "Mozilla/5.0", "Accept": "application/xml"} + self._socks5_proxy = False def get_url(self, url, retry=3): try_time = 0 @@ -77,6 +78,7 @@ class RequestURL: "http": url, } elif settings.proxy.type == "socks5": + self._socks5_proxy = True socks.set_default_proxy( socks.SOCKS5, addr=settings.proxy.host, @@ -91,4 +93,8 @@ class RequestURL: return self def __exit__(self, exc_type, exc_val, exc_tb): + if self._socks5_proxy: + socks.set_default_proxy() + socket.socket = socks.socksocket + self._socks5_proxy = False self.session.close() diff --git a/src/module/notification/notification.py b/src/module/notification/notification.py index 4181ac02..22bde17a 100644 --- a/src/module/notification/notification.py +++ b/src/module/notification/notification.py @@ -10,20 +10,23 @@ from module.database import BangumiDatabase logger = logging.getLogger(__name__) -def getClient(type=settings.notification.type): +def getClient(type: str): if type.lower() == "telegram": return TelegramNotification elif type.lower() == "server-chan": return ServerChanNotification elif type.lower() == "bark": return BarkNotification + elif type.lower() == "wecom": + return WecomNotification else: return None -class PostNotification(getClient()): +class PostNotification: def __init__(self): - super().__init__( + Notifier = getClient(settings.notification.type) + self.notifier = Notifier( token=settings.notification.token, chat_id=settings.notification.chat_id ) @@ -46,12 +49,18 @@ class PostNotification(getClient()): def send_msg(self, notify: Notification) -> bool: text = self._gen_message(notify) try: - self.post_msg(text) + self.notifier.post_msg(text) logger.debug(f"Send notification: {notify.official_title}") except Exception as e: logger.warning(f"Failed to send notification: {e}") return False + def __enter__(self): + self.notifier.__enter__() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.notifier.__exit__(exc_type, exc_val, exc_tb) if __name__ == "__main__": info = Notification( diff --git a/src/module/notification/plugin/__init__.py b/src/module/notification/plugin/__init__.py index 106087f6..ece03d58 100644 --- a/src/module/notification/plugin/__init__.py +++ b/src/module/notification/plugin/__init__.py @@ -1,3 +1,4 @@ from .bark import BarkNotification from .server_chan import ServerChanNotification -from .telegram import TelegramNotification \ No newline at end of file +from .telegram import TelegramNotification +from .wecom import WecomNotification \ No newline at end of file diff --git a/src/module/notification/plugin/wecom.py b/src/module/notification/plugin/wecom.py new file mode 100644 index 00000000..c56e97e2 --- /dev/null +++ b/src/module/notification/plugin/wecom.py @@ -0,0 +1,35 @@ +import logging +from module.network import RequestContent + +logger = logging.getLogger(__name__) + + +class WecomNotification(RequestContent): + """企业微信推送 基于图文消息""" + + def __init__(self, token, chat_id, **kwargs): + super().__init__() + #Chat_id is used as noti_url in this push tunnel + self.notification_url = f"{chat_id}" + self.token = token + + def post_msg(self, text: str) -> bool: + ##Change message format to match Wecom push better + info = text.split(":") + print(info) + title = "【番剧更新】" + info[1].split("\n")[0].strip() + msg = info[2].split("\n")[0].strip()+" "+info[3].split("\n")[0].strip() + picurl = info[3].split("\n")[1].strip() + #Default pic to avoid blank in message. Resolution:1068*455 + if picurl == "": + picurl = "https://article.biliimg.com/bfs/article/d8bcd0408bf32594fd82f27de7d2c685829d1b2e.png" + data = { + "key":self.token, + "type": "news", + "title": title, + "msg": msg, + "picurl":picurl + } + resp = self.post_data(self.notification_url, data) + logger.debug(f"Wecom notification: {resp.status_code}") + return resp.status_code == 200 diff --git a/src/module/parser/analyser/torrent_parser.py b/src/module/parser/analyser/torrent_parser.py index fab6e485..ba2a1880 100644 --- a/src/module/parser/analyser/torrent_parser.py +++ b/src/module/parser/analyser/torrent_parser.py @@ -18,8 +18,8 @@ RULES = [ ] SUBTITLE_LANG = { - "zh-tw": ["TC", "CHT", "繁", "zh-tw"], - "zh": ["SC", "CHS", "简", "zh"], + "zh-tw": ["TC", "CHT", "cht", "繁", "zh-tw"], + "zh": ["SC", "CHS", "chs", "简", "zh"], } diff --git a/src/module/security/api.py b/src/module/security/api.py index 13069b72..14282171 100644 --- a/src/module/security/api.py +++ b/src/module/security/api.py @@ -6,7 +6,6 @@ from .jwt import verify_token from module.database.user import AuthDB from module.models.user import User - oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login") @@ -50,8 +49,4 @@ def update_user_info(user_data: User, current_user): def auth_user(username, password): with AuthDB() as db: - if not db.auth_user(username, password): - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="invalid username or password", - ) + db.auth_user(username, password)