diff --git a/sdks/java/io/delta/build.gradle b/sdks/java/io/delta/build.gradle index 617965b3bc4e..c07aef6981b9 100644 --- a/sdks/java/io/delta/build.gradle +++ b/sdks/java/io/delta/build.gradle @@ -26,6 +26,10 @@ applyJavaNature( description = "Apache Beam :: SDKs :: Java :: IO :: Delta Lake" ext.summary = "Integration with Delta Lake." +// We need to override the GCS bigdataoss connector version to prevent conflicts. +def bigdataoss_gcs_connector_version = "4.0.4" + +def parquet_version = "1.16.0" dependencies { implementation project(path: ":sdks:java:core", configuration: "shadow") @@ -35,5 +39,45 @@ dependencies { permitUnusedDeclared library.java.delta_kernel_api permitUnusedDeclared library.java.delta_kernel_defaults + implementation library.java.hadoop_common + implementation library.java.joda_time + // implementation library.java.slf4j_api + implementation "org.apache.parquet:parquet-column:$parquet_version" + implementation "org.apache.parquet:parquet-hadoop:$parquet_version" + + // We need to override the GCS connector version to prevent conflicts with + // latest Hadoop. 
+ implementation "com.google.cloud.bigdataoss:gcs-connector:$bigdataoss_gcs_connector_version" + implementation "com.google.cloud.bigdataoss:util-hadoop:$bigdataoss_gcs_connector_version" + implementation "com.google.cloud.bigdataoss:gcsio:$bigdataoss_gcs_connector_version" + implementation "com.google.cloud.bigdataoss:util:$bigdataoss_gcs_connector_version" + permitUnusedDeclared "com.google.cloud.bigdataoss:gcs-connector:$bigdataoss_gcs_connector_version" + permitUnusedDeclared "com.google.cloud.bigdataoss:util-hadoop:$bigdataoss_gcs_connector_version" + permitUnusedDeclared "com.google.cloud.bigdataoss:gcsio:$bigdataoss_gcs_connector_version" + permitUnusedDeclared "com.google.cloud.bigdataoss:util:$bigdataoss_gcs_connector_version" + + // For Avro conversions + testImplementation project(":sdks:java:extensions:avro") + + testImplementation library.java.avro testImplementation library.java.junit + testImplementation library.java.hamcrest + testImplementation "org.apache.parquet:parquet-avro:$parquet_version" + testImplementation project(":sdks:java:io:parquet") + testImplementation project(":sdks:java:managed") + testRuntimeOnly "org.yaml:snakeyaml:2.0" + testImplementation project(path: ":runners:direct-java", configuration: "shadow") +} + +configurations.all { + // Exclude conflicting logging frameworks + exclude group: "org.apache.logging.log4j", module: "log4j-slf4j2-impl" + exclude group: "org.apache.logging.log4j", module: "log4j-slf4j-impl" + exclude group: "org.slf4j", module: "slf4j-reload4j" + + // Force overriding for all configurations + resolutionStrategy.force "com.google.cloud.bigdataoss:gcs-connector:$bigdataoss_gcs_connector_version" + resolutionStrategy.force "com.google.cloud.bigdataoss:util-hadoop:$bigdataoss_gcs_connector_version" + resolutionStrategy.force "com.google.cloud.bigdataoss:gcsio:$bigdataoss_gcs_connector_version" + resolutionStrategy.force "com.google.cloud.bigdataoss:util:$bigdataoss_gcs_connector_version" } diff --git 
a/sdks/java/io/delta/src/main/java/org/apache/beam/sdk/io/delta/BeamEngine.java b/sdks/java/io/delta/src/main/java/org/apache/beam/sdk/io/delta/BeamEngine.java new file mode 100644 index 000000000000..de82d8d01b81 --- /dev/null +++ b/sdks/java/io/delta/src/main/java/org/apache/beam/sdk/io/delta/BeamEngine.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.delta; + +import io.delta.kernel.engine.Engine; +import io.delta.kernel.engine.ExpressionHandler; +import io.delta.kernel.engine.FileSystemClient; +import io.delta.kernel.engine.JsonHandler; +import io.delta.kernel.engine.ParquetHandler; + +/** A Beam specific {@link Engine} wrapper that provides a custom {@link ParquetHandler}. 
*/ +public class BeamEngine implements Engine { + private final Engine delegate; + private final ParquetHandler parquetHandler; + + public BeamEngine(Engine delegate, ParquetHandler parquetHandler) { + this.delegate = delegate; + this.parquetHandler = parquetHandler; + } + + @Override + public ExpressionHandler getExpressionHandler() { + return delegate.getExpressionHandler(); + } + + @Override + public JsonHandler getJsonHandler() { + return delegate.getJsonHandler(); + } + + @Override + public FileSystemClient getFileSystemClient() { + return delegate.getFileSystemClient(); + } + + @Override + public ParquetHandler getParquetHandler() { + return parquetHandler; + } +} diff --git a/sdks/java/io/delta/src/main/java/org/apache/beam/sdk/io/delta/BeamParquetHandler.java b/sdks/java/io/delta/src/main/java/org/apache/beam/sdk/io/delta/BeamParquetHandler.java new file mode 100644 index 000000000000..0ed1b582e982 --- /dev/null +++ b/sdks/java/io/delta/src/main/java/org/apache/beam/sdk/io/delta/BeamParquetHandler.java @@ -0,0 +1,371 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.io.delta; + +import io.delta.kernel.data.FilteredColumnarBatch; +import io.delta.kernel.defaults.internal.parquet.ParquetFileReader.BatchReadSupport; +import io.delta.kernel.engine.FileReadResult; +import io.delta.kernel.engine.ParquetHandler; +import io.delta.kernel.expressions.Column; +import io.delta.kernel.expressions.Predicate; +import io.delta.kernel.types.MetadataColumnSpec; +import io.delta.kernel.types.StructType; +import io.delta.kernel.utils.CloseableIterator; +import io.delta.kernel.utils.DataFileStatus; +import io.delta.kernel.utils.FileStatus; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.NoSuchElementException; +import java.util.Optional; +import java.util.Set; +import org.apache.beam.sdk.io.range.OffsetRange; +import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; +import org.apache.parquet.column.page.PageReadStore; +import org.apache.parquet.filter2.compat.FilterCompat; +import org.apache.parquet.format.converter.ParquetMetadataConverter; +import org.apache.parquet.hadoop.ParquetFileReader; +import org.apache.parquet.hadoop.api.InitContext; +import org.apache.parquet.hadoop.api.ReadSupport; +import org.apache.parquet.hadoop.metadata.FileMetaData; +import org.apache.parquet.hadoop.metadata.ParquetMetadata; +import org.apache.parquet.hadoop.util.HadoopInputFile; +import org.apache.parquet.io.ColumnIOFactory; +import org.apache.parquet.io.MessageColumnIO; +import org.apache.parquet.io.RecordReader; +import org.apache.parquet.io.api.RecordMaterializer; +import org.apache.parquet.schema.MessageType; + +/** + * A Beam specific {@link ParquetHandler} that delegates row group claiming to a {@link + * DeltaReadTaskTracker}. 
+ */ +public class BeamParquetHandler implements ParquetHandler { + private final Configuration conf; + private final ParquetHandler delegate; + private final RestrictionTracker tracker; + private static final long DEFAULT_START_RG_INDEX = 0L; + + public BeamParquetHandler( + Configuration conf, ParquetHandler delegate, RestrictionTracker tracker) { + this.conf = conf; + this.delegate = delegate; + this.tracker = tracker; + } + + private boolean claimFailed = false; + + /** + * A method that is expected to be called after the first file processing is done. It returns + * whether the last file process resulted in a claim failure. This allows the caller to skip + * trying to read the remaining files of the task which would result in claim failures for each + * row group within them. + * + * @return true, if the last file process resulted in a claim failure. Returns false otherwise. + */ + public boolean hasClaimFailed() { + return claimFailed; + } + + @Override + public CloseableIterator readParquetFiles( + CloseableIterator fileIter, + StructType physicalSchema, + Optional predicate) + throws IOException { + return readParquetFiles(fileIter, physicalSchema, predicate, DEFAULT_START_RG_INDEX); + } + + /** + * Reads Parquet files starting from a given row group index. + * + *

This takes the {@code RestrictionTracker} held by this {@code BeamParquetHandler} + * into consideration when reading by performing the following: + * + *

* Skips blocks of the set of files till the given start row group index or the start point + * of the {@code RestrictionTracker}, whichever is higher. * Invokes {@code tryClaim} when reading + * a specific block and stops reading if a {@code tryClaim} fails. * Stops reading if the end of the + * range of the {@code RestrictionTracker} is reached. + * + *

If {@code tryClaim} fails during reading, subsequent {@code hasClaimFailed} calls will + * return {@code true}, so the caller can skip reading subsequent files that are in the range + * being considered for reading. + */ + public CloseableIterator readParquetFiles( + CloseableIterator fileIter, + StructType physicalSchema, + Optional predicate, + long startRgIndex) + throws IOException { + + List> results = new ArrayList<>(); + boolean hasRowIndexCol = physicalSchema.contains(MetadataColumnSpec.ROW_INDEX); + + long currentRgIndex = startRgIndex; + + try { + while (fileIter.hasNext()) { + if (currentRgIndex >= tracker.currentRestriction().getTo()) { + // Skipping all blocks for the remaining files since they are located after the + // end index of the tracker. Since currentRgIndex is monotonically increasing, + // we can break the loop immediately to avoid extremely expensive network I/O. + break; + } + + FileStatus fileStatus = fileIter.next(); + Path hadoopPath = new Path(fileStatus.getPath()); + ParquetMetadata metadata = + ParquetFileReader.readFooter(conf, hadoopPath, ParquetMetadataConverter.NO_FILTER); + long fileBlocks = metadata.getBlocks().size(); + + if (currentRgIndex + fileBlocks <= tracker.currentRestriction().getFrom()) { + // Skipping all blocks for the current file since they are located before the + // start index of the tracker. + currentRgIndex += fileBlocks; + continue; + } + + results.add( + readParquetFileDirect( + fileStatus, + hadoopPath, + metadata, + physicalSchema, + hasRowIndexCol, + currentRgIndex, + fileBlocks)); + + currentRgIndex += fileBlocks; + } + } finally { + fileIter.close(); + } + + return combineResults(results); + } + + // Reads the correct set of blocks that belong to the given Parquet file that + // are within range for the current `RestrictionTracker`. If the current file + // has some blocks that are within the tracker's range and some that are + // outside, + // this will only read the blocks that are within the range. 
+ private CloseableIterator readParquetFileDirect( + FileStatus fileStatus, + Path hadoopPath, + ParquetMetadata metadata, + StructType physicalSchema, + boolean hasRowIndexCol, + long startRgIndex, + long fileBlocks) { + + return new CloseableIterator() { + @javax.annotation.Nullable private ParquetFileReader reader = null; + @javax.annotation.Nullable private BatchReadSupport readSupport = null; + @javax.annotation.Nullable private RecordMaterializer recordConverter = null; + @javax.annotation.Nullable private MessageColumnIO columnIO = null; + + private long currentRgOffset = 0; + @javax.annotation.Nullable private RecordReader currentRecordReader = null; + private long currentRgTotalRows = 0; + private long currentRgRowOffset = 0; + private long currentRgStartingRowIndex = 0; + + @javax.annotation.Nullable private FileReadResult nextResult = null; + private boolean isDone = false; + + private void initReaderIfRequired() throws IOException { + if (reader != null) { + return; + } + HadoopInputFile inputFile = HadoopInputFile.fromPath(hadoopPath, conf); + ParquetFileReader localReader = ParquetFileReader.open(inputFile); + reader = localReader; + + FileMetaData fileMetaData = metadata.getFileMetaData(); + MessageType fileSchema = fileMetaData.getSchema(); + Map> keyValueMetadata = new HashMap<>(); + if (fileMetaData.getKeyValueMetaData() != null) { + for (Map.Entry entry : fileMetaData.getKeyValueMetaData().entrySet()) { + keyValueMetadata.put(entry.getKey(), Collections.singleton(entry.getValue())); + } + } + + BatchReadSupport localReadSupport = new BatchReadSupport(1024, physicalSchema); + readSupport = localReadSupport; + ReadSupport.ReadContext readContext = + localReadSupport.init(new InitContext(conf, keyValueMetadata, fileSchema)); + RecordMaterializer localRecordConverter = + localReadSupport.prepareForRead( + conf, fileMetaData.getKeyValueMetaData(), fileSchema, readContext); + recordConverter = localRecordConverter; + 
localReader.setRequestedSchema(readContext.getRequestedSchema()); + + ColumnIOFactory columnIOFactory = new ColumnIOFactory(fileMetaData.getCreatedBy()); + columnIO = columnIOFactory.getColumnIO(readContext.getRequestedSchema(), fileSchema, true); + } + + @Override + public boolean hasNext() { + if (isDone) { + return false; + } + if (nextResult != null) { + return true; + } + + try { + initReaderIfRequired(); + ParquetFileReader localReader = reader; + BatchReadSupport localReadSupport = readSupport; + MessageColumnIO localColumnIO = columnIO; + RecordMaterializer localRecordConverter = recordConverter; + if (localReader == null + || localReadSupport == null + || localColumnIO == null + || localRecordConverter == null) { + throw new IllegalStateException("Reader not initialized"); + } + + while (true) { + RecordReader localRecordReader = currentRecordReader; + if (localRecordReader != null && currentRgRowOffset < currentRgTotalRows) { + int batchSize = (int) Math.min(1024L, currentRgTotalRows - currentRgRowOffset); + for (int i = 0; i < batchSize; i++) { + localRecordReader.read(); + long rowIndex = + hasRowIndexCol ? (currentRgStartingRowIndex + currentRgRowOffset + i) : -1L; + localReadSupport.finalizeCurrentRow(rowIndex); + } + currentRgRowOffset += batchSize; + io.delta.kernel.data.ColumnarBatch batch = + localReadSupport.getDataAsColumnarBatch(batchSize); + nextResult = new FileReadResult(batch, fileStatus.getPath()); + return true; + } + + currentRecordReader = null; + if (currentRgOffset >= fileBlocks) { + isDone = true; + return false; + } + + // Checking the range for specific row groups. + long rgIndex = startRgIndex + currentRgOffset; + if (rgIndex < tracker.currentRestriction().getFrom() + || rgIndex >= tracker.currentRestriction().getTo()) { + localReader.skipNextRowGroup(); + currentRgOffset++; + continue; + } + + // We only read the row group if it's within the range for the + // RestrictionTracker. 
+ if (tracker.tryClaim(rgIndex)) { + PageReadStore pages = localReader.readNextRowGroup(); + currentRecordReader = + localColumnIO.getRecordReader(pages, localRecordConverter, FilterCompat.NOOP); + currentRgTotalRows = pages.getRowCount(); + currentRgRowOffset = 0; + currentRgStartingRowIndex = pages.getRowIndexOffset().orElse(0L); + currentRgOffset++; + } else { + // Mark claim failed for the current row group, so we stop processing + // the remaining row groups in the source. + claimFailed = true; + isDone = true; + return false; + } + } + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public FileReadResult next() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + FileReadResult res = nextResult; + if (res == null) { + throw new NoSuchElementException(); + } + nextResult = null; + return res; + } + + @Override + public void close() throws IOException { + if (reader != null) { + reader.close(); + } + } + }; + } + + @Override + public void writeParquetFileAtomically( + String filePath, CloseableIterator data) throws IOException { + delegate.writeParquetFileAtomically(filePath, data); + } + + @Override + public CloseableIterator writeParquetFiles( + String filePath, CloseableIterator data, List statsColumns) + throws IOException { + return delegate.writeParquetFiles(filePath, data, statsColumns); + } + + private static CloseableIterator combineResults( + List> iterators) { + return new CloseableIterator() { + private int currentIdx = 0; + + @Override + public boolean hasNext() { + while (currentIdx < iterators.size()) { + if (iterators.get(currentIdx).hasNext()) { + return true; + } + currentIdx++; + } + return false; + } + + @Override + public FileReadResult next() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + return iterators.get(currentIdx).next(); + } + + @Override + public void close() throws IOException { + for (CloseableIterator it : iterators) { + it.close(); + } + } + }; + } +} diff 
--git a/sdks/java/io/delta/src/main/java/org/apache/beam/sdk/io/delta/CreateReadTasksDoFn.java b/sdks/java/io/delta/src/main/java/org/apache/beam/sdk/io/delta/CreateReadTasksDoFn.java new file mode 100644 index 000000000000..aa4eaa61bd46 --- /dev/null +++ b/sdks/java/io/delta/src/main/java/org/apache/beam/sdk/io/delta/CreateReadTasksDoFn.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.delta; + +import io.delta.kernel.Scan; +import io.delta.kernel.Snapshot; +import io.delta.kernel.Table; +import io.delta.kernel.data.FilteredColumnarBatch; +import io.delta.kernel.data.Row; +import io.delta.kernel.defaults.engine.DefaultEngine; +import io.delta.kernel.engine.Engine; +import io.delta.kernel.internal.InternalScanFileUtils; +import io.delta.kernel.utils.CloseableIterator; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.hadoop.conf.Configuration; +import org.checkerframework.checker.nullness.qual.Nullable; + +/** A DoFn that reads the Delta log and outputs a list of DeltaReadTask records to read. 
*/ +class CreateReadTasksDoFn extends DoFn { + private static final long MAX_TASK_SIZE_BYTES = 1024L * 1024L * 1024L; // 1 GB + private final @Nullable Map hadoopConfig; + + public CreateReadTasksDoFn(@Nullable Map hadoopConfig) { + this.hadoopConfig = hadoopConfig; + } + + @ProcessElement + public void processElement(@Element String tablePath, OutputReceiver out) + throws Exception { + Configuration conf = new Configuration(); + if (hadoopConfig != null) { + for (Map.Entry entry : hadoopConfig.entrySet()) { + conf.set(entry.getKey(), entry.getValue()); + } + } + Engine engine = DefaultEngine.create(conf); + Table table = Table.forPath(engine, tablePath); + Snapshot snapshot = table.getLatestSnapshot(engine); + Scan scan = snapshot.getScanBuilder().build(); + Row scanState = scan.getScanState(engine); + SerializableRow serializableScanState = new SerializableRow(scanState); + + List currentGroup = new ArrayList<>(); + long currentGroupSize = 0L; + + try (CloseableIterator scanFiles = scan.getScanFiles(engine)) { + while (scanFiles.hasNext()) { + FilteredColumnarBatch batch = scanFiles.next(); + try (CloseableIterator rows = batch.getRows()) { + while (rows.hasNext()) { + Row scanFileRow = rows.next(); + SerializableRow fileRow = new SerializableRow(scanFileRow); + long fileSize = InternalScanFileUtils.getAddFileStatus(fileRow).getSize(); + + if (fileSize >= MAX_TASK_SIZE_BYTES) { + if (!currentGroup.isEmpty()) { + DeltaReadTask readTask = new DeltaReadTask(currentGroup, serializableScanState); + out.output(readTask); + currentGroup = new ArrayList<>(); + currentGroupSize = 0L; + } + + DeltaReadTask readTask = + new DeltaReadTask(Collections.singletonList(fileRow), serializableScanState); + out.output(readTask); + } else { + if (currentGroupSize + fileSize > MAX_TASK_SIZE_BYTES) { + DeltaReadTask readTask = new DeltaReadTask(currentGroup, serializableScanState); + out.output(readTask); + currentGroup = new ArrayList<>(); + currentGroup.add(fileRow); + 
currentGroupSize = fileSize; + } else { + currentGroup.add(fileRow); + currentGroupSize += fileSize; + } + } + } + } + } + } + + if (!currentGroup.isEmpty()) { + DeltaReadTask readTask = new DeltaReadTask(currentGroup, serializableScanState); + out.output(readTask); + } + } +} diff --git a/sdks/java/io/delta/src/main/java/org/apache/beam/sdk/io/delta/DeltaIO.java b/sdks/java/io/delta/src/main/java/org/apache/beam/sdk/io/delta/DeltaIO.java index 6c5df4728b4e..c511a7380dc1 100644 --- a/sdks/java/io/delta/src/main/java/org/apache/beam/sdk/io/delta/DeltaIO.java +++ b/sdks/java/io/delta/src/main/java/org/apache/beam/sdk/io/delta/DeltaIO.java @@ -18,12 +18,33 @@ package org.apache.beam.sdk.io.delta; import com.google.auto.value.AutoValue; +import io.delta.kernel.Table; +import io.delta.kernel.defaults.engine.DefaultEngine; +import io.delta.kernel.engine.Engine; +import io.delta.kernel.types.ArrayType; +import io.delta.kernel.types.BinaryType; +import io.delta.kernel.types.BooleanType; +import io.delta.kernel.types.DataType; +import io.delta.kernel.types.DateType; +import io.delta.kernel.types.DoubleType; +import io.delta.kernel.types.FloatType; +import io.delta.kernel.types.IntegerType; +import io.delta.kernel.types.LongType; +import io.delta.kernel.types.MapType; +import io.delta.kernel.types.StringType; +import io.delta.kernel.types.StructField; +import io.delta.kernel.types.StructType; +import io.delta.kernel.types.TimestampType; import java.util.Map; import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.values.PBegin; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.Row; +import org.apache.hadoop.conf.Configuration; import org.checkerframework.checker.nullness.qual.Nullable; /** @@ -83,9 +104,74 @@ public ReadRows withConfig(Map 
config) { @Override public PCollection expand(PBegin input) { - // TODO(https://github.com/apache/beam/issues/38551): Implement expansion for - // Delta Lake ReadRows - throw new UnsupportedOperationException("Not implemented yet."); + String path = getTablePath(); + if (path == null) { + throw new IllegalArgumentException("Table path must be set."); + } + + Configuration conf = new Configuration(); + Map hadoopConfig = getHadoopConfig(); + if (hadoopConfig != null) { + for (Map.Entry entry : hadoopConfig.entrySet()) { + conf.set(entry.getKey(), entry.getValue()); + } + } + Engine engine = DefaultEngine.create(conf); + Table table = Table.forPath(engine, path); + io.delta.kernel.Snapshot snapshot = table.getLatestSnapshot(engine); + StructType deltaSchema = snapshot.getSchema(); + if (deltaSchema == null) { + throw new IllegalStateException("Table schema is null."); + } + Schema beamSchema = convertToBeamSchema(deltaSchema); + + return input + .apply("Create Path", Create.of(path)) + .apply("Plan Files", ParDo.of(new CreateReadTasksDoFn(hadoopConfig))) + .apply("Read Logical Data", ParDo.of(new DeltaSourceDoFn(hadoopConfig))) + .setRowSchema(beamSchema); + } + + static Schema convertToBeamSchema(StructType deltaSchema) { + Schema.Builder builder = Schema.builder(); + for (StructField field : deltaSchema.fields()) { + builder.addField(field.getName(), convertToBeamFieldType(field.getDataType())); + } + return builder.build(); + } + + static Schema.FieldType convertToBeamFieldType(DataType deltaType) { + if (deltaType instanceof StringType) { + return Schema.FieldType.STRING; + } else if (deltaType instanceof IntegerType) { + return Schema.FieldType.INT32; + } else if (deltaType instanceof LongType) { + return Schema.FieldType.INT64; + } else if (deltaType instanceof FloatType) { + return Schema.FieldType.FLOAT; + } else if (deltaType instanceof DoubleType) { + return Schema.FieldType.DOUBLE; + } else if (deltaType instanceof BooleanType) { + return 
Schema.FieldType.BOOLEAN; + } else if (deltaType instanceof BinaryType) { + return Schema.FieldType.BYTES; + } else if (deltaType instanceof TimestampType) { + return Schema.FieldType.DATETIME; + } else if (deltaType instanceof DateType) { + return Schema.FieldType.DATETIME; + } else if (deltaType instanceof ArrayType) { + DataType elementType = ((ArrayType) deltaType).getElementType(); + return Schema.FieldType.iterable(convertToBeamFieldType(elementType)); + } else if (deltaType instanceof MapType) { + DataType keyType = ((MapType) deltaType).getKeyType(); + DataType valueType = ((MapType) deltaType).getValueType(); + return Schema.FieldType.map( + convertToBeamFieldType(keyType), convertToBeamFieldType(valueType)); + } else if (deltaType instanceof StructType) { + return Schema.FieldType.row(convertToBeamSchema((StructType) deltaType)); + } else { + throw new UnsupportedOperationException("Unsupported Delta type: " + deltaType.getClass()); + } } } } diff --git a/sdks/java/io/delta/src/main/java/org/apache/beam/sdk/io/delta/DeltaReadTask.java b/sdks/java/io/delta/src/main/java/org/apache/beam/sdk/io/delta/DeltaReadTask.java new file mode 100644 index 000000000000..817feed9d891 --- /dev/null +++ b/sdks/java/io/delta/src/main/java/org/apache/beam/sdk/io/delta/DeltaReadTask.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.delta; + +import java.io.Serializable; +import java.util.List; +import java.util.Objects; +import org.checkerframework.checker.nullness.qual.Nullable; + +/** + * A serializable task containing the necessary metadata to read a group of files in a Delta table. + * Packs both the {@code scanFileRows} (representing the physical files and deletion vectors) and + * {@code scanStateRow} (containing snapshot-level read schemas, configuration, and options). + */ +public class DeltaReadTask implements Serializable { + private static final long serialVersionUID = 1L; + + private final List scanFileRows; + private final SerializableRow scanStateRow; + + public DeltaReadTask(List scanFileRows, SerializableRow scanStateRow) { + this.scanFileRows = scanFileRows; + this.scanStateRow = scanStateRow; + } + + public List getScanFileRows() { + return scanFileRows; + } + + public SerializableRow getScanStateRow() { + return scanStateRow; + } + + @Override + public boolean equals(@Nullable Object o) { + if (this == o) { + return true; + } + if (!(o instanceof DeltaReadTask)) { + return false; + } + DeltaReadTask that = (DeltaReadTask) o; + return Objects.equals(scanFileRows, that.scanFileRows) + && Objects.equals(scanStateRow, that.scanStateRow); + } + + @Override + public int hashCode() { + return Objects.hash(scanFileRows, scanStateRow); + } + + @Override + public String toString() { + return "DeltaReadTask{" + + "scanFileRows=" + + scanFileRows + + ", scanStateRow=" + + scanStateRow + + '}'; + } +} diff --git 
a/sdks/java/io/delta/src/main/java/org/apache/beam/sdk/io/delta/DeltaReadTaskTracker.java b/sdks/java/io/delta/src/main/java/org/apache/beam/sdk/io/delta/DeltaReadTaskTracker.java new file mode 100644 index 000000000000..c81a2a231713 --- /dev/null +++ b/sdks/java/io/delta/src/main/java/org/apache/beam/sdk/io/delta/DeltaReadTaskTracker.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.delta; + +import java.util.List; +import org.apache.beam.sdk.io.range.OffsetRange; +import org.apache.beam.sdk.transforms.splittabledofn.OffsetRangeTracker; + +/** + * A {@link org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker} for tracking progress + * across Parquet row groups represented by a {@link DeltaReadTask}. + */ +public class DeltaReadTaskTracker extends OffsetRangeTracker { + private final List rowGroupSizes; + + public DeltaReadTaskTracker(OffsetRange restriction, List rowGroupSizes) { + super(restriction); + this.rowGroupSizes = rowGroupSizes; + } + + @Override + public Progress getProgress() { + long workCompleted = 0L; + long workRemaining = 0L; + long from = range.getFrom(); + long to = range.getTo(); + long attempted = lastAttemptedOffset == null ? 
(from - 1) : lastAttemptedOffset; + + for (int i = (int) from; i < (int) to; i++) { + // Upper bound of the range is the number of row groups. + if (i < rowGroupSizes.size()) { + if (i <= attempted) { + workCompleted += rowGroupSizes.get(i); + } else { + workRemaining += rowGroupSizes.get(i); + } + } + } + return Progress.from((double) workCompleted, (double) workRemaining); + } +} diff --git a/sdks/java/io/delta/src/main/java/org/apache/beam/sdk/io/delta/DeltaSourceDoFn.java b/sdks/java/io/delta/src/main/java/org/apache/beam/sdk/io/delta/DeltaSourceDoFn.java new file mode 100644 index 000000000000..9d3195462172 --- /dev/null +++ b/sdks/java/io/delta/src/main/java/org/apache/beam/sdk/io/delta/DeltaSourceDoFn.java @@ -0,0 +1,498 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.io.delta; + +import io.delta.kernel.Scan; +import io.delta.kernel.data.ArrayValue; +import io.delta.kernel.data.ColumnVector; +import io.delta.kernel.data.ColumnarBatch; +import io.delta.kernel.data.FilteredColumnarBatch; +import io.delta.kernel.data.MapValue; +import io.delta.kernel.defaults.engine.DefaultEngine; +import io.delta.kernel.engine.Engine; +import io.delta.kernel.engine.FileReadResult; +import io.delta.kernel.internal.InternalScanFileUtils; +import io.delta.kernel.internal.data.ScanStateRow; +import io.delta.kernel.internal.util.Utils; +import io.delta.kernel.types.ArrayType; +import io.delta.kernel.types.BinaryType; +import io.delta.kernel.types.BooleanType; +import io.delta.kernel.types.ByteType; +import io.delta.kernel.types.DataType; +import io.delta.kernel.types.DateType; +import io.delta.kernel.types.DoubleType; +import io.delta.kernel.types.FloatType; +import io.delta.kernel.types.IntegerType; +import io.delta.kernel.types.LongType; +import io.delta.kernel.types.MapType; +import io.delta.kernel.types.ShortType; +import io.delta.kernel.types.StringType; +import io.delta.kernel.types.StructField; +import io.delta.kernel.types.StructType; +import io.delta.kernel.types.TimestampType; +import io.delta.kernel.utils.CloseableIterator; +import io.delta.kernel.utils.FileStatus; +import java.math.BigDecimal; +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import org.apache.beam.sdk.io.range.OffsetRange; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker; +import org.apache.beam.sdk.values.Row; +import org.apache.hadoop.conf.Configuration; +import org.checkerframework.checker.nullness.qual.Nullable; + +/** + * A Splittable DoFn that processes {@link DeltaReadTask} elements, performs logical reads, and + * supports dynamic 
work rebalancing.
+ */
+@DoFn.BoundedPerElement
+class DeltaSourceDoFn extends DoFn<DeltaReadTask, Row> {
+  // Hadoop settings applied on top of a default Configuration; null means none were supplied.
+  @Nullable Map<String, String> hadoopConfig;
+  // Assigned in setUp(); transient, so it is not carried along when the DoFn is serialized.
+  private transient @Nullable Engine engine;
+  // Built lazily by getConfiguration().
+  private transient @Nullable Configuration conf;
+
+  // Footer metadata of the most recently inspected task, kept so that the Parquet footers are
+  // not read again for every call made on behalf of the same task.
+  private transient @Nullable DeltaReadTask cachedTask;
+  private transient @Nullable List<Long> cachedRowGroupSizes;
+  private transient @Nullable List<Long> cachedBlockCountsPerFile;
+
+  /**
+   * Creates the DoFn.
+   *
+   * @param hadoopConfig Hadoop configuration entries to apply, or {@code null} for none
+   */
+  public DeltaSourceDoFn(@Nullable Map<String, String> hadoopConfig) {
+    this.hadoopConfig = hadoopConfig;
+  }
+
+  // Returns the Hadoop configuration, building it from `hadoopConfig` on first use.
+  private synchronized Configuration getConfiguration() {
+    Configuration localConf = conf;
+    if (localConf == null) {
+      localConf = new Configuration();
+      if (hadoopConfig != null) {
+        for (Map.Entry<String, String> entry : hadoopConfig.entrySet()) {
+          localConf.set(entry.getKey(), entry.getValue());
+        }
+      }
+      conf = localConf;
+    }
+    return localConf;
+  }
+
+  // Returns the cached per-file row group counts when they were computed for `task`, else null.
+  private synchronized @Nullable List<Long> getCachedBlockCounts(DeltaReadTask task) {
+    if (task.equals(cachedTask)) {
+      return cachedBlockCountsPerFile;
+    }
+    return null;
+  }
+
+  // Returns the sizes of the row groups for a given DeltaReadTask.
+  private synchronized List<Long> getRowGroupSizes(DeltaReadTask task) {
+    // Serve repeated requests for the same task from the cache instead of re-reading footers.
+    if (task.equals(cachedTask) && cachedRowGroupSizes != null) {
+      return cachedRowGroupSizes;
+    }
+
+    List<Long> sizes = new ArrayList<>();
+    List<Long> blockCounts = new ArrayList<>();
+    Configuration conf = getConfiguration();
+    for (SerializableRow scanFileRow : task.getScanFileRows()) {
+      String pathStr = InternalScanFileUtils.getAddFileStatus(scanFileRow).getPath();
+      try {
+        org.apache.hadoop.fs.Path hadoopPath = new org.apache.hadoop.fs.Path(pathStr);
+        org.apache.parquet.hadoop.metadata.ParquetMetadata metadata =
+            org.apache.parquet.hadoop.ParquetFileReader.readFooter(
+                conf,
+                hadoopPath,
+                org.apache.parquet.format.converter.ParquetMetadataConverter.NO_FILTER);
+        // One entry per row group, in file order, so that a list index is a global row group
+        // index across all files of the task.
+        long blocksInFile = metadata.getBlocks().size();
+        for (org.apache.parquet.hadoop.metadata.BlockMetaData block : metadata.getBlocks()) {
+          sizes.add(block.getTotalByteSize());
+        }
+        blockCounts.add(blocksInFile);
+      } catch (java.io.IOException e) {
+        throw new RuntimeException("Failed to read Parquet footer for " + pathStr, e);
+      }
+    }
+
+    cachedTask = task;
+    cachedRowGroupSizes = sizes;
+    cachedBlockCountsPerFile = blockCounts;
+    return sizes;
+  }
+
+  // The initial restriction covers every row group of every file in the task.
+  @GetInitialRestriction
+  public OffsetRange getInitialRestriction(@Element DeltaReadTask task) {
+    List<Long> rowGroupSizes = getRowGroupSizes(task);
+    // Note that we use the number of row groups, `rowGroupSizes.size()`, here
+    // as the upper bound, not the byte size of the row groups.
+    return new OffsetRange(0L, rowGroupSizes.size());
+  }
+
+  // The tracker receives the row group sizes so that it can weight progress by size.
+  @NewTracker
+  public DeltaReadTaskTracker newTracker(
+      @Restriction OffsetRange restriction, @Element DeltaReadTask task) {
+    return new DeltaReadTaskTracker(restriction, getRowGroupSizes(task));
+  }
+
+  // Creates the Delta Kernel engine used by processElement.
+  @Setup
+  public void setUp() {
+    engine = DefaultEngine.create(getConfiguration());
+  }
+
+  /**
+   * Reads the row groups of {@code task} that fall inside the tracker's restriction and emits
+   * every logical row as a Beam {@link Row}.
+   *
+   * @param task the files to read together with the scan state they belong to
+   * @param tracker the tracker for the range of row-group indices to process
+   * @param out receives the converted rows
+   * @return always {@link ProcessContinuation#stop()}
+   * @throws IllegalStateException if the engine has not been created by {@link #setUp()}
+   * @throws Exception if reading or converting the data fails
+   */
+  @ProcessElement
+  public ProcessContinuation processElement(
+      @Element DeltaReadTask task,
+      RestrictionTracker<OffsetRange, Long> tracker,
+      OutputReceiver<Row> out)
+      throws Exception {
+
+    SerializableRow scanStateRow = task.getScanStateRow();
+    StructType physicalSchema = ScanStateRow.getPhysicalDataReadSchema(scanStateRow);
+    StructType logicalSchema = ScanStateRow.getLogicalSchema(scanStateRow);
+    Schema beamSchema = DeltaIO.ReadRows.convertToBeamSchema(logicalSchema);
+
+    Engine currentEngine = engine;
+    if (currentEngine == null) {
+      // A missing engine is a lifecycle problem (setUp has not run), not a bad argument.
+      throw new IllegalStateException("Expected the engine to not be null");
+    }
+
+    // `BeamParquetHandler` takes a reference to the `RestrictionTracker` so that it
+    // can perform `getFrom`, `getTo`, `tryClaim` requests to return the correct set
+    // of row groups that map to the current restriction.
+    BeamParquetHandler parquetHandler =
+        new BeamParquetHandler(getConfiguration(), currentEngine.getParquetHandler(), tracker);
+    BeamEngine beamEngine = new BeamEngine(currentEngine, parquetHandler);
+
+    // Global index of the first row group of the file that is currently being read.
+    long currentStartRgIndex = 0L;
+
+    // We have to go through the files in the `DeltaReadTask` in order so that the
+    // `RestrictionTracker` can correctly handle the range of the current split.
+    List<Long> cachedBlockCounts = getCachedBlockCounts(task);
+    List<SerializableRow> scanFileRows = task.getScanFileRows();
+    for (int i = 0; i < scanFileRows.size(); i++) {
+      SerializableRow scanFileRow = scanFileRows.get(i);
+      FileStatus fileStatus = InternalScanFileUtils.getAddFileStatus(scanFileRow);
+
+      // Number of row groups in this file: taken from the cache when it was filled for this
+      // task, otherwise read from the file's footer.
+      long fileBlocks;
+      if (cachedBlockCounts != null && i < cachedBlockCounts.size()) {
+        fileBlocks = cachedBlockCounts.get(i);
+      } else {
+        org.apache.hadoop.fs.Path hadoopPath = new org.apache.hadoop.fs.Path(fileStatus.getPath());
+        org.apache.parquet.hadoop.metadata.ParquetMetadata metadata =
+            org.apache.parquet.hadoop.ParquetFileReader.readFooter(
+                getConfiguration(),
+                hadoopPath,
+                org.apache.parquet.format.converter.ParquetMetadataConverter.NO_FILTER);
+        fileBlocks = metadata.getBlocks().size();
+      }
+
+      try (CloseableIterator<FileReadResult> fileReadResults =
+          parquetHandler.readParquetFiles(
+              Utils.singletonCloseableIterator(fileStatus),
+              physicalSchema,
+              Optional.empty(),
+              currentStartRgIndex)) {
+
+        // Get the correct set of physical data for the current file that are within the
+        // range for the current `RestrictionTracker`.
+        CloseableIterator<ColumnarBatch> physicalData =
+            new CloseableIterator<ColumnarBatch>() {
+              // Nothing to release here: `fileReadResults` is closed by the enclosing
+              // try-with-resources.
+              @Override
+              public void close() throws java.io.IOException {}
+
+              @Override
+              public boolean hasNext() {
+                return fileReadResults.hasNext();
+              }
+
+              @Override
+              public ColumnarBatch next() {
+                return fileReadResults.next().getData();
+              }
+            };
+
+        // Convert physical data to logical data.
+        try (CloseableIterator<FilteredColumnarBatch> logicalBatches =
+            Scan.transformPhysicalData(beamEngine, scanStateRow, scanFileRow, physicalData)) {
+
+          while (logicalBatches.hasNext()) {
+            FilteredColumnarBatch batch = logicalBatches.next();
+            try (CloseableIterator<io.delta.kernel.data.Row> logicalRows = batch.getRows()) {
+              while (logicalRows.hasNext()) {
+                io.delta.kernel.data.Row deltaRow = logicalRows.next();
+                Row beamRow = toBeamRow(deltaRow, beamSchema);
+                out.output(beamRow);
+              }
+            }
+          }
+        }
+      }
+
+      // Advance the total number of blocks handled so far.
+      currentStartRgIndex += fileBlocks;
+
+      // If the tryClaim failed during processing of the current file, there's no need
+      // to look at the rest of the files within the task.
+      if (parquetHandler.hasClaimFailed()) {
+        break;
+      }
+    }
+    return ProcessContinuation.stop();
+  }
+
+  // Convert Delta `Row` to Beam `Row`.
+  private static Row toBeamRow(io.delta.kernel.data.Row deltaRow, Schema beamSchema) {
+    Row.Builder builder = Row.withSchema(beamSchema);
+    StructType deltaSchema = deltaRow.getSchema();
+    List<StructField> fields = deltaSchema.fields();
+    // Values are added positionally, in the field order of the Delta schema.
+    for (int i = 0; i < fields.size(); i++) {
+      StructField field = fields.get(i);
+      builder.addValue(getFieldValue(deltaRow, i, field.getDataType()));
+    }
+    return builder.build();
+  }
+
+  // Returns the value at a specific index in a given row.
+  private static @Nullable Object getFieldValue(
+      io.delta.kernel.data.Row row, int index, DataType type) {
+    if (row.isNullAt(index)) {
+      return null;
+    }
+    if (type instanceof BooleanType) {
+      return row.getBoolean(index);
+    } else if (type instanceof ByteType) {
+      return (int) row.getByte(index);
+    } else if (type instanceof ShortType) {
+      return (int) row.getShort(index);
+    } else if (type instanceof IntegerType) {
+      return row.getInt(index);
+    } else if (type instanceof LongType) {
+      return row.getLong(index);
+    } else if (type instanceof FloatType) {
+      return row.getFloat(index);
+    } else if (type instanceof DoubleType) {
+      return row.getDouble(index);
+    } else if (type instanceof StringType) {
+      return row.getString(index);
+    } else if (type instanceof BinaryType) {
+      return row.getBinary(index);
+    } else if (type instanceof TimestampType) {
+      long microSeconds = row.getLong(index);
+      // Floor instead of truncating, so that negative (pre-epoch) values are not rounded
+      // towards the epoch.
+      return new org.joda.time.Instant(Math.floorDiv(microSeconds, 1000L));
+    } else if (type instanceof DateType) {
+      // Convert days since epoch to milliseconds since epoch.
+      int daysSinceEpoch = row.getInt(index);
+      return new org.joda.time.Instant(daysSinceEpoch * 86400000L);
+    } else if (type instanceof ArrayType) {
+      ArrayValue arrayValue = row.getArray(index);
+      int size = arrayValue.getSize();
+      ColumnVector elements = arrayValue.getElements();
+      DataType elementType = ((ArrayType) type).getElementType();
+      List<@Nullable Object> list = new ArrayList<>(size);
+      for (int i = 0; i < size; i++) {
+        list.add(getVectorValue(elements, i, elementType));
+      }
+      return list;
+    } else if (type instanceof MapType) {
+      MapValue mapValue = row.getMap(index);
+      int size = mapValue.getSize();
+      ColumnVector keys = mapValue.getKeys();
+      ColumnVector values = mapValue.getValues();
+      DataType keyType = ((MapType) type).getKeyType();
+      DataType valueType = ((MapType) type).getValueType();
+      Map<Object, @Nullable Object> map = new LinkedHashMap<>(size);
+      for (int i = 0; i < size; i++) {
+        // Entries whose key is null are skipped.
+        Object key = getVectorValue(keys, i, keyType);
+        if (key != null) {
+          map.put(key, getVectorValue(values, i, valueType));
+        }
+      }
+      return map;
+    } else if (type instanceof StructType) {
+      io.delta.kernel.data.Row nestedRow = row.getStruct(index);
+      Schema nestedBeamSchema = DeltaIO.ReadRows.convertToBeamSchema((StructType) type);
+      return toBeamRow(nestedRow, nestedBeamSchema);
+    }
+    throw new UnsupportedOperationException("Unsupported type: " + type.getClass());
+  }
+
+  // Returns the value at a specific index in a given column vector.
+  private static @Nullable Object getVectorValue(ColumnVector vector, int index, DataType type) {
+    if (vector.isNullAt(index)) {
+      return null;
+    }
+    if (type instanceof BooleanType) {
+      return vector.getBoolean(index);
+    } else if (type instanceof ByteType) {
+      return (int) vector.getByte(index);
+    } else if (type instanceof ShortType) {
+      return (int) vector.getShort(index);
+    } else if (type instanceof IntegerType) {
+      return vector.getInt(index);
+    } else if (type instanceof LongType) {
+      return vector.getLong(index);
+    } else if (type instanceof FloatType) {
+      return vector.getFloat(index);
+    } else if (type instanceof DoubleType) {
+      return vector.getDouble(index);
+    } else if (type instanceof StringType) {
+      return vector.getString(index);
+    } else if (type instanceof BinaryType) {
+      return vector.getBinary(index);
+    } else if (type instanceof TimestampType) {
+      long microSeconds = vector.getLong(index);
+      // Floor instead of truncating, so that negative (pre-epoch) values are not rounded
+      // towards the epoch.
+      return new org.joda.time.Instant(Math.floorDiv(microSeconds, 1000L));
+    } else if (type instanceof DateType) {
+      // Convert days since epoch to milliseconds since epoch.
+      int daysSinceEpoch = vector.getInt(index);
+      return new org.joda.time.Instant(daysSinceEpoch * 24L * 60L * 60L * 1000L);
+    } else if (type instanceof ArrayType) {
+      ArrayValue arrayValue = vector.getArray(index);
+      int size = arrayValue.getSize();
+      ColumnVector elements = arrayValue.getElements();
+      DataType elementType = ((ArrayType) type).getElementType();
+      List<@Nullable Object> list = new ArrayList<>(size);
+      for (int i = 0; i < size; i++) {
+        list.add(getVectorValue(elements, i, elementType));
+      }
+      return list;
+    } else if (type instanceof MapType) {
+      MapValue mapValue = vector.getMap(index);
+      int size = mapValue.getSize();
+      ColumnVector keys = mapValue.getKeys();
+      ColumnVector values = mapValue.getValues();
+      DataType keyType = ((MapType) type).getKeyType();
+      DataType valueType = ((MapType) type).getValueType();
+      Map<Object, @Nullable Object> map = new LinkedHashMap<>(size);
+      for (int i = 0; i < size; i++) {
+        // Entries whose key is null are skipped.
+        Object key = getVectorValue(keys, i, keyType);
+        if (key != null) {
+          map.put(key, getVectorValue(values, i, valueType));
+        }
+      }
+      return map;
+    } else if (type instanceof StructType) {
+      // View the struct's child vectors as a row positioned at `index`, then convert that row.
+      StructType structType = (StructType) type;
+      int numFields = structType.fields().size();
+      ColumnVector[] childVectors = new ColumnVector[numFields];
+      for (int i = 0; i < numFields; i++) {
+        childVectors[i] = vector.getChild(i);
+      }
+      io.delta.kernel.data.Row nestedRow = new VectorRow(structType, childVectors, index);
+      Schema nestedBeamSchema = DeltaIO.ReadRows.convertToBeamSchema(structType);
+      return toBeamRow(nestedRow, nestedBeamSchema);
+    }
+    throw new UnsupportedOperationException("Unsupported vector type: " + type.getClass());
+  }
+
+  // A Delta Row implementation used to efficiently convert columnar data from Delta Lake
+  // to Beam Rows. It keeps references to the column vectors, without creating
+  // additional memory copies of the fields, and returns the values at one specific
+  // row index.
+  private static class VectorRow implements io.delta.kernel.data.Row {
+    private final StructType rowType;
+    private final ColumnVector[] columns;
+    private final int position;
+
+    VectorRow(StructType schema, ColumnVector[] fields, int rowIndex) {
+      this.rowType = schema;
+      this.columns = fields;
+      this.position = rowIndex;
+    }
+
+    // Returns the column vector that backs the field at `ord`.
+    private ColumnVector column(int ord) {
+      return columns[ord];
+    }
+
+    @Override
+    public StructType getSchema() {
+      return rowType;
+    }
+
+    // Every accessor below reads the field's column vector at this row's position.
+
+    @Override
+    public boolean isNullAt(int ord) {
+      return column(ord).isNullAt(position);
+    }
+
+    @Override
+    public boolean getBoolean(int ord) {
+      return column(ord).getBoolean(position);
+    }
+
+    @Override
+    public byte getByte(int ord) {
+      return column(ord).getByte(position);
+    }
+
+    @Override
+    public short getShort(int ord) {
+      return column(ord).getShort(position);
+    }
+
+    @Override
+    public int getInt(int ord) {
+      return column(ord).getInt(position);
+    }
+
+    @Override
+    public long getLong(int ord) {
+      return column(ord).getLong(position);
+    }
+
+    @Override
+    public float getFloat(int ord) {
+      return column(ord).getFloat(position);
+    }
+
+    @Override
+    public double getDouble(int ord) {
+      return column(ord).getDouble(position);
+    }
+
+    @Override
+    public String getString(int ord) {
+      return column(ord).getString(position);
+    }
+
+    @Override
+    public byte[] getBinary(int ord) {
+      return column(ord).getBinary(position);
+    }
+
+    @Override
+    public BigDecimal getDecimal(int ord) {
+      return column(ord).getDecimal(position);
+    }
+
+    @Override
+    public io.delta.kernel.data.Row getStruct(int ord) {
+      // The nested row views the struct column's children at the same row position.
+      StructType nestedType = (StructType) rowType.fields().get(ord).getDataType();
+      ColumnVector[] nestedColumns = new ColumnVector[nestedType.fields().size()];
+      for (int child = 0; child < nestedColumns.length; child++) {
+        nestedColumns[child] = column(ord).getChild(child);
+      }
+      return new VectorRow(nestedType, nestedColumns, position);
+    }
+
+    @Override
+    public ArrayValue getArray(int ord) {
+      return column(ord).getArray(position);
+    }
+
+    @Override
+    public MapValue getMap(int ord) {
+      return column(ord).getMap(position);
+    }
+  }
+}
diff --git a/sdks/java/io/delta/src/main/java/org/apache/beam/sdk/io/delta/SerializableRow.java b/sdks/java/io/delta/src/main/java/org/apache/beam/sdk/io/delta/SerializableRow.java
new file mode 100644
index 000000000000..3ce3f1f5ee0f
--- /dev/null
+++ b/sdks/java/io/delta/src/main/java/org/apache/beam/sdk/io/delta/SerializableRow.java
@@ -0,0 +1,534 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package org.apache.beam.sdk.io.delta; + +import io.delta.kernel.data.ArrayValue; +import io.delta.kernel.data.ColumnVector; +import io.delta.kernel.data.MapValue; +import io.delta.kernel.data.Row; +import io.delta.kernel.types.ArrayType; +import io.delta.kernel.types.BinaryType; +import io.delta.kernel.types.BooleanType; +import io.delta.kernel.types.ByteType; +import io.delta.kernel.types.DataType; +import io.delta.kernel.types.DecimalType; +import io.delta.kernel.types.DoubleType; +import io.delta.kernel.types.FloatType; +import io.delta.kernel.types.IntegerType; +import io.delta.kernel.types.LongType; +import io.delta.kernel.types.MapType; +import io.delta.kernel.types.ShortType; +import io.delta.kernel.types.StringType; +import io.delta.kernel.types.StructType; +import java.io.Serializable; +import java.math.BigDecimal; +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import org.checkerframework.checker.nullness.qual.Nullable; + +/** + * A serializable wrapper for Delta {@link Row} that implements the {@link Row} interface itself, + * allowing worker nodes to access serialized Row objects using standard Delta Kernel APIs. 
+ */
+public class SerializableRow implements Row, Serializable {
+  private static final long serialVersionUID = 1L;
+
+  private final SerializableStructType schema;
+  // One entry per field of the schema. Nested structs are stored as SerializableRow, arrays as
+  // List and maps as Map; a null entry means the field is null.
+  private final @Nullable Object[] values;
+
+  /**
+   * Copies every field of {@code row} into a serializable snapshot.
+   *
+   * @param row the row to copy
+   * @throws IllegalArgumentException if a field has a data type that is not supported
+   */
+  public SerializableRow(Row row) {
+    this.schema = new SerializableStructType(row.getSchema());
+    StructType structType = row.getSchema();
+    int numFields = structType.fields().size();
+    this.values = new Object[numFields];
+    for (int i = 0; i < numFields; i++) {
+      DataType type = structType.fields().get(i).getDataType();
+      this.values[i] = getValue(row, i, type);
+    }
+  }
+
+  @Override
+  public StructType getSchema() {
+    return schema.get();
+  }
+
+  @Override
+  public boolean isNullAt(int ord) {
+    return values == null || values[ord] == null;
+  }
+
+  // The primitive getters below return the type's zero value when the stored field is null.
+
+  @Override
+  public boolean getBoolean(int ord) {
+    Object val = Objects.requireNonNull(values)[ord];
+    return val != null ? (Boolean) val : false;
+  }
+
+  @Override
+  public byte getByte(int ord) {
+    Object val = Objects.requireNonNull(values)[ord];
+    return val != null ? (Byte) val : 0;
+  }
+
+  @Override
+  public short getShort(int ord) {
+    Object val = Objects.requireNonNull(values)[ord];
+    return val != null ? (Short) val : 0;
+  }
+
+  @Override
+  public int getInt(int ord) {
+    Object val = Objects.requireNonNull(values)[ord];
+    return val != null ? (Integer) val : 0;
+  }
+
+  @Override
+  public long getLong(int ord) {
+    Object val = Objects.requireNonNull(values)[ord];
+    return val != null ? (Long) val : 0L;
+  }
+
+  @Override
+  public float getFloat(int ord) {
+    Object val = Objects.requireNonNull(values)[ord];
+    return val != null ? (Float) val : 0.0f;
+  }
+
+  @Override
+  public double getDouble(int ord) {
+    Object val = Objects.requireNonNull(values)[ord];
+    return val != null ? (Double) val : 0.0d;
+  }
+
+  // The object getters below return null when the stored field is null.
+
+  @Override
+  @SuppressWarnings("nullness")
+  public String getString(int ord) {
+    return (String) Objects.requireNonNull(values)[ord];
+  }
+
+  @Override
+  @SuppressWarnings("nullness")
+  public byte[] getBinary(int ord) {
+    return (byte[]) Objects.requireNonNull(values)[ord];
+  }
+
+  @Override
+  @SuppressWarnings("nullness")
+  public BigDecimal getDecimal(int ord) {
+    return (BigDecimal) Objects.requireNonNull(values)[ord];
+  }
+
+  @Override
+  @SuppressWarnings("nullness")
+  public Row getStruct(int ord) {
+    return (Row) Objects.requireNonNull(values)[ord];
+  }
+
+  @Override
+  @SuppressWarnings({"unchecked", "nullness"})
+  public ArrayValue getArray(int ord) {
+    Object val = Objects.requireNonNull(values)[ord];
+    if (val == null) {
+      return null;
+    }
+    DataType elementType =
+        ((ArrayType) getSchema().fields().get(ord).getDataType()).getElementType();
+    return new SerializableArrayValue((List<@Nullable Object>) val, elementType);
+  }
+
+  @Override
+  @SuppressWarnings({"unchecked", "nullness"})
+  public MapValue getMap(int ord) {
+    Object val = Objects.requireNonNull(values)[ord];
+    if (val == null) {
+      return null;
+    }
+    MapType mapType = (MapType) getSchema().fields().get(ord).getDataType();
+    return new SerializableMapValue(
+        (Map<Object, @Nullable Object>) val, mapType.getKeyType(), mapType.getValueType());
+  }
+
+  @Override
+  public boolean equals(@Nullable Object o) {
+    if (this == o) {
+      return true;
+    }
+    if (!(o instanceof SerializableRow)) {
+      return false;
+    }
+    SerializableRow that = (SerializableRow) o;
+    // deepEquals compares array-valued fields such as byte[] by content.
+    return Objects.equals(schema, that.schema) && java.util.Arrays.deepEquals(values, that.values);
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hash(schema, java.util.Arrays.deepHashCode(values));
+  }
+
+  @Override
+  public String toString() {
+    StringBuilder sb = new StringBuilder();
+    sb.append("SerializableRow{schema=").append(schema).append(", values=[");
+    if (values != null) {
+      for (int i = 0; i < values.length; i++) {
+        if (i > 0) {
+          sb.append(", ");
+        }
+        sb.append(values[i]);
+      }
+    }
+    sb.append("]}");
+    return sb.toString();
+  }
+
+  // Reads the field at `index` of `row` into its serializable form, or null if it is null.
+  private static @Nullable Object getValue(Row row, int index, DataType type) {
+    if (row.isNullAt(index)) {
+      return null;
+    }
+    if (type instanceof BooleanType) {
+      return row.getBoolean(index);
+    } else if (type instanceof ByteType) {
+      return row.getByte(index);
+    } else if (type instanceof ShortType) {
+      return row.getShort(index);
+    } else if (type instanceof IntegerType) {
+      return row.getInt(index);
+    } else if (type instanceof LongType) {
+      return row.getLong(index);
+    } else if (type instanceof FloatType) {
+      return row.getFloat(index);
+    } else if (type instanceof DoubleType) {
+      return row.getDouble(index);
+    } else if (type instanceof StringType) {
+      return row.getString(index);
+    } else if (type instanceof BinaryType) {
+      return row.getBinary(index);
+    } else if (type instanceof DecimalType) {
+      return row.getDecimal(index);
+    } else if (type instanceof StructType) {
+      return new SerializableRow(row.getStruct(index));
+    } else if (type instanceof ArrayType) {
+      ArrayValue arr = row.getArray(index);
+      return convertArray(arr, (ArrayType) type);
+    } else if (type instanceof MapType) {
+      MapValue map = row.getMap(index);
+      return convertMap(map, (MapType) type);
+    }
+    throw new IllegalArgumentException("Unsupported type: " + type);
+  }
+
+  // Reads the entry at `index` of `vector` into its serializable form, or null if it is null.
+  private static @Nullable Object getVectorValue(ColumnVector vector, int index, DataType type) {
+    if (vector.isNullAt(index)) {
+      return null;
+    }
+    if (type instanceof BooleanType) {
+      return vector.getBoolean(index);
+    } else if (type instanceof ByteType) {
+      return vector.getByte(index);
+    } else if (type instanceof ShortType) {
+      return vector.getShort(index);
+    } else if (type instanceof IntegerType) {
+      return vector.getInt(index);
+    } else if (type instanceof LongType) {
+      return vector.getLong(index);
+    } else if (type instanceof FloatType) {
+      return vector.getFloat(index);
+    } else if (type instanceof DoubleType) {
+      return vector.getDouble(index);
+    } else if (type instanceof StringType) {
+      return vector.getString(index);
+    } else if (type instanceof BinaryType) {
+      return vector.getBinary(index);
+    } else if (type instanceof DecimalType) {
+      return vector.getDecimal(index);
+    } else if (type instanceof StructType) {
+      // View the struct's child vectors as a row positioned at `index`, then snapshot that row.
+      StructType structType = (StructType) type;
+      int numFields = structType.fields().size();
+      ColumnVector[] childFields = new ColumnVector[numFields];
+      for (int j = 0; j < numFields; j++) {
+        childFields[j] = vector.getChild(j);
+      }
+      return new SerializableRow(new VectorRow(structType, childFields, index));
+    } else if (type instanceof ArrayType) {
+      ArrayValue arr = vector.getArray(index);
+      return convertArray(arr, (ArrayType) type);
+    } else if (type instanceof MapType) {
+      MapValue map = vector.getMap(index);
+      return convertMap(map, (MapType) type);
+    }
+    throw new IllegalArgumentException("Unsupported vector type: " + type);
+  }
+
+  // Copies the elements of `arr` into a list of serializable values.
+  private static List<@Nullable Object> convertArray(ArrayValue arr, ArrayType arrayType) {
+    int size = arr.getSize();
+    ColumnVector elements = arr.getElements();
+    DataType elementType = arrayType.getElementType();
+    List<@Nullable Object> result = new ArrayList<>(size);
+    for (int i = 0; i < size; i++) {
+      result.add(getVectorValue(elements, i, elementType));
+    }
+    return result;
+  }
+
+  // Copies the entries of `map` into an insertion-ordered map of serializable values. Entries
+  // whose key is null are skipped.
+  private static Map<Object, @Nullable Object> convertMap(MapValue map, MapType mapType) {
+    int size = map.getSize();
+    ColumnVector keys = map.getKeys();
+    ColumnVector values = map.getValues();
+    DataType keyType = mapType.getKeyType();
+    DataType valueType = mapType.getValueType();
+    Map<Object, @Nullable Object> result = new LinkedHashMap<>(size);
+    for (int i = 0; i < size; i++) {
+      Object key = getVectorValue(keys, i, keyType);
+      if (key != null) {
+        result.put(key, getVectorValue(values, i, valueType));
+      }
+    }
+    return result;
+  }
+
+  // A Row view over one position of a set of column vectors; every getter reads the field's
+  // vector at `rowIndex`.
+  private static class VectorRow implements Row {
+    private final StructType schema;
+    private final ColumnVector[] fields;
+    private final int rowIndex;
+
+    VectorRow(StructType schema, ColumnVector[] fields, int rowIndex) {
+      this.schema = schema;
+      this.fields = fields;
+      this.rowIndex = rowIndex;
+    }
+
+    @Override
+    public StructType getSchema() {
+      return schema;
+    }
+
+    @Override
+    public boolean isNullAt(int ord) {
+      return fields[ord].isNullAt(rowIndex);
+    }
+
+    @Override
+    public boolean getBoolean(int ord) {
+      return fields[ord].getBoolean(rowIndex);
+    }
+
+    @Override
+    public byte getByte(int ord) {
+      return fields[ord].getByte(rowIndex);
+    }
+
+    @Override
+    public short getShort(int ord) {
+      return fields[ord].getShort(rowIndex);
+    }
+
+    @Override
+    public int getInt(int ord) {
+      return fields[ord].getInt(rowIndex);
+    }
+
+    @Override
+    public long getLong(int ord) {
+      return fields[ord].getLong(rowIndex);
+    }
+
+    @Override
+    public float getFloat(int ord) {
+      return fields[ord].getFloat(rowIndex);
+    }
+
+    @Override
+    public double getDouble(int ord) {
+      return fields[ord].getDouble(rowIndex);
+    }
+
+    @Override
+    public String getString(int ord) {
+      return fields[ord].getString(rowIndex);
+    }
+
+    @Override
+    public byte[] getBinary(int ord) {
+      return fields[ord].getBinary(rowIndex);
+    }
+
+    @Override
+    public BigDecimal getDecimal(int ord) {
+      return fields[ord].getDecimal(rowIndex);
+    }
+
+    @Override
+    public Row getStruct(int ord) {
+      StructType childSchema = (StructType) schema.fields().get(ord).getDataType();
+      int numFields = childSchema.fields().size();
+      ColumnVector[] childFields = new ColumnVector[numFields];
+      for (int j = 0; j < numFields; j++) {
+        childFields[j] = fields[ord].getChild(j);
+      }
+      return new VectorRow(childSchema, childFields, rowIndex);
+    }
+
+    @Override
+    public ArrayValue getArray(int ord) {
+      return fields[ord].getArray(rowIndex);
+    }
+
+    @Override
+    public MapValue getMap(int ord) {
+      return fields[ord].getMap(rowIndex);
+    }
+  }
+
+  // Exposes a list of already converted values (as produced by getValue and getVectorValue)
+  // through the ColumnVector interface.
+  private static class ListColumnVector implements ColumnVector {
+    private final DataType dataType;
+    private final List<@Nullable Object> list;
+
+    @SuppressWarnings("unchecked")
+    ListColumnVector(DataType dataType, List<?> list) {
+      this.dataType = dataType;
+      this.list = (List<@Nullable Object>) list;
+    }
+
+    @Override
+    public DataType getDataType() {
+      return dataType;
+    }
+
+    @Override
+    public int getSize() {
+      return list.size();
+    }
+
+    @Override
+    public boolean isNullAt(int rowId) {
+      return list.get(rowId) == null;
+    }
+
+    @Override
+    public boolean getBoolean(int rowId) {
+      Object val = list.get(rowId);
+      return val != null ? (Boolean) val : false;
+    }
+
+    @Override
+    public byte getByte(int rowId) {
+      Object val = list.get(rowId);
+      return val != null ? (Byte) val : 0;
+    }
+
+    @Override
+    public short getShort(int rowId) {
+      Object val = list.get(rowId);
+      return val != null ? (Short) val : 0;
+    }
+
+    @Override
+    public int getInt(int rowId) {
+      Object val = list.get(rowId);
+      return val != null ? (Integer) val : 0;
+    }
+
+    @Override
+    public long getLong(int rowId) {
+      Object val = list.get(rowId);
+      return val != null ? (Long) val : 0L;
+    }
+
+    @Override
+    public float getFloat(int rowId) {
+      Object val = list.get(rowId);
+      return val != null ? (Float) val : 0.0f;
+    }
+
+    @Override
+    public double getDouble(int rowId) {
+      Object val = list.get(rowId);
+      return val != null ? (Double) val : 0.0d;
+    }
+
+    @Override
+    @SuppressWarnings("nullness")
+    public String getString(int rowId) {
+      return (String) list.get(rowId);
+    }
+
+    @Override
+    @SuppressWarnings("nullness")
+    public byte[] getBinary(int rowId) {
+      return (byte[]) list.get(rowId);
+    }
+
+    @Override
+    @SuppressWarnings("nullness")
+    public BigDecimal getDecimal(int rowId) {
+      return (BigDecimal) list.get(rowId);
+    }
+
+    // Nested arrays are stored as lists (see convertArray), so wrap the stored list again.
+    @Override
+    @SuppressWarnings("nullness")
+    public ArrayValue getArray(int rowId) {
+      Object val = list.get(rowId);
+      if (val == null) {
+        return null;
+      }
+      return new SerializableArrayValue((List<?>) val, ((ArrayType) dataType).getElementType());
+    }
+
+    // Nested maps are stored as maps (see convertMap), so wrap the stored map again.
+    @Override
+    @SuppressWarnings({"unchecked", "nullness"})
+    public MapValue getMap(int rowId) {
+      Object val = list.get(rowId);
+      if (val == null) {
+        return null;
+      }
+      MapType mapType = (MapType) dataType;
+      return new SerializableMapValue(
+          (Map<Object, @Nullable Object>) val, mapType.getKeyType(), mapType.getValueType());
+    }
+
+    // Nested structs are stored as SerializableRow, so the child vector for a field is that
+    // field's stored value taken from every row, with null standing in for null rows.
+    @Override
+    public ColumnVector getChild(int ordinal) {
+      DataType childType = ((StructType) dataType).fields().get(ordinal).getDataType();
+      List<@Nullable Object> childValues = new ArrayList<>(list.size());
+      for (Object element : list) {
+        childValues.add(element == null ? null : ((SerializableRow) element).values[ordinal]);
+      }
+      return new ListColumnVector(childType, childValues);
+    }
+
+    @Override
+    public void close() {}
+  }
+
+  // ArrayValue view over a list of already converted element values.
+  private static class SerializableArrayValue implements ArrayValue {
+    private final List<?> list;
+    private final DataType elementType;
+
+    SerializableArrayValue(List<?> list, DataType elementType) {
+      this.list = list;
+      this.elementType = elementType;
+    }
+
+    @Override
+    public int getSize() {
+      return list.size();
+    }
+
+    @Override
+    public ColumnVector getElements() {
+      return new ListColumnVector(elementType, list);
+    }
+  }
+
+  // MapValue view over a map of already converted keys and values. Keys and values are exposed
+  // in the map's iteration order, so position i of both vectors belongs to the same entry.
+  private static class SerializableMapValue implements MapValue {
+    private final Map<Object, @Nullable Object> map;
+    private final DataType keyType;
+    private final DataType valueType;
+
+    SerializableMapValue(
+        Map<Object, @Nullable Object> map, DataType keyType, DataType valueType) {
+      this.map = map;
+      this.keyType = keyType;
+      this.valueType = valueType;
+    }
+
+    @Override
+    public int getSize() {
+      return map.size();
+    }
+
+    @Override
+    public ColumnVector getKeys() {
+      return new ListColumnVector(keyType, new ArrayList<>(map.keySet()));
+    }
+
+    @Override
+    public ColumnVector getValues() {
+      return new ListColumnVector(valueType, new ArrayList<>(map.values()));
+    }
+  }
+}
diff --git a/sdks/java/io/delta/src/main/java/org/apache/beam/sdk/io/delta/SerializableStructType.java b/sdks/java/io/delta/src/main/java/org/apache/beam/sdk/io/delta/SerializableStructType.java
new file mode 100644
index 000000000000..08a27ac55303
--- /dev/null
+++ b/sdks/java/io/delta/src/main/java/org/apache/beam/sdk/io/delta/SerializableStructType.java
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license
agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.delta; + +import io.delta.kernel.internal.types.DataTypeJsonSerDe; +import io.delta.kernel.types.StructType; +import java.io.Serializable; +import java.util.Objects; +import org.checkerframework.checker.nullness.qual.Nullable; + +/** + * A serializable wrapper for {@link StructType} using the Delta Kernel JSON + * serializer/deserializer. 
+ */ +public class SerializableStructType implements Serializable { + private static final long serialVersionUID = 1L; + + private final String jsonSchema; + private transient StructType structType; + + public SerializableStructType(StructType structType) { + this.structType = Objects.requireNonNull(structType, "structType cannot be null"); + this.jsonSchema = DataTypeJsonSerDe.serializeStructType(structType); + } + + public StructType get() { + if (structType == null) { + structType = DataTypeJsonSerDe.deserializeStructType(jsonSchema); + } + return structType; + } + + @Override + public boolean equals(@Nullable Object o) { + if (this == o) { + return true; + } + if (!(o instanceof SerializableStructType)) { + return false; + } + SerializableStructType that = (SerializableStructType) o; + return Objects.equals(jsonSchema, that.jsonSchema); + } + + @Override + public int hashCode() { + return Objects.hash(jsonSchema); + } + + @Override + public String toString() { + return jsonSchema; + } +} diff --git a/sdks/java/io/delta/src/test/java/org/apache/beam/sdk/io/delta/DeltaIOTest.java b/sdks/java/io/delta/src/test/java/org/apache/beam/sdk/io/delta/DeltaIOTest.java index 4ab932fc5d83..c08a4c8ee17d 100644 --- a/sdks/java/io/delta/src/test/java/org/apache/beam/sdk/io/delta/DeltaIOTest.java +++ b/sdks/java/io/delta/src/test/java/org/apache/beam/sdk/io/delta/DeltaIOTest.java @@ -17,18 +17,57 @@ */ package org.apache.beam.sdk.io.delta; +import io.delta.kernel.types.ArrayType; +import io.delta.kernel.types.BinaryType; +import io.delta.kernel.types.BooleanType; +import io.delta.kernel.types.DateType; +import io.delta.kernel.types.DoubleType; +import io.delta.kernel.types.FloatType; +import io.delta.kernel.types.IntegerType; +import io.delta.kernel.types.LongType; +import io.delta.kernel.types.MapType; +import io.delta.kernel.types.StringType; +import io.delta.kernel.types.StructField; +import io.delta.kernel.types.StructType; +import io.delta.kernel.types.TimestampType; 
+import java.io.File; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; import java.util.HashMap; import java.util.Map; +import org.apache.avro.generic.GenericRecord; +import org.apache.beam.sdk.extensions.avro.coders.AvroCoder; +import org.apache.beam.sdk.extensions.avro.schemas.utils.AvroUtils; +import org.apache.beam.sdk.io.Compression; +import org.apache.beam.sdk.io.FileIO; import org.apache.beam.sdk.io.delta.DeltaIO.ReadRows; +import org.apache.beam.sdk.io.parquet.ParquetIO; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.Count; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.PaneInfo; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.Row; import org.junit.Assert; +import org.junit.Rule; import org.junit.Test; +import org.junit.rules.TemporaryFolder; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; -/** Unit tests for the {@link DeltaIO}. */ +/** Unit and local integration tests for {@link DeltaIO}. 
*/ @RunWith(JUnit4.class) public class DeltaIOTest { + @Rule public TestPipeline writePipeline = TestPipeline.create(); + @Rule public TestPipeline readPipeline = TestPipeline.create(); + @Rule public TestPipeline filteringPipeline = TestPipeline.create(); + @Rule public TemporaryFolder tempFolder = new TemporaryFolder(); + @Test public void testReadRowsBuilderAndGetters() { String tablePath = "/path/to/table"; @@ -59,4 +98,664 @@ public void testReadRowsNullDefaults() { Assert.assertNull(readRows.getTimestamp()); Assert.assertNull(readRows.getHadoopConfig()); } + + @Test + public void testPrintScanStateSchema() throws Exception { + File tableDir = tempFolder.newFolder("delta-table-schema"); + File logDir = new File(tableDir, "_delta_log"); + logDir.mkdirs(); + File commitFile = new File(logDir, "00000000000000000000.json"); + + String commitContent = + "{\"protocol\":{\"minReaderVersion\":1,\"minWriterVersion\":2}}\n" + + "{\"metaData\":{\"id\":\"test-id\",\"format\":{\"provider\":\"parquet\",\"options\":{}},\"schemaString\":\"{\\\"type\\\":\\\"struct\\\",\\\"fields\\\":[{\\\"name\\\":\\\"name\\\",\\\"type\\\":\\\"string\\\",\\\"nullable\\\":true,\\\"metadata\\\":{}}]}\",\"partitionColumns\":[],\"configuration\":{},\"createdAt\":123456789}}\n" + + "{\"add\":{\"path\":\"part-00000.parquet\",\"partitionValues\":{},\"size\":100,\"modificationTime\":123456789,\"dataChange\":true}}"; + + Files.write(commitFile.toPath(), commitContent.getBytes(StandardCharsets.UTF_8)); + + io.delta.kernel.defaults.engine.DefaultEngine engine = + io.delta.kernel.defaults.engine.DefaultEngine.create( + new org.apache.hadoop.conf.Configuration()); + io.delta.kernel.Table table = io.delta.kernel.Table.forPath(engine, tableDir.getAbsolutePath()); + io.delta.kernel.Snapshot snapshot = table.getLatestSnapshot(engine); + io.delta.kernel.Scan scan = snapshot.getScanBuilder().build(); + + io.delta.kernel.data.Row scanState = scan.getScanState(engine); + System.err.println("SCAN STATE SCHEMA: " + 
scanState.getSchema().toString()); + + try (io.delta.kernel.utils.CloseableIterator + scanFiles = scan.getScanFiles(engine)) { + while (scanFiles.hasNext()) { + io.delta.kernel.data.FilteredColumnarBatch batch = scanFiles.next(); + try (io.delta.kernel.utils.CloseableIterator rows = + batch.getRows()) { + while (rows.hasNext()) { + io.delta.kernel.data.Row row = rows.next(); + verifySerialization(row); + } + } + } + } + } + + private void verifySerialization(io.delta.kernel.data.Row row) throws Exception { + SerializableRow serializableRow = new SerializableRow(row); + + // Serialize using standard Java Serialization + java.io.ByteArrayOutputStream baos = new java.io.ByteArrayOutputStream(); + try (java.io.ObjectOutputStream oos = new java.io.ObjectOutputStream(baos)) { + oos.writeObject(serializableRow); + } + + byte[] bytes = baos.toByteArray(); + + // Deserialize + SerializableRow deserializedRow; + java.io.ByteArrayInputStream bais = new java.io.ByteArrayInputStream(bytes); + try (java.io.ObjectInputStream ois = new java.io.ObjectInputStream(bais)) { + deserializedRow = (SerializableRow) ois.readObject(); + } + + // Assert equals + org.junit.Assert.assertEquals(serializableRow, deserializedRow); + org.junit.Assert.assertEquals( + row.getSchema().toString(), deserializedRow.getSchema().toString()); + + // Deep verify fields + io.delta.kernel.types.StructType schema = row.getSchema(); + for (int i = 0; i < schema.fields().size(); i++) { + org.junit.Assert.assertEquals(row.isNullAt(i), deserializedRow.isNullAt(i)); + if (!row.isNullAt(i)) { + io.delta.kernel.types.DataType type = schema.fields().get(i).getDataType(); + if (type instanceof io.delta.kernel.types.StringType) { + org.junit.Assert.assertEquals(row.getString(i), deserializedRow.getString(i)); + } else if (type instanceof io.delta.kernel.types.LongType) { + org.junit.Assert.assertEquals(row.getLong(i), deserializedRow.getLong(i)); + } + } + } + } + + @Test + public void testCreateReadTasksDoFn() throws 
Exception { + File tableDir = tempFolder.newFolder("delta-table"); + File logDir = new File(tableDir, "_delta_log"); + logDir.mkdirs(); + File commitFile = new File(logDir, "00000000000000000000.json"); + + String commitContent = + "{\"protocol\":{\"minReaderVersion\":1,\"minWriterVersion\":2}}\n" + + "{\"metaData\":{\"id\":\"test-id\",\"format\":{\"provider\":\"parquet\",\"options\":{}},\"schemaString\":\"{\\\"type\\\":\\\"struct\\\",\\\"fields\\\":[{\\\"name\\\":\\\"name\\\",\\\"type\\\":\\\"string\\\",\\\"nullable\\\":true,\\\"metadata\\\":{}}]}\",\"partitionColumns\":[],\"configuration\":{},\"createdAt\":123456789}}\n" + + "{\"add\":{\"path\":\"part-00000.parquet\",\"partitionValues\":{},\"size\":100,\"modificationTime\":123456789,\"dataChange\":true}}"; + + Files.write(commitFile.toPath(), commitContent.getBytes(StandardCharsets.UTF_8)); + + PCollection output = + writePipeline + .apply(Create.of(tableDir.getAbsolutePath())) + .apply(ParDo.of(new CreateReadTasksDoFn(null))); + + PCollection paths = + output.apply( + org.apache.beam.sdk.transforms.MapElements.into( + org.apache.beam.sdk.values.TypeDescriptors.strings()) + .via( + task -> + io.delta.kernel.internal.InternalScanFileUtils.getAddFileStatus( + task.getScanFileRows().get(0)) + .getPath())); + + PAssert.that(paths) + .containsInAnyOrder("file:" + tableDir.getAbsolutePath() + "/part-00000.parquet"); + + writePipeline.run().waitUntilFinish(); + } + + @Test + public void testCreateReadTasksDoFnGrouping() throws Exception { + File tableDir = tempFolder.newFolder("delta-table-grouping"); + File logDir = new File(tableDir, "_delta_log"); + logDir.mkdirs(); + File commitFile = new File(logDir, "00000000000000000000.json"); + + String commitContent = + "{\"protocol\":{\"minReaderVersion\":1,\"minWriterVersion\":2}}\n" + + 
"{\"metaData\":{\"id\":\"test-id\",\"format\":{\"provider\":\"parquet\",\"options\":{}},\"schemaString\":\"{\\\"type\\\":\\\"struct\\\",\\\"fields\\\":[{\\\"name\\\":\\\"name\\\",\\\"type\\\":\\\"string\\\",\\\"nullable\\\":true,\\\"metadata\\\":{}}]}\",\"partitionColumns\":[],\"configuration\":{},\"createdAt\":123456789}}\n" + + "{\"add\":{\"path\":\"part-00001.parquet\",\"partitionValues\":{},\"size\":400000000,\"modificationTime\":123456789,\"dataChange\":true}}\n" + + "{\"add\":{\"path\":\"part-00002.parquet\",\"partitionValues\":{},\"size\":400000000,\"modificationTime\":123456789,\"dataChange\":true}}\n" + + "{\"add\":{\"path\":\"part-00003.parquet\",\"partitionValues\":{},\"size\":1200000000,\"modificationTime\":123456789,\"dataChange\":true}}\n" + + "{\"add\":{\"path\":\"part-00004.parquet\",\"partitionValues\":{},\"size\":100,\"modificationTime\":123456789,\"dataChange\":true}}"; + + Files.write(commitFile.toPath(), commitContent.getBytes(StandardCharsets.UTF_8)); + + PCollection output = + writePipeline + .apply("Create Grouping Input", Create.of(tableDir.getAbsolutePath())) + .apply("Plan Grouped Files", ParDo.of(new CreateReadTasksDoFn(null))); + + PCollection taskDescriptions = + output.apply( + org.apache.beam.sdk.transforms.MapElements.into( + org.apache.beam.sdk.values.TypeDescriptors.strings()) + .via( + task -> { + StringBuilder sb = new StringBuilder(); + for (SerializableRow row : task.getScanFileRows()) { + if (sb.length() > 0) { + sb.append(","); + } + String fullPath = + io.delta.kernel.internal.InternalScanFileUtils.getAddFileStatus(row) + .getPath(); + String filename = fullPath.substring(fullPath.lastIndexOf('/') + 1); + sb.append(filename); + } + return sb.toString(); + })); + + PAssert.that(taskDescriptions) + .containsInAnyOrder( + "part-00001.parquet,part-00002.parquet", "part-00003.parquet", "part-00004.parquet"); + + writePipeline.run().waitUntilFinish(); + } + + @Test + public void testFullPipelineRead() throws Exception { + File 
tableDir = tempFolder.newFolder("delta-table-full"); + + // 1. Write a Parquet file using Beam + Schema schema = Schema.builder().addField("name", Schema.FieldType.STRING).build(); + Row row = Row.withSchema(schema).addValues("test-name").build(); + + org.apache.avro.Schema avroSchema = AvroUtils.toAvroSchema(schema); + GenericRecord record = AvroUtils.toGenericRecord(row, avroSchema); + + writePipeline + .apply("Create Input", Create.of(record).withCoder(AvroCoder.of(avroSchema))) + .apply( + "Write Parquet", + FileIO.write() + .via(ParquetIO.sink(avroSchema)) + .to(tableDir.getAbsolutePath() + "/") + .withNaming( + (BoundedWindow window, + PaneInfo paneInfo, + int numShards, + int shardIndex, + Compression compression) -> "part-00000.parquet")); + + writePipeline.run().waitUntilFinish(); + + System.out.println("FILES IN TABLE DIR:"); + for (File f : tableDir.listFiles()) { + System.out.println( + " - " + f.getName() + " (size=" + f.length() + ", isDir=" + f.isDirectory() + ")"); + if (f.isDirectory()) { + for (File sub : f.listFiles()) { + System.out.println(" - " + sub.getName() + " (size=" + sub.length() + ")"); + } + } + } + + File parquetFile = new File(tableDir, "part-00000.parquet"); + byte[] fileBytes = Files.readAllBytes(parquetFile.toPath()); + System.out.println("PARQUET FILE LENGTH: " + fileBytes.length); + if (fileBytes.length >= 8) { + System.out.println( + "PARQUET FIRST 4 BYTES: " + + fileBytes[0] + + ", " + + fileBytes[1] + + ", " + + fileBytes[2] + + ", " + + fileBytes[3] + + " ('" + + (char) fileBytes[0] + + (char) fileBytes[1] + + (char) fileBytes[2] + + (char) fileBytes[3] + + "')"); + int len = fileBytes.length; + System.out.println( + "PARQUET LAST 4 BYTES: " + + fileBytes[len - 4] + + ", " + + fileBytes[len - 3] + + ", " + + fileBytes[len - 2] + + ", " + + fileBytes[len - 1] + + " ('" + + (char) fileBytes[len - 4] + + (char) fileBytes[len - 3] + + (char) fileBytes[len - 2] + + (char) fileBytes[len - 1] + + "')"); + } + + // 2. 
Create the Delta log + File logDir = new File(tableDir, "_delta_log"); + logDir.mkdirs(); + File commitFile = new File(logDir, "00000000000000000000.json"); + + String commitContent = + "{\"protocol\":{\"minReaderVersion\":1,\"minWriterVersion\":2}}\n" + + "{\"metaData\":{\"id\":\"test-id\",\"format\":{\"provider\":\"parquet\",\"options\":{}},\"schemaString\":\"{\\\"type\\\":\\\"struct\\\",\\\"fields\\\":[{\\\"name\\\":\\\"name\\\",\\\"type\\\":\\\"string\\\",\\\"nullable\\\":true,\\\"metadata\\\":{}}]}\",\"partitionColumns\":[],\"configuration\":{},\"createdAt\":123456789}}\n" + + "{\"add\":{\"path\":\"part-00000.parquet\",\"partitionValues\":{},\"size\":" + + fileBytes.length + + ",\"modificationTime\":123456789,\"dataChange\":true}}"; + + Files.write(commitFile.toPath(), commitContent.getBytes(StandardCharsets.UTF_8)); + + // 3. Read it using DeltaIO + PCollection output = + readPipeline.apply(DeltaIO.readRows().from(tableDir.getAbsolutePath())); + + PAssert.that(output).containsInAnyOrder(row); + + readPipeline.run().waitUntilFinish(); + } + + private byte[] writeParquetFile(File file, Row row) throws Exception { + org.apache.avro.Schema avroSchema = + org.apache.beam.sdk.extensions.avro.schemas.utils.AvroUtils.toAvroSchema(row.getSchema()); + org.apache.avro.generic.GenericRecord record = + org.apache.beam.sdk.extensions.avro.schemas.utils.AvroUtils.toGenericRecord( + row, avroSchema); + org.apache.hadoop.fs.Path path = new org.apache.hadoop.fs.Path(file.getAbsolutePath()); + try (org.apache.parquet.hadoop.ParquetWriter writer = + org.apache.parquet.avro.AvroParquetWriter.builder( + path) + .withSchema(avroSchema) + .withConf(new org.apache.hadoop.conf.Configuration()) + .build()) { + writer.write(record); + } + return java.nio.file.Files.readAllBytes(file.toPath()); + } + + @Test + @org.junit.Ignore("Manual integration test with external local table") + public void testReadingLocalTable() throws Exception { + PCollection output = + readPipeline.apply( + 
DeltaIO.readRows()
+                .from("/Users/chamikara/testing/delta_lake/test_repo/test_table_1_gb"));
+    PCollection counted = output.apply(Count.globally());
+    // NOTE(review): input and output paths are absolute paths under one user's home directory.
+    counted
+        .apply(
+            "Convert to String",
+            org.apache.beam.sdk.transforms.MapElements.into(
+                    org.apache.beam.sdk.values.TypeDescriptors.strings())
+                .via(String::valueOf))
+        .apply(
+            "Write to File",
+            org.apache.beam.sdk.io.TextIO.write()
+                .to("/Users/chamikara/testing/delta_lake/test_repo_pipeline_output/output")
+                .withSuffix(".txt")
+                .withoutSharding());
+
+    readPipeline.run().waitUntilFinish();
+  }
+  /** Checks {@code convertToBeamSchema} against a Delta schema with twelve field types. */
+  @Test
+  public void testConvertToBeamSchema() {
+    StructType deltaSchema =
+        new StructType(
+            java.util.Arrays.asList(
+                new StructField("string", StringType.STRING, false),
+                new StructField("integer", IntegerType.INTEGER, false),
+                new StructField("long", LongType.LONG, false),
+                new StructField("float", FloatType.FLOAT, false),
+                new StructField("double", DoubleType.DOUBLE, false),
+                new StructField("boolean", BooleanType.BOOLEAN, false),
+                new StructField("binary", BinaryType.BINARY, false),
+                new StructField("timestamp", TimestampType.TIMESTAMP, false),
+                new StructField("date", DateType.DATE, false),
+                new StructField("array", new ArrayType(StringType.STRING, true), false),
+                new StructField(
+                    "map", new MapType(StringType.STRING, IntegerType.INTEGER, true), false),
+                new StructField(
+                    "struct",
+                    new StructType(
+                        java.util.Arrays.asList(
+                            new StructField("nested_string", StringType.STRING, false))),
+                    false)));
+
+    Schema nestedSchema =
+        Schema.builder().addField("nested_string", Schema.FieldType.STRING).build();
+    // Expected mapping; note that both Delta timestamp and date are expected as DATETIME.
+    Schema expectedSchema =
+        Schema.builder()
+            .addField("string", Schema.FieldType.STRING)
+            .addField("integer", Schema.FieldType.INT32)
+            .addField("long", Schema.FieldType.INT64)
+            .addField("float", Schema.FieldType.FLOAT)
+            .addField("double", Schema.FieldType.DOUBLE)
+            .addField("boolean", Schema.FieldType.BOOLEAN)
+            .addField("binary", Schema.FieldType.BYTES)
+            .addField("timestamp", Schema.FieldType.DATETIME)
+            .addField("date", Schema.FieldType.DATETIME)
+            .addField("array", Schema.FieldType.iterable(Schema.FieldType.STRING))
+            .addField("map", Schema.FieldType.map(Schema.FieldType.STRING, Schema.FieldType.INT32))
+            .addField("struct", Schema.FieldType.row(nestedSchema))
+            .build();
+
+    Schema actualSchema = DeltaIO.ReadRows.convertToBeamSchema(deltaSchema);
+    org.junit.Assert.assertEquals(expectedSchema, actualSchema);
+  }
+  /** Checks reported progress after each claim over a range [0, 3) with sizes 100/200/300. */
+  @Test
+  public void testDeltaReadTaskTracker() {
+    java.util.List sizes = java.util.Arrays.asList(100L, 200L, 300L);
+    org.apache.beam.sdk.io.range.OffsetRange range =
+        new org.apache.beam.sdk.io.range.OffsetRange(0L, 3L);
+    DeltaReadTaskTracker tracker = new DeltaReadTaskTracker(range, sizes);
+    // Before any claim: nothing completed, the sum of all sizes remaining.
+    org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker.Progress progress =
+        tracker.getProgress();
+    org.junit.Assert.assertEquals(0.0, progress.getWorkCompleted(), 0.001);
+    org.junit.Assert.assertEquals(600.0, progress.getWorkRemaining(), 0.001);
+
+    org.junit.Assert.assertTrue(tracker.tryClaim(0L));
+    progress = tracker.getProgress();
+    org.junit.Assert.assertEquals(100.0, progress.getWorkCompleted(), 0.001);
+    org.junit.Assert.assertEquals(500.0, progress.getWorkRemaining(), 0.001);
+
+    org.junit.Assert.assertTrue(tracker.tryClaim(1L));
+    progress = tracker.getProgress();
+    org.junit.Assert.assertEquals(300.0, progress.getWorkCompleted(), 0.001);
+    org.junit.Assert.assertEquals(300.0, progress.getWorkRemaining(), 0.001);
+
+    org.junit.Assert.assertTrue(tracker.tryClaim(2L));
+    progress = tracker.getProgress();
+    org.junit.Assert.assertEquals(600.0, progress.getWorkCompleted(), 0.001);
+    org.junit.Assert.assertEquals(0.0, progress.getWorkRemaining(), 0.001);
+
+    tracker.checkDone();
+  }
+  /** Checks that {@code BeamEngine} hands back the {@code BeamParquetHandler} it was built with. */
+  @Test
+  public void testBeamParquetHandler() {
+    java.util.List sizes = java.util.Arrays.asList(100L, 200L);
+    org.apache.beam.sdk.io.range.OffsetRange range =
+        new org.apache.beam.sdk.io.range.OffsetRange(0L, 2L);
+    DeltaReadTaskTracker tracker = new DeltaReadTaskTracker(range, sizes);
+
+    org.apache.hadoop.conf.Configuration conf = new org.apache.hadoop.conf.Configuration();
+    io.delta.kernel.engine.ParquetHandler dummyDelegate =
+        new io.delta.kernel.engine.ParquetHandler() {
+          @Override
+          public io.delta.kernel.utils.CloseableIterator
+              readParquetFiles(
+                  io.delta.kernel.utils.CloseableIterator
+                      fileIter,
+                  io.delta.kernel.types.StructType physicalSchema,
+                  java.util.Optional predicate)
+                  throws java.io.IOException {
+            return new io.delta.kernel.utils.CloseableIterator<
+                io.delta.kernel.engine.FileReadResult>() {
+              @Override
+              public boolean hasNext() {
+                return false;
+              }
+
+              @Override
+              public io.delta.kernel.engine.FileReadResult next() {
+                throw new java.util.NoSuchElementException();
+              }
+
+              @Override
+              public void close() {}
+            };
+          }
+
+          @Override
+          public void writeParquetFileAtomically(
+              String filePath,
+              io.delta.kernel.utils.CloseableIterator
+                  data)
+              throws java.io.IOException {}
+
+          @Override
+          public io.delta.kernel.utils.CloseableIterator
+              writeParquetFiles(
+                  String filePath,
+                  io.delta.kernel.utils.CloseableIterator<
+                          io.delta.kernel.data.FilteredColumnarBatch>
+                      data,
+                  java.util.List statsColumns)
+                  throws java.io.IOException {
+            return new io.delta.kernel.utils.CloseableIterator<
+                io.delta.kernel.utils.DataFileStatus>() {
+              @Override
+              public boolean hasNext() {
+                return false;
+              }
+
+              @Override
+              public io.delta.kernel.utils.DataFileStatus next() {
+                throw new java.util.NoSuchElementException();
+              }
+
+              @Override
+              public void close() {}
+            };
+          }
+        };
+
+    BeamParquetHandler handler = new BeamParquetHandler(conf, dummyDelegate, tracker);
+    org.junit.Assert.assertNotNull(handler);
+
+    BeamEngine beamEngine =
+        new BeamEngine(io.delta.kernel.defaults.engine.DefaultEngine.create(conf), handler);
+    org.junit.Assert.assertEquals(handler, beamEngine.getParquetHandler());
+  }
+  /** Checks that both write methods of the handler reach the delegate {@code ParquetHandler}. */
+  @Test
+  public void testBeamParquetHandlerWriteDelegation() throws Exception {
+    java.util.List sizes = java.util.Arrays.asList(100L);
+    org.apache.beam.sdk.io.range.OffsetRange range =
+        new org.apache.beam.sdk.io.range.OffsetRange(0L, 1L);
+    DeltaReadTaskTracker tracker = new DeltaReadTaskTracker(range, sizes);
+    org.apache.hadoop.conf.Configuration conf = new org.apache.hadoop.conf.Configuration();
+    // flags[0]: writeParquetFileAtomically was called; flags[1]: writeParquetFiles was called.
+    boolean[] flags = new boolean[2];
+    io.delta.kernel.engine.ParquetHandler delegate =
+        new io.delta.kernel.engine.ParquetHandler() {
+          @Override
+          public io.delta.kernel.utils.CloseableIterator
+              readParquetFiles(
+                  io.delta.kernel.utils.CloseableIterator
+                      fileIter,
+                  io.delta.kernel.types.StructType physicalSchema,
+                  java.util.Optional predicate) {
+            return null;
+          }
+
+          @Override
+          public void writeParquetFileAtomically(
+              String filePath,
+              io.delta.kernel.utils.CloseableIterator
+                  data) {
+            flags[0] = true;
+          }
+
+          @Override
+          public io.delta.kernel.utils.CloseableIterator
+              writeParquetFiles(
+                  String filePath,
+                  io.delta.kernel.utils.CloseableIterator<
+                          io.delta.kernel.data.FilteredColumnarBatch>
+                      data,
+                  java.util.List statsColumns) {
+            flags[1] = true;
+            return new io.delta.kernel.utils.CloseableIterator<
+                io.delta.kernel.utils.DataFileStatus>() {
+              @Override
+              public boolean hasNext() {
+                return false;
+              }
+
+              @Override
+              public io.delta.kernel.utils.DataFileStatus next() {
+                throw new java.util.NoSuchElementException();
+              }
+
+              @Override
+              public void close() {}
+            };
+          }
+        };
+
+    BeamParquetHandler handler = new BeamParquetHandler(conf, delegate, tracker);
+    handler.writeParquetFileAtomically("path", null);
+    org.junit.Assert.assertTrue(flags[0]);
+
+    handler.writeParquetFiles("path", null, java.util.Collections.emptyList());
+    org.junit.Assert.assertTrue(flags[1]);
+  }
+  /** Reads one real Parquet file through handlers whose trackers do or do not allow index 0. */
+  @Test
+  public void testBeamParquetHandlerReadFiltering() throws Exception {
+    File tableDir = tempFolder.newFolder("parquet-filtering-test");
+
+    Schema schema = Schema.builder().addField("name", Schema.FieldType.STRING).build();
+    Row row = Row.withSchema(schema).addValues("test-name").build();
+    org.apache.avro.Schema avroSchema = AvroUtils.toAvroSchema(schema);
+    GenericRecord record = AvroUtils.toGenericRecord(row, avroSchema);
+    // Write the single-record Parquet file that every case below tries to read.
+    filteringPipeline
+        .apply("Create Input", Create.of(record).withCoder(AvroCoder.of(avroSchema)))
+        .apply(
+            "Write Parquet",
+            FileIO.write()
+                .via(ParquetIO.sink(avroSchema))
+                .to(tableDir.getAbsolutePath() + "/")
+                .withNaming((w, p, n, s, c) -> "part-00000.parquet"));
+
+    filteringPipeline.run().waitUntilFinish();
+
+    File parquetFile = new File(tableDir, "part-00000.parquet");
+    io.delta.kernel.utils.FileStatus fileStatus =
+        io.delta.kernel.utils.FileStatus.of(
+            parquetFile.getAbsolutePath(), parquetFile.length(), 123456789L);
+
+    org.apache.hadoop.conf.Configuration conf = new org.apache.hadoop.conf.Configuration();
+    io.delta.kernel.types.StructType physicalSchema =
+        new io.delta.kernel.types.StructType(
+            java.util.Arrays.asList(
+                new io.delta.kernel.types.StructField(
+                    "name", io.delta.kernel.types.StringType.STRING, true)));
+    // Delegate stub: its read method returns null and its write methods do nothing.
+    io.delta.kernel.engine.ParquetHandler dummyDelegate =
+        new io.delta.kernel.engine.ParquetHandler() {
+          @Override
+          public io.delta.kernel.utils.CloseableIterator
+              readParquetFiles(
+                  io.delta.kernel.utils.CloseableIterator
+                      fileIter,
+                  io.delta.kernel.types.StructType physicalSchema,
+                  java.util.Optional predicate) {
+            return null;
+          }
+
+          @Override
+          public void writeParquetFileAtomically(
+              String filePath,
+              io.delta.kernel.utils.CloseableIterator
+                  data) {}
+
+          @Override
+          public io.delta.kernel.utils.CloseableIterator
+              writeParquetFiles(
+                  String filePath,
+                  io.delta.kernel.utils.CloseableIterator<
+                          io.delta.kernel.data.FilteredColumnarBatch>
+                      data,
+                  java.util.List statsColumns) {
+            return null;
+          }
+        };
+
+    // Case A: Out of bounds before (tracker range [10, 20))
+    DeltaReadTaskTracker trackerA =
+        new DeltaReadTaskTracker(
+            new org.apache.beam.sdk.io.range.OffsetRange(10L, 20L),
+            java.util.Collections.singletonList(parquetFile.length()));
+    BeamParquetHandler handlerA = new BeamParquetHandler(conf, dummyDelegate, trackerA);
+    try (io.delta.kernel.utils.CloseableIterator iter =
+        handlerA.readParquetFiles(
+            io.delta.kernel.internal.util.Utils.singletonCloseableIterator(fileStatus),
+            physicalSchema,
+            java.util.Optional.empty())) {
+      org.junit.Assert.assertFalse(iter.hasNext());
+      try {
+        iter.next();
+        org.junit.Assert.fail("Expected NoSuchElementException");
+      } catch (java.util.NoSuchElementException e) {
+        // expected
+      }
+    }
+
+    // Case B: Out of bounds after (tracker range [0, 0))
+    DeltaReadTaskTracker trackerB =
+        new DeltaReadTaskTracker(
+            new org.apache.beam.sdk.io.range.OffsetRange(0L, 0L),
+            java.util.Collections.singletonList(parquetFile.length()));
+    BeamParquetHandler handlerB = new BeamParquetHandler(conf, dummyDelegate, trackerB);
+    try (io.delta.kernel.utils.CloseableIterator iter =
+        handlerB.readParquetFiles(
+            io.delta.kernel.internal.util.Utils.singletonCloseableIterator(fileStatus),
+            physicalSchema,
+            java.util.Optional.empty())) {
+      org.junit.Assert.assertFalse(iter.hasNext());
+    }
+
+    // Case C: Claim fails
+    DeltaReadTaskTracker trackerC =
+        new DeltaReadTaskTracker(
+            new org.apache.beam.sdk.io.range.OffsetRange(0L, 1L),
+            java.util.Collections.singletonList(parquetFile.length())) {
+          @Override
+          public boolean tryClaim(Long i) {
+            return false; // Simulate failure to claim
+          }
+        };
+    BeamParquetHandler handlerC = new BeamParquetHandler(conf, dummyDelegate, trackerC);
+    try (io.delta.kernel.utils.CloseableIterator iter =
+        handlerC.readParquetFiles(
+            io.delta.kernel.internal.util.Utils.singletonCloseableIterator(fileStatus),
+            physicalSchema,
+            java.util.Optional.empty())) {
+      org.junit.Assert.assertFalse(iter.hasNext());
+    }
+
+    // Case D: Successful claim and read
+    DeltaReadTaskTracker trackerD =
+        new DeltaReadTaskTracker(
+            new org.apache.beam.sdk.io.range.OffsetRange(0L, 1L),
+            java.util.Collections.singletonList(parquetFile.length()));
+    BeamParquetHandler handlerD = new BeamParquetHandler(conf, dummyDelegate, trackerD);
+    try (io.delta.kernel.utils.CloseableIterator iter =
+        handlerD.readParquetFiles(
+            io.delta.kernel.internal.util.Utils.singletonCloseableIterator(fileStatus),
+            physicalSchema,
+            java.util.Optional.empty())) {
+      org.junit.Assert.assertTrue(iter.hasNext());
+      io.delta.kernel.engine.FileReadResult res = iter.next();
+      org.junit.Assert.assertNotNull(res);
+      org.junit.Assert.assertNotNull(res.getData());
+      org.junit.Assert.assertFalse(iter.hasNext());
+      try {
+        iter.next();
+        org.junit.Assert.fail("Expected NoSuchElementException");
+      } catch (java.util.NoSuchElementException e) {
+        // expected
+      }
+    }
+  }
 }
diff --git a/sdks/java/io/delta/src/test/java/org/apache/beam/sdk/io/delta/SerializableRowTest.java b/sdks/java/io/delta/src/test/java/org/apache/beam/sdk/io/delta/SerializableRowTest.java
new file mode 100644
index 000000000000..743eb60cd784
--- /dev/null
+++ b/sdks/java/io/delta/src/test/java/org/apache/beam/sdk/io/delta/SerializableRowTest.java
@@ -0,0 +1,486 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package org.apache.beam.sdk.io.delta; + +import io.delta.kernel.data.ArrayValue; +import io.delta.kernel.data.ColumnVector; +import io.delta.kernel.data.MapValue; +import io.delta.kernel.data.Row; +import io.delta.kernel.types.ArrayType; +import io.delta.kernel.types.BinaryType; +import io.delta.kernel.types.BooleanType; +import io.delta.kernel.types.ByteType; +import io.delta.kernel.types.DataType; +import io.delta.kernel.types.DecimalType; +import io.delta.kernel.types.DoubleType; +import io.delta.kernel.types.FloatType; +import io.delta.kernel.types.IntegerType; +import io.delta.kernel.types.LongType; +import io.delta.kernel.types.MapType; +import io.delta.kernel.types.ShortType; +import io.delta.kernel.types.StringType; +import io.delta.kernel.types.StructField; +import io.delta.kernel.types.StructType; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.math.BigDecimal; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class SerializableRowTest { + + @Test + public void testAllTypesReadWriteSerialization() throws Exception { + StructType structSchema = + new StructType(Arrays.asList(new StructField("nested_int", IntegerType.INTEGER, false))); + + StructType schema = + new StructType( + Arrays.asList( + new StructField("boolean", BooleanType.BOOLEAN, true), + new StructField("byte", ByteType.BYTE, true), + new StructField("short", ShortType.SHORT, true), + new StructField("integer", IntegerType.INTEGER, true), + new StructField("long", LongType.LONG, true), + new StructField("float", FloatType.FLOAT, true), + new StructField("double", DoubleType.DOUBLE, true), + new StructField("string", StringType.STRING, true), + 
new StructField("binary", BinaryType.BINARY, true), + new StructField("decimal", new DecimalType(10, 2), true), + new StructField("struct", structSchema, true), + new StructField("array", new ArrayType(StringType.STRING, true), true), + new StructField( + "map", new MapType(StringType.STRING, IntegerType.INTEGER, true), true))); + + Map values = new LinkedHashMap<>(); + values.put("boolean", true); + values.put("byte", (byte) 1); + values.put("short", (short) 2); + values.put("integer", 3); + values.put("long", 4L); + values.put("float", 5.0f); + values.put("double", 6.0d); + values.put("string", "hello"); + values.put("binary", new byte[] {7, 8}); + values.put("decimal", new BigDecimal("9.20")); + + Map nestedValues = new LinkedHashMap<>(); + nestedValues.put("nested_int", 42); + values.put("struct", new FakeRow(structSchema, nestedValues)); + + values.put("array", Arrays.asList("a", "b", null)); + + Map mapValues = new LinkedHashMap<>(); + mapValues.put("key1", 100); + mapValues.put("key2", null); + values.put("map", mapValues); + + FakeRow originalRow = new FakeRow(schema, values); + SerializableRow serializableRow = new SerializableRow(originalRow); + + verifyRowContents(serializableRow, schema, values); + + // Test equals and hashCode + SerializableRow serializableRow2 = new SerializableRow(new FakeRow(schema, values)); + Assert.assertEquals(serializableRow, serializableRow2); + Assert.assertEquals(serializableRow.hashCode(), serializableRow2.hashCode()); + Assert.assertNotNull(serializableRow.toString()); + + // Serialize and Deserialize + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + try (ObjectOutputStream oos = new ObjectOutputStream(baos)) { + oos.writeObject(serializableRow); + } + + byte[] bytes = baos.toByteArray(); + SerializableRow deserializedRow; + try (ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(bytes))) { + deserializedRow = (SerializableRow) ois.readObject(); + } + + verifyRowContents(deserializedRow, 
schema, values); + Assert.assertEquals(serializableRow, deserializedRow); + } + + @Test + public void testNullValuesReadWriteSerialization() throws Exception { + StructType structSchema = + new StructType(Arrays.asList(new StructField("nested_int", IntegerType.INTEGER, true))); + + StructType schema = + new StructType( + Arrays.asList( + new StructField("boolean", BooleanType.BOOLEAN, true), + new StructField("byte", ByteType.BYTE, true), + new StructField("short", ShortType.SHORT, true), + new StructField("integer", IntegerType.INTEGER, true), + new StructField("long", LongType.LONG, true), + new StructField("float", FloatType.FLOAT, true), + new StructField("double", DoubleType.DOUBLE, true), + new StructField("string", StringType.STRING, true), + new StructField("binary", BinaryType.BINARY, true), + new StructField("decimal", new DecimalType(10, 2), true), + new StructField("struct", structSchema, true), + new StructField("array", new ArrayType(StringType.STRING, true), true), + new StructField( + "map", new MapType(StringType.STRING, IntegerType.INTEGER, true), true))); + + Map nullValues = new LinkedHashMap<>(); + for (StructField field : schema.fields()) { + nullValues.put(field.getName(), null); + } + + FakeRow originalRow = new FakeRow(schema, nullValues); + SerializableRow serializableRow = new SerializableRow(originalRow); + + verifyRowContents(serializableRow, schema, nullValues); + + // Serialize and Deserialize + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + try (ObjectOutputStream oos = new ObjectOutputStream(baos)) { + oos.writeObject(serializableRow); + } + + byte[] bytes = baos.toByteArray(); + SerializableRow deserializedRow; + try (ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(bytes))) { + deserializedRow = (SerializableRow) ois.readObject(); + } + + verifyRowContents(deserializedRow, schema, nullValues); + Assert.assertEquals(serializableRow, deserializedRow); + } + + private void verifyRowContents(Row row, 
StructType schema, Map expectedValues) { + Assert.assertEquals(schema.toString(), row.getSchema().toString()); + int i = 0; + for (StructField field : schema.fields()) { + Object expected = expectedValues.get(field.getName()); + Assert.assertEquals(expected == null, row.isNullAt(i)); + if (expected != null) { + DataType type = field.getDataType(); + if (type instanceof BooleanType) { + Assert.assertEquals(expected, row.getBoolean(i)); + } else if (type instanceof ByteType) { + Assert.assertEquals(expected, row.getByte(i)); + } else if (type instanceof ShortType) { + Assert.assertEquals(expected, row.getShort(i)); + } else if (type instanceof IntegerType) { + Assert.assertEquals(expected, row.getInt(i)); + } else if (type instanceof LongType) { + Assert.assertEquals(expected, row.getLong(i)); + } else if (type instanceof FloatType) { + Assert.assertEquals(expected, row.getFloat(i)); + } else if (type instanceof DoubleType) { + Assert.assertEquals(expected, row.getDouble(i)); + } else if (type instanceof StringType) { + Assert.assertEquals(expected, row.getString(i)); + } else if (type instanceof BinaryType) { + Assert.assertArrayEquals((byte[]) expected, row.getBinary(i)); + } else if (type instanceof DecimalType) { + Assert.assertEquals(expected, row.getDecimal(i)); + } else if (type instanceof StructType) { + Row actualStruct = row.getStruct(i); + Assert.assertNotNull(actualStruct); + Row expectedStruct = (Row) expected; + Assert.assertEquals( + expectedStruct.getSchema().toString(), actualStruct.getSchema().toString()); + for (int j = 0; j < expectedStruct.getSchema().fields().size(); j++) { + Assert.assertEquals(expectedStruct.isNullAt(j), actualStruct.isNullAt(j)); + if (!expectedStruct.isNullAt(j)) { + Assert.assertEquals(expectedStruct.getInt(j), actualStruct.getInt(j)); + } + } + } else if (type instanceof ArrayType) { + ArrayValue actualArray = row.getArray(i); + Assert.assertNotNull(actualArray); + List expectedList = (List) expected; + 
Assert.assertEquals(expectedList.size(), actualArray.getSize()); + ColumnVector vector = actualArray.getElements(); + for (int j = 0; j < expectedList.size(); j++) { + Assert.assertEquals(expectedList.get(j) == null, vector.isNullAt(j)); + if (expectedList.get(j) != null) { + Assert.assertEquals(expectedList.get(j), vector.getString(j)); + } + } + } else if (type instanceof MapType) { + MapValue actualMap = row.getMap(i); + Assert.assertNotNull(actualMap); + Map expectedMap = (Map) expected; + Assert.assertEquals(expectedMap.size(), actualMap.getSize()); + ColumnVector keys = actualMap.getKeys(); + ColumnVector valuesVector = actualMap.getValues(); + int j = 0; + for (Map.Entry entry : expectedMap.entrySet()) { + Assert.assertEquals(entry.getKey(), keys.getString(j)); + Assert.assertEquals(entry.getValue() == null, valuesVector.isNullAt(j)); + if (entry.getValue() != null) { + Assert.assertEquals(entry.getValue(), valuesVector.getInt(j)); + } + j++; + } + } + } + i++; + } + } + + private static class FakeRow implements Row { + private final StructType schema; + private final Map values; + + FakeRow(StructType schema, Map values) { + this.schema = schema; + this.values = values; + } + + @Override + public StructType getSchema() { + return schema; + } + + @Override + public boolean isNullAt(int ord) { + String fieldName = schema.fields().get(ord).getName(); + return values.get(fieldName) == null; + } + + @Override + public boolean getBoolean(int ord) { + String fieldName = schema.fields().get(ord).getName(); + return (Boolean) values.get(fieldName); + } + + @Override + public byte getByte(int ord) { + String fieldName = schema.fields().get(ord).getName(); + return (Byte) values.get(fieldName); + } + + @Override + public short getShort(int ord) { + String fieldName = schema.fields().get(ord).getName(); + return (Short) values.get(fieldName); + } + + @Override + public int getInt(int ord) { + String fieldName = schema.fields().get(ord).getName(); + return (Integer) 
values.get(fieldName); + } + + @Override + public long getLong(int ord) { + String fieldName = schema.fields().get(ord).getName(); + return (Long) values.get(fieldName); + } + + @Override + public float getFloat(int ord) { + String fieldName = schema.fields().get(ord).getName(); + return (Float) values.get(fieldName); + } + + @Override + public double getDouble(int ord) { + String fieldName = schema.fields().get(ord).getName(); + return (Double) values.get(fieldName); + } + + @Override + public String getString(int ord) { + String fieldName = schema.fields().get(ord).getName(); + return (String) values.get(fieldName); + } + + @Override + public byte[] getBinary(int ord) { + String fieldName = schema.fields().get(ord).getName(); + return (byte[]) values.get(fieldName); + } + + @Override + public BigDecimal getDecimal(int ord) { + String fieldName = schema.fields().get(ord).getName(); + return (BigDecimal) values.get(fieldName); + } + + @Override + public Row getStruct(int ord) { + String fieldName = schema.fields().get(ord).getName(); + return (Row) values.get(fieldName); + } + + @Override + public ArrayValue getArray(int ord) { + String fieldName = schema.fields().get(ord).getName(); + List list = (List) values.get(fieldName); + if (list == null) { + return null; + } + DataType elementType = ((ArrayType) schema.fields().get(ord).getDataType()).getElementType(); + return new FakeArrayValue(list, elementType); + } + + @Override + public MapValue getMap(int ord) { + String fieldName = schema.fields().get(ord).getName(); + Map map = (Map) values.get(fieldName); + if (map == null) { + return null; + } + MapType mapType = (MapType) schema.fields().get(ord).getDataType(); + return new FakeMapValue(map, mapType.getKeyType(), mapType.getValueType()); + } + } + + private static class FakeArrayValue implements ArrayValue { + private final List list; + private final DataType elementType; + + FakeArrayValue(List list, DataType elementType) { + this.list = list; + 
this.elementType = elementType; + } + + @Override + public int getSize() { + return list.size(); + } + + @Override + public ColumnVector getElements() { + return new FakeColumnVector(elementType, list); + } + } + + private static class FakeMapValue implements MapValue { + private final Map map; + private final DataType keyType; + private final DataType valueType; + + FakeMapValue(Map map, DataType keyType, DataType valueType) { + this.map = map; + this.keyType = keyType; + this.valueType = valueType; + } + + @Override + public int getSize() { + return map.size(); + } + + @Override + public ColumnVector getKeys() { + return new FakeColumnVector(keyType, new ArrayList<>(map.keySet())); + } + + @Override + public ColumnVector getValues() { + return new FakeColumnVector(valueType, new ArrayList<>(map.values())); + } + } + + private static class FakeColumnVector implements ColumnVector { + private final DataType dataType; + private final List list; + + FakeColumnVector(DataType dataType, List list) { + this.dataType = dataType; + this.list = new ArrayList<>(list); + } + + @Override + public DataType getDataType() { + return dataType; + } + + @Override + public int getSize() { + return list.size(); + } + + @Override + public boolean isNullAt(int rowId) { + return list.get(rowId) == null; + } + + @Override + public boolean getBoolean(int rowId) { + return (Boolean) list.get(rowId); + } + + @Override + public byte getByte(int rowId) { + return (Byte) list.get(rowId); + } + + @Override + public short getShort(int rowId) { + return (Short) list.get(rowId); + } + + @Override + public int getInt(int rowId) { + return (Integer) list.get(rowId); + } + + @Override + public long getLong(int rowId) { + return (Long) list.get(rowId); + } + + @Override + public float getFloat(int rowId) { + return (Float) list.get(rowId); + } + + @Override + public double getDouble(int rowId) { + return (Double) list.get(rowId); + } + + @Override + public String getString(int rowId) { + return 
(String) list.get(rowId); + } + + @Override + public byte[] getBinary(int rowId) { + return (byte[]) list.get(rowId); + } + + @Override + public BigDecimal getDecimal(int rowId) { + return (BigDecimal) list.get(rowId); + } + + @Override + public void close() {} + } +}