diff --git a/app/api/endpoints/system.py b/app/api/endpoints/system.py index 44085269..5b63a225 100644 --- a/app/api/endpoints/system.py +++ b/app/api/endpoints/system.py @@ -2,13 +2,14 @@ import asyncio import io import json import re +import tempfile from collections import deque from datetime import datetime from typing import Optional, Union, Annotated -import aiofiles import pillow_avif # noqa 用于自动注册AVIF支持 from PIL import Image +from aiopath import AsyncPath from fastapi import APIRouter, Body, Depends, HTTPException, Header, Request, Response from fastapi.responses import StreamingResponse @@ -63,24 +64,27 @@ async def fetch_image( raise HTTPException(status_code=404, detail="Unsafe URL") # 后续观察系统性能表现,如果发现磁盘缓存和HTTP缓存无法满足高并发情况下的响应速度需求,可以考虑重新引入内存缓存 - cache_path = None + cache_path: Optional[AsyncPath] = None if use_disk_cache: # 生成缓存路径 + base_path = AsyncPath(settings.CACHE_PATH) sanitized_path = SecurityUtils.sanitize_url_path(url) - cache_path = settings.CACHE_PATH / "images" / sanitized_path + cache_path = base_path / "images" / sanitized_path # 没有文件类型,则添加后缀,在恶意文件类型和实际需求下的折衷选择 if not cache_path.suffix: cache_path = cache_path.with_suffix(".jpg") # 确保缓存路径和文件类型合法 - if not SecurityUtils.is_safe_path(settings.CACHE_PATH, cache_path, settings.SECURITY_IMAGE_SUFFIXES): + if not await SecurityUtils.async_is_safe_path(base_path=base_path, + user_path=cache_path, + allowed_suffixes=settings.SECURITY_IMAGE_SUFFIXES): raise HTTPException(status_code=400, detail="Invalid cache path or file type") # 目前暂不考虑磁盘缓存文件是否过期,后续通过缓存清理机制处理 - if cache_path.exists(): + if cache_path and await cache_path.exists(): try: - async with aiofiles.open(cache_path, 'rb') as f: + async with cache_path.open(cache_path, 'rb') as f: content = await f.read() etag = HashUtils.md5(content) headers = RequestUtils.generate_cache_headers(etag, max_age=86400 * 7) @@ -115,11 +119,12 @@ async def fetch_image( # 如果需要使用磁盘缓存,则保存到磁盘 if use_disk_cache and cache_path: try: - if not cache_path.parent.exists(): - cache_path.parent.mkdir(parents=True, exist_ok=True) - # 使用异步文件操作写入缓存 - async with aiofiles.open(cache_path, 'wb') as f: - await f.write(content) + if not await cache_path.parent.exists(): + await cache_path.parent.mkdir(parents=True, exist_ok=True) + with tempfile.NamedTemporaryFile(dir=cache_path.parent, delete=False) as tmp_file: + tmp_file.write(content) + temp_path = AsyncPath(tmp_file.name) + await temp_path.replace(cache_path) except Exception as e: logger.debug(f"Failed to write cache file {cache_path}: {e}") @@ -357,12 +362,12 @@ async def get_logging(request: Request, length: Optional[int] = 50, logfile: Opt length = -1 时, 返回text/plain 否则 返回格式SSE """ - log_path = settings.LOG_PATH / logfile + log_path = AsyncPath(settings.LOG_PATH) / logfile if not SecurityUtils.is_safe_path(settings.LOG_PATH, log_path, allowed_suffixes={".log"}): raise HTTPException(status_code=404, detail="Not Found") - if not log_path.exists() or not log_path.is_file(): + if not await log_path.exists() or not await log_path.is_file(): raise HTTPException(status_code=404, detail="Not Found") async def log_generator(): @@ -370,7 +375,7 @@ async def get_logging(request: Request, length: Optional[int] = 50, logfile: Opt # 使用固定大小的双向队列来限制内存使用 lines_queue = deque(maxlen=max(length, 50)) # 使用 aiofiles 异步读取文件 - async with aiofiles.open(log_path, mode="r", encoding="utf-8") as f: + async with log_path.open(mode="r", encoding="utf-8") as f: # 逐行读取文件,将每一行存入队列 file_content = await f.read() for line in file_content.splitlines(): @@ -393,10 +398,10 @@ async def get_logging(request: Request, length: Optional[int] = 50, logfile: Opt # 根据length参数返回不同的响应 if length == -1: # 返回全部日志作为文本响应 - if not log_path.exists(): + if not await log_path.exists(): return Response(content="日志文件不存在!", media_type="text/plain") # 使用 aiofiles 异步读取文件 - async with aiofiles.open(log_path, mode="r", encoding="utf-8") as file: + async with log_path.open(mode="r", encoding="utf-8") as file: text = await file.read() # 倒序输出 text = "\n".join(text.split("\n")[::-1]) diff --git a/app/utils/http.py b/app/utils/http.py index 75161976..2eb49e0b 100644 --- a/app/utils/http.py +++ b/app/utils/http.py @@ -168,7 +168,9 @@ class RequestUtils: try: return req_method(method, url, **kwargs) except requests.exceptions.RequestException as e: - logger.debug(f"请求失败: {e}") + # 获取更详细的错误信息 + error_msg = str(e) if str(e) else f"未知网络错误 (URL: {url}, Method: {method.upper()})" + logger.debug(f"请求失败: {error_msg}") if raise_exception: raise return None @@ -603,18 +605,18 @@ class AsyncRequestUtils: """ if not proxies: return None - + # 如果已经是字符串格式,直接返回 if isinstance(proxies, str): return proxies - + # 如果是字典格式,提取http或https代理 if isinstance(proxies, dict): # 优先使用https代理,如果没有则使用http代理 proxy_url = proxies.get("https") or proxies.get("http") if proxy_url: return proxy_url - + return None @asynccontextmanager @@ -669,7 +671,9 @@ class AsyncRequestUtils: try: return await client.request(method, url, **kwargs) except httpx.RequestError as e: - logger.debug(f"异步请求失败: {e}") + # 获取更详细的错误信息 + error_msg = str(e) if str(e) else f"未知网络错误 (URL: {url}, Method: {method.upper()})" + logger.debug(f"异步请求失败: {error_msg}") if raise_exception: raise return None diff --git a/app/utils/security.py b/app/utils/security.py index 8190d928..93e8f640 100644 --- a/app/utils/security.py +++ b/app/utils/security.py @@ -3,6 +3,8 @@ from pathlib import Path from typing import List, Optional, Set, Union from urllib.parse import quote, urlparse +from aiopath import AsyncPath + from app.log import logger @@ -39,6 +41,37 @@ class SecurityUtils: logger.debug(f"Error occurred while validating paths: {e}") return False + @staticmethod + async def async_is_safe_path(base_path: AsyncPath, user_path: AsyncPath, + allowed_suffixes: Optional[Union[Set[str], List[str]]] = None) -> bool: + """ + 异步验证用户提供的路径是否在基准目录内,并检查文件类型是否合法,防止目录遍历攻击 + + :param base_path: 基准目录,允许访问的根目录 + :param user_path: 用户提供的路径,需检查其是否位于基准目录内 + :param allowed_suffixes: 允许的文件后缀名集合,用于验证文件类型 + :return: 如果用户路径安全且位于基准目录内,且文件类型合法,返回 True;否则返回 False + :raises Exception: 如果解析路径时发生错误,则捕获并记录异常 + """ + try: + # resolve() 将相对路径转换为绝对路径,并处理符号链接和'..' + base_path_resolved = await base_path.resolve() + user_path_resolved = await user_path.resolve() + + # 检查用户路径是否在基准目录或基准目录的子目录内 + if base_path_resolved != user_path_resolved and base_path_resolved not in user_path_resolved.parents: + return False + + if allowed_suffixes is not None: + allowed_suffixes = set(allowed_suffixes) + if user_path.suffix.lower() not in allowed_suffixes: + return False + + return True + except Exception as e: + logger.debug(f"Error occurred while validating paths: {e}") + return False + @staticmethod def is_safe_url(url: str, allowed_domains: Union[Set[str], List[str]], strict: bool = False) -> bool: """ diff --git a/requirements.in b/requirements.in index 34944860..d1b7766a 100644 --- a/requirements.in +++ b/requirements.in @@ -60,6 +60,8 @@ Pinyin2Hanzi~=0.1.1 pywebpush~=2.0.3 python-cookietools==0.0.4 aiofiles~=24.1.0 +aiopath~=0.7.7 +asynctempfile~=0.5.0 aiosqlite~=0.21.0 jieba~=0.42.1 rsa~=4.9