From ee4d6d0db3a3c2f5ed65679ca21c3a2ecaea52b8 Mon Sep 17 00:00:00 2001 From: jxxghp Date: Thu, 31 Jul 2025 09:55:47 +0800 Subject: [PATCH] fix cache --- app/core/cache.py | 80 +++++++++++++++++++++++++++++++++-------------- 1 file changed, 56 insertions(+), 24 deletions(-) diff --git a/app/core/cache.py b/app/core/cache.py index e2a1d0d4..1cccb63c 100644 --- a/app/core/cache.py +++ b/app/core/cache.py @@ -529,33 +529,65 @@ def cached(region: Optional[str] = None, maxsize: Optional[int] = 512, ttl: Opti # 获取缓存区 cache_region = region if region is not None else f"{func.__module__}.{func.__name__}" - @wraps(func) - def wrapper(*args, **kwargs): - # 获取缓存键 - cache_key = cache_backend.get_cache_key(func, args, kwargs) - # 尝试获取缓存 - cached_value = cache_backend.get(cache_key, region=cache_region) - if should_cache(cached_value) and is_valid_cache_value(cache_key, cached_value, cache_region): - return cached_value - # 执行函数并缓存结果 - result = func(*args, **kwargs) - # 判断是否需要缓存 - if not should_cache(result): + # 检查是否为异步函数 + is_async = inspect.iscoroutinefunction(func) + + if is_async: + @wraps(func) + async def async_wrapper(*args, **kwargs): + # 获取缓存键 + cache_key = cache_backend.get_cache_key(func, args, kwargs) + # 尝试获取缓存 + cached_value = cache_backend.get(cache_key, region=cache_region) + if should_cache(cached_value) and is_valid_cache_value(cache_key, cached_value, cache_region): + return cached_value + # 执行异步函数并缓存结果 + result = await func(*args, **kwargs) + # 判断是否需要缓存 + if not should_cache(result): + return result + # 设置缓存(如果有传入的 maxsize 和 ttl,则覆盖默认值) + cache_backend.set(cache_key, result, ttl=ttl, maxsize=maxsize, region=cache_region) return result - # 设置缓存(如果有传入的 maxsize 和 ttl,则覆盖默认值) - cache_backend.set(cache_key, result, ttl=ttl, maxsize=maxsize, region=cache_region) - return result - def cache_clear(): - """ - 清理缓存区 - """ - # 清理缓存区 - cache_backend.clear(region=cache_region) + def cache_clear(): + """ + 清理缓存区 + """ + # 清理缓存区 + cache_backend.clear(region=cache_region) - wrapper.cache_region = cache_region - wrapper.cache_clear = cache_clear - return wrapper + async_wrapper.cache_region = cache_region + async_wrapper.cache_clear = cache_clear + return async_wrapper + else: + @wraps(func) + def wrapper(*args, **kwargs): + # 获取缓存键 + cache_key = cache_backend.get_cache_key(func, args, kwargs) + # 尝试获取缓存 + cached_value = cache_backend.get(cache_key, region=cache_region) + if should_cache(cached_value) and is_valid_cache_value(cache_key, cached_value, cache_region): + return cached_value + # 执行函数并缓存结果 + result = func(*args, **kwargs) + # 判断是否需要缓存 + if not should_cache(result): + return result + # 设置缓存(如果有传入的 maxsize 和 ttl,则覆盖默认值) + cache_backend.set(cache_key, result, ttl=ttl, maxsize=maxsize, region=cache_region) + return result + + def cache_clear(): + """ + 清理缓存区 + """ + # 清理缓存区 + cache_backend.clear(region=cache_region) + + wrapper.cache_region = cache_region + wrapper.cache_clear = cache_clear + return wrapper return decorator