diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 64ad10050d..529fffad15 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -877,12 +877,35 @@ def upsert( # get list of rows that exist so we don't have to load the entire target table matched_predicate = upsert_util.create_match_filter(df, join_cols) + # Augment the row filter with [min, max] predicates on any + # partition source column present in ``df``. ``inclusive_projection`` + # projects the range through monotonic partition transforms when + # planning the scan, so ``DataScan.plan_files`` can prune the + # destination at the manifest + file level. + matched_predicate = upsert_util.augment_filter_with_partition_ranges( + matched_predicate, + df, + self.table_metadata.schema(), + self.table_metadata.spec(), + ) + + # When ``when_matched_update_all=False`` the consumer loop below + # only ever reads ``join_cols`` off each destination batch (to + # build the per-batch match filter via + # ``upsert_util.create_match_filter``). Project ``join_cols`` + # only so the parquet reader can prune wide non-key columns. + # + # ``when_matched_update_all=True`` falls back to the legacy + # ``("*",)`` projection. + selected_fields: tuple[str, ...] = ("*",) if when_matched_update_all else tuple(join_cols) + # We must use Transaction.table_metadata for the scan. This includes all uncommitted - but relevant - changes. matched_iceberg_record_batches_scan = DataScan( table_metadata=self.table_metadata, io=self._table.io, row_filter=matched_predicate, + selected_fields=selected_fields, case_sensitive=case_sensitive, ) @@ -2072,13 +2095,11 @@ def _build_residual_evaluator(self, spec_id: int) -> Callable[[DataFile], Residu # The lambda created here is run in multiple threads. # So we avoid creating _EvaluatorExpression methods bound to a single # shared instance across multiple threads. 
- return lambda datafile: ( - residual_evaluator_of( - spec=spec, - expr=self.row_filter, - case_sensitive=self.case_sensitive, - schema=self.table_metadata.schema(), - ) + return lambda datafile: residual_evaluator_of( + spec=spec, + expr=self.row_filter, + case_sensitive=self.case_sensitive, + schema=self.table_metadata.schema(), ) @staticmethod diff --git a/pyiceberg/table/upsert_util.py b/pyiceberg/table/upsert_util.py index 6f32826eb0..717f8a3a92 100644 --- a/pyiceberg/table/upsert_util.py +++ b/pyiceberg/table/upsert_util.py @@ -23,11 +23,16 @@ from pyiceberg.expressions import ( AlwaysFalse, + And, BooleanExpression, EqualTo, + GreaterThanOrEqual, In, + LessThanOrEqual, Or, ) +from pyiceberg.partitioning import PartitionSpec +from pyiceberg.schema import Schema def create_match_filter(df: pyarrow_table, join_cols: list[str]) -> BooleanExpression: @@ -53,6 +58,104 @@ def has_duplicate_rows(df: pyarrow_table, join_cols: list[str]) -> bool: return len(df.select(join_cols).group_by(join_cols).aggregate([([], "count_all")]).filter(pc.field("count_all") > 1)) > 0 +def augment_filter_with_partition_ranges( + matched_predicate: BooleanExpression, + df: pyarrow_table, + schema: Schema, + spec: PartitionSpec, +) -> BooleanExpression: + """Return *matched_predicate* AND'd with ``[min, max]`` predicates on partition source columns. + + Iceberg's ``inclusive_projection`` projects each range through the + partition transform (``hours``, ``days``, ``months``, ``years``, + ``identity``, ``truncate``) when planning the scan, so + ``DataScan.plan_files`` can prune manifests and data files that + don't overlap the source's value range. 
Without this augmentation, + tables whose partition spec sources from columns NOT in + ``join_cols`` (a common pattern for append-only event logs + partitioned by time but keyed by composite IDs) fall through to a + full table scan on every upsert because the row filter built from + ``join_cols`` alone projects to ``AlwaysTrue`` against the + partition spec. + + Bucket and other non-monotonic transforms return ``None`` from + their ``project`` method for inequalities, so the augmentation is + safe — it either prunes or contributes ``AlwaysTrue`` (no harm). + + A partition source column is skipped from augmentation when: + + - It isn't present on ``df`` (no source value to bound). + - It is entirely null in ``df`` (no meaningful min/max). + - It contains any null in ``df`` (preserving correctness: a + ``GreaterThanOrEqual(col, non_null_min)`` predicate would + exclude destination rows whose partition value is ``NULL``, + potentially missing a key match. Without partition pruning + those NULL-partition rows are scanned normally.) + + When ``min == max`` for a column, an ``EqualTo`` predicate is + emitted instead of the range pair — tighter, and lets exact + partition pruning fire. + + Args: + matched_predicate: The row filter built from ``join_cols``. + df: Source data frame whose values bound the augmentation. + schema: Iceberg schema, used to resolve partition source ids + to column names. + spec: Active partition spec. + + Returns: + The augmented predicate, or *matched_predicate* unchanged + when no partition source column qualifies. + """ + if spec.is_unpartitioned(): + return matched_predicate + + df_columns = set(df.column_names) + augmentations: list[BooleanExpression] = [] + + # Iterate distinct source columns rather than partition fields — + # multiple partition fields can share a source column (e.g. 
+ # ``bucket(8, id), truncate(4, id)``) but we only need to add the + # source-column range once; ``inclusive_projection`` projects + # through each partition field independently. + seen_source_ids: set[int] = set() + for field in spec.fields: + if field.source_id in seen_source_ids: + continue + seen_source_ids.add(field.source_id) + + col_name = schema.find_field(field.source_id).name + if col_name not in df_columns: + continue + + col = df[col_name] + if col.null_count > 0: + # Mixing null with a bounded predicate would exclude + # destination rows whose partition value is null, + # potentially missing key matches. Skip pruning rather + # than risk a correctness regression. + continue + + col_min = pc.min(col).as_py() + col_max = pc.max(col).as_py() + if col_min is None or col_max is None: + # Defensive — ``null_count == 0`` should imply both bounds + # are non-null, but pyarrow's min/max can still return None + # on empty columns. + continue + + if col_min == col_max: + augmentations.append(EqualTo(col_name, col_min)) + else: + augmentations.append(GreaterThanOrEqual(col_name, col_min)) + augmentations.append(LessThanOrEqual(col_name, col_max)) + + if not augmentations: + return matched_predicate + + return functools.reduce(And, [matched_predicate, *augmentations]) + + def get_rows_to_update(source_table: pa.Table, target_table: pa.Table, join_cols: list[str]) -> pa.Table: """ Return a table with rows that need to be updated in the target table based on the join columns. diff --git a/tests/table/test_upsert.py b/tests/table/test_upsert.py index 08f90c6600..0b8ca0155f 100644 --- a/tests/table/test_upsert.py +++ b/tests/table/test_upsert.py @@ -14,7 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+import datetime from pathlib import PosixPath +from typing import Any import pyarrow as pa import pytest @@ -23,14 +25,16 @@ from pyiceberg.catalog import Catalog from pyiceberg.exceptions import NoSuchTableError -from pyiceberg.expressions import AlwaysTrue, And, EqualTo, Reference +from pyiceberg.expressions import AlwaysTrue, And, EqualTo, GreaterThanOrEqual, LessThanOrEqual, Reference from pyiceberg.expressions.literals import LongLiteral from pyiceberg.io.pyarrow import schema_to_pyarrow +from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC, PartitionField, PartitionSpec from pyiceberg.schema import Schema from pyiceberg.table import Table, UpsertResult from pyiceberg.table.snapshots import Operation -from pyiceberg.table.upsert_util import create_match_filter -from pyiceberg.types import IntegerType, NestedField, StringType, StructType +from pyiceberg.table.upsert_util import augment_filter_with_partition_ranges, create_match_filter +from pyiceberg.transforms import IdentityTransform +from pyiceberg.types import DateType, IntegerType, NestedField, StringType, StructType from tests.catalog.test_base import InMemoryCatalog @@ -888,3 +892,660 @@ def test_upsert_snapshot_properties(catalog: Catalog) -> None: for snapshot in snapshots[initial_snapshot_count:]: assert snapshot.summary is not None assert snapshot.summary.additional_properties.get("test_prop") == "test_value" + + +# --------------------------------------------------------------------------- +# Partition-range augmentation for upsert row filters. +# +# ``Transaction.upsert`` builds its scan ``row_filter`` from ``join_cols`` +# alone via ``create_match_filter``. 
When the partition spec sources from +# columns NOT in ``join_cols`` (a common pattern for append-only event logs +# partitioned by time but keyed by composite IDs), ``inclusive_projection`` +# collapses the entire predicate to ``AlwaysTrue`` against the partition +# spec and ``DataScan.plan_files`` falls through to a full table scan. +# +# ``augment_filter_with_partition_ranges`` derives ``[min, max]`` predicates +# from ``df`` for every partition source column present in the frame and +# ANDs them into the row filter. Iceberg's inclusive projection then +# projects each range through the partition transform when planning the +# scan, enabling manifest- and file-level pruning. +# +# See related issues #2138, #2159, #3129. +# --------------------------------------------------------------------------- + + +class TestAugmentFilterWithPartitionRanges: + """Pure-function tests for ``augment_filter_with_partition_ranges``. + + Asserts the structural shape of the augmented predicate. End-to-end + file-pruning behaviour is exercised by the upsert integration tests + below. + """ + + @staticmethod + def _orders_schema() -> Schema: + return Schema( + NestedField(1, "order_id", IntegerType(), required=True), + NestedField(2, "order_date", DateType(), required=True), + NestedField(3, "order_type", StringType(), required=True), + ) + + @staticmethod + def _orders_pa_schema() -> pa.Schema: + return pa.schema( + [ + pa.field("order_id", pa.int32(), nullable=False), + pa.field("order_date", pa.date32(), nullable=False), + pa.field("order_type", pa.string(), nullable=False), + ] + ) + + def _df(self, rows: list[dict[str, object]]) -> pa.Table: + return pa.Table.from_pylist(rows, schema=self._orders_pa_schema()) + + def test_unpartitioned_spec_returns_input_unchanged(self) -> None: + """Tables without a partition spec have nothing to project through. 
+ The augmentation must short-circuit and hand back the exact + ``matched_predicate`` object — no allocation, no semantic change.""" + df = self._df([{"order_id": 1, "order_date": datetime.date(2026, 1, 1), "order_type": "A"}]) + matched = create_match_filter(df, ["order_id"]) + augmented = augment_filter_with_partition_ranges( + matched_predicate=matched, + df=df, + schema=self._orders_schema(), + spec=UNPARTITIONED_PARTITION_SPEC, + ) + assert augmented is matched + + def test_partition_source_column_not_in_df_skipped(self) -> None: + """Source frames that don't contain the partition source column + can't contribute a bound — the augmentation has to skip rather + than guess. Returns ``matched_predicate`` unchanged so the + existing scan behaviour applies.""" + df_no_date = pa.Table.from_pylist( + [{"order_id": 1, "order_type": "A"}], + schema=pa.schema( + [ + pa.field("order_id", pa.int32(), nullable=False), + pa.field("order_type", pa.string(), nullable=False), + ] + ), + ) + matched = create_match_filter(df_no_date, ["order_id"]) + spec = PartitionSpec(PartitionField(source_id=2, field_id=1000, transform=IdentityTransform(), name="order_date")) + augmented = augment_filter_with_partition_ranges( + matched_predicate=matched, + df=df_no_date, + schema=self._orders_schema(), + spec=spec, + ) + assert augmented == matched + + def test_partition_source_column_all_nulls_skipped(self) -> None: + """When every value of the partition source column in ``df`` is + null, there is no meaningful ``min`` / ``max`` to bound the + predicate. 
Skip rather than emit a vacuous augmentation.""" + df = pa.Table.from_pylist( + [{"order_id": 1, "order_date": None, "order_type": "A"}], + schema=pa.schema( + [ + pa.field("order_id", pa.int32(), nullable=False), + pa.field("order_date", pa.date32(), nullable=True), + pa.field("order_type", pa.string(), nullable=False), + ] + ), + ) + matched = create_match_filter(df, ["order_id"]) + spec = PartitionSpec(PartitionField(source_id=2, field_id=1000, transform=IdentityTransform(), name="order_date")) + augmented = augment_filter_with_partition_ranges( + matched_predicate=matched, + df=df, + schema=self._orders_schema(), + spec=spec, + ) + assert augmented == matched + + def test_partition_source_column_some_nulls_skipped(self) -> None: + """Correctness guard: a partial-null source column cannot use a + non-null ``GreaterThanOrEqual`` augmentation because destination + rows whose partition value is NULL would be excluded from the + match scan even though their ``(key)`` may collide with the + null-partition source rows. Skip pruning over emitting an unsafe + predicate.""" + df = pa.Table.from_pylist( + [ + {"order_id": 1, "order_date": datetime.date(2026, 1, 1), "order_type": "A"}, + {"order_id": 2, "order_date": None, "order_type": "B"}, + ], + schema=pa.schema( + [ + pa.field("order_id", pa.int32(), nullable=False), + pa.field("order_date", pa.date32(), nullable=True), + pa.field("order_type", pa.string(), nullable=False), + ] + ), + ) + matched = create_match_filter(df, ["order_id"]) + spec = PartitionSpec(PartitionField(source_id=2, field_id=1000, transform=IdentityTransform(), name="order_date")) + augmented = augment_filter_with_partition_ranges( + matched_predicate=matched, + df=df, + schema=self._orders_schema(), + spec=spec, + ) + assert augmented == matched + + def test_single_value_partition_column_emits_equal_to(self) -> None: + """``min == max`` collapses to a single ``EqualTo`` — tighter than + the range pair and lets exact partition pruning fire (e.g. 
when + every source row falls in the same hourly bucket).""" + df = self._df( + [ + {"order_id": 1, "order_date": datetime.date(2026, 1, 1), "order_type": "A"}, + {"order_id": 2, "order_date": datetime.date(2026, 1, 1), "order_type": "B"}, + ] + ) + matched = create_match_filter(df, ["order_id"]) + spec = PartitionSpec(PartitionField(source_id=2, field_id=1000, transform=IdentityTransform(), name="order_date")) + augmented = augment_filter_with_partition_ranges( + matched_predicate=matched, + df=df, + schema=self._orders_schema(), + spec=spec, + ) + assert augmented == And(matched, EqualTo("order_date", datetime.date(2026, 1, 1))) + + def test_range_emits_gteq_and_lteq(self) -> None: + """Multiple distinct values → ``GreaterThanOrEqual(min) AND + LessThanOrEqual(max)`` pair, AND'd onto the original matched + predicate. Inclusive_projection handles the partition-transform + projection at scan time.""" + df = self._df( + [ + {"order_id": 1, "order_date": datetime.date(2026, 1, 1), "order_type": "A"}, + {"order_id": 2, "order_date": datetime.date(2026, 1, 15), "order_type": "B"}, + {"order_id": 3, "order_date": datetime.date(2026, 2, 1), "order_type": "C"}, + ] + ) + matched = create_match_filter(df, ["order_id"]) + spec = PartitionSpec(PartitionField(source_id=2, field_id=1000, transform=IdentityTransform(), name="order_date")) + augmented = augment_filter_with_partition_ranges( + matched_predicate=matched, + df=df, + schema=self._orders_schema(), + spec=spec, + ) + assert augmented == And( + And(matched, GreaterThanOrEqual("order_date", datetime.date(2026, 1, 1))), + LessThanOrEqual("order_date", datetime.date(2026, 2, 1)), + ) + + def test_multiple_partition_fields_share_source_id_emitted_once(self) -> None: + """When two partition fields source from the same column (e.g. + ``bucket(8, id), truncate(4, id)``), only one source-column range + is emitted. 
``inclusive_projection`` projects through each + partition field independently at scan time, so a single source- + range predicate suffices for both.""" + df = self._df( + [ + {"order_id": 1, "order_date": datetime.date(2026, 1, 1), "order_type": "A"}, + {"order_id": 10, "order_date": datetime.date(2026, 1, 2), "order_type": "B"}, + ] + ) + matched = create_match_filter(df, ["order_type"]) + + from pyiceberg.transforms import BucketTransform, TruncateTransform + + spec = PartitionSpec( + PartitionField(source_id=1, field_id=1000, transform=BucketTransform(8), name="order_id_bucket"), + PartitionField(source_id=1, field_id=1001, transform=TruncateTransform(4), name="order_id_trunc"), + ) + augmented = augment_filter_with_partition_ranges( + matched_predicate=matched, + df=df, + schema=self._orders_schema(), + spec=spec, + ) + # Exactly one ``GreaterThanOrEqual`` + ``LessThanOrEqual`` pair on + # ``order_id`` — not duplicated for each partition field. + assert augmented == And( + And(matched, GreaterThanOrEqual("order_id", 1)), + LessThanOrEqual("order_id", 10), + ) + + +class TestUpsertPartitionPruningIntegration: + """End-to-end upsert against partitioned tables. + + Verifies that the augmented row filter doesn't change upsert + semantics — ``rows_updated`` / ``rows_inserted`` match the original + behaviour — across the three structural cases: + + 1. Partition source not in ``join_cols`` (the case the augmentation + fires for; biggest perf gain). + 2. Partition source IS in ``join_cols`` (augmentation contributes + redundantly but doesn't change correctness). + 3. Unpartitioned (augmentation is a no-op). + """ + + def test_upsert_correct_when_partition_col_not_in_join_cols(self, catalog: Catalog) -> None: + """Source partitioned by ``order_date`` but keyed on ``order_id``. 
+ Augmentation fires — semantics must be identical to the + unpartitioned baseline.""" + identifier = "default.test_upsert_partition_not_in_join_cols" + _drop_table(catalog, identifier) + + schema = Schema( + NestedField(1, "order_id", IntegerType(), required=True), + NestedField(2, "order_date", DateType(), required=True), + NestedField(3, "order_type", StringType(), required=True), + ) + spec = PartitionSpec(PartitionField(source_id=2, field_id=1000, transform=IdentityTransform(), name="order_date")) + table = catalog.create_table(identifier, schema=schema, partition_spec=spec) + + arrow_schema = pa.schema( + [ + pa.field("order_id", pa.int32(), nullable=False), + pa.field("order_date", pa.date32(), nullable=False), + pa.field("order_type", pa.string(), nullable=False), + ] + ) + # Initial load: ids 1-5 across two different partitions. + initial = pa.Table.from_pylist( + [ + {"order_id": 1, "order_date": datetime.date(2026, 1, 1), "order_type": "A"}, + {"order_id": 2, "order_date": datetime.date(2026, 1, 1), "order_type": "A"}, + {"order_id": 3, "order_date": datetime.date(2026, 1, 2), "order_type": "A"}, + {"order_id": 4, "order_date": datetime.date(2026, 1, 2), "order_type": "A"}, + {"order_id": 5, "order_date": datetime.date(2026, 1, 2), "order_type": "A"}, + ], + schema=arrow_schema, + ) + table.append(initial) + + # Upsert: ids 3-7 (3 update existing, 2 are new), all in the + # same partition the augmentation will prune to. 
+ upsert_df = pa.Table.from_pylist( + [ + {"order_id": 3, "order_date": datetime.date(2026, 1, 2), "order_type": "B"}, + {"order_id": 4, "order_date": datetime.date(2026, 1, 2), "order_type": "B"}, + {"order_id": 5, "order_date": datetime.date(2026, 1, 2), "order_type": "B"}, + {"order_id": 6, "order_date": datetime.date(2026, 1, 2), "order_type": "B"}, + {"order_id": 7, "order_date": datetime.date(2026, 1, 2), "order_type": "B"}, + ], + schema=arrow_schema, + ) + + res = table.upsert(df=upsert_df, join_cols=["order_id"]) + assert res.rows_updated == 3 + assert res.rows_inserted == 2 + + # Sanity: the table now has 7 rows total (5 initial - 3 updated + 3 updated + 2 inserted == 7). + final = table.scan().to_arrow() + assert final.num_rows == 7 + assert set(final["order_id"].to_pylist()) == {1, 2, 3, 4, 5, 6, 7} + + def test_upsert_correct_when_partition_col_in_join_cols(self, catalog: Catalog) -> None: + """Partition column IS one of the ``join_cols``. The augmentation + adds a redundant ``order_date`` range to a predicate that already + constrains ``order_date`` via ``create_match_filter`` — no + semantic change should result.""" + identifier = "default.test_upsert_partition_in_join_cols" + _drop_table(catalog, identifier) + + schema = Schema( + NestedField(1, "order_id", IntegerType(), required=True), + NestedField(2, "order_date", DateType(), required=True), + NestedField(3, "order_type", StringType(), required=True), + ) + spec = PartitionSpec(PartitionField(source_id=2, field_id=1000, transform=IdentityTransform(), name="order_date")) + table = catalog.create_table(identifier, schema=schema, partition_spec=spec) + + arrow_schema = pa.schema( + [ + pa.field("order_id", pa.int32(), nullable=False), + pa.field("order_date", pa.date32(), nullable=False), + pa.field("order_type", pa.string(), nullable=False), + ] + ) + initial = pa.Table.from_pylist( + [ + {"order_id": 1, "order_date": datetime.date(2026, 1, 1), "order_type": "A"}, + {"order_id": 2, "order_date": 
datetime.date(2026, 1, 2), "order_type": "A"}, + ], + schema=arrow_schema, + ) + table.append(initial) + + upsert_df = pa.Table.from_pylist( + [ + {"order_id": 1, "order_date": datetime.date(2026, 1, 1), "order_type": "B"}, # update + {"order_id": 3, "order_date": datetime.date(2026, 1, 3), "order_type": "B"}, # insert + ], + schema=arrow_schema, + ) + + res = table.upsert(df=upsert_df, join_cols=["order_id", "order_date"]) + assert res.rows_updated == 1 + assert res.rows_inserted == 1 + + def test_augmented_predicate_prunes_destination_files(self, catalog: Catalog) -> None: + """Smoke test: ``DataScan.plan_files()`` returns strictly fewer + files when the augmented predicate is used. This is the actual + perf claim the optimization makes; the integration tests above + verify semantics, this one verifies pruning happens. + + Realistic worst-case shape: the key column (``order_id``) has + per-file lower/upper bounds that span the entire range of source + keys (modelled by writing the same set of ``order_id`` values + across every partition). This defeats ``_InclusiveMetricsEvaluator`` + which would otherwise prune via file-level parquet column stats — + and is exactly the situation real workloads hit when keys are + UUIDs or otherwise uniformly distributed across files + independent of partition. With key-level metrics unable to + prune, partition-spec projection on ``order_date`` is the only + lever left, and the test pins the assertion that our + augmentation activates it. 
+ """ + from pyiceberg.table import DataScan + + identifier = "default.test_augmented_predicate_prunes_files" + _drop_table(catalog, identifier) + + schema = Schema( + NestedField(1, "order_id", IntegerType(), required=True), + NestedField(2, "order_date", DateType(), required=True), + NestedField(3, "order_type", StringType(), required=True), + ) + spec = PartitionSpec(PartitionField(source_id=2, field_id=1000, transform=IdentityTransform(), name="order_date")) + table = catalog.create_table(identifier, schema=schema, partition_spec=spec) + + arrow_schema = pa.schema( + [ + pa.field("order_id", pa.int32(), nullable=False), + pa.field("order_date", pa.date32(), nullable=False), + pa.field("order_type", pa.string(), nullable=False), + ] + ) + + # Five separate appends, one per partition. Each appended file + # contains the SAME set of ``order_id`` values {1..5} so per-file + # parquet stats on ``order_id`` are identical across files. This + # mirrors UUID-keyed workloads where per-file key bounds span + # the full key space and metrics-level pruning is useless; + # partition pruning is the only effective lever. + partitions = [ + datetime.date(2026, 1, 1), + datetime.date(2026, 1, 2), + datetime.date(2026, 1, 3), + datetime.date(2026, 1, 4), + datetime.date(2026, 1, 5), + ] + order_ids_per_partition = [1, 2, 3, 4, 5] + for d in partitions: + table.append( + pa.Table.from_pylist( + [{"order_id": oid, "order_date": d, "order_type": "A"} for oid in order_ids_per_partition], + schema=arrow_schema, + ) + ) + + # Setup invariant: one file per partition (5 total). 
+ all_files_scan = DataScan(table_metadata=table.metadata, io=table.io, row_filter=AlwaysTrue()) + total_files = len(list(all_files_scan.plan_files())) + assert total_files == 5, f"setup invariant: expected 5 destination files, got {total_files}" + + # Source covers only 2 of the 5 partitions: one row with order_id=2 + # (which appears in every destination file → file-level metrics + # pruning on order_id can't help) and one new row with order_id=99. + src = pa.Table.from_pylist( + [ + {"order_id": 2, "order_date": datetime.date(2026, 1, 2), "order_type": "B"}, + {"order_id": 99, "order_date": datetime.date(2026, 1, 3), "order_type": "B"}, + ], + schema=arrow_schema, + ) + + # (a) Original behaviour: row_filter built from join_cols alone. + # join_cols = ['order_id'] doesn't reference 'order_date', so + # inclusive_projection collapses the partition projection to + # AlwaysTrue. order_id=2 falls within [1, 5] in every file → + # metrics evaluator can't prune either. All 5 files listed. + original_predicate = create_match_filter(src, ["order_id"]) + original_files = len( + list(DataScan(table_metadata=table.metadata, io=table.io, row_filter=original_predicate).plan_files()) + ) + + # (b) Augmented predicate adds [min, max] on the partition source. + # inclusive_projection projects through IdentityTransform; only + # the 2 daily partitions overlapping [2026-01-02, 2026-01-03] + # are kept. + augmented_predicate = augment_filter_with_partition_ranges( + matched_predicate=original_predicate, + df=src, + schema=table.metadata.schema(), + spec=table.metadata.spec(), + ) + augmented_files = len( + list(DataScan(table_metadata=table.metadata, io=table.io, row_filter=augmented_predicate).plan_files()) + ) + + # The whole point of the optimization. 
+ assert original_files == 5, ( + f"original behaviour invariant: with per-file order_id bounds spanning the full source key set, " + f"neither partition projection nor metrics evaluation can prune; expected 5 files, got {original_files}" + ) + assert augmented_files == 2, ( + f"augmented predicate must prune to overlapping partitions only; " + f"expected 2 files (2026-01-02, 2026-01-03), got {augmented_files}" + ) + assert augmented_files < original_files + + def test_upsert_unpartitioned_unchanged(self, catalog: Catalog) -> None: + """Sanity that the augmentation doesn't alter the unpartitioned + path. Mirrors ``test_merge_rows`` but smaller — purely a + regression guard against the augmentation accidentally tripping + unpartitioned tables.""" + identifier = "default.test_upsert_unpartitioned_augmentation_noop" + _drop_table(catalog, identifier) + + schema = Schema( + NestedField(1, "order_id", IntegerType(), required=True), + NestedField(2, "order_type", StringType(), required=True), + ) + table = catalog.create_table(identifier, schema=schema) + + arrow_schema = pa.schema( + [ + pa.field("order_id", pa.int32(), nullable=False), + pa.field("order_type", pa.string(), nullable=False), + ] + ) + table.append(pa.Table.from_pylist([{"order_id": 1, "order_type": "A"}], schema=arrow_schema)) + + upsert_df = pa.Table.from_pylist( + [ + {"order_id": 1, "order_type": "B"}, # update + {"order_id": 2, "order_type": "B"}, # insert + ], + schema=arrow_schema, + ) + res = table.upsert(df=upsert_df, join_cols=["order_id"]) + assert res.rows_updated == 1 + assert res.rows_inserted == 1 + + +class TestUpsertScanProjection: + """``Transaction.upsert`` narrows the destination scan's + ``selected_fields`` to ``join_cols`` when ``when_matched_update_all=False``. + + Rationale: the insert-on-no-match branch only reads ``join_cols`` + off each destination batch (to feed ``create_match_filter``); every + other column is unused. 
Projection at the scan boundary lets the + parquet reader prune wide non-key columns at the file level — + significant for tables whose payload column (e.g. a JSON ``log``) + dominates file bytes. ``_projected_field_ids`` auto-unions the + row-filter's column ids back in, so the augmented ``order_date`` + range and the original join-key predicates still see the columns + they need for filter evaluation. + + Falls back to ``("*",)`` when ``when_matched_update_all=True`` + because ``get_rows_to_update`` compares non-key columns to detect + actual value changes. + """ + + @staticmethod + def _build_partitioned_table(catalog: Catalog, identifier: str) -> Table: + _drop_table(catalog, identifier) + schema = Schema( + NestedField(1, "order_id", IntegerType(), required=True), + NestedField(2, "order_date", DateType(), required=True), + NestedField(3, "order_type", StringType(), required=True), + ) + spec = PartitionSpec(PartitionField(source_id=2, field_id=1000, transform=IdentityTransform(), name="order_date")) + return catalog.create_table(identifier, schema=schema, partition_spec=spec) + + @staticmethod + def _arrow_schema() -> pa.Schema: + return pa.schema( + [ + pa.field("order_id", pa.int32(), nullable=False), + pa.field("order_date", pa.date32(), nullable=False), + pa.field("order_type", pa.string(), nullable=False), + ] + ) + + def _seed(self, table: Table) -> None: + table.append( + pa.Table.from_pylist( + [ + {"order_id": 1, "order_date": datetime.date(2026, 1, 1), "order_type": "A"}, + {"order_id": 2, "order_date": datetime.date(2026, 1, 2), "order_type": "A"}, + ], + schema=self._arrow_schema(), + ) + ) + + @pytest.fixture + def captured_scans(self, monkeypatch: pytest.MonkeyPatch) -> list[dict[str, Any]]: + """Spy on ``DataScan.__init__`` to capture every kwargs dict. 
+ + Lets the tests pin which ``selected_fields`` the upsert path + actually passes — assertions on the surfaced batch schema alone + would miss the case where the underlying projection contract + regresses but the test data happens to have only join_cols + anyway. + + The spy preserves ``__init__``'s signature via + :func:`functools.wraps` so ``DataScan.update()``'s reflective + ``inspect.signature(type(self).__init__).parameters`` lookup + (used by ``use_ref``) still resolves to the real parameter + names, not the spy's ``**kwargs``. + """ + import functools + + from pyiceberg.table import DataScan + + captured: list[dict[str, Any]] = [] + original_init = DataScan.__init__ + + @functools.wraps(original_init) + def _spy(self: DataScan, *args: Any, **kwargs: Any) -> None: + captured.append(dict(kwargs)) + original_init(self, *args, **kwargs) + + monkeypatch.setattr(DataScan, "__init__", _spy) + return captured + + def test_when_matched_false_projects_join_cols_only(self, catalog: Catalog, captured_scans: list[dict[str, Any]]) -> None: + """The insert-on-no-match branch never reads non-key destination + columns, so the scan must narrow the projection to ``join_cols`` + — saving the parquet reader from materialising wide payload + columns just to be discarded.""" + table = self._build_partitioned_table(catalog, "default.test_upsert_projection_insert_only") + self._seed(table) + upsert_df = pa.Table.from_pylist( + [ + {"order_id": 2, "order_date": datetime.date(2026, 1, 2), "order_type": "B"}, + {"order_id": 3, "order_date": datetime.date(2026, 1, 3), "order_type": "B"}, + ], + schema=self._arrow_schema(), + ) + + # Snapshot only the scans constructed during the upsert (the + # seed append above may have created its own). 
+ before = len(captured_scans) + res = table.upsert(df=upsert_df, join_cols=["order_id"], when_matched_update_all=False) + upsert_scans = captured_scans[before:] + assert res.rows_inserted == 1 + assert res.rows_updated == 0 + + # The upsert constructs one DataScan for the destination match. + # ``use_ref`` may construct a second DataScan as an inherited + # copy (via ``self.update``), which carries the same + # ``selected_fields`` through. Pin both: at least one scan was + # constructed during the upsert, and every scan that ran + # carries the narrowed projection. ``_projected_field_ids`` + # auto-unions the row filter's column ids back in, so + # ``order_date`` (added by the partition-range augmentation) + # is still read for filter evaluation without us having to + # list it explicitly. + assert upsert_scans, "upsert path constructed no DataScan — projection contract regression" + selected = [s.get("selected_fields") for s in upsert_scans] + assert all(sf == ("order_id",) for sf in selected), ( + f"expected every DataScan during upsert to use selected_fields=('order_id',); got {selected}" + ) + + def test_when_matched_true_keeps_star_projection(self, catalog: Catalog, captured_scans: list[dict[str, Any]]) -> None: + """The update branch's ``get_rows_to_update`` compares non-key + columns to detect actual value changes — projecting only + ``join_cols`` would feed it data with no non-key columns to + compare and silently turn every match into a write-back. 
Must + keep ``("*",)``.""" + table = self._build_partitioned_table(catalog, "default.test_upsert_projection_update_mode") + self._seed(table) + upsert_df = pa.Table.from_pylist( + [ + {"order_id": 1, "order_date": datetime.date(2026, 1, 1), "order_type": "B"}, + {"order_id": 3, "order_date": datetime.date(2026, 1, 3), "order_type": "B"}, + ], + schema=self._arrow_schema(), + ) + + before = len(captured_scans) + res = table.upsert(df=upsert_df, join_cols=["order_id"], when_matched_update_all=True) + upsert_scans = captured_scans[before:] + assert res.rows_updated == 1 + assert res.rows_inserted == 1 + + assert upsert_scans, "upsert path constructed no DataScan — projection contract regression" + selected = [s.get("selected_fields") for s in upsert_scans] + assert all(sf == ("*",) for sf in selected), ( + f"expected every DataScan during upsert to keep selected_fields=('*',) for the update branch; got {selected}" + ) + + def test_update_mode_actually_updates_non_key_columns(self, catalog: Catalog) -> None: + """End-to-end correctness pin: with ``when_matched_update_all=True`` + the destination scan must read non-key columns so + ``get_rows_to_update`` can detect ``order_type`` changes. A + regression that narrows projection unconditionally would skip + the comparison and silently miss updates whose non-key columns + differ. + """ + identifier = "default.test_upsert_update_mode_correctness" + table = self._build_partitioned_table(catalog, identifier) + self._seed(table) + # Source has the same (order_id, order_date) as one destination + # row but a different ``order_type``. Update path must detect + # the non-key change and overwrite. 
+ upsert_df = pa.Table.from_pylist( + [{"order_id": 2, "order_date": datetime.date(2026, 1, 2), "order_type": "CHANGED"}], + schema=self._arrow_schema(), + ) + res = table.upsert(df=upsert_df, join_cols=["order_id"], when_matched_update_all=True) + assert res.rows_updated == 1 + assert res.rows_inserted == 0 + + # Read back: the original 'A' must have been overwritten with 'CHANGED'. + rows = {r["order_id"]: r for r in table.scan().to_arrow().to_pylist()} + assert rows[2]["order_type"] == "CHANGED"