diff --git a/.github/workflows/paimon-python-checks.yml b/.github/workflows/paimon-python-checks.yml
index 6a88767590e3..dc4cccaeea20 100755
--- a/.github/workflows/paimon-python-checks.yml
+++ b/.github/workflows/paimon-python-checks.yml
@@ -133,7 +133,7 @@ jobs:
else
python -m pip install --upgrade pip
pip install torch --index-url https://download.pytorch.org/whl/cpu
- python -m pip install pyroaring readerwriterlock==1.0.9 fsspec==2024.3.1 cachetools==5.3.3 ossfs==2023.12.0 ray==2.54.0 fastavro==1.11.1 pyarrow==16.0.0 zstandard==0.24.0 polars==1.32.0 duckdb==1.3.2 numpy==1.24.3 pandas==2.0.3 pylance==0.39.0 cramjam flake8==4.0.1 pytest~=7.0 py4j==0.10.9.9 requests parameterized==0.9.0 'daft>=0.7.6' pypaimon-rust==0.2.0
+ python -m pip install pyroaring readerwriterlock==1.0.9 fsspec==2024.3.1 cachetools==5.3.3 ossfs==2023.12.0 ray==2.54.0 fastavro==1.11.1 pyarrow==16.0.0 zstandard==0.24.0 polars==1.32.0 duckdb==1.3.2 numpy==1.24.3 pandas==2.0.3 pylance==0.39.0 cramjam flake8==4.0.1 pytest~=7.0 py4j==0.10.9.9 requests parameterized==0.9.0 'daft>=0.7.6' pypaimon-rust==0.2.0 'datafusion>=52'
python -m pip install 'lumina-data>=${{ env.LUMINA_DATA_VERSION }}' -i https://pypi.org/simple/
if python -c "import sys; sys.exit(0 if sys.version_info >= (3, 11) else 1)"; then
python -m pip install vortex-data==0.70.0
diff --git a/docs/docs/pypaimon/ray-data.md b/docs/docs/pypaimon/ray-data.md
index 658a1098ae04..a3c59c9f68b8 100644
--- a/docs/docs/pypaimon/ray-data.md
+++ b/docs/docs/pypaimon/ray-data.md
@@ -357,12 +357,29 @@ metrics = merge_into(
print(metrics) # {"num_matched": 3, "num_inserted": 2, "num_unchanged": 0}
```
+Conditional clauses filter which matched/unmatched rows are acted on:
+
+```python
+merge_into(
+ target="db.table",
+ source=source_ds,
+ catalog_options=catalog_options,
+ on=["id"],
+ when_matched=[WhenMatched(update="*", condition="s.age > t.age")],
+ when_not_matched=[WhenNotMatched(insert="*", condition="s.age > 18")],
+)
+```
+
+Conditions use SQL-style expressions with `s.` (source) and `t.` (target)
+column prefixes. `WhenNotMatched` conditions may only reference source
+columns (`s.*`). Requires the `datafusion` package: `pip install pypaimon[sql]`.
+
- `update` / `insert`: only `"*"` is supported in this PR. A future follow-up
will add mapping-based SET (e.g. `{"col": "s.col"}`) where values are
analyzable string expressions (`"s.
"`, `"t."`, or literals),
not Python callables.
-- `condition`: reserved for a future follow-up; passing a non-None value
- currently raises `NotImplementedError`.
+- `condition`: an optional SQL-style boolean expression. Use `s.` and
+ `t.` to reference source and target columns.
**Parameters:**
- `source`: a `ray.data.Dataset`, `pyarrow.Table`, `pandas.DataFrame`, or a
@@ -375,10 +392,9 @@ print(metrics) # {"num_matched": 3, "num_inserted": 2, "num_unchanged": 0}
tasks (update transform, group write, insert transform).
- `concurrency`: scheduling for the insert sink.
-**Returns:** `{"num_matched", "num_inserted", "num_unchanged"}`. In this PR
-every matched row is updated, so `num_matched` always equals `num_updated`
-and `num_unchanged` is always `0`; conditional clauses (added later) can
-make `num_unchanged > 0`.
+**Returns:** `{"num_matched", "num_inserted", "num_unchanged"}`. `num_matched`
+counts the rows actually updated (after condition filtering). `num_unchanged`
+is `0` in the current implementation.
**Notes:**
- Partition key columns cannot be updated by matched clauses. If the target
diff --git a/paimon-python/dev/requirements-dev.txt b/paimon-python/dev/requirements-dev.txt
index 9ef88817f726..c83a2e44b8b6 100644
--- a/paimon-python/dev/requirements-dev.txt
+++ b/paimon-python/dev/requirements-dev.txt
@@ -28,5 +28,7 @@ requests
parameterized
# Vortex 0.71.0 regresses native predicate pushdown on single-row files.
vortex-data==0.70.0; python_version >= "3.11"
+# merge_into condition expressions (optional, for condition tests)
+datafusion>=52; python_version >= "3.10"
# Lumina vector search (optional, for lumina index tests)
lumina-data>=0.1.0
diff --git a/paimon-python/pypaimon/ray/data_evolution_merge_into.py b/paimon-python/pypaimon/ray/data_evolution_merge_into.py
index b90abdc745c7..5be68f301e05 100644
--- a/paimon-python/pypaimon/ray/data_evolution_merge_into.py
+++ b/paimon-python/pypaimon/ray/data_evolution_merge_into.py
@@ -96,12 +96,6 @@ def _prepare(target, source, catalog_options, when_matched, when_not_matched, on
"WhenNotMatched clause; multi-clause fall-through will be added "
"in a follow-up PR."
)
- for clause in list(when_matched) + list(when_not_matched):
- if clause.condition is not None:
- raise NotImplementedError(
- "merge_into does not yet support condition expressions; "
- "this will be added in a follow-up PR."
- )
target_on_cols, source_on_cols = _normalize_on(on)
from pypaimon.catalog.catalog_factory import CatalogFactory
@@ -136,14 +130,41 @@ def _prepare(target, source, catalog_options, when_matched, when_not_matched, on
spec=_normalize_set_spec(
c.update, settable_field_names, on_map,
),
+ condition=c.condition,
)
for c in when_matched
]
+ has_condition = any(
+ c.condition is not None
+ for c in list(when_matched) + list(when_not_matched)
+ )
+ if has_condition:
+ from pypaimon.ray.merge_condition import (
+ _require_datafusion, extract_target_columns,
+ )
+ _require_datafusion()
+ for c in when_not_matched:
+ if c.condition is not None:
+ t_refs = extract_target_columns(c.condition)
+ if t_refs:
+ raise ValueError(
+ f"WhenNotMatched condition must not reference "
+ f"target columns (t.*), but found: {sorted(t_refs)}"
+ )
+ for c in list(when_matched) + list(when_not_matched):
+ if c.condition is not None:
+ blob_refs = extract_target_columns(c.condition) & blob_cols
+ if blob_refs:
+ raise ValueError(
+ f"condition must not reference blob columns, "
+ f"but found: {sorted(blob_refs)}"
+ )
not_matched_specs = [
_NormalizedClause(
spec=_normalize_set_spec(
c.insert, settable_field_names, on_map,
),
+ condition=c.condition,
)
for c in when_not_matched
]
@@ -154,6 +175,25 @@ def _prepare(target, source, catalog_options, when_matched, when_not_matched, on
source_ds, settable_field_names, on_map,
)
+ if has_condition:
+ from pypaimon.ray.merge_condition import extract_columns
+ source_names = set(_source_schema_or_raise(source_ds).names)
+ target_names = set(full_target_field_names)
+ for c in list(when_matched) + list(when_not_matched):
+ if c.condition is not None:
+ for ref in extract_columns(c.condition):
+ prefix, col = ref.split(".", 1)
+ if prefix == "s" and col not in source_names:
+ raise ValueError(
+ f"condition references unknown source "
+ f"column '{col}'"
+ )
+ if prefix == "t" and col not in target_names:
+ raise ValueError(
+ f"condition references unknown target "
+ f"column '{col}'"
+ )
+
from pypaimon.schema.data_types import PyarrowFieldParser
full_pa_schema = PyarrowFieldParser.from_paimon_schema(
table.table_schema.fields
@@ -272,8 +312,7 @@ def _execute_and_commit(
tc.commit(all_msgs)
tc.close()
- # MVP has no condition, so every matched row is updated; num_unchanged
- # is always 0. Kept in the dict for API stability when condition lands.
+ # num_matched = rows that passed the condition and were updated
return {
"num_matched": num_updated,
"num_inserted": num_inserted,
@@ -375,6 +414,12 @@ def _resolve_target_projection(
needed = set(_needed_target_cols(
clauses, target_on, update_cols, target_field_names,
))
+ if any(c.condition is not None for c in clauses):
+ from pypaimon.ray.merge_condition import extract_target_columns
+ target_set = set(target_field_names)
+ for clause in clauses:
+ if clause.condition is not None:
+ needed |= extract_target_columns(clause.condition) & target_set
return [c for c in target_field_names if c in needed]
diff --git a/paimon-python/pypaimon/ray/data_evolution_merge_join.py b/paimon-python/pypaimon/ray/data_evolution_merge_join.py
index 40e7994b2ddd..6d8b40e39028 100644
--- a/paimon-python/pypaimon/ray/data_evolution_merge_join.py
+++ b/paimon-python/pypaimon/ray/data_evolution_merge_join.py
@@ -94,12 +94,31 @@ def build_matched_update_ds(
f"build_matched_update_ds expected 1 clause, got {len(clauses)}"
)
spec = clauses[0].spec
+ condition = clauses[0].condition
captured_update_cols = list(update_cols)
captured_row_id_name = row_id_name
captured_on_pairs = list(zip(source_on, target_on))
captured_schema = update_schema
+ captured_apply = None
+ captured_rewritten = None
+ if condition is not None:
+ from pypaimon.ray.merge_condition import (
+ apply_condition, remap_source_on_keys, rewrite_condition,
+ )
+ on_map = dict(zip(source_on, target_on))
+ captured_rewritten = remap_source_on_keys(
+ rewrite_condition(condition), on_map,
+ )
+ captured_apply = apply_condition
+
def _transform(batch: pa.Table) -> pa.Table:
+ if captured_apply is not None:
+ batch = captured_apply(
+ batch, captured_rewritten, captured_schema,
+ )
+ if batch.num_rows == 0:
+ return batch
return vectorized_matched_transform(
batch, spec, captured_on_pairs,
captured_update_cols, captured_row_id_name,
@@ -308,8 +327,21 @@ def build_not_matched_insert_ds(
f"build_not_matched_insert_ds expected 1 clause, got {len(clauses)}"
)
spec = clauses[0].spec
+ condition = clauses[0].condition
+ captured_apply = None
+ captured_rewritten = None
+ if condition is not None:
+ from pypaimon.ray.merge_condition import apply_condition, rewrite_condition
+ captured_rewritten = rewrite_condition(condition)
+ captured_apply = apply_condition
def _transform(batch: pa.Table) -> pa.Table:
+ if captured_apply is not None:
+ batch = captured_apply(
+ batch, captured_rewritten, out_schema,
+ )
+ if batch.num_rows == 0:
+ return _coerce_large_string_types(batch)
return _coerce_large_string_types(
vectorized_insert_transform(
batch, spec, captured_field_names, out_schema
diff --git a/paimon-python/pypaimon/ray/data_evolution_merge_transform.py b/paimon-python/pypaimon/ray/data_evolution_merge_transform.py
index 0fc2d22f779d..ed786467f1e7 100644
--- a/paimon-python/pypaimon/ray/data_evolution_merge_transform.py
+++ b/paimon-python/pypaimon/ray/data_evolution_merge_transform.py
@@ -40,6 +40,7 @@ class WhenNotMatched:
@dataclass
class _NormalizedClause:
spec: Dict[str, Any]
+ condition: Optional[str] = None
def vectorized_matched_transform(
diff --git a/paimon-python/pypaimon/ray/merge_condition.py b/paimon-python/pypaimon/ray/merge_condition.py
new file mode 100644
index 000000000000..5497406c5cd2
--- /dev/null
+++ b/paimon-python/pypaimon/ray/merge_condition.py
@@ -0,0 +1,104 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+import re
+from typing import Mapping, Set
+
+import pyarrow as pa
+
+
+_COL_REF_PATTERN = re.compile(r'\b([st])\.(\w+)\b')
+
+
+def _require_datafusion():
+ try:
+ import datafusion
+ return datafusion
+ except ImportError:
+ raise ImportError(
+ "merge_into condition expressions require the 'datafusion' "
+ "package. Install it with: pip install pypaimon[sql]"
+ )
+
+
+_STRING_LITERAL = re.compile(r"'(?:[^']|'')*'")
+
+
+def _strip_string_literals(condition: str) -> str:
+ return _STRING_LITERAL.sub('', condition)
+
+
+def rewrite_condition(condition: str) -> str:
+ parts, last = [], 0
+ for m in _STRING_LITERAL.finditer(condition):
+ parts.append(_COL_REF_PATTERN.sub(r'"\1.\2"', condition[last:m.start()]))
+ parts.append(m.group())
+ last = m.end()
+ parts.append(_COL_REF_PATTERN.sub(r'"\1.\2"', condition[last:]))
+ return ''.join(parts)
+
+
+def remap_source_on_keys(
+ rewritten: str, on_map: Mapping[str, str],
+) -> str:
+ for s_col, t_col in on_map.items():
+ old, new = f'"s.{s_col}"', f'"t.{t_col}"'
+ parts, last = [], 0
+ for m in _STRING_LITERAL.finditer(rewritten):
+ parts.append(rewritten[last:m.start()].replace(old, new))
+ parts.append(m.group())
+ last = m.end()
+ parts.append(rewritten[last:].replace(old, new))
+ rewritten = ''.join(parts)
+ return rewritten
+
+
+def filter_batch(
+ batch: pa.Table, condition: str, _pre_rewritten: bool = False,
+) -> pa.Table:
+ if batch.num_rows == 0:
+ return batch
+ datafusion = _require_datafusion()
+ rewritten = condition if _pre_rewritten else rewrite_condition(condition)
+ ctx = datafusion.SessionContext()
+ ctx.register_record_batches("_batch", [batch.to_batches()])
+ result = ctx.sql(
+ f'SELECT * FROM _batch WHERE {rewritten}'
+ )
+ return result.to_arrow_table()
+
+
+def apply_condition(
+ batch: pa.Table, rewritten: str, empty_schema: pa.Schema,
+) -> pa.Table:
+ batch = filter_batch(batch, rewritten, _pre_rewritten=True)
+ if batch.num_rows == 0:
+ return empty_schema.empty_table()
+ return batch
+
+
+def extract_columns(condition: str) -> Set[str]:
+ stripped = _strip_string_literals(condition)
+ return {f"{m.group(1)}.{m.group(2)}"
+ for m in _COL_REF_PATTERN.finditer(stripped)}
+
+
+def extract_target_columns(condition: str) -> Set[str]:
+ stripped = _strip_string_literals(condition)
+ return {m.group(2) for m in _COL_REF_PATTERN.finditer(stripped)
+ if m.group(1) == "t"}
diff --git a/paimon-python/pypaimon/tests/ray_data_evolution_merge_into_test.py b/paimon-python/pypaimon/tests/ray_data_evolution_merge_into_test.py
index 727185a2e4d2..47981088f2d3 100644
--- a/paimon-python/pypaimon/tests/ray_data_evolution_merge_into_test.py
+++ b/paimon-python/pypaimon/tests/ray_data_evolution_merge_into_test.py
@@ -28,6 +28,15 @@
from pypaimon import CatalogFactory, Schema
from pypaimon.ray import WhenMatched, WhenNotMatched, merge_into
+try:
+ import datafusion # noqa: F401
+ _HAS_DATAFUSION = True
+except ImportError:
+ _HAS_DATAFUSION = False
+
+_SKIP_CONDITION = not _HAS_DATAFUSION
+_SKIP_REASON = "datafusion not installed"
+
_TEST_NUM_PARTITIONS = 2
@@ -153,6 +162,56 @@ def test_source_missing_on_col_raises(self):
)
self.assertIn("'id'", str(ctx.exception))
+ @unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON)
+ def test_not_matched_condition_rejects_target_refs(self):
+ target = self._create_table()
+ with self.assertRaises(ValueError) as ctx:
+ merge_into(
+ target=target,
+ source=self._source(),
+ catalog_options=self.catalog_options,
+ on=['id'],
+ when_not_matched=[
+ WhenNotMatched(insert='*', condition='t.age > 10')
+ ],
+ num_partitions=_TEST_NUM_PARTITIONS,
+ )
+ self.assertIn('t.', str(ctx.exception))
+
+ @unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON)
+ def test_condition_unknown_source_col_rejected(self):
+ target = self._create_table()
+ self._write(target, self._source())
+ with self.assertRaises(ValueError) as ctx:
+ merge_into(
+ target=target,
+ source=self._source(),
+ catalog_options=self.catalog_options,
+ on=['id'],
+ when_matched=[
+ WhenMatched(update='*', condition='s.nonexistent > 0')
+ ],
+ num_partitions=_TEST_NUM_PARTITIONS,
+ )
+ self.assertIn('nonexistent', str(ctx.exception))
+
+ @unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON)
+ def test_condition_unknown_target_col_rejected(self):
+ target = self._create_table()
+ self._write(target, self._source())
+ with self.assertRaises(ValueError) as ctx:
+ merge_into(
+ target=target,
+ source=self._source(),
+ catalog_options=self.catalog_options,
+ on=['id'],
+ when_matched=[
+ WhenMatched(update='*', condition='s.age > t.nonexistent')
+ ],
+ num_partitions=_TEST_NUM_PARTITIONS,
+ )
+ self.assertIn('nonexistent', str(ctx.exception))
+
def test_matched_update_star(self):
target = self._create_table()
self._write(
@@ -519,12 +578,249 @@ def test_partitioned_insert_allowed(self):
self.assertEqual(out['id'], [1, 2])
self.assertEqual(out['pt'], ['a', 'b'])
+ @unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON)
+ def test_matched_update_with_condition(self):
+ target = self._create_table()
+ self._write(
+ target,
+ pa.Table.from_pydict(
+ {
+ 'id': pa.array([1, 2, 3], type=pa.int32()),
+ 'name': ['a', 'b', 'c'],
+ 'age': pa.array([10, 20, 30], type=pa.int32()),
+ },
+ schema=self.pa_schema,
+ ),
+ )
+
+ source = pa.Table.from_pydict(
+ {
+ 'id': pa.array([1, 2, 3], type=pa.int32()),
+ 'name': ['a2', 'b2', 'c2'],
+ 'age': pa.array([15, 25, 45], type=pa.int32()),
+ },
+ schema=self.pa_schema,
+ )
+
+ merge_into(
+ target=target,
+ source=source,
+ catalog_options=self.catalog_options,
+ on=['id'],
+ when_matched=[WhenMatched(update='*', condition='s.age > t.age + 10')],
+ num_partitions=_TEST_NUM_PARTITIONS,
+ )
+
+ out = self._read_sorted(target)
+ self.assertEqual(out['id'], [1, 2, 3])
+ self.assertEqual(out['name'], ['a', 'b', 'c2'])
+ self.assertEqual(out['age'], [10, 20, 45])
+
+ @unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON)
+ def test_matched_condition_with_source_on_key(self):
+ target = self._create_table()
+ self._write(
+ target,
+ pa.Table.from_pydict(
+ {
+ 'id': pa.array([1, 2, 3], type=pa.int32()),
+ 'name': ['a', 'b', 'c'],
+ 'age': pa.array([10, 20, 30], type=pa.int32()),
+ },
+ schema=self.pa_schema,
+ ),
+ )
+
+ source = pa.Table.from_pydict(
+ {
+ 'id': pa.array([1, 2, 3], type=pa.int32()),
+ 'name': ['a2', 'b2', 'c2'],
+ 'age': pa.array([15, 25, 35], type=pa.int32()),
+ },
+ schema=self.pa_schema,
+ )
+
+ merge_into(
+ target=target,
+ source=source,
+ catalog_options=self.catalog_options,
+ on=['id'],
+ when_matched=[WhenMatched(update='*', condition='s.id >= 2')],
+ num_partitions=_TEST_NUM_PARTITIONS,
+ )
+
+ out = self._read_sorted(target)
+ self.assertEqual(out['id'], [1, 2, 3])
+ self.assertEqual(out['name'], ['a', 'b2', 'c2'])
+ self.assertEqual(out['age'], [10, 25, 35])
+
+ @unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON)
+ def test_not_matched_insert_with_condition(self):
+ target = self._create_table()
+ self._write(
+ target,
+ pa.Table.from_pydict(
+ {
+ 'id': pa.array([1], type=pa.int32()),
+ 'name': ['a'],
+ 'age': pa.array([10], type=pa.int32()),
+ },
+ schema=self.pa_schema,
+ ),
+ )
+
+ source = pa.Table.from_pydict(
+ {
+ 'id': pa.array([2, 3, 4], type=pa.int32()),
+ 'name': ['b', 'c', 'd'],
+ 'age': pa.array([15, 25, 5], type=pa.int32()),
+ },
+ schema=self.pa_schema,
+ )
+
+ merge_into(
+ target=target,
+ source=source,
+ catalog_options=self.catalog_options,
+ on=['id'],
+ when_not_matched=[
+ WhenNotMatched(insert='*', condition='s.age >= 10')
+ ],
+ num_partitions=_TEST_NUM_PARTITIONS,
+ )
+
+ out = self._read_sorted(target)
+ self.assertEqual(out['id'], [1, 2, 3])
+ self.assertEqual(out['name'], ['a', 'b', 'c'])
+ self.assertEqual(out['age'], [10, 15, 25])
+
+ @unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON)
+ def test_combined_with_conditions(self):
+ target = self._create_table()
+ self._write(
+ target,
+ pa.Table.from_pydict(
+ {
+ 'id': pa.array([1, 2], type=pa.int32()),
+ 'name': ['a', 'b'],
+ 'age': pa.array([10, 20], type=pa.int32()),
+ },
+ schema=self.pa_schema,
+ ),
+ )
+
+ source = pa.Table.from_pydict(
+ {
+ 'id': pa.array([1, 2, 3, 4], type=pa.int32()),
+ 'name': ['a2', 'b2', 'c', 'd'],
+ 'age': pa.array([50, 5, 30, 8], type=pa.int32()),
+ },
+ schema=self.pa_schema,
+ )
+
+ metrics = merge_into(
+ target=target,
+ source=source,
+ catalog_options=self.catalog_options,
+ on=['id'],
+ when_matched=[WhenMatched(update='*', condition='s.age > t.age')],
+ when_not_matched=[
+ WhenNotMatched(insert='*', condition='s.age > 10')
+ ],
+ num_partitions=_TEST_NUM_PARTITIONS,
+ )
+
+ out = self._read_sorted(target)
+ self.assertEqual(out['id'], [1, 2, 3])
+ self.assertEqual(out['name'], ['a2', 'b', 'c'])
+ self.assertEqual(out['age'], [50, 20, 30])
+ self.assertEqual(metrics['num_matched'], 1)
+ self.assertEqual(metrics['num_inserted'], 1)
+
+ @unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON)
+ def test_condition_no_rows_match_is_noop(self):
+ target = self._create_table()
+ self._write(
+ target,
+ pa.Table.from_pydict(
+ {
+ 'id': pa.array([1, 2], type=pa.int32()),
+ 'name': ['a', 'b'],
+ 'age': pa.array([10, 20], type=pa.int32()),
+ },
+ schema=self.pa_schema,
+ ),
+ )
+
+ source = pa.Table.from_pydict(
+ {
+ 'id': pa.array([1, 2], type=pa.int32()),
+ 'name': ['a2', 'b2'],
+ 'age': pa.array([5, 5], type=pa.int32()),
+ },
+ schema=self.pa_schema,
+ )
+
+ merge_into(
+ target=target,
+ source=source,
+ catalog_options=self.catalog_options,
+ on=['id'],
+ when_matched=[WhenMatched(update='*', condition='s.age > t.age')],
+ num_partitions=_TEST_NUM_PARTITIONS,
+ )
+
+ out = self._read_sorted(target)
+ self.assertEqual(out['id'], [1, 2])
+ self.assertEqual(out['name'], ['a', 'b'])
+ self.assertEqual(out['age'], [10, 20])
+
+ @unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON)
+ def test_duplicate_source_filtered_by_condition(self):
+ target = self._create_table()
+ self._write(
+ target,
+ pa.Table.from_pydict(
+ {
+ 'id': pa.array([1], type=pa.int32()),
+ 'name': ['a'],
+ 'age': pa.array([10], type=pa.int32()),
+ },
+ schema=self.pa_schema,
+ ),
+ )
+
+ source = pa.Table.from_pydict(
+ {
+ 'id': pa.array([1, 1], type=pa.int32()),
+ 'name': ['x', 'y'],
+ 'age': pa.array([5, 20], type=pa.int32()),
+ },
+ schema=self.pa_schema,
+ )
+
+ merge_into(
+ target=target,
+ source=source,
+ catalog_options=self.catalog_options,
+ on=['id'],
+ when_matched=[
+ WhenMatched(update='*', condition='s.age > t.age')
+ ],
+ num_partitions=_TEST_NUM_PARTITIONS,
+ )
+
+ out = self._read_sorted(target)
+ self.assertEqual(out['id'], [1])
+ self.assertEqual(out['name'], ['y'])
+ self.assertEqual(out['age'], [20])
+
class TargetProjectionTest(unittest.TestCase):
- def _clause(self, spec):
+ def _clause(self, spec, condition=None):
from pypaimon.ray import data_evolution_merge_into as m
- return m._NormalizedClause(spec=spec)
+ return m._NormalizedClause(spec=spec, condition=condition)
def test_unconditional_set_excludes_target_update_col(self):
from pypaimon.ray import data_evolution_merge_into as m
@@ -534,6 +830,91 @@ def test_unconditional_set_excludes_target_update_col(self):
)
self.assertEqual(['id'], cols)
+ def test_condition_adds_referenced_target_cols(self):
+ from pypaimon.ray import data_evolution_merge_into as m
+ cols = m._resolve_target_projection(
+ [self._clause({'feature': 's.feature'}, condition='s.age > t.age')],
+ ['id'], ['feature'], ['id', 'feature', 'age', 'image'],
+ )
+ self.assertIn('age', cols)
+ self.assertIn('id', cols)
+
+
+class MergeConditionUnitTest(unittest.TestCase):
+
+ def test_rewrite_condition(self):
+ from pypaimon.ray.merge_condition import rewrite_condition
+ self.assertEqual(
+ rewrite_condition('s.age > t.age + 10'),
+ '"s.age" > "t.age" + 10',
+ )
+
+ def test_rewrite_condition_preserves_string_literals(self):
+ from pypaimon.ray.merge_condition import rewrite_condition
+ self.assertEqual(
+ rewrite_condition("s.status = 't.active' AND s.age > t.age"),
+ '"s.status" = \'t.active\' AND "s.age" > "t.age"',
+ )
+
+ def test_remap_source_on_keys(self):
+ from pypaimon.ray.merge_condition import (
+ remap_source_on_keys, rewrite_condition,
+ )
+ rewritten = rewrite_condition('s.id > 1 AND s.age > t.age')
+ remapped = remap_source_on_keys(rewritten, {'id': 'id'})
+ self.assertEqual(remapped, '"t.id" > 1 AND "s.age" > "t.age"')
+
+ def test_remap_source_on_keys_renamed(self):
+ from pypaimon.ray.merge_condition import (
+ remap_source_on_keys, rewrite_condition,
+ )
+ rewritten = rewrite_condition('s.uid > 1')
+ remapped = remap_source_on_keys(rewritten, {'uid': 'id'})
+ self.assertEqual(remapped, '"t.id" > 1')
+
+ def test_remap_preserves_string_literals(self):
+ from pypaimon.ray.merge_condition import (
+ remap_source_on_keys, rewrite_condition,
+ )
+ rewritten = rewrite_condition("s.note = '\"s.id\"' AND s.id = 1")
+ remapped = remap_source_on_keys(rewritten, {'id': 'id'})
+ self.assertEqual(
+ remapped,
+ '"s.note" = \'\"s.id\"\' AND "t.id" = 1',
+ )
+
+ def test_extract_target_columns(self):
+ from pypaimon.ray.merge_condition import extract_target_columns
+ self.assertEqual(
+ extract_target_columns('s.name = t.name AND s.age > t.age'),
+ {'name', 'age'},
+ )
+
+ def test_extract_target_columns_ignores_string_literals(self):
+ from pypaimon.ray.merge_condition import extract_target_columns
+ self.assertEqual(
+ extract_target_columns("s.name = 't.fake' AND s.age > t.age"),
+ {'age'},
+ )
+
+ def test_extract_columns(self):
+ from pypaimon.ray.merge_condition import extract_columns
+ self.assertEqual(
+ extract_columns('s.id = t.id AND s.age > t.age'),
+ {'s.id', 't.id', 's.age', 't.age'},
+ )
+
+ @unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON)
+ def test_filter_batch(self):
+ from pypaimon.ray.merge_condition import filter_batch
+ batch = pa.table({
+ 's.id': pa.array([1, 2, 3], type=pa.int32()),
+ 's.age': pa.array([10, 25, 30], type=pa.int32()),
+ 't.age': pa.array([20, 20, 20], type=pa.int32()),
+ })
+ result = filter_batch(batch, 's.age > t.age')
+ self.assertEqual(result.column('s.id').to_pylist(), [2, 3])
+
if __name__ == '__main__':
unittest.main()