From 169833ea770a76f2be30c0881924918dfa4869a2 Mon Sep 17 00:00:00 2001
From: 0lai0
Date: Wed, 11 Mar 2026 23:16:39 +0800
Subject: [PATCH 1/2] Native engine crashes on literal DateTrunc and
 TimestampTrunc

---
 .../src/datetime_funcs/date_trunc.rs          | 24 ++++++++++++-----
 .../src/datetime_funcs/timestamp_trunc.rs     | 26 +++++++++++++------
 .../comet/CometTemporalExpressionSuite.scala  | 14 ++++++++++
 3 files changed, 49 insertions(+), 15 deletions(-)

diff --git a/native/spark-expr/src/datetime_funcs/date_trunc.rs b/native/spark-expr/src/datetime_funcs/date_trunc.rs
index 6d36b0975c..2454f31cc9 100644
--- a/native/spark-expr/src/datetime_funcs/date_trunc.rs
+++ b/native/spark-expr/src/datetime_funcs/date_trunc.rs
@@ -67,18 +67,28 @@ impl ScalarUDFImpl for SparkDateTrunc {
     fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
         let [date, format] = take_function_args(self.name(), args.args)?;
-        match (date, format) {
-            (ColumnarValue::Array(date), ColumnarValue::Scalar(Utf8(Some(format)))) => {
-                let result = date_trunc_dyn(&date, format)?;
+        let num_rows = [&date, &format]
+            .iter()
+            .find_map(|arg| match arg {
+                ColumnarValue::Array(array) => Some(array.len()),
+                ColumnarValue::Scalar(_) => None,
+            })
+            .unwrap_or(1);
+
+        let date_arr = date.into_array(num_rows)?;
+
+        match format {
+            ColumnarValue::Scalar(Utf8(Some(format))) => {
+                let result = date_trunc_dyn(&date_arr, format)?;
                 Ok(ColumnarValue::Array(result))
             }
-            (ColumnarValue::Array(date), ColumnarValue::Array(formats)) => {
-                let result = date_trunc_array_fmt_dyn(&date, &formats)?;
+            ColumnarValue::Array(formats) => {
+                let result = date_trunc_array_fmt_dyn(&date_arr, &formats)?;
                 Ok(ColumnarValue::Array(result))
             }
             _ => Err(DataFusionError::Execution(
-                "Invalid input to function DateTrunc. Expected (PrimitiveArray, Scalar) or \
-                    (PrimitiveArray, StringArray)".to_string(),
+                "Invalid format input to function DateTrunc. \
+                    Expected Scalar or StringArray"
+                    .to_string(),
             )),
         }
     }
diff --git a/native/spark-expr/src/datetime_funcs/timestamp_trunc.rs b/native/spark-expr/src/datetime_funcs/timestamp_trunc.rs
index 1a35f02e07..c748c4e2e1 100644
--- a/native/spark-expr/src/datetime_funcs/timestamp_trunc.rs
+++ b/native/spark-expr/src/datetime_funcs/timestamp_trunc.rs
@@ -111,19 +111,30 @@ impl PhysicalExpr for TimestampTruncExpr {
         let timestamp = self.child.evaluate(batch)?;
         let format = self.format.evaluate(batch)?;
         let tz = self.timezone.clone();
-        match (timestamp, format) {
-            (ColumnarValue::Array(ts), ColumnarValue::Scalar(Utf8(Some(format)))) => {
+
+        let num_rows = match &timestamp {
+            ColumnarValue::Array(array) => array.len(),
+            ColumnarValue::Scalar(_) => match &format {
+                ColumnarValue::Array(array) => array.len(),
+                ColumnarValue::Scalar(_) => 1,
+            },
+        };
+
+        let ts_arr = timestamp.into_array(num_rows)?;
+
+        match format {
+            ColumnarValue::Scalar(Utf8(Some(format))) => {
                 let ts = array_with_timezone(
-                    ts,
+                    ts_arr,
                     tz.clone(),
                     Some(&DataType::Timestamp(Microsecond, Some(tz.into()))),
                 )?;
                 let result = timestamp_trunc_dyn(&ts, format)?;
                 Ok(ColumnarValue::Array(result))
             }
-            (ColumnarValue::Array(ts), ColumnarValue::Array(formats)) => {
+            ColumnarValue::Array(formats) => {
                 let ts = array_with_timezone(
-                    ts,
+                    ts_arr,
                     tz.clone(),
                     Some(&DataType::Timestamp(Microsecond, Some(tz.into()))),
                 )?;
@@ -131,9 +142,8 @@ impl PhysicalExpr for TimestampTruncExpr {
                 Ok(ColumnarValue::Array(result))
             }
             _ => Err(DataFusionError::Execution(
-                "Invalid input to function TimestampTrunc. \
-                    Expected (PrimitiveArray, Scalar, String) or \
-                    (PrimitiveArray, StringArray, String)"
+                "Invalid format input to function TimestampTrunc. \
+                    Expected Scalar or StringArray"
                     .to_string(),
             )),
         }
diff --git a/spark/src/test/scala/org/apache/comet/CometTemporalExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometTemporalExpressionSuite.scala
index 1ae6926e05..2f14fa1d6c 100644
--- a/spark/src/test/scala/org/apache/comet/CometTemporalExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometTemporalExpressionSuite.scala
@@ -58,6 +58,13 @@ class CometTemporalExpressionSuite extends CometTestBase with AdaptiveSparkPlanH
     checkSparkAnswerAndFallbackReason(
       "SELECT c0, trunc(c0, c1) from tbl order by c0, c1",
       "Invalid format strings will throw an exception instead of returning NULL")
+
+    // Disable constant folding to ensure literal expressions are executed by Comet
+    withSQLConf(
+      SQLConf.OPTIMIZER_EXCLUDED_RULES.key ->
+        "org.apache.spark.sql.catalyst.optimizer.ConstantFolding") {
+      checkSparkAnswerAndOperator("SELECT trunc(DATE('2024-06-15'), 'year')")
+    }
   }
 
   test("date_trunc (TruncTimestamp) - reading from DataFrame") {
@@ -82,6 +89,13 @@ class CometTemporalExpressionSuite extends CometTestBase with AdaptiveSparkPlanH
     checkSparkAnswerAndFallbackReason(
       "SELECT c0, date_trunc(fmt, c0) from tbl order by c0, fmt",
       "Invalid format strings will throw an exception instead of returning NULL")
+
+    // Disable constant folding to ensure literal expressions are executed by Comet
+    withSQLConf(
+      SQLConf.OPTIMIZER_EXCLUDED_RULES.key ->
+        "org.apache.spark.sql.catalyst.optimizer.ConstantFolding") {
+      checkSparkAnswerAndOperator("SELECT date_trunc('year', TIMESTAMP '2024-06-15 10:30:45')")
+    }
   }
 }

From 7209dee98caff951926d9ea1af455e574c5c378f Mon Sep 17 00:00:00 2001
From: 0lai0
Date: Thu, 12 Mar 2026 10:40:05 +0800
Subject: [PATCH 2/2] address comment

---
 .../src/datetime_funcs/date_trunc.rs          | 33 ++++++++--------
 .../src/datetime_funcs/timestamp_trunc.rs     | 38 +++++++++----------
 2 files changed, 34 insertions(+), 37 deletions(-)

diff --git a/native/spark-expr/src/datetime_funcs/date_trunc.rs b/native/spark-expr/src/datetime_funcs/date_trunc.rs
index 2454f31cc9..aeae18e36f 100644
--- a/native/spark-expr/src/datetime_funcs/date_trunc.rs
+++ b/native/spark-expr/src/datetime_funcs/date_trunc.rs
@@ -16,7 +16,9 @@
 // under the License.
 
 use arrow::datatypes::DataType;
-use datafusion::common::{utils::take_function_args, DataFusionError, Result, ScalarValue::Utf8};
+use datafusion::common::{
+    utils::take_function_args, DataFusionError, Result, ScalarValue, ScalarValue::Utf8,
+};
 use datafusion::logical_expr::{
     ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
 };
@@ -67,28 +69,23 @@ impl ScalarUDFImpl for SparkDateTrunc {
     fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
         let [date, format] = take_function_args(self.name(), args.args)?;
-        let num_rows = [&date, &format]
-            .iter()
-            .find_map(|arg| match arg {
-                ColumnarValue::Array(array) => Some(array.len()),
-                ColumnarValue::Scalar(_) => None,
-            })
-            .unwrap_or(1);
-
-        let date_arr = date.into_array(num_rows)?;
-
-        match format {
-            ColumnarValue::Scalar(Utf8(Some(format))) => {
-                let result = date_trunc_dyn(&date_arr, format)?;
+        match (date, format) {
+            (ColumnarValue::Array(date), ColumnarValue::Scalar(Utf8(Some(format)))) => {
+                let result = date_trunc_dyn(&date, format)?;
                 Ok(ColumnarValue::Array(result))
             }
-            ColumnarValue::Array(formats) => {
-                let result = date_trunc_array_fmt_dyn(&date_arr, &formats)?;
+            (ColumnarValue::Array(date), ColumnarValue::Array(formats)) => {
+                let result = date_trunc_array_fmt_dyn(&date, &formats)?;
                 Ok(ColumnarValue::Array(result))
             }
+            (ColumnarValue::Scalar(date_scalar), ColumnarValue::Scalar(Utf8(Some(format)))) => {
+                let date_arr = date_scalar.to_array()?;
+                let result = date_trunc_dyn(&date_arr, format)?;
+                let scalar = ScalarValue::try_from_array(&result, 0)?;
+                Ok(ColumnarValue::Scalar(scalar))
+            }
             _ => Err(DataFusionError::Execution(
-                "Invalid format input to function DateTrunc. \
-                    Expected Scalar or StringArray"
-                    .to_string(),
+                "Invalid input to function DateTrunc. Expected (Date32, Utf8)".to_string(),
             )),
         }
     }
diff --git a/native/spark-expr/src/datetime_funcs/timestamp_trunc.rs b/native/spark-expr/src/datetime_funcs/timestamp_trunc.rs
index c748c4e2e1..2d7a571b76 100644
--- a/native/spark-expr/src/datetime_funcs/timestamp_trunc.rs
+++ b/native/spark-expr/src/datetime_funcs/timestamp_trunc.rs
@@ -18,7 +18,7 @@
 use crate::utils::array_with_timezone;
 use arrow::datatypes::{DataType, Schema, TimeUnit::Microsecond};
 use arrow::record_batch::RecordBatch;
-use datafusion::common::{DataFusionError, ScalarValue::Utf8};
+use datafusion::common::{DataFusionError, ScalarValue, ScalarValue::Utf8};
 use datafusion::logical_expr::ColumnarValue;
 use datafusion::physical_expr::PhysicalExpr;
 use std::hash::Hash;
@@ -111,39 +111,39 @@ impl PhysicalExpr for TimestampTruncExpr {
         let timestamp = self.child.evaluate(batch)?;
         let format = self.format.evaluate(batch)?;
         let tz = self.timezone.clone();
-
-        let num_rows = match &timestamp {
-            ColumnarValue::Array(array) => array.len(),
-            ColumnarValue::Scalar(_) => match &format {
-                ColumnarValue::Array(array) => array.len(),
-                ColumnarValue::Scalar(_) => 1,
-            },
-        };
-
-        let ts_arr = timestamp.into_array(num_rows)?;
-
-        match format {
-            ColumnarValue::Scalar(Utf8(Some(format))) => {
+        match (timestamp, format) {
+            (ColumnarValue::Array(ts), ColumnarValue::Scalar(Utf8(Some(format)))) => {
                 let ts = array_with_timezone(
-                    ts_arr,
+                    ts,
                     tz.clone(),
                     Some(&DataType::Timestamp(Microsecond, Some(tz.into()))),
                 )?;
                 let result = timestamp_trunc_dyn(&ts, format)?;
                 Ok(ColumnarValue::Array(result))
             }
-            ColumnarValue::Array(formats) => {
+            (ColumnarValue::Array(ts), ColumnarValue::Array(formats)) => {
                 let ts = array_with_timezone(
-                    ts_arr,
+                    ts,
                     tz.clone(),
                     Some(&DataType::Timestamp(Microsecond, Some(tz.into()))),
                 )?;
                 let result = timestamp_trunc_array_fmt_dyn(&ts, &formats)?;
                 Ok(ColumnarValue::Array(result))
             }
+            (ColumnarValue::Scalar(ts_scalar), ColumnarValue::Scalar(Utf8(Some(format)))) => {
+                let ts_arr = ts_scalar.to_array()?;
+                let ts = array_with_timezone(
+                    ts_arr,
+                    tz.clone(),
+                    Some(&DataType::Timestamp(Microsecond, Some(tz.into()))),
+                )?;
+                let result = timestamp_trunc_dyn(&ts, format)?;
+                let scalar = ScalarValue::try_from_array(&result, 0)?;
+                Ok(ColumnarValue::Scalar(scalar))
+            }
             _ => Err(DataFusionError::Execution(
-                "Invalid format input to function TimestampTrunc. \
-                    Expected Scalar or StringArray"
+                "Invalid input to function TimestampTrunc. \
+                    Expected (Timestamp, Utf8)"
                     .to_string(),
             )),
        }