feat(auth): enhance auxiliary authentication

This commit is contained in:
InfinityPacer
2024-10-19 03:16:04 +08:00
parent 386ff672a7
commit d8e7c7e6d7
9 changed files with 294 additions and 167 deletions

View File

@@ -1,21 +1,15 @@
import secrets
from datetime import timedelta
from typing import Any, List
from fastapi import APIRouter, Depends, Form, HTTPException
from fastapi.security import OAuth2PasswordRequestForm
from sqlalchemy.orm import Session
from app import schemas
from app.chain.tmdb import TmdbChain
from app.chain.user import UserChain
from app.core import security
from app.core.config import settings
from app.core.security import get_password_hash
from app.db import get_db
from app.db.models.user import User
from app.helper.sites import SitesHelper
from app.log import logger
from app.utils.web import WebUtils
router = APIRouter()
@@ -23,60 +17,32 @@ router = APIRouter()
@router.post("/access-token", summary="获取token", response_model=schemas.Token)
async def login_access_token(
db: Session = Depends(get_db),
form_data: OAuth2PasswordRequestForm = Depends(),
otp_password: str = Form(None)
) -> Any:
"""
获取认证Token
"""
# 检查数据库
success, user = User.authenticate(
db=db,
name=form_data.username,
password=form_data.password,
otp_password=otp_password
)
success, user_or_message = UserChain().user_authenticate(username=form_data.username,
password=form_data.password,
mfa_code=otp_password)
if not success:
# 认证不成功
if not user:
if not settings.AUXILIARY_AUTH_ENABLE:
logger.warn(f"用户 {form_data.username} 登录失败!")
raise HTTPException(status_code=401, detail="用户名、密码或二次校验码不正确")
else:
# 如果找不到用户并开启了辅助认证
logger.warn(f"登录用户 {form_data.username} 本地不存在,尝试辅助认证 ...")
success = UserChain().user_authenticate(form_data.username, form_data.password)
if not success:
logger.warn(f"用户 {form_data.username} 登录失败!")
raise HTTPException(status_code=401, detail="用户名、密码、二次校验码不正确")
else:
logger.info(f"用户 {form_data.username} 辅助认证成功,以普通用户登录...")
# 加入用户信息表
logger.info(f"创建用户: {form_data.username}")
user = User(name=form_data.username, is_active=True,
is_superuser=False, hashed_password=get_password_hash(secrets.token_urlsafe(16)))
user.create(db)
else:
# 用户存在,但认证失败
logger.warn(f"用户 {user.name} 登录失败!")
raise HTTPException(status_code=401, detail="用户名、密码或二次校验码不正确")
elif user and not user.is_active:
raise HTTPException(status_code=403, detail="用户未启用")
logger.info(f"用户 {user.name} 登录成功!")
raise HTTPException(status_code=401, detail=user_or_message)
level = SitesHelper().auth_level
return schemas.Token(
access_token=security.create_access_token(
userid=user.id,
username=user.name,
super_user=user.is_superuser,
userid=user_or_message.id,
username=user_or_message.name,
super_user=user_or_message.is_superuser,
expires_delta=timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES),
level=level
),
token_type="bearer",
super_user=user.is_superuser,
user_name=user.name,
avatar=user.avatar,
super_user=user_or_message.is_superuser,
user_name=user_or_message.name,
avatar=user_or_message.avatar,
level=level
)

View File

@@ -1,80 +1,224 @@
import secrets
from typing import Optional, Tuple, Union
from app.chain import ChainBase
from app.core.config import settings
from app.core.security import get_password_hash, verify_password
from app.db.models.user import User
from app.db.user_oper import UserOper
from app.log import logger
from app.schemas.event import AuthPassedInterceptData, AuthVerificationData
from app.schemas.event import AuthCredentials, AuthInterceptCredentials
from app.schemas.types import ChainEventType
from app.utils.otp import OtpUtils
from app.utils.singleton import Singleton
PASSWORD_INVALID_CREDENTIALS_MESSAGE = "用户名或密码或二次校验码不正确"
class UserChain(ChainBase):
class UserChain(ChainBase, metaclass=Singleton):
"""
用户链
用户链,处理多种认证协议
"""
def user_authenticate(self, name: str, password: str) -> bool:
"""
辅助完成用户认证。
def __init__(self):
super().__init__()
self.user_oper = UserOper()
:param name: 用户名
:param password: 密码
:return: 认证成功时返回 True否则返回 False
def user_authenticate(
self,
username: Optional[str] = None,
password: Optional[str] = None,
mfa_code: Optional[str] = None,
code: Optional[str] = None,
grant_type: str = "password"
) -> Union[Tuple[bool, Optional[str]], Tuple[bool, Optional[User]]]:
"""
logger.debug(f"开始对用户 {name} 通过系统预置渠道进行辅助认证")
auth_data = AuthVerificationData(name=name, password=password)
# 尝试通过默认的认证模块认证
try:
result = self.run_module("user_authenticate", auth_data=auth_data)
if result:
return self._process_auth_success(name, result)
except Exception as e:
logger.error(f"认证模块运行出错:{e}")
return False
认证用户,根据不同的 grant_type 处理不同的认证流程
# 如果预置的认证未通过,则触发 AuthVerification 事件
logger.debug(f"用户 {name} 未通过系统预置渠道认证,触发认证事件")
event = self.eventmanager.send_event(
etype=ChainEventType.AuthVerification,
data=auth_data
:param username: 用户名,适用于 "password" grant_type
:param password: 用户密码,适用于 "password" grant_type
:param mfa_code: 一次性密码,适用于 "password" grant_type
:param code: 授权码,适用于 "authorization_code" grant_type
:param grant_type: 认证类型,如 "password", "authorization_code", "client_credentials"
:return:
- 对于成功的认证,返回 (True, User)
- 对于失败的认证,返回 (False, "错误信息")
"""
credentials = AuthCredentials(
username=username,
password=password,
mfa_code=mfa_code,
code=code,
grant_type=grant_type
)
if not event:
logger.debug(f"开始使用 {grant_type} 认证,对用户 {username} 进行身份校验")
if credentials.grant_type == "password":
# Password 认证
success, user_or_message = self.password_authenticate(credentials=credentials)
if success:
# 如果用户启用了二次验证码,则进一步验证
if not self._verify_mfa(user_or_message, credentials.mfa_code):
return False, PASSWORD_INVALID_CREDENTIALS_MESSAGE
logger.info(f"用户 {username} 通过密码认证成功")
return True, user_or_message
else:
# 用户不存在或密码错误,考虑辅助认证
if settings.AUXILIARY_AUTH_ENABLE:
# 检查是否因为用户被禁用
user = self.user_oper.get_by_name(name=username)
if user and not user.is_active:
logger.info(f"用户 {username} 已被禁用,跳过后续辅助认证")
return False, PASSWORD_INVALID_CREDENTIALS_MESSAGE
logger.warning("密码认证失败,尝试通过外部服务进行辅助认证 ...")
aux_success, aux_user_or_message = self.auxiliary_authenticate(credentials=credentials)
if aux_success:
# 辅助认证成功后再验证二次验证码
if not self._verify_mfa(aux_user_or_message, credentials.mfa_code):
return False, PASSWORD_INVALID_CREDENTIALS_MESSAGE
return True, aux_user_or_message
else:
return False, PASSWORD_INVALID_CREDENTIALS_MESSAGE
else:
logger.debug(f"辅助认证未启用,用户 {username} 认证失败")
return False, PASSWORD_INVALID_CREDENTIALS_MESSAGE
else:
# 处理其他认证类型的分支
if settings.AUXILIARY_AUTH_ENABLE:
aux_success, aux_user_or_message = self.auxiliary_authenticate(credentials=credentials)
if aux_success:
logger.info(f"用户 {username} 辅助认证成功")
return True, aux_user_or_message
else:
logger.warning(f"用户 {username} 辅助认证失败")
return False, "认证失败"
else:
logger.debug(f"辅助认证未启用,认证类型 {grant_type} 未实现")
return False, "未实现的认证类型"
def password_authenticate(self, credentials: AuthCredentials) -> Tuple[bool, Union[User, str]]:
"""
密码认证
:param credentials: 认证凭证,包含用户名、密码以及可选的 MFA 认证码
:return:
- 成功时返回 (True, User),其中 User 是认证通过的用户对象
- 失败时返回 (False, "错误信息")
"""
if not credentials or credentials.grant_type != "password":
logger.debug("密码认证失败,认证类型不匹配")
return False, PASSWORD_INVALID_CREDENTIALS_MESSAGE
user = self.user_oper.get_by_name(name=credentials.username)
if not user:
logger.debug(f"密码认证失败,用户 {credentials.username} 不存在")
return False, PASSWORD_INVALID_CREDENTIALS_MESSAGE
if not user.is_active:
logger.debug(f"密码认证失败,用户 {credentials.username} 已被禁用")
return False, PASSWORD_INVALID_CREDENTIALS_MESSAGE
if not verify_password(credentials.password, str(user.hashed_password)):
logger.debug(f"密码认证失败,用户 {credentials.username} 的密码验证不通过")
return False, PASSWORD_INVALID_CREDENTIALS_MESSAGE
return True, user
def auxiliary_authenticate(self, credentials: AuthCredentials) -> Tuple[bool, Union[User, str]]:
"""
辅助用户认证
:param credentials: 认证凭证,包含必要的认证信息
:return:
- 成功时返回 (True, User),其中 User 是认证通过的用户对象
- 失败时返回 (False, "错误信息")
"""
if not credentials:
return False, "认证凭证无效"
logger.debug(f"尝试通过系统模块进行辅助认证,用户: {credentials.username}")
result = self.run_module("user_authenticate", credentials=credentials)
if not result:
logger.debug(f"通过系统模块辅助认证失败,尝试触发 {ChainEventType.AuthVerification} 事件")
event = self.eventmanager.send_event(etype=ChainEventType.AuthVerification, data=credentials)
if not event or not event.event_data:
logger.error(f"{credentials.grant_type} 辅助认证失败,未返回有效数据")
return False, f"{credentials.grant_type} 辅助认证事件失败或无效"
credentials = event.event_data # 使用事件返回的认证数据
else:
logger.info(f"通过系统模块辅助认证成功,用户: {credentials.username}")
credentials = result # 使用模块认证返回的认证数据
# 处理认证成功的逻辑
success = self._process_auth_success(username=credentials.username, credentials=credentials)
if success:
logger.info(f"用户 {credentials.username} 辅助认证通过")
return True, self.user_oper.get_by_name(credentials.username)
else:
logger.warning(f"用户 {credentials.username} 辅助认证未通过")
return False, "用户名或密码或二次校验码不正确"
@staticmethod
def _verify_mfa(user: User, mfa_code: Optional[str]) -> bool:
"""
验证 MFA二次验证码
:param user: 用户对象
:param mfa_code: 二次验证码
:return: 如果验证成功返回 True否则返回 False
"""
if not user.is_otp:
return True
if not mfa_code:
logger.debug(f"用户 {user.name} 缺少 MFA 认证码")
return False
if event and event.event_data:
try:
return self._process_auth_success(name, event.event_data)
except Exception as e:
logger.error(f"AuthVerificationData 数据验证失败:{e}")
if not OtpUtils.check(str(user.otp_secret), mfa_code):
logger.debug(f"用户 {user.name} 的 MFA 认证失败")
return False
return True
def _process_auth_success(self, username: str, credentials: AuthCredentials) -> bool:
"""
处理辅助认证成功的逻辑,返回用户对象或创建新用户
:param username: 用户名
:param credentials: 认证凭证,包含 token、channel、service 等信息
:return:
- 如果认证成功并且用户存在或已创建,返回 User 对象
- 如果认证被拦截或失败,返回 None
"""
token, channel, service = credentials.token, credentials.channel, credentials.service
if not all([token, channel, service]):
logger.debug(f"用户 {username} 未通过 {credentials.grant_type} 认证,必要信息不足")
return False
anonymized_token = f"{token[:len(token) // 2]}********"
logger.info(
f"认证类型:{credentials.grant_type},用户:{username},渠道:{channel}"
f"服务:{service} 认证成功token{anonymized_token}")
# 触发认证通过的拦截事件
intercept_event = self.eventmanager.send_event(
etype=ChainEventType.AuthPassedIntercept,
data=AuthInterceptCredentials(username=username, channel=channel, service=service, token=token)
)
if intercept_event and intercept_event.event_data:
intercept_data: AuthInterceptCredentials = intercept_event.event_data
if intercept_data.cancel:
logger.warning(
f"认证被拦截,用户:{username},渠道:{channel},服务:{service},拦截源:{intercept_data.source}")
return False
# 认证失败
logger.warning(f"用户 {name} 辅助认证失败")
return False
def _process_auth_success(self, name: str, data: AuthVerificationData) -> bool:
"""
处理认证成功后的逻辑,记录日志并处理拦截事件。
:param name: 用户名
:param data: 认证返回的数据,包含 token、channel 和 service
:return: 成功返回 True若被拦截返回 False
"""
token, channel, service = data.token, data.channel, data.service
if token and channel and service:
# 匿名化 token
anonymized_token = f"{token[:len(token) // 2]}****"
logger.info(f"用户 {name} 通过渠道 {channel},服务: {service} 认证成功token: {anonymized_token}")
# 触发认证通过的拦截事件
intercept_event = self.eventmanager.send_event(
etype=ChainEventType.AuthPassedIntercept,
data=AuthPassedInterceptData(name=name, channel=channel, service=service, token=token)
)
if intercept_event and intercept_event.event_data:
intercept_data: AuthPassedInterceptData = intercept_event.event_data
if intercept_data.cancel:
logger.info(
f"认证被拦截,用户: {name},渠道: {channel},服务: {service},拦截源: {intercept_data.source}")
return False
# 检查用户是否存在,如果不存在则创建新用户
user = self.user_oper.get_by_name(name=username)
if user:
return True
logger.warning(f"用户 {name} 未通过辅助认证")
return False
logger.info(f"用户 {username} 不存在,已通过 {credentials.grant_type} 认证并已创建普通用户")
self.user_oper.add(name=username, is_active=True, is_superuser=False,
hashed_password=get_password_hash(secrets.token_urlsafe(16)))
return True

View File

@@ -70,7 +70,7 @@ class ConfigModel(BaseModel):
CONFIG_DIR: Optional[str] = None
# 超级管理员
SUPERUSER: str = "admin"
# 辅助认证,允许通过外部服务(如媒体服务器/插件等)认证并创建用户
# 辅助认证,允许通过外部服务进行认证、单点登录以及自动创建用户
AUXILIARY_AUTH_ENABLE: bool = False
# API密钥需要更换
API_TOKEN: Optional[str] = None

View File

@@ -1,11 +1,7 @@
from typing import Tuple, Any
from sqlalchemy import Boolean, Column, Integer, String, Sequence, JSON
from sqlalchemy import Boolean, Column, Integer, JSON, Sequence, String
from sqlalchemy.orm import Session
from app.core.security import verify_password
from app.db import db_query, db_update, Base
from app.utils.otp import OtpUtils
from app.db import Base, db_query, db_update
class User(Base):
@@ -35,20 +31,6 @@ class User(Base):
# 用户个性化设置 json
settings = Column(JSON, default=dict)
@staticmethod
@db_query
def authenticate(db: Session, name: str, password: str,
otp_password: str) -> Tuple[bool, Any]:
user = db.query(User).filter(User.name == name).first()
if not user:
return False, None
if not verify_password(password, str(user.hashed_password)):
return False, user
if user.is_otp:
if not otp_password or not OtpUtils.check(str(user.otp_secret), otp_password):
return False, user
return True, user
@staticmethod
@db_query
def get_by_name(db: Session, name: str):

View File

@@ -5,8 +5,7 @@ from sqlalchemy.orm import Session
from app import schemas
from app.core.security import verify_token
from app.db import DbOper
from app.db import get_db
from app.db import DbOper, get_db
from app.db.models.user import User
@@ -52,6 +51,19 @@ class UserOper(DbOper):
用户管理
"""
def add(self, **kwargs):
"""
新增用户
"""
user = User(**kwargs)
user.create(self._db)
def get_by_name(self, name: str) -> User:
"""
根据用户名获取用户
"""
return User.get_by_name(self._db, name)
def get_permissions(self, name: str) -> dict:
"""
获取用户权限

View File

@@ -5,7 +5,7 @@ from app.core.context import MediaInfo
from app.log import logger
from app.modules import _MediaServerBase, _ModuleBase
from app.modules.emby.emby import Emby
from app.schemas.event import AuthVerificationData
from app.schemas.event import AuthCredentials
from app.schemas.types import MediaType, ModuleType
@@ -58,22 +58,22 @@ class EmbyModule(_ModuleBase, _MediaServerBase[Emby]):
logger.info(f"Emby服务器 {name} 连接断开,尝试重连 ...")
server.reconnect()
def user_authenticate(self, auth_data: AuthVerificationData) -> Optional[AuthVerificationData]:
def user_authenticate(self, credentials: AuthCredentials) -> Optional[AuthCredentials]:
"""
使用Emby用户辅助完成用户认证
:param auth_data: 认证数据
:param credentials: 认证数据
:return: 认证数据
"""
# Emby认证
if not auth_data:
if not credentials or credentials.grant_type != "password":
return None
for name, server in self.get_instances().items():
token = server.authenticate(auth_data.name, auth_data.password)
token = server.authenticate(credentials.username, credentials.password)
if token:
auth_data.channel = self.get_name()
auth_data.service = name
auth_data.token = token
return auth_data
credentials.channel = self.get_name()
credentials.service = name
credentials.token = token
return credentials
return None
def webhook_parser(self, body: Any, form: Any, args: Any) -> Optional[schemas.WebhookEventInfo]:

View File

@@ -5,7 +5,7 @@ from app.core.context import MediaInfo
from app.log import logger
from app.modules import _MediaServerBase, _ModuleBase
from app.modules.jellyfin.jellyfin import Jellyfin
from app.schemas.event import AuthVerificationData
from app.schemas.event import AuthCredentials
from app.schemas.types import MediaType, ModuleType
@@ -58,22 +58,22 @@ class JellyfinModule(_ModuleBase, _MediaServerBase[Jellyfin]):
return False, f"无法连接Jellyfin服务器{name}"
return True, ""
def user_authenticate(self, auth_data: AuthVerificationData) -> Optional[AuthVerificationData]:
def user_authenticate(self, credentials: AuthCredentials) -> Optional[AuthCredentials]:
"""
使用Jellyfin用户辅助完成用户认证
:param auth_data: 认证数据
:param credentials: 认证数据
:return: 认证数据
"""
# Jellyfin认证
if not auth_data:
if not credentials or credentials.grant_type != "password":
return None
for name, server in self.get_instances().items():
token = server.authenticate(auth_data.name, auth_data.password)
token = server.authenticate(credentials.username, credentials.password)
if token:
auth_data.channel = self.get_name()
auth_data.service = name
auth_data.token = token
return auth_data
credentials.channel = self.get_name()
credentials.service = name
credentials.token = token
return credentials
return None
def webhook_parser(self, body: Any, form: Any, args: Any) -> Optional[schemas.WebhookEventInfo]:

View File

@@ -1,6 +1,6 @@
from typing import Optional
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, root_validator
class BaseEventData(BaseModel):
@@ -17,37 +17,60 @@ class ChainEventData(BaseEventData):
pass
class AuthVerificationData(ChainEventData):
class AuthCredentials(ChainEventData):
"""
AuthVerification 事件的数据模型
Attributes:
# 输入参数
name (str): 用户
password (str): 用户密码
# 输出参数
token (str): 认证令牌
channel (str): 认证渠道
service (str): 服务名称
username (Optional[str]): 用户名,适用于 "password" grant_type
password (Optional[str]): 用户密码,适用于 "password" grant_type
mfa_code (Optional[str]): 一次性密码,目前仅适用于 "password" 认证类型
code (Optional[str]): 授权码,适用于 "authorization_code" grant_type
grant_type (str): 认证类型,如 "password", "authorization_code", "client_credentials"
# scope (List[str]): 权限范围,如 ["read", "write"]
token (Optional[str]): 认证令牌
channel (Optional[str]): 认证渠道
service (Optional[str]): 服务名称
"""
# 输入参数
name: str = Field(..., description="用户名")
password: str = Field(..., description="用户密码")
username: Optional[str] = Field(None, description="用户名,适用于 'password' 认证类型")
password: Optional[str] = Field(None, description="用户密码,适用于 'password' 认证类型")
mfa_code: Optional[str] = Field(None, description="一次性密码,目前仅适用于 'password' 认证类型")
code: Optional[str] = Field(None, description="授权码,适用于 'authorization_code' 认证类型")
grant_type: str = Field(..., description="认证类型,如 'password', 'authorization_code', 'client_credentials'")
# scope: List[str] = Field(default_factory=list, description="权限范围,如 ['read', 'write']")
# 输出参数
# grant_type 为 authorization_code 时,输出参数包括 username、token、channel、service
token: Optional[str] = Field(None, description="认证令牌")
channel: Optional[str] = Field(None, description="认证渠道")
service: Optional[str] = Field(None, description="服务名称")
@root_validator(pre=True)
def check_fields_based_on_grant_type(cls, values):
grant_type = values.get("grant_type")
if not grant_type:
values["grant_type"] = "password"
grant_type = "password"
class AuthPassedInterceptData(ChainEventData):
if grant_type == "password":
if not values.get("username") or not values.get("password"):
raise ValueError("username and password are required for grant_type 'password'")
elif grant_type == "authorization_code":
if not values.get("code"):
raise ValueError("code is required for grant_type 'authorization_code'")
return values
class AuthInterceptCredentials(ChainEventData):
"""
AuthPassedIntercept 事件的数据模型
AuthPassedIntercept 事件的数据模型
Attributes:
# 输入参数
name (str): 用户名
username (str): 用户名
channel (str): 认证渠道
service (str): 服务名称
token (str): 认证令牌
@@ -57,7 +80,7 @@ class AuthPassedInterceptData(ChainEventData):
cancel (bool): 是否取消认证,默认值为 False
"""
# 输入参数
name: str = Field(..., description="用户名")
username: str = Field(..., description="用户名")
channel: str = Field(..., description="认证渠道")
service: str = Field(..., description="服务名称")
token: Optional[str] = Field(None, description="认证令牌")

View File

@@ -17,7 +17,7 @@ DB_MAX_OVERFLOW=5
DB_TIMEOUT=60
# 【*】超级管理员,设置后一但重启将固化到数据库中,修改将无效(初始化超级管理员密码仅会生成一次,请在日志中查看并自行登录系统修改)
SUPERUSER=admin
# 辅助认证,允许通过媒体服务器认证并创建用户
# 辅助认证,允许通过外部服务进行认证、单点登录以及自动创建用户
AUXILIARY_AUTH_ENABLE=false
# 大内存模式,开启后会增加缓存数量,但会占用更多内存
BIG_MEMORY_MODE=false