feat(chain): add auth event to ChainEventType

This commit is contained in:
InfinityPacer
2024-10-18 20:03:05 +08:00
parent 5ef4fc04d5
commit 1a8e1844b4
8 changed files with 195 additions and 43 deletions

View File

@@ -1,7 +1,8 @@
import secrets
from datetime import timedelta
from typing import Any, List
from fastapi import APIRouter, Depends, HTTPException, Form
from fastapi import APIRouter, Depends, Form, HTTPException
from fastapi.security import OAuth2PasswordRequestForm
from sqlalchemy.orm import Session
@@ -45,16 +46,16 @@ async def login_access_token(
else:
# 如果找不到用户并开启了辅助认证
logger.warn(f"登录用户 {form_data.username} 本地不存在,尝试辅助认证 ...")
token = UserChain().user_authenticate(form_data.username, form_data.password)
if not token:
success = UserChain().user_authenticate(form_data.username, form_data.password)
if not success:
logger.warn(f"用户 {form_data.username} 登录失败!")
raise HTTPException(status_code=401, detail="用户名、密码、二次校验码不正确")
else:
logger.info(f"用户 {form_data.username} 辅助认证成功,用户信息: {token}以普通用户登录...")
logger.info(f"用户 {form_data.username} 辅助认证成功,以普通用户登录...")
# 加入用户信息表
logger.info(f"创建用户: {form_data.username}")
user = User(name=form_data.username, is_active=True,
is_superuser=False, hashed_password=get_password_hash(token))
is_superuser=False, hashed_password=get_password_hash(secrets.token_urlsafe(16)))
user.create(db)
else:
# 用户存在,但认证失败

View File

@@ -1,15 +1,80 @@
from typing import Optional
from app.chain import ChainBase
from app.log import logger
from app.schemas.event import AuthPassedInterceptData, AuthVerificationData
from app.schemas.types import ChainEventType
class UserChain(ChainBase):
"""
用户链
"""
def user_authenticate(self, name, password) -> Optional[str]:
def user_authenticate(self, name: str, password: str) -> bool:
"""
辅助完成用户认证
辅助完成用户认证
:param name: 用户名
:param password: 密码
:return: token
:return: 认证成功时返回 True否则返回 False
"""
return self.run_module("user_authenticate", name=name, password=password)
logger.debug(f"开始对用户 {name} 通过系统预置渠道进行辅助认证")
auth_data = AuthVerificationData(name=name, password=password)
# 尝试通过默认的认证模块认证
try:
result = self.run_module("user_authenticate", auth_data=auth_data)
if result:
return self._process_auth_success(name, result)
except Exception as e:
logger.error(f"认证模块运行出错:{e}")
return False
# 如果预置的认证未通过,则触发 AuthVerification 事件
logger.debug(f"用户 {name} 未通过系统预置渠道认证,触发认证事件")
event = self.eventmanager.send_event(
etype=ChainEventType.AuthVerification,
data=auth_data
)
if not event:
return False
if event and event.event_data:
try:
return self._process_auth_success(name, event.event_data)
except Exception as e:
logger.error(f"AuthVerificationData 数据验证失败:{e}")
return False
# 认证失败
logger.warning(f"用户 {name} 辅助认证失败")
return False
def _process_auth_success(self, name: str, data: AuthVerificationData) -> bool:
"""
处理认证成功后的逻辑,记录日志并处理拦截事件。
:param name: 用户名
:param data: 认证返回的数据,包含 token、channel 和 service
:return: 成功返回 True若被拦截返回 False
"""
token, channel, service = data.token, data.channel, data.service
if token and channel and service:
# 匿名化 token
anonymized_token = f"{token[:len(token) // 2]}****"
logger.info(f"用户 {name} 通过渠道 {channel},服务: {service} 认证成功token: {anonymized_token}")
# 触发认证通过的拦截事件
intercept_event = self.eventmanager.send_event(
etype=ChainEventType.AuthPassedIntercept,
data=AuthPassedInterceptData(name=name, channel=channel, service=service, token=token)
)
if intercept_event and intercept_event.event_data:
intercept_data: AuthPassedInterceptData = intercept_event.event_data
if intercept_data.cancel:
logger.info(
f"认证被拦截,用户: {name},渠道: {channel},服务: {service},拦截源: {intercept_data.source}")
return False
return True
logger.warning(f"用户 {name} 未通过辅助认证")
return False

View File

@@ -70,7 +70,7 @@ class ConfigModel(BaseModel):
CONFIG_DIR: Optional[str] = None
# 超级管理员
SUPERUSER: str = "admin"
# 辅助认证,允许通过媒体服务器认证并创建用户
# 辅助认证,允许通过外部服务(如媒体服务器/插件等)认证并创建用户
AUXILIARY_AUTH_ENABLE: bool = False
# API密钥需要更换
API_TOKEN: Optional[str] = None

View File

@@ -13,6 +13,7 @@ from typing import Callable, Dict, List, Optional, Union
from app.helper.message import MessageHelper
from app.helper.thread import ThreadHelper
from app.log import logger
from app.schemas.event import ChainEventData
from app.schemas.types import ChainEventType, EventType
from app.utils.limit import ExponentialBackoffRateLimiter
from app.utils.singleton import Singleton
@@ -28,7 +29,8 @@ class Event:
事件类,封装事件的基本信息
"""
def __init__(self, event_type: Union[EventType, ChainEventType], event_data: Optional[Dict] = None,
def __init__(self, event_type: Union[EventType, ChainEventType],
event_data: Optional[Union[Dict, ChainEventData]] = None,
priority: int = DEFAULT_EVENT_PRIORITY):
"""
:param event_type: 事件的类型,支持 EventType 或 ChainEventType
@@ -314,13 +316,13 @@ class EventManager(metaclass=Singleton):
return True
def __trigger_chain_event(self, event: Event) -> Event:
def __trigger_chain_event(self, event: Event) -> Optional[Event]:
"""
触发链式事件,按顺序调用订阅的处理器,并记录处理耗时
"""
logger.debug(f"Triggering synchronous chain event: {event}")
self.__dispatch_chain_event(event)
return event
dispatch = self.__dispatch_chain_event(event)
return event if dispatch else None
def __trigger_broadcast_event(self, event: Event):
"""
@@ -330,23 +332,25 @@ class EventManager(metaclass=Singleton):
logger.debug(f"Triggering broadcast event: {event}")
self.__event_queue.put((event.priority, event))
def __dispatch_chain_event(self, event: Event):
def __dispatch_chain_event(self, event: Event) -> bool:
"""
同步方式调度链式事件,按优先级顺序逐个调用事件处理器,并记录每个处理器的处理时间
:param event: 要调度的事件对象
"""
handlers = self.__chain_subscribers.get(event.event_type, {})
if not handlers:
return
self.__log_event_lifecycle(event, "started")
logger.debug(f"No handlers found for chain event: {event}")
return False
self.__log_event_lifecycle(event, "Started")
for handler_id, (priority, handler) in handlers.items():
start_time = time.time()
self.__safe_invoke_handler(handler, event)
logger.debug(
f"Handler {self.__get_handler_identifier(handler)} (Priority: {priority}) ,"
f" completed in {time.time() - start_time:.3f}s"
f"{self.__get_handler_identifier(handler)} (Priority: {priority}), "
f"completed in {time.time() - start_time:.3f}s for event: {event}"
)
self.__log_event_lifecycle(event, "completed")
self.__log_event_lifecycle(event, "Completed")
return True
def __dispatch_broadcast_event(self, event: Event):
"""
@@ -355,6 +359,7 @@ class EventManager(metaclass=Singleton):
"""
handlers = self.__broadcast_subscribers.get(event.event_type, {})
if not handlers:
logger.debug(f"No handlers found for broadcast event: {event}")
return
for handler_id, handler in handlers.items():
self.__executor.submit(self.__safe_invoke_handler, handler, event)

View File

@@ -1,10 +1,11 @@
from typing import Optional, Tuple, Union, Any, List, Generator
from typing import Any, Generator, List, Optional, Tuple, Union
from app import schemas
from app.core.context import MediaInfo
from app.log import logger
from app.modules import _ModuleBase, _MediaServerBase
from app.modules import _MediaServerBase, _ModuleBase
from app.modules.emby.emby import Emby
from app.schemas.event import AuthVerificationData
from app.schemas.types import MediaType, ModuleType
@@ -57,18 +58,22 @@ class EmbyModule(_ModuleBase, _MediaServerBase[Emby]):
logger.info(f"Emby服务器 {name} 连接断开,尝试重连 ...")
server.reconnect()
def user_authenticate(self, name: str, password: str) -> Optional[str]:
def user_authenticate(self, auth_data: AuthVerificationData) -> Optional[AuthVerificationData]:
"""
使用Emby用户辅助完成用户认证
:param name: 用户名
:param password: 密码
:return: token or None
:param auth_data: 认证数据
:return: 认证数据
"""
# Emby认证
for server in self.get_instances().values():
result = server.authenticate(name, password)
if result:
return result
if not auth_data:
return None
for name, server in self.get_instances().items():
token = server.authenticate(auth_data.name, auth_data.password)
if token:
auth_data.channel = self.get_name()
auth_data.service = name
auth_data.token = token
return auth_data
return None
def webhook_parser(self, body: Any, form: Any, args: Any) -> Optional[schemas.WebhookEventInfo]:

View File

@@ -1,10 +1,11 @@
from typing import Optional, Tuple, Union, Any, List, Generator
from typing import Any, Generator, List, Optional, Tuple, Union
from app import schemas
from app.core.context import MediaInfo
from app.log import logger
from app.modules import _ModuleBase, _MediaServerBase
from app.modules import _MediaServerBase, _ModuleBase
from app.modules.jellyfin.jellyfin import Jellyfin
from app.schemas.event import AuthVerificationData
from app.schemas.types import MediaType, ModuleType
@@ -57,18 +58,22 @@ class JellyfinModule(_ModuleBase, _MediaServerBase[Jellyfin]):
return False, f"无法连接Jellyfin服务器{name}"
return True, ""
def user_authenticate(self, name: str, password: str) -> Optional[str]:
def user_authenticate(self, auth_data: AuthVerificationData) -> Optional[AuthVerificationData]:
"""
使用Emby用户辅助完成用户认证
:param name: 用户名
:param password: 密码
:return: Token or None
使用Jellyfin用户辅助完成用户认证
:param auth_data: 认证数据
:return: 认证数据
"""
# Jellyfin认证
for server in self.get_instances().values():
result = server.authenticate(name, password)
if result:
return result
if not auth_data:
return None
for name, server in self.get_instances().items():
token = server.authenticate(auth_data.name, auth_data.password)
if token:
auth_data.channel = self.get_name()
auth_data.service = name
auth_data.token = token
return auth_data
return None
def webhook_parser(self, body: Any, form: Any, args: Any) -> Optional[schemas.WebhookEventInfo]:

67
app/schemas/event.py Normal file
View File

@@ -0,0 +1,67 @@
from typing import Optional
from pydantic import BaseModel, Field
class BaseEventData(BaseModel):
"""
事件数据的基类,所有具体事件数据类应继承自此类
"""
pass
class ChainEventData(BaseEventData):
"""
链式事件数据的基类,所有具体事件数据类应继承自此类
"""
pass
class AuthVerificationData(ChainEventData):
"""
AuthVerification 事件的数据模型
Attributes:
# 输入参数
name (str): 用户名
password (str): 用户密码
# 输出参数
token (str): 认证令牌
channel (str): 认证渠道
service (str): 服务名称
"""
# 输入参数
name: str = Field(..., description="用户名")
password: str = Field(..., description="用户密码")
# 输出参数
token: Optional[str] = Field(None, description="认证令牌")
channel: Optional[str] = Field(None, description="认证渠道")
service: Optional[str] = Field(None, description="服务名称")
class AuthPassedInterceptData(ChainEventData):
"""
AuthPassedIntercept 事件的数据模型。
Attributes:
# 输入参数
name (str): 用户名
channel (str): 认证渠道
service (str): 服务名称
token (str): 认证令牌
# 输出参数
source (str): 拦截源,默认值为 "未知拦截源"
cancel (bool): 是否取消认证,默认值为 False
"""
# 输入参数
name: str = Field(..., description="用户名")
channel: str = Field(..., description="认证渠道")
service: str = Field(..., description="服务名称")
token: Optional[str] = Field(None, description="认证令牌")
# 输出参数
source: str = Field("未知拦截源", description="拦截源")
cancel: bool = Field(False, description="是否取消认证")

View File

@@ -62,6 +62,10 @@ class EventType(Enum):
class ChainEventType(Enum):
# 名称识别请求
NameRecognize = "name.recognize"
# 认证验证请求
AuthVerification = "auth.verification"
# 认证通过拦截
AuthPassedIntercept = "auth.passed.intercept"
# 系统配置Key字典