mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-04-14 10:10:20 +08:00
feat(chain): add auth event to ChainEventType
This commit is contained in:
@@ -1,7 +1,8 @@
|
||||
import secrets
|
||||
from datetime import timedelta
|
||||
from typing import Any, List
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Form
|
||||
from fastapi import APIRouter, Depends, Form, HTTPException
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -45,16 +46,16 @@ async def login_access_token(
|
||||
else:
|
||||
# 如果找不到用户并开启了辅助认证
|
||||
logger.warn(f"登录用户 {form_data.username} 本地不存在,尝试辅助认证 ...")
|
||||
token = UserChain().user_authenticate(form_data.username, form_data.password)
|
||||
if not token:
|
||||
success = UserChain().user_authenticate(form_data.username, form_data.password)
|
||||
if not success:
|
||||
logger.warn(f"用户 {form_data.username} 登录失败!")
|
||||
raise HTTPException(status_code=401, detail="用户名、密码、二次校验码不正确")
|
||||
else:
|
||||
logger.info(f"用户 {form_data.username} 辅助认证成功,用户信息: {token},以普通用户登录...")
|
||||
logger.info(f"用户 {form_data.username} 辅助认证成功,以普通用户登录...")
|
||||
# 加入用户信息表
|
||||
logger.info(f"创建用户: {form_data.username}")
|
||||
user = User(name=form_data.username, is_active=True,
|
||||
is_superuser=False, hashed_password=get_password_hash(token))
|
||||
is_superuser=False, hashed_password=get_password_hash(secrets.token_urlsafe(16)))
|
||||
user.create(db)
|
||||
else:
|
||||
# 用户存在,但认证失败
|
||||
|
||||
@@ -1,15 +1,80 @@
|
||||
from typing import Optional
|
||||
|
||||
from app.chain import ChainBase
|
||||
from app.log import logger
|
||||
from app.schemas.event import AuthPassedInterceptData, AuthVerificationData
|
||||
from app.schemas.types import ChainEventType
|
||||
|
||||
|
||||
class UserChain(ChainBase):
|
||||
"""
|
||||
用户链
|
||||
"""
|
||||
|
||||
def user_authenticate(self, name, password) -> Optional[str]:
|
||||
def user_authenticate(self, name: str, password: str) -> bool:
|
||||
"""
|
||||
辅助完成用户认证
|
||||
辅助完成用户认证。
|
||||
|
||||
:param name: 用户名
|
||||
:param password: 密码
|
||||
:return: token
|
||||
:return: 认证成功时返回 True,否则返回 False
|
||||
"""
|
||||
return self.run_module("user_authenticate", name=name, password=password)
|
||||
logger.debug(f"开始对用户 {name} 通过系统预置渠道进行辅助认证")
|
||||
auth_data = AuthVerificationData(name=name, password=password)
|
||||
# 尝试通过默认的认证模块认证
|
||||
try:
|
||||
result = self.run_module("user_authenticate", auth_data=auth_data)
|
||||
if result:
|
||||
return self._process_auth_success(name, result)
|
||||
except Exception as e:
|
||||
logger.error(f"认证模块运行出错:{e}")
|
||||
return False
|
||||
|
||||
# 如果预置的认证未通过,则触发 AuthVerification 事件
|
||||
logger.debug(f"用户 {name} 未通过系统预置渠道认证,触发认证事件")
|
||||
event = self.eventmanager.send_event(
|
||||
etype=ChainEventType.AuthVerification,
|
||||
data=auth_data
|
||||
)
|
||||
if not event:
|
||||
return False
|
||||
if event and event.event_data:
|
||||
try:
|
||||
return self._process_auth_success(name, event.event_data)
|
||||
except Exception as e:
|
||||
logger.error(f"AuthVerificationData 数据验证失败:{e}")
|
||||
return False
|
||||
|
||||
# 认证失败
|
||||
logger.warning(f"用户 {name} 辅助认证失败")
|
||||
return False
|
||||
|
||||
def _process_auth_success(self, name: str, data: AuthVerificationData) -> bool:
|
||||
"""
|
||||
处理认证成功后的逻辑,记录日志并处理拦截事件。
|
||||
|
||||
:param name: 用户名
|
||||
:param data: 认证返回的数据,包含 token、channel 和 service
|
||||
:return: 成功返回 True,若被拦截返回 False
|
||||
"""
|
||||
token, channel, service = data.token, data.channel, data.service
|
||||
if token and channel and service:
|
||||
# 匿名化 token
|
||||
anonymized_token = f"{token[:len(token) // 2]}****"
|
||||
logger.info(f"用户 {name} 通过渠道 {channel},服务: {service} 认证成功,token: {anonymized_token}")
|
||||
|
||||
# 触发认证通过的拦截事件
|
||||
intercept_event = self.eventmanager.send_event(
|
||||
etype=ChainEventType.AuthPassedIntercept,
|
||||
data=AuthPassedInterceptData(name=name, channel=channel, service=service, token=token)
|
||||
)
|
||||
|
||||
if intercept_event and intercept_event.event_data:
|
||||
intercept_data: AuthPassedInterceptData = intercept_event.event_data
|
||||
if intercept_data.cancel:
|
||||
logger.info(
|
||||
f"认证被拦截,用户: {name},渠道: {channel},服务: {service},拦截源: {intercept_data.source}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
logger.warning(f"用户 {name} 未通过辅助认证")
|
||||
return False
|
||||
|
||||
@@ -70,7 +70,7 @@ class ConfigModel(BaseModel):
|
||||
CONFIG_DIR: Optional[str] = None
|
||||
# 超级管理员
|
||||
SUPERUSER: str = "admin"
|
||||
# 辅助认证,允许通过媒体服务器认证并创建用户
|
||||
# 辅助认证,允许通过外部服务(如媒体服务器/插件等)认证并创建用户
|
||||
AUXILIARY_AUTH_ENABLE: bool = False
|
||||
# API密钥,需要更换
|
||||
API_TOKEN: Optional[str] = None
|
||||
|
||||
@@ -13,6 +13,7 @@ from typing import Callable, Dict, List, Optional, Union
|
||||
from app.helper.message import MessageHelper
|
||||
from app.helper.thread import ThreadHelper
|
||||
from app.log import logger
|
||||
from app.schemas.event import ChainEventData
|
||||
from app.schemas.types import ChainEventType, EventType
|
||||
from app.utils.limit import ExponentialBackoffRateLimiter
|
||||
from app.utils.singleton import Singleton
|
||||
@@ -28,7 +29,8 @@ class Event:
|
||||
事件类,封装事件的基本信息
|
||||
"""
|
||||
|
||||
def __init__(self, event_type: Union[EventType, ChainEventType], event_data: Optional[Dict] = None,
|
||||
def __init__(self, event_type: Union[EventType, ChainEventType],
|
||||
event_data: Optional[Union[Dict, ChainEventData]] = None,
|
||||
priority: int = DEFAULT_EVENT_PRIORITY):
|
||||
"""
|
||||
:param event_type: 事件的类型,支持 EventType 或 ChainEventType
|
||||
@@ -314,13 +316,13 @@ class EventManager(metaclass=Singleton):
|
||||
|
||||
return True
|
||||
|
||||
def __trigger_chain_event(self, event: Event) -> Event:
|
||||
def __trigger_chain_event(self, event: Event) -> Optional[Event]:
|
||||
"""
|
||||
触发链式事件,按顺序调用订阅的处理器,并记录处理耗时
|
||||
"""
|
||||
logger.debug(f"Triggering synchronous chain event: {event}")
|
||||
self.__dispatch_chain_event(event)
|
||||
return event
|
||||
dispatch = self.__dispatch_chain_event(event)
|
||||
return event if dispatch else None
|
||||
|
||||
def __trigger_broadcast_event(self, event: Event):
|
||||
"""
|
||||
@@ -330,23 +332,25 @@ class EventManager(metaclass=Singleton):
|
||||
logger.debug(f"Triggering broadcast event: {event}")
|
||||
self.__event_queue.put((event.priority, event))
|
||||
|
||||
def __dispatch_chain_event(self, event: Event):
|
||||
def __dispatch_chain_event(self, event: Event) -> bool:
|
||||
"""
|
||||
同步方式调度链式事件,按优先级顺序逐个调用事件处理器,并记录每个处理器的处理时间
|
||||
:param event: 要调度的事件对象
|
||||
"""
|
||||
handlers = self.__chain_subscribers.get(event.event_type, {})
|
||||
if not handlers:
|
||||
return
|
||||
self.__log_event_lifecycle(event, "started")
|
||||
logger.debug(f"No handlers found for chain event: {event}")
|
||||
return False
|
||||
self.__log_event_lifecycle(event, "Started")
|
||||
for handler_id, (priority, handler) in handlers.items():
|
||||
start_time = time.time()
|
||||
self.__safe_invoke_handler(handler, event)
|
||||
logger.debug(
|
||||
f"Handler {self.__get_handler_identifier(handler)} (Priority: {priority}) ,"
|
||||
f" completed in {time.time() - start_time:.3f}s"
|
||||
f"{self.__get_handler_identifier(handler)} (Priority: {priority}), "
|
||||
f"completed in {time.time() - start_time:.3f}s for event: {event}"
|
||||
)
|
||||
self.__log_event_lifecycle(event, "completed")
|
||||
self.__log_event_lifecycle(event, "Completed")
|
||||
return True
|
||||
|
||||
def __dispatch_broadcast_event(self, event: Event):
|
||||
"""
|
||||
@@ -355,6 +359,7 @@ class EventManager(metaclass=Singleton):
|
||||
"""
|
||||
handlers = self.__broadcast_subscribers.get(event.event_type, {})
|
||||
if not handlers:
|
||||
logger.debug(f"No handlers found for broadcast event: {event}")
|
||||
return
|
||||
for handler_id, handler in handlers.items():
|
||||
self.__executor.submit(self.__safe_invoke_handler, handler, event)
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
from typing import Optional, Tuple, Union, Any, List, Generator
|
||||
from typing import Any, Generator, List, Optional, Tuple, Union
|
||||
|
||||
from app import schemas
|
||||
from app.core.context import MediaInfo
|
||||
from app.log import logger
|
||||
from app.modules import _ModuleBase, _MediaServerBase
|
||||
from app.modules import _MediaServerBase, _ModuleBase
|
||||
from app.modules.emby.emby import Emby
|
||||
from app.schemas.event import AuthVerificationData
|
||||
from app.schemas.types import MediaType, ModuleType
|
||||
|
||||
|
||||
@@ -57,18 +58,22 @@ class EmbyModule(_ModuleBase, _MediaServerBase[Emby]):
|
||||
logger.info(f"Emby服务器 {name} 连接断开,尝试重连 ...")
|
||||
server.reconnect()
|
||||
|
||||
def user_authenticate(self, name: str, password: str) -> Optional[str]:
|
||||
def user_authenticate(self, auth_data: AuthVerificationData) -> Optional[AuthVerificationData]:
|
||||
"""
|
||||
使用Emby用户辅助完成用户认证
|
||||
:param name: 用户名
|
||||
:param password: 密码
|
||||
:return: token or None
|
||||
:param auth_data: 认证数据
|
||||
:return: 认证数据
|
||||
"""
|
||||
# Emby认证
|
||||
for server in self.get_instances().values():
|
||||
result = server.authenticate(name, password)
|
||||
if result:
|
||||
return result
|
||||
if not auth_data:
|
||||
return None
|
||||
for name, server in self.get_instances().items():
|
||||
token = server.authenticate(auth_data.name, auth_data.password)
|
||||
if token:
|
||||
auth_data.channel = self.get_name()
|
||||
auth_data.service = name
|
||||
auth_data.token = token
|
||||
return auth_data
|
||||
return None
|
||||
|
||||
def webhook_parser(self, body: Any, form: Any, args: Any) -> Optional[schemas.WebhookEventInfo]:
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
from typing import Optional, Tuple, Union, Any, List, Generator
|
||||
from typing import Any, Generator, List, Optional, Tuple, Union
|
||||
|
||||
from app import schemas
|
||||
from app.core.context import MediaInfo
|
||||
from app.log import logger
|
||||
from app.modules import _ModuleBase, _MediaServerBase
|
||||
from app.modules import _MediaServerBase, _ModuleBase
|
||||
from app.modules.jellyfin.jellyfin import Jellyfin
|
||||
from app.schemas.event import AuthVerificationData
|
||||
from app.schemas.types import MediaType, ModuleType
|
||||
|
||||
|
||||
@@ -57,18 +58,22 @@ class JellyfinModule(_ModuleBase, _MediaServerBase[Jellyfin]):
|
||||
return False, f"无法连接Jellyfin服务器:{name}"
|
||||
return True, ""
|
||||
|
||||
def user_authenticate(self, name: str, password: str) -> Optional[str]:
|
||||
def user_authenticate(self, auth_data: AuthVerificationData) -> Optional[AuthVerificationData]:
|
||||
"""
|
||||
使用Emby用户辅助完成用户认证
|
||||
:param name: 用户名
|
||||
:param password: 密码
|
||||
:return: Token or None
|
||||
使用Jellyfin用户辅助完成用户认证
|
||||
:param auth_data: 认证数据
|
||||
:return: 认证数据
|
||||
"""
|
||||
# Jellyfin认证
|
||||
for server in self.get_instances().values():
|
||||
result = server.authenticate(name, password)
|
||||
if result:
|
||||
return result
|
||||
if not auth_data:
|
||||
return None
|
||||
for name, server in self.get_instances().items():
|
||||
token = server.authenticate(auth_data.name, auth_data.password)
|
||||
if token:
|
||||
auth_data.channel = self.get_name()
|
||||
auth_data.service = name
|
||||
auth_data.token = token
|
||||
return auth_data
|
||||
return None
|
||||
|
||||
def webhook_parser(self, body: Any, form: Any, args: Any) -> Optional[schemas.WebhookEventInfo]:
|
||||
|
||||
67
app/schemas/event.py
Normal file
67
app/schemas/event.py
Normal file
@@ -0,0 +1,67 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class BaseEventData(BaseModel):
|
||||
"""
|
||||
事件数据的基类,所有具体事件数据类应继承自此类
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class ChainEventData(BaseEventData):
|
||||
"""
|
||||
链式事件数据的基类,所有具体事件数据类应继承自此类
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class AuthVerificationData(ChainEventData):
|
||||
"""
|
||||
AuthVerification 事件的数据模型
|
||||
|
||||
Attributes:
|
||||
# 输入参数
|
||||
name (str): 用户名
|
||||
password (str): 用户密码
|
||||
|
||||
# 输出参数
|
||||
token (str): 认证令牌
|
||||
channel (str): 认证渠道
|
||||
service (str): 服务名称
|
||||
"""
|
||||
# 输入参数
|
||||
name: str = Field(..., description="用户名")
|
||||
password: str = Field(..., description="用户密码")
|
||||
|
||||
# 输出参数
|
||||
token: Optional[str] = Field(None, description="认证令牌")
|
||||
channel: Optional[str] = Field(None, description="认证渠道")
|
||||
service: Optional[str] = Field(None, description="服务名称")
|
||||
|
||||
|
||||
class AuthPassedInterceptData(ChainEventData):
|
||||
"""
|
||||
AuthPassedIntercept 事件的数据模型。
|
||||
|
||||
Attributes:
|
||||
# 输入参数
|
||||
name (str): 用户名
|
||||
channel (str): 认证渠道
|
||||
service (str): 服务名称
|
||||
token (str): 认证令牌
|
||||
|
||||
# 输出参数
|
||||
source (str): 拦截源,默认值为 "未知拦截源"
|
||||
cancel (bool): 是否取消认证,默认值为 False
|
||||
"""
|
||||
# 输入参数
|
||||
name: str = Field(..., description="用户名")
|
||||
channel: str = Field(..., description="认证渠道")
|
||||
service: str = Field(..., description="服务名称")
|
||||
token: Optional[str] = Field(None, description="认证令牌")
|
||||
|
||||
# 输出参数
|
||||
source: str = Field("未知拦截源", description="拦截源")
|
||||
cancel: bool = Field(False, description="是否取消认证")
|
||||
@@ -62,6 +62,10 @@ class EventType(Enum):
|
||||
class ChainEventType(Enum):
|
||||
# 名称识别请求
|
||||
NameRecognize = "name.recognize"
|
||||
# 认证验证请求
|
||||
AuthVerification = "auth.verification"
|
||||
# 认证通过拦截
|
||||
AuthPassedIntercept = "auth.passed.intercept"
|
||||
|
||||
|
||||
# 系统配置Key字典
|
||||
|
||||
Reference in New Issue
Block a user