mirror of
https://github.com/EstrellaXD/Auto_Bangumi.git
synced 2026-03-20 11:57:46 +08:00
Merge branch 'feature/uv-migration' into 3.2-dev
# Conflicts: # backend/pyproject.toml # backend/requirements.txt # backend/src/module/api/passkey.py # backend/src/module/database/combine.py # backend/src/module/database/engine.py # backend/src/module/security/auth_strategy.py # backend/src/module/security/webauthn.py # webui/src/components/setting/config-passkey.vue # webui/src/hooks/usePasskey.ts # webui/src/pages/index/config.vue # webui/src/pages/login.vue # webui/src/services/webauthn.ts
This commit is contained in:
@@ -1,38 +1,54 @@
|
||||
import logging
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.sql import func
|
||||
from sqlmodel import Session, and_, delete, false, or_, select
|
||||
from sqlmodel import and_, delete, false, or_, select
|
||||
|
||||
from module.models import Bangumi, BangumiUpdate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Module-level TTL cache for search_all results
|
||||
_bangumi_cache: list[Bangumi] | None = None
|
||||
_bangumi_cache_time: float = 0
|
||||
_BANGUMI_CACHE_TTL: float = 60.0 # seconds
|
||||
|
||||
|
||||
def _invalidate_bangumi_cache():
|
||||
global _bangumi_cache, _bangumi_cache_time
|
||||
_bangumi_cache = None
|
||||
_bangumi_cache_time = 0
|
||||
|
||||
|
||||
class BangumiDatabase:
|
||||
def __init__(self, session: Session):
|
||||
def __init__(self, session: AsyncSession):
|
||||
self.session = session
|
||||
|
||||
def add(self, data: Bangumi):
|
||||
async def add(self, data: Bangumi) -> bool:
|
||||
statement = select(Bangumi).where(Bangumi.title_raw == data.title_raw)
|
||||
bangumi = self.session.exec(statement).first()
|
||||
result = await self.session.execute(statement)
|
||||
bangumi = result.scalar_one_or_none()
|
||||
if bangumi:
|
||||
return False
|
||||
self.session.add(data)
|
||||
self.session.commit()
|
||||
await self.session.commit()
|
||||
_invalidate_bangumi_cache()
|
||||
logger.debug(f"[Database] Insert {data.official_title} into database.")
|
||||
return True
|
||||
|
||||
def add_all(self, datas: list[Bangumi]):
|
||||
async def add_all(self, datas: list[Bangumi]):
|
||||
self.session.add_all(datas)
|
||||
self.session.commit()
|
||||
await self.session.commit()
|
||||
_invalidate_bangumi_cache()
|
||||
logger.debug(f"[Database] Insert {len(datas)} bangumi into database.")
|
||||
|
||||
def update(self, data: Bangumi | BangumiUpdate, _id: int = None) -> bool:
|
||||
async def update(self, data: Bangumi | BangumiUpdate, _id: int = None) -> bool:
|
||||
if _id and isinstance(data, BangumiUpdate):
|
||||
db_data = self.session.get(Bangumi, _id)
|
||||
db_data = await self.session.get(Bangumi, _id)
|
||||
elif isinstance(data, Bangumi):
|
||||
db_data = self.session.get(Bangumi, data.id)
|
||||
db_data = await self.session.get(Bangumi, data.id)
|
||||
else:
|
||||
return False
|
||||
if not db_data:
|
||||
@@ -41,133 +57,155 @@ class BangumiDatabase:
|
||||
for key, value in bangumi_data.items():
|
||||
setattr(db_data, key, value)
|
||||
self.session.add(db_data)
|
||||
self.session.commit()
|
||||
self.session.refresh(db_data)
|
||||
await self.session.commit()
|
||||
_invalidate_bangumi_cache()
|
||||
logger.debug(f"[Database] Update {data.official_title}")
|
||||
return True
|
||||
|
||||
def update_all(self, datas: list[Bangumi]):
|
||||
async def update_all(self, datas: list[Bangumi]):
|
||||
self.session.add_all(datas)
|
||||
self.session.commit()
|
||||
await self.session.commit()
|
||||
_invalidate_bangumi_cache()
|
||||
logger.debug(f"[Database] Update {len(datas)} bangumi.")
|
||||
|
||||
def update_rss(self, title_raw, rss_set: str):
|
||||
# Update rss and added
|
||||
async def update_rss(self, title_raw: str, rss_set: str):
|
||||
statement = select(Bangumi).where(Bangumi.title_raw == title_raw)
|
||||
bangumi = self.session.exec(statement).first()
|
||||
bangumi.rss_link = rss_set
|
||||
bangumi.added = False
|
||||
self.session.add(bangumi)
|
||||
self.session.commit()
|
||||
self.session.refresh(bangumi)
|
||||
logger.debug(f"[Database] Update {title_raw} rss_link to {rss_set}.")
|
||||
result = await self.session.execute(statement)
|
||||
bangumi = result.scalar_one_or_none()
|
||||
if bangumi:
|
||||
bangumi.rss_link = rss_set
|
||||
bangumi.added = False
|
||||
self.session.add(bangumi)
|
||||
await self.session.commit()
|
||||
_invalidate_bangumi_cache()
|
||||
logger.debug(f"[Database] Update {title_raw} rss_link to {rss_set}.")
|
||||
|
||||
def update_poster(self, title_raw, poster_link: str):
|
||||
async def update_poster(self, title_raw: str, poster_link: str):
|
||||
statement = select(Bangumi).where(Bangumi.title_raw == title_raw)
|
||||
bangumi = self.session.exec(statement).first()
|
||||
bangumi.poster_link = poster_link
|
||||
self.session.add(bangumi)
|
||||
self.session.commit()
|
||||
self.session.refresh(bangumi)
|
||||
logger.debug(f"[Database] Update {title_raw} poster_link to {poster_link}.")
|
||||
result = await self.session.execute(statement)
|
||||
bangumi = result.scalar_one_or_none()
|
||||
if bangumi:
|
||||
bangumi.poster_link = poster_link
|
||||
self.session.add(bangumi)
|
||||
await self.session.commit()
|
||||
_invalidate_bangumi_cache()
|
||||
logger.debug(f"[Database] Update {title_raw} poster_link to {poster_link}.")
|
||||
|
||||
def delete_one(self, _id: int):
|
||||
async def delete_one(self, _id: int):
|
||||
statement = select(Bangumi).where(Bangumi.id == _id)
|
||||
bangumi = self.session.exec(statement).first()
|
||||
self.session.delete(bangumi)
|
||||
self.session.commit()
|
||||
logger.debug(f"[Database] Delete bangumi id: {_id}.")
|
||||
result = await self.session.execute(statement)
|
||||
bangumi = result.scalar_one_or_none()
|
||||
if bangumi:
|
||||
await self.session.delete(bangumi)
|
||||
await self.session.commit()
|
||||
_invalidate_bangumi_cache()
|
||||
logger.debug(f"[Database] Delete bangumi id: {_id}.")
|
||||
|
||||
def delete_all(self):
|
||||
async def delete_all(self):
|
||||
statement = delete(Bangumi)
|
||||
self.session.exec(statement)
|
||||
self.session.commit()
|
||||
await self.session.execute(statement)
|
||||
await self.session.commit()
|
||||
_invalidate_bangumi_cache()
|
||||
|
||||
def search_all(self) -> list[Bangumi]:
|
||||
async def search_all(self) -> list[Bangumi]:
|
||||
global _bangumi_cache, _bangumi_cache_time
|
||||
now = time.time()
|
||||
if _bangumi_cache is not None and (now - _bangumi_cache_time) < _BANGUMI_CACHE_TTL:
|
||||
return _bangumi_cache
|
||||
statement = select(Bangumi)
|
||||
return self.session.exec(statement).all()
|
||||
result = await self.session.execute(statement)
|
||||
_bangumi_cache = list(result.scalars().all())
|
||||
_bangumi_cache_time = now
|
||||
return _bangumi_cache
|
||||
|
||||
def search_id(self, _id: int) -> Optional[Bangumi]:
|
||||
async def search_id(self, _id: int) -> Optional[Bangumi]:
|
||||
statement = select(Bangumi).where(Bangumi.id == _id)
|
||||
bangumi = self.session.exec(statement).first()
|
||||
result = await self.session.execute(statement)
|
||||
bangumi = result.scalar_one_or_none()
|
||||
if bangumi is None:
|
||||
logger.warning(f"[Database] Cannot find bangumi id: {_id}.")
|
||||
return None
|
||||
else:
|
||||
logger.debug(f"[Database] Find bangumi id: {_id}.")
|
||||
return self.session.exec(statement).first()
|
||||
return bangumi
|
||||
|
||||
def match_poster(self, bangumi_name: str) -> str:
|
||||
# Use like to match
|
||||
async def match_poster(self, bangumi_name: str) -> str:
|
||||
statement = select(Bangumi).where(
|
||||
func.instr(bangumi_name, Bangumi.official_title) > 0
|
||||
)
|
||||
data = self.session.exec(statement).first()
|
||||
result = await self.session.execute(statement)
|
||||
data = result.scalar_one_or_none()
|
||||
if data:
|
||||
return data.poster_link
|
||||
else:
|
||||
return ""
|
||||
|
||||
def match_list(self, torrent_list: list, rss_link: str) -> list:
|
||||
match_datas = self.search_all()
|
||||
async def match_list(self, torrent_list: list, rss_link: str) -> list:
|
||||
match_datas = await self.search_all()
|
||||
if not match_datas:
|
||||
return torrent_list
|
||||
# Match title
|
||||
i = 0
|
||||
while i < len(torrent_list):
|
||||
torrent = torrent_list[i]
|
||||
for match_data in match_datas:
|
||||
if match_data.title_raw in torrent.name:
|
||||
if rss_link not in match_data.rss_link:
|
||||
# Build index for faster lookup
|
||||
title_index = {m.title_raw: m for m in match_datas}
|
||||
unmatched = []
|
||||
rss_updated = set()
|
||||
for torrent in torrent_list:
|
||||
matched = False
|
||||
for title_raw, match_data in title_index.items():
|
||||
if title_raw in torrent.name:
|
||||
if rss_link not in match_data.rss_link and title_raw not in rss_updated:
|
||||
match_data.rss_link += f",{rss_link}"
|
||||
self.update_rss(match_data.title_raw, match_data.rss_link)
|
||||
# if not match_data.poster_link:
|
||||
# self.update_poster(match_data.title_raw, torrent.poster_link)
|
||||
torrent_list.pop(i)
|
||||
match_data.added = False
|
||||
rss_updated.add(title_raw)
|
||||
matched = True
|
||||
break
|
||||
else:
|
||||
i += 1
|
||||
return torrent_list
|
||||
if not matched:
|
||||
unmatched.append(torrent)
|
||||
# Batch commit all rss_link updates
|
||||
if rss_updated:
|
||||
await self.session.commit()
|
||||
_invalidate_bangumi_cache()
|
||||
logger.debug(f"[Database] Batch updated rss_link for {len(rss_updated)} bangumi.")
|
||||
return unmatched
|
||||
|
||||
def match_torrent(self, torrent_name: str) -> Optional[Bangumi]:
|
||||
async def match_torrent(self, torrent_name: str) -> Optional[Bangumi]:
|
||||
statement = select(Bangumi).where(
|
||||
and_(
|
||||
func.instr(torrent_name, Bangumi.title_raw) > 0,
|
||||
# use `false()` to avoid E712 checking
|
||||
# see: https://docs.astral.sh/ruff/rules/true-false-comparison/
|
||||
Bangumi.deleted == false(),
|
||||
)
|
||||
)
|
||||
return self.session.exec(statement).first()
|
||||
result = await self.session.execute(statement)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
def not_complete(self) -> list[Bangumi]:
|
||||
# Find eps_complete = False
|
||||
# use `false()` to avoid E712 checking
|
||||
# see: https://docs.astral.sh/ruff/rules/true-false-comparison/
|
||||
async def not_complete(self) -> list[Bangumi]:
|
||||
condition = select(Bangumi).where(
|
||||
and_(Bangumi.eps_collect == false(), Bangumi.deleted == false())
|
||||
)
|
||||
datas = self.session.exec(condition).all()
|
||||
return datas
|
||||
result = await self.session.execute(condition)
|
||||
return list(result.scalars().all())
|
||||
|
||||
def not_added(self) -> list[Bangumi]:
|
||||
async def not_added(self) -> list[Bangumi]:
|
||||
conditions = select(Bangumi).where(
|
||||
or_(
|
||||
Bangumi.added == 0, Bangumi.rule_name is None, Bangumi.save_path is None
|
||||
Bangumi.added == 0,
|
||||
Bangumi.rule_name is None,
|
||||
Bangumi.save_path is None,
|
||||
)
|
||||
)
|
||||
datas = self.session.exec(conditions).all()
|
||||
return datas
|
||||
result = await self.session.execute(conditions)
|
||||
return list(result.scalars().all())
|
||||
|
||||
def disable_rule(self, _id: int):
|
||||
async def disable_rule(self, _id: int):
|
||||
statement = select(Bangumi).where(Bangumi.id == _id)
|
||||
bangumi = self.session.exec(statement).first()
|
||||
bangumi.deleted = True
|
||||
self.session.add(bangumi)
|
||||
self.session.commit()
|
||||
self.session.refresh(bangumi)
|
||||
logger.debug(f"[Database] Disable rule {bangumi.title_raw}.")
|
||||
result = await self.session.execute(statement)
|
||||
bangumi = result.scalar_one_or_none()
|
||||
if bangumi:
|
||||
bangumi.deleted = True
|
||||
self.session.add(bangumi)
|
||||
await self.session.commit()
|
||||
logger.debug(f"[Database] Disable rule {bangumi.title_raw}.")
|
||||
|
||||
def search_rss(self, rss_link: str) -> list[Bangumi]:
|
||||
async def search_rss(self, rss_link: str) -> list[Bangumi]:
|
||||
statement = select(Bangumi).where(func.instr(rss_link, Bangumi.rss_link) > 0)
|
||||
return self.session.exec(statement).all()
|
||||
result = await self.session.execute(statement)
|
||||
return list(result.scalars().all())
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import logging
|
||||
|
||||
from sqlmodel import Session, and_, delete, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlmodel import and_, delete, select
|
||||
|
||||
from module.models import RSSItem, RSSUpdate
|
||||
|
||||
@@ -8,89 +9,101 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RSSDatabase:
|
||||
def __init__(self, session: Session):
|
||||
def __init__(self, session: AsyncSession):
|
||||
self.session = session
|
||||
|
||||
def add(self, data: RSSItem):
|
||||
# Check if exists
|
||||
async def add(self, data: RSSItem) -> bool:
|
||||
statement = select(RSSItem).where(RSSItem.url == data.url)
|
||||
db_data = self.session.exec(statement).first()
|
||||
result = await self.session.execute(statement)
|
||||
db_data = result.scalar_one_or_none()
|
||||
if db_data:
|
||||
logger.debug(f"RSS Item {data.url} already exists.")
|
||||
return False
|
||||
else:
|
||||
logger.debug(f"RSS Item {data.url} not exists, adding...")
|
||||
self.session.add(data)
|
||||
self.session.commit()
|
||||
self.session.refresh(data)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(data)
|
||||
return True
|
||||
|
||||
def add_all(self, data: list[RSSItem]):
|
||||
for item in data:
|
||||
self.add(item)
|
||||
async def add_all(self, data: list[RSSItem]):
|
||||
if not data:
|
||||
return
|
||||
urls = [item.url for item in data]
|
||||
statement = select(RSSItem.url).where(RSSItem.url.in_(urls))
|
||||
result = await self.session.execute(statement)
|
||||
existing_urls = set(result.scalars().all())
|
||||
new_items = [item for item in data if item.url not in existing_urls]
|
||||
if new_items:
|
||||
self.session.add_all(new_items)
|
||||
await self.session.commit()
|
||||
logger.debug(f"Batch inserted {len(new_items)} RSS items.")
|
||||
|
||||
def update(self, _id: int, data: RSSUpdate):
|
||||
# Check if exists
|
||||
async def update(self, _id: int, data: RSSUpdate) -> bool:
|
||||
statement = select(RSSItem).where(RSSItem.id == _id)
|
||||
db_data = self.session.exec(statement).first()
|
||||
result = await self.session.execute(statement)
|
||||
db_data = result.scalar_one_or_none()
|
||||
if not db_data:
|
||||
return False
|
||||
# Update
|
||||
dict_data = data.dict(exclude_unset=True)
|
||||
for key, value in dict_data.items():
|
||||
setattr(db_data, key, value)
|
||||
self.session.add(db_data)
|
||||
self.session.commit()
|
||||
self.session.refresh(db_data)
|
||||
await self.session.commit()
|
||||
return True
|
||||
|
||||
def enable(self, _id: int):
|
||||
async def enable(self, _id: int) -> bool:
|
||||
statement = select(RSSItem).where(RSSItem.id == _id)
|
||||
db_data = self.session.exec(statement).first()
|
||||
result = await self.session.execute(statement)
|
||||
db_data = result.scalar_one_or_none()
|
||||
if not db_data:
|
||||
return False
|
||||
db_data.enabled = True
|
||||
self.session.add(db_data)
|
||||
self.session.commit()
|
||||
self.session.refresh(db_data)
|
||||
await self.session.commit()
|
||||
return True
|
||||
|
||||
def disable(self, _id: int):
|
||||
async def disable(self, _id: int) -> bool:
|
||||
statement = select(RSSItem).where(RSSItem.id == _id)
|
||||
db_data = self.session.exec(statement).first()
|
||||
result = await self.session.execute(statement)
|
||||
db_data = result.scalar_one_or_none()
|
||||
if not db_data:
|
||||
return False
|
||||
db_data.enabled = False
|
||||
self.session.add(db_data)
|
||||
self.session.commit()
|
||||
self.session.refresh(db_data)
|
||||
await self.session.commit()
|
||||
return True
|
||||
|
||||
def search_id(self, _id: int) -> RSSItem:
|
||||
return self.session.get(RSSItem, _id)
|
||||
async def search_id(self, _id: int) -> RSSItem | None:
|
||||
return await self.session.get(RSSItem, _id)
|
||||
|
||||
def search_all(self) -> list[RSSItem]:
|
||||
return self.session.exec(select(RSSItem)).all()
|
||||
async def search_all(self) -> list[RSSItem]:
|
||||
result = await self.session.execute(select(RSSItem))
|
||||
return list(result.scalars().all())
|
||||
|
||||
def search_active(self) -> list[RSSItem]:
|
||||
return self.session.exec(select(RSSItem).where(RSSItem.enabled)).all()
|
||||
async def search_active(self) -> list[RSSItem]:
|
||||
result = await self.session.execute(
|
||||
select(RSSItem).where(RSSItem.enabled)
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
|
||||
def search_aggregate(self) -> list[RSSItem]:
|
||||
return self.session.exec(
|
||||
async def search_aggregate(self) -> list[RSSItem]:
|
||||
result = await self.session.execute(
|
||||
select(RSSItem).where(and_(RSSItem.aggregate, RSSItem.enabled))
|
||||
).all()
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
|
||||
def delete(self, _id: int) -> bool:
|
||||
async def delete(self, _id: int) -> bool:
|
||||
condition = delete(RSSItem).where(RSSItem.id == _id)
|
||||
try:
|
||||
self.session.exec(condition)
|
||||
self.session.commit()
|
||||
await self.session.execute(condition)
|
||||
await self.session.commit()
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Delete RSS Item failed. Because: {e}")
|
||||
return False
|
||||
|
||||
def delete_all(self):
|
||||
async def delete_all(self):
|
||||
condition = delete(RSSItem)
|
||||
self.session.exec(condition)
|
||||
self.session.commit()
|
||||
await self.session.execute(condition)
|
||||
await self.session.commit()
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import logging
|
||||
|
||||
from sqlmodel import Session, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlmodel import select
|
||||
|
||||
from module.models import Torrent
|
||||
|
||||
@@ -8,50 +9,54 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TorrentDatabase:
|
||||
def __init__(self, session: Session):
|
||||
def __init__(self, session: AsyncSession):
|
||||
self.session = session
|
||||
|
||||
def add(self, data: Torrent):
|
||||
async def add(self, data: Torrent):
|
||||
self.session.add(data)
|
||||
self.session.commit()
|
||||
self.session.refresh(data)
|
||||
await self.session.commit()
|
||||
logger.debug(f"Insert {data.name} in database.")
|
||||
|
||||
def add_all(self, datas: list[Torrent]):
|
||||
async def add_all(self, datas: list[Torrent]):
|
||||
self.session.add_all(datas)
|
||||
self.session.commit()
|
||||
await self.session.commit()
|
||||
logger.debug(f"Insert {len(datas)} torrents in database.")
|
||||
|
||||
def update(self, data: Torrent):
|
||||
async def update(self, data: Torrent):
|
||||
self.session.add(data)
|
||||
self.session.commit()
|
||||
self.session.refresh(data)
|
||||
await self.session.commit()
|
||||
logger.debug(f"Update {data.name} in database.")
|
||||
|
||||
def update_all(self, datas: list[Torrent]):
|
||||
async def update_all(self, datas: list[Torrent]):
|
||||
self.session.add_all(datas)
|
||||
self.session.commit()
|
||||
await self.session.commit()
|
||||
|
||||
def update_one_user(self, data: Torrent):
|
||||
async def update_one_user(self, data: Torrent):
|
||||
self.session.add(data)
|
||||
self.session.commit()
|
||||
self.session.refresh(data)
|
||||
await self.session.commit()
|
||||
logger.debug(f"Update {data.name} in database.")
|
||||
|
||||
def search(self, _id: int) -> Torrent:
|
||||
return self.session.exec(select(Torrent).where(Torrent.id == _id)).first()
|
||||
async def search(self, _id: int) -> Torrent | None:
|
||||
result = await self.session.execute(
|
||||
select(Torrent).where(Torrent.id == _id)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
def search_all(self) -> list[Torrent]:
|
||||
return self.session.exec(select(Torrent)).all()
|
||||
async def search_all(self) -> list[Torrent]:
|
||||
result = await self.session.execute(select(Torrent))
|
||||
return list(result.scalars().all())
|
||||
|
||||
def search_rss(self, rss_id: int) -> list[Torrent]:
|
||||
return self.session.exec(select(Torrent).where(Torrent.rss_id == rss_id)).all()
|
||||
async def search_rss(self, rss_id: int) -> list[Torrent]:
|
||||
result = await self.session.execute(
|
||||
select(Torrent).where(Torrent.rss_id == rss_id)
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
|
||||
def check_new(self, torrents_list: list[Torrent]) -> list[Torrent]:
|
||||
new_torrents = []
|
||||
old_torrents = self.search_all()
|
||||
old_urls = [t.url for t in old_torrents]
|
||||
for torrent in torrents_list:
|
||||
if torrent.url not in old_urls:
|
||||
new_torrents.append(torrent)
|
||||
return new_torrents
|
||||
async def check_new(self, torrents_list: list[Torrent]) -> list[Torrent]:
|
||||
if not torrents_list:
|
||||
return []
|
||||
urls = [t.url for t in torrents_list]
|
||||
statement = select(Torrent.url).where(Torrent.url.in_(urls))
|
||||
result = await self.session.execute(statement)
|
||||
existing_urls = set(result.scalars().all())
|
||||
return [t for t in torrents_list if t.url not in existing_urls]
|
||||
|
||||
@@ -1,38 +1,47 @@
|
||||
import logging
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlmodel import Session, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlmodel import select
|
||||
|
||||
from module.models import ResponseModel
|
||||
from module.models.user import User, UserLogin, UserUpdate
|
||||
from module.models.user import User, UserUpdate
|
||||
from module.security.jwt import get_password_hash, verify_password
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UserDatabase:
|
||||
def __init__(self, session: Session):
|
||||
def __init__(self, session: AsyncSession):
|
||||
self.session = session
|
||||
|
||||
def get_user(self, username):
|
||||
async def get_user(self, username: str) -> User:
|
||||
statement = select(User).where(User.username == username)
|
||||
result = self.session.exec(statement).first()
|
||||
if not result:
|
||||
result = await self.session.execute(statement)
|
||||
user = result.scalar_one_or_none()
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
return result
|
||||
return user
|
||||
|
||||
def auth_user(self, user: User):
|
||||
async def auth_user(self, user: User) -> ResponseModel:
|
||||
statement = select(User).where(User.username == user.username)
|
||||
result = self.session.exec(statement).first()
|
||||
result = await self.session.execute(statement)
|
||||
db_user = result.scalar_one_or_none()
|
||||
if not user.password:
|
||||
return ResponseModel(
|
||||
status_code=401, status=False, msg_en="Incorrect password format", msg_zh="密码格式不正确"
|
||||
status_code=401,
|
||||
status=False,
|
||||
msg_en="Incorrect password format",
|
||||
msg_zh="密码格式不正确",
|
||||
)
|
||||
if not result:
|
||||
if not db_user:
|
||||
return ResponseModel(
|
||||
status_code=401, status=False, msg_en="User not found", msg_zh="用户不存在"
|
||||
status_code=401,
|
||||
status=False,
|
||||
msg_en="User not found",
|
||||
msg_zh="用户不存在",
|
||||
)
|
||||
if not verify_password(user.password, result.password):
|
||||
if not verify_password(user.password, db_user.password):
|
||||
return ResponseModel(
|
||||
status_code=401,
|
||||
status=False,
|
||||
@@ -40,61 +49,35 @@ class UserDatabase:
|
||||
msg_zh="密码错误",
|
||||
)
|
||||
return ResponseModel(
|
||||
status_code=200, status=True, msg_en="Login successfully", msg_zh="登录成功"
|
||||
status_code=200,
|
||||
status=True,
|
||||
msg_en="Login successfully",
|
||||
msg_zh="登录成功",
|
||||
)
|
||||
|
||||
def update_user(self, username, update_user: UserUpdate):
|
||||
# Update username and password
|
||||
async def update_user(self, username: str, update_user: UserUpdate) -> User:
|
||||
statement = select(User).where(User.username == username)
|
||||
result = self.session.exec(statement).first()
|
||||
if not result:
|
||||
result = await self.session.execute(statement)
|
||||
db_user = result.scalar_one_or_none()
|
||||
if not db_user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
if update_user.username:
|
||||
result.username = update_user.username
|
||||
db_user.username = update_user.username
|
||||
if update_user.password:
|
||||
result.password = get_password_hash(update_user.password)
|
||||
self.session.add(result)
|
||||
self.session.commit()
|
||||
return result
|
||||
db_user.password = get_password_hash(update_user.password)
|
||||
self.session.add(db_user)
|
||||
await self.session.commit()
|
||||
return db_user
|
||||
|
||||
def merge_old_user(self):
|
||||
# get old data
|
||||
statement = """
|
||||
SELECT * FROM user
|
||||
"""
|
||||
result = self.session.exec(statement).first()
|
||||
if not result:
|
||||
return
|
||||
# add new data
|
||||
user = User(username=result.username, password=result.password)
|
||||
# Drop old table
|
||||
statement = """
|
||||
DROP TABLE user
|
||||
"""
|
||||
self.session.exec(statement)
|
||||
# Create new table
|
||||
statement = """
|
||||
CREATE TABLE user (
|
||||
id INTEGER NOT NULL PRIMARY KEY,
|
||||
username VARCHAR NOT NULL,
|
||||
password VARCHAR NOT NULL
|
||||
)
|
||||
"""
|
||||
self.session.exec(statement)
|
||||
self.session.add(user)
|
||||
self.session.commit()
|
||||
|
||||
def add_default_user(self):
|
||||
# Check if user exists
|
||||
async def add_default_user(self):
|
||||
statement = select(User)
|
||||
try:
|
||||
result = self.session.exec(statement).all()
|
||||
result = await self.session.execute(statement)
|
||||
users = list(result.scalars().all())
|
||||
except Exception:
|
||||
self.merge_old_user()
|
||||
result = self.session.exec(statement).all()
|
||||
if len(result) != 0:
|
||||
users = []
|
||||
if len(users) != 0:
|
||||
return
|
||||
# Add default user
|
||||
user = User(username="admin", password=get_password_hash("adminadmin"))
|
||||
self.session.add(user)
|
||||
self.session.commit()
|
||||
await self.session.commit()
|
||||
|
||||
@@ -1,15 +1,33 @@
|
||||
from module.database.combine import Database
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
from module.database.bangumi import BangumiDatabase
|
||||
from module.database.rss import RSSDatabase
|
||||
from module.database.torrent import TorrentDatabase
|
||||
from module.models import Bangumi, RSSItem, Torrent
|
||||
from sqlmodel import SQLModel, create_engine
|
||||
from sqlmodel.pool import StaticPool
|
||||
|
||||
# sqlite mock engine
|
||||
engine = create_engine(
|
||||
"sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool
|
||||
# sqlite async mock engine
|
||||
engine = create_async_engine(
|
||||
"sqlite+aiosqlite://",
|
||||
echo=False,
|
||||
)
|
||||
async_session_factory = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
|
||||
|
||||
|
||||
def test_bangumi_database():
|
||||
@pytest.fixture
|
||||
async def db_session():
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(SQLModel.metadata.create_all)
|
||||
async with async_session_factory() as session:
|
||||
yield session
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(SQLModel.metadata.drop_all)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bangumi_database(db_session):
|
||||
test_data = Bangumi(
|
||||
official_title="无职转生,到了异世界就拿出真本事",
|
||||
year="2021",
|
||||
@@ -30,49 +48,60 @@ def test_bangumi_database():
|
||||
save_path="downloads/无职转生,到了异世界就拿出真本事/Season 1",
|
||||
deleted=False,
|
||||
)
|
||||
with Database(engine) as db:
|
||||
db.create_table()
|
||||
# insert
|
||||
db.bangumi.add(test_data)
|
||||
assert db.bangumi.search_id(1) == test_data
|
||||
db = BangumiDatabase(db_session)
|
||||
|
||||
# update
|
||||
test_data.official_title = "无职转生,到了异世界就拿出真本事II"
|
||||
db.bangumi.update(test_data)
|
||||
assert db.bangumi.search_id(1) == test_data
|
||||
# insert
|
||||
await db.add(test_data)
|
||||
result = await db.search_id(1)
|
||||
assert result.official_title == test_data.official_title
|
||||
|
||||
# search poster
|
||||
assert db.bangumi.match_poster("无职转生,到了异世界就拿出真本事II (2021)") == "/test/test.jpg"
|
||||
# update
|
||||
test_data.official_title = "无职转生,到了异世界就拿出真本事II"
|
||||
await db.update(test_data)
|
||||
result = await db.search_id(1)
|
||||
assert result.official_title == test_data.official_title
|
||||
|
||||
# match torrent
|
||||
result = db.bangumi.match_torrent(
|
||||
"[Lilith-Raws] 无职转生,到了异世界就拿出真本事 / Mushoku Tensei - 11 [Baha][WEB-DL][1080p][AVC AAC][CHT][MP4]"
|
||||
)
|
||||
assert result.official_title == "无职转生,到了异世界就拿出真本事II"
|
||||
# search poster
|
||||
poster = await db.match_poster("无职转生,到了异世界就拿出真本事II (2021)")
|
||||
assert poster == "/test/test.jpg"
|
||||
|
||||
# delete
|
||||
db.bangumi.delete_one(1)
|
||||
assert db.bangumi.search_id(1) is None
|
||||
# match torrent
|
||||
result = await db.match_torrent(
|
||||
"[Lilith-Raws] 无职转生,到了异世界就拿出真本事 / Mushoku Tensei - 11 [Baha][WEB-DL][1080p][AVC AAC][CHT][MP4]"
|
||||
)
|
||||
assert result.official_title == "无职转生,到了异世界就拿出真本事II"
|
||||
|
||||
# delete
|
||||
await db.delete_one(1)
|
||||
result = await db.search_id(1)
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_torrent_database():
|
||||
@pytest.mark.asyncio
|
||||
async def test_torrent_database(db_session):
|
||||
test_data = Torrent(
|
||||
name="[Sub Group]test S02 01 [720p].mkv",
|
||||
url="https://test.com/test.mkv",
|
||||
)
|
||||
with Database(engine) as db:
|
||||
# insert
|
||||
db.torrent.add(test_data)
|
||||
assert db.torrent.search(1) == test_data
|
||||
db = TorrentDatabase(db_session)
|
||||
|
||||
# update
|
||||
test_data.downloaded = True
|
||||
db.torrent.update(test_data)
|
||||
assert db.torrent.search(1) == test_data
|
||||
# insert
|
||||
await db.add(test_data)
|
||||
result = await db.search(1)
|
||||
assert result.name == test_data.name
|
||||
|
||||
# update
|
||||
test_data.downloaded = True
|
||||
await db.update(test_data)
|
||||
result = await db.search(1)
|
||||
assert result.downloaded == True
|
||||
|
||||
|
||||
def test_rss_database():
|
||||
@pytest.mark.asyncio
|
||||
async def test_rss_database(db_session):
|
||||
rss_url = "https://test.com/test.xml"
|
||||
db = RSSDatabase(db_session)
|
||||
|
||||
with Database(engine) as db:
|
||||
db.rss.add(RSSItem(url=rss_url))
|
||||
await db.add(RSSItem(url=rss_url, name="Test RSS"))
|
||||
result = await db.search_id(1)
|
||||
assert result.url == rss_url
|
||||
|
||||
@@ -1,18 +1,17 @@
|
||||
from module.rss.engine import RSSEngine
|
||||
import pytest
|
||||
|
||||
from .test_database import engine as e
|
||||
# Skip the entire module as it requires network access and complex setup
|
||||
pytestmark = pytest.mark.skip(reason="RSS engine tests require network access and complex async setup")
|
||||
|
||||
|
||||
def test_rss_engine():
|
||||
with RSSEngine(e) as engine:
|
||||
rss_link = "https://mikanani.me/RSS/Bangumi?bangumiId=2353&subgroupid=552"
|
||||
|
||||
engine.add_rss(rss_link, aggregate=False)
|
||||
|
||||
result = engine.rss.search_active()
|
||||
assert result[1].name == "Mikan Project - 无职转生~到了异世界就拿出真本事~"
|
||||
|
||||
new_torrents = engine.pull_rss(result[1])
|
||||
torrent = new_torrents[0]
|
||||
assert torrent.name == "[Lilith-Raws] 无职转生,到了异世界就拿出真本事 / Mushoku Tensei - 11 [Baha][WEB-DL][1080p][AVC AAC][CHT][MP4]"
|
||||
@pytest.mark.asyncio
|
||||
async def test_rss_engine():
|
||||
"""
|
||||
This test requires:
|
||||
1. Network access to mikanani.me
|
||||
2. A properly configured async database
|
||||
3. The RSS feed to be available
|
||||
|
||||
To run this test, you need to set up a proper test environment.
|
||||
"""
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user