121 changes: 121 additions & 0 deletions docs/source/user-guide/latest/custom_comet_udfs.md
@@ -0,0 +1,121 @@
<!---
Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
-->

# Custom Comet UDFs

Comet lets you register a user-supplied vectorized implementation of a Spark UDF. When the
registered name matches a `ScalaUDF` in your query, Comet routes the call to your vectorized
implementation on the native path instead of running Spark's row-at-a-time function. Other Comet
operators in the same plan stay native.

This is a more direct alternative to the [Janino codegen path](scala_java_udfs.md): you supply
the columnar implementation yourself, work with Arrow vectors directly, and Comet ships the data
to your code through the existing native UDF bridge.

This feature is experimental. The `CometUDF` trait and `CometUDFRegistry` API carry the
`@org.apache.spark.annotation.Unstable` annotation and may change.

## API

```scala
package com.example

import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.{IntVector, ValueVector}
import org.apache.comet.udf.CometUDF

class PlusOneUdf extends CometUDF {
  // One allocator per UDF instance; it backs every output vector this instance creates.
  private val allocator = new RootAllocator()

  override def evaluate(inputs: Array[ValueVector], numRows: Int): ValueVector = {
    val in = inputs(0).asInstanceOf[IntVector]
    val out = new IntVector("out", allocator)
    out.allocateNew(numRows)
    var i = 0
    while (i < numRows) {
      // Propagate nulls; otherwise add one to the input value.
      if (in.isNull(i)) out.setNull(i) else out.set(i, in.get(i) + 1)
      i += 1
    }
    out.setValueCount(numRows)
    out
  }
}
```

Register both the Spark UDF (for type binding and the row-based fallback) and the Comet UDF:

```scala
import org.apache.comet.udf.CometUDFRegistry

spark.udf.register("plus_one", (x: Int) => x + 1)
CometUDFRegistry.register("plus_one", classOf[com.example.PlusOneUdf])
```

Use it from SQL or the DataFrame API as you would any Spark UDF; Comet looks up the registered
class at plan time:

```sql
SELECT plus_one(value) FROM t
```
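The same call can be made through the DataFrame API. The following is a minimal sketch, assuming `plus_one` has already been registered on both sides as shown above and that a table `t` with an `Int` column `value` exists; the helper name `plusOneColumn` and the output alias are illustrative only.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.expr

// Hypothetical helper: selects plus_one(value) from table `t`.
// Assumes "plus_one" is registered with spark.udf.register and CometUDFRegistry.register.
def plusOneColumn(spark: SparkSession): DataFrame =
  spark.table("t").select(expr("plus_one(value)").as("value_plus_one"))
```

Going through `expr` keeps the call bound by name, which is what the registry lookup keys on.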

## Contract

The `CometUDF` trait:

```scala
trait CometUDF {
  def evaluate(inputs: Array[ValueVector], numRows: Int): ValueVector
}
```

- Vector arguments arrive at the row count of the current batch.
- Scalar (literal-folded) arguments arrive as length-1 vectors and must be read at index 0.
- The returned vector's length must match `numRows`.
- Implementations must have a public no-arg constructor.
- Comet creates one instance of each registered class per Spark task attempt and reuses it for
  every batch within that task, so instance fields may hold per-task state. Instances are dropped
  at task completion; do not hold state that must persist across tasks.
- At most one thread calls `evaluate` on a given instance at a time, so per-task state does not
require synchronization.
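
The scalar-argument rule can be illustrated with a two-argument UDF. This is a sketch, not part of Comet: `AddUdf` is a hypothetical class, and treating any input with a value count of 1 as a scalar is an assumption made here for illustration, based on the statement above that literal-folded arguments arrive as length-1 vectors.

```scala
package com.example

import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.{IntVector, ValueVector}
import org.apache.comet.udf.CometUDF

// Hypothetical two-argument UDF: add(a, b) over Int inputs.
class AddUdf extends CometUDF {
  private val allocator = new RootAllocator()

  override def evaluate(inputs: Array[ValueVector], numRows: Int): ValueVector = {
    val a = inputs(0).asInstanceOf[IntVector]
    val b = inputs(1).asInstanceOf[IntVector]
    // Assumption: a length-1 input is a literal-folded scalar and is read at index 0.
    val aScalar = a.getValueCount == 1
    val bScalar = b.getValueCount == 1

    val out = new IntVector("out", allocator)
    out.allocateNew(numRows)
    var i = 0
    while (i < numRows) {
      val ai = if (aScalar) 0 else i
      val bi = if (bScalar) 0 else i
      if (a.isNull(ai) || b.isNull(bi)) out.setNull(i)
      else out.set(i, a.get(ai) + b.get(bi))
      i += 1
    }
    // The returned vector's length must match numRows.
    out.setValueCount(numRows)
    out
  }
}
```

With this shape, a call such as `add(value, 10)` would read the literal `10` from index 0 on every row while walking `value` row by row.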

## Routing precedence

For each `ScalaUDF` Comet encounters:

1. If the UDF name has a `CometUDF` registered via `CometUDFRegistry`, Comet routes the call to
the registered class.
2. Otherwise, if `spark.comet.exec.scalaUDF.codegen.enabled=true`, Comet uses the Janino codegen
dispatcher.
3. Otherwise, the enclosing operator falls back to Spark.
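
The precedence above can be modeled as a small pure function. This is an illustrative sketch only: `RoutingSketch`, `Route`, and `chooseRoute` are hypothetical names that do not exist in Comet, and the sketch ignores any further reasons the real serde may have for falling back.

```scala
// Hypothetical model of the routing decision; none of these names are Comet APIs.
object RoutingSketch {
  sealed trait Route
  case object RegisteredCometUdf extends Route
  case object JaninoCodegen extends Route
  case object SparkFallback extends Route

  def chooseRoute(
      udfName: Option[String],
      registered: Set[String],
      codegenEnabled: Boolean): Route =
    // 1. A registered CometUDF wins, regardless of the codegen setting.
    if (udfName.exists(registered.contains)) RegisteredCometUdf
    // 2. Otherwise use the codegen dispatcher when it is enabled.
    else if (codegenEnabled) JaninoCodegen
    // 3. Otherwise the enclosing operator falls back to Spark.
    else SparkFallback
}
```

Note that a registered UDF takes the first branch even when `spark.comet.exec.scalaUDF.codegen.enabled` is `false`, and an anonymous UDF (no name) can never match the registry.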

## Cluster deployment

The class is loaded on each executor via the task's context classloader, so the jar containing
your `CometUDF` must be on the executor classpath (e.g. `spark.jars`, `--jars`, or shaded into
the application). Registration calls themselves are driver-side; executors receive the class
name through the serialized plan.
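
One way to put the jar on the executor classpath is to set `spark.jars` when the session is created. The sketch below assumes a placeholder jar path and application name, and it omits whatever settings your deployment uses to enable Comet itself.

```scala
import org.apache.spark.sql.SparkSession

// Hypothetical helper; the jar path and app name are placeholders.
def buildSession(): SparkSession =
  SparkSession
    .builder()
    .appName("comet-udf-example")
    // Ships the jar containing the CometUDF classes to the executors.
    .config("spark.jars", "/path/to/my-comet-udfs.jar")
    .getOrCreate()
```

Passing the same jar with `--jars` on `spark-submit` should have an equivalent effect; `spark.jars` only takes effect if it is set before the underlying `SparkContext` starts.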

## Limitations

- Aggregate, table, Python, Pandas, and Hive UDFs are out of scope.
- The matching Spark UDF must be registered separately; without it, the function name will not
bind during analysis.
- Return type and nullability come from the registered Spark UDF, not from the `CometUDF` class.
Make sure your vectorized implementation produces a vector compatible with the declared
Spark return type.
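
As an illustration of the last point, suppose the Spark UDF is registered as `spark.udf.register("plus_one_long", (x: Int) => x.toLong + 1L)`, so its declared return type is `LongType`. The sketch below assumes that a `LongType` result corresponds to Arrow's 64-bit `BigIntVector`; `PlusOneLongUdf` and `plus_one_long` are hypothetical names.

```scala
package com.example

import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.{BigIntVector, IntVector, ValueVector}
import org.apache.comet.udf.CometUDF

// Hypothetical UDF whose Spark counterpart returns Long rather than Int.
class PlusOneLongUdf extends CometUDF {
  private val allocator = new RootAllocator()

  override def evaluate(inputs: Array[ValueVector], numRows: Int): ValueVector = {
    val in = inputs(0).asInstanceOf[IntVector]
    // Assumption: the output vector type must match the Spark-declared LongType.
    val out = new BigIntVector("out", allocator)
    out.allocateNew(numRows)
    var i = 0
    while (i < numRows) {
      if (in.isNull(i)) out.setNull(i) else out.set(i, in.get(i).toLong + 1L)
      i += 1
    }
    out.setValueCount(numRows)
    out
  }
}
```

Returning an `IntVector` here would disagree with the return type that Comet serializes from the Spark UDF.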
1 change: 1 addition & 0 deletions docs/source/user-guide/latest/index.rst
@@ -44,6 +44,7 @@ to read more.
Supported Operators <operators>
Supported Expressions <expressions>
ScalaUDF and Java UDF Support <scala_java_udfs>
Custom Comet UDFs <custom_comet_udfs>
Configuration Settings <configs>
Compatibility Guide <compatibility/index>
Understanding Comet Plans <understanding-comet-plans>
66 changes: 56 additions & 10 deletions spark/src/main/scala/org/apache/comet/serde/CometScalaUDF.scala
@@ -20,35 +20,47 @@
package org.apache.comet.serde

import org.apache.spark.SparkEnv
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSeq, BindReferences, Literal, ScalaUDF}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSeq, BindReferences, KnownNotNull, Literal, ScalaUDF}
import org.apache.spark.sql.types.BinaryType

import org.apache.comet.CometConf
import org.apache.comet.CometSparkSessionExtensions.withInfo
import org.apache.comet.codegen.CometBatchKernelCodegen
import org.apache.comet.serde.ExprOuterClass.Expr
import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, serializeDataType}
import org.apache.comet.udf.CometUDFRegistry
import org.apache.comet.udf.codegen.CometScalaUDFCodegen

/**
* Routes scalar `ScalaUDF` (Scala and Java UDFs) through the codegen dispatcher.
* `ScalaUDF.doGenCode` emits compilable Java that invokes the user function via
* `ctx.addReferenceObj`; the dispatcher serializes the bound tree, the closure serializer carries
* the function reference across the wire, and the Janino-compiled kernel invokes it in a tight
* batch loop.
* Routes scalar `ScalaUDF` (Scala and Java UDFs) through one of two native paths.
*
* Not covered:
* If the UDF name has a user-supplied [[org.apache.comet.udf.CometUDF]] registered in
* [[CometUDFRegistry]], emit a `JvmScalarUdf` that targets the registered class directly. The
* native side passes each argument as an Arrow vector and the user implementation produces the
* output vector.
*
* Otherwise, route through the Janino codegen dispatcher. `ScalaUDF.doGenCode` emits compilable
* Java that invokes the user function via `ctx.addReferenceObj`; the dispatcher serializes the
* bound tree, the closure serializer carries the function reference across the wire, and the
* Janino-compiled kernel invokes it in a tight batch loop. Gated by
* [[CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED]]; when disabled, unregistered ScalaUDFs fall back
* to Spark.
*
* Not covered by either path:
* - Aggregate UDFs (`ScalaAggregator`, `TypedImperativeAggregate`, legacy UDAF).
* - Table UDFs and generators.
* - Python / Pandas UDFs.
* - Hive `GenericUDF` / `SimpleUDF`.
*
* Gated by [[CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED]]. When disabled, plans containing a
* `ScalaUDF` fall back to Spark for the enclosing operator.
*/
object CometScalaUDF extends CometExpressionSerde[ScalaUDF] {

  override def convert(expr: ScalaUDF, inputs: Seq[Attribute], binding: Boolean): Option[Expr] = {
    expr.udfName.flatMap(CometUDFRegistry.get) match {
      case Some(udfClass) =>
        return convertRegistered(expr, inputs, binding, udfClass.getName)
      case None =>
    }

    if (!CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED.get()) {
      withInfo(
        expr,
@@ -99,4 +111,38 @@ object CometScalaUDF extends CometExpressionSerde[ScalaUDF] {
        .setJvmScalarUdf(udfBuilder.build())
        .build())
  }

  /**
   * Build a `JvmScalarUdf` proto that targets a user-registered [[org.apache.comet.udf.CometUDF]]
   * class directly. Each ScalaUDF child becomes an `args` entry; the native side evaluates them
   * to Arrow vectors before invoking `evaluate(inputs, numRows)` on a per-task instance.
   *
   * Spark wraps UDF arguments in `KnownNotNull` when the UDF declares its inputs non-nullable;
   * strip those so the underlying expression has a Comet serde mapping.
   */
  private def convertRegistered(
      expr: ScalaUDF,
      inputs: Seq[Attribute],
      binding: Boolean,
      className: String): Option[Expr] = {
    val argProtos = expr.children.map {
      case KnownNotNull(child) => exprToProtoInternal(child, inputs, binding)
      case other => exprToProtoInternal(other, inputs, binding)
    }
    if (argProtos.exists(_.isEmpty)) {
      return None
    }
    val returnTypeProto = serializeDataType(expr.dataType).getOrElse(return None)
    val udfBuilder = ExprOuterClass.JvmScalarUdf
      .newBuilder()
      .setClassName(className)
      .setReturnType(returnTypeProto)
      .setReturnNullable(expr.nullable)
    argProtos.foreach(p => udfBuilder.addArgs(p.get))
    Some(
      ExprOuterClass.Expr
        .newBuilder()
        .setJvmScalarUdf(udfBuilder.build())
        .build())
  }
}
2 changes: 2 additions & 0 deletions spark/src/main/scala/org/apache/comet/udf/CometUDF.scala
@@ -20,6 +20,7 @@
package org.apache.comet.udf

import org.apache.arrow.vector.ValueVector
import org.apache.spark.annotation.Unstable

/**
* Scalar UDF invoked from native execution via JNI. Receives Arrow vectors as input and returns
@@ -43,6 +44,7 @@ import org.apache.arrow.vector.ValueVector
* per partition and Tokio polls one future per worker, so the per-task instance is never touched
* concurrently even if the task's future migrates between Tokio workers across batches.
*/
@Unstable
trait CometUDF {
  def evaluate(inputs: Array[ValueVector], numRows: Int): ValueVector
}
68 changes: 68 additions & 0 deletions spark/src/main/scala/org/apache/comet/udf/CometUDFRegistry.scala
@@ -0,0 +1,68 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.comet.udf

import java.util.concurrent.ConcurrentHashMap

import org.apache.spark.annotation.Unstable

/**
* Driver-side registry mapping a Spark UDF name to a user-supplied [[CometUDF]] implementation
* class. When the [[org.apache.comet.serde.CometScalaUDF]] serde encounters a `ScalaUDF` whose
* name is registered, it routes the call through the registered class instead of the Janino
* codegen dispatcher.
*
* Typical usage:
* {{{
* spark.udf.register("plus_one", (x: Int) => x + 1) // row-based fallback / type info for Spark
* CometUDFRegistry.register("plus_one", classOf[PlusOneUdf])
* }}}
*
* The matching Spark UDF must be registered separately (e.g. via `spark.udf.register`) so Spark
* can bind the function during analysis and supply the return type. If Comet is disabled or the
* enclosing operator falls back, Spark evaluates the row-based UDF.
*
* Registration is driver-side state; executors look up the class name from the serialized plan
* and load the class via the executor's context classloader.
*/
@Unstable
object CometUDFRegistry {

  private val registry = new ConcurrentHashMap[String, Class[_ <: CometUDF]]()

  /** Register a [[CometUDF]] implementation against a Spark UDF name. */
  def register(udfName: String, udfClass: Class[_ <: CometUDF]): Unit = {
    registry.put(udfName, udfClass)
  }

  /** Remove a previously registered UDF. No-op if not registered. */
  def unregister(udfName: String): Unit = {
    registry.remove(udfName)
  }

  /** Whether a UDF name has a registered [[CometUDF]] implementation. */
  def isRegistered(udfName: String): Boolean = registry.containsKey(udfName)

  private[comet] def get(udfName: String): Option[Class[_ <: CometUDF]] =
    Option(registry.get(udfName))

  // Visible for testing.
  private[comet] def clear(): Unit = registry.clear()
}