diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 9c80f33d39..326388253e 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -71,7 +71,8 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim { classOf[Flatten] -> CometFlatten, classOf[GetArrayItem] -> CometGetArrayItem, classOf[Size] -> CometSize, - classOf[ArraysZip] -> CometArraysZip) + classOf[ArraysZip] -> CometArraysZip, + classOf[ArrayExists] -> CometArrayExists) private val conditionalExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map(classOf[CaseWhen] -> CometCaseWhen, classOf[If] -> CometIf) diff --git a/spark/src/main/scala/org/apache/comet/serde/arrays.scala b/spark/src/main/scala/org/apache/comet/serde/arrays.scala index 5edc08840a..74fe02bfe2 100644 --- a/spark/src/main/scala/org/apache/comet/serde/arrays.scala +++ b/spark/src/main/scala/org/apache/comet/serde/arrays.scala @@ -22,7 +22,7 @@ package org.apache.comet.serde import scala.annotation.tailrec import scala.jdk.CollectionConverters._ -import org.apache.spark.sql.catalyst.expressions.{And, ArrayAppend, ArrayContains, ArrayExcept, ArrayFilter, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayMax, ArrayMin, ArrayPosition, ArrayRemove, ArrayRepeat, ArraysOverlap, ArraysZip, ArrayUnion, Attribute, CreateArray, ElementAt, EmptyRow, Expression, Flatten, GetArrayItem, IsNotNull, Literal, Reverse, Size, SortArray} +import org.apache.spark.sql.catalyst.expressions.{And, ArrayAppend, ArrayContains, ArrayExcept, ArrayExists, ArrayFilter, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayMax, ArrayMin, ArrayPosition, ArrayRemove, ArrayRepeat, ArraysOverlap, ArraysZip, ArrayUnion, Attribute, AttributeReference, CreateArray, ElementAt, EmptyRow, Expression, Flatten, GetArrayItem, IsNotNull, LambdaFunction, Literal, 
/**
 * Serde for Spark's `exists(array, x -> predicate(x))` higher-order function.
 *
 * The expression is emitted as a `JvmScalarUdf` call to `org.apache.comet.udf.ArrayExistsUDF`
 * with two arguments: the serialized array argument, and a binary literal carrying the
 * Java-serialized [[ArrayExists]] expression so the lambda body travels with the plan.
 */
object CometArrayExists extends CometExpressionSerde[ArrayExists] {

  /** Returns true when `dt` is one of the array element types this serde accepts. */
  private def isElementTypeSupported(dt: DataType): Boolean = dt match {
    case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType |
        _: DecimalType | DateType | TimestampType | StringType =>
      true
    case _ => false
  }

  /**
   * Decides whether `expr` can be handled.
   *
   * Unsupported when the argument is not an array, when the element type is not accepted by
   * [[isElementTypeSupported]], when the lambda is not a single-argument lambda, or when the
   * lambda body references any column (`AttributeReference`).
   */
  override def getSupportLevel(expr: ArrayExists): SupportLevel =
    expr.argument.dataType match {
      case ArrayType(elementType, _) if !isElementTypeSupported(elementType) =>
        Unsupported(Some(s"Unsupported array element type: $elementType"))
      case _: ArrayType =>
        lambdaSupportLevel(expr)
      case other =>
        // A refutable `val ArrayType(...) = ...` pattern would throw MatchError here; report
        // the expression as unsupported instead so the caller can fall back.
        Unsupported(Some(s"Expected an array argument but got: $other"))
    }

  /** Only lambdas that reference the lambda variable alone (no captured columns) qualify. */
  private def lambdaSupportLevel(expr: ArrayExists): SupportLevel = expr.function match {
    case LambdaFunction(body, Seq(_: NamedLambdaVariable), _) =>
      val capturesColumn = body.find {
        case _: AttributeReference => true
        case _ => false
      }.isDefined
      if (capturesColumn) {
        Unsupported(Some("Lambda references columns outside the array element"))
      } else {
        Compatible()
      }
    case _ =>
      Unsupported(Some("Only single-argument lambda functions are supported"))
  }

  /**
   * Java-serializes the whole expression.
   *
   * @return
   *   the serialized bytes, or `None` (after recording a fallback reason on `expr`) when
   *   serialization fails, e.g. because the expression tree holds a non-serializable object
   */
  private def javaSerialize(expr: ArrayExists): Option[Array[Byte]] = {
    val buffer = new java.io.ByteArrayOutputStream()
    try {
      val objectOut = new java.io.ObjectOutputStream(buffer)
      try objectOut.writeObject(expr)
      finally objectOut.close()
      Some(buffer.toByteArray)
    } catch {
      case scala.util.control.NonFatal(e) =>
        withInfo(expr, s"Failed to Java-serialize ArrayExists expression: ${e.getMessage}")
        None
    }
  }

  /**
   * Builds the `JvmScalarUdf` protobuf for `expr`.
   *
   * @param expr
   *   the `ArrayExists` expression to convert
   * @param inputs
   *   attributes passed through to `exprToProtoInternal` for the child expressions
   * @param binding
   *   passed through to `exprToProtoInternal`
   * @return
   *   the protobuf expression, or `None` if any part could not be serialized
   */
  override def convert(
      expr: ArrayExists,
      inputs: Seq[Attribute],
      binding: Boolean): Option[ExprOuterClass.Expr] = {

    // Records `reason` on the expression when `result` is empty, then passes `result` through.
    def orInfo[T](result: Option[T], reason: String): Option[T] = {
      if (result.isEmpty) {
        withInfo(expr, reason)
      }
      result
    }

    // The lambda body is evaluated inside ArrayExistsUDF, so the entire ArrayExists Catalyst
    // expression must travel with the plan. It is shipped through the proto as a binary
    // Literal argument, which keeps the lambda self-contained.
    for {
      arrayProto <- orInfo(
        exprToProtoInternal(expr.argument, inputs, binding),
        "Failed to serialize array argument")
      payload <- javaSerialize(expr)
      payloadProto <- orInfo(
        exprToProtoInternal(Literal(payload, BinaryType), inputs, binding),
        "Failed to serialize lambda expression payload")
      returnType <- serializeDataType(BooleanType)
    } yield {
      val udf = ExprOuterClass.JvmScalarUdf
        .newBuilder()
        .setClassName("org.apache.comet.udf.ArrayExistsUDF")
        .addArgs(arrayProto)
        .addArgs(payloadProto)
        .setReturnType(returnType)
        .setReturnNullable(expr.nullable)
        .build()

      ExprOuterClass.Expr
        .newBuilder()
        .setJvmScalarUdf(udf)
        .build()
    }
  }
}
/**
 * JVM UDF implementing Spark's `exists(array, x -> predicate(x))` higher-order function.
 *
 * Inputs:
 *   - inputs(0): ListVector (the array column)
 *   - inputs(1): VarBinaryVector whose first value is the Java-serialized [[ArrayExists]]
 *     Catalyst expression. Shipping the expression in the plan avoids the driver-vs-executor
 *     mismatch a process-local registry would suffer.
 *
 * Output: BitVector (nullable boolean) with one value per entry of the input array vector.
 *
 * Implements Spark's three-valued logic (when the expression's `followThreeValuedLogic` flag
 * is set):
 *   - true if any element satisfies the predicate
 *   - null if no element satisfies but the predicate returned null for at least one element
 *   - false otherwise
 *
 * A null array yields a null result.
 */
class ArrayExistsUDF extends CometUDF {

  /**
   * Evaluates the predicate over every array in `inputs(0)`.
   *
   * @param inputs
   *   exactly two vectors, as described on the class
   * @param numRows
   *   not consulted; the output length is taken from the list vector's value count
   * @return
   *   a newly allocated BitVector owned by the caller
   * @throws IllegalArgumentException
   *   if the number of inputs is wrong or the payload is missing/null
   */
  override def evaluate(inputs: Array[ValueVector], numRows: Int): ValueVector = {
    require(inputs.length == 2, s"ArrayExistsUDF expects 2 inputs, got ${inputs.length}")
    val listVec = inputs(0).asInstanceOf[ListVector]
    val payloadVec = inputs(1).asInstanceOf[VarBinaryVector]
    require(
      payloadVec.getValueCount >= 1 && !payloadVec.isNull(0),
      "ArrayExistsUDF requires a non-null scalar payload")

    val arrayExistsExpr = deserializeExpression(payloadVec.get(0))
    val LambdaFunction(_, Seq(elementVar: NamedLambdaVariable), _) = arrayExistsExpr.function
    val body = arrayExistsExpr.functionForEval
    val followThreeValuedLogic = arrayExistsExpr.followThreeValuedLogic
    val elementType = elementVar.dataType

    val dataVec = listVec.getDataVector
    val n = listVec.getValueCount
    val out = new BitVector("exists_result", CometArrowAllocator)

    // Evaluating the lambda or decoding an element may throw. Track completion so the
    // Arrow buffers of `out` are released on any failure path instead of leaking.
    var completed = false
    try {
      out.allocateNew(n)

      var i = 0
      while (i < n) {
        if (listVec.isNull(i)) {
          out.setNull(i)
        } else {
          val endIdx = listVec.getElementEndIndex(i)
          var exists = false
          var foundNull = false
          var j = listVec.getElementStartIndex(i)
          // Stop at the first element for which the predicate is true.
          while (j < endIdx && !exists) {
            val elem: Any =
              if (dataVec.isNull(j)) null else getSparkValue(dataVec, j, elementType)
            // The lambda body reads its argument through the variable's value holder.
            elementVar.value.set(elem)
            val ret = body.eval(null)
            if (ret == null) foundNull = true
            else if (ret.asInstanceOf[Boolean]) exists = true
            j += 1
          }
          if (exists) {
            out.set(i, 1)
          } else if (followThreeValuedLogic && foundNull) {
            out.setNull(i)
          } else {
            out.set(i, 0)
          }
        }
        i += 1
      }
      out.setValueCount(n)
      completed = true
      out
    } finally {
      if (!completed) {
        out.close()
      }
    }
  }

  /**
   * Deserializes the Java-serialized [[ArrayExists]] payload.
   *
   * Classes are resolved through the thread context class loader first (when one is set), so
   * classes referenced by the lambda body that are only visible to that loader can be found;
   * `ObjectInputStream`'s default resolution is used as the fallback, which also covers
   * primitive type descriptors.
   */
  private def deserializeExpression(payload: Array[Byte]): ArrayExists = {
    val objectIn =
      new java.io.ObjectInputStream(new java.io.ByteArrayInputStream(payload)) {
        override protected def resolveClass(desc: java.io.ObjectStreamClass): Class[_] = {
          val loader = Thread.currentThread().getContextClassLoader
          if (loader == null) {
            super.resolveClass(desc)
          } else {
            try Class.forName(desc.getName, false, loader)
            catch {
              case _: ClassNotFoundException => super.resolveClass(desc)
            }
          }
        }
      }
    try objectIn.readObject().asInstanceOf[ArrayExists]
    finally objectIn.close()
  }

  /**
   * Reads the non-null value at `index` of `vec` and converts it to the value representation
   * the lambda body is evaluated against for `sparkType`.
   *
   * @throws UnsupportedOperationException
   *   if `sparkType` is not one of the handled element types
   */
  private def getSparkValue(vec: ValueVector, index: Int, sparkType: DataType): Any = {
    sparkType match {
      case BooleanType =>
        vec.asInstanceOf[BitVector].get(index) == 1
      case ByteType =>
        vec.asInstanceOf[TinyIntVector].get(index).toByte
      case ShortType =>
        vec.asInstanceOf[SmallIntVector].get(index).toShort
      case IntegerType =>
        vec.asInstanceOf[IntVector].get(index)
      case LongType =>
        vec.asInstanceOf[BigIntVector].get(index)
      case FloatType =>
        vec.asInstanceOf[Float4Vector].get(index)
      case DoubleType =>
        vec.asInstanceOf[Float8Vector].get(index)
      case StringType =>
        val bytes = vec.asInstanceOf[VarCharVector].get(index)
        UTF8String.fromBytes(bytes)
      case BinaryType =>
        vec.asInstanceOf[VarBinaryVector].get(index)
      case dt: DecimalType =>
        val decimal = vec.asInstanceOf[DecimalVector].getObject(index)
        Decimal(decimal, dt.precision, dt.scale)
      case DateType =>
        vec.asInstanceOf[DateDayVector].get(index)
      case TimestampType =>
        vec.asInstanceOf[TimeStampMicroTZVector].get(index)
      case _ =>
        throw new UnsupportedOperationException(
          s"ArrayExistsUDF does not yet support element type: $sparkType")
    }
  }
}
// Tests for `exists(array, lambda)`. Each table declares an explicit array element type:
// Spark DDL requires `ARRAY<elementType>`, a bare `ARRAY` does not parse.

test("array_exists - integer predicate") {
  withTable("t") {
    sql("CREATE TABLE t (arr ARRAY<INT>) USING parquet")
    sql("INSERT INTO t VALUES (array(1, 2, 3)), (array(4, 5, 6)), (array(-1, -2)), (NULL)")
    checkSparkAnswerAndOperator(sql("SELECT exists(arr, x -> x > 2) FROM t"))
  }
}

test("array_exists - string predicate") {
  withTable("t") {
    sql("CREATE TABLE t (arr ARRAY<STRING>) USING parquet")
    sql(
      "INSERT INTO t VALUES (array('hello', 'world')), (array('foo')), (array(NULL, 'bar')), (NULL)")
    checkSparkAnswerAndOperator(sql("SELECT exists(arr, x -> x = 'world') FROM t"))
  }
}

test("array_exists - null elements with three-valued logic") {
  withTable("t") {
    sql("CREATE TABLE t (arr ARRAY<INT>) USING parquet")
    sql("INSERT INTO t VALUES (array(1, NULL, 3)), (array(NULL, NULL)), (array(4, 5))")
    checkSparkAnswerAndOperator(sql("SELECT exists(arr, x -> x > 10) FROM t"))
  }
}

test("array_exists - all elements match") {
  withTable("t") {
    sql("CREATE TABLE t (arr ARRAY<INT>) USING parquet")
    sql("INSERT INTO t VALUES (array(10, 20, 30)), (array(1, 2, 3))")
    checkSparkAnswerAndOperator(sql("SELECT exists(arr, x -> x > 0) FROM t"))
  }
}

test("array_exists - empty array") {
  withTable("t") {
    sql("CREATE TABLE t (arr ARRAY<INT>) USING parquet")
    sql("INSERT INTO t VALUES (array()), (array(1))")
    checkSparkAnswerAndOperator(sql("SELECT exists(arr, x -> x > 0) FROM t"))
  }
}

test("array_exists - DataFrame API") {
  withTable("t") {
    sql("CREATE TABLE t (arr ARRAY<INT>) USING parquet")
    sql("INSERT INTO t VALUES (array(1, 2, 3)), (array(1, 2)), (array()), (NULL)")
    val df = spark.table("t")
    checkSparkAnswerAndOperator(df.select(exists(col("arr"), x => x > 2)))
    checkSparkAnswerAndOperator(
      df.select(
        exists(col("arr"), x => x > 0).as("any_positive"),
        exists(col("arr"), x => x > 100).as("any_large")))
  }
}

test("array_exists - decimal element type") {
  withTable("t") {
    sql("CREATE TABLE t (arr ARRAY<DECIMAL(10,2)>) USING parquet")
    sql("INSERT INTO t VALUES (array(1.50, 2.75, 3.25)), (array(0.10, 0.20))")
    checkSparkAnswerAndOperator(
      spark.table("t").select(exists(col("arr"), x => x > lit(BigDecimal("2.00")))))
  }
}

test("array_exists - date element type") {
  withTable("t") {
    sql("CREATE TABLE t (arr ARRAY<DATE>) USING parquet")
    sql(
      "INSERT INTO t VALUES (array(date'2024-01-01', date'2024-06-15')), (array(date'2023-01-01'))")
    checkSparkAnswerAndOperator(
      spark
        .table("t")
        .select(exists(col("arr"), x => x > lit("2024-03-01").cast("date"))))
  }
}

test("array_exists - timestamp element type") {
  withTable("t") {
    sql("CREATE TABLE t (arr ARRAY<TIMESTAMP>) USING parquet")
    sql(
      "INSERT INTO t VALUES " +
        "(array(timestamp'2024-01-01 00:00:00', timestamp'2024-06-15 12:30:00')), " +
        "(array(timestamp'2023-01-01 00:00:00'))")
    checkSparkAnswerAndOperator(
      spark
        .table("t")
        .select(exists(col("arr"), x => x > lit("2024-03-01 00:00:00").cast("timestamp"))))
  }
}

test("array_exists - literal lambda bodies") {
  withTable("t") {
    sql("CREATE TABLE t (arr ARRAY<INT>) USING parquet")
    sql("INSERT INTO t VALUES (array(1, 2, 3)), (array()), (NULL)")
    val df = spark.table("t")
    checkSparkAnswerAndOperator(df.select(exists(col("arr"), _ => lit(false))))
    checkSparkAnswerAndOperator(df.select(exists(col("arr"), _ => lit(true))))
    checkSparkAnswerAndOperator(df.select(exists(col("arr"), _ => lit(null).cast("boolean"))))
  }
}

test("array_exists - CaseWhen/If in lambda") {
  withTable("t") {
    sql("CREATE TABLE t (arr ARRAY<INT>) USING parquet")
    sql("INSERT INTO t VALUES (array(1, 2, 3)), (array(-1, 0, 1)), (NULL)")
    val df = spark.table("t")
    checkSparkAnswerAndOperator(
      df.selectExpr("exists(arr, x -> CASE WHEN x > 0 THEN true ELSE false END)"))
    checkSparkAnswerAndOperator(df.selectExpr("exists(arr, x -> IF(x > 0, true, false))"))
  }
}

test("array_exists - fallback for unsupported element type") {
  withTable("t") {
    // Binary elements are not in the serde's supported element types, so Comet must fall back.
    sql("CREATE TABLE t (arr ARRAY<BINARY>) USING parquet")
    sql("INSERT INTO t VALUES (array(X'01', X'02'))")
    checkSparkAnswerAndFallbackReason(
      spark.table("t").select(exists(col("arr"), x => x.isNotNull)),
      "Unsupported array element type")
  }
}

test("array_exists - fallback for lambda capturing outer column") {
  withTable("t") {
    sql("CREATE TABLE t (arr ARRAY<INT>, threshold INT) USING parquet")
    sql("INSERT INTO t VALUES (array(1, 2, 3), 2), (array(1, 2), 5)")
    checkSparkAnswerAndFallbackReason(
      spark.table("t").select(exists(col("arr"), x => x > col("threshold"))),
      "Lambda references columns outside the array element")
  }
}

test("array_exists - fallback for nested lambda") {
  withTable("t") {
    sql("CREATE TABLE t (arr1 ARRAY<INT>, arr2 ARRAY<INT>) USING parquet")
    sql("INSERT INTO t VALUES (array(1, 2, 3), array(4, 5, 6)), (array(10), array(1))")
    // The outer lambda body contains a reference to column `arr2`, which triggers the fallback.
    checkSparkAnswerAndFallbackReason(
      spark.table("t").select(exists(col("arr1"), x => exists(col("arr2"), y => y > x))),
      "Lambda references columns outside the array element")
  }
}
/**
 * Benchmarks `exists(array, lambda)` over a Parquet table with a single ten-element
 * `int_arr` column derived from `value` in the source table.
 *
 * Two predicates are measured, `x > 50` and `x < 0`.
 *
 * @param values
 *   cardinality passed to `runExpressionBenchmark`
 */
def arrayExistsBenchmark(values: Int): Unit = {
  // (benchmark name, query) pairs, executed in order against the prepared table.
  val benchmarkCases = Seq(
    "array_exists - int array (x -> x > 50)" ->
      "SELECT exists(int_arr, x -> x > 50) FROM parquetV1Table",
    "array_exists - int array (x -> x < 0)" ->
      "SELECT exists(int_arr, x -> x < 0) FROM parquetV1Table")

  withTempPath { dir =>
    withTempTable("parquetV1Table") {
      val arraySource = spark.sql(s"""SELECT
                   | array(
                   | cast(value % 100 as int),
                   | cast((value + 1) % 100 as int),
                   | cast((value + 2) % 100 as int),
                   | cast((value + 3) % 100 as int),
                   | cast((value + 4) % 100 as int),
                   | cast((value + 5) % 100 as int),
                   | cast((value + 6) % 100 as int),
                   | cast((value + 7) % 100 as int),
                   | cast((value + 8) % 100 as int),
                   | cast((value + 9) % 100 as int)
                   | ) as int_arr
                   |FROM $tbl""".stripMargin)
      prepareTable(dir, arraySource)

      benchmarkCases.foreach { case (benchmarkName, query) =>
        runExpressionBenchmark(benchmarkName, values, query)
      }
    }
  }
}