Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion paimon-python/pypaimon/table/row/blob.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import io
import struct
from abc import ABC, abstractmethod
from typing import BinaryIO, Optional, Union
from typing import BinaryIO, Callable, Optional, Union
from urllib.parse import urlparse

from pypaimon.common.uri_reader import UriReader, FileUriReader
Expand Down Expand Up @@ -382,3 +382,6 @@ def __eq__(self, other) -> bool:

def __hash__(self) -> int:
    """Return the hash of the wrapped descriptor (``self._descriptor``)."""
    descriptor_hash = hash(self._descriptor)
    return descriptor_hash


# Callback type invoked by the blob writer once per blob field value written.
# Arguments: the blob field name, and a BlobDescriptor locating the bytes that
# were just written (None when the value is null).
# Return value: True asks the writer to flush its output stream after this
# blob; the return value is not used for null values.
BlobConsumer = Callable[[str, Optional[BlobDescriptor]], bool]
236 changes: 236 additions & 0 deletions paimon-python/pypaimon/tests/blob_table_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3568,5 +3568,241 @@ def test_get_blob_on_non_blob_column_with_magic_bytes_raises(self):
mock_create.assert_not_called()


class BlobConsumerTest(unittest.TestCase):
    """Tests for BlobConsumer callback functionality."""

    # Table options every blob table in this suite is created with.
    _BASE_OPTIONS = {
        'row-tracking.enabled': 'true',
        'data-evolution.enabled': 'true',
    }

    @classmethod
    def setUpClass(cls):
        """Create a temporary warehouse and the shared ``test_db`` database."""
        cls.temp_dir = tempfile.mkdtemp()
        try:
            cls.warehouse = os.path.join(cls.temp_dir, 'warehouse')
            cls.catalog = CatalogFactory.create({'warehouse': cls.warehouse})
            cls.catalog.create_database('test_db', False)
        except Exception:
            # unittest does not run tearDownClass when setUpClass raises, so
            # remove the temp directory here to avoid leaking it.
            shutil.rmtree(cls.temp_dir, ignore_errors=True)
            raise

    @classmethod
    def tearDownClass(cls):
        """Remove the temporary warehouse directory."""
        shutil.rmtree(cls.temp_dir, ignore_errors=True)

    def _create_table(self, table_name, pa_schema, extra_options=None):
        """Create ``test_db.<table_name>`` from *pa_schema* and return the table."""
        options = dict(self._BASE_OPTIONS)
        if extra_options:
            options.update(extra_options)
        schema = Schema.from_pyarrow_schema(pa_schema, options=options)
        identifier = 'test_db.' + table_name
        self.catalog.create_table(identifier, schema, False)
        return self.catalog.get_table(identifier)

    def _write_and_commit(self, table, data, blob_consumer=None):
        """Write *data* to *table* and commit, always closing the writer."""
        write_builder = table.new_batch_write_builder()
        writer = write_builder.new_write()
        try:
            if blob_consumer is not None:
                writer.with_blob_consumer(blob_consumer)
            writer.write_arrow(data)
            commit_messages = writer.prepare_commit()
            write_builder.new_commit().commit(commit_messages)
        finally:
            # Close even when the write or commit fails so the writer's
            # resources are not leaked into later tests.
            writer.close()

    def test_blob_consumer_basic(self):
        """Consumer receives one BlobDescriptor per blob written, None for nulls."""
        from pypaimon.table.row.blob import Blob, BlobDescriptor
        from pypaimon.common.uri_reader import FileUriReader

        pa_schema = pa.schema([
            ('id', pa.int32()),
            ('name', pa.string()),
            ('blob_data', pa.large_binary()),
        ])
        table = self._create_table('blob_consumer_basic', pa_schema)

        blob_bytes = b'hello_blob_consumer'
        received = []

        def my_consumer(field_name, descriptor):
            received.append((field_name, descriptor))
            return True

        test_data = pa.Table.from_pydict({
            'id': [1, 2, 3],
            'name': ['a', 'b', 'c'],
            'blob_data': [blob_bytes, blob_bytes, None],
        }, schema=pa_schema)
        self._write_and_commit(table, test_data, my_consumer)

        self.assertEqual(len(received), 3)

        # The first two rows carry real blobs: each descriptor must read back
        # the exact bytes that were written.
        uri_reader = FileUriReader(table.file_io)
        for field_name, desc in received[:2]:
            self.assertEqual(field_name, 'blob_data')
            self.assertIsInstance(desc, BlobDescriptor)
            blob = Blob.from_descriptor(uri_reader, desc)
            self.assertEqual(blob.to_data(), blob_bytes)

        # The third row is null: the consumer is still called, with None.
        self.assertEqual(received[2][0], 'blob_data')
        self.assertIsNone(received[2][1])

    def test_blob_consumer_flush_behavior(self):
        """Consumer return value controls flush; verify flush count."""
        from pypaimon.table.row.blob import Blob
        from pypaimon.common.uri_reader import FileUriReader

        pa_schema = pa.schema([
            ('id', pa.int32()),
            ('name', pa.string()),
            ('blob_data', pa.large_binary()),
        ])
        table = self._create_table('blob_consumer_flush', pa_schema)

        blob_bytes = b'flush_test_blob'
        descriptors = []
        flush_count = [0]

        def my_consumer(field_name, descriptor):
            # Request a flush on every second blob.
            descriptors.append(descriptor)
            should_flush = len(descriptors) % 2 == 0
            if should_flush:
                flush_count[0] += 1
            return should_flush

        test_data = pa.Table.from_pydict({
            'id': list(range(5)),
            'name': [f'row{i}' for i in range(5)],
            'blob_data': [blob_bytes] * 5,
        }, schema=pa_schema)
        self._write_and_commit(table, test_data, my_consumer)

        self.assertEqual(len(descriptors), 5)
        self.assertEqual(flush_count[0], 2)

        uri_reader = FileUriReader(table.file_io)
        for desc in descriptors:
            self.assertIsNotNone(desc)
            blob = Blob.from_descriptor(uri_reader, desc)
            self.assertEqual(blob.to_data(), blob_bytes)

    def test_blob_consumer_no_consumer_set(self):
        """Without consumer, writing still works normally."""
        pa_schema = pa.schema([
            ('id', pa.int32()),
            ('blob_data', pa.large_binary()),
        ])
        table = self._create_table('blob_no_consumer', pa_schema)

        test_data = pa.Table.from_pydict({
            'id': [1, 2],
            'blob_data': [b'data1', b'data2'],
        }, schema=pa_schema)
        self._write_and_commit(table, test_data)

        read_builder = table.new_read_builder()
        splits = read_builder.new_scan().plan().splits()
        result = read_builder.new_read().to_arrow(splits)
        self.assertEqual(result.column('blob_data').to_pylist(), [b'data1', b'data2'])

    def test_blob_consumer_chain_call(self):
        """with_blob_consumer returns self for chaining."""
        pa_schema = pa.schema([
            ('id', pa.int32()),
            ('blob_data', pa.large_binary()),
        ])
        table = self._create_table('blob_consumer_chain', pa_schema)

        writer = table.new_batch_write_builder().new_write()
        try:
            result = writer.with_blob_consumer(lambda f, d: False)
            # Chaining only works if the very same writer object comes back;
            # a non-None check would also pass for an unrelated return value.
            self.assertIs(result, writer)
        finally:
            writer.close()

    def test_blob_consumer_abort_preserves_files(self):
        """Abort with consumer must not delete blob files that descriptors point to."""
        from pypaimon.table.row.blob import Blob
        from pypaimon.common.uri_reader import FileUriReader

        pa_schema = pa.schema([
            ('id', pa.int32()),
            ('blob_data', pa.large_binary()),
        ])
        # Target size (1KB) is smaller than one blob (2048 bytes).
        table = self._create_table(
            'blob_consumer_abort', pa_schema, {'blob.target-file-size': '1KB'})

        blob_bytes = b'X' * 2048
        received = []

        def my_consumer(field_name, descriptor):
            received.append(descriptor)
            return False

        write_builder = table.new_batch_write_builder()
        writer = write_builder.new_write()
        writer.with_blob_consumer(my_consumer)

        test_data = pa.Table.from_pydict({
            'id': list(range(5)),
            'blob_data': [blob_bytes] * 5,
        }, schema=pa_schema)
        writer.write_arrow(test_data)

        self.assertGreater(len(received), 0)

        # Capture data writers before close() clears them, then abort each one.
        data_writers = list(writer.file_store_write.data_writers.values())
        self.assertGreater(len(data_writers), 0)
        for dw in data_writers:
            dw.abort()

        # Every descriptor returned to the consumer must still be readable.
        uri_reader = FileUriReader(table.file_io)
        for desc in received:
            self.assertIsNotNone(desc)
            data = Blob.from_descriptor(uri_reader, desc).to_data()
            self.assertEqual(data, blob_bytes)

    def test_blob_consumer_after_write_raises(self):
        """Setting consumer after data has been written must raise."""
        pa_schema = pa.schema([
            ('id', pa.int32()),
            ('blob_data', pa.large_binary()),
        ])
        table = self._create_table('blob_consumer_late', pa_schema)

        writer = table.new_batch_write_builder().new_write()
        try:
            writer.write_arrow(pa.Table.from_pydict({
                'id': [1],
                'blob_data': [b'data'],
            }, schema=pa_schema))

            with self.assertRaises(RuntimeError):
                writer.with_blob_consumer(lambda f, d: False)
        finally:
            # Close even if the expected RuntimeError was not raised.
            writer.close()


# Allow running this test module directly as a script.
if __name__ == '__main__':
    unittest.main()
23 changes: 20 additions & 3 deletions paimon-python/pypaimon/write/blob_format_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@

import struct
import zlib
from typing import BinaryIO, List
from typing import BinaryIO, List, Optional

from pypaimon.table.row.blob import Blob, BlobData, BlobDescriptor
from pypaimon.table.row.blob import Blob, BlobData, BlobDescriptor, BlobConsumer
from pypaimon.common.delta_varint_compressor import DeltaVarintCompressor


Expand All @@ -31,18 +31,25 @@ class BlobFormatWriter:
BUFFER_SIZE = 4096
METADATA_SIZE = 12 # 8-byte length + 4-byte CRC

def __init__(self, output_stream: BinaryIO):
def __init__(self, output_stream: BinaryIO,
             blob_consumer: Optional[BlobConsumer] = None,
             file_path: Optional[str] = None):
    """Set up a writer that appends blob records to *output_stream*.

    Args:
        output_stream: Binary stream the blob records are written to.
        blob_consumer: Optional callback notified for each blob value
            added; ``None`` disables notifications.
        file_path: Location recorded in the descriptors passed to
            *blob_consumer*.
    """
    # Destination stream and write bookkeeping: the current byte offset and
    # the per-record lengths collected so far.
    self.output_stream = output_stream
    self.position = 0
    self.lengths: List[int] = []

    # Callback plumbing, only consulted when a consumer was supplied.
    self._file_path = file_path
    self._blob_consumer = blob_consumer

def add_element(self, row) -> None:
if not hasattr(row, 'values') or len(row.values) != 1:
raise ValueError("BlobFormatWriter only supports one field")

blob_field_name = row.fields[0].name
blob_value = row.values[0]
if blob_value is None:
self.lengths.append(self.NULL_LENGTH)
if self._blob_consumer is not None:
self._blob_consumer(blob_field_name, None)
return

if not isinstance(blob_value, Blob):
Expand All @@ -59,6 +66,8 @@ def add_element(self, row) -> None:
magic_bytes = struct.pack('<I', self.MAGIC_NUMBER) # Little endian
crc32 = self._write_with_crc(magic_bytes, crc32)

blob_pos = self.position

# Write blob data
if isinstance(blob_value, BlobData):
data = blob_value.to_data()
Expand All @@ -74,6 +83,8 @@ def add_element(self, row) -> None:
finally:
stream.close()

blob_length = self.position - blob_pos

# Calculate total length including magic + data + metadata (length + CRC)
bin_length = self.position - previous_pos + self.METADATA_SIZE
self.lengths.append(bin_length)
Expand All @@ -88,6 +99,12 @@ def add_element(self, row) -> None:
self.output_stream.write(crc_bytes)
self.position += 4

if self._blob_consumer is not None:
descriptor = BlobDescriptor(self._file_path, blob_pos, blob_length)
flush = self._blob_consumer(blob_field_name, descriptor)
if flush:
self.output_stream.flush()

def _write_with_crc(self, data: bytes, crc32: int) -> int:
crc32 = zlib.crc32(data, crc32)
self.output_stream.write(data)
Expand Down
2 changes: 2 additions & 0 deletions paimon-python/pypaimon/write/file_store_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def __init__(self, table, commit_user):
self.data_writers: Dict[Tuple, DataWriter] = {}
self.max_seq_numbers: dict = {}
self.write_cols = None
self.blob_consumer = None
self.commit_identifier = 0
self.options = CoreOptions.copy(table.options)
if self.table.bucket_mode() == BucketMode.POSTPONE_MODE:
Expand Down Expand Up @@ -72,6 +73,7 @@ def max_seq_number():
max_seq_number=0,
options=options,
write_cols=self.write_cols,
blob_consumer=self.blob_consumer,
)
elif self._has_vector_columns() and options.with_vector_format():
return DataVectorWriter(
Expand Down
9 changes: 9 additions & 0 deletions paimon-python/pypaimon/write/table_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from pypaimon.schema.data_types import PyarrowFieldParser
from pypaimon.snapshot.snapshot import BATCH_COMMIT_IDENTIFIER
from pypaimon.table.row.blob import BlobConsumer
from pypaimon.write.commit_message import CommitMessage
from pypaimon.write.file_store_write import FileStoreWrite

Expand Down Expand Up @@ -72,6 +73,14 @@ def with_write_type(self, write_cols: List[str]):
self.file_store_write.write_cols = write_cols
return self

def with_blob_consumer(self, blob_consumer: BlobConsumer):
    """Register *blob_consumer* on the underlying file store write.

    The consumer must be set before any data writer exists, i.e. before
    the first write.

    Returns:
        This writer, so calls can be chained.

    Raises:
        RuntimeError: If data writers have already been created.
    """
    store_write = self.file_store_write
    if not store_write.data_writers:
        store_write.blob_consumer = blob_consumer
        return self
    raise RuntimeError(
        "with_blob_consumer must be called before any write operation."
    )

def write_ray(
self,
dataset: "Dataset",
Expand Down
Loading
Loading