Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 28 additions & 7 deletions pyiceberg/table/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -877,12 +877,35 @@ def upsert(
# get list of rows that exist so we don't have to load the entire target table
matched_predicate = upsert_util.create_match_filter(df, join_cols)

# Augment the row filter with [min, max] predicates on any
# partition source column present in ``df``. ``inclusive_projection``
# projects the range through monotonic partition transforms when
# planning the scan, so ``DataScan.plan_files`` can prune the
# destination at the manifest + file level.
matched_predicate = upsert_util.augment_filter_with_partition_ranges(
matched_predicate,
df,
self.table_metadata.schema(),
self.table_metadata.spec(),
)

# When ``when_matched_update_all=False`` the consumer loop below
# only ever reads ``join_cols`` off each destination batch (to
# build the per-batch match filter via
# ``upsert_util.create_match_filter``). Project ``join_cols``
# only so the parquet reader can prune wide non-key columns.
#
# ``when_matched_update_all=True`` falls back to the legacy
# ``("*",)`` projection.
selected_fields: tuple[str, ...] = ("*",) if when_matched_update_all else tuple(join_cols)

# We must use Transaction.table_metadata for the scan. This includes all uncommitted - but relevant - changes.

matched_iceberg_record_batches_scan = DataScan(
table_metadata=self.table_metadata,
io=self._table.io,
row_filter=matched_predicate,
selected_fields=selected_fields,
case_sensitive=case_sensitive,
)

Expand Down Expand Up @@ -2072,13 +2095,11 @@ def _build_residual_evaluator(self, spec_id: int) -> Callable[[DataFile], Residu
# The lambda created here is run in multiple threads.
# So we avoid creating _EvaluatorExpression methods bound to a single
# shared instance across multiple threads.
return lambda datafile: (
residual_evaluator_of(
spec=spec,
expr=self.row_filter,
case_sensitive=self.case_sensitive,
schema=self.table_metadata.schema(),
)
return lambda datafile: residual_evaluator_of(
spec=spec,
expr=self.row_filter,
case_sensitive=self.case_sensitive,
schema=self.table_metadata.schema(),
)

@staticmethod
Expand Down
103 changes: 103 additions & 0 deletions pyiceberg/table/upsert_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,16 @@

from pyiceberg.expressions import (
AlwaysFalse,
And,
BooleanExpression,
EqualTo,
GreaterThanOrEqual,
In,
LessThanOrEqual,
Or,
)
from pyiceberg.partitioning import PartitionSpec
from pyiceberg.schema import Schema


def create_match_filter(df: pyarrow_table, join_cols: list[str]) -> BooleanExpression:
Expand All @@ -53,6 +58,104 @@ def has_duplicate_rows(df: pyarrow_table, join_cols: list[str]) -> bool:
return len(df.select(join_cols).group_by(join_cols).aggregate([([], "count_all")]).filter(pc.field("count_all") > 1)) > 0


def augment_filter_with_partition_ranges(
    matched_predicate: BooleanExpression,
    df: pyarrow_table,
    schema: Schema,
    spec: PartitionSpec,
) -> BooleanExpression:
    """Return *matched_predicate* AND'd with ``[min, max]`` predicates on partition source columns.

    Iceberg's ``inclusive_projection`` projects each range through the
    partition transform (``hours``, ``days``, ``months``, ``years``,
    ``identity``, ``truncate``) when planning the scan, so
    ``DataScan.plan_files`` can prune manifests and data files that
    don't overlap the source's value range. Without this augmentation,
    tables whose partition spec sources from columns NOT in
    ``join_cols`` (a common pattern for append-only event logs
    partitioned by time but keyed by composite IDs) fall through to a
    full table scan on every upsert because the row filter built from
    ``join_cols`` alone projects to ``AlwaysTrue`` against the
    partition spec.

    Bucket and other non-monotonic transforms return ``None`` from
    their ``project`` method for inequalities, so for those fields the
    augmentation contributes nothing to partition pruning.

    A partition source column is skipped from augmentation when:

    - It isn't present on ``df`` (no source value to bound).
    - It is entirely null in ``df`` (no meaningful min/max).
    - It contains any null in ``df`` (preserving correctness: a
      ``GreaterThanOrEqual(col, non_null_min)`` predicate would
      exclude destination rows whose partition value is ``NULL``,
      potentially missing a key match. Without partition pruning
      those NULL-partition rows are scanned normally.)
    - It is a floating-point column containing any ``NaN`` in ``df``.
      This mirrors the null rule: the computed bounds are expected to
      ignore ``NaN``, so a range predicate would exclude destination
      rows whose value is ``NaN`` (and an all-``NaN`` column has no
      usable bounds at all).

    When ``min == max`` for a column, an ``EqualTo`` predicate is
    emitted instead of the range pair — tighter, and lets exact
    partition pruning fire.

    NOTE(review): the returned expression is a plain conjunction, so
    the added bounds restrict which *rows* satisfy the filter, not only
    which files are planned. This presumes a destination row that
    matches a source key also has its partition source value inside
    the source's ``[min, max]`` — confirm that holds for callers whose
    ``join_cols`` do not include the partition source columns.

    Args:
        matched_predicate: The row filter built from ``join_cols``.
        df: Source data frame whose values bound the augmentation.
        schema: Iceberg schema, used to resolve partition source ids
            to column names.
        spec: Active partition spec.

    Returns:
        The augmented predicate, or *matched_predicate* unchanged
        when no partition source column qualifies.
    """
    if spec.is_unpartitioned():
        return matched_predicate

    df_columns = set(df.column_names)
    augmentations: list[BooleanExpression] = []

    # Iterate distinct source columns rather than partition fields —
    # multiple partition fields can share a source column (e.g.
    # ``bucket(8, id), truncate(4, id)``) but we only need to add the
    # source-column range once; ``inclusive_projection`` projects
    # through each partition field independently.
    seen_source_ids: set[int] = set()
    for field in spec.fields:
        if field.source_id in seen_source_ids:
            continue
        seen_source_ids.add(field.source_id)

        col_name = schema.find_field(field.source_id).name
        if col_name not in df_columns:
            continue

        col = df[col_name]
        if col.null_count > 0:
            # Mixing null with a bounded predicate would exclude
            # destination rows whose partition value is null,
            # potentially missing key matches. Skip pruning rather
            # than risk a correctness regression.
            continue

        if pa.types.is_floating(col.type) and pc.any(pc.is_nan(col)).as_py():
            # Same reasoning as the null guard above, applied to NaN:
            # bounds derived from the non-NaN values cannot describe
            # NaN rows, so leave this column unbounded. ``pc.any`` may
            # yield null (-> ``None``, falsy) on an empty column; that
            # case is handled by the bounds check below.
            continue

        # One aggregation pass yields both bounds.
        bounds = pc.min_max(col)
        col_min = bounds["min"].as_py()
        col_max = bounds["max"].as_py()
        if col_min is None or col_max is None:
            # Defensive — ``null_count == 0`` should imply both bounds
            # are non-null, but pyarrow's min/max can still return None
            # on empty columns.
            continue

        if col_min == col_max:
            augmentations.append(EqualTo(col_name, col_min))
        else:
            augmentations.append(GreaterThanOrEqual(col_name, col_min))
            augmentations.append(LessThanOrEqual(col_name, col_max))

    if not augmentations:
        return matched_predicate

    return functools.reduce(And, [matched_predicate, *augmentations])


def get_rows_to_update(source_table: pa.Table, target_table: pa.Table, join_cols: list[str]) -> pa.Table:
"""
Return a table with rows that need to be updated in the target table based on the join columns.
Expand Down
Loading
Loading