diff --git a/paimon-python/pypaimon/table/row/blob.py b/paimon-python/pypaimon/table/row/blob.py index 43391775bd8d..056316d55fb7 100644 --- a/paimon-python/pypaimon/table/row/blob.py +++ b/paimon-python/pypaimon/table/row/blob.py @@ -18,7 +18,7 @@ import io import struct from abc import ABC, abstractmethod -from typing import BinaryIO, Optional, Union +from typing import BinaryIO, Callable, Optional, Union from urllib.parse import urlparse from pypaimon.common.uri_reader import UriReader, FileUriReader @@ -382,3 +382,6 @@ def __eq__(self, other) -> bool: def __hash__(self) -> int: return hash(self._descriptor) + + +BlobConsumer = Callable[[str, Optional[BlobDescriptor]], bool] diff --git a/paimon-python/pypaimon/tests/blob_table_test.py b/paimon-python/pypaimon/tests/blob_table_test.py index ca9161bd1424..295182d43e92 100755 --- a/paimon-python/pypaimon/tests/blob_table_test.py +++ b/paimon-python/pypaimon/tests/blob_table_test.py @@ -3568,5 +3568,241 @@ def test_get_blob_on_non_blob_column_with_magic_bytes_raises(self): mock_create.assert_not_called() +class BlobConsumerTest(unittest.TestCase): + """Tests for BlobConsumer callback functionality.""" + + @classmethod + def setUpClass(cls): + cls.temp_dir = tempfile.mkdtemp() + cls.warehouse = os.path.join(cls.temp_dir, 'warehouse') + cls.catalog = CatalogFactory.create({'warehouse': cls.warehouse}) + cls.catalog.create_database('test_db', False) + + @classmethod + def tearDownClass(cls): + shutil.rmtree(cls.temp_dir, ignore_errors=True) + + def test_blob_consumer_basic(self): + """Consumer receives one BlobDescriptor per blob written, None for nulls.""" + from pypaimon.table.row.blob import Blob, BlobDescriptor + from pypaimon.common.uri_reader import FileUriReader + + pa_schema = pa.schema([ + ('id', pa.int32()), + ('name', pa.string()), + ('blob_data', pa.large_binary()), + ]) + schema = Schema.from_pyarrow_schema(pa_schema, options={ + 'row-tracking.enabled': 'true', + 'data-evolution.enabled': 'true', + }) + self.catalog.create_table('test_db.blob_consumer_basic', schema, False) + table = self.catalog.get_table('test_db.blob_consumer_basic') + + blob_bytes = b'hello_blob_consumer' + received = [] + + def my_consumer(field_name, descriptor): + received.append((field_name, descriptor)) + return True + + write_builder = table.new_batch_write_builder() + writer = write_builder.new_write() + writer.with_blob_consumer(my_consumer) + + test_data = pa.Table.from_pydict({ + 'id': [1, 2, 3], + 'name': ['a', 'b', 'c'], + 'blob_data': [blob_bytes, blob_bytes, None], + }, schema=pa_schema) + writer.write_arrow(test_data) + commit_messages = writer.prepare_commit() + write_builder.new_commit().commit(commit_messages) + writer.close() + + self.assertEqual(len(received), 3) + + for field_name, desc in received[:2]: + self.assertEqual(field_name, 'blob_data') + self.assertIsInstance(desc, BlobDescriptor) + uri_reader = FileUriReader(table.file_io) + blob = Blob.from_descriptor(uri_reader, desc) + self.assertEqual(blob.to_data(), blob_bytes) + + self.assertEqual(received[2][0], 'blob_data') + self.assertIsNone(received[2][1]) + + def test_blob_consumer_flush_behavior(self): + """Consumer return value controls flush; verify flush count.""" + pa_schema = pa.schema([ + ('id', pa.int32()), + ('name', pa.string()), + ('blob_data', pa.large_binary()), + ]) + schema = Schema.from_pyarrow_schema(pa_schema, options={ + 'row-tracking.enabled': 'true', + 'data-evolution.enabled': 'true', + }) + self.catalog.create_table('test_db.blob_consumer_flush', schema, False) + table = self.catalog.get_table('test_db.blob_consumer_flush') + + blob_bytes = b'flush_test_blob' + descriptors = [] + flush_count = [0] + + def my_consumer(field_name, descriptor): + descriptors.append(descriptor) + should_flush = len(descriptors) % 2 == 0 + if should_flush: + flush_count[0] += 1 + return should_flush + + write_builder = table.new_batch_write_builder() + writer = write_builder.new_write() + writer.with_blob_consumer(my_consumer) + + test_data = pa.Table.from_pydict({ + 'id': list(range(5)), + 'name': [f'row{i}' for i in range(5)], + 'blob_data': [blob_bytes] * 5, + }, schema=pa_schema) + writer.write_arrow(test_data) + commit_messages = writer.prepare_commit() + write_builder.new_commit().commit(commit_messages) + writer.close() + + self.assertEqual(len(descriptors), 5) + self.assertEqual(flush_count[0], 2) + + from pypaimon.table.row.blob import Blob + from pypaimon.common.uri_reader import FileUriReader + uri_reader = FileUriReader(table.file_io) + for desc in descriptors: + self.assertIsNotNone(desc) + blob = Blob.from_descriptor(uri_reader, desc) + self.assertEqual(blob.to_data(), blob_bytes) + + def test_blob_consumer_no_consumer_set(self): + """Without consumer, writing still works normally.""" + pa_schema = pa.schema([ + ('id', pa.int32()), + ('blob_data', pa.large_binary()), + ]) + schema = Schema.from_pyarrow_schema(pa_schema, options={ + 'row-tracking.enabled': 'true', + 'data-evolution.enabled': 'true', + }) + self.catalog.create_table('test_db.blob_no_consumer', schema, False) + table = self.catalog.get_table('test_db.blob_no_consumer') + + write_builder = table.new_batch_write_builder() + writer = write_builder.new_write() + + test_data = pa.Table.from_pydict({ + 'id': [1, 2], + 'blob_data': [b'data1', b'data2'], + }, schema=pa_schema) + writer.write_arrow(test_data) + commit_messages = writer.prepare_commit() + write_builder.new_commit().commit(commit_messages) + writer.close() + + result = table.new_read_builder().new_read().to_arrow( + table.new_read_builder().new_scan().plan().splits()) + self.assertEqual(result.column('blob_data').to_pylist(), [b'data1', b'data2']) + + def test_blob_consumer_chain_call(self): + """with_blob_consumer returns self for chaining.""" + pa_schema = pa.schema([ + ('id', pa.int32()), + ('blob_data', pa.large_binary()), + ]) + schema = Schema.from_pyarrow_schema(pa_schema, options={ + 'row-tracking.enabled': 'true', + 'data-evolution.enabled': 'true', + }) + self.catalog.create_table('test_db.blob_consumer_chain', schema, False) + table = self.catalog.get_table('test_db.blob_consumer_chain') + + write_builder = table.new_batch_write_builder() + result = write_builder.new_write().with_blob_consumer(lambda f, d: False) + self.assertIsNotNone(result) + result.close() + + def test_blob_consumer_abort_preserves_files(self): + """Abort with consumer must not delete blob files that descriptors point to.""" + from pypaimon.table.row.blob import Blob + from pypaimon.common.uri_reader import FileUriReader + + pa_schema = pa.schema([ + ('id', pa.int32()), + ('blob_data', pa.large_binary()), + ]) + schema = Schema.from_pyarrow_schema(pa_schema, options={ + 'row-tracking.enabled': 'true', + 'data-evolution.enabled': 'true', + 'blob.target-file-size': '1KB', + }) + self.catalog.create_table('test_db.blob_consumer_abort', schema, False) + table = self.catalog.get_table('test_db.blob_consumer_abort') + + blob_bytes = b'X' * 2048 + received = [] + + def my_consumer(field_name, descriptor): + received.append(descriptor) + return False + + write_builder = table.new_batch_write_builder() + writer = write_builder.new_write() + writer.with_blob_consumer(my_consumer) + + test_data = pa.Table.from_pydict({ + 'id': list(range(5)), + 'blob_data': [blob_bytes] * 5, + }, schema=pa_schema) + writer.write_arrow(test_data) + + self.assertGreater(len(received), 0) + + # Capture data writers before close() clears them, then abort each one. + data_writers = list(writer.file_store_write.data_writers.values()) + self.assertGreater(len(data_writers), 0) + for dw in data_writers: + dw.abort() + + # Every descriptor returned to the consumer must still be readable. + uri_reader = FileUriReader(table.file_io) + for desc in received: + self.assertIsNotNone(desc) + data = Blob.from_descriptor(uri_reader, desc).to_data() + self.assertEqual(data, blob_bytes) + + def test_blob_consumer_after_write_raises(self): + """Setting consumer after data has been written must raise.""" + pa_schema = pa.schema([ + ('id', pa.int32()), + ('blob_data', pa.large_binary()), + ]) + schema = Schema.from_pyarrow_schema(pa_schema, options={ + 'row-tracking.enabled': 'true', + 'data-evolution.enabled': 'true', + }) + self.catalog.create_table('test_db.blob_consumer_late', schema, False) + table = self.catalog.get_table('test_db.blob_consumer_late') + + write_builder = table.new_batch_write_builder() + writer = write_builder.new_write() + + writer.write_arrow(pa.Table.from_pydict({ + 'id': [1], + 'blob_data': [b'data'], + }, schema=pa_schema)) + + with self.assertRaises(RuntimeError): + writer.with_blob_consumer(lambda f, d: False) + writer.close() + + if __name__ == '__main__': unittest.main() diff --git a/paimon-python/pypaimon/write/blob_format_writer.py b/paimon-python/pypaimon/write/blob_format_writer.py index 92257f4ca97a..6004425b7181 100644 --- a/paimon-python/pypaimon/write/blob_format_writer.py +++ b/paimon-python/pypaimon/write/blob_format_writer.py @@ -17,9 +17,9 @@ import struct import zlib -from typing import BinaryIO, List +from typing import BinaryIO, List, Optional -from pypaimon.table.row.blob import Blob, BlobData, BlobDescriptor +from pypaimon.table.row.blob import Blob, BlobData, BlobDescriptor, BlobConsumer from pypaimon.common.delta_varint_compressor import DeltaVarintCompressor @@ -31,8 +31,12 @@ class BlobFormatWriter: BUFFER_SIZE = 4096 METADATA_SIZE = 12 # 8-byte length + 4-byte CRC - def __init__(self, output_stream: BinaryIO): + def __init__(self, output_stream: BinaryIO, + blob_consumer: Optional[BlobConsumer] = None, + file_path: Optional[str] = None): self.output_stream = output_stream + self._blob_consumer = blob_consumer + self._file_path = file_path self.lengths: List[int] = [] self.position = 0 @@ -40,9 +44,12 @@ def add_element(self, row) -> None: if not hasattr(row, 'values') or len(row.values) != 1: raise ValueError("BlobFormatWriter only supports one field") + blob_field_name = row.fields[0].name blob_value = row.values[0] if blob_value is None: self.lengths.append(self.NULL_LENGTH) + if self._blob_consumer is not None: + self._blob_consumer(blob_field_name, None) return if not isinstance(blob_value, Blob): @@ -59,6 +66,8 @@ def add_element(self, row) -> None: magic_bytes = struct.pack(' None: finally: stream.close() + blob_length = self.position - blob_pos + # Calculate total length including magic + data + metadata (length + CRC) bin_length = self.position - previous_pos + self.METADATA_SIZE self.lengths.append(bin_length) @@ -88,6 +99,12 @@ def add_element(self, row) -> None: self.output_stream.write(crc_bytes) self.position += 4 + if self._blob_consumer is not None: + descriptor = BlobDescriptor(self._file_path, blob_pos, blob_length) + flush = self._blob_consumer(blob_field_name, descriptor) + if flush: + self.output_stream.flush() + def _write_with_crc(self, data: bytes, crc32: int) -> int: crc32 = zlib.crc32(data, crc32) self.output_stream.write(data) diff --git a/paimon-python/pypaimon/write/file_store_write.py b/paimon-python/pypaimon/write/file_store_write.py index 6a20708a14fd..f1374797e15b 100644 --- a/paimon-python/pypaimon/write/file_store_write.py +++ b/paimon-python/pypaimon/write/file_store_write.py @@ -40,6 +40,7 @@ def __init__(self, table, commit_user): self.data_writers: Dict[Tuple, DataWriter] = {} self.max_seq_numbers: dict = {} self.write_cols = None + self.blob_consumer = None self.commit_identifier = 0 self.options = CoreOptions.copy(table.options) if self.table.bucket_mode() == BucketMode.POSTPONE_MODE: @@ -72,6 +73,7 @@ def max_seq_number(): max_seq_number=0, options=options, write_cols=self.write_cols, + blob_consumer=self.blob_consumer, ) elif self._has_vector_columns() and options.with_vector_format(): return DataVectorWriter( diff --git a/paimon-python/pypaimon/write/table_write.py b/paimon-python/pypaimon/write/table_write.py index 1eeeb5e8467b..411ddd9ceb21 100644 --- a/paimon-python/pypaimon/write/table_write.py +++ b/paimon-python/pypaimon/write/table_write.py @@ -22,6 +22,7 @@ from pypaimon.schema.data_types import PyarrowFieldParser from pypaimon.snapshot.snapshot import BATCH_COMMIT_IDENTIFIER +from pypaimon.table.row.blob import BlobConsumer from pypaimon.write.commit_message import CommitMessage from pypaimon.write.file_store_write import FileStoreWrite @@ -72,6 +73,14 @@ def with_write_type(self, write_cols: List[str]): self.file_store_write.write_cols = write_cols return self + def with_blob_consumer(self, blob_consumer: BlobConsumer): + if self.file_store_write.data_writers: + raise RuntimeError( + "with_blob_consumer must be called before any write operation." + ) + self.file_store_write.blob_consumer = blob_consumer + return self + def write_ray( self, dataset: "Dataset", diff --git a/paimon-python/pypaimon/write/writer/blob_file_writer.py b/paimon-python/pypaimon/write/writer/blob_file_writer.py index fe911586be70..e4aa66a1953c 100644 --- a/paimon-python/pypaimon/write/writer/blob_file_writer.py +++ b/paimon-python/pypaimon/write/writer/blob_file_writer.py @@ -22,7 +22,7 @@ from pypaimon.write.blob_format_writer import BlobFormatWriter from pypaimon.table.row.generic_row import GenericRow, RowKind -from pypaimon.table.row.blob import Blob, BlobData, BlobDescriptor +from pypaimon.table.row.blob import Blob, BlobConsumer, BlobData, BlobDescriptor from pypaimon.schema.data_types import DataField, PyarrowFieldParser @@ -32,11 +32,16 @@ class BlobFileWriter: Writes rows one by one and tracks file size. """ - def __init__(self, file_io, file_path: Path): + def __init__(self, file_io, file_path: Path, blob_consumer: Optional[BlobConsumer] = None): self.file_io = file_io self.file_path = file_path + self._blob_consumer = blob_consumer self.output_stream = file_io.new_output_stream(file_path) - self.writer = BlobFormatWriter(self.output_stream) + self.writer = BlobFormatWriter( + self.output_stream, + blob_consumer=blob_consumer, + file_path=str(file_path), + ) self.row_count = 0 self.closed = False @@ -118,7 +123,7 @@ def close(self) -> int: return file_size def abort(self): - """Abort the writer and delete the file.""" + """Abort the writer and delete the file (unless a blob consumer holds references).""" if not self.closed: try: if hasattr(self.output_stream, 'close'): @@ -127,5 +132,5 @@ def abort(self): pass self.closed = True - # Delete the file - self.file_io.delete_quietly(self.file_path) + if self._blob_consumer is None: + self.file_io.delete_quietly(self.file_path) diff --git a/paimon-python/pypaimon/write/writer/blob_writer.py b/paimon-python/pypaimon/write/writer/blob_writer.py index adc86e848689..24f64e4dff78 100644 --- a/paimon-python/pypaimon/write/writer/blob_writer.py +++ b/paimon-python/pypaimon/write/writer/blob_writer.py @@ -22,6 +22,7 @@ from pypaimon.common.options.core_options import CoreOptions from pypaimon.data.timestamp import Timestamp +from pypaimon.table.row.blob import BlobConsumer from pypaimon.write.writer.append_only_data_writer import AppendOnlyDataWriter from pypaimon.write.writer.blob_file_writer import BlobFileWriter @@ -31,7 +32,7 @@ class BlobWriter(AppendOnlyDataWriter): def __init__(self, table, partition: Tuple, bucket: int, max_seq_number: int, blob_column: str, - options: Dict[str, str] = None): + options: Dict[str, str] = None, blob_consumer: Optional[BlobConsumer] = None): super().__init__(table, partition, bucket, max_seq_number, options, write_cols=[blob_column]) @@ -44,6 +45,7 @@ def __init__(self, table, partition: Tuple, bucket: int, max_seq_number: int, bl options = self.table.options self.blob_target_file_size = CoreOptions.blob_target_file_size(options) + self._blob_consumer = blob_consumer self.current_writer: Optional[BlobFileWriter] = None self.current_file_path: Optional[str] = None self.record_count = 0 @@ -98,7 +100,7 @@ def open_current_writer(self): self.file_count += 1 # Increment counter for next file file_path = self._generate_file_path(file_name) self.current_file_path = file_path - self.current_writer = BlobFileWriter(self.file_io, file_path) + self.current_writer = BlobFileWriter(self.file_io, file_path, blob_consumer=self._blob_consumer) def rolling_file(self) -> bool: if self.current_writer is None: @@ -236,7 +238,11 @@ def abort(self): logger.warning(f"Error aborting blob writer: {e}", exc_info=e) self.current_writer = None self.current_file_path = None - super().abort() + if self._blob_consumer is not None: + self.pending_data = None + self.committed_files.clear() + else: + super().abort() @staticmethod def _get_column_stats(data_or_batch, column_name: str): diff --git a/paimon-python/pypaimon/write/writer/dedicated_format_writer.py b/paimon-python/pypaimon/write/writer/dedicated_format_writer.py index 6e7ea5778148..28ff70b36017 100644 --- a/paimon-python/pypaimon/write/writer/dedicated_format_writer.py +++ b/paimon-python/pypaimon/write/writer/dedicated_format_writer.py @@ -26,6 +26,7 @@ from pypaimon.manifest.schema.data_file_meta import DataFileMeta from pypaimon.manifest.schema.simple_stats import SimpleStats from pypaimon.schema.data_types import VectorType +from pypaimon.table.row.blob import BlobConsumer from pypaimon.table.row.generic_row import GenericRow from pypaimon.write.writer.data_writer import DataWriter @@ -50,7 +51,7 @@ class DedicatedFormatWriter(DataWriter): CHECK_ROLLING_RECORD_CNT = 1000 def __init__(self, table, partition: Tuple, bucket: int, max_seq_number: int, options: CoreOptions = None, - write_cols: Optional[List[str]] = None): + write_cols: Optional[List[str]] = None, blob_consumer: Optional[BlobConsumer] = None): super().__init__(table, partition, bucket, max_seq_number, options, write_cols=write_cols) # Determine blob columns from table schema @@ -123,7 +124,8 @@ def __init__(self, table, partition: Tuple, bucket: int, max_seq_number: int, op bucket=self.bucket, max_seq_number=max_seq_number, blob_column=blob_column, - options=options + options=options, + blob_consumer=blob_consumer, ) # Initialize vector writer when vector.file.format is configured.