Merge pull request #5060 from wikrin/cached

This commit is contained in:
jxxghp
2025-10-19 22:37:33 +08:00
committed by GitHub
4 changed files with 80 additions and 44 deletions

View File

@@ -11,7 +11,7 @@ from fastapi.concurrency import run_in_threadpool
from qbittorrentapi import TorrentFilesList from qbittorrentapi import TorrentFilesList
from transmission_rpc import File from transmission_rpc import File
from app.core.cache import FileCache, AsyncFileCache from app.core.cache import FileCache, AsyncFileCache, fresh, async_fresh
from app.core.config import settings from app.core.config import settings
from app.core.context import Context, MediaInfo, TorrentInfo from app.core.context import Context, MediaInfo, TorrentInfo
from app.core.event import EventManager from app.core.event import EventManager
@@ -358,9 +358,10 @@ class ChainBase(metaclass=ABCMeta):
if tmdbid: if tmdbid:
doubanid = None doubanid = None
bangumiid = None bangumiid = None
return self.run_module("recognize_media", meta=meta, mtype=mtype, with fresh(not cache):
tmdbid=tmdbid, doubanid=doubanid, bangumiid=bangumiid, return self.run_module("recognize_media", meta=meta, mtype=mtype,
episode_group=episode_group, cache=cache) tmdbid=tmdbid, doubanid=doubanid, bangumiid=bangumiid,
episode_group=episode_group, cache=cache)
async def async_recognize_media(self, meta: MetaBase = None, async def async_recognize_media(self, meta: MetaBase = None,
mtype: Optional[MediaType] = None, mtype: Optional[MediaType] = None,
@@ -391,9 +392,10 @@ class ChainBase(metaclass=ABCMeta):
if tmdbid: if tmdbid:
doubanid = None doubanid = None
bangumiid = None bangumiid = None
return await self.async_run_module("async_recognize_media", meta=meta, mtype=mtype, async with async_fresh(not cache):
tmdbid=tmdbid, doubanid=doubanid, bangumiid=bangumiid, return await self.async_run_module("async_recognize_media", meta=meta, mtype=mtype,
episode_group=episode_group, cache=cache) tmdbid=tmdbid, doubanid=doubanid, bangumiid=bangumiid,
episode_group=episode_group, cache=cache)
def match_doubaninfo(self, name: str, imdbid: Optional[str] = None, def match_doubaninfo(self, name: str, imdbid: Optional[str] = None,
mtype: Optional[MediaType] = None, year: Optional[str] = None, season: Optional[int] = None, mtype: Optional[MediaType] = None, year: Optional[str] = None, season: Optional[int] = None,

View File

@@ -290,7 +290,7 @@ class DownloadChain(ChainBase):
# 登记下载记录 # 登记下载记录
downloadhis = DownloadHistoryOper() downloadhis = DownloadHistoryOper()
downloadhis.add( downloadhis.add(
path=str(download_path), path=download_path.as_posix(),
type=_media.type.value, type=_media.type.value,
title=_media.title, title=_media.title,
year=_media.year, year=_media.year,
@@ -331,8 +331,8 @@ class DownloadChain(ChainBase):
files_to_add.append({ files_to_add.append({
"download_hash": _hash, "download_hash": _hash,
"downloader": _downloader, "downloader": _downloader,
"fullpath": str(_save_path / file), "fullpath": (_save_path / file).as_posix(),
"savepath": str(_save_path), "savepath": _save_path.as_posix(),
"filepath": file, "filepath": file,
"torrentname": _meta.org_string, "torrentname": _meta.org_string,
}) })

View File

@@ -1,8 +1,10 @@
import contextvars
import inspect import inspect
import shutil import shutil
import tempfile import tempfile
import threading import threading
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from contextlib import contextmanager, asynccontextmanager
from functools import wraps from functools import wraps
from pathlib import Path from pathlib import Path
from typing import Any, Dict, Optional, Generator, AsyncGenerator, Tuple, Literal, Union from typing import Any, Dict, Optional, Generator, AsyncGenerator, Tuple, Literal, Union
@@ -27,6 +29,9 @@ DEFAULT_CACHE_TTL = 365 * 24 * 60 * 60
lock = threading.Lock() lock = threading.Lock()
# 上下文变量来控制缓存行为
_fresh = contextvars.ContextVar('fresh', default=False)
class CacheBackend(ABC): class CacheBackend(ABC):
""" """
@@ -1010,6 +1015,45 @@ class AsyncFileBackend(AsyncCacheBackend):
pass pass
@contextmanager
def fresh(fresh: bool = False):
"""
是否获取新数据(不使用缓存的值)
Usage:
with fresh():
result = some_cached_function()
"""
token = _fresh.set(fresh)
try:
yield
finally:
_fresh.reset(token)
@asynccontextmanager
async def async_fresh(fresh: bool = False):
"""
是否获取新数据(不使用缓存的值)
Usage:
async with async_fresh():
result = await some_async_cached_function()
"""
token = _fresh.set(fresh)
try:
yield
finally:
_fresh.reset(token)
def is_fresh() -> bool:
"""
是否获取新数据
"""
try:
return _fresh.get()
except LookupError:
return False
def FileCache(base: Path = settings.TEMP_PATH, ttl: Optional[int] = None) -> CacheBackend: def FileCache(base: Path = settings.TEMP_PATH, ttl: Optional[int] = None) -> CacheBackend:
""" """
获取文件缓存后端实例Redis或文件系统ttl仅在Redis环境中有效 获取文件缓存后端实例Redis或文件系统ttl仅在Redis环境中有效
@@ -1084,16 +1128,6 @@ def cached(region: Optional[str] = None, maxsize: Optional[int] = 1024, ttl: Opt
""" """
def decorator(func): def decorator(func):
# 检查是否为异步函数
is_async = inspect.iscoroutinefunction(func)
# 根据函数类型选择对应的缓存后端没有ttl时默认是 LRU 缓存,否则是 TTL 缓存
if is_async:
# 异步函数使用异步缓存后端
cache_backend = AsyncCache(cache_type="ttl" if ttl else "lru", maxsize=maxsize, ttl=ttl)
else:
# 同步函数使用同步缓存后端
cache_backend = Cache(cache_type="ttl" if ttl else "lru", maxsize=maxsize, ttl=ttl)
def should_cache(value: Any) -> bool: def should_cache(value: Any) -> bool:
""" """
@@ -1169,16 +1203,20 @@ def cached(region: Optional[str] = None, maxsize: Optional[int] = 1024, ttl: Opt
is_async = inspect.iscoroutinefunction(func) is_async = inspect.iscoroutinefunction(func)
if is_async: if is_async:
# 异步函数使用异步缓存后端
cache_backend = AsyncCache(cache_type="ttl" if ttl else "lru", maxsize=maxsize, ttl=ttl)
# 异步函数的缓存装饰器 # 异步函数的缓存装饰器
@wraps(func) @wraps(func)
async def async_wrapper(*args, **kwargs): async def async_wrapper(*args, **kwargs):
# 获取缓存键 # 获取缓存键
cache_key = __get_cache_key(args, kwargs) cache_key = __get_cache_key(args, kwargs)
# 尝试获取缓存
cached_value = await cache_backend.get(cache_key, region=cache_region) if not is_fresh():
if should_cache(cached_value) and await async_is_valid_cache_value(cache_key, cached_value, # 尝试获取缓存
cache_region): cached_value = await cache_backend.get(cache_key, region=cache_region)
return cached_value if should_cache(cached_value) and await async_is_valid_cache_value(cache_key, cached_value,
cache_region):
return cached_value
# 执行异步函数并缓存结果 # 执行异步函数并缓存结果
result = await func(*args, **kwargs) result = await func(*args, **kwargs)
# 判断是否需要缓存 # 判断是否需要缓存
@@ -1198,15 +1236,19 @@ def cached(region: Optional[str] = None, maxsize: Optional[int] = 1024, ttl: Opt
async_wrapper.cache_clear = cache_clear async_wrapper.cache_clear = cache_clear
return async_wrapper return async_wrapper
else: else:
# 同步函数使用同步缓存后端
cache_backend = Cache(cache_type="ttl" if ttl else "lru", maxsize=maxsize, ttl=ttl)
# 同步函数的缓存装饰器 # 同步函数的缓存装饰器
@wraps(func) @wraps(func)
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
# 获取缓存键 # 获取缓存键
cache_key = __get_cache_key(args, kwargs) cache_key = __get_cache_key(args, kwargs)
# 尝试获取缓存
cached_value = cache_backend.get(cache_key, region=cache_region) if not is_fresh():
if should_cache(cached_value) and is_valid_cache_value(cache_key, cached_value, cache_region): # 尝试获取缓存
return cached_value cached_value = cache_backend.get(cache_key, region=cache_region)
if should_cache(cached_value) and is_valid_cache_value(cache_key, cached_value, cache_region):
return cached_value
# 执行函数并缓存结果 # 执行函数并缓存结果
result = func(*args, **kwargs) result = func(*args, **kwargs)
# 判断是否需要缓存 # 判断是否需要缓存

View File

@@ -18,14 +18,13 @@ logger = logging.getLogger(__name__)
class TMDb(object): class TMDb(object):
def __init__(self, obj_cached=True, session=None, language=None): def __init__(self, session=None, language=None):
self._api_key = settings.TMDB_API_KEY self._api_key = settings.TMDB_API_KEY
self._language = language or settings.TMDB_LOCALE or "en-US" self._language = language or settings.TMDB_LOCALE or "en-US"
self._session_id = None self._session_id = None
self._session = session self._session = session
self._wait_on_rate_limit = True self._wait_on_rate_limit = True
self._debug_enabled = False self._debug_enabled = False
self._cache_enabled = obj_cached
self._proxies = settings.PROXY self._proxies = settings.PROXY
self._domain = settings.TMDB_API_DOMAIN self._domain = settings.TMDB_API_DOMAIN
self._page = None self._page = None
@@ -41,7 +40,6 @@ class TMDb(object):
self._remaining = 40 self._remaining = 40
self._reset = None self._reset = None
self._timeout = 15 self._timeout = 15
self.obj_cached = obj_cached
self.__clear_async_cache__ = False self.__clear_async_cache__ = False
@@ -119,14 +117,6 @@ class TMDb(object):
def debug(self, debug): def debug(self, debug):
self._debug_enabled = bool(debug) self._debug_enabled = bool(debug)
@property
def cache(self):
return self._cache_enabled
@cache.setter
def cache(self, cache):
self._cache_enabled = bool(cache)
@cached(maxsize=settings.CONF.tmdb, ttl=settings.CONF.meta, skip_none=True) @cached(maxsize=settings.CONF.tmdb, ttl=settings.CONF.meta, skip_none=True)
def cached_request(self, method, url, data, json, def cached_request(self, method, url, data, json,
_ts=datetime.strftime(datetime.now(), '%Y%m%d')): _ts=datetime.strftime(datetime.now(), '%Y%m%d')):
@@ -224,8 +214,9 @@ class TMDb(object):
self._validate_api_key() self._validate_api_key()
url = self._build_url(action, params) url = self._build_url(action, params)
if self.cache and self.obj_cached and call_cached and method != "POST": if call_cached and method != "POST":
req = self.cached_request(method, url, data, json) req = self.cached_request(method, url, data, json,
_ts=datetime.strftime(datetime.now(), '%Y%m%d'))
else: else:
req = self.request(method, url, data, json) req = self.request(method, url, data, json)
@@ -253,8 +244,9 @@ class TMDb(object):
self._validate_api_key() self._validate_api_key()
url = self._build_url(action, params) url = self._build_url(action, params)
if self.cache and self.obj_cached and call_cached and method != "POST": if call_cached and method != "POST":
req = await self.async_cached_request(method, url, data, json) req = await self.async_cached_request(method, url, data, json,
_ts=datetime.strftime(datetime.now(), '%Y%m%d'))
else: else:
req = await self.async_request(method, url, data, json) req = await self.async_request(method, url, data, json)