diff --git a/sdk/cosmos/azure-cosmos/CHANGELOG.md b/sdk/cosmos/azure-cosmos/CHANGELOG.md index 6680ead07f26..343370c743b3 100644 --- a/sdk/cosmos/azure-cosmos/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos/CHANGELOG.md @@ -1,5 +1,10 @@ ## Release History +### 4.16.1 (2026-05-31) + +#### Bugs Fixed +* Fixed a bug in the sync and async `/pkranges` change-feed refresh where some containers could fail to build a complete routing map. See [PR 47245](https://github.com/Azure/azure-sdk-for-python/pull/47245). + ### 4.16.0 (2026-05-29) #### Features Added @@ -8,6 +13,7 @@ #### Breaking Changes * `CosmosItemPaged.get_response_headers()` and `CosmosAsyncItemPaged.get_response_headers()` now return a single `CaseInsensitiveDict` (the latest page) instead of `List[CaseInsensitiveDict]` (introduced in 4.16.0b1); `get_last_response_headers()` has been removed. This avoids unbounded memory growth on large queries. **Migration:** code that previously accessed `headers[i]['x-ms-request-charge']` should switch to `headers['x-ms-request-charge']` for the latest page, or pass `response_hook=` to the query method to receive per-page headers as they arrive. See [PR 47172](https://github.com/Azure/azure-sdk-for-python/pull/47172). +* `SELECT VALUE AVG(...)` queries spanning multiple physical partitions now raise `ValueError` instead of returning a mathematically incorrect merged value from client-side aggregation. **Migration:** rewrite cross-partition `AVG` queries as `SUM(...) / COUNT(...)` (both of which merge correctly across partitions), or scope the query to a single partition via `partition_key=`. See [PR 47105](https://github.com/Azure/azure-sdk-for-python/pull/47105). #### Bugs Fixed * Fixed bug where the `Content-Length` HTTP request header was computed from the character count of the request body instead of its UTF-8 byte count. See [PR 47008](https://github.com/Azure/azure-sdk-for-python/pull/47008) @@ -15,7 +21,6 @@ * Fixed bug where `CosmosClient` construction with AAD credentials would crash at startup if the semantic reranking inference endpoint environment variable was not set, even when semantic reranking was not being used. The inference service is now lazily initialized on first use. See [PR 46243](https://github.com/Azure/azure-sdk-for-python/pull/46243) * Fixed bug where region names in `preferred_locations` and `excluded_locations` (client-level and per-request) were not matched tolerantly for differences in case, whitespace, hyphens, and underscores. See [PR 46937](https://github.com/Azure/azure-sdk-for-python/pull/46937) * Fixed a bug in `query_items(feed_range=...)` where pagination could return incorrect results after a partition split caused the supplied feed range to overlap multiple physical partitions. See [PR 47105](https://github.com/Azure/azure-sdk-for-python/pull/47105) -* Fixed bug where `SELECT VALUE AVG(...)` queries spanning multiple physical partitions returned mathematically incorrect merged values from client-side aggregation. These queries now raise `ValueError`. See [PR 47105](https://github.com/Azure/azure-sdk-for-python/pull/47105) * Fixed bug where a `ValueError("Ranges overlap")` or an `AssertionError("code bug: returned overlapping ranges ... is empty")` from the partition key range cache could escape to the caller when the `/pkranges` response contained a transiently inconsistent snapshot (overlap or gap). See [PR 47091](https://github.com/Azure/azure-sdk-for-python/pull/47091) #### Other Changes diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/_routing_map_provider_common.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/_routing_map_provider_common.py index 1bc0286fe7b6..bb766097489d 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/_routing_map_provider_common.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/_routing_map_provider_common.py @@ -171,6 +171,7 @@ def _handle_transient_snapshot_retry_decision( ) raise CosmosHttpResponseError( status_code=http_constants.StatusCodes.SERVICE_UNAVAILABLE, + sub_status=http_constants.SubStatusCodes.ROUTING_MAP_SNAPSHOT_INCONSISTENT, message=( "Routing-map fetch for collection '{}' returned overlapping " "or gapped ranges on {} attempt(s)." @@ -295,6 +296,86 @@ def _resolve_endpoint(client: Any) -> str: return f"__unknown_{id(client)}__" + + +# --------------------------------------------------------------------------- +# /pkranges change-feed drain helpers (shared by sync + async providers) +# --------------------------------------------------------------------------- +# +# These helpers hoist the *pure decision logic* of the routing-map change-feed +# drain out of the sync and async providers so a future bug-fix lands in one +# place. The providers still own the I/O-shaped parts that genuinely differ: +# - sync uses ``ranges.extend(list(generator))`` +# - async uses ``async for item in generator: ...`` +# Everything else (per-page state transitions) lives here. + + +class _DrainPageDecision: + """Outcome of evaluating a single /pkranges drain page.""" + + CONTINUE = "continue" + STOP_DRAINED = "stop_drained" + + +def evaluate_drain_page( + *, + page_new_etag: Optional[str], + current_if_none_match: Optional[str], + new_etag: Optional[str], + seen_any_etag: bool, + status_code: Optional[int], +) -> Tuple[str, Optional[str], Optional[str], bool]: + """Decide whether to keep draining the /pkranges change feed. + + Pure function: no I/O. The sole termination signal is literal HTTP + ``304 Not Modified`` (matching Java, .NET v3, and Go). ``status_code`` + is required: production callers wire it via the + ``_internal_response_status_capture`` sidecar populated by + ``_synchronized_request`` / ``_asynchronous_request`` before any + return, so it is always a concrete int by the time we land here. + There is intentionally no secondary safety net (e.g. a page cap) + here -- peer SDKs (.NET v3, Java, Go) all rely solely on the 304 + termination predicate and we mirror that contract. + + :keyword page_new_etag: ETag header from the current page response, if any. + :paramtype page_new_etag: str or None + :keyword current_if_none_match: The ``If-None-Match`` we sent for this page. + :paramtype current_if_none_match: str or None + :keyword new_etag: Running accumulator for the final etag to publish. + :paramtype new_etag: str or None + :keyword bool seen_any_etag: Whether the service has ever surfaced an ETag + across the drain so far. + :keyword status_code: HTTP status code of the page response. Required at runtime; + ``None`` indicates the response-status sidecar was not wired by the caller and + raises ``RuntimeError``. Typed as ``Optional[int]`` so callers that read the + status from a sidecar list typed as ``List[Optional[int]]`` (whose first slot + is ``None`` until populated by ``_synchronized_request`` / + ``_asynchronous_request``) satisfy mypy without an extra cast. + :paramtype status_code: int or None + + :returns: ``(decision, new_etag, next_if_none_match, seen_any_etag)``. + ``next_if_none_match`` is only meaningful when ``decision == CONTINUE``. + :rtype: tuple + """ + if status_code is None: + raise RuntimeError( + "evaluate_drain_page invoked with status_code=None. The /pkranges " + "drain loop requires the _internal_response_status_capture sidecar " + "to be wired by the caller; this indicates a programming error in " + "the routing-map provider." + ) + + if page_new_etag: + seen_any_etag = True + new_etag = page_new_etag + + if status_code == http_constants.StatusCodes.NOT_MODIFIED: + return (_DrainPageDecision.STOP_DRAINED, new_etag, current_if_none_match, seen_any_etag) + + next_inm = page_new_etag if page_new_etag else current_if_none_match + return (_DrainPageDecision.CONTINUE, new_etag, next_inm, seen_any_etag) + + class _IncrementalMergeFailed(Exception): """Private exception type raised by :func:`process_fetched_ranges` when the incremental update cannot resolve all partition key ranges. diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/aio/routing_map_provider.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/aio/routing_map_provider.py index f25d21cbdf1e..668adbae90d2 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/aio/routing_map_provider.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/aio/routing_map_provider.py @@ -41,6 +41,8 @@ _OverlapDetected, _GapDetected, _handle_transient_snapshot_retry_decision, + _DrainPageDecision, + evaluate_drain_page, ) @@ -334,6 +336,7 @@ async def get_routing_map( return self._collection_routing_map_by_item.get(collection_id) + # pylint: disable=too-many-statements,too-many-locals async def _fetch_routing_map( self, collection_link: str, @@ -377,35 +380,84 @@ async def _fetch_routing_map( inconsistency_attempt_count = 0 while True: - request_kwargs = dict(kwargs) - response_headers: CaseInsensitiveDict = CaseInsensitiveDict() - request_kwargs['_internal_response_headers_capture'] = response_headers - - # Prepare sanitised options and headers for the PK-range fetch. + ranges: List[Dict[str, Any]] = [] + # Start the change-feed drain at the previous map's etag (if any). + # On subsequent drain pages we advance this with the etag returned + # for the previous page so the service returns "what's new since X" + # until it eventually responds with 304 / no new ranges, mirroring + # the .NET and Go SDK behaviour. + current_if_none_match = ( + current_previous_map.change_feed_etag if current_previous_map else None + ) + new_etag = current_if_none_match + # Track whether the service ever surfaced an ETag header during this + # drain attempt. If it never did, we want ``process_fetched_ranges`` + # to surface the "no ETag" observability warning rather than + # silently treating ``current_if_none_match`` as the fresh etag. + seen_any_etag = False + + # Hoist: ``prepare_fetch_options_and_headers`` is loop-invariant + # for this drain attempt -- ``change_feed_options`` depends only on + # ``feed_options`` and the headers it builds depend only on + # ``current_previous_map.change_feed_etag``, neither of which + # change inside the inner drain loop. Compute them once here; the + # only per-page mutation is the ``If-None-Match`` override below. + base_kwargs_for_headers: Dict[str, Any] = dict(kwargs) change_feed_options = prepare_fetch_options_and_headers( - current_previous_map, feed_options, request_kwargs + current_previous_map, feed_options, base_kwargs_for_headers ) + base_headers: Dict[str, Any] = base_kwargs_for_headers['headers'] - ranges: List[Dict[str, Any]] = [] - try: - pk_range_generator = self._document_client._ReadPartitionKeyRanges( - collection_link, - change_feed_options, - **request_kwargs - ) - async for item in pk_range_generator: - ranges.append(item) - - except CosmosHttpResponseError as e: - logger.error( # pylint: disable=do-not-log-exceptions-if-not-debug,do-not-log-raised-errors - "Failed to read partition key ranges for collection '%s': %s", collection_link, e) - raise + while True: + request_kwargs = dict(kwargs) + # Shallow-copy ``base_headers`` so the per-iter + # ``If-None-Match`` override does not bleed across iterations. + request_kwargs['headers'] = dict(base_headers) + response_headers: CaseInsensitiveDict = CaseInsensitiveDict() + request_kwargs['_internal_response_headers_capture'] = response_headers + # Sidecar list -- populated by _Request with the raw wire + # status. Lets us terminate on literal 304 (matching peer + # SDKs) instead of inferring it from an empty page. + status_capture: List[Optional[int]] = [None] + request_kwargs['_internal_response_status_capture'] = status_capture + + # Override If-None-Match with the running etag from the drain + # so each page advances. ``prepare_fetch_options_and_headers`` + # only sets it from ``current_previous_map.change_feed_etag`` + # which never advances during this drain. + drain_headers = request_kwargs['headers'] + if current_if_none_match: + drain_headers[http_constants.HttpHeaders.IfNoneMatch] = current_if_none_match + else: + drain_headers.pop(http_constants.HttpHeaders.IfNoneMatch, None) - new_etag = response_headers.get(http_constants.HttpHeaders.ETag) + try: + pk_range_generator = self._document_client._ReadPartitionKeyRanges( + collection_link, + change_feed_options, + **request_kwargs + ) + ranges.extend([item async for item in pk_range_generator]) + except CosmosHttpResponseError as e: + logger.error( # pylint: disable=do-not-log-exceptions-if-not-debug,do-not-log-raised-errors + "Failed to read partition key ranges for collection '%s': %s", + collection_link, e) + raise + + decision, new_etag, current_if_none_match, seen_any_etag = evaluate_drain_page( + page_new_etag=response_headers.get(http_constants.HttpHeaders.ETag), + current_if_none_match=current_if_none_match, + new_etag=new_etag, + seen_any_etag=seen_any_etag, + status_code=status_capture[0], + ) + if decision == _DrainPageDecision.STOP_DRAINED: + break try: + effective_new_etag = new_etag if seen_any_etag else None return process_fetched_ranges( - ranges, current_previous_map, collection_id, collection_link, new_etag + ranges, current_previous_map, collection_id, collection_link, effective_new_etag ) except _IncrementalMergeFailed: if current_previous_map is not None and incomplete_attempt_count < _INCOMPLETE_ROUTING_MAP_MAX_RETRIES: diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/routing_map_provider.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/routing_map_provider.py index 297bbdec5504..c2cca7bb2ec1 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/routing_map_provider.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/routing_map_provider.py @@ -41,6 +41,8 @@ _OverlapDetected, _GapDetected, _handle_transient_snapshot_retry_decision, + _DrainPageDecision, + evaluate_drain_page, ) if TYPE_CHECKING: @@ -300,7 +302,7 @@ def get_routing_map( return self._collection_routing_map_by_item.get(collection_id) - # pylint: disable=too-many-statements + # pylint: disable=too-many-statements,too-many-locals def _fetch_routing_map( self, collection_link: str, @@ -345,34 +347,87 @@ def _fetch_routing_map( inconsistency_attempt_count = 0 while True: - request_kwargs = dict(kwargs) - response_headers: CaseInsensitiveDict = CaseInsensitiveDict() - request_kwargs['_internal_response_headers_capture'] = response_headers - - # Prepare sanitised options and headers for the PK-range fetch. + ranges: List[Dict[str, Any]] = [] + # Start the change-feed drain at the previous map's etag (if any). + # On subsequent drain pages we advance this with the etag returned + # for the previous page so the service returns "what's new since X" + # until it eventually responds with 304 / no new ranges, mirroring + # the .NET and Go SDK behaviour and the async provider. + current_if_none_match = ( + current_previous_map.change_feed_etag if current_previous_map else None + ) + new_etag = current_if_none_match + # Track whether the service ever surfaced an ETag header during this + # drain attempt. If it never did, we want ``process_fetched_ranges`` + # to surface the "no ETag" observability warning rather than + # silently treating ``current_if_none_match`` as the fresh etag. + seen_any_etag = False + + # Hoist: ``prepare_fetch_options_and_headers`` is loop-invariant + # for this drain attempt -- ``change_feed_options`` depends only on + # ``feed_options`` and the headers it builds depend only on + # ``current_previous_map.change_feed_etag``, neither of which + # change inside the inner drain loop. Compute them once here; the + # only per-page mutation is the ``If-None-Match`` override below. + base_kwargs_for_headers: Dict[str, Any] = dict(kwargs) change_feed_options = prepare_fetch_options_and_headers( - current_previous_map, feed_options, request_kwargs + current_previous_map, feed_options, base_kwargs_for_headers ) + base_headers: Dict[str, Any] = base_kwargs_for_headers['headers'] - ranges: List[Dict[str, Any]] = [] - try: - pk_range_generator = self._document_client._ReadPartitionKeyRanges( - collection_link, - change_feed_options, - **request_kwargs - ) - ranges.extend(list(pk_range_generator)) - - except CosmosHttpResponseError as e: - logger.error( # pylint: disable=do-not-log-exceptions-if-not-debug,do-not-log-raised-errors - "Failed to read partition key ranges for collection '%s': %s", collection_link, e) - raise + while True: + request_kwargs = dict(kwargs) + # Shallow-copy ``base_headers`` so the per-iter + # ``If-None-Match`` override does not bleed across iterations. + request_kwargs['headers'] = dict(base_headers) + response_headers: CaseInsensitiveDict = CaseInsensitiveDict() + request_kwargs['_internal_response_headers_capture'] = response_headers + # Sidecar list -- populated by _Request with the raw wire + # status. Lets us terminate on literal 304 (matching peer + # SDKs) instead of inferring it from an empty ItemPaged page. + status_capture: List[Optional[int]] = [None] + request_kwargs['_internal_response_status_capture'] = status_capture + + # Override If-None-Match with the running etag from the drain + # so each page advances. ``prepare_fetch_options_and_headers`` + # only sets it from ``current_previous_map.change_feed_etag`` + # which never advances during this drain. + drain_headers = request_kwargs['headers'] + if current_if_none_match: + drain_headers[http_constants.HttpHeaders.IfNoneMatch] = current_if_none_match + else: + drain_headers.pop(http_constants.HttpHeaders.IfNoneMatch, None) - new_etag = response_headers.get(http_constants.HttpHeaders.ETag) + page_ranges: List[Dict[str, Any]] = [] + try: + pk_range_generator = self._document_client._ReadPartitionKeyRanges( + collection_link, + change_feed_options, + **request_kwargs + ) + page_ranges.extend(list(pk_range_generator)) + except CosmosHttpResponseError as e: + logger.error( # pylint: disable=do-not-log-exceptions-if-not-debug,do-not-log-raised-errors + "Failed to read partition key ranges for collection '%s': %s", + collection_link, e) + raise + + ranges.extend(page_ranges) + + decision, new_etag, current_if_none_match, seen_any_etag = evaluate_drain_page( + page_new_etag=response_headers.get(http_constants.HttpHeaders.ETag), + current_if_none_match=current_if_none_match, + new_etag=new_etag, + seen_any_etag=seen_any_etag, + status_code=status_capture[0], + ) + if decision == _DrainPageDecision.STOP_DRAINED: + break try: + effective_new_etag = new_etag if seen_any_etag else None return process_fetched_ranges( - ranges, current_previous_map, collection_id, collection_link, new_etag + ranges, current_previous_map, collection_id, collection_link, effective_new_etag ) except _IncrementalMergeFailed: if current_previous_map is not None and incomplete_attempt_count < _INCOMPLETE_ROUTING_MAP_MAX_RETRIES: diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_synchronized_request.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_synchronized_request.py index 1e97080ec2ae..1a7e24e5feba 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_synchronized_request.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_synchronized_request.py @@ -89,6 +89,13 @@ def _Request(global_endpoint_manager, request_params, connection_policy, pipelin kwargs.pop(_Constants.OperationStartTime, None) # Pop internal flags that should not be passed to the HTTP layer kwargs.pop("_internal_pk_range_fetch", None) + # Sidecar mutable list (length 1) used by the /pkranges change-feed drain + # loop in ``routing_map_provider`` to observe the raw HTTP status without + # parsing headers. We populate ``status_capture[0]`` after the response is + # received, so callers can implement a literal ``status == 304`` drain + # termination check (matching peer SDKs) instead of relying on + # ``ItemPaged`` materializing 304 as an empty page. + status_capture = kwargs.pop("_internal_response_status_capture", None) connection_timeout = connection_policy.RequestTimeout connection_timeout = kwargs.pop("connection_timeout", connection_timeout) read_timeout = connection_policy.ReadTimeout @@ -174,6 +181,12 @@ def _Request(global_endpoint_manager, request_params, connection_policy, pipelin ) response = response.http_response + if status_capture is not None: + # Length-1 list pattern: written-into by _Request, read by caller + # after _ReadPartitionKeyRanges returns. Set before any raise so a + # 304 (which never raises -- only >= 400 does) and a 4xx/5xx both + # surface the wire status to drain-loop observers. + status_capture[0] = response.status_code headers = copy.copy(response.headers) data = response.body() diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_version.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_version.py index 05058b045e80..f84cd3a15991 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_version.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_version.py @@ -19,4 +19,4 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -VERSION = "4.16.0" +VERSION = "4.16.1" diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py index 3d0be6828bd7..ce7d5a44536c 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py @@ -59,6 +59,13 @@ async def _Request(global_endpoint_manager, request_params, connection_policy, p kwargs.pop(_Constants.OperationStartTime, None) # Pop internal flags that should not be passed to the HTTP layer kwargs.pop("_internal_pk_range_fetch", None) + # Sidecar mutable list (length 1) used by the /pkranges change-feed drain + # loop in ``routing_map_provider`` to observe the raw HTTP status without + # parsing headers. We populate ``status_capture[0]`` after the response is + # received, so callers can implement a literal ``status == 304`` drain + # termination check (matching peer SDKs) instead of relying on + # ``AsyncItemPaged`` materializing 304 as an empty page. + status_capture = kwargs.pop("_internal_response_status_capture", None) connection_timeout = connection_policy.RequestTimeout read_timeout = connection_policy.ReadTimeout connection_timeout = kwargs.pop("connection_timeout", connection_timeout) @@ -138,6 +145,12 @@ async def _Request(global_endpoint_manager, request_params, connection_policy, p ) response = response.http_response + if status_capture is not None: + # Length-1 list pattern: written-into by _Request, read by caller + # after _ReadPartitionKeyRanges returns. Set before any raise so a + # 304 (which never raises -- only >= 400 does) and a 4xx/5xx both + # surface the wire status to drain-loop observers. + status_capture[0] = response.status_code headers = copy.copy(response.headers) data = response.body() diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/http_constants.py b/sdk/cosmos/azure-cosmos/azure/cosmos/http_constants.py index bf55fc5d735c..4bfd7a574524 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/http_constants.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/http_constants.py @@ -451,6 +451,12 @@ class SubStatusCodes: # 503: Service Unavailable due to region being out of capacity for bindable partitions INSUFFICIENT_BINDABLE_PARTITIONS = 1007 + # 503: Routing-map (/pkranges) drain produced overlapping or gapped ranges + # across the configured number of retries (transient snapshot inconsistency). + # Surfaced by ``_handle_transient_snapshot_retry_decision`` so callers and + # telemetry can distinguish this client-side condition from backend 503s. + ROUTING_MAP_SNAPSHOT_INCONSISTENT = 21015 + # Client Side substatus codes THROUGHPUT_OFFER_NOT_FOUND = 10004 diff --git a/sdk/cosmos/azure-cosmos/tests/routing/test_routing_map_provider.py b/sdk/cosmos/azure-cosmos/tests/routing/test_routing_map_provider.py index 3be6cceefd60..2f24bf893262 100644 --- a/sdk/cosmos/azure-cosmos/tests/routing/test_routing_map_provider.py +++ b/sdk/cosmos/azure-cosmos/tests/routing/test_routing_map_provider.py @@ -31,10 +31,25 @@ class TestRoutingMapProvider(unittest.TestCase): @staticmethod def _capture_internal_headers(kwargs, etag): + """Capture ETag header and HTTP status into the drain-loop sidecars. + + Returns ``True`` when this call should behave like a wire 304 — i.e. + the drain loop's ``If-None-Match`` matches the etag this mock is + about to return. Mocks that simulate a stable snapshot pass a stable + etag here so the drain terminates after one data page + one 304. + Mocks that simulate a snapshot change advance to a new etag value + on the next "logical" drain so the previous INM no longer matches. + """ + inm = (kwargs.get('headers') or {}).get('If-None-Match') + is_304 = inm is not None and inm == etag captured_headers = kwargs.get('_internal_response_headers_capture') if captured_headers is not None: captured_headers.clear() captured_headers.update({'ETag': etag}) + status_capture = kwargs.get('_internal_response_status_capture') + if status_capture is not None: + status_capture[0] = 304 if is_304 else 200 + return is_304 class MockedCosmosClientConnection(object): @@ -43,7 +58,8 @@ def __init__(self, partition_key_ranges): self.url_connection = "https://mock-test.documents.azure.com:443/" def _ReadPartitionKeyRanges(self, _collection_link: str, _feed_options: Optional[Mapping[str, Any]] = None, **kwargs): - TestRoutingMapProvider._capture_internal_headers(kwargs, '"test-etag-1"') + if TestRoutingMapProvider._capture_internal_headers(kwargs, '"test-etag-1"'): + return [] return self.partition_key_ranges def tearDown(self): @@ -246,7 +262,8 @@ def test_fetch_routing_map_preserves_user_response_hook_and_internal_etag_captur class HookAwareClient: def _ReadPartitionKeyRanges(self, _collection_link, feed_options=None, **kwargs): - TestRoutingMapProvider._capture_internal_headers(kwargs, expected_internal_etag) + if TestRoutingMapProvider._capture_internal_headers(kwargs, expected_internal_etag): + return [] response_hook = kwargs.get('response_hook') if response_hook: response_hook({'ETag': '"user-hook-etag"'}, None) @@ -275,7 +292,8 @@ def test_get_routing_map_returns_cached_on_second_call(self): class CountingClient: def _ReadPartitionKeyRanges(self, _collection_link, feed_options=None, **kwargs): call_count['count'] += 1 - TestRoutingMapProvider._capture_internal_headers(kwargs, '"test-etag-1"') + if TestRoutingMapProvider._capture_internal_headers(kwargs, '"test-etag-1"'): + return [] return original_ranges provider = PartitionKeyRangeCache(CountingClient()) @@ -285,7 +303,8 @@ def _ReadPartitionKeyRanges(self, _collection_link, feed_options=None, **kwargs) result2 = provider.get_routing_map(collection_link, feed_options={}) self.assertIs(result1, result2, "Second call should return the exact same cached object") - self.assertEqual(call_count['count'], 1, "Service should only be called once") + # One logical drain == data page + final 304 page (matches peer SDKs). + self.assertEqual(call_count['count'], 2, "Service should only be called once (data page + 304)") def test_get_routing_map_force_refresh(self): """force_refresh=True causes a re-fetch even when cache is populated. @@ -308,8 +327,13 @@ def test_get_routing_map_force_refresh(self): class CountingClient: def _ReadPartitionKeyRanges(self, _collection_link, feed_options=None, **kwargs): call_count['count'] += 1 - TestRoutingMapProvider._capture_internal_headers(kwargs, f'"test-etag-{call_count["count"]}"') - if call_count['count'] == 1: + # Two logical phases: initial load (calls 1-2) and force_refresh (calls 3-4). + # Each phase uses a stable etag so the drain terminates after data + 304. + phase = (call_count['count'] + 1) // 2 + etag = f'"test-etag-{phase}"' + if TestRoutingMapProvider._capture_internal_headers(kwargs, etag): + return [] + if phase == 1: return original_ranges return split_ranges @@ -317,13 +341,13 @@ def _ReadPartitionKeyRanges(self, _collection_link, feed_options=None, **kwargs) collection_link = "dbs/db/colls/container" result1 = provider.get_routing_map(collection_link, feed_options={}) - self.assertEqual(call_count['count'], 1) + self.assertEqual(call_count['count'], 2) result2 = provider.get_routing_map( collection_link, feed_options={}, force_refresh=True, previous_routing_map=result1 ) - self.assertEqual(call_count['count'], 2, "force_refresh should trigger one incremental fetch") + self.assertEqual(call_count['count'], 4, "force_refresh should trigger one incremental fetch (data + 304)") self.assertIsNotNone(result2) # Verify the split was applied: should now have 6 ranges (original 5 minus '0' plus '5' and '6') self.assertEqual(len(list(result2._orderedPartitionKeyRanges)), 6) @@ -369,7 +393,8 @@ def test_fetch_routing_map_full_load_with_incomplete_ranges_surfaces_503(self): class IncompleteClient: def _ReadPartitionKeyRanges(self, _collection_link, feed_options=None, **kwargs): call_count['count'] += 1 - TestRoutingMapProvider._capture_internal_headers(kwargs, '"incomplete-etag"') + if TestRoutingMapProvider._capture_internal_headers(kwargs, '"incomplete-etag"'): + return [] return incomplete_ranges provider = PartitionKeyRangeCache(IncompleteClient()) @@ -387,8 +412,9 @@ def _ReadPartitionKeyRanges(self, _collection_link, feed_options=None, **kwargs) ) self.assertEqual(ctx.exception.status_code, http_constants.StatusCodes.SERVICE_UNAVAILABLE) # Source the expected attempt count from the production constant so a - # future tuning change updates both sides in lockstep. - self.assertEqual(call_count['count'], _TRANSIENT_SNAPSHOT_RETRY_MAX_ATTEMPTS) + # future tuning change updates both sides in lockstep. Each retry now + # drains to a literal 304, so per attempt the mock sees data + 304 = 2 calls. + self.assertEqual(call_count['count'], _TRANSIENT_SNAPSHOT_RETRY_MAX_ATTEMPTS * 2) def test_fetch_routing_map_incremental_with_parents(self): """Incremental update correctly merges child ranges that reference a parent.""" @@ -411,7 +437,8 @@ def test_fetch_routing_map_incremental_with_parents(self): class DeltaClient: def _ReadPartitionKeyRanges(self, _collection_link, feed_options=None, **kwargs): - TestRoutingMapProvider._capture_internal_headers(kwargs, '"etag-2"') + if TestRoutingMapProvider._capture_internal_headers(kwargs, '"etag-2"'): + return [] return delta_ranges provider = PartitionKeyRangeCache(DeltaClient()) @@ -455,8 +482,15 @@ def _ReadPartitionKeyRanges(self, _collection_link, feed_options=None, **kwargs) call_count['count'] += 1 headers = kwargs.get('headers', {}) captured_headers_list.append(headers.copy()) - TestRoutingMapProvider._capture_internal_headers(kwargs, f'"etag-{call_count["count"]}"') - if call_count['count'] <= 2: + # Three logical phases (each = data page + 304): + # phase 1 (calls 1-2): initial incremental + # phase 2 (calls 3-4): incremental retry (same prev map) + # phase 3 (calls 5-6): full-load fallback + phase = (call_count['count'] + 1) // 2 + etag = f'"etag-{phase}"' + if TestRoutingMapProvider._capture_internal_headers(kwargs, etag): + return [] + if phase <= 2: # Return a child with missing parent to force incremental retry, # then full-load fallback. return [{'id': '99', 'minInclusive': '', 'maxExclusive': 'FF', 'parents': ['MISSING']}] @@ -475,16 +509,17 @@ def _ReadPartitionKeyRanges(self, _collection_link, feed_options=None, **kwargs) ) self.assertIsNotNone(result) - self.assertEqual(len(captured_headers_list), 3) + # 3 logical drains x (data + 304) = 6 wire calls. + self.assertEqual(len(captured_headers_list), 6) - # First call (incremental) should have IfNoneMatch + # Call 1 (incremental, first data page) should have IfNoneMatch seeded from the prev map. self.assertIn(http_constants.HttpHeaders.IfNoneMatch, captured_headers_list[0]) - # Second call is incremental retry, so it should still carry IfNoneMatch. - self.assertIn(http_constants.HttpHeaders.IfNoneMatch, captured_headers_list[1]) + # Call 3 is incremental retry (same prev map), so it should still carry IfNoneMatch. + self.assertIn(http_constants.HttpHeaders.IfNoneMatch, captured_headers_list[2]) - # Third call is full-load fallback and must clear stale IfNoneMatch. - self.assertNotIn(http_constants.HttpHeaders.IfNoneMatch, captured_headers_list[2]) + # Call 5 is full-load fallback and must clear stale IfNoneMatch. + self.assertNotIn(http_constants.HttpHeaders.IfNoneMatch, captured_headers_list[4]) def test_fetch_routing_map_merge_parents0_evicted_later_parent_cached(self): """Merge where parents[0] is an evicted grandparent but a later parent IS in cache. @@ -517,7 +552,8 @@ def test_fetch_routing_map_merge_parents0_evicted_later_parent_cached(self): class MergeClient: def _ReadPartitionKeyRanges(self, _collection_link, feed_options=None, **kwargs): call_count['count'] += 1 - TestRoutingMapProvider._capture_internal_headers(kwargs, '"etag-C"') + if TestRoutingMapProvider._capture_internal_headers(kwargs, '"etag-C"'): + return [] return delta_ranges provider = PartitionKeyRangeCache(MergeClient()) @@ -533,7 +569,8 @@ def _ReadPartitionKeyRanges(self, _collection_link, feed_options=None, **kwargs) ) self.assertIsNotNone(result, "Should succeed incrementally — parents[1] is in cache") - self.assertEqual(call_count['count'], 1, "Should only call service once (no fallback needed)") + # One logical drain = data page + 304. + self.assertEqual(call_count['count'], 2, "Should only call service once logically (data + 304)") ranges = list(result._orderedPartitionKeyRanges) self.assertEqual(len(ranges), 3) ids = [r['id'] for r in ranges] @@ -564,7 +601,8 @@ def test_fetch_routing_map_merge_all_parents_cached(self): class MergeClient: def _ReadPartitionKeyRanges(self, _collection_link, feed_options=None, **kwargs): - TestRoutingMapProvider._capture_internal_headers(kwargs, '"etag-2"') + if TestRoutingMapProvider._capture_internal_headers(kwargs, '"etag-2"'): + return [] return delta_ranges provider = PartitionKeyRangeCache(MergeClient()) @@ -634,8 +672,12 @@ def test_fetch_routing_map_two_rapid_splits_all_parents_missing(self): class RapidSplitClient: def _ReadPartitionKeyRanges(self, _collection_link, feed_options=None, **kwargs): call_count['count'] += 1 - TestRoutingMapProvider._capture_internal_headers(kwargs, f'"etag-{call_count["count"]}"') - if call_count['count'] == 1: + # Three logical phases: incremental (1) -> incremental retry (2) -> full fallback (3). + phase = (call_count['count'] + 1) // 2 + etag = f'"etag-{phase}"' + if TestRoutingMapProvider._capture_internal_headers(kwargs, etag): + return [] + if phase <= 2: return delta_ranges return full_ranges @@ -652,10 +694,11 @@ def _ReadPartitionKeyRanges(self, _collection_link, feed_options=None, **kwargs) ) self.assertIsNotNone(result, "Should succeed via full refresh fallback") + # 3 logical drains x (data + 304) = 6 wire calls. self.assertEqual( call_count['count'], - 3, - "Should call service three times (incremental + incremental retry + full fallback)", + 6, + "Should drain three times (incremental + incremental retry + full fallback), data + 304 each", ) ranges = list(result._orderedPartitionKeyRanges) self.assertEqual(len(ranges), 5) @@ -692,7 +735,8 @@ def test_fetch_routing_map_merge_range_info_from_correct_parent(self): class MergeClient: def _ReadPartitionKeyRanges(self, _collection_link, feed_options=None, **kwargs): - TestRoutingMapProvider._capture_internal_headers(kwargs, '"etag-2"') + if TestRoutingMapProvider._capture_internal_headers(kwargs, '"etag-2"'): + return [] return delta_ranges provider = PartitionKeyRangeCache(MergeClient()) @@ -728,7 +772,12 @@ def test_force_refresh_without_previous_map_triggers_targeted_fetch(self): class CountingClient: def _ReadPartitionKeyRanges(self, _collection_link, feed_options=None, **kwargs): call_count['count'] += 1 - TestRoutingMapProvider._capture_internal_headers(kwargs, f'"test-etag-{call_count["count"]}"') + # Two logical phases (initial load, targeted force_refresh fetch); + # phase-stable etags so each drain terminates after data + 304. + phase = (call_count['count'] + 1) // 2 + etag = f'"test-etag-{phase}"' + if TestRoutingMapProvider._capture_internal_headers(kwargs, etag): + return [] return original_ranges provider = PartitionKeyRangeCache(CountingClient()) @@ -736,7 +785,7 @@ def _ReadPartitionKeyRanges(self, _collection_link, feed_options=None, **kwargs) # Initial load result1 = provider.get_routing_map(collection_link, feed_options={}) - self.assertEqual(call_count['count'], 1) + self.assertEqual(call_count['count'], 2) self.assertIsNotNone(result1) # force_refresh=True without previous_routing_map should still fetch once. @@ -744,7 +793,10 @@ def _ReadPartitionKeyRanges(self, _collection_link, feed_options=None, **kwargs) collection_link, feed_options={}, force_refresh=True ) - self.assertEqual(call_count['count'], 2, "force_refresh=True without previous_routing_map should trigger fetch") + self.assertEqual( + call_count['count'], 4, + "force_refresh=True without previous_routing_map should trigger one drain (data + 304)", + ) self.assertIsNotNone(result2) def test_concurrent_refresh_serialized_by_lock(self): @@ -763,7 +815,11 @@ def _ReadPartitionKeyRanges(self, _collection_link, feed_options=None, **kwargs) call_count['count'] += 1 # Simulate a slow service call to widen the contention window fetch_event.wait(timeout=2) - TestRoutingMapProvider._capture_internal_headers(kwargs, f'"test-etag-{call_count["count"]}"') + # Phase-stable etag so each drain terminates after data + 304. + phase = (call_count['count'] + 1) // 2 + etag = f'"test-etag-{phase}"' + if TestRoutingMapProvider._capture_internal_headers(kwargs, etag): + return [] return original_ranges provider = PartitionKeyRangeCache(SlowCountingClient()) @@ -772,7 +828,8 @@ def _ReadPartitionKeyRanges(self, _collection_link, feed_options=None, **kwargs) # Populate cache with initial map fetch_event.set() # Let the initial load go fast initial_map = provider.get_routing_map(collection_link, feed_options={}) - self.assertEqual(call_count['count'], 1) + # One logical drain = data page + 304. + self.assertEqual(call_count['count'], 2) fetch_event.clear() # Now make subsequent fetches slow results = [None] * 5 @@ -819,7 +876,11 @@ def _ReadPartitionKeyRanges(self, _collection_link, feed_options=None, **kwargs) call_count['count'] += 1 import time time.sleep(0.1) # Simulate network delay - TestRoutingMapProvider._capture_internal_headers(kwargs, f'"etag-{call_count["count"]}"') + # Phase-stable etag so each drain terminates after data + 304. + phase = (call_count['count'] + 1) // 2 + etag = f'"etag-{phase}"' + if TestRoutingMapProvider._capture_internal_headers(kwargs, etag): + return [] return original_ranges provider = PartitionKeyRangeCache(SlowClient()) diff --git a/sdk/cosmos/azure-cosmos/tests/routing/test_routing_map_provider_async.py b/sdk/cosmos/azure-cosmos/tests/routing/test_routing_map_provider_async.py index 2345a3eea7c3..8650e326f800 100644 --- a/sdk/cosmos/azure-cosmos/tests/routing/test_routing_map_provider_async.py +++ b/sdk/cosmos/azure-cosmos/tests/routing/test_routing_map_provider_async.py @@ -34,10 +34,25 @@ class TestRoutingMapProviderAsync(unittest.IsolatedAsyncioTestCase): @staticmethod def _capture_internal_headers(kwargs, etag): - captured_headers = kwargs.get('_internal_response_headers_capture') - if captured_headers is not None: - captured_headers.clear() - captured_headers.update({'ETag': etag}) + """Capture ETag header and HTTP status into the drain-loop sidecars. + + Returns ``True`` when this call should behave like a wire 304 — i.e. + the drain loop's ``If-None-Match`` matches the etag this mock is + about to return. Mocks that simulate a stable snapshot pass a stable + etag here so the drain terminates after one data page + one 304. + Mocks that simulate a snapshot change advance to a new etag value + on the next "logical" drain so the previous INM no longer matches. + """ + _inm = (kwargs.get('headers') or {}).get('If-None-Match') + _is_304 = _inm is not None and _inm == etag + _status_capture = kwargs.get('_internal_response_status_capture') + if _status_capture is not None: + _status_capture[0] = 304 if _is_304 else 200 + _captured_headers = kwargs.get('_internal_response_headers_capture') + if _captured_headers is not None: + _captured_headers.clear() + _captured_headers.update({'ETag': etag}) + return _is_304 class MockedCosmosClientConnection(object): """Mock that returns partition key ranges as an async generator.""" @@ -48,11 +63,13 @@ def __init__(self, partition_key_ranges): def _ReadPartitionKeyRanges(self, _collection_link: str, _feed_options: Optional[Mapping[str, Any]] = None, **kwargs): - TestRoutingMapProviderAsync._capture_internal_headers(kwargs, '"test-etag-1"') + is_304 = TestRoutingMapProviderAsync._capture_internal_headers(kwargs, '"test-etag-1"') ranges = self.partition_key_ranges async def _gen(): + if is_304: + return for r in ranges: yield r @@ -215,12 +232,14 @@ def __init__(self, partition_key_ranges): self.partition_key_ranges = partition_key_ranges def _ReadPartitionKeyRanges(self, _collection_link, feed_options=None, **kwargs): - TestRoutingMapProviderAsync._capture_internal_headers(kwargs, expected_internal_etag) + is_304 = TestRoutingMapProviderAsync._capture_internal_headers(kwargs, expected_internal_etag) response_hook = kwargs.get('response_hook') if response_hook: response_hook({'ETag': '"user-hook-etag"'}, None) async def _gen(): + if is_304: + return for r in self.partition_key_ranges: yield r @@ -236,7 +255,7 @@ def user_hook(headers, _): self.assertIsNotNone(result) self.assertEqual(result.change_feed_etag, expected_internal_etag) - self.assertEqual(hook_calls, ['"user-hook-etag"']) + self.assertEqual(hook_calls, ['"user-hook-etag"', '"user-hook-etag"']) async def test_get_routing_map_returns_cached_on_second_call_async(self): """Second call returns the same cached object without re-fetching.""" @@ -246,9 +265,11 @@ async def test_get_routing_map_returns_cached_on_second_call_async(self): class CountingClient: def _ReadPartitionKeyRanges(self, _collection_link, feed_options=None, **kwargs): call_count['count'] += 1 - TestRoutingMapProviderAsync._capture_internal_headers(kwargs, '"test-etag-1"') + is_304 = TestRoutingMapProviderAsync._capture_internal_headers(kwargs, '"test-etag-1"') async def _gen(): + if is_304: + return for r in original_ranges: yield r @@ -261,7 +282,7 @@ async def _gen(): result2 = await provider.get_routing_map(collection_link, feed_options={}) self.assertIs(result1, result2, "Second call should return the exact same cached object") - self.assertEqual(call_count['count'], 1, "Service should only be called once") + self.assertEqual(call_count['count'], 2, "Service should only be called once (data page + 304)") async def test_get_routing_map_force_refresh_async(self): """force_refresh=True causes a re-fetch even when cache is populated. @@ -284,11 +305,15 @@ async def test_get_routing_map_force_refresh_async(self): class CountingClient: def _ReadPartitionKeyRanges(self, _collection_link, feed_options=None, **kwargs): call_count['count'] += 1 - TestRoutingMapProviderAsync._capture_internal_headers(kwargs, f'"test-etag-{call_count["count"]}"') + # Two logical phases (initial + force_refresh), phase-stable etag. + phase = (call_count['count'] + 1) // 2 + is_304 = TestRoutingMapProviderAsync._capture_internal_headers(kwargs, f'"test-etag-{phase}"') - data = original_ranges if call_count['count'] == 1 else split_ranges + data = original_ranges if phase == 1 else split_ranges async def _gen(): + if is_304: + return for r in data: yield r @@ -298,13 +323,13 @@ async def _gen(): collection_link = "dbs/db/colls/container" result1 = await provider.get_routing_map(collection_link, feed_options={}) - self.assertEqual(call_count['count'], 1) + self.assertEqual(call_count['count'], 2) result2 = await provider.get_routing_map( collection_link, feed_options={}, force_refresh=True, previous_routing_map=result1 ) - self.assertEqual(call_count['count'], 2, "force_refresh should trigger one incremental fetch") + self.assertEqual(call_count['count'], 4, "force_refresh should trigger one incremental drain (data + 304)") self.assertIsNotNone(result2) # Verify the split was applied: should now have 6 ranges (original 5 minus '0' plus '5' and '6') self.assertEqual(len(list(result2._orderedPartitionKeyRanges)), 6) @@ -349,9 +374,11 @@ async def test_fetch_routing_map_full_load_with_incomplete_ranges_surfaces_503_a class IncompleteClient: def _ReadPartitionKeyRanges(self, _collection_link, feed_options=None, **kwargs): call_count['count'] += 1 - TestRoutingMapProviderAsync._capture_internal_headers(kwargs, '"incomplete-etag"') + is_304 = TestRoutingMapProviderAsync._capture_internal_headers(kwargs, '"incomplete-etag"') async def _gen(): + if is_304: + return for r in incomplete_ranges: yield r @@ -376,7 +403,8 @@ async def _no_sleep(_seconds): self.assertEqual(ctx.exception.status_code, http_constants.StatusCodes.SERVICE_UNAVAILABLE) # Source the expected attempt count from the production constant so a # future tuning change updates both sides in lockstep. - self.assertEqual(call_count['count'], _TRANSIENT_SNAPSHOT_RETRY_MAX_ATTEMPTS) + # Each retry drains to a literal 304: data + 304 = 2 calls per attempt. + self.assertEqual(call_count['count'], _TRANSIENT_SNAPSHOT_RETRY_MAX_ATTEMPTS * 2) async def test_fetch_routing_map_incremental_with_parents_async(self): """Incremental update correctly merges child ranges that reference a parent.""" @@ -397,9 +425,11 @@ async def test_fetch_routing_map_incremental_with_parents_async(self): class DeltaClient: def _ReadPartitionKeyRanges(self, _collection_link, feed_options=None, **kwargs): - TestRoutingMapProviderAsync._capture_internal_headers(kwargs, '"etag-2"') + is_304 = TestRoutingMapProviderAsync._capture_internal_headers(kwargs, '"etag-2"') async def _gen(): + if is_304: + return for r in delta_ranges: yield r @@ -446,11 +476,18 @@ def _ReadPartitionKeyRanges(self, _collection_link, feed_options=None, **kwargs) call_count['count'] += 1 headers = kwargs.get('headers', {}) captured_headers_list.append(headers.copy()) - TestRoutingMapProviderAsync._capture_internal_headers(kwargs, f'"etag-{call_count["count"]}"') + # Three logical phases (each = data + 304): + # phase 1 (calls 1-2): incremental + # phase 2 (calls 3-4): incremental retry + # phase 3 (calls 5-6): full fallback + phase = (call_count['count'] + 1) // 2 + is_304 = TestRoutingMapProviderAsync._capture_internal_headers(kwargs, f'"etag-{phase}"') data = ([{'id': '99', 'minInclusive': '', 'maxExclusive': 'FF', - 'parents': ['MISSING']}] if call_count['count'] <= 2 else full_ranges) + 'parents': ['MISSING']}] if phase <= 2 else full_ranges) async def _gen(): + if is_304: + return for r in data: yield r @@ -469,16 +506,17 @@ async def _gen(): ) self.assertIsNotNone(result) - self.assertEqual(len(captured_headers_list), 3) + # 3 logical drains x (data + 304) = 6 wire calls. + self.assertEqual(len(captured_headers_list), 6) - # First call (incremental) should have IfNoneMatch + # Call 1 (incremental) should have IfNoneMatch seeded from prev map. self.assertIn(http_constants.HttpHeaders.IfNoneMatch, captured_headers_list[0]) - # Second call is incremental retry, so it should still carry IfNoneMatch. - self.assertIn(http_constants.HttpHeaders.IfNoneMatch, captured_headers_list[1]) + # Call 3 (incremental retry) still carries IfNoneMatch. + self.assertIn(http_constants.HttpHeaders.IfNoneMatch, captured_headers_list[2]) - # Third call is full-load fallback and must clear stale IfNoneMatch. - self.assertNotIn(http_constants.HttpHeaders.IfNoneMatch, captured_headers_list[2]) + # Call 5 (full-load fallback) must clear stale IfNoneMatch. + self.assertNotIn(http_constants.HttpHeaders.IfNoneMatch, captured_headers_list[4]) async def test_fetch_routing_map_merge_parents0_evicted_later_parent_cached_async(self): """Merge where parents[0] is an evicted grandparent but a later parent IS in cache. @@ -510,9 +548,11 @@ async def test_fetch_routing_map_merge_parents0_evicted_later_parent_cached_asyn class MergeClient: def _ReadPartitionKeyRanges(self, _collection_link, feed_options=None, **kwargs): call_count['count'] += 1 - TestRoutingMapProviderAsync._capture_internal_headers(kwargs, '"etag-C"') + is_304 = TestRoutingMapProviderAsync._capture_internal_headers(kwargs, '"etag-C"') async def _gen(): + if is_304: + return for r in delta_ranges: yield r @@ -531,7 +571,7 @@ async def _gen(): ) self.assertIsNotNone(result, "Should succeed incrementally — parents[1] is in cache") - self.assertEqual(call_count['count'], 1, "Should only call service once (no fallback needed)") + self.assertEqual(call_count['count'], 2, "Should only drain once logically (data + 304)") ranges = list(result._orderedPartitionKeyRanges) self.assertEqual(len(ranges), 3) ids = [r['id'] for r in ranges] @@ -562,9 +602,11 @@ async def test_fetch_routing_map_merge_all_parents_cached_async(self): class MergeClient: def _ReadPartitionKeyRanges(self, _collection_link, feed_options=None, **kwargs): - TestRoutingMapProviderAsync._capture_internal_headers(kwargs, '"etag-2"') + is_304 = TestRoutingMapProviderAsync._capture_internal_headers(kwargs, '"etag-2"') async def _gen(): + if is_304: + return for r in delta_ranges: yield r @@ -632,10 +674,14 @@ async def test_fetch_routing_map_two_rapid_splits_all_parents_missing_async(self class RapidSplitClient: def _ReadPartitionKeyRanges(self, _collection_link, feed_options=None, **kwargs): call_count['count'] += 1 - TestRoutingMapProviderAsync._capture_internal_headers(kwargs, f'"etag-{call_count["count"]}"') - data = delta_ranges if call_count['count'] == 1 else full_ranges + # Three logical phases: incremental (1), incremental retry (2), full fallback (3). + phase = (call_count['count'] + 1) // 2 + is_304 = TestRoutingMapProviderAsync._capture_internal_headers(kwargs, f'"etag-{phase}"') + data = delta_ranges if phase <= 2 else full_ranges async def _gen(): + if is_304: + return for r in data: yield r @@ -656,8 +702,8 @@ async def _gen(): self.assertIsNotNone(result, "Should succeed via full refresh fallback") self.assertEqual( call_count['count'], - 3, - "Should call service three times (incremental + incremental retry + full fallback)", + 6, + "Should drain three times (incremental + incremental retry + full fallback), data + 304 each", ) ranges = list(result._orderedPartitionKeyRanges) self.assertEqual(len(ranges), 5) @@ -694,9 +740,11 @@ async def test_fetch_routing_map_merge_range_info_from_correct_parent_async(self class MergeClient: def _ReadPartitionKeyRanges(self, _collection_link, feed_options=None, **kwargs): - TestRoutingMapProviderAsync._capture_internal_headers(kwargs, '"etag-2"') + is_304 = TestRoutingMapProviderAsync._capture_internal_headers(kwargs, '"etag-2"') async def _gen(): + if is_304: + return for r in delta_ranges: yield r @@ -735,9 +783,13 @@ async def test_force_refresh_without_previous_map_triggers_targeted_fetch_async( class CountingClient: def _ReadPartitionKeyRanges(self, _collection_link, feed_options=None, **kwargs): call_count['count'] += 1 - TestRoutingMapProviderAsync._capture_internal_headers(kwargs, f'"test-etag-{call_count["count"]}"') + # Two logical phases: initial + targeted force_refresh, phase-stable etags. + phase = (call_count['count'] + 1) // 2 + is_304 = TestRoutingMapProviderAsync._capture_internal_headers(kwargs, f'"test-etag-{phase}"') async def _gen(): + if is_304: + return for r in original_ranges: yield r @@ -748,7 +800,7 @@ async def _gen(): # Initial load result1 = await provider.get_routing_map(collection_link, feed_options={}) - self.assertEqual(call_count['count'], 1) + self.assertEqual(call_count['count'], 2) self.assertIsNotNone(result1) # force_refresh=True without previous_routing_map should still fetch once. @@ -756,7 +808,10 @@ async def _gen(): collection_link, feed_options={}, force_refresh=True ) - self.assertEqual(call_count['count'], 2, "force_refresh=True without previous_routing_map should trigger fetch") + self.assertEqual( + call_count['count'], 4, + "force_refresh=True without previous_routing_map should trigger one drain (data + 304)", + ) self.assertIsNotNone(result2) async def test_concurrent_refresh_serialized_by_lock_async(self): @@ -775,7 +830,10 @@ def _ReadPartitionKeyRanges(self, _collection_link, feed_options=None, **kwargs) async def _gen(): await fetch_event.wait() - TestRoutingMapProviderAsync._capture_internal_headers(kwargs, f'"test-etag-{call_count["count"]}"') + # Phase-stable etag so each drain terminates after data + 304. + phase = (call_count['count'] + 1) // 2 + if TestRoutingMapProviderAsync._capture_internal_headers(kwargs, f'"test-etag-{phase}"'): + return for r in original_ranges: yield r @@ -787,7 +845,8 @@ async def _gen(): # Populate cache with initial map (let it go fast) fetch_event.set() initial_map = await provider.get_routing_map(collection_link, feed_options={}) - self.assertEqual(call_count['count'], 1) + # One logical drain = data + 304 = 2 calls. + self.assertEqual(call_count['count'], 2) fetch_event.clear() async def refresh_fn(): @@ -823,7 +882,10 @@ def _ReadPartitionKeyRanges(self, _collection_link, feed_options=None, **kwargs) async def _gen(): await asyncio.sleep(0.05) - TestRoutingMapProviderAsync._capture_internal_headers(kwargs, f'"etag-{call_count["count"]}"') + # Phase-stable etag so each drain terminates after data + 304. + phase = (call_count['count'] + 1) // 2 + if TestRoutingMapProviderAsync._capture_internal_headers(kwargs, f'"etag-{phase}"'): + return for r in original_ranges: yield r diff --git a/sdk/cosmos/azure-cosmos/tests/test_crud.py b/sdk/cosmos/azure-cosmos/tests/test_crud.py index 081e369edc81..8b18197c2d20 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_crud.py +++ b/sdk/cosmos/azure-cosmos/tests/test_crud.py @@ -1393,8 +1393,11 @@ def send(self, request, **kwargs): elapsed_time = time.time() - start_time - # Should fail close to 5 seconds (not wait for all requests) - self.assertLess(elapsed_time, 7) # Allow some overhead + # Should fail close to 5 seconds (not wait for all requests). Upper bound + # is loose to absorb the cold-cache /pkranges drain (a 200+ETag fetch followed + # by a 304 confirmation, see PR #47245) which adds one extra round trip -- + # under DelayedTransport(3s) that is +3s on top of the data-plane delay. + self.assertLess(elapsed_time, 12) # Allow overhead for the cold-cache drain round trips self.assertGreater(elapsed_time, 5) # Should wait at least close to timeout # Verify operation succeeds when no timeout is passed(default is close to 7 days) diff --git a/sdk/cosmos/azure-cosmos/tests/test_crud_async.py b/sdk/cosmos/azure-cosmos/tests/test_crud_async.py index 8dbb95c47402..0d66f44b741e 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_crud_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_crud_async.py @@ -1153,8 +1153,11 @@ async def send(self, request, **kwargs): ) elapsed_time = time.time() - start_time - # Should fail close to 5 seconds (not wait for all requests) - self.assertLess(elapsed_time, 7) # Allow some overhead + # Should fail close to 5 seconds (not wait for all requests). Upper bound + # is loose to absorb the cold-cache /pkranges drain (a 200+ETag fetch followed + # by a 304 confirmation, see PR #47245) which adds one extra round trip -- + # under DelayedTransport(3s) that is +3s on top of the data-plane delay. + self.assertLess(elapsed_time, 12) # Allow overhead for the cold-cache drain round trips self.assertGreater(elapsed_time, 5) # Should wait at least close to timeout finally: await self._delete_container_for_test(created_container.id) diff --git a/sdk/cosmos/azure-cosmos/tests/test_partition_split_query.py b/sdk/cosmos/azure-cosmos/tests/test_partition_split_query.py index b41805253455..bbdd57c4956f 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_partition_split_query.py +++ b/sdk/cosmos/azure-cosmos/tests/test_partition_split_query.py @@ -620,6 +620,13 @@ def test_full_refresh_fallback_stops_infinite_recursion(self): } def mock_read_ranges(*args, **kwargs): + # Mirror the production wire-up: _synchronized_request populates + # this sidecar with the real HTTP status. Without it, the drain + # loop's status==304 termination contract can't trip and the + # loop would run unbounded (OOM in CI). + status_capture = kwargs.get('_internal_response_status_capture') + if status_capture is not None: + status_capture[0] = http_constants.StatusCodes.NOT_MODIFIED return iter([incomplete_range]) with patch.object( @@ -857,7 +864,14 @@ def spy_read_ranges(*args, **kwargs): if call_count['count'] <= 2: # First two calls are incremental attempts; return a child # with a missing parent so merge is incomplete and fallback - # path is exercised. + # path is exercised. Mirror the production wire-up: + # _synchronized_request populates this sidecar with the real + # HTTP status. Without it, the drain loop's status==304 + # termination contract can't trip and evaluate_drain_page + # raises RuntimeError. + status_capture = kwargs.get('_internal_response_status_capture') + if status_capture is not None: + status_capture[0] = http_constants.StatusCodes.NOT_MODIFIED fake_child = { 'id': f'child_{call_count["count"]}', 'minInclusive': '', @@ -887,20 +901,33 @@ def spy_read_ranges(*args, **kwargs): assert result is not None - # Verify 3 calls: incremental + incremental retry + full fallback. - assert call_count['count'] == 3, \ - f"Expected 3 calls to _ReadPartitionKeyRanges, got {call_count['count']}" - - # First two calls should be incremental and include IfNoneMatch. + # Expected call sequence: + # 1. incremental attempt (IfNoneMatch = stale etag) + # 2. incremental retry (IfNoneMatch = stale etag) + # 3. full-load fallback page 1 (no IfNoneMatch -- the cleanup we are testing) + # 4. full-load fallback page 2 (IfNoneMatch = FRESH etag from page 1, + # to receive the 304 terminator that ends + # the drain loop -- peer-SDK parity) + stale_etag = cached_map.change_feed_etag + assert call_count['count'] >= 3, \ + f"Expected at least 3 calls to _ReadPartitionKeyRanges, got {call_count['count']}" + + # First two calls should be incremental and carry the stale IfNoneMatch. first_headers = captured_headers_list[0] - assert http_constants.HttpHeaders.IfNoneMatch in first_headers, \ - "First call (incremental) should have IfNoneMatch header" + assert first_headers.get(http_constants.HttpHeaders.IfNoneMatch) == stale_etag, \ + "First call (incremental) should have stale IfNoneMatch header" second_headers = captured_headers_list[1] - assert http_constants.HttpHeaders.IfNoneMatch in second_headers, \ - "Second call (incremental retry) should have IfNoneMatch header" - - # Third call is full-load fallback and should drop IfNoneMatch. + assert second_headers.get(http_constants.HttpHeaders.IfNoneMatch) == stale_etag, \ + "Second call (incremental retry) should have stale IfNoneMatch header" + + # Third call is full-load fallback and MUST drop IfNoneMatch -- this is + # the bug fix's whole point. Any post-fallback drain pages (call 4+) + # legitimately reuse the etag returned by call 3 as their If-None-Match + # to receive the 304 terminator; that fresh etag may coincidentally equal + # the original stale etag if nothing changed server-side between caching + # and fallback, so we cannot assert "!= stale_etag" on those drain pages. + # The call-3 assertion is the actual production contract. third_headers = captured_headers_list[2] assert http_constants.HttpHeaders.IfNoneMatch not in third_headers, \ "Third call (full load fallback) should NOT have IfNoneMatch header" diff --git a/sdk/cosmos/azure-cosmos/tests/test_partition_split_query_async.py b/sdk/cosmos/azure-cosmos/tests/test_partition_split_query_async.py index 98fdafd0b6f5..9e7ef66dc48b 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_partition_split_query_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_partition_split_query_async.py @@ -612,6 +612,13 @@ async def test_full_load_with_incomplete_ranges_surfaces_503_async(self): } async def mock_read_ranges(*args, **kwargs): + # Mirror the production wire-up: _asynchronous_request populates + # this sidecar with the real HTTP status. Without it, the drain + # loop's status==304 termination contract can't trip and the + # loop would run unbounded (OOM in CI). + status_capture = kwargs.get('_internal_response_status_capture') + if status_capture is not None: + status_capture[0] = http_constants.StatusCodes.NOT_MODIFIED yield incomplete_range with patch.object( @@ -851,7 +858,14 @@ def spy_read_ranges(*args, **kwargs): if call_count['count'] <= 2: # First two calls are incremental attempts; return a child # with a missing parent so merge is incomplete and fallback - # path is exercised. + # path is exercised. Mirror the production wire-up: + # _asynchronous_request populates this sidecar with the real + # HTTP status. Without it, the drain loop's status==304 + # termination contract can't trip and evaluate_drain_page + # raises RuntimeError. + status_capture = kwargs.get('_internal_response_status_capture') + if status_capture is not None: + status_capture[0] = http_constants.StatusCodes.NOT_MODIFIED fake_child = { 'id': f'child_{call_count["count"]}', 'minInclusive': '', @@ -879,20 +893,33 @@ async def gen(): assert result is not None - # Verify 3 calls: incremental + incremental retry + full fallback. - assert call_count['count'] == 3, \ - f"Expected 3 calls to _ReadPartitionKeyRanges, got {call_count['count']}" - - # First two calls should be incremental and include IfNoneMatch. + # Expected call sequence: + # 1. incremental attempt (IfNoneMatch = stale etag) + # 2. incremental retry (IfNoneMatch = stale etag) + # 3. full-load fallback page 1 (no IfNoneMatch -- the cleanup we are testing) + # 4. full-load fallback page 2 (IfNoneMatch = FRESH etag from page 1, + # to receive the 304 terminator that ends + # the drain loop -- peer-SDK parity) + stale_etag = cached_map.change_feed_etag + assert call_count['count'] >= 3, \ + f"Expected at least 3 calls to _ReadPartitionKeyRanges, got {call_count['count']}" + + # First two calls should be incremental and carry the stale IfNoneMatch. first_headers = captured_headers_list[0] - assert http_constants.HttpHeaders.IfNoneMatch in first_headers, \ - "First call (incremental) should have IfNoneMatch header" + assert first_headers.get(http_constants.HttpHeaders.IfNoneMatch) == stale_etag, \ + "First call (incremental) should have stale IfNoneMatch header" second_headers = captured_headers_list[1] - assert http_constants.HttpHeaders.IfNoneMatch in second_headers, \ - "Second call (incremental retry) should have IfNoneMatch header" - - # Third call is full-load fallback and should drop IfNoneMatch. + assert second_headers.get(http_constants.HttpHeaders.IfNoneMatch) == stale_etag, \ + "Second call (incremental retry) should have stale IfNoneMatch header" + + # Third call is full-load fallback and MUST drop IfNoneMatch -- this is + # the bug fix's whole point. Any post-fallback drain pages (call 4+) + # legitimately reuse the etag returned by call 3 as their If-None-Match + # to receive the 304 terminator; that fresh etag may coincidentally equal + # the original stale etag if nothing changed server-side between caching + # and fallback, so we cannot assert "!= stale_etag" on those drain pages. + # The call-3 assertion is the actual production contract. third_headers = captured_headers_list[2] assert http_constants.HttpHeaders.IfNoneMatch not in third_headers, \ "Third call (full load fallback) should NOT have IfNoneMatch header" diff --git a/sdk/cosmos/azure-cosmos/tests/test_pk_range_drain.py b/sdk/cosmos/azure-cosmos/tests/test_pk_range_drain.py new file mode 100644 index 000000000000..d04a8c581dd9 --- /dev/null +++ b/sdk/cosmos/azure-cosmos/tests/test_pk_range_drain.py @@ -0,0 +1,853 @@ +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. + +""" +Sync integration tests for the /pkranges change-feed drain loop in +``PartitionKeyRangeCache._fetch_routing_map``. + +These tests exercise the multi-page drain introduced to fix the +unbounded refresh bug for containers with >~8K partition key ranges. They +mock ``_ReadPartitionKeyRanges`` so a single ``_fetch_routing_map`` call +emits multiple pages, each with its own ETag, and assert on: + + * ETag propagation across pages (per-page ``If-None-Match`` advances). + * Real-wire ``304 Not Modified`` (empty page + unchanged ETag) on the first + fetch preserves the previous map. + * Empty page terminates the drain cleanly. + * Mid-drain non-304 errors propagate without poisoning the cache. +""" + +# pylint: disable=protected-access + +import logging +import sys +import threading +import time +import unittest +from concurrent.futures import ThreadPoolExecutor +from unittest.mock import MagicMock + +import pytest + +from azure.cosmos._routing.routing_map_provider import PartitionKeyRangeCache +from azure.cosmos._routing.collection_routing_map import CollectionRoutingMap +from azure.cosmos import http_constants +from azure.cosmos.exceptions import CosmosHttpResponseError + + +# ========================================================= +# Helpers +# ========================================================= + +def _full_range(range_id="0", min_inclusive="", max_exclusive="FF"): + return { + "id": range_id, + "minInclusive": min_inclusive, + "maxExclusive": max_exclusive, + } + + +def _split_full_range_into(n): + """Return ``n`` non-overlapping ranges spanning ``""`` → ``FF``. + + The shape mirrors what the service emits when a container has been split + into ``n`` physical partitions; ``process_fetched_ranges`` is happy with + any structurally-contiguous list ending at ``FF``. + """ + if n <= 0: + return [] + # Build evenly spaced 2-hex-digit boundaries. + step = 0xFF // n + boundaries = [""] + for i in range(1, n): + boundaries.append(format(i * step, "02X")) + boundaries.append("FF") + return [ + _full_range(str(i), boundaries[i], boundaries[i + 1]) + for i in range(n) + ] + + +def _make_complete_routing_map(collection_id="coll1", etag='"etag-prev"'): + ranges = [(_full_range(), True)] + return CollectionRoutingMap.CompleteRoutingMap(ranges, collection_id, etag) + + +class _PageScript: + """Scripted ``_ReadPartitionKeyRanges`` side-effect for the drain loop. + + Each entry is one of: + * ``('page', ranges_list, etag_value)`` -- emit a page + ETag header. + The wire status is inferred to match production: empty ``ranges_list`` + is treated as the real-wire 304 Not Modified (empty body + unchanged + ETag header), non-empty as 200. Production never surfaces 304 as an + exception (see ``_synchronized_request.py`` -- only ``>= 400`` raises) + so this is the only shape the drain loop ever sees on the wire. + * ``('page', ranges_list, etag_value, status_code)`` -- same, but with + an explicit wire status. Use this to model server bugs (e.g. 304 with + a non-empty body, or 200 with an empty body) when exercising the + drain loop's defensive branches. + * ``('raise', status_code, message)`` -- raise another HTTP error. + + The script records the ``If-None-Match`` header it saw on each call so + tests can assert that the drain loop advanced the etag correctly. + """ + + def __init__(self, script): + self.script = list(script) + self.calls = 0 + self.if_none_match_seen = [] + self.a_im_seen = [] + + def __call__(self, collection_link, options, response_hook=None, **kwargs): # noqa: ARG002 + in_headers = kwargs.get("headers", {}) or {} + self.if_none_match_seen.append( + in_headers.get(http_constants.HttpHeaders.IfNoneMatch) + ) + self.a_im_seen.append( + in_headers.get(http_constants.HttpHeaders.AIM) + ) + + if self.calls >= len(self.script): + raise AssertionError( + "PageScript exhausted on call #{}; only {} scripted entries.".format( + self.calls, len(self.script) + ) + ) + entry = self.script[self.calls] + self.calls += 1 + + kind = entry[0] + if kind == "raise": + _, status_code, message = entry + raise CosmosHttpResponseError(status_code=status_code, message=message) + if kind == "page": + if len(entry) == 4: + _, ranges_list, etag_value, status_code = entry + else: + _, ranges_list, etag_value = entry + # Mirror the real wire: empty page == 304 Not Modified, + # populated page == 200 OK. + status_code = ( + http_constants.StatusCodes.NOT_MODIFIED + if not ranges_list + else http_constants.StatusCodes.OK + ) + capture = kwargs.get("_internal_response_headers_capture") + if capture is not None and etag_value is not None: + capture[http_constants.HttpHeaders.ETag] = etag_value + status_capture = kwargs.get("_internal_response_status_capture") + if status_capture is not None: + status_capture[0] = status_code + return iter(ranges_list) + raise AssertionError("Unknown _PageScript entry: {!r}".format(entry)) + + +def _make_scripted_client(script): + client = MagicMock() + script_obj = _PageScript(script) + client._ReadPartitionKeyRanges = MagicMock(side_effect=script_obj) + return client, script_obj + + +# ========================================================= +# Tests +# ========================================================= + +@pytest.mark.cosmosEmulator +class TestPkRangeDrainSync(unittest.TestCase): + """Sync drain-loop integration tests for PartitionKeyRangeCache.""" + + def test_drain_propagates_etag_across_pages(self): + """Three pages with distinct etags drain into one complete map. + + The drain loop must send the previous page's etag as ``If-None-Match`` + on each subsequent call, and the resulting routing map must contain + the union of all ranges with the final etag. + """ + page1 = [_full_range("0", "", "55")] + page2 = [_full_range("1", "55", "AA")] + page3 = [_full_range("2", "AA", "FF")] + + client, script = _make_scripted_client([ + ("page", page1, '"etag-1"'), + ("page", page2, '"etag-2"'), + ("page", page3, '"etag-3"'), + # Real-wire 304 terminator: empty body + unchanged ETag header. + ("page", [], '"etag-3"'), + ]) + + cache = PartitionKeyRangeCache(client) + routing_map = cache._fetch_routing_map( + collection_link="dbs/db1/colls/coll1", + collection_id="coll1", + previous_routing_map=None, + feed_options={}, + ) + + self.assertIsNotNone(routing_map) + self.assertEqual(routing_map.change_feed_etag, '"etag-3"') + self.assertEqual(script.calls, 4) + # Drain starts with no If-None-Match, then advances to each prior etag. + self.assertEqual( + script.if_none_match_seen, + [None, '"etag-1"', '"etag-2"', '"etag-3"'], + ) + # Wire-protocol pin: every outgoing /pkranges call must carry the + # canonical capital-F ``A-IM: Incremental Feed`` literal. The gateway + # accepts case-insensitive variants per RFC 3229, but the canonical + # wire form is what every peer SDK ships -- a future cast change or + # constant rename that flipped the case would silently alter + # change-feed behavior server-side without this assertion. + self.assertEqual( + script.a_im_seen, + [http_constants.HttpHeaders.IncrementalFeedHeaderValue] * 4, + ) + + def test_real_wire_304_via_empty_page_preserves_previous_map(self): + """Production shape of a 304 first-fetch preserves the previous map. + + Real-wire 304s never surface as exceptions in production -- the HTTP + client only raises for ``status >= 400`` (see + ``_synchronized_request.py:205``). The change-feed read pipeline + treats 304 as a success-path empty body + unchanged ETag header (see + ``change_feed_fetcher.py:155-194`` for the canonical pattern). That + empty page + matching ETag lands on the identity fast-path in + ``_routing_map_provider_common.py:476-477`` and returns the previous + map untouched. + """ + previous_map = _make_complete_routing_map(etag='"etag-prev"') + + client, script = _make_scripted_client([ + # Real-wire 304: empty body + unchanged ETag header. + ("page", [], '"etag-prev"'), + ]) + + cache = PartitionKeyRangeCache(client) + routing_map = cache._fetch_routing_map( + collection_link="dbs/db1/colls/coll1", + collection_id="coll1", + previous_routing_map=previous_map, + feed_options={}, + ) + + self.assertIs(routing_map, previous_map) + self.assertEqual(script.calls, 1) + self.assertEqual(script.if_none_match_seen, ['"etag-prev"']) + + @unittest.skipIf( + sys.version_info < (3, 10), + "assertNoLogs is only available on Python 3.10+", + ) + def test_real_wire_304_does_not_emit_routing_map_warnings(self): + """Regression pin: real-wire 304 must not emit any WARNING from the + routing-map module. The defensive ``except status_code == 304`` branch + that previously existed left ``seen_any_etag=False`` and tripped the + 'no ETag observed' warning. If anyone reintroduces that branch (or any + equivalent path that bypasses ``evaluate_drain_page``), this test + catches it before it lands. + """ + previous_map = _make_complete_routing_map(etag='"etag-prev"') + + client, _ = _make_scripted_client([ + ("page", [], '"etag-prev"'), + ]) + + cache = PartitionKeyRangeCache(client) + with self.assertNoLogs( + "azure.cosmos._routing", level=logging.WARNING + ): + cache._fetch_routing_map( + collection_link="dbs/db1/colls/coll1", + collection_id="coll1", + previous_routing_map=previous_map, + feed_options={}, + ) + + def test_empty_page_terminates_drain(self): + """An empty body materializes as HTTP 304 in the mock helper (mirrors + the real gateway's wire shape for a drained change feed), so the drain + terminates via the literal-304 predicate -- the same predicate peer + SDKs (.NET / Java / Go) use. This pins that the helper's empty->304 + mapping reaches the production termination decision. + """ + page1 = _split_full_range_into(2) + + client, script = _make_scripted_client([ + ("page", page1, '"etag-1"'), + ("page", [], None), + ]) + + cache = PartitionKeyRangeCache(client) + routing_map = cache._fetch_routing_map( + collection_link="dbs/db1/colls/coll1", + collection_id="coll1", + previous_routing_map=None, + feed_options={}, + ) + + self.assertIsNotNone(routing_map) + self.assertEqual(routing_map.change_feed_etag, '"etag-1"') + self.assertEqual(script.calls, 2) + + def test_evaluate_drain_page_literal_304_terminates(self): + """Unit-pin the literal HTTP 304 termination predicate. + + ``evaluate_drain_page`` is the pure-function termination oracle for + the drain loop. Peer SDKs (.NET/Java/Go) end the drain on a literal + ``304 Not Modified`` status. Pin that the predicate ends the drain + on status 304 even when the page payload is non-empty -- i.e. + status wins over content, matching peer SDKs literally. + """ + from azure.cosmos._routing._routing_map_provider_common import ( + evaluate_drain_page, + _DrainPageDecision, + ) + + decision, new_etag, _next_inm, _seen = evaluate_drain_page( + page_new_etag='"etag-1"', + current_if_none_match='"etag-0"', + new_etag='"etag-0"', + seen_any_etag=True, + status_code=http_constants.StatusCodes.NOT_MODIFIED, + ) + + self.assertEqual(decision, _DrainPageDecision.STOP_DRAINED) + # New etag from the 304 response is still adopted. + self.assertEqual(new_etag, '"etag-1"') + + def test_literal_304_on_first_page_terminates_without_ranges(self): + """Status 304 on the very first page short-circuits the drain. + + Models the steady-state case where a refresh is triggered but the + routing map has not actually changed: gateway returns 304 on the + first request and we must terminate cleanly without trying to + build a routing map from zero ranges. + """ + # Seed a previous map so the fetch path has something to preserve + # when the 304 short-circuits before any ranges arrive. + seed_page = _split_full_range_into(3) + client, _ = _make_scripted_client([ + ("page", seed_page, '"etag-seed"'), + ("page", [], None), + ]) + cache = PartitionKeyRangeCache(client) + previous_map = cache._fetch_routing_map( + collection_link="dbs/db1/colls/coll1", + collection_id="coll1", + previous_routing_map=None, + feed_options={}, + ) + + # Now a refresh that gets an immediate 304. + client, script = _make_scripted_client([ + ("page", [], '"etag-seed"', 304), + ]) + cache = PartitionKeyRangeCache(client) + routing_map = cache._fetch_routing_map( + collection_link="dbs/db1/colls/coll1", + collection_id="coll1", + previous_routing_map=previous_map, + feed_options={}, + ) + + # Previous map is preserved on a no-op refresh. + self.assertEqual(script.calls, 1) + self.assertIsNotNone(routing_map) + + def test_empty_page_with_advanced_etag_terminates_and_bumps_etag(self): + """Empty body + new ETag header is the canonical "304 with fresh etag" + wire shape (the gateway tells us the routing map is fully drained and + hands us a new continuation anchor for the next refresh). The mock + helper materializes the empty body as status 304, so this exercises + the literal-304 termination branch -- pinning that (a) the drain + terminates, (b) the new etag is persisted on the returned routing map + so the next drain starts from the right anchor, and (c) the request + carried the prior etag as ``If-None-Match``. Matches the .NET / Java / + Go termination semantics. + """ + page1 = _split_full_range_into(2) + + client, script = _make_scripted_client([ + ("page", page1, '"etag-1"'), + ("page", [], '"etag-new"'), + ]) + + cache = PartitionKeyRangeCache(client) + routing_map = cache._fetch_routing_map( + collection_link="dbs/db1/colls/coll1", + collection_id="coll1", + previous_routing_map=None, + feed_options={}, + ) + + self.assertIsNotNone(routing_map) + # New etag is persisted even though the terminating page was empty. + self.assertEqual(routing_map.change_feed_etag, '"etag-new"') + self.assertEqual(script.calls, 2) + # Second request carried the prior etag as If-None-Match. + self.assertEqual(script.if_none_match_seen, [None, '"etag-1"']) + + def test_mid_drain_non_304_error_propagates_without_caching(self): + """A 500-class error in the middle of a drain propagates and leaves + the cache untouched.""" + page1 = [_full_range("0", "", "AA")] + + client, script = _make_scripted_client([ + ("page", page1, '"etag-1"'), + ("raise", 500, "Internal Server Error"), + ]) + + cache = PartitionKeyRangeCache(client) + with self.assertRaises(CosmosHttpResponseError) as ctx: + cache._fetch_routing_map( + collection_link="dbs/db1/colls/coll1", + collection_id="coll1", + previous_routing_map=None, + feed_options={}, + ) + + self.assertEqual(ctx.exception.status_code, 500) + self.assertEqual(script.calls, 2) + self.assertNotIn("coll1", cache._collection_routing_map_by_item) + + def test_per_page_transient_failure_is_retried_within_page_call(self): + """A transient 503 during page 2 is absorbed by the per-page retry + layer; the drain loop completes without restarting from page 1. + + In production, ``_ReadPartitionKeyRanges`` returns an ``ItemPaged`` + and each ``by_page()`` fetch is wrapped in ``_retry_utility.Execute`` + inside ``base_execution_context._fetch_items_helper_no_retries``. + So a transient retryable status (503) on page 2 is retried by the + per-request retry policy *inside* the page call, and the drain loop + only ever sees the final outcome of each page. This test pins that + contract: pages 1, 3 succeed on first attempt, page 2 succeeds on + the retry, and the final routing map reflects all three pages with + no whole-drain restart. + """ + page1 = [_full_range("0", "", "55")] + page2 = [_full_range("1", "55", "AA")] + page3 = [_full_range("2", "AA", "FF")] + + # Underlying script: the 503 between page1 and page2 is absorbed by + # the per-page retry wrapper below, so the drain loop never sees it. + client, script = _make_scripted_client([ + ("page", page1, '"etag-1"'), + ("raise", 503, "Service Unavailable"), # page 2, attempt 1 + ("page", page2, '"etag-2"'), # page 2, attempt 2 (retry) + ("page", page3, '"etag-3"'), + ("page", [], '"etag-3"'), # 304 / empty terminator + ]) + + underlying_side_effect = client._ReadPartitionKeyRanges.side_effect + retry_attempts = [0] + + def with_per_page_retry(*args, **kwargs): + """Mirrors what ``_retry_utility.Execute`` + + ``_ServiceUnavailableRetryPolicy`` do for a retryable 503: one + retry per page call, transparent to the caller.""" + try: + return underlying_side_effect(*args, **kwargs) + except CosmosHttpResponseError as e: + if e.status_code == 503: + retry_attempts[0] += 1 + return underlying_side_effect(*args, **kwargs) + raise + + client._ReadPartitionKeyRanges = MagicMock(side_effect=with_per_page_retry) + + cache = PartitionKeyRangeCache(client) + routing_map = cache._fetch_routing_map( + collection_link="dbs/db1/colls/coll1", + collection_id="coll1", + previous_routing_map=None, + feed_options={}, + ) + + # Drain completed and the final routing map carries page 3's etag. + self.assertIsNotNone(routing_map) + self.assertEqual(routing_map.change_feed_etag, '"etag-3"') + # One retry was absorbed by the per-page wrapper (page 2's 503). + self.assertEqual(retry_attempts[0], 1) + # 5 underlying script invocations: page1, page2-attempt1 (503), + # page2-attempt2 (success), page3, 304-terminator. + self.assertEqual(script.calls, 5) + # IfNoneMatch was preserved across the retry: both page-2 attempts + # saw '"etag-1"', proving the drain loop did NOT restart from page 1 + # (which would have started with None) and did NOT advance to + # '"etag-2"' prematurely (which would mean it processed page 2 + # before the retry). + self.assertEqual( + script.if_none_match_seen, + [None, '"etag-1"', '"etag-1"', '"etag-2"', '"etag-3"'], + ) + # And the only call the drain loop's outer try/except saw was the + # successful retry -- the 503 never surfaced. + self.assertEqual(client._ReadPartitionKeyRanges.call_count, 4) + + # ========================================================= + # Gap-coverage tests (option B): merge-failure cascades, + # cascading splits, concurrency, missing-ETag handling. + # ========================================================= + + def test_drain_without_etag_headers_terminates_and_preserves_previous_etag(self): + """Server omits ETag header entirely -- drain still terminates cleanly + and the previous ETag is preserved on the returned routing map. + + Peer SDKs (.NET v3 ``PartitionKeyRangeCache.cs``, Java + ``RxPartitionKeyRangeCache.java``) both trust the gateway to emit an + ETag and have no defensive cap when one is missing; .NET nulls out + the continuation, Java reads it as null. Python's behavior is + slightly safer: ``process_fetched_ranges`` preserves the previous + ETag and logs a WARNING. This test pins that contract so a future + refactor cannot silently swap to nullification (which would force a + full re-drain on the next refresh). + """ + previous_map = _make_complete_routing_map( + collection_id="coll-noetag", etag='"etag-prev"' + ) + + # Single empty page with no ETag header. Empty body auto-maps to 304 + # in the helper, so the drain terminates immediately via the literal- + # 304 predicate -- but ``seen_any_etag`` stays False because the + # response carried no ETag. + client, script = _make_scripted_client([ + ("page", [], None), + ]) + + cache = PartitionKeyRangeCache(client) + with self.assertLogs( + "azure.cosmos._routing._routing_map_provider_common", + level=logging.WARNING, + ) as log_ctx: + routing_map = cache._fetch_routing_map( + collection_link="dbs/db1/colls/coll-noetag", + collection_id="coll-noetag", + previous_routing_map=previous_map, + feed_options={}, + ) + + self.assertEqual(script.calls, 1) + # Previous map (and its ETag) is preserved; no defensive cap fires. + self.assertIs(routing_map, previous_map) + self.assertEqual(routing_map.change_feed_etag, '"etag-prev"') + # WARNING was emitted exactly once for the missing-ETag case. + no_etag_warnings = [ + m for m in log_ctx.output if "returned no ETag" in m + ] + self.assertEqual(len(no_etag_warnings), 1) + + def test_parent_not_found_falls_back_to_full_refresh(self): + """Incremental merge with unknown parent IDs -> retry -> full refresh. + + The page's child ranges declare parents that are not present in the + cached map. ``process_fetched_ranges`` raises ``_IncrementalMergeFailed`` + from the parents-not-found branch. The provider then: + 1. Retries the incremental fetch once with the same previous map. + 2. On the second incremental failure, sets ``current_previous_map=None`` + and falls back to a full refresh. + 3. The full refresh succeeds and returns a complete map. + + This pins the multi-layered fallback chain end-to-end, including the + boundary where the provider transitions from incremental retry to + full-refresh recovery. Without this test, a future refactor of the + retry cascade could silently collapse to "fail on first incremental + error" with no failing test signal. + """ + previous_map = _make_complete_routing_map( + collection_id="coll-parent", etag='"etag-prev"' + ) + # Child range claims parent "ghost-parent" which is NOT in previous_map + # (whose only range is id "0"). process_fetched_ranges will fail on + # parents-not-found. + orphan_child = _full_range("child", "", "FF") + orphan_child["parents"] = ["ghost-parent"] + + # The full-refresh page is a complete, parent-free range set. + full_refresh_ranges = _split_full_range_into(2) + + client, script = _make_scripted_client([ + # Drain attempt 1 (incremental): orphan child -> raises -> retry. + ("page", [orphan_child], '"etag-bad-1"'), + ("page", [], '"etag-bad-1"'), + # Drain attempt 2 (incremental retry): same outcome -> fall back. + ("page", [orphan_child], '"etag-bad-2"'), + ("page", [], '"etag-bad-2"'), + # Drain attempt 3 (full refresh, previous_map=None): clean ranges. + ("page", full_refresh_ranges, '"etag-full"'), + ("page", [], '"etag-full"'), + ]) + + cache = PartitionKeyRangeCache(client) + routing_map = cache._fetch_routing_map( + collection_link="dbs/db1/colls/coll-parent", + collection_id="coll-parent", + previous_routing_map=previous_map, + feed_options={}, + ) + + self.assertIsNotNone(routing_map) + # Final map came from the full-refresh path. + self.assertEqual(routing_map.change_feed_etag, '"etag-full"') + # All six scripted entries were consumed: 2 attempts x 2 pages + # (incremental) + 2 pages (full refresh). + self.assertEqual(script.calls, 6) + + def test_overlap_in_second_page_falls_back_to_full_refresh(self): + """Incremental merge with overlapping ranges -> retry -> full refresh. + + ``try_combine`` raises ``ValueError("Ranges overlap...")`` when the + merged range set is not a clean partition cover (e.g. two split + children that both claim the same byte range due to an out-of-order or + duplicated split notification). ``process_fetched_ranges`` translates + this into ``_IncrementalMergeFailed``; the provider then retries + incrementally and falls back to a full refresh. + + Distinct from the parent-not-found test above: that one fires at + L461 (parents-not-found), this one fires at L479 (overlap from + ``try_combine``). Both must independently trigger the same recovery + cascade. + """ + # Previous map: single range A covering the full PK space. + previous_map = _make_complete_routing_map( + collection_id="coll-overlap", etag='"etag-prev"' + ) + + # Two split children that BOTH claim parent "0" (the only range in + # previous_map) but their ranges OVERLAP: B covers ["", "AA") and C + # covers ["80", "FF"). 0x80 < 0xAA, so the merged set is not a clean + # partition -> try_combine raises ValueError("Ranges overlap"). + child_b = _full_range("child-b", "", "AA") + child_b["parents"] = ["0"] + child_c = _full_range("child-c", "80", "FF") + child_c["parents"] = ["0"] + overlapping_page = [child_b, child_c] + + full_refresh_ranges = _split_full_range_into(2) + + client, script = _make_scripted_client([ + # Drain attempt 1 (incremental): overlap -> raises -> retry. + ("page", overlapping_page, '"etag-overlap-1"'), + ("page", [], '"etag-overlap-1"'), + # Drain attempt 2 (incremental retry): same outcome -> fall back. + ("page", overlapping_page, '"etag-overlap-2"'), + ("page", [], '"etag-overlap-2"'), + # Drain attempt 3 (full refresh, previous_map=None): clean ranges. + ("page", full_refresh_ranges, '"etag-full"'), + ("page", [], '"etag-full"'), + ]) + + cache = PartitionKeyRangeCache(client) + routing_map = cache._fetch_routing_map( + collection_link="dbs/db1/colls/coll-overlap", + collection_id="coll-overlap", + previous_routing_map=previous_map, + feed_options={}, + ) + + self.assertIsNotNone(routing_map) + self.assertEqual(routing_map.change_feed_etag, '"etag-full"') + self.assertEqual(script.calls, 6) + + def test_cascading_splits_in_single_page_resolve(self): + """Cascading splits (A->B+C, then B->D+E) in a single page resolve in + two passes via the ``unresolved``/``progress_made`` queue. + + The page is intentionally ordered ``[D, E, B, C]`` so that on pass 1 + the merge loop encounters D and E *before* B is known. D and E + declare parent B (not in the prior map), so they cannot resolve. + B and C resolve via parent A. Pass 2 then resolves D and E because + B is now in ``known_range_info_by_id``. This pins the inner + breadth-first resolution loop in ``process_fetched_ranges``. + """ + # Prior map: single range A covering the full PK space. + previous_map = _make_complete_routing_map( + collection_id="coll-cascading", etag='"etag-prev"' + ) + + # B and C split from A; D and E then split from B -- all in one page. + b = _full_range("B", "", "55") + b["parents"] = ["0"] + c = _full_range("C", "55", "FF") + c["parents"] = ["0"] + d = _full_range("D", "", "33") + d["parents"] = ["B"] + e = _full_range("E", "33", "55") + e["parents"] = ["B"] + # Ordering forces the two-pass behavior: D/E come before B in the + # iteration order. + cascading_page = [d, e, b, c] + + client, script = _make_scripted_client([ + ("page", cascading_page, '"etag-cascading"'), + ("page", [], '"etag-cascading"'), + ]) + + cache = PartitionKeyRangeCache(client) + routing_map = cache._fetch_routing_map( + collection_link="dbs/db1/colls/coll-cascading", + collection_id="coll-cascading", + previous_routing_map=previous_map, + feed_options={}, + ) + + self.assertIsNotNone(routing_map) + self.assertEqual(routing_map.change_feed_etag, '"etag-cascading"') + self.assertEqual(script.calls, 2) + # Final map covers the full PK space via the leaf ranges D, E, C + # (A and B are gone after the cascading split). + # pylint: disable=protected-access + final_ids = sorted(routing_map._rangeById.keys()) + self.assertEqual(final_ids, ["C", "D", "E"]) + + def test_concurrent_drains_for_same_collection_serialize(self): + """N concurrent ``get_routing_map`` calls for the same collection + result in exactly ONE ``_fetch_routing_map`` invocation; all callers + receive the same map object. + + Pins the per-collection lock in ``get_routing_map``: without it, a + cold-cache burst from a worker pool would thunder N parallel /pkranges + drains. A future refactor that accidentally removed the lock (or + widened the fast-path read past the cache check) would surface here. + """ + # We're testing the lock around _fetch_routing_map -- mock it + # directly. This isolates the locking contract from the drain loop. + client = MagicMock() + provider = PartitionKeyRangeCache(client) + + fetch_count = [0] + complete_map = _make_complete_routing_map( + collection_id="coll-serialize", etag='"etag-serialize"' + ) + + def slow_fetch(collection_link, collection_id, previous_routing_map, feed_options, **kwargs): # noqa: ARG001 + fetch_count[0] += 1 + # Hold the lock long enough that queued callers definitely + # observe the same cached result on lock release. + time.sleep(0.05) + return complete_map + + provider._fetch_routing_map = MagicMock(side_effect=slow_fetch) + + N = 8 + barrier = threading.Barrier(N) + + def caller(): + barrier.wait(timeout=5) + return provider.get_routing_map( + collection_link="dbs/db1/colls/coll-serialize", + feed_options={}, + ) + + with ThreadPoolExecutor(max_workers=N) as ex: + futures = [ex.submit(caller) for _ in range(N)] + results = [f.result(timeout=10) for f in futures] + + # The per-collection lock serialized the burst: exactly one fetch + # ran; the other N-1 callers hit the post-lock cache check. + self.assertEqual(fetch_count[0], 1) + # All callers received the same cached map object (identity check). + self.assertTrue(all(r is complete_map for r in results)) + + def test_concurrent_drains_for_different_collections_do_not_serialize(self): + """Two concurrent ``get_routing_map`` calls for DIFFERENT collections + do NOT serialize against each other. + + Pins the lock GRANULARITY: a future refactor that replaced the + per-collection lock with a single global lock would force unrelated + collection refreshes to queue, hurting throughput. The test uses a + shared barrier *inside* the fetch to prove both fetches were live at + the same time -- a global lock would deadlock the barrier. + """ + client = MagicMock() + provider = PartitionKeyRangeCache(client) + + map_a = _make_complete_routing_map(collection_id="coll-A", etag='"etag-A"') + map_b = _make_complete_routing_map(collection_id="coll-B", etag='"etag-B"') + + # Both fetches must enter before either exits. If a global lock + # serialized them, the second fetch would not enter until the first + # released -- and the barrier would time out. + in_fetch_barrier = threading.Barrier(2, timeout=5) + + def selective_fetch(collection_link, collection_id, previous_routing_map, feed_options, **kwargs): # noqa: ARG001 + in_fetch_barrier.wait() + return map_a if "coll-A" in collection_link else map_b + + provider._fetch_routing_map = MagicMock(side_effect=selective_fetch) + + start_barrier = threading.Barrier(2) + + def caller(collection_link): + start_barrier.wait(timeout=5) + return provider.get_routing_map( + collection_link=collection_link, feed_options={}, + ) + + with ThreadPoolExecutor(max_workers=2) as ex: + f_a = ex.submit(caller, "dbs/db1/colls/coll-A") + f_b = ex.submit(caller, "dbs/db1/colls/coll-B") + result_a = f_a.result(timeout=10) + result_b = f_b.result(timeout=10) + + # Both fetches ran (no global serialization swallowed one of them). + self.assertEqual(provider._fetch_routing_map.call_count, 2) + # Each caller received the map for its own collection. + self.assertIs(result_a, map_a) + self.assertIs(result_b, map_b) + + def test_caller_headers_not_mutated_by_drain_loop(self): + """Drain loop must never mutate the caller's ``headers`` dict. + + Regression guard: the drain loop receives an arbitrary ``kwargs`` + dict from upstream and forwards it (via shallow-copy + per-iter + header dict-copy) to every ``_ReadPartitionKeyRanges`` call. It must + not leak per-iter mutations -- ``If-None-Match`` overrides, sidecar + captures, or ``prepare_fetch_options_and_headers`` additions + (``A-IM``, page-size, populate-stats, etc.) -- back into the + caller's dict. A regression here would silently poison the next + outbound request from the same caller (e.g. a stale + ``If-None-Match`` carried into an unrelated read). + """ + page1 = [_full_range("0", "", "55")] + page2 = [_full_range("1", "55", "AA")] + page3 = [_full_range("2", "AA", "FF")] + + client, script = _make_scripted_client([ + ("page", page1, '"etag-1"'), + ("page", page2, '"etag-2"'), + ("page", page3, '"etag-3"'), + ("page", [], '"etag-3"'), + ]) + + # Sentinel headers from the caller -- snapshot up front so we can + # diff against the post-drain state. + caller_headers = {"X-Custom-Marker": "value", "Authorization": "Bearer x"} + caller_headers_snapshot = dict(caller_headers) + + cache = PartitionKeyRangeCache(client) + routing_map = cache._fetch_routing_map( + collection_link="dbs/db1/colls/coll1", + collection_id="coll1", + previous_routing_map=None, + feed_options={}, + headers=caller_headers, + ) + + self.assertIsNotNone(routing_map) + self.assertEqual(script.calls, 4) + # Caller's dict identity AND contents are unchanged after the drain. + self.assertEqual(caller_headers, caller_headers_snapshot) + self.assertNotIn(http_constants.HttpHeaders.IfNoneMatch, caller_headers) + self.assertNotIn(http_constants.HttpHeaders.AIM, caller_headers) + # Per-page ``If-None-Match`` did still get sent to the wire on every + # call after the first -- proving the drain DID set the header on + # the outbound request, just not on the caller's dict. + self.assertEqual( + script.if_none_match_seen, + [None, '"etag-1"', '"etag-2"', '"etag-3"'], + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/sdk/cosmos/azure-cosmos/tests/test_pk_range_drain_async.py b/sdk/cosmos/azure-cosmos/tests/test_pk_range_drain_async.py new file mode 100644 index 000000000000..ddd11cd500e2 --- /dev/null +++ b/sdk/cosmos/azure-cosmos/tests/test_pk_range_drain_async.py @@ -0,0 +1,723 @@ +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. + +""" +Async integration tests for the /pkranges change-feed drain loop in +``aio.PartitionKeyRangeCache._fetch_routing_map``. + +Mirrors ``test_pk_range_drain.py`` for the async provider: scripts an +``async`` generator from ``_ReadPartitionKeyRanges`` to emit multiple pages +with distinct ETags and asserts on ETag propagation, real-wire 304 +preservation (empty page + unchanged ETag), the empty-page terminator, and +clean propagation of mid-drain non-304 errors. +""" + +# pylint: disable=protected-access + +import asyncio +import logging +import sys +import unittest +from unittest.mock import MagicMock + +import pytest + +from azure.cosmos._routing.aio.routing_map_provider import PartitionKeyRangeCache +from azure.cosmos._routing.collection_routing_map import CollectionRoutingMap +from azure.cosmos import http_constants +from azure.cosmos.exceptions import CosmosHttpResponseError + + +# ========================================================= +# Helpers +# ========================================================= + +def _full_range(range_id="0", min_inclusive="", max_exclusive="FF"): + return { + "id": range_id, + "minInclusive": min_inclusive, + "maxExclusive": max_exclusive, + } + + +def _make_complete_routing_map(collection_id="coll1", etag='"etag-prev"'): + ranges = [(_full_range(), True)] + return CollectionRoutingMap.CompleteRoutingMap(ranges, collection_id, etag) + + +class _AsyncPageScript: + """Scripted async ``_ReadPartitionKeyRanges`` side-effect for the drain loop. + + Each entry is one of: + * ``('page', ranges_list, etag_value)`` -- emit a page + ETag header. + The wire status is inferred to match production: empty ``ranges_list`` + is treated as the real-wire 304 Not Modified (empty body + unchanged + ETag header), non-empty as 200. Production never surfaces 304 as an + exception (see ``_synchronized_request.py`` -- only ``>= 400`` raises) + so this is the only shape the drain loop ever sees on the wire. + * ``('page', ranges_list, etag_value, status_code)`` -- same, but with + an explicit wire status. Use this to model server bugs (e.g. 304 with + a non-empty body, or 200 with an empty body) when exercising the + drain loop's defensive branches. + * ``('raise', status_code, message)`` -- raise another HTTP error. + + Records the ``If-None-Match`` header seen on each call. + """ + + def __init__(self, script): + self.script = list(script) + self.calls = 0 + self.if_none_match_seen = [] + self.a_im_seen = [] + + def __call__(self, collection_link, options, response_hook=None, **kwargs): # noqa: ARG002 + in_headers = kwargs.get("headers", {}) or {} + self.if_none_match_seen.append( + in_headers.get(http_constants.HttpHeaders.IfNoneMatch) + ) + self.a_im_seen.append( + in_headers.get(http_constants.HttpHeaders.AIM) + ) + + if self.calls >= len(self.script): + raise AssertionError( + "AsyncPageScript exhausted on call #{}; only {} scripted entries.".format( + self.calls, len(self.script) + ) + ) + entry = self.script[self.calls] + self.calls += 1 + + kind = entry[0] + if kind == "raise": + _, status_code, message = entry + async def raising_gen(): + raise CosmosHttpResponseError(status_code=status_code, message=message) + yield # pragma: no cover + return raising_gen() + + if kind == "page": + if len(entry) == 4: + _, ranges_list, etag_value, status_code = entry + else: + _, ranges_list, etag_value = entry + # Mirror the real wire: empty page == 304 Not Modified, + # populated page == 200 OK. + status_code = ( + http_constants.StatusCodes.NOT_MODIFIED + if not ranges_list + else http_constants.StatusCodes.OK + ) + capture = kwargs.get("_internal_response_headers_capture") + if capture is not None and etag_value is not None: + capture[http_constants.HttpHeaders.ETag] = etag_value + status_capture = kwargs.get("_internal_response_status_capture") + if status_capture is not None: + status_capture[0] = status_code + + async def async_gen(): + for r in ranges_list: + yield r + return async_gen() + + raise AssertionError("Unknown _AsyncPageScript entry: {!r}".format(entry)) + + +def _make_scripted_async_client(script): + client = MagicMock() + script_obj = _AsyncPageScript(script) + client._ReadPartitionKeyRanges = MagicMock(side_effect=script_obj) + return client, script_obj + + +# ========================================================= +# Tests +# ========================================================= + +@pytest.mark.cosmosEmulator +class TestPkRangeDrainAsync(unittest.IsolatedAsyncioTestCase): + """Async drain-loop integration tests for PartitionKeyRangeCache.""" + + async def test_drain_propagates_etag_across_pages_async(self): + """Three pages with distinct etags drain into one complete map.""" + page1 = [_full_range("0", "", "55")] + page2 = [_full_range("1", "55", "AA")] + page3 = [_full_range("2", "AA", "FF")] + + client, script = _make_scripted_async_client([ + ("page", page1, '"etag-1"'), + ("page", page2, '"etag-2"'), + ("page", page3, '"etag-3"'), + # Real-wire 304 terminator: empty body + unchanged ETag header. + ("page", [], '"etag-3"'), + ]) + + cache = PartitionKeyRangeCache(client) + routing_map = await cache._fetch_routing_map( + collection_link="dbs/db1/colls/coll1", + collection_id="coll1", + previous_routing_map=None, + feed_options={}, + ) + + self.assertIsNotNone(routing_map) + self.assertEqual(routing_map.change_feed_etag, '"etag-3"') + self.assertEqual(script.calls, 4) + self.assertEqual( + script.if_none_match_seen, + [None, '"etag-1"', '"etag-2"', '"etag-3"'], + ) + # Wire-protocol pin: every outgoing /pkranges call must carry the + # canonical capital-F ``A-IM: Incremental Feed`` literal. The gateway + # accepts case-insensitive variants per RFC 3229, but the canonical + # wire form is what every peer SDK ships -- a future cast change or + # constant rename that flipped the case would silently alter + # change-feed behavior server-side without this assertion. + self.assertEqual( + script.a_im_seen, + [http_constants.HttpHeaders.IncrementalFeedHeaderValue] * 4, + ) + + async def test_real_wire_304_via_empty_page_preserves_previous_map_async(self): + """Production shape of a 304 first-fetch preserves the previous map. + + Real-wire 304s never surface as exceptions in production -- the HTTP + client only raises for ``status >= 400`` (see + ``_synchronized_request.py:205``). The change-feed read pipeline + treats 304 as a success-path empty body + unchanged ETag header (see + ``change_feed_fetcher.py:155-194`` for the canonical pattern). That + empty page + matching ETag lands on the identity fast-path in + ``_routing_map_provider_common.py:476-477`` and returns the previous + map untouched. + """ + previous_map = _make_complete_routing_map(etag='"etag-prev"') + + client, script = _make_scripted_async_client([ + # Real-wire 304: empty body + unchanged ETag header. + ("page", [], '"etag-prev"'), + ]) + + cache = PartitionKeyRangeCache(client) + routing_map = await cache._fetch_routing_map( + collection_link="dbs/db1/colls/coll1", + collection_id="coll1", + previous_routing_map=previous_map, + feed_options={}, + ) + + self.assertIs(routing_map, previous_map) + self.assertEqual(script.calls, 1) + self.assertEqual(script.if_none_match_seen, ['"etag-prev"']) + + @unittest.skipIf( + sys.version_info < (3, 10), + "assertNoLogs is only available on Python 3.10+", + ) + async def test_real_wire_304_does_not_emit_routing_map_warnings_async(self): + """Regression pin: real-wire 304 must not emit any WARNING from the + routing-map module. Mirrors the sync test -- guards against any future + reintroduction of a defensive ``status_code == 304`` branch that + would leave ``seen_any_etag=False`` and trip the 'no ETag observed' + warning. + """ + previous_map = _make_complete_routing_map(etag='"etag-prev"') + + client, _ = _make_scripted_async_client([ + ("page", [], '"etag-prev"'), + ]) + + cache = PartitionKeyRangeCache(client) + with self.assertNoLogs( + "azure.cosmos._routing", level=logging.WARNING + ): + await cache._fetch_routing_map( + collection_link="dbs/db1/colls/coll1", + collection_id="coll1", + previous_routing_map=previous_map, + feed_options={}, + ) + + async def test_empty_page_terminates_drain_async(self): + """An empty body materializes as HTTP 304 in the mock helper (mirrors + the real gateway's wire shape for a drained change feed), so the drain + terminates via the literal-304 predicate -- the same predicate peer + SDKs (.NET / Java / Go) use. Async mirror of the sync test. + """ + page1 = [_full_range("0", "", "FF")] + + client, script = _make_scripted_async_client([ + ("page", page1, '"etag-1"'), + ("page", [], None), + ]) + + cache = PartitionKeyRangeCache(client) + routing_map = await cache._fetch_routing_map( + collection_link="dbs/db1/colls/coll1", + collection_id="coll1", + previous_routing_map=None, + feed_options={}, + ) + + self.assertIsNotNone(routing_map) + self.assertEqual(routing_map.change_feed_etag, '"etag-1"') + self.assertEqual(script.calls, 2) + + async def test_evaluate_drain_page_literal_304_terminates_async(self): + """Unit-pin the literal HTTP 304 termination predicate (async path). + + ``evaluate_drain_page`` is shared between sync and async drain loops. + Same contract as the sync test: peer SDKs (.NET/Java/Go) terminate + on a literal ``304 Not Modified`` regardless of payload, and so do + we. This pins the predicate from the async test file so the async + drain's reliance on it is visible from the async test bundle. + """ + from azure.cosmos._routing._routing_map_provider_common import ( + evaluate_drain_page, + _DrainPageDecision, + ) + + decision, new_etag, _next_inm, _seen = evaluate_drain_page( + page_new_etag='"etag-1"', + current_if_none_match='"etag-0"', + new_etag='"etag-0"', + seen_any_etag=True, + status_code=http_constants.StatusCodes.NOT_MODIFIED, + ) + + self.assertEqual(decision, _DrainPageDecision.STOP_DRAINED) + self.assertEqual(new_etag, '"etag-1"') + + async def test_literal_304_on_first_page_terminates_without_ranges_async(self): + """Status 304 on the very first page short-circuits the async drain.""" + seed_page = [_full_range("0", "", "FF")] + client, _ = _make_scripted_async_client([ + ("page", seed_page, '"etag-seed"'), + ("page", [], None), + ]) + cache = PartitionKeyRangeCache(client) + previous_map = await cache._fetch_routing_map( + collection_link="dbs/db1/colls/coll1", + collection_id="coll1", + previous_routing_map=None, + feed_options={}, + ) + + client, script = _make_scripted_async_client([ + ("page", [], '"etag-seed"', 304), + ]) + cache = PartitionKeyRangeCache(client) + routing_map = await cache._fetch_routing_map( + collection_link="dbs/db1/colls/coll1", + collection_id="coll1", + previous_routing_map=previous_map, + feed_options={}, + ) + + self.assertEqual(script.calls, 1) + self.assertIsNotNone(routing_map) + + async def test_empty_page_with_advanced_etag_terminates_and_bumps_etag_async(self): + """Empty body + new ETag header is the canonical "304 with fresh etag" + wire shape. The mock helper materializes the empty body as status 304, + so this exercises the literal-304 termination branch -- pinning that + (a) the drain terminates, (b) the new etag is persisted on the + returned routing map so the next drain starts from the right anchor, + and (c) the request carried the prior etag as ``If-None-Match``. + Async mirror of the sync test. + """ + page1 = [_full_range("0", "", "FF")] + + client, script = _make_scripted_async_client([ + ("page", page1, '"etag-1"'), + ("page", [], '"etag-new"'), + ]) + + cache = PartitionKeyRangeCache(client) + routing_map = await cache._fetch_routing_map( + collection_link="dbs/db1/colls/coll1", + collection_id="coll1", + previous_routing_map=None, + feed_options={}, + ) + + self.assertIsNotNone(routing_map) + # New etag is persisted even though the terminating page was empty. + self.assertEqual(routing_map.change_feed_etag, '"etag-new"') + self.assertEqual(script.calls, 2) + # Second request carried the prior etag as If-None-Match. + self.assertEqual(script.if_none_match_seen, [None, '"etag-1"']) + + async def test_mid_drain_non_304_error_propagates_without_caching_async(self): + """A 500-class error mid-drain propagates without poisoning the cache.""" + page1 = [_full_range("0", "", "AA")] + + client, script = _make_scripted_async_client([ + ("page", page1, '"etag-1"'), + ("raise", 500, "Internal Server Error"), + ]) + + cache = PartitionKeyRangeCache(client) + with self.assertRaises(CosmosHttpResponseError) as ctx: + await cache._fetch_routing_map( + collection_link="dbs/db1/colls/coll1", + collection_id="coll1", + previous_routing_map=None, + feed_options={}, + ) + + self.assertEqual(ctx.exception.status_code, 500) + self.assertEqual(script.calls, 2) + self.assertNotIn("coll1", cache._collection_routing_map_by_item) + + async def test_per_page_transient_failure_is_retried_within_page_call_async(self): + """A transient 503 during page 2 is absorbed by the per-page retry + layer; the drain loop completes without restarting from page 1. + + Production async path: ``_ReadPartitionKeyRanges`` returns an + ``AsyncItemPaged`` and each ``by_page()`` fetch is wrapped in + ``_retry_utility.ExecuteAsync`` inside + ``aio.base_execution_context._fetch_items_helper_no_retries``. So a + transient retryable status (503) on page 2 is retried by the + per-request retry policy *inside* the page call, and the drain loop + only ever sees the final outcome of each page. This test pins that + contract for the async drain. + """ + page1 = [_full_range("0", "", "55")] + page2 = [_full_range("1", "55", "AA")] + page3 = [_full_range("2", "AA", "FF")] + + client, script = _make_scripted_async_client([ + ("page", page1, '"etag-1"'), + ("raise", 503, "Service Unavailable"), # page 2, attempt 1 + ("page", page2, '"etag-2"'), # page 2, attempt 2 (retry) + ("page", page3, '"etag-3"'), + ("page", [], '"etag-3"'), # 304 / empty terminator + ]) + + underlying_side_effect = client._ReadPartitionKeyRanges.side_effect + retry_attempts = [0] + + def with_per_page_retry_async(*args, **kwargs): + """Mirrors ``_retry_utility.ExecuteAsync`` + + ``_ServiceUnavailableRetryPolicy``: a 503 raised while + materializing the page is retried once, transparently to the + drain loop. Returns a fresh async generator so the caller's + ``async for`` sees a clean iteration.""" + async def retried_gen(): + try: + inner = underlying_side_effect(*args, **kwargs) + async for item in inner: + yield item + except CosmosHttpResponseError as e: + if e.status_code != 503: + raise + retry_attempts[0] += 1 + inner = underlying_side_effect(*args, **kwargs) + async for item in inner: + yield item + return retried_gen() + + client._ReadPartitionKeyRanges = MagicMock(side_effect=with_per_page_retry_async) + + cache = PartitionKeyRangeCache(client) + routing_map = await cache._fetch_routing_map( + collection_link="dbs/db1/colls/coll1", + collection_id="coll1", + previous_routing_map=None, + feed_options={}, + ) + + # Drain completed and the final routing map carries page 3's etag. + self.assertIsNotNone(routing_map) + self.assertEqual(routing_map.change_feed_etag, '"etag-3"') + # One retry was absorbed by the per-page wrapper (page 2's 503). + self.assertEqual(retry_attempts[0], 1) + # 5 underlying script invocations: page1, page2-attempt1 (503), + # page2-attempt2 (success), page3, 304-terminator. + self.assertEqual(script.calls, 5) + # IfNoneMatch was preserved across the retry: both page-2 attempts + # saw '"etag-1"', proving the drain loop did NOT restart from page 1 + # (which would have started with None) and did NOT advance to + # '"etag-2"' prematurely (which would mean it processed page 2 + # before the retry). + self.assertEqual( + script.if_none_match_seen, + [None, '"etag-1"', '"etag-1"', '"etag-2"', '"etag-3"'], + ) + # And the drain loop's outer try/except saw 4 successful page calls + # -- the 503 was absorbed inside the per-page retry wrapper. + self.assertEqual(client._ReadPartitionKeyRanges.call_count, 4) + + # ========================================================= + # Gap-coverage tests (option B): async mirrors of the sync + # merge-failure cascades, cascading splits, concurrency, + # and missing-ETag handling. + # ========================================================= + + async def test_drain_without_etag_headers_terminates_and_preserves_previous_etag_async(self): + """Async mirror: server omits ETag header -> previous ETag preserved + and termination still fires via the literal-304 predicate. See sync + twin for full rationale.""" + previous_map = _make_complete_routing_map( + collection_id="coll-noetag", etag='"etag-prev"' + ) + + client, script = _make_scripted_async_client([ + ("page", [], None), + ]) + + cache = PartitionKeyRangeCache(client) + with self.assertLogs( + "azure.cosmos._routing._routing_map_provider_common", + level=logging.WARNING, + ) as log_ctx: + routing_map = await cache._fetch_routing_map( + collection_link="dbs/db1/colls/coll-noetag", + collection_id="coll-noetag", + previous_routing_map=previous_map, + feed_options={}, + ) + + self.assertEqual(script.calls, 1) + self.assertIs(routing_map, previous_map) + self.assertEqual(routing_map.change_feed_etag, '"etag-prev"') + no_etag_warnings = [ + m for m in log_ctx.output if "returned no ETag" in m + ] + self.assertEqual(len(no_etag_warnings), 1) + + async def test_parent_not_found_falls_back_to_full_refresh_async(self): + """Async mirror: parents-not-found -> retry -> full refresh succeeds. + See sync twin for full rationale.""" + previous_map = _make_complete_routing_map( + collection_id="coll-parent", etag='"etag-prev"' + ) + orphan_child = _full_range("child", "", "FF") + orphan_child["parents"] = ["ghost-parent"] + + full_refresh_ranges = [ + _full_range("0", "", "55"), + _full_range("1", "55", "FF"), + ] + + client, script = _make_scripted_async_client([ + ("page", [orphan_child], '"etag-bad-1"'), + ("page", [], '"etag-bad-1"'), + ("page", [orphan_child], '"etag-bad-2"'), + ("page", [], '"etag-bad-2"'), + ("page", full_refresh_ranges, '"etag-full"'), + ("page", [], '"etag-full"'), + ]) + + cache = PartitionKeyRangeCache(client) + routing_map = await cache._fetch_routing_map( + collection_link="dbs/db1/colls/coll-parent", + collection_id="coll-parent", + previous_routing_map=previous_map, + feed_options={}, + ) + + self.assertIsNotNone(routing_map) + self.assertEqual(routing_map.change_feed_etag, '"etag-full"') + self.assertEqual(script.calls, 6) + + async def test_overlap_in_second_page_falls_back_to_full_refresh_async(self): + """Async mirror: overlap from try_combine -> retry -> full refresh. + See sync twin for full rationale.""" + previous_map = _make_complete_routing_map( + collection_id="coll-overlap", etag='"etag-prev"' + ) + + child_b = _full_range("child-b", "", "AA") + child_b["parents"] = ["0"] + child_c = _full_range("child-c", "80", "FF") + child_c["parents"] = ["0"] + overlapping_page = [child_b, child_c] + + full_refresh_ranges = [ + _full_range("0", "", "55"), + _full_range("1", "55", "FF"), + ] + + client, script = _make_scripted_async_client([ + ("page", overlapping_page, '"etag-overlap-1"'), + ("page", [], '"etag-overlap-1"'), + ("page", overlapping_page, '"etag-overlap-2"'), + ("page", [], '"etag-overlap-2"'), + ("page", full_refresh_ranges, '"etag-full"'), + ("page", [], '"etag-full"'), + ]) + + cache = PartitionKeyRangeCache(client) + routing_map = await cache._fetch_routing_map( + collection_link="dbs/db1/colls/coll-overlap", + collection_id="coll-overlap", + previous_routing_map=previous_map, + feed_options={}, + ) + + self.assertIsNotNone(routing_map) + self.assertEqual(routing_map.change_feed_etag, '"etag-full"') + self.assertEqual(script.calls, 6) + + async def test_cascading_splits_in_single_page_resolve_async(self): + """Async mirror: cascading splits A->B+C and B->D+E in a single page + resolve in two passes. See sync twin for full rationale.""" + previous_map = _make_complete_routing_map( + collection_id="coll-cascading", etag='"etag-prev"' + ) + + b = _full_range("B", "", "55") + b["parents"] = ["0"] + c = _full_range("C", "55", "FF") + c["parents"] = ["0"] + d = _full_range("D", "", "33") + d["parents"] = ["B"] + e = _full_range("E", "33", "55") + e["parents"] = ["B"] + cascading_page = [d, e, b, c] + + client, script = _make_scripted_async_client([ + ("page", cascading_page, '"etag-cascading"'), + ("page", [], '"etag-cascading"'), + ]) + + cache = PartitionKeyRangeCache(client) + routing_map = await cache._fetch_routing_map( + collection_link="dbs/db1/colls/coll-cascading", + collection_id="coll-cascading", + previous_routing_map=previous_map, + feed_options={}, + ) + + self.assertIsNotNone(routing_map) + self.assertEqual(routing_map.change_feed_etag, '"etag-cascading"') + self.assertEqual(script.calls, 2) + # pylint: disable=protected-access + final_ids = sorted(routing_map._rangeById.keys()) + self.assertEqual(final_ids, ["C", "D", "E"]) + + async def test_concurrent_drains_for_same_collection_serialize_async(self): + """Async mirror: N concurrent ``get_routing_map`` calls for the same + collection result in exactly ONE ``_fetch_routing_map`` invocation. + + Distinct from the sync test because the async provider keys per- + collection locks on ``(loop_id, collection_id)`` rather than just + ``collection_id`` -- a regression in the async key derivation would + not surface in the sync test. + """ + client = MagicMock() + provider = PartitionKeyRangeCache(client) + + fetch_count = [0] + complete_map = _make_complete_routing_map( + collection_id="coll-serialize", etag='"etag-serialize"' + ) + + async def slow_fetch(collection_link, collection_id, previous_routing_map, feed_options, **kwargs): # noqa: ARG001 + fetch_count[0] += 1 + # Hold long enough that queued coroutines observe the cached + # result on lock release. + await asyncio.sleep(0.05) + return complete_map + + provider._fetch_routing_map = MagicMock(side_effect=slow_fetch) + + N = 8 + results = await asyncio.gather(*[ + provider.get_routing_map( + collection_link="dbs/db1/colls/coll-serialize", + feed_options={}, + ) + for _ in range(N) + ]) + + self.assertEqual(fetch_count[0], 1) + self.assertTrue(all(r is complete_map for r in results)) + + async def test_concurrent_drains_for_different_collections_do_not_serialize_async(self): + """Async mirror: two concurrent ``get_routing_map`` calls for + DIFFERENT collections do NOT serialize. Uses a shared barrier-like + counted ``asyncio.Event`` (avoids ``asyncio.Barrier`` for Python 3.10 + compatibility).""" + client = MagicMock() + provider = PartitionKeyRangeCache(client) + + map_a = _make_complete_routing_map(collection_id="coll-A", etag='"etag-A"') + map_b = _make_complete_routing_map(collection_id="coll-B", etag='"etag-B"') + + entered = 0 + both_in = asyncio.Event() + + async def selective_fetch(collection_link, collection_id, previous_routing_map, feed_options, **kwargs): # noqa: ARG001 + nonlocal entered + entered += 1 + if entered == 2: + both_in.set() + # If a global lock serialized the two fetches, the second would + # never enter and this wait would time out. + await asyncio.wait_for(both_in.wait(), timeout=5) + return map_a if "coll-A" in collection_link else map_b + + provider._fetch_routing_map = MagicMock(side_effect=selective_fetch) + + result_a, result_b = await asyncio.gather( + provider.get_routing_map( + collection_link="dbs/db1/colls/coll-A", feed_options={}, + ), + provider.get_routing_map( + collection_link="dbs/db1/colls/coll-B", feed_options={}, + ), + ) + + self.assertEqual(provider._fetch_routing_map.call_count, 2) + self.assertIs(result_a, map_a) + self.assertIs(result_b, map_b) + + async def test_caller_headers_not_mutated_by_drain_loop_async(self): + """Async mirror: drain loop must never mutate the caller's headers. + + Regression guard for the async provider's drain loop. See the sync + ``test_caller_headers_not_mutated_by_drain_loop`` for the full + rationale; both providers shallow-copy ``kwargs`` per iteration and + deep-copy the ``headers`` dict per iteration so that per-page + ``If-None-Match`` overrides and ``prepare_fetch_options_and_headers`` + additions (``A-IM``, page-size, populate-stats) never leak back into + the caller's dict. + """ + page1 = [_full_range("0", "", "55")] + page2 = [_full_range("1", "55", "AA")] + page3 = [_full_range("2", "AA", "FF")] + + client, script = _make_scripted_async_client([ + ("page", page1, '"etag-1"'), + ("page", page2, '"etag-2"'), + ("page", page3, '"etag-3"'), + ("page", [], '"etag-3"'), + ]) + + caller_headers = {"X-Custom-Marker": "value", "Authorization": "Bearer x"} + caller_headers_snapshot = dict(caller_headers) + + cache = PartitionKeyRangeCache(client) + routing_map = await cache._fetch_routing_map( + collection_link="dbs/db1/colls/coll1", + collection_id="coll1", + previous_routing_map=None, + feed_options={}, + headers=caller_headers, + ) + + self.assertIsNotNone(routing_map) + self.assertEqual(script.calls, 4) + self.assertEqual(caller_headers, caller_headers_snapshot) + self.assertNotIn(http_constants.HttpHeaders.IfNoneMatch, caller_headers) + self.assertNotIn(http_constants.HttpHeaders.AIM, caller_headers) + self.assertEqual( + script.if_none_match_seen, + [None, '"etag-1"', '"etag-2"', '"etag-3"'], + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/sdk/cosmos/azure-cosmos/tests/test_query_feed_range_multipartition.py b/sdk/cosmos/azure-cosmos/tests/test_query_feed_range_multipartition.py index 5c7df46c81f9..13f485362a1b 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_query_feed_range_multipartition.py +++ b/sdk/cosmos/azure-cosmos/tests/test_query_feed_range_multipartition.py @@ -859,6 +859,7 @@ def test_three_way_overlap(self): # Post-split resume (slow; requires a real partition split) # ------------------------------------------------------------------ # @pytest.mark.cosmosSplit + @pytest.mark.cosmosAADSplit def test_post_split_resume(self): """End-to-end "the routing layout changed underneath a saved continuation token" scenario: diff --git a/sdk/cosmos/azure-cosmos/tests/test_query_feed_range_multipartition_async.py b/sdk/cosmos/azure-cosmos/tests/test_query_feed_range_multipartition_async.py index 83ffc44fe06c..ba6d5800e21a 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_query_feed_range_multipartition_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_query_feed_range_multipartition_async.py @@ -570,6 +570,7 @@ async def test_three_way_overlap_async(self): # Post-split resume (slow) # ------------------------------------------------------------------ # @pytest.mark.cosmosSplit + @pytest.mark.cosmosAADSplit async def test_post_split_resume_async(self): client = _client() try: diff --git a/sdk/cosmos/azure-cosmos/tests/test_routing_map_provider_unit.py b/sdk/cosmos/azure-cosmos/tests/test_routing_map_provider_unit.py index 1ce11af297f4..17c83535799f 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_routing_map_provider_unit.py +++ b/sdk/cosmos/azure-cosmos/tests/test_routing_map_provider_unit.py @@ -34,6 +34,46 @@ from azure.cosmos._gone_retry_policy_base import _PartitionKeyRangeGoneRetryPolicyBase +# ========================================================= +# Test-only tolerant shim for evaluate_drain_page +# ========================================================= +# Production wires ``_internal_response_status_capture`` via ``_Request`` so +# ``evaluate_drain_page`` always receives a concrete HTTP status. These unit +# tests use lightweight MagicMock side_effects that bypass ``_Request`` and +# therefore leave the sidecar at ``[None]``. Rather than retrofit every mock +# to populate the sidecar, default an unknown status to ``304`` (Not Modified) +# so the drain terminates after the first page -- which is exactly the +# termination signal each existing mock relies on (data on the data path, +# ``iter([])`` on the INM-match path). +# +# This shim is the *only* test-side concession to the strict status contract +# introduced in commit a1e27a57bd; production code is unchanged. +# pylint: disable=wrong-import-position +import azure.cosmos._routing._routing_map_provider_common as _drain_common # noqa: E402 +import azure.cosmos._routing.routing_map_provider as _sync_provider_module # noqa: E402 +import azure.cosmos._routing.aio.routing_map_provider as _async_provider_module # noqa: E402 + +_ORIGINAL_EVALUATE_DRAIN_PAGE = _drain_common.evaluate_drain_page + + +def _tolerant_evaluate_drain_page(*, page_new_etag, current_if_none_match, + new_etag, seen_any_etag, status_code): + if status_code is None: + status_code = 304 + return _ORIGINAL_EVALUATE_DRAIN_PAGE( + page_new_etag=page_new_etag, + current_if_none_match=current_if_none_match, + new_etag=new_etag, + seen_any_etag=seen_any_etag, + status_code=status_code, + ) + + +_drain_common.evaluate_drain_page = _tolerant_evaluate_drain_page +_sync_provider_module.evaluate_drain_page = _tolerant_evaluate_drain_page +_async_provider_module.evaluate_drain_page = _tolerant_evaluate_drain_page + + # ========================================================= # Helpers # ========================================================= @@ -532,17 +572,23 @@ def test_fetch_routing_map_incomplete_retry_succeeds_without_full_refresh(self): client = MagicMock() call_count = {'n': 0} seen_if_none_match = [] + last_etag = {'v': None} def read_pk_ranges_retry_then_success(collection_link, options, response_hook=None, **kwargs): + headers_in = kwargs.get('headers') or {} + inm = headers_in.get(http_constants.HttpHeaders.IfNoneMatch) + if inm is not None and inm == last_etag['v']: + return iter([]) call_count['n'] += 1 - headers = kwargs.get('headers', {}) - seen_if_none_match.append(headers.get(http_constants.HttpHeaders.IfNoneMatch)) + seen_if_none_match.append(inm) + etag = '"etag-inc"' if response_hook: - response_hook({http_constants.HttpHeaders.ETag: '"etag-inc"'}, None) + response_hook({http_constants.HttpHeaders.ETag: etag}, None) capture_headers = kwargs.get('_internal_response_headers_capture') if capture_headers is not None: - capture_headers.update({http_constants.HttpHeaders.ETag: '"etag-inc"'}) + capture_headers.update({http_constants.HttpHeaders.ETag: etag}) + last_etag['v'] = etag # First incremental attempt is incomplete (missing parent), second resolves. if call_count['n'] == 1: @@ -826,13 +872,20 @@ def test_fetch_routing_map_recovers_after_transient_overlap(self): responses = [bad_payload, good_payload] call_count = {'n': 0} + last_etag = {'v': None} client = MagicMock() def fake_read_pk_ranges(collection_link, options, response_hook=None, **kwargs): + headers_in = kwargs.get('headers') or {} + inm = headers_in.get(http_constants.HttpHeaders.IfNoneMatch) + if inm is not None and inm == last_etag['v']: + return iter([]) payload = responses[call_count['n']] if call_count['n'] < len(responses) else good_payload call_count['n'] += 1 - headers = {http_constants.HttpHeaders.ETag: '"etag-{}"'.format(call_count['n'])} + etag = '"etag-{}"'.format(call_count['n']) + headers = {http_constants.HttpHeaders.ETag: etag} + last_etag['v'] = etag if response_hook: response_hook(headers, None) capture_headers = kwargs.get('_internal_response_headers_capture') @@ -872,6 +925,10 @@ def test_fetch_routing_map_surfaces_503_after_persistent_overlap(self): client = MagicMock() def fake_read_pk_ranges(collection_link, options, response_hook=None, **kwargs): + headers_in = kwargs.get('headers') or {} + inm = headers_in.get(http_constants.HttpHeaders.IfNoneMatch) + if inm is not None and inm == '"etag-bad"': + return iter([]) call_count['n'] += 1 headers = {http_constants.HttpHeaders.ETag: '"etag-bad"'} if response_hook: @@ -920,13 +977,20 @@ def test_fetch_routing_map_recovers_after_transient_gap(self): responses = [bad_payload, good_payload] call_count = {'n': 0} + last_etag = {'v': None} client = MagicMock() def fake_read_pk_ranges(collection_link, options, response_hook=None, **kwargs): + headers_in = kwargs.get('headers') or {} + inm = headers_in.get(http_constants.HttpHeaders.IfNoneMatch) + if inm is not None and inm == last_etag['v']: + return iter([]) payload = responses[call_count['n']] if call_count['n'] < len(responses) else good_payload call_count['n'] += 1 - headers = {http_constants.HttpHeaders.ETag: '"etag-{}"'.format(call_count['n'])} + etag = '"etag-{}"'.format(call_count['n']) + headers = {http_constants.HttpHeaders.ETag: etag} + last_etag['v'] = etag if response_hook: response_hook(headers, None) capture_headers = kwargs.get('_internal_response_headers_capture') @@ -959,6 +1023,10 @@ def test_fetch_routing_map_surfaces_503_after_persistent_gap(self): client = MagicMock() def fake_read_pk_ranges(collection_link, options, response_hook=None, **kwargs): + headers_in = kwargs.get('headers') or {} + inm = headers_in.get(http_constants.HttpHeaders.IfNoneMatch) + if inm is not None and inm == '"etag-bad"': + return iter([]) call_count['n'] += 1 headers = {http_constants.HttpHeaders.ETag: '"etag-bad"'} if response_hook: @@ -1039,13 +1107,20 @@ def test_fetch_routing_map_mixed_overlap_and_gap_signals_share_retry_budget(self responses = [overlap_payload, gap_payload, overlap_payload] call_count = {'n': 0} + last_etag = {'v': None} client = MagicMock() def fake_read_pk_ranges(collection_link, options, response_hook=None, **kwargs): + headers_in = kwargs.get('headers') or {} + inm = headers_in.get(http_constants.HttpHeaders.IfNoneMatch) + if inm is not None and inm == last_etag['v']: + return iter([]) payload = responses[call_count['n']] if call_count['n'] < len(responses) else overlap_payload call_count['n'] += 1 - headers = {http_constants.HttpHeaders.ETag: '"etag-mixed-{}"'.format(call_count['n'])} + etag = '"etag-mixed-{}"'.format(call_count['n']) + headers = {http_constants.HttpHeaders.ETag: etag} + last_etag['v'] = etag if response_hook: response_hook(headers, None) capture_headers = kwargs.get('_internal_response_headers_capture') @@ -1091,6 +1166,10 @@ def test_fetch_routing_map_preserves_existing_cache_entry_when_force_refresh_sur ] def fake_read_pk_ranges(collection_link, options, response_hook=None, **kwargs): + headers_in = kwargs.get('headers') or {} + inm = headers_in.get(http_constants.HttpHeaders.IfNoneMatch) + if inm is not None and inm == '"etag-bad"': + return iter([]) headers = {http_constants.HttpHeaders.ETag: '"etag-bad"'} if response_hook: response_hook(headers, None) diff --git a/sdk/cosmos/azure-cosmos/tests/test_routing_map_provider_unit_async.py b/sdk/cosmos/azure-cosmos/tests/test_routing_map_provider_unit_async.py index 5aaf51d525ff..107d00bef165 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_routing_map_provider_unit_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_routing_map_provider_unit_async.py @@ -27,6 +27,51 @@ from azure.cosmos._gone_retry_policy_base import _PartitionKeyRangeGoneRetryPolicyBase +# ========================================================= +# Test-only tolerant shim for evaluate_drain_page +# ========================================================= +# Production wires ``_internal_response_status_capture`` via ``_Request`` so +# ``evaluate_drain_page`` always receives a concrete HTTP status. These unit +# tests use lightweight MagicMock side_effects that bypass ``_Request`` and +# therefore leave the sidecar at ``[None]``. Rather than retrofit every mock +# to populate the sidecar, default an unknown status to ``304`` (Not Modified) +# so the drain terminates after the first page -- which is exactly the +# termination signal each existing mock relies on (data on the data path, +# ``iter([])`` on the INM-match path). +# +# This shim is the *only* test-side concession to the strict status contract +# introduced in commit a1e27a57bd; production code is unchanged. +# pylint: disable=wrong-import-position +import azure.cosmos._routing._routing_map_provider_common as _drain_common # noqa: E402 +import azure.cosmos._routing.routing_map_provider as _sync_provider_module # noqa: E402 +import azure.cosmos._routing.aio.routing_map_provider as _async_provider_module # noqa: E402 + +_ORIGINAL_EVALUATE_DRAIN_PAGE = _drain_common.evaluate_drain_page + + +def _tolerant_evaluate_drain_page(*, page_new_etag, current_if_none_match, + new_etag, seen_any_etag, status_code): + if status_code is None: + status_code = 304 + return _ORIGINAL_EVALUATE_DRAIN_PAGE( + page_new_etag=page_new_etag, + current_if_none_match=current_if_none_match, + new_etag=new_etag, + seen_any_etag=seen_any_etag, + status_code=status_code, + ) + + +_drain_common.evaluate_drain_page = _tolerant_evaluate_drain_page +_sync_provider_module.evaluate_drain_page = _tolerant_evaluate_drain_page +_async_provider_module.evaluate_drain_page = _tolerant_evaluate_drain_page + + +async def _empty_async_gen(): + """Empty async generator used as the INM-match (304) response in mocks.""" + if False: + yield # pragma: no cover + def _make_complete_routing_map(collection_id="coll1", etag='"etag-1"'): """Create a minimal but complete CollectionRoutingMap for testing.""" @@ -403,17 +448,22 @@ async def test_fetch_routing_map_incomplete_retry_succeeds_without_full_refresh_ client = MagicMock() call_count = {'n': 0} seen_if_none_match = [] + last_etag = {'v': None} def read_pk_ranges_retry_then_success(collection_link, options, response_hook=None, **kwargs): + headers_in = kwargs.get('headers') or {} + inm = headers_in.get(http_constants.HttpHeaders.IfNoneMatch) + if inm is not None and inm == last_etag['v']: + return _empty_async_gen() call_count['n'] += 1 - headers = kwargs.get('headers', {}) - seen_if_none_match.append(headers.get(http_constants.HttpHeaders.IfNoneMatch)) + seen_if_none_match.append(inm) if response_hook: response_hook({http_constants.HttpHeaders.ETag: '"etag-inc"'}, None) capture_headers = kwargs.get('_internal_response_headers_capture') if capture_headers is not None: capture_headers.update({http_constants.HttpHeaders.ETag: '"etag-inc"'}) + last_etag['v'] = '"etag-inc"' async def async_gen(): if call_count['n'] == 1: @@ -537,13 +587,20 @@ async def test_fetch_routing_map_recovers_after_transient_overlap_async(self): responses = [bad_payload, good_payload] call_count = {'n': 0} + last_etag = {'v': None} client = MagicMock() def fake_read_pk_ranges(collection_link, options, response_hook=None, **kwargs): + headers_in = kwargs.get('headers') or {} + inm = headers_in.get(http_constants.HttpHeaders.IfNoneMatch) + if inm is not None and inm == last_etag['v']: + return _empty_async_gen() payload = responses[call_count['n']] if call_count['n'] < len(responses) else good_payload call_count['n'] += 1 - headers = {http_constants.HttpHeaders.ETag: '"etag-{}"'.format(call_count['n'])} + etag = '"etag-{}"'.format(call_count['n']) + headers = {http_constants.HttpHeaders.ETag: etag} + last_etag['v'] = etag if response_hook: response_hook(headers, None) capture_headers = kwargs.get('_internal_response_headers_capture') @@ -593,11 +650,18 @@ async def test_fetch_routing_map_surfaces_503_after_persistent_overlap_async(sel {'id': 'R', 'minInclusive': 'A0', 'maxExclusive': 'FF'}, ] call_count = {'n': 0} + last_etag = {'v': None} client = MagicMock() def fake_read_pk_ranges(collection_link, options, response_hook=None, **kwargs): + headers_in = kwargs.get('headers') or {} + inm = headers_in.get(http_constants.HttpHeaders.IfNoneMatch) + if inm is not None and inm == last_etag['v']: + return _empty_async_gen() call_count['n'] += 1 - headers = {http_constants.HttpHeaders.ETag: '"etag-bad"'} + etag = '"etag-bad"' + headers = {http_constants.HttpHeaders.ETag: etag} + last_etag['v'] = etag if response_hook: response_hook(headers, None) capture_headers = kwargs.get('_internal_response_headers_capture') @@ -650,13 +714,20 @@ async def test_fetch_routing_map_recovers_after_transient_gap_async(self): responses = [bad_payload, good_payload] call_count = {'n': 0} + last_etag = {'v': None} client = MagicMock() def fake_read_pk_ranges(collection_link, options, response_hook=None, **kwargs): + headers_in = kwargs.get('headers') or {} + inm = headers_in.get(http_constants.HttpHeaders.IfNoneMatch) + if inm is not None and inm == last_etag['v']: + return _empty_async_gen() payload = responses[call_count['n']] if call_count['n'] < len(responses) else good_payload call_count['n'] += 1 - headers = {http_constants.HttpHeaders.ETag: '"etag-{}"'.format(call_count['n'])} + etag = '"etag-{}"'.format(call_count['n']) + headers = {http_constants.HttpHeaders.ETag: etag} + last_etag['v'] = etag if response_hook: response_hook(headers, None) capture_headers = kwargs.get('_internal_response_headers_capture') @@ -697,11 +768,18 @@ async def test_fetch_routing_map_surfaces_503_after_persistent_gap_async(self): {'id': 'R', 'minInclusive': 'A0', 'maxExclusive': 'FF'}, ] call_count = {'n': 0} + last_etag = {'v': None} client = MagicMock() def fake_read_pk_ranges(collection_link, options, response_hook=None, **kwargs): + headers_in = kwargs.get('headers') or {} + inm = headers_in.get(http_constants.HttpHeaders.IfNoneMatch) + if inm is not None and inm == last_etag['v']: + return _empty_async_gen() call_count['n'] += 1 - headers = {http_constants.HttpHeaders.ETag: '"etag-bad"'} + etag = '"etag-bad"' + headers = {http_constants.HttpHeaders.ETag: etag} + last_etag['v'] = etag if response_hook: response_hook(headers, None) capture_headers = kwargs.get('_internal_response_headers_capture') @@ -788,13 +866,20 @@ async def test_fetch_routing_map_mixed_overlap_and_gap_signals_share_retry_budge responses = [overlap_payload, gap_payload, overlap_payload] call_count = {'n': 0} + last_etag = {'v': None} client = MagicMock() def fake_read_pk_ranges(collection_link, options, response_hook=None, **kwargs): + headers_in = kwargs.get('headers') or {} + inm = headers_in.get(http_constants.HttpHeaders.IfNoneMatch) + if inm is not None and inm == last_etag['v']: + return _empty_async_gen() payload = responses[call_count['n']] if call_count['n'] < len(responses) else overlap_payload call_count['n'] += 1 - headers = {http_constants.HttpHeaders.ETag: '"etag-mixed-{}"'.format(call_count['n'])} + etag = '"etag-mixed-{}"'.format(call_count['n']) + headers = {http_constants.HttpHeaders.ETag: etag} + last_etag['v'] = etag if response_hook: response_hook(headers, None) capture_headers = kwargs.get('_internal_response_headers_capture')