Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim {
classOf[Flatten] -> CometFlatten,
classOf[GetArrayItem] -> CometGetArrayItem,
classOf[Size] -> CometSize,
classOf[ArraysZip] -> CometArraysZip)
classOf[ArraysZip] -> CometArraysZip,
classOf[ArrayExists] -> CometArrayExists)

private val conditionalExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] =
Map(classOf[CaseWhen] -> CometCaseWhen, classOf[If] -> CometIf)
Expand Down
76 changes: 75 additions & 1 deletion spark/src/main/scala/org/apache/comet/serde/arrays.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ package org.apache.comet.serde
import scala.annotation.tailrec
import scala.jdk.CollectionConverters._

import org.apache.spark.sql.catalyst.expressions.{And, ArrayAppend, ArrayContains, ArrayExcept, ArrayFilter, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayMax, ArrayMin, ArrayPosition, ArrayRemove, ArrayRepeat, ArraysOverlap, ArraysZip, ArrayUnion, Attribute, CreateArray, ElementAt, EmptyRow, Expression, Flatten, GetArrayItem, IsNotNull, Literal, Reverse, Size, SortArray}
import org.apache.spark.sql.catalyst.expressions.{And, ArrayAppend, ArrayContains, ArrayExcept, ArrayExists, ArrayFilter, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayMax, ArrayMin, ArrayPosition, ArrayRemove, ArrayRepeat, ArraysOverlap, ArraysZip, ArrayUnion, Attribute, AttributeReference, CreateArray, ElementAt, EmptyRow, Expression, Flatten, GetArrayItem, IsNotNull, LambdaFunction, Literal, NamedLambdaVariable, Reverse, Size, SortArray}
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -812,3 +812,77 @@ trait ArraysBase {
}
}
}

/**
 * Serde for Spark's `exists(array, x -> predicate(x))` higher-order function.
 *
 * The lambda is not translated into a native expression. `convert` emits a `JvmScalarUdf` call
 * to `org.apache.comet.udf.ArrayExistsUDF` with two arguments: the serialized array expression
 * and a binary literal holding the Java-serialized [[ArrayExists]] expression.
 */
object CometArrayExists extends CometExpressionSerde[ArrayExists] {

  /** Returns true when arrays with elements of type `dt` are accepted by this serde. */
  private def isElementTypeSupported(dt: DataType): Boolean = dt match {
    case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType |
        _: DecimalType | DateType | TimestampType | StringType =>
      true
    case _ => false
  }

  /**
   * Classifies the lambda: only a single-argument lambda whose body contains no
   * [[AttributeReference]] (i.e. no captured columns) is reported as compatible.
   */
  private def lambdaSupportLevel(function: Expression): SupportLevel = function match {
    case LambdaFunction(body, Seq(_: NamedLambdaVariable), _) =>
      val capturesColumn = body.collect { case a: AttributeReference => a }.nonEmpty
      if (capturesColumn) {
        Unsupported(Some("Lambda references columns outside the array element"))
      } else {
        Compatible()
      }
    case _ =>
      Unsupported(Some("Only single-argument lambda functions are supported"))
  }

  /**
   * Reports whether `expr` can be offloaded: the argument must be an array of a supported
   * element type and the lambda must pass [[lambdaSupportLevel]].
   */
  override def getSupportLevel(expr: ArrayExists): SupportLevel =
    expr.argument.dataType match {
      case ArrayType(elementType, _) if !isElementTypeSupported(elementType) =>
        Unsupported(Some(s"Unsupported array element type: $elementType"))
      case _: ArrayType =>
        lambdaSupportLevel(expr.function)
      case other =>
        // Report a fallback reason instead of failing with a MatchError.
        Unsupported(Some(s"Unsupported argument type: $other"))
    }

  /**
   * Java-serializes the whole expression so the lambda can travel inside the plan.
   *
   * @return
   *   the serialized bytes, or `None` (with a fallback reason attached to `expr`) when the
   *   expression cannot be serialized
   */
  private def javaSerialize(expr: ArrayExists): Option[Array[Byte]] = {
    val buffer = new java.io.ByteArrayOutputStream()
    try {
      val objectOut = new java.io.ObjectOutputStream(buffer)
      try objectOut.writeObject(expr)
      finally objectOut.close()
      Some(buffer.toByteArray)
    } catch {
      // NotSerializableException and friends are IOExceptions; fall back rather than
      // aborting query planning.
      case e: java.io.IOException =>
        withInfo(expr, s"Failed to Java-serialize lambda expression: ${e.getMessage}")
        None
    }
  }

  /**
   * Builds the `JvmScalarUdf` proto for `expr`.
   *
   * @return
   *   the proto expression, or `None` when the array argument, the lambda payload or the
   *   return type cannot be serialized
   */
  override def convert(
      expr: ArrayExists,
      inputs: Seq[Attribute],
      binding: Boolean): Option[ExprOuterClass.Expr] = {

    // Attaches `message` to `expr` as the fallback reason when `value` is empty.
    def orInfo[T](value: Option[T], message: String): Option[T] = {
      if (value.isEmpty) {
        withInfo(expr, message)
      }
      value
    }

    // The lambda body is evaluated on the executor inside ArrayExistsUDF, so the entire
    // ArrayExists Catalyst expression must travel with the plan. The driver-side
    // CometLambdaRegistry approach broke under real distributed execution because the
    // executor JVM doesn't share the driver's map. Serializing the expression as bytes
    // and shipping it through the proto as a Literal arg keeps the lambda self-contained.
    for {
      arrayProto <- orInfo(
        exprToProtoInternal(expr.argument, inputs, binding),
        "Failed to serialize array argument")
      payloadBytes <- javaSerialize(expr)
      payloadProto <- orInfo(
        exprToProtoInternal(Literal(payloadBytes, BinaryType), inputs, binding),
        "Failed to serialize lambda expression payload")
      returnType <- serializeDataType(BooleanType)
    } yield {
      val udf = ExprOuterClass.JvmScalarUdf
        .newBuilder()
        .setClassName("org.apache.comet.udf.ArrayExistsUDF")
        .addArgs(arrayProto)
        .addArgs(payloadProto)
        .setReturnType(returnType)
        .setReturnNullable(expr.nullable)
        .build()

      ExprOuterClass.Expr
        .newBuilder()
        .setJvmScalarUdf(udf)
        .build()
    }
  }
}
146 changes: 146 additions & 0 deletions spark/src/main/scala/org/apache/comet/udf/ArrayExistsUDF.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.comet.udf

import org.apache.arrow.vector._
import org.apache.arrow.vector.complex.ListVector
import org.apache.spark.sql.catalyst.expressions.{ArrayExists, LambdaFunction, NamedLambdaVariable}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

import org.apache.comet.CometArrowAllocator

/**
* JVM UDF implementing Spark's `exists(array, x -> predicate(x))` higher-order function.
*
* Inputs:
* - inputs(0): ListVector (the array column)
* - inputs(1): VarBinaryVector length-1 scalar containing the Java-serialized [[ArrayExists]]
* Catalyst expression. Shipping the expression in the proto avoids the driver-vs-executor
* mismatch a process-local registry would suffer.
*
* Output: BitVector (nullable boolean), same length as the input array vector.
*
* Implements Spark's three-valued logic:
* - true if any element satisfies the predicate
* - null if no element satisfies but the predicate returned null for at least one element
* - false if all elements produce false
*/
class ArrayExistsUDF extends CometUDF {

  /**
   * Evaluates the serialized predicate against every list in the array vector.
   *
   * @param inputs
   *   exactly two vectors: `inputs(0)` is cast to [[ListVector]] (the arrays) and `inputs(1)`
   *   to [[VarBinaryVector]], whose slot 0 holds the Java-serialized [[ArrayExists]]
   * @param numRows
   *   not consulted; the output length is taken from the list vector's value count
   * @return
   *   a newly allocated [[BitVector]] with one nullable boolean per input list
   * @throws IllegalArgumentException
   *   if the input count or payload is invalid, or the lambda is not single-argument
   * @throws UnsupportedOperationException
   *   if the lambda variable has an element type `getSparkValue` does not handle
   */
  override def evaluate(inputs: Array[ValueVector], numRows: Int): ValueVector = {
    require(inputs.length == 2, s"ArrayExistsUDF expects 2 inputs, got ${inputs.length}")
    val listVec = inputs(0).asInstanceOf[ListVector]
    val payloadVec = inputs(1).asInstanceOf[VarBinaryVector]
    require(
      payloadVec.getValueCount >= 1 && !payloadVec.isNull(0),
      "ArrayExistsUDF requires a non-null scalar payload")

    // The payload is deserialized on every call to evaluate.
    val arrayExistsExpr = deserialize(payloadVec.get(0))
    val elementVar = arrayExistsExpr.function match {
      case LambdaFunction(_, Seq(v: NamedLambdaVariable), _) => v
      case other =>
        throw new IllegalArgumentException(
          s"ArrayExistsUDF expects a single-argument lambda, got: $other")
    }
    val body = arrayExistsExpr.functionForEval
    val followThreeValuedLogic = arrayExistsExpr.followThreeValuedLogic
    val elementType = elementVar.dataType

    val dataVec = listVec.getDataVector
    val n = listVec.getValueCount
    val out = new BitVector("exists_result", CometArrowAllocator)
    var completed = false
    try {
      out.allocateNew(n)
      var i = 0
      while (i < n) {
        if (listVec.isNull(i)) {
          out.setNull(i)
        } else {
          val endIdx = listVec.getElementEndIndex(i)
          var j = listVec.getElementStartIndex(i)
          var exists = false
          var foundNull = false
          // Stop at the first element that satisfies the predicate.
          while (j < endIdx && !exists) {
            val elem = if (dataVec.isNull(j)) null else getSparkValue(dataVec, j, elementType)
            elementVar.value.set(elem)
            val ret = body.eval(null)
            if (ret == null) foundNull = true
            else if (ret.asInstanceOf[Boolean]) exists = true
            j += 1
          }
          if (exists) {
            out.set(i, 1)
          } else if (followThreeValuedLogic && foundNull) {
            out.setNull(i)
          } else {
            out.set(i, 0)
          }
        }
        i += 1
      }
      out.setValueCount(n)
      completed = true
      out
    } finally {
      // Release the output buffers if evaluation failed part-way; on success the vector is
      // returned to the caller.
      if (!completed) {
        out.close()
      }
    }
  }

  /**
   * Reads the Java-serialized [[ArrayExists]] expression from `payloadBytes`.
   *
   * NOTE(review): this uses plain Java deserialization with `ObjectInputStream`'s default class
   * resolution. The payload is expected to be the bytes produced by the plan serde, not
   * external input -- confirm, and confirm that classes referenced by the lambda body are
   * resolvable here.
   */
  private def deserialize(payloadBytes: Array[Byte]): ArrayExists = {
    val objectIn = new java.io.ObjectInputStream(new java.io.ByteArrayInputStream(payloadBytes))
    try objectIn.readObject().asInstanceOf[ArrayExists]
    finally objectIn.close()
  }

  /**
   * Reads the element at `index` of `vec` and converts it to the value handed to the lambda
   * variable for `sparkType`. The caller is responsible for checking nullness first.
   *
   * @throws UnsupportedOperationException
   *   for element types without a mapping below
   */
  private def getSparkValue(vec: ValueVector, index: Int, sparkType: DataType): Any = {
    sparkType match {
      case BooleanType =>
        vec.asInstanceOf[BitVector].get(index) == 1
      case ByteType =>
        vec.asInstanceOf[TinyIntVector].get(index).toByte
      case ShortType =>
        vec.asInstanceOf[SmallIntVector].get(index).toShort
      case IntegerType =>
        vec.asInstanceOf[IntVector].get(index)
      case LongType =>
        vec.asInstanceOf[BigIntVector].get(index)
      case FloatType =>
        vec.asInstanceOf[Float4Vector].get(index)
      case DoubleType =>
        vec.asInstanceOf[Float8Vector].get(index)
      case StringType =>
        UTF8String.fromBytes(vec.asInstanceOf[VarCharVector].get(index))
      case BinaryType =>
        vec.asInstanceOf[VarBinaryVector].get(index)
      case dt: DecimalType =>
        val decimal = vec.asInstanceOf[DecimalVector].getObject(index)
        Decimal(decimal, dt.precision, dt.scale)
      case DateType =>
        vec.asInstanceOf[DateDayVector].get(index)
      case TimestampType =>
        vec.asInstanceOf[TimeStampMicroTZVector].get(index)
      case _ =>
        throw new UnsupportedOperationException(
          s"ArrayExistsUDF does not yet support element type: $sparkType")
    }
  }
}

This file was deleted.

Loading
Loading