mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-03-20 03:57:30 +08:00
add aiopath
This commit is contained in:
@@ -2,13 +2,14 @@ import asyncio
|
||||
import io
|
||||
import json
|
||||
import re
|
||||
import tempfile
|
||||
from collections import deque
|
||||
from datetime import datetime
|
||||
from typing import Optional, Union, Annotated
|
||||
|
||||
import aiofiles
|
||||
import pillow_avif # noqa 用于自动注册AVIF支持
|
||||
from PIL import Image
|
||||
from aiopath import AsyncPath
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Header, Request, Response
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
@@ -63,24 +64,27 @@ async def fetch_image(
|
||||
raise HTTPException(status_code=404, detail="Unsafe URL")
|
||||
|
||||
# 后续观察系统性能表现,如果发现磁盘缓存和HTTP缓存无法满足高并发情况下的响应速度需求,可以考虑重新引入内存缓存
|
||||
cache_path = None
|
||||
cache_path: Optional[AsyncPath] = None
|
||||
if use_disk_cache:
|
||||
# 生成缓存路径
|
||||
base_path = AsyncPath(settings.CACHE_PATH)
|
||||
sanitized_path = SecurityUtils.sanitize_url_path(url)
|
||||
cache_path = settings.CACHE_PATH / "images" / sanitized_path
|
||||
cache_path = base_path / "images" / sanitized_path
|
||||
|
||||
# 没有文件类型,则添加后缀,在恶意文件类型和实际需求下的折衷选择
|
||||
if not cache_path.suffix:
|
||||
cache_path = cache_path.with_suffix(".jpg")
|
||||
|
||||
# 确保缓存路径和文件类型合法
|
||||
if not SecurityUtils.is_safe_path(settings.CACHE_PATH, cache_path, settings.SECURITY_IMAGE_SUFFIXES):
|
||||
if not await SecurityUtils.async_is_safe_path(base_path=base_path,
|
||||
user_path=cache_path,
|
||||
allowed_suffixes=settings.SECURITY_IMAGE_SUFFIXES):
|
||||
raise HTTPException(status_code=400, detail="Invalid cache path or file type")
|
||||
|
||||
# 目前暂不考虑磁盘缓存文件是否过期,后续通过缓存清理机制处理
|
||||
if cache_path.exists():
|
||||
if cache_path and await cache_path.exists():
|
||||
try:
|
||||
async with aiofiles.open(cache_path, 'rb') as f:
|
||||
async with cache_path.open(cache_path, 'rb') as f:
|
||||
content = await f.read()
|
||||
etag = HashUtils.md5(content)
|
||||
headers = RequestUtils.generate_cache_headers(etag, max_age=86400 * 7)
|
||||
@@ -115,11 +119,12 @@ async def fetch_image(
|
||||
# 如果需要使用磁盘缓存,则保存到磁盘
|
||||
if use_disk_cache and cache_path:
|
||||
try:
|
||||
if not cache_path.parent.exists():
|
||||
cache_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
# 使用异步文件操作写入缓存
|
||||
async with aiofiles.open(cache_path, 'wb') as f:
|
||||
await f.write(content)
|
||||
if not await cache_path.parent.exists():
|
||||
await cache_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with tempfile.NamedTemporaryFile(dir=cache_path.parent, delete=False) as tmp_file:
|
||||
tmp_file.write(content)
|
||||
temp_path = AsyncPath(tmp_file.name)
|
||||
await temp_path.replace(cache_path)
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to write cache file {cache_path}: {e}")
|
||||
|
||||
@@ -357,12 +362,12 @@ async def get_logging(request: Request, length: Optional[int] = 50, logfile: Opt
|
||||
length = -1 时, 返回text/plain
|
||||
否则 返回格式SSE
|
||||
"""
|
||||
log_path = settings.LOG_PATH / logfile
|
||||
log_path = AsyncPath(settings.LOG_PATH) / logfile
|
||||
|
||||
if not SecurityUtils.is_safe_path(settings.LOG_PATH, log_path, allowed_suffixes={".log"}):
|
||||
raise HTTPException(status_code=404, detail="Not Found")
|
||||
|
||||
if not log_path.exists() or not log_path.is_file():
|
||||
if not await log_path.exists() or not await log_path.is_file():
|
||||
raise HTTPException(status_code=404, detail="Not Found")
|
||||
|
||||
async def log_generator():
|
||||
@@ -370,7 +375,7 @@ async def get_logging(request: Request, length: Optional[int] = 50, logfile: Opt
|
||||
# 使用固定大小的双向队列来限制内存使用
|
||||
lines_queue = deque(maxlen=max(length, 50))
|
||||
# 使用 aiofiles 异步读取文件
|
||||
async with aiofiles.open(log_path, mode="r", encoding="utf-8") as f:
|
||||
async with log_path.open(mode="r", encoding="utf-8") as f:
|
||||
# 逐行读取文件,将每一行存入队列
|
||||
file_content = await f.read()
|
||||
for line in file_content.splitlines():
|
||||
@@ -393,10 +398,10 @@ async def get_logging(request: Request, length: Optional[int] = 50, logfile: Opt
|
||||
# 根据length参数返回不同的响应
|
||||
if length == -1:
|
||||
# 返回全部日志作为文本响应
|
||||
if not log_path.exists():
|
||||
if not await log_path.exists():
|
||||
return Response(content="日志文件不存在!", media_type="text/plain")
|
||||
# 使用 aiofiles 异步读取文件
|
||||
async with aiofiles.open(log_path, mode="r", encoding="utf-8") as file:
|
||||
async with log_path.open(mode="r", encoding="utf-8") as file:
|
||||
text = await file.read()
|
||||
# 倒序输出
|
||||
text = "\n".join(text.split("\n")[::-1])
|
||||
|
||||
@@ -168,7 +168,9 @@ class RequestUtils:
|
||||
try:
|
||||
return req_method(method, url, **kwargs)
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.debug(f"请求失败: {e}")
|
||||
# 获取更详细的错误信息
|
||||
error_msg = str(e) if str(e) else f"未知网络错误 (URL: {url}, Method: {method.upper()})"
|
||||
logger.debug(f"请求失败: {error_msg}")
|
||||
if raise_exception:
|
||||
raise
|
||||
return None
|
||||
@@ -669,7 +671,9 @@ class AsyncRequestUtils:
|
||||
try:
|
||||
return await client.request(method, url, **kwargs)
|
||||
except httpx.RequestError as e:
|
||||
logger.debug(f"异步请求失败: {e}")
|
||||
# 获取更详细的错误信息
|
||||
error_msg = str(e) if str(e) else f"未知网络错误 (URL: {url}, Method: {method.upper()})"
|
||||
logger.debug(f"异步请求失败: {error_msg}")
|
||||
if raise_exception:
|
||||
raise
|
||||
return None
|
||||
|
||||
@@ -3,6 +3,8 @@ from pathlib import Path
|
||||
from typing import List, Optional, Set, Union
|
||||
from urllib.parse import quote, urlparse
|
||||
|
||||
from aiopath import AsyncPath
|
||||
|
||||
from app.log import logger
|
||||
|
||||
|
||||
@@ -39,6 +41,37 @@ class SecurityUtils:
|
||||
logger.debug(f"Error occurred while validating paths: {e}")
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
async def async_is_safe_path(base_path: AsyncPath, user_path: AsyncPath,
|
||||
allowed_suffixes: Optional[Union[Set[str], List[str]]] = None) -> bool:
|
||||
"""
|
||||
异步验证用户提供的路径是否在基准目录内,并检查文件类型是否合法,防止目录遍历攻击
|
||||
|
||||
:param base_path: 基准目录,允许访问的根目录
|
||||
:param user_path: 用户提供的路径,需检查其是否位于基准目录内
|
||||
:param allowed_suffixes: 允许的文件后缀名集合,用于验证文件类型
|
||||
:return: 如果用户路径安全且位于基准目录内,且文件类型合法,返回 True;否则返回 False
|
||||
:raises Exception: 如果解析路径时发生错误,则捕获并记录异常
|
||||
"""
|
||||
try:
|
||||
# resolve() 将相对路径转换为绝对路径,并处理符号链接和'..'
|
||||
base_path_resolved = await base_path.resolve()
|
||||
user_path_resolved = await user_path.resolve()
|
||||
|
||||
# 检查用户路径是否在基准目录或基准目录的子目录内
|
||||
if base_path_resolved != user_path_resolved and base_path_resolved not in user_path_resolved.parents:
|
||||
return False
|
||||
|
||||
if allowed_suffixes is not None:
|
||||
allowed_suffixes = set(allowed_suffixes)
|
||||
if user_path.suffix.lower() not in allowed_suffixes:
|
||||
return False
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.debug(f"Error occurred while validating paths: {e}")
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def is_safe_url(url: str, allowed_domains: Union[Set[str], List[str]], strict: bool = False) -> bool:
|
||||
"""
|
||||
|
||||
@@ -60,6 +60,8 @@ Pinyin2Hanzi~=0.1.1
|
||||
pywebpush~=2.0.3
|
||||
python-cookietools==0.0.4
|
||||
aiofiles~=24.1.0
|
||||
aiopath~=0.7.7
|
||||
asynctempfile~=0.5.0
|
||||
aiosqlite~=0.21.0
|
||||
jieba~=0.42.1
|
||||
rsa~=4.9
|
||||
|
||||
Reference in New Issue
Block a user