mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-06-16 06:58:08 +08:00
fix: enhance caching mechanism and improve type hints in DoH and workflow modules
This commit is contained in:
@@ -115,6 +115,11 @@ class WorkflowExecutor:
|
||||
# 工作流数据
|
||||
self.workflow = workflow
|
||||
self.step_callback = step_callback
|
||||
self.step_callback_extended = (
|
||||
len(inspect.signature(step_callback).parameters) > 2
|
||||
if step_callback
|
||||
else False
|
||||
)
|
||||
self.actions = {action['id']: Action(**action) for action in workflow.actions}
|
||||
self.flows = [ActionFlow(**flow) for flow in workflow.flows]
|
||||
self.execution_config = getattr(workflow, "execution_config", None) or {}
|
||||
@@ -820,11 +825,10 @@ class WorkflowExecutor:
|
||||
"""
|
||||
if not self.step_callback:
|
||||
return
|
||||
callback_params = inspect.signature(self.step_callback).parameters
|
||||
if len(callback_params) <= 2:
|
||||
self.step_callback(action, self.context)
|
||||
if self.step_callback_extended:
|
||||
self.step_callback(action, self.context, self.build_execution_state(), completed)
|
||||
return
|
||||
self.step_callback(action, self.context, self.build_execution_state(), completed)
|
||||
self.step_callback(action, self.context)
|
||||
|
||||
@staticmethod
|
||||
def extract_context_outputs(context: ActionContext) -> dict:
|
||||
|
||||
@@ -1079,6 +1079,14 @@ def cached(region: Optional[str] = None, maxsize: Optional[int] = 1024, ttl: Opt
|
||||
"""
|
||||
|
||||
def decorator(func):
|
||||
# 函数签名在装饰后不会变化,预计算可避免每次缓存访问都重复反射。
|
||||
signature = inspect.signature(func)
|
||||
parameter_names = list(signature.parameters.keys())
|
||||
cache_parameter_names = (
|
||||
parameter_names[1:]
|
||||
if parameter_names and parameter_names[0] in ("self", "cls")
|
||||
else parameter_names
|
||||
)
|
||||
|
||||
def should_cache(value: Any) -> bool:
|
||||
"""
|
||||
@@ -1143,17 +1151,12 @@ def cached(region: Optional[str] = None, maxsize: Optional[int] = 1024, ttl: Opt
|
||||
:param kwargs: 关键字参数
|
||||
:return: 缓存键
|
||||
"""
|
||||
signature = inspect.signature(func)
|
||||
# 绑定传入的参数并应用默认值
|
||||
bound = signature.bind(*args, **kwargs)
|
||||
bound.apply_defaults()
|
||||
# 忽略第一个参数,如果它是实例(self)或类(cls)
|
||||
parameters = list(signature.parameters.keys())
|
||||
if parameters and parameters[0] in ("self", "cls"):
|
||||
bound.arguments.pop(parameters[0], None)
|
||||
# 按照函数签名顺序提取参数值列表
|
||||
keys = [
|
||||
bound.arguments[param] for param in signature.parameters if param in bound.arguments
|
||||
bound.arguments[param] for param in cache_parameter_names if param in bound.arguments
|
||||
]
|
||||
# 使用有序参数生成缓存键
|
||||
return f"{func_name}_{hashkey(*keys)}"
|
||||
|
||||
@@ -29,12 +29,12 @@ _doh_lock = Lock()
|
||||
_orig_getaddrinfo = socket.getaddrinfo
|
||||
|
||||
|
||||
def enable_doh(enable: bool):
|
||||
def enable_doh(enable: bool) -> None:
|
||||
"""
|
||||
对 socket.getaddrinfo 进行补丁
|
||||
"""
|
||||
|
||||
def _patched_getaddrinfo(host, *args, **kwargs):
|
||||
def _patched_getaddrinfo(host: str, *args, **kwargs):
|
||||
"""
|
||||
socket.getaddrinfo的补丁版本。
|
||||
"""
|
||||
@@ -42,9 +42,9 @@ def enable_doh(enable: bool):
|
||||
return _orig_getaddrinfo(host, *args, **kwargs)
|
||||
# 检查主机是否已解析
|
||||
with _doh_lock:
|
||||
ip = _doh_cache.get("host", None)
|
||||
ip = _doh_cache.get(host, None)
|
||||
if ip is not None:
|
||||
logger.info("已解析 [%s] 为 [%s] (缓存)", host, ip)
|
||||
logger.info(f"已解析 [{host}] 为 [{ip}] (缓存)")
|
||||
return _orig_getaddrinfo(ip, *args, **kwargs)
|
||||
# 使用DoH解析主机
|
||||
futures = []
|
||||
@@ -53,7 +53,7 @@ def enable_doh(enable: bool):
|
||||
for future in concurrent.futures.as_completed(futures):
|
||||
ip = future.result()
|
||||
if ip is not None:
|
||||
logger.info("已解析 [%s] 为 [%s]", host, ip)
|
||||
logger.info(f"已解析 [{host}] 为 [{ip}]")
|
||||
with _doh_lock:
|
||||
_doh_cache[host] = ip
|
||||
host = ip
|
||||
@@ -73,16 +73,16 @@ class DohHelper(ConfigReloadMixin, metaclass=Singleton):
|
||||
"""
|
||||
CONFIG_WATCH = {"DOH_ENABLE", "DOH_DOMAINS", "DOH_RESOLVERS"}
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
enable_doh(settings.DOH_ENABLE)
|
||||
|
||||
def on_config_changed(self):
|
||||
def on_config_changed(self) -> None:
|
||||
with _doh_lock:
|
||||
# DOH配置有变动的情况下,清空缓存
|
||||
_doh_cache.clear()
|
||||
enable_doh(settings.DOH_ENABLE)
|
||||
|
||||
def get_reload_name(self):
|
||||
def get_reload_name(self) -> str:
|
||||
return 'DoH'
|
||||
|
||||
def _doh_query(resolver: str, host: str) -> Optional[str]:
|
||||
@@ -121,11 +121,11 @@ def _doh_query(resolver: str, host: str) -> Optional[str]:
|
||||
b64message = base64.b64encode(message).decode("utf-8").rstrip("=")
|
||||
url = f"https://{resolver}/dns-query?dns={b64message}"
|
||||
headers = {"Content-Type": "application/dns-message"}
|
||||
logger.debug("DoH请求: %s", url)
|
||||
logger.debug(f"DoH请求: {url}")
|
||||
|
||||
request = urllib.request.Request(url, headers=headers, method="GET")
|
||||
with urllib.request.urlopen(request, timeout=_doh_timeout) as response:
|
||||
logger.debug("解析器(%s)响应: %s", resolver, response.status)
|
||||
logger.debug(f"解析器({resolver})响应: {response.status}")
|
||||
if response.status != 200:
|
||||
return None
|
||||
resp_body = response.read()
|
||||
@@ -138,7 +138,7 @@ def _doh_query(resolver: str, host: str) -> Optional[str]:
|
||||
# 将rdata转换为IP地址
|
||||
return socket.inet_ntoa(resp_body[first_rdata_start:first_rdata_end])
|
||||
except Exception as e:
|
||||
logger.error("解析器(%s)请求错误: %s", resolver, e)
|
||||
logger.error(f"解析器({resolver})请求错误: {e}")
|
||||
return None
|
||||
|
||||
|
||||
@@ -148,17 +148,17 @@ def doh_query_json(resolver: str, host: str) -> Optional[str]:
|
||||
"""
|
||||
url = f"https://{resolver}/dns-query?name={host}&type=A"
|
||||
headers = {"Accept": "application/dns-json"}
|
||||
logger.debug("DoH请求: %s", url)
|
||||
logger.debug(f"DoH请求: {url}")
|
||||
try:
|
||||
request = urllib.request.Request(url, headers=headers, method="GET")
|
||||
with urllib.request.urlopen(request, timeout=_doh_timeout) as response:
|
||||
logger.debug("解析器(%s)响应: %s", resolver, response.status)
|
||||
logger.debug(f"解析器({resolver})响应: {response.status}")
|
||||
if response.status != 200:
|
||||
return None
|
||||
response_body = response.read().decode("utf-8")
|
||||
logger.debug("<== body: %s", response_body)
|
||||
logger.debug(f"<== body: {response_body}")
|
||||
answer = json.loads(response_body)["Answer"]
|
||||
return answer[0]["data"]
|
||||
except Exception as e:
|
||||
logger.error("解析器(%s)请求错误: %s", resolver, e)
|
||||
logger.error(f"解析器({resolver})请求错误: {e}")
|
||||
return None
|
||||
|
||||
41
tests/test_doh_helper.py
Normal file
41
tests/test_doh_helper.py
Normal file
@@ -0,0 +1,41 @@
|
||||
import socket
|
||||
|
||||
from app.helper import doh
|
||||
|
||||
|
||||
def test_enable_doh_reuses_cached_host_resolution(monkeypatch):
|
||||
"""
|
||||
同一 DoH 域名第二次解析应命中缓存,避免重复请求远端解析器。
|
||||
"""
|
||||
query_calls = []
|
||||
resolved_hosts = []
|
||||
|
||||
def fake_query(resolver: str, host: str) -> str:
|
||||
query_calls.append((resolver, host))
|
||||
return "203.0.113.7"
|
||||
|
||||
def fake_getaddrinfo(host: str, *args, **kwargs):
|
||||
resolved_hosts.append(host)
|
||||
return [(socket.AF_INET, socket.SOCK_STREAM, 0, "", (host, 0))]
|
||||
|
||||
monkeypatch.setattr(doh.settings, "DOH_DOMAINS", "example.com")
|
||||
monkeypatch.setattr(doh.settings, "DOH_RESOLVERS", "resolver.test")
|
||||
monkeypatch.setattr(doh, "_doh_query", fake_query)
|
||||
monkeypatch.setattr(doh, "_orig_getaddrinfo", fake_getaddrinfo)
|
||||
|
||||
original_getaddrinfo = socket.getaddrinfo
|
||||
with doh._doh_lock:
|
||||
doh._doh_cache.clear()
|
||||
|
||||
try:
|
||||
doh.enable_doh(True)
|
||||
|
||||
socket.getaddrinfo("example.com", None)
|
||||
socket.getaddrinfo("example.com", None)
|
||||
finally:
|
||||
socket.getaddrinfo = original_getaddrinfo
|
||||
with doh._doh_lock:
|
||||
doh._doh_cache.clear()
|
||||
|
||||
assert query_calls == [("resolver.test", "example.com")]
|
||||
assert resolved_hosts == ["203.0.113.7", "203.0.113.7"]
|
||||
Reference in New Issue
Block a user