diff --git a/docs/source/user-guide/latest/custom_comet_udfs.md b/docs/source/user-guide/latest/custom_comet_udfs.md new file mode 100644 index 0000000000..5d8f2dafb7 --- /dev/null +++ b/docs/source/user-guide/latest/custom_comet_udfs.md @@ -0,0 +1,121 @@ + + +# Custom Comet UDFs + +Comet lets you register a user-supplied vectorized implementation of a Spark UDF. When the +registered name matches a `ScalaUDF` in your query, Comet routes the call to your vectorized +implementation on the native path instead of running Spark's row-at-a-time function. Other Comet +operators in the same plan stay native. + +This is a more direct alternative to the [Janino codegen path](scala_java_udfs.md): you supply +the columnar implementation yourself, work with Arrow vectors directly, and Comet ships the data +to your code through the existing native UDF bridge. + +This feature is experimental. The `CometUDF` trait and `CometUDFRegistry` API carry the +`@org.apache.spark.annotation.Unstable` annotation and may change. 
+ +## API + +```scala +package com.example + +import org.apache.arrow.memory.RootAllocator +import org.apache.arrow.vector.{IntVector, ValueVector} +import org.apache.comet.udf.CometUDF + +class PlusOneUdf extends CometUDF { + private val allocator = new RootAllocator() + + override def evaluate(inputs: Array[ValueVector], numRows: Int): ValueVector = { + val in = inputs(0).asInstanceOf[IntVector] + val out = new IntVector("out", allocator) + out.allocateNew(numRows) + var i = 0 + while (i < numRows) { + if (in.isNull(i)) out.setNull(i) else out.set(i, in.get(i) + 1) + i += 1 + } + out.setValueCount(numRows) + out + } +} +``` + +Register both the Spark UDF (for type binding and the row-based fallback) and the Comet UDF: + +```scala +import org.apache.comet.udf.CometUDFRegistry + +spark.udf.register("plus_one", (x: Int) => x + 1) +CometUDFRegistry.register("plus_one", classOf[com.example.PlusOneUdf]) +``` + +Use it from SQL or DataFrame as you would any Spark UDF; Comet picks up the registered class at +plan time: + +```sql +SELECT plus_one(value) FROM t +``` + +## Contract + +The `CometUDF` trait: + +```scala +trait CometUDF { + def evaluate(inputs: Array[ValueVector], numRows: Int): ValueVector +} +``` + +- Vector arguments arrive at the row count of the current batch. +- Scalar (literal-folded) arguments arrive as length-1 vectors and must be read at index 0. +- The returned vector's length must match `numRows`. +- Implementations must have a public no-arg constructor. +- A fresh instance is created per Spark task attempt per class and reused for every batch within + that task; instance fields may hold per-task state. Instances are dropped at task completion. + Do not hold state that must persist across tasks. +- At most one thread calls `evaluate` on a given instance at a time, so per-task state does not + require synchronization. + +## Routing precedence + +For each `ScalaUDF` Comet encounters: + +1. 
If the UDF name has a `CometUDF` registered via `CometUDFRegistry`, Comet routes the call to + the registered class. +2. Otherwise, if `spark.comet.exec.scalaUDF.codegen.enabled=true`, Comet uses the Janino codegen + dispatcher. +3. Otherwise, the enclosing operator falls back to Spark. + +## Cluster deployment + +The class is loaded on each executor via the task's context classloader, so the jar containing +your `CometUDF` must be on the executor classpath (e.g. `spark.jars`, `--jars`, or shaded into +the application). Registration calls themselves are driver-side; executors receive the class +name through the serialized plan. + +## Limitations + +- Aggregate, table, Python, Pandas, and Hive UDFs are out of scope. +- The matching Spark UDF must be registered separately; without it, the function name will not + bind during analysis. +- Return type and nullability come from the registered Spark UDF, not from the `CometUDF` class. + Make sure your vectorized implementation produces a vector compatible with the declared + Spark return type. diff --git a/docs/source/user-guide/latest/index.rst b/docs/source/user-guide/latest/index.rst index 9587b2ee03..be00eb243c 100644 --- a/docs/source/user-guide/latest/index.rst +++ b/docs/source/user-guide/latest/index.rst @@ -44,6 +44,7 @@ to read more. 
Supported Operators Supported Expressions ScalaUDF and Java UDF Support + Custom Comet UDFs Configuration Settings Compatibility Guide Understanding Comet Plans diff --git a/spark/src/main/scala/org/apache/comet/serde/CometScalaUDF.scala b/spark/src/main/scala/org/apache/comet/serde/CometScalaUDF.scala index bf636f7221..99111d8f23 100644 --- a/spark/src/main/scala/org/apache/comet/serde/CometScalaUDF.scala +++ b/spark/src/main/scala/org/apache/comet/serde/CometScalaUDF.scala @@ -20,7 +20,7 @@ package org.apache.comet.serde import org.apache.spark.SparkEnv -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSeq, BindReferences, Literal, ScalaUDF} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSeq, BindReferences, KnownNotNull, Literal, ScalaUDF} import org.apache.spark.sql.types.BinaryType import org.apache.comet.CometConf @@ -28,27 +28,39 @@ import org.apache.comet.CometSparkSessionExtensions.withInfo import org.apache.comet.codegen.CometBatchKernelCodegen import org.apache.comet.serde.ExprOuterClass.Expr import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, serializeDataType} +import org.apache.comet.udf.CometUDFRegistry import org.apache.comet.udf.codegen.CometScalaUDFCodegen /** - * Routes scalar `ScalaUDF` (Scala and Java UDFs) through the codegen dispatcher. - * `ScalaUDF.doGenCode` emits compilable Java that invokes the user function via - * `ctx.addReferenceObj`; the dispatcher serializes the bound tree, the closure serializer carries - * the function reference across the wire, and the Janino-compiled kernel invokes it in a tight - * batch loop. + * Routes scalar `ScalaUDF` (Scala and Java UDFs) through one of two native paths. * - * Not covered: + * If the UDF name has a user-supplied [[org.apache.comet.udf.CometUDF]] registered in + * [[CometUDFRegistry]], emit a `JvmScalarUdf` that targets the registered class directly. 
The + * native side passes each argument as an Arrow vector and the user implementation produces the + * output vector. + * + * Otherwise, route through the Janino codegen dispatcher. `ScalaUDF.doGenCode` emits compilable + * Java that invokes the user function via `ctx.addReferenceObj`; the dispatcher serializes the + * bound tree, the closure serializer carries the function reference across the wire, and the + * Janino-compiled kernel invokes it in a tight batch loop. Gated by + * [[CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED]]; when disabled, unregistered ScalaUDFs fall back + * to Spark. + * + * Not covered by either path: * - Aggregate UDFs (`ScalaAggregator`, `TypedImperativeAggregate`, legacy UDAF). * - Table UDFs and generators. * - Python / Pandas UDFs. * - Hive `GenericUDF` / `SimpleUDF`. - * - * Gated by [[CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED]]. When disabled, plans containing a - * `ScalaUDF` fall back to Spark for the enclosing operator. */ object CometScalaUDF extends CometExpressionSerde[ScalaUDF] { override def convert(expr: ScalaUDF, inputs: Seq[Attribute], binding: Boolean): Option[Expr] = { + expr.udfName.flatMap(CometUDFRegistry.get) match { + case Some(udfClass) => + return convertRegistered(expr, inputs, binding, udfClass.getName) + case None => + } + if (!CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED.get()) { withInfo( expr, @@ -99,4 +111,38 @@ object CometScalaUDF extends CometExpressionSerde[ScalaUDF] { .setJvmScalarUdf(udfBuilder.build()) .build()) } + + /** + * Build a `JvmScalarUdf` proto that targets a user-registered [[org.apache.comet.udf.CometUDF]] + * class directly. Each ScalaUDF child becomes an `args` entry; the native side evaluates them + * to Arrow vectors before invoking `evaluate(inputs, numRows)` on a per-task instance. + * + * Spark wraps UDF arguments in `KnownNotNull` when the UDF declares its inputs non-nullable; + * strip those so the underlying expression has a Comet serde mapping. 
/**
 * Build a `JvmScalarUdf` proto that targets a user-registered
 * [[org.apache.comet.udf.CometUDF]] class directly.
 *
 * Each child of the `ScalaUDF` becomes one `args` entry of the proto, and the declared Spark
 * return type and nullability are copied from `expr`.
 *
 * Spark wraps UDF arguments in `KnownNotNull` when the UDF declares its inputs non-nullable;
 * the wrapper is stripped so the underlying expression is what gets serialized.
 *
 * @param expr
 *   the `ScalaUDF` being converted
 * @param inputs
 *   attributes the argument expressions are resolved against
 * @param binding
 *   forwarded unchanged to `exprToProtoInternal`
 * @param className
 *   fully-qualified name of the registered `CometUDF` implementation
 * @return
 *   the serialized expression, or `None` (with a fallback reason attached to `expr`) when an
 *   argument or the return type cannot be serialized
 */
private def convertRegistered(
    expr: ScalaUDF,
    inputs: Seq[Attribute],
    binding: Boolean,
    className: String): Option[Expr] = {
  val udfLabel = expr.udfName.getOrElse(className)

  // Serialize every argument eagerly (same evaluation order as before) so that each
  // unsupported child gets the chance to record its own fallback reason.
  val argProtos = expr.children.map { arg =>
    val unwrapped = arg match {
      case KnownNotNull(inner) => inner
      case other => other
    }
    exprToProtoInternal(unwrapped, inputs, binding)
  }

  if (argProtos.exists(_.isEmpty)) {
    // Record why the registered route was abandoned instead of falling back silently.
    withInfo(
      expr,
      s"Registered CometUDF '$udfLabel' has an argument expression that Comet cannot serialize")
    None
  } else {
    serializeDataType(expr.dataType) match {
      case None =>
        withInfo(
          expr,
          s"Registered CometUDF '$udfLabel' has unsupported return type ${expr.dataType}")
        None
      case Some(returnTypeProto) =>
        val udfBuilder = ExprOuterClass.JvmScalarUdf
          .newBuilder()
          .setClassName(className)
          .setReturnType(returnTypeProto)
          .setReturnNullable(expr.nullable)
        // All entries are defined at this point; flatten avoids Option.get.
        argProtos.flatten.foreach(arg => udfBuilder.addArgs(arg))
        Some(
          ExprOuterClass.Expr
            .newBuilder()
            .setJvmScalarUdf(udfBuilder.build())
            .build())
    }
  }
}
@Unstable
trait CometUDF {

  /**
   * Evaluate the UDF over one batch of Arrow vectors.
   *
   * @param inputs
   *   one vector per UDF argument. Per the user guide contract, column arguments arrive at the
   *   row count of the current batch, while scalar (literal-folded) arguments arrive as
   *   length-1 vectors that must be read at index 0.
   * @param numRows
   *   number of rows in the current batch
   * @return
   *   the result vector; its length must match `numRows`
   */
  def evaluate(inputs: Array[ValueVector], numRows: Int): ValueVector
}
/**
 * Driver-side registry mapping a Spark UDF name to a user-supplied [[CometUDF]] implementation
 * class. When the [[org.apache.comet.serde.CometScalaUDF]] serde encounters a `ScalaUDF` whose
 * name is registered, it routes the call through the registered class instead of the Janino
 * codegen dispatcher.
 *
 * Typical usage:
 * {{{
 * spark.udf.register("plus_one", (x: Int) => x + 1) // row-based fallback / type info for Spark
 * CometUDFRegistry.register("plus_one", classOf[PlusOneUdf])
 * }}}
 *
 * The matching Spark UDF must be registered separately (e.g. via `spark.udf.register`) so Spark
 * can bind the function during analysis and supply the return type. If Comet is disabled or the
 * enclosing operator falls back, Spark evaluates the row-based UDF.
 *
 * Registration is driver-side state; executors look up the class name from the serialized plan
 * and load the class via the executor's context classloader.
 */
@Unstable
object CometUDFRegistry {

  // Backed by a ConcurrentHashMap, which rejects null keys and values; the public methods
  // below guard against nulls so callers get a clear error (or a no-op) instead of a bare NPE.
  private val registry = new ConcurrentHashMap[String, Class[_ <: CometUDF]]()

  /**
   * Register a [[CometUDF]] implementation against a Spark UDF name. A later registration under
   * the same name replaces the earlier one.
   *
   * The class is validated eagerly so that an unusable implementation is rejected here rather
   * than when it is first instantiated.
   *
   * @param udfName
   *   the name the Spark UDF was registered under; must be non-empty
   * @param udfClass
   *   concrete implementation class with a public no-arg constructor
   * @throws IllegalArgumentException
   *   if the name is null/empty, or the class is null, abstract, an interface, or lacks a
   *   public no-arg constructor
   */
  def register(udfName: String, udfClass: Class[_ <: CometUDF]): Unit = {
    require(udfName != null && udfName.nonEmpty, "udfName must be a non-empty string")
    require(udfClass != null, s"CometUDF class for '$udfName' must not be null")
    require(
      !udfClass.isInterface && !java.lang.reflect.Modifier.isAbstract(udfClass.getModifiers),
      s"CometUDF class ${udfClass.getName} must be a concrete class")
    require(
      hasPublicNoArgConstructor(udfClass),
      s"CometUDF class ${udfClass.getName} must have a public no-arg constructor")
    registry.put(udfName, udfClass)
  }

  /** Remove a previously registered UDF. No-op if not registered (or if the name is null). */
  def unregister(udfName: String): Unit = {
    if (udfName != null) {
      registry.remove(udfName)
    }
  }

  /** Whether a UDF name has a registered [[CometUDF]] implementation. */
  def isRegistered(udfName: String): Boolean =
    udfName != null && registry.containsKey(udfName)

  /** Look up the implementation class registered for `udfName`, if any. */
  private[comet] def get(udfName: String): Option[Class[_ <: CometUDF]] =
    Option(udfName).flatMap(name => Option(registry.get(name)))

  // Visible for testing.
  private[comet] def clear(): Unit = registry.clear()

  /** True when `cls` exposes a public zero-argument constructor. */
  private def hasPublicNoArgConstructor(cls: Class[_]): Boolean =
    scala.util.Try(cls.getConstructor()).isSuccess
}
/**
 * End-to-end test for [[CometUDFRegistry]]. Registering a [[CometUDF]] against a Spark UDF name
 * causes the [[org.apache.comet.serde.CometScalaUDF]] serde to route through the registered class
 * on the native path.
 */
class CometRegisteredUdfSuite extends CometTestBase with AdaptiveSparkPlanHelper {

  override protected def afterEach(): Unit = {
    // Always run the parent cleanup, even if clearing the registry were to fail.
    try CometUDFRegistry.clear()
    finally super.afterEach()
  }

  test("registered CometUDF runs on the native path") {
    spark.udf.register("plus_one", (x: Int) => x + 1)
    CometUDFRegistry.register("plus_one", classOf[PlusOneCometUDF])

    // Disable the Janino codegen route so that staying fully native can only be explained by
    // the registry route, which is checked before the codegen flag.
    withSQLConf(CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED.key -> "false") {
      withParquetTable((0 until 8).map(Tuple1(_)), "t") {
        checkSparkAnswerAndOperator(sql("SELECT plus_one(_1) FROM t"))
      }
    }
  }

  test("unregistered ScalaUDF still falls back when codegen is disabled") {
    spark.udf.register("plus_two", (x: Int) => x + 2)
    assert(!CometUDFRegistry.isRegistered("plus_two"))

    // The test name promises "codegen is disabled"; set the flag explicitly rather than
    // relying on whatever the default happens to be.
    withSQLConf(CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED.key -> "false") {
      withParquetTable((0 until 4).map(Tuple1(_)), "t") {
        checkSparkAnswer(sql("SELECT plus_two(_1) FROM t"))
      }
    }
  }

  test("register / isRegistered / unregister round-trip") {
    assert(!CometUDFRegistry.isRegistered("plus_one"))
    CometUDFRegistry.register("plus_one", classOf[PlusOneCometUDF])
    assert(CometUDFRegistry.isRegistered("plus_one"))
    CometUDFRegistry.unregister("plus_one")
    assert(!CometUDFRegistry.isRegistered("plus_one"))
  }
}

/**
 * Test [[CometUDF]] that returns `input + 1` over an int vector. Top-level with a no-arg
 * constructor so `CometUdfBridge` can instantiate it reflectively per task.
 */
class PlusOneCometUDF extends CometUDF {
  // A RootAllocator owned by the UDF instance keeps the test self-contained. Production UDFs
  // would typically reuse a TaskContext-scoped allocator, but the per-task instance lifecycle
  // makes either choice safe.
  private val allocator = new RootAllocator()

  /** Adds one to every non-null element of the first input vector; nulls stay null. */
  override def evaluate(inputs: Array[ValueVector], numRows: Int): ValueVector = {
    val input = inputs(0).asInstanceOf[IntVector]
    val out = new IntVector("plus_one_out", allocator)
    out.allocateNew(numRows)
    var i = 0
    while (i < numRows) {
      if (input.isNull(i)) out.setNull(i)
      else out.set(i, input.get(i) + 1)
      i += 1
    }
    out.setValueCount(numRows)
    out
  }
}