diff --git a/app/api/endpoints/download.py b/app/api/endpoints/download.py index d159f288..b7a1ade8 100644 --- a/app/api/endpoints/download.py +++ b/app/api/endpoints/download.py @@ -102,7 +102,8 @@ def start( @router.get("/stop/{hashString}", summary="暂停任务", response_model=schemas.Response) -def stop(hashString: str) -> Any: +def stop(hashString: str, + _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 暂停下载任务 """ @@ -111,7 +112,8 @@ def stop(hashString: str) -> Any: @router.delete("/{hashString}", summary="删除下载任务", response_model=schemas.Response) -def delete(hashString: str) -> Any: +def delete(hashString: str, + _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 删除下载任务 """ diff --git a/app/api/endpoints/mediaserver.py b/app/api/endpoints/mediaserver.py index ef30e3fc..0b41cb21 100644 --- a/app/api/endpoints/mediaserver.py +++ b/app/api/endpoints/mediaserver.py @@ -6,38 +6,38 @@ from sqlalchemy.orm import Session from app import schemas from app.chain.download import DownloadChain from app.chain.mediaserver import MediaServerChain -from app.core.config import settings from app.core.context import MediaInfo from app.core.metainfo import MetaInfo from app.core.security import verify_token from app.db import get_db from app.db.mediaserver_oper import MediaServerOper from app.db.models import MediaServerItem +from app.helper.mediaserver import MediaServerHelper from app.schemas import MediaType, NotExistMediaInfo router = APIRouter() -@router.get("/play/{itemid}", summary="在线播放") -def play_item(itemid: str) -> schemas.Response: +@router.get("/play/{itemid:path}", summary="在线播放") +def play_item(itemid: str, _: schemas.TokenPayload = Depends(verify_token)) -> schemas.Response: """ 获取媒体服务器播放页面地址 """ if not itemid: - return schemas.Response(success=False, msg="参数错误") - if not settings.MEDIASERVER: - return schemas.Response(success=False, msg="未配置媒体服务器") - # 查找一个不为空的值 - mediaserver = next((server for server in settings.MEDIASERVER.split(",") if server), None) - if not mediaserver: - return schemas.Response(success=False, msg="未配置媒体服务器") - play_url = MediaServerChain().get_play_url(server=mediaserver, item_id=itemid) - # 重定向到play_url - if not play_url: - return schemas.Response(success=False, msg="未找到播放地址") - return schemas.Response(success=True, data={ - "url": play_url - }) + return schemas.Response(success=False, message="参数错误") + configs = MediaServerHelper().get_configs() + if not configs: + return schemas.Response(success=False, message="未配置媒体服务器") + media_chain = MediaServerChain() + for name in configs.keys(): + item = media_chain.iteminfo(server=name, item_id=itemid) + if item: + play_url = media_chain.get_play_url(server=name, item_id=itemid) + if play_url: + return schemas.Response(success=True, data={ + "url": play_url + }) + return schemas.Response(success=False, message="未找到播放地址") @router.get("/exists", summary="查询本地是否存在(数据库)", response_model=schemas.Response) diff --git a/app/api/endpoints/message.py b/app/api/endpoints/message.py index 239b052f..48121d3f 100644 --- a/app/api/endpoints/message.py +++ b/app/api/endpoints/message.py @@ -9,7 +9,7 @@ from starlette.responses import PlainTextResponse from app import schemas from app.chain.message import MessageChain from app.core.config import settings, global_vars -from app.core.security import verify_token +from app.core.security import verify_token, verify_apitoken from app.db import get_db from app.db.models import User from app.db.models.message import Message @@ -30,7 +30,8 @@ def start_message_chain(body: Any, form: Any, args: Any): @router.post("/", summary="接收用户消息", response_model=schemas.Response) -async def user_message(background_tasks: BackgroundTasks, request: Request): +async def user_message(background_tasks: BackgroundTasks, request: Request, + _: schemas.TokenPayload = Depends(verify_apitoken)): """ 用户消息响应,配置请求中需要添加参数:token=API_TOKEN&source=消息配置名 """ @@ -102,18 +103,17 @@ def wechat_verify(echostr: str, msg_signature: str, timestamp: Union[str, int], return "未找到对应的消息配置" -def vocechat_verify(token: str) -> Any: +def vocechat_verify() -> Any: """ VoceChat验证响应 """ - if token == settings.API_TOKEN: - return {"status": "OK"} - return {"status": "API_TOKEN ERROR"} + return {"status": "OK"} @router.get("/", summary="回调请求验证") def incoming_verify(token: str = None, echostr: str = None, msg_signature: str = None, - timestamp: Union[str, int] = None, nonce: str = None, source: str = None) -> Any: + timestamp: Union[str, int] = None, nonce: str = None, source: str = None, + _: schemas.TokenPayload = Depends(verify_apitoken)) -> Any: """ 微信/VoceChat等验证响应 """ @@ -121,7 +121,7 @@ def incoming_verify(token: str = None, echostr: str = None, msg_signature: str = f"msg_signature={msg_signature}, timestamp={timestamp}, nonce={nonce}") if echostr and msg_signature and timestamp and nonce: return wechat_verify(echostr, msg_signature, timestamp, nonce, source) - return vocechat_verify(token) + return vocechat_verify() @router.post("/webpush/subscribe", summary="客户端webpush通知订阅", response_model=schemas.Response) diff --git a/app/api/endpoints/search.py b/app/api/endpoints/search.py index dea642bc..345f0b10 100644 --- a/app/api/endpoints/search.py +++ b/app/api/endpoints/search.py @@ -25,7 +25,8 @@ def search_latest(_: schemas.TokenPayload = Depends(verify_token)) -> Any: def search_by_id(mediaid: str, mtype: str = None, area: str = "title", - season: str = None) -> Any: + season: str = None, + _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 根据TMDBID/豆瓣ID精确搜索站点资源 tmdb:/douban:/bangumi: """ @@ -89,7 +90,8 @@ def search_by_id(mediaid: str, @router.get("/title", summary="模糊搜索资源", response_model=schemas.Response) def search_by_title(keyword: str = None, page: int = 0, - site: int = None) -> Any: + site: int = None, + _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 根据名称模糊搜索站点资源,支持分页,关键词为空是返回首页资源 """ diff --git a/app/api/endpoints/system.py b/app/api/endpoints/system.py index 7704c788..3b3e923d 100644 --- a/app/api/endpoints/system.py +++ b/app/api/endpoints/system.py @@ -12,7 +12,7 @@ from app.chain.search import SearchChain from app.chain.system import SystemChain from app.core.config import settings, global_vars from app.core.module import ModuleManager -from app.core.security import verify_token +from app.core.security import verify_token, verify_uri_token from app.db.models import User from app.db.systemconfig_oper import SystemConfigOper from app.db.user_oper import get_current_active_superuser @@ -137,16 +137,10 @@ def set_env_setting(env: dict, @router.get("/progress/{process_type}", summary="实时进度") -def get_progress(process_type: str, token: str): +def get_progress(process_type: str, _: schemas.TokenPayload = Depends(verify_token)): """ 实时获取处理进度,返回格式为SSE """ - if not token or not verify_token(token): - raise HTTPException( - status_code=403, - detail="认证失败!", - ) - progress = ProgressHelper() def event_generator(): @@ -192,16 +186,10 @@ def set_setting(key: str, value: Union[list, dict, bool, int, str] = None, @router.get("/message", summary="实时消息") -def get_message(token: str, role: str = "system"): +def get_message(role: str = "system", _: schemas.TokenPayload = Depends(verify_uri_token)): """ 实时获取系统消息,返回格式为SSE """ - if not token or not verify_token(token): - raise HTTPException( - status_code=403, - detail="认证失败!", - ) - message = MessageHelper() def event_generator(): @@ -216,18 +204,12 @@ def get_message(token: str, role: str = "system"): @router.get("/logging", summary="实时日志") -def get_logging(token: str, length: int = 50, logfile: str = "moviepilot.log"): +def get_logging(length: int = 50, logfile: str = "moviepilot.log", _: schemas.TokenPayload = Depends(verify_uri_token)): """ 实时获取系统日志 length = -1 时, 返回text/plain 否则 返回格式SSE """ - if not token or not verify_token(token): - raise HTTPException( - status_code=403, - detail="认证失败!", - ) - log_path = settings.LOG_PATH / logfile def log_generator(): diff --git a/app/api/endpoints/user.py b/app/api/endpoints/user.py index 56172be3..ae295eaa 100644 --- a/app/api/endpoints/user.py +++ b/app/api/endpoints/user.py @@ -139,7 +139,7 @@ def otp_disable( def otp_enable(userid: str, db: Session = Depends(get_db)) -> Any: user: User = User.get_by_name(db, userid) if not user: - return schemas.Response(success=False, message="用户不存在") + return schemas.Response(success=False) return schemas.Response(success=user.is_otp) diff --git a/app/core/security.py b/app/core/security.py index 1f96190e..8c73c2f7 100644 --- a/app/core/security.py +++ b/app/core/security.py @@ -27,6 +27,9 @@ oauth2_scheme = OAuth2PasswordBearer( tokenUrl=f"{settings.API_V1_STR}/login/access-token" ) +# JWT TOKEN 通过 QUERY 认证 +jwt_token_query = APIKeyQuery(name="token", auto_error=False, scheme_name="jwt_token_query") + # API TOKEN 通过 QUERY 认证 api_token_query = APIKeyQuery(name="token", auto_error=False, scheme_name="api_token_query") @@ -75,13 +78,13 @@ def create_access_token( return encoded_jwt -def verify_token(token: str = Security(oauth2_scheme)) -> schemas.TokenPayload: - """ - 使用 JWT Token 进行身份认证并解析 Token 的内容 - :param token: JWT 令牌,从请求的 Authorization 头部获取 - :return: 包含用户身份信息的 Token 负载数据 - :raises HTTPException: 如果令牌无效或解码失败,抛出 403 错误 +def __verify_token(token: str) -> schemas.TokenPayload: """ + 使用 JWT Token 进行身份认证并解析 Token 的内容 + :param token: JWT 令牌 + :return: 包含用户身份信息的 Token 负载数据 + :raises HTTPException: 如果令牌无效或解码失败,抛出 403 错误 + """ try: payload = jwt.decode( token, settings.SECRET_KEY, algorithms=[ALGORITHM] @@ -94,6 +97,26 @@ def verify_token(token: str = Security(oauth2_scheme)) -> schemas.TokenPayload: ) +def verify_token(token: str = Security(oauth2_scheme)) -> schemas.TokenPayload: + """ + 使用 JWT Token 进行身份认证并解析 Token 的内容 + :param token: JWT 令牌,从请求的 Authorization 头部获取 + :return: 包含用户身份信息的 Token 负载数据 + :raises HTTPException: 如果令牌无效或解码失败,抛出 403 错误 + """ + return __verify_token(token) + + +def verify_uri_token(token: str = Security(jwt_token_query)) -> schemas.TokenPayload: + """ + 使用 JWT Token 进行身份认证并解析 Token 的内容 + :param token: JWT 令牌,从 URL 中的 `token` 查询参数获取 + :return: 包含用户身份信息的 Token 负载数据 + :raises HTTPException: 如果令牌无效或解码失败,抛出 403 错误 + """ + return __verify_token(token) + + def __get_api_token( token_query: Annotated[str | None, Security(api_token_query)] = None ) -> str: @@ -153,15 +176,6 @@ def verify_apikey(apikey: str = Security(__get_api_key)) -> str: return __verify_key(apikey, settings.API_TOKEN, "API_KEY") -def verify_uri_token(token: str = Security(__get_api_token)) -> str: - """ - 使用 API Token 进行身份认证 - :param token: API Token,从 URL 查询参数中获取 - :return: 返回校验通过的 API Token - """ - return __verify_key(token, settings.API_TOKEN, "API_TOKEN") - - def verify_password(plain_password: str, hashed_password: str) -> bool: return pwd_context.verify(plain_password, hashed_password) diff --git a/app/modules/slack/__init__.py b/app/modules/slack/__init__.py index f86d08ff..34476f85 100644 --- a/app/modules/slack/__init__.py +++ b/app/modules/slack/__init__.py @@ -171,10 +171,6 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]): client_config = self.get_config(source) if not client_config: return None - # 校验token - token = args.get("token") - if not token or token != settings.API_TOKEN: - return None try: msg_json: dict = json.loads(body) except Exception as err: diff --git a/app/modules/telegram/__init__.py b/app/modules/telegram/__init__.py index af444c15..ad19647c 100644 --- a/app/modules/telegram/__init__.py +++ b/app/modules/telegram/__init__.py @@ -85,10 +85,6 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]): if not client_config: return None client: Telegram = self.get_instance(source) - # 校验token - token = args.get("token") - if not token or token != settings.API_TOKEN: - return None try: message: dict = json.loads(body) except Exception as err: diff --git a/app/modules/vocechat/__init__.py b/app/modules/vocechat/__init__.py index 661bb580..479cd5f1 100644 --- a/app/modules/vocechat/__init__.py +++ b/app/modules/vocechat/__init__.py @@ -81,11 +81,6 @@ class VoceChatModule(_ModuleBase, _MessageBase[VoceChat]): # 非新消息 return None logger.debug(f"收到VoceChat请求:{msg_body}") - # token校验 - token = args.get("token") - if not token or token != settings.API_TOKEN: - logger.warn(f"VoceChat请求token校验失败:{token}") - return None # 文本内容 content = msg_body.get("detail", {}).get("content") # 用户ID