diff --git a/app/api/endpoints/history.py b/app/api/endpoints/history.py index 3d5d0e8d..3cc59d3f 100644 --- a/app/api/endpoints/history.py +++ b/app/api/endpoints/history.py @@ -2,51 +2,52 @@ from typing import List, Any, Optional import jieba from fastapi import APIRouter, Depends +from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Session from app import schemas from app.chain.storage import StorageChain from app.core.event import eventmanager from app.core.security import verify_token -from app.db import get_db +from app.db import get_async_db, get_db from app.db.models import User from app.db.models.downloadhistory import DownloadHistory from app.db.models.transferhistory import TransferHistory -from app.db.user_oper import get_current_active_superuser +from app.db.user_oper import get_current_active_superuser_async, get_current_active_superuser from app.schemas.types import EventType, MediaType router = APIRouter() @router.get("/download", summary="查询下载历史记录", response_model=List[schemas.DownloadHistory]) -def download_history(page: Optional[int] = 1, - count: Optional[int] = 30, - db: Session = Depends(get_db), - _: schemas.TokenPayload = Depends(verify_token)) -> Any: +async def download_history(page: Optional[int] = 1, + count: Optional[int] = 30, + db: AsyncSession = Depends(get_async_db), + _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 查询下载历史记录 """ - return DownloadHistory.list_by_page(db, page, count) + return await DownloadHistory.async_list_by_page(db, page, count) @router.delete("/download", summary="删除下载历史记录", response_model=schemas.Response) -def delete_download_history(history_in: schemas.DownloadHistory, - db: Session = Depends(get_db), - _: schemas.TokenPayload = Depends(verify_token)) -> Any: +async def delete_download_history(history_in: schemas.DownloadHistory, + db: AsyncSession = Depends(get_async_db), + _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 删除下载历史记录 """ - DownloadHistory.delete(db, history_in.id) + await DownloadHistory.async_delete(db, history_in.id) return schemas.Response(success=True) @router.get("/transfer", summary="查询整理记录", response_model=schemas.Response) -def transfer_history(title: Optional[str] = None, - page: Optional[int] = 1, - count: Optional[int] = 30, - status: Optional[bool] = None, - db: Session = Depends(get_db), - _: schemas.TokenPayload = Depends(verify_token)) -> Any: +async def transfer_history(title: Optional[str] = None, + page: Optional[int] = 1, + count: Optional[int] = 30, + status: Optional[bool] = None, + db: AsyncSession = Depends(get_async_db), + _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 查询整理记录 """ @@ -60,12 +61,12 @@ def transfer_history(title: Optional[str] = None, if title: words = jieba.cut(title, HMM=False) title = "%".join(words) - total = TransferHistory.count_by_title(db, title=title, status=status) - result = TransferHistory.list_by_title(db, title=title, page=page, - count=count, status=status) + total = await TransferHistory.async_count_by_title(db, title=title, status=status) + result = await TransferHistory.async_list_by_title(db, title=title, page=page, + count=count, status=status) else: - result = TransferHistory.list_by_page(db, page=page, count=count, status=status) - total = TransferHistory.count(db, status=status) + result = await TransferHistory.async_list_by_page(db, page=page, count=count, status=status) + total = await TransferHistory.async_count(db, status=status) return schemas.Response(success=True, data={ @@ -111,10 +112,10 @@ def delete_transfer_history(history_in: schemas.TransferHistory, @router.get("/empty/transfer", summary="清空整理记录", response_model=schemas.Response) -def delete_transfer_history(db: Session = Depends(get_db), - _: User = Depends(get_current_active_superuser)) -> Any: +async def empty_transfer_history(db: AsyncSession = Depends(get_async_db), + _: User = Depends(get_current_active_superuser_async)) -> Any: """ 清空整理记录 """ - TransferHistory.truncate(db) + await TransferHistory.async_truncate(db) return schemas.Response(success=True) diff --git a/app/db/models/downloadhistory.py b/app/db/models/downloadhistory.py index 7bfe079a..07ec984b 100644 --- a/app/db/models/downloadhistory.py +++ b/app/db/models/downloadhistory.py @@ -1,10 +1,11 @@ import time from typing import Optional -from sqlalchemy import Column, Integer, String, Sequence, JSON +from sqlalchemy import Column, Integer, String, Sequence, JSON, select from sqlalchemy.orm import Session +from sqlalchemy.ext.asyncio import AsyncSession -from app.db import db_query, db_update, Base +from app.db import db_query, db_update, Base, async_db_query class DownloadHistory(Base): @@ -76,6 +77,14 @@ class DownloadHistory(Base): def list_by_page(db: Session, page: Optional[int] = 1, count: Optional[int] = 30): return db.query(DownloadHistory).offset((page - 1) * count).limit(count).all() + @classmethod + @async_db_query + async def async_list_by_page(cls, db: AsyncSession, page: Optional[int] = 1, count: Optional[int] = 30): + result = await db.execute( + select(cls).offset((page - 1) * count).limit(count) + ) + return result.scalars().all() + @staticmethod @db_query def get_by_path(db: Session, path: str): diff --git a/app/db/models/transferhistory.py b/app/db/models/transferhistory.py index a0b7661f..fecbdca8 100644 --- a/app/db/models/transferhistory.py +++ b/app/db/models/transferhistory.py @@ -1,10 +1,11 @@ import time from typing import Optional -from sqlalchemy import Column, Integer, String, Sequence, Boolean, func, or_, JSON +from sqlalchemy import Column, Integer, String, Sequence, Boolean, func, or_, JSON, select +from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Session -from app.db import db_query, db_update, Base +from app.db import db_query, db_update, Base, async_db_query class TransferHistory(Base): @@ -77,6 +78,30 @@ class TransferHistory(Base): TransferHistory.date.desc() ).offset((page - 1) * count).limit(count).all() + @classmethod + @async_db_query + async def async_list_by_title(cls, db: AsyncSession, title: str, page: Optional[int] = 1, count: Optional[int] = 30, + status: bool = None): + if status is not None: + result = await db.execute( + select(cls).filter( + cls.status == status + ).order_by( + cls.date.desc() + ).offset((page - 1) * count).limit(count) + ) + else: + result = await db.execute( + select(cls).filter(or_( + cls.title.like(f'%{title}%'), + cls.src.like(f'%{title}%'), + cls.dest.like(f'%{title}%'), + )).order_by( + cls.date.desc() + ).offset((page - 1) * count).limit(count) + ) + return result.scalars().all() + @staticmethod @db_query def list_by_page(db: Session, page: Optional[int] = 1, count: Optional[int] = 30, status: bool = None): @@ -91,6 +116,26 @@ class TransferHistory(Base): TransferHistory.date.desc() ).offset((page - 1) * count).limit(count).all() + @classmethod + @async_db_query + async def async_list_by_page(cls, db: AsyncSession, page: Optional[int] = 1, count: Optional[int] = 30, + status: bool = None): + if status is not None: + result = await db.execute( + select(cls).filter( + cls.status == status + ).order_by( + cls.date.desc() + ).offset((page - 1) * count).limit(count) + ) + else: + result = await db.execute( + select(cls).order_by( + cls.date.desc() + ).offset((page - 1) * count).limit(count) + ) + return result.scalars().all() + @staticmethod @db_query def get_by_hash(db: Session, download_hash: str): @@ -135,6 +180,19 @@ class TransferHistory(Base): else: return db.query(func.count(TransferHistory.id)).first()[0] + @classmethod + @async_db_query + async def async_count(cls, db: AsyncSession, status: bool = None): + if status is not None: + result = await db.execute( + select(func.count(cls.id)).filter(cls.status == status) + ) + else: + result = await db.execute( + select(func.count(cls.id)) + ) + return result.scalar() + @staticmethod @db_query def count_by_title(db: Session, title: str, status: bool = None): @@ -147,6 +205,23 @@ class TransferHistory(Base): TransferHistory.dest.like(f'%{title}%') )).first()[0] + @classmethod + @async_db_query + async def async_count_by_title(cls, db: AsyncSession, title: str, status: bool = None): + if status is not None: + result = await db.execute( + select(func.count(cls.id)).filter(cls.status == status) + ) + else: + result = await db.execute( + select(func.count(cls.id)).filter(or_( + cls.title.like(f'%{title}%'), + cls.src.like(f'%{title}%'), + cls.dest.like(f'%{title}%') + )) + ) + return result.scalar() + @staticmethod @db_query def list_by(db: Session, mtype: Optional[str] = None, title: Optional[str] = None, year: Optional[str] = None, diff --git a/app/db/user_oper.py b/app/db/user_oper.py index 705b6153..fc816b64 100644 --- a/app/db/user_oper.py +++ b/app/db/user_oper.py @@ -1,11 +1,12 @@ from typing import Optional, List from fastapi import Depends, HTTPException +from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Session from app import schemas from app.core.security import verify_token -from app.db import DbOper, get_db +from app.db import DbOper, get_db, get_async_db from app.db.models.user import User @@ -22,6 +23,19 @@ def get_current_user( return user +async def get_current_user_async( + db: AsyncSession = Depends(get_async_db), + token_data: schemas.TokenPayload = Depends(verify_token) +) -> User: + """ + 异步获取当前用户 + """ + user = await User.async_get(db, rid=token_data.sub) + if not user: + raise HTTPException(status_code=403, detail="用户不存在") + return user + + def get_current_active_user( current_user: User = Depends(get_current_user), ) -> User: @@ -33,6 +47,17 @@ def get_current_active_user( return current_user +async def get_current_active_user_async( + current_user: User = Depends(get_current_user_async), +) -> User: + """ + 异步获取当前激活用户 + """ + if not current_user.is_active: + raise HTTPException(status_code=403, detail="用户未激活") + return current_user + + def get_current_active_superuser( current_user: User = Depends(get_current_user), ) -> User: @@ -46,6 +71,19 @@ def get_current_active_superuser( return current_user +async def get_current_active_superuser_async( + current_user: User = Depends(get_current_user_async), +) -> User: + """ + 异步获取当前激活超级管理员 + """ + if not current_user.is_superuser: + raise HTTPException( + status_code=400, detail="用户权限不足" + ) + return current_user + + class UserOper(DbOper): """ 用户管理