diff --git a/app/helper/__init__.py b/app/helper/__init__.py index 99af9db9..ff2e48a2 100644 --- a/app/helper/__init__.py +++ b/app/helper/__init__.py @@ -1,2 +1 @@ -from .doh import doh_query_json from .cloudflare import under_challenge diff --git a/app/helper/doh.py b/app/helper/doh.py index e730d70b..112cec80 100644 --- a/app/helper/doh.py +++ b/app/helper/doh.py @@ -10,10 +10,14 @@ import socket import struct import urllib import urllib.request +from threading import Lock from typing import Dict, Optional from app.core.config import settings +from app.core.event import Event, eventmanager from app.log import logger +from app.schemas import ConfigChangeEventData +from app.schemas.types import EventType # 定义一个全局线程池执行器 _executor = concurrent.futures.ThreadPoolExecutor() @@ -21,11 +25,15 @@ _executor = concurrent.futures.ThreadPoolExecutor() # 定义默认的DoH配置 _doh_timeout = 5 _doh_cache: Dict[str, str] = {} +_doh_lock = Lock() +# 保存原始的 socket.getaddrinfo 方法 +_orig_getaddrinfo = socket.getaddrinfo -# 对 socket.getaddrinfo 进行补丁 -if settings.DOH_ENABLE: - # 保存原始的 socket.getaddrinfo 方法 - _orig_getaddrinfo = socket.getaddrinfo + +def enable_doh(enable: bool): + """ + 对 socket.getaddrinfo 进行补丁 + """ def _patched_getaddrinfo(host, *args, **kwargs): """ @@ -34,8 +42,9 @@ if settings.DOH_ENABLE: if host not in settings.DOH_DOMAINS.split(","): return _orig_getaddrinfo(host, *args, **kwargs) # 检查主机是否已解析 - if host in _doh_cache: - ip = _doh_cache[host] + with _doh_lock: + ip = _doh_cache.get("host", None) + if ip is not None: logger.info("已解析 [%s] 为 [%s] (缓存)", host, ip) return _orig_getaddrinfo(ip, *args, **kwargs) # 使用DoH解析主机 @@ -46,13 +55,34 @@ if settings.DOH_ENABLE: ip = future.result() if ip is not None: logger.info("已解析 [%s] 为 [%s]", host, ip) - _doh_cache[host] = ip + with _doh_lock: + _doh_cache[host] = ip host = ip break return _orig_getaddrinfo(host, *args, **kwargs) - # 替换 socket.getaddrinfo 方法 - socket.getaddrinfo = _patched_getaddrinfo + if enable: + # 替换 socket.getaddrinfo 方法 + socket.getaddrinfo = _patched_getaddrinfo + else: + socket.getaddrinfo = _orig_getaddrinfo + +class DohHelper: + def __init__(self): + enable_doh(settings.DOH_ENABLE) + + @eventmanager.register(EventType.ConfigChanged) + @staticmethod + def handle_config_changed(event: Event): + if not event: + return + event_data: ConfigChangeEventData = event.event_data + if event_data.key not in ["DOH_ENABLE", "DOH_DOMAINS", "DOH_RESOLVERS"]: + return + with _doh_lock: + # DOH配置有变动的情况下,清空缓存 + _doh_cache.clear() + enable_doh(settings.DOH_ENABLE) def _doh_query(resolver: str, host: str) -> Optional[str]: diff --git a/app/startup/modules_initializer.py b/app/startup/modules_initializer.py index 9c459e8d..9f6301e5 100644 --- a/app/startup/modules_initializer.py +++ b/app/startup/modules_initializer.py @@ -19,6 +19,7 @@ except ImportError as e: from app.core.event import EventManager from app.helper.thread import ThreadHelper from app.helper.display import DisplayHelper +from app.helper.doh import DohHelper from app.helper.resource import ResourceHelper from app.helper.message import MessageHelper from app.schemas import Notification, NotificationType @@ -132,6 +133,8 @@ def init_modules(): """ # 虚拟显示 DisplayHelper() + # DoH + DohHelper() # 站点管理 SitesHelper() # 资源包检测