diff --git a/spark/src/main/scala/org/apache/comet/serde/datetime.scala b/spark/src/main/scala/org/apache/comet/serde/datetime.scala index 3192697b26..bb0a57632f 100644 --- a/spark/src/main/scala/org/apache/comet/serde/datetime.scala +++ b/spark/src/main/scala/org/apache/comet/serde/datetime.scala @@ -26,11 +26,13 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DateType, DoubleType, FloatType, IntegerType, LongType, StringType, TimestampNTZType, TimestampType} import org.apache.spark.unsafe.types.UTF8String +import org.apache.comet.CometConf import org.apache.comet.CometSparkSessionExtensions.withInfo import org.apache.comet.expressions.{CometCast, CometEvalMode} import org.apache.comet.serde.CometGetDateField.CometGetDateField import org.apache.comet.serde.ExprOuterClass.Expr import org.apache.comet.serde.QueryPlanSerde._ +import org.apache.comet.udf.DateFormatUDF private object CometGetDateField extends Enumeration { type CometGetDateField = Value @@ -572,17 +574,21 @@ object CometTruncTimestamp extends CometExpressionSerde[TruncTimestamp] { } /** - * Converts Spark DateFormatClass expression to DataFusion's to_char function. + * Converts Spark DateFormatClass expression to DataFusion's to_char function when the format and + * timezone are mappable; otherwise emits a JvmScalarUdf that delegates to Spark's own + * `DateFormatClass` so that any format / timezone combination remains supported. * - * Spark uses Java SimpleDateFormat patterns while DataFusion uses strftime patterns. This - * implementation supports a whitelist of common format strings that can be reliably mapped - * between the two systems. 
+ * Routing: + * - format is a literal in `supportedFormats` AND timezone is UTC -> native to_char + * - format is a literal in `supportedFormats` AND timezone is non-UTC, with the per-expression + * allowIncompatible flag set -> native to_char (results may differ from Spark) + * - all other cases -> JVM UDF (`org.apache.comet.udf.DateFormatUDF`) */ object CometDateFormat extends CometExpressionSerde[DateFormatClass] { /** * Mapping from Spark SimpleDateFormat patterns to strftime patterns. Only formats in this map - * are supported. + * are supported by the native path. */ val supportedFormats: Map[String, String] = Map( // Full date formats @@ -616,67 +622,70 @@ object CometDateFormat extends CometExpressionSerde[DateFormatClass] { // ISO formats "yyyy-MM-dd'T'HH:mm:ss" -> "%Y-%m-%dT%H:%M:%S") - override def getIncompatibleReasons(): Seq[String] = Seq( - "Non-UTC timezones may produce different results than Spark") - - override def getUnsupportedReasons(): Seq[String] = Seq( - "Only the following formats are supported:" + - supportedFormats.keys.toSeq.sorted - .map(k => s"`$k`") - .mkString("\n - ", "\n - ", "")) + // The JVM UDF covers every case that the native path cannot, so the expression is always + // emittable. Compatibility decisions happen inside `convert`. 
+ override def getSupportLevel(expr: DateFormatClass): SupportLevel = Compatible() - override def getSupportLevel(expr: DateFormatClass): SupportLevel = { - // Check timezone - only UTC is fully compatible + override def convert( + expr: DateFormatClass, + inputs: Seq[Attribute], + binding: Boolean): Option[ExprOuterClass.Expr] = { val timezone = expr.timeZoneId.getOrElse("UTC") val isUtc = timezone == "UTC" || timezone == "Etc/UTC" - expr.right match { - case Literal(fmt: UTF8String, _) => - val format = fmt.toString - if (supportedFormats.contains(format)) { - if (isUtc) { - Compatible() - } else { - Incompatible(Some(s"Non-UTC timezone '$timezone' may produce different results")) - } - } else { - Unsupported( - Some( - s"Format '$format' is not supported. Supported formats: " + - supportedFormats.keys.mkString(", "))) - } - case _ => - Unsupported(Some("Only literal format strings are supported")) + val nativeFormat: Option[String] = expr.right match { + case Literal(fmt: UTF8String, _) => supportedFormats.get(fmt.toString) + case _ => None + } + + val canUseNative = nativeFormat.isDefined && { + isUtc || CometConf.isExprAllowIncompat(getExprConfigName(expr)) + } + + if (canUseNative) { + val childExpr = exprToProtoInternal(expr.left, inputs, binding) + val formatExpr = exprToProtoInternal(Literal(nativeFormat.get), inputs, binding) + val optExpr = scalarFunctionExprToProtoWithReturnType( + "to_char", + StringType, + false, + childExpr, + formatExpr) + optExprWithInfo(optExpr, expr, expr.left, expr.right) + } else { + convertViaJvmUdf(expr, timezone, inputs, binding) } } - override def convert( + private def convertViaJvmUdf( expr: DateFormatClass, + timezone: String, inputs: Seq[Attribute], binding: Boolean): Option[ExprOuterClass.Expr] = { - // Get the format string - must be a literal for us to map it - val strftimeFormat = expr.right match { - case Literal(fmt: UTF8String, _) => - supportedFormats.get(fmt.toString) - case _ => None + val tsProto = 
exprToProtoInternal(expr.left, inputs, binding) + val fmtProto = exprToProtoInternal(expr.right, inputs, binding) + val tzProto = exprToProtoInternal(Literal(timezone), inputs, binding) + if (tsProto.isEmpty || fmtProto.isEmpty || tzProto.isEmpty) { + withInfo(expr, expr.left, expr.right) + return None } - - strftimeFormat match { - case Some(format) => - val childExpr = exprToProtoInternal(expr.left, inputs, binding) - val formatExpr = exprToProtoInternal(Literal(format), inputs, binding) - - val optExpr = scalarFunctionExprToProtoWithReturnType( - "to_char", - StringType, - false, - childExpr, - formatExpr) - optExprWithInfo(optExpr, expr, expr.left, expr.right) - case None => - withInfo(expr, expr.left, expr.right) - None + val returnType = serializeDataType(StringType).getOrElse { + withInfo(expr, "Failed to serialize StringType return type") + return None } + val udfBuilder = ExprOuterClass.JvmScalarUdf + .newBuilder() + .setClassName(classOf[DateFormatUDF].getName) + .addArgs(tsProto.get) + .addArgs(fmtProto.get) + .addArgs(tzProto.get) + .setReturnType(returnType) + .setReturnNullable(expr.nullable) + Some( + ExprOuterClass.Expr + .newBuilder() + .setJvmScalarUdf(udfBuilder.build()) + .build()) } } diff --git a/spark/src/main/scala/org/apache/comet/udf/DateFormatUDF.scala b/spark/src/main/scala/org/apache/comet/udf/DateFormatUDF.scala new file mode 100644 index 0000000000..966301e90e --- /dev/null +++ b/spark/src/main/scala/org/apache/comet/udf/DateFormatUDF.scala @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.udf
+
+import java.nio.charset.StandardCharsets.UTF_8
+import java.util.concurrent.ConcurrentHashMap
+
+import org.apache.arrow.vector.{TimeStampMicroTZVector, TimeStampMicroVector, ValueVector, VarCharVector}
+import org.apache.spark.sql.catalyst.expressions.{BoundReference, DateFormatClass, Literal}
+import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
+import org.apache.spark.sql.types.{StringType, TimestampType}
+import org.apache.spark.unsafe.types.UTF8String
+
+import org.apache.comet.CometArrowAllocator
+
+/**
+ * `date_format(timestamp, format)` implemented by delegating to Spark's `DateFormatClass`.
+ *
+ * Used as the JVM fallback when CometDateFormat cannot push to native (non-UTC timezone, format
+ * outside the strftime-mappable whitelist, non-literal format string).
+ *
+ * Inputs:
+ *   - inputs(0): TimeStampMicro[TZ]Vector - timestamp column (microseconds since epoch)
+ *   - inputs(1): VarCharVector - format string; length-1 if literal, else per-row
+ *   - inputs(2): VarCharVector - session timezone id (length-1 scalar)
+ *
+ * Output: VarCharVector of length `numRows` (closed here, not leaked, if evaluation throws).
+ */
+class DateFormatUDF extends CometUDF {
+  // Cache one DateFormatClass per (format, timezone). Constructing it with a Literal format makes
+  // its `formatterOption` lazy-val resolve to Some(formatter), so subsequent eval calls reuse the
+  // formatter instead of rebuilding it per row.
+  private val cache = new ConcurrentHashMap[(String, String), DateFormatClass]()
+
+  private def lookup(formatStr: String, tzId: String): DateFormatClass =
+    cache.computeIfAbsent(
+      (formatStr, tzId),
+      { case (f, tz) =>
+        DateFormatClass(
+          BoundReference(0, TimestampType, nullable = true),
+          Literal(UTF8String.fromString(f), StringType),
+          Some(tz))
+      })
+
+  override def evaluate(inputs: Array[ValueVector], numRows: Int): ValueVector = {
+    require(
+      inputs.length == 3,
+      s"DateFormatUDF expects 3 inputs (timestamp, format, timezone), got ${inputs.length}")
+    val tsVec = inputs(0)
+    val fmtVec = inputs(1).asInstanceOf[VarCharVector]
+    val tzVec = inputs(2).asInstanceOf[VarCharVector]
+    require(
+      tzVec.getValueCount >= 1 && !tzVec.isNull(0),
+      "DateFormatUDF requires a non-null scalar timezone")
+
+    val tzId = new String(tzVec.get(0), UTF_8)
+    val fmtScalar = fmtVec.getValueCount == 1
+    // A scalar format never varies, so resolve its DateFormatClass once for the whole batch.
+    val scalarDf: DateFormatClass =
+      if (fmtScalar && !fmtVec.isNull(0)) lookup(new String(fmtVec.get(0), UTF_8), tzId)
+      else null
+    val getMicros: Int => Long = tsVec match {
+      case t: TimeStampMicroTZVector => i => t.get(i)
+      case t: TimeStampMicroVector => i => t.get(i)
+      case other =>
+        throw new RuntimeException(
+          s"DateFormatUDF: unsupported timestamp vector ${other.getClass.getName}")
+    }
+    val out = new VarCharVector("date_format_result", CometArrowAllocator)
+    // Rows can fail mid-fill (e.g. a pattern Spark rejects); never leak `out` in that case.
+    var completed = false
+    try {
+      out.allocateNew(numRows)
+      val row = new GenericInternalRow(1)
+      var i = 0
+      while (i < numRows) {
+        if (tsVec.isNull(i) || fmtVec.isNull(if (fmtScalar) 0 else i)) {
+          out.setNull(i)
+        } else {
+          val df =
+            if (scalarDf != null) scalarDf else lookup(new String(fmtVec.get(i), UTF_8), tzId)
+          row.update(0, getMicros(i))
+          df.eval(row) match {
+            case s: UTF8String => out.setSafe(i, s.getBytes)
+            case _ => out.setNull(i)
+          }
+        }
+        i += 1
+      }
+      out.setValueCount(numRows)
+      completed = true
+      out
+    } finally {
+      if (!completed) out.close()
+    }
+  }
+}
diff --git a/spark/src/test/resources/sql-tests/expressions/datetime/date_format.sql b/spark/src/test/resources/sql-tests/expressions/datetime/date_format.sql
index 09333f44d3..0af00d477d 100644
--- a/spark/src/test/resources/sql-tests/expressions/datetime/date_format.sql
+++ b/spark/src/test/resources/sql-tests/expressions/datetime/date_format.sql
@@ -21,15 +21,15 @@ CREATE TABLE test_date_format(ts timestamp) USING parquet
 statement
 INSERT INTO test_date_format VALUES (timestamp('2024-06-15 10:30:45')), (timestamp('1970-01-01 00:00:00')), (NULL)
 
-query expect_fallback(Non-UTC timezone)
+query
 SELECT date_format(ts, 'yyyy-MM-dd') FROM test_date_format
 
-query expect_fallback(Non-UTC timezone)
+query
 SELECT date_format(ts, 'HH:mm:ss') FROM test_date_format
 
-query expect_fallback(Non-UTC timezone)
+query
 SELECT date_format(ts, 'yyyy-MM-dd HH:mm:ss') FROM test_date_format
 
 -- literal arguments
-query expect_fallback(Non-UTC timezone)
+query
 SELECT 
date_format(timestamp('2024-06-15 10:30:45'), 'yyyy-MM-dd') diff --git a/spark/src/test/scala/org/apache/comet/CometTemporalExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometTemporalExpressionSuite.scala index a8147089d9..1346d3c2df 100644 --- a/spark/src/test/scala/org/apache/comet/CometTemporalExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometTemporalExpressionSuite.scala @@ -214,10 +214,8 @@ class CometTemporalExpressionSuite extends CometTestBase with AdaptiveSparkPlanH } test("date_format - timestamp_ntz input") { - // TimestampNTZ is timezone-independent, so date_format should produce the same - // formatted string regardless of session timezone. Comet currently only runs this - // natively for UTC; for non-UTC it falls back to Spark. We verify correctness - // (matching Spark's output) in all cases. + // TimestampNTZ is timezone-independent, so date_format must produce the same string + // regardless of session timezone. val r = new Random(42) val ntzSchema = StructType(Seq(StructField("ts_ntz", DataTypes.TimestampNTZType, true))) val ntzDF = FuzzDataGenerator.generateDataFrame(r, spark, ntzSchema, 100, DataGenOptions()) @@ -227,14 +225,8 @@ class CometTemporalExpressionSuite extends CometTestBase with AdaptiveSparkPlanH for (tz <- crossTimezones) { withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> tz) { for (format <- supportedFormats) { - if (tz == "UTC") { - checkSparkAnswerAndOperator( - s"SELECT ts_ntz, date_format(ts_ntz, '$format') from ntz_tbl order by ts_ntz") - } else { - // Non-UTC falls back to Spark but should still produce correct results - checkSparkAnswer( - s"SELECT ts_ntz, date_format(ts_ntz, '$format') from ntz_tbl order by ts_ntz") - } + checkSparkAnswerAndOperator( + s"SELECT ts_ntz, date_format(ts_ntz, '$format') from ntz_tbl order by ts_ntz") } } } @@ -476,18 +468,16 @@ class CometTemporalExpressionSuite extends CometTestBase with AdaptiveSparkPlanH } } - test("date_format unsupported format falls 
back to Spark") { + test("date_format unsupported format runs via JVM UDF inside Comet") { createTimestampTestData.createOrReplaceTempView("tbl") withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC") { - // Unsupported format string - checkSparkAnswerAndFallbackReason( - "SELECT c0, date_format(c0, 'yyyy-MM-dd EEEE') from tbl order by c0", - "Format 'yyyy-MM-dd EEEE' is not supported") + checkSparkAnswerAndOperator( + "SELECT c0, date_format(c0, 'yyyy-MM-dd EEEE') from tbl order by c0") } } - test("date_format with non-UTC timezone falls back to Spark") { + test("date_format with non-UTC timezone runs via JVM UDF inside Comet") { createTimestampTestData.createOrReplaceTempView("tbl") val nonUtcTimezones = @@ -495,15 +485,13 @@ class CometTemporalExpressionSuite extends CometTestBase with AdaptiveSparkPlanH for (tz <- nonUtcTimezones) { withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> tz) { - // Non-UTC timezones should fall back to Spark as Incompatible - checkSparkAnswerAndFallbackReason( - "SELECT c0, date_format(c0, 'yyyy-MM-dd HH:mm:ss') from tbl order by c0", - s"Non-UTC timezone '$tz' may produce different results") + checkSparkAnswerAndOperator( + "SELECT c0, date_format(c0, 'yyyy-MM-dd HH:mm:ss') from tbl order by c0") } } } - test("date_format with non-UTC timezone works when allowIncompatible is enabled") { + test("date_format with non-UTC timezone takes native path when allowIncompatible is enabled") { createTimestampTestData.createOrReplaceTempView("tbl") val nonUtcTimezones = Seq("America/New_York", "Europe/London", "Asia/Tokyo") @@ -511,10 +499,13 @@ class CometTemporalExpressionSuite extends CometTestBase with AdaptiveSparkPlanH for (tz <- nonUtcTimezones) { withSQLConf( SQLConf.SESSION_LOCAL_TIMEZONE.key -> tz, - "spark.comet.expr.DateFormatClass.allowIncompatible" -> "true") { - // With allowIncompatible enabled, Comet will execute the expression - // Results may differ from Spark but should not throw errors - checkSparkAnswer("SELECT c0, 
date_format(c0, 'yyyy-MM-dd') from tbl order by c0") + "spark.comet.expression.DateFormatClass.allowIncompatible" -> "true") { + // Native to_char results may diverge from Spark for non-UTC timezones (the reason the + // JVM UDF is the default), so we only check that execution stays inside Comet. ORDER BY + // is omitted to keep the plan free of AQEShuffleRead. + val df = sql("SELECT c0, date_format(c0, 'yyyy-MM-dd') from tbl") + df.collect() + checkCometOperators(stripAQEPlan(df.queryExecution.executedPlan)) } } }