fix: api bugs, collect bug.

This commit is contained in:
EstrellaXD
2023-08-10 19:59:01 +08:00
parent e503196146
commit 0b053c9312
6 changed files with 61 additions and 22 deletions

View File

@@ -7,6 +7,7 @@ from .download import router as download_router
from .log import router as log_router
from .program import router as program_router
from .rss import router as rss_router
from .search import router as search_router
__all__ = "v1"
@@ -19,3 +20,4 @@ v1.include_router(download_router)
v1.include_router(bangumi_router)
v1.include_router(config_router)
v1.include_router(rss_router)
v1.include_router(search_router)

View File

@@ -15,8 +15,7 @@ async def get_all_data(current_user=Depends(get_current_user)):
if not current_user:
raise UNAUTHORIZED
with TorrentManager() as manager:
resp = manager.bangumi.search_all()
return u_response(resp)
return manager.bangumi.search_all()
@router.get("/get/{bangumi_id}", response_model=Bangumi)
@@ -24,8 +23,7 @@ async def get_data(bangumi_id: str, current_user=Depends(get_current_user)):
if not current_user:
raise UNAUTHORIZED
with TorrentManager() as manager:
resp = manager.search_one(bangumi_id)
return u_response(resp)
return manager.search_one(bangumi_id)
@router.patch("/update/{bangumi_id}")

View File

@@ -6,18 +6,18 @@ from module.conf import settings
from module.models import Config
from module.security.api import get_current_user, UNAUTHORIZED
router = APIRouter(tags=["config"])
router = APIRouter(prefix="/config", tags=["config"])
logger = logging.getLogger(__name__)
@router.get("/getConfig", response_model=Config)
@router.get("/get", response_model=Config)
async def get_config(current_user=Depends(get_current_user)):
if not current_user:
raise UNAUTHORIZED
return settings.dict()
@router.post("/updateConfig")
@router.patch("/update")
async def update_config(config: Config, current_user=Depends(get_current_user)):
if not current_user:
raise UNAUTHORIZED

View File

@@ -0,0 +1,26 @@
from fastapi import APIRouter, Query, Depends
from fastapi.responses import StreamingResponse
from module.searcher import SearchTorrent
from module.security.api import get_current_user, UNAUTHORIZED
router = APIRouter(prefix="/search", tags=["search"])
@router.get("/")
async def search_torrents(
site: str = "mikan",
keywords: str = Query(None),
current_user=Depends(get_current_user),
):
if not current_user:
raise UNAUTHORIZED
if not keywords:
return []
keywords = keywords.split(" ")
with SearchTorrent() as st:
return StreamingResponse(
content=st.analyse_keyword(keywords=keywords, site=site),
media_type="application/json",
)

View File

@@ -15,6 +15,7 @@ class RequestContent(RequestURL):
self,
_url: str,
_filter: str = "|".join(settings.rss_parser.filter),
limit: int = 100,
retry: int = 3,
) -> list[Torrent]:
try:
@@ -28,6 +29,8 @@ class RequestContent(RequestURL):
torrents.append(
Torrent(name=_title, url=torrent_url, homepage=homepage)
)
if len(torrents) >= limit:
break
return torrents
except ConnectionError:
return []

View File

@@ -1,5 +1,9 @@
import json
from module.models import Bangumi, Torrent
from module.network import RequestContent
from module.rss import RSSAnalyser
from module.searcher.plugin import search_url
SEARCH_KEY = [
@@ -12,23 +16,27 @@ SEARCH_KEY = [
]
class SearchTorrent(RequestContent):
class SearchTorrent(RequestContent, RSSAnalyser):
def search_torrents(
self, keywords: list[str], site: str = "mikan"
self, keywords: list[str], site: str = "mikan", limit: int = 5
) -> list[Torrent]:
url = search_url(site, keywords)
# TorrentInfo to TorrentBase
torrents = self.get_torrents(url)
torrents = self.get_torrents(url, limit=limit)
return torrents
def to_dict():
for torrent in torrents:
yield {
"name": torrent.name,
"torrent_link": torrent.url,
"homepage": torrent.homepage,
}
return [Torrent(**d) for d in to_dict()]
def analyse_keyword(self, keywords: list[str], site: str = "mikan"):
bangumis = []
torrents = self.search_torrents(keywords, site)
# Generate a list of json
yield "["
for idx, torrent in enumerate(torrents):
bangumi = self.torrent_to_data(torrent)
if bangumi:
yield json.dumps(bangumi.dict())
if idx != len(torrents) - 1:
yield ","
yield "]"
# Analyse bangumis
def search_season(self, data: Bangumi):
keywords = [getattr(data, key) for key in SEARCH_KEY if getattr(data, key)]
@@ -38,5 +46,7 @@ class SearchTorrent(RequestContent):
if __name__ == "__main__":
with SearchTorrent() as st:
for t in st.search_torrents(["魔法科高校の劣等生"]):
print(t)
keywords = ["无职转生", "第二季"]
bangumis = st.analyse_keyword(keywords)
for bangumi in bangumis:
print(bangumi)