fix: leak of response model.

This commit is contained in:
EstrellaXD
2023-08-13 18:13:25 +08:00
parent 8300461155
commit 6cac596d85
12 changed files with 130 additions and 57 deletions

View File

@@ -4,7 +4,7 @@ from fastapi.responses import JSONResponse
from .response import u_response
from module.manager import TorrentManager
from module.models import Bangumi, BangumiUpdate
from module.models import Bangumi, BangumiUpdate, APIResponse
from module.security.api import get_current_user, UNAUTHORIZED
router = APIRouter(prefix="/bangumi", tags=["bangumi"])
@@ -23,10 +23,11 @@ async def get_data(bangumi_id: str, current_user=Depends(get_current_user)):
if not current_user:
raise UNAUTHORIZED
with TorrentManager() as manager:
return manager.search_one(bangumi_id)
resp = manager.search_one(bangumi_id)
return resp
@router.patch("/update/{bangumi_id}")
@router.patch("/update/{bangumi_id}", response_model=APIResponse)
async def update_rule(
bangumi_id: int, data: BangumiUpdate, current_user=Depends(get_current_user)
):
@@ -37,7 +38,7 @@ async def update_rule(
return u_response(resp)
@router.delete("/delete/{bangumi_id}")
@router.delete("/delete/{bangumi_id}", response_model=APIResponse)
async def delete_rule(
bangumi_id: str, file: bool = False, current_user=Depends(get_current_user)
):
@@ -48,7 +49,7 @@ async def delete_rule(
return u_response(resp)
@router.delete("/delete/many/")
@router.delete("/delete/many/", response_model=APIResponse)
async def delete_many_rule(
bangumi_id: list, file: bool = False, current_user=Depends(get_current_user)
):
@@ -59,14 +60,15 @@ async def delete_many_rule(
manager.delete_rule(i, file)
@router.delete("/disable/{bangumi_id}")
@router.delete("/disable/{bangumi_id}", response_model=APIResponse)
async def disable_rule(
bangumi_id: str, file: bool = False, current_user=Depends(get_current_user)
):
if not current_user:
raise UNAUTHORIZED
with TorrentManager() as manager:
return manager.disable_rule(bangumi_id, file)
resp = manager.disable_rule(bangumi_id, file)
return u_response(resp)
@router.delete("/disable/many/")
@@ -80,18 +82,22 @@ async def disable_many_rule(
manager.disable_rule(i, file)
@router.get("/enable/{bangumi_id}")
@router.get("/enable/{bangumi_id}", response_model=APIResponse)
async def enable_rule(bangumi_id: str, current_user=Depends(get_current_user)):
if not current_user:
raise UNAUTHORIZED
with TorrentManager() as manager:
return manager.enable_rule(bangumi_id)
resp = manager.enable_rule(bangumi_id)
return u_response(resp)
@router.get("/reset/all")
@router.get("/reset/all", response_model=APIResponse)
async def reset_all(current_user=Depends(get_current_user)):
if not current_user:
raise UNAUTHORIZED
with TorrentManager() as manager:
manager.bangumi.delete_all()
return JSONResponse(status_code=200, content={"message": "OK"})
return JSONResponse(
status_code=200,
content={"msg_en": "Reset all rules successfully.", "msg_zh": "重置所有规则成功。"},
)

View File

@@ -1,9 +1,10 @@
import logging
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi import APIRouter, Depends
from fastapi.responses import JSONResponse
from module.conf import settings
from module.models import Config
from module.models import Config, APIResponse
from module.security.api import get_current_user, UNAUTHORIZED
router = APIRouter(prefix="/config", tags=["config"])
@@ -17,7 +18,7 @@ async def get_config(current_user=Depends(get_current_user)):
return settings
@router.patch("/update")
@router.patch("/update", response_model=APIResponse)
async def update_config(config: Config, current_user=Depends(get_current_user)):
if not current_user:
raise UNAUTHORIZED
@@ -26,7 +27,13 @@ async def update_config(config: Config, current_user=Depends(get_current_user)):
settings.load()
# update_rss()
logger.info("Config updated")
return {"message": "Success"}
return JSONResponse(
status_code=200,
content={"msg_en": "Update config successfully.", "msg_zh": "更新配置成功。"}
)
except Exception as e:
logger.warning(e)
return {"message": "Failed to update config"}
return JSONResponse(
status_code=406,
content={"msg_en": "Update config failed.", "msg_zh": "更新配置失败。"}
)

View File

@@ -1,8 +1,8 @@
from fastapi import APIRouter, Depends
from fastapi.responses import JSONResponse
from module.manager import SeasonCollector
from module.models import Bangumi, RSSItem
from module.models.api import RssLink
from module.models import Bangumi, RSSItem, APIResponse
from module.rss import RSSAnalyser
from module.security.api import get_current_user, UNAUTHORIZED
@@ -10,7 +10,7 @@ router = APIRouter(prefix="/download", tags=["download"])
analyser = RSSAnalyser()
@router.post("/analysis")
@router.post("/analysis", response_model=Bangumi)
async def analysis(rss: RSSItem, current_user=Depends(get_current_user)):
if not current_user:
raise UNAUTHORIZED
@@ -18,30 +18,48 @@ async def analysis(rss: RSSItem, current_user=Depends(get_current_user)):
if data:
return data
else:
return {"status": "Failed to parse link"}
return JSONResponse(
status_code=406,
content={"msg_en": "Analysis failed.", "msg_zh": "解析失败。"},
)
@router.post("/collection")
@router.post("/collection", response_model=APIResponse)
async def download_collection(data: Bangumi, current_user=Depends(get_current_user)):
if not current_user:
raise UNAUTHORIZED
if data:
with SeasonCollector() as collector:
if collector.collect_season(data, data.rss_link[0]):
return {"status": "Success"}
return JSONResponse(
status_code=200,
content={"msg_en": "Add torrent successfully.", "msg_zh": "添加种子成功。"},
)
else:
return {"status": "Failed to add torrent"}
return JSONResponse(
status_code=406,
content={"msg_en": "Add torrent failed.", "msg_zh": "添加种子失败。"},
)
else:
return {"status": "Failed to parse link"}
return JSONResponse(
status_code=406,
content={"msg_en": "Add torrent failed.", "msg_zh": "添加种子失败。"},
)
@router.post("/subscribe")
@router.post("/subscribe", response_model=APIResponse)
async def subscribe(data: Bangumi, current_user=Depends(get_current_user)):
if not current_user:
raise UNAUTHORIZED
if data:
with SeasonCollector() as collector:
collector.subscribe_season(data)
return {"status": "Success"}
return JSONResponse(
status_code=200,
content={"msg_en": "Subscribe successfully.", "msg_zh": "订阅成功。"},
)
else:
return {"status": "Failed to parse link"}
return JSONResponse(
status_code=406,
content={"msg_en": "Subscribe failed.", "msg_zh": "订阅失败。"},
)

View File

@@ -1,9 +1,9 @@
import os
from fastapi import APIRouter, Depends, HTTPException, Response, status
from fastapi.responses import JSONResponse
from module.conf import LOG_PATH
from module.security.api import get_current_user, UNAUTHORIZED
from module.models import APIResponse
router = APIRouter(prefix="/log", tags=["log"])
@@ -19,12 +19,18 @@ async def get_log(current_user=Depends(get_current_user)):
return Response("Log file not found", status_code=404)
@router.get("/clear")
@router.get("/clear", response_model=APIResponse)
async def clear_log(current_user=Depends(get_current_user)):
if not current_user:
raise UNAUTHORIZED
if LOG_PATH.exists():
LOG_PATH.write_text("")
return {"status": "ok"}
return JSONResponse(
status_code=200,
content={"msg_en": "Log cleared successfully.", "msg_zh": "日志清除成功。"},
)
else:
return Response("Log file not found", status_code=404)
return JSONResponse(
status_code=406,
content={"msg_en": "Log file not found.", "msg_zh": "日志文件未找到。"},
)

View File

@@ -3,6 +3,7 @@ import os
import signal
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import JSONResponse
from module.core import Program
from module.conf import VERSION
@@ -36,7 +37,7 @@ async def restart(current_user=Depends(get_current_user)):
raise HTTPException(status_code=500, detail="Failed to restart program")
@router.get("/start")
@router.get("/start", response_model=JSONResponse)
async def start(current_user=Depends(get_current_user)):
if not current_user:
raise UNAUTHORIZED

View File

@@ -7,7 +7,6 @@ def u_response(response_model: ResponseModel):
return JSONResponse(
status_code=response_model.status_code,
content={
"status": response_model.status,
"msg_en": response_model.msg_en,
"msg_zh": response_model.msg_zh,
},

View File

@@ -1,10 +1,9 @@
from typing import Optional
from fastapi import APIRouter, Depends
from fastapi.responses import JSONResponse
from .response import u_response
from module.models import RSSItem, RSSUpdate
from module.models import RSSItem, RSSUpdate, Torrent
from module.rss import RSSEngine
from module.security.api import get_current_user, UNAUTHORIZED
from module.downloader import DownloadClient
@@ -21,7 +20,7 @@ async def get_rss(current_user=Depends(get_current_user)):
return engine.rss.search_all()
@router.post("/add")
@router.post("/add", response_model=JSONResponse)
async def add_rss(rss: RSSItem, current_user=Depends(get_current_user)):
if not current_user:
raise UNAUTHORIZED
@@ -30,15 +29,25 @@ async def add_rss(rss: RSSItem, current_user=Depends(get_current_user)):
return u_response(result)
@router.delete("/delete/{rss_id}")
@router.delete("/delete/{rss_id}", response_model=JSONResponse)
async def delete_rss(rss_id: int, current_user=Depends(get_current_user)):
if not current_user:
raise UNAUTHORIZED
with RSSEngine() as engine:
result = engine.rss.delete(rss_id)
if result:
return JSONResponse(
status_code=200,
content={"msg_en": "Delete RSS successfully.", "msg_zh": "删除 RSS 成功。"},
)
else:
return JSONResponse(
status_code=406,
content={"msg_en": "Delete RSS failed.", "msg_zh": "删除 RSS 失败。"},
)
@router.patch("/update/{rss_id}")
@router.patch("/update/{rss_id}", response_model=JSONResponse)
async def update_rss(
rss_id: int, data: RSSUpdate, current_user=Depends(get_current_user)
):
@@ -46,25 +55,43 @@ async def update_rss(
raise UNAUTHORIZED
with RSSEngine() as engine:
result = engine.rss.update(rss_id, data)
if result:
return JSONResponse(
status_code=200,
content={"msg_en": "Update RSS successfully.", "msg_zh": "更新 RSS 成功。"},
)
else:
return JSONResponse(
status_code=406,
content={"msg_en": "Update RSS failed.", "msg_zh": "更新 RSS 失败。"},
)
@router.get("/refresh/all")
@router.get("/refresh/all", response_model=JSONResponse)
async def refresh_all(current_user=Depends(get_current_user)):
if not current_user:
raise UNAUTHORIZED
with RSSEngine() as engine, DownloadClient() as client:
response = engine.refresh_rss(client)
engine.refresh_rss(client)
return JSONResponse(
status_code=200,
content={"msg_en": "Refresh all RSS successfully.", "msg_zh": "刷新 RSS 成功。"},
)
@router.get("/refresh/{rss_id}")
@router.get("/refresh/{rss_id}", response_model=JSONResponse)
async def refresh_rss(rss_id: int, current_user=Depends(get_current_user)):
if not current_user:
raise UNAUTHORIZED
with RSSEngine() as engine, DownloadClient() as client:
response = engine.refresh_rss(client, rss_id)
engine.refresh_rss(client, rss_id)
return JSONResponse(
status_code=200,
content={"msg_en": "Refresh RSS successfully.", "msg_zh": "刷新 RSS 成功。"},
)
@router.get("/torrent/{rss_id}")
@router.get("/torrent/{rss_id}", response_model=list[Torrent])
async def get_torrent(rss_id: int, current_user=Depends(get_current_user)):
if not current_user:
raise UNAUTHORIZED

View File

@@ -3,12 +3,13 @@ from fastapi.responses import StreamingResponse
from module.searcher import SearchTorrent
from module.security.api import get_current_user, UNAUTHORIZED
from module.models import Torrent
router = APIRouter(prefix="/search", tags=["search"])
@router.get("/")
@router.get("/", response_model=list[Torrent])
async def search_torrents(
site: str = "mikan",
keywords: str = Query(None),

View File

@@ -58,10 +58,15 @@ class RSSDatabase:
select(RSSItem).where(and_(RSSItem.aggregate, RSSItem.enabled))
).all()
def delete(self, _id: int):
def delete(self, _id: int) -> bool:
condition = delete(RSSItem).where(RSSItem.id == _id)
self.session.exec(condition)
self.session.commit()
try:
self.session.exec(condition)
self.session.commit()
return True
except Exception as e:
logger.error("Delete RSS Item failed.")
return False
def delete_all(self):
condition = delete(RSSItem)

View File

@@ -72,12 +72,7 @@ class TorrentManager(Database):
self.bangumi.update(data)
if file:
torrent_message = self.delete_torrents(data, client)
return JSONResponse(
status_code=200,
content={
"msg": f"Disable {data.official_title} rule. {torrent_message}"
},
)
return torrent_message
logger.info(f"[Manager] Disable rule for {data.official_title}")
return ResponseModel(
status_code=200,
@@ -148,6 +143,9 @@ class TorrentManager(Database):
data = self.bangumi.search_id(int(_id))
if not data:
logger.error(f"[Manager] Can't find data with {_id}")
return {"status": "error", "msg": f"Can't find data with {_id}"}
return JSONResponse(
status_code=406,
content={"msg_en": f"Can't find data with {_id}", "msg_zh": f"无法找到 id {_id} 的数据"},
)
else:
return data

View File

@@ -3,4 +3,4 @@ from .config import Config
from .rss import RSSItem, RSSUpdate
from .torrent import EpisodeFile, SubtitleFile, Torrent, TorrentUpdate
from .user import UserLogin, User, UserUpdate
from .response import ResponseModel
from .response import ResponseModel, APIResponse

View File

@@ -6,3 +6,8 @@ class ResponseModel(BaseModel):
status_code: int = Field(..., example=200)
msg_en: str
msg_zh: str
class APIResponse(BaseModel):
msg_en: str = Field(..., example="Success")
msg_zh: str = Field(..., example="成功")