Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 6 additions & 8 deletions native/core/src/execution/columnar_to_row.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1052,10 +1052,10 @@ impl ColumnarToRowContext {
})
}
(DataType::Int32, DataType::Decimal128(precision, scale)) => {
// Parquet stores small-precision decimals as Int32 for efficiency.
// When COMET_USE_DECIMAL_128 is false, BatchReader produces these types.
// The Int32 value is already scaled (e.g., -1 means -0.01 for scale 2).
// We need to reinterpret (not cast) to Decimal128 preserving the value.
// Parquet stores small-precision decimals as Int32 for efficiency, and the
// reader may surface them as the physical Int32 type. The value is already
// scaled (e.g., -1 means -0.01 for scale 2). Reinterpret (not cast) to
// Decimal128 preserving the value.
let int_array = array.as_any().downcast_ref::<Int32Array>().ok_or_else(|| {
CometError::Internal("Failed to downcast to Int32Array".to_string())
})?;
Expand Down Expand Up @@ -2581,8 +2581,7 @@ mod tests {
#[test]
fn test_convert_int32_to_decimal128() {
// Test that Int32 arrays are reinterpreted (not cast) as Decimal128 when the schema expects Decimal128.
// This can happen when COMET_USE_DECIMAL_128 is false and the parquet reader produces
// Int32 for small-precision decimals.
// This can happen when the parquet reader surfaces small-precision decimals as Int32.

// Create an Int32 array representing decimals: [-1, -2, -3] which at scale 2 means
// [-0.01, -0.02, -0.03]
Expand Down Expand Up @@ -2619,8 +2618,7 @@ mod tests {
#[test]
fn test_convert_int64_to_decimal128() {
// Test that Int64 arrays are correctly converted to Decimal128 when the schema expects Decimal128.
// This can happen when COMET_USE_DECIMAL_128 is false and the parquet reader produces
// Int64 for medium-precision decimals.
// This can happen when the parquet reader surfaces medium-precision decimals as Int64.

// Create an Int64 array representing decimals
let int_array: ArrayRef = Arc::new(Int64Array::from(vec![-100i64, -200, -300]));
Expand Down
4 changes: 0 additions & 4 deletions native/core/src/parquet/parquet_support.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,6 @@ pub struct SparkParquetOptions {
pub allow_incompat: bool,
/// Support casting unsigned ints to signed ints (used by Parquet SchemaAdapter)
pub allow_cast_unsigned_ints: bool,
/// Whether to always represent decimals using 128 bits. If false, the native reader may represent decimals using 32 or 64 bits, depending on the precision.
pub use_decimal_128: bool,
/// Whether to read dates/timestamps that were written in the legacy hybrid Julian + Gregorian calendar as it is. If false, throw exceptions instead. If the spark type is TimestampNTZ, this should be true.
pub use_legacy_date_timestamp_or_ntz: bool,
// Whether schema field names are case sensitive
Expand Down Expand Up @@ -105,7 +103,6 @@ impl SparkParquetOptions {
timezone: timezone.to_string(),
allow_incompat,
allow_cast_unsigned_ints: false,
use_decimal_128: false,
use_legacy_date_timestamp_or_ntz: false,
case_sensitive: false,
return_null_struct_if_all_fields_missing: true,
Expand All @@ -121,7 +118,6 @@ impl SparkParquetOptions {
timezone: "".to_string(),
allow_incompat,
allow_cast_unsigned_ints: false,
use_decimal_128: false,
use_legacy_date_timestamp_or_ntz: false,
case_sensitive: false,
return_null_struct_if_all_fields_missing: true,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,12 @@ public abstract class CometDecodedVector extends CometVector {
private byte validityByteCache;
protected boolean isUuid;

protected CometDecodedVector(ValueVector vector, Field valueField, boolean useDecimal128) {
this(vector, valueField, useDecimal128, false);
protected CometDecodedVector(ValueVector vector, Field valueField) {
this(vector, valueField, false);
}

protected CometDecodedVector(
ValueVector vector, Field valueField, boolean useDecimal128, boolean isUuid) {
super(Utils.fromArrowField(valueField), useDecimal128);
protected CometDecodedVector(ValueVector vector, Field valueField, boolean isUuid) {
super(Utils.fromArrowField(valueField));
this.valueVector = vector;
this.numNulls = valueVector.getNullCount();
this.numValues = valueVector.getValueCount();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,11 @@ public class CometDelegateVector extends CometVector {
protected CometVector delegate;

public CometDelegateVector(DataType dataType) {
this(dataType, null, false);
this(dataType, null);
}

public CometDelegateVector(DataType dataType, boolean useDecimal128) {
this(dataType, null, useDecimal128);
}

public CometDelegateVector(DataType dataType, CometVector delegate, boolean useDecimal128) {
super(dataType, useDecimal128);
public CometDelegateVector(DataType dataType, CometVector delegate) {
super(dataType);
if (delegate instanceof CometDelegateVector) {
throw new IllegalArgumentException("cannot have nested delegation");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,21 +35,17 @@ public class CometDictionaryVector extends CometDecodedVector {
private final boolean isAlias;

public CometDictionaryVector(
CometPlainVector indices,
CometDictionary values,
DictionaryProvider provider,
boolean useDecimal128) {
this(indices, values, provider, useDecimal128, false, false);
CometPlainVector indices, CometDictionary values, DictionaryProvider provider) {
this(indices, values, provider, false, false);
}

public CometDictionaryVector(
CometPlainVector indices,
CometDictionary values,
DictionaryProvider provider,
boolean useDecimal128,
boolean isAlias,
boolean isUuid) {
super(indices.valueVector, values.getValueVector().getField(), useDecimal128, isUuid);
super(indices.valueVector, values.getValueVector().getField(), isUuid);
Preconditions.checkArgument(
indices.valueVector instanceof IntVector, "'indices' should be a IntVector");
this.values = values;
Expand Down Expand Up @@ -131,11 +127,11 @@ byte[] getBinaryDecimal(int i) {
public CometVector slice(int offset, int length) {
TransferPair tp = indices.valueVector.getTransferPair(indices.valueVector.getAllocator());
tp.splitAndTransfer(offset, length);
CometPlainVector sliced = new CometPlainVector(tp.getTo(), useDecimal128);
CometPlainVector sliced = new CometPlainVector(tp.getTo());

// Set the alias flag to true so that the sliced vector will not close the dictionary vector.
// Otherwise, if the dictionary is closed, the sliced vector will not be able to access the
// dictionary.
return new CometDictionaryVector(sliced, values, provider, useDecimal128, true, isUuid);
return new CometDictionaryVector(sliced, values, provider, true, isUuid);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,13 @@ public class CometListVector extends CometDecodedVector {
final ColumnVector dataColumnVector;
final DictionaryProvider dictionaryProvider;

public CometListVector(
ValueVector vector, boolean useDecimal128, DictionaryProvider dictionaryProvider) {
super(vector, vector.getField(), useDecimal128);
public CometListVector(ValueVector vector, DictionaryProvider dictionaryProvider) {
super(vector, vector.getField());

this.listVector = ((ListVector) vector);
this.dataVector = listVector.getDataVector();
this.dictionaryProvider = dictionaryProvider;
this.dataColumnVector = getVector(dataVector, useDecimal128, dictionaryProvider);
this.dataColumnVector = getVector(dataVector, dictionaryProvider);
}

@Override
Expand All @@ -57,6 +56,6 @@ public CometVector slice(int offset, int length) {
TransferPair tp = this.valueVector.getTransferPair(this.valueVector.getAllocator());
tp.splitAndTransfer(offset, length);

return new CometListVector(tp.getTo(), useDecimal128, dictionaryProvider);
return new CometListVector(tp.getTo(), dictionaryProvider);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,15 @@ public class CometMapVector extends CometDecodedVector {
final ColumnVector keys;
final ColumnVector values;

public CometMapVector(
ValueVector vector, boolean useDecimal128, DictionaryProvider dictionaryProvider) {
super(vector, vector.getField(), useDecimal128);
public CometMapVector(ValueVector vector, DictionaryProvider dictionaryProvider) {
super(vector, vector.getField());

this.mapVector = ((MapVector) vector);
this.dataVector = mapVector.getDataVector();
this.dictionaryProvider = dictionaryProvider;

if (dataVector instanceof StructVector) {
this.dataColumnVector = new CometStructVector(dataVector, useDecimal128, dictionaryProvider);
this.dataColumnVector = new CometStructVector(dataVector, dictionaryProvider);

if (dataColumnVector.children.size() != 2) {
throw new RuntimeException(
Expand Down Expand Up @@ -77,6 +76,6 @@ public CometVector slice(int offset, int length) {
TransferPair tp = this.valueVector.getTransferPair(this.valueVector.getAllocator());
tp.splitAndTransfer(offset, length);

return new CometMapVector(tp.getTo(), useDecimal128, dictionaryProvider);
return new CometMapVector(tp.getTo(), dictionaryProvider);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,17 +40,16 @@ public class CometPlainVector extends CometDecodedVector {

private boolean isReused;

public CometPlainVector(ValueVector vector, boolean useDecimal128) {
this(vector, useDecimal128, false);
public CometPlainVector(ValueVector vector) {
this(vector, false);
}

public CometPlainVector(ValueVector vector, boolean useDecimal128, boolean isUuid) {
this(vector, useDecimal128, isUuid, false);
public CometPlainVector(ValueVector vector, boolean isUuid) {
this(vector, isUuid, false);
}

public CometPlainVector(
ValueVector vector, boolean useDecimal128, boolean isUuid, boolean isReused) {
super(vector, vector.getField(), useDecimal128, isUuid);
public CometPlainVector(ValueVector vector, boolean isUuid, boolean isReused) {
super(vector, vector.getField(), isUuid);
// NullType doesn't have data buffer.
if (vector instanceof NullVector) {
this.valueBufferAddress = -1;
Expand Down Expand Up @@ -184,7 +183,7 @@ public CometVector slice(int offset, int length) {
TransferPair tp = this.valueVector.getTransferPair(this.valueVector.getAllocator());
tp.splitAndTransfer(offset, length);

return new CometPlainVector(tp.getTo(), useDecimal128);
return new CometPlainVector(tp.getTo());
}

private static UUID convertToUuid(byte[] buf) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,7 @@ public class CometSelectionVector extends CometVector {
* @throws IllegalArgumentException if any index is out of bounds
*/
public CometSelectionVector(CometVector values, int[] indices, int numValues) {
// Use the values vector's datatype, useDecimal128, and dictionary provider
super(values.dataType(), values.useDecimal128);
super(values.dataType());

this.values = values;
this.selectionIndices = indices;
Expand All @@ -97,8 +96,7 @@ public CometSelectionVector(CometVector values, int[] indices, int numValues) {
}
indicesVector.setValueCount(numValues);

this.indices =
CometVector.getVector(indicesVector, values.useDecimal128, values.getDictionaryProvider());
this.indices = CometVector.getVector(indicesVector, values.getDictionaryProvider());
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,8 @@ public class CometStructVector extends CometDecodedVector {
final List<ColumnVector> children;
final DictionaryProvider dictionaryProvider;

public CometStructVector(
ValueVector vector, boolean useDecimal128, DictionaryProvider dictionaryProvider) {
super(vector, vector.getField(), useDecimal128);
public CometStructVector(ValueVector vector, DictionaryProvider dictionaryProvider) {
super(vector, vector.getField());

StructVector structVector = ((StructVector) vector);

Expand All @@ -44,7 +43,7 @@ public CometStructVector(

for (int i = 0; i < size; ++i) {
ValueVector value = structVector.getVectorById(i);
children.add(getVector(value, useDecimal128, dictionaryProvider));
children.add(getVector(value, dictionaryProvider));
}
this.children = children;
this.dictionaryProvider = dictionaryProvider;
Expand All @@ -60,6 +59,6 @@ public CometVector slice(int offset, int length) {
TransferPair tp = this.valueVector.getTransferPair(this.valueVector.getAllocator());
tp.splitAndTransfer(offset, length);

return new CometStructVector(tp.getTo(), useDecimal128, dictionaryProvider);
return new CometStructVector(tp.getTo(), dictionaryProvider);
}
}
33 changes: 12 additions & 21 deletions spark/src/main/java/org/apache/comet/vector/CometVector.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
import org.apache.arrow.vector.types.pojo.DictionaryEncoding;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.Decimal;
import org.apache.spark.sql.types.IntegerType;
import org.apache.spark.sql.vectorized.ColumnVector;
import org.apache.spark.sql.vectorized.ColumnarArray;
import org.apache.spark.sql.vectorized.ColumnarMap;
Expand All @@ -43,7 +42,6 @@
public abstract class CometVector extends ColumnVector {
private static final int DECIMAL_BYTE_WIDTH = 16;
private final byte[] DECIMAL_BYTES = new byte[DECIMAL_BYTE_WIDTH];
protected final boolean useDecimal128;

private static final long decimalValOffset;

Expand All @@ -58,9 +56,8 @@ public abstract class CometVector extends ColumnVector {
}
}

public CometVector(DataType type, boolean useDecimal128) {
public CometVector(DataType type) {
super(type);
this.useDecimal128 = useDecimal128;
}

/**
Expand All @@ -86,10 +83,8 @@ public boolean isFixedLength() {
@Override
public Decimal getDecimal(int i, int precision, int scale) {
if (isNullAt(i)) return null;
if (!useDecimal128 && precision <= Decimal.MAX_INT_DIGITS() && type instanceof IntegerType) {
return createDecimal(getInt(i), precision, scale);
} else if (precision <= Decimal.MAX_LONG_DIGITS()) {
return createDecimal(useDecimal128 ? getLongDecimal(i) : getLong(i), precision, scale);
if (precision <= Decimal.MAX_LONG_DIGITS()) {
return createDecimal(getLongDecimal(i), precision, scale);
} else {
byte[] bytes = getBinaryDecimal(i);
BigInteger bigInteger = new BigInteger(bytes);
Expand Down Expand Up @@ -230,37 +225,33 @@ public DictionaryProvider getDictionaryProvider() {
* Returns a corresponding `CometVector` implementation based on the given Arrow `ValueVector`.
*
* @param vector Arrow `ValueVector`
* @param useDecimal128 Whether to use Decimal128 for decimal column
* @return `CometVector` implementation
*/
public static CometVector getVector(
ValueVector vector, boolean useDecimal128, DictionaryProvider dictionaryProvider) {
public static CometVector getVector(ValueVector vector, DictionaryProvider dictionaryProvider) {
if (vector instanceof StructVector) {
return new CometStructVector(vector, useDecimal128, dictionaryProvider);
return new CometStructVector(vector, dictionaryProvider);
} else if (vector instanceof MapVector) {
return new CometMapVector(vector, useDecimal128, dictionaryProvider);
return new CometMapVector(vector, dictionaryProvider);
} else if (vector instanceof ListVector) {
return new CometListVector(vector, useDecimal128, dictionaryProvider);
return new CometListVector(vector, dictionaryProvider);
} else {
DictionaryEncoding dictionaryEncoding = vector.getField().getDictionary();
CometPlainVector cometVector = new CometPlainVector(vector, useDecimal128);
CometPlainVector cometVector = new CometPlainVector(vector);

if (dictionaryEncoding == null) {
return cometVector;
} else {
Dictionary dictionary = dictionaryProvider.lookup(dictionaryEncoding.getId());
CometPlainVector dictionaryVector =
new CometPlainVector(dictionary.getVector(), useDecimal128);
CometPlainVector dictionaryVector = new CometPlainVector(dictionary.getVector());
CometDictionary cometDictionary = new CometDictionary(dictionaryVector);

return new CometDictionaryVector(
cometVector, cometDictionary, dictionaryProvider, useDecimal128);
return new CometDictionaryVector(cometVector, cometDictionary, dictionaryProvider);
}
}
}

protected static CometVector getVector(ValueVector vector, boolean useDecimal128) {
return getVector(vector, useDecimal128, null);
protected static CometVector getVector(ValueVector vector) {
return getVector(vector, null);
}

private UnsupportedOperationException notImplementedException() {
Expand Down
10 changes: 0 additions & 10 deletions spark/src/main/scala/org/apache/comet/CometConf.scala
Original file line number Diff line number Diff line change
Expand Up @@ -699,16 +699,6 @@ object CometConf extends ShimCometConf {
.booleanConf
.createWithDefault(false)

val COMET_USE_DECIMAL_128: ConfigEntry[Boolean] = conf("spark.comet.use.decimal128")
.internal()
.category(CATEGORY_EXEC)
.doc("If true, Comet will always use 128 bits to represent a decimal value, regardless of " +
"its precision. If false, Comet will use 32, 64 and 128 bits respectively depending on " +
"the precision. N.B. this is NOT a user-facing config but should be inferred and set by " +
"Comet itself.")
.booleanConf
.createWithDefault(false)

val COMET_USE_LAZY_MATERIALIZATION: ConfigEntry[Boolean] = conf(
"spark.comet.use.lazyMaterialization")
.internal()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -452,8 +452,7 @@ private[codegen] object CometBatchKernelCodegenInput {
out: mutable.ArrayBuffer[String]): Unit = spec match {
case sc: ScalarColumnSpec =>
if (wrapsInCometPlainVector(sc.vectorClass)) {
// `useDecimal128 = true` matches Spark's 128-bit decimal storage.
out += s"this.$path = new $cometPlainVectorName($source, true);"
out += s"this.$path = new $cometPlainVectorName($source);"
} else {
out += s"this.$path = (${sc.vectorClass.getName}) $source;"
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -187,9 +187,6 @@ case class CometExecRule(session: SparkSession)
case s: ShuffleExchangeExec =>
CometShuffleExchangeExec.shuffleSupported(s) match {
case Some(CometNativeShuffle) =>
// Switch to use Decimal128 regardless of precision, since Arrow native execution
// doesn't support Decimal32 and Decimal64 yet.
conf.setConfString(CometConf.COMET_USE_DECIMAL_128.key, "true")
CometShuffleExchangeExec(s, shuffleType = CometNativeShuffle)
case Some(CometColumnarShuffle) =>
CometShuffleExchangeExec(s, shuffleType = CometColumnarShuffle)
Expand Down
Loading
Loading