diff --git a/native/spark-expr/src/datetime_funcs/date_trunc.rs b/native/spark-expr/src/datetime_funcs/date_trunc.rs index 6d36b0975c..aeae18e36f 100644 --- a/native/spark-expr/src/datetime_funcs/date_trunc.rs +++ b/native/spark-expr/src/datetime_funcs/date_trunc.rs @@ -16,7 +16,9 @@ // under the License. use arrow::datatypes::DataType; -use datafusion::common::{utils::take_function_args, DataFusionError, Result, ScalarValue::Utf8}; +use datafusion::common::{ + utils::take_function_args, DataFusionError, Result, ScalarValue, ScalarValue::Utf8, +}; use datafusion::logical_expr::{ ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, }; @@ -76,9 +78,14 @@ impl ScalarUDFImpl for SparkDateTrunc { let result = date_trunc_array_fmt_dyn(&date, &formats)?; Ok(ColumnarValue::Array(result)) } + (ColumnarValue::Scalar(date_scalar), ColumnarValue::Scalar(Utf8(Some(format)))) => { + let date_arr = date_scalar.to_array()?; + let result = date_trunc_dyn(&date_arr, format)?; + let scalar = ScalarValue::try_from_array(&result, 0)?; + Ok(ColumnarValue::Scalar(scalar)) + } _ => Err(DataFusionError::Execution( - "Invalid input to function DateTrunc. Expected (PrimitiveArray, Scalar) or \ - (PrimitiveArray, StringArray)".to_string(), + "Invalid input to function DateTrunc. Expected (Date32, Utf8)".to_string(), )), } } diff --git a/native/spark-expr/src/datetime_funcs/timestamp_trunc.rs b/native/spark-expr/src/datetime_funcs/timestamp_trunc.rs index 1a35f02e07..2d7a571b76 100644 --- a/native/spark-expr/src/datetime_funcs/timestamp_trunc.rs +++ b/native/spark-expr/src/datetime_funcs/timestamp_trunc.rs @@ -18,7 +18,7 @@ use crate::utils::array_with_timezone; use arrow::datatypes::{DataType, Schema, TimeUnit::Microsecond}; use arrow::record_batch::RecordBatch; -use datafusion::common::{DataFusionError, ScalarValue::Utf8}; +use datafusion::common::{DataFusionError, ScalarValue, ScalarValue::Utf8}; use datafusion::logical_expr::ColumnarValue; use datafusion::physical_expr::PhysicalExpr; use std::hash::Hash; @@ -130,10 +130,20 @@ impl PhysicalExpr for TimestampTruncExpr { let result = timestamp_trunc_array_fmt_dyn(&ts, &formats)?; Ok(ColumnarValue::Array(result)) } + (ColumnarValue::Scalar(ts_scalar), ColumnarValue::Scalar(Utf8(Some(format)))) => { + let ts_arr = ts_scalar.to_array()?; + let ts = array_with_timezone( + ts_arr, + tz.clone(), + Some(&DataType::Timestamp(Microsecond, Some(tz.into()))), + )?; + let result = timestamp_trunc_dyn(&ts, format)?; + let scalar = ScalarValue::try_from_array(&result, 0)?; + Ok(ColumnarValue::Scalar(scalar)) + } _ => Err(DataFusionError::Execution( "Invalid input to function TimestampTrunc. \ - Expected (PrimitiveArray, Scalar, String) or \ - (PrimitiveArray, StringArray, String)" + Expected (Timestamp, Utf8)" .to_string(), )), } diff --git a/spark/src/test/scala/org/apache/comet/CometTemporalExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometTemporalExpressionSuite.scala index 1ae6926e05..2f14fa1d6c 100644 --- a/spark/src/test/scala/org/apache/comet/CometTemporalExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometTemporalExpressionSuite.scala @@ -58,6 +58,13 @@ class CometTemporalExpressionSuite extends CometTestBase with AdaptiveSparkPlanH checkSparkAnswerAndFallbackReason( "SELECT c0, trunc(c0, c1) from tbl order by c0, c1", "Invalid format strings will throw an exception instead of returning NULL") + + // Disable constant folding to ensure literal expressions are executed by Comet + withSQLConf( + SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> + "org.apache.spark.sql.catalyst.optimizer.ConstantFolding") { + checkSparkAnswerAndOperator("SELECT trunc(DATE('2024-06-15'), 'year')") + } } test("date_trunc (TruncTimestamp) - reading from DataFrame") { @@ -82,6 +89,13 @@ class CometTemporalExpressionSuite extends CometTestBase with AdaptiveSparkPlanH checkSparkAnswerAndFallbackReason( "SELECT c0, date_trunc(fmt, c0) from tbl order by c0, fmt", "Invalid format strings will throw an exception instead of returning NULL") + + // Disable constant folding to ensure literal expressions are executed by Comet + withSQLConf( + SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> + "org.apache.spark.sql.catalyst.optimizer.ConstantFolding") { + checkSparkAnswerAndOperator("SELECT date_trunc('year', TIMESTAMP '2024-06-15 10:30:45')") + } } }