fix: enhance caching mechanism and improve type hints in DoH and workflow modules

This commit is contained in:
jxxghp
2026-06-12 08:09:54 +08:00
parent d26225b998
commit 83cc7ea716
4 changed files with 73 additions and 25 deletions

View File

@@ -115,6 +115,11 @@ class WorkflowExecutor:
# 工作流数据
self.workflow = workflow
self.step_callback = step_callback
self.step_callback_extended = (
len(inspect.signature(step_callback).parameters) > 2
if step_callback
else False
)
self.actions = {action['id']: Action(**action) for action in workflow.actions}
self.flows = [ActionFlow(**flow) for flow in workflow.flows]
self.execution_config = getattr(workflow, "execution_config", None) or {}
@@ -820,11 +825,10 @@ class WorkflowExecutor:
"""
if not self.step_callback:
return
callback_params = inspect.signature(self.step_callback).parameters
if len(callback_params) <= 2:
self.step_callback(action, self.context)
if self.step_callback_extended:
self.step_callback(action, self.context, self.build_execution_state(), completed)
return
self.step_callback(action, self.context, self.build_execution_state(), completed)
self.step_callback(action, self.context)
@staticmethod
def extract_context_outputs(context: ActionContext) -> dict:

View File

@@ -1079,6 +1079,14 @@ def cached(region: Optional[str] = None, maxsize: Optional[int] = 1024, ttl: Opt
"""
def decorator(func):
# 函数签名在装饰后不会变化,预计算可避免每次缓存访问都重复反射。
signature = inspect.signature(func)
parameter_names = list(signature.parameters.keys())
cache_parameter_names = (
parameter_names[1:]
if parameter_names and parameter_names[0] in ("self", "cls")
else parameter_names
)
def should_cache(value: Any) -> bool:
"""
@@ -1143,17 +1151,12 @@ def cached(region: Optional[str] = None, maxsize: Optional[int] = 1024, ttl: Opt
:param kwargs: 关键字参数
:return: 缓存键
"""
signature = inspect.signature(func)
# 绑定传入的参数并应用默认值
bound = signature.bind(*args, **kwargs)
bound.apply_defaults()
# 忽略第一个参数,如果它是实例(self)或类(cls)
parameters = list(signature.parameters.keys())
if parameters and parameters[0] in ("self", "cls"):
bound.arguments.pop(parameters[0], None)
# 按照函数签名顺序提取参数值列表
keys = [
bound.arguments[param] for param in signature.parameters if param in bound.arguments
bound.arguments[param] for param in cache_parameter_names if param in bound.arguments
]
# 使用有序参数生成缓存键
return f"{func_name}_{hashkey(*keys)}"

View File

@@ -29,12 +29,12 @@ _doh_lock = Lock()
_orig_getaddrinfo = socket.getaddrinfo
def enable_doh(enable: bool):
def enable_doh(enable: bool) -> None:
"""
对 socket.getaddrinfo 进行补丁
"""
def _patched_getaddrinfo(host, *args, **kwargs):
def _patched_getaddrinfo(host: str, *args, **kwargs):
"""
socket.getaddrinfo的补丁版本。
"""
@@ -42,9 +42,9 @@ def enable_doh(enable: bool):
return _orig_getaddrinfo(host, *args, **kwargs)
# 检查主机是否已解析
with _doh_lock:
ip = _doh_cache.get("host", None)
ip = _doh_cache.get(host, None)
if ip is not None:
logger.info("已解析 [%s] 为 [%s] (缓存)", host, ip)
logger.info(f"已解析 [{host}] 为 [{ip}] (缓存)")
return _orig_getaddrinfo(ip, *args, **kwargs)
# 使用DoH解析主机
futures = []
@@ -53,7 +53,7 @@ def enable_doh(enable: bool):
for future in concurrent.futures.as_completed(futures):
ip = future.result()
if ip is not None:
logger.info("已解析 [%s] 为 [%s]", host, ip)
logger.info(f"已解析 [{host}] 为 [{ip}]")
with _doh_lock:
_doh_cache[host] = ip
host = ip
@@ -73,16 +73,16 @@ class DohHelper(ConfigReloadMixin, metaclass=Singleton):
"""
CONFIG_WATCH = {"DOH_ENABLE", "DOH_DOMAINS", "DOH_RESOLVERS"}
def __init__(self):
def __init__(self) -> None:
enable_doh(settings.DOH_ENABLE)
def on_config_changed(self):
def on_config_changed(self) -> None:
with _doh_lock:
# DOH配置有变动的情况下清空缓存
_doh_cache.clear()
enable_doh(settings.DOH_ENABLE)
def get_reload_name(self):
def get_reload_name(self) -> str:
return 'DoH'
def _doh_query(resolver: str, host: str) -> Optional[str]:
@@ -121,11 +121,11 @@ def _doh_query(resolver: str, host: str) -> Optional[str]:
b64message = base64.b64encode(message).decode("utf-8").rstrip("=")
url = f"https://{resolver}/dns-query?dns={b64message}"
headers = {"Content-Type": "application/dns-message"}
logger.debug("DoH请求: %s", url)
logger.debug(f"DoH请求: {url}")
request = urllib.request.Request(url, headers=headers, method="GET")
with urllib.request.urlopen(request, timeout=_doh_timeout) as response:
logger.debug("解析器(%s)响应: %s", resolver, response.status)
logger.debug(f"解析器({resolver})响应: {response.status}")
if response.status != 200:
return None
resp_body = response.read()
@@ -138,7 +138,7 @@ def _doh_query(resolver: str, host: str) -> Optional[str]:
# 将rdata转换为IP地址
return socket.inet_ntoa(resp_body[first_rdata_start:first_rdata_end])
except Exception as e:
logger.error("解析器(%s)请求错误: %s", resolver, e)
logger.error(f"解析器({resolver})请求错误: {e}")
return None
@@ -148,17 +148,17 @@ def doh_query_json(resolver: str, host: str) -> Optional[str]:
"""
url = f"https://{resolver}/dns-query?name={host}&type=A"
headers = {"Accept": "application/dns-json"}
logger.debug("DoH请求: %s", url)
logger.debug(f"DoH请求: {url}")
try:
request = urllib.request.Request(url, headers=headers, method="GET")
with urllib.request.urlopen(request, timeout=_doh_timeout) as response:
logger.debug("解析器(%s)响应: %s", resolver, response.status)
logger.debug(f"解析器({resolver})响应: {response.status}")
if response.status != 200:
return None
response_body = response.read().decode("utf-8")
logger.debug("<== body: %s", response_body)
logger.debug(f"<== body: {response_body}")
answer = json.loads(response_body)["Answer"]
return answer[0]["data"]
except Exception as e:
logger.error("解析器(%s)请求错误: %s", resolver, e)
logger.error(f"解析器({resolver})请求错误: {e}")
return None

41
tests/test_doh_helper.py Normal file
View File

@@ -0,0 +1,41 @@
import socket
from app.helper import doh
def test_enable_doh_reuses_cached_host_resolution(monkeypatch):
"""
同一 DoH 域名第二次解析应命中缓存,避免重复请求远端解析器。
"""
query_calls = []
resolved_hosts = []
def fake_query(resolver: str, host: str) -> str:
query_calls.append((resolver, host))
return "203.0.113.7"
def fake_getaddrinfo(host: str, *args, **kwargs):
resolved_hosts.append(host)
return [(socket.AF_INET, socket.SOCK_STREAM, 0, "", (host, 0))]
monkeypatch.setattr(doh.settings, "DOH_DOMAINS", "example.com")
monkeypatch.setattr(doh.settings, "DOH_RESOLVERS", "resolver.test")
monkeypatch.setattr(doh, "_doh_query", fake_query)
monkeypatch.setattr(doh, "_orig_getaddrinfo", fake_getaddrinfo)
original_getaddrinfo = socket.getaddrinfo
with doh._doh_lock:
doh._doh_cache.clear()
try:
doh.enable_doh(True)
socket.getaddrinfo("example.com", None)
socket.getaddrinfo("example.com", None)
finally:
socket.getaddrinfo = original_getaddrinfo
with doh._doh_lock:
doh._doh_cache.clear()
assert query_calls == [("resolver.test", "example.com")]
assert resolved_hosts == ["203.0.113.7", "203.0.113.7"]