diff --git a/native/shuffle/src/spark_unsafe/list.rs b/native/shuffle/src/spark_unsafe/list.rs index 3fea3fadeb..1a7da0b32f 100644 --- a/native/shuffle/src/spark_unsafe/list.rs +++ b/native/shuffle/src/spark_unsafe/list.rs @@ -24,7 +24,8 @@ use arrow::array::{ builder::{ ArrayBuilder, BinaryBuilder, BooleanBuilder, Date32Builder, Decimal128Builder, Float32Builder, Float64Builder, Int16Builder, Int32Builder, Int64Builder, Int8Builder, - ListBuilder, StringBuilder, StructBuilder, TimestampMicrosecondBuilder, + ListBuilder, StringBuilder, StructBuilder, Time64NanosecondBuilder, + TimestampMicrosecondBuilder, }, MapBuilder, }; @@ -179,6 +180,7 @@ impl SparkUnsafeArray { impl_append_to_builder!(append_ints_to_builder, Int32Builder, i32); impl_append_to_builder!(append_longs_to_builder, Int64Builder, i64); + impl_append_to_builder!(append_time64s_to_builder, Time64NanosecondBuilder, i64); impl_append_to_builder!(append_shorts_to_builder, Int16Builder, i16); impl_append_to_builder!(append_bytes_to_builder, Int8Builder, i8); impl_append_to_builder!(append_floats_to_builder, Float32Builder, f32); diff --git a/native/shuffle/src/spark_unsafe/row.rs b/native/shuffle/src/spark_unsafe/row.rs index ec0903bc56..25ff7c91b4 100644 --- a/native/shuffle/src/spark_unsafe/row.rs +++ b/native/shuffle/src/spark_unsafe/row.rs @@ -29,7 +29,7 @@ use arrow::array::{ ArrayBuilder, BinaryBuilder, BinaryDictionaryBuilder, BooleanBuilder, Date32Builder, Decimal128Builder, Float32Builder, Float64Builder, Int16Builder, Int32Builder, Int64Builder, Int8Builder, ListBuilder, MapBuilder, StringBuilder, StringDictionaryBuilder, - StructBuilder, TimestampMicrosecondBuilder, + StructBuilder, Time64NanosecondBuilder, TimestampMicrosecondBuilder, }, types::Int32Type, Array, ArrayRef, RecordBatch, RecordBatchOptions, @@ -274,6 +274,12 @@ pub(super) fn append_field( .append_value(row.get_timestamp(idx)) ); } + DataType::Time64(TimeUnit::Nanosecond) => { + append_field_to_builder!( + Time64NanosecondBuilder, + |builder: &mut Time64NanosecondBuilder| builder.append_value(row.get_long(idx)) + ); + } DataType::Binary => { append_field_to_builder!(BinaryBuilder, |builder: &mut BinaryBuilder| builder .append_value(row.get_binary(idx))); @@ -435,6 +441,13 @@ fn append_nested_struct_fields_field_major( |row: &SparkUnsafeRow, idx| row.get_timestamp(idx) ); } + DataType::Time64(TimeUnit::Nanosecond) => { + process_field!( + Time64NanosecondBuilder, + field_idx, + |row: &SparkUnsafeRow, idx| row.get_long(idx) + ); + } DataType::Binary => { let field_builder = get_field_builder!(struct_builder, BinaryBuilder, field_idx); @@ -652,6 +665,9 @@ fn append_list_column_batch( DataType::Timestamp(TimeUnit::Microsecond, _) => { process_primitive_lists!(TimestampMicrosecondBuilder, append_timestamps_to_builder); } + DataType::Time64(TimeUnit::Nanosecond) => { + process_primitive_lists!(Time64NanosecondBuilder, append_time64s_to_builder); + } // For complex element types, fall back to per-row dispatch _ => { for i in row_start..row_end { @@ -876,6 +892,13 @@ fn append_struct_fields_field_major( |row: &SparkUnsafeRow, idx| row.get_timestamp(idx) ); } + DataType::Time64(TimeUnit::Nanosecond) => { + process_field!( + Time64NanosecondBuilder, + field_idx, + |row: &SparkUnsafeRow, idx| row.get_long(idx) + ); + } DataType::Binary => { let field_builder = get_field_builder!(struct_builder, BinaryBuilder, field_idx); @@ -1155,6 +1178,13 @@ fn append_columns( .append_value(row.get_timestamp(idx)) ); } + DataType::Time64(TimeUnit::Nanosecond) => { + append_column_to_builder!( + Time64NanosecondBuilder, + |builder: &mut Time64NanosecondBuilder, row: &SparkUnsafeRow, idx| builder + .append_value(row.get_long(idx)) + ); + } DataType::Map(field, _) => { let map_builder = downcast_builder_ref!( MapBuilder, Box>, @@ -1255,6 +1285,9 @@ fn make_builders( DataType::Timestamp(TimeUnit::Microsecond, _) => { Box::new(TimestampMicrosecondBuilder::with_capacity(row_num).with_data_type(dt.clone())) } + DataType::Time64(TimeUnit::Nanosecond) => { + Box::new(Time64NanosecondBuilder::with_capacity(row_num)) + } DataType::Map(field, _) => { let (key_field, value_field, map_field_names) = get_map_key_value_fields(field)?; let key_dt = key_field.data_type(); diff --git a/native/spark-expr/src/hash_funcs/utils.rs b/native/spark-expr/src/hash_funcs/utils.rs index d634b27bf8..979f787b59 100644 --- a/native/spark-expr/src/hash_funcs/utils.rs +++ b/native/spark-expr/src/hash_funcs/utils.rs @@ -380,6 +380,10 @@ macro_rules! hash_list_with_primitive_elements { let elem_array = $values.as_any().downcast_ref::().unwrap(); $crate::hash_list_primitive!(offsets: $offsets, $list_array, elem_array, $hashes_buffer, $hash_method, |v: i64| v.to_le_bytes()); } + DataType::Time64(TimeUnit::Nanosecond) => { + let elem_array = $values.as_any().downcast_ref::().unwrap(); + $crate::hash_list_primitive!(offsets: $offsets, $list_array, elem_array, $hashes_buffer, $hash_method, |v: i64| v.to_le_bytes()); + } _ => { // Fall back to recursive approach for complex element types $crate::hash_list_array!($list_array_type, $fallback_offset_type, $col, $hashes_buffer, $recursive_hash_method); @@ -471,6 +475,10 @@ macro_rules! hash_list_with_primitive_elements { let elem_array = $values.as_any().downcast_ref::().unwrap(); $crate::hash_list_primitive!(fixed_size: $list_size, $list_array, elem_array, $hashes_buffer, $hash_method, |v: i64| v.to_le_bytes()); } + DataType::Time64(TimeUnit::Nanosecond) => { + let elem_array = $values.as_any().downcast_ref::().unwrap(); + $crate::hash_list_primitive!(fixed_size: $list_size, $list_array, elem_array, $hashes_buffer, $hash_method, |v: i64| v.to_le_bytes()); + } _ => { // Fall back to recursive approach for complex element types if $list_array.null_count() == 0 { @@ -689,6 +697,15 @@ macro_rules! create_hashes_internal { $hash_method ); } + DataType::Time64(TimeUnit::Nanosecond) => { + $crate::hash_array_primitive!( + Time64NanosecondArray, + col, + i64, + $hashes_buffer, + $hash_method + ); + } DataType::Utf8 => { $crate::hash_array!(StringArray, col, $hashes_buffer, $hash_method); } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala index d4ee4e4ccf..2ceac51139 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala @@ -369,6 +369,8 @@ object CometShuffleExchangeExec _: FloatType | _: DoubleType | _: StringType | _: BinaryType | _: TimestampType | _: TimestampNTZType | _: DecimalType | _: DateType => true + case dt if isTimeType(dt) => + true case StructType(fields) => fields.nonEmpty && fields.forall(f => supportedSerializableDataType(f.dataType)) case ArrayType(elementType, _) => @@ -492,6 +494,8 @@ object CometShuffleExchangeExec _: FloatType | _: DoubleType | _: StringType | _: BinaryType | _: TimestampType | _: TimestampNTZType | _: DecimalType | _: DateType => true + case dt if isTimeType(dt) => + true case StructType(fields) => fields.nonEmpty && fields.forall(f => supportedSerializableDataType(f.dataType)) && // Java Arrow stream reader cannot work on duplicate field name diff --git a/spark/src/test/resources/sql-tests/expressions/datetime/shuffle_time.sql b/spark/src/test/resources/sql-tests/expressions/datetime/shuffle_time.sql new file mode 100644 index 0000000000..318f5c1738 --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/datetime/shuffle_time.sql @@ -0,0 +1,40 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +-- MinSparkVersion: 4.1 +-- Config: spark.sql.timeType.enabled=true +-- Config: spark.comet.native.shuffle.partitioning.roundrobin.enabled=true +-- ConfigMatrix: spark.comet.exec.shuffle.mode=native,jvm + +statement +CREATE TABLE test_time_shuffle(hours int, minutes int, secs decimal(16,6)) USING parquet + +statement +INSERT INTO test_time_shuffle VALUES (12, 30, 45.123456), (0, 0, 0.0), (23, 59, 59.999999), (1, 2, 3.5), + (NULL, 0, 0.0), (NULL, NULL, NULL) + +query +SELECT /*+ REPARTITION(3) */ make_time(hours, minutes, secs) AS t FROM test_time_shuffle + +query +SELECT /*+ REPARTITION(3) */ hours, make_time(hours, minutes, secs) AS t, secs FROM test_time_shuffle + +query +SELECT /*+ REPARTITION(3) */ named_struct('t', make_time(hours, minutes, secs)) AS s FROM test_time_shuffle + +query +SELECT /*+ REPARTITION(3) */ array(make_time(hours, minutes, secs)) AS arr FROM test_time_shuffle