From 589df384c99fac1dcd8b2519c99eaecbd92fe3f0 Mon Sep 17 00:00:00 2001 From: EstrellaXD Date: Sat, 20 May 2023 17:31:42 +0800 Subject: [PATCH] feat: enable auth in all api --- src/module/api/auth.py | 46 ++++++--------------------------- src/module/api/bangumi.py | 27 ++++++++++++++----- src/module/api/config.py | 11 ++++++-- src/module/api/download.py | 15 ++++++++--- src/module/api/log.py | 11 +++++--- src/module/api/program.py | 33 ++++++++++++++++------- src/module/database/__init__.py | 1 + src/module/security/__init__.py | 2 ++ src/module/security/api.py | 46 +++++++++++++++++++++++++++++++++ 9 files changed, 130 insertions(+), 62 deletions(-) create mode 100644 src/module/security/api.py diff --git a/src/module/api/auth.py b/src/module/api/auth.py index 607a0497..aaa8db45 100644 --- a/src/module/api/auth.py +++ b/src/module/api/auth.py @@ -1,45 +1,19 @@ from fastapi import Depends, HTTPException, status -from fastapi.security import OAuth2PasswordRequestForm, OAuth2PasswordBearer +from fastapi.security import OAuth2PasswordRequestForm -from module.database.user import AuthDB -from module.security.jwt import create_access_token, decode_token +from module.security import create_access_token, get_current_user, update_user_info, auth_user from module.models.user import User from .program import router -oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login") - - -async def get_current_user(token: str = Depends(oauth2_scheme)): - if not token: - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token") - payload = decode_token(token) - if not payload: - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token") - username = payload.get("sub") - with AuthDB() as user_db: - user = user_db.get_user(username) - if not user: - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid username") - return user - - -async def get_token_data(token: str = Depends(oauth2_scheme)): - payload = decode_token(token) - if not payload: - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token") - return payload - @router.post("/api/v1/auth/login", response_model=dict, tags=["auth"]) async def login(form_data: OAuth2PasswordRequestForm = Depends()): username = form_data.username password = form_data.password - with AuthDB() as db: - if not db.auth_user(username, password): - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid username or password") - token = create_access_token({"sub": username}) - return {"access_token": token, "token_type": "bearer", "expire": 86400} + auth_user(username, password) + token = create_access_token({"sub": username}) + return {"access_token": token, "token_type": "bearer", "expire": 86400} @router.get("/api/v1/auth/refresh_token", response_model=dict, tags=["auth"]) @@ -62,12 +36,8 @@ async def logout( @router.post("/api/v1/auth/update", response_model=dict, tags=["auth"]) -async def update_user(data: User, current_user: User = Depends(get_current_user)): +async def update_user(user_data: User, current_user: User = Depends(get_current_user)): if not current_user: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token") - try: - with AuthDB() as db: - db.update_user(current_user.username, data) - return {"message": "update success"} - except Exception as e: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) + return update_user_info(user_data, current_user) + diff --git a/src/module/api/bangumi.py b/src/module/api/bangumi.py index 3fb79763..85068877 100644 --- a/src/module/api/bangumi.py +++ b/src/module/api/bangumi.py @@ -1,42 +1,57 @@ +from fastapi import Depends, HTTPException, status + from .log import router from module.models import BangumiData from module.database import BangumiDatabase from module.manager import TorrentManager +from module.security import get_current_user @router.get("/api/v1/bangumi/getAll", tags=["bangumi"], response_model=list[BangumiData]) -async def get_all_data(): +async def get_all_data(current_user=Depends(get_current_user)): + if not current_user: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token") with TorrentManager() as torrent: return torrent.search_all() @router.get("/api/v1/bangumi/getData/{bangumi_id}", tags=["bangumi"], response_model=BangumiData) -async def get_data(bangumi_id: str): +async def get_data(bangumi_id: str, current_user=Depends(get_current_user)): + if not current_user: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token") with TorrentManager() as torrent: return torrent.search_data(bangumi_id) @router.post("/api/v1/bangumi/updateData", tags=["bangumi"]) -async def update_data(data: BangumiData): +async def update_data(data: BangumiData, current_user=Depends(get_current_user)): + if not current_user: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token") with TorrentManager() as torrent: return torrent.update_rule(data) @router.delete("/api/v1/bangumi/deleteData/{bangumi_id}", tags=["bangumi"]) -async def delete_data(bangumi_id: str): +async def delete_data(bangumi_id: str, current_user=Depends(get_current_user)): + if not current_user: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token") with TorrentManager() as torrent: return torrent.delete_data(bangumi_id) @router.delete("/api/v1/bangumi/deleteRule/{bangumi_id}", tags=["bangumi"]) -async def delete_rule(bangumi_id: str, file: bool = False): +async def delete_rule(bangumi_id: str, file: bool = False, current_user=Depends(get_current_user)): + if not current_user: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token") with TorrentManager() as torrent: return torrent.delete_rule(bangumi_id, file) @router.get("/api/v1/bangumi/resetAll", tags=["bangumi"]) -async def reset_all(): +async def reset_all(current_user=Depends(get_current_user)): + if not current_user: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token") with BangumiDatabase() as database: database.delete_all() return {"status": "ok"} diff --git a/src/module/api/config.py b/src/module/api/config.py index 7400d288..a82bca5a 100644 --- a/src/module/api/config.py +++ b/src/module/api/config.py @@ -1,20 +1,27 @@ import logging +from fastapi import Depends, HTTPException, status + from .bangumi import router from module.conf import settings from module.models import Config +from module.security import get_current_user logger = logging.getLogger(__name__) @router.get("/api/v1/getConfig", tags=["config"], response_model=Config) -async def get_config(): +async def get_config(current_user=Depends(get_current_user)): + if not current_user: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token") return settings @router.post("/api/v1/updateConfig", tags=["config"]) -async def update_config(config: Config): +async def update_config(config: Config, current_user=Depends(get_current_user)): + if not current_user: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token") try: settings.save(config_dict=config.dict()) settings.load() diff --git a/src/module/api/download.py b/src/module/api/download.py index 6d90aeac..b8396ea5 100644 --- a/src/module/api/download.py +++ b/src/module/api/download.py @@ -1,9 +1,12 @@ +from fastapi import Depends, HTTPException, status + from .config import router from module.models.api import * from module.models import BangumiData from module.manager import SeasonCollector from module.rss import analyser +from module.security import get_current_user def link_process(link): @@ -11,7 +14,9 @@ def link_process(link): @router.post("/api/v1/download/analysis", tags=["download"]) -async def analysis(link: RssLink): +async def analysis(link: RssLink, current_user=Depends(get_current_user)): + if not current_user: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token") data = link_process(link.rss_link) if data: return data[0] @@ -20,7 +25,9 @@ async def analysis(link: RssLink): @router.post("/api/v1/download/collection", tags=["download"]) -async def download_collection(data: BangumiData): +async def download_collection(data: BangumiData, current_user=Depends(get_current_user)): + if not current_user: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token") if data: with SeasonCollector() as collector: collector.collect_season(data, data.rss_link[0]) @@ -30,7 +37,9 @@ async def download_collection(data: BangumiData): @router.post("/api/v1/download/subscribe", tags=["download"]) -async def subscribe(data: BangumiData): +async def subscribe(data: BangumiData, current_user=Depends(get_current_user)): + if not current_user: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token") if data: with SeasonCollector() as collector: collector.subscribe_season(data) diff --git a/src/module/api/log.py b/src/module/api/log.py index 645f5695..cf26cb8c 100644 --- a/src/module/api/log.py +++ b/src/module/api/log.py @@ -1,13 +1,16 @@ import os -from fastapi import Response +from fastapi import Response, HTTPException, Depends, status from .auth import router from module.conf import LOG_PATH +from module.security import get_current_user @router.get("/api/v1/log", tags=["log"]) -async def get_log(): +async def get_log(current_user=Depends(get_current_user)): + if not current_user: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token") if os.path.isfile(LOG_PATH): with open(LOG_PATH, "r") as f: return Response(f.read(), media_type="text/plain") @@ -16,7 +19,9 @@ async def get_log(): @router.get("/api/v1/log/clear", tags=["log"]) -async def clear_log(): +async def clear_log(current_user=Depends(get_current_user)): + if not current_user: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token") if os.path.isfile(LOG_PATH): with open(LOG_PATH, "w") as f: f.write("") diff --git a/src/module/api/program.py b/src/module/api/program.py index e8db5e8a..73dcfa0a 100644 --- a/src/module/api/program.py +++ b/src/module/api/program.py @@ -2,12 +2,11 @@ import signal import logging import os -from fastapi.exceptions import HTTPException - - +from fastapi import HTTPException, status, Depends from fastapi import FastAPI from module.core import Program +from module.security import get_current_user logger = logging.getLogger(__name__) program = Program() @@ -25,7 +24,9 @@ async def shutdown(): @router.get("/api/v1/restart", tags=["program"]) -async def restart(): +async def restart(current_user=Depends(get_current_user)): + if not current_user: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token") try: program.restart() return {"status": "ok"} @@ -36,7 +37,9 @@ async def restart(): @router.get("/api/v1/start", tags=["program"]) -async def start(): +async def start(current_user=Depends(get_current_user)): + if not current_user: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token") try: program.start() return {"status": "ok"} @@ -47,13 +50,17 @@ async def start(): @router.get("/api/v1/stop", tags=["program"]) -async def stop(): +async def stop(current_user=Depends(get_current_user)): + if not current_user: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token") program.stop() return {"status": "ok"} @router.get("/api/v1/status", tags=["program"]) -async def status(): +async def status(current_user=Depends(get_current_user)): + if not current_user: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token") if not program.is_running: return {"status": "stop"} else: @@ -61,7 +68,9 @@ async def status(): @router.get("/api/v1/shutdown", tags=["program"]) -async def shutdown_program(): +async def shutdown_program(current_user=Depends(get_current_user)): + if not current_user: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token") program.stop() logger.info("Shutting down program...") os.kill(os.getpid(), signal.SIGINT) @@ -70,10 +79,14 @@ async def shutdown_program(): # Check status @router.get("/api/v1/check/downloader", tags=["check"]) -async def check_downloader_status(): +async def check_downloader_status(current_user=Depends(get_current_user)): + if not current_user: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token") return program.check_downloader() @router.get("/api/v1/check/rss", tags=["check"]) -async def check_rss_status(): +async def check_rss_status(current_user=Depends(get_current_user)): + if not current_user: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token") return program.check_analyser() diff --git a/src/module/database/__init__.py b/src/module/database/__init__.py index c5d59b08..7a0fd814 100644 --- a/src/module/database/__init__.py +++ b/src/module/database/__init__.py @@ -1 +1,2 @@ from .bangumi import BangumiDatabase +from .user import AuthDB diff --git a/src/module/security/__init__.py b/src/module/security/__init__.py index e69de29b..28f5d20e 100644 --- a/src/module/security/__init__.py +++ b/src/module/security/__init__.py @@ -0,0 +1,2 @@ +from .jwt import create_access_token +from .api import get_current_user, get_token_data, auth_user, update_user_info diff --git a/src/module/security/api.py b/src/module/security/api.py new file mode 100644 index 00000000..395d683d --- /dev/null +++ b/src/module/security/api.py @@ -0,0 +1,46 @@ +from fastapi import Depends, HTTPException, status +from fastapi.security import OAuth2PasswordBearer + +from .jwt import decode_token + +from module.database import AuthDB +from module.models.user import User + + +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login") + + +async def get_current_user(token: str = Depends(oauth2_scheme)): + if not token: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token") + payload = decode_token(token) + if not payload: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token") + username = payload.get("sub") + with AuthDB() as user_db: + user = user_db.get_user(username) + if not user: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid username") + return user + + +async def get_token_data(token: str = Depends(oauth2_scheme)): + payload = decode_token(token) + if not payload: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token") + return payload + + +def update_user_info(user_data: User, current_user): + try: + with AuthDB() as db: + db.update_user(current_user.username, user_data) + return {"message": "update success"} + except Exception as e: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) + + +def auth_user(username, password): + with AuthDB() as db: + if not db.auth_user(username, password): + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid username or password") \ No newline at end of file