From 5d1438447aa26b1011886385232b40a2f3d65bf6 Mon Sep 17 00:00:00 2001 From: xiaohongbo Date: Wed, 3 Jun 2026 15:22:06 +0800 Subject: [PATCH 1/7] [ray] Support condition expressions in merge_into Add condition support for WhenMatched and WhenNotMatched clauses using DataFusion SQL engine for expression evaluation. - Condition filtering in both matched (update) and not-matched (insert) paths - Rewrite and remap respect SQL string literal spans - Validate: WhenNotMatched rejects t.* refs, blob column refs rejected - Fail-fast datafusion availability check - Source ON key remapped to target ON key in matched conditions - Add datafusion>=52 to CI dependencies - SessionContext cached per worker, empty batch handled safely --- .github/workflows/paimon-python-checks.yml | 2 +- docs/docs/pypaimon/ray-data.md | 28 +- .../pypaimon/ray/data_evolution_merge_into.py | 41 ++- .../pypaimon/ray/data_evolution_merge_join.py | 32 ++ .../ray/data_evolution_merge_transform.py | 1 + paimon-python/pypaimon/ray/merge_condition.py | 116 +++++++ .../ray_data_evolution_merge_into_test.py | 299 +++++++++++++++++- 7 files changed, 502 insertions(+), 17 deletions(-) create mode 100644 paimon-python/pypaimon/ray/merge_condition.py diff --git a/.github/workflows/paimon-python-checks.yml b/.github/workflows/paimon-python-checks.yml index 6a88767590e3..dc4cccaeea20 100755 --- a/.github/workflows/paimon-python-checks.yml +++ b/.github/workflows/paimon-python-checks.yml @@ -133,7 +133,7 @@ jobs: else python -m pip install --upgrade pip pip install torch --index-url https://download.pytorch.org/whl/cpu - python -m pip install pyroaring readerwriterlock==1.0.9 fsspec==2024.3.1 cachetools==5.3.3 ossfs==2023.12.0 ray==2.54.0 fastavro==1.11.1 pyarrow==16.0.0 zstandard==0.24.0 polars==1.32.0 duckdb==1.3.2 numpy==1.24.3 pandas==2.0.3 pylance==0.39.0 cramjam flake8==4.0.1 pytest~=7.0 py4j==0.10.9.9 requests parameterized==0.9.0 'daft>=0.7.6' pypaimon-rust==0.2.0 + python -m pip install pyroaring readerwriterlock==1.0.9 fsspec==2024.3.1 cachetools==5.3.3 ossfs==2023.12.0 ray==2.54.0 fastavro==1.11.1 pyarrow==16.0.0 zstandard==0.24.0 polars==1.32.0 duckdb==1.3.2 numpy==1.24.3 pandas==2.0.3 pylance==0.39.0 cramjam flake8==4.0.1 pytest~=7.0 py4j==0.10.9.9 requests parameterized==0.9.0 'daft>=0.7.6' pypaimon-rust==0.2.0 'datafusion>=52' python -m pip install 'lumina-data>=${{ env.LUMINA_DATA_VERSION }}' -i https://pypi.org/simple/ if python -c "import sys; sys.exit(0 if sys.version_info >= (3, 11) else 1)"; then python -m pip install vortex-data==0.70.0 diff --git a/docs/docs/pypaimon/ray-data.md b/docs/docs/pypaimon/ray-data.md index 658a1098ae04..a3c59c9f68b8 100644 --- a/docs/docs/pypaimon/ray-data.md +++ b/docs/docs/pypaimon/ray-data.md @@ -357,12 +357,29 @@ metrics = merge_into( print(metrics) # {"num_matched": 3, "num_inserted": 2, "num_unchanged": 0} ``` +Conditional clauses filter which matched/unmatched rows are acted on: + +```python +merge_into( + target="db.table", + source=source_ds, + catalog_options=catalog_options, + on=["id"], + when_matched=[WhenMatched(update="*", condition="s.age > t.age")], + when_not_matched=[WhenNotMatched(insert="*", condition="s.age > 18")], +) +``` + +Conditions use SQL-style expressions with `s.` (source) and `t.` (target) +column prefixes. `WhenNotMatched` conditions may only reference source +columns (`s.*`). Requires the `datafusion` package: `pip install pypaimon[sql]`. + - `update` / `insert`: only `"*"` is supported in this PR. A future follow-up will add mapping-based SET (e.g. `{"col": "s.col"}`) where values are analyzable string expressions (`"s."`, `"t."`, or literals), not Python callables. -- `condition`: reserved for a future follow-up; passing a non-None value - currently raises `NotImplementedError`. +- `condition`: an optional SQL-style boolean expression. Use `s.` and + `t.` to reference source and target columns. **Parameters:** - `source`: a `ray.data.Dataset`, `pyarrow.Table`, `pandas.DataFrame`, or a @@ -375,10 +392,9 @@ print(metrics) # {"num_matched": 3, "num_inserted": 2, "num_unchanged": 0} tasks (update transform, group write, insert transform). - `concurrency`: scheduling for the insert sink. -**Returns:** `{"num_matched", "num_inserted", "num_unchanged"}`. In this PR -every matched row is updated, so `num_matched` always equals `num_updated` -and `num_unchanged` is always `0`; conditional clauses (added later) can -make `num_unchanged > 0`. +**Returns:** `{"num_matched", "num_inserted", "num_unchanged"}`. `num_matched` +counts the rows actually updated (after condition filtering). `num_unchanged` +is `0` in the current implementation. **Notes:** - Partition key columns cannot be updated by matched clauses. If the target diff --git a/paimon-python/pypaimon/ray/data_evolution_merge_into.py b/paimon-python/pypaimon/ray/data_evolution_merge_into.py index b90abdc745c7..1a701af95c86 100644 --- a/paimon-python/pypaimon/ray/data_evolution_merge_into.py +++ b/paimon-python/pypaimon/ray/data_evolution_merge_into.py @@ -96,12 +96,6 @@ def _prepare(target, source, catalog_options, when_matched, when_not_matched, on "WhenNotMatched clause; multi-clause fall-through will be added " "in a follow-up PR." ) - for clause in list(when_matched) + list(when_not_matched): - if clause.condition is not None: - raise NotImplementedError( - "merge_into does not yet support condition expressions; " - "this will be added in a follow-up PR." - ) target_on_cols, source_on_cols = _normalize_on(on) from pypaimon.catalog.catalog_factory import CatalogFactory @@ -136,14 +130,42 @@ def _prepare(target, source, catalog_options, when_matched, when_not_matched, on spec=_normalize_set_spec( c.update, settable_field_names, on_map, ), + condition=c.condition, ) for c in when_matched ] + has_condition = any( + c.condition is not None + for c in list(when_matched) + list(when_not_matched) + ) + if has_condition: + from pypaimon.ray.merge_condition import ( + _require_datafusion, extract_target_columns, + ) + _require_datafusion() + for c in list(when_matched) + list(when_not_matched): + if c.condition is not None: + blob_refs = extract_target_columns(c.condition) & blob_cols + if blob_refs: + raise ValueError( + f"condition must not reference blob columns, " + f"but found: {sorted(blob_refs)}" + ) + for c in when_not_matched: + if c.condition is not None: + from pypaimon.ray.merge_condition import extract_target_columns + t_refs = extract_target_columns(c.condition) + if t_refs: + raise ValueError( + f"WhenNotMatched condition must not reference target " + f"columns (t.*), but found: {sorted(t_refs)}" + ) not_matched_specs = [ _NormalizedClause( spec=_normalize_set_spec( c.insert, settable_field_names, on_map, ), + condition=c.condition, ) for c in when_not_matched ] @@ -272,8 +294,6 @@ def _execute_and_commit( tc.commit(all_msgs) tc.close() - # MVP has no condition, so every matched row is updated; num_unchanged - # is always 0. Kept in the dict for API stability when condition lands. return { "num_matched": num_updated, "num_inserted": num_inserted, @@ -375,6 +395,11 @@ def _resolve_target_projection( needed = set(_needed_target_cols( clauses, target_on, update_cols, target_field_names, )) + from pypaimon.ray.merge_condition import extract_target_columns + target_set = set(target_field_names) + for clause in clauses: + if clause.condition is not None: + needed |= extract_target_columns(clause.condition) & target_set return [c for c in target_field_names if c in needed] diff --git a/paimon-python/pypaimon/ray/data_evolution_merge_join.py b/paimon-python/pypaimon/ray/data_evolution_merge_join.py index 40e7994b2ddd..6d8b40e39028 100644 --- a/paimon-python/pypaimon/ray/data_evolution_merge_join.py +++ b/paimon-python/pypaimon/ray/data_evolution_merge_join.py @@ -94,12 +94,31 @@ def build_matched_update_ds( f"build_matched_update_ds expected 1 clause, got {len(clauses)}" ) spec = clauses[0].spec + condition = clauses[0].condition captured_update_cols = list(update_cols) captured_row_id_name = row_id_name captured_on_pairs = list(zip(source_on, target_on)) captured_schema = update_schema + captured_apply = None + captured_rewritten = None + if condition is not None: + from pypaimon.ray.merge_condition import ( + apply_condition, remap_source_on_keys, rewrite_condition, + ) + on_map = dict(zip(source_on, target_on)) + captured_rewritten = remap_source_on_keys( + rewrite_condition(condition), on_map, + ) + captured_apply = apply_condition + def _transform(batch: pa.Table) -> pa.Table: + if captured_apply is not None: + batch = captured_apply( + batch, captured_rewritten, captured_schema, + ) + if batch.num_rows == 0: + return batch return vectorized_matched_transform( batch, spec, captured_on_pairs, captured_update_cols, captured_row_id_name, @@ -308,8 +327,21 @@ def build_not_matched_insert_ds( f"build_not_matched_insert_ds expected 1 clause, got {len(clauses)}" ) spec = clauses[0].spec + condition = clauses[0].condition + captured_apply = None + captured_rewritten = None + if condition is not None: + from pypaimon.ray.merge_condition import apply_condition, rewrite_condition + captured_rewritten = rewrite_condition(condition) + captured_apply = apply_condition def _transform(batch: pa.Table) -> pa.Table: + if captured_apply is not None: + batch = captured_apply( + batch, captured_rewritten, out_schema, + ) + if batch.num_rows == 0: + return _coerce_large_string_types(batch) return _coerce_large_string_types( vectorized_insert_transform( batch, spec, captured_field_names, out_schema diff --git a/paimon-python/pypaimon/ray/data_evolution_merge_transform.py b/paimon-python/pypaimon/ray/data_evolution_merge_transform.py index 0fc2d22f779d..ed786467f1e7 100644 --- a/paimon-python/pypaimon/ray/data_evolution_merge_transform.py +++ b/paimon-python/pypaimon/ray/data_evolution_merge_transform.py @@ -40,6 +40,7 @@ class WhenNotMatched: @dataclass class _NormalizedClause: spec: Dict[str, Any] + condition: Optional[str] = None def vectorized_matched_transform( diff --git a/paimon-python/pypaimon/ray/merge_condition.py b/paimon-python/pypaimon/ray/merge_condition.py new file mode 100644 index 000000000000..28147be4c568 --- /dev/null +++ b/paimon-python/pypaimon/ray/merge_condition.py @@ -0,0 +1,116 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +import re +from typing import Mapping, Set + +import pyarrow as pa + + +_COL_REF_PATTERN = re.compile(r'\b([st])\.(\w+)\b') + + +def _require_datafusion(): + try: + import datafusion + return datafusion + except ImportError: + raise ImportError( + "merge_into condition expressions require the 'datafusion' " + "package. Install it with: pip install pypaimon[sql]" + ) + + +_STRING_LITERAL = re.compile(r"'(?:[^']|'')*'") + + +def _strip_string_literals(condition: str) -> str: + return _STRING_LITERAL.sub('', condition) + + +def rewrite_condition(condition: str) -> str: + parts, last = [], 0 + for m in _STRING_LITERAL.finditer(condition): + parts.append(_COL_REF_PATTERN.sub(r'"\1.\2"', condition[last:m.start()])) + parts.append(m.group()) + last = m.end() + parts.append(_COL_REF_PATTERN.sub(r'"\1.\2"', condition[last:])) + return ''.join(parts) + + +def remap_source_on_keys( + rewritten: str, on_map: Mapping[str, str], +) -> str: + for s_col, t_col in on_map.items(): + old, new = f'"s.{s_col}"', f'"t.{t_col}"' + parts, last = [], 0 + for m in _STRING_LITERAL.finditer(rewritten): + parts.append(rewritten[last:m.start()].replace(old, new)) + parts.append(m.group()) + last = m.end() + parts.append(rewritten[last:].replace(old, new)) + rewritten = ''.join(parts) + return rewritten + + +_SESSION_CTX = None + + +def _get_session_context(): + global _SESSION_CTX + if _SESSION_CTX is None: + datafusion = _require_datafusion() + _SESSION_CTX = datafusion.SessionContext() + return _SESSION_CTX + + +def filter_batch( + batch: pa.Table, condition: str, _pre_rewritten: bool = False, +) -> pa.Table: + if batch.num_rows == 0: + return batch + rewritten = condition if _pre_rewritten else rewrite_condition(condition) + ctx = _get_session_context() + if ctx.table_exist("_merge_batch"): + ctx.deregister_table("_merge_batch") + ctx.register_record_batches("_merge_batch", [batch.to_batches()]) + result = ctx.sql( + f'SELECT * FROM _merge_batch WHERE {rewritten}' + ) + return result.to_arrow_table() + + +def apply_condition( + batch: pa.Table, rewritten: str, empty_schema: pa.Schema, +) -> pa.Table: + batch = filter_batch(batch, rewritten, _pre_rewritten=True) + if batch.num_rows == 0: + return empty_schema.empty_table() + return batch + + +def extract_columns(condition: str) -> Set[str]: + stripped = _strip_string_literals(condition) + return {f"{m.group(1)}.{m.group(2)}" + for m in _COL_REF_PATTERN.finditer(stripped)} + + +def extract_target_columns(condition: str) -> Set[str]: + stripped = _strip_string_literals(condition) + return {m.group(2) for m in _COL_REF_PATTERN.finditer(stripped) + if m.group(1) == "t"} diff --git a/paimon-python/pypaimon/tests/ray_data_evolution_merge_into_test.py b/paimon-python/pypaimon/tests/ray_data_evolution_merge_into_test.py index 727185a2e4d2..67b707d8a4af 100644 --- a/paimon-python/pypaimon/tests/ray_data_evolution_merge_into_test.py +++ b/paimon-python/pypaimon/tests/ray_data_evolution_merge_into_test.py @@ -153,6 +153,21 @@ def test_source_missing_on_col_raises(self): ) self.assertIn("'id'", str(ctx.exception)) + def test_not_matched_condition_rejects_target_refs(self): + target = self._create_table() + with self.assertRaises(ValueError) as ctx: + merge_into( + target=target, + source=self._source(), + catalog_options=self.catalog_options, + on=['id'], + when_not_matched=[ + WhenNotMatched(insert='*', condition='t.age > 10') + ], + num_partitions=_TEST_NUM_PARTITIONS, + ) + self.assertIn('t.', str(ctx.exception)) + def test_matched_update_star(self): target = self._create_table() self._write( @@ -519,12 +534,204 @@ def test_partitioned_insert_allowed(self): self.assertEqual(out['id'], [1, 2]) self.assertEqual(out['pt'], ['a', 'b']) + def test_matched_update_with_condition(self): + target = self._create_table() + self._write( + target, + pa.Table.from_pydict( + { + 'id': pa.array([1, 2, 3], type=pa.int32()), + 'name': ['a', 'b', 'c'], + 'age': pa.array([10, 20, 30], type=pa.int32()), + }, + schema=self.pa_schema, + ), + ) + + source = pa.Table.from_pydict( + { + 'id': pa.array([1, 2, 3], type=pa.int32()), + 'name': ['a2', 'b2', 'c2'], + 'age': pa.array([15, 25, 45], type=pa.int32()), + }, + schema=self.pa_schema, + ) + + merge_into( + target=target, + source=source, + catalog_options=self.catalog_options, + on=['id'], + when_matched=[WhenMatched(update='*', condition='s.age > t.age + 10')], + num_partitions=_TEST_NUM_PARTITIONS, + ) + + out = self._read_sorted(target) + self.assertEqual(out['id'], [1, 2, 3]) + self.assertEqual(out['name'], ['a', 'b', 'c2']) + self.assertEqual(out['age'], [10, 20, 45]) + + def test_matched_condition_with_source_on_key(self): + target = self._create_table() + self._write( + target, + pa.Table.from_pydict( + { + 'id': pa.array([1, 2, 3], type=pa.int32()), + 'name': ['a', 'b', 'c'], + 'age': pa.array([10, 20, 30], type=pa.int32()), + }, + schema=self.pa_schema, + ), + ) + + source = pa.Table.from_pydict( + { + 'id': pa.array([1, 2, 3], type=pa.int32()), + 'name': ['a2', 'b2', 'c2'], + 'age': pa.array([15, 25, 35], type=pa.int32()), + }, + schema=self.pa_schema, + ) + + merge_into( + target=target, + source=source, + catalog_options=self.catalog_options, + on=['id'], + when_matched=[WhenMatched(update='*', condition='s.id >= 2')], + num_partitions=_TEST_NUM_PARTITIONS, + ) + + out = self._read_sorted(target) + self.assertEqual(out['id'], [1, 2, 3]) + self.assertEqual(out['name'], ['a', 'b2', 'c2']) + self.assertEqual(out['age'], [10, 25, 35]) + + def test_not_matched_insert_with_condition(self): + target = self._create_table() + self._write( + target, + pa.Table.from_pydict( + { + 'id': pa.array([1], type=pa.int32()), + 'name': ['a'], + 'age': pa.array([10], type=pa.int32()), + }, + schema=self.pa_schema, + ), + ) + + source = pa.Table.from_pydict( + { + 'id': pa.array([2, 3, 4], type=pa.int32()), + 'name': ['b', 'c', 'd'], + 'age': pa.array([15, 25, 5], type=pa.int32()), + }, + schema=self.pa_schema, + ) + + merge_into( + target=target, + source=source, + catalog_options=self.catalog_options, + on=['id'], + when_not_matched=[ + WhenNotMatched(insert='*', condition='s.age >= 10') + ], + num_partitions=_TEST_NUM_PARTITIONS, + ) + + out = self._read_sorted(target) + self.assertEqual(out['id'], [1, 2, 3]) + self.assertEqual(out['name'], ['a', 'b', 'c']) + self.assertEqual(out['age'], [10, 15, 25]) + + def test_combined_with_conditions(self): + target = self._create_table() + self._write( + target, + pa.Table.from_pydict( + { + 'id': pa.array([1, 2], type=pa.int32()), + 'name': ['a', 'b'], + 'age': pa.array([10, 20], type=pa.int32()), + }, + schema=self.pa_schema, + ), + ) + + source = pa.Table.from_pydict( + { + 'id': pa.array([1, 2, 3, 4], type=pa.int32()), + 'name': ['a2', 'b2', 'c', 'd'], + 'age': pa.array([50, 5, 30, 8], type=pa.int32()), + }, + schema=self.pa_schema, + ) + + metrics = merge_into( + target=target, + source=source, + catalog_options=self.catalog_options, + on=['id'], + when_matched=[WhenMatched(update='*', condition='s.age > t.age')], + when_not_matched=[ + WhenNotMatched(insert='*', condition='s.age > 10') + ], + num_partitions=_TEST_NUM_PARTITIONS, + ) + + out = self._read_sorted(target) + self.assertEqual(out['id'], [1, 2, 3]) + self.assertEqual(out['name'], ['a2', 'b', 'c']) + self.assertEqual(out['age'], [50, 20, 30]) + self.assertEqual(metrics['num_matched'], 1) + self.assertEqual(metrics['num_inserted'], 1) + + def test_condition_no_rows_match_is_noop(self): + target = self._create_table() + self._write( + target, + pa.Table.from_pydict( + { + 'id': pa.array([1, 2], type=pa.int32()), + 'name': ['a', 'b'], + 'age': pa.array([10, 20], type=pa.int32()), + }, + schema=self.pa_schema, + ), + ) + + source = pa.Table.from_pydict( + { + 'id': pa.array([1, 2], type=pa.int32()), + 'name': ['a2', 'b2'], + 'age': pa.array([5, 5], type=pa.int32()), + }, + schema=self.pa_schema, + ) + + merge_into( + target=target, + source=source, + catalog_options=self.catalog_options, + on=['id'], + when_matched=[WhenMatched(update='*', condition='s.age > t.age')], + num_partitions=_TEST_NUM_PARTITIONS, + ) + + out = self._read_sorted(target) + self.assertEqual(out['id'], [1, 2]) + self.assertEqual(out['name'], ['a', 'b']) + self.assertEqual(out['age'], [10, 20]) + class TargetProjectionTest(unittest.TestCase): - def _clause(self, spec): + def _clause(self, spec, condition=None): from pypaimon.ray import data_evolution_merge_into as m - return m._NormalizedClause(spec=spec) + return m._NormalizedClause(spec=spec, condition=condition) def test_unconditional_set_excludes_target_update_col(self): from pypaimon.ray import data_evolution_merge_into as m @@ -534,6 +741,94 @@ def test_unconditional_set_excludes_target_update_col(self): ) self.assertEqual(['id'], cols) + def test_condition_adds_referenced_target_cols(self): + from pypaimon.ray import data_evolution_merge_into as m + cols = m._resolve_target_projection( + [self._clause({'feature': 's.feature'}, condition='s.age > t.age')], + ['id'], ['feature'], ['id', 'feature', 'age', 'image'], + ) + self.assertIn('age', cols) + self.assertIn('id', cols) + + +class MergeConditionUnitTest(unittest.TestCase): + + def test_rewrite_condition(self): + from pypaimon.ray.merge_condition import rewrite_condition + self.assertEqual( + rewrite_condition('s.age > t.age + 10'), + '"s.age" > "t.age" + 10', + ) + + def test_rewrite_condition_preserves_string_literals(self): + from pypaimon.ray.merge_condition import rewrite_condition + self.assertEqual( + rewrite_condition("s.status = 't.active' AND s.age > t.age"), + '"s.status" = \'t.active\' AND "s.age" > "t.age"', + ) + + def test_remap_source_on_keys(self): + from pypaimon.ray.merge_condition import ( + remap_source_on_keys, rewrite_condition, + ) + rewritten = rewrite_condition('s.id > 1 AND s.age > t.age') + remapped = remap_source_on_keys(rewritten, {'id': 'id'}) + self.assertEqual(remapped, '"t.id" > 1 AND "s.age" > "t.age"') + + def test_remap_source_on_keys_renamed(self): + from pypaimon.ray.merge_condition import ( + remap_source_on_keys, rewrite_condition, + ) + rewritten = rewrite_condition('s.uid > 1') + remapped = remap_source_on_keys(rewritten, {'uid': 'id'}) + self.assertEqual(remapped, '"t.id" > 1') + + def test_remap_preserves_string_literals(self): + from pypaimon.ray.merge_condition import ( + remap_source_on_keys, rewrite_condition, + ) + rewritten = rewrite_condition("s.note = '\"s.id\"' AND s.id = 1") + remapped = remap_source_on_keys(rewritten, {'id': 'id'}) + self.assertEqual( + remapped, + '"s.note" = \'\"s.id\"\' AND "t.id" = 1', + ) + + def test_extract_target_columns(self): + from pypaimon.ray.merge_condition import extract_target_columns + self.assertEqual( + extract_target_columns('s.name = t.name AND s.age > t.age'), + {'name', 'age'}, + ) + + def test_extract_target_columns_ignores_string_literals(self): + from pypaimon.ray.merge_condition import extract_target_columns + self.assertEqual( + extract_target_columns("s.name = 't.fake' AND s.age > t.age"), + {'age'}, + ) + + def test_extract_columns(self): + from pypaimon.ray.merge_condition import extract_columns + self.assertEqual( + extract_columns('s.id = t.id AND s.age > t.age'), + {'s.id', 't.id', 's.age', 't.age'}, + ) + + def test_filter_batch(self): + try: + import datafusion # noqa: F401 + except ImportError: + self.skipTest("datafusion not installed") + from pypaimon.ray.merge_condition import filter_batch + batch = pa.table({ + 's.id': pa.array([1, 2, 3], type=pa.int32()), + 's.age': pa.array([10, 25, 30], type=pa.int32()), + 't.age': pa.array([20, 20, 20], type=pa.int32()), + }) + result = filter_batch(batch, 's.age > t.age') + self.assertEqual(result.column('s.id').to_pylist(), [2, 3]) + if __name__ == '__main__': unittest.main() From 23ef7457593b17c022e42844411f3f436c2b98d2 Mon Sep 17 00:00:00 2001 From: xiaohongbo Date: Wed, 3 Jun 2026 15:57:39 +0800 Subject: [PATCH 2/7] [ray] Harden condition validation and simplify SessionContext - Create fresh SessionContext per filter_batch call (no global state) - Guard merge_condition import behind condition check - Check WhenNotMatched target-ref before blob-ref for clearer errors - Clarify num_matched semantics in comment --- .../pypaimon/ray/data_evolution_merge_into.py | 29 ++++++++++--------- paimon-python/pypaimon/ray/merge_condition.py | 20 +++---------- 2 files changed, 19 insertions(+), 30 deletions(-) diff --git a/paimon-python/pypaimon/ray/data_evolution_merge_into.py b/paimon-python/pypaimon/ray/data_evolution_merge_into.py index 1a701af95c86..a886d5c55c7e 100644 --- a/paimon-python/pypaimon/ray/data_evolution_merge_into.py +++ b/paimon-python/pypaimon/ray/data_evolution_merge_into.py @@ -143,6 +143,14 @@ def _prepare(target, source, catalog_options, when_matched, when_not_matched, on _require_datafusion, extract_target_columns, ) _require_datafusion() + for c in when_not_matched: + if c.condition is not None: + t_refs = extract_target_columns(c.condition) + if t_refs: + raise ValueError( + f"WhenNotMatched condition must not reference " + f"target columns (t.*), but found: {sorted(t_refs)}" + ) for c in list(when_matched) + list(when_not_matched): if c.condition is not None: blob_refs = extract_target_columns(c.condition) & blob_cols @@ -151,15 +159,6 @@ def _prepare(target, source, catalog_options, when_matched, when_not_matched, on f"condition must not reference blob columns, " f"but found: {sorted(blob_refs)}" ) - for c in when_not_matched: - if c.condition is not None: - from pypaimon.ray.merge_condition import extract_target_columns - t_refs = extract_target_columns(c.condition) - if t_refs: - raise ValueError( - f"WhenNotMatched condition must not reference target " - f"columns (t.*), but found: {sorted(t_refs)}" - ) not_matched_specs = [ _NormalizedClause( spec=_normalize_set_spec( @@ -294,6 +293,7 @@ def _execute_and_commit( tc.commit(all_msgs) tc.close() + # num_matched = rows that passed the condition and were updated return { "num_matched": num_updated, "num_inserted": num_inserted, @@ -395,11 +395,12 @@ def _resolve_target_projection( needed = set(_needed_target_cols( clauses, target_on, update_cols, target_field_names, )) - from pypaimon.ray.merge_condition import extract_target_columns - target_set = set(target_field_names) - for clause in clauses: - if clause.condition is not None: - needed |= extract_target_columns(clause.condition) & target_set + if any(c.condition is not None for c in clauses): + from pypaimon.ray.merge_condition import extract_target_columns + target_set = set(target_field_names) + for clause in clauses: + if clause.condition is not None: + needed |= extract_target_columns(clause.condition) & target_set return [c for c in target_field_names if c in needed] diff --git a/paimon-python/pypaimon/ray/merge_condition.py b/paimon-python/pypaimon/ray/merge_condition.py index 28147be4c568..5497406c5cd2 100644 --- a/paimon-python/pypaimon/ray/merge_condition.py +++ b/paimon-python/pypaimon/ray/merge_condition.py @@ -68,29 +68,17 @@ def remap_source_on_keys( return rewritten -_SESSION_CTX = None - - -def _get_session_context(): - global _SESSION_CTX - if _SESSION_CTX is None: - datafusion = _require_datafusion() - _SESSION_CTX = datafusion.SessionContext() - return _SESSION_CTX - - def filter_batch( batch: pa.Table, condition: str, _pre_rewritten: bool = False, ) -> pa.Table: if batch.num_rows == 0: return batch + datafusion = _require_datafusion() rewritten = condition if _pre_rewritten else rewrite_condition(condition) - ctx = _get_session_context() - if ctx.table_exist("_merge_batch"): - ctx.deregister_table("_merge_batch") - ctx.register_record_batches("_merge_batch", [batch.to_batches()]) + ctx = datafusion.SessionContext() + ctx.register_record_batches("_batch", [batch.to_batches()]) result = ctx.sql( - f'SELECT * FROM _merge_batch WHERE {rewritten}' + f'SELECT * FROM _batch WHERE {rewritten}' ) return result.to_arrow_table() From 19381cf6455eed48c285ae967e22c872a038c4e9 Mon Sep 17 00:00:00 2001 From: xiaohongbo Date: Wed, 3 Jun 2026 16:17:18 +0800 Subject: [PATCH 3/7] [ray] Add datafusion to dev requirements Local dev installs via requirements-dev.txt were missing datafusion, causing condition integration tests to fail outside CI. --- paimon-python/dev/requirements-dev.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/paimon-python/dev/requirements-dev.txt b/paimon-python/dev/requirements-dev.txt index 9ef88817f726..c83a2e44b8b6 100644 --- a/paimon-python/dev/requirements-dev.txt +++ b/paimon-python/dev/requirements-dev.txt @@ -28,5 +28,7 @@ requests parameterized # Vortex 0.71.0 regresses native predicate pushdown on single-row files. vortex-data==0.70.0; python_version >= "3.11" +# merge_into condition expressions (optional, for condition tests) +datafusion>=52; python_version >= "3.10" # Lumina vector search (optional, for lumina index tests) lumina-data>=0.1.0 From aa7f3970ee17156de0a0eb9ee4af16eb161b52d7 Mon Sep 17 00:00:00 2001 From: xiaohongbo Date: Wed, 3 Jun 2026 17:30:54 +0800 Subject: [PATCH 4/7] [ray] Add test for duplicate source rows filtered by condition Two source rows match the same target row (id=1). Without condition this would raise "multiple source rows". With condition s.age > t.age, only the row with age=20 passes (age=5 is filtered), so the update succeeds with exactly one matching row. --- .../ray_data_evolution_merge_into_test.py | 39 +++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/paimon-python/pypaimon/tests/ray_data_evolution_merge_into_test.py b/paimon-python/pypaimon/tests/ray_data_evolution_merge_into_test.py index 67b707d8a4af..5a11f23afa09 100644 --- a/paimon-python/pypaimon/tests/ray_data_evolution_merge_into_test.py +++ b/paimon-python/pypaimon/tests/ray_data_evolution_merge_into_test.py @@ -726,6 +726,45 @@ def test_condition_no_rows_match_is_noop(self): self.assertEqual(out['name'], ['a', 'b']) self.assertEqual(out['age'], [10, 20]) + def test_duplicate_source_filtered_by_condition(self): + target = self._create_table() + self._write( + target, + pa.Table.from_pydict( + { + 'id': pa.array([1], type=pa.int32()), + 'name': ['a'], + 'age': pa.array([10], type=pa.int32()), + }, + schema=self.pa_schema, + ), + ) + + source = pa.Table.from_pydict( + { + 'id': pa.array([1, 1], type=pa.int32()), + 'name': ['x', 'y'], + 'age': pa.array([5, 20], type=pa.int32()), + }, + schema=self.pa_schema, + ) + + merge_into( + target=target, + source=source, + catalog_options=self.catalog_options, + on=['id'], + when_matched=[ + WhenMatched(update='*', condition='s.age > t.age') + ], + num_partitions=_TEST_NUM_PARTITIONS, + ) + + out = self._read_sorted(target) + self.assertEqual(out['id'], [1]) + self.assertEqual(out['name'], ['y']) + self.assertEqual(out['age'], [20]) + class TargetProjectionTest(unittest.TestCase): From 927511a036e798d285f217fdbb56222a1621bed2 Mon Sep 17 00:00:00 2001 From: xiaohongbo Date: Wed, 3 Jun 2026 17:41:32 +0800 Subject: [PATCH 5/7] [ray] Validate condition column refs against source/target schema Check that s.* and t.* references in condition expressions exist in the source and target schemas at merge_into call time, instead of deferring to DataFusion runtime errors. --- .../pypaimon/ray/data_evolution_merge_into.py | 19 +++++++++++ .../ray_data_evolution_merge_into_test.py | 32 +++++++++++++++++++ 2 files changed, 51 insertions(+) diff --git a/paimon-python/pypaimon/ray/data_evolution_merge_into.py b/paimon-python/pypaimon/ray/data_evolution_merge_into.py index a886d5c55c7e..5be68f301e05 100644 --- a/paimon-python/pypaimon/ray/data_evolution_merge_into.py +++ b/paimon-python/pypaimon/ray/data_evolution_merge_into.py @@ -175,6 +175,25 @@ def _prepare(target, source, catalog_options, when_matched, when_not_matched, on source_ds, settable_field_names, on_map, ) + if has_condition: + from pypaimon.ray.merge_condition import extract_columns + source_names = set(_source_schema_or_raise(source_ds).names) + target_names = set(full_target_field_names) + for c in list(when_matched) + list(when_not_matched): + if c.condition is not None: + for ref in extract_columns(c.condition): + prefix, col = ref.split(".", 1) + if prefix == "s" and col not in source_names: + raise ValueError( + f"condition references unknown source " + f"column '{col}'" + ) + if prefix == "t" and col not in target_names: + raise ValueError( + f"condition references unknown target " + f"column '{col}'" + ) + from pypaimon.schema.data_types import PyarrowFieldParser full_pa_schema = PyarrowFieldParser.from_paimon_schema( table.table_schema.fields diff --git a/paimon-python/pypaimon/tests/ray_data_evolution_merge_into_test.py b/paimon-python/pypaimon/tests/ray_data_evolution_merge_into_test.py index 5a11f23afa09..50855421f02c 100644 --- a/paimon-python/pypaimon/tests/ray_data_evolution_merge_into_test.py +++ b/paimon-python/pypaimon/tests/ray_data_evolution_merge_into_test.py @@ -168,6 +168,38 @@ def test_not_matched_condition_rejects_target_refs(self): ) self.assertIn('t.', str(ctx.exception)) + def test_condition_unknown_source_col_rejected(self): + target = self._create_table() + self._write(target, self._source()) + with self.assertRaises(ValueError) as ctx: + merge_into( + target=target, + source=self._source(), + catalog_options=self.catalog_options, + on=['id'], + when_matched=[ + WhenMatched(update='*', condition='s.nonexistent > 0') + ], + num_partitions=_TEST_NUM_PARTITIONS, + ) + self.assertIn('nonexistent', str(ctx.exception)) + + def test_condition_unknown_target_col_rejected(self): + target = self._create_table() + self._write(target, self._source()) + with self.assertRaises(ValueError) as ctx: + merge_into( + target=target, + source=self._source(), + catalog_options=self.catalog_options, + on=['id'], + when_matched=[ + WhenMatched(update='*', condition='s.age > t.nonexistent') + ], + num_partitions=_TEST_NUM_PARTITIONS, + ) + self.assertIn('nonexistent', str(ctx.exception)) + def test_matched_update_star(self): target = self._create_table() self._write( From dab57ba09e464f7797a00ad8e24f17244bbec33f Mon Sep 17 00:00:00 2001 From: xiaohongbo Date: Wed, 3 Jun 2026 17:59:24 +0800 Subject: [PATCH 6/7] [ray] Skip condition tests when datafusion is not installed Add @unittest.skipIf decorator to all condition E2E tests so they gracefully skip on Python < 3.10 or environments without datafusion. --- .../ray_data_evolution_merge_into_test.py | 22 +++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/paimon-python/pypaimon/tests/ray_data_evolution_merge_into_test.py b/paimon-python/pypaimon/tests/ray_data_evolution_merge_into_test.py index 50855421f02c..0fcd3e219caf 100644 --- a/paimon-python/pypaimon/tests/ray_data_evolution_merge_into_test.py +++ b/paimon-python/pypaimon/tests/ray_data_evolution_merge_into_test.py @@ -28,6 +28,15 @@ from pypaimon import CatalogFactory, Schema from pypaimon.ray import WhenMatched, WhenNotMatched, merge_into +try: + import datafusion # noqa: F401 + _HAS_DATAFUSION = True +except ImportError: + _HAS_DATAFUSION = False + +_SKIP_CONDITION = not _HAS_DATAFUSION +_SKIP_REASON = "datafusion not installed" + _TEST_NUM_PARTITIONS = 2 @@ -168,6 +177,7 @@ def test_not_matched_condition_rejects_target_refs(self): ) self.assertIn('t.', str(ctx.exception)) + @unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON) def test_condition_unknown_source_col_rejected(self): target = self._create_table() self._write(target, self._source()) @@ -184,6 +194,7 @@ def test_condition_unknown_source_col_rejected(self): ) self.assertIn('nonexistent', str(ctx.exception)) + @unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON) def test_condition_unknown_target_col_rejected(self): target = self._create_table() self._write(target, self._source()) @@ -566,6 +577,7 @@ def test_partitioned_insert_allowed(self): self.assertEqual(out['id'], [1, 2]) self.assertEqual(out['pt'], ['a', 'b']) + @unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON) def test_matched_update_with_condition(self): target = self._create_table() self._write( @@ -603,6 +615,7 @@ def test_matched_update_with_condition(self): self.assertEqual(out['name'], ['a', 'b', 'c2']) self.assertEqual(out['age'], [10, 20, 45]) + @unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON) def test_matched_condition_with_source_on_key(self): target = self._create_table() self._write( @@ -640,6 +653,7 @@ def test_matched_condition_with_source_on_key(self): self.assertEqual(out['name'], ['a', 'b2', 'c2']) self.assertEqual(out['age'], [10, 25, 35]) + @unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON) def test_not_matched_insert_with_condition(self): target = self._create_table() self._write( @@ -679,6 +693,7 @@ def test_not_matched_insert_with_condition(self): self.assertEqual(out['name'], ['a', 'b', 'c']) self.assertEqual(out['age'], [10, 15, 25]) + @unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON) def test_combined_with_conditions(self): target = self._create_table() self._write( @@ -721,6 +736,7 @@ def test_combined_with_conditions(self): self.assertEqual(metrics['num_matched'], 1) self.assertEqual(metrics['num_inserted'], 1) + @unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON) def test_condition_no_rows_match_is_noop(self): target = self._create_table() self._write( @@ -758,6 +774,7 @@ def test_condition_no_rows_match_is_noop(self): self.assertEqual(out['name'], ['a', 'b']) self.assertEqual(out['age'], [10, 20]) + @unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON) def test_duplicate_source_filtered_by_condition(self): target = self._create_table() self._write( @@ -886,11 +903,8 @@ def test_extract_columns(self): {'s.id', 't.id', 's.age', 't.age'}, ) + @unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON) def test_filter_batch(self): - try: - import datafusion # noqa: F401 - except ImportError: - self.skipTest("datafusion not installed") from pypaimon.ray.merge_condition import filter_batch batch = pa.table({ 's.id': pa.array([1, 2, 3], type=pa.int32()), From 4849d577fa1a40de9722ab5ecdc9f4bf4750baeb Mon Sep 17 00:00:00 2001 From: xiaohongbo Date: Wed, 3 Jun 2026 18:17:20 +0800 Subject: [PATCH 7/7] [ray] Add missing skip decorator to target-ref rejection test test_not_matched_condition_rejects_target_refs also requires datafusion (_prepare calls _require_datafusion before the ValueError check). --- .../pypaimon/tests/ray_data_evolution_merge_into_test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/paimon-python/pypaimon/tests/ray_data_evolution_merge_into_test.py b/paimon-python/pypaimon/tests/ray_data_evolution_merge_into_test.py index 0fcd3e219caf..47981088f2d3 100644 --- a/paimon-python/pypaimon/tests/ray_data_evolution_merge_into_test.py +++ b/paimon-python/pypaimon/tests/ray_data_evolution_merge_into_test.py @@ -162,6 +162,7 @@ def test_source_missing_on_col_raises(self): ) self.assertIn("'id'", str(ctx.exception)) + @unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON) def test_not_matched_condition_rejects_target_refs(self): target = self._create_table() with self.assertRaises(ValueError) as ctx: