add aiopath

This commit is contained in:
jxxghp
2025-07-30 19:49:59 +08:00
parent 49647e3bb5
commit c8749b3c9c
4 changed files with 65 additions and 21 deletions

View File

@@ -2,13 +2,14 @@ import asyncio
import io import io
import json import json
import re import re
import tempfile
from collections import deque from collections import deque
from datetime import datetime from datetime import datetime
from typing import Optional, Union, Annotated from typing import Optional, Union, Annotated
import aiofiles
import pillow_avif # noqa 用于自动注册AVIF支持 import pillow_avif # noqa 用于自动注册AVIF支持
from PIL import Image from PIL import Image
from aiopath import AsyncPath
from fastapi import APIRouter, Body, Depends, HTTPException, Header, Request, Response from fastapi import APIRouter, Body, Depends, HTTPException, Header, Request, Response
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
@@ -63,24 +64,27 @@ async def fetch_image(
raise HTTPException(status_code=404, detail="Unsafe URL") raise HTTPException(status_code=404, detail="Unsafe URL")
# 后续观察系统性能表现如果发现磁盘缓存和HTTP缓存无法满足高并发情况下的响应速度需求可以考虑重新引入内存缓存 # 后续观察系统性能表现如果发现磁盘缓存和HTTP缓存无法满足高并发情况下的响应速度需求可以考虑重新引入内存缓存
cache_path = None cache_path: Optional[AsyncPath] = None
if use_disk_cache: if use_disk_cache:
# 生成缓存路径 # 生成缓存路径
base_path = AsyncPath(settings.CACHE_PATH)
sanitized_path = SecurityUtils.sanitize_url_path(url) sanitized_path = SecurityUtils.sanitize_url_path(url)
cache_path = settings.CACHE_PATH / "images" / sanitized_path cache_path = base_path / "images" / sanitized_path
# 没有文件类型,则添加后缀,在恶意文件类型和实际需求下的折衷选择 # 没有文件类型,则添加后缀,在恶意文件类型和实际需求下的折衷选择
if not cache_path.suffix: if not cache_path.suffix:
cache_path = cache_path.with_suffix(".jpg") cache_path = cache_path.with_suffix(".jpg")
# 确保缓存路径和文件类型合法 # 确保缓存路径和文件类型合法
if not SecurityUtils.is_safe_path(settings.CACHE_PATH, cache_path, settings.SECURITY_IMAGE_SUFFIXES): if not await SecurityUtils.async_is_safe_path(base_path=base_path,
user_path=cache_path,
allowed_suffixes=settings.SECURITY_IMAGE_SUFFIXES):
raise HTTPException(status_code=400, detail="Invalid cache path or file type") raise HTTPException(status_code=400, detail="Invalid cache path or file type")
# 目前暂不考虑磁盘缓存文件是否过期,后续通过缓存清理机制处理 # 目前暂不考虑磁盘缓存文件是否过期,后续通过缓存清理机制处理
if cache_path.exists(): if cache_path and await cache_path.exists():
try: try:
async with aiofiles.open(cache_path, 'rb') as f: async with cache_path.open('rb') as f:
content = await f.read() content = await f.read()
etag = HashUtils.md5(content) etag = HashUtils.md5(content)
headers = RequestUtils.generate_cache_headers(etag, max_age=86400 * 7) headers = RequestUtils.generate_cache_headers(etag, max_age=86400 * 7)
@@ -115,11 +119,12 @@ async def fetch_image(
# 如果需要使用磁盘缓存,则保存到磁盘 # 如果需要使用磁盘缓存,则保存到磁盘
if use_disk_cache and cache_path: if use_disk_cache and cache_path:
try: try:
if not cache_path.parent.exists(): if not await cache_path.parent.exists():
cache_path.parent.mkdir(parents=True, exist_ok=True) await cache_path.parent.mkdir(parents=True, exist_ok=True)
# 使用异步文件操作写入缓存 with tempfile.NamedTemporaryFile(dir=cache_path.parent, delete=False) as tmp_file:
async with aiofiles.open(cache_path, 'wb') as f: tmp_file.write(content)
await f.write(content) temp_path = AsyncPath(tmp_file.name)
await temp_path.replace(cache_path)
except Exception as e: except Exception as e:
logger.debug(f"Failed to write cache file {cache_path}: {e}") logger.debug(f"Failed to write cache file {cache_path}: {e}")
@@ -357,12 +362,12 @@ async def get_logging(request: Request, length: Optional[int] = 50, logfile: Opt
length = -1 时, 返回text/plain length = -1 时, 返回text/plain
否则 返回格式SSE 否则 返回格式SSE
""" """
log_path = settings.LOG_PATH / logfile log_path = AsyncPath(settings.LOG_PATH) / logfile
if not SecurityUtils.is_safe_path(settings.LOG_PATH, log_path, allowed_suffixes={".log"}): if not SecurityUtils.is_safe_path(settings.LOG_PATH, log_path, allowed_suffixes={".log"}):
raise HTTPException(status_code=404, detail="Not Found") raise HTTPException(status_code=404, detail="Not Found")
if not log_path.exists() or not log_path.is_file(): if not await log_path.exists() or not await log_path.is_file():
raise HTTPException(status_code=404, detail="Not Found") raise HTTPException(status_code=404, detail="Not Found")
async def log_generator(): async def log_generator():
@@ -370,7 +375,7 @@ async def get_logging(request: Request, length: Optional[int] = 50, logfile: Opt
# 使用固定大小的双向队列来限制内存使用 # 使用固定大小的双向队列来限制内存使用
lines_queue = deque(maxlen=max(length, 50)) lines_queue = deque(maxlen=max(length, 50))
# 使用 aiofiles 异步读取文件 # 使用 aiofiles 异步读取文件
async with aiofiles.open(log_path, mode="r", encoding="utf-8") as f: async with log_path.open(mode="r", encoding="utf-8") as f:
# 逐行读取文件,将每一行存入队列 # 逐行读取文件,将每一行存入队列
file_content = await f.read() file_content = await f.read()
for line in file_content.splitlines(): for line in file_content.splitlines():
@@ -393,10 +398,10 @@ async def get_logging(request: Request, length: Optional[int] = 50, logfile: Opt
# 根据length参数返回不同的响应 # 根据length参数返回不同的响应
if length == -1: if length == -1:
# 返回全部日志作为文本响应 # 返回全部日志作为文本响应
if not log_path.exists(): if not await log_path.exists():
return Response(content="日志文件不存在!", media_type="text/plain") return Response(content="日志文件不存在!", media_type="text/plain")
# 使用 aiofiles 异步读取文件 # 使用 aiofiles 异步读取文件
async with aiofiles.open(log_path, mode="r", encoding="utf-8") as file: async with log_path.open(mode="r", encoding="utf-8") as file:
text = await file.read() text = await file.read()
# 倒序输出 # 倒序输出
text = "\n".join(text.split("\n")[::-1]) text = "\n".join(text.split("\n")[::-1])

View File

@@ -168,7 +168,9 @@ class RequestUtils:
try: try:
return req_method(method, url, **kwargs) return req_method(method, url, **kwargs)
except requests.exceptions.RequestException as e: except requests.exceptions.RequestException as e:
logger.debug(f"请求失败: {e}") # 获取更详细的错误信息
error_msg = str(e) if str(e) else f"未知网络错误 (URL: {url}, Method: {method.upper()})"
logger.debug(f"请求失败: {error_msg}")
if raise_exception: if raise_exception:
raise raise
return None return None
@@ -669,7 +671,9 @@ class AsyncRequestUtils:
try: try:
return await client.request(method, url, **kwargs) return await client.request(method, url, **kwargs)
except httpx.RequestError as e: except httpx.RequestError as e:
logger.debug(f"异步请求失败: {e}") # 获取更详细的错误信息
error_msg = str(e) if str(e) else f"未知网络错误 (URL: {url}, Method: {method.upper()})"
logger.debug(f"异步请求失败: {error_msg}")
if raise_exception: if raise_exception:
raise raise
return None return None

View File

@@ -3,6 +3,8 @@ from pathlib import Path
from typing import List, Optional, Set, Union from typing import List, Optional, Set, Union
from urllib.parse import quote, urlparse from urllib.parse import quote, urlparse
from aiopath import AsyncPath
from app.log import logger from app.log import logger
@@ -39,6 +41,37 @@ class SecurityUtils:
logger.debug(f"Error occurred while validating paths: {e}") logger.debug(f"Error occurred while validating paths: {e}")
return False return False
@staticmethod
async def async_is_safe_path(base_path: AsyncPath, user_path: AsyncPath,
allowed_suffixes: Optional[Union[Set[str], List[str]]] = None) -> bool:
"""
Asynchronously verify that a user-supplied path stays inside the base directory
and, optionally, that its file suffix is allowed — guards against directory traversal.
:param base_path: base directory; the root that access is restricted to
:param user_path: user-supplied path to validate against the base directory
:param allowed_suffixes: collection of permitted file suffixes used to validate the file type
:return: True if the path resolves inside the base directory and the suffix is allowed, otherwise False
:raises Exception: errors raised while resolving paths are caught, logged, and reported as False
"""
try:
# resolve() converts relative paths to absolute ones and expands symlinks and '..'
base_path_resolved = await base_path.resolve()
user_path_resolved = await user_path.resolve()
# Accept only the base directory itself or a path located beneath it
if base_path_resolved != user_path_resolved and base_path_resolved not in user_path_resolved.parents:
return False
if allowed_suffixes is not None:
allowed_suffixes = set(allowed_suffixes)
# NOTE(review): suffix is read from the UNRESOLVED user_path, not the resolved one — confirm intended
if user_path.suffix.lower() not in allowed_suffixes:
return False
return True
except Exception as e:
logger.debug(f"Error occurred while validating paths: {e}")
return False
@staticmethod @staticmethod
def is_safe_url(url: str, allowed_domains: Union[Set[str], List[str]], strict: bool = False) -> bool: def is_safe_url(url: str, allowed_domains: Union[Set[str], List[str]], strict: bool = False) -> bool:
""" """

View File

@@ -60,6 +60,8 @@ Pinyin2Hanzi~=0.1.1
pywebpush~=2.0.3 pywebpush~=2.0.3
python-cookietools==0.0.4 python-cookietools==0.0.4
aiofiles~=24.1.0 aiofiles~=24.1.0
aiopath~=0.7.7
asynctempfile~=0.5.0
aiosqlite~=0.21.0 aiosqlite~=0.21.0
jieba~=0.42.1 jieba~=0.42.1
rsa~=4.9 rsa~=4.9