diff --git a/paimon-python/pypaimon/read/reader/limited_record_reader.py b/paimon-python/pypaimon/read/reader/limited_record_reader.py
index 74f2612ebdc0..f78221d3f87b 100644
--- a/paimon-python/pypaimon/read/reader/limited_record_reader.py
+++ b/paimon-python/pypaimon/read/reader/limited_record_reader.py
@@ -26,6 +26,9 @@
 
 from typing import Optional
 
+from pyarrow import RecordBatch
+
+from pypaimon.read.reader.iface.record_batch_reader import RecordBatchReader
 from pypaimon.read.reader.iface.record_iterator import RecordIterator
 from pypaimon.read.reader.iface.record_reader import RecordReader
 
@@ -68,3 +71,45 @@ def next(self):
             return None
         self._limiter.count += 1
         return row
+
+
+class LimitedRecordBatchReader(RecordBatchReader):
+    """Stop emitting rows once ``limit`` rows have been delivered.
+
+    Unlike ``LimitedRecordReader`` (which inherits ``RecordReader``),
+    this class inherits ``RecordBatchReader`` so that the
+    ``isinstance(..., RecordBatchReader)`` gate in TableRead picks the
+    arrow-batch code path.
+    """
+
+    def __init__(self, inner: RecordBatchReader, limit: int):
+        if limit < 0:
+            raise ValueError("limit must be non-negative, got %d" % limit)
+        self._inner = inner
+        self._limit = limit
+        self.count = 0  # rows emitted so far
+        # Expose the inner reader's attributes on the wrapper itself.
+        self.file_io = inner.file_io
+        self.blob_field_indices = inner.blob_field_indices
+        self.vector_field_indices = inner.vector_field_indices
+
+    def read_arrow_batch(self) -> Optional[RecordBatch]:
+        """Return the next batch, trimmed so at most ``limit`` rows are emitted in total.
+
+        Returns ``None`` once the limit is reached or the inner reader is exhausted.
+        """
+        if self.count >= self._limit:
+            return None
+        batch = self._inner.read_arrow_batch()
+        if batch is None:
+            return None
+        remaining = self._limit - self.count
+        if batch.num_rows > remaining:
+            # Trim the batch that crosses the limit so the running total lands exactly on it.
+            batch = batch.slice(0, remaining)
+        self.count += batch.num_rows
+        return batch
+
+    def close(self) -> None:
+        """Close the wrapped reader."""
+        self._inner.close()
diff --git a/paimon-python/pypaimon/read/split_read.py b/paimon-python/pypaimon/read/split_read.py
index ddb349e0cab3..4e9b7644cd20 100644
--- a/paimon-python/pypaimon/read/split_read.py
+++ b/paimon-python/pypaimon/read/split_read.py
@@ -42,6 +42,7 @@
 from pypaimon.read.reader.format_avro_reader import FormatAvroReader
 from pypaimon.read.reader.blob_descriptor_convert_reader import BlobDescriptorConvertReader
 from pypaimon.read.reader.filter_record_batch_reader import FilterRecordBatchReader
+from pypaimon.read.reader.limited_record_reader import LimitedRecordBatchReader, LimitedRecordReader
 from pypaimon.read.reader.row_range_filter_record_reader import RowIdFilterRecordBatchReader
 from pypaimon.read.reader.format_blob_reader import FormatBlobReader
 from pypaimon.read.reader.format_lance_reader import FormatLanceReader
@@ -109,7 +110,8 @@ def __init__(
             read_type: List[DataField],
             split: Split,
             row_tracking_enabled: bool,
-            nested_name_paths: Optional[List[List[str]]] = None):
+            nested_name_paths: Optional[List[List[str]]] = None,
+            limit: Optional[int] = None):
         from pypaimon.table.file_store_table import FileStoreTable
 
         self.table: FileStoreTable = table
@@ -119,6 +121,7 @@ def __init__(
         self.row_tracking_enabled = row_tracking_enabled
         self.value_arity = len(read_type)
         self.nested_name_paths = nested_name_paths
+        self.limit = limit
         # Snapshot the raw value-side schema before _create_key_value_fields
         # wraps it, so MergeFileSplitRead can hand per-value-field nullable
         # flags to merge functions that mirror Java's NOT-NULL check.
@@ -620,9 +623,14 @@ def create_reader(self) -> RecordReader:
                 vector_field_indices=_vector_field_indices(self.read_fields))
         # if the table is appendonly table, we don't need extra filter, all predicates has pushed down
         if self.table.is_primary_key_table and self.predicate_for_reader:
-            return FilterRecordReader(concat_reader, self.predicate_for_reader)
+            reader = FilterRecordReader(concat_reader, self.predicate_for_reader)
+            if self.limit is not None:
+                reader = LimitedRecordReader(reader, self.limit)
         else:
-            return concat_reader
+            reader = concat_reader
+            if self.limit is not None:
+                reader = LimitedRecordBatchReader(reader, self.limit)
+        return reader
 
     def _get_all_data_fields(self):
         if self.row_tracking_enabled:
@@ -650,9 +658,9 @@ def __init__(
             split=split,
             row_tracking_enabled=row_tracking_enabled,
             nested_name_paths=None,
+            limit=limit,
         )
         self.outer_extract_name_paths = outer_extract_name_paths
-        self.limit = limit
         # Built once per split-read (value_fields and options are constant
         # for the object's life), not per section. ``None`` when
         # ``sequence.field`` is unset, in which case the heap falls back to
@@ -757,8 +765,6 @@ def create_reader(self) -> RecordReader:
                 blob_field_indices=_blob_field_indices(inner_value_fields),
                 vector_field_indices=_vector_field_indices(inner_value_fields))
         if self.limit is not None:
-            from pypaimon.read.reader.limited_record_reader import \
-                LimitedRecordReader
             reader = LimitedRecordReader(reader, self.limit)
         return reader
 
@@ -775,7 +781,8 @@ def __init__(
             read_type: List[DataField],
             split: Split,
             row_tracking_enabled: bool,
-            nested_name_paths: Optional[List[List[str]]] = None):
+            nested_name_paths: Optional[List[List[str]]] = None,
+            limit: Optional[int] = None):
         self.row_ranges = None
         actual_split = split
         if isinstance(split, IndexedSplit):
@@ -784,6 +791,7 @@ def __init__(
         super().__init__(
             table, predicate, read_type, actual_split, row_tracking_enabled,
             nested_name_paths=nested_name_paths,
+            limit=limit,
         )
 
     def _push_down_predicate(self) -> Optional[Predicate]:
@@ -827,6 +835,9 @@ def create_reader(self) -> RecordReader:
                 and CoreOptions.blob_descriptor_fields(self.table.options)):
             reader = BlobDescriptorConvertReader(reader, self.table)
 
+        if self.limit is not None:
+            reader = LimitedRecordBatchReader(reader, self.limit)
+
         return reader
 
     def _split_by_row_id(self, files: List[DataFileMeta]) -> List[List[DataFileMeta]]:
diff --git a/paimon-python/pypaimon/read/table_read.py b/paimon-python/pypaimon/read/table_read.py
index 7c717df24d93..826b2b4024a1 100644
--- a/paimon-python/pypaimon/read/table_read.py
+++ b/paimon-python/pypaimon/read/table_read.py
@@ -580,6 +580,7 @@ def _create_split_read(self, split: Split) -> SplitRead:
                 split=split,
                 row_tracking_enabled=True,
                 nested_name_paths=self.nested_name_paths,
+                limit=self.limit,
             )
         else:
             return RawFileSplitRead(
@@ -589,6 +590,7 @@ def _create_split_read(self, split: Split) -> SplitRead:
                 split=split,
                 row_tracking_enabled=self.table.options.row_tracking_enabled(),
                 nested_name_paths=self.nested_name_paths,
+                limit=self.limit,
             )
 
     def _widen_to_top_level_for_merge(self) -> List[DataField]:
diff --git a/paimon-python/pypaimon/tests/blob_table_test.py b/paimon-python/pypaimon/tests/blob_table_test.py
index ca9161bd1424..fbacdb892033 100755
--- a/paimon-python/pypaimon/tests/blob_table_test.py
+++ b/paimon-python/pypaimon/tests/blob_table_test.py
@@ -3369,6 +3369,22 @@ def test_get_blob_access(self):
         self.assertEqual(results[1], (2, b'img_data_2'))
         self.assertEqual(results[2], (3, b'img_data_3'))
 
+    def test_get_blob_access_with_limit(self):
+        read_builder = self.table.new_read_builder().with_limit(2)
+        splits = read_builder.new_scan().plan().splits()
+        read = read_builder.new_read()
+
+        results = []
+        for row in read.to_iterator(splits):
+            blob = row.get_blob(2)
+            self.assertIsNotNone(blob)
+            results.append((row.get_field(0), blob.to_data()))
+
+        self.assertEqual(len(results), 2)
+        expected_rows = ((1, b'img_data_1'), (2, b'img_data_2'), (3, b'img_data_3'))
+        for row in results:
+            self.assertIn(row, expected_rows)
+
     def test_get_blob_streaming(self):
         read_builder = self.table.new_read_builder()
         splits = read_builder.new_scan().plan().splits()
diff --git a/paimon-python/pypaimon/tests/test_limit_pushdown.py b/paimon-python/pypaimon/tests/test_limit_pushdown.py
index 2e717c28c7e1..a4b51ee3c00c 100644
--- a/paimon-python/pypaimon/tests/test_limit_pushdown.py
+++ b/paimon-python/pypaimon/tests/test_limit_pushdown.py
@@ -207,6 +207,58 @@ def test_to_iterator_limit_short_circuits(self):
         rows = list(it)
         self.assertEqual(len(rows), 7)
 
+    # ---- SplitRead-level limit pushdown verification ---------------------
+
+    def test_append_only_split_read_creates_limited_batch_reader(self):
+        """Verify that RawFileSplitRead.create_reader() returns a
+        LimitedRecordBatchReader (inherits RecordBatchReader) when limit
+        is set, so the arrow-batch read path is taken."""
+        from pypaimon.read.reader.iface.record_batch_reader import RecordBatchReader
+        from pypaimon.read.reader.limited_record_reader import LimitedRecordBatchReader
+
+        table = self._create_ao_table('limit_ao_split_read')
+        self._write_ao_partitions(table, [('p1', list(range(10)))])
+        rb = table.new_read_builder().with_limit(3)
+        table_read = rb.new_read()
+        splits = rb.new_scan().plan().splits()
+        self.assertGreater(len(splits), 0)
+        for split in splits:
+            split_read = table_read._create_split_read(split)
+            self.assertEqual(split_read.limit, 3)
+            reader = split_read.create_reader()
+            try:
+                self.assertIsInstance(reader, LimitedRecordBatchReader,
+                                      "RawFileSplitRead.create_reader() should wrap with LimitedRecordBatchReader")
+                self.assertIsInstance(reader, RecordBatchReader,
+                                      "LimitedRecordBatchReader should be a RecordBatchReader")
+            finally:
+                reader.close()
+
+    def test_append_only_split_read_limit_truncates_within_split(self):
+        """Directly read from a single split's reader with limit and verify
+        the reader itself stops at the limit boundary, not relying on
+        TableRead-level truncation."""
+        table = self._create_ao_table('limit_ao_split_truncate')
+        self._write_ao_partitions(table, [('p1', list(range(20)))])
+        rb = table.new_read_builder().with_limit(5)
+        table_read = rb.new_read()
+        splits = rb.new_scan().plan().splits()
+        self.assertEqual(len(splits), 1)
+        split_read = table_read._create_split_read(splits[0])
+        reader = split_read.create_reader()
+        # Drain the reader directly, bypassing TableRead-level control
+        total_rows = 0
+        try:
+            while True:
+                batch = reader.read_arrow_batch()
+                if batch is None:
+                    break
+                total_rows += batch.num_rows
+        finally:
+            reader.close()
+        self.assertEqual(total_rows, 5,
+                         "SplitRead-level reader should stop at limit=5, got %d" % total_rows)
+
 
 if __name__ == '__main__':
     unittest.main()