From 7f41e246bb3a1f1fa4b58649554c8c3976b91ee8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=96=86=E5=AE=87?= Date: Mon, 1 Jun 2026 18:25:11 +0800 Subject: [PATCH 1/5] [python] support chunk shuffle for append table --- .../scanner/chunk_shuffle_split_generator.py | 373 ++++++++ .../pypaimon/read/scanner/file_scanner.py | 66 +- paimon-python/pypaimon/read/table_scan.py | 4 + .../chunk_shuffle_split_generator_test.py | 800 ++++++++++++++++++ 4 files changed, 1242 insertions(+), 1 deletion(-) create mode 100644 paimon-python/pypaimon/read/scanner/chunk_shuffle_split_generator.py create mode 100644 paimon-python/pypaimon/tests/scanner/chunk_shuffle_split_generator_test.py diff --git a/paimon-python/pypaimon/read/scanner/chunk_shuffle_split_generator.py b/paimon-python/pypaimon/read/scanner/chunk_shuffle_split_generator.py new file mode 100644 index 000000000000..c64b620e398b --- /dev/null +++ b/paimon-python/pypaimon/read/scanner/chunk_shuffle_split_generator.py @@ -0,0 +1,373 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import random +from abc import abstractmethod +from collections import defaultdict +from dataclasses import dataclass +from typing import Any, List, Optional, Tuple + +from pypaimon.globalindex.indexed_split import IndexedSplit +from pypaimon.manifest.schema.data_file_meta import DataFileMeta +from pypaimon.manifest.schema.manifest_entry import ManifestEntry +from pypaimon.read.scanner.split_generator import AbstractSplitGenerator +from pypaimon.read.sliced_split import SlicedSplit +from pypaimon.read.split import DataSplit, Split +from pypaimon.table.row.generic_row import GenericRow +from pypaimon.utils.range import Range + + +def _null_safe_partition_key(partition_values) -> tuple: + """Wrap each partition value with a None-aware tag so tuples that mix + null and non-null partition values can be ordered without raising + ``TypeError: '<' not supported between instances of 'NoneType' and 'str'``. + Paimon supports null partition values; Python 3 refuses to compare + None against str/int directly. + """ + return tuple((v is None, v) for v in partition_values) + + +@dataclass +class _Chunk: + """A unit of work for one DataLoader read. ``segments`` carries + subclass-specific payload (file segments for append, aligned-group + segments for data evolution). + """ + partition: GenericRow + bucket: int + segments: List[Any] + + +class ChunkShuffleSplitGeneratorBase(AbstractSplitGenerator): + """Common scaffolding for deterministic chunk-shuffled split generation. + + Pipeline (template method, in :meth:`create_splits`): + 1. Stable-sort entries (key from :meth:`_sort_key`) so manifest-read + parallelism cannot bleed into the output. + 2. Group by (partition, bucket); iterate groups in sorted-key order. + 3. Per group, call :meth:`_slice_group_into_chunks` to produce a list + of segment lists — one segment list per chunk. + 4. Wrap each chunk with its (partition, bucket) into ``_Chunk``, + concatenate across groups. + 5. ``random.Random(seed).shuffle`` all chunks. + 6. If sharded, take this worker's slice via balanced ``_compute_shard_range``. + 7. Map each chunk through :meth:`_chunk_to_split`. + + Subclasses implement the three abstract hooks. Reader paths + (``RawFileSplitRead`` for append, ``DataEvolutionSplitRead`` for DE) + are unchanged because chunks ride on existing wrappers + (``SlicedSplit`` / ``IndexedSplit``). + """ + + def __init__( + self, + table, + target_split_size: int, + open_file_cost: int, + deletion_files_map=None, + seed: int = 0, + chunk_size: int = 0, + ): + super().__init__(table, target_split_size, open_file_cost, deletion_files_map) + self.seed = seed + self.chunk_size = chunk_size + + def create_splits(self, file_entries: List[ManifestEntry]) -> List[Split]: + if not file_entries: + return [] + + sorted_entries = sorted(file_entries, key=self._sort_key) + + partitioned: "defaultdict[Tuple[tuple, int], List[ManifestEntry]]" = defaultdict(list) + for entry in sorted_entries: + partitioned[(tuple(entry.partition.values), entry.bucket)].append(entry) + + all_chunks: List[_Chunk] = [] + for key in sorted( + partitioned.keys(), + key=lambda k: (_null_safe_partition_key(k[0]), k[1]), + ): + entries_in_group = partitioned[key] + partition_row = entries_in_group[0].partition + bucket = entries_in_group[0].bucket + # Materialize file_path once per unique file in this group. + seen_paths: set = set() + for entry in entries_in_group: + f = entry.file + if f.file_name in seen_paths: + continue + seen_paths.add(f.file_name) + f.set_file_path( + self.table.table_path, + partition_row, + bucket, + self.default_part_value, + ) + for segments in self._slice_group_into_chunks(entries_in_group): + all_chunks.append(_Chunk(partition_row, bucket, segments)) + + rng = random.Random(self.seed) + rng.shuffle(all_chunks) + + if self.idx_of_this_subtask is not None: + start, end = self._compute_shard_range(len(all_chunks)) + all_chunks = all_chunks[start:end] + + return [self._chunk_to_split(c) for c in all_chunks] + + @abstractmethod + def _sort_key(self, entry: ManifestEntry): + """Return a comparable, deterministic key for stable sort.""" + + @abstractmethod + def _slice_group_into_chunks(self, entries: List[ManifestEntry]) -> List[List[Any]]: + """Cut one (partition, bucket) group into chunks of segments. + + Each returned inner list represents one chunk; segment shape is + subclass-defined. + """ + + @abstractmethod + def _chunk_to_split(self, chunk: _Chunk) -> Split: + """Wrap a chunk into a Split that the existing readers consume.""" + + +# --------------------------------------------------------------------------- +# Append (non-DE, non-DV) implementation +# --------------------------------------------------------------------------- + + +@dataclass +class _FileSegment: + """A contiguous slice of a data file inside one chunk. + + start/end are half-open row offsets within the file when the chunk + boundary falls inside the file; both are None when the chunk owns + the full file (so SlicedSplit's shard_file_idx_map can skip it and + treat the file as full — see sliced_split.py:73-78). + """ + file: DataFileMeta + start: Optional[int] + end: Optional[int] + + +class AppendChunkShuffleSplitGenerator(ChunkShuffleSplitGeneratorBase): + """Chunk-shuffled splits for plain append tables (non-PK, non-DV, non-DE).""" + + def _sort_key(self, entry: ManifestEntry): + return ( + _null_safe_partition_key(entry.partition.values), + entry.bucket, + entry.file.file_name, + ) + + def _slice_group_into_chunks( + self, entries: List[ManifestEntry] + ) -> List[List[_FileSegment]]: + """Cut a (partition, bucket) group into chunks of at most + ``self.chunk_size`` rows. ``chunk_size`` is a hard upper bound: + the last chunk may be smaller, but no chunk exceeds it. + """ + chunks: List[List[_FileSegment]] = [] + current: List[_FileSegment] = [] + current_rows = 0 + + for entry in entries: + file = entry.file + offset = 0 + remaining = file.row_count + while remaining > 0: + avail = self.chunk_size - current_rows + if avail <= 0: + chunks.append(current) + current = [] + current_rows = 0 + avail = self.chunk_size + + take = min(remaining, avail) + + if take == file.row_count and offset == 0: + current.append(_FileSegment(file, None, None)) + else: + current.append(_FileSegment(file, offset, offset + take)) + + current_rows += take + offset += take + remaining -= take + + if current: + chunks.append(current) + + return chunks + + def _chunk_to_split(self, chunk: _Chunk) -> Split: + files: List[DataFileMeta] = [] + shard_file_idx_map = {} + for seg in chunk.segments: + files.append(seg.file) + if seg.start is not None and seg.end is not None: + shard_file_idx_map[seg.file.file_name] = (seg.start, seg.end) + + # set_file_path is already done once per unique file in + # ChunkShuffleSplitGeneratorBase.create_splits. + + data_split = DataSplit( + files=files, + partition=chunk.partition, + bucket=chunk.bucket, + raw_convertible=True, + data_deletion_files=None, + ) + + if shard_file_idx_map: + return SlicedSplit(data_split, shard_file_idx_map) + return data_split + + +# --------------------------------------------------------------------------- +# Data Evolution implementation +# --------------------------------------------------------------------------- + + +@dataclass +class _AlignedGroupSegment: + """A row_id sub-range over one row-id-aligned file group. + + ``files`` is the entire group (may include blob/vector siblings), + so the reader sees every column file even when only a slice of the + group's row_id range lands in this chunk. ``row_range`` is the + inclusive global row_id range this segment owns. + """ + files: List[DataFileMeta] + row_range: Range + + +class DataEvolutionChunkShuffleSplitGenerator(ChunkShuffleSplitGeneratorBase): + """Chunk-shuffled splits for data-evolution append tables. + + The minimum cuttable unit is a row_id-aligned file group: cutting + inside one group would orphan column files relative to the row_id + range, so we keep groups intact and only slice along their row_id + axis. Each chunk maps to an :class:`IndexedSplit` whose ``row_ranges`` + bound the readable slice for that chunk. + """ + + def _sort_key(self, entry: ManifestEntry): + first_row_id = ( + entry.file.first_row_id + if entry.file.first_row_id is not None + else float('-inf') + ) + is_special = 1 if ( + DataFileMeta.is_blob_file(entry.file.file_name) + or DataFileMeta.is_vector_file(entry.file.file_name) + ) else 0 + return ( + _null_safe_partition_key(entry.partition.values), + entry.bucket, + first_row_id, + is_special, + entry.file.file_name, + ) + + def _slice_group_into_chunks( + self, entries: List[ManifestEntry] + ) -> List[List[_AlignedGroupSegment]]: + files = [e.file for e in entries] + # (Range, [files]) pairs sorted by row_id — see helper docstring. + aligned_groups = self._split_by_row_id_with_range(files) + + chunks: List[List[_AlignedGroupSegment]] = [] + current: List[_AlignedGroupSegment] = [] + current_rows = 0 + + for group_range, group_files in aligned_groups: + offset = 0 + group_rows = group_range.count() + while offset < group_rows: + avail = self.chunk_size - current_rows + if avail <= 0: + chunks.append(current) + current = [] + current_rows = 0 + avail = self.chunk_size + + take = min(group_rows - offset, avail) + seg_range = Range( + group_range.from_ + offset, + group_range.from_ + offset + take - 1, + ) + current.append(_AlignedGroupSegment(group_files, seg_range)) + current_rows += take + offset += take + + if current: + chunks.append(current) + + return chunks + + def _chunk_to_split(self, chunk: _Chunk) -> Split: + segments = chunk.segments + if len(segments) == 1: + all_files = segments[0].files + row_ranges = [segments[0].row_range] + else: + all_files = [] + row_ranges = [] + for seg in segments: + all_files.extend(seg.files) + row_ranges.append(seg.row_range) + row_ranges.sort(key=lambda r: r.from_) + + data_split = DataSplit( + files=all_files, + partition=chunk.partition, + bucket=chunk.bucket, + raw_convertible=False, + data_deletion_files=None, + ) + return IndexedSplit(data_split, row_ranges, scores=None) + + @staticmethod + def _split_by_row_id_with_range( + files: List[DataFileMeta], + ) -> List[Tuple[Range, List[DataFileMeta]]]: + """Group files by overlapping row_id range, returning (range, files) + pairs sorted by ``range.from_``. + + Mirrors :meth:`DataEvolutionSplitGenerator._split_by_row_id` but + also returns the merged row_id range per group, which the chunk + slicer needs to drive row-count accumulation. Files without + ``first_row_id`` are skipped (DE invariant guarantees presence; + defensive in case stray entries sneak in). + """ + list_ranges = [f.row_id_range() for f in files if f.row_id_range() is not None] + if not list_ranges: + return [] + sorted_ranges = Range.sort_and_merge_overlap(list_ranges, True, False) + + range_to_files: "dict[Range, List[DataFileMeta]]" = {} + for f in files: + file_range = f.row_id_range() + if file_range is None: + continue + for r in sorted_ranges: + if r.overlaps(file_range): + range_to_files.setdefault(r, []).append(f) + break + + return sorted(range_to_files.items(), key=lambda kv: kv[0].from_) diff --git a/paimon-python/pypaimon/read/scanner/file_scanner.py b/paimon-python/pypaimon/read/scanner/file_scanner.py index 5424d2d51910..650bbe8ca402 100755 --- a/paimon-python/pypaimon/read/scanner/file_scanner.py +++ b/paimon-python/pypaimon/read/scanner/file_scanner.py @@ -40,6 +40,10 @@ AppendTableSplitGenerator from pypaimon.read.scanner.bucket_select_converter import \ create_bucket_selector +from pypaimon.read.scanner.chunk_shuffle_split_generator import ( + AppendChunkShuffleSplitGenerator, + DataEvolutionChunkShuffleSplitGenerator, +) from pypaimon.read.scanner.data_evolution_split_generator import \ DataEvolutionSplitGenerator from pypaimon.read.scanner.primary_key_table_split_generator import \ @@ -204,6 +208,7 @@ def __init__( self.number_of_para_subtasks = None self.start_pos_of_this_subtask = None self.end_pos_of_this_subtask = None + self.chunk_shuffle: Optional[Tuple[int, int]] = None self.only_read_real_buckets = options.bucket() == BucketMode.POSTPONE_BUCKET.value self.data_evolution = options.data_evolution_enabled() @@ -243,7 +248,34 @@ def _deletion_files_map(self, entries: List[ManifestEntry]) -> Dict[tuple, Dict[ def scan(self) -> Plan: start_ms = time.time() * 1000 # Create appropriate split generator based on table type - if self.table.is_primary_key_table: + if self.chunk_shuffle is not None: + self._validate_chunk_shuffle_compat() + seed, chunk_size = self.chunk_shuffle + # Both append and DE paths use plan_files() directly: the + # predicate is partition-only (enforced by + # _validate_chunk_shuffle_compat), so manifest_entry-level + # partition pruning in plan_files() is the only filter we + # want — no row_id range pushdown, no global index lookup. + entries = self.plan_files() + if self.data_evolution: + split_generator = DataEvolutionChunkShuffleSplitGenerator( + self.table, + self.target_split_size, + self.open_file_cost, + self._deletion_files_map(entries), + seed=seed, + chunk_size=chunk_size, + ) + else: + split_generator = AppendChunkShuffleSplitGenerator( + self.table, + self.target_split_size, + self.open_file_cost, + self._deletion_files_map(entries), + seed=seed, + chunk_size=chunk_size, + ) + elif self.table.is_primary_key_table: entries = self.plan_files() split_generator = PrimaryKeyTableSplitGenerator( self.table, @@ -441,6 +473,38 @@ def scan_with_stats(self) -> Tuple[Plan, ScanStats]: plan = self.scan() return plan, self.scan_stats + def with_chunk_shuffle(self, seed: int, chunk_size: int) -> 'FileScanner': + if not isinstance(seed, int): + raise ValueError("chunk_shuffle seed must be an int") + if not isinstance(chunk_size, int) or chunk_size <= 0: + raise ValueError("chunk_shuffle chunk_size must be a positive int") + self.chunk_shuffle = (seed, chunk_size) + return self + + def _validate_chunk_shuffle_compat(self) -> None: + if self.table.is_primary_key_table: + raise ValueError("chunk_shuffle only supports append tables") + if self.deletion_vectors_enabled: + raise ValueError("chunk_shuffle not supported with deletion vectors") + if self.start_pos_of_this_subtask is not None: + raise ValueError("chunk_shuffle cannot combine with with_slice") + if self.limit is not None: + raise ValueError("chunk_shuffle cannot combine with limit") + if self._global_index_result is not None: + raise ValueError("chunk_shuffle cannot combine with global index") + # Only partition predicates are allowed: row-level / column-level + # predicates would silently shrink each chunk's effective row count, + # breaking the chunk_size contract DataLoader callers expect. + if self.predicate is not None: + partition_keys = set(self.table.partition_keys or []) + non_partition_fields = _get_all_fields(self.predicate) - partition_keys + if non_partition_fields: + raise ValueError( + "chunk_shuffle predicate must reference only partition keys; " + "got non-partition fields: " + f"{sorted(non_partition_fields)}" + ) + def _apply_push_down_limit(self, splits: List[DataSplit]) -> List[DataSplit]: """Mirror Java ``DataTableBatchScan.applyPushDownLimit``: sum the DV-aware ``merged_row_count`` (== Java ``Split.mergedRowCount()``) diff --git a/paimon-python/pypaimon/read/table_scan.py b/paimon-python/pypaimon/read/table_scan.py index 623261803503..06eafc38d2d9 100755 --- a/paimon-python/pypaimon/read/table_scan.py +++ b/paimon-python/pypaimon/read/table_scan.py @@ -158,3 +158,7 @@ def with_slice(self, start_pos, end_pos) -> 'TableScan': def with_global_index_result(self, result) -> 'TableScan': self.file_scanner.with_global_index_result(result) return self + + def with_chunk_shuffle(self, seed: int, chunk_size: int) -> 'TableScan': + self.file_scanner.with_chunk_shuffle(seed, chunk_size) + return self diff --git a/paimon-python/pypaimon/tests/scanner/chunk_shuffle_split_generator_test.py b/paimon-python/pypaimon/tests/scanner/chunk_shuffle_split_generator_test.py new file mode 100644 index 000000000000..be50cfa637d1 --- /dev/null +++ b/paimon-python/pypaimon/tests/scanner/chunk_shuffle_split_generator_test.py @@ -0,0 +1,800 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Tests for ChunkShuffleSplitGenerator and TableScan.with_chunk_shuffle. + +Algorithmic tests use Mock entries so they don't touch disk; the +end-to-end test writes a real append table and validates that all +workers together cover the data exactly once. +""" + +import os +import shutil +import tempfile +import unittest +from unittest.mock import Mock + +import pyarrow as pa + +from pypaimon import CatalogFactory, Schema +from pypaimon.globalindex.indexed_split import IndexedSplit +from pypaimon.manifest.schema.data_file_meta import DataFileMeta +from pypaimon.read.scanner.chunk_shuffle_split_generator import ( + AppendChunkShuffleSplitGenerator, + DataEvolutionChunkShuffleSplitGenerator, +) +from pypaimon.read.sliced_split import SlicedSplit +from pypaimon.read.split import DataSplit +from pypaimon.utils.range import Range + + +def _mock_table(table_path='/tmp/_chunk_shuffle_test_path'): + table = Mock() + table.table_path = table_path + table.options = Mock() + return table + + +def _mock_entry(partition_values, bucket, file_name, row_count, file_size=1024): + entry = Mock() + entry.partition = Mock() + entry.partition.values = partition_values + entry.bucket = bucket + entry.file = Mock() + entry.file.file_name = file_name + entry.file.file_size = file_size + entry.file.row_count = row_count + # Swallow set_file_path so we don't need to mock partition path encoding. + entry.file.set_file_path = Mock() + return entry + + +def _make_generator(seed, chunk_size, table=None): + if table is None: + table = _mock_table() + return AppendChunkShuffleSplitGenerator( + table, + target_split_size=128 * 1024 * 1024, + open_file_cost=4 * 1024 * 1024, + deletion_files_map=None, + seed=seed, + chunk_size=chunk_size, + ) + + +def _make_de_generator(seed, chunk_size, table=None): + if table is None: + table = _mock_table() + return DataEvolutionChunkShuffleSplitGenerator( + table, + target_split_size=128 * 1024 * 1024, + open_file_cost=4 * 1024 * 1024, + deletion_files_map=None, + seed=seed, + chunk_size=chunk_size, + ) + + +def _mock_de_entry(partition_values, bucket, file_name, first_row_id, row_count, file_size=1024): + """A DE-flavoured mock entry: file carries first_row_id and a real + Range so :meth:`row_id_range` and ``Range.overlaps`` work.""" + entry = Mock() + entry.partition = Mock() + entry.partition.values = partition_values + entry.bucket = bucket + file = Mock(spec=DataFileMeta) + file.file_name = file_name + file.file_size = file_size + file.row_count = row_count + file.first_row_id = first_row_id + file.row_id_range = lambda f=first_row_id, c=row_count: Range(f, f + c - 1) + file.set_file_path = Mock() + entry.file = file + return entry + + +def _split_signature(split): + """A stable, comparable identity for a split — what the worker would actually read.""" + if isinstance(split, SlicedSplit): + underlying = split.data_split() + files = tuple(f.file_name for f in underlying.files) + idx_map = tuple(sorted(split.shard_file_idx_map().items())) + return (tuple(underlying.partition.values), underlying.bucket, files, idx_map) + if isinstance(split, IndexedSplit): + underlying = split.data_split() + files = tuple(sorted(f.file_name for f in underlying.files)) + ranges = tuple((r.from_, r.to) for r in split.row_ranges()) + return (tuple(underlying.partition.values), underlying.bucket, files, ranges) + if isinstance(split, DataSplit): + files = tuple(f.file_name for f in split.files) + return (tuple(split.partition.values), split.bucket, files, ()) + raise AssertionError("unexpected split type: %r" % type(split)) + + +def _split_rows(split): + """Effective row count this split actually exposes.""" + return split.row_count + + +class ChunkShuffleSplitGeneratorAlgoTest(unittest.TestCase): + + def test_no_entries_returns_empty(self): + gen = _make_generator(seed=1, chunk_size=100) + self.assertEqual(gen.create_splits([]), []) + + def test_full_files_no_truncation(self): + entries = [ + _mock_entry([], 0, 'f1', 100), + _mock_entry([], 0, 'f2', 100), + _mock_entry([], 0, 'f3', 100), + ] + gen = _make_generator(seed=1, chunk_size=100) + splits = gen.create_splits(entries) + # 3 chunks, each holding exactly one whole file → all DataSplit, no SlicedSplit + self.assertEqual(len(splits), 3) + for s in splits: + self.assertIsInstance(s, DataSplit) + self.assertEqual(s.row_count, 100) + + def test_chunk_truncates_inside_file(self): + # one file of 250 rows, chunk_size 100 → 3 chunks: 100, 100, 50 + entries = [_mock_entry([], 0, 'f1', 250)] + gen = _make_generator(seed=1, chunk_size=100) + splits = gen.create_splits(entries) + self.assertEqual(len(splits), 3) + # All three chunks slice the same file → all SlicedSplit + for s in splits: + self.assertIsInstance(s, SlicedSplit) + # union of (start, end) intervals must cover [0, 250) + intervals = sorted(s.shard_file_idx_map()['f1'] for s in splits) + self.assertEqual(intervals, [(0, 100), (100, 200), (200, 250)]) + total = sum(end - start for start, end in intervals) + self.assertEqual(total, 250) + + def test_chunk_spans_multiple_files(self): + # f1=30, f2=30, f3=30, chunk_size=50 → chunks: [f1(30)+f2(0,20)], [f2(20,30)+f3(0,40 cap 30=30)] ... + entries = [ + _mock_entry([], 0, 'f1', 30), + _mock_entry([], 0, 'f2', 30), + _mock_entry([], 0, 'f3', 30), + ] + gen = _make_generator(seed=1, chunk_size=50) + splits = gen.create_splits(entries) + # total 90 rows, chunk_size 50 → 2 chunks (50 + 40) + self.assertEqual(len(splits), 2) + total_rows = sum(_split_rows(s) for s in splits) + self.assertEqual(total_rows, 90) + + def test_chunk_size_larger_than_total(self): + entries = [ + _mock_entry([], 0, 'f1', 30), + _mock_entry([], 0, 'f2', 30), + ] + gen = _make_generator(seed=1, chunk_size=1000) + splits = gen.create_splits(entries) + self.assertEqual(len(splits), 1) + # No truncation — full files inside one chunk → DataSplit not SlicedSplit + self.assertIsInstance(splits[0], DataSplit) + self.assertEqual(_split_rows(splits[0]), 60) + + def test_deterministic_same_seed_same_order(self): + entries = [_mock_entry([], 0, f'f{i}', 100) for i in range(20)] + gen1 = _make_generator(seed=42, chunk_size=50) + gen2 = _make_generator(seed=42, chunk_size=50) + splits1 = gen1.create_splits(entries) + splits2 = gen2.create_splits(entries) + self.assertEqual( + [_split_signature(s) for s in splits1], + [_split_signature(s) for s in splits2], + ) + + def test_different_seed_different_order(self): + entries = [_mock_entry([], 0, f'f{i}', 100) for i in range(50)] + gen1 = _make_generator(seed=1, chunk_size=100) + gen2 = _make_generator(seed=2, chunk_size=100) + sigs1 = [_split_signature(s) for s in gen1.create_splits(entries)] + sigs2 = [_split_signature(s) for s in gen2.create_splits(entries)] + # Same set of chunks, different order — high probability they differ on 50 items + self.assertEqual(sorted(sigs1), sorted(sigs2)) + self.assertNotEqual(sigs1, sigs2) + + def test_shuffle_actually_reorders(self): + # 20 files in scan order f0..f19. After shuffle the file order should not be sorted. + entries = [_mock_entry([], 0, f'f{i:02d}', 100) for i in range(20)] + gen = _make_generator(seed=42, chunk_size=100) + splits = gen.create_splits(entries) + file_names = [s.files[0].file_name for s in splits] + self.assertNotEqual(file_names, sorted(file_names)) + + def test_shard_round_trip_no_overlap_no_loss(self): + # 13 files × 100 rows = 1300 rows. 4 workers. + entries = [_mock_entry([], 0, f'f{i:02d}', 100) for i in range(13)] + num_workers = 4 + all_sigs = [] + total_rows = 0 + for worker in range(num_workers): + gen = _make_generator(seed=7, chunk_size=100) + gen.with_shard(worker, num_workers) + splits = gen.create_splits(list(entries)) # copy: shuffle is in-place on chunks list + for s in splits: + all_sigs.append(_split_signature(s)) + total_rows += _split_rows(s) + self.assertEqual(total_rows, 13 * 100) + # No duplicate chunks across workers + self.assertEqual(len(all_sigs), len(set(all_sigs))) + # All chunks together equal an unsharded run + unsharded = _make_generator(seed=7, chunk_size=100).create_splits(list(entries)) + self.assertEqual( + sorted(all_sigs), + sorted(_split_signature(s) for s in unsharded), + ) + + def test_shard_balanced_distribution(self): + # 10 chunks across 3 workers → 4, 3, 3 (front-loaded by _compute_shard_range) + entries = [_mock_entry([], 0, f'f{i:02d}', 100) for i in range(10)] + counts = [] + for worker in range(3): + gen = _make_generator(seed=0, chunk_size=100) + gen.with_shard(worker, 3) + counts.append(len(gen.create_splits(list(entries)))) + self.assertEqual(sorted(counts, reverse=True), [4, 3, 3]) + + def test_chunks_fewer_than_workers(self): + # 2 chunks, 5 workers → 3 workers get nothing + entries = [_mock_entry([], 0, f'f{i}', 100) for i in range(2)] + empties = 0 + non_empties = 0 + for worker in range(5): + gen = _make_generator(seed=0, chunk_size=100) + gen.with_shard(worker, 5) + n = len(gen.create_splits(list(entries))) + if n == 0: + empties += 1 + else: + non_empties += 1 + self.assertEqual(n, 1) + self.assertEqual(empties, 3) + self.assertEqual(non_empties, 2) + + def test_multi_partition_no_chunk_crosses_partition(self): + entries = [ + _mock_entry(['p1'], 0, 'f1', 100), + _mock_entry(['p1'], 0, 'f2', 100), + _mock_entry(['p2'], 0, 'f3', 100), + _mock_entry(['p2'], 0, 'f4', 100), + ] + gen = _make_generator(seed=0, chunk_size=100) + splits = gen.create_splits(entries) + # Each split's underlying files come from one partition only + for s in splits: + partitions_in_files = set() + data_split = s.data_split() if isinstance(s, SlicedSplit) else s + partitions_in_files.add(tuple(data_split.partition.values)) + self.assertEqual(len(partitions_in_files), 1) + + def test_null_and_non_null_partitions_sort_safely(self): + # Mixing null and non-null partition values used to raise + # ``TypeError: '<' not supported between instances of 'NoneType' and 'str'`` + # before _null_safe_partition_key. Validate planning succeeds and + # both partitions produce splits. + entries = [ + _mock_entry(['p1'], 0, 'f1', 100), + _mock_entry([None], 0, 'f2', 100), + _mock_entry(['p2'], 0, 'f3', 100), + ] + gen = _make_generator(seed=1, chunk_size=100) + splits = gen.create_splits(entries) + self.assertEqual(len(splits), 3) + partitions = {tuple(_split_signature(s)[0]) for s in splits} + self.assertEqual(partitions, {('p1',), ('p2',), (None,)}) + + def test_input_order_does_not_affect_output_when_same_files(self): + """Manifest read parallelism shouldn't bleed through — sorting is internal.""" + a = _mock_entry([], 0, 'f1', 100) + b = _mock_entry([], 0, 'f2', 100) + c = _mock_entry([], 0, 'f3', 100) + gen1 = _make_generator(seed=99, chunk_size=100) + gen2 = _make_generator(seed=99, chunk_size=100) + sigs1 = [_split_signature(s) for s in gen1.create_splits([a, b, c])] + sigs2 = [_split_signature(s) for s in gen2.create_splits([c, a, b])] + self.assertEqual(sigs1, sigs2) + + +class ChunkShuffleEndToEndTest(unittest.TestCase): + """Real append table → with_chunk_shuffle → multiple workers → union == original.""" + + @classmethod + def setUpClass(cls): + cls.tempdir = tempfile.mkdtemp() + cls.warehouse = os.path.join(cls.tempdir, 'warehouse') + cls.catalog = CatalogFactory.create({'warehouse': cls.warehouse}) + cls.catalog.create_database('default', True) + + @classmethod + def tearDownClass(cls): + shutil.rmtree(cls.tempdir, ignore_errors=True) + + def _create_append_table(self, name, partition_keys=None): + pa_schema = pa.schema([ + ('id', pa.int64()), + ('value', pa.string()), + ('part', pa.string()), + ]) + schema = Schema.from_pyarrow_schema( + pa_schema, partition_keys=partition_keys or []) + identifier = f'default.{name}' + self.catalog.create_table(identifier, schema, False) + return self.catalog.get_table(identifier), pa_schema + + def _write_n_batches(self, table, pa_schema, batches): + wb = table.new_batch_write_builder() + for batch in batches: + tw = wb.new_write() + tc = wb.new_commit() + tw.write_arrow(pa.Table.from_pydict(batch, schema=pa_schema)) + tc.commit(tw.prepare_commit()) + tw.close() + tc.close() + + def test_workers_union_equals_full_table(self): + table, pa_schema = self._create_append_table('cs_union') + # 4 commits × 50 rows = 200 rows across several files + batches = [] + for c in range(4): + base = c * 50 + batches.append({ + 'id': list(range(base, base + 50)), + 'value': [f'v{i}' for i in range(base, base + 50)], + 'part': ['p1'] * 50, + }) + self._write_n_batches(table, pa_schema, batches) + + read_builder = table.new_read_builder() + table_read = read_builder.new_read() + + num_workers = 3 + worker_tables = [] + for w in range(num_workers): + scan = read_builder.new_scan() \ + .with_chunk_shuffle(seed=123, chunk_size=37) \ + .with_shard(w, num_workers) + splits = scan.plan().splits() + if splits: + worker_tables.append(table_read.to_arrow(splits)) + + actual = pa.concat_tables(worker_tables).sort_by('id') if worker_tables else None + self.assertIsNotNone(actual) + self.assertEqual(actual.num_rows, 200) + self.assertEqual(actual.column('id').to_pylist(), list(range(200))) + + def test_deterministic_plan_across_calls(self): + table, pa_schema = self._create_append_table('cs_determinism') + self._write_n_batches(table, pa_schema, [{ + 'id': list(range(100)), + 'value': [f'v{i}' for i in range(100)], + 'part': ['p'] * 100, + }]) + + def plan_files(worker): + scan = table.new_read_builder().new_scan() \ + .with_chunk_shuffle(seed=42, chunk_size=20) \ + .with_shard(worker, 3) + return [_split_signature(s) for s in scan.plan().splits()] + + for worker in range(3): + self.assertEqual(plan_files(worker), plan_files(worker)) + + +class ChunkShuffleCompatibilityTest(unittest.TestCase): + """Validates the reject-on-incompatible-combination matrix.""" + + @classmethod + def setUpClass(cls): + cls.tempdir = tempfile.mkdtemp() + cls.warehouse = os.path.join(cls.tempdir, 'warehouse') + cls.catalog = CatalogFactory.create({'warehouse': cls.warehouse}) + cls.catalog.create_database('default', True) + + @classmethod + def tearDownClass(cls): + shutil.rmtree(cls.tempdir, ignore_errors=True) + + def _append_table(self, name, options=None, partition_keys=None): + if partition_keys: + pa_schema = pa.schema([ + ('id', pa.int64()), + ('value', pa.string()), + ('part', pa.string()), + ]) + else: + pa_schema = pa.schema([('id', pa.int64()), ('value', pa.string())]) + schema = Schema.from_pyarrow_schema( + pa_schema, partition_keys=partition_keys, options=options or {}) + self.catalog.create_table(f'default.{name}', schema, False) + return self.catalog.get_table(f'default.{name}') + + def _pk_table(self, name): + pa_schema = pa.schema([ + pa.field('id', pa.int64(), nullable=False), + ('value', pa.string()), + ]) + schema = Schema.from_pyarrow_schema( + pa_schema, primary_keys=['id'], options={'bucket': '1'}) + self.catalog.create_table(f'default.{name}', schema, False) + return self.catalog.get_table(f'default.{name}') + + def test_pk_table_rejected(self): + table = self._pk_table('cs_pk') + scan = table.new_read_builder().new_scan() + scan.with_chunk_shuffle(seed=1, chunk_size=100) + with self.assertRaisesRegex(ValueError, "only supports append tables"): + scan.plan() + + def test_dv_table_rejected(self): + table = self._append_table('cs_dv', options={'deletion-vectors.enabled': 'true'}) + scan = table.new_read_builder().new_scan() + scan.with_chunk_shuffle(seed=1, chunk_size=100) + with self.assertRaisesRegex(ValueError, "deletion vectors"): + scan.plan() + + def test_with_slice_then_chunk_shuffle_rejected(self): + table = self._append_table('cs_slice') + scan = table.new_read_builder().new_scan() + scan.with_slice(0, 100).with_chunk_shuffle(seed=1, chunk_size=100) + with self.assertRaisesRegex(ValueError, "with_slice"): + scan.plan() + + def test_limit_with_chunk_shuffle_rejected(self): + table = self._append_table('cs_limit') + scan = table.new_read_builder().with_limit(50).new_scan() + scan.with_chunk_shuffle(seed=1, chunk_size=100) + with self.assertRaisesRegex(ValueError, "limit"): + scan.plan() + + def test_invalid_chunk_size(self): + table = self._append_table('cs_invalid') + scan = table.new_read_builder().new_scan() + with self.assertRaisesRegex(ValueError, "chunk_size"): + scan.with_chunk_shuffle(seed=1, chunk_size=0) + with self.assertRaisesRegex(ValueError, "chunk_size"): + scan.with_chunk_shuffle(seed=1, chunk_size=-5) + + def test_column_predicate_rejected(self): + # Non-partition predicate would silently shrink effective chunk + # row counts inside the reader → not allowed. + table = self._append_table('cs_col_pred', partition_keys=['part']) + rb = table.new_read_builder() + col_pred = rb.new_predicate_builder().equal('id', 5) + rb = rb.with_filter(col_pred) + scan = rb.new_scan().with_chunk_shuffle(seed=1, chunk_size=10) + with self.assertRaisesRegex(ValueError, "partition keys"): + scan.plan() + + def test_partition_predicate_allowed(self): + # Filter is partition-only → must succeed and read only the + # matching partition. + table, pa_schema = self._partitioned_table_with_data('cs_part_pred') + + rb = table.new_read_builder() + pred = rb.new_predicate_builder().equal('part', 'p1') + scan = rb.with_filter(pred).new_scan() \ + .with_chunk_shuffle(seed=1, chunk_size=10) + plan = scan.plan() + # All splits should be from partition 'p1' + for split in plan.splits(): + partition_values = split.partition.values + self.assertEqual(tuple(partition_values), ('p1',)) + + def _partitioned_table_with_data(self, name): + pa_schema = pa.schema([ + ('id', pa.int64()), + ('value', pa.string()), + ('part', pa.string()), + ]) + schema = Schema.from_pyarrow_schema(pa_schema, partition_keys=['part']) + identifier = f'default.{name}' + self.catalog.create_table(identifier, schema, False) + table = self.catalog.get_table(identifier) + wb = table.new_batch_write_builder() + for part, ids in [('p1', range(50)), ('p2', range(50, 100))]: + tw = wb.new_write() + tc = wb.new_commit() + tw.write_arrow(pa.Table.from_pydict( + {'id': list(ids), + 'value': [f'v{i}' for i in ids], + 'part': [part] * 50}, + schema=pa_schema, + )) + tc.commit(tw.prepare_commit()) + tw.close() + tc.close() + return table, pa_schema + + +class DataEvolutionChunkShuffleAlgoTest(unittest.TestCase): + """Mock-based tests for the DE chunk slicer.""" + + def test_no_entries_returns_empty(self): + gen = _make_de_generator(seed=1, chunk_size=100) + self.assertEqual(gen.create_splits([]), []) + + def test_full_aligned_groups_one_per_chunk(self): + # Three commits of 100 rows each → three aligned groups. + # chunk_size = 100 → 3 chunks, each holding one group whole. + entries = [ + _mock_de_entry([], 0, 'g0.parquet', 0, 100), + _mock_de_entry([], 0, 'g1.parquet', 100, 100), + _mock_de_entry([], 0, 'g2.parquet', 200, 100), + ] + gen = _make_de_generator(seed=1, chunk_size=100) + splits = gen.create_splits(entries) + self.assertEqual(len(splits), 3) + for s in splits: + self.assertIsInstance(s, IndexedSplit) + self.assertEqual(s.row_count, 100) + self.assertEqual(len(s.row_ranges()), 1) + + def test_aligned_group_split_across_chunks(self): + # One 250-row group, chunk_size=100 → 3 chunks (100, 100, 50). + # All three chunks reference the SAME aligned group's files but + # each carries a distinct row_range slice. + entries = [_mock_de_entry([], 0, 'g0.parquet', 1000, 250)] + gen = _make_de_generator(seed=1, chunk_size=100) + splits = gen.create_splits(entries) + self.assertEqual(len(splits), 3) + + # Union of the three chunks' row_ranges must cover the whole group [1000, 1249]. + ranges = [] + for s in splits: + self.assertIsInstance(s, IndexedSplit) + ranges.extend((r.from_, r.to) for r in s.row_ranges()) + ranges.sort() + self.assertEqual(ranges, [(1000, 1099), (1100, 1199), (1200, 1249)]) + total = sum(r[1] - r[0] + 1 for r in ranges) + self.assertEqual(total, 250) + + def test_chunk_pulls_in_blob_siblings(self): + # One aligned group with a main parquet and a blob sibling sharing the + # row_id range. A single chunk must include BOTH files so the reader + # can union the columns. + entries = [ + _mock_de_entry([], 0, 'g0.parquet', 0, 100), + _mock_de_entry([], 0, 'g0.blob', 0, 100), # .blob ext → is_blob_file + ] + gen = _make_de_generator(seed=1, chunk_size=100) + splits = gen.create_splits(entries) + self.assertEqual(len(splits), 1) + files = sorted(f.file_name for f in splits[0].files) + self.assertEqual(files, ['g0.blob', 'g0.parquet']) + + def test_blob_propagates_when_group_split(self): + # Same scenario but chunk_size halves the group → the blob sibling + # must appear in BOTH chunk splits. + entries = [ + _mock_de_entry([], 0, 'g0.parquet', 0, 100), + _mock_de_entry([], 0, 'g0.blob', 0, 100), + ] + gen = _make_de_generator(seed=1, chunk_size=50) + splits = gen.create_splits(entries) + self.assertEqual(len(splits), 2) + for s in splits: + files = sorted(f.file_name for f in s.files) + self.assertEqual(files, ['g0.blob', 'g0.parquet']) + + def test_deterministic_same_seed(self): + entries = [_mock_de_entry([], 0, f'g{i:02d}.parquet', i * 100, 100) for i in range(20)] + gen1 = _make_de_generator(seed=42, chunk_size=100) + gen2 = _make_de_generator(seed=42, chunk_size=100) + sigs1 = [_split_signature(s) for s in gen1.create_splits(entries)] + sigs2 = [_split_signature(s) for s in gen2.create_splits(entries)] + self.assertEqual(sigs1, sigs2) + + def test_different_seed_reorders(self): + entries = [_mock_de_entry([], 0, f'g{i:02d}.parquet', i * 100, 100) for i in range(50)] + gen1 = _make_de_generator(seed=1, chunk_size=100) + gen2 = _make_de_generator(seed=2, chunk_size=100) + sigs1 = [_split_signature(s) for s in gen1.create_splits(entries)] + sigs2 = [_split_signature(s) for s in gen2.create_splits(entries)] + self.assertEqual(sorted(sigs1), sorted(sigs2)) + self.assertNotEqual(sigs1, sigs2) + + def test_input_order_does_not_affect_output(self): + a = _mock_de_entry([], 0, 'g0.parquet', 0, 100) + b = _mock_de_entry([], 0, 'g1.parquet', 100, 100) + c = _mock_de_entry([], 0, 'g2.parquet', 200, 100) + gen1 = _make_de_generator(seed=99, chunk_size=100) + gen2 = _make_de_generator(seed=99, chunk_size=100) + sigs1 = [_split_signature(s) for s in gen1.create_splits([a, b, c])] + sigs2 = [_split_signature(s) for s in gen2.create_splits([c, a, b])] + self.assertEqual(sigs1, sigs2) + + def test_shard_round_trip_no_overlap_no_loss(self): + # 13 aligned groups × 100 rows = 1300 rows. Shard across 4 workers. + entries = [_mock_de_entry([], 0, f'g{i:02d}.parquet', i * 100, 100) for i in range(13)] + num_workers = 4 + + unsharded = _make_de_generator(seed=7, chunk_size=100).create_splits(list(entries)) + unsharded_sigs = sorted(_split_signature(s) for s in unsharded) + + sharded_sigs = [] + total_rows = 0 + for w in range(num_workers): + gen = _make_de_generator(seed=7, chunk_size=100) + gen.with_shard(w, num_workers) + for s in gen.create_splits(list(entries)): + sharded_sigs.append(_split_signature(s)) + total_rows += s.row_count + self.assertEqual(total_rows, 13 * 100) + # No duplicate splits across workers + self.assertEqual(len(sharded_sigs), len(set(sharded_sigs))) + self.assertEqual(sorted(sharded_sigs), unsharded_sigs) + + def test_multi_partition_no_chunk_crosses_partition(self): + entries = [ + _mock_de_entry(['p1'], 0, 'g0.parquet', 0, 100), + _mock_de_entry(['p1'], 0, 'g1.parquet', 100, 100), + _mock_de_entry(['p2'], 0, 'g2.parquet', 200, 100), + _mock_de_entry(['p2'], 0, 'g3.parquet', 300, 100), + ] + gen = _make_de_generator(seed=0, chunk_size=100) + splits = gen.create_splits(entries) + for s in splits: + data_split = s.data_split() if isinstance(s, IndexedSplit) else s + self.assertEqual(len({tuple(data_split.partition.values)}), 1) + + def test_null_and_non_null_partitions_sort_safely(self): + # Same null-vs-non-null sort guard, exercised on the DE path. + entries = [ + _mock_de_entry(['p1'], 0, 'g0.parquet', 0, 100), + _mock_de_entry([None], 0, 'g1.parquet', 100, 100), + _mock_de_entry(['p2'], 0, 'g2.parquet', 200, 100), + ] + gen = _make_de_generator(seed=1, chunk_size=100) + splits = gen.create_splits(entries) + self.assertEqual(len(splits), 3) + partitions = {_split_signature(s)[0] for s in splits} + self.assertEqual(partitions, {('p1',), ('p2',), (None,)}) + + +class DataEvolutionChunkShuffleEndToEndTest(unittest.TestCase): + """Real DE table → with_chunk_shuffle → multi-worker → union == full table.""" + + @classmethod + def setUpClass(cls): + cls.tempdir = tempfile.mkdtemp() + cls.warehouse = os.path.join(cls.tempdir, 'warehouse') + cls.catalog = CatalogFactory.create({'warehouse': cls.warehouse}) + cls.catalog.create_database('default', True) + + @classmethod + def tearDownClass(cls): + shutil.rmtree(cls.tempdir, ignore_errors=True) + + def _create_de_table(self, name): + pa_schema = pa.schema([ + ('id', pa.int32()), + ('value', pa.string()), + ('payload', pa.large_binary()), + ]) + schema = Schema.from_pyarrow_schema( + pa_schema, + options={ + 'row-tracking.enabled': 'true', + 'data-evolution.enabled': 'true', + 'blob.target-file-size': '1 b', + }, + ) + identifier = f'default.{name}' + self.catalog.create_table(identifier, schema, False) + return self.catalog.get_table(identifier), pa_schema + + @staticmethod + def _payloads(ids): + return [f'payload-{i:03d}'.encode('utf-8') for i in ids] + + def _commit_full_rows(self, table, pa_schema, ids): + wb = table.new_batch_write_builder() + tw = wb.new_write() + tc = wb.new_commit() + tw.write_arrow(pa.Table.from_pydict( + { + 'id': ids, + 'value': [f'v{i}' for i in ids], + 'payload': self._payloads(ids), + }, + schema=pa_schema)) + commit_messages = tw.prepare_commit() + tc.commit(commit_messages) + tw.close() + tc.close() + return commit_messages + + def _assert_commit_has_main_and_multiple_blob_files(self, commit_messages): + all_files = [f for msg in commit_messages for f in msg.new_files] + main_files = [f for f in all_files if not DataFileMeta.is_blob_file(f.file_name)] + blob_files = [f for f in all_files if DataFileMeta.is_blob_file(f.file_name)] + self.assertGreaterEqual(len(main_files), 1) + self.assertGreater( + len(blob_files), 1, + "DE chunk-shuffle tests should exercise one row-id group with multiple blob files", + ) + + def _assert_splits_include_blob_files(self, splits): + self.assertGreater(len(splits), 0) + for split in splits: + data_split = split.data_split() if isinstance(split, IndexedSplit) else split + blob_files = [ + f for f in data_split.files + if DataFileMeta.is_blob_file(f.file_name) + ] + self.assertGreater( + len(blob_files), 0, + "Each DE chunk should keep blob sidecar files with its aligned row-id group", + ) + + def test_workers_union_equals_full_table(self): + table, pa_schema = self._create_de_table('cs_de_union') + # 4 commits → 4 aligned groups. Each group has one normal file and + # multiple blob sidecar files because blob.target-file-size is 1 byte. + for c in range(4): + base = c * 50 + commit_messages = self._commit_full_rows( + table, pa_schema, list(range(base, base + 50))) + self._assert_commit_has_main_and_multiple_blob_files(commit_messages) + + read_builder = table.new_read_builder() + table_read = read_builder.new_read() + + num_workers = 3 + worker_tables = [] + for w in range(num_workers): + scan = read_builder.new_scan() \ + .with_chunk_shuffle(seed=123, chunk_size=37) \ + .with_shard(w, num_workers) + splits = scan.plan().splits() + if splits: + self._assert_splits_include_blob_files(splits) + worker_tables.append(table_read.to_arrow(splits)) + + actual = pa.concat_tables(worker_tables).sort_by('id') + self.assertEqual(actual.num_rows, 200) + self.assertEqual(actual.column('id').to_pylist(), list(range(200))) + self.assertEqual(actual.column('payload').to_pylist(), self._payloads(range(200))) + + def test_deterministic_plan_across_calls(self): + table, pa_schema = self._create_de_table('cs_de_determinism') + for c in range(3): + base = c * 40 + commit_messages = self._commit_full_rows( + table, pa_schema, list(range(base, base + 40))) + self._assert_commit_has_main_and_multiple_blob_files(commit_messages) + + def plan_sigs(worker): + scan = table.new_read_builder().new_scan() \ + .with_chunk_shuffle(seed=42, chunk_size=15) \ + .with_shard(worker, 4) + splits = scan.plan().splits() + if splits: + self._assert_splits_include_blob_files(splits) + return [_split_signature(s) for s in splits] + + for worker in range(4): + self.assertEqual(plan_sigs(worker), plan_sigs(worker)) + + +if __name__ == '__main__': + unittest.main() From 75fee93b3c5e13bb7101974d493bb35731f6ef9b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=96=86=E5=AE=87?= Date: Tue, 2 Jun 2026 18:25:24 +0800 Subject: [PATCH 2/5] pytorch dataset supports --- docs/docs/pypaimon/pytorch.md | 95 ++++++++ .../pypaimon/read/datasource/torch_dataset.py | 205 ++++++++++++++---- paimon-python/pypaimon/read/table_read.py | 20 ++ .../pypaimon/tests/torch_read_test.py | 193 +++++++++++++++++ 4 files changed, 471 insertions(+), 42 deletions(-) diff --git a/docs/docs/pypaimon/pytorch.md b/docs/docs/pypaimon/pytorch.md index 9e98b487eeed..61918447d215 100644 --- a/docs/docs/pypaimon/pytorch.md +++ b/docs/docs/pypaimon/pytorch.md @@ -58,3 +58,98 @@ When the `streaming` parameter is true, it will iteratively read; when it is false, it will read the full amount of data into memory. **`prefetch_concurrency`** (default: 1): When streaming is true, number of threads used for parallel prefetch within each DataLoader worker. Set to a value greater than 1 to partition splits across threads and increase read throughput. Has no effect when streaming is false. + +## Shuffle + +PyPaimon supports streaming shuffle for PyTorch `IterableDataset`. The shuffle +pipeline can be composed of three layers: + +1. **Chunk shuffle**: split files into row chunks during scan planning and + shuffle the generated chunk splits. This is enabled by + `TableScan.with_chunk_shuffle(seed, chunk_size)`. +2. **Split interleave**: read from multiple splits in round-robin order inside + each DataLoader worker. +3. **Buffer shuffle**: apply a reservoir-style row shuffle buffer before rows + are yielded to PyTorch. + +Chunk shuffle is a scan planning feature for append tables, including +Data Evolution append tables. For Data Evolution tables, chunk shuffle keeps +row-id-aligned data files and sidecar files together while slicing by row-id +range. Primary-key tables and deletion-vector scans are not supported by +`with_chunk_shuffle`. + +The second and third layers are Dataset features. They work on the splits you +pass to `to_torch`, so they can be used with either normal splits or +chunk-shuffled splits. + +### Use Dataset Shuffle Only + +Use this when normal scan splits are enough and you only want split interleave +plus row buffer shuffle: + +```python +from torch.utils.data import DataLoader + +table_scan = read_builder.new_scan() +table_read = read_builder.new_read() +splits = table_scan.plan().splits() + +dataset = table_read.to_torch( + splits, + streaming=True, + shuffle=True, + seed=42, + buffer_size=1000, + max_buffer_input_splits=10, +) + +dataloader = DataLoader( + dataset, + batch_size=32, + num_workers=2, + shuffle=False, +) +``` + +`buffer_size` controls the row shuffle buffer. Larger values produce a better +approximation of global shuffle, at the cost of more memory. If +`max_buffer_input_splits` is `1`, split interleave is skipped and only buffer +shuffle is applied. `shuffle=True` requires `streaming=True` and does not +support `prefetch_concurrency > 1`. + +### Use All Three Layers + +For append tables, enable chunk shuffle during scan planning, then enable +Dataset shuffle when converting to PyTorch: + +```python +from torch.utils.data import DataLoader + +seed = 42 + +table_scan = read_builder.new_scan().with_chunk_shuffle( + seed=seed, + chunk_size=1000, +) +table_read = read_builder.new_read() +splits = table_scan.plan().splits() + +dataset = table_read.to_torch( + splits, + streaming=True, + shuffle=True, + seed=seed, + buffer_size=1000, + max_buffer_input_splits=10, +) + +dataloader = DataLoader( + dataset, + batch_size=32, + num_workers=2, + shuffle=False, +) +``` + +Call `dataset.set_epoch(epoch)` before creating or iterating a DataLoader for a +new training epoch if you want a different buffer-shuffle order for each epoch. diff --git a/paimon-python/pypaimon/read/datasource/torch_dataset.py b/paimon-python/pypaimon/read/datasource/torch_dataset.py index 6012d3dc680e..4958cf497264 100644 --- a/paimon-python/pypaimon/read/datasource/torch_dataset.py +++ b/paimon-python/pypaimon/read/datasource/torch_dataset.py @@ -19,8 +19,9 @@ Module to read a Paimon table into PyTorch Dataset. """ import queue +import random import threading -from typing import List +from typing import Iterator, List import torch from torch.utils.data import Dataset, IterableDataset @@ -76,7 +77,44 @@ def __getitem__(self, index: int): return self._data[index] -class TorchIterDataset(IterableDataset): +class _BaseTorchIterDataset(IterableDataset): + """ + Shared helpers for streaming PyTorch datasets backed by Paimon splits. + """ + + def __init__(self, table_read: TableRead, splits: List[Split]): + self.table_read = table_read + self.splits = splits + self.field_names = [field.name for field in table_read.read_type] + + def _row_to_dict(self, offset_row) -> dict: + row_dict = {} + for i, field_name in enumerate(self.field_names): + value = offset_row.get_field(i) + row_dict[field_name] = value + return row_dict + + def _worker_splits(self, worker_info) -> List[Split]: + if worker_info is None: + return self.splits + + worker_id = worker_info.id + num_workers = worker_info.num_workers + total_splits = len(self.splits) + splits_per_worker = total_splits // num_workers + remainder = total_splits % num_workers + + if worker_id < remainder: + start_idx = worker_id * (splits_per_worker + 1) + end_idx = start_idx + splits_per_worker + 1 + else: + start_idx = worker_id * splits_per_worker + remainder + end_idx = start_idx + splits_per_worker + + return self.splits[start_idx:end_idx] + + +class TorchIterDataset(_BaseTorchIterDataset): """ PyTorch IterableDataset implementation for reading Paimon table data. @@ -104,18 +142,8 @@ def __init__(self, table_read: TableRead, splits: List[Split], prefetch_concurre this worker (default 1). When > 1, splits are partitioned across threads to increase read throughput. """ - self.table_read = table_read - self.splits = splits + super().__init__(table_read, splits) self.prefetch_concurrency = max(1, int(prefetch_concurrency)) - # Get field names from read_type - self.field_names = [field.name for field in table_read.read_type] - - def _row_to_dict(self, offset_row) -> dict: - row_dict = {} - for i, field_name in enumerate(self.field_names): - value = offset_row.get_field(i) - row_dict[field_name] = value - return row_dict def __iter__(self): """ @@ -128,30 +156,7 @@ def __iter__(self): row data of dict type, where keys are column names """ worker_info = torch.utils.data.get_worker_info() - - if worker_info is None: - # Single-process data loading, iterate over all splits - splits_to_process = self.splits - else: - # Multi-process data loading, partition splits across workers - worker_id = worker_info.id - num_workers = worker_info.num_workers - - # Calculate start and end indices for this worker - # Distribute splits evenly by slicing - total_splits = len(self.splits) - splits_per_worker = total_splits // num_workers - remainder = total_splits % num_workers - - # Workers with id < remainder get one extra split - if worker_id < remainder: - start_idx = worker_id * (splits_per_worker + 1) - end_idx = start_idx + splits_per_worker + 1 - else: - start_idx = worker_id * splits_per_worker + remainder - end_idx = start_idx + splits_per_worker - - splits_to_process = self.splits[start_idx:end_idx] + splits_to_process = self._worker_splits(worker_info) if self.prefetch_concurrency > 1: for row in self._iter_rows(splits_to_process): @@ -161,11 +166,7 @@ def __iter__(self): worker_iterator = self.table_read.to_iterator(splits_to_process) for offset_row in worker_iterator: - row_dict = {} - for i, field_name in enumerate(self.field_names): - value = offset_row.get_field(i) - row_dict[field_name] = value - yield row_dict + yield self._row_to_dict(offset_row) def _iter_rows(self, splits: List[Split]): n = min(self.prefetch_concurrency, len(splits)) @@ -221,3 +222,123 @@ def producer(split_group: List): stop.set() for t in threads: t.join(timeout=self._PREFETCH_JOIN_TIMEOUT_SEC) + + +class TorchShuffledIterDataset(_BaseTorchIterDataset): + """ + PyTorch IterableDataset with Paimon-controlled streaming shuffle. + + This dataset consumes pre-planned splits, then mixes rows with split + interleaving and a shuffle buffer. Chunk-level shuffle, when needed, + stays in TableScan.with_chunk_shuffle(). + """ + + def __init__( + self, + table_read: TableRead, + splits: List[Split], + seed: int = 0, + buffer_size: int = 1000, + max_buffer_input_splits: int = 10, + ): + super().__init__(table_read, splits) + self.seed = self._require_int(seed, "seed") + self.buffer_size = self._require_positive_int(buffer_size, "buffer_size") + self.max_buffer_input_splits = self._require_positive_int( + max_buffer_input_splits, "max_buffer_input_splits") + self.epoch = 0 + + @staticmethod + def _require_int(value: int, name: str) -> int: + if not isinstance(value, int): + raise ValueError("%s must be an int" % name) + return value + + @staticmethod + def _require_positive_int(value: int, name: str) -> int: + if not isinstance(value, int) or value <= 0: + raise ValueError("%s must be a positive int" % name) + return value + + def set_epoch(self, epoch: int) -> "TorchShuffledIterDataset": + self.epoch = self._require_int(epoch, "epoch") + return self + + def __iter__(self): + worker_info = torch.utils.data.get_worker_info() + worker_id = worker_info.id if worker_info is not None else 0 + splits_to_process = self._worker_splits(worker_info) + + if self.max_buffer_input_splits == 1: + rows = self._iter_ordered_rows(splits_to_process) + else: + rows = self._iter_interleaved_rows(splits_to_process) + for row in self._iter_buffer_shuffled_rows(rows, worker_id): + yield row + + def _iter_ordered_rows(self, splits: List[Split]) -> Iterator[dict]: + for offset_row in self.table_read.to_iterator(splits): + yield self._row_to_dict(offset_row) + + def _iter_interleaved_rows(self, splits: List[Split]) -> Iterator[dict]: + if not splits: + return + + split_iter = iter(splits) + active: List[Iterator] = [] + + def add_next_split() -> bool: + try: + split = next(split_iter) + except StopIteration: + return False + active.append(self.table_read.to_iterator([split])) + return True + + for _ in range(min(self.max_buffer_input_splits, len(splits))): + add_next_split() + + idx = 0 + try: + while active: + if idx >= len(active): + idx = 0 + row_iter = active[idx] + try: + offset_row = next(row_iter) + except StopIteration: + self._close_iterator(row_iter) + del active[idx] + add_next_split() + continue + + yield self._row_to_dict(offset_row) + idx += 1 + finally: + for row_iter in active: + self._close_iterator(row_iter) + + @staticmethod + def _close_iterator(row_iter) -> None: + close = getattr(row_iter, "close", None) + if close is not None: + close() + + def _iter_buffer_shuffled_rows( + self, + rows: Iterator[dict], + worker_id: int, + ) -> Iterator[dict]: + rng = random.Random(self.seed + self.epoch * 1000003 + worker_id) + buffer = [] + for row in rows: + if len(buffer) < self.buffer_size: + buffer.append(row) + continue + idx = rng.randint(0, self.buffer_size - 1) + yield buffer[idx] + buffer[idx] = row + + rng.shuffle(buffer) + for row in buffer: + yield row diff --git a/paimon-python/pypaimon/read/table_read.py b/paimon-python/pypaimon/read/table_read.py index 52a4eaaa7f1a..32c6036fddbf 100644 --- a/paimon-python/pypaimon/read/table_read.py +++ b/paimon-python/pypaimon/read/table_read.py @@ -511,8 +511,28 @@ def to_torch( splits: List[Split], streaming: bool = False, prefetch_concurrency: int = 1, + *, + shuffle: bool = False, + seed: int = 0, + buffer_size: int = 1000, + max_buffer_input_splits: int = 10, ) -> "torch.utils.data.Dataset": """Wrap Paimon table data to PyTorch Dataset.""" + if shuffle: + if not streaming: + raise ValueError("shuffle=True only supports streaming=True") + if prefetch_concurrency > 1: + raise ValueError("shuffle=True does not support prefetch_concurrency > 1") + from pypaimon.read.datasource.torch_dataset import TorchShuffledIterDataset + dataset = TorchShuffledIterDataset( + self, + splits, + seed=seed, + buffer_size=buffer_size, + max_buffer_input_splits=max_buffer_input_splits, + ) + return dataset + if streaming: from pypaimon.read.datasource.torch_dataset import TorchIterDataset dataset = TorchIterDataset(self, splits, prefetch_concurrency) diff --git a/paimon-python/pypaimon/tests/torch_read_test.py b/paimon-python/pypaimon/tests/torch_read_test.py index ac6088ece308..176fe74b4d51 100644 --- a/paimon-python/pypaimon/tests/torch_read_test.py +++ b/paimon-python/pypaimon/tests/torch_read_test.py @@ -645,6 +645,199 @@ def test_torch_read_with_predicate(self): print("✓ All predicate test cases passed!") print(f"{'=' * 60}\n") + def test_torch_streaming_shuffle_single_worker(self): + table = self._create_shuffle_append_table('default.test_torch_shuffle_single') + read_builder = table.new_read_builder().with_projection(['user_id']) + table_read = read_builder.new_read() + splits = read_builder.new_scan().plan().splits() + + expected = list(range(80)) + for max_buffer_input_splits in [1, 3]: + with self.subTest(max_buffer_input_splits=max_buffer_input_splits): + dataset = table_read.to_torch( + splits, + streaming=True, + shuffle=True, + seed=17, + buffer_size=7, + max_buffer_input_splits=max_buffer_input_splits, + ) + ids = self._collect_torch_user_ids(dataset, num_workers=0) + self.assertEqual(sorted(ids), expected) + self.assertNotEqual(ids, expected) + + def test_torch_streaming_shuffle_seed_and_epoch(self): + table = self._create_shuffle_append_table('default.test_torch_shuffle_epoch') + read_builder = table.new_read_builder().with_projection(['user_id']) + table_read = read_builder.new_read() + splits = read_builder.new_scan().plan().splits() + + dataset = table_read.to_torch( + splits, + streaming=True, + shuffle=True, + seed=23, + buffer_size=11, + max_buffer_input_splits=4, + ) + epoch0 = self._collect_torch_user_ids(dataset, num_workers=0) + epoch0_again = self._collect_torch_user_ids(dataset, num_workers=0) + self.assertEqual(epoch0, epoch0_again) + + dataset.set_epoch(1) + epoch1 = self._collect_torch_user_ids(dataset, num_workers=0) + self.assertEqual(sorted(epoch1), list(range(80))) + self.assertNotEqual(epoch0, epoch1) + + dataset.set_epoch(0) + self.assertEqual(epoch0, self._collect_torch_user_ids(dataset, num_workers=0)) + + other_seed_dataset = table_read.to_torch( + splits, + streaming=True, + shuffle=True, + seed=24, + buffer_size=11, + max_buffer_input_splits=4, + ) + self.assertNotEqual( + epoch0, + self._collect_torch_user_ids(other_seed_dataset, num_workers=0), + ) + + def test_torch_streaming_shuffle_multi_worker(self): + table = self._create_shuffle_append_table('default.test_torch_shuffle_multi') + read_builder = table.new_read_builder().with_projection(['user_id']) + table_read = read_builder.new_read() + splits = read_builder.new_scan() \ + .with_chunk_shuffle(seed=31, chunk_size=5) \ + .plan() \ + .splits() + + dataset = table_read.to_torch( + splits, + streaming=True, + shuffle=True, + seed=31, + buffer_size=13, + max_buffer_input_splits=4, + ) + ids = self._collect_torch_user_ids(dataset, num_workers=2) + + expected = list(range(80)) + self.assertEqual(len(ids), len(expected)) + self.assertEqual(sorted(ids), expected) + + def test_torch_streaming_shuffle_rejects_non_streaming(self): + table = self._create_shuffle_append_table('default.test_torch_shuffle_non_streaming') + read_builder = table.new_read_builder() + table_read = read_builder.new_read() + splits = read_builder.new_scan().plan().splits() + + with self.assertRaisesRegex(ValueError, "streaming=True"): + table_read.to_torch(splits, streaming=False, shuffle=True) + + def test_torch_streaming_shuffle_accepts_pk_table_splits(self): + pa_schema = pa.schema([ + pa.field('user_id', pa.int32(), nullable=False), + ('item_id', pa.int64()), + ('behavior', pa.string()), + ('dt', pa.string()) + ]) + schema = Schema.from_pyarrow_schema( + pa_schema, + primary_keys=['user_id'], + options={'bucket': '1'}, + ) + self.catalog.create_table('default.test_torch_shuffle_pk', schema, False) + table = self.catalog.get_table('default.test_torch_shuffle_pk') + self._write_test_table(table) + + read_builder = table.new_read_builder().with_projection(['user_id']) + splits = read_builder.new_scan().plan().splits() + dataset = read_builder.new_read().to_torch( + splits, + streaming=True, + shuffle=True, + seed=7, + buffer_size=3, + ) + ids = self._collect_torch_user_ids(dataset, num_workers=0) + + self.assertEqual(sorted(ids), [1, 2, 3, 4, 5, 6, 7, 8]) + + def test_torch_streaming_shuffle_rejects_invalid_dataset_options(self): + table = self._create_shuffle_append_table('default.test_torch_shuffle_invalid_options') + read_builder = table.new_read_builder().with_projection(['user_id']) + table_read = read_builder.new_read() + splits = read_builder.new_scan().plan().splits() + + with self.assertRaisesRegex(ValueError, "prefetch_concurrency"): + table_read.to_torch( + splits, + streaming=True, + shuffle=True, + prefetch_concurrency=2, + ) + with self.assertRaisesRegex(ValueError, "buffer_size"): + table_read.to_torch( + splits, + streaming=True, + shuffle=True, + buffer_size=0, + ) + with self.assertRaisesRegex(ValueError, "max_buffer_input_splits"): + table_read.to_torch( + splits, + streaming=True, + shuffle=True, + max_buffer_input_splits=0, + ) + + def _create_shuffle_append_table( + self, + identifier, + total_rows=80, + rows_per_commit=10, + partition_keys=None, + ): + schema = Schema.from_pyarrow_schema( + self.pa_schema, + partition_keys=partition_keys or [], + ) + self.catalog.create_table(identifier, schema, False) + table = self.catalog.get_table(identifier) + + write_builder = table.new_batch_write_builder() + for start in range(0, total_rows, rows_per_commit): + end = min(start + rows_per_commit, total_rows) + table_write = write_builder.new_write() + table_commit = write_builder.new_commit() + pa_table = pa.Table.from_pydict({ + 'user_id': list(range(start, end)), + 'item_id': [1000 + i for i in range(start, end)], + 'behavior': [chr(ord('a') + (i % 26)) for i in range(start, end)], + 'dt': [f'p{i % 4}' for i in range(start, end)], + }, schema=self.pa_schema) + table_write.write_arrow(pa_table) + table_commit.commit(table_write.prepare_commit()) + table_write.close() + table_commit.close() + return table + + @staticmethod + def _collect_torch_user_ids(dataset, num_workers=0): + dataloader = DataLoader( + dataset, + batch_size=8, + num_workers=num_workers, + shuffle=False, + ) + all_user_ids = [] + for batch_data in dataloader: + all_user_ids.extend(batch_data['user_id'].tolist()) + return all_user_ids + def _write_test_table(self, table): write_builder = table.new_batch_write_builder() table_pa_schema = self.pk_pa_schema if table.primary_keys else self.pa_schema From 54051e5a36cdc41487f5bf1ba0126e1abc62ef37 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=96=86=E5=AE=87?= Date: Wed, 3 Jun 2026 12:12:13 +0800 Subject: [PATCH 3/5] fix comments --- .../pypaimon/read/datasource/torch_dataset.py | 23 +++++++++++- .../pypaimon/tests/torch_read_test.py | 37 +++++++++++++++++++ 2 files changed, 58 insertions(+), 2 deletions(-) diff --git a/paimon-python/pypaimon/read/datasource/torch_dataset.py b/paimon-python/pypaimon/read/datasource/torch_dataset.py index 4958cf497264..5eb3485dddd1 100644 --- a/paimon-python/pypaimon/read/datasource/torch_dataset.py +++ b/paimon-python/pypaimon/read/datasource/torch_dataset.py @@ -30,6 +30,12 @@ from pypaimon.read.table_read import TableRead +def _share_epoch_with_torch_workers(value): + if isinstance(value, torch.Tensor): + return value.share_memory_() + return torch.tensor(value, dtype=torch.long).share_memory_() + + class TorchDataset(Dataset): """ PyTorch Dataset implementation for reading Paimon table data. @@ -246,7 +252,20 @@ def __init__( self.buffer_size = self._require_positive_int(buffer_size, "buffer_size") self.max_buffer_input_splits = self._require_positive_int( max_buffer_input_splits, "max_buffer_input_splits") - self.epoch = 0 + self._epoch = _share_epoch_with_torch_workers(0) + + def __setstate__(self, state): + self.__dict__ = state + self._epoch = _share_epoch_with_torch_workers(self._epoch) + + @property + def epoch(self) -> int: + return int(self._epoch) + + @epoch.setter + def epoch(self, epoch: int) -> None: + epoch = self._require_int(epoch, "epoch") + self._epoch += epoch - self._epoch @staticmethod def _require_int(value: int, name: str) -> int: @@ -261,7 +280,7 @@ def _require_positive_int(value: int, name: str) -> int: return value def set_epoch(self, epoch: int) -> "TorchShuffledIterDataset": - self.epoch = self._require_int(epoch, "epoch") + self.epoch = epoch return self def __iter__(self): diff --git a/paimon-python/pypaimon/tests/torch_read_test.py b/paimon-python/pypaimon/tests/torch_read_test.py index 176fe74b4d51..5f55cb2bc892 100644 --- a/paimon-python/pypaimon/tests/torch_read_test.py +++ b/paimon-python/pypaimon/tests/torch_read_test.py @@ -705,6 +705,36 @@ def test_torch_streaming_shuffle_seed_and_epoch(self): self._collect_torch_user_ids(other_seed_dataset, num_workers=0), ) + def test_torch_streaming_shuffle_epoch_with_persistent_workers(self): + table = self._create_shuffle_append_table('default.test_torch_shuffle_persistent_epoch') + read_builder = table.new_read_builder().with_projection(['user_id']) + table_read = read_builder.new_read() + splits = read_builder.new_scan().plan().splits() + + dataset = table_read.to_torch( + splits, + streaming=True, + shuffle=True, + seed=23, + buffer_size=11, + max_buffer_input_splits=4, + ) + dataloader = DataLoader( + dataset, + batch_size=8, + num_workers=2, + persistent_workers=True, + shuffle=False, + ) + + epoch0 = self._collect_torch_user_ids_from_dataloader(dataloader) + self.assertEqual(epoch0, self._collect_torch_user_ids_from_dataloader(dataloader)) + + dataset.set_epoch(1) + epoch1 = self._collect_torch_user_ids_from_dataloader(dataloader) + self.assertEqual(sorted(epoch1), list(range(80))) + self.assertNotEqual(epoch0, epoch1) + def test_torch_streaming_shuffle_multi_worker(self): table = self._create_shuffle_append_table('default.test_torch_shuffle_multi') read_builder = table.new_read_builder().with_projection(['user_id']) @@ -838,6 +868,13 @@ def _collect_torch_user_ids(dataset, num_workers=0): all_user_ids.extend(batch_data['user_id'].tolist()) return all_user_ids + @staticmethod + def _collect_torch_user_ids_from_dataloader(dataloader): + all_user_ids = [] + for batch_data in dataloader: + all_user_ids.extend(batch_data['user_id'].tolist()) + return all_user_ids + def _write_test_table(self, table): write_builder = table.new_batch_write_builder() table_pa_schema = self.pk_pa_schema if table.primary_keys else self.pa_schema From b887962e3a9f14642173c4079e8315b080480c45 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=96=86=E5=AE=87?= Date: Wed, 3 Jun 2026 12:25:51 +0800 Subject: [PATCH 4/5] emphasise chunk shuffle is for random access formats --- docs/docs/pypaimon/pytorch.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/docs/pypaimon/pytorch.md b/docs/docs/pypaimon/pytorch.md index 61918447d215..0f7f7bbdef6d 100644 --- a/docs/docs/pypaimon/pytorch.md +++ b/docs/docs/pypaimon/pytorch.md @@ -75,7 +75,9 @@ pipeline can be composed of three layers: Chunk shuffle is a scan planning feature for append tables, including Data Evolution append tables. For Data Evolution tables, chunk shuffle keeps row-id-aligned data files and sidecar files together while slicing by row-id -range. Primary-key tables and deletion-vector scans are not supported by +range. Chunk shuffle should be used with file formats that **support random +access**. Currently, the random-access file formats are Lance, Vortex, Row, and +Blob. Primary-key tables and deletion-vector scans are not supported by `with_chunk_shuffle`. The second and third layers are Dataset features. They work on the splits you From 5f704b1b6a1c4a43ab9536ce689f3f2fafd7d564 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=96=86=E5=AE=87?= Date: Wed, 3 Jun 2026 15:57:11 +0800 Subject: [PATCH 5/5] add check for first_row_id --- .../scanner/chunk_shuffle_split_generator.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/paimon-python/pypaimon/read/scanner/chunk_shuffle_split_generator.py b/paimon-python/pypaimon/read/scanner/chunk_shuffle_split_generator.py index c64b620e398b..504236e50071 100644 --- a/paimon-python/pypaimon/read/scanner/chunk_shuffle_split_generator.py +++ b/paimon-python/pypaimon/read/scanner/chunk_shuffle_split_generator.py @@ -351,11 +351,17 @@ def _split_by_row_id_with_range( Mirrors :meth:`DataEvolutionSplitGenerator._split_by_row_id` but also returns the merged row_id range per group, which the chunk - slicer needs to drive row-count accumulation. Files without - ``first_row_id`` are skipped (DE invariant guarantees presence; - defensive in case stray entries sneak in). + slicer needs to drive row-count accumulation. """ - list_ranges = [f.row_id_range() for f in files if f.row_id_range() is not None] + list_ranges = [] + for f in files: + file_range = f.row_id_range() + if file_range is None: + raise ValueError( + "chunk_shuffle for data evolution tables requires row tracking; " + f"file {f.file_name} is missing first_row_id" + ) + list_ranges.append(file_range) if not list_ranges: return [] sorted_ranges = Range.sort_and_merge_overlap(list_ranges, True, False) @@ -363,8 +369,6 @@ def _split_by_row_id_with_range( range_to_files: "dict[Range, List[DataFileMeta]]" = {} for f in files: file_range = f.row_id_range() - if file_range is None: - continue for r in sorted_ranges: if r.overlaps(file_range): range_to_files.setdefault(r, []).append(f)