mirror of
https://github.com/EstrellaXD/Auto_Bangumi.git
synced 2026-04-14 10:30:35 +08:00
fix: old data support problem.
This commit is contained in:
@@ -4,12 +4,12 @@ from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
|
||||
from module.models.user import User
|
||||
from module.security import (
|
||||
from module.security.api import (
|
||||
auth_user,
|
||||
create_access_token,
|
||||
get_current_user,
|
||||
update_user_info,
|
||||
)
|
||||
from module.security.jwt import create_access_token
|
||||
|
||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ from fastapi.responses import JSONResponse
|
||||
|
||||
from module.manager import TorrentManager
|
||||
from module.models import Bangumi
|
||||
from module.security import get_current_user
|
||||
from module.security.api import get_current_user
|
||||
|
||||
router = APIRouter(prefix="/bangumi", tags=["bangumi"])
|
||||
|
||||
@@ -14,8 +14,8 @@ async def get_all_data(current_user=Depends(get_current_user)):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token"
|
||||
)
|
||||
with TorrentManager() as torrent:
|
||||
return torrent.search_all()
|
||||
with TorrentManager() as manager:
|
||||
return manager.bangumi.search_all()
|
||||
|
||||
|
||||
@router.get("/getData/{bangumi_id}", response_model=Bangumi)
|
||||
@@ -24,8 +24,8 @@ async def get_data(bangumi_id: str, current_user=Depends(get_current_user)):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token"
|
||||
)
|
||||
with TorrentManager() as torrent:
|
||||
return torrent.search_one(bangumi_id)
|
||||
with TorrentManager() as manager:
|
||||
return manager.search_one(bangumi_id)
|
||||
|
||||
|
||||
@router.post("/updateRule")
|
||||
@@ -34,8 +34,8 @@ async def update_rule(data: Bangumi, current_user=Depends(get_current_user)):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token"
|
||||
)
|
||||
with TorrentManager() as torrent:
|
||||
return torrent.update_rule(data)
|
||||
with TorrentManager() as manager:
|
||||
return manager.update_rule(data)
|
||||
|
||||
|
||||
@router.delete("/deleteRule/{bangumi_id}")
|
||||
@@ -46,8 +46,8 @@ async def delete_rule(
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token"
|
||||
)
|
||||
with TorrentManager() as torrent:
|
||||
return torrent.delete_rule(bangumi_id, file)
|
||||
with TorrentManager() as manager:
|
||||
return manager.delete_rule(bangumi_id, file)
|
||||
|
||||
|
||||
@router.delete("/disableRule/{bangumi_id}")
|
||||
@@ -58,8 +58,8 @@ async def disable_rule(
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token"
|
||||
)
|
||||
with TorrentManager() as torrent:
|
||||
return torrent.disable_rule(bangumi_id, file)
|
||||
with TorrentManager() as manager:
|
||||
return manager.disable_rule(bangumi_id, file)
|
||||
|
||||
|
||||
@router.get("/enableRule/{bangumi_id}")
|
||||
@@ -68,8 +68,8 @@ async def enable_rule(bangumi_id: str, current_user=Depends(get_current_user)):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token"
|
||||
)
|
||||
with TorrentManager() as torrent:
|
||||
return torrent.enable_rule(bangumi_id)
|
||||
with TorrentManager() as manager:
|
||||
return manager.enable_rule(bangumi_id)
|
||||
|
||||
|
||||
@router.get("/resetAll")
|
||||
@@ -78,6 +78,6 @@ async def reset_all(current_user=Depends(get_current_user)):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token"
|
||||
)
|
||||
with TorrentManager() as torrent:
|
||||
torrent.delete_all()
|
||||
with TorrentManager() as manager:
|
||||
manager.bangumi.delete_all()
|
||||
return JSONResponse(status_code=200, content={"message": "OK"})
|
||||
|
||||
@@ -4,7 +4,7 @@ from fastapi import APIRouter, Depends, HTTPException, status
|
||||
|
||||
from module.conf import settings
|
||||
from module.models import Config
|
||||
from module.security import get_current_user
|
||||
from module.security.api import get_current_user
|
||||
|
||||
router = APIRouter(tags=["config"])
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -4,7 +4,7 @@ from module.manager import SeasonCollector
|
||||
from module.models import Bangumi
|
||||
from module.models.api import RssLink
|
||||
from module.rss import analyser
|
||||
from module.security import get_current_user
|
||||
from module.security.api import get_current_user
|
||||
|
||||
router = APIRouter(prefix="/download", tags=["download"])
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ import os
|
||||
from fastapi import APIRouter, Depends, HTTPException, Response, status
|
||||
|
||||
from module.conf import LOG_PATH
|
||||
from module.security import get_current_user
|
||||
from module.security.api import get_current_user
|
||||
|
||||
router = APIRouter(prefix="/log", tags=["log"])
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ import sys
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
|
||||
from module.core import Program
|
||||
from module.security import get_current_user
|
||||
from module.security.api import get_current_user
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
program = Program()
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import logging
|
||||
|
||||
from module.conf import VERSION, settings
|
||||
from module.update import data_migration
|
||||
from module.update import data_migration, start_up
|
||||
|
||||
from .sub_thread import RenameThread, RSSThread
|
||||
|
||||
@@ -32,6 +32,7 @@ class Program(RenameThread, RSSThread):
|
||||
|
||||
def startup(self):
|
||||
self.__start_info()
|
||||
start_up(self.first_run)
|
||||
if self.first_run:
|
||||
logger.info("First run detected, please configure the program in webui.")
|
||||
return {"status": "First run detected."}
|
||||
|
||||
@@ -3,6 +3,7 @@ from sqlmodel import Session, SQLModel
|
||||
from .rss import RSSDatabase
|
||||
from .torrent import TorrentDatabase
|
||||
from .bangumi import BangumiDatabase
|
||||
from .user import UserDatabase
|
||||
from .engine import engine as e
|
||||
|
||||
|
||||
@@ -13,6 +14,18 @@ class Database(Session):
|
||||
self.rss = RSSDatabase(self)
|
||||
self.torrent = TorrentDatabase(self)
|
||||
self.bangumi = BangumiDatabase(self)
|
||||
self.user = UserDatabase(self)
|
||||
|
||||
def create_table(self):
|
||||
SQLModel.metadata.create_all(self.engine)
|
||||
|
||||
def drop_table(self):
|
||||
SQLModel.metadata.drop_all(self.engine)
|
||||
|
||||
def migrate(self):
|
||||
# Run migration online
|
||||
from alembic import command
|
||||
from alembic.config import Config
|
||||
|
||||
alembic_cfg = Config("alembic.ini")
|
||||
command.upgrade(alembic_cfg, "head")
|
||||
|
||||
@@ -30,6 +30,10 @@ class RSSDatabase:
|
||||
self.session.commit()
|
||||
self.session.refresh(data)
|
||||
|
||||
# TODO: Check if this is needed
|
||||
def search_id(self, _id: int) -> RSSItem:
|
||||
return self.session.get(RSSItem, _id)
|
||||
|
||||
def search_all(self) -> list[RSSItem]:
|
||||
return self.session.exec(select(RSSItem)).all()
|
||||
|
||||
|
||||
@@ -4,34 +4,25 @@ from fastapi import HTTPException
|
||||
|
||||
from module.models.user import User, UserUpdate, UserLogin
|
||||
from module.security.jwt import get_password_hash, verify_password
|
||||
from module.database.engine import engine
|
||||
from sqlmodel import Session, select, SQLModel
|
||||
from sqlalchemy.exc import UnboundExecutionError, OperationalError
|
||||
from sqlmodel import Session, select
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UserDatabase(Session):
|
||||
def __init__(self):
|
||||
super().__init__(engine)
|
||||
statement = select(User)
|
||||
try:
|
||||
self.exec(statement)
|
||||
except OperationalError:
|
||||
SQLModel.metadata.create_all(engine)
|
||||
self.add(User())
|
||||
self.commit()
|
||||
class UserDatabase:
|
||||
def __init__(self, session: Session):
|
||||
self.session = session
|
||||
|
||||
def get_user(self, username):
|
||||
statement = select(User).where(User.username == username)
|
||||
result = self.exec(statement).first()
|
||||
result = self.session.exec(statement).first()
|
||||
if not result:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
return result
|
||||
|
||||
def auth_user(self, user: UserLogin) -> bool:
|
||||
statement = select(User).where(User.username == user.username)
|
||||
result = self.exec(statement).first()
|
||||
result = self.session.exec(statement).first()
|
||||
if not result:
|
||||
raise HTTPException(status_code=401, detail="User not found")
|
||||
if not verify_password(user.password, result.password):
|
||||
@@ -41,19 +32,55 @@ class UserDatabase(Session):
|
||||
def update_user(self, username, update_user: UserUpdate):
|
||||
# Update username and password
|
||||
statement = select(User).where(User.username == username)
|
||||
result = self.exec(statement).first()
|
||||
result = self.session.exec(statement).first()
|
||||
if not result:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
if update_user.username:
|
||||
result.username = update_user.username
|
||||
if update_user.password:
|
||||
result.password = get_password_hash(update_user.password)
|
||||
self.add(result)
|
||||
self.commit()
|
||||
self.session.add(result)
|
||||
self.session.commit()
|
||||
return result
|
||||
|
||||
def merge_old_user(self):
|
||||
# get old data
|
||||
statement = """
|
||||
SELECT * FROM user
|
||||
"""
|
||||
result = self.session.exec(statement).first()
|
||||
if not result:
|
||||
return
|
||||
# add new data
|
||||
user = User(username=result.username, password=result.password)
|
||||
# Drop old table
|
||||
statement = """
|
||||
DROP TABLE user
|
||||
"""
|
||||
self.session.exec(statement)
|
||||
# Create new table
|
||||
statement = """
|
||||
CREATE TABLE user (
|
||||
id INTEGER NOT NULL PRIMARY KEY,
|
||||
username VARCHAR NOT NULL,
|
||||
password VARCHAR NOT NULL
|
||||
)
|
||||
"""
|
||||
self.session.exec(statement)
|
||||
self.session.add(user)
|
||||
self.session.commit()
|
||||
|
||||
if __name__ == "__main__":
|
||||
with UserDatabase() as db:
|
||||
# db.update_user(UserLogin(username="admin", password="adminadmin"), User(username="admin", password="cica1234"))
|
||||
db.update_user("admin", UserUpdate(username="estrella", password="cica1234"))
|
||||
def add_default_user(self):
|
||||
# Check if user exists
|
||||
statement = select(User)
|
||||
try:
|
||||
result = self.session.exec(statement).all()
|
||||
except Exception as e:
|
||||
self.merge_old_user()
|
||||
result = self.session.exec(statement).all()
|
||||
if len(result) != 0:
|
||||
return
|
||||
# Add default user
|
||||
user = User(username="admin", password=get_password_hash("adminadmin"))
|
||||
self.session.add(user)
|
||||
self.session.commit()
|
||||
|
||||
@@ -2,4 +2,4 @@ from .bangumi import Bangumi, Episode, BangumiUpdate, Notification
|
||||
from .config import Config
|
||||
from .rss import RSSItem, RSSUpdate
|
||||
from .torrent import EpisodeFile, SubtitleFile, Torrent, TorrentUpdate
|
||||
from .user import UserLogin
|
||||
from .user import UserLogin, User
|
||||
|
||||
@@ -5,7 +5,8 @@ from typing import Optional
|
||||
|
||||
class Torrent(SQLModel, table=True):
|
||||
id: int = Field(default=None, primary_key=True, alias="id")
|
||||
refer_id: Optional[int] = Field(None, alias="refer_id")
|
||||
bangumi_id: Optional[int] = Field(None, alias="refer_id", foreign_key="bangumi.id")
|
||||
rss_id: Optional[int] = Field(None, alias="rss_id", foreign_key="rssitem.id")
|
||||
name: str = Field("", alias="name")
|
||||
url: str = Field("https://example.com/torrent", alias="url")
|
||||
homepage: Optional[str] = Field(None, alias="homepage")
|
||||
|
||||
@@ -17,9 +17,12 @@ class RSSEngine(Database):
|
||||
super().__init__(_engine)
|
||||
|
||||
@staticmethod
|
||||
def _get_torrents(rss_link: str) -> list[Torrent]:
|
||||
def _get_torrents(rss: RSSItem) -> list[Torrent]:
|
||||
with RequestContent() as req:
|
||||
torrents = req.get_torrents(rss_link)
|
||||
torrents = req.get_torrents(rss.url)
|
||||
# Add RSS ID
|
||||
for torrent in torrents:
|
||||
torrent.rss_id = rss.id
|
||||
return torrents
|
||||
|
||||
def get_combine_rss(self) -> list[RSSItem]:
|
||||
@@ -33,7 +36,7 @@ class RSSEngine(Database):
|
||||
self.rss.add(rss_data)
|
||||
|
||||
def pull_rss(self, rss_item: RSSItem) -> list[Torrent]:
|
||||
torrents = self._get_torrents(rss_item.url)
|
||||
torrents = self._get_torrents(rss_item)
|
||||
new_torrents = self.torrent.check_new(torrents)
|
||||
return new_torrents
|
||||
|
||||
@@ -42,7 +45,7 @@ class RSSEngine(Database):
|
||||
if matched:
|
||||
_filter = matched.filter.replace(",", "|")
|
||||
if not re.search(_filter, torrent.name, re.IGNORECASE):
|
||||
torrent.refer_id = matched.id
|
||||
torrent.bangumi_id = matched.id
|
||||
torrent.save_path = matched.save_path
|
||||
return matched
|
||||
return None
|
||||
|
||||
@@ -1,2 +0,0 @@
|
||||
from .api import auth_user, get_current_user, get_token_data, update_user_info
|
||||
from .jwt import create_access_token
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
|
||||
from module.database.user import UserDatabase
|
||||
from module.database import Database
|
||||
from module.models.user import User
|
||||
|
||||
from .jwt import verify_token
|
||||
@@ -20,8 +20,8 @@ async def get_current_user(token: str = Depends(oauth2_scheme)):
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token"
|
||||
)
|
||||
username = payload.get("sub")
|
||||
with UserDatabase as user_db:
|
||||
user = user_db.get_user(username)
|
||||
with Database() as db:
|
||||
user = db.user.get_user(username)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid username"
|
||||
@@ -40,13 +40,13 @@ async def get_token_data(token: str = Depends(oauth2_scheme)):
|
||||
|
||||
def update_user_info(user_data: User, current_user):
|
||||
try:
|
||||
with UserDatabase as db:
|
||||
db.update_user(current_user.username, user_data)
|
||||
with Database() as db:
|
||||
db.user.update_user(current_user.username, user_data)
|
||||
return True
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||
|
||||
|
||||
def auth_user(username, password):
|
||||
with UserDatabase() as db:
|
||||
db.auth_user(username, password)
|
||||
with Database() as db:
|
||||
db.user.auth_user(username, password)
|
||||
|
||||
@@ -1 +1,2 @@
|
||||
from .data_migration import data_migration
|
||||
from .startup import start_up
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
import os
|
||||
|
||||
from module.conf import LEGACY_DATA_PATH
|
||||
from module.database import Database
|
||||
from module.rss import RSSEngine
|
||||
from module.models import Bangumi
|
||||
from module.utils import json_config
|
||||
|
||||
@@ -15,8 +13,9 @@ def data_migration():
|
||||
new_data = []
|
||||
for info in infos:
|
||||
new_data.append(Bangumi(**info, rss_link=[rss_link]))
|
||||
with Database() as db:
|
||||
db.create_table()
|
||||
db.bangumi.add_all(new_data)
|
||||
|
||||
with RSSEngine() as engine:
|
||||
engine.create_table()
|
||||
engine.bangumi.add_all(new_data)
|
||||
engine.user.add_default_user()
|
||||
engine.add_rss(rss_link)
|
||||
LEGACY_DATA_PATH.unlink(missing_ok=True)
|
||||
|
||||
23
backend/src/module/update/startup.py
Normal file
23
backend/src/module/update/startup.py
Normal file
@@ -0,0 +1,23 @@
|
||||
import logging
|
||||
|
||||
from module.rss import RSSEngine
|
||||
from module.conf import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def start_up(first_run):
|
||||
with RSSEngine() as engine:
|
||||
engine.create_table()
|
||||
engine.user.add_default_user()
|
||||
if not first_run:
|
||||
main_rss = engine.rss.search_id(1)
|
||||
if not main_rss:
|
||||
engine.add_rss(settings.rss_link, name="Mikan RSS", combine=True)
|
||||
elif main_rss.url != settings.rss_link:
|
||||
main_rss.url = settings.rss_link
|
||||
engine.rss.update(main_rss)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
start_up(False)
|
||||
Reference in New Issue
Block a user