From 2483223ee6e90497f2804980ff395c9595c037e9 Mon Sep 17 00:00:00 2001 From: Idris Yandarov <32651311+r142f@users.noreply.github.com> Date: Mon, 26 Jan 2026 06:36:36 +0300 Subject: [PATCH 01/19] feat: add nearest DC detection with TCP race --- tests/aio/test_nearest_dc.py | 145 ++++++++++++++++++++++++++++++ tests/test_nearest_dc.py | 132 ++++++++++++++++++++++++++++ ydb/aio/nearest_dc.py | 164 ++++++++++++++++++++++++++++++++++ ydb/aio/pool.py | 26 +++++- ydb/driver.py | 4 + ydb/nearest_dc.py | 166 +++++++++++++++++++++++++++++++++++ ydb/pool.py | 26 +++++- 7 files changed, 659 insertions(+), 4 deletions(-) create mode 100644 tests/aio/test_nearest_dc.py create mode 100644 tests/test_nearest_dc.py create mode 100644 ydb/aio/nearest_dc.py create mode 100644 ydb/nearest_dc.py diff --git a/tests/aio/test_nearest_dc.py b/tests/aio/test_nearest_dc.py new file mode 100644 index 00000000..bd252cff --- /dev/null +++ b/tests/aio/test_nearest_dc.py @@ -0,0 +1,145 @@ +import asyncio +import pytest +from ydb.aio import nearest_dc + + +class MockEndpoint: + def __init__(self, address, port, location): + self.address = address + self.port = port + self.endpoint = f"{address}:{port}" + self.location = location + + +class MockWriter: + def __init__(self): + self.closed = False + + def close(self): + self.closed = True + + async def wait_closed(self): + await asyncio.sleep(0) + + +@pytest.mark.asyncio +async def test_check_fastest_endpoint_empty(): + assert await nearest_dc._check_fastest_endpoint([]) is None + + +@pytest.mark.asyncio +async def test_check_fastest_endpoint_all_fail(monkeypatch): + async def fake_open_connection(host, port): + raise OSError("connect failed") + + monkeypatch.setattr(nearest_dc.asyncio, "open_connection", fake_open_connection) + + endpoints = [ + MockEndpoint("a", 1, "dc1"), + MockEndpoint("b", 1, "dc2"), + ] + assert await nearest_dc._check_fastest_endpoint(endpoints, timeout=0.05) is None + + +@pytest.mark.asyncio +async def test_check_fastest_endpoint_fastest_wins(monkeypatch): + async def fake_open_connection(host, port): + if host == "slow": + await asyncio.sleep(0.05) + return None, MockWriter() + + monkeypatch.setattr(nearest_dc.asyncio, "open_connection", fake_open_connection) + + endpoints = [ + MockEndpoint("slow", 1, "dc_slow"), + MockEndpoint("fast", 1, "dc_fast"), + ] + winner = await nearest_dc._check_fastest_endpoint(endpoints, timeout=0.2) + assert winner is not None + assert winner.location == "dc_fast" + + +@pytest.mark.asyncio +async def test_check_fastest_endpoint_respects_main_timeout(monkeypatch): + async def fake_open_connection(host, port): + await asyncio.sleep(0.2) + return None, MockWriter() + + monkeypatch.setattr(nearest_dc.asyncio, "open_connection", fake_open_connection) + + endpoints = [ + MockEndpoint("hang1", 1, "dc1"), + MockEndpoint("hang2", 1, "dc2"), + ] + + winner = await nearest_dc._check_fastest_endpoint(endpoints, timeout=0.05) + + assert winner is None + + +@pytest.mark.asyncio +async def test_detect_local_dc_empty_endpoints(): + with pytest.raises(ValueError, match="Empty endpoints"): + await nearest_dc.detect_local_dc([]) + + +@pytest.mark.asyncio +async def test_detect_local_dc_single_location_returns_immediately(monkeypatch): + async def fail_if_called(*args, **kwargs): + raise AssertionError("open_connection should not be called for single location") + + monkeypatch.setattr(nearest_dc.asyncio, "open_connection", fail_if_called) + + endpoints = [ + MockEndpoint("h1", 1, "dc1"), + MockEndpoint("h2", 1, "dc1"), + ] + assert await nearest_dc.detect_local_dc(endpoints) == "dc1" + + +@pytest.mark.asyncio +async def test_detect_local_dc_fallback_to_first_location_when_all_fail(monkeypatch): + async def fake_open_connection(host, port): + raise OSError("connect failed") + + monkeypatch.setattr(nearest_dc.asyncio, "open_connection", fake_open_connection) + + endpoints = [ + MockEndpoint("bad1", 9999, "dc1"), + MockEndpoint("bad2", 9999, "dc2"), + ] + assert await nearest_dc.detect_local_dc(endpoints, timeout=0.05) == "dc1" + + +@pytest.mark.asyncio +async def test_detect_local_dc_returns_location_of_fastest(monkeypatch): + async def fake_open_connection(host, port): + if host == "dc1_host": + await asyncio.sleep(0.05) + return None, MockWriter() + + monkeypatch.setattr(nearest_dc.asyncio, "open_connection", fake_open_connection) + + endpoints = [ + MockEndpoint("dc1_host", 1, "dc1"), + MockEndpoint("dc2_host", 1, "dc2"), + ] + assert await nearest_dc.detect_local_dc(endpoints, max_per_location=5, timeout=0.2) == "dc2" + + +@pytest.mark.asyncio +async def test_detect_local_dc_respects_max_per_location(monkeypatch): + calls = [] + + async def fake_open_connection(host, port): + calls.append((host, port)) + raise OSError("connect failed") + + monkeypatch.setattr(nearest_dc.asyncio, "open_connection", fake_open_connection) + + endpoints = [MockEndpoint(f"dc1_{i}", 1, "dc1") for i in range(5)] + [ + MockEndpoint(f"dc2_{i}", 1, "dc2") for i in range(5) + ] + await nearest_dc.detect_local_dc(endpoints, max_per_location=2, timeout=0.2) + + assert len(calls) == 4 diff --git a/tests/test_nearest_dc.py b/tests/test_nearest_dc.py new file mode 100644 index 00000000..97c53d68 --- /dev/null +++ b/tests/test_nearest_dc.py @@ -0,0 +1,132 @@ +import time +import pytest +from ydb import nearest_dc + + +class MockEndpoint: + def __init__(self, address, port, location): + self.address = address + self.port = port + self.endpoint = f"{address}:{port}" + self.location = location + + +class DummySock: + def close(self): + pass + + +def test_check_fastest_endpoint_empty(): + assert nearest_dc._check_fastest_endpoint([]) is None + + +def test_check_fastest_endpoint_all_fail(monkeypatch): + def fake_create_connection(addr_port, timeout=None): + raise OSError("connect failed") + + monkeypatch.setattr(nearest_dc.socket, "create_connection", fake_create_connection) + + endpoints = [ + MockEndpoint("a", 1, "dc1"), + MockEndpoint("b", 1, "dc2"), + ] + assert nearest_dc._check_fastest_endpoint(endpoints, timeout=0.05) is None + + +def test_check_fastest_endpoint_fastest_wins(monkeypatch): + def fake_create_connection(addr_port, timeout=None): + host, _ = addr_port + if host == "slow": + time.sleep(0.05) + return DummySock() + + monkeypatch.setattr(nearest_dc.socket, "create_connection", fake_create_connection) + + endpoints = [ + MockEndpoint("slow", 1, "dc_slow"), + MockEndpoint("fast", 1, "dc_fast"), + ] + winner = nearest_dc._check_fastest_endpoint(endpoints, timeout=0.2) + assert winner is not None + assert winner.location == "dc_fast" + + +def test_check_fastest_endpoint_respects_main_timeout(monkeypatch): + def fake_create_connection(addr_port, timeout=None): + time.sleep(0.2) + return DummySock() + + monkeypatch.setattr(nearest_dc.socket, "create_connection", fake_create_connection) + + endpoints = [ + MockEndpoint("hang1", 1, "dc1"), + MockEndpoint("hang2", 1, "dc2"), + ] + + winner = nearest_dc._check_fastest_endpoint(endpoints, timeout=0.05) + + assert winner is None + + +def test_detect_local_dc_empty_endpoints(): + with pytest.raises(ValueError, match="Empty endpoints"): + nearest_dc.detect_local_dc([]) + + +def test_detect_local_dc_single_location_returns_immediately(monkeypatch): + def fail_if_called(*args, **kwargs): + raise AssertionError("create_connection should not be called for single location") + + monkeypatch.setattr(nearest_dc.socket, "create_connection", fail_if_called) + + endpoints = [ + MockEndpoint("h1", 1, "dc1"), + MockEndpoint("h2", 1, "dc1"), + ] + assert nearest_dc.detect_local_dc(endpoints) == "dc1" + + +def test_detect_local_dc_fallback_to_first_location_when_all_fail(monkeypatch): + def fake_create_connection(addr_port, timeout=None): + raise OSError("connect failed") + + monkeypatch.setattr(nearest_dc.socket, "create_connection", fake_create_connection) + + endpoints = [ + MockEndpoint("bad1", 9999, "dc1"), + MockEndpoint("bad2", 9999, "dc2"), + ] + assert nearest_dc.detect_local_dc(endpoints, timeout=0.05) == "dc1" + + +def test_detect_local_dc_returns_location_of_fastest(monkeypatch): + def fake_create_connection(addr_port, timeout=None): + host, _ = addr_port + if host == "dc1_host": + time.sleep(0.05) + return DummySock() + + monkeypatch.setattr(nearest_dc.socket, "create_connection", fake_create_connection) + + endpoints = [ + MockEndpoint("dc1_host", 1, "dc1"), + MockEndpoint("dc2_host", 1, "dc2"), + ] + assert nearest_dc.detect_local_dc(endpoints, max_per_location=5, timeout=0.2) == "dc2" + + +def test_detect_local_dc_respects_max_per_location(monkeypatch): + calls = [] + + def fake_create_connection(addr_port, timeout=None): + calls.append(addr_port) + raise OSError("connect failed") + + monkeypatch.setattr(nearest_dc.socket, "create_connection", fake_create_connection) + + endpoints = [MockEndpoint(f"dc1_{i}", 1, "dc1") for i in range(5)] + [ + MockEndpoint(f"dc2_{i}", 1, "dc2") for i in range(5) + ] + nearest_dc.detect_local_dc(endpoints, max_per_location=2, timeout=0.2) + + assert len(calls) == 4 diff --git a/ydb/aio/nearest_dc.py b/ydb/aio/nearest_dc.py new file mode 100644 index 00000000..3d2cd001 --- /dev/null +++ b/ydb/aio/nearest_dc.py @@ -0,0 +1,164 @@ +# -*- coding: utf-8 -*- +import asyncio +import logging +import random +import time +from typing import List, Dict, Optional + +from .. import resolver + + +logger = logging.getLogger(__name__) + + +async def _check_fastest_endpoint( + endpoints: List[resolver.EndpointInfo], timeout: float = 5.0 +) -> Optional[resolver.EndpointInfo]: + """ + Perform async TCP race: connect to all endpoints concurrently and return the fastest one. + + This function starts async TCP connections to all provided endpoints concurrently using + asyncio tasks and returns the first one that successfully connects. Other connection + attempts are cancelled once a winner is found. + + :param endpoints: List of resolver.EndpointInfo objects + :param timeout: Maximum time to wait for any connection (seconds) + :return: Fastest endpoint that connected successfully, or None if all failed or timeout + """ + if not endpoints: + return None + + deadline = time.monotonic() + timeout + + async def try_connect(endpoint): + remaining = deadline - time.monotonic() + if remaining <= 0: + return None + + try: + _, writer = await asyncio.wait_for( + asyncio.open_connection(endpoint.address, endpoint.port), + timeout=remaining, + ) + writer.close() + await writer.wait_closed() + return endpoint + except (OSError, asyncio.TimeoutError) as e: + logger.debug("Failed to connect to %s: %s", endpoint.endpoint, e) + return None + + tasks = [asyncio.create_task(try_connect(endpoint)) for endpoint in endpoints] + try: + for task in asyncio.as_completed(tasks, timeout=timeout): + endpoint = await task + if endpoint is not None: + return endpoint + return None + except asyncio.TimeoutError: + logger.debug("TCP race timeout after %.2fs, no endpoint connected in time", timeout) + return None + finally: + for t in tasks: + if not t.done(): + t.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + + +def _split_endpoints_by_location( + endpoints: List[resolver.EndpointInfo], +) -> Dict[str, List[resolver.EndpointInfo]]: + """ + Group endpoints by their location. + + :param endpoints: List of resolver.EndpointInfo objects + :return: Dictionary mapping location -> list of resolver.EndpointInfo + """ + result = {} + for endpoint in endpoints: + location = endpoint.location + if location not in result: + result[location] = [] + result[location].append(endpoint) + return result + + +def _get_random_endpoints(endpoints: List[resolver.EndpointInfo], count: int) -> List[resolver.EndpointInfo]: + """ + Get random sample of endpoints. + + :param endpoints: List of resolver.EndpointInfo objects + :param count: Maximum number of endpoints to return + :return: Random sample of resolver.EndpointInfo + """ + if len(endpoints) <= count: + return endpoints + + endpoints_copy = list(endpoints) + random.shuffle(endpoints_copy) + return endpoints_copy[:count] + + +async def detect_local_dc( + endpoints: List[resolver.EndpointInfo], max_per_location: int = 3, timeout: float = 5.0 +) -> str: + """ + Detect nearest datacenter by performing async TCP race between endpoints. + + This function groups endpoints by location, selects random samples from each location, + and performs parallel TCP connections to find the fastest one. The location of the + fastest endpoint is considered the nearest datacenter. + + Algorithm: + 1. Group endpoints by location + 2. If only one location exists, return it immediately + 3. Select up to max_per_location random endpoints from each location + 4. Perform TCP race: connect to all selected endpoints simultaneously + 5. Return the location of the first endpoint that connects successfully + + :param endpoints: List of resolver.EndpointInfo objects from discovery + :param max_per_location: Maximum number of endpoints to test per location (default: 3) + :param timeout: TCP connection timeout in seconds (default: 5.0) + :return: Location string of the nearest datacenter + :raises ValueError: If endpoints list is empty or detection fails + """ + if not endpoints: + raise ValueError("Empty endpoints list for local DC detection") + + endpoints_by_location = _split_endpoints_by_location(endpoints) + + logger.debug( + "Detecting local DC from %d endpoints across %d locations", + len(endpoints), + len(endpoints_by_location), + ) + + if len(endpoints_by_location) == 1: + location = list(endpoints_by_location.keys())[0] + logger.info("Only one location found: %s", location) + return location + + endpoints_to_test = [] + for location, location_endpoints in endpoints_by_location.items(): + sample = _get_random_endpoints(location_endpoints, max_per_location) + endpoints_to_test.extend(sample) + logger.debug( + "Selected %d/%d endpoints from location '%s' for testing", + len(sample), + len(location_endpoints), + location, + ) + + fastest_endpoint = await _check_fastest_endpoint(endpoints_to_test, timeout=timeout) + + if fastest_endpoint is None: + fallback_location = endpoints[0].location + logger.warning( + "Failed to detect local DC via TCP race, falling back to first endpoint location: %s", + fallback_location, + ) + return fallback_location + + detected_location = fastest_endpoint.location + logger.info("Detected local DC: %s", detected_location) + + return detected_location diff --git a/ydb/aio/pool.py b/ydb/aio/pool.py index 0e96602c..751fe2e7 100644 --- a/ydb/aio/pool.py +++ b/ydb/aio/pool.py @@ -10,7 +10,7 @@ from .connection import Connection, EndpointKey -from . import resolver +from . import nearest_dc, resolver if TYPE_CHECKING: from ydb.driver import DriverConfig @@ -145,6 +145,28 @@ async def execute_discovery(self) -> bool: if cached_endpoint.endpoint not in resolved_endpoints: self._cache.make_outdated(cached_endpoint) + local_dc = resolve_details.self_location + + # Detect local DC using TCP latency if enabled + if self._driver_config.detect_local_dc: + try: + detected_location = await nearest_dc.detect_local_dc( + resolve_details.endpoints, max_per_location=3, timeout=self._ready_timeout + ) + if detected_location: + local_dc = detected_location + self.logger.info( + "Detected local DC via TCP latency: %s (server reported: %s)", + local_dc, + resolve_details.self_location, + ) + except Exception as e: + self.logger.warning( + "Failed to detect local DC via TCP latency, using server location: %s. Error: %s", + resolve_details.self_location, + e, + ) + for resolved_endpoint in resolve_details.endpoints: if self._ssl_required and not resolved_endpoint.ssl: continue @@ -152,7 +174,7 @@ async def execute_discovery(self) -> bool: if not self._ssl_required and resolved_endpoint.ssl: continue - preferred = resolve_details.self_location == resolved_endpoint.location + preferred = local_dc == resolved_endpoint.location for ( endpoint, diff --git a/ydb/driver.py b/ydb/driver.py index 555ca485..a16e759b 100644 --- a/ydb/driver.py +++ b/ydb/driver.py @@ -110,6 +110,7 @@ class DriverConfig(object): "discovery_request_timeout", "compression", "disable_discovery", + "detect_local_dc", "_additional_sdk_headers", ) @@ -136,6 +137,7 @@ def __init__( discovery_request_timeout: int = 10, compression: Optional[grpc.Compression] = None, disable_discovery: bool = False, + detect_local_dc=False, *, _additional_sdk_headers: Tuple[str, ...] = (), ) -> None: @@ -159,6 +161,7 @@ def __init__( :param grpc_lb_policy_name: A load balancing policy to be used for discovery channel construction. Default value is `round_round` :param discovery_request_timeout: A default timeout to complete the discovery. The default value is 10 seconds. :param disable_discovery: If True, endpoint discovery is disabled and only the start endpoint is used for all requests. + :param detect_local_dc: If True, detect nearest datacenter using TCP latency measurement instead of using server-provided self_location. :param _additional_sdk_headers: Reserved for SDK integrations (e.g. dbapi, sqlalchemy). Do not use in application code. """ @@ -188,6 +191,7 @@ def __init__( self.discovery_request_timeout = discovery_request_timeout self.compression = compression self.disable_discovery = disable_discovery + self.detect_local_dc = detect_local_dc self._additional_sdk_headers = _additional_sdk_headers def set_database(self, database: str) -> "DriverConfig": diff --git a/ydb/nearest_dc.py b/ydb/nearest_dc.py new file mode 100644 index 00000000..f2840683 --- /dev/null +++ b/ydb/nearest_dc.py @@ -0,0 +1,166 @@ +# -*- coding: utf-8 -*- +import socket +import threading +import logging +import random +import time +from typing import List, Dict, Optional + +from . import resolver + + +logger = logging.getLogger(__name__) + + +def _check_fastest_endpoint( + endpoints: List[resolver.EndpointInfo], timeout: float = 5.0 +) -> Optional[resolver.EndpointInfo]: + """ + Perform TCP race: connect to all endpoints simultaneously and return the fastest one. + + This function starts TCP connections to all provided endpoints in parallel + and returns the first one that successfully connects. Other connection attempts + will continue until their socket timeout expires (they cannot be interrupted). + + :param endpoints: List of resolver.EndpointInfo objects + :param timeout: Maximum time to wait for any connection (seconds) + :return: Fastest endpoint that connected successfully, or None if all failed + """ + if not endpoints: + return None + + result = {"endpoint": None, "lock": threading.Lock()} + stop_event = threading.Event() + deadline = time.monotonic() + timeout + + def try_connect(endpoint: resolver.EndpointInfo): + """Try to connect to endpoint and report if successful.""" + remaining = deadline - time.monotonic() + if remaining <= 0 or stop_event.is_set(): + return + + try: + sock = socket.create_connection((endpoint.address, endpoint.port), timeout=remaining) + + try: + with result["lock"]: + if result["endpoint"] is None: + result["endpoint"] = endpoint + stop_event.set() + logger.debug("TCP race winner: %s (location: %s)", endpoint.endpoint, endpoint.location) + finally: + sock.close() + + except Exception as e: + logger.warning("Unexpected error connecting to %s: %s", endpoint.endpoint, e) + + threads: List[threading.Thread] = [] + for ep in endpoints: + thread = threading.Thread(target=try_connect, args=(ep,), daemon=True) + thread.start() + threads.append(thread) + + for thread in threads: + remaining = deadline - time.monotonic() + if remaining <= 0 or stop_event.is_set(): + break + + thread.join(timeout=remaining) + + return result["endpoint"] + + +def _split_endpoints_by_location(endpoints: List[resolver.EndpointInfo]) -> Dict[str, List[resolver.EndpointInfo]]: + """ + Group endpoints by their location. + + :param endpoints: List of resolver.EndpointInfo objects + :return: Dictionary mapping location -> list of resolver.EndpointInfo + """ + result = {} + for endpoint in endpoints: + location = endpoint.location + if location not in result: + result[location] = [] + result[location].append(endpoint) + return result + + +def _get_random_endpoints(endpoints: List[resolver.EndpointInfo], count: int) -> List[resolver.EndpointInfo]: + """ + Get random sample of endpoints. + + :param endpoints: List of resolver.EndpointInfo objects + :param count: Maximum number of endpoints to return + :return: Random sample of resolver.EndpointInfo + """ + if len(endpoints) <= count: + return endpoints + + endpoints_copy = list(endpoints) + random.shuffle(endpoints_copy) + return endpoints_copy[:count] + + +def detect_local_dc(endpoints: List[resolver.EndpointInfo], max_per_location: int = 3, timeout: float = 5.0) -> str: + """ + Detect nearest datacenter by performing TCP race between endpoints. + + This function groups endpoints by location, selects random samples from each location, + and performs parallel TCP connections to find the fastest one. The location of the + fastest endpoint is considered the nearest datacenter. + + Algorithm: + 1. Group endpoints by location + 2. If only one location exists, return it immediately + 3. Select up to max_per_location random endpoints from each location + 4. Perform TCP race: connect to all selected endpoints simultaneously + 5. Return the location of the first endpoint that connects successfully + + :param endpoints: List of resolver.EndpointInfo objects from discovery + :param max_per_location: Maximum number of endpoints to test per location (default: 3) + :param timeout: TCP connection timeout in seconds (default: 5.0) + :return: Location string of the nearest datacenter + :raises ValueError: If endpoints list is empty or detection fails + """ + if not endpoints: + raise ValueError("Empty endpoints list for local DC detection") + + endpoints_by_location = _split_endpoints_by_location(endpoints) + + logger.debug( + "Detecting local DC from %d endpoints across %d locations", + len(endpoints), + len(endpoints_by_location), + ) + + if len(endpoints_by_location) == 1: + location = list(endpoints_by_location.keys())[0] + logger.info("Only one location found: %s", location) + return location + + endpoints_to_test = [] + for location, location_endpoints in endpoints_by_location.items(): + sample = _get_random_endpoints(location_endpoints, max_per_location) + endpoints_to_test.extend(sample) + logger.debug( + "Selected %d/%d endpoints from location '%s' for testing", + len(sample), + len(location_endpoints), + location, + ) + + fastest_endpoint = _check_fastest_endpoint(endpoints_to_test, timeout=timeout) + + if fastest_endpoint is None: + fallback_location = endpoints[0].location + logger.warning( + "Failed to detect local DC via TCP race, falling back to first endpoint location: %s", + fallback_location, + ) + return fallback_location + + detected_location = fastest_endpoint.location + logger.info("Detected local DC: %s", detected_location) + + return detected_location diff --git a/ydb/pool.py b/ydb/pool.py index 1d1374e6..1c011f22 100644 --- a/ydb/pool.py +++ b/ydb/pool.py @@ -9,7 +9,7 @@ import random from typing import Any, Callable, ContextManager, List, Optional, Set, Tuple, TYPE_CHECKING -from . import connection as connection_impl, issues, resolver, _utilities, tracing +from . import connection as connection_impl, issues, nearest_dc, resolver, _utilities, tracing from abc import abstractmethod from .connection import Connection, EndpointKey @@ -232,6 +232,28 @@ def execute_discovery(self) -> bool: if cached_endpoint.endpoint not in resolved_endpoints: self._cache.make_outdated(cached_endpoint) + local_dc = resolve_details.self_location + + # Detect local DC using TCP latency if enabled + if self._driver_config.detect_local_dc: + try: + detected_location = nearest_dc.detect_local_dc( + resolve_details.endpoints, max_per_location=3, timeout=self._ready_timeout + ) + if detected_location: + local_dc = detected_location + self.logger.info( + "Detected local DC via TCP latency: %s (server reported: %s)", + local_dc, + resolve_details.self_location, + ) + except Exception as e: + self.logger.warning( + "Failed to detect local DC via TCP latency, using server location: %s. Error: %s", + resolve_details.self_location, + e, + ) + for resolved_endpoint in resolve_details.endpoints: if self._ssl_required and not resolved_endpoint.ssl: continue @@ -239,7 +261,7 @@ def execute_discovery(self) -> bool: if not self._ssl_required and resolved_endpoint.ssl: continue - preferred = resolve_details.self_location == resolved_endpoint.location + preferred = local_dc == resolved_endpoint.location for ( endpoint, From 8469cead203856c1cb235cef7bf19b56fef41058 Mon Sep 17 00:00:00 2001 From: Idris Yandarov <32651311+r142f@users.noreply.github.com> Date: Thu, 5 Feb 2026 07:04:33 +0300 Subject: [PATCH 02/19] fix: fix linting issues --- ydb/aio/nearest_dc.py | 4 ++-- ydb/nearest_dc.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/ydb/aio/nearest_dc.py b/ydb/aio/nearest_dc.py index 3d2cd001..e819718c 100644 --- a/ydb/aio/nearest_dc.py +++ b/ydb/aio/nearest_dc.py @@ -3,7 +3,7 @@ import logging import random import time -from typing import List, Dict, Optional +from typing import Dict, List, Optional from .. import resolver @@ -73,7 +73,7 @@ def _split_endpoints_by_location( :param endpoints: List of resolver.EndpointInfo objects :return: Dictionary mapping location -> list of resolver.EndpointInfo """ - result = {} + result: Dict[str, List[resolver.EndpointInfo]] = {} for endpoint in endpoints: location = endpoint.location if location not in result: diff --git a/ydb/nearest_dc.py b/ydb/nearest_dc.py index f2840683..12b8f555 100644 --- a/ydb/nearest_dc.py +++ b/ydb/nearest_dc.py @@ -4,7 +4,7 @@ import logging import random import time -from typing import List, Dict, Optional +from typing import Any, Dict, List, Optional from . import resolver @@ -29,7 +29,7 @@ def _check_fastest_endpoint( if not endpoints: return None - result = {"endpoint": None, "lock": threading.Lock()} + result: Dict[str, Any] = {"endpoint": None, "lock": threading.Lock()} stop_event = threading.Event() deadline = time.monotonic() + timeout @@ -77,7 +77,7 @@ def _split_endpoints_by_location(endpoints: List[resolver.EndpointInfo]) -> Dict :param endpoints: List of resolver.EndpointInfo objects :return: Dictionary mapping location -> list of resolver.EndpointInfo """ - result = {} + result: Dict[str, List[resolver.EndpointInfo]] = {} for endpoint in endpoints: location = endpoint.location if location not in result: From 8c65143d77e7fd2902a07b97daa0c472e81c3fe2 Mon Sep 17 00:00:00 2001 From: Idris Yandarov <32651311+r142f@users.noreply.github.com> Date: Sat, 7 Feb 2026 12:06:26 +0300 Subject: [PATCH 03/19] fix: fixing flaws --- tests/aio/test_nearest_dc.py | 4 ++-- tests/test_nearest_dc.py | 4 ++-- ydb/aio/nearest_dc.py | 21 ++++++++++----------- ydb/aio/pool.py | 5 +++++ ydb/driver.py | 2 +- ydb/nearest_dc.py | 28 ++++++++++++---------------- ydb/pool.py | 5 +++++ 7 files changed, 37 insertions(+), 32 deletions(-) diff --git a/tests/aio/test_nearest_dc.py b/tests/aio/test_nearest_dc.py index bd252cff..be9b1f08 100644 --- a/tests/aio/test_nearest_dc.py +++ b/tests/aio/test_nearest_dc.py @@ -98,7 +98,7 @@ async def fail_if_called(*args, **kwargs): @pytest.mark.asyncio -async def test_detect_local_dc_fallback_to_first_location_when_all_fail(monkeypatch): +async def test_detect_local_dc_returns_none_when_all_fail(monkeypatch): async def fake_open_connection(host, port): raise OSError("connect failed") @@ -108,7 +108,7 @@ async def fake_open_connection(host, port): MockEndpoint("bad1", 9999, "dc1"), MockEndpoint("bad2", 9999, "dc2"), ] - assert await nearest_dc.detect_local_dc(endpoints, timeout=0.05) == "dc1" + assert await nearest_dc.detect_local_dc(endpoints, timeout=0.05) is None @pytest.mark.asyncio diff --git a/tests/test_nearest_dc.py b/tests/test_nearest_dc.py index 97c53d68..c626fb7f 100644 --- a/tests/test_nearest_dc.py +++ b/tests/test_nearest_dc.py @@ -86,7 +86,7 @@ def fail_if_called(*args, **kwargs): assert nearest_dc.detect_local_dc(endpoints) == "dc1" -def test_detect_local_dc_fallback_to_first_location_when_all_fail(monkeypatch): +def test_detect_local_dc_returns_none_when_all_fail(monkeypatch): def fake_create_connection(addr_port, timeout=None): raise OSError("connect failed") @@ -96,7 +96,7 @@ def fake_create_connection(addr_port, timeout=None): MockEndpoint("bad1", 9999, "dc1"), MockEndpoint("bad2", 9999, "dc2"), ] - assert nearest_dc.detect_local_dc(endpoints, timeout=0.05) == "dc1" + assert nearest_dc.detect_local_dc(endpoints, timeout=0.05) is None def test_detect_local_dc_returns_location_of_fastest(monkeypatch): diff --git a/ydb/aio/nearest_dc.py b/ydb/aio/nearest_dc.py index e819718c..ae50d52d 100644 --- a/ydb/aio/nearest_dc.py +++ b/ydb/aio/nearest_dc.py @@ -43,8 +43,10 @@ async def try_connect(endpoint): writer.close() await writer.wait_closed() return endpoint - except (OSError, asyncio.TimeoutError) as e: - logger.debug("Failed to connect to %s: %s", endpoint.endpoint, e) + except (OSError, asyncio.TimeoutError): + return None + except Exception as e: + logger.debug("Unexpected error connecting to %s: %s", endpoint.endpoint, e) return None tasks = [asyncio.create_task(try_connect(endpoint)) for endpoint in endpoints] @@ -100,7 +102,7 @@ def _get_random_endpoints(endpoints: List[resolver.EndpointInfo], count: int) -> async def detect_local_dc( endpoints: List[resolver.EndpointInfo], max_per_location: int = 3, timeout: float = 5.0 -) -> str: +) -> Optional[str]: """ Detect nearest datacenter by performing async TCP race between endpoints. @@ -114,12 +116,13 @@ async def detect_local_dc( 3. Select up to max_per_location random endpoints from each location 4. Perform TCP race: connect to all selected endpoints simultaneously 5. Return the location of the first endpoint that connects successfully + 6. If all connections fail, return None :param endpoints: List of resolver.EndpointInfo objects from discovery :param max_per_location: Maximum number of endpoints to test per location (default: 3) :param timeout: TCP connection timeout in seconds (default: 5.0) - :return: Location string of the nearest datacenter - :raises ValueError: If endpoints list is empty or detection fails + :return: Location string of the nearest datacenter, or None if detection failed + :raises ValueError: If endpoints list is empty """ if not endpoints: raise ValueError("Empty endpoints list for local DC detection") @@ -151,12 +154,8 @@ async def detect_local_dc( fastest_endpoint = await _check_fastest_endpoint(endpoints_to_test, timeout=timeout) if fastest_endpoint is None: - fallback_location = endpoints[0].location - logger.warning( - "Failed to detect local DC via TCP race, falling back to first endpoint location: %s", - fallback_location, - ) - return fallback_location + logger.warning("Failed to detect local DC via TCP race: no endpoint connected in time") + return None detected_location = fastest_endpoint.location logger.info("Detected local DC: %s", detected_location) diff --git a/ydb/aio/pool.py b/ydb/aio/pool.py index 751fe2e7..1e607077 100644 --- a/ydb/aio/pool.py +++ b/ydb/aio/pool.py @@ -160,6 +160,11 @@ async def execute_discovery(self) -> bool: local_dc, resolve_details.self_location, ) + else: + self.logger.warning( + "Failed to detect local DC via TCP latency, using server location: %s", + resolve_details.self_location, + ) except Exception as e: self.logger.warning( "Failed to detect local DC via TCP latency, using server location: %s. Error: %s", diff --git a/ydb/driver.py b/ydb/driver.py index a16e759b..467b1158 100644 --- a/ydb/driver.py +++ b/ydb/driver.py @@ -137,7 +137,7 @@ def __init__( discovery_request_timeout: int = 10, compression: Optional[grpc.Compression] = None, disable_discovery: bool = False, - detect_local_dc=False, + detect_local_dc: bool = False, *, _additional_sdk_headers: Tuple[str, ...] = (), ) -> None: diff --git a/ydb/nearest_dc.py b/ydb/nearest_dc.py index 12b8f555..a107e1ad 100644 --- a/ydb/nearest_dc.py +++ b/ydb/nearest_dc.py @@ -51,8 +51,10 @@ def try_connect(endpoint: resolver.EndpointInfo): finally: sock.close() + except (OSError, socket.timeout): + pass except Exception as e: - logger.warning("Unexpected error connecting to %s: %s", endpoint.endpoint, e) + logger.debug("Unexpected error connecting to %s: %s", endpoint.endpoint, e) threads: List[threading.Thread] = [] for ep in endpoints: @@ -60,12 +62,7 @@ def try_connect(endpoint: resolver.EndpointInfo): thread.start() threads.append(thread) - for thread in threads: - remaining = deadline - time.monotonic() - if remaining <= 0 or stop_event.is_set(): - break - - thread.join(timeout=remaining) + stop_event.wait(timeout=max(0.0, deadline - time.monotonic())) return result["endpoint"] @@ -102,7 +99,9 @@ def _get_random_endpoints(endpoints: List[resolver.EndpointInfo], count: int) -> return endpoints_copy[:count] -def detect_local_dc(endpoints: List[resolver.EndpointInfo], max_per_location: int = 3, timeout: float = 5.0) -> str: +def detect_local_dc( + endpoints: List[resolver.EndpointInfo], max_per_location: int = 3, timeout: float = 5.0 +) -> Optional[str]: """ Detect nearest datacenter by performing TCP race between endpoints. @@ -116,12 +115,13 @@ def detect_local_dc(endpoints: List[resolver.EndpointInfo], max_per_location: in 3. Select up to max_per_location random endpoints from each location 4. Perform TCP race: connect to all selected endpoints simultaneously 5. Return the location of the first endpoint that connects successfully + 6. If all connections fail, return None :param endpoints: List of resolver.EndpointInfo objects from discovery :param max_per_location: Maximum number of endpoints to test per location (default: 3) :param timeout: TCP connection timeout in seconds (default: 5.0) - :return: Location string of the nearest datacenter - :raises ValueError: If endpoints list is empty or detection fails + :return: Location string of the nearest datacenter, or None if detection failed + :raises ValueError: If endpoints list is empty """ if not endpoints: raise ValueError("Empty endpoints list for local DC detection") @@ -153,12 +153,8 @@ def detect_local_dc(endpoints: List[resolver.EndpointInfo], max_per_location: in fastest_endpoint = _check_fastest_endpoint(endpoints_to_test, timeout=timeout) if fastest_endpoint is None: - fallback_location = endpoints[0].location - logger.warning( - "Failed to detect local DC via TCP race, falling back to first endpoint location: %s", - fallback_location, - ) - return fallback_location + logger.warning("Failed to detect local DC via TCP race: no endpoint connected in time") + return None detected_location = fastest_endpoint.location logger.info("Detected local DC: %s", detected_location) diff --git a/ydb/pool.py b/ydb/pool.py index 1c011f22..70599931 100644 --- a/ydb/pool.py +++ b/ydb/pool.py @@ -247,6 +247,11 @@ def execute_discovery(self) -> bool: local_dc, resolve_details.self_location, ) + else: + self.logger.warning( + "Failed to detect local DC via TCP latency, using server location: %s", + resolve_details.self_location, + ) except Exception as e: self.logger.warning( "Failed to detect local DC via TCP latency, using server location: %s. Error: %s", From e571db0fd8bd2039fed4513e43685e289f6ced3b Mon Sep 17 00:00:00 2001 From: Idris Yandarov <32651311+r142f@users.noreply.github.com> Date: Tue, 10 Feb 2026 05:25:14 +0300 Subject: [PATCH 04/19] fix: fixing flaws --- ydb/aio/nearest_dc.py | 6 +++--- ydb/driver.py | 4 +++- ydb/nearest_dc.py | 7 ++++--- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/ydb/aio/nearest_dc.py b/ydb/aio/nearest_dc.py index ae50d52d..e9699f64 100644 --- a/ydb/aio/nearest_dc.py +++ b/ydb/aio/nearest_dc.py @@ -137,7 +137,7 @@ async def detect_local_dc( if len(endpoints_by_location) == 1: location = list(endpoints_by_location.keys())[0] - logger.info("Only one location found: %s", location) + logger.debug("Only one location found: %s", location) return location endpoints_to_test = [] @@ -154,10 +154,10 @@ async def detect_local_dc( fastest_endpoint = await _check_fastest_endpoint(endpoints_to_test, timeout=timeout) if fastest_endpoint is None: - logger.warning("Failed to detect local DC via TCP race: no endpoint connected in time") + logger.debug("Failed to detect local DC via TCP race: no endpoint connected in time") return None detected_location = fastest_endpoint.location - logger.info("Detected local DC: %s", detected_location) + logger.debug("Detected local DC: %s", detected_location) return detected_location diff --git a/ydb/driver.py b/ydb/driver.py index 467b1158..7c123823 100644 --- a/ydb/driver.py +++ b/ydb/driver.py @@ -161,7 +161,9 @@ def __init__( :param grpc_lb_policy_name: A load balancing policy to be used for discovery channel construction. Default value is `round_round` :param discovery_request_timeout: A default timeout to complete the discovery. The default value is 10 seconds. :param disable_discovery: If True, endpoint discovery is disabled and only the start endpoint is used for all requests. - :param detect_local_dc: If True, detect nearest datacenter using TCP latency measurement instead of using server-provided self_location. + :param detect_local_dc: If True, detect nearest datacenter using TCP latency measurement instead of using\ + server-provided self_location. **Note**: This option only affects endpoint selection when use_all_nodes=False.\ + When use_all_nodes=True (default), all endpoints are used regardless of detected location. :param _additional_sdk_headers: Reserved for SDK integrations (e.g. dbapi, sqlalchemy). Do not use in application code. """ diff --git a/ydb/nearest_dc.py b/ydb/nearest_dc.py index a107e1ad..5e0f4ec5 100644 --- a/ydb/nearest_dc.py +++ b/ydb/nearest_dc.py @@ -52,6 +52,7 @@ def try_connect(endpoint: resolver.EndpointInfo): sock.close() except (OSError, socket.timeout): + # Ignore expected connection errors; endpoints that fail simply lose the TCP race. pass except Exception as e: logger.debug("Unexpected error connecting to %s: %s", endpoint.endpoint, e) @@ -136,7 +137,7 @@ def detect_local_dc( if len(endpoints_by_location) == 1: location = list(endpoints_by_location.keys())[0] - logger.info("Only one location found: %s", location) + logger.debug("Only one location found: %s", location) return location endpoints_to_test = [] @@ -153,10 +154,10 @@ def detect_local_dc( fastest_endpoint = _check_fastest_endpoint(endpoints_to_test, timeout=timeout) if fastest_endpoint is None: - logger.warning("Failed to detect local DC via TCP race: no endpoint connected in time") + logger.debug("Failed to detect local DC via TCP race: no endpoint connected in time") return None detected_location = fastest_endpoint.location - logger.info("Detected local DC: %s", detected_location) + logger.debug("Detected local DC: %s", detected_location) return detected_location From d04f03bc4d313f239528ddc060d21c56e85fc3d9 Mon Sep 17 00:00:00 2001 From: Idris Yandarov <32651311+r142f@users.noreply.github.com> Date: Wed, 11 Feb 2026 06:39:15 +0300 Subject: [PATCH 05/19] fix: fixing flaws --- ydb/aio/nearest_dc.py | 5 +-- ydb/aio/pool.py | 1 + ydb/nearest_dc.py | 71 ++++++++++++++++++++++++++----------------- ydb/pool.py | 1 + 4 files changed, 46 insertions(+), 32 deletions(-) diff --git a/ydb/aio/nearest_dc.py b/ydb/aio/nearest_dc.py index e9699f64..27a33b33 100644 --- a/ydb/aio/nearest_dc.py +++ b/ydb/aio/nearest_dc.py @@ -94,10 +94,7 @@ def _get_random_endpoints(endpoints: List[resolver.EndpointInfo], count: int) -> """ if len(endpoints) <= count: return endpoints - - endpoints_copy = list(endpoints) - random.shuffle(endpoints_copy) - return endpoints_copy[:count] + return random.sample(endpoints, count) async def detect_local_dc( diff --git a/ydb/aio/pool.py b/ydb/aio/pool.py index 1e607077..81b7d35c 100644 --- a/ydb/aio/pool.py +++ b/ydb/aio/pool.py @@ -170,6 +170,7 @@ async def execute_discovery(self) -> bool: "Failed to detect local DC via TCP latency, using server location: %s. Error: %s", resolve_details.self_location, e, + exc_info=True, ) for resolved_endpoint in resolve_details.endpoints: diff --git a/ydb/nearest_dc.py b/ydb/nearest_dc.py index 5e0f4ec5..d4f7f7d7 100644 --- a/ydb/nearest_dc.py +++ b/ydb/nearest_dc.py @@ -1,26 +1,37 @@ # -*- coding: utf-8 -*- +import atexit +import concurrent.futures import socket import threading import logging import random import time -from typing import Any, Dict, List, Optional +from typing import Dict, List, Optional from . import resolver logger = logging.getLogger(__name__) +# Module-level thread pool for TCP race (reused across discovery cycles) +_TCP_RACE_MAX_WORKERS = 15 +_TCP_RACE_EXECUTOR = concurrent.futures.ThreadPoolExecutor( + max_workers=_TCP_RACE_MAX_WORKERS, + thread_name_prefix="ydb-tcp-race", +) + +# Ensure executor is shut down on process exit +atexit.register(lambda: _TCP_RACE_EXECUTOR.shutdown(wait=False, cancel_futures=True)) + def _check_fastest_endpoint( endpoints: List[resolver.EndpointInfo], timeout: float = 5.0 ) -> Optional[resolver.EndpointInfo]: """ - Perform TCP race: connect to all endpoints simultaneously and return the fastest one. + Perform TCP race using a bounded thread pool and return the fastest endpoint. - This function starts TCP connections to all provided endpoints in parallel - and returns the first one that successfully connects. Other connection attempts - will continue until their socket timeout expires (they cannot be interrupted). + Uses a module-level ThreadPoolExecutor to avoid creating new threads on every + discovery cycle. Returns immediately when the first endpoint connects successfully. :param endpoints: List of resolver.EndpointInfo objects :param timeout: Maximum time to wait for any connection (seconds) @@ -29,43 +40,50 @@ def _check_fastest_endpoint( if not endpoints: return None - result: Dict[str, Any] = {"endpoint": None, "lock": threading.Lock()} + endpoints = _get_random_endpoints(endpoints, _TCP_RACE_MAX_WORKERS) + stop_event = threading.Event() + winner_lock = threading.Lock() deadline = time.monotonic() + timeout - def try_connect(endpoint: resolver.EndpointInfo): - """Try to connect to endpoint and report if successful.""" + def try_connect(endpoint: resolver.EndpointInfo) -> Optional[resolver.EndpointInfo]: + """Try to connect to endpoint and return it if successful.""" remaining = deadline - time.monotonic() if remaining <= 0 or stop_event.is_set(): - return + return None try: sock = socket.create_connection((endpoint.address, endpoint.port), timeout=remaining) - try: - with result["lock"]: - if result["endpoint"] is None: - result["endpoint"] = endpoint - stop_event.set() - logger.debug("TCP race winner: %s (location: %s)", endpoint.endpoint, endpoint.location) + with winner_lock: + if stop_event.is_set(): + return None + stop_event.set() + return endpoint finally: sock.close() - except (OSError, socket.timeout): # Ignore expected connection errors; endpoints that fail simply lose the TCP race. - pass + return None except Exception as e: logger.debug("Unexpected error connecting to %s: %s", endpoint.endpoint, e) + return None - threads: List[threading.Thread] = [] - for ep in endpoints: - thread = threading.Thread(target=try_connect, args=(ep,), daemon=True) - thread.start() - threads.append(thread) + futures: List[concurrent.futures.Future] = [_TCP_RACE_EXECUTOR.submit(try_connect, ep) for ep in endpoints] - stop_event.wait(timeout=max(0.0, deadline - time.monotonic())) + try: + for fut in concurrent.futures.as_completed(futures, timeout=timeout): + result = fut.result() + if result is not None: + return result + except concurrent.futures.TimeoutError: + # Overall timeout expired + pass + finally: + for f in futures: + f.cancel() - return result["endpoint"] + return None def _split_endpoints_by_location(endpoints: List[resolver.EndpointInfo]) -> Dict[str, List[resolver.EndpointInfo]]: @@ -94,10 +112,7 @@ def _get_random_endpoints(endpoints: List[resolver.EndpointInfo], count: int) -> """ if len(endpoints) <= count: return endpoints - - endpoints_copy = list(endpoints) - random.shuffle(endpoints_copy) - return endpoints_copy[:count] + return random.sample(endpoints, count) def detect_local_dc( diff --git a/ydb/pool.py b/ydb/pool.py index 70599931..a0c9e673 100644 --- a/ydb/pool.py +++ b/ydb/pool.py @@ -257,6 +257,7 @@ def execute_discovery(self) -> bool: "Failed to detect local DC via TCP latency, using server location: %s. Error: %s", resolve_details.self_location, e, + exc_info=True, ) for resolved_endpoint in resolve_details.endpoints: From 7e51850198123bb382cfcccf3047570519f5584d Mon Sep 17 00:00:00 2001 From: Idris Yandarov <32651311+r142f@users.noreply.github.com> Date: Wed, 11 Feb 2026 06:55:39 +0300 Subject: [PATCH 06/19] fix: add tests --- tests/aio/test_discovery_detect_local_dc.py | 106 ++++++++++++++++++++ tests/test_discovery_detect_local_dc.py | 90 +++++++++++++++++ 2 files changed, 196 insertions(+) create mode 100644 tests/aio/test_discovery_detect_local_dc.py create mode 100644 tests/test_discovery_detect_local_dc.py diff --git a/tests/aio/test_discovery_detect_local_dc.py b/tests/aio/test_discovery_detect_local_dc.py new file mode 100644 index 00000000..1e7faf5a --- /dev/null +++ b/tests/aio/test_discovery_detect_local_dc.py @@ -0,0 +1,106 @@ +# -*- coding: utf-8 -*- +import pytest +from unittest.mock import MagicMock, patch, AsyncMock +from ydb import driver, connection +from ydb.aio import pool, nearest_dc + + +class MockEndpointInfo: + def __init__(self, address, port, location): + self.address = address + self.port = port + self.endpoint = f"{address}:{port}" + self.location = location + self.ssl = False + self.node_id = 1 + + def endpoints_with_options(self): + yield (self.endpoint, connection.EndpointOptions(ssl_target_name_override=None, node_id=self.node_id)) + + +class MockDiscoveryResult: + def __init__(self, self_location, endpoints): + self.self_location = self_location + self.endpoints = endpoints + + +@pytest.mark.asyncio +async def test_detect_local_dc_overrides_server_location(): + """Test that detected location overrides server's self_location for preferred endpoints.""" + # Server reports dc1, but we detect dc2 as nearest + endpoints = [ + MockEndpointInfo("dc1-host", 2135, "dc1"), + MockEndpointInfo("dc2-host", 2135, "dc2"), + ] + mock_result = MockDiscoveryResult(self_location="dc1", endpoints=endpoints) + + mock_resolver = MagicMock() + mock_resolver.resolve = AsyncMock(return_value=mock_result) + + preferred = [] + + def mock_init(self, endpoint, driver_config, endpoint_options=None): + self.endpoint = endpoint + self.node_id = 1 + + with patch.object(nearest_dc, "detect_local_dc", AsyncMock(return_value="dc2")): + with patch("ydb.aio.connection.Connection.__init__", mock_init): + with patch("ydb.aio.connection.Connection.connection_ready", AsyncMock()): + with patch("ydb.aio.connection.Connection.close", AsyncMock()): + with patch("ydb.aio.connection.Connection.add_cleanup_callback", lambda *a: None): + config = driver.DriverConfig( + endpoint="grpc://test:2135", database="/local", detect_local_dc=True + ) + discovery = pool.Discovery(store=pool.ConnectionsCache(), driver_config=config) + discovery._resolver = mock_resolver + + original_add = discovery._cache.add + discovery._cache.add = lambda conn, pref=False: ( + preferred.append(conn.endpoint) if pref else None, + original_add(conn, pref), + )[1] + + await discovery.execute_discovery() + + assert any("dc2" in ep for ep in preferred), "dc2 should be preferred (detected)" + assert not any("dc1" in ep for ep in preferred), "dc1 should not be preferred" + + +@pytest.mark.asyncio +async def test_detect_local_dc_failure_fallback(): + """Test that detection failure falls back to server's self_location.""" + endpoints = [ + MockEndpointInfo("dc1-host", 2135, "dc1"), + MockEndpointInfo("dc2-host", 2135, "dc2"), + ] + mock_result = MockDiscoveryResult(self_location="dc1", endpoints=endpoints) + + mock_resolver = MagicMock() + mock_resolver.resolve = AsyncMock(return_value=mock_result) + + preferred = [] + + def mock_init(self, endpoint, driver_config, endpoint_options=None): + self.endpoint = endpoint + self.node_id = 1 + + with patch.object(nearest_dc, "detect_local_dc", AsyncMock(return_value=None)): + with patch("ydb.aio.connection.Connection.__init__", mock_init): + with patch("ydb.aio.connection.Connection.connection_ready", AsyncMock()): + with patch("ydb.aio.connection.Connection.close", AsyncMock()): + with patch("ydb.aio.connection.Connection.add_cleanup_callback", lambda *a: None): + config = driver.DriverConfig( + endpoint="grpc://test:2135", database="/local", detect_local_dc=True + ) + discovery = pool.Discovery(store=pool.ConnectionsCache(), driver_config=config) + discovery._resolver = mock_resolver + + original_add = discovery._cache.add + discovery._cache.add = lambda conn, pref=False: ( + preferred.append(conn.endpoint) if pref else None, + original_add(conn, pref), + )[1] + + await discovery.execute_discovery() + + assert any("dc1" in ep for ep in preferred), "dc1 should be preferred (server fallback)" diff --git a/tests/test_discovery_detect_local_dc.py b/tests/test_discovery_detect_local_dc.py new file mode 100644 index 00000000..291db3ac --- /dev/null +++ b/tests/test_discovery_detect_local_dc.py @@ -0,0 +1,90 @@ +# -*- coding: utf-8 -*- +from unittest.mock import Mock, MagicMock, patch +from ydb import driver, pool, nearest_dc, connection + + +class MockEndpointInfo: + def __init__(self, address, port, location): + self.address = address + self.port = port + self.endpoint = f"{address}:{port}" + self.location = location + self.ssl = False + self.node_id = 1 + + def endpoints_with_options(self): + yield (self.endpoint, connection.EndpointOptions(ssl_target_name_override=None, node_id=self.node_id)) + + +class MockDiscoveryResult: + def __init__(self, self_location, endpoints): + self.self_location = self_location + self.endpoints = endpoints + + +def test_detect_local_dc_overrides_server_location(): + """Test that detected location overrides server's self_location for preferred endpoints.""" + # Server reports dc1, but we detect dc2 as nearest + endpoints = [ + MockEndpointInfo("dc1-host", 2135, "dc1"), + MockEndpointInfo("dc2-host", 2135, "dc2"), + ] + mock_result = MockDiscoveryResult(self_location="dc1", endpoints=endpoints) + + mock_resolver = MagicMock() + mock_resolver.context_resolve.return_value.__enter__.return_value = mock_result + mock_resolver.context_resolve.return_value.__exit__.return_value = None + + preferred = [] + + with patch.object(nearest_dc, "detect_local_dc", Mock(return_value="dc2")): + with patch( + "ydb.connection.Connection.ready_factory", lambda *args, **kw: MagicMock(endpoint=args[0], node_id=1) + ): + config = driver.DriverConfig(endpoint="grpc://test:2135", database="/local", detect_local_dc=True) + discovery = pool.Discovery(store=pool.ConnectionsCache(), driver_config=config) + discovery._resolver = mock_resolver + + original_add = discovery._cache.add + discovery._cache.add = lambda conn, pref=False: ( + preferred.append(conn.endpoint) if pref else None, + original_add(conn, pref), + )[1] + + discovery.execute_discovery() + + assert any("dc2" in ep for ep in preferred), "dc2 should be preferred (detected)" + assert not any("dc1" in ep for ep in preferred), "dc1 should not be preferred" + + +def test_detect_local_dc_failure_fallback(): + """Test that detection failure falls back to server's self_location.""" + endpoints = [ + MockEndpointInfo("dc1-host", 2135, "dc1"), + MockEndpointInfo("dc2-host", 2135, "dc2"), + ] + mock_result = MockDiscoveryResult(self_location="dc1", endpoints=endpoints) + + mock_resolver = MagicMock() + mock_resolver.context_resolve.return_value.__enter__.return_value = mock_result + mock_resolver.context_resolve.return_value.__exit__.return_value = None + + preferred = [] + + with patch.object(nearest_dc, "detect_local_dc", Mock(return_value=None)): + with patch( + "ydb.connection.Connection.ready_factory", lambda *args, **kw: MagicMock(endpoint=args[0], node_id=1) + ): + config = driver.DriverConfig(endpoint="grpc://test:2135", database="/local", detect_local_dc=True) + discovery = pool.Discovery(store=pool.ConnectionsCache(), driver_config=config) + discovery._resolver = mock_resolver + + original_add = discovery._cache.add + discovery._cache.add = lambda conn, pref=False: ( + preferred.append(conn.endpoint) if pref else None, + original_add(conn, pref), + )[1] + + discovery.execute_discovery() + + assert any("dc1" in ep for ep in preferred), "dc1 should be preferred (server fallback)" From 206f3549670f8e7558f000ed2a32ebaf78c79d75 Mon Sep 17 00:00:00 2001 From: Idris Yandarov <32651311+r142f@users.noreply.github.com> Date: Wed, 11 Feb 2026 12:05:41 +0300 Subject: [PATCH 07/19] fix: fixing flaws --- tests/aio/test_discovery_detect_local_dc.py | 12 ++++++++---- tests/test_discovery_detect_local_dc.py | 12 ++++++++---- ydb/nearest_dc.py | 21 +++++++++++++++++---- 3 files changed, 33 insertions(+), 12 deletions(-) diff --git a/tests/aio/test_discovery_detect_local_dc.py b/tests/aio/test_discovery_detect_local_dc.py index 1e7faf5a..93e8ddb1 100644 --- a/tests/aio/test_discovery_detect_local_dc.py +++ b/tests/aio/test_discovery_detect_local_dc.py @@ -49,9 +49,11 @@ def mock_init(self, endpoint, driver_config, endpoint_options=None): with patch("ydb.aio.connection.Connection.close", AsyncMock()): with patch("ydb.aio.connection.Connection.add_cleanup_callback", lambda *a: None): config = driver.DriverConfig( - endpoint="grpc://test:2135", database="/local", detect_local_dc=True + endpoint="grpc://test:2135", database="/local", detect_local_dc=True, use_all_nodes=False + ) + discovery = pool.Discovery( + store=pool.ConnectionsCache(config.use_all_nodes), driver_config=config ) - discovery = pool.Discovery(store=pool.ConnectionsCache(), driver_config=config) discovery._resolver = mock_resolver original_add = discovery._cache.add @@ -90,9 +92,11 @@ def mock_init(self, endpoint, driver_config, endpoint_options=None): with patch("ydb.aio.connection.Connection.close", AsyncMock()): with patch("ydb.aio.connection.Connection.add_cleanup_callback", lambda *a: None): config = driver.DriverConfig( - endpoint="grpc://test:2135", database="/local", detect_local_dc=True + endpoint="grpc://test:2135", database="/local", detect_local_dc=True, use_all_nodes=False + ) + discovery = pool.Discovery( + store=pool.ConnectionsCache(config.use_all_nodes), driver_config=config ) - discovery = pool.Discovery(store=pool.ConnectionsCache(), driver_config=config) discovery._resolver = mock_resolver original_add = discovery._cache.add diff --git a/tests/test_discovery_detect_local_dc.py b/tests/test_discovery_detect_local_dc.py index 291db3ac..9be8efa1 100644 --- a/tests/test_discovery_detect_local_dc.py +++ b/tests/test_discovery_detect_local_dc.py @@ -41,8 +41,10 @@ def test_detect_local_dc_overrides_server_location(): with patch( "ydb.connection.Connection.ready_factory", lambda *args, **kw: MagicMock(endpoint=args[0], node_id=1) ): - config = driver.DriverConfig(endpoint="grpc://test:2135", database="/local", detect_local_dc=True) - discovery = pool.Discovery(store=pool.ConnectionsCache(), driver_config=config) + config = driver.DriverConfig( + endpoint="grpc://test:2135", database="/local", detect_local_dc=True, use_all_nodes=False + ) + discovery = pool.Discovery(store=pool.ConnectionsCache(config.use_all_nodes), driver_config=config) discovery._resolver = mock_resolver original_add = discovery._cache.add @@ -75,8 +77,10 @@ def test_detect_local_dc_failure_fallback(): with patch( "ydb.connection.Connection.ready_factory", lambda *args, **kw: MagicMock(endpoint=args[0], node_id=1) ): - config = driver.DriverConfig(endpoint="grpc://test:2135", database="/local", detect_local_dc=True) - discovery = pool.Discovery(store=pool.ConnectionsCache(), driver_config=config) + config = driver.DriverConfig( + endpoint="grpc://test:2135", database="/local", detect_local_dc=True, use_all_nodes=False + ) + discovery = pool.Discovery(store=pool.ConnectionsCache(config.use_all_nodes), driver_config=config) discovery._resolver = mock_resolver original_add = discovery._cache.add diff --git a/ydb/nearest_dc.py b/ydb/nearest_dc.py index d4f7f7d7..611a15dd 100644 --- a/ydb/nearest_dc.py +++ b/ydb/nearest_dc.py @@ -2,6 +2,7 @@ import atexit import concurrent.futures import socket +import sys import threading import logging import random @@ -14,14 +15,21 @@ logger = logging.getLogger(__name__) # Module-level thread pool for TCP race (reused across discovery cycles) -_TCP_RACE_MAX_WORKERS = 15 +_TCP_RACE_MAX_WORKERS = 30 _TCP_RACE_EXECUTOR = concurrent.futures.ThreadPoolExecutor( max_workers=_TCP_RACE_MAX_WORKERS, thread_name_prefix="ydb-tcp-race", ) -# Ensure executor is shut down on process exit -atexit.register(lambda: _TCP_RACE_EXECUTOR.shutdown(wait=False, cancel_futures=True)) + +def _shutdown_executor(): + if sys.version_info >= (3, 9): + _TCP_RACE_EXECUTOR.shutdown(wait=False, cancel_futures=True) + else: + _TCP_RACE_EXECUTOR.shutdown(wait=False) + + +atexit.register(_shutdown_executor) def _check_fastest_endpoint( @@ -33,6 +41,9 @@ def _check_fastest_endpoint( Uses a module-level ThreadPoolExecutor to avoid creating new threads on every discovery cycle. Returns immediately when the first endpoint connects successfully. + If there are more endpoints than the thread pool size, takes one random endpoint + per location to ensure fair representation of all locations in the race. + :param endpoints: List of resolver.EndpointInfo objects :param timeout: Maximum time to wait for any connection (seconds) :return: Fastest endpoint that connected successfully, or None if all failed @@ -40,7 +51,9 @@ def _check_fastest_endpoint( if not endpoints: return None - endpoints = _get_random_endpoints(endpoints, _TCP_RACE_MAX_WORKERS) + if len(endpoints) > _TCP_RACE_MAX_WORKERS: + endpoints_by_location = _split_endpoints_by_location(endpoints) + endpoints = [random.choice(location_eps) for location_eps in endpoints_by_location.values()] stop_event = threading.Event() winner_lock = threading.Lock() From b36a6c036a30ed5c231f2bae7fcd6fcc432eaea1 Mon Sep 17 00:00:00 2001 From: Idris Yandarov <32651311+r142f@users.noreply.github.com> Date: Wed, 11 Feb 2026 12:39:55 +0300 Subject: [PATCH 08/19] fix: fixing flaws --- tests/aio/test_nearest_dc.py | 14 ++++++++++++++ tests/test_nearest_dc.py | 12 ++++++++++++ ydb/aio/nearest_dc.py | 27 +++++++++++++++++++++------ ydb/nearest_dc.py | 16 ++++++++++++---- 4 files changed, 59 insertions(+), 10 deletions(-) diff --git a/tests/aio/test_nearest_dc.py b/tests/aio/test_nearest_dc.py index be9b1f08..4bf0b458 100644 --- a/tests/aio/test_nearest_dc.py +++ b/tests/aio/test_nearest_dc.py @@ -143,3 +143,17 @@ async def fake_open_connection(host, port): await nearest_dc.detect_local_dc(endpoints, max_per_location=2, timeout=0.2) assert len(calls) == 4 + + +@pytest.mark.asyncio +async def test_detect_local_dc_validates_max_per_location(): + endpoints = [MockEndpoint("h1", 1, "dc1")] + with pytest.raises(ValueError, match="max_per_location must be >= 1"): + await nearest_dc.detect_local_dc(endpoints, max_per_location=0) + + +@pytest.mark.asyncio +async def test_detect_local_dc_validates_timeout(): + endpoints = [MockEndpoint("h1", 1, "dc1")] + with pytest.raises(ValueError, match="timeout must be > 0"): + await nearest_dc.detect_local_dc(endpoints, timeout=0) diff --git a/tests/test_nearest_dc.py b/tests/test_nearest_dc.py index c626fb7f..6f7f0c43 100644 --- a/tests/test_nearest_dc.py +++ b/tests/test_nearest_dc.py @@ -130,3 +130,15 @@ def fake_create_connection(addr_port, timeout=None): nearest_dc.detect_local_dc(endpoints, max_per_location=2, timeout=0.2) assert len(calls) == 4 + + +def test_detect_local_dc_validates_max_per_location(): + endpoints = [MockEndpoint("h1", 1, "dc1")] + with pytest.raises(ValueError, match="max_per_location must be >= 1"): + nearest_dc.detect_local_dc(endpoints, max_per_location=0) + + +def test_detect_local_dc_validates_timeout(): + endpoints = [MockEndpoint("h1", 1, "dc1")] + with pytest.raises(ValueError, match="timeout must be > 0"): + nearest_dc.detect_local_dc(endpoints, timeout=0) diff --git a/ydb/aio/nearest_dc.py b/ydb/aio/nearest_dc.py index 27a33b33..0802bc5c 100644 --- a/ydb/aio/nearest_dc.py +++ b/ydb/aio/nearest_dc.py @@ -111,18 +111,23 @@ async def detect_local_dc( 1. Group endpoints by location 2. If only one location exists, return it immediately 3. Select up to max_per_location random endpoints from each location - 4. Perform TCP race: connect to all selected endpoints simultaneously - 5. Return the location of the first endpoint that connects successfully - 6. If all connections fail, return None + 4. If too many endpoints, reduce to one per location and cap at limit + 5. Perform TCP race: connect to all selected endpoints simultaneously + 6. Return the location of the first endpoint that connects successfully + 7. If all connections fail, return None :param endpoints: List of resolver.EndpointInfo objects from discovery - :param max_per_location: Maximum number of endpoints to test per location (default: 3) - :param timeout: TCP connection timeout in seconds (default: 5.0) + :param max_per_location: Maximum number of endpoints to test per location (default: 3, must be >= 1) + :param timeout: TCP connection timeout in seconds (default: 5.0, must be > 0) :return: Location string of the nearest datacenter, or None if detection failed - :raises ValueError: If endpoints list is empty + :raises ValueError: If endpoints list is empty, max_per_location < 1, or timeout <= 0 """ if not endpoints: raise ValueError("Empty endpoints list for local DC detection") + if max_per_location < 1: + raise ValueError(f"max_per_location must be >= 1, got {max_per_location}") + if timeout <= 0: + raise ValueError(f"timeout must be > 0, got {timeout}") endpoints_by_location = _split_endpoints_by_location(endpoints) @@ -137,6 +142,8 @@ async def detect_local_dc( logger.debug("Only one location found: %s", location) return location + _MAX_CONCURRENT_TASKS = 99 + endpoints_to_test = [] for location, location_endpoints in endpoints_by_location.items(): sample = _get_random_endpoints(location_endpoints, max_per_location) @@ -148,6 +155,14 @@ async def detect_local_dc( location, ) + if len(endpoints_to_test) > _MAX_CONCURRENT_TASKS: + endpoints_to_test = [random.choice(location_eps) for location_eps in endpoints_by_location.values()] + + if len(endpoints_to_test) > _MAX_CONCURRENT_TASKS: + endpoints_to_test = random.sample(endpoints_to_test, _MAX_CONCURRENT_TASKS) + + logger.debug("Capped endpoints to %d to limit concurrent tasks", len(endpoints_to_test)) + fastest_endpoint = await _check_fastest_endpoint(endpoints_to_test, timeout=timeout) if fastest_endpoint is None: diff --git a/ydb/nearest_dc.py b/ydb/nearest_dc.py index 611a15dd..306149d4 100644 --- a/ydb/nearest_dc.py +++ b/ydb/nearest_dc.py @@ -42,7 +42,8 @@ def _check_fastest_endpoint( discovery cycle. Returns immediately when the first endpoint connects successfully. If there are more endpoints than the thread pool size, takes one random endpoint - per location to ensure fair representation of all locations in the race. + per location to ensure fair representation of all locations in the race. If there + are still too many locations, randomly samples them to stay within the limit. :param endpoints: List of resolver.EndpointInfo objects :param timeout: Maximum time to wait for any connection (seconds) @@ -55,6 +56,9 @@ def _check_fastest_endpoint( endpoints_by_location = _split_endpoints_by_location(endpoints) endpoints = [random.choice(location_eps) for location_eps in endpoints_by_location.values()] + if len(endpoints) > _TCP_RACE_MAX_WORKERS: + endpoints = random.sample(endpoints, _TCP_RACE_MAX_WORKERS) + stop_event = threading.Event() winner_lock = threading.Lock() deadline = time.monotonic() + timeout @@ -147,13 +151,17 @@ def detect_local_dc( 6. If all connections fail, return None :param endpoints: List of resolver.EndpointInfo objects from discovery - :param max_per_location: Maximum number of endpoints to test per location (default: 3) - :param timeout: TCP connection timeout in seconds (default: 5.0) + :param max_per_location: Maximum number of endpoints to test per location (default: 3, must be >= 1) + :param timeout: TCP connection timeout in seconds (default: 5.0, must be > 0) :return: Location string of the nearest datacenter, or None if detection failed - :raises ValueError: If endpoints list is empty + :raises ValueError: If endpoints list is empty, max_per_location < 1, or timeout <= 0 """ if not endpoints: raise ValueError("Empty endpoints list for local DC detection") + if max_per_location < 1: + raise ValueError(f"max_per_location must be >= 1, got {max_per_location}") + if timeout <= 0: + raise ValueError(f"timeout must be > 0, got {timeout}") endpoints_by_location = _split_endpoints_by_location(endpoints) From fb5083dc831c5baf0d6ad6a88dc97714d56ad52e Mon Sep 17 00:00:00 2001 From: Idris Yandarov <32651311+r142f@users.noreply.github.com> Date: Wed, 11 Feb 2026 13:16:47 +0300 Subject: [PATCH 09/19] fix: fixing flaws --- ydb/nearest_dc.py | 45 ++++++++++++++++++++++++++++++++++----------- 1 file changed, 34 insertions(+), 11 deletions(-) diff --git a/ydb/nearest_dc.py b/ydb/nearest_dc.py index 306149d4..ce3cea24 100644 --- a/ydb/nearest_dc.py +++ b/ydb/nearest_dc.py @@ -16,20 +16,42 @@ # Module-level thread pool for TCP race (reused across discovery cycles) _TCP_RACE_MAX_WORKERS = 30 -_TCP_RACE_EXECUTOR = concurrent.futures.ThreadPoolExecutor( - max_workers=_TCP_RACE_MAX_WORKERS, - thread_name_prefix="ydb-tcp-race", -) +_TCP_RACE_EXECUTOR: Optional[concurrent.futures.ThreadPoolExecutor] = None +_EXECUTOR_LOCK = threading.Lock() +_ATEXIT_REGISTERED = False -def _shutdown_executor(): - if sys.version_info >= (3, 9): - _TCP_RACE_EXECUTOR.shutdown(wait=False, cancel_futures=True) - else: - _TCP_RACE_EXECUTOR.shutdown(wait=False) +def _get_executor() -> concurrent.futures.ThreadPoolExecutor: + """ + Lazily create and return the thread pool executor. + + The executor is created on first use to avoid import-time side effects. + The atexit hook is registered only when the executor is actually created. + """ + global _TCP_RACE_EXECUTOR, _ATEXIT_REGISTERED + if _TCP_RACE_EXECUTOR is None: + with _EXECUTOR_LOCK: + if _TCP_RACE_EXECUTOR is None: + _TCP_RACE_EXECUTOR = concurrent.futures.ThreadPoolExecutor( + max_workers=_TCP_RACE_MAX_WORKERS, + thread_name_prefix="ydb-tcp-race", + ) -atexit.register(_shutdown_executor) + if not _ATEXIT_REGISTERED: + atexit.register(_shutdown_executor) + _ATEXIT_REGISTERED = True + + return _TCP_RACE_EXECUTOR + + +def _shutdown_executor(): + """Shutdown the executor if it was created.""" + if _TCP_RACE_EXECUTOR is not None: + if sys.version_info >= (3, 9): + _TCP_RACE_EXECUTOR.shutdown(wait=False, cancel_futures=True) + else: + _TCP_RACE_EXECUTOR.shutdown(wait=False) def _check_fastest_endpoint( @@ -86,7 +108,8 @@ def try_connect(endpoint: resolver.EndpointInfo) -> Optional[resolver.EndpointIn logger.debug("Unexpected error connecting to %s: %s", endpoint.endpoint, e) return None - futures: List[concurrent.futures.Future] = [_TCP_RACE_EXECUTOR.submit(try_connect, ep) for ep in endpoints] + executor = _get_executor() + futures: List[concurrent.futures.Future] = [executor.submit(try_connect, ep) for ep in endpoints] try: for fut in concurrent.futures.as_completed(futures, timeout=timeout): From 9e8cafbfe5dfcb9dc457d7318185ffc29ba2a8d6 Mon Sep 17 00:00:00 2001 From: Idris Yandarov <32651311+r142f@users.noreply.github.com> Date: Wed, 11 Feb 2026 14:12:19 +0300 Subject: [PATCH 10/19] fix: fixing flaws --- ydb/aio/nearest_dc.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/ydb/aio/nearest_dc.py b/ydb/aio/nearest_dc.py index 0802bc5c..a5c3cda6 100644 --- a/ydb/aio/nearest_dc.py +++ b/ydb/aio/nearest_dc.py @@ -63,7 +63,6 @@ async def try_connect(endpoint): for t in tasks: if not t.done(): t.cancel() - await asyncio.gather(*tasks, return_exceptions=True) def _split_endpoints_by_location( @@ -142,7 +141,7 @@ async def detect_local_dc( logger.debug("Only one location found: %s", location) return location - _MAX_CONCURRENT_TASKS = 99 + _MAX_CONCURRENT_TASKS = 30 endpoints_to_test = [] for location, location_endpoints in endpoints_by_location.items(): From 5688e6a07825275234d71813957d4b4259ae535b Mon Sep 17 00:00:00 2001 From: Idris Yandarov <32651311+r142f@users.noreply.github.com> Date: Wed, 11 Feb 2026 15:17:28 +0300 Subject: [PATCH 11/19] fix: fixing flaws --- tests/aio/test_nearest_dc.py | 4 +++- tests/test_nearest_dc.py | 4 +++- ydb/aio/nearest_dc.py | 9 ++++++++- ydb/nearest_dc.py | 9 ++++++++- 4 files changed, 22 insertions(+), 4 deletions(-) diff --git a/tests/aio/test_nearest_dc.py b/tests/aio/test_nearest_dc.py index 4bf0b458..8ce6d06b 100644 --- a/tests/aio/test_nearest_dc.py +++ b/tests/aio/test_nearest_dc.py @@ -4,11 +4,13 @@ class MockEndpoint: - def __init__(self, address, port, location): + def __init__(self, address, port, location, ipv4_addrs=(), ipv6_addrs=()): self.address = address self.port = port self.endpoint = f"{address}:{port}" self.location = location + self.ipv4_addrs = ipv4_addrs + self.ipv6_addrs = ipv6_addrs class MockWriter: diff --git a/tests/test_nearest_dc.py b/tests/test_nearest_dc.py index 6f7f0c43..cc76e10d 100644 --- a/tests/test_nearest_dc.py +++ b/tests/test_nearest_dc.py @@ -4,11 +4,13 @@ class MockEndpoint: - def __init__(self, address, port, location): + def __init__(self, address, port, location, ipv4_addrs=(), ipv6_addrs=()): self.address = address self.port = port self.endpoint = f"{address}:{port}" self.location = location + self.ipv4_addrs = ipv4_addrs + self.ipv6_addrs = ipv6_addrs class DummySock: diff --git a/ydb/aio/nearest_dc.py b/ydb/aio/nearest_dc.py index a5c3cda6..c17e656d 100644 --- a/ydb/aio/nearest_dc.py +++ b/ydb/aio/nearest_dc.py @@ -35,9 +35,16 @@ async def try_connect(endpoint): if remaining <= 0: return None + if endpoint.ipv6_addrs: + target_host = endpoint.ipv6_addrs[0] + elif endpoint.ipv4_addrs: + target_host = endpoint.ipv4_addrs[0] + else: + target_host = endpoint.address + try: _, writer = await asyncio.wait_for( - asyncio.open_connection(endpoint.address, endpoint.port), + asyncio.open_connection(target_host, endpoint.port), timeout=remaining, ) writer.close() diff --git a/ydb/nearest_dc.py b/ydb/nearest_dc.py index ce3cea24..6b82adeb 100644 --- a/ydb/nearest_dc.py +++ b/ydb/nearest_dc.py @@ -91,8 +91,15 @@ def try_connect(endpoint: resolver.EndpointInfo) -> Optional[resolver.EndpointIn if remaining <= 0 or stop_event.is_set(): return None + if endpoint.ipv6_addrs: + target_host = endpoint.ipv6_addrs[0] + elif endpoint.ipv4_addrs: + target_host = endpoint.ipv4_addrs[0] + else: + target_host = endpoint.address + try: - sock = socket.create_connection((endpoint.address, endpoint.port), timeout=remaining) + sock = socket.create_connection((target_host, endpoint.port), timeout=remaining) try: with winner_lock: if stop_event.is_set(): From 33826e7d3a7a103628f3ae59d2fe47fd37d4571c Mon Sep 17 00:00:00 2001 From: Idris Yandarov <32651311+r142f@users.noreply.github.com> Date: Wed, 11 Feb 2026 17:17:37 +0300 Subject: [PATCH 12/19] fix: fixing flaws --- tests/aio/test_discovery_detect_local_dc.py | 33 ++++++++++++++ tests/test_discovery_detect_local_dc.py | 26 +++++++++++ ydb/aio/pool.py | 49 +++++++++++++-------- ydb/pool.py | 49 +++++++++++++-------- 4 files changed, 121 insertions(+), 36 deletions(-) diff --git a/tests/aio/test_discovery_detect_local_dc.py b/tests/aio/test_discovery_detect_local_dc.py index 93e8ddb1..4f042396 100644 --- a/tests/aio/test_discovery_detect_local_dc.py +++ b/tests/aio/test_discovery_detect_local_dc.py @@ -108,3 +108,36 @@ def mock_init(self, endpoint, driver_config, endpoint_options=None): await discovery.execute_discovery() assert any("dc1" in ep for ep in preferred), "dc1 should be preferred (server fallback)" + + +@pytest.mark.asyncio +async def test_detect_local_dc_skipped_when_use_all_nodes_true(): + """Test that detect_local_dc is NOT called when use_all_nodes=True.""" + endpoints = [ + MockEndpointInfo("dc1-host", 2135, "dc1"), + MockEndpointInfo("dc2-host", 2135, "dc2"), + ] + mock_result = MockDiscoveryResult(self_location="dc1", endpoints=endpoints) + + mock_resolver = MagicMock() + mock_resolver.resolve = AsyncMock(return_value=mock_result) + + def mock_init(self, endpoint, driver_config, endpoint_options=None): + self.endpoint = endpoint + self.node_id = 1 + + with patch.object(nearest_dc, "detect_local_dc", AsyncMock(return_value="dc2")) as detect_mock: + with patch("ydb.aio.connection.Connection.__init__", mock_init): + with patch("ydb.aio.connection.Connection.connection_ready", AsyncMock()): + with patch("ydb.aio.connection.Connection.close", AsyncMock()): + with patch("ydb.aio.connection.Connection.add_cleanup_callback", lambda *a: None): + config = driver.DriverConfig( + endpoint="grpc://test:2135", database="/local", detect_local_dc=True, use_all_nodes=True + ) + discovery = pool.Discovery( + store=pool.ConnectionsCache(config.use_all_nodes), driver_config=config + ) + discovery._resolver = mock_resolver + await discovery.execute_discovery() + + assert detect_mock.call_count == 0, "detect_local_dc should NOT be called when use_all_nodes=True" diff --git a/tests/test_discovery_detect_local_dc.py b/tests/test_discovery_detect_local_dc.py index 9be8efa1..4123dd3f 100644 --- a/tests/test_discovery_detect_local_dc.py +++ b/tests/test_discovery_detect_local_dc.py @@ -92,3 +92,29 @@ def test_detect_local_dc_failure_fallback(): discovery.execute_discovery() assert any("dc1" in ep for ep in preferred), "dc1 should be preferred (server fallback)" + + +def test_detect_local_dc_skipped_when_use_all_nodes_true(): + """Test that detect_local_dc is NOT called when use_all_nodes=True.""" + endpoints = [ + MockEndpointInfo("dc1-host", 2135, "dc1"), + MockEndpointInfo("dc2-host", 2135, "dc2"), + ] + mock_result = MockDiscoveryResult(self_location="dc1", endpoints=endpoints) + + mock_resolver = MagicMock() + mock_resolver.context_resolve.return_value.__enter__.return_value = mock_result + mock_resolver.context_resolve.return_value.__exit__.return_value = None + + with patch.object(nearest_dc, "detect_local_dc", Mock(return_value="dc2")) as detect_mock: + with patch( + "ydb.connection.Connection.ready_factory", lambda *args, **kw: MagicMock(endpoint=args[0], node_id=1) + ): + config = driver.DriverConfig( + endpoint="grpc://test:2135", database="/local", detect_local_dc=True, use_all_nodes=True + ) + discovery = pool.Discovery(store=pool.ConnectionsCache(config.use_all_nodes), driver_config=config) + discovery._resolver = mock_resolver + discovery.execute_discovery() + + assert detect_mock.call_count == 0, "detect_local_dc should NOT be called when use_all_nodes=True" diff --git a/ydb/aio/pool.py b/ydb/aio/pool.py index 81b7d35c..15a00051 100644 --- a/ydb/aio/pool.py +++ b/ydb/aio/pool.py @@ -147,30 +147,43 @@ async def execute_discovery(self) -> bool: local_dc = resolve_details.self_location - # Detect local DC using TCP latency if enabled - if self._driver_config.detect_local_dc: - try: - detected_location = await nearest_dc.detect_local_dc( - resolve_details.endpoints, max_per_location=3, timeout=self._ready_timeout - ) - if detected_location: - local_dc = detected_location - self.logger.info( - "Detected local DC via TCP latency: %s (server reported: %s)", - local_dc, - resolve_details.self_location, + # Detect local DC using TCP latency if enabled and preferred is meaningful + if self._driver_config.detect_local_dc and not self._driver_config.use_all_nodes: + # Use only endpoints that match the SSL requirements for detection + ssl_filtered_endpoints = [ + endpoint + for endpoint in resolve_details.endpoints + if (self._ssl_required and endpoint.ssl) or (not self._ssl_required and not endpoint.ssl) + ] + + if ssl_filtered_endpoints: + try: + detected_location = await nearest_dc.detect_local_dc( + ssl_filtered_endpoints, max_per_location=3, timeout=self._ready_timeout ) - else: + if detected_location: + local_dc = detected_location + self.logger.info( + "Detected local DC via TCP latency: %s (server reported: %s)", + local_dc, + resolve_details.self_location, + ) + else: + self.logger.warning( + "Failed to detect local DC via TCP latency, using server location: %s", + resolve_details.self_location, + ) + except Exception as e: self.logger.warning( - "Failed to detect local DC via TCP latency, using server location: %s", + "Failed to detect local DC via TCP latency, using server location: %s. Error: %s", resolve_details.self_location, + e, + exc_info=True, ) - except Exception as e: + else: self.logger.warning( - "Failed to detect local DC via TCP latency, using server location: %s. Error: %s", + "No SSL-compatible endpoints for local DC detection, using server location: %s", resolve_details.self_location, - e, - exc_info=True, ) for resolved_endpoint in resolve_details.endpoints: diff --git a/ydb/pool.py b/ydb/pool.py index a0c9e673..a45a931c 100644 --- a/ydb/pool.py +++ b/ydb/pool.py @@ -234,30 +234,43 @@ def execute_discovery(self) -> bool: local_dc = resolve_details.self_location - # Detect local DC using TCP latency if enabled - if self._driver_config.detect_local_dc: - try: - detected_location = nearest_dc.detect_local_dc( - resolve_details.endpoints, max_per_location=3, timeout=self._ready_timeout - ) - if detected_location: - local_dc = detected_location - self.logger.info( - "Detected local DC via TCP latency: %s (server reported: %s)", - local_dc, - resolve_details.self_location, + # Detect local DC using TCP latency if enabled and preferred is meaningful + if self._driver_config.detect_local_dc and not self._driver_config.use_all_nodes: + # Use only endpoints that match the SSL requirements for detection + ssl_filtered_endpoints = [ + endpoint + for endpoint in resolve_details.endpoints + if (self._ssl_required and endpoint.ssl) or (not self._ssl_required and not endpoint.ssl) + ] + + if ssl_filtered_endpoints: + try: + detected_location = nearest_dc.detect_local_dc( + ssl_filtered_endpoints, max_per_location=3, timeout=self._ready_timeout ) - else: + if detected_location: + local_dc = detected_location + self.logger.info( + "Detected local DC via TCP latency: %s (server reported: %s)", + local_dc, + resolve_details.self_location, + ) + else: + self.logger.warning( + "Failed to detect local DC via TCP latency, using server location: %s", + resolve_details.self_location, + ) + except Exception as e: self.logger.warning( - "Failed to detect local DC via TCP latency, using server location: %s", + "Failed to detect local DC via TCP latency, using server location: %s. Error: %s", resolve_details.self_location, + e, + exc_info=True, ) - except Exception as e: + else: self.logger.warning( - "Failed to detect local DC via TCP latency, using server location: %s. Error: %s", + "No SSL-compatible endpoints for local DC detection, using server location: %s", resolve_details.self_location, - e, - exc_info=True, ) for resolved_endpoint in resolve_details.endpoints: From 102ada34c045fc026a4d903a5812e60fe325e2ba Mon Sep 17 00:00:00 2001 From: Idris Yandarov <32651311+r142f@users.noreply.github.com> Date: Sun, 29 Mar 2026 15:25:11 +0300 Subject: [PATCH 13/19] fix: move nearest_dc to _utilities --- tests/aio/test_discovery_detect_local_dc.py | 8 +- tests/aio/test_nearest_dc.py | 38 ++-- tests/test_discovery_detect_local_dc.py | 8 +- tests/test_nearest_dc.py | 38 ++-- ydb/_utilities.py | 232 ++++++++++++++++++++ ydb/aio/_utilities.py | 184 ++++++++++++++++ ydb/aio/nearest_dc.py | 181 --------------- ydb/aio/pool.py | 4 +- ydb/nearest_dc.py | 229 ------------------- ydb/pool.py | 4 +- 10 files changed, 466 insertions(+), 460 deletions(-) delete mode 100644 ydb/aio/nearest_dc.py delete mode 100644 ydb/nearest_dc.py diff --git a/tests/aio/test_discovery_detect_local_dc.py b/tests/aio/test_discovery_detect_local_dc.py index 4f042396..ad534291 100644 --- a/tests/aio/test_discovery_detect_local_dc.py +++ b/tests/aio/test_discovery_detect_local_dc.py @@ -2,7 +2,7 @@ import pytest from unittest.mock import MagicMock, patch, AsyncMock from ydb import driver, connection -from ydb.aio import pool, nearest_dc +from ydb.aio import pool, _utilities class MockEndpointInfo: @@ -43,7 +43,7 @@ def mock_init(self, endpoint, driver_config, endpoint_options=None): self.endpoint = endpoint self.node_id = 1 - with patch.object(nearest_dc, "detect_local_dc", AsyncMock(return_value="dc2")): + with patch.object(_utilities, "detect_local_dc", AsyncMock(return_value="dc2")): with patch("ydb.aio.connection.Connection.__init__", mock_init): with patch("ydb.aio.connection.Connection.connection_ready", AsyncMock()): with patch("ydb.aio.connection.Connection.close", AsyncMock()): @@ -86,7 +86,7 @@ def mock_init(self, endpoint, driver_config, endpoint_options=None): self.endpoint = endpoint self.node_id = 1 - with patch.object(nearest_dc, "detect_local_dc", AsyncMock(return_value=None)): + with patch.object(_utilities, "detect_local_dc", AsyncMock(return_value=None)): with patch("ydb.aio.connection.Connection.__init__", mock_init): with patch("ydb.aio.connection.Connection.connection_ready", AsyncMock()): with patch("ydb.aio.connection.Connection.close", AsyncMock()): @@ -126,7 +126,7 @@ def mock_init(self, endpoint, driver_config, endpoint_options=None): self.endpoint = endpoint self.node_id = 1 - with patch.object(nearest_dc, "detect_local_dc", AsyncMock(return_value="dc2")) as detect_mock: + with patch.object(_utilities, "detect_local_dc", AsyncMock(return_value="dc2")) as detect_mock: with patch("ydb.aio.connection.Connection.__init__", mock_init): with patch("ydb.aio.connection.Connection.connection_ready", AsyncMock()): with patch("ydb.aio.connection.Connection.close", AsyncMock()): diff --git a/tests/aio/test_nearest_dc.py b/tests/aio/test_nearest_dc.py index 8ce6d06b..71c96819 100644 --- a/tests/aio/test_nearest_dc.py +++ b/tests/aio/test_nearest_dc.py @@ -1,6 +1,6 @@ import asyncio import pytest -from ydb.aio import nearest_dc +from ydb.aio import _utilities class MockEndpoint: @@ -26,7 +26,7 @@ async def wait_closed(self): @pytest.mark.asyncio async def test_check_fastest_endpoint_empty(): - assert await nearest_dc._check_fastest_endpoint([]) is None + assert await _utilities._check_fastest_endpoint([]) is None @pytest.mark.asyncio @@ -34,13 +34,13 @@ async def test_check_fastest_endpoint_all_fail(monkeypatch): async def fake_open_connection(host, port): raise OSError("connect failed") - monkeypatch.setattr(nearest_dc.asyncio, "open_connection", fake_open_connection) + monkeypatch.setattr(_utilities.asyncio, "open_connection", fake_open_connection) endpoints = [ MockEndpoint("a", 1, "dc1"), MockEndpoint("b", 1, "dc2"), ] - assert await nearest_dc._check_fastest_endpoint(endpoints, timeout=0.05) is None + assert await _utilities._check_fastest_endpoint(endpoints, timeout=0.05) is None @pytest.mark.asyncio @@ -50,13 +50,13 @@ async def fake_open_connection(host, port): await asyncio.sleep(0.05) return None, MockWriter() - monkeypatch.setattr(nearest_dc.asyncio, "open_connection", fake_open_connection) + monkeypatch.setattr(_utilities.asyncio, "open_connection", fake_open_connection) endpoints = [ MockEndpoint("slow", 1, "dc_slow"), MockEndpoint("fast", 1, "dc_fast"), ] - winner = await nearest_dc._check_fastest_endpoint(endpoints, timeout=0.2) + winner = await _utilities._check_fastest_endpoint(endpoints, timeout=0.2) assert winner is not None assert winner.location == "dc_fast" @@ -67,14 +67,14 @@ async def fake_open_connection(host, port): await asyncio.sleep(0.2) return None, MockWriter() - monkeypatch.setattr(nearest_dc.asyncio, "open_connection", fake_open_connection) + monkeypatch.setattr(_utilities.asyncio, "open_connection", fake_open_connection) endpoints = [ MockEndpoint("hang1", 1, "dc1"), MockEndpoint("hang2", 1, "dc2"), ] - winner = await nearest_dc._check_fastest_endpoint(endpoints, timeout=0.05) + winner = await _utilities._check_fastest_endpoint(endpoints, timeout=0.05) assert winner is None @@ -82,7 +82,7 @@ async def fake_open_connection(host, port): @pytest.mark.asyncio async def test_detect_local_dc_empty_endpoints(): with pytest.raises(ValueError, match="Empty endpoints"): - await nearest_dc.detect_local_dc([]) + await _utilities.detect_local_dc([]) @pytest.mark.asyncio @@ -90,13 +90,13 @@ async def test_detect_local_dc_single_location_returns_immediately(monkeypatch): async def fail_if_called(*args, **kwargs): raise AssertionError("open_connection should not be called for single location") - monkeypatch.setattr(nearest_dc.asyncio, "open_connection", fail_if_called) + monkeypatch.setattr(_utilities.asyncio, "open_connection", fail_if_called) endpoints = [ MockEndpoint("h1", 1, "dc1"), MockEndpoint("h2", 1, "dc1"), ] - assert await nearest_dc.detect_local_dc(endpoints) == "dc1" + assert await _utilities.detect_local_dc(endpoints) == "dc1" @pytest.mark.asyncio @@ -104,13 +104,13 @@ async def test_detect_local_dc_returns_none_when_all_fail(monkeypatch): async def fake_open_connection(host, port): raise OSError("connect failed") - monkeypatch.setattr(nearest_dc.asyncio, "open_connection", fake_open_connection) + monkeypatch.setattr(_utilities.asyncio, "open_connection", fake_open_connection) endpoints = [ MockEndpoint("bad1", 9999, "dc1"), MockEndpoint("bad2", 9999, "dc2"), ] - assert await nearest_dc.detect_local_dc(endpoints, timeout=0.05) is None + assert await _utilities.detect_local_dc(endpoints, timeout=0.05) is None @pytest.mark.asyncio @@ -120,13 +120,13 @@ async def fake_open_connection(host, port): await asyncio.sleep(0.05) return None, MockWriter() - monkeypatch.setattr(nearest_dc.asyncio, "open_connection", fake_open_connection) + monkeypatch.setattr(_utilities.asyncio, "open_connection", fake_open_connection) endpoints = [ MockEndpoint("dc1_host", 1, "dc1"), MockEndpoint("dc2_host", 1, "dc2"), ] - assert await nearest_dc.detect_local_dc(endpoints, max_per_location=5, timeout=0.2) == "dc2" + assert await _utilities.detect_local_dc(endpoints, max_per_location=5, timeout=0.2) == "dc2" @pytest.mark.asyncio @@ -137,12 +137,12 @@ async def fake_open_connection(host, port): calls.append((host, port)) raise OSError("connect failed") - monkeypatch.setattr(nearest_dc.asyncio, "open_connection", fake_open_connection) + monkeypatch.setattr(_utilities.asyncio, "open_connection", fake_open_connection) endpoints = [MockEndpoint(f"dc1_{i}", 1, "dc1") for i in range(5)] + [ MockEndpoint(f"dc2_{i}", 1, "dc2") for i in range(5) ] - await nearest_dc.detect_local_dc(endpoints, max_per_location=2, timeout=0.2) + await _utilities.detect_local_dc(endpoints, max_per_location=2, timeout=0.2) assert len(calls) == 4 @@ -151,11 +151,11 @@ async def fake_open_connection(host, port): async def test_detect_local_dc_validates_max_per_location(): endpoints = [MockEndpoint("h1", 1, "dc1")] with pytest.raises(ValueError, match="max_per_location must be >= 1"): - await nearest_dc.detect_local_dc(endpoints, max_per_location=0) + await _utilities.detect_local_dc(endpoints, max_per_location=0) @pytest.mark.asyncio async def test_detect_local_dc_validates_timeout(): endpoints = [MockEndpoint("h1", 1, "dc1")] with pytest.raises(ValueError, match="timeout must be > 0"): - await nearest_dc.detect_local_dc(endpoints, timeout=0) + await _utilities.detect_local_dc(endpoints, timeout=0) diff --git a/tests/test_discovery_detect_local_dc.py b/tests/test_discovery_detect_local_dc.py index 4123dd3f..874ffba7 100644 --- a/tests/test_discovery_detect_local_dc.py +++ b/tests/test_discovery_detect_local_dc.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- from unittest.mock import Mock, MagicMock, patch -from ydb import driver, pool, nearest_dc, connection +from ydb import driver, pool, _utilities, connection class MockEndpointInfo: @@ -37,7 +37,7 @@ def test_detect_local_dc_overrides_server_location(): preferred = [] - with patch.object(nearest_dc, "detect_local_dc", Mock(return_value="dc2")): + with patch.object(_utilities, "detect_local_dc", Mock(return_value="dc2")): with patch( "ydb.connection.Connection.ready_factory", lambda *args, **kw: MagicMock(endpoint=args[0], node_id=1) ): @@ -73,7 +73,7 @@ def test_detect_local_dc_failure_fallback(): preferred = [] - with patch.object(nearest_dc, "detect_local_dc", Mock(return_value=None)): + with patch.object(_utilities, "detect_local_dc", Mock(return_value=None)): with patch( "ydb.connection.Connection.ready_factory", lambda *args, **kw: MagicMock(endpoint=args[0], node_id=1) ): @@ -106,7 +106,7 @@ def test_detect_local_dc_skipped_when_use_all_nodes_true(): mock_resolver.context_resolve.return_value.__enter__.return_value = mock_result mock_resolver.context_resolve.return_value.__exit__.return_value = None - with patch.object(nearest_dc, "detect_local_dc", Mock(return_value="dc2")) as detect_mock: + with patch.object(_utilities, "detect_local_dc", Mock(return_value="dc2")) as detect_mock: with patch( "ydb.connection.Connection.ready_factory", lambda *args, **kw: MagicMock(endpoint=args[0], node_id=1) ): diff --git a/tests/test_nearest_dc.py b/tests/test_nearest_dc.py index cc76e10d..87f4572e 100644 --- a/tests/test_nearest_dc.py +++ b/tests/test_nearest_dc.py @@ -1,6 +1,6 @@ import time import pytest -from ydb import nearest_dc +from ydb import _utilities class MockEndpoint: @@ -19,20 +19,20 @@ def close(self): def test_check_fastest_endpoint_empty(): - assert nearest_dc._check_fastest_endpoint([]) is None + assert _utilities._check_fastest_endpoint([]) is None def test_check_fastest_endpoint_all_fail(monkeypatch): def fake_create_connection(addr_port, timeout=None): raise OSError("connect failed") - monkeypatch.setattr(nearest_dc.socket, "create_connection", fake_create_connection) + monkeypatch.setattr(_utilities.socket, "create_connection", fake_create_connection) endpoints = [ MockEndpoint("a", 1, "dc1"), MockEndpoint("b", 1, "dc2"), ] - assert nearest_dc._check_fastest_endpoint(endpoints, timeout=0.05) is None + assert _utilities._check_fastest_endpoint(endpoints, timeout=0.05) is None def test_check_fastest_endpoint_fastest_wins(monkeypatch): @@ -42,13 +42,13 @@ def fake_create_connection(addr_port, timeout=None): time.sleep(0.05) return DummySock() - monkeypatch.setattr(nearest_dc.socket, "create_connection", fake_create_connection) + monkeypatch.setattr(_utilities.socket, "create_connection", fake_create_connection) endpoints = [ MockEndpoint("slow", 1, "dc_slow"), MockEndpoint("fast", 1, "dc_fast"), ] - winner = nearest_dc._check_fastest_endpoint(endpoints, timeout=0.2) + winner = _utilities._check_fastest_endpoint(endpoints, timeout=0.2) assert winner is not None assert winner.location == "dc_fast" @@ -58,47 +58,47 @@ def fake_create_connection(addr_port, timeout=None): time.sleep(0.2) return DummySock() - monkeypatch.setattr(nearest_dc.socket, "create_connection", fake_create_connection) + monkeypatch.setattr(_utilities.socket, "create_connection", fake_create_connection) endpoints = [ MockEndpoint("hang1", 1, "dc1"), MockEndpoint("hang2", 1, "dc2"), ] - winner = nearest_dc._check_fastest_endpoint(endpoints, timeout=0.05) + winner = _utilities._check_fastest_endpoint(endpoints, timeout=0.05) assert winner is None def test_detect_local_dc_empty_endpoints(): with pytest.raises(ValueError, match="Empty endpoints"): - nearest_dc.detect_local_dc([]) + _utilities.detect_local_dc([]) def test_detect_local_dc_single_location_returns_immediately(monkeypatch): def fail_if_called(*args, **kwargs): raise AssertionError("create_connection should not be called for single location") - monkeypatch.setattr(nearest_dc.socket, "create_connection", fail_if_called) + monkeypatch.setattr(_utilities.socket, "create_connection", fail_if_called) endpoints = [ MockEndpoint("h1", 1, "dc1"), MockEndpoint("h2", 1, "dc1"), ] - assert nearest_dc.detect_local_dc(endpoints) == "dc1" + assert _utilities.detect_local_dc(endpoints) == "dc1" def test_detect_local_dc_returns_none_when_all_fail(monkeypatch): def fake_create_connection(addr_port, timeout=None): raise OSError("connect failed") - monkeypatch.setattr(nearest_dc.socket, "create_connection", fake_create_connection) + monkeypatch.setattr(_utilities.socket, "create_connection", fake_create_connection) endpoints = [ MockEndpoint("bad1", 9999, "dc1"), MockEndpoint("bad2", 9999, "dc2"), ] - assert nearest_dc.detect_local_dc(endpoints, timeout=0.05) is None + assert _utilities.detect_local_dc(endpoints, timeout=0.05) is None def test_detect_local_dc_returns_location_of_fastest(monkeypatch): @@ -108,13 +108,13 @@ def fake_create_connection(addr_port, timeout=None): time.sleep(0.05) return DummySock() - monkeypatch.setattr(nearest_dc.socket, "create_connection", fake_create_connection) + monkeypatch.setattr(_utilities.socket, "create_connection", fake_create_connection) endpoints = [ MockEndpoint("dc1_host", 1, "dc1"), MockEndpoint("dc2_host", 1, "dc2"), ] - assert nearest_dc.detect_local_dc(endpoints, max_per_location=5, timeout=0.2) == "dc2" + assert _utilities.detect_local_dc(endpoints, max_per_location=5, timeout=0.2) == "dc2" def test_detect_local_dc_respects_max_per_location(monkeypatch): @@ -124,12 +124,12 @@ def fake_create_connection(addr_port, timeout=None): calls.append(addr_port) raise OSError("connect failed") - monkeypatch.setattr(nearest_dc.socket, "create_connection", fake_create_connection) + monkeypatch.setattr(_utilities.socket, "create_connection", fake_create_connection) endpoints = [MockEndpoint(f"dc1_{i}", 1, "dc1") for i in range(5)] + [ MockEndpoint(f"dc2_{i}", 1, "dc2") for i in range(5) ] - nearest_dc.detect_local_dc(endpoints, max_per_location=2, timeout=0.2) + _utilities.detect_local_dc(endpoints, max_per_location=2, timeout=0.2) assert len(calls) == 4 @@ -137,10 +137,10 @@ def fake_create_connection(addr_port, timeout=None): def test_detect_local_dc_validates_max_per_location(): endpoints = [MockEndpoint("h1", 1, "dc1")] with pytest.raises(ValueError, match="max_per_location must be >= 1"): - nearest_dc.detect_local_dc(endpoints, max_per_location=0) + _utilities.detect_local_dc(endpoints, max_per_location=0) def test_detect_local_dc_validates_timeout(): endpoints = [MockEndpoint("h1", 1, "dc1")] with pytest.raises(ValueError, match="timeout must be > 0"): - nearest_dc.detect_local_dc(endpoints, timeout=0) + _utilities.detect_local_dc(endpoints, timeout=0) diff --git a/ydb/_utilities.py b/ydb/_utilities.py index 2e04ba12..df152987 100644 --- a/ydb/_utilities.py +++ b/ydb/_utilities.py @@ -1,4 +1,6 @@ # -*- coding: utf-8 -*- +import atexit +import concurrent.futures import importlib.util import threading import codecs @@ -6,11 +8,20 @@ import functools import hashlib import collections +import socket +import sys +import logging +import random +import time import urllib.parse +from typing import Dict, List, Optional, TYPE_CHECKING from . import ydb_version import typing +if TYPE_CHECKING: + from . import resolver + interceptor: typing.Any try: from . import interceptor @@ -208,3 +219,224 @@ def get_first_response(waiter): thread.start() return waiter.result(timeout=timeout) + + +# ============================================================================ +# Nearest DC detection utilities +# ============================================================================ + +logger = logging.getLogger(__name__) + +# Module-level thread pool for TCP race (reused across discovery cycles) +_TCP_RACE_MAX_WORKERS = 30 +_TCP_RACE_EXECUTOR: Optional[concurrent.futures.ThreadPoolExecutor] = None +_EXECUTOR_LOCK = threading.Lock() +_ATEXIT_REGISTERED = False + + +def _get_executor() -> concurrent.futures.ThreadPoolExecutor: + """ + Lazily create and return the thread pool executor. + + The executor is created on first use to avoid import-time side effects. + The atexit hook is registered only when the executor is actually created. + """ + global _TCP_RACE_EXECUTOR, _ATEXIT_REGISTERED + + if _TCP_RACE_EXECUTOR is None: + with _EXECUTOR_LOCK: + if _TCP_RACE_EXECUTOR is None: + _TCP_RACE_EXECUTOR = concurrent.futures.ThreadPoolExecutor( + max_workers=_TCP_RACE_MAX_WORKERS, + thread_name_prefix="ydb-tcp-race", + ) + + if not _ATEXIT_REGISTERED: + atexit.register(_shutdown_executor) + _ATEXIT_REGISTERED = True + + return _TCP_RACE_EXECUTOR + + +def _shutdown_executor(): + """Shutdown the executor if it was created.""" + if _TCP_RACE_EXECUTOR is not None: + if sys.version_info >= (3, 9): + _TCP_RACE_EXECUTOR.shutdown(wait=False, cancel_futures=True) + else: + _TCP_RACE_EXECUTOR.shutdown(wait=False) + + +def _check_fastest_endpoint( + endpoints: List["resolver.EndpointInfo"], timeout: float = 5.0 +) -> Optional["resolver.EndpointInfo"]: + """ + Perform TCP race using a bounded thread pool and return the fastest endpoint. + + Uses a module-level ThreadPoolExecutor to avoid creating new threads on every + discovery cycle. Returns immediately when the first endpoint connects successfully. + + If there are more endpoints than the thread pool size, takes one random endpoint + per location to ensure fair representation of all locations in the race. If there + are still too many locations, randomly samples them to stay within the limit. + + :param endpoints: List of resolver.EndpointInfo objects + :param timeout: Maximum time to wait for any connection (seconds) + :return: Fastest endpoint that connected successfully, or None if all failed + """ + if not endpoints: + return None + + if len(endpoints) > _TCP_RACE_MAX_WORKERS: + endpoints_by_location = _split_endpoints_by_location(endpoints) + endpoints = [random.choice(location_eps) for location_eps in endpoints_by_location.values()] + + if len(endpoints) > _TCP_RACE_MAX_WORKERS: + endpoints = random.sample(endpoints, _TCP_RACE_MAX_WORKERS) + + stop_event = threading.Event() + winner_lock = threading.Lock() + deadline = time.monotonic() + timeout + + def try_connect(endpoint: "resolver.EndpointInfo") -> Optional["resolver.EndpointInfo"]: + """Try to connect to endpoint and return it if successful.""" + remaining = deadline - time.monotonic() + if remaining <= 0 or stop_event.is_set(): + return None + + if endpoint.ipv6_addrs: + target_host = endpoint.ipv6_addrs[0] + elif endpoint.ipv4_addrs: + target_host = endpoint.ipv4_addrs[0] + else: + target_host = endpoint.address + + try: + sock = socket.create_connection((target_host, endpoint.port), timeout=remaining) + try: + with winner_lock: + if stop_event.is_set(): + return None + stop_event.set() + return endpoint + finally: + sock.close() + except (OSError, socket.timeout): + # Ignore expected connection errors; endpoints that fail simply lose the TCP race. + return None + except Exception as e: + logger.debug("Unexpected error connecting to %s: %s", endpoint.endpoint, e) + return None + + executor = _get_executor() + futures_list: List[concurrent.futures.Future] = [executor.submit(try_connect, ep) for ep in endpoints] + + try: + for fut in concurrent.futures.as_completed(futures_list, timeout=timeout): + result = fut.result() + if result is not None: + return result + except concurrent.futures.TimeoutError: + # Overall timeout expired + pass + finally: + for f in futures_list: + f.cancel() + + return None + + +def _split_endpoints_by_location(endpoints: List["resolver.EndpointInfo"]) -> Dict[str, List["resolver.EndpointInfo"]]: + """ + Group endpoints by their location. + + :param endpoints: List of resolver.EndpointInfo objects + :return: Dictionary mapping location -> list of resolver.EndpointInfo + """ + result: Dict[str, List["resolver.EndpointInfo"]] = {} + for endpoint in endpoints: + location = endpoint.location + if location not in result: + result[location] = [] + result[location].append(endpoint) + return result + + +def _get_random_endpoints(endpoints: List["resolver.EndpointInfo"], count: int) -> List["resolver.EndpointInfo"]: + """ + Get random sample of endpoints. + + :param endpoints: List of resolver.EndpointInfo objects + :param count: Maximum number of endpoints to return + :return: Random sample of resolver.EndpointInfo + """ + if len(endpoints) <= count: + return endpoints + return random.sample(endpoints, count) + + +def detect_local_dc( + endpoints: List["resolver.EndpointInfo"], max_per_location: int = 3, timeout: float = 5.0 +) -> Optional[str]: + """ + Detect nearest datacenter by performing TCP race between endpoints. + + This function groups endpoints by location, selects random samples from each location, + and performs parallel TCP connections to find the fastest one. The location of the + fastest endpoint is considered the nearest datacenter. + + Algorithm: + 1. Group endpoints by location + 2. If only one location exists, return it immediately + 3. Select up to max_per_location random endpoints from each location + 4. Perform TCP race: connect to all selected endpoints simultaneously + 5. Return the location of the first endpoint that connects successfully + 6. If all connections fail, return None + + :param endpoints: List of resolver.EndpointInfo objects from discovery + :param max_per_location: Maximum number of endpoints to test per location (default: 3, must be >= 1) + :param timeout: TCP connection timeout in seconds (default: 5.0, must be > 0) + :return: Location string of the nearest datacenter, or None if detection failed + :raises ValueError: If endpoints list is empty, max_per_location < 1, or timeout <= 0 + """ + if not endpoints: + raise ValueError("Empty endpoints list for local DC detection") + if max_per_location < 1: + raise ValueError(f"max_per_location must be >= 1, got {max_per_location}") + if timeout <= 0: + raise ValueError(f"timeout must be > 0, got {timeout}") + + endpoints_by_location = _split_endpoints_by_location(endpoints) + + logger.debug( + "Detecting local DC from %d endpoints across %d locations", + len(endpoints), + len(endpoints_by_location), + ) + + if len(endpoints_by_location) == 1: + location = list(endpoints_by_location.keys())[0] + logger.debug("Only one location found: %s", location) + return location + + endpoints_to_test = [] + for location, location_endpoints in endpoints_by_location.items(): + sample = _get_random_endpoints(location_endpoints, max_per_location) + endpoints_to_test.extend(sample) + logger.debug( + "Selected %d/%d endpoints from location '%s' for testing", + len(sample), + len(location_endpoints), + location, + ) + + fastest_endpoint = _check_fastest_endpoint(endpoints_to_test, timeout=timeout) + + if fastest_endpoint is None: + logger.debug("Failed to detect local DC via TCP race: no endpoint connected in time") + return None + + detected_location = fastest_endpoint.location + logger.debug("Detected local DC: %s", detected_location) + + return detected_location diff --git a/ydb/aio/_utilities.py b/ydb/aio/_utilities.py index 062545d8..32708dea 100644 --- a/ydb/aio/_utilities.py +++ b/ydb/aio/_utilities.py @@ -1,4 +1,13 @@ import asyncio +import logging +import random +import time +from typing import Dict, List, Optional + +from .. import resolver + + +logger = logging.getLogger(__name__) class AsyncResponseIterator(object): @@ -35,3 +44,178 @@ async def get_first_response(): return await stream.next() return await asyncio.wait_for(get_first_response(), timeout) + + +# ============================================================================ +# Nearest DC detection utilities +# ============================================================================ + + +async def _check_fastest_endpoint( + endpoints: List[resolver.EndpointInfo], timeout: float = 5.0 +) -> Optional[resolver.EndpointInfo]: + """ + Perform async TCP race: connect to all endpoints concurrently and return the fastest one. + + This function starts async TCP connections to all provided endpoints concurrently using + asyncio tasks and returns the first one that successfully connects. Other connection + attempts are cancelled once a winner is found. + + :param endpoints: List of resolver.EndpointInfo objects + :param timeout: Maximum time to wait for any connection (seconds) + :return: Fastest endpoint that connected successfully, or None if all failed or timeout + """ + if not endpoints: + return None + + deadline = time.monotonic() + timeout + + async def try_connect(endpoint): + remaining = deadline - time.monotonic() + if remaining <= 0: + return None + + if endpoint.ipv6_addrs: + target_host = endpoint.ipv6_addrs[0] + elif endpoint.ipv4_addrs: + target_host = endpoint.ipv4_addrs[0] + else: + target_host = endpoint.address + + try: + _, writer = await asyncio.wait_for( + asyncio.open_connection(target_host, endpoint.port), + timeout=remaining, + ) + writer.close() + await writer.wait_closed() + return endpoint + except (OSError, asyncio.TimeoutError): + return None + except Exception as e: + logger.debug("Unexpected error connecting to %s: %s", endpoint.endpoint, e) + return None + + tasks = [asyncio.create_task(try_connect(endpoint)) for endpoint in endpoints] + try: + for task in asyncio.as_completed(tasks, timeout=timeout): + endpoint = await task + if endpoint is not None: + return endpoint + return None + except asyncio.TimeoutError: + logger.debug("TCP race timeout after %.2fs, no endpoint connected in time", timeout) + return None + finally: + for t in tasks: + if not t.done(): + t.cancel() + + +def _split_endpoints_by_location( + endpoints: List[resolver.EndpointInfo], +) -> Dict[str, List[resolver.EndpointInfo]]: + """ + Group endpoints by their location. + + :param endpoints: List of resolver.EndpointInfo objects + :return: Dictionary mapping location -> list of resolver.EndpointInfo + """ + result: Dict[str, List[resolver.EndpointInfo]] = {} + for endpoint in endpoints: + location = endpoint.location + if location not in result: + result[location] = [] + result[location].append(endpoint) + return result + + +def _get_random_endpoints(endpoints: List[resolver.EndpointInfo], count: int) -> List[resolver.EndpointInfo]: + """ + Get random sample of endpoints. + + :param endpoints: List of resolver.EndpointInfo objects + :param count: Maximum number of endpoints to return + :return: Random sample of resolver.EndpointInfo + """ + if len(endpoints) <= count: + return endpoints + return random.sample(endpoints, count) + + +async def detect_local_dc( + endpoints: List[resolver.EndpointInfo], max_per_location: int = 3, timeout: float = 5.0 +) -> Optional[str]: + """ + Detect nearest datacenter by performing async TCP race between endpoints. + + This function groups endpoints by location, selects random samples from each location, + and performs parallel TCP connections to find the fastest one. The location of the + fastest endpoint is considered the nearest datacenter. + + Algorithm: + 1. Group endpoints by location + 2. If only one location exists, return it immediately + 3. Select up to max_per_location random endpoints from each location + 4. If too many endpoints, reduce to one per location and cap at limit + 5. Perform TCP race: connect to all selected endpoints simultaneously + 6. Return the location of the first endpoint that connects successfully + 7. If all connections fail, return None + + :param endpoints: List of resolver.EndpointInfo objects from discovery + :param max_per_location: Maximum number of endpoints to test per location (default: 3, must be >= 1) + :param timeout: TCP connection timeout in seconds (default: 5.0, must be > 0) + :return: Location string of the nearest datacenter, or None if detection failed + :raises ValueError: If endpoints list is empty, max_per_location < 1, or timeout <= 0 + """ + if not endpoints: + raise ValueError("Empty endpoints list for local DC detection") + if max_per_location < 1: + raise ValueError(f"max_per_location must be >= 1, got {max_per_location}") + if timeout <= 0: + raise ValueError(f"timeout must be > 0, got {timeout}") + + endpoints_by_location = _split_endpoints_by_location(endpoints) + + logger.debug( + "Detecting local DC from %d endpoints across %d locations", + len(endpoints), + len(endpoints_by_location), + ) + + if len(endpoints_by_location) == 1: + location = list(endpoints_by_location.keys())[0] + logger.debug("Only one location found: %s", location) + return location + + _MAX_CONCURRENT_TASKS = 30 + + endpoints_to_test = [] + for location, location_endpoints in endpoints_by_location.items(): + sample = _get_random_endpoints(location_endpoints, max_per_location) + endpoints_to_test.extend(sample) + logger.debug( + "Selected %d/%d endpoints from location '%s' for testing", + len(sample), + len(location_endpoints), + location, + ) + + if len(endpoints_to_test) > _MAX_CONCURRENT_TASKS: + endpoints_to_test = [random.choice(location_eps) for location_eps in endpoints_by_location.values()] + + if len(endpoints_to_test) > _MAX_CONCURRENT_TASKS: + endpoints_to_test = random.sample(endpoints_to_test, _MAX_CONCURRENT_TASKS) + + logger.debug("Capped endpoints to %d to limit concurrent tasks", len(endpoints_to_test)) + + fastest_endpoint = await _check_fastest_endpoint(endpoints_to_test, timeout=timeout) + + if fastest_endpoint is None: + logger.debug("Failed to detect local DC via TCP race: no endpoint connected in time") + return None + + detected_location = fastest_endpoint.location + logger.debug("Detected local DC: %s", detected_location) + + return detected_location diff --git a/ydb/aio/nearest_dc.py b/ydb/aio/nearest_dc.py deleted file mode 100644 index c17e656d..00000000 --- a/ydb/aio/nearest_dc.py +++ /dev/null @@ -1,181 +0,0 @@ -# -*- coding: utf-8 -*- -import asyncio -import logging -import random -import time -from typing import Dict, List, Optional - -from .. import resolver - - -logger = logging.getLogger(__name__) - - -async def _check_fastest_endpoint( - endpoints: List[resolver.EndpointInfo], timeout: float = 5.0 -) -> Optional[resolver.EndpointInfo]: - """ - Perform async TCP race: connect to all endpoints concurrently and return the fastest one. - - This function starts async TCP connections to all provided endpoints concurrently using - asyncio tasks and returns the first one that successfully connects. Other connection - attempts are cancelled once a winner is found. - - :param endpoints: List of resolver.EndpointInfo objects - :param timeout: Maximum time to wait for any connection (seconds) - :return: Fastest endpoint that connected successfully, or None if all failed or timeout - """ - if not endpoints: - return None - - deadline = time.monotonic() + timeout - - async def try_connect(endpoint): - remaining = deadline - time.monotonic() - if remaining <= 0: - return None - - if endpoint.ipv6_addrs: - target_host = endpoint.ipv6_addrs[0] - elif endpoint.ipv4_addrs: - target_host = endpoint.ipv4_addrs[0] - else: - target_host = endpoint.address - - try: - _, writer = await asyncio.wait_for( - asyncio.open_connection(target_host, endpoint.port), - timeout=remaining, - ) - writer.close() - await writer.wait_closed() - return endpoint - except (OSError, asyncio.TimeoutError): - return None - except Exception as e: - logger.debug("Unexpected error connecting to %s: %s", endpoint.endpoint, e) - return None - - tasks = [asyncio.create_task(try_connect(endpoint)) for endpoint in endpoints] - try: - for task in asyncio.as_completed(tasks, timeout=timeout): - endpoint = await task - if endpoint is not None: - return endpoint - return None - except asyncio.TimeoutError: - logger.debug("TCP race timeout after %.2fs, no endpoint connected in time", timeout) - return None - finally: - for t in tasks: - if not t.done(): - t.cancel() - - -def _split_endpoints_by_location( - endpoints: List[resolver.EndpointInfo], -) -> Dict[str, List[resolver.EndpointInfo]]: - """ - Group endpoints by their location. - - :param endpoints: List of resolver.EndpointInfo objects - :return: Dictionary mapping location -> list of resolver.EndpointInfo - """ - result: Dict[str, List[resolver.EndpointInfo]] = {} - for endpoint in endpoints: - location = endpoint.location - if location not in result: - result[location] = [] - result[location].append(endpoint) - return result - - -def _get_random_endpoints(endpoints: List[resolver.EndpointInfo], count: int) -> List[resolver.EndpointInfo]: - """ - Get random sample of endpoints. - - :param endpoints: List of resolver.EndpointInfo objects - :param count: Maximum number of endpoints to return - :return: Random sample of resolver.EndpointInfo - """ - if len(endpoints) <= count: - return endpoints - return random.sample(endpoints, count) - - -async def detect_local_dc( - endpoints: List[resolver.EndpointInfo], max_per_location: int = 3, timeout: float = 5.0 -) -> Optional[str]: - """ - Detect nearest datacenter by performing async TCP race between endpoints. - - This function groups endpoints by location, selects random samples from each location, - and performs parallel TCP connections to find the fastest one. The location of the - fastest endpoint is considered the nearest datacenter. - - Algorithm: - 1. Group endpoints by location - 2. If only one location exists, return it immediately - 3. Select up to max_per_location random endpoints from each location - 4. If too many endpoints, reduce to one per location and cap at limit - 5. Perform TCP race: connect to all selected endpoints simultaneously - 6. Return the location of the first endpoint that connects successfully - 7. If all connections fail, return None - - :param endpoints: List of resolver.EndpointInfo objects from discovery - :param max_per_location: Maximum number of endpoints to test per location (default: 3, must be >= 1) - :param timeout: TCP connection timeout in seconds (default: 5.0, must be > 0) - :return: Location string of the nearest datacenter, or None if detection failed - :raises ValueError: If endpoints list is empty, max_per_location < 1, or timeout <= 0 - """ - if not endpoints: - raise ValueError("Empty endpoints list for local DC detection") - if max_per_location < 1: - raise ValueError(f"max_per_location must be >= 1, got {max_per_location}") - if timeout <= 0: - raise ValueError(f"timeout must be > 0, got {timeout}") - - endpoints_by_location = _split_endpoints_by_location(endpoints) - - logger.debug( - "Detecting local DC from %d endpoints across %d locations", - len(endpoints), - len(endpoints_by_location), - ) - - if len(endpoints_by_location) == 1: - location = list(endpoints_by_location.keys())[0] - logger.debug("Only one location found: %s", location) - return location - - _MAX_CONCURRENT_TASKS = 30 - - endpoints_to_test = [] - for location, location_endpoints in endpoints_by_location.items(): - sample = _get_random_endpoints(location_endpoints, max_per_location) - endpoints_to_test.extend(sample) - logger.debug( - "Selected %d/%d endpoints from location '%s' for testing", - len(sample), - len(location_endpoints), - location, - ) - - if len(endpoints_to_test) > _MAX_CONCURRENT_TASKS: - endpoints_to_test = [random.choice(location_eps) for location_eps in endpoints_by_location.values()] - - if len(endpoints_to_test) > _MAX_CONCURRENT_TASKS: - endpoints_to_test = random.sample(endpoints_to_test, _MAX_CONCURRENT_TASKS) - - logger.debug("Capped endpoints to %d to limit concurrent tasks", len(endpoints_to_test)) - - fastest_endpoint = await _check_fastest_endpoint(endpoints_to_test, timeout=timeout) - - if fastest_endpoint is None: - logger.debug("Failed to detect local DC via TCP race: no endpoint connected in time") - return None - - detected_location = fastest_endpoint.location - logger.debug("Detected local DC: %s", detected_location) - - return detected_location diff --git a/ydb/aio/pool.py b/ydb/aio/pool.py index 15a00051..fe709133 100644 --- a/ydb/aio/pool.py +++ b/ydb/aio/pool.py @@ -10,7 +10,7 @@ from .connection import Connection, EndpointKey -from . import nearest_dc, resolver +from . import resolver, _utilities if TYPE_CHECKING: from ydb.driver import DriverConfig @@ -158,7 +158,7 @@ async def execute_discovery(self) -> bool: if ssl_filtered_endpoints: try: - detected_location = await nearest_dc.detect_local_dc( + detected_location = await _utilities.detect_local_dc( ssl_filtered_endpoints, max_per_location=3, timeout=self._ready_timeout ) if detected_location: diff --git a/ydb/nearest_dc.py b/ydb/nearest_dc.py deleted file mode 100644 index 6b82adeb..00000000 --- a/ydb/nearest_dc.py +++ /dev/null @@ -1,229 +0,0 @@ -# -*- coding: utf-8 -*- -import atexit -import concurrent.futures -import socket -import sys -import threading -import logging -import random -import time -from typing import Dict, List, Optional - -from . import resolver - - -logger = logging.getLogger(__name__) - -# Module-level thread pool for TCP race (reused across discovery cycles) -_TCP_RACE_MAX_WORKERS = 30 -_TCP_RACE_EXECUTOR: Optional[concurrent.futures.ThreadPoolExecutor] = None -_EXECUTOR_LOCK = threading.Lock() -_ATEXIT_REGISTERED = False - - -def _get_executor() -> concurrent.futures.ThreadPoolExecutor: - """ - Lazily create and return the thread pool executor. - - The executor is created on first use to avoid import-time side effects. - The atexit hook is registered only when the executor is actually created. - """ - global _TCP_RACE_EXECUTOR, _ATEXIT_REGISTERED - - if _TCP_RACE_EXECUTOR is None: - with _EXECUTOR_LOCK: - if _TCP_RACE_EXECUTOR is None: - _TCP_RACE_EXECUTOR = concurrent.futures.ThreadPoolExecutor( - max_workers=_TCP_RACE_MAX_WORKERS, - thread_name_prefix="ydb-tcp-race", - ) - - if not _ATEXIT_REGISTERED: - atexit.register(_shutdown_executor) - _ATEXIT_REGISTERED = True - - return _TCP_RACE_EXECUTOR - - -def _shutdown_executor(): - """Shutdown the executor if it was created.""" - if _TCP_RACE_EXECUTOR is not None: - if sys.version_info >= (3, 9): - _TCP_RACE_EXECUTOR.shutdown(wait=False, cancel_futures=True) - else: - _TCP_RACE_EXECUTOR.shutdown(wait=False) - - -def _check_fastest_endpoint( - endpoints: List[resolver.EndpointInfo], timeout: float = 5.0 -) -> Optional[resolver.EndpointInfo]: - """ - Perform TCP race using a bounded thread pool and return the fastest endpoint. - - Uses a module-level ThreadPoolExecutor to avoid creating new threads on every - discovery cycle. Returns immediately when the first endpoint connects successfully. - - If there are more endpoints than the thread pool size, takes one random endpoint - per location to ensure fair representation of all locations in the race. If there - are still too many locations, randomly samples them to stay within the limit. - - :param endpoints: List of resolver.EndpointInfo objects - :param timeout: Maximum time to wait for any connection (seconds) - :return: Fastest endpoint that connected successfully, or None if all failed - """ - if not endpoints: - return None - - if len(endpoints) > _TCP_RACE_MAX_WORKERS: - endpoints_by_location = _split_endpoints_by_location(endpoints) - endpoints = [random.choice(location_eps) for location_eps in endpoints_by_location.values()] - - if len(endpoints) > _TCP_RACE_MAX_WORKERS: - endpoints = random.sample(endpoints, _TCP_RACE_MAX_WORKERS) - - stop_event = threading.Event() - winner_lock = threading.Lock() - deadline = time.monotonic() + timeout - - def try_connect(endpoint: resolver.EndpointInfo) -> Optional[resolver.EndpointInfo]: - """Try to connect to endpoint and return it if successful.""" - remaining = deadline - time.monotonic() - if remaining <= 0 or stop_event.is_set(): - return None - - if endpoint.ipv6_addrs: - target_host = endpoint.ipv6_addrs[0] - elif endpoint.ipv4_addrs: - target_host = endpoint.ipv4_addrs[0] - else: - target_host = endpoint.address - - try: - sock = socket.create_connection((target_host, endpoint.port), timeout=remaining) - try: - with winner_lock: - if stop_event.is_set(): - return None - stop_event.set() - return endpoint - finally: - sock.close() - except (OSError, socket.timeout): - # Ignore expected connection errors; endpoints that fail simply lose the TCP race. - return None - except Exception as e: - logger.debug("Unexpected error connecting to %s: %s", endpoint.endpoint, e) - return None - - executor = _get_executor() - futures: List[concurrent.futures.Future] = [executor.submit(try_connect, ep) for ep in endpoints] - - try: - for fut in concurrent.futures.as_completed(futures, timeout=timeout): - result = fut.result() - if result is not None: - return result - except concurrent.futures.TimeoutError: - # Overall timeout expired - pass - finally: - for f in futures: - f.cancel() - - return None - - -def _split_endpoints_by_location(endpoints: List[resolver.EndpointInfo]) -> Dict[str, List[resolver.EndpointInfo]]: - """ - Group endpoints by their location. - - :param endpoints: List of resolver.EndpointInfo objects - :return: Dictionary mapping location -> list of resolver.EndpointInfo - """ - result: Dict[str, List[resolver.EndpointInfo]] = {} - for endpoint in endpoints: - location = endpoint.location - if location not in result: - result[location] = [] - result[location].append(endpoint) - return result - - -def _get_random_endpoints(endpoints: List[resolver.EndpointInfo], count: int) -> List[resolver.EndpointInfo]: - """ - Get random sample of endpoints. - - :param endpoints: List of resolver.EndpointInfo objects - :param count: Maximum number of endpoints to return - :return: Random sample of resolver.EndpointInfo - """ - if len(endpoints) <= count: - return endpoints - return random.sample(endpoints, count) - - -def detect_local_dc( - endpoints: List[resolver.EndpointInfo], max_per_location: int = 3, timeout: float = 5.0 -) -> Optional[str]: - """ - Detect nearest datacenter by performing TCP race between endpoints. - - This function groups endpoints by location, selects random samples from each location, - and performs parallel TCP connections to find the fastest one. The location of the - fastest endpoint is considered the nearest datacenter. - - Algorithm: - 1. Group endpoints by location - 2. If only one location exists, return it immediately - 3. Select up to max_per_location random endpoints from each location - 4. Perform TCP race: connect to all selected endpoints simultaneously - 5. Return the location of the first endpoint that connects successfully - 6. If all connections fail, return None - - :param endpoints: List of resolver.EndpointInfo objects from discovery - :param max_per_location: Maximum number of endpoints to test per location (default: 3, must be >= 1) - :param timeout: TCP connection timeout in seconds (default: 5.0, must be > 0) - :return: Location string of the nearest datacenter, or None if detection failed - :raises ValueError: If endpoints list is empty, max_per_location < 1, or timeout <= 0 - """ - if not endpoints: - raise ValueError("Empty endpoints list for local DC detection") - if max_per_location < 1: - raise ValueError(f"max_per_location must be >= 1, got {max_per_location}") - if timeout <= 0: - raise ValueError(f"timeout must be > 0, got {timeout}") - - endpoints_by_location = _split_endpoints_by_location(endpoints) - - logger.debug( - "Detecting local DC from %d endpoints across %d locations", - len(endpoints), - len(endpoints_by_location), - ) - - if len(endpoints_by_location) == 1: - location = list(endpoints_by_location.keys())[0] - logger.debug("Only one location found: %s", location) - return location - - endpoints_to_test = [] - for location, location_endpoints in endpoints_by_location.items(): - sample = _get_random_endpoints(location_endpoints, max_per_location) - endpoints_to_test.extend(sample) - logger.debug( - "Selected %d/%d endpoints from location '%s' for testing", - len(sample), - len(location_endpoints), - location, - ) - - fastest_endpoint = _check_fastest_endpoint(endpoints_to_test, timeout=timeout) - - if fastest_endpoint is None: - logger.debug("Failed to detect local DC via TCP race: no endpoint connected in time") - return None - - detected_location = fastest_endpoint.location - logger.debug("Detected local DC: %s", detected_location) - - return detected_location diff --git a/ydb/pool.py b/ydb/pool.py index a45a931c..2901c573 100644 --- a/ydb/pool.py +++ b/ydb/pool.py @@ -9,7 +9,7 @@ import random from typing import Any, Callable, ContextManager, List, Optional, Set, Tuple, TYPE_CHECKING -from . import connection as connection_impl, issues, nearest_dc, resolver, _utilities, tracing +from . import connection as connection_impl, issues, resolver, _utilities, tracing from abc import abstractmethod from .connection import Connection, EndpointKey @@ -245,7 +245,7 @@ def execute_discovery(self) -> bool: if ssl_filtered_endpoints: try: - detected_location = nearest_dc.detect_local_dc( + detected_location = _utilities.detect_local_dc( ssl_filtered_endpoints, max_per_location=3, timeout=self._ready_timeout ) if detected_location: From 5265ba139426f1fb2f93553d460cd041f45e9372 Mon Sep 17 00:00:00 2001 From: Idris Yandarov <32651311+r142f@users.noreply.github.com> Date: Mon, 30 Mar 2026 13:38:59 +0300 Subject: [PATCH 14/19] fix: fixing flaw --- ydb/aio/_utilities.py | 1 + 1 file changed, 1 insertion(+) diff --git a/ydb/aio/_utilities.py b/ydb/aio/_utilities.py index 32708dea..e3e6699e 100644 --- a/ydb/aio/_utilities.py +++ b/ydb/aio/_utilities.py @@ -110,6 +110,7 @@ async def try_connect(endpoint): for t in tasks: if not t.done(): t.cancel() + await asyncio.gather(*tasks, return_exceptions=True) def _split_endpoints_by_location( From e182284f089dfefaa2c42e874a067a6f37072f8a Mon Sep 17 00:00:00 2001 From: Idris Yandarov <32651311+r142f@users.noreply.github.com> Date: Mon, 30 Mar 2026 22:14:52 +0300 Subject: [PATCH 15/19] fix: commit to rerun checks --- ydb/aio/_utilities.py | 1 + 1 file changed, 1 insertion(+) diff --git a/ydb/aio/_utilities.py b/ydb/aio/_utilities.py index e3e6699e..3d6d25b3 100644 --- a/ydb/aio/_utilities.py +++ b/ydb/aio/_utilities.py @@ -110,6 +110,7 @@ async def try_connect(endpoint): for t in tasks: if not t.done(): t.cancel() + await asyncio.gather(*tasks, return_exceptions=True) From a00a511ceb2d3457637a8f657c6152f98bad7f89 Mon Sep 17 00:00:00 2001 From: Idris Yandarov <32651311+r142f@users.noreply.github.com> Date: Mon, 30 Mar 2026 22:33:58 +0300 Subject: [PATCH 16/19] fix: commit to rerun checks --- ydb/aio/_utilities.py | 1 - 1 file changed, 1 deletion(-) diff --git a/ydb/aio/_utilities.py b/ydb/aio/_utilities.py index 3d6d25b3..e3e6699e 100644 --- a/ydb/aio/_utilities.py +++ b/ydb/aio/_utilities.py @@ -110,7 +110,6 @@ async def try_connect(endpoint): for t in tasks: if not t.done(): t.cancel() - await asyncio.gather(*tasks, return_exceptions=True) From db49c4dfd6ddad455c3e295a377ee6b6538fca9e Mon Sep 17 00:00:00 2001 From: Idris Yandarov <32651311+r142f@users.noreply.github.com> Date: Mon, 30 Mar 2026 22:46:07 +0300 Subject: [PATCH 17/19] fix: commit to rerun checks --- ydb/aio/_utilities.py | 1 + 1 file changed, 1 insertion(+) diff --git a/ydb/aio/_utilities.py b/ydb/aio/_utilities.py index e3e6699e..3d6d25b3 100644 --- a/ydb/aio/_utilities.py +++ b/ydb/aio/_utilities.py @@ -110,6 +110,7 @@ async def try_connect(endpoint): for t in tasks: if not t.done(): t.cancel() + await asyncio.gather(*tasks, return_exceptions=True) From 9d73ff8dc008587134b9b62aa8af67e8f3c98fa0 Mon Sep 17 00:00:00 2001 From: Idris Yandarov <32651311+r142f@users.noreply.github.com> Date: Mon, 30 Mar 2026 23:06:37 +0300 Subject: [PATCH 18/19] fix: commit to rerun checks --- ydb/aio/_utilities.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/ydb/aio/_utilities.py b/ydb/aio/_utilities.py index 3d6d25b3..32708dea 100644 --- a/ydb/aio/_utilities.py +++ b/ydb/aio/_utilities.py @@ -111,8 +111,6 @@ async def try_connect(endpoint): if not t.done(): t.cancel() - await asyncio.gather(*tasks, return_exceptions=True) - def _split_endpoints_by_location( endpoints: List[resolver.EndpointInfo], From f7feaf386d114dbc5cd4f3a7ef8c94d8fd58678f Mon Sep 17 00:00:00 2001 From: Idris Yandarov <32651311+r142f@users.noreply.github.com> Date: Mon, 30 Mar 2026 23:15:04 +0300 Subject: [PATCH 19/19] fix: commit to rerun checks --- ydb/aio/_utilities.py | 1 + 1 file changed, 1 insertion(+) diff --git a/ydb/aio/_utilities.py b/ydb/aio/_utilities.py index 32708dea..e3e6699e 100644 --- a/ydb/aio/_utilities.py +++ b/ydb/aio/_utilities.py @@ -110,6 +110,7 @@ async def try_connect(endpoint): for t in tasks: if not t.done(): t.cancel() + await asyncio.gather(*tasks, return_exceptions=True) def _split_endpoints_by_location(