Files
MoviePilot/app/db/user_oper.py

150 lines
3.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from typing import Optional, List
from fastapi import Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from app import schemas
from app.core.security import verify_token
from app.db import DbOper, get_db, get_async_db
from app.db.models.user import User
def get_current_user(
    db: Session = Depends(get_db),
    token_data: schemas.TokenPayload = Depends(verify_token),
) -> User:
    """
    Load the user referenced by the verified token payload.

    :param db: database session injected through ``get_db``
    :param token_data: token payload injected through ``verify_token``;
        its ``sub`` field is used as the user record id
    :return: the matching ``User``
    :raises HTTPException: 403 when no user is found for that id
    """
    found = User.get(db, rid=token_data.sub)
    if found:
        return found
    raise HTTPException(status_code=403, detail="用户不存在")
async def get_current_user_async(
    db: AsyncSession = Depends(get_async_db),
    token_data: schemas.TokenPayload = Depends(verify_token),
) -> User:
    """
    Asynchronously load the user referenced by the verified token payload.

    :param db: async database session injected through ``get_async_db``
    :param token_data: token payload injected through ``verify_token``;
        its ``sub`` field is used as the user record id
    :return: the matching ``User``
    :raises HTTPException: 403 when no user is found for that id
    """
    found = await User.async_get(db, rid=token_data.sub)
    if found:
        return found
    raise HTTPException(status_code=403, detail="用户不存在")
def get_current_active_user(
    current_user: User = Depends(get_current_user),
) -> User:
    """
    Return the current user, rejecting accounts that are not active.

    :param current_user: user injected through ``get_current_user``
    :return: the same user object when its ``is_active`` flag is truthy
    :raises HTTPException: 403 when the account is not active
    """
    if current_user.is_active:
        return current_user
    raise HTTPException(status_code=403, detail="用户未激活")
async def get_current_active_user_async(
    current_user: User = Depends(get_current_user_async),
) -> User:
    """
    Async variant: return the current user, rejecting inactive accounts.

    :param current_user: user injected through ``get_current_user_async``
    :return: the same user object when its ``is_active`` flag is truthy
    :raises HTTPException: 403 when the account is not active
    """
    if current_user.is_active:
        return current_user
    raise HTTPException(status_code=403, detail="用户未激活")
def get_current_active_superuser(
    current_user: User = Depends(get_current_user),
) -> User:
    """
    Return the current user only if it is an active superuser.

    :param current_user: user injected through ``get_current_user``
    :return: the same user object, unchanged
    :raises HTTPException: 403 when the account is not active,
        400 when the user is not a superuser
    """
    # get_current_user only verifies that the user exists, so the "active"
    # part of this dependency's contract must be enforced here; otherwise a
    # deactivated superuser would still pass.
    if not current_user.is_active:
        raise HTTPException(status_code=403, detail="用户未激活")
    if not current_user.is_superuser:
        # Status code 400 is kept as-is: existing clients may rely on it.
        raise HTTPException(
            status_code=400, detail="用户权限不足"
        )
    return current_user
async def get_current_active_superuser_async(
    current_user: User = Depends(get_current_user_async),
) -> User:
    """
    Async variant: return the current user only if it is an active superuser.

    :param current_user: user injected through ``get_current_user_async``
    :return: the same user object, unchanged
    :raises HTTPException: 403 when the account is not active,
        400 when the user is not a superuser
    """
    # get_current_user_async only verifies that the user exists, so the
    # "active" part of this dependency's contract must be enforced here;
    # otherwise a deactivated superuser would still pass.
    if not current_user.is_active:
        raise HTTPException(status_code=403, detail="用户未激活")
    if not current_user.is_superuser:
        # Status code 400 is kept as-is: existing clients may rely on it.
        raise HTTPException(
            status_code=400, detail="用户权限不足"
        )
    return current_user
class UserOper(DbOper):
    """
    User management.

    Every method delegates to the ``User`` model using ``self._db``
    (presumably the session set up by ``DbOper`` -- confirm in the base class).
    """

    def list(self) -> List[User]:
        """
        Get the list of users.
        """
        return User.list(self._db)

    def add(self, **kwargs):
        """
        Add a user.

        :param kwargs: forwarded unchanged to the ``User`` constructor
        """
        user = User(**kwargs)
        user.create(self._db)

    def get_by_name(self, name: str) -> Optional[User]:
        """
        Get a user by user name.

        :param name: user name to look up
        :return: whatever ``User.get_by_name`` yields; the sibling methods
            treat a falsy result as "user does not exist"
        """
        return User.get_by_name(self._db, name)

    def get_permissions(self, name: str) -> dict:
        """
        Get a user's permissions.

        :param name: user name to look up
        :return: the user's ``permissions``, or an empty dict when the user
            does not exist or has no permissions set
        """
        user = User.get_by_name(self._db, name)
        if user:
            return user.permissions or {}
        return {}

    def get_settings(self, name: str) -> Optional[dict]:
        """
        Get a user's personal settings.

        :param name: user name to look up
        :return: the user's ``settings`` (an empty dict when unset), or
            ``None`` when the user does not exist
        """
        user = User.get_by_name(self._db, name)
        if user:
            return user.settings or {}
        return None

    def get_setting(self, name: str, key: str) -> Optional[str]:
        """
        Get a single personal setting of a user.

        :param name: user name to look up
        :param key: settings key to read
        :return: the stored value, or ``None`` when the user does not exist,
            has no settings, or the key is absent
        """
        settings = self.get_settings(name)
        if settings:
            return settings.get(key)
        return None

    def get_name(self, **kwargs) -> Optional[str]:
        """
        Get a user name from a bound account.

        Each keyword names a settings key and carries the bound account value
        to search for. The first user whose settings match ANY of the given
        pairs is returned.

        :param kwargs: ``settings_key=value`` pairs to match
        :return: the matching user's name, or ``None`` when nothing matches
        """
        if not kwargs:
            # Nothing to match against; skip loading the user list.
            return None
        # Normalise the wanted values once instead of once per user.
        wanted = {key: str(value) for key, value in kwargs.items()}
        for user in self.list():
            user_setting = user.settings
            if not user_setting:
                continue
            for key, expected in wanted.items():
                stored = user_setting.get(key)
                # Compare both sides as text so a binding stored as a
                # non-string (e.g. a numeric id) still matches; a missing
                # key never matches.
                if stored is not None and str(stored) == expected:
                    return user.name
        return None