diff --git a/native/spark-expr/Cargo.toml b/native/spark-expr/Cargo.toml index e7c238f7eb..b014c49a2d 100644 --- a/native/spark-expr/Cargo.toml +++ b/native/spark-expr/Cargo.toml @@ -103,3 +103,7 @@ path = "tests/spark_expr_reg.rs" [[bench]] name = "cast_from_boolean" harness = false + +[[bench]] +name = "cast_non_int_numeric_timestamp" +harness = false diff --git a/native/spark-expr/benches/cast_non_int_numeric_timestamp.rs b/native/spark-expr/benches/cast_non_int_numeric_timestamp.rs new file mode 100644 index 0000000000..ea1a85e407 --- /dev/null +++ b/native/spark-expr/benches/cast_non_int_numeric_timestamp.rs @@ -0,0 +1,143 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use arrow::array::builder::{BooleanBuilder, Decimal128Builder, Float32Builder, Float64Builder}; +use arrow::array::RecordBatch; +use arrow::datatypes::{DataType, Field, Schema, TimeUnit}; +use criterion::{criterion_group, criterion_main, Criterion}; +use datafusion::physical_expr::{expressions::Column, PhysicalExpr}; +use datafusion_comet_spark_expr::{Cast, EvalMode, SparkCastOptions}; +use std::sync::Arc; + +const BATCH_SIZE: usize = 8192; + +fn criterion_benchmark(c: &mut Criterion) { + let spark_cast_options = SparkCastOptions::new(EvalMode::Legacy, "UTC", false); + let timestamp_type = DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into())); + + let mut group = c.benchmark_group("cast_non_int_numeric_to_timestamp"); + + // Float32 -> Timestamp + let batch_f32 = create_float32_batch(); + let expr_f32 = Arc::new(Column::new("a", 0)); + let cast_f32_to_ts = Cast::new(expr_f32, timestamp_type.clone(), spark_cast_options.clone()); + group.bench_function("cast_f32_to_timestamp", |b| { + b.iter(|| cast_f32_to_ts.evaluate(&batch_f32).unwrap()); + }); + + // Float64 -> Timestamp + let batch_f64 = create_float64_batch(); + let expr_f64 = Arc::new(Column::new("a", 0)); + let cast_f64_to_ts = Cast::new(expr_f64, timestamp_type.clone(), spark_cast_options.clone()); + group.bench_function("cast_f64_to_timestamp", |b| { + b.iter(|| cast_f64_to_ts.evaluate(&batch_f64).unwrap()); + }); + + // Boolean -> Timestamp + let batch_bool = create_boolean_batch(); + let expr_bool = Arc::new(Column::new("a", 0)); + let cast_bool_to_ts = Cast::new( + expr_bool, + timestamp_type.clone(), + spark_cast_options.clone(), + ); + group.bench_function("cast_bool_to_timestamp", |b| { + b.iter(|| cast_bool_to_ts.evaluate(&batch_bool).unwrap()); + }); + + // Decimal128 -> Timestamp + let batch_decimal = create_decimal128_batch(); + let expr_decimal = Arc::new(Column::new("a", 0)); + let cast_decimal_to_ts = Cast::new( + expr_decimal, + timestamp_type.clone(), + spark_cast_options.clone(), 
+ ); + group.bench_function("cast_decimal_to_timestamp", |b| { + b.iter(|| cast_decimal_to_ts.evaluate(&batch_decimal).unwrap()); + }); + + group.finish(); +} + +fn create_float32_batch() -> RecordBatch { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)])); + let mut b = Float32Builder::with_capacity(BATCH_SIZE); + for i in 0..BATCH_SIZE { + if i % 10 == 0 { + b.append_null(); + } else { + b.append_value(rand::random::()); + } + } + RecordBatch::try_new(schema, vec![Arc::new(b.finish())]).unwrap() +} + +fn create_float64_batch() -> RecordBatch { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Float64, true)])); + let mut b = Float64Builder::with_capacity(BATCH_SIZE); + for i in 0..BATCH_SIZE { + if i % 10 == 0 { + b.append_null(); + } else { + b.append_value(rand::random::()); + } + } + RecordBatch::try_new(schema, vec![Arc::new(b.finish())]).unwrap() +} + +fn create_boolean_batch() -> RecordBatch { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Boolean, true)])); + let mut b = BooleanBuilder::with_capacity(BATCH_SIZE); + for i in 0..BATCH_SIZE { + if i % 10 == 0 { + b.append_null(); + } else { + b.append_value(rand::random::()); + } + } + RecordBatch::try_new(schema, vec![Arc::new(b.finish())]).unwrap() +} + +fn create_decimal128_batch() -> RecordBatch { + let schema = Arc::new(Schema::new(vec![Field::new( + "a", + DataType::Decimal128(18, 6), + true, + )])); + let mut b = Decimal128Builder::with_capacity(BATCH_SIZE); + for i in 0..BATCH_SIZE { + if i % 10 == 0 { + b.append_null(); + } else { + b.append_value(i as i128 * 1_000_000); + } + } + let array = b.finish().with_precision_and_scale(18, 6).unwrap(); + RecordBatch::try_new(schema, vec![Arc::new(array)]).unwrap() +} + +fn config() -> Criterion { + Criterion::default() +} + +criterion_group! 
{ + name = benches; + config = config(); + targets = criterion_benchmark +} +criterion_main!(benches); diff --git a/native/spark-expr/src/conversion_funcs/boolean.rs b/native/spark-expr/src/conversion_funcs/boolean.rs index db288fa32a..49855790b1 100644 --- a/native/spark-expr/src/conversion_funcs/boolean.rs +++ b/native/spark-expr/src/conversion_funcs/boolean.rs @@ -16,7 +16,7 @@ // under the License. use crate::SparkResult; -use arrow::array::{ArrayRef, AsArray, Decimal128Array}; +use arrow::array::{Array, ArrayRef, AsArray, Decimal128Array, TimestampMicrosecondBuilder}; use arrow::datatypes::DataType; use std::sync::Arc; @@ -28,7 +28,6 @@ pub fn is_df_cast_from_bool_spark_compatible(to_type: &DataType) -> bool { ) } -// only DF incompatible boolean cast pub fn cast_boolean_to_decimal( array: &ArrayRef, precision: u8, @@ -43,6 +42,25 @@ pub fn cast_boolean_to_decimal( Ok(Arc::new(result.with_precision_and_scale(precision, scale)?)) } +pub(crate) fn cast_boolean_to_timestamp( + array_ref: &ArrayRef, + target_tz: &Option>, +) -> SparkResult { + let bool_array = array_ref.as_boolean(); + let mut builder = TimestampMicrosecondBuilder::with_capacity(bool_array.len()); + + for i in 0..bool_array.len() { + if bool_array.is_null(i) { + builder.append_null(); + } else { + let micros = if bool_array.value(i) { 1 } else { 0 }; + builder.append_value(micros); + } + } + + Ok(Arc::new(builder.finish().with_timezone_opt(target_tz.clone())) as ArrayRef) +} + #[cfg(test)] mod tests { use super::*; @@ -53,6 +71,7 @@ mod tests { Int64Array, Int8Array, StringArray, }; use arrow::datatypes::DataType::Decimal128; + use arrow::datatypes::TimestampMicrosecondType; use std::sync::Arc; fn test_input_bool_array() -> ArrayRef { @@ -193,4 +212,26 @@ mod tests { assert_eq!(arr.value(1), expected_arr.value(1)); assert!(arr.is_null(2)); } + + #[test] + fn test_cast_boolean_to_timestamp() { + let timezones: [Option>; 3] = [ + Some(Arc::from("UTC")), + Some(Arc::from("America/Los_Angeles")), + 
None, + ]; + + for tz in &timezones { + let bool_array: ArrayRef = + Arc::new(BooleanArray::from(vec![Some(true), Some(false), None])); + + let result = cast_boolean_to_timestamp(&bool_array, tz).unwrap(); + let ts_array = result.as_primitive::(); + + assert_eq!(ts_array.value(0), 1); // true -> 1 microsecond + assert_eq!(ts_array.value(1), 0); // false -> 0 (epoch) + assert!(ts_array.is_null(2)); + assert_eq!(ts_array.timezone(), tz.as_ref().map(|s| s.as_ref())); + } + } } diff --git a/native/spark-expr/src/conversion_funcs/cast.rs b/native/spark-expr/src/conversion_funcs/cast.rs index ff09dbe06e..a9e6888145 100644 --- a/native/spark-expr/src/conversion_funcs/cast.rs +++ b/native/spark-expr/src/conversion_funcs/cast.rs @@ -16,14 +16,15 @@ // under the License. use crate::conversion_funcs::boolean::{ - cast_boolean_to_decimal, is_df_cast_from_bool_spark_compatible, + cast_boolean_to_decimal, cast_boolean_to_timestamp, is_df_cast_from_bool_spark_compatible, }; use crate::conversion_funcs::numeric::{ - cast_float32_to_decimal128, cast_float64_to_decimal128, cast_int_to_decimal128, - cast_int_to_timestamp, is_df_cast_from_decimal_spark_compatible, - is_df_cast_from_float_spark_compatible, is_df_cast_from_int_spark_compatible, - spark_cast_decimal_to_boolean, spark_cast_float32_to_utf8, spark_cast_float64_to_utf8, - spark_cast_int_to_int, spark_cast_nonintegral_numeric_to_integral, + cast_decimal_to_timestamp, cast_float32_to_decimal128, cast_float64_to_decimal128, + cast_float_to_timestamp, cast_int_to_decimal128, cast_int_to_timestamp, + is_df_cast_from_decimal_spark_compatible, is_df_cast_from_float_spark_compatible, + is_df_cast_from_int_spark_compatible, spark_cast_decimal_to_boolean, + spark_cast_float32_to_utf8, spark_cast_float64_to_utf8, spark_cast_int_to_int, + spark_cast_nonintegral_numeric_to_integral, }; use crate::conversion_funcs::string::{ cast_string_to_date, cast_string_to_decimal, cast_string_to_float, cast_string_to_int, @@ -384,6 +385,9 @@ 
pub(crate) fn cast_array( cast_boolean_to_decimal(&array, *precision, *scale) } (Int8 | Int16 | Int32 | Int64, Timestamp(_, tz)) => cast_int_to_timestamp(&array, tz), + (Float32 | Float64, Timestamp(_, tz)) => cast_float_to_timestamp(&array, tz, eval_mode), + (Boolean, Timestamp(_, tz)) => cast_boolean_to_timestamp(&array, tz), + (Decimal128(_, scale), Timestamp(_, tz)) => cast_decimal_to_timestamp(&array, tz, *scale), _ if cast_options.is_adapting_schema || is_datafusion_spark_compatible(&from_type, to_type) => { diff --git a/native/spark-expr/src/conversion_funcs/numeric.rs b/native/spark-expr/src/conversion_funcs/numeric.rs index d204e2871c..59a65fb49f 100644 --- a/native/spark-expr/src/conversion_funcs/numeric.rs +++ b/native/spark-expr/src/conversion_funcs/numeric.rs @@ -24,7 +24,7 @@ use arrow::array::{ OffsetSizeTrait, PrimitiveArray, TimestampMicrosecondBuilder, }; use arrow::datatypes::{ - is_validate_decimal_precision, ArrowPrimitiveType, DataType, Decimal128Type, Float32Type, + i256, is_validate_decimal_precision, ArrowPrimitiveType, DataType, Decimal128Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, }; use num::{cast::AsPrimitive, ToPrimitive, Zero}; @@ -75,6 +75,56 @@ pub(crate) fn is_df_cast_from_decimal_spark_compatible(to_type: &DataType) -> bo ) } +macro_rules! 
cast_float_to_timestamp_impl { + ($array:expr, $builder:expr, $primitive_type:ty, $eval_mode:expr) => {{ + let arr = $array.as_primitive::<$primitive_type>(); + for i in 0..arr.len() { + if arr.is_null(i) { + $builder.append_null(); + } else { + let val = arr.value(i) as f64; + // Path 1: NaN/Infinity check - error says TIMESTAMP + if val.is_nan() || val.is_infinite() { + if $eval_mode == EvalMode::Ansi { + return Err(SparkError::CastInvalidValue { + value: val.to_string(), + from_type: "DOUBLE".to_string(), + to_type: "TIMESTAMP".to_string(), + }); + } + $builder.append_null(); + } else { + // Path 2: Multiply then check overflow - error says BIGINT + let micros = val * MICROS_PER_SECOND as f64; + if micros.floor() <= i64::MAX as f64 && micros.ceil() >= i64::MIN as f64 { + $builder.append_value(micros as i64); + } else { + if $eval_mode == EvalMode::Ansi { + let value_str = if micros.is_infinite() { + if micros.is_sign_positive() { + "Infinity".to_string() + } else { + "-Infinity".to_string() + } + } else if micros.is_nan() { + "NaN".to_string() + } else { + format!("{:e}", micros).to_uppercase() + "D" + }; + return Err(SparkError::CastOverFlow { + value: value_str, + from_type: "DOUBLE".to_string(), + to_type: "BIGINT".to_string(), + }); + } + $builder.append_null(); + } + } + } + } + }}; +} + macro_rules! 
cast_float_to_string { ($from:expr, $eval_mode:expr, $type:ty, $output_type:ty, $offset_type:ty) => {{ @@ -913,6 +963,57 @@ pub(crate) fn cast_int_to_timestamp( Ok(Arc::new(builder.finish().with_timezone_opt(target_tz.clone())) as ArrayRef) } +pub(crate) fn cast_decimal_to_timestamp( + array_ref: &ArrayRef, + target_tz: &Option>, + scale: i8, +) -> SparkResult { + let arr = array_ref.as_primitive::(); + let scale_factor = 10_i128.pow(scale as u32); + let mut builder = TimestampMicrosecondBuilder::with_capacity(arr.len()); + + for i in 0..arr.len() { + if arr.is_null(i) { + builder.append_null(); + } else { + let value = arr.value(i); + // Note: spark's big decimal truncates to + // long value and does not throw error (in all eval modes) + let value_256 = i256::from_i128(value); + let micros_256 = value_256 * i256::from_i128(MICROS_PER_SECOND as i128); + let ts = micros_256 / i256::from_i128(scale_factor); + builder.append_value(ts.as_i128() as i64); + } + } + + Ok(Arc::new(builder.finish().with_timezone_opt(target_tz.clone())) as ArrayRef) +} + +pub(crate) fn cast_float_to_timestamp( + array_ref: &ArrayRef, + target_tz: &Option>, + eval_mode: EvalMode, +) -> SparkResult { + let mut builder = TimestampMicrosecondBuilder::with_capacity(array_ref.len()); + + match array_ref.data_type() { + DataType::Float32 => { + cast_float_to_timestamp_impl!(array_ref, builder, Float32Type, eval_mode) + } + DataType::Float64 => { + cast_float_to_timestamp_impl!(array_ref, builder, Float64Type, eval_mode) + } + dt => { + return Err(SparkError::Internal(format!( + "Unsupported type for cast_float_to_timestamp: {:?}", + dt + ))) + } + } + + Ok(Arc::new(builder.finish().with_timezone_opt(target_tz.clone())) as ArrayRef) +} + #[cfg(test)] mod tests { use super::*; @@ -1100,4 +1201,113 @@ mod tests { assert!(casted.is_null(8)); assert!(casted.is_null(9)); } + + #[test] + fn test_cast_decimal_to_timestamp() { + let timezones: [Option>; 3] = [ + Some(Arc::from("UTC")), + 
Some(Arc::from("America/Los_Angeles")), + None, + ]; + + for tz in &timezones { + // Decimal128 with scale 6 + let decimal_array: ArrayRef = Arc::new( + Decimal128Array::from(vec![ + Some(0_i128), + Some(1_000_000_i128), + Some(-1_000_000_i128), + Some(1_500_000_i128), + Some(123_456_789_i128), + None, + ]) + .with_precision_and_scale(18, 6) + .unwrap(), + ); + + let result = cast_decimal_to_timestamp(&decimal_array, tz, 6).unwrap(); + let ts_array = result.as_primitive::(); + + assert_eq!(ts_array.value(0), 0); + assert_eq!(ts_array.value(1), 1_000_000); + assert_eq!(ts_array.value(2), -1_000_000); + assert_eq!(ts_array.value(3), 1_500_000); + assert_eq!(ts_array.value(4), 123_456_789); + assert!(ts_array.is_null(5)); + assert_eq!(ts_array.timezone(), tz.as_ref().map(|s| s.as_ref())); + + // Test with scale 2 + let decimal_array: ArrayRef = Arc::new( + Decimal128Array::from(vec![Some(100_i128), Some(150_i128), Some(-250_i128)]) + .with_precision_and_scale(10, 2) + .unwrap(), + ); + + let result = cast_decimal_to_timestamp(&decimal_array, tz, 2).unwrap(); + let ts_array = result.as_primitive::(); + + assert_eq!(ts_array.value(0), 1_000_000); + assert_eq!(ts_array.value(1), 1_500_000); + assert_eq!(ts_array.value(2), -2_500_000); + } + } + + #[test] + fn test_cast_float_to_timestamp() { + let timezones: [Option>; 3] = [ + Some(Arc::from("UTC")), + Some(Arc::from("America/Los_Angeles")), + None, + ]; + let eval_modes = [EvalMode::Legacy, EvalMode::Ansi, EvalMode::Try]; + + for tz in &timezones { + for eval_mode in &eval_modes { + // Float64 tests + let f64_array: ArrayRef = Arc::new(Float64Array::from(vec![ + Some(0.0), + Some(1.0), + Some(-1.0), + Some(1.5), + Some(0.000001), + None, + ])); + + let result = cast_float_to_timestamp(&f64_array, tz, *eval_mode).unwrap(); + let ts_array = result.as_primitive::(); + + assert_eq!(ts_array.value(0), 0); + assert_eq!(ts_array.value(1), 1_000_000); + assert_eq!(ts_array.value(2), -1_000_000); + assert_eq!(ts_array.value(3), 
1_500_000); + assert_eq!(ts_array.value(4), 1); + assert!(ts_array.is_null(5)); + assert_eq!(ts_array.timezone(), tz.as_ref().map(|s| s.as_ref())); + + // Float32 tests + let f32_array: ArrayRef = Arc::new(Float32Array::from(vec![ + Some(0.0_f32), + Some(1.0_f32), + Some(-1.0_f32), + None, + ])); + + let result = cast_float_to_timestamp(&f32_array, tz, *eval_mode).unwrap(); + let ts_array = result.as_primitive::(); + + assert_eq!(ts_array.value(0), 0); + assert_eq!(ts_array.value(1), 1_000_000); + assert_eq!(ts_array.value(2), -1_000_000); + assert!(ts_array.is_null(3)); + } + } + + // ANSI mode errors on NaN/Infinity + let tz = &Some(Arc::from("UTC")); + let f64_nan: ArrayRef = Arc::new(Float64Array::from(vec![Some(f64::NAN)])); + assert!(cast_float_to_timestamp(&f64_nan, tz, EvalMode::Ansi).is_err()); + + let f64_inf: ArrayRef = Arc::new(Float64Array::from(vec![Some(f64::INFINITY)])); + assert!(cast_float_to_timestamp(&f64_inf, tz, EvalMode::Ansi).is_err()); + } } diff --git a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala index 15dfcb2d7c..95d5366907 100644 --- a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala +++ b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala @@ -21,7 +21,7 @@ package org.apache.comet.expressions import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Expression, Literal} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{ArrayType, DataType, DataTypes, DecimalType, NullType, StructType} +import org.apache.spark.sql.types.{ArrayType, DataType, DataTypes, DecimalType, NullType, StructType, TimestampType} import org.apache.comet.CometConf import org.apache.comet.CometSparkSessionExtensions.withInfo @@ -63,16 +63,17 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim { cast: Cast, inputs: Seq[Attribute], binding: Boolean): Option[ExprOuterClass.Expr] = { + val 
cometEvalMode = evalMode(cast) cast.child match { case _: Literal => exprToProtoInternal(Literal.create(cast.eval(), cast.dataType), inputs, binding) case _ => - if (isAlwaysCastToNull(cast.child.dataType, cast.dataType, evalMode(cast))) { + if (isAlwaysCastToNull(cast.child.dataType, cast.dataType, cometEvalMode)) { exprToProtoInternal(Literal.create(null, cast.dataType), inputs, binding) } else { val childExpr = exprToProtoInternal(cast.child, inputs, binding) if (childExpr.isDefined) { - castToProto(cast, cast.timeZoneId, cast.dataType, childExpr.get, evalMode(cast)) + castToProto(cast, cast.timeZoneId, cast.dataType, childExpr.get, cometEvalMode) } else { withInfo(cast, cast.child) None @@ -165,7 +166,7 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim { case (_: DecimalType, _) => canCastFromDecimal(toType) case (DataTypes.BooleanType, _) => - canCastFromBoolean(toType) + canCastFromBoolean(toType, evalMode) case (DataTypes.ByteType, _) => canCastFromByte(toType, evalMode) case (DataTypes.ShortType, _) => @@ -282,12 +283,15 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim { } } - private def canCastFromBoolean(toType: DataType): SupportLevel = toType match { - case DataTypes.ByteType | DataTypes.ShortType | DataTypes.IntegerType | DataTypes.LongType | - DataTypes.FloatType | DataTypes.DoubleType | _: DecimalType => - Compatible() - case _ => unsupported(DataTypes.BooleanType, toType) - } + private def canCastFromBoolean(toType: DataType, evalMode: CometEvalMode.Value): SupportLevel = + toType match { + case DataTypes.ByteType | DataTypes.ShortType | DataTypes.IntegerType | DataTypes.LongType | + DataTypes.FloatType | DataTypes.DoubleType | _: DecimalType => + Compatible() + case _: TimestampType if evalMode == CometEvalMode.LEGACY => + Compatible() + case _ => unsupported(DataTypes.BooleanType, toType) + } private def canCastFromByte(toType: DataType, evalMode: CometEvalMode.Value): SupportLevel = toType match { 
@@ -357,7 +361,7 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim { private def canCastFromFloat(toType: DataType): SupportLevel = toType match { case DataTypes.BooleanType | DataTypes.DoubleType | DataTypes.ByteType | DataTypes.ShortType | - DataTypes.IntegerType | DataTypes.LongType => + DataTypes.IntegerType | DataTypes.LongType | DataTypes.TimestampType => Compatible() case _: DecimalType => // https://github.com/apache/datafusion-comet/issues/1371 @@ -368,7 +372,7 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim { private def canCastFromDouble(toType: DataType): SupportLevel = toType match { case DataTypes.BooleanType | DataTypes.FloatType | DataTypes.ByteType | DataTypes.ShortType | - DataTypes.IntegerType | DataTypes.LongType => + DataTypes.IntegerType | DataTypes.LongType | DataTypes.TimestampType => Compatible() case _: DecimalType => // https://github.com/apache/datafusion-comet/issues/1371 @@ -378,7 +382,8 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim { private def canCastFromDecimal(toType: DataType): SupportLevel = toType match { case DataTypes.FloatType | DataTypes.DoubleType | DataTypes.ByteType | DataTypes.ShortType | - DataTypes.IntegerType | DataTypes.LongType | DataTypes.BooleanType => + DataTypes.IntegerType | DataTypes.LongType | DataTypes.BooleanType | + DataTypes.TimestampType => Compatible() case _ => Unsupported(Some(s"Cast from DecimalType to $toType is not supported")) } diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala index 72c2390d71..48242a978c 100644 --- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala @@ -167,9 +167,9 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { castTest(generateBools(), DataTypes.StringType) } - ignore("cast BooleanType to TimestampType") { - // 
Arrow error: Cast error: Casting from Boolean to Timestamp(Microsecond, Some("UTC")) not supported - castTest(generateBools(), DataTypes.TimestampType) + test("cast BooleanType to TimestampType") { + // Spark does not support ANSI or Try mode for Boolean to Timestamp casts + castTest(generateBools(), DataTypes.TimestampType, testAnsi = false, testTry = false) } // CAST from ByteType @@ -504,9 +504,13 @@ castTest(withNulls(values).toDF("a"), DataTypes.StringType) } - ignore("cast FloatType to TimestampType") { - // java.lang.ArithmeticException: long overflow - castTest(generateFloats(), DataTypes.TimestampType) + test("cast FloatType to TimestampType") { + compatibleTimezones.foreach { tz => + withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> tz) { + // Use useDataFrameDiff to avoid collect() which fails on extreme timestamp values + castTest(generateFloats(), DataTypes.TimestampType, useDataFrameDiff = true) + } + } } // CAST from DoubleType @@ -560,9 +564,13 @@ castTest(withNulls(values).toDF("a"), DataTypes.StringType) } - ignore("cast DoubleType to TimestampType") { - // java.lang.ArithmeticException: long overflow - castTest(generateDoubles(), DataTypes.TimestampType) + test("cast DoubleType to TimestampType") { + compatibleTimezones.foreach { tz => + withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> tz) { + // Use useDataFrameDiff to avoid collect() which fails on extreme timestamp values + castTest(generateDoubles(), DataTypes.TimestampType, useDataFrameDiff = true) + } + } } // CAST from DecimalType(10,2) @@ -627,11 +635,14 @@ castTest(generateDecimalsPrecision10Scale2(), DataTypes.StringType) } - ignore("cast DecimalType(10,2) to TimestampType") { - // input: -123456.789000000000000000, expected: 1969-12-30 05:42:23.211, actual: 1969-12-31 
15:59:59.876544 + test("cast DecimalType(10,2) to TimestampType") { castTest(generateDecimalsPrecision10Scale2(), DataTypes.TimestampType) } + test("cast DecimalType(38,18) to TimestampType") { + castTest(generateDecimalsPrecision38Scale18(), DataTypes.TimestampType) + } + // CAST from StringType test("cast StringType to BooleanType") { @@ -1466,7 +1477,8 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { toType: DataType, hasIncompatibleType: Boolean = false, testAnsi: Boolean = true, - testTry: Boolean = true): Unit = { + testTry: Boolean = true, + useDataFrameDiff: Boolean = false): Unit = { withTempPath { dir => val data = roundtripParquet(input, dir).coalesce(1) @@ -1474,22 +1486,29 @@ withSQLConf((SQLConf.ANSI_ENABLED.key, "false")) { // cast() should return null for invalid inputs when ansi mode is disabled val df = data.select(col("a"), col("a").cast(toType)).orderBy(col("a")) - if (hasIncompatibleType) { - checkSparkAnswer(df) + if (useDataFrameDiff) { + assertDataFrameEqualsWithExceptions(df, assertCometNative = !hasIncompatibleType) } else { - checkSparkAnswerAndOperator(df) + if (hasIncompatibleType) { + checkSparkAnswer(df) + } else { + checkSparkAnswerAndOperator(df) + } } if (testTry) { data.createOrReplaceTempView("t") -// try_cast() should always return null for invalid inputs -// not using spark DSL since it `try_cast` is only available from Spark 4x - val df2 = - spark.sql(s"select a, try_cast(a as ${toType.sql}) from t order by a") + // try_cast() should always return null for invalid inputs + // not using spark DSL since `try_cast` is only available from Spark 4x + val df2 = spark.sql(s"select a, try_cast(a as ${toType.sql}) from t order by a") if (hasIncompatibleType) { checkSparkAnswer(df2) } else { - checkSparkAnswerAndOperator(df2) + if (useDataFrameDiff) { + assertDataFrameEqualsWithExceptions(df2, assertCometNative = 
!hasIncompatibleType) + } else { + checkSparkAnswerAndOperator(df2) + } } } } @@ -1502,7 +1521,12 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { // cast() should throw exception on invalid inputs when ansi mode is enabled val df = data.withColumn("converted", col("a").cast(toType)) - checkSparkAnswerMaybeThrows(df) match { + val res = if (useDataFrameDiff) { + assertDataFrameEqualsWithExceptions(df, assertCometNative = !hasIncompatibleType) + } else { + checkSparkAnswerMaybeThrows(df) + } + res match { case (None, None) => // neither system threw an exception case (None, Some(e)) => @@ -1546,12 +1570,15 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } } + } - // try_cast() should always return null for invalid inputs - if (testTry) { - data.createOrReplaceTempView("t") - val df2 = - spark.sql(s"select a, try_cast(a as ${toType.sql}) from t order by a") + // try_cast() should always return null for invalid inputs + if (testTry) { + data.createOrReplaceTempView("t") + val df2 = spark.sql(s"select a, try_cast(a as ${toType.sql}) from t order by a") + if (useDataFrameDiff) { + assertDataFrameEqualsWithExceptions(df2, assertCometNative = !hasIncompatibleType) + } else { if (hasIncompatibleType) { checkSparkAnswer(df2) } else { diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala index 41080ed9e9..f831d53bfe 100644 --- a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala +++ b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala @@ -1276,4 +1276,53 @@ abstract class CometTestBase !usingLegacyNativeCometScan(conf) && CometConf.COMET_PARQUET_UNSIGNED_SMALL_INT_CHECK.get(conf) } + + /** + * Compares Spark and Comet results using foreach() and exceptAll() to avoid collect() + */ + protected def assertDataFrameEqualsWithExceptions( + df: => DataFrame, + assertCometNative: Boolean = true): (Option[Throwable], 
Option[Throwable]) = { + + var expected: Try[Unit] = null + withSQLConf(CometConf.COMET_ENABLED.key -> "false") { + expected = Try(datasetOfRows(spark, df.logicalPlan).foreach(_ => ())) + } + val actual = Try(datasetOfRows(spark, df.logicalPlan).foreach(_ => ())) + + (expected, actual) match { + case (Success(_), Success(_)) => + // compare results and confirm that they match + var dfSpark: DataFrame = null + withSQLConf(CometConf.COMET_ENABLED.key -> "false") { + dfSpark = datasetOfRows(spark, df.logicalPlan) + } + val dfComet = datasetOfRows(spark, df.logicalPlan) + + // Compare schemas + assert( + dfSpark.schema == dfComet.schema, + s"Schema mismatch:\nSpark: ${dfSpark.schema}\nComet: ${dfComet.schema}") + + val sparkMinusComet = dfSpark.exceptAll(dfComet) + val cometMinusSpark = dfComet.exceptAll(dfSpark) + val diffCount1 = sparkMinusComet.count() + val diffCount2 = cometMinusSpark.count() + + if (diffCount1 > 0 || diffCount2 > 0) { + fail( + "Results do not match. " + + s"Rows in Spark but not Comet: $diffCount1. " + + s"Rows in Comet but not Spark: $diffCount2.") + } + + if (assertCometNative) { + checkCometOperators(stripAQEPlan(dfComet.queryExecution.executedPlan)) + } + + (None, None) + case _ => + (expected.failed.toOption, actual.failed.toOption) + } + } }