From 1c4e8dc2932d464f24bba54e37a3152cc68fe1c4 Mon Sep 17 00:00:00 2001 From: EstrellaXD Date: Mon, 7 Aug 2023 20:14:45 +0800 Subject: [PATCH] fix: old data support problem. --- backend/src/module/api/auth.py | 4 +- backend/src/module/api/bangumi.py | 30 ++++----- backend/src/module/api/config.py | 2 +- backend/src/module/api/download.py | 2 +- backend/src/module/api/log.py | 2 +- backend/src/module/api/program.py | 2 +- backend/src/module/core/program.py | 3 +- backend/src/module/database/combine.py | 13 ++++ backend/src/module/database/rss.py | 4 ++ backend/src/module/database/user.py | 71 ++++++++++++++------- backend/src/module/models/__init__.py | 2 +- backend/src/module/models/torrent.py | 3 +- backend/src/module/rss/engine.py | 11 ++-- backend/src/module/security/__init__.py | 2 - backend/src/module/security/api.py | 14 ++-- backend/src/module/update/__init__.py | 1 + backend/src/module/update/data_migration.py | 13 ++-- backend/src/module/update/startup.py | 23 +++++++ 18 files changed, 136 insertions(+), 66 deletions(-) create mode 100644 backend/src/module/update/startup.py diff --git a/backend/src/module/api/auth.py b/backend/src/module/api/auth.py index a1a2fc10..d670a8c7 100644 --- a/backend/src/module/api/auth.py +++ b/backend/src/module/api/auth.py @@ -4,12 +4,12 @@ from fastapi import APIRouter, Depends, HTTPException, status from fastapi.security import OAuth2PasswordRequestForm from module.models.user import User -from module.security import ( +from module.security.api import ( auth_user, - create_access_token, get_current_user, update_user_info, ) +from module.security.jwt import create_access_token router = APIRouter(prefix="/auth", tags=["auth"]) diff --git a/backend/src/module/api/bangumi.py b/backend/src/module/api/bangumi.py index 7254fe45..dceea760 100644 --- a/backend/src/module/api/bangumi.py +++ b/backend/src/module/api/bangumi.py @@ -3,7 +3,7 @@ from fastapi.responses import JSONResponse from module.manager import TorrentManager from module.models import Bangumi -from module.security import get_current_user +from module.security.api import get_current_user router = APIRouter(prefix="/bangumi", tags=["bangumi"]) @@ -14,8 +14,8 @@ async def get_all_data(current_user=Depends(get_current_user)): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token" ) - with TorrentManager() as torrent: - return torrent.search_all() + with TorrentManager() as manager: + return manager.bangumi.search_all() @router.get("/getData/{bangumi_id}", response_model=Bangumi) @@ -24,8 +24,8 @@ async def get_data(bangumi_id: str, current_user=Depends(get_current_user)): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token" ) - with TorrentManager() as torrent: - return torrent.search_one(bangumi_id) + with TorrentManager() as manager: + return manager.search_one(bangumi_id) @router.post("/updateRule") @@ -34,8 +34,8 @@ async def update_rule(data: Bangumi, current_user=Depends(get_current_user)): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token" ) - with TorrentManager() as torrent: - return torrent.update_rule(data) + with TorrentManager() as manager: + return manager.update_rule(data) @router.delete("/deleteRule/{bangumi_id}") @@ -46,8 +46,8 @@ async def delete_rule( raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token" ) - with TorrentManager() as torrent: - return torrent.delete_rule(bangumi_id, file) + with TorrentManager() as manager: + return manager.delete_rule(bangumi_id, file) @router.delete("/disableRule/{bangumi_id}") @@ -58,8 +58,8 @@ async def disable_rule( raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token" ) - with TorrentManager() as torrent: - return torrent.disable_rule(bangumi_id, file) + with TorrentManager() as manager: + return manager.disable_rule(bangumi_id, file) @router.get("/enableRule/{bangumi_id}") @@ -68,8 +68,8 @@ async def enable_rule(bangumi_id: str, current_user=Depends(get_current_user)): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token" ) - with TorrentManager() as torrent: - return torrent.enable_rule(bangumi_id) + with TorrentManager() as manager: + return manager.enable_rule(bangumi_id) @router.get("/resetAll") @@ -78,6 +78,6 @@ async def reset_all(current_user=Depends(get_current_user)): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token" ) - with TorrentManager() as torrent: - torrent.delete_all() + with TorrentManager() as manager: + manager.bangumi.delete_all() return JSONResponse(status_code=200, content={"message": "OK"}) diff --git a/backend/src/module/api/config.py b/backend/src/module/api/config.py index 2c396919..25e4d2a8 100644 --- a/backend/src/module/api/config.py +++ b/backend/src/module/api/config.py @@ -4,7 +4,7 @@ from fastapi import APIRouter, Depends, HTTPException, status from module.conf import settings from module.models import Config -from module.security import get_current_user +from module.security.api import get_current_user router = APIRouter(tags=["config"]) logger = logging.getLogger(__name__) diff --git a/backend/src/module/api/download.py b/backend/src/module/api/download.py index dcd36f1e..30ac2518 100644 --- a/backend/src/module/api/download.py +++ b/backend/src/module/api/download.py @@ -4,7 +4,7 @@ from module.manager import SeasonCollector from module.models import Bangumi from module.models.api import RssLink from module.rss import analyser -from module.security import get_current_user +from module.security.api import get_current_user router = APIRouter(prefix="/download", tags=["download"]) diff --git a/backend/src/module/api/log.py b/backend/src/module/api/log.py index 88ec2c57..d530adba 100644 --- a/backend/src/module/api/log.py +++ b/backend/src/module/api/log.py @@ -3,7 +3,7 @@ import os from fastapi import APIRouter, Depends, HTTPException, Response, status from module.conf import LOG_PATH -from module.security import get_current_user +from module.security.api import get_current_user router = APIRouter(prefix="/log", tags=["log"]) diff --git a/backend/src/module/api/program.py b/backend/src/module/api/program.py index 83720a0e..b45f8978 100644 --- a/backend/src/module/api/program.py +++ b/backend/src/module/api/program.py @@ -6,7 +6,7 @@ import sys from fastapi import APIRouter, Depends, HTTPException, status from module.core import Program -from module.security import get_current_user +from module.security.api import get_current_user logger = logging.getLogger(__name__) program = Program() diff --git a/backend/src/module/core/program.py b/backend/src/module/core/program.py index 22a2915f..e926d0f1 100644 --- a/backend/src/module/core/program.py +++ b/backend/src/module/core/program.py @@ -1,7 +1,7 @@ import logging from module.conf import VERSION, settings -from module.update import data_migration +from module.update import data_migration, start_up from .sub_thread import RenameThread, RSSThread @@ -32,6 +32,7 @@ class Program(RenameThread, RSSThread): def startup(self): self.__start_info() + start_up(self.first_run) if self.first_run: logger.info("First run detected, please configure the program in webui.") return {"status": "First run detected."} diff --git a/backend/src/module/database/combine.py b/backend/src/module/database/combine.py index 25c717dd..805f165e 100644 --- a/backend/src/module/database/combine.py +++ b/backend/src/module/database/combine.py @@ -3,6 +3,7 @@ from sqlmodel import Session, SQLModel from .rss import RSSDatabase from .torrent import TorrentDatabase from .bangumi import BangumiDatabase +from .user import UserDatabase from .engine import engine as e @@ -13,6 +14,18 @@ class Database(Session): self.rss = RSSDatabase(self) self.torrent = TorrentDatabase(self) self.bangumi = BangumiDatabase(self) + self.user = UserDatabase(self) def create_table(self): SQLModel.metadata.create_all(self.engine) + + def drop_table(self): + SQLModel.metadata.drop_all(self.engine) + + def migrate(self): + # Run migration online + from alembic import command + from alembic.config import Config + + alembic_cfg = Config("alembic.ini") + command.upgrade(alembic_cfg, "head") diff --git a/backend/src/module/database/rss.py b/backend/src/module/database/rss.py index 3f182877..f213a2e6 100644 --- a/backend/src/module/database/rss.py +++ b/backend/src/module/database/rss.py @@ -30,6 +30,10 @@ class RSSDatabase: self.session.commit() self.session.refresh(data) + # TODO: Check if this is needed + def search_id(self, _id: int) -> RSSItem: + return self.session.get(RSSItem, _id) + def search_all(self) -> list[RSSItem]: return self.session.exec(select(RSSItem)).all() diff --git a/backend/src/module/database/user.py b/backend/src/module/database/user.py index 6fe77d2f..9664915a 100644 --- a/backend/src/module/database/user.py +++ b/backend/src/module/database/user.py @@ -4,34 +4,25 @@ from fastapi import HTTPException from module.models.user import User, UserUpdate, UserLogin from module.security.jwt import get_password_hash, verify_password -from module.database.engine import engine -from sqlmodel import Session, select, SQLModel -from sqlalchemy.exc import UnboundExecutionError, OperationalError +from sqlmodel import Session, select logger = logging.getLogger(__name__) -class UserDatabase(Session): - def __init__(self): - super().__init__(engine) - statement = select(User) - try: - self.exec(statement) - except OperationalError: - SQLModel.metadata.create_all(engine) - self.add(User()) - self.commit() +class UserDatabase: + def __init__(self, session: Session): + self.session = session def get_user(self, username): statement = select(User).where(User.username == username) - result = self.exec(statement).first() + result = self.session.exec(statement).first() if not result: raise HTTPException(status_code=404, detail="User not found") return result def auth_user(self, user: UserLogin) -> bool: statement = select(User).where(User.username == user.username) - result = self.exec(statement).first() + result = self.session.exec(statement).first() if not result: raise HTTPException(status_code=401, detail="User not found") if not verify_password(user.password, result.password): @@ -41,19 +32,55 @@ class UserDatabase(Session): def update_user(self, username, update_user: UserUpdate): # Update username and password statement = select(User).where(User.username == username) - result = self.exec(statement).first() + result = self.session.exec(statement).first() if not result: raise HTTPException(status_code=404, detail="User not found") if update_user.username: result.username = update_user.username if update_user.password: result.password = get_password_hash(update_user.password) - self.add(result) - self.commit() + self.session.add(result) + self.session.commit() return result + def merge_old_user(self): + # get old data + statement = """ + SELECT * FROM user + """ + result = self.session.exec(statement).first() + if not result: + return + # add new data + user = User(username=result.username, password=result.password) + # Drop old table + statement = """ + DROP TABLE user + """ + self.session.exec(statement) + # Create new table + statement = """ + CREATE TABLE user ( + id INTEGER NOT NULL PRIMARY KEY, + username VARCHAR NOT NULL, + password VARCHAR NOT NULL + ) + """ + self.session.exec(statement) + self.session.add(user) + self.session.commit() -if __name__ == "__main__": - with UserDatabase() as db: - # db.update_user(UserLogin(username="admin", password="adminadmin"), User(username="admin", password="cica1234")) - db.update_user("admin", UserUpdate(username="estrella", password="cica1234")) + def add_default_user(self): + # Check if user exists + statement = select(User) + try: + result = self.session.exec(statement).all() + except Exception as e: + self.merge_old_user() + result = self.session.exec(statement).all() + if len(result) != 0: + return + # Add default user + user = User(username="admin", password=get_password_hash("adminadmin")) + self.session.add(user) + self.session.commit() diff --git a/backend/src/module/models/__init__.py b/backend/src/module/models/__init__.py index 046930c8..dc391533 100644 --- a/backend/src/module/models/__init__.py +++ b/backend/src/module/models/__init__.py @@ -2,4 +2,4 @@ from .bangumi import Bangumi, Episode, BangumiUpdate, Notification from .config import Config from .rss import RSSItem, RSSUpdate from .torrent import EpisodeFile, SubtitleFile, Torrent, TorrentUpdate -from .user import UserLogin +from .user import UserLogin, User diff --git a/backend/src/module/models/torrent.py b/backend/src/module/models/torrent.py index 1e6198d7..0d5f68be 100644 --- a/backend/src/module/models/torrent.py +++ b/backend/src/module/models/torrent.py @@ -5,7 +5,8 @@ from typing import Optional class Torrent(SQLModel, table=True): id: int = Field(default=None, primary_key=True, alias="id") - refer_id: Optional[int] = Field(None, alias="refer_id") + bangumi_id: Optional[int] = Field(None, alias="refer_id", foreign_key="bangumi.id") + rss_id: Optional[int] = Field(None, alias="rss_id", foreign_key="rssitem.id") name: str = Field("", alias="name") url: str = Field("https://example.com/torrent", alias="url") homepage: Optional[str] = Field(None, alias="homepage") diff --git a/backend/src/module/rss/engine.py b/backend/src/module/rss/engine.py index c36153d3..802d0bbe 100644 --- a/backend/src/module/rss/engine.py +++ b/backend/src/module/rss/engine.py @@ -17,9 +17,12 @@ class RSSEngine(Database): super().__init__(_engine) @staticmethod - def _get_torrents(rss_link: str) -> list[Torrent]: + def _get_torrents(rss: RSSItem) -> list[Torrent]: with RequestContent() as req: - torrents = req.get_torrents(rss_link) + torrents = req.get_torrents(rss.url) + # Add RSS ID + for torrent in torrents: + torrent.rss_id = rss.id return torrents def get_combine_rss(self) -> list[RSSItem]: @@ -33,7 +36,7 @@ class RSSEngine(Database): self.rss.add(rss_data) def pull_rss(self, rss_item: RSSItem) -> list[Torrent]: - torrents = self._get_torrents(rss_item.url) + torrents = self._get_torrents(rss_item) new_torrents = self.torrent.check_new(torrents) return new_torrents @@ -42,7 +45,7 @@ class RSSEngine(Database): if matched: _filter = matched.filter.replace(",", "|") if not re.search(_filter, torrent.name, re.IGNORECASE): - torrent.refer_id = matched.id + torrent.bangumi_id = matched.id torrent.save_path = matched.save_path return matched return None diff --git a/backend/src/module/security/__init__.py b/backend/src/module/security/__init__.py index 7ce58e8c..e69de29b 100644 --- a/backend/src/module/security/__init__.py +++ b/backend/src/module/security/__init__.py @@ -1,2 +0,0 @@ -from .api import auth_user, get_current_user, get_token_data, update_user_info -from .jwt import create_access_token diff --git a/backend/src/module/security/api.py b/backend/src/module/security/api.py index b41dcafa..3fd59331 100644 --- a/backend/src/module/security/api.py +++ b/backend/src/module/security/api.py @@ -1,7 +1,7 @@ from fastapi import Depends, HTTPException, status from fastapi.security import OAuth2PasswordBearer -from module.database.user import UserDatabase +from module.database import Database from module.models.user import User from .jwt import verify_token @@ -20,8 +20,8 @@ async def get_current_user(token: str = Depends(oauth2_scheme)): status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token" ) username = payload.get("sub") - with UserDatabase as user_db: - user = user_db.get_user(username) + with Database() as db: + user = db.user.get_user(username) if not user: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid username" @@ -40,13 +40,13 @@ async def get_token_data(token: str = Depends(oauth2_scheme)): def update_user_info(user_data: User, current_user): try: - with UserDatabase as db: - db.update_user(current_user.username, user_data) + with Database() as db: + db.user.update_user(current_user.username, user_data) return True except Exception as e: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) def auth_user(username, password): - with UserDatabase() as db: - db.auth_user(username, password) + with Database() as db: + db.user.auth_user(username, password) diff --git a/backend/src/module/update/__init__.py b/backend/src/module/update/__init__.py index 0044a5be..2f4148c6 100644 --- a/backend/src/module/update/__init__.py +++ b/backend/src/module/update/__init__.py @@ -1 +1,2 @@ from .data_migration import data_migration +from .startup import start_up diff --git a/backend/src/module/update/data_migration.py b/backend/src/module/update/data_migration.py index 95ccc905..ed5a69b7 100644 --- a/backend/src/module/update/data_migration.py +++ b/backend/src/module/update/data_migration.py @@ -1,7 +1,5 @@ -import os - from module.conf import LEGACY_DATA_PATH -from module.database import Database +from module.rss import RSSEngine from module.models import Bangumi from module.utils import json_config @@ -15,8 +13,9 @@ def data_migration(): new_data = [] for info in infos: new_data.append(Bangumi(**info, rss_link=[rss_link])) - with Database() as db: - db.create_table() - db.bangumi.add_all(new_data) - + with RSSEngine() as engine: + engine.create_table() + engine.bangumi.add_all(new_data) + engine.user.add_default_user() + engine.add_rss(rss_link) LEGACY_DATA_PATH.unlink(missing_ok=True) diff --git a/backend/src/module/update/startup.py b/backend/src/module/update/startup.py new file mode 100644 index 00000000..888bf9ad --- /dev/null +++ b/backend/src/module/update/startup.py @@ -0,0 +1,23 @@ +import logging + +from module.rss import RSSEngine +from module.conf import settings + +logger = logging.getLogger(__name__) + + +def start_up(first_run): + with RSSEngine() as engine: + engine.create_table() + engine.user.add_default_user() + if not first_run: + main_rss = engine.rss.search_id(1) + if not main_rss: + engine.add_rss(settings.rss_link, name="Mikan RSS", combine=True) + elif main_rss.url != settings.rss_link: + main_rss.url = settings.rss_link + engine.rss.update(main_rss) + + +if __name__ == "__main__": + start_up(False)