Merge pull request #502 from EstrellaXD/3.1-dev

3.1.4
This commit is contained in:
Estrella Pan
2023-10-03 22:56:29 +08:00
committed by GitHub
58 changed files with 665 additions and 205 deletions

9
.gitignore vendored
View File

@@ -182,7 +182,6 @@ test.*
# webui
logs
*.log
npm-debug.log*
yarn-debug.log*
yarn-error.log*
@@ -202,8 +201,6 @@ dev-dist
!.vscode/tasks.json
!.vscode/launch.json
!.vscode/extensions.json
.idea
.DS_Store
*.suo
*.ntvs*
@@ -211,5 +208,11 @@ dev-dist
*.sln
*.sw?
# vitepress
/docs/.vitepress/cache/
# test file
test.*
test_*

View File

@@ -20,6 +20,7 @@ RUN set -ex && \
su-exec \
shadow \
tini \
openssl \
tzdata && \
python3 -m pip install --no-cache-dir --upgrade pip && \
sed -i '/bcrypt/d' requirements.txt && \

View File

@@ -4,5 +4,8 @@ repos:
hooks:
- id: black
language: python
# TODO: add ruff lint check before committing.
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.0.291
hooks:
- id: ruff

View File

@@ -1,4 +1,5 @@
-r requirements.txt
ruff
black
pre-commit
pre-commit
pytest

View File

@@ -26,3 +26,4 @@ python-multipart==0.0.6
sqlmodel==0.0.8
sse-starlette==1.6.5
semver==3.0.1
openai==0.28.1

View File

@@ -1,5 +1,5 @@
import os
import logging
import os
import uvicorn
from fastapi import FastAPI, Request
@@ -8,8 +8,7 @@ from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from module.api import v1
from module.api.proxy import router as proxy_router
from module.conf import settings, setup_logger, VERSION
from starlette.types import ASGIApp
from module.conf import VERSION, settings, setup_logger
setup_logger(reset=True)
logger = logging.getLogger(__name__)

View File

@@ -1,21 +1,21 @@
from datetime import timedelta
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordRequestForm
from fastapi.responses import JSONResponse, Response
from fastapi.security import OAuth2PasswordRequestForm
from .response import u_response
from module.models.user import User, UserUpdate
from module.models import APIResponse
from module.models.user import User, UserUpdate
from module.security.api import (
active_user,
auth_user,
get_current_user,
update_user_info,
active_user
)
from module.security.jwt import create_access_token
from .response import u_response
router = APIRouter(prefix="/auth", tags=["auth"])
@@ -31,7 +31,10 @@ async def login(response: Response, form_data=Depends(OAuth2PasswordRequestForm)
return {"access_token": token, "token_type": "bearer"}
return u_response(resp)
@router.get("/refresh_token", response_model=dict, dependencies=[Depends(get_current_user)])
@router.get(
"/refresh_token", response_model=dict, dependencies=[Depends(get_current_user)]
)
async def refresh(response: Response):
token = create_access_token(
data={"sub": active_user[0]}, expires_delta=timedelta(days=1)
@@ -40,7 +43,9 @@ async def refresh(response: Response):
return {"access_token": token, "token_type": "bearer"}
@router.get("/logout", response_model=APIResponse, dependencies=[Depends(get_current_user)])
@router.get(
"/logout", response_model=APIResponse, dependencies=[Depends(get_current_user)]
)
async def logout(response: Response):
active_user.clear()
response.delete_cookie(key="token")
@@ -51,16 +56,20 @@ async def logout(response: Response):
@router.post("/update", response_model=dict, dependencies=[Depends(get_current_user)])
async def update_user(
user_data: UserUpdate, response: Response
):
async def update_user(user_data: UserUpdate, response: Response):
old_user = active_user[0]
if update_user_info(user_data, old_user):
token = create_access_token(data={"sub": old_user}, expires_delta=timedelta(days=1))
token = create_access_token(
data={"sub": old_user}, expires_delta=timedelta(days=1)
)
response.set_cookie(
key="token",
value=token,
httponly=True,
max_age=86400,
)
return {"access_token": token, "token_type": "bearer", "message": "update success"}
return {
"access_token": token,
"token_type": "bearer",
"message": "update success",
}

View File

@@ -1,11 +1,11 @@
from fastapi import APIRouter, Depends
from fastapi.responses import JSONResponse
from .response import u_response
from module.manager import TorrentManager
from module.models import Bangumi, BangumiUpdate, APIResponse
from module.security.api import get_current_user, UNAUTHORIZED
from module.models import APIResponse, Bangumi, BangumiUpdate
from module.security.api import UNAUTHORIZED, get_current_user
from .response import u_response
router = APIRouter(prefix="/bangumi", tags=["bangumi"])
@@ -16,36 +16,55 @@ def str_to_list(data: Bangumi):
return data
@router.get("/get/all", response_model=list[Bangumi], dependencies=[Depends(get_current_user)])
@router.get(
"/get/all", response_model=list[Bangumi], dependencies=[Depends(get_current_user)]
)
async def get_all_data():
with TorrentManager() as manager:
return manager.bangumi.search_all()
@router.get("/get/{bangumi_id}", response_model=Bangumi, dependencies=[Depends(get_current_user)])
@router.get(
"/get/{bangumi_id}",
response_model=Bangumi,
dependencies=[Depends(get_current_user)],
)
async def get_data(bangumi_id: str):
with TorrentManager() as manager:
resp = manager.search_one(bangumi_id)
return resp
@router.patch("/update/{bangumi_id}", response_model=APIResponse, dependencies=[Depends(get_current_user)])
@router.patch(
"/update/{bangumi_id}",
response_model=APIResponse,
dependencies=[Depends(get_current_user)],
)
async def update_rule(
bangumi_id: int, data: BangumiUpdate,
bangumi_id: int,
data: BangumiUpdate,
):
with TorrentManager() as manager:
resp = manager.update_rule(bangumi_id, data)
return u_response(resp)
@router.delete(path="/delete/{bangumi_id}", response_model=APIResponse, dependencies=[Depends(get_current_user)])
@router.delete(
path="/delete/{bangumi_id}",
response_model=APIResponse,
dependencies=[Depends(get_current_user)],
)
async def delete_rule(bangumi_id: str, file: bool = False):
with TorrentManager() as manager:
resp = manager.delete_rule(bangumi_id, file)
return u_response(resp)
@router.delete(path="/delete/many/", response_model=APIResponse, dependencies=[Depends(get_current_user)])
@router.delete(
path="/delete/many/",
response_model=APIResponse,
dependencies=[Depends(get_current_user)],
)
async def delete_many_rule(bangumi_id: list, file: bool = False):
with TorrentManager() as manager:
for i in bangumi_id:
@@ -53,14 +72,22 @@ async def delete_many_rule(bangumi_id: list, file: bool = False):
return u_response(resp)
@router.delete(path="/disable/{bangumi_id}", response_model=APIResponse, dependencies=[Depends(get_current_user)])
@router.delete(
path="/disable/{bangumi_id}",
response_model=APIResponse,
dependencies=[Depends(get_current_user)],
)
async def disable_rule(bangumi_id: str, file: bool = False):
with TorrentManager() as manager:
resp = manager.disable_rule(bangumi_id, file)
return u_response(resp)
@router.delete(path="/disable/many/", response_model=APIResponse, dependencies=[Depends(get_current_user)])
@router.delete(
path="/disable/many/",
response_model=APIResponse,
dependencies=[Depends(get_current_user)],
)
async def disable_many_rule(bangumi_id: list, file: bool = False):
with TorrentManager() as manager:
for i in bangumi_id:
@@ -68,21 +95,31 @@ async def disable_many_rule(bangumi_id: list, file: bool = False):
return u_response(resp)
@router.get(path="/enable/{bangumi_id}", response_model=APIResponse, dependencies=[Depends(get_current_user)])
@router.get(
path="/enable/{bangumi_id}",
response_model=APIResponse,
dependencies=[Depends(get_current_user)],
)
async def enable_rule(bangumi_id: str):
with TorrentManager() as manager:
resp = manager.enable_rule(bangumi_id)
return u_response(resp)
@router.get(path="/refresh/poster/all", response_model=APIResponse, dependencies=[Depends(get_current_user)])
@router.get(
path="/refresh/poster/all",
response_model=APIResponse,
dependencies=[Depends(get_current_user)],
)
async def refresh_poster():
with TorrentManager() as manager:
resp = manager.refresh_poster()
return u_response(resp)
@router.get("/reset/all", response_model=APIResponse, dependencies=[Depends(get_current_user)])
@router.get(
"/reset/all", response_model=APIResponse, dependencies=[Depends(get_current_user)]
)
async def reset_all():
with TorrentManager() as manager:
manager.bangumi.delete_all()

View File

@@ -4,8 +4,8 @@ from fastapi import APIRouter, Depends
from fastapi.responses import JSONResponse
from module.conf import settings
from module.models import Config, APIResponse
from module.security.api import get_current_user, UNAUTHORIZED
from module.models import APIResponse, Config
from module.security.api import UNAUTHORIZED, get_current_user
router = APIRouter(prefix="/config", tags=["config"])
logger = logging.getLogger(__name__)
@@ -16,7 +16,9 @@ async def get_config():
return settings
@router.patch("/update", response_model=APIResponse, dependencies=[Depends(get_current_user)])
@router.patch(
"/update", response_model=APIResponse, dependencies=[Depends(get_current_user)]
)
async def update_config(config: Config):
try:
settings.save(config_dict=config.dict())
@@ -25,11 +27,11 @@ async def update_config(config: Config):
logger.info("Config updated")
return JSONResponse(
status_code=200,
content={"msg_en": "Update config successfully.", "msg_zh": "更新配置成功。"}
content={"msg_en": "Update config successfully.", "msg_zh": "更新配置成功。"},
)
except Exception as e:
logger.warning(e)
return JSONResponse(
status_code=406,
content={"msg_en": "Update config failed.", "msg_zh": "更新配置失败。"}
content={"msg_en": "Update config failed.", "msg_zh": "更新配置失败。"},
)

View File

@@ -2,8 +2,8 @@ from fastapi import APIRouter, Depends, HTTPException, Response, status
from fastapi.responses import JSONResponse
from module.conf import LOG_PATH
from module.security.api import get_current_user, UNAUTHORIZED
from module.models import APIResponse
from module.security.api import UNAUTHORIZED, get_current_user
router = APIRouter(prefix="/log", tags=["log"])
@@ -17,7 +17,9 @@ async def get_log():
return Response("Log file not found", status_code=404)
@router.get("/clear", response_model=APIResponse, dependencies=[Depends(get_current_user)])
@router.get(
"/clear", response_model=APIResponse, dependencies=[Depends(get_current_user)]
)
async def clear_log():
if LOG_PATH.exists():
LOG_PATH.write_text("")

View File

@@ -5,12 +5,12 @@ import signal
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import JSONResponse
from .response import u_response
from module.conf import VERSION
from module.core import Program
from module.models import APIResponse
from module.conf import VERSION
from module.security.api import get_current_user, UNAUTHORIZED
from module.security.api import UNAUTHORIZED, get_current_user
from .response import u_response
logger = logging.getLogger(__name__)
program = Program()
@@ -27,7 +27,9 @@ async def shutdown():
program.stop()
@router.get("/restart", response_model=APIResponse, dependencies=[Depends(get_current_user)])
@router.get(
"/restart", response_model=APIResponse, dependencies=[Depends(get_current_user)]
)
async def restart():
try:
resp = program.restart()
@@ -40,11 +42,13 @@ async def restart():
detail={
"msg_en": "Failed to restart program.",
"msg_zh": "重启程序失败。",
}
},
)
@router.get("/start", response_model=APIResponse, dependencies=[Depends(get_current_user)])
@router.get(
"/start", response_model=APIResponse, dependencies=[Depends(get_current_user)]
)
async def start():
try:
resp = program.start()
@@ -57,11 +61,13 @@ async def start():
detail={
"msg_en": "Failed to start program.",
"msg_zh": "启动程序失败。",
}
},
)
@router.get("/stop", response_model=APIResponse, dependencies=[Depends(get_current_user)])
@router.get(
"/stop", response_model=APIResponse, dependencies=[Depends(get_current_user)]
)
async def stop():
return u_response(program.stop())
@@ -82,7 +88,9 @@ async def program_status():
}
@router.get("/shutdown", response_model=APIResponse, dependencies=[Depends(get_current_user)])
@router.get(
"/shutdown", response_model=APIResponse, dependencies=[Depends(get_current_user)]
)
async def shutdown_program():
program.stop()
logger.info("Shutting down program...")
@@ -94,6 +102,11 @@ async def shutdown_program():
# Check status
@router.get("/check/downloader", tags=["check"], response_model=bool, dependencies=[Depends(get_current_user)])
@router.get(
"/check/downloader",
tags=["check"],
response_model=bool,
dependencies=[Depends(get_current_user)],
)
async def check_downloader_status():
return program.check_downloader()

View File

@@ -1,5 +1,5 @@
from fastapi.responses import JSONResponse
from fastapi.exceptions import HTTPException
from fastapi.responses import JSONResponse
from module.models.response import ResponseModel
@@ -11,4 +11,4 @@ def u_response(response_model: ResponseModel):
"msg_en": response_model.msg_en,
"msg_zh": response_model.msg_zh,
},
)
)

View File

@@ -1,39 +1,52 @@
from fastapi import APIRouter, Depends
from fastapi.responses import JSONResponse
from .response import u_response
from module.models import RSSItem, RSSUpdate, Torrent, APIResponse, Bangumi
from module.rss import RSSEngine, RSSAnalyser
from module.security.api import get_current_user, UNAUTHORIZED
from module.downloader import DownloadClient
from module.manager import SeasonCollector
from module.models import APIResponse, Bangumi, RSSItem, RSSUpdate, Torrent
from module.rss import RSSAnalyser, RSSEngine
from module.security.api import UNAUTHORIZED, get_current_user
from .response import u_response
router = APIRouter(prefix="/rss", tags=["rss"])
@router.get(path="", response_model=list[RSSItem], dependencies=[Depends(get_current_user)])
@router.get(
path="", response_model=list[RSSItem], dependencies=[Depends(get_current_user)]
)
async def get_rss():
with RSSEngine() as engine:
return engine.rss.search_all()
@router.post(path="/add", response_model=APIResponse, dependencies=[Depends(get_current_user)])
@router.post(
path="/add", response_model=APIResponse, dependencies=[Depends(get_current_user)]
)
async def add_rss(rss: RSSItem):
with RSSEngine() as engine:
result = engine.add_rss(rss.url, rss.name, rss.aggregate, rss.parser)
return u_response(result)
@router.post(path="/enable/many", response_model=APIResponse, dependencies=[Depends(get_current_user)])
async def enable_many_rss(rss_ids: list[int], ):
@router.post(
path="/enable/many",
response_model=APIResponse,
dependencies=[Depends(get_current_user)],
)
async def enable_many_rss(
rss_ids: list[int],
):
with RSSEngine() as engine:
result = engine.enable_list(rss_ids)
return u_response(result)
@router.delete(path="/delete/{rss_id}", response_model=APIResponse, dependencies=[Depends(get_current_user)])
@router.delete(
path="/delete/{rss_id}",
response_model=APIResponse,
dependencies=[Depends(get_current_user)],
)
async def delete_rss(rss_id: int):
with RSSEngine() as engine:
if engine.rss.delete(rss_id):
@@ -48,14 +61,24 @@ async def delete_rss(rss_id: int):
)
@router.post(path="/delete/many", response_model=APIResponse, dependencies=[Depends(get_current_user)])
async def delete_many_rss(rss_ids: list[int], ):
@router.post(
path="/delete/many",
response_model=APIResponse,
dependencies=[Depends(get_current_user)],
)
async def delete_many_rss(
rss_ids: list[int],
):
with RSSEngine() as engine:
result = engine.delete_list(rss_ids)
return u_response(result)
@router.patch(path="/disable/{rss_id}", response_model=APIResponse, dependencies=[Depends(get_current_user)])
@router.patch(
path="/disable/{rss_id}",
response_model=APIResponse,
dependencies=[Depends(get_current_user)],
)
async def disable_rss(rss_id: int):
with RSSEngine() as engine:
if engine.rss.disable(rss_id):
@@ -70,14 +93,22 @@ async def disable_rss(rss_id: int):
)
@router.post(path="/disable/many", response_model=APIResponse, dependencies=[Depends(get_current_user)])
@router.post(
path="/disable/many",
response_model=APIResponse,
dependencies=[Depends(get_current_user)],
)
async def disable_many_rss(rss_ids: list[int]):
with RSSEngine() as engine:
result = engine.disable_list(rss_ids)
return u_response(result)
@router.patch(path="/update/{rss_id}", response_model=APIResponse, dependencies=[Depends(get_current_user)])
@router.patch(
path="/update/{rss_id}",
response_model=APIResponse,
dependencies=[Depends(get_current_user)],
)
async def update_rss(
rss_id: int, data: RSSUpdate, current_user=Depends(get_current_user)
):
@@ -96,7 +127,11 @@ async def update_rss(
)
@router.get(path="/refresh/all", response_model=APIResponse, dependencies=[Depends(get_current_user)])
@router.get(
path="/refresh/all",
response_model=APIResponse,
dependencies=[Depends(get_current_user)],
)
async def refresh_all():
with RSSEngine() as engine, DownloadClient() as client:
engine.refresh_rss(client)
@@ -106,7 +141,11 @@ async def refresh_all():
)
@router.get(path="/refresh/{rss_id}", response_model=APIResponse, dependencies=[Depends(get_current_user)])
@router.get(
path="/refresh/{rss_id}",
response_model=APIResponse,
dependencies=[Depends(get_current_user)],
)
async def refresh_rss(rss_id: int):
with RSSEngine() as engine, DownloadClient() as client:
engine.refresh_rss(client, rss_id)
@@ -116,8 +155,14 @@ async def refresh_rss(rss_id: int):
)
@router.get(path="/torrent/{rss_id}", response_model=list[Torrent], dependencies=[Depends(get_current_user)])
async def get_torrent(rss_id: int, ):
@router.get(
path="/torrent/{rss_id}",
response_model=list[Torrent],
dependencies=[Depends(get_current_user)],
)
async def get_torrent(
rss_id: int,
):
with RSSEngine() as engine:
return engine.get_rss_torrents(rss_id)
@@ -126,7 +171,9 @@ async def get_torrent(rss_id: int, ):
analyser = RSSAnalyser()
@router.post("/analysis", response_model=Bangumi, dependencies=[Depends(get_current_user)])
@router.post(
"/analysis", response_model=Bangumi, dependencies=[Depends(get_current_user)]
)
async def analysis(rss: RSSItem):
data = analyser.link_to_data(rss)
if isinstance(data, Bangumi):
@@ -135,16 +182,19 @@ async def analysis(rss: RSSItem):
return u_response(data)
@router.post("/collect", response_model=APIResponse, dependencies=[Depends(get_current_user)])
@router.post(
"/collect", response_model=APIResponse, dependencies=[Depends(get_current_user)]
)
async def download_collection(data: Bangumi):
with SeasonCollector() as collector:
resp = collector.collect_season(data, data.rss_link)
return u_response(resp)
@router.post("/subscribe", response_model=APIResponse, dependencies=[Depends(get_current_user)])
@router.post(
"/subscribe", response_model=APIResponse, dependencies=[Depends(get_current_user)]
)
async def subscribe(data: Bangumi):
with SeasonCollector() as collector:
resp = collector.subscribe_season(data)
return u_response(resp)

View File

@@ -1,19 +1,17 @@
from fastapi import APIRouter, Query, Depends
from fastapi import APIRouter, Depends, Query
from sse_starlette.sse import EventSourceResponse
from module.searcher import SearchTorrent, SEARCH_CONFIG
from module.security.api import get_current_user, UNAUTHORIZED
from module.models import Bangumi
from module.searcher import SEARCH_CONFIG, SearchTorrent
from module.security.api import UNAUTHORIZED, get_current_user
router = APIRouter(prefix="/search", tags=["search"])
@router.get("/bangumi", response_model=list[Bangumi], dependencies=[Depends(get_current_user)])
async def search_torrents(
site: str = "mikan",
keywords: str = Query(None)
):
@router.get(
"/bangumi", response_model=list[Bangumi], dependencies=[Depends(get_current_user)]
)
async def search_torrents(site: str = "mikan", keywords: str = Query(None)):
"""
Server Send Event for per Bangumi item
"""
@@ -26,6 +24,8 @@ async def search_torrents(
)
@router.get("/provider", response_model=list[str], dependencies=[Depends(get_current_user)])
@router.get(
"/provider", response_model=list[str], dependencies=[Depends(get_current_user)]
)
async def search_provider():
return list(SEARCH_CONFIG.keys())

View File

@@ -1,8 +1,9 @@
import logging
import requests
from pathlib import Path
from module.conf import settings, VERSION
import requests
from module.conf import VERSION, settings
from module.downloader import DownloadClient
from module.models import Config
from module.update import version_check
@@ -50,7 +51,11 @@ class Checker:
@staticmethod
def check_downloader() -> bool:
try:
url = f"http://{settings.downloader.host}" if "://" not in settings.downloader.host else f"{settings.downloader.host}"
url = (
f"http://{settings.downloader.host}"
if "://" not in settings.downloader.host
else f"{settings.downloader.host}"
)
response = requests.get(url, timeout=2)
if settings.downloader.type in response.text.lower():
with DownloadClient() as client:
@@ -74,4 +79,3 @@ class Checker:
if __name__ == "__main__":
# print(Checker().check_downloader())
requests.get("http://162.200.20.1", timeout=2)

View File

@@ -1,10 +1,11 @@
from pathlib import Path
from module.utils import json_config
DEFAULT_PROVIDER = {
"mikan": "https://mikanani.me/RSS/Search?searchstr=%s",
"nyaa": "https://nyaa.si/?page=rss&q=%s&c=0_0&f=0",
"dmhy": "http://dmhy.org/topics/rss/rss.xml?keyword=%s"
"mikan": "https://mikanani.me/RSS/Search?searchstr=%s",
"nyaa": "https://nyaa.si/?page=rss&q=%s&c=0_0&f=0",
"dmhy": "http://dmhy.org/topics/rss/rss.xml?keyword=%s",
}
PROVIDER_PATH = Path("config/search_provider.json")
@@ -19,5 +20,3 @@ def load_provider():
SEARCH_CONFIG = load_provider()

View File

@@ -1,8 +1,8 @@
import logging
from module.conf import VERSION, settings
from module.update import data_migration, from_30_to_31, start_up, first_run
from module.models import ResponseModel
from module.update import data_migration, first_run, from_30_to_31, start_up
from .sub_thread import RenameThread, RSSThread

View File

@@ -1,9 +1,9 @@
import logging
from sqlmodel import Session, select, delete, or_, and_
from sqlalchemy.sql import func
from typing import Optional
from sqlalchemy.sql import func
from sqlmodel import Session, and_, delete, false, or_, select
from module.models import Bangumi, BangumiUpdate
logger = logging.getLogger(__name__)
@@ -14,9 +14,14 @@ class BangumiDatabase:
self.session = session
def add(self, data: Bangumi):
statement = select(Bangumi).where(Bangumi.title_raw == data.title_raw)
bangumi = self.session.exec(statement).first()
if bangumi:
return False
self.session.add(data)
self.session.commit()
logger.debug(f"[Database] Insert {data.official_title} into database.")
return True
def add_all(self, datas: list[Bangumi]):
self.session.add_all(datas)
@@ -128,14 +133,18 @@ class BangumiDatabase:
statement = select(Bangumi).where(
and_(
func.instr(torrent_name, Bangumi.title_raw) > 0,
Bangumi.deleted == False,
# use `false()` to avoid E712 checking
# see: https://docs.astral.sh/ruff/rules/true-false-comparison/
Bangumi.deleted == false(),
)
)
return self.session.exec(statement).first()
def not_complete(self) -> list[Bangumi]:
# Find eps_complete = False
condition = select(Bangumi).where(Bangumi.eps_collect == False)
# use `false()` to avoid E712 checking
# see: https://docs.astral.sh/ruff/rules/true-false-comparison/
condition = select(Bangumi).where(Bangumi.eps_collect == false())
datas = self.session.exec(condition).all()
return datas

View File

@@ -1,12 +1,12 @@
from sqlmodel import Session, SQLModel
from module.models import Bangumi, User
from .bangumi import BangumiDatabase
from .engine import engine as e
from .rss import RSSDatabase
from .torrent import TorrentDatabase
from .bangumi import BangumiDatabase
from .user import UserDatabase
from .engine import engine as e
from module.models import User, Bangumi
class Database(Session):
@@ -40,5 +40,3 @@ class Database(Session):
self.bangumi.add_all(readd_bangumi)
self.add(User(**user_data[0]))
self.commit()

View File

@@ -1,6 +1,6 @@
from sqlmodel import create_engine, Session
from module.conf import DATA_PATH
from sqlmodel import Session, create_engine
from module.conf import DATA_PATH
engine = create_engine(DATA_PATH)

View File

@@ -1,6 +1,6 @@
import logging
from sqlmodel import Session, select, delete, and_
from sqlmodel import Session, and_, delete, select
from module.models import RSSItem, RSSUpdate
@@ -66,7 +66,6 @@ class RSSDatabase:
self.session.refresh(db_data)
return True
def search_id(self, _id: int) -> RSSItem:
return self.session.get(RSSItem, _id)
@@ -88,7 +87,7 @@ class RSSDatabase:
self.session.commit()
return True
except Exception as e:
logger.error("Delete RSS Item failed.")
logger.error(f"Delete RSS Item failed. Because: {e}")
return False
def delete_all(self):

View File

@@ -1,12 +1,12 @@
import logging
from fastapi import HTTPException
from module.models.user import User, UserUpdate, UserLogin
from module.models import ResponseModel
from module.security.jwt import get_password_hash, verify_password
from sqlmodel import Session, select
from module.models import ResponseModel
from module.models.user import User, UserLogin, UserUpdate
from module.security.jwt import get_password_hash, verify_password
logger = logging.getLogger(__name__)
@@ -26,23 +26,17 @@ class UserDatabase:
result = self.session.exec(statement).first()
if not result:
return ResponseModel(
status_code=401,
status=False,
msg_en="User not found",
msg_zh="用户不存在"
status_code=401, status=False, msg_en="User not found", msg_zh="用户不存在"
)
if not verify_password(user.password, result.password):
return ResponseModel(
status_code=401,
status=False,
msg_en="Incorrect password",
msg_zh="密码错误"
msg_zh="密码错误",
)
return ResponseModel(
status_code=200,
status=True,
msg_en="Login successfully",
msg_zh="登录成功"
status_code=200, status=True, msg_en="Login successfully", msg_zh="登录成功"
)
def update_user(self, username, update_user: UserUpdate):
@@ -91,7 +85,7 @@ class UserDatabase:
statement = select(User)
try:
result = self.session.exec(statement).all()
except Exception as e:
except Exception:
self.merge_old_user()
result = self.session.exec(statement).all()
if len(result) != 0:

View File

@@ -1,8 +1,8 @@
import logging
from os import PathLike
import re
from os import PathLike
from module.conf import settings, PLATFORM
from module.conf import PLATFORM, settings
from module.models import Bangumi, BangumiUpdate
logger = logging.getLogger(__name__)
@@ -72,4 +72,3 @@ class TorrentPath:
@staticmethod
def _join_path(*args):
return str(Path(*args))

View File

@@ -2,8 +2,8 @@ import logging
from module.downloader import DownloadClient
from module.models import Bangumi, ResponseModel
from module.searcher import SearchTorrent
from module.rss import RSSEngine
from module.searcher import SearchTorrent
logger = logging.getLogger(__name__)
@@ -13,17 +13,21 @@ class SeasonCollector(DownloadClient):
logger.info(
f"Start collecting {bangumi.official_title} Season {bangumi.season}..."
)
with SearchTorrent() as st:
with SearchTorrent() as st, RSSEngine() as engine:
if not link:
torrents = st.search_season(bangumi)
else:
torrents = st.get_torrents(link, bangumi.filter.replace(",", "|"))
if self.add_torrent(torrents, bangumi):
logger.info(f"Collections of {bangumi.official_title} Season {bangumi.season} completed.")
logger.info(
f"Collections of {bangumi.official_title} Season {bangumi.season} completed."
)
for torrent in torrents:
torrent.downloaded = True
bangumi.eps_collect = True
with RSSEngine() as engine:
engine.bangumi.update(bangumi)
engine.torrent.add_all(torrents)
if engine.bangumi.update(bangumi):
engine.bangumi.add(bangumi)
engine.torrent.add_all(torrents)
return ResponseModel(
status=True,
status_code=200,
@@ -31,12 +35,14 @@ class SeasonCollector(DownloadClient):
msg_zh=f"收集 {bangumi.official_title}{bangumi.season} 季完成。",
)
else:
logger.warning(f"Collection of {bangumi.official_title} Season {bangumi.season} failed.")
logger.warning(
f"Already collected {bangumi.official_title} Season {bangumi.season}."
)
return ResponseModel(
status=False,
status_code=406,
msg_en=f"Collection of {bangumi.official_title} Season {bangumi.season} failed.",
msg_zh=f"收集 {bangumi.official_title}{bangumi.season} 季失败。",
msg_zh=f"收集 {bangumi.official_title}{bangumi.season} 季失败, 种子已经添加",
)
@staticmethod

View File

@@ -1,6 +1,5 @@
import logging
from module.database import Database
from module.downloader import DownloadClient
from module.models import Bangumi, BangumiUpdate, ResponseModel
@@ -117,6 +116,7 @@ class TorrentManager(Database):
path = client._gen_save_path(data)
if match_list:
client.move_torrent(match_list, path)
data.save_path = path
self.bangumi.update(data, bangumi_id)
return ResponseModel(
status_code=200,
@@ -142,8 +142,8 @@ class TorrentManager(Database):
return ResponseModel(
status_code=200,
status=True,
msg_en=f"Refresh poster link successfully.",
msg_zh=f"刷新海报链接成功。",
msg_en="Refresh poster link successfully.",
msg_zh="刷新海报链接成功。",
)
def search_all_bangumi(self):
@@ -165,6 +165,7 @@ class TorrentManager(Database):
else:
return data
if __name__ == '__main__':
if __name__ == "__main__":
with TorrentManager() as manager:
manager.refresh_poster()

View File

@@ -1,6 +1,6 @@
from .bangumi import Bangumi, Episode, BangumiUpdate, Notification
from .bangumi import Bangumi, BangumiUpdate, Episode, Notification
from .config import Config
from .response import APIResponse, ResponseModel
from .rss import RSSItem, RSSUpdate
from .torrent import EpisodeFile, SubtitleFile, Torrent, TorrentUpdate
from .user import UserLogin, User, UserUpdate
from .response import ResponseModel, APIResponse
from .user import User, UserLogin, UserUpdate

View File

@@ -1,8 +1,8 @@
from dataclasses import dataclass
from typing import Optional
from pydantic import BaseModel
from sqlmodel import SQLModel, Field
from typing import Optional
from sqlmodel import Field, SQLModel
class Bangumi(SQLModel, table=True):

View File

@@ -1,4 +1,5 @@
from os.path import expandvars
from pydantic import BaseModel, Field
@@ -81,6 +82,15 @@ class Notification(BaseModel):
return expandvars(self.chat_id_)
class ExperimentalOpenAI(BaseModel):
enable: bool = Field(False, description="Enable experimental OpenAI")
api_key: str = Field("", description="OpenAI api key")
api_base: str = Field(
"https://api.openai.com/v1", description="OpenAI api base url"
)
model: str = Field("gpt-3.5-turbo", description="OpenAI model")
class Config(BaseModel):
program: Program = Program()
downloader: Downloader = Downloader()
@@ -89,6 +99,7 @@ class Config(BaseModel):
log: Log = Log()
proxy: Proxy = Proxy()
notification: Notification = Notification()
experimental_openai: ExperimentalOpenAI = ExperimentalOpenAI()
def dict(self, *args, by_alias=True, **kwargs):
return super().dict(*args, by_alias=by_alias, **kwargs)

View File

@@ -1,6 +1,7 @@
from sqlmodel import SQLModel, Field
from typing import Optional
from sqlmodel import Field, SQLModel
class RSSItem(SQLModel, table=True):
id: int = Field(default=None, primary_key=True, alias="id")

View File

@@ -1,7 +1,8 @@
from pydantic import BaseModel
from sqlmodel import SQLModel, Field
from typing import Optional
from pydantic import BaseModel
from sqlmodel import Field, SQLModel
class Torrent(SQLModel, table=True):
id: int = Field(default=None, primary_key=True, alias="id")

View File

@@ -1,6 +1,7 @@
from pydantic import BaseModel
from typing import Optional
from sqlmodel import SQLModel, Field
from pydantic import BaseModel
from sqlmodel import Field, SQLModel
class User(SQLModel, table=True):

View File

@@ -1,5 +1,5 @@
import re
import logging
import re
import xml.etree.ElementTree
from module.conf import settings

View File

@@ -9,13 +9,13 @@ logger = logging.getLogger(__name__)
class TelegramNotification(RequestContent):
def __init__(self, token, chat_id):
super().__init__()
self.notification_url = f"https://api.telegram.org/bot{token}/sendMessage"
self.notification_url = f"https://api.telegram.org/bot{token}/sendPhoto"
self.chat_id = chat_id
@staticmethod
def gen_message(notify: Notification) -> str:
text = f"""
番剧名称:{notify.official_title}\n季度: 第{notify.season}\n更新集数: 第{notify.episode}\n{notify.poster_path}\n
番剧名称:{notify.official_title}\n季度: 第{notify.season}\n更新集数: 第{notify.episode}
"""
return text
@@ -23,7 +23,8 @@ class TelegramNotification(RequestContent):
text = self.gen_message(notify)
data = {
"chat_id": self.chat_id,
"text": text,
"caption": text,
"photo": notify.poster_path,
"disable_notification": True,
}
resp = self.post_data(self.notification_url, data)

View File

@@ -1,4 +1,5 @@
from .mikan_parser import mikan_parser
from .openai import OpenAIParser
from .raw_parser import raw_parser
from .tmdb_parser import tmdb_parser
from .torrent_parser import torrent_parser
from .mikan_parser import mikan_parser

View File

@@ -0,0 +1,129 @@
import asyncio
import json
import logging
import openai
logger = logging.getLogger(__name__)
DEFAULT_PROMPT = """\
You will now play the role of a super assistant.
Your task is to extract structured data from unstructured text content and output it in JSON format.
If you are unable to extract any information, please keep all fields and leave the field empty or default value like `''`, `None`.
But Do not fabricate data!
the python structured data type is:
```python
@dataclass
class Episode:
title_en: Optional[str]
title_zh: Optional[str]
title_jp: Optional[str]
season: int
season_raw: str
episode: int
sub: str
group: str
resolution: str
source: str
```
Example:
```
input: "【喵萌奶茶屋】★04月新番★[夏日重现/Summer Time Rendering][11][1080p][繁日双语][招募翻译]"
output: '{"group": "喵萌奶茶屋", "title_en": "Summer Time Rendering", "resolution": "1080p", "episode": 11, "season": 1, "title_zh": "夏日重现", "sub": "", "title_jp": "", "season_raw": "", "source": ""}'
input: "【幻樱字幕组】【4月新番】【古见同学有交流障碍症 第二季 Komi-san wa, Komyushou Desu. S02】【22】【GB_MP4】【1920X1080】"
output: '{"group": "幻樱字幕组", "title_en": "Komi-san wa, Komyushou Desu.", "resolution": "1920X1080", "episode": 22, "season": 2, "title_zh": "古见同学有交流障碍症", "sub": "", "title_jp": "", "season_raw": "", "source": ""}'
input: "[Lilith-Raws] 关于我在无意间被隔壁的天使变成废柴这件事 / Otonari no Tenshi-sama - 09 [Baha][WEB-DL][1080p][AVC AAC][CHT][MP4]"
output: '{"group": "Lilith-Raws", "title_en": "Otonari no Tenshi-sama", "resolution": "1080p", "episode": 9, "season": 1, "source": "WEB-DL", "title_zh": "关于我在无意间被隔壁的天使变成废柴这件事", "sub": "CHT", "title_jp": ""}'
```
"""
class OpenAIParser:
def __init__(
self,
api_key: str,
api_base: str = "https://api.openai.com/v1",
model: str = "gpt-3.5-turbo",
**kwargs,
) -> None:
"""OpenAIParser is a class to parse text with openai
Args:
api_key (str): the OpenAI api key
api_base (str):
the OpenAI api base url, you can use custom url here. \
Defaults to "https://api.openai.com/v1".
model (str):
the ChatGPT model parameter, you can get more details from \
https://platform.openai.com/docs/api-reference/chat/create. \
Defaults to "gpt-3.5-turbo".
kwargs (dict):
the OpenAI ChatGPT parameters, you can get more details from \
https://platform.openai.com/docs/api-reference/chat/create.
Raises:
ValueError: if api_key is not provided.
"""
if not api_key:
raise ValueError("API key is required.")
self._api_key = api_key
self.api_base = api_base
self.model = model
self.openai_kwargs = kwargs
def parse(
self, text: str, prompt: str | None = None, asdict: bool = True
) -> dict | str:
"""parse text with openai
Args:
text (str): the text to be parsed
prompt (str | None, optional):
the custom prompt. Built-in prompt will be used if no prompt is provided. \
Defaults to None.
asdict (bool, optional):
whether to return the result as dict or not. \
Defaults to True.
Returns:
dict | str: the parsed result.
"""
if not prompt:
prompt = DEFAULT_PROMPT
async def complete() -> str:
resp = await openai.ChatCompletion.acreate(
api_key=self._api_key,
api_base=self.api_base,
model=self.model,
messages=[
dict(role="system", content=prompt),
dict(role="user", content=text),
],
# set temperature to 0 to make results be more stable and reproducible.
temperature=0,
**self.openai_kwargs,
)
result = resp["choices"][0]["message"]["content"]
return result
loop = asyncio.get_event_loop()
result = loop.run_until_complete(complete())
if asdict:
try:
result = json.loads(result)
except json.JSONDecodeError:
logger.warning(f"Cannot parse result {result} as python dict.")
logger.debug(f"the parsed result is: {result}")
return result

View File

@@ -5,7 +5,6 @@ from dataclasses import dataclass
from module.conf import TMDB_API
from module.network import RequestContent
TMDB_URL = "https://api.themoviedb.org"
@@ -50,7 +49,9 @@ def get_season(seasons: list) -> tuple[int, str]:
[year, _, _] = date
now_year = time.localtime().tm_year
if int(year) <= now_year:
return int(re.findall(r"\d", season.get("season"))[0]), season.get("poster_path")
return int(re.findall(r"\d", season.get("season"))[0]), season.get(
"poster_path"
)
return len(ss), ss[-1].get("poster_path")
@@ -100,5 +101,5 @@ def tmdb_parser(title, language) -> TMDBInfo | None:
return None
if __name__ == '__main__':
print(tmdb_parser("魔法禁书目录", "zh"))
if __name__ == "__main__":
print(tmdb_parser("魔法禁书目录", "zh"))

View File

@@ -1,6 +1,6 @@
import logging
from pathlib import Path
import re
from pathlib import Path
from module.models import EpisodeFile, SubtitleFile

View File

@@ -1 +0,0 @@

View File

@@ -2,8 +2,14 @@ import logging
from module.conf import settings
from module.models import Bangumi
from .analyser import raw_parser, tmdb_parser, torrent_parser, mikan_parser
from module.models.bangumi import Episode
from module.parser.analyser import (
OpenAIParser,
mikan_parser,
raw_parser,
tmdb_parser,
torrent_parser,
)
logger = logging.getLogger(__name__)
@@ -43,14 +49,24 @@ class TitleParser:
logger.debug(f"TMDB Matched, official title is {tmdb_info.title}")
bangumi.poster_link = tmdb_info.poster_link
else:
logger.warning(f"Cannot match {bangumi.official_title} in TMDB. Use raw title instead.")
logger.warning(
f"Cannot match {bangumi.official_title} in TMDB. Use raw title instead."
)
logger.warning("Please change bangumi info manually.")
@staticmethod
def raw_parser(raw: str) -> Bangumi | None:
language = settings.rss_parser.language
try:
episode = raw_parser(raw)
# use OpenAI ChatGPT to parse raw title and get structured data
if settings.experimental_openai.enable:
kwargs = settings.experimental_openai.dict(exclude={"enable"})
gpt = OpenAIParser(**kwargs)
episode_dict = gpt.parse(raw, asdict=True)
episode = Episode(**episode_dict)
else:
episode = raw_parser(raw)
titles = {
"zh": episode.title_zh,
"en": episode.title_en,

View File

@@ -1,13 +1,13 @@
import logging
import re
from .engine import RSSEngine
from module.conf import settings
from module.models import Bangumi, Torrent, RSSItem, ResponseModel
from module.models import Bangumi, ResponseModel, RSSItem, Torrent
from module.network import RequestContent
from module.parser import TitleParser
from .engine import RSSEngine
logger = logging.getLogger(__name__)

View File

@@ -1,13 +1,11 @@
import re
import logging
import re
from typing import Optional
from module.models import Bangumi, RSSItem, Torrent, ResponseModel
from module.network import RequestContent
from module.downloader import DownloadClient
from module.database import Database, engine
from module.downloader import DownloadClient
from module.models import Bangumi, ResponseModel, RSSItem, Torrent
from module.network import RequestContent
logger = logging.getLogger(__name__)
@@ -33,7 +31,13 @@ class RSSEngine(Database):
else:
return []
def add_rss(self, rss_link: str, name: str | None = None, aggregate: bool = True, parser: str = "mikan"):
def add_rss(
self,
rss_link: str,
name: str | None = None,
aggregate: bool = True,
parser: str = "mikan",
):
if not name:
with RequestContent() as req:
name = req.get_rss_title(rss_link)
@@ -98,6 +102,8 @@ class RSSEngine(Database):
def match_torrent(self, torrent: Torrent) -> Optional[Bangumi]:
matched: Bangumi = self.bangumi.match_torrent(torrent.name)
if matched:
if matched.filter == "":
return matched
_filter = matched.filter.replace(",", "|")
if not re.search(_filter, torrent.name, re.IGNORECASE):
torrent.bangumi_id = matched.id
@@ -127,7 +133,9 @@ class RSSEngine(Database):
def download_bangumi(self, bangumi: Bangumi):
with RequestContent() as req:
torrents = req.get_torrents(bangumi.rss_link, bangumi.filter.replace(",", "|"))
torrents = req.get_torrents(
bangumi.rss_link, bangumi.filter.replace(",", "|")
)
if torrents:
with DownloadClient() as client:
client.add_torrent(torrents, bangumi)

View File

@@ -1,2 +1,2 @@
from .searcher import SearchTorrent
from .provider import SEARCH_CONFIG
from .searcher import SearchTorrent

View File

@@ -1,7 +1,7 @@
import re
from module.models import RSSItem
from module.conf import SEARCH_CONFIG
from module.models import RSSItem
def search_url(site: str, keywords: list[str]) -> RSSItem:
@@ -17,4 +17,4 @@ def search_url(site: str, keywords: list[str]) -> RSSItem:
)
return rss_item
else:
raise ValueError(f"Site {site} is not supported")
raise ValueError(f"Site {site} is not supported")

View File

@@ -1,7 +1,7 @@
import json
from typing import TypeAlias
from module.models import Bangumi, Torrent, RSSItem
from module.models import Bangumi, RSSItem, Torrent
from module.network import RequestContent
from module.rss import RSSAnalyser
@@ -20,13 +20,13 @@ BangumiJSON: TypeAlias = str
class SearchTorrent(RequestContent, RSSAnalyser):
def search_torrents(
self, rss_item: RSSItem
) -> list[Torrent]:
def search_torrents(self, rss_item: RSSItem) -> list[Torrent]:
torrents = self.get_torrents(rss_item.url)
return torrents
def analyse_keyword(self, keywords: list[str], site: str = "mikan", limit: int = 5) -> BangumiJSON:
def analyse_keyword(
self, keywords: list[str], site: str = "mikan", limit: int = 5
) -> BangumiJSON:
rss_item = search_url(site, keywords)
torrents = self.search_torrents(rss_item)
# yield for EventSourceResponse (Server Send)
@@ -39,7 +39,7 @@ class SearchTorrent(RequestContent, RSSAnalyser):
if bangumi and special_link not in exist_list:
bangumi.rss_link = special_link
exist_list.append(special_link)
yield json.dumps(bangumi.dict(), separators=(',', ':'))
yield json.dumps(bangumi.dict(), separators=(",", ":"))
@staticmethod
def special_url(data: Bangumi, site: str) -> RSSItem:
@@ -50,4 +50,4 @@ class SearchTorrent(RequestContent, RSSAnalyser):
def search_season(self, data: Bangumi, site: str = "mikan") -> list[Torrent]:
rss_item = self.special_url(data, site)
torrents = self.search_torrents(rss_item)
return [torrent for torrent in torrents if data.title_raw in torrent.name]
return [torrent for torrent in torrents if data.title_raw in torrent.name]

View File

@@ -1,4 +1,4 @@
from fastapi import Depends, HTTPException, status, Cookie
from fastapi import Cookie, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from module.database import Database

View File

@@ -1,4 +1,4 @@
from .data_migration import data_migration
from .startup import start_up, first_run
from .version_check import version_check
from .cross_version import from_30_to_31
from .data_migration import data_migration
from .startup import first_run, start_up
from .version_check import version_check

View File

@@ -1,4 +1,5 @@
import re
from urllib3.util import parse_url
from module.rss import RSSEngine
@@ -13,7 +14,9 @@ def from_30_to_31():
for bangumi in bangumis:
if bangumi.poster_link:
rss_link = bangumi.rss_link.split(",")[-1]
if rss_link not in rss_pool and not re.search(r"\d+.\d+.\d+.\d+", rss_link):
if rss_link not in rss_pool and not re.search(
r"\d+.\d+.\d+.\d+", rss_link
):
rss_pool.append(rss_link)
root_path = parse_url(rss_link).host
if "://" not in bangumi.poster_link:

View File

@@ -1,6 +1,6 @@
from module.conf import LEGACY_DATA_PATH
from module.rss import RSSEngine
from module.models import Bangumi
from module.rss import RSSEngine
from module.utils import json_config

View File

@@ -6,6 +6,8 @@ from module.conf import VERSION, VERSION_PATH
def version_check() -> bool:
if VERSION == "DEV_VERSION":
return True
if VERSION == "local":
return True
if not VERSION_PATH.exists():
with open(VERSION_PATH, "w") as f:
f.write(VERSION + "\n")

View File

@@ -1,9 +1,7 @@
from sqlmodel import create_engine, SQLModel
from sqlmodel.pool import StaticPool
from module.database.combine import Database
from module.models import Bangumi, Torrent, RSSItem
from module.models import Bangumi, RSSItem, Torrent
from sqlmodel import SQLModel, create_engine
from sqlmodel.pool import StaticPool
# sqlite mock engine
engine = create_engine(
@@ -47,7 +45,9 @@ def test_bangumi_database():
assert db.bangumi.match_poster("无职转生到了异世界就拿出真本事II (2021)") == "/test/test.jpg"
# match torrent
result = db.bangumi.match_torrent("[Lilith-Raws] 无职转生,到了异世界就拿出真本事 / Mushoku Tensei - 11 [Baha][WEB-DL][1080p][AVC AAC][CHT][MP4]")
result = db.bangumi.match_torrent(
"[Lilith-Raws] 无职转生,到了异世界就拿出真本事 / Mushoku Tensei - 11 [Baha][WEB-DL][1080p][AVC AAC][CHT][MP4]"
)
assert result.official_title == "无职转生到了异世界就拿出真本事II"
# delete

View File

@@ -0,0 +1,36 @@
import json
import os
from unittest import mock
from dotenv import load_dotenv
from module.parser.analyser.openai import OpenAIParser
load_dotenv()
class TestOpenAIParser:
    """Tests for OpenAIParser with the OpenAI HTTP API mocked out."""

    @classmethod
    def setup_class(cls):
        # Use a real key from the environment when available; otherwise a
        # dummy value (the API call is mocked, so any non-empty key works).
        api_key = os.getenv("OPENAI_API_KEY") or "testing!"
        cls.parser = OpenAIParser(api_key=api_key)

    def test_parse(self):
        text = "[梦蓝字幕组]New Doraemon 哆啦A梦新番[747][2023.02.25][AVC][1080P][GB_JP][MP4]"
        expected = {
            "group": "梦蓝字幕组",
            "title_en": "New Doraemon",
            "resolution": "1080P",
            "episode": 747,
            "season": 1,
            "title_zh": "哆啦A梦新番",
            "sub": "GB_JP",
            "title_jp": "",
            "season_raw": "2023.02.25",
            "source": "AVC",
        }
        # BUGFIX: the original test patched OpenAIParser.parse itself and then
        # asserted on the mock's own return value, so none of the parser logic
        # was ever executed. Mock the underlying OpenAI API coroutine instead,
        # so parse() genuinely runs (prompt assembly, event loop, JSON decode).
        with mock.patch(
            "openai.ChatCompletion.acreate", new_callable=mock.AsyncMock
        ) as mocker:
            mocker.return_value = {
                "choices": [{"message": {"content": json.dumps(expected)}}]
            }
            result = self.parser.parse(text=text, asdict=True)
        assert result == expected

View File

@@ -0,0 +1,30 @@
import json
import os
import pytest
from module.conf import settings
from module.parser.title_parser import TitleParser
class TestTitleParser:
    """Tests for TitleParser.raw_parser with and without the OpenAI backend."""

    def test_parse_without_openai(self):
        # Exercises the built-in regex-based raw parser path.
        text = "[梦蓝字幕组]New Doraemon 哆啦A梦新番[747][2023.02.25][AVC][1080P][GB_JP][MP4]"
        result = TitleParser.raw_parser(text)
        assert result.group_name == "梦蓝字幕组"
        assert result.title_raw == "New Doraemon"
        assert result.dpi == "1080P"
        assert result.season == 1
        assert result.subtitle == "GB_JP"

    @pytest.mark.skipif(
        not settings.experimental_openai.enable,
        reason="OpenAI is not enabled in settings",
    )
    def test_parse_with_openai(self):
        # Same title and assertions as above; when experimental_openai is
        # enabled, raw_parser() routes through the OpenAI parser instead.
        # NOTE(review): this presumably performs a live API call using the
        # credentials in settings — confirm before enabling in CI.
        text = "[梦蓝字幕组]New Doraemon 哆啦A梦新番[747][2023.02.25][AVC][1080P][GB_JP][MP4]"
        result = TitleParser.raw_parser(text)
        assert result.group_name == "梦蓝字幕组"
        assert result.title_raw == "New Doraemon"
        assert result.dpi == "1080P"
        assert result.season == 1
        assert result.subtitle == "GB_JP"

View File

@@ -0,0 +1,58 @@
<script lang="ts" setup>
// Settings panel for the experimental OpenAI-based title parser.
import { Caution } from '@icon-park/vue-next';

import type { SettingItem } from '#/components';
import type { ExperimentalOpenAI } from '#/config';

const { t } = useMyI18n();
const { getSettingGroup } = useConfigStore();

// View over the `experimental_openai` group of the app configuration.
const experimentalFeatures = getSettingGroup('experimental_openai');

// Form rows rendered below; each configKey maps to a field of ExperimentalOpenAI.
const items: SettingItem<ExperimentalOpenAI>[] = [
  {
    configKey: 'enable',
    label: () => t('config.experimental_openai_set.enable'),
    type: 'switch',
  },
  {
    configKey: 'api_key',
    label: () => t('config.experimental_openai_set.api_key'),
    type: 'input',
    prop: {
      // Rendered as a password field so the key is not shown in clear text.
      type: 'password',
      placeholder: 'e.g: sk-3Bl****w2E9kW',
    },
  },
  {
    configKey: 'api_base',
    label: () => t('config.experimental_openai_set.api_base'),
    type: 'input',
    prop: {
      type: 'url',
      placeholder: 'OpenAI API Base URL',
    },
  },
  {
    configKey: 'model',
    label: () => t('config.experimental_openai_set.model'),
    type: 'select',
  },
];
</script>
<template>
  <ab-fold-panel :title="$t('config.experimental_openai_set.title')">
    <!-- Warning banner: the feature is experimental and may be unstable. -->
    <div fx-cer gap-2 mb-4 p-2 bg-amber-300 rounded-4px>
      <Caution />
      <span>{{ $t('config.experimental_openai_set.warning') }}</span>
    </div>
    <!-- One ab-setting row per item, two-way bound to the config group. -->
    <div space-y-12px>
      <ab-setting
        v-for="i in items"
        :key="i.configKey"
        v-bind="i"
        v-model:data="experimentalFeatures[i.configKey]"
      ></ab-setting>
    </div>
  </ab-fold-panel>
</template>

View File

@@ -72,7 +72,7 @@
"status": "Status",
"delete": "Delete",
"disable": "Disable",
"enable": "Enable",
"enable": "Enable"
},
"player": {
"hit": "Please set up the media player"
@@ -137,6 +137,14 @@
"username": "Username",
"password": "Password"
},
"experimental_openai_set": {
"title": "Experimental Setting",
"warning": "Warning: Experimental feature is not yet stable. Please use with caution.",
"enable": "Enable OpenAI",
"api_key": "OpenAI API Key",
"api_base": "OpenAI API Base URL",
"model": "OpenAI Model"
},
"media_player_set": {
"title": "Media Player Setting",
"type": "type",

View File

@@ -58,8 +58,8 @@
"apply": "应用",
"yes_btn": "是",
"no_btn": "否",
"enable_hit": "确定启用该规则?",
"delete_hit": "是否删除本地文件?",
"enable_hit": "确定启用该规则",
"delete_hit": "是否删除本地文件",
"enable_rule": "启用规则",
"edit_rule": "编辑规则"
}
@@ -137,6 +137,14 @@
"username": "用户名",
"password": "密码"
},
"experimental_openai_set": {
"title": "实验功能设置",
"warning": "警告:实验功能尚未稳定,请谨慎使用",
"enable": "启用 OpenAI",
"api_key": "OpenAI API Key",
"api_base": "OpenAI API Base URL",
"model": "OpenAI 模型"
},
"media_player_set": {
"title": "播放器设置",
"type": "类型",

View File

@@ -28,6 +28,8 @@ definePage({
<config-proxy></config-proxy>
<config-player></config-player>
<config-experimental></config-experimental>
</div>
</div>

View File

@@ -47,6 +47,12 @@ export interface Config {
token: string;
chat_id: string;
};
experimental_openai: {
enable: boolean;
api_key: string;
api_base: string;
model: string;
};
}
export const initConfig: Config = {
@@ -96,6 +102,12 @@ export const initConfig: Config = {
token: '',
chat_id: '',
},
experimental_openai: {
enable: false,
api_key: '',
api_base: 'https://api.openai.com/v1/',
model: 'gpt-3.5-turbo',
},
};
type getItem<T extends keyof Config> = Pick<Config, T>[T];
@@ -107,6 +119,7 @@ export type BangumiManage = getItem<'bangumi_manage'>;
export type Log = getItem<'log'>;
export type Proxy = getItem<'proxy'>;
export type Notification = getItem<'notification'>;
export type ExperimentalOpenAI = getItem<'experimental_openai'>;
/** 下载方式 */
export type DownloaderType = UnionToTuple<Downloader['type']>;

View File

@@ -34,6 +34,7 @@ declare module '@vue/runtime-core' {
AbTag: typeof import('./../../src/components/basic/ab-tag.vue')['default']
AbTopbar: typeof import('./../../src/components/layout/ab-topbar.vue')['default']
ConfigDownload: typeof import('./../../src/components/setting/config-download.vue')['default']
ConfigExperimental: typeof import('./../../src/components/setting/config-experimental.vue')['default']
ConfigManage: typeof import('./../../src/components/setting/config-manage.vue')['default']
ConfigNormal: typeof import('./../../src/components/setting/config-normal.vue')['default']
ConfigNotification: typeof import('./../../src/components/setting/config-notification.vue')['default']