From 6cac596d855cc5e9480ee32ea734c4d0949e3bdc Mon Sep 17 00:00:00 2001 From: EstrellaXD Date: Sun, 13 Aug 2023 18:13:25 +0800 Subject: [PATCH] fix: leak of response model. --- backend/src/module/api/bangumi.py | 28 +++++++++------ backend/src/module/api/config.py | 17 +++++++--- backend/src/module/api/download.py | 40 ++++++++++++++++------ backend/src/module/api/log.py | 16 ++++++--- backend/src/module/api/program.py | 3 +- backend/src/module/api/response.py | 1 - backend/src/module/api/rss.py | 49 +++++++++++++++++++++------ backend/src/module/api/search.py | 3 +- backend/src/module/database/rss.py | 11 ++++-- backend/src/module/manager/torrent.py | 12 +++---- backend/src/module/models/__init__.py | 2 +- backend/src/module/models/response.py | 5 +++ 12 files changed, 130 insertions(+), 57 deletions(-) diff --git a/backend/src/module/api/bangumi.py b/backend/src/module/api/bangumi.py index 64e2bfb3..464f001d 100644 --- a/backend/src/module/api/bangumi.py +++ b/backend/src/module/api/bangumi.py @@ -4,7 +4,7 @@ from fastapi.responses import JSONResponse from .response import u_response from module.manager import TorrentManager -from module.models import Bangumi, BangumiUpdate +from module.models import Bangumi, BangumiUpdate, APIResponse from module.security.api import get_current_user, UNAUTHORIZED router = APIRouter(prefix="/bangumi", tags=["bangumi"]) @@ -23,10 +23,11 @@ async def get_data(bangumi_id: str, current_user=Depends(get_current_user)): if not current_user: raise UNAUTHORIZED with TorrentManager() as manager: - return manager.search_one(bangumi_id) + resp = manager.search_one(bangumi_id) + return resp -@router.patch("/update/{bangumi_id}") +@router.patch("/update/{bangumi_id}", response_model=APIResponse) async def update_rule( bangumi_id: int, data: BangumiUpdate, current_user=Depends(get_current_user) ): @@ -37,7 +38,7 @@ async def update_rule( return u_response(resp) -@router.delete("/delete/{bangumi_id}") +@router.delete("/delete/{bangumi_id}", response_model=APIResponse) async def delete_rule( bangumi_id: str, file: bool = False, current_user=Depends(get_current_user) ): @@ -48,7 +49,7 @@ async def delete_rule( return u_response(resp) -@router.delete("/delete/many/") +@router.delete("/delete/many/", response_model=APIResponse) async def delete_many_rule( bangumi_id: list, file: bool = False, current_user=Depends(get_current_user) ): @@ -59,14 +60,15 @@ async def delete_many_rule( manager.delete_rule(i, file) -@router.delete("/disable/{bangumi_id}") +@router.delete("/disable/{bangumi_id}", response_model=APIResponse) async def disable_rule( bangumi_id: str, file: bool = False, current_user=Depends(get_current_user) ): if not current_user: raise UNAUTHORIZED with TorrentManager() as manager: - return manager.disable_rule(bangumi_id, file) + resp = manager.disable_rule(bangumi_id, file) + return u_response(resp) @router.delete("/disable/many/") @@ -80,18 +82,22 @@ async def disable_many_rule( manager.disable_rule(i, file) -@router.get("/enable/{bangumi_id}") +@router.get("/enable/{bangumi_id}", response_model=APIResponse) async def enable_rule(bangumi_id: str, current_user=Depends(get_current_user)): if not current_user: raise UNAUTHORIZED with TorrentManager() as manager: - return manager.enable_rule(bangumi_id) + resp = manager.enable_rule(bangumi_id) + return u_response(resp) -@router.get("/reset/all") +@router.get("/reset/all", response_model=APIResponse) async def reset_all(current_user=Depends(get_current_user)): if not current_user: raise UNAUTHORIZED with TorrentManager() as manager: manager.bangumi.delete_all() - return JSONResponse(status_code=200, content={"message": "OK"}) + return JSONResponse( + status_code=200, + content={"msg_en": "Reset all rules successfully.", "msg_zh": "重置所有规则成功。"}, + ) diff --git a/backend/src/module/api/config.py b/backend/src/module/api/config.py index 3ed0b0c9..5f9138b6 100644 --- a/backend/src/module/api/config.py +++ b/backend/src/module/api/config.py @@ -1,9 +1,10 @@ import logging -from fastapi import APIRouter, Depends, HTTPException, status +from fastapi import APIRouter, Depends +from fastapi.responses import JSONResponse from module.conf import settings -from module.models import Config +from module.models import Config, APIResponse from module.security.api import get_current_user, UNAUTHORIZED router = APIRouter(prefix="/config", tags=["config"]) @@ -17,7 +18,7 @@ async def get_config(current_user=Depends(get_current_user)): return settings -@router.patch("/update") +@router.patch("/update", response_model=APIResponse) async def update_config(config: Config, current_user=Depends(get_current_user)): if not current_user: raise UNAUTHORIZED @@ -26,7 +27,13 @@ async def update_config(config: Config, current_user=Depends(get_current_user)): settings.load() # update_rss() logger.info("Config updated") - return {"message": "Success"} + return JSONResponse( + status_code=200, + content={"msg_en": "Update config successfully.", "msg_zh": "更新配置成功。"} + ) except Exception as e: logger.warning(e) - return {"message": "Failed to update config"} + return JSONResponse( + status_code=406, + content={"msg_en": "Update config failed.", "msg_zh": "更新配置失败。"} + ) diff --git a/backend/src/module/api/download.py b/backend/src/module/api/download.py index 8e4b48a2..1ebf3581 100644 --- a/backend/src/module/api/download.py +++ b/backend/src/module/api/download.py @@ -1,8 +1,8 @@ from fastapi import APIRouter, Depends +from fastapi.responses import JSONResponse from module.manager import SeasonCollector -from module.models import Bangumi, RSSItem -from module.models.api import RssLink +from module.models import Bangumi, RSSItem, APIResponse from module.rss import RSSAnalyser from module.security.api import get_current_user, UNAUTHORIZED @@ -10,7 +10,7 @@ router = APIRouter(prefix="/download", tags=["download"]) analyser = RSSAnalyser() -@router.post("/analysis") +@router.post("/analysis", response_model=Bangumi) async def analysis(rss: RSSItem, current_user=Depends(get_current_user)): if not current_user: raise UNAUTHORIZED @@ -18,30 +18,48 @@ async def analysis(rss: RSSItem, current_user=Depends(get_current_user)): if data: return data else: - return {"status": "Failed to parse link"} + return JSONResponse( + status_code=406, + content={"msg_en": "Analysis failed.", "msg_zh": "解析失败。"}, + ) -@router.post("/collection") +@router.post("/collection", response_model=APIResponse) async def download_collection(data: Bangumi, current_user=Depends(get_current_user)): if not current_user: raise UNAUTHORIZED if data: with SeasonCollector() as collector: if collector.collect_season(data, data.rss_link[0]): - return {"status": "Success"} + return JSONResponse( + status_code=200, + content={"msg_en": "Add torrent successfully.", "msg_zh": "添加种子成功。"}, + ) else: - return {"status": "Failed to add torrent"} + return JSONResponse( + status_code=406, + content={"msg_en": "Add torrent failed.", "msg_zh": "添加种子失败。"}, + ) else: - return {"status": "Failed to parse link"} + return JSONResponse( + status_code=406, + content={"msg_en": "Add torrent failed.", "msg_zh": "添加种子失败。"}, + ) -@router.post("/subscribe") +@router.post("/subscribe", response_model=APIResponse) async def subscribe(data: Bangumi, current_user=Depends(get_current_user)): if not current_user: raise UNAUTHORIZED if data: with SeasonCollector() as collector: collector.subscribe_season(data) - return {"status": "Success"} + return JSONResponse( + status_code=200, + content={"msg_en": "Subscribe successfully.", "msg_zh": "订阅成功。"}, + ) else: - return {"status": "Failed to parse link"} + return JSONResponse( + status_code=406, + content={"msg_en": "Subscribe failed.", "msg_zh": "订阅失败。"}, + ) diff --git a/backend/src/module/api/log.py b/backend/src/module/api/log.py index 4c1b6e1d..520e316d 100644 --- a/backend/src/module/api/log.py +++ b/backend/src/module/api/log.py @@ -1,9 +1,9 @@ -import os - from fastapi import APIRouter, Depends, HTTPException, Response, status +from fastapi.responses import JSONResponse from module.conf import LOG_PATH from module.security.api import get_current_user, UNAUTHORIZED +from module.models import APIResponse router = APIRouter(prefix="/log", tags=["log"]) @@ -19,12 +19,18 @@ async def get_log(current_user=Depends(get_current_user)): return Response("Log file not found", status_code=404) -@router.get("/clear") +@router.get("/clear", response_model=APIResponse) async def clear_log(current_user=Depends(get_current_user)): if not current_user: raise UNAUTHORIZED if LOG_PATH.exists(): LOG_PATH.write_text("") - return {"status": "ok"} + return JSONResponse( + status_code=200, + content={"msg_en": "Log cleared successfully.", "msg_zh": "日志清除成功。"}, + ) else: - return Response("Log file not found", status_code=404) + return JSONResponse( + status_code=406, + content={"msg_en": "Log file not found.", "msg_zh": "日志文件未找到。"}, + ) diff --git a/backend/src/module/api/program.py b/backend/src/module/api/program.py index 4bed2f8f..7bd1cedd 100644 --- a/backend/src/module/api/program.py +++ b/backend/src/module/api/program.py @@ -3,6 +3,7 @@ import os import signal from fastapi import APIRouter, Depends, HTTPException +from fastapi.responses import JSONResponse from module.core import Program from module.conf import VERSION @@ -36,7 +37,7 @@ async def restart(current_user=Depends(get_current_user)): raise HTTPException(status_code=500, detail="Failed to restart program") -@router.get("/start") +@router.get("/start", response_model=JSONResponse) async def start(current_user=Depends(get_current_user)): if not current_user: raise UNAUTHORIZED diff --git a/backend/src/module/api/response.py b/backend/src/module/api/response.py index 00365a7e..b4705cc4 100644 --- a/backend/src/module/api/response.py +++ b/backend/src/module/api/response.py @@ -7,7 +7,6 @@ def u_response(response_model: ResponseModel): return JSONResponse( status_code=response_model.status_code, content={ - "status": response_model.status, "msg_en": response_model.msg_en, "msg_zh": response_model.msg_zh, }, diff --git a/backend/src/module/api/rss.py b/backend/src/module/api/rss.py index eaec955c..5ea96ddb 100644 --- a/backend/src/module/api/rss.py +++ b/backend/src/module/api/rss.py @@ -1,10 +1,9 @@ -from typing import Optional - from fastapi import APIRouter, Depends +from fastapi.responses import JSONResponse from .response import u_response -from module.models import RSSItem, RSSUpdate +from module.models import RSSItem, RSSUpdate, Torrent from module.rss import RSSEngine from module.security.api import get_current_user, UNAUTHORIZED from module.downloader import DownloadClient @@ -21,7 +20,7 @@ async def get_rss(current_user=Depends(get_current_user)): return engine.rss.search_all() -@router.post("/add") +@router.post("/add", response_model=JSONResponse) async def add_rss(rss: RSSItem, current_user=Depends(get_current_user)): if not current_user: raise UNAUTHORIZED @@ -30,15 +29,25 @@ async def add_rss(rss: RSSItem, current_user=Depends(get_current_user)): return u_response(result) -@router.delete("/delete/{rss_id}") +@router.delete("/delete/{rss_id}", response_model=JSONResponse) async def delete_rss(rss_id: int, current_user=Depends(get_current_user)): if not current_user: raise UNAUTHORIZED with RSSEngine() as engine: result = engine.rss.delete(rss_id) + if result: + return JSONResponse( + status_code=200, + content={"msg_en": "Delete RSS successfully.", "msg_zh": "删除 RSS 成功。"}, + ) + else: + return JSONResponse( + status_code=406, + content={"msg_en": "Delete RSS failed.", "msg_zh": "删除 RSS 失败。"}, + ) -@router.patch("/update/{rss_id}") +@router.patch("/update/{rss_id}", response_model=JSONResponse) async def update_rss( rss_id: int, data: RSSUpdate, current_user=Depends(get_current_user) ): @@ -46,25 +55,43 @@ async def update_rss( raise UNAUTHORIZED with RSSEngine() as engine: result = engine.rss.update(rss_id, data) + if result: + return JSONResponse( + status_code=200, + content={"msg_en": "Update RSS successfully.", "msg_zh": "更新 RSS 成功。"}, + ) + else: + return JSONResponse( + status_code=406, + content={"msg_en": "Update RSS failed.", "msg_zh": "更新 RSS 失败。"}, + ) -@router.get("/refresh/all") +@router.get("/refresh/all", response_model=JSONResponse) async def refresh_all(current_user=Depends(get_current_user)): if not current_user: raise UNAUTHORIZED with RSSEngine() as engine, DownloadClient() as client: - response = engine.refresh_rss(client) + engine.refresh_rss(client) + return JSONResponse( + status_code=200, + content={"msg_en": "Refresh all RSS successfully.", "msg_zh": "刷新 RSS 成功。"}, + ) -@router.get("/refresh/{rss_id}") +@router.get("/refresh/{rss_id}", response_model=JSONResponse) async def refresh_rss(rss_id: int, current_user=Depends(get_current_user)): if not current_user: raise UNAUTHORIZED with RSSEngine() as engine, DownloadClient() as client: - response = engine.refresh_rss(client, rss_id) + engine.refresh_rss(client, rss_id) + return JSONResponse( + status_code=200, + content={"msg_en": "Refresh RSS successfully.", "msg_zh": "刷新 RSS 成功。"}, + ) -@router.get("/torrent/{rss_id}") +@router.get("/torrent/{rss_id}", response_model=list[Torrent]) async def get_torrent(rss_id: int, current_user=Depends(get_current_user)): if not current_user: raise UNAUTHORIZED diff --git a/backend/src/module/api/search.py b/backend/src/module/api/search.py index 0daff511..beaa18fb 100644 --- a/backend/src/module/api/search.py +++ b/backend/src/module/api/search.py @@ -3,12 +3,13 @@ from fastapi.responses import StreamingResponse from module.searcher import SearchTorrent from module.security.api import get_current_user, UNAUTHORIZED +from module.models import Torrent router = APIRouter(prefix="/search", tags=["search"]) -@router.get("/") +@router.get("/", response_model=list[Torrent]) async def search_torrents( site: str = "mikan", keywords: str = Query(None), diff --git a/backend/src/module/database/rss.py b/backend/src/module/database/rss.py index c649bb16..97b4415b 100644 --- a/backend/src/module/database/rss.py +++ b/backend/src/module/database/rss.py @@ -58,10 +58,15 @@ class RSSDatabase: select(RSSItem).where(and_(RSSItem.aggregate, RSSItem.enabled)) ).all() - def delete(self, _id: int): + def delete(self, _id: int) -> bool: condition = delete(RSSItem).where(RSSItem.id == _id) - self.session.exec(condition) - self.session.commit() + try: + self.session.exec(condition) + self.session.commit() + return True + except Exception as e: + logger.error("Delete RSS Item failed.") + return False def delete_all(self): condition = delete(RSSItem) diff --git a/backend/src/module/manager/torrent.py b/backend/src/module/manager/torrent.py index 81411c88..ce2daaba 100644 --- a/backend/src/module/manager/torrent.py +++ b/backend/src/module/manager/torrent.py @@ -72,12 +72,7 @@ class TorrentManager(Database): self.bangumi.update(data) if file: torrent_message = self.delete_torrents(data, client) - return JSONResponse( - status_code=200, - content={ - "msg": f"Disable {data.official_title} rule. {torrent_message}" - }, - ) + return torrent_message logger.info(f"[Manager] Disable rule for {data.official_title}") return ResponseModel( status_code=200, @@ -148,6 +143,9 @@ class TorrentManager(Database): data = self.bangumi.search_id(int(_id)) if not data: logger.error(f"[Manager] Can't find data with {_id}") - return {"status": "error", "msg": f"Can't find data with {_id}"} + return JSONResponse( + status_code=406, + content={"msg_en": f"Can't find data with {_id}", "msg_zh": f"无法找到 id {_id} 的数据"}, + ) else: return data diff --git a/backend/src/module/models/__init__.py b/backend/src/module/models/__init__.py index a5578a5f..7a00b90d 100644 --- a/backend/src/module/models/__init__.py +++ b/backend/src/module/models/__init__.py @@ -3,4 +3,4 @@ from .config import Config from .rss import RSSItem, RSSUpdate from .torrent import EpisodeFile, SubtitleFile, Torrent, TorrentUpdate from .user import UserLogin, User, UserUpdate -from .response import ResponseModel +from .response import ResponseModel, APIResponse diff --git a/backend/src/module/models/response.py b/backend/src/module/models/response.py index 4ade25d6..73079631 100644 --- a/backend/src/module/models/response.py +++ b/backend/src/module/models/response.py @@ -6,3 +6,8 @@ class ResponseModel(BaseModel): status_code: int = Field(..., example=200) msg_en: str msg_zh: str + + +class APIResponse(BaseModel): + msg_en: str = Field(..., example="Success") + msg_zh: str = Field(..., example="成功") \ No newline at end of file