diff --git a/app/chain/workflow.py b/app/chain/workflow.py index f726ce44..ad4b0804 100644 --- a/app/chain/workflow.py +++ b/app/chain/workflow.py @@ -115,6 +115,11 @@ class WorkflowExecutor: # 工作流数据 self.workflow = workflow self.step_callback = step_callback + self.step_callback_extended = ( + len(inspect.signature(step_callback).parameters) > 2 + if step_callback + else False + ) self.actions = {action['id']: Action(**action) for action in workflow.actions} self.flows = [ActionFlow(**flow) for flow in workflow.flows] self.execution_config = getattr(workflow, "execution_config", None) or {} @@ -820,11 +825,10 @@ class WorkflowExecutor: """ if not self.step_callback: return - callback_params = inspect.signature(self.step_callback).parameters - if len(callback_params) <= 2: - self.step_callback(action, self.context) + if self.step_callback_extended: + self.step_callback(action, self.context, self.build_execution_state(), completed) return - self.step_callback(action, self.context, self.build_execution_state(), completed) + self.step_callback(action, self.context) @staticmethod def extract_context_outputs(context: ActionContext) -> dict: diff --git a/app/core/cache.py b/app/core/cache.py index 5e450939..ca57377c 100644 --- a/app/core/cache.py +++ b/app/core/cache.py @@ -1079,6 +1079,14 @@ def cached(region: Optional[str] = None, maxsize: Optional[int] = 1024, ttl: Opt """ def decorator(func): + # 函数签名在装饰后不会变化,预计算可避免每次缓存访问都重复反射。 + signature = inspect.signature(func) + parameter_names = list(signature.parameters.keys()) + cache_parameter_names = ( + parameter_names[1:] + if parameter_names and parameter_names[0] in ("self", "cls") + else parameter_names + ) def should_cache(value: Any) -> bool: """ @@ -1143,17 +1151,12 @@ def cached(region: Optional[str] = None, maxsize: Optional[int] = 1024, ttl: Opt :param kwargs: 关键字参数 :return: 缓存键 """ - signature = inspect.signature(func) # 绑定传入的参数并应用默认值 bound = signature.bind(*args, **kwargs) bound.apply_defaults() - # 忽略第一个参数,如果它是实例(self)或类(cls) - parameters = list(signature.parameters.keys()) - if parameters and parameters[0] in ("self", "cls"): - bound.arguments.pop(parameters[0], None) # 按照函数签名顺序提取参数值列表 keys = [ - bound.arguments[param] for param in signature.parameters if param in bound.arguments + bound.arguments[param] for param in cache_parameter_names if param in bound.arguments ] # 使用有序参数生成缓存键 return f"{func_name}_{hashkey(*keys)}" diff --git a/app/helper/doh.py b/app/helper/doh.py index f577b938..620be393 100644 --- a/app/helper/doh.py +++ b/app/helper/doh.py @@ -29,12 +29,12 @@ _doh_lock = Lock() _orig_getaddrinfo = socket.getaddrinfo -def enable_doh(enable: bool): +def enable_doh(enable: bool) -> None: """ 对 socket.getaddrinfo 进行补丁 """ - def _patched_getaddrinfo(host, *args, **kwargs): + def _patched_getaddrinfo(host: str, *args, **kwargs): """ socket.getaddrinfo的补丁版本。 """ @@ -42,9 +42,9 @@ def enable_doh(enable: bool): return _orig_getaddrinfo(host, *args, **kwargs) # 检查主机是否已解析 with _doh_lock: - ip = _doh_cache.get("host", None) + ip = _doh_cache.get(host, None) if ip is not None: - logger.info("已解析 [%s] 为 [%s] (缓存)", host, ip) + logger.info(f"已解析 [{host}] 为 [{ip}] (缓存)") return _orig_getaddrinfo(ip, *args, **kwargs) # 使用DoH解析主机 futures = [] @@ -53,7 +53,7 @@ def enable_doh(enable: bool): for future in concurrent.futures.as_completed(futures): ip = future.result() if ip is not None: - logger.info("已解析 [%s] 为 [%s]", host, ip) + logger.info(f"已解析 [{host}] 为 [{ip}]") with _doh_lock: _doh_cache[host] = ip host = ip @@ -73,16 +73,16 @@ class DohHelper(ConfigReloadMixin, metaclass=Singleton): """ CONFIG_WATCH = {"DOH_ENABLE", "DOH_DOMAINS", "DOH_RESOLVERS"} - def __init__(self): + def __init__(self) -> None: enable_doh(settings.DOH_ENABLE) - def on_config_changed(self): + def on_config_changed(self) -> None: with _doh_lock: # DOH配置有变动的情况下,清空缓存 _doh_cache.clear() enable_doh(settings.DOH_ENABLE) - def get_reload_name(self): + def get_reload_name(self) -> str: return 'DoH' def _doh_query(resolver: str, host: str) -> Optional[str]: @@ -121,11 +121,11 @@ def _doh_query(resolver: str, host: str) -> Optional[str]: b64message = base64.b64encode(message).decode("utf-8").rstrip("=") url = f"https://{resolver}/dns-query?dns={b64message}" headers = {"Content-Type": "application/dns-message"} - logger.debug("DoH请求: %s", url) + logger.debug(f"DoH请求: {url}") request = urllib.request.Request(url, headers=headers, method="GET") with urllib.request.urlopen(request, timeout=_doh_timeout) as response: - logger.debug("解析器(%s)响应: %s", resolver, response.status) + logger.debug(f"解析器({resolver})响应: {response.status}") if response.status != 200: return None resp_body = response.read() @@ -138,7 +138,7 @@ def _doh_query(resolver: str, host: str) -> Optional[str]: # 将rdata转换为IP地址 return socket.inet_ntoa(resp_body[first_rdata_start:first_rdata_end]) except Exception as e: - logger.error("解析器(%s)请求错误: %s", resolver, e) + logger.error(f"解析器({resolver})请求错误: {e}") return None @@ -148,17 +148,17 @@ def doh_query_json(resolver: str, host: str) -> Optional[str]: """ url = f"https://{resolver}/dns-query?name={host}&type=A" headers = {"Accept": "application/dns-json"} - logger.debug("DoH请求: %s", url) + logger.debug(f"DoH请求: {url}") try: request = urllib.request.Request(url, headers=headers, method="GET") with urllib.request.urlopen(request, timeout=_doh_timeout) as response: - logger.debug("解析器(%s)响应: %s", resolver, response.status) + logger.debug(f"解析器({resolver})响应: {response.status}") if response.status != 200: return None response_body = response.read().decode("utf-8") - logger.debug("<== body: %s", response_body) + logger.debug(f"<== body: {response_body}") answer = json.loads(response_body)["Answer"] return answer[0]["data"] except Exception as e: - logger.error("解析器(%s)请求错误: %s", resolver, e) + logger.error(f"解析器({resolver})请求错误: {e}") return None diff --git a/tests/test_doh_helper.py b/tests/test_doh_helper.py new file mode 100644 index 00000000..544483c9 --- /dev/null +++ b/tests/test_doh_helper.py @@ -0,0 +1,41 @@ +import socket + +from app.helper import doh + + +def test_enable_doh_reuses_cached_host_resolution(monkeypatch): + """ + 同一 DoH 域名第二次解析应命中缓存,避免重复请求远端解析器。 + """ + query_calls = [] + resolved_hosts = [] + + def fake_query(resolver: str, host: str) -> str: + query_calls.append((resolver, host)) + return "203.0.113.7" + + def fake_getaddrinfo(host: str, *args, **kwargs): + resolved_hosts.append(host) + return [(socket.AF_INET, socket.SOCK_STREAM, 0, "", (host, 0))] + + monkeypatch.setattr(doh.settings, "DOH_DOMAINS", "example.com") + monkeypatch.setattr(doh.settings, "DOH_RESOLVERS", "resolver.test") + monkeypatch.setattr(doh, "_doh_query", fake_query) + monkeypatch.setattr(doh, "_orig_getaddrinfo", fake_getaddrinfo) + + original_getaddrinfo = socket.getaddrinfo + with doh._doh_lock: + doh._doh_cache.clear() + + try: + doh.enable_doh(True) + + socket.getaddrinfo("example.com", None) + socket.getaddrinfo("example.com", None) + finally: + socket.getaddrinfo = original_getaddrinfo + with doh._doh_lock: + doh._doh_cache.clear() + + assert query_calls == [("resolver.test", "example.com")] + assert resolved_hosts == ["203.0.113.7", "203.0.113.7"]