Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 63 additions & 54 deletions spark/src/main/scala/org/apache/comet/serde/datetime.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,13 @@ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DateType, DoubleType, FloatType, IntegerType, LongType, StringType, TimestampNTZType, TimestampType}
import org.apache.spark.unsafe.types.UTF8String

import org.apache.comet.CometConf
import org.apache.comet.CometSparkSessionExtensions.withInfo
import org.apache.comet.expressions.{CometCast, CometEvalMode}
import org.apache.comet.serde.CometGetDateField.CometGetDateField
import org.apache.comet.serde.ExprOuterClass.Expr
import org.apache.comet.serde.QueryPlanSerde._
import org.apache.comet.udf.DateFormatUDF

private object CometGetDateField extends Enumeration {
type CometGetDateField = Value
Expand Down Expand Up @@ -572,17 +574,21 @@ object CometTruncTimestamp extends CometExpressionSerde[TruncTimestamp] {
}

/**
* Converts Spark DateFormatClass expression to DataFusion's to_char function.
* Converts Spark DateFormatClass expression to DataFusion's to_char function when the format and
* timezone are mappable; otherwise emits a JvmScalarUdf that delegates to Spark's own
* `DateFormatClass` so that any format / timezone combination remains supported.
*
* Spark uses Java SimpleDateFormat patterns while DataFusion uses strftime patterns. This
* implementation supports a whitelist of common format strings that can be reliably mapped
* between the two systems.
* Routing:
* - format is a literal in `supportedFormats` AND timezone is UTC -> native to_char
* - format is a literal in `supportedFormats` AND timezone is non-UTC, with the per-expression
* allowIncompatible flag set -> native to_char (results may differ from Spark)
* - all other cases -> JVM UDF (`org.apache.comet.udf.DateFormatUDF`)
*/
object CometDateFormat extends CometExpressionSerde[DateFormatClass] {

/**
* Mapping from Spark SimpleDateFormat patterns to strftime patterns. Only formats in this map
* are supported.
* are supported by the native path.
*/
val supportedFormats: Map[String, String] = Map(
// Full date formats
Expand Down Expand Up @@ -616,67 +622,70 @@ object CometDateFormat extends CometExpressionSerde[DateFormatClass] {
// ISO formats
"yyyy-MM-dd'T'HH:mm:ss" -> "%Y-%m-%dT%H:%M:%S")

override def getIncompatibleReasons(): Seq[String] = Seq(
"Non-UTC timezones may produce different results than Spark")

override def getUnsupportedReasons(): Seq[String] = Seq(
"Only the following formats are supported:" +
supportedFormats.keys.toSeq.sorted
.map(k => s"`$k`")
.mkString("\n - ", "\n - ", ""))
// The JVM UDF covers every case that the native path cannot, so the expression is always
// emittable. Compatibility decisions happen inside `convert`.
override def getSupportLevel(expr: DateFormatClass): SupportLevel = Compatible()

override def getSupportLevel(expr: DateFormatClass): SupportLevel = {
// Check timezone - only UTC is fully compatible
override def convert(
expr: DateFormatClass,
inputs: Seq[Attribute],
binding: Boolean): Option[ExprOuterClass.Expr] = {
val timezone = expr.timeZoneId.getOrElse("UTC")
val isUtc = timezone == "UTC" || timezone == "Etc/UTC"

expr.right match {
case Literal(fmt: UTF8String, _) =>
val format = fmt.toString
if (supportedFormats.contains(format)) {
if (isUtc) {
Compatible()
} else {
Incompatible(Some(s"Non-UTC timezone '$timezone' may produce different results"))
}
} else {
Unsupported(
Some(
s"Format '$format' is not supported. Supported formats: " +
supportedFormats.keys.mkString(", ")))
}
case _ =>
Unsupported(Some("Only literal format strings are supported"))
val nativeFormat: Option[String] = expr.right match {
case Literal(fmt: UTF8String, _) => supportedFormats.get(fmt.toString)
case _ => None
}

val canUseNative = nativeFormat.isDefined && {
isUtc || CometConf.isExprAllowIncompat(getExprConfigName(expr))
}

if (canUseNative) {
val childExpr = exprToProtoInternal(expr.left, inputs, binding)
val formatExpr = exprToProtoInternal(Literal(nativeFormat.get), inputs, binding)
val optExpr = scalarFunctionExprToProtoWithReturnType(
"to_char",
StringType,
false,
childExpr,
formatExpr)
optExprWithInfo(optExpr, expr, expr.left, expr.right)
} else {
convertViaJvmUdf(expr, timezone, inputs, binding)
}
}

override def convert(
private def convertViaJvmUdf(
expr: DateFormatClass,
timezone: String,
inputs: Seq[Attribute],
binding: Boolean): Option[ExprOuterClass.Expr] = {
// Get the format string - must be a literal for us to map it
val strftimeFormat = expr.right match {
case Literal(fmt: UTF8String, _) =>
supportedFormats.get(fmt.toString)
case _ => None
val tsProto = exprToProtoInternal(expr.left, inputs, binding)
val fmtProto = exprToProtoInternal(expr.right, inputs, binding)
val tzProto = exprToProtoInternal(Literal(timezone), inputs, binding)
if (tsProto.isEmpty || fmtProto.isEmpty || tzProto.isEmpty) {
withInfo(expr, expr.left, expr.right)
return None
}

strftimeFormat match {
case Some(format) =>
val childExpr = exprToProtoInternal(expr.left, inputs, binding)
val formatExpr = exprToProtoInternal(Literal(format), inputs, binding)

val optExpr = scalarFunctionExprToProtoWithReturnType(
"to_char",
StringType,
false,
childExpr,
formatExpr)
optExprWithInfo(optExpr, expr, expr.left, expr.right)
case None =>
withInfo(expr, expr.left, expr.right)
None
val returnType = serializeDataType(StringType).getOrElse {
withInfo(expr, "Failed to serialize StringType return type")
return None
}
val udfBuilder = ExprOuterClass.JvmScalarUdf
.newBuilder()
.setClassName(classOf[DateFormatUDF].getName)
.addArgs(tsProto.get)
.addArgs(fmtProto.get)
.addArgs(tzProto.get)
.setReturnType(returnType)
.setReturnNullable(expr.nullable)
Some(
ExprOuterClass.Expr
.newBuilder()
.setJvmScalarUdf(udfBuilder.build())
.build())
}
}

Expand Down
114 changes: 114 additions & 0 deletions spark/src/main/scala/org/apache/comet/udf/DateFormatUDF.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.comet.udf

import java.nio.charset.StandardCharsets.UTF_8
import java.util.concurrent.ConcurrentHashMap

import org.apache.arrow.vector.{TimeStampMicroTZVector, TimeStampMicroVector, ValueVector, VarCharVector}
import org.apache.spark.sql.catalyst.expressions.{BoundReference, DateFormatClass, Literal}
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.types.{StringType, TimestampType}
import org.apache.spark.unsafe.types.UTF8String

import org.apache.comet.CometArrowAllocator

/**
 * `date_format(timestamp, format)` implemented by delegating to Spark's `DateFormatClass`.
 *
 * Used as the JVM fallback when CometDateFormat cannot push to native (non-UTC timezone, format
 * outside the strftime-mappable whitelist, non-literal format string).
 *
 * Inputs:
 *   - inputs(0): TimeStampMicro[TZ]Vector - timestamp column (microseconds since epoch)
 *   - inputs(1): VarCharVector - format string; a vector holding exactly one value is treated as
 *     a scalar applied to every row, otherwise the format is read per row
 *   - inputs(2): VarCharVector - session timezone id (only element 0 is read; must be non-null)
 *
 * Output: VarCharVector of length `numRows`. A row is null when its timestamp or its format is
 * null, or when `DateFormatClass.eval` itself returns null. On success the returned vector is
 * handed to the caller; if evaluation fails part-way, the partially built vector is closed before
 * the exception propagates so its buffers are released.
 */
class DateFormatUDF extends CometUDF {

  // Cache one DateFormatClass per (format, timezone). Constructing it with a Literal format makes
  // its `formatterOption` lazy-val resolve to Some(formatter), so subsequent eval calls reuse the
  // formatter instead of rebuilding it per row.
  // NOTE(review): entries are never evicted, so a per-row format column with many distinct
  // values grows this map for the lifetime of the UDF instance - confirm that is acceptable.
  private val cache = new ConcurrentHashMap[(String, String), DateFormatClass]()

  /** Returns the cached `DateFormatClass` for `(formatStr, tzId)`, creating it on first use. */
  private def lookup(formatStr: String, tzId: String): DateFormatClass =
    cache.computeIfAbsent(
      (formatStr, tzId),
      { case (f, tz) =>
        DateFormatClass(
          BoundReference(0, TimestampType, nullable = true),
          Literal(UTF8String.fromString(f), StringType),
          Some(tz))
      })

  /**
   * Formats `numRows` timestamps into a newly allocated `VarCharVector`.
   *
   * @param inputs
   *   exactly three vectors: timestamp, format, timezone (see the class documentation)
   * @param numRows
   *   number of output rows to produce
   * @return
   *   a `VarCharVector` with value count `numRows`
   * @throws IllegalArgumentException
   *   if `inputs` does not have three elements or the timezone is missing / null
   * @throws RuntimeException
   *   if the timestamp vector is neither `TimeStampMicroTZVector` nor `TimeStampMicroVector`
   */
  override def evaluate(inputs: Array[ValueVector], numRows: Int): ValueVector = {
    require(
      inputs.length == 3,
      s"DateFormatUDF expects 3 inputs (timestamp, format, timezone), got ${inputs.length}")
    val tsVec = inputs(0)
    val fmtVec = inputs(1).asInstanceOf[VarCharVector]
    val tzVec = inputs(2).asInstanceOf[VarCharVector]
    require(
      tzVec.getValueCount >= 1 && !tzVec.isNull(0),
      "DateFormatUDF requires a non-null scalar timezone")

    val tzId = new String(tzVec.get(0), UTF_8)
    val fmtScalar = fmtVec.getValueCount == 1
    // For scalar format the format never varies, so resolve the DateFormatClass once and reuse it
    // across every row instead of doing a Tuple2 allocation + HashMap lookup per row. A null
    // scalar format leaves this empty; every row is then null via the per-row null check below.
    val scalarDf: Option[DateFormatClass] =
      if (fmtScalar && !fmtVec.isNull(0)) Some(lookup(new String(fmtVec.get(0), UTF_8), tzId))
      else None

    // Resolved before the output vector is allocated so an unsupported input type cannot leak it.
    val getMicros: Int => Long = tsVec match {
      case t: TimeStampMicroTZVector => i => t.get(i)
      case t: TimeStampMicroVector => i => t.get(i)
      case other =>
        throw new RuntimeException(
          s"DateFormatUDF: unsupported timestamp vector ${other.getClass.getName}")
    }

    val out = new VarCharVector("date_format_result", CometArrowAllocator)
    // Ownership of `out` passes to the caller only when the whole batch was produced; on any
    // failure (allocation, formatter construction, eval, buffer growth) it is closed in `finally`
    // so the buffers it took from the allocator are not leaked.
    var handedOff = false
    try {
      out.allocateNew(numRows)

      // Single reusable input row: DateFormatClass reads its timestamp from ordinal 0.
      val row = new GenericInternalRow(1)

      var i = 0
      while (i < numRows) {
        val fmtIdx = if (fmtScalar) 0 else i
        if (tsVec.isNull(i) || fmtVec.isNull(fmtIdx)) {
          out.setNull(i)
        } else {
          val df = scalarDf match {
            case Some(cached) => cached
            // Read the format at the same index that was null-checked above.
            case None => lookup(new String(fmtVec.get(fmtIdx), UTF_8), tzId)
          }
          row.update(0, getMicros(i))
          val formatted = df.eval(row)
          if (formatted == null) out.setNull(i)
          else out.setSafe(i, formatted.asInstanceOf[UTF8String].getBytes)
        }
        i += 1
      }
      out.setValueCount(numRows)
      handedOff = true
      out
    } finally {
      if (!handedOff) out.close()
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,15 @@ CREATE TABLE test_date_format(ts timestamp) USING parquet
statement
INSERT INTO test_date_format VALUES (timestamp('2024-06-15 10:30:45')), (timestamp('1970-01-01 00:00:00')), (NULL)

query expect_fallback(Non-UTC timezone)
query
SELECT date_format(ts, 'yyyy-MM-dd') FROM test_date_format

query expect_fallback(Non-UTC timezone)
query
SELECT date_format(ts, 'HH:mm:ss') FROM test_date_format

query expect_fallback(Non-UTC timezone)
query
SELECT date_format(ts, 'yyyy-MM-dd HH:mm:ss') FROM test_date_format

-- literal arguments
query expect_fallback(Non-UTC timezone)
query
SELECT date_format(timestamp('2024-06-15 10:30:45'), 'yyyy-MM-dd')
Original file line number Diff line number Diff line change
Expand Up @@ -214,10 +214,8 @@ class CometTemporalExpressionSuite extends CometTestBase with AdaptiveSparkPlanH
}

test("date_format - timestamp_ntz input") {
// TimestampNTZ is timezone-independent, so date_format should produce the same
// formatted string regardless of session timezone. Comet currently only runs this
// natively for UTC; for non-UTC it falls back to Spark. We verify correctness
// (matching Spark's output) in all cases.
// TimestampNTZ is timezone-independent, so date_format must produce the same string
// regardless of session timezone.
val r = new Random(42)
val ntzSchema = StructType(Seq(StructField("ts_ntz", DataTypes.TimestampNTZType, true)))
val ntzDF = FuzzDataGenerator.generateDataFrame(r, spark, ntzSchema, 100, DataGenOptions())
Expand All @@ -227,14 +225,8 @@ class CometTemporalExpressionSuite extends CometTestBase with AdaptiveSparkPlanH
for (tz <- crossTimezones) {
withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> tz) {
for (format <- supportedFormats) {
if (tz == "UTC") {
checkSparkAnswerAndOperator(
s"SELECT ts_ntz, date_format(ts_ntz, '$format') from ntz_tbl order by ts_ntz")
} else {
// Non-UTC falls back to Spark but should still produce correct results
checkSparkAnswer(
s"SELECT ts_ntz, date_format(ts_ntz, '$format') from ntz_tbl order by ts_ntz")
}
checkSparkAnswerAndOperator(
s"SELECT ts_ntz, date_format(ts_ntz, '$format') from ntz_tbl order by ts_ntz")
}
}
}
Expand Down Expand Up @@ -476,45 +468,44 @@ class CometTemporalExpressionSuite extends CometTestBase with AdaptiveSparkPlanH
}
}

test("date_format unsupported format falls back to Spark") {
test("date_format unsupported format runs via JVM UDF inside Comet") {
createTimestampTestData.createOrReplaceTempView("tbl")

withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC") {
// Unsupported format string
checkSparkAnswerAndFallbackReason(
"SELECT c0, date_format(c0, 'yyyy-MM-dd EEEE') from tbl order by c0",
"Format 'yyyy-MM-dd EEEE' is not supported")
checkSparkAnswerAndOperator(
"SELECT c0, date_format(c0, 'yyyy-MM-dd EEEE') from tbl order by c0")
}
}

test("date_format with non-UTC timezone falls back to Spark") {
test("date_format with non-UTC timezone runs via JVM UDF inside Comet") {
createTimestampTestData.createOrReplaceTempView("tbl")

val nonUtcTimezones =
Seq("America/New_York", "America/Los_Angeles", "Europe/London", "Asia/Tokyo")

for (tz <- nonUtcTimezones) {
withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> tz) {
// Non-UTC timezones should fall back to Spark as Incompatible
checkSparkAnswerAndFallbackReason(
"SELECT c0, date_format(c0, 'yyyy-MM-dd HH:mm:ss') from tbl order by c0",
s"Non-UTC timezone '$tz' may produce different results")
checkSparkAnswerAndOperator(
"SELECT c0, date_format(c0, 'yyyy-MM-dd HH:mm:ss') from tbl order by c0")
}
}
}

test("date_format with non-UTC timezone works when allowIncompatible is enabled") {
test("date_format with non-UTC timezone takes native path when allowIncompatible is enabled") {
createTimestampTestData.createOrReplaceTempView("tbl")

val nonUtcTimezones = Seq("America/New_York", "Europe/London", "Asia/Tokyo")

for (tz <- nonUtcTimezones) {
withSQLConf(
SQLConf.SESSION_LOCAL_TIMEZONE.key -> tz,
"spark.comet.expr.DateFormatClass.allowIncompatible" -> "true") {
// With allowIncompatible enabled, Comet will execute the expression
// Results may differ from Spark but should not throw errors
checkSparkAnswer("SELECT c0, date_format(c0, 'yyyy-MM-dd') from tbl order by c0")
"spark.comet.expression.DateFormatClass.allowIncompatible" -> "true") {
// Native to_char results may diverge from Spark for non-UTC timezones (the reason the
// JVM UDF is the default), so we only check that execution stays inside Comet. ORDER BY
// is omitted to keep the plan free of AQEShuffleRead.
val df = sql("SELECT c0, date_format(c0, 'yyyy-MM-dd') from tbl")
df.collect()
checkCometOperators(stripAQEPlan(df.queryExecution.executedPlan))
}
}
}
Expand Down
Loading