fix: old data support problem.

This commit is contained in:
EstrellaXD
2023-08-07 20:14:45 +08:00
parent 84d5dbaceb
commit 1c4e8dc293
18 changed files with 136 additions and 66 deletions

View File

@@ -4,12 +4,12 @@ from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordRequestForm
from module.models.user import User
from module.security import (
from module.security.api import (
auth_user,
create_access_token,
get_current_user,
update_user_info,
)
from module.security.jwt import create_access_token
router = APIRouter(prefix="/auth", tags=["auth"])

View File

@@ -3,7 +3,7 @@ from fastapi.responses import JSONResponse
from module.manager import TorrentManager
from module.models import Bangumi
from module.security import get_current_user
from module.security.api import get_current_user
router = APIRouter(prefix="/bangumi", tags=["bangumi"])
@@ -14,8 +14,8 @@ async def get_all_data(current_user=Depends(get_current_user)):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token"
)
with TorrentManager() as torrent:
return torrent.search_all()
with TorrentManager() as manager:
return manager.bangumi.search_all()
@router.get("/getData/{bangumi_id}", response_model=Bangumi)
@@ -24,8 +24,8 @@ async def get_data(bangumi_id: str, current_user=Depends(get_current_user)):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token"
)
with TorrentManager() as torrent:
return torrent.search_one(bangumi_id)
with TorrentManager() as manager:
return manager.search_one(bangumi_id)
@router.post("/updateRule")
@@ -34,8 +34,8 @@ async def update_rule(data: Bangumi, current_user=Depends(get_current_user)):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token"
)
with TorrentManager() as torrent:
return torrent.update_rule(data)
with TorrentManager() as manager:
return manager.update_rule(data)
@router.delete("/deleteRule/{bangumi_id}")
@@ -46,8 +46,8 @@ async def delete_rule(
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token"
)
with TorrentManager() as torrent:
return torrent.delete_rule(bangumi_id, file)
with TorrentManager() as manager:
return manager.delete_rule(bangumi_id, file)
@router.delete("/disableRule/{bangumi_id}")
@@ -58,8 +58,8 @@ async def disable_rule(
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token"
)
with TorrentManager() as torrent:
return torrent.disable_rule(bangumi_id, file)
with TorrentManager() as manager:
return manager.disable_rule(bangumi_id, file)
@router.get("/enableRule/{bangumi_id}")
@@ -68,8 +68,8 @@ async def enable_rule(bangumi_id: str, current_user=Depends(get_current_user)):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token"
)
with TorrentManager() as torrent:
return torrent.enable_rule(bangumi_id)
with TorrentManager() as manager:
return manager.enable_rule(bangumi_id)
@router.get("/resetAll")
@@ -78,6 +78,6 @@ async def reset_all(current_user=Depends(get_current_user)):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token"
)
with TorrentManager() as torrent:
torrent.delete_all()
with TorrentManager() as manager:
manager.bangumi.delete_all()
return JSONResponse(status_code=200, content={"message": "OK"})

View File

@@ -4,7 +4,7 @@ from fastapi import APIRouter, Depends, HTTPException, status
from module.conf import settings
from module.models import Config
from module.security import get_current_user
from module.security.api import get_current_user
router = APIRouter(tags=["config"])
logger = logging.getLogger(__name__)

View File

@@ -4,7 +4,7 @@ from module.manager import SeasonCollector
from module.models import Bangumi
from module.models.api import RssLink
from module.rss import analyser
from module.security import get_current_user
from module.security.api import get_current_user
router = APIRouter(prefix="/download", tags=["download"])

View File

@@ -3,7 +3,7 @@ import os
from fastapi import APIRouter, Depends, HTTPException, Response, status
from module.conf import LOG_PATH
from module.security import get_current_user
from module.security.api import get_current_user
router = APIRouter(prefix="/log", tags=["log"])

View File

@@ -6,7 +6,7 @@ import sys
from fastapi import APIRouter, Depends, HTTPException, status
from module.core import Program
from module.security import get_current_user
from module.security.api import get_current_user
logger = logging.getLogger(__name__)
program = Program()

View File

@@ -1,7 +1,7 @@
import logging
from module.conf import VERSION, settings
from module.update import data_migration
from module.update import data_migration, start_up
from .sub_thread import RenameThread, RSSThread
@@ -32,6 +32,7 @@ class Program(RenameThread, RSSThread):
def startup(self):
self.__start_info()
start_up(self.first_run)
if self.first_run:
logger.info("First run detected, please configure the program in webui.")
return {"status": "First run detected."}

View File

@@ -3,6 +3,7 @@ from sqlmodel import Session, SQLModel
from .rss import RSSDatabase
from .torrent import TorrentDatabase
from .bangumi import BangumiDatabase
from .user import UserDatabase
from .engine import engine as e
@@ -13,6 +14,18 @@ class Database(Session):
self.rss = RSSDatabase(self)
self.torrent = TorrentDatabase(self)
self.bangumi = BangumiDatabase(self)
self.user = UserDatabase(self)
def create_table(self):
SQLModel.metadata.create_all(self.engine)
def drop_table(self):
SQLModel.metadata.drop_all(self.engine)
def migrate(self):
# Run migration online
from alembic import command
from alembic.config import Config
alembic_cfg = Config("alembic.ini")
command.upgrade(alembic_cfg, "head")

View File

@@ -30,6 +30,10 @@ class RSSDatabase:
self.session.commit()
self.session.refresh(data)
# TODO: Check if this is needed
def search_id(self, _id: int) -> RSSItem:
return self.session.get(RSSItem, _id)
def search_all(self) -> list[RSSItem]:
return self.session.exec(select(RSSItem)).all()

View File

@@ -4,34 +4,25 @@ from fastapi import HTTPException
from module.models.user import User, UserUpdate, UserLogin
from module.security.jwt import get_password_hash, verify_password
from module.database.engine import engine
from sqlmodel import Session, select, SQLModel
from sqlalchemy.exc import UnboundExecutionError, OperationalError
from sqlmodel import Session, select
logger = logging.getLogger(__name__)
class UserDatabase(Session):
def __init__(self):
super().__init__(engine)
statement = select(User)
try:
self.exec(statement)
except OperationalError:
SQLModel.metadata.create_all(engine)
self.add(User())
self.commit()
class UserDatabase:
def __init__(self, session: Session):
self.session = session
def get_user(self, username):
statement = select(User).where(User.username == username)
result = self.exec(statement).first()
result = self.session.exec(statement).first()
if not result:
raise HTTPException(status_code=404, detail="User not found")
return result
def auth_user(self, user: UserLogin) -> bool:
statement = select(User).where(User.username == user.username)
result = self.exec(statement).first()
result = self.session.exec(statement).first()
if not result:
raise HTTPException(status_code=401, detail="User not found")
if not verify_password(user.password, result.password):
@@ -41,19 +32,55 @@ class UserDatabase(Session):
def update_user(self, username, update_user: UserUpdate):
# Update username and password
statement = select(User).where(User.username == username)
result = self.exec(statement).first()
result = self.session.exec(statement).first()
if not result:
raise HTTPException(status_code=404, detail="User not found")
if update_user.username:
result.username = update_user.username
if update_user.password:
result.password = get_password_hash(update_user.password)
self.add(result)
self.commit()
self.session.add(result)
self.session.commit()
return result
def merge_old_user(self):
# get old data
statement = """
SELECT * FROM user
"""
result = self.session.exec(statement).first()
if not result:
return
# add new data
user = User(username=result.username, password=result.password)
# Drop old table
statement = """
DROP TABLE user
"""
self.session.exec(statement)
# Create new table
statement = """
CREATE TABLE user (
id INTEGER NOT NULL PRIMARY KEY,
username VARCHAR NOT NULL,
password VARCHAR NOT NULL
)
"""
self.session.exec(statement)
self.session.add(user)
self.session.commit()
if __name__ == "__main__":
with UserDatabase() as db:
# db.update_user(UserLogin(username="admin", password="adminadmin"), User(username="admin", password="cica1234"))
db.update_user("admin", UserUpdate(username="estrella", password="cica1234"))
def add_default_user(self):
# Check if user exists
statement = select(User)
try:
result = self.session.exec(statement).all()
except Exception as e:
self.merge_old_user()
result = self.session.exec(statement).all()
if len(result) != 0:
return
# Add default user
user = User(username="admin", password=get_password_hash("adminadmin"))
self.session.add(user)
self.session.commit()

View File

@@ -2,4 +2,4 @@ from .bangumi import Bangumi, Episode, BangumiUpdate, Notification
from .config import Config
from .rss import RSSItem, RSSUpdate
from .torrent import EpisodeFile, SubtitleFile, Torrent, TorrentUpdate
from .user import UserLogin
from .user import UserLogin, User

View File

@@ -5,7 +5,8 @@ from typing import Optional
class Torrent(SQLModel, table=True):
id: int = Field(default=None, primary_key=True, alias="id")
refer_id: Optional[int] = Field(None, alias="refer_id")
bangumi_id: Optional[int] = Field(None, alias="refer_id", foreign_key="bangumi.id")
rss_id: Optional[int] = Field(None, alias="rss_id", foreign_key="rssitem.id")
name: str = Field("", alias="name")
url: str = Field("https://example.com/torrent", alias="url")
homepage: Optional[str] = Field(None, alias="homepage")

View File

@@ -17,9 +17,12 @@ class RSSEngine(Database):
super().__init__(_engine)
@staticmethod
def _get_torrents(rss_link: str) -> list[Torrent]:
def _get_torrents(rss: RSSItem) -> list[Torrent]:
with RequestContent() as req:
torrents = req.get_torrents(rss_link)
torrents = req.get_torrents(rss.url)
# Add RSS ID
for torrent in torrents:
torrent.rss_id = rss.id
return torrents
def get_combine_rss(self) -> list[RSSItem]:
@@ -33,7 +36,7 @@ class RSSEngine(Database):
self.rss.add(rss_data)
def pull_rss(self, rss_item: RSSItem) -> list[Torrent]:
torrents = self._get_torrents(rss_item.url)
torrents = self._get_torrents(rss_item)
new_torrents = self.torrent.check_new(torrents)
return new_torrents
@@ -42,7 +45,7 @@ class RSSEngine(Database):
if matched:
_filter = matched.filter.replace(",", "|")
if not re.search(_filter, torrent.name, re.IGNORECASE):
torrent.refer_id = matched.id
torrent.bangumi_id = matched.id
torrent.save_path = matched.save_path
return matched
return None

View File

@@ -1,2 +0,0 @@
from .api import auth_user, get_current_user, get_token_data, update_user_info
from .jwt import create_access_token

View File

@@ -1,7 +1,7 @@
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from module.database.user import UserDatabase
from module.database import Database
from module.models.user import User
from .jwt import verify_token
@@ -20,8 +20,8 @@ async def get_current_user(token: str = Depends(oauth2_scheme)):
status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid token"
)
username = payload.get("sub")
with UserDatabase as user_db:
user = user_db.get_user(username)
with Database() as db:
user = db.user.get_user(username)
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid username"
@@ -40,13 +40,13 @@ async def get_token_data(token: str = Depends(oauth2_scheme)):
def update_user_info(user_data: User, current_user):
try:
with UserDatabase as db:
db.update_user(current_user.username, user_data)
with Database() as db:
db.user.update_user(current_user.username, user_data)
return True
except Exception as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
def auth_user(username, password):
with UserDatabase() as db:
db.auth_user(username, password)
with Database() as db:
db.user.auth_user(username, password)

View File

@@ -1 +1,2 @@
from .data_migration import data_migration
from .startup import start_up

View File

@@ -1,7 +1,5 @@
import os
from module.conf import LEGACY_DATA_PATH
from module.database import Database
from module.rss import RSSEngine
from module.models import Bangumi
from module.utils import json_config
@@ -15,8 +13,9 @@ def data_migration():
new_data = []
for info in infos:
new_data.append(Bangumi(**info, rss_link=[rss_link]))
with Database() as db:
db.create_table()
db.bangumi.add_all(new_data)
with RSSEngine() as engine:
engine.create_table()
engine.bangumi.add_all(new_data)
engine.user.add_default_user()
engine.add_rss(rss_link)
LEGACY_DATA_PATH.unlink(missing_ok=True)

View File

@@ -0,0 +1,23 @@
import logging
from module.rss import RSSEngine
from module.conf import settings
logger = logging.getLogger(__name__)
def start_up(first_run):
with RSSEngine() as engine:
engine.create_table()
engine.user.add_default_user()
if not first_run:
main_rss = engine.rss.search_id(1)
if not main_rss:
engine.add_rss(settings.rss_link, name="Mikan RSS", combine=True)
elif main_rss.url != settings.rss_link:
main_rss.url = settings.rss_link
engine.rss.update(main_rss)
if __name__ == "__main__":
start_up(False)