diff --git a/google/cloud/firestore_v1/base_query.py b/google/cloud/firestore_v1/base_query.py index b1b74fcf1..c5a78a0a9 100644 --- a/google/cloud/firestore_v1/base_query.py +++ b/google/cloud/firestore_v1/base_query.py @@ -1366,6 +1366,69 @@ def _cursor_pb(cursor_pair: Optional[Tuple[list, bool]]) -> Optional[Cursor]: return None +def _where_conditions_from_cursor( + cursor: Tuple[List, bool], + orderings: List[pipeline_expressions.Ordering], + is_start_cursor: bool, +) -> pipeline_expressions.BooleanExpression: + """ + Converts a cursor into a filter condition for the pipeline. + + Args: + cursor: The cursor values and the 'before' flag. + orderings: The list of ordering expressions used in the query. + is_start_cursor: True if this is a start_at/start_after cursor, False if it is an end_at/end_before cursor. + Returns: + A BooleanExpression representing the cursor condition. + """ + cursor_values, before = cursor + size = len(cursor_values) + + if is_start_cursor: + filter_func = pipeline_expressions.Expression.greater_than + else: + filter_func = pipeline_expressions.Expression.less_than + + field = orderings[size - 1].expr + value = pipeline_expressions.Constant(cursor_values[size - 1]) + + # Add condition for last bound + condition = filter_func(field, value) + + if (is_start_cursor and before) or (not is_start_cursor and not before): + # When the cursor bound is inclusive, then the last bound + # can be equal to the value, otherwise it's not equal + condition = pipeline_expressions.Or(condition, field.equal(value)) + + # Iterate backwards over the remaining bounds, adding a condition for each one + for i in range(size - 2, -1, -1): + field = orderings[i].expr + value = pipeline_expressions.Constant(cursor_values[i]) + + # For each field in the orderings, the condition is either + # a) lessThan|greaterThan the cursor value, + # b) or equal the cursor value and lessThan|greaterThan the cursor values for other fields + condition = pipeline_expressions.Or( + filter_func(field, value), + pipeline_expressions.And(field.equal(value), condition), + ) + + return condition + + +def _reverse_orderings( + orderings: List[pipeline_expressions.Ordering], +) -> List[pipeline_expressions.Ordering]: + reversed_orderings = [] + for o in orderings: + if o.order_dir == pipeline_expressions.Ordering.Direction.ASCENDING: + new_dir = "descending" + else: + new_dir = "ascending" + reversed_orderings.append(pipeline_expressions.Ordering(o.expr, new_dir)) + return reversed_orderings + + def _query_response_to_snapshot( response_pb: RunQueryResponse, collection, expected_prefix: str ) -> Optional[document.DocumentSnapshot]: diff --git a/google/cloud/firestore_v1/services/firestore/async_client.py b/google/cloud/firestore_v1/services/firestore/async_client.py index 3557eb94c..96421f879 100644 --- a/google/cloud/firestore_v1/services/firestore/async_client.py +++ b/google/cloud/firestore_v1/services/firestore/async_client.py @@ -238,6 +238,9 @@ def __init__( If a Callable is given, it will be called with the same set of initialization arguments as used in the FirestoreTransport constructor. If set to None, a transport is chosen automatically. + NOTE: "rest" transport functionality is currently in a + beta state (preview). We welcome your feedback via an + issue in this library's source repository. client_options (Optional[Union[google.api_core.client_options.ClientOptions, dict]]): Custom options for the client. diff --git a/google/cloud/firestore_v1/services/firestore/client.py b/google/cloud/firestore_v1/services/firestore/client.py index ac86aaa9e..e362896af 100644 --- a/google/cloud/firestore_v1/services/firestore/client.py +++ b/google/cloud/firestore_v1/services/firestore/client.py @@ -571,6 +571,9 @@ def __init__( If a Callable is given, it will be called with the same set of initialization arguments as used in the FirestoreTransport constructor. If set to None, a transport is chosen automatically. + NOTE: "rest" transport functionality is currently in a + beta state (preview). We welcome your feedback via an + issue in this library's source repository. client_options (Optional[Union[google.api_core.client_options.ClientOptions, dict]]): Custom options for the client. diff --git a/google/cloud/firestore_v1/services/firestore/transports/rest.py b/google/cloud/firestore_v1/services/firestore/transports/rest.py index 845569d97..31db60310 100644 --- a/google/cloud/firestore_v1/services/firestore/transports/rest.py +++ b/google/cloud/firestore_v1/services/firestore/transports/rest.py @@ -994,14 +994,9 @@ def __init__( ) -> None: """Instantiate the transport. - Args: - host (Optional[str]): - The hostname to connect to (default: 'firestore.googleapis.com'). - credentials (Optional[google.auth.credentials.Credentials]): The - authorization credentials to attach to requests. These - credentials identify the application to the service; if none - are specified, the client will attempt to ascertain the - credentials from the environment. + NOTE: This REST transport functionality is currently in a beta + state (preview). We welcome your feedback via a GitHub issue in + this library's repository. Thank you! credentials_file (Optional[str]): Deprecated. A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. diff --git a/google/cloud/firestore_v1/services/firestore/transports/rest_base.py b/google/cloud/firestore_v1/services/firestore/transports/rest_base.py index 80ce35e49..66cffc43c 100644 --- a/google/cloud/firestore_v1/services/firestore/transports/rest_base.py +++ b/google/cloud/firestore_v1/services/firestore/transports/rest_base.py @@ -130,7 +130,7 @@ def _get_request_body_json(transcoded_request): # Jsonify the request body body = json_format.MessageToJson( - transcoded_request["body"], use_integers_for_enums=True + transcoded_request["body"], use_integers_for_enums=False ) return body @@ -139,7 +139,7 @@ def _get_query_params_json(transcoded_request): query_params = json.loads( json_format.MessageToJson( transcoded_request["query_params"], - use_integers_for_enums=True, + use_integers_for_enums=False, ) ) query_params.update( @@ -148,7 +148,6 @@ def _get_query_params_json(transcoded_request): ) ) - query_params["$alt"] = "json;enum-encoding=int" return query_params class _BaseBatchWrite: @@ -187,7 +186,7 @@ def _get_request_body_json(transcoded_request): # Jsonify the request body body = json_format.MessageToJson( - transcoded_request["body"], use_integers_for_enums=True + transcoded_request["body"], use_integers_for_enums=False ) return body @@ -196,7 +195,7 @@ def _get_query_params_json(transcoded_request): query_params = json.loads( json_format.MessageToJson( transcoded_request["query_params"], - use_integers_for_enums=True, + use_integers_for_enums=False, ) ) query_params.update( @@ -205,7 +204,6 @@ def _get_query_params_json(transcoded_request): ) ) - query_params["$alt"] = "json;enum-encoding=int" return query_params class _BaseBeginTransaction: @@ -244,7 +242,7 @@ def _get_request_body_json(transcoded_request): # Jsonify the request body body = json_format.MessageToJson( - transcoded_request["body"], use_integers_for_enums=True + transcoded_request["body"], use_integers_for_enums=False ) return body @@ -253,7 +251,7 @@ def _get_query_params_json(transcoded_request): query_params = json.loads( json_format.MessageToJson( transcoded_request["query_params"], - use_integers_for_enums=True, + use_integers_for_enums=False, ) ) query_params.update( @@ -262,7 +260,6 @@ def _get_query_params_json(transcoded_request): ) ) - query_params["$alt"] = "json;enum-encoding=int" return query_params class _BaseCommit: @@ -301,7 +298,7 @@ def _get_request_body_json(transcoded_request): # Jsonify the request body body = json_format.MessageToJson( - transcoded_request["body"], use_integers_for_enums=True + transcoded_request["body"], use_integers_for_enums=False ) return body @@ -310,7 +307,7 @@ def _get_query_params_json(transcoded_request): query_params = json.loads( json_format.MessageToJson( transcoded_request["query_params"], - use_integers_for_enums=True, + use_integers_for_enums=False, ) ) query_params.update( @@ -319,7 +316,6 @@ def _get_query_params_json(transcoded_request): ) ) - query_params["$alt"] = "json;enum-encoding=int" return query_params class _BaseCreateDocument: @@ -358,7 +354,7 @@ def _get_request_body_json(transcoded_request): # Jsonify the request body body = json_format.MessageToJson( - transcoded_request["body"], use_integers_for_enums=True + transcoded_request["body"], use_integers_for_enums=False ) return body @@ -367,7 +363,7 @@ def _get_query_params_json(transcoded_request): query_params = json.loads( json_format.MessageToJson( transcoded_request["query_params"], - use_integers_for_enums=True, + use_integers_for_enums=False, ) ) query_params.update( @@ -376,7 +372,6 @@ def _get_query_params_json(transcoded_request): ) ) - query_params["$alt"] = "json;enum-encoding=int" return query_params class _BaseDeleteDocument: @@ -414,7 +409,7 @@ def _get_query_params_json(transcoded_request): query_params = json.loads( json_format.MessageToJson( transcoded_request["query_params"], - use_integers_for_enums=True, + use_integers_for_enums=False, ) ) query_params.update( @@ -423,6 +418,62 @@ def _get_query_params_json(transcoded_request): ) ) + return query_params + + class _BaseExecutePipeline: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "post", + "uri": "/v1/{database=projects/*/databases/*}/documents:executePipeline", + "body": "*", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = firestore.ExecutePipelineRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_request_body_json(transcoded_request): + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request["body"], use_integers_for_enums=False + ) + return body + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=False, + ) + ) + query_params.update( + _BaseFirestoreRestTransport._BaseExecutePipeline._get_unset_required_fields( + query_params + ) + ) + query_params["$alt"] = "json;enum-encoding=int" return query_params @@ -518,7 +569,7 @@ def _get_query_params_json(transcoded_request): query_params = json.loads( json_format.MessageToJson( transcoded_request["query_params"], - use_integers_for_enums=True, + use_integers_for_enums=False, ) ) query_params.update( @@ -527,7 +578,6 @@ def _get_query_params_json(transcoded_request): ) ) - query_params["$alt"] = "json;enum-encoding=int" return query_params class _BaseListCollectionIds: @@ -571,7 +621,7 @@ def _get_request_body_json(transcoded_request): # Jsonify the request body body = json_format.MessageToJson( - transcoded_request["body"], use_integers_for_enums=True + transcoded_request["body"], use_integers_for_enums=False ) return body @@ -580,7 +630,7 @@ def _get_query_params_json(transcoded_request): query_params = json.loads( json_format.MessageToJson( transcoded_request["query_params"], - use_integers_for_enums=True, + use_integers_for_enums=False, ) ) query_params.update( @@ -589,7 +639,6 @@ def _get_query_params_json(transcoded_request): ) ) - query_params["$alt"] = "json;enum-encoding=int" return query_params class _BaseListDocuments: @@ -631,7 +680,7 @@ def _get_query_params_json(transcoded_request): query_params = json.loads( json_format.MessageToJson( transcoded_request["query_params"], - use_integers_for_enums=True, + use_integers_for_enums=False, ) ) query_params.update( @@ -640,7 +689,6 @@ def _get_query_params_json(transcoded_request): ) ) - query_params["$alt"] = "json;enum-encoding=int" return query_params class _BaseListen: @@ -688,7 +736,7 @@ def _get_request_body_json(transcoded_request): # Jsonify the request body body = json_format.MessageToJson( - transcoded_request["body"], use_integers_for_enums=True + transcoded_request["body"], use_integers_for_enums=False ) return body @@ -697,7 +745,7 @@ def _get_query_params_json(transcoded_request): query_params = json.loads( json_format.MessageToJson( transcoded_request["query_params"], - use_integers_for_enums=True, + use_integers_for_enums=False, ) ) query_params.update( @@ -706,7 +754,6 @@ def _get_query_params_json(transcoded_request): ) ) - query_params["$alt"] = "json;enum-encoding=int" return query_params class _BaseRollback: @@ -745,7 +792,7 @@ def _get_request_body_json(transcoded_request): # Jsonify the request body body = json_format.MessageToJson( - transcoded_request["body"], use_integers_for_enums=True + transcoded_request["body"], use_integers_for_enums=False ) return body @@ -754,7 +801,7 @@ def _get_query_params_json(transcoded_request): query_params = json.loads( json_format.MessageToJson( transcoded_request["query_params"], - use_integers_for_enums=True, + use_integers_for_enums=False, ) ) query_params.update( @@ -763,7 +810,6 @@ def _get_query_params_json(transcoded_request): ) ) - query_params["$alt"] = "json;enum-encoding=int" return query_params class _BaseRunAggregationQuery: @@ -807,7 +853,7 @@ def _get_request_body_json(transcoded_request): # Jsonify the request body body = json_format.MessageToJson( - transcoded_request["body"], use_integers_for_enums=True + transcoded_request["body"], use_integers_for_enums=False ) return body @@ -816,7 +862,7 @@ def _get_query_params_json(transcoded_request): query_params = json.loads( json_format.MessageToJson( transcoded_request["query_params"], - use_integers_for_enums=True, + use_integers_for_enums=False, ) ) query_params.update( @@ -825,7 +871,6 @@ def _get_query_params_json(transcoded_request): ) ) - query_params["$alt"] = "json;enum-encoding=int" return query_params class _BaseRunQuery: @@ -869,7 +914,7 @@ def _get_request_body_json(transcoded_request): # Jsonify the request body body = json_format.MessageToJson( - transcoded_request["body"], use_integers_for_enums=True + transcoded_request["body"], use_integers_for_enums=False ) return body @@ -878,7 +923,7 @@ def _get_query_params_json(transcoded_request): query_params = json.loads( json_format.MessageToJson( transcoded_request["query_params"], - use_integers_for_enums=True, + use_integers_for_enums=False, ) ) query_params.update( @@ -887,7 +932,6 @@ def _get_query_params_json(transcoded_request): ) ) - query_params["$alt"] = "json;enum-encoding=int" return query_params class _BaseUpdateDocument: @@ -926,7 +970,7 @@ def _get_request_body_json(transcoded_request): # Jsonify the request body body = json_format.MessageToJson( - transcoded_request["body"], use_integers_for_enums=True + transcoded_request["body"], use_integers_for_enums=False ) return body @@ -935,7 +979,7 @@ def _get_query_params_json(transcoded_request): query_params = json.loads( json_format.MessageToJson( transcoded_request["query_params"], - use_integers_for_enums=True, + use_integers_for_enums=False, ) ) query_params.update( @@ -944,7 +988,6 @@ def _get_query_params_json(transcoded_request): ) ) - query_params["$alt"] = "json;enum-encoding=int" return query_params class _BaseWrite: diff --git a/tests/system/test_system.py b/tests/system/test_system.py index 0c86c69a3..d45753b64 100644 --- a/tests/system/test_system.py +++ b/tests/system/test_system.py @@ -66,6 +66,12 @@ def _get_credentials_and_project(): @pytest.fixture(scope="session") def database(request): + from test__helpers import FIRESTORE_ENTERPRISE_DB + + # enterprise mode currently does not support RunQuery calls in prod on kokoro test project + # TODO: remove skip when kokoro test project supports full enterprise mode + if request.param == FIRESTORE_ENTERPRISE_DB and IS_KOKORO_TEST: + pytest.skip("enterprise mode does not support RunQuery on kokoro") return request.param @@ -1450,6 +1456,7 @@ def test_query_stream_w_start_end_cursor(query_docs, database): for key, value in values: assert stored[key] == value assert value["a"] == num_vals - 2 + verify_pipeline(query) @pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) @@ -1883,6 +1890,9 @@ def test_query_with_order_dot_key(client, cleanup, database): ) cursor_with_key_data = list(query4.stream()) assert found_data == [snap.to_dict() for snap in cursor_with_key_data] + verify_pipeline(query) + verify_pipeline(query2) + verify_pipeline(query3) @pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) @@ -2007,6 +2017,7 @@ def test_collection_group_queries_startat_endat(client, cleanup, database): snapshots = list(query.stream()) found = set(snapshot.id for snapshot in snapshots) assert found == set(["cg-doc2", "cg-doc3", "cg-doc4"]) + verify_pipeline(query) query = ( client.collection_group(collection_group) @@ -2017,6 +2028,7 @@ def test_collection_group_queries_startat_endat(client, cleanup, database): snapshots = list(query.stream()) found = set(snapshot.id for snapshot in snapshots) assert found == set(["cg-doc2"]) + verify_pipeline(query) @pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) @@ -3013,6 +3025,7 @@ def test_count_query_with_start_at(query, database): for result in count_query.stream(): for aggregation_result in result: assert aggregation_result.value == expected_count + verify_pipeline(count_query) @pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) diff --git a/tests/system/test_system_async.py b/tests/system/test_system_async.py index 1442e7932..c74463962 100644 --- a/tests/system/test_system_async.py +++ b/tests/system/test_system_async.py @@ -144,6 +144,12 @@ def _verify_explain_metrics_analyze_false(explain_metrics): @pytest.fixture(scope="session") def database(request): + from test__helpers import FIRESTORE_ENTERPRISE_DB + + # enterprise mode currently does not support RunQuery calls in prod on kokoro test project + # TODO: remove skip when kokoro test project supports full enterprise mode + if request.param == FIRESTORE_ENTERPRISE_DB and IS_KOKORO_TEST: + pytest.skip("enterprise mode does not support RunQuery on kokoro") return request.param diff --git a/tests/unit/gapic/firestore_v1/test_firestore.py b/tests/unit/gapic/firestore_v1/test_firestore.py index e3821e772..af45e4326 100644 --- a/tests/unit/gapic/firestore_v1/test_firestore.py +++ b/tests/unit/gapic/firestore_v1/test_firestore.py @@ -6385,7 +6385,7 @@ def test_get_document_rest_required_fields(request_type=firestore.GetDocumentReq response = client.get_document(request) - expected_params = [("$alt", "json;enum-encoding=int")] + expected_params = [] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params @@ -6526,7 +6526,7 @@ def test_list_documents_rest_required_fields( response = client.list_documents(request) - expected_params = [("$alt", "json;enum-encoding=int")] + expected_params = [] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params @@ -6727,7 +6727,7 @@ def test_update_document_rest_required_fields( response = client.update_document(request) - expected_params = [("$alt", "json;enum-encoding=int")] + expected_params = [] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params @@ -6919,7 +6919,7 @@ def test_delete_document_rest_required_fields( response = client.delete_document(request) - expected_params = [("$alt", "json;enum-encoding=int")] + expected_params = [] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params @@ -7105,7 +7105,7 @@ def test_batch_get_documents_rest_required_fields( iter_content.return_value = iter(json_return_value) response = client.batch_get_documents(request) - expected_params = [("$alt", "json;enum-encoding=int")] + expected_params = [] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params @@ -7228,7 +7228,7 @@ def test_begin_transaction_rest_required_fields( response = client.begin_transaction(request) - expected_params = [("$alt", "json;enum-encoding=int")] + expected_params = [] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params @@ -7405,7 +7405,7 @@ def test_commit_rest_required_fields(request_type=firestore.CommitRequest): response = client.commit(request) - expected_params = [("$alt", "json;enum-encoding=int")] + expected_params = [] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params @@ -7585,7 +7585,7 @@ def test_rollback_rest_required_fields(request_type=firestore.RollbackRequest): response = client.rollback(request) - expected_params = [("$alt", "json;enum-encoding=int")] + expected_params = [] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params @@ -7773,7 +7773,7 @@ def test_run_query_rest_required_fields(request_type=firestore.RunQueryRequest): iter_content.return_value = iter(json_return_value) response = client.run_query(request) - expected_params = [("$alt", "json;enum-encoding=int")] + expected_params = [] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params @@ -8028,7 +8028,7 @@ def test_run_aggregation_query_rest_required_fields( iter_content.return_value = iter(json_return_value) response = client.run_aggregation_query(request) - expected_params = [("$alt", "json;enum-encoding=int")] + expected_params = [] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params @@ -8149,7 +8149,7 @@ def test_partition_query_rest_required_fields( response = client.partition_query(request) - expected_params = [("$alt", "json;enum-encoding=int")] + expected_params = [] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params @@ -8357,7 +8357,7 @@ def test_list_collection_ids_rest_required_fields( response = client.list_collection_ids(request) - expected_params = [("$alt", "json;enum-encoding=int")] + expected_params = [] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params @@ -8597,7 +8597,7 @@ def test_batch_write_rest_required_fields(request_type=firestore.BatchWriteReque response = client.batch_write(request) - expected_params = [("$alt", "json;enum-encoding=int")] + expected_params = [] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params @@ -8729,7 +8729,7 @@ def test_create_document_rest_required_fields( response = client.create_document(request) - expected_params = [("$alt", "json;enum-encoding=int")] + expected_params = [] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params diff --git a/tests/unit/v1/test_base_query.py b/tests/unit/v1/test_base_query.py index 4a4dac727..98461922f 100644 --- a/tests/unit/v1/test_base_query.py +++ b/tests/unit/v1/test_base_query.py @@ -2298,3 +2298,99 @@ def _make_snapshot(docref, values): from google.cloud.firestore_v1 import document return document.DocumentSnapshot(docref, values, True, None, None, None) + + +def test__build_pipeline_limit_to_last_ordering(): + from google.cloud.firestore_v1 import pipeline_expressions as expr + + # Verify that for limit_to_last=True: + # 1. Sort (reversed) + # 2. Where (cursor condition) + + client = make_client() + # Query: Order by 'a' ASC, StartAt(10), LimitToLast(5) + query = ( + client.collection("my_col").order_by("a").start_at({"a": 10}).limit_to_last(5) + ) + + pipeline = query._build_pipeline(client.pipeline()) + + # Expected stages: + # 0. Collection + # 1. Exists (for 'a') + # 2. Sort (DESCENDING) -> This must come BEFORE the cursor filter + # 3. Where (a > 10 condition or similar) + # 4. Limit (5) + # 5. Sort (ASCENDING) + + assert len(pipeline.stages) >= 4 + + # Find indices + sort_reversed_idx = -1 + cursor_where_idx = -1 + + for i, stage in enumerate(pipeline.stages): + if isinstance(stage, stages.Sort): + # Check if it is the reversed sort (DESCENDING) + if ( + len(stage.orders) > 0 + and stage.orders[0].order_dir == expr.Ordering.Direction.DESCENDING + ): + if sort_reversed_idx == -1: + sort_reversed_idx = i + + if isinstance(stage, stages.Where): + # Check if this is the cursor condition. + # Cursor condition for start_at({"a": 10}) should be related to 'a' and 10. + # usually an OR or Comparison. + # The Exists filter is also a Where, but it's usually `exists(a)`. + + # Simple check: The condition is not just an 'exists' function call. + cond = stage.condition + if not (hasattr(cond, "name") and cond.name == "exists"): + # Assume this is the cursor filter + cursor_where_idx = i + + assert sort_reversed_idx != -1, "Reversed sort stage not found" + assert cursor_where_idx != -1, "Cursor filter stage not found" + + # Reversed Sort must happen BEFORE Cursor Filter + assert sort_reversed_idx < cursor_where_idx + + +def test__build_pipeline_normal_ordering(): + from google.cloud.firestore_v1 import pipeline_expressions as expr + + # Verify that for limit_to_last=False (Normal): + # 1. Where (cursor condition) + # 2. Sort + + client = make_client() + # Query: Order by 'a' ASC, StartAt(10) + query = client.collection("my_col").order_by("a").start_at({"a": 10}) + + pipeline = query._build_pipeline(client.pipeline()) + + # Expected stages: + # 0. Collection + # 1. Exists (for 'a') + # 2. Where (cursor condition) + # 3. Sort (ASCENDING) + + sort_idx = -1 + cursor_where_idx = -1 + + for i, stage in enumerate(pipeline.stages): + if isinstance(stage, stages.Sort): + sort_idx = i + + if isinstance(stage, stages.Where): + cond = stage.condition + if not (hasattr(cond, "name") and cond.name == "exists"): + cursor_where_idx = i + + assert sort_idx != -1, "Sort stage not found" + assert cursor_where_idx != -1, "Cursor filter stage not found" + + # Cursor Filter must happen BEFORE Sort + assert cursor_where_idx < sort_idx