diff --git a/docs/docs/pypaimon/ray-data.md b/docs/docs/pypaimon/ray-data.md index a3c59c9f68b8..d160b4302f65 100644 --- a/docs/docs/pypaimon/ray-data.md +++ b/docs/docs/pypaimon/ray-data.md @@ -374,10 +374,16 @@ Conditions use SQL-style expressions with `s.` (source) and `t.` (target) column prefixes. `WhenNotMatched` conditions may only reference source columns (`s.*`). Requires the `datafusion` package: `pip install pypaimon[sql]`. -- `update` / `insert`: only `"*"` is supported in this PR. A future follow-up - will add mapping-based SET (e.g. `{"col": "s.col"}`) where values are - analyzable string expressions (`"s."`, `"t."`, or literals), - not Python callables. +- `update` / `insert`: `"*"` updates/inserts all non-blob columns from source. + A mapping selects specific columns: + ```python + from pypaimon.ray import source_col, target_col, lit + + WhenMatched(update={"age": source_col("age"), "name": target_col("name")}) + WhenNotMatched(insert={"id": source_col("id"), "status": lit("new")}) + ``` + The string shorthands `"s.<col>"` / `"t.<col>"` also work (`t.*` only in update). + Use `lit()` for string literals that start with `s.` or `t.`; a bare string with either prefix is read as a column reference. - `condition`: an optional SQL-style boolean expression. Use `s.` and `t.` to reference source and target columns. 
diff --git a/paimon-python/pypaimon/ray/__init__.py b/paimon-python/pypaimon/ray/__init__.py index 9161f3cbb3b7..4280187956e3 100644 --- a/paimon-python/pypaimon/ray/__init__.py +++ b/paimon-python/pypaimon/ray/__init__.py @@ -21,6 +21,11 @@ WhenNotMatched, merge_into, ) +from pypaimon.ray.data_evolution_merge_transform import ( + source_col, + target_col, + lit, +) __all__ = [ "read_paimon", @@ -28,4 +33,7 @@ "merge_into", "WhenMatched", "WhenNotMatched", + "source_col", + "target_col", + "lit", ] diff --git a/paimon-python/pypaimon/ray/data_evolution_merge_into.py b/paimon-python/pypaimon/ray/data_evolution_merge_into.py index 5be68f301e05..35451bb744e9 100644 --- a/paimon-python/pypaimon/ray/data_evolution_merge_into.py +++ b/paimon-python/pypaimon/ray/data_evolution_merge_into.py @@ -30,8 +30,11 @@ distributed_write_collect_msgs, ) from pypaimon.ray.data_evolution_merge_transform import ( + LiteralValue, OnSpec, SetSpec, + SourceColumnRef, + TargetColumnRef, WhenMatched, WhenNotMatched, _NormalizedClause, @@ -159,20 +162,23 @@ def _prepare(target, source, catalog_options, when_matched, when_not_matched, on f"condition must not reference blob columns, " f"but found: {sorted(blob_refs)}" ) - not_matched_specs = [ - _NormalizedClause( - spec=_normalize_set_spec( - c.insert, settable_field_names, on_map, - ), - condition=c.condition, + not_matched_specs = [] + for c in when_not_matched: + spec = _normalize_set_spec( + c.insert, settable_field_names, on_map, + allow_target_refs=False, + ) + for tk, sk in on_map.items(): + if tk in settable_field_names and tk not in spec: + spec[tk] = SourceColumnRef(sk) + not_matched_specs.append( + _NormalizedClause(spec=spec, condition=c.condition) ) - for c in when_not_matched - ] source_ds = _normalize_source(source, catalog_options) _validate_source_on_cols(source_ds, source_on_cols) _validate_source_has_target_cols( - source_ds, settable_field_names, on_map, + source_ds, matched_specs + not_matched_specs, ) if has_condition: 
@@ -398,8 +404,8 @@ def _needed_target_cols( set_by_all = set(update_cols) for clause in clauses: for value in clause.spec.values(): - if isinstance(value, str) and value.startswith("t."): - needed.add(value[2:]) + if isinstance(value, TargetColumnRef): + needed.add(value.column) set_by_all &= set(clause.spec.keys()) needed |= set(update_cols) - set_by_all return [c for c in all_target_cols if c in needed] @@ -427,15 +433,66 @@ def _normalize_set_spec( spec: SetSpec, target_field_names: Sequence[str], on_map: Optional[Mapping[str, str]] = None, + allow_target_refs: bool = True, ) -> Dict[str, Any]: on_map = on_map or {} - if spec != "*": - raise NotImplementedError( - "merge_into currently only supports '*' for update/insert; " - "partial SET will be added in a follow-up PR." + if spec == "*": + return { + col: SourceColumnRef(on_map.get(col, col)) + for col in target_field_names + } + if not isinstance(spec, Mapping): + raise TypeError( + f"SET spec must be '*' or a mapping, got {type(spec).__name__}" ) - # A renamed ON key resolves via the source's ON column, not its own name. 
- return {col: f"s.{on_map.get(col, col)}" for col in target_field_names} + if not spec: + raise ValueError("SET spec must not be empty") + target_set = set(target_field_names) + for key in spec: + if key not in target_set: + raise ValueError( + f"SET spec references unknown target column '{key}'" + ) + result: Dict[str, Any] = {} + for key, val in spec.items(): + if callable(val) and not isinstance(val, type): + raise TypeError( + "SET values must be source_col(), target_col(), " + "lit(), or literals, not callables" + ) + if isinstance(val, SourceColumnRef): + result[key] = val + elif isinstance(val, TargetColumnRef): + if not allow_target_refs: + raise ValueError( + "INSERT spec must not reference target columns " + f"(t.*), but found: 't.{val.column}'" + ) + if val.column not in target_set: + raise ValueError( + f"SET spec references unknown target column " + f"'{val.column}'" + ) + result[key] = val + elif isinstance(val, LiteralValue): + result[key] = val + elif isinstance(val, str) and val.startswith("s."): + result[key] = SourceColumnRef(val[2:]) + elif isinstance(val, str) and val.startswith("t."): + if not allow_target_refs: + raise ValueError( + "INSERT spec must not reference target columns " + f"(t.*), but found: '{val}'" + ) + ref = val[2:] + if ref not in target_set: + raise ValueError( + f"SET spec references unknown target column '{ref}'" + ) + result[key] = TargetColumnRef(ref) + else: + result[key] = LiteralValue(val) + return result def _normalize_source(source: Any, catalog_options: Dict[str, str]): @@ -483,17 +540,16 @@ def _validate_source_on_cols(source_ds, on: Sequence[str]) -> None: def _validate_source_has_target_cols( source_ds, - target_field_names: Sequence[str], - on_map: Mapping[str, str], + specs: List[_NormalizedClause], ) -> None: - """For update='*'/insert='*', source must carry every (non-blob) target - column; otherwise the SET spec resolves to null and silently overwrites.""" names = 
set(_source_schema_or_raise(source_ds).names) - expected = {on_map.get(c, c) for c in target_field_names} - missing = sorted(expected - names) + needed = set() + for clause in specs: + for val in clause.spec.values(): + if isinstance(val, SourceColumnRef): + needed.add(val.column) + missing = sorted(needed - names) if missing: raise ValueError( - f"source is missing target columns {missing}; " - f"update='*'/insert='*' requires the source to carry every " - f"(non-blob) target column." + f"source is missing columns {missing} referenced by SET spec" ) diff --git a/paimon-python/pypaimon/ray/data_evolution_merge_transform.py b/paimon-python/pypaimon/ray/data_evolution_merge_transform.py index ed786467f1e7..003977f3e7f2 100644 --- a/paimon-python/pypaimon/ray/data_evolution_merge_transform.py +++ b/paimon-python/pypaimon/ray/data_evolution_merge_transform.py @@ -25,6 +25,33 @@ OnSpec = Union[Sequence[str], Mapping[str, str]] +@dataclass(frozen=True) +class SourceColumnRef: + column: str + + +@dataclass(frozen=True) +class TargetColumnRef: + column: str + + +@dataclass(frozen=True) +class LiteralValue: + value: Any + + +def source_col(name: str) -> SourceColumnRef: + return SourceColumnRef(name) + + +def target_col(name: str) -> TargetColumnRef: + return TargetColumnRef(name) + + +def lit(value: Any) -> LiteralValue: + return LiteralValue(value) + + @dataclass class WhenMatched: update: SetSpec @@ -105,18 +132,19 @@ def _resolve_spec_array( on_pairs: Sequence[Tuple[str, str]], out_type: pa.DataType, ): - if isinstance(val, str) and val.startswith("s."): - ref = val[2:] + if isinstance(val, LiteralValue): + return pa.array([val.value] * batch.num_rows, type=out_type) + if isinstance(val, SourceColumnRef): + ref = val.column if f"s.{ref}" in available: return batch.column(f"s.{ref}") for sk, tk in on_pairs: if sk == ref and f"t.{tk}" in available: return batch.column(f"t.{tk}") return pa.nulls(batch.num_rows, type=out_type) - if isinstance(val, str) and 
val.startswith("t."): - ref = val[2:] - col_name = f"t.{ref}" + if isinstance(val, TargetColumnRef): + col_name = f"t.{val.column}" return batch.column(col_name) if col_name in available else pa.nulls( batch.num_rows, type=out_type ) - return pa.array([val] * batch.num_rows, type=out_type) + raise TypeError(f"unexpected spec value type: {type(val).__name__}") diff --git a/paimon-python/pypaimon/tests/ray_data_evolution_merge_into_test.py b/paimon-python/pypaimon/tests/ray_data_evolution_merge_into_test.py index 47981088f2d3..00814057c23e 100644 --- a/paimon-python/pypaimon/tests/ray_data_evolution_merge_into_test.py +++ b/paimon-python/pypaimon/tests/ray_data_evolution_merge_into_test.py @@ -26,7 +26,10 @@ import ray from pypaimon import CatalogFactory, Schema -from pypaimon.ray import WhenMatched, WhenNotMatched, merge_into +from pypaimon.ray import ( + WhenMatched, WhenNotMatched, merge_into, + source_col, target_col, lit, +) try: import datafusion # noqa: F401 @@ -815,6 +818,472 @@ def test_duplicate_source_filtered_by_condition(self): self.assertEqual(out['name'], ['y']) self.assertEqual(out['age'], [20]) + def test_matched_partial_update(self): + target = self._create_table() + self._write( + target, + pa.Table.from_pydict( + { + 'id': pa.array([1, 2], type=pa.int32()), + 'name': ['a', 'b'], + 'age': pa.array([10, 20], type=pa.int32()), + }, + schema=self.pa_schema, + ), + ) + + source = pa.Table.from_pydict( + { + 'id': pa.array([1, 2], type=pa.int32()), + 'name': ['a2', 'b2'], + 'age': pa.array([99, 88], type=pa.int32()), + }, + schema=self.pa_schema, + ) + + merge_into( + target=target, + source=source, + catalog_options=self.catalog_options, + on=['id'], + when_matched=[WhenMatched(update={'age': 's.age'})], + num_partitions=_TEST_NUM_PARTITIONS, + ) + + out = self._read_sorted(target) + self.assertEqual(out['id'], [1, 2]) + self.assertEqual(out['name'], ['a', 'b']) + self.assertEqual(out['age'], [99, 88]) + + def test_insert_partial_mapping(self): + 
target = self._create_table() + + source = pa.Table.from_pydict( + { + 'id': pa.array([1, 2], type=pa.int32()), + 'name': ['a', 'b'], + 'age': pa.array([10, 20], type=pa.int32()), + }, + schema=self.pa_schema, + ) + + merge_into( + target=target, + source=source, + catalog_options=self.catalog_options, + on=['id'], + when_not_matched=[ + WhenNotMatched(insert={'id': 's.id', 'name': 's.name'}) + ], + num_partitions=_TEST_NUM_PARTITIONS, + ) + + out = self._read_sorted(target) + self.assertEqual(out['id'], [1, 2]) + self.assertEqual(out['name'], ['a', 'b']) + self.assertEqual(out['age'], [None, None]) + + def test_update_with_literal(self): + target = self._create_table() + self._write( + target, + pa.Table.from_pydict( + { + 'id': pa.array([1], type=pa.int32()), + 'name': ['old'], + 'age': pa.array([10], type=pa.int32()), + }, + schema=self.pa_schema, + ), + ) + + source = pa.Table.from_pydict( + { + 'id': pa.array([1], type=pa.int32()), + 'name': ['ignored'], + 'age': pa.array([99], type=pa.int32()), + }, + schema=self.pa_schema, + ) + + merge_into( + target=target, + source=source, + catalog_options=self.catalog_options, + on=['id'], + when_matched=[WhenMatched(update={'name': 'updated'})], + num_partitions=_TEST_NUM_PARTITIONS, + ) + + out = self._read_sorted(target) + self.assertEqual(out['name'], ['updated']) + self.assertEqual(out['age'], [10]) + + def test_invalid_target_column_rejected(self): + target = self._create_table() + with self.assertRaises(ValueError) as ctx: + merge_into( + target=target, + source=self._source(), + catalog_options=self.catalog_options, + on=['id'], + when_matched=[WhenMatched(update={'nonexistent': 's.id'})], + num_partitions=_TEST_NUM_PARTITIONS, + ) + self.assertIn('nonexistent', str(ctx.exception)) + + def test_invalid_target_ref_rejected(self): + target = self._create_table() + with self.assertRaises(ValueError) as ctx: + merge_into( + target=target, + source=self._source(), + catalog_options=self.catalog_options, + on=['id'], 
+ when_matched=[WhenMatched(update={'name': 't.nme'})], + num_partitions=_TEST_NUM_PARTITIONS, + ) + self.assertIn('nme', str(ctx.exception)) + + def test_empty_mapping_rejected(self): + target = self._create_table() + with self.assertRaises(ValueError): + merge_into( + target=target, + source=self._source(), + catalog_options=self.catalog_options, + on=['id'], + when_matched=[WhenMatched(update={})], + num_partitions=_TEST_NUM_PARTITIONS, + ) + + def test_insert_target_ref_rejected(self): + target = self._create_table() + with self.assertRaises(ValueError) as ctx: + merge_into( + target=target, + source=self._source(), + catalog_options=self.catalog_options, + on=['id'], + when_not_matched=[ + WhenNotMatched(insert={'name': 't.name'}) + ], + num_partitions=_TEST_NUM_PARTITIONS, + ) + self.assertIn('t.', str(ctx.exception)) + + def test_matched_update_with_target_ref(self): + target = self._create_table() + self._write( + target, + pa.Table.from_pydict( + { + 'id': pa.array([1], type=pa.int32()), + 'name': ['old'], + 'age': pa.array([10], type=pa.int32()), + }, + schema=self.pa_schema, + ), + ) + + source = pa.Table.from_pydict( + { + 'id': pa.array([1], type=pa.int32()), + 'name': ['ignored'], + 'age': pa.array([99], type=pa.int32()), + }, + schema=self.pa_schema, + ) + + merge_into( + target=target, + source=source, + catalog_options=self.catalog_options, + on=['id'], + when_matched=[WhenMatched(update={'age': 's.age', 'name': 't.name'})], + num_partitions=_TEST_NUM_PARTITIONS, + ) + + out = self._read_sorted(target) + self.assertEqual(out['name'], ['old']) + self.assertEqual(out['age'], [99]) + + def test_callable_value_rejected(self): + target = self._create_table() + with self.assertRaises(TypeError): + merge_into( + target=target, + source=self._source(), + catalog_options=self.catalog_options, + on=['id'], + when_matched=[WhenMatched(update={'name': lambda r: r})], + num_partitions=_TEST_NUM_PARTITIONS, + ) + + def test_source_missing_referenced_col(self): + 
target = self._create_table() + source = pa.Table.from_pydict( + {'id': pa.array([1], type=pa.int32())}, + schema=pa.schema([('id', pa.int32())]), + ) + with self.assertRaises(ValueError) as ctx: + merge_into( + target=target, + source=source, + catalog_options=self.catalog_options, + on=['id'], + when_matched=[WhenMatched(update={'name': 's.name'})], + num_partitions=_TEST_NUM_PARTITIONS, + ) + self.assertIn('name', str(ctx.exception)) + + def test_partial_insert_auto_fills_on_key(self): + target = self._create_table() + + source = pa.Table.from_pydict( + { + 'id': pa.array([1, 2], type=pa.int32()), + 'name': ['a', 'b'], + 'age': pa.array([10, 20], type=pa.int32()), + }, + schema=self.pa_schema, + ) + + merge_into( + target=target, + source=source, + catalog_options=self.catalog_options, + on=['id'], + when_not_matched=[ + WhenNotMatched(insert={'name': 's.name'}) + ], + num_partitions=_TEST_NUM_PARTITIONS, + ) + + out = self._read_sorted(target) + self.assertEqual(out['id'], [1, 2]) + self.assertEqual(out['name'], ['a', 'b']) + + def test_partial_insert_renamed_on_key_auto_filled(self): + target = self._create_table() + + source_schema = pa.schema([ + ('uid', pa.int32()), + ('name', pa.string()), + ('age', pa.int32()), + ]) + source = pa.Table.from_pydict( + { + 'uid': pa.array([1, 2], type=pa.int32()), + 'name': ['a', 'b'], + 'age': pa.array([10, 20], type=pa.int32()), + }, + schema=source_schema, + ) + + merge_into( + target=target, + source=source, + catalog_options=self.catalog_options, + on={'id': 'uid'}, + when_not_matched=[ + WhenNotMatched(insert={'name': 's.name'}) + ], + num_partitions=_TEST_NUM_PARTITIONS, + ) + + out = self._read_sorted(target) + self.assertEqual(out['id'], [1, 2]) + self.assertEqual(out['name'], ['a', 'b']) + + def test_explicit_source_ref_not_remapped_by_on_key(self): + target = self._create_table() + self._write( + target, + pa.Table.from_pydict( + { + 'id': pa.array([1], type=pa.int32()), + 'name': ['old'], + 'age': pa.array([10], 
type=pa.int32()), + }, + schema=self.pa_schema, + ), + ) + + source_schema = pa.schema([ + ('uid', pa.int32()), + ('id', pa.int32()), + ('name', pa.string()), + ('age', pa.int32()), + ]) + source = pa.Table.from_pydict( + { + 'uid': pa.array([1], type=pa.int32()), + 'id': pa.array([42], type=pa.int32()), + 'name': ['new'], + 'age': pa.array([99], type=pa.int32()), + }, + schema=source_schema, + ) + + merge_into( + target=target, + source=source, + catalog_options=self.catalog_options, + on={'id': 'uid'}, + when_matched=[WhenMatched(update={ + 'age': source_col('id'), + })], + num_partitions=_TEST_NUM_PARTITIONS, + ) + + out = self._read_sorted(target) + self.assertEqual(out['age'], [42]) + self.assertEqual(out['name'], ['old']) + + def test_renamed_on_key_missing_source_col_rejected(self): + target = self._create_table() + source_schema = pa.schema([ + ('uid', pa.int32()), + ('name', pa.string()), + ('age', pa.int32()), + ]) + source = pa.Table.from_pydict( + { + 'uid': pa.array([1], type=pa.int32()), + 'name': ['a'], + 'age': pa.array([10], type=pa.int32()), + }, + schema=source_schema, + ) + with self.assertRaises(ValueError) as ctx: + merge_into( + target=target, + source=source, + catalog_options=self.catalog_options, + on={'id': 'uid'}, + when_matched=[WhenMatched(update={ + 'id': source_col('id'), + })], + num_partitions=_TEST_NUM_PARTITIONS, + ) + self.assertIn('id', str(ctx.exception)) + + def test_lit_prevents_column_ref_interpretation(self): + target = self._create_table() + self._write( + target, + pa.Table.from_pydict( + { + 'id': pa.array([1], type=pa.int32()), + 'name': ['old'], + 'age': pa.array([10], type=pa.int32()), + }, + schema=self.pa_schema, + ), + ) + + source = pa.Table.from_pydict( + { + 'id': pa.array([1], type=pa.int32()), + 'name': ['ignored'], + 'age': pa.array([99], type=pa.int32()), + }, + schema=self.pa_schema, + ) + + merge_into( + target=target, + source=source, + catalog_options=self.catalog_options, + on=['id'], + 
when_matched=[WhenMatched(update={ + 'name': lit('s.active'), + })], + num_partitions=_TEST_NUM_PARTITIONS, + ) + + out = self._read_sorted(target) + self.assertEqual(out['name'], ['s.active']) + self.assertEqual(out['age'], [10]) + + def test_source_col_helper(self): + target = self._create_table() + self._write( + target, + pa.Table.from_pydict( + { + 'id': pa.array([1], type=pa.int32()), + 'name': ['old'], + 'age': pa.array([10], type=pa.int32()), + }, + schema=self.pa_schema, + ), + ) + + source = pa.Table.from_pydict( + { + 'id': pa.array([1], type=pa.int32()), + 'name': ['new'], + 'age': pa.array([99], type=pa.int32()), + }, + schema=self.pa_schema, + ) + + merge_into( + target=target, + source=source, + catalog_options=self.catalog_options, + on=['id'], + when_matched=[WhenMatched(update={ + 'age': source_col('age'), + })], + num_partitions=_TEST_NUM_PARTITIONS, + ) + + out = self._read_sorted(target) + self.assertEqual(out['name'], ['old']) + self.assertEqual(out['age'], [99]) + + def test_target_col_helper(self): + target = self._create_table() + self._write( + target, + pa.Table.from_pydict( + { + 'id': pa.array([1], type=pa.int32()), + 'name': ['keep'], + 'age': pa.array([10], type=pa.int32()), + }, + schema=self.pa_schema, + ), + ) + + source = pa.Table.from_pydict( + { + 'id': pa.array([1], type=pa.int32()), + 'name': ['ignored'], + 'age': pa.array([99], type=pa.int32()), + }, + schema=self.pa_schema, + ) + + merge_into( + target=target, + source=source, + catalog_options=self.catalog_options, + on=['id'], + when_matched=[WhenMatched(update={ + 'age': source_col('age'), + 'name': target_col('name'), + })], + num_partitions=_TEST_NUM_PARTITIONS, + ) + + out = self._read_sorted(target) + self.assertEqual(out['name'], ['keep']) + self.assertEqual(out['age'], [99]) + class TargetProjectionTest(unittest.TestCase):