Merge pull request #5060 from wikrin/cached

This commit is contained in:
jxxghp
2025-10-19 22:37:33 +08:00
committed by GitHub
4 changed files with 80 additions and 44 deletions

View File

@@ -11,7 +11,7 @@ from fastapi.concurrency import run_in_threadpool
from qbittorrentapi import TorrentFilesList
from transmission_rpc import File
from app.core.cache import FileCache, AsyncFileCache
from app.core.cache import FileCache, AsyncFileCache, fresh, async_fresh
from app.core.config import settings
from app.core.context import Context, MediaInfo, TorrentInfo
from app.core.event import EventManager
@@ -358,9 +358,10 @@ class ChainBase(metaclass=ABCMeta):
if tmdbid:
doubanid = None
bangumiid = None
return self.run_module("recognize_media", meta=meta, mtype=mtype,
tmdbid=tmdbid, doubanid=doubanid, bangumiid=bangumiid,
episode_group=episode_group, cache=cache)
with fresh(not cache):
return self.run_module("recognize_media", meta=meta, mtype=mtype,
tmdbid=tmdbid, doubanid=doubanid, bangumiid=bangumiid,
episode_group=episode_group, cache=cache)
async def async_recognize_media(self, meta: MetaBase = None,
mtype: Optional[MediaType] = None,
@@ -391,9 +392,10 @@ class ChainBase(metaclass=ABCMeta):
if tmdbid:
doubanid = None
bangumiid = None
return await self.async_run_module("async_recognize_media", meta=meta, mtype=mtype,
tmdbid=tmdbid, doubanid=doubanid, bangumiid=bangumiid,
episode_group=episode_group, cache=cache)
async with async_fresh(not cache):
return await self.async_run_module("async_recognize_media", meta=meta, mtype=mtype,
tmdbid=tmdbid, doubanid=doubanid, bangumiid=bangumiid,
episode_group=episode_group, cache=cache)
def match_doubaninfo(self, name: str, imdbid: Optional[str] = None,
mtype: Optional[MediaType] = None, year: Optional[str] = None, season: Optional[int] = None,

View File

@@ -290,7 +290,7 @@ class DownloadChain(ChainBase):
# 登记下载记录
downloadhis = DownloadHistoryOper()
downloadhis.add(
path=str(download_path),
path=download_path.as_posix(),
type=_media.type.value,
title=_media.title,
year=_media.year,
@@ -331,8 +331,8 @@ class DownloadChain(ChainBase):
files_to_add.append({
"download_hash": _hash,
"downloader": _downloader,
"fullpath": str(_save_path / file),
"savepath": str(_save_path),
"fullpath": (_save_path / file).as_posix(),
"savepath": _save_path.as_posix(),
"filepath": file,
"torrentname": _meta.org_string,
})

View File

@@ -1,8 +1,10 @@
import contextvars
import inspect
import shutil
import tempfile
import threading
from abc import ABC, abstractmethod
from contextlib import contextmanager, asynccontextmanager
from functools import wraps
from pathlib import Path
from typing import Any, Dict, Optional, Generator, AsyncGenerator, Tuple, Literal, Union
@@ -27,6 +29,9 @@ DEFAULT_CACHE_TTL = 365 * 24 * 60 * 60
lock = threading.Lock()
# 上下文变量来控制缓存行为
_fresh = contextvars.ContextVar('fresh', default=False)
class CacheBackend(ABC):
"""
@@ -1010,6 +1015,45 @@ class AsyncFileBackend(AsyncCacheBackend):
pass
@contextmanager
def fresh(fresh: bool = False):
    """
    Context manager that controls whether cached values should be bypassed.

    While the ``with`` block is active, cache-aware code can consult the
    ``_fresh`` context variable (via ``is_fresh()``) to decide whether to
    skip cached values and fetch fresh data instead.

    Usage:
        with fresh():
            result = some_cached_function()
    """
    # Remember the previous context state so that nested usage restores
    # the outer value correctly.
    reset_token = _fresh.set(fresh)
    try:
        yield
    finally:
        # Always restore the outer value, even if the body raised.
        _fresh.reset(reset_token)
@asynccontextmanager
async def async_fresh(fresh: bool = False):
    """
    Async context manager that controls whether cached values should be bypassed.

    Async counterpart of ``fresh()``: while the ``async with`` block is
    active, cache-aware code can consult the ``_fresh`` context variable
    (via ``is_fresh()``) to decide whether to skip cached values.

    Usage:
        async with async_fresh():
            result = await some_async_cached_function()
    """
    # Remember the previous context state so that nested usage restores
    # the outer value correctly.
    reset_token = _fresh.set(fresh)
    try:
        yield
    finally:
        # Always restore the outer value, even if the body raised.
        _fresh.reset(reset_token)
def is_fresh() -> bool:
    """
    Return True when the current context requests fresh data (cache bypass).

    ``_fresh`` is declared with ``default=False``, and ``ContextVar.get()``
    without an argument raises ``LookupError`` only when the variable has
    neither a value nor a default — so the former ``try/except LookupError``
    branch was unreachable and has been removed.
    """
    return _fresh.get()
def FileCache(base: Path = settings.TEMP_PATH, ttl: Optional[int] = None) -> CacheBackend:
"""
获取文件缓存后端实例Redis或文件系统ttl仅在Redis环境中有效
@@ -1084,16 +1128,6 @@ def cached(region: Optional[str] = None, maxsize: Optional[int] = 1024, ttl: Opt
"""
def decorator(func):
# 检查是否为异步函数
is_async = inspect.iscoroutinefunction(func)
# 根据函数类型选择对应的缓存后端没有ttl时默认是 LRU 缓存,否则是 TTL 缓存
if is_async:
# 异步函数使用异步缓存后端
cache_backend = AsyncCache(cache_type="ttl" if ttl else "lru", maxsize=maxsize, ttl=ttl)
else:
# 同步函数使用同步缓存后端
cache_backend = Cache(cache_type="ttl" if ttl else "lru", maxsize=maxsize, ttl=ttl)
def should_cache(value: Any) -> bool:
"""
@@ -1169,16 +1203,20 @@ def cached(region: Optional[str] = None, maxsize: Optional[int] = 1024, ttl: Opt
is_async = inspect.iscoroutinefunction(func)
if is_async:
# 异步函数使用异步缓存后端
cache_backend = AsyncCache(cache_type="ttl" if ttl else "lru", maxsize=maxsize, ttl=ttl)
# 异步函数的缓存装饰器
@wraps(func)
async def async_wrapper(*args, **kwargs):
# 获取缓存键
cache_key = __get_cache_key(args, kwargs)
# 尝试获取缓存
cached_value = await cache_backend.get(cache_key, region=cache_region)
if should_cache(cached_value) and await async_is_valid_cache_value(cache_key, cached_value,
cache_region):
return cached_value
if not is_fresh():
# 尝试获取缓存
cached_value = await cache_backend.get(cache_key, region=cache_region)
if should_cache(cached_value) and await async_is_valid_cache_value(cache_key, cached_value,
cache_region):
return cached_value
# 执行异步函数并缓存结果
result = await func(*args, **kwargs)
# 判断是否需要缓存
@@ -1198,15 +1236,19 @@ def cached(region: Optional[str] = None, maxsize: Optional[int] = 1024, ttl: Opt
async_wrapper.cache_clear = cache_clear
return async_wrapper
else:
# 同步函数使用同步缓存后端
cache_backend = Cache(cache_type="ttl" if ttl else "lru", maxsize=maxsize, ttl=ttl)
# 同步函数的缓存装饰器
@wraps(func)
def wrapper(*args, **kwargs):
# 获取缓存键
cache_key = __get_cache_key(args, kwargs)
# 尝试获取缓存
cached_value = cache_backend.get(cache_key, region=cache_region)
if should_cache(cached_value) and is_valid_cache_value(cache_key, cached_value, cache_region):
return cached_value
if not is_fresh():
# 尝试获取缓存
cached_value = cache_backend.get(cache_key, region=cache_region)
if should_cache(cached_value) and is_valid_cache_value(cache_key, cached_value, cache_region):
return cached_value
# 执行函数并缓存结果
result = func(*args, **kwargs)
# 判断是否需要缓存

View File

@@ -18,14 +18,13 @@ logger = logging.getLogger(__name__)
class TMDb(object):
def __init__(self, obj_cached=True, session=None, language=None):
def __init__(self, session=None, language=None):
self._api_key = settings.TMDB_API_KEY
self._language = language or settings.TMDB_LOCALE or "en-US"
self._session_id = None
self._session = session
self._wait_on_rate_limit = True
self._debug_enabled = False
self._cache_enabled = obj_cached
self._proxies = settings.PROXY
self._domain = settings.TMDB_API_DOMAIN
self._page = None
@@ -41,7 +40,6 @@ class TMDb(object):
self._remaining = 40
self._reset = None
self._timeout = 15
self.obj_cached = obj_cached
self.__clear_async_cache__ = False
@@ -119,14 +117,6 @@ class TMDb(object):
def debug(self, debug):
self._debug_enabled = bool(debug)
@property
def cache(self):
return self._cache_enabled
@cache.setter
def cache(self, cache):
self._cache_enabled = bool(cache)
@cached(maxsize=settings.CONF.tmdb, ttl=settings.CONF.meta, skip_none=True)
def cached_request(self, method, url, data, json,
_ts=datetime.strftime(datetime.now(), '%Y%m%d')):
@@ -224,8 +214,9 @@ class TMDb(object):
self._validate_api_key()
url = self._build_url(action, params)
if self.cache and self.obj_cached and call_cached and method != "POST":
req = self.cached_request(method, url, data, json)
if call_cached and method != "POST":
req = self.cached_request(method, url, data, json,
_ts=datetime.strftime(datetime.now(), '%Y%m%d'))
else:
req = self.request(method, url, data, json)
@@ -253,8 +244,9 @@ class TMDb(object):
self._validate_api_key()
url = self._build_url(action, params)
if self.cache and self.obj_cached and call_cached and method != "POST":
req = await self.async_cached_request(method, url, data, json)
if call_cached and method != "POST":
req = await self.async_cached_request(method, url, data, json,
_ts=datetime.strftime(datetime.now(), '%Y%m%d'))
else:
req = await self.async_request(method, url, data, json)