Merge pull request #2818 from InfinityPacer/feature/security

This commit is contained in:
jxxghp
2024-10-08 20:36:49 +08:00
committed by GitHub
10 changed files with 67 additions and 80 deletions

View File

@@ -102,7 +102,8 @@ def start(
@router.get("/stop/{hashString}", summary="暂停任务", response_model=schemas.Response)
def stop(hashString: str) -> Any:
def stop(hashString: str,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
暂停下载任务
"""
@@ -111,7 +112,8 @@ def stop(hashString: str) -> Any:
@router.delete("/{hashString}", summary="删除下载任务", response_model=schemas.Response)
def delete(hashString: str) -> Any:
def delete(hashString: str,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
删除下载任务
"""

View File

@@ -6,38 +6,38 @@ from sqlalchemy.orm import Session
from app import schemas
from app.chain.download import DownloadChain
from app.chain.mediaserver import MediaServerChain
from app.core.config import settings
from app.core.context import MediaInfo
from app.core.metainfo import MetaInfo
from app.core.security import verify_token
from app.db import get_db
from app.db.mediaserver_oper import MediaServerOper
from app.db.models import MediaServerItem
from app.helper.mediaserver import MediaServerHelper
from app.schemas import MediaType, NotExistMediaInfo
router = APIRouter()
@router.get("/play/{itemid}", summary="在线播放")
def play_item(itemid: str) -> schemas.Response:
@router.get("/play/{itemid:path}", summary="在线播放")
def play_item(itemid: str, _: schemas.TokenPayload = Depends(verify_token)) -> schemas.Response:
"""
获取媒体服务器播放页面地址
"""
if not itemid:
return schemas.Response(success=False, msg="参数错误")
if not settings.MEDIASERVER:
return schemas.Response(success=False, msg="未配置媒体服务器")
# 查找一个不为空的值
mediaserver = next((server for server in settings.MEDIASERVER.split(",") if server), None)
if not mediaserver:
return schemas.Response(success=False, msg="未配置媒体服务器")
play_url = MediaServerChain().get_play_url(server=mediaserver, item_id=itemid)
# 重定向到play_url
if not play_url:
return schemas.Response(success=False, msg="未找到播放地址")
return schemas.Response(success=True, data={
"url": play_url
})
return schemas.Response(success=False, message="参数错误")
configs = MediaServerHelper().get_configs()
if not configs:
return schemas.Response(success=False, message="未配置媒体服务器")
media_chain = MediaServerChain()
for name in configs.keys():
item = media_chain.iteminfo(server=name, item_id=itemid)
if item:
play_url = media_chain.get_play_url(server=name, item_id=itemid)
if play_url:
return schemas.Response(success=True, data={
"url": play_url
})
return schemas.Response(success=False, message="未找到播放地址")
@router.get("/exists", summary="查询本地是否存在(数据库)", response_model=schemas.Response)

View File

@@ -9,7 +9,7 @@ from starlette.responses import PlainTextResponse
from app import schemas
from app.chain.message import MessageChain
from app.core.config import settings, global_vars
from app.core.security import verify_token
from app.core.security import verify_token, verify_apitoken
from app.db import get_db
from app.db.models import User
from app.db.models.message import Message
@@ -30,7 +30,8 @@ def start_message_chain(body: Any, form: Any, args: Any):
@router.post("/", summary="接收用户消息", response_model=schemas.Response)
async def user_message(background_tasks: BackgroundTasks, request: Request):
async def user_message(background_tasks: BackgroundTasks, request: Request,
_: schemas.TokenPayload = Depends(verify_apitoken)):
"""
用户消息响应配置请求中需要添加参数token=API_TOKEN&source=消息配置名
"""
@@ -102,18 +103,17 @@ def wechat_verify(echostr: str, msg_signature: str, timestamp: Union[str, int],
return "未找到对应的消息配置"
def vocechat_verify(token: str) -> Any:
def vocechat_verify() -> Any:
"""
VoceChat验证响应
"""
if token == settings.API_TOKEN:
return {"status": "OK"}
return {"status": "API_TOKEN ERROR"}
return {"status": "OK"}
@router.get("/", summary="回调请求验证")
def incoming_verify(token: str = None, echostr: str = None, msg_signature: str = None,
timestamp: Union[str, int] = None, nonce: str = None, source: str = None) -> Any:
timestamp: Union[str, int] = None, nonce: str = None, source: str = None,
_: schemas.TokenPayload = Depends(verify_apitoken)) -> Any:
"""
微信/VoceChat等验证响应
"""
@@ -121,7 +121,7 @@ def incoming_verify(token: str = None, echostr: str = None, msg_signature: str =
f"msg_signature={msg_signature}, timestamp={timestamp}, nonce={nonce}")
if echostr and msg_signature and timestamp and nonce:
return wechat_verify(echostr, msg_signature, timestamp, nonce, source)
return vocechat_verify(token)
return vocechat_verify()
@router.post("/webpush/subscribe", summary="客户端webpush通知订阅", response_model=schemas.Response)

View File

@@ -25,7 +25,8 @@ def search_latest(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
def search_by_id(mediaid: str,
mtype: str = None,
area: str = "title",
season: str = None) -> Any:
season: str = None,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
根据TMDBID/豆瓣ID精确搜索站点资源 tmdb:/douban:/bangumi:
"""
@@ -89,7 +90,8 @@ def search_by_id(mediaid: str,
@router.get("/title", summary="模糊搜索资源", response_model=schemas.Response)
def search_by_title(keyword: str = None,
page: int = 0,
site: int = None) -> Any:
site: int = None,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
根据名称模糊搜索站点资源,支持分页,关键词为空是返回首页资源
"""

View File

@@ -12,7 +12,7 @@ from app.chain.search import SearchChain
from app.chain.system import SystemChain
from app.core.config import settings, global_vars
from app.core.module import ModuleManager
from app.core.security import verify_token
from app.core.security import verify_token, verify_uri_token
from app.db.models import User
from app.db.systemconfig_oper import SystemConfigOper
from app.db.user_oper import get_current_active_superuser
@@ -137,16 +137,10 @@ def set_env_setting(env: dict,
@router.get("/progress/{process_type}", summary="实时进度")
def get_progress(process_type: str, token: str):
def get_progress(process_type: str, _: schemas.TokenPayload = Depends(verify_token)):
"""
实时获取处理进度返回格式为SSE
"""
if not token or not verify_token(token):
raise HTTPException(
status_code=403,
detail="认证失败!",
)
progress = ProgressHelper()
def event_generator():
@@ -192,16 +186,10 @@ def set_setting(key: str, value: Union[list, dict, bool, int, str] = None,
@router.get("/message", summary="实时消息")
def get_message(token: str, role: str = "system"):
def get_message(role: str = "system", _: schemas.TokenPayload = Depends(verify_uri_token)):
"""
实时获取系统消息返回格式为SSE
"""
if not token or not verify_token(token):
raise HTTPException(
status_code=403,
detail="认证失败!",
)
message = MessageHelper()
def event_generator():
@@ -216,18 +204,12 @@ def get_message(token: str, role: str = "system"):
@router.get("/logging", summary="实时日志")
def get_logging(token: str, length: int = 50, logfile: str = "moviepilot.log"):
def get_logging(length: int = 50, logfile: str = "moviepilot.log", _: schemas.TokenPayload = Depends(verify_uri_token)):
"""
实时获取系统日志
length = -1 时, 返回text/plain
否则 返回格式SSE
"""
if not token or not verify_token(token):
raise HTTPException(
status_code=403,
detail="认证失败!",
)
log_path = settings.LOG_PATH / logfile
def log_generator():

View File

@@ -139,7 +139,7 @@ def otp_disable(
def otp_enable(userid: str, db: Session = Depends(get_db)) -> Any:
user: User = User.get_by_name(db, userid)
if not user:
return schemas.Response(success=False, message="用户不存在")
return schemas.Response(success=False)
return schemas.Response(success=user.is_otp)

View File

@@ -27,6 +27,9 @@ oauth2_scheme = OAuth2PasswordBearer(
tokenUrl=f"{settings.API_V1_STR}/login/access-token"
)
# JWT TOKEN 通过 QUERY 认证
jwt_token_query = APIKeyQuery(name="token", auto_error=False, scheme_name="jwt_token_query")
# API TOKEN 通过 QUERY 认证
api_token_query = APIKeyQuery(name="token", auto_error=False, scheme_name="api_token_query")
@@ -75,13 +78,13 @@ def create_access_token(
return encoded_jwt
def verify_token(token: str = Security(oauth2_scheme)) -> schemas.TokenPayload:
"""
使用 JWT Token 进行身份认证并解析 Token 的内容
:param token: JWT 令牌,从请求的 Authorization 头部获取
:return: 包含用户身份信息的 Token 负载数据
:raises HTTPException: 如果令牌无效或解码失败,抛出 403 错误
def __verify_token(token: str) -> schemas.TokenPayload:
"""
使用 JWT Token 进行身份认证并解析 Token 的内容
:param token: JWT 令牌
:return: 包含用户身份信息的 Token 负载数据
:raises HTTPException: 如果令牌无效或解码失败,抛出 403 错误
"""
try:
payload = jwt.decode(
token, settings.SECRET_KEY, algorithms=[ALGORITHM]
@@ -94,6 +97,26 @@ def verify_token(token: str = Security(oauth2_scheme)) -> schemas.TokenPayload:
)
def verify_token(token: str = Security(oauth2_scheme)) -> schemas.TokenPayload:
"""
使用 JWT Token 进行身份认证并解析 Token 的内容
:param token: JWT 令牌,从请求的 Authorization 头部获取
:return: 包含用户身份信息的 Token 负载数据
:raises HTTPException: 如果令牌无效或解码失败,抛出 403 错误
"""
return __verify_token(token)
def verify_uri_token(token: str = Security(jwt_token_query)) -> schemas.TokenPayload:
"""
使用 JWT Token 进行身份认证并解析 Token 的内容
:param token: JWT 令牌,从 URL 中的 `token` 查询参数获取
:return: 包含用户身份信息的 Token 负载数据
:raises HTTPException: 如果令牌无效或解码失败,抛出 403 错误
"""
return __verify_token(token)
def __get_api_token(
token_query: Annotated[str | None, Security(api_token_query)] = None
) -> str:
@@ -153,15 +176,6 @@ def verify_apikey(apikey: str = Security(__get_api_key)) -> str:
return __verify_key(apikey, settings.API_TOKEN, "API_KEY")
def verify_uri_token(token: str = Security(__get_api_token)) -> str:
"""
使用 API Token 进行身份认证
:param token: API Token从 URL 查询参数中获取
:return: 返回校验通过的 API Token
"""
return __verify_key(token, settings.API_TOKEN, "API_TOKEN")
def verify_password(plain_password: str, hashed_password: str) -> bool:
return pwd_context.verify(plain_password, hashed_password)

View File

@@ -171,10 +171,6 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]):
client_config = self.get_config(source)
if not client_config:
return None
# 校验token
token = args.get("token")
if not token or token != settings.API_TOKEN:
return None
try:
msg_json: dict = json.loads(body)
except Exception as err:

View File

@@ -85,10 +85,6 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
if not client_config:
return None
client: Telegram = self.get_instance(source)
# 校验token
token = args.get("token")
if not token or token != settings.API_TOKEN:
return None
try:
message: dict = json.loads(body)
except Exception as err:

View File

@@ -81,11 +81,6 @@ class VoceChatModule(_ModuleBase, _MessageBase[VoceChat]):
# 非新消息
return None
logger.debug(f"收到VoceChat请求{msg_body}")
# token校验
token = args.get("token")
if not token or token != settings.API_TOKEN:
logger.warn(f"VoceChat请求token校验失败{token}")
return None
# 文本内容
content = msg_body.get("detail", {}).get("content")
# 用户ID