From 1a8e1844b4c5bad4fb18824542daa48c3c856b06 Mon Sep 17 00:00:00 2001 From: InfinityPacer <160988576+InfinityPacer@users.noreply.github.com> Date: Fri, 18 Oct 2024 20:03:05 +0800 Subject: [PATCH] feat(chain): add auth event to ChainEventType --- app/api/endpoints/login.py | 11 ++--- app/chain/user.py | 77 +++++++++++++++++++++++++++++--- app/core/config.py | 2 +- app/core/event.py | 25 ++++++----- app/modules/emby/__init__.py | 25 ++++++----- app/modules/jellyfin/__init__.py | 27 ++++++----- app/schemas/event.py | 67 +++++++++++++++++++++++++++ app/schemas/types.py | 4 ++ 8 files changed, 195 insertions(+), 43 deletions(-) create mode 100644 app/schemas/event.py diff --git a/app/api/endpoints/login.py b/app/api/endpoints/login.py index 7d88f8fd..8df058a4 100644 --- a/app/api/endpoints/login.py +++ b/app/api/endpoints/login.py @@ -1,7 +1,8 @@ +import secrets from datetime import timedelta from typing import Any, List -from fastapi import APIRouter, Depends, HTTPException, Form +from fastapi import APIRouter, Depends, Form, HTTPException from fastapi.security import OAuth2PasswordRequestForm from sqlalchemy.orm import Session @@ -45,16 +46,16 @@ async def login_access_token( else: # 如果找不到用户并开启了辅助认证 logger.warn(f"登录用户 {form_data.username} 本地不存在,尝试辅助认证 ...") - token = UserChain().user_authenticate(form_data.username, form_data.password) - if not token: + success = UserChain().user_authenticate(form_data.username, form_data.password) + if not success: logger.warn(f"用户 {form_data.username} 登录失败!") raise HTTPException(status_code=401, detail="用户名、密码、二次校验码不正确") else: - logger.info(f"用户 {form_data.username} 辅助认证成功,用户信息: {token},以普通用户登录...") + logger.info(f"用户 {form_data.username} 辅助认证成功,以普通用户登录...") # 加入用户信息表 logger.info(f"创建用户: {form_data.username}") user = User(name=form_data.username, is_active=True, - is_superuser=False, hashed_password=get_password_hash(token)) + is_superuser=False, hashed_password=get_password_hash(secrets.token_urlsafe(16))) user.create(db) else: # 用户存在,但认证失败 diff --git a/app/chain/user.py b/app/chain/user.py index 6ba155f5..8c1c193b 100644 --- a/app/chain/user.py +++ b/app/chain/user.py @@ -1,15 +1,80 @@ -from typing import Optional - from app.chain import ChainBase +from app.log import logger +from app.schemas.event import AuthPassedInterceptData, AuthVerificationData +from app.schemas.types import ChainEventType class UserChain(ChainBase): + """ + 用户链 + """ - def user_authenticate(self, name, password) -> Optional[str]: + def user_authenticate(self, name: str, password: str) -> bool: """ - 辅助完成用户认证 + 辅助完成用户认证。 + :param name: 用户名 :param password: 密码 - :return: token + :return: 认证成功时返回 True,否则返回 False """ - return self.run_module("user_authenticate", name=name, password=password) + logger.debug(f"开始对用户 {name} 通过系统预置渠道进行辅助认证") + auth_data = AuthVerificationData(name=name, password=password) + # 尝试通过默认的认证模块认证 + try: + result = self.run_module("user_authenticate", auth_data=auth_data) + if result: + return self._process_auth_success(name, result) + except Exception as e: + logger.error(f"认证模块运行出错:{e}") + return False + + # 如果预置的认证未通过,则触发 AuthVerification 事件 + logger.debug(f"用户 {name} 未通过系统预置渠道认证,触发认证事件") + event = self.eventmanager.send_event( + etype=ChainEventType.AuthVerification, + data=auth_data + ) + if not event: + return False + if event and event.event_data: + try: + return self._process_auth_success(name, event.event_data) + except Exception as e: + logger.error(f"AuthVerificationData 数据验证失败:{e}") + return False + + # 认证失败 + logger.warning(f"用户 {name} 辅助认证失败") + return False + + def _process_auth_success(self, name: str, data: AuthVerificationData) -> bool: + """ + 处理认证成功后的逻辑,记录日志并处理拦截事件。 + + :param name: 用户名 + :param data: 认证返回的数据,包含 token、channel 和 service + :return: 成功返回 True,若被拦截返回 False + """ + token, channel, service = data.token, data.channel, data.service + if token and channel and service: + # 匿名化 token + anonymized_token = f"{token[:len(token) // 2]}****" + logger.info(f"用户 {name} 通过渠道 {channel},服务: {service} 认证成功,token: {anonymized_token}") + + # 触发认证通过的拦截事件 + intercept_event = self.eventmanager.send_event( + etype=ChainEventType.AuthPassedIntercept, + data=AuthPassedInterceptData(name=name, channel=channel, service=service, token=token) + ) + + if intercept_event and intercept_event.event_data: + intercept_data: AuthPassedInterceptData = intercept_event.event_data + if intercept_data.cancel: + logger.info( + f"认证被拦截,用户: {name},渠道: {channel},服务: {service},拦截源: {intercept_data.source}") + return False + + return True + + logger.warning(f"用户 {name} 未通过辅助认证") + return False diff --git a/app/core/config.py b/app/core/config.py index 3be5f160..dd014abe 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -70,7 +70,7 @@ class ConfigModel(BaseModel): CONFIG_DIR: Optional[str] = None # 超级管理员 SUPERUSER: str = "admin" - # 辅助认证,允许通过媒体服务器认证并创建用户 + # 辅助认证,允许通过外部服务(如媒体服务器/插件等)认证并创建用户 AUXILIARY_AUTH_ENABLE: bool = False # API密钥,需要更换 API_TOKEN: Optional[str] = None diff --git a/app/core/event.py b/app/core/event.py index f04d2c2a..69839070 100644 --- a/app/core/event.py +++ b/app/core/event.py @@ -13,6 +13,7 @@ from typing import Callable, Dict, List, Optional, Union from app.helper.message import MessageHelper from app.helper.thread import ThreadHelper from app.log import logger +from app.schemas.event import ChainEventData from app.schemas.types import ChainEventType, EventType from app.utils.limit import ExponentialBackoffRateLimiter from app.utils.singleton import Singleton @@ -28,7 +29,8 @@ class Event: 事件类,封装事件的基本信息 """ - def __init__(self, event_type: Union[EventType, ChainEventType], event_data: Optional[Dict] = None, + def __init__(self, event_type: Union[EventType, ChainEventType], + event_data: Optional[Union[Dict, ChainEventData]] = None, priority: int = DEFAULT_EVENT_PRIORITY): """ :param event_type: 事件的类型,支持 EventType 或 ChainEventType @@ -314,13 +316,13 @@ class EventManager(metaclass=Singleton): return True - def __trigger_chain_event(self, event: Event) -> Event: + def __trigger_chain_event(self, event: Event) -> Optional[Event]: """ 触发链式事件,按顺序调用订阅的处理器,并记录处理耗时 """ logger.debug(f"Triggering synchronous chain event: {event}") - self.__dispatch_chain_event(event) - return event + dispatch = self.__dispatch_chain_event(event) + return event if dispatch else None def __trigger_broadcast_event(self, event: Event): """ @@ -330,23 +332,25 @@ class EventManager(metaclass=Singleton): logger.debug(f"Triggering broadcast event: {event}") self.__event_queue.put((event.priority, event)) - def __dispatch_chain_event(self, event: Event): + def __dispatch_chain_event(self, event: Event) -> bool: """ 同步方式调度链式事件,按优先级顺序逐个调用事件处理器,并记录每个处理器的处理时间 :param event: 要调度的事件对象 """ handlers = self.__chain_subscribers.get(event.event_type, {}) if not handlers: - return - self.__log_event_lifecycle(event, "started") + logger.debug(f"No handlers found for chain event: {event}") + return False + self.__log_event_lifecycle(event, "Started") for handler_id, (priority, handler) in handlers.items(): start_time = time.time() self.__safe_invoke_handler(handler, event) logger.debug( - f"Handler {self.__get_handler_identifier(handler)} (Priority: {priority}) ," - f" completed in {time.time() - start_time:.3f}s" + f"{self.__get_handler_identifier(handler)} (Priority: {priority}), " + f"completed in {time.time() - start_time:.3f}s for event: {event}" ) - self.__log_event_lifecycle(event, "completed") + self.__log_event_lifecycle(event, "Completed") + return True def __dispatch_broadcast_event(self, event: Event): """ @@ -355,6 +359,7 @@ class EventManager(metaclass=Singleton): """ handlers = self.__broadcast_subscribers.get(event.event_type, {}) if not handlers: + logger.debug(f"No handlers found for broadcast event: {event}") return for handler_id, handler in handlers.items(): self.__executor.submit(self.__safe_invoke_handler, handler, event) diff --git a/app/modules/emby/__init__.py b/app/modules/emby/__init__.py index 98cc1487..a23d8289 100644 --- a/app/modules/emby/__init__.py +++ b/app/modules/emby/__init__.py @@ -1,10 +1,11 @@ -from typing import Optional, Tuple, Union, Any, List, Generator +from typing import Any, Generator, List, Optional, Tuple, Union from app import schemas from app.core.context import MediaInfo from app.log import logger -from app.modules import _ModuleBase, _MediaServerBase +from app.modules import _MediaServerBase, _ModuleBase from app.modules.emby.emby import Emby +from app.schemas.event import AuthVerificationData from app.schemas.types import MediaType, ModuleType @@ -57,18 +58,22 @@ class EmbyModule(_ModuleBase, _MediaServerBase[Emby]): logger.info(f"Emby服务器 {name} 连接断开,尝试重连 ...") server.reconnect() - def user_authenticate(self, name: str, password: str) -> Optional[str]: + def user_authenticate(self, auth_data: AuthVerificationData) -> Optional[AuthVerificationData]: """ 使用Emby用户辅助完成用户认证 - :param name: 用户名 - :param password: 密码 - :return: token or None + :param auth_data: 认证数据 + :return: 认证数据 """ # Emby认证 - for server in self.get_instances().values(): - result = server.authenticate(name, password) - if result: - return result + if not auth_data: + return None + for name, server in self.get_instances().items(): + token = server.authenticate(auth_data.name, auth_data.password) + if token: + auth_data.channel = self.get_name() + auth_data.service = name + auth_data.token = token + return auth_data return None def webhook_parser(self, body: Any, form: Any, args: Any) -> Optional[schemas.WebhookEventInfo]: diff --git a/app/modules/jellyfin/__init__.py b/app/modules/jellyfin/__init__.py index e59e45d7..a48a0c71 100644 --- a/app/modules/jellyfin/__init__.py +++ b/app/modules/jellyfin/__init__.py @@ -1,10 +1,11 @@ -from typing import Optional, Tuple, Union, Any, List, Generator +from typing import Any, Generator, List, Optional, Tuple, Union from app import schemas from app.core.context import MediaInfo from app.log import logger -from app.modules import _ModuleBase, _MediaServerBase +from app.modules import _MediaServerBase, _ModuleBase from app.modules.jellyfin.jellyfin import Jellyfin +from app.schemas.event import AuthVerificationData from app.schemas.types import MediaType, ModuleType @@ -57,18 +58,22 @@ class JellyfinModule(_ModuleBase, _MediaServerBase[Jellyfin]): return False, f"无法连接Jellyfin服务器:{name}" return True, "" - def user_authenticate(self, name: str, password: str) -> Optional[str]: + def user_authenticate(self, auth_data: AuthVerificationData) -> Optional[AuthVerificationData]: """ - 使用Emby用户辅助完成用户认证 - :param name: 用户名 - :param password: 密码 - :return: Token or None + 使用Jellyfin用户辅助完成用户认证 + :param auth_data: 认证数据 + :return: 认证数据 """ # Jellyfin认证 - for server in self.get_instances().values(): - result = server.authenticate(name, password) - if result: - return result + if not auth_data: + return None + for name, server in self.get_instances().items(): + token = server.authenticate(auth_data.name, auth_data.password) + if token: + auth_data.channel = self.get_name() + auth_data.service = name + auth_data.token = token + return auth_data return None def webhook_parser(self, body: Any, form: Any, args: Any) -> Optional[schemas.WebhookEventInfo]: diff --git a/app/schemas/event.py b/app/schemas/event.py new file mode 100644 index 00000000..48368905 --- /dev/null +++ b/app/schemas/event.py @@ -0,0 +1,67 @@ +from typing import Optional + +from pydantic import BaseModel, Field + + +class BaseEventData(BaseModel): + """ + 事件数据的基类,所有具体事件数据类应继承自此类 + """ + pass + + +class ChainEventData(BaseEventData): + """ + 链式事件数据的基类,所有具体事件数据类应继承自此类 + """ + pass + + +class AuthVerificationData(ChainEventData): + """ + AuthVerification 事件的数据模型 + + Attributes: + # 输入参数 + name (str): 用户名 + password (str): 用户密码 + + # 输出参数 + token (str): 认证令牌 + channel (str): 认证渠道 + service (str): 服务名称 + """ + # 输入参数 + name: str = Field(..., description="用户名") + password: str = Field(..., description="用户密码") + + # 输出参数 + token: Optional[str] = Field(None, description="认证令牌") + channel: Optional[str] = Field(None, description="认证渠道") + service: Optional[str] = Field(None, description="服务名称") + + +class AuthPassedInterceptData(ChainEventData): + """ + AuthPassedIntercept 事件的数据模型。 + + Attributes: + # 输入参数 + name (str): 用户名 + channel (str): 认证渠道 + service (str): 服务名称 + token (str): 认证令牌 + + # 输出参数 + source (str): 拦截源,默认值为 "未知拦截源" + cancel (bool): 是否取消认证,默认值为 False + """ + # 输入参数 + name: str = Field(..., description="用户名") + channel: str = Field(..., description="认证渠道") + service: str = Field(..., description="服务名称") + token: Optional[str] = Field(None, description="认证令牌") + + # 输出参数 + source: str = Field("未知拦截源", description="拦截源") + cancel: bool = Field(False, description="是否取消认证") diff --git a/app/schemas/types.py b/app/schemas/types.py index 5cb99091..5fe66c3a 100644 --- a/app/schemas/types.py +++ b/app/schemas/types.py @@ -62,6 +62,10 @@ class EventType(Enum): class ChainEventType(Enum): # 名称识别请求 NameRecognize = "name.recognize" + # 认证验证请求 + AuthVerification = "auth.verification" + # 认证通过拦截 + AuthPassedIntercept = "auth.passed.intercept" # 系统配置Key字典