Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/paimon-python-checks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ jobs:
else
python -m pip install --upgrade pip
pip install torch --index-url https://download.pytorch.org/whl/cpu
python -m pip install pyroaring readerwriterlock==1.0.9 fsspec==2024.3.1 cachetools==5.3.3 ossfs==2023.12.0 ray==2.54.0 fastavro==1.11.1 pyarrow==16.0.0 zstandard==0.24.0 polars==1.32.0 duckdb==1.3.2 numpy==1.24.3 pandas==2.0.3 pylance==0.39.0 cramjam flake8==4.0.1 pytest~=7.0 py4j==0.10.9.9 requests parameterized==0.9.0 'daft>=0.7.6' pypaimon-rust==0.2.0
python -m pip install pyroaring readerwriterlock==1.0.9 fsspec==2024.3.1 cachetools==5.3.3 ossfs==2023.12.0 ray==2.54.0 fastavro==1.11.1 pyarrow==16.0.0 zstandard==0.24.0 polars==1.32.0 duckdb==1.3.2 numpy==1.24.3 pandas==2.0.3 pylance==0.39.0 cramjam flake8==4.0.1 pytest~=7.0 py4j==0.10.9.9 requests parameterized==0.9.0 'daft>=0.7.6' pypaimon-rust==0.2.0 'datafusion>=52'
python -m pip install 'lumina-data>=${{ env.LUMINA_DATA_VERSION }}' -i https://pypi.org/simple/
if python -c "import sys; sys.exit(0 if sys.version_info >= (3, 11) else 1)"; then
python -m pip install vortex-data==0.70.0
Expand Down
28 changes: 22 additions & 6 deletions docs/docs/pypaimon/ray-data.md
Original file line number Diff line number Diff line change
Expand Up @@ -357,12 +357,29 @@ metrics = merge_into(
print(metrics) # {"num_matched": 3, "num_inserted": 2, "num_unchanged": 0}
```

Conditional clauses filter which matched/unmatched rows are acted on:

```python
merge_into(
target="db.table",
source=source_ds,
catalog_options=catalog_options,
on=["id"],
when_matched=[WhenMatched(update="*", condition="s.age > t.age")],
when_not_matched=[WhenNotMatched(insert="*", condition="s.age > 18")],
)
```

Conditions use SQL-style expressions with `s.` (source) and `t.` (target)
column prefixes. `WhenNotMatched` conditions may only reference source
columns (`s.*`). Requires the `datafusion` package: `pip install pypaimon[sql]`.

- `update` / `insert`: only `"*"` is supported in this PR. A future follow-up
will add mapping-based SET (e.g. `{"col": "s.col"}`) where values are
analyzable string expressions (`"s.<col>"`, `"t.<col>"`, or literals),
not Python callables.
- `condition`: reserved for a future follow-up; passing a non-None value
currently raises `NotImplementedError`.
- `condition`: an optional SQL-style boolean expression. Use `s.<col>` and
`t.<col>` to reference source and target columns.

**Parameters:**
- `source`: a `ray.data.Dataset`, `pyarrow.Table`, `pandas.DataFrame`, or a
Expand All @@ -375,10 +392,9 @@ print(metrics) # {"num_matched": 3, "num_inserted": 2, "num_unchanged": 0}
tasks (update transform, group write, insert transform).
- `concurrency`: scheduling for the insert sink.

**Returns:** `{"num_matched", "num_inserted", "num_unchanged"}`. In this PR
every matched row is updated, so `num_matched` always equals `num_updated`
and `num_unchanged` is always `0`; conditional clauses (added later) can
make `num_unchanged > 0`.
**Returns:** `{"num_matched", "num_inserted", "num_unchanged"}`. `num_matched`
counts the rows actually updated (after condition filtering). `num_unchanged`
is `0` in the current implementation.

**Notes:**
- Partition key columns cannot be updated by matched clauses. If the target
Expand Down
2 changes: 2 additions & 0 deletions paimon-python/dev/requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,7 @@ requests
parameterized
# Vortex 0.71.0 regresses native predicate pushdown on single-row files.
vortex-data==0.70.0; python_version >= "3.11"
# merge_into condition expressions (optional, for condition tests)
datafusion>=52; python_version >= "3.10"
# Lumina vector search (optional, for lumina index tests)
lumina-data>=0.1.0
61 changes: 53 additions & 8 deletions paimon-python/pypaimon/ray/data_evolution_merge_into.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,12 +96,6 @@ def _prepare(target, source, catalog_options, when_matched, when_not_matched, on
"WhenNotMatched clause; multi-clause fall-through will be added "
"in a follow-up PR."
)
for clause in list(when_matched) + list(when_not_matched):
if clause.condition is not None:
raise NotImplementedError(
"merge_into does not yet support condition expressions; "
"this will be added in a follow-up PR."
)
target_on_cols, source_on_cols = _normalize_on(on)

from pypaimon.catalog.catalog_factory import CatalogFactory
Expand Down Expand Up @@ -136,14 +130,41 @@ def _prepare(target, source, catalog_options, when_matched, when_not_matched, on
spec=_normalize_set_spec(
c.update, settable_field_names, on_map,
),
condition=c.condition,
)
for c in when_matched
]
has_condition = any(
c.condition is not None
for c in list(when_matched) + list(when_not_matched)
)
if has_condition:
from pypaimon.ray.merge_condition import (
_require_datafusion, extract_target_columns,
)
_require_datafusion()
for c in when_not_matched:
if c.condition is not None:
t_refs = extract_target_columns(c.condition)
if t_refs:
raise ValueError(
f"WhenNotMatched condition must not reference "
f"target columns (t.*), but found: {sorted(t_refs)}"
)
for c in list(when_matched) + list(when_not_matched):
if c.condition is not None:
blob_refs = extract_target_columns(c.condition) & blob_cols
if blob_refs:
raise ValueError(
f"condition must not reference blob columns, "
f"but found: {sorted(blob_refs)}"
)
not_matched_specs = [
_NormalizedClause(
spec=_normalize_set_spec(
c.insert, settable_field_names, on_map,
),
condition=c.condition,
)
for c in when_not_matched
]
Expand All @@ -154,6 +175,25 @@ def _prepare(target, source, catalog_options, when_matched, when_not_matched, on
source_ds, settable_field_names, on_map,
)

if has_condition:
from pypaimon.ray.merge_condition import extract_columns
source_names = set(_source_schema_or_raise(source_ds).names)
target_names = set(full_target_field_names)
for c in list(when_matched) + list(when_not_matched):
if c.condition is not None:
for ref in extract_columns(c.condition):
prefix, col = ref.split(".", 1)
if prefix == "s" and col not in source_names:
raise ValueError(
f"condition references unknown source "
f"column '{col}'"
)
if prefix == "t" and col not in target_names:
raise ValueError(
f"condition references unknown target "
f"column '{col}'"
)

from pypaimon.schema.data_types import PyarrowFieldParser
full_pa_schema = PyarrowFieldParser.from_paimon_schema(
table.table_schema.fields
Expand Down Expand Up @@ -272,8 +312,7 @@ def _execute_and_commit(
tc.commit(all_msgs)
tc.close()

# MVP has no condition, so every matched row is updated; num_unchanged
# is always 0. Kept in the dict for API stability when condition lands.
# num_matched = rows that passed the condition and were updated
return {
"num_matched": num_updated,
"num_inserted": num_inserted,
Expand Down Expand Up @@ -375,6 +414,12 @@ def _resolve_target_projection(
needed = set(_needed_target_cols(
clauses, target_on, update_cols, target_field_names,
))
if any(c.condition is not None for c in clauses):
from pypaimon.ray.merge_condition import extract_target_columns
target_set = set(target_field_names)
for clause in clauses:
if clause.condition is not None:
needed |= extract_target_columns(clause.condition) & target_set
return [c for c in target_field_names if c in needed]


Expand Down
32 changes: 32 additions & 0 deletions paimon-python/pypaimon/ray/data_evolution_merge_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,12 +94,31 @@ def build_matched_update_ds(
f"build_matched_update_ds expected 1 clause, got {len(clauses)}"
)
spec = clauses[0].spec
condition = clauses[0].condition
captured_update_cols = list(update_cols)
captured_row_id_name = row_id_name
captured_on_pairs = list(zip(source_on, target_on))
captured_schema = update_schema

captured_apply = None
captured_rewritten = None
if condition is not None:
from pypaimon.ray.merge_condition import (
apply_condition, remap_source_on_keys, rewrite_condition,
)
on_map = dict(zip(source_on, target_on))
captured_rewritten = remap_source_on_keys(
rewrite_condition(condition), on_map,
)
captured_apply = apply_condition

def _transform(batch: pa.Table) -> pa.Table:
if captured_apply is not None:
batch = captured_apply(
batch, captured_rewritten, captured_schema,
)
if batch.num_rows == 0:
return batch
return vectorized_matched_transform(
batch, spec, captured_on_pairs,
captured_update_cols, captured_row_id_name,
Expand Down Expand Up @@ -308,8 +327,21 @@ def build_not_matched_insert_ds(
f"build_not_matched_insert_ds expected 1 clause, got {len(clauses)}"
)
spec = clauses[0].spec
condition = clauses[0].condition
captured_apply = None
captured_rewritten = None
if condition is not None:
from pypaimon.ray.merge_condition import apply_condition, rewrite_condition
captured_rewritten = rewrite_condition(condition)
captured_apply = apply_condition

def _transform(batch: pa.Table) -> pa.Table:
if captured_apply is not None:
batch = captured_apply(
batch, captured_rewritten, out_schema,
)
if batch.num_rows == 0:
return _coerce_large_string_types(batch)
return _coerce_large_string_types(
vectorized_insert_transform(
batch, spec, captured_field_names, out_schema
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class WhenNotMatched:
@dataclass
class _NormalizedClause:
spec: Dict[str, Any]
condition: Optional[str] = None


def vectorized_matched_transform(
Expand Down
104 changes: 104 additions & 0 deletions paimon-python/pypaimon/ray/merge_condition.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################

import re
from typing import Mapping, Set

import pyarrow as pa


_COL_REF_PATTERN = re.compile(r'\b([st])\.(\w+)\b')


def _require_datafusion():
try:
import datafusion
return datafusion
except ImportError:
raise ImportError(
"merge_into condition expressions require the 'datafusion' "
"package. Install it with: pip install pypaimon[sql]"
)


_STRING_LITERAL = re.compile(r"'(?:[^']|'')*'")


def _strip_string_literals(condition: str) -> str:
return _STRING_LITERAL.sub('', condition)


def rewrite_condition(condition: str) -> str:
parts, last = [], 0
for m in _STRING_LITERAL.finditer(condition):
parts.append(_COL_REF_PATTERN.sub(r'"\1.\2"', condition[last:m.start()]))
parts.append(m.group())
last = m.end()
parts.append(_COL_REF_PATTERN.sub(r'"\1.\2"', condition[last:]))
return ''.join(parts)


def remap_source_on_keys(
rewritten: str, on_map: Mapping[str, str],
) -> str:
for s_col, t_col in on_map.items():
old, new = f'"s.{s_col}"', f'"t.{t_col}"'
parts, last = [], 0
for m in _STRING_LITERAL.finditer(rewritten):
parts.append(rewritten[last:m.start()].replace(old, new))
parts.append(m.group())
last = m.end()
parts.append(rewritten[last:].replace(old, new))
rewritten = ''.join(parts)
return rewritten


def filter_batch(
batch: pa.Table, condition: str, _pre_rewritten: bool = False,
) -> pa.Table:
if batch.num_rows == 0:
return batch
datafusion = _require_datafusion()
rewritten = condition if _pre_rewritten else rewrite_condition(condition)
ctx = datafusion.SessionContext()
ctx.register_record_batches("_batch", [batch.to_batches()])
result = ctx.sql(
f'SELECT * FROM _batch WHERE {rewritten}'
)
return result.to_arrow_table()


def apply_condition(
batch: pa.Table, rewritten: str, empty_schema: pa.Schema,
) -> pa.Table:
batch = filter_batch(batch, rewritten, _pre_rewritten=True)
if batch.num_rows == 0:
return empty_schema.empty_table()
return batch


def extract_columns(condition: str) -> Set[str]:
stripped = _strip_string_literals(condition)
return {f"{m.group(1)}.{m.group(2)}"
for m in _COL_REF_PATTERN.finditer(stripped)}


def extract_target_columns(condition: str) -> Set[str]:
stripped = _strip_string_literals(condition)
return {m.group(2) for m in _COL_REF_PATTERN.finditer(stripped)
if m.group(1) == "t"}
Loading
Loading