feat(redis): add Redis support

This commit is contained in:
InfinityPacer
2025-01-19 02:50:28 +08:00
parent 9f22ce5cc0
commit cb5c06ee7e
3 changed files with 109 additions and 27 deletions

View File

@@ -1,5 +1,4 @@
import inspect
import os
from abc import ABC, abstractmethod
from functools import wraps
from typing import Any, Dict, Optional
@@ -8,6 +7,9 @@ import redis
from cachetools import TTLCache
from cachetools.keys import hashkey
from app.core.config import settings
from app.log import logger
# 默认缓存区
DEFAULT_CACHE_REGION = "DEFAULT"
@@ -60,6 +62,13 @@ class CacheBackend(ABC):
"""
pass
@abstractmethod
def close(self) -> None:
"""
关闭缓存连接
"""
pass
@staticmethod
def get_region(region: str = DEFAULT_CACHE_REGION):
"""
@@ -154,10 +163,18 @@ class CacheToolsBackend(CacheBackend):
region_cache = self.__get_region_cache(region)
if region_cache:
region_cache.clear()
logger.info(f"Cleared cache for region: {region}")
else:
# 清除所有区域的缓存
for region_cache in self._region_caches.values():
region_cache.clear()
logger.info("Cleared all cache")
def close(self) -> None:
"""
内存缓存不需要关闭资源
"""
pass
class RedisBackend(CacheBackend):
@@ -174,14 +191,28 @@ class RedisBackend(CacheBackend):
"""
self.redis_url = redis_url
self.ttl = ttl
self.client = redis.StrictRedis.from_url(redis_url)
try:
self.client = redis.Redis.from_url(
redis_url,
decode_responses=True,
socket_timeout=30,
socket_connect_timeout=5,
max_connections=100,
health_check_interval=60,
)
# 测试连接,确保 Redis 可用
self.client.ping()
logger.debug(f"Successfully connected to Redis")
except redis.RedisError as e:
logger.error(f"Failed to connect to Redis: {e}")
raise RuntimeError("Redis connection failed") from e
@staticmethod
def get_redis_key(region, key):
def get_redis_key(self, region: str, key: str) -> str:
"""
获取缓存 Key
"""
# 使用 region 作为缓存键的一部分
region = self.get_region(region)
return f"region:{region}:key:{key}"
def set(self, key: str, value: Any, ttl: int = None, region: str = DEFAULT_CACHE_REGION, **kwargs) -> None:
@@ -194,11 +225,15 @@ class RedisBackend(CacheBackend):
:param region: 缓存的区
:param kwargs: kwargs
"""
ttl = ttl or self.ttl
redis_key = self.get_redis_key(region, key)
self.client.setex(redis_key, ttl, value)
try:
ttl = ttl or self.ttl
redis_key = self.get_redis_key(region, key)
kwargs.pop("maxsize")
self.client.set(redis_key, value, ex=ttl, **kwargs)
except redis.RedisError as e:
logger.error(f"Failed to set key: {key} in region: {region}, error: {e}")
def get(self, key: str, region: str = DEFAULT_CACHE_REGION) -> Any:
def get(self, key: str, region: str = DEFAULT_CACHE_REGION) -> Optional[Any]:
"""
获取缓存的值
@@ -206,9 +241,13 @@ class RedisBackend(CacheBackend):
:param region: 缓存的区
:return: 返回缓存的值,如果缓存不存在返回 None
"""
redis_key = self.get_redis_key(region, key)
value = self.client.get(redis_key)
return value
try:
redis_key = self.get_redis_key(region, key)
value = self.client.get(redis_key)
return value
except redis.RedisError as e:
logger.error(f"Failed to get key: {key} in region: {region}, error: {e}")
return None
def delete(self, key: str, region: str = DEFAULT_CACHE_REGION) -> None:
"""
@@ -217,24 +256,36 @@ class RedisBackend(CacheBackend):
:param key: 缓存的键
:param region: 缓存的区
"""
redis_key = self.get_redis_key(region, key)
self.client.delete(redis_key)
try:
redis_key = self.get_redis_key(region, key)
self.client.delete(redis_key)
except redis.RedisError as e:
logger.error(f"Failed to delete key: {key} in region: {region}, error: {e}")
def clear(self, region: Optional[str] = None) -> None:
"""
清除 Redis 中指定区域的缓存或全部缓存
清除指定区域的缓存或全部缓存
:param region: 缓存的区
"""
if region:
# 清除指定区域的所有键
pattern = f"{region}:*"
keys = list(self.client.keys(pattern))
if keys:
self.client.delete(*keys)
else:
# 清除所有缓存
self.client.flushdb()
try:
if region:
pattern = f"region:{region}:key:*"
for key in self.client.scan_iter(pattern):
self.client.delete(key)
logger.info(f"Cleared Redis cache for region: {region}")
else:
self.client.flushdb()
logger.info("Cleared all Redis cache")
except redis.RedisError as e:
logger.error(f"Failed to clear cache, region: {region}, error: {e}")
def close(self) -> None:
"""
关闭 Redis 客户端的连接池
"""
if self.client:
self.client.close()
def get_cache_backend(maxsize: int = 1000, ttl: int = 1800) -> CacheBackend:
@@ -245,10 +296,23 @@ def get_cache_backend(maxsize: int = 1000, ttl: int = 1800) -> CacheBackend:
:param ttl: 缓存的默认存活时间,单位秒
:return: 返回缓存后端实例
"""
cache_type = os.getenv("CACHE_TYPE", "cachetools").lower()
cache_type = settings.CACHE_BACKEND_TYPE
logger.debug(f"Cache backend type from settings: {cache_type}")
if cache_type == "redis":
return RedisBackend(redis_url=os.getenv("REDIS_URL", "redis://localhost"))
redis_url = settings.CACHE_BACKEND_URL
if redis_url:
try:
logger.debug(f"Attempting to use RedisBackend with URL: {redis_url}, TTL: {ttl}")
return RedisBackend(redis_url=redis_url, ttl=ttl)
except RuntimeError:
logger.warning("Falling back to CacheToolsBackend due to Redis connection failure.")
else:
logger.debug("Cache backend type is redis, but no valid REDIS_URL found. "
"Falling back to CacheToolsBackend.")
# 如果不是 Redis回退到内存缓存
logger.debug(f"Using CacheToolsBackend with default maxsize: {maxsize}, TTL: {ttl}")
return CacheToolsBackend(maxsize=maxsize, ttl=ttl)
@@ -330,7 +394,6 @@ def cached(region: Optional[str] = None, maxsize: int = 1000, ttl: int = 1800,
"""
# 清理缓存区
cache_backend.clear(region=cache_region)
print(f"{cache_region} region cache is cleared")
wrapper.cache_region = cache_region
wrapper.cache_clear = cache_clear
@@ -341,3 +404,15 @@ def cached(region: Optional[str] = None, maxsize: int = 1000, ttl: int = 1800,
# 缓存后端实例
cache_backend = get_cache_backend()
def close_cache() -> None:
"""
关闭缓存后端连接并清理资源
"""
try:
if cache_backend:
cache_backend.close()
logger.info("Cache backend closed successfully.")
except Exception as e:
logger.info(f"Error while closing cache backend: {e}")

View File

@@ -71,6 +71,10 @@ class ConfigModel(BaseModel):
DB_TIMEOUT: int = 60
# SQLite 是否启用 WAL 模式,默认关闭
DB_WAL_ENABLE: bool = False
# 缓存类型,支持 cachetools 和 redis默认使用 cachetools
CACHE_BACKEND_TYPE: str = "cachetools"
# 缓存连接字符串,仅外部缓存(如 Redis、Memcached需要
CACHE_BACKEND_URL: Optional[str] = None
# 配置文件目录
CONFIG_DIR: Optional[str] = None
# 超级管理员
@@ -351,7 +355,7 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
return default, True
@validator('*', pre=True, always=True)
def generic_type_validator(cls, value: Any, field): # noqa
def generic_type_validator(cls, value: Any, field): # noqa
"""
通用校验器,尝试将配置值转换为期望的类型
"""

View File

@@ -2,6 +2,7 @@ import sys
from fastapi import FastAPI
from app.core.cache import close_cache
from app.core.config import global_vars, settings
from app.core.module import ModuleManager
from app.log import logger
@@ -129,6 +130,8 @@ def shutdown_modules(_: FastAPI):
Monitor().stop()
# 停止线程池
ThreadHelper().shutdown()
# 停止缓存连接
close_cache()
# 停止数据库连接
close_database()
# 停止前端服务