Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions paimon-python/pypaimon/read/reader/limited_record_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@

from typing import Optional

from pyarrow import RecordBatch

from pypaimon.read.reader.iface.record_batch_reader import RecordBatchReader
from pypaimon.read.reader.iface.record_iterator import RecordIterator
from pypaimon.read.reader.iface.record_reader import RecordReader

Expand Down Expand Up @@ -68,3 +71,38 @@ def next(self):
return None
self._limiter.count += 1
return row


class LimitedRecordBatchReader(RecordBatchReader):
    """Batch reader wrapper that caps the total number of rows returned.

    Batches from the wrapped reader are passed through until ``limit``
    rows have been handed out. The batch that would cross the limit is
    sliced down to the rows still allowed, and every later call returns
    ``None``.

    ``LimitedRecordReader`` derives from ``RecordReader``; this class
    derives from ``RecordBatchReader`` instead so that the
    ``isinstance(..., RecordBatchReader)`` gate in TableRead picks the
    arrow-batch code path.
    """

    def __init__(self, inner: RecordBatchReader, limit: int):
        """Wrap ``inner`` so that at most ``limit`` rows are emitted.

        Args:
            inner: The batch reader to pull batches from.
            limit: Maximum number of rows to emit in total.

        Raises:
            ValueError: If ``limit`` is negative.
        """
        if limit < 0:
            raise ValueError("limit must be non-negative, got %d" % limit)
        self._inner = inner
        self._limit = limit
        # Running total of rows already returned to the caller.
        self.count = 0
        # Copy the wrapped reader's attributes onto the wrapper.
        self.file_io = inner.file_io
        self.blob_field_indices = inner.blob_field_indices
        self.vector_field_indices = inner.vector_field_indices

    def read_arrow_batch(self) -> Optional[RecordBatch]:
        """Return the next batch, truncated to the remaining row budget.

        Returns:
            The next batch (sliced if it would exceed the limit), or
            ``None`` once the limit is reached or the wrapped reader is
            exhausted.
        """
        budget = self._limit - self.count
        if budget <= 0:
            # Limit already satisfied: do not touch the wrapped reader.
            return None

        chunk = self._inner.read_arrow_batch()
        if chunk is None:
            return None

        if chunk.num_rows > budget:
            # Keep only the leading rows that still fit under the limit.
            chunk = chunk.slice(0, budget)
        self.count += chunk.num_rows
        return chunk

    def close(self) -> None:
        """Close the wrapped reader."""
        self._inner.close()
29 changes: 20 additions & 9 deletions paimon-python/pypaimon/read/split_read.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from pypaimon.read.reader.format_avro_reader import FormatAvroReader
from pypaimon.read.reader.blob_descriptor_convert_reader import BlobDescriptorConvertReader
from pypaimon.read.reader.filter_record_batch_reader import FilterRecordBatchReader
from pypaimon.read.reader.limited_record_reader import LimitedRecordBatchReader, LimitedRecordReader
from pypaimon.read.reader.row_range_filter_record_reader import RowIdFilterRecordBatchReader
from pypaimon.read.reader.format_blob_reader import FormatBlobReader
from pypaimon.read.reader.format_lance_reader import FormatLanceReader
Expand Down Expand Up @@ -109,7 +110,8 @@ def __init__(
read_type: List[DataField],
split: Split,
row_tracking_enabled: bool,
nested_name_paths: Optional[List[List[str]]] = None):
nested_name_paths: Optional[List[List[str]]] = None,
limit: Optional[int] = None):
from pypaimon.table.file_store_table import FileStoreTable

self.table: FileStoreTable = table
Expand All @@ -119,6 +121,7 @@ def __init__(
self.row_tracking_enabled = row_tracking_enabled
self.value_arity = len(read_type)
self.nested_name_paths = nested_name_paths
self.limit = limit
# Snapshot the raw value-side schema before _create_key_value_fields
# wraps it, so MergeFileSplitRead can hand per-value-field nullable
# flags to merge functions that mirror Java's NOT-NULL check.
Expand All @@ -144,8 +147,8 @@ def __init__(
# the space FilterRecordReader actually evaluates against.
read_type_names = {f.name for f in read_type}
if (
self.predicate is not None
and _get_all_fields(self.predicate).issubset(read_type_names)
self.predicate is not None
and _get_all_fields(self.predicate).issubset(read_type_names)
):
self.predicate_for_reader = rewrite_predicate_indices(
self.predicate, read_type
Expand Down Expand Up @@ -620,9 +623,14 @@ def create_reader(self) -> RecordReader:
vector_field_indices=_vector_field_indices(self.read_fields))
# If the table is append-only, no extra filter is needed: all predicates have already been pushed down.
if self.table.is_primary_key_table and self.predicate_for_reader:
return FilterRecordReader(concat_reader, self.predicate_for_reader)
reader = FilterRecordReader(concat_reader, self.predicate_for_reader)
if self.limit is not None:
reader = LimitedRecordReader(reader, self.limit)
else:
return concat_reader
reader = concat_reader
if self.limit is not None:
reader = LimitedRecordBatchReader(reader, self.limit)
return reader

def _get_all_data_fields(self):
if self.row_tracking_enabled:
Expand Down Expand Up @@ -650,9 +658,9 @@ def __init__(
split=split,
row_tracking_enabled=row_tracking_enabled,
nested_name_paths=None,
limit=limit,
)
self.outer_extract_name_paths = outer_extract_name_paths
self.limit = limit
# Built once per split-read (value_fields and options are constant
# for the object's life), not per section. ``None`` when
# ``sequence.field`` is unset, in which case the heap falls back to
Expand Down Expand Up @@ -757,8 +765,6 @@ def create_reader(self) -> RecordReader:
blob_field_indices=_blob_field_indices(inner_value_fields),
vector_field_indices=_vector_field_indices(inner_value_fields))
if self.limit is not None:
from pypaimon.read.reader.limited_record_reader import \
LimitedRecordReader
reader = LimitedRecordReader(reader, self.limit)
return reader

Expand All @@ -775,7 +781,8 @@ def __init__(
read_type: List[DataField],
split: Split,
row_tracking_enabled: bool,
nested_name_paths: Optional[List[List[str]]] = None):
nested_name_paths: Optional[List[List[str]]] = None,
limit: Optional[int] = None):
self.row_ranges = None
actual_split = split
if isinstance(split, IndexedSplit):
Expand All @@ -784,6 +791,7 @@ def __init__(
super().__init__(
table, predicate, read_type, actual_split, row_tracking_enabled,
nested_name_paths=nested_name_paths,
limit=limit,
)

def _push_down_predicate(self) -> Optional[Predicate]:
Expand Down Expand Up @@ -827,6 +835,9 @@ def create_reader(self) -> RecordReader:
and CoreOptions.blob_descriptor_fields(self.table.options)):
reader = BlobDescriptorConvertReader(reader, self.table)

if self.limit is not None:
reader = LimitedRecordBatchReader(reader, self.limit)

return reader

def _split_by_row_id(self, files: List[DataFileMeta]) -> List[List[DataFileMeta]]:
Expand Down
2 changes: 2 additions & 0 deletions paimon-python/pypaimon/read/table_read.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,7 @@ def _create_split_read(self, split: Split) -> SplitRead:
split=split,
row_tracking_enabled=True,
nested_name_paths=self.nested_name_paths,
limit=self.limit,
)
else:
return RawFileSplitRead(
Expand All @@ -589,6 +590,7 @@ def _create_split_read(self, split: Split) -> SplitRead:
split=split,
row_tracking_enabled=self.table.options.row_tracking_enabled(),
nested_name_paths=self.nested_name_paths,
limit=self.limit,
)

def _widen_to_top_level_for_merge(self) -> List[DataField]:
Expand Down
16 changes: 16 additions & 0 deletions paimon-python/pypaimon/tests/blob_table_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3369,6 +3369,22 @@ def test_get_blob_access(self):
self.assertEqual(results[1], (2, b'img_data_2'))
self.assertEqual(results[2], (3, b'img_data_3'))

def test_get_blob_access_with_limit(self):
    """Iterating blobs with ``with_limit(2)`` yields exactly two known rows."""
    builder = self.table.new_read_builder().with_limit(2)
    planned_splits = builder.new_scan().plan().splits()
    table_read = builder.new_read()

    collected = []
    for record in table_read.to_iterator(planned_splits):
        blob = record.get_blob(2)
        self.assertIsNotNone(blob)
        collected.append((record.get_field(0), blob.to_data()))

    # Exactly `limit` rows come back; which ones is not pinned, only
    # that each one is a row the fixture table contains.
    self.assertEqual(len(collected), 2)
    known_ids = (1, 2, 3)
    known_payloads = (b'img_data_1', b'img_data_2', b'img_data_3')
    for row_id, data in collected:
        self.assertIn(row_id, known_ids)
        self.assertIn(data, known_payloads)

def test_get_blob_streaming(self):
read_builder = self.table.new_read_builder()
splits = read_builder.new_scan().plan().splits()
Expand Down
52 changes: 50 additions & 2 deletions paimon-python/pypaimon/tests/test_limit_pushdown.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,8 @@ def test_append_only_limit_stops_within_first_split(self):
exactly 3 rows — even though each partition split has 5 rows."""
table = self._create_ao_table('limit_ao_within_split')
self._write_ao_partitions(table, [
('p1', list(range(5))), # 5 rows
('p2', list(range(5, 10))), # 5 rows
('p1', list(range(5))), # 5 rows
('p2', list(range(5, 10))), # 5 rows
])
rb = table.new_read_builder().with_limit(3)
result = rb.new_read().to_arrow(rb.new_scan().plan().splits())
Expand Down Expand Up @@ -207,6 +207,54 @@ def test_to_iterator_limit_short_circuits(self):
rows = list(it)
self.assertEqual(len(rows), 7)

# ---- SplitRead-level limit pushdown verification ---------------------

def test_append_only_split_read_creates_limited_batch_reader(self):
    """With a limit set, each split read must hand back a
    ``LimitedRecordBatchReader`` that is also a ``RecordBatchReader``,
    so the arrow-batch read path is taken."""
    from pypaimon.read.reader.iface.record_batch_reader import RecordBatchReader
    from pypaimon.read.reader.limited_record_reader import LimitedRecordBatchReader

    table = self._create_ao_table('limit_ao_split_read')
    self._write_ao_partitions(table, [('p1', list(range(10)))])

    builder = table.new_read_builder().with_limit(3)
    table_read = builder.new_read()
    planned_splits = builder.new_scan().plan().splits()
    self.assertGreater(len(planned_splits), 0)

    for planned in planned_splits:
        split_read = table_read._create_split_read(planned)
        # The limit given to the builder must reach the split read.
        self.assertEqual(split_read.limit, 3)
        reader = split_read.create_reader()
        self.assertIsInstance(
            reader, LimitedRecordBatchReader,
            "RawFileSplitRead.create_reader() should wrap with LimitedRecordBatchReader")
        self.assertIsInstance(
            reader, RecordBatchReader,
            "LimitedRecordBatchReader should be a RecordBatchReader")
        reader.close()

def test_append_only_split_read_limit_truncates_within_split(self):
    """Drain a single split's reader directly and check that the reader
    itself stops at the limit, without relying on TableRead-level
    truncation."""
    table = self._create_ao_table('limit_ao_split_truncate')
    self._write_ao_partitions(table, [('p1', list(range(20)))])

    builder = table.new_read_builder().with_limit(5)
    table_read = builder.new_read()
    planned_splits = builder.new_scan().plan().splits()
    self.assertEqual(len(planned_splits), 1)

    reader = table_read._create_split_read(planned_splits[0]).create_reader()

    # Pull batches straight from the reader, bypassing TableRead-level control.
    rows_seen = 0
    batch = reader.read_arrow_batch()
    while batch is not None:
        rows_seen += batch.num_rows
        batch = reader.read_arrow_batch()
    reader.close()

    self.assertEqual(
        rows_seen, 5,
        "SplitRead-level reader should stop at limit=5, got %d" % rows_seen)


if __name__ == '__main__':
unittest.main()
Loading