diff --git a/native/spark-expr/src/bloom_filter/bloom_filter_agg.rs b/native/spark-expr/src/bloom_filter/bloom_filter_agg.rs index 4ac236e6fc..2a7a4b5506 100644 --- a/native/spark-expr/src/bloom_filter/bloom_filter_agg.rs +++ b/native/spark-expr/src/bloom_filter/bloom_filter_agg.rs @@ -25,7 +25,7 @@ use crate::bloom_filter::spark_bloom_filter::{SparkBloomFilter, SparkBloomFilter use arrow::array::ArrayRef; use arrow::array::BinaryArray; use datafusion::common::{downcast_value, ScalarValue}; -use datafusion::error::Result; +use datafusion::error::{DataFusionError, Result}; use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion::logical_expr::{AggregateUDFImpl, Signature}; use datafusion::physical_expr::expressions::Literal; @@ -141,8 +141,16 @@ impl Accumulator for SparkBloomFilter { ScalarValue::Utf8(Some(value)) => { self.put_binary(value.as_bytes()); } - _ => { - unreachable!() + // Spark's BloomFilterAggregate.update ignores null inputs. + ScalarValue::Int8(None) + | ScalarValue::Int16(None) + | ScalarValue::Int32(None) + | ScalarValue::Int64(None) + | ScalarValue::Utf8(None) => {} + other => { + return Err(DataFusionError::Internal(format!( + "bloom_filter_agg received an unsupported input type: {other:?}" + ))); } } Ok(()) @@ -150,6 +158,13 @@ impl Accumulator for SparkBloomFilter { } fn evaluate(&mut self) -> Result { + // Spark's BloomFilterAggregate.eval returns NULL when no bit is set, + // i.e. the aggregate saw no non-null input. Mirror that here so an + // empty-input bloom_filter_agg yields NULL rather than a serialized + // empty filter. + if self.cardinality() == 0 { + return Ok(ScalarValue::Binary(None)); + } Ok(ScalarValue::Binary(Some(self.spark_serialization()))) } @@ -173,7 +188,34 @@ impl Accumulator for SparkBloomFilter { ); assert_eq!(states[0].len(), 1); let state_sv = downcast_value!(states[0], BinaryArray); - self.merge_filter(state_sv.value_data()); - Ok(()) + self.merge_filter(state_sv.value_data()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + /// Spark's BloomFilterAggregate.eval returns NULL when the filter saw no + /// non-null input (cardinality 0); an untouched accumulator must match. + #[test] + fn evaluate_empty_filter_yields_null() { + let num_bits = 1024; + let num_hash = spark_bloom_filter::optimal_num_hash_functions(100, num_bits); + let mut acc = SparkBloomFilter::new(SparkBloomFilterVersion::V1, num_hash, num_bits, 0); + assert_eq!(acc.evaluate().unwrap(), ScalarValue::Binary(None)); + } + + /// A filter with at least one set bit serializes to a non-null binary. + #[test] + fn evaluate_non_empty_filter_yields_binary() { + let num_bits = 1024; + let num_hash = spark_bloom_filter::optimal_num_hash_functions(100, num_bits); + let mut acc = SparkBloomFilter::new(SparkBloomFilterVersion::V1, num_hash, num_bits, 0); + acc.put_long(42); + assert!(matches!( + acc.evaluate().unwrap(), + ScalarValue::Binary(Some(_)) + )); } } diff --git a/native/spark-expr/src/bloom_filter/spark_bit_array.rs b/native/spark-expr/src/bloom_filter/spark_bit_array.rs index 954e983a0f..6d43bdb942 100644 --- a/native/spark-expr/src/bloom_filter/spark_bit_array.rs +++ b/native/spark-expr/src/bloom_filter/spark_bit_array.rs @@ -15,9 +15,6 @@ // specific language governing permissions and limitations // under the License. -use arrow::datatypes::ToByteSlice; -use std::iter::zip; - /// A simple bit array implementation that simulates the behavior of Spark's BitArray which is /// used in the BloomFilter implementation. Some methods are not implemented as they are not /// required for the current use case. @@ -61,41 +58,28 @@ impl SparkBitArray { self.word_size() as u64 * 64 } - pub fn byte_size(&self) -> usize { - self.word_size() * 8 - } - pub fn word_size(&self) -> usize { self.data.len() } - #[allow(dead_code)] // this is only called from tests - pub fn cardinality(&self) -> usize { - self.bit_count - } - - pub fn to_bytes(&self) -> Vec { - Vec::from(self.data.to_byte_slice()) - } - pub fn data(&self) -> Vec { self.data.clone() } - // Combines SparkBitArrays, however other is a &[u8] because we anticipate to come from an - // Arrow ScalarValue::Binary which is a byte vector underneath, rather than a word vector. - pub fn merge_bits(&mut self, other: &[u8]) { - assert_eq!(self.byte_size(), other.len()); + /// Number of set bits in the array. Mirrors Spark's `BitArray.cardinality()`. + pub fn cardinality(&self) -> usize { + self.bit_count + } + + /// OR-merge `incoming` (big-endian `u64` words, one per word in `self`) into + /// `self.data` in place and refresh `bit_count` in the same pass. The caller + /// is responsible for ensuring `incoming.len() == self.word_size() * 8`. + pub fn merge_be_words(&mut self, incoming: &[u8]) { + debug_assert_eq!(self.data.len() * 8, incoming.len()); let mut bit_count: usize = 0; - // For each word, merge the bits into self, and accumulate a new bit_count. - for i in zip( - self.data.iter_mut(), - other - .chunks(8) - .map(|chunk| u64::from_ne_bytes(chunk.try_into().unwrap())), - ) { - *i.0 |= i.1; - bit_count += i.0.count_ones() as usize; + for (word, chunk) in self.data.iter_mut().zip(incoming.chunks_exact(8)) { + *word |= u64::from_be_bytes(chunk.try_into().unwrap()); + bit_count += word.count_ones() as usize; } self.bit_count = bit_count; } @@ -108,6 +92,37 @@ pub fn num_words(num_bits: usize) -> usize { #[cfg(test)] mod test { use super::*; + use arrow::datatypes::ToByteSlice; + use std::iter::zip; + + impl SparkBitArray { + fn byte_size(&self) -> usize { + self.word_size() * 8 + } + + fn to_bytes(&self) -> Vec { + Vec::from(self.data.to_byte_slice()) + } + + /// Combines SparkBitArrays, however other is a &[u8] because we anticipate to come from + /// an Arrow ScalarValue::Binary which is a byte vector underneath, rather than a word + /// vector. + fn merge_bits(&mut self, other: &[u8]) { + assert_eq!(self.byte_size(), other.len()); + let mut bit_count: usize = 0; + // For each word, merge the bits into self, and accumulate a new bit_count. + for i in zip( + self.data.iter_mut(), + other + .chunks(8) + .map(|chunk| u64::from_ne_bytes(chunk.try_into().unwrap())), + ) { + *i.0 |= i.1; + bit_count += i.0.count_ones() as usize; + } + self.bit_count = bit_count; + } + } #[test] fn test_spark_bit_array() { diff --git a/native/spark-expr/src/bloom_filter/spark_bloom_filter.rs b/native/spark-expr/src/bloom_filter/spark_bloom_filter.rs index 2f52941210..0c5dbea6c8 100644 --- a/native/spark-expr/src/bloom_filter/spark_bloom_filter.rs +++ b/native/spark-expr/src/bloom_filter/spark_bloom_filter.rs @@ -17,6 +17,7 @@ use arrow::array::{ArrowNativeTypeOp, BooleanArray, Int64Array}; use arrow::datatypes::ToByteSlice; +use datafusion::common::{DataFusionError, Result as DFResult}; use std::cmp; use crate::bloom_filter::spark_bit_array; @@ -271,17 +272,72 @@ impl SparkBloomFilter { .collect() } + /// Number of set bits in the underlying bit array. Mirrors Spark's + /// `BloomFilter.cardinality()`: a filter that has seen no non-null input + /// has cardinality 0. + pub fn cardinality(&self) -> usize { + self.bits.cardinality() + } + pub fn state_as_bytes(&self) -> Vec { - self.bits.to_bytes() + self.spark_serialization() } - pub fn merge_filter(&mut self, other: &[u8]) { - assert_eq!( - other.len(), - self.bits.byte_size(), - "Cannot merge SparkBloomFilters with different lengths." - ); - self.bits.merge_bits(other); + pub fn merge_filter(&mut self, other: &[u8]) -> DFResult<()> { + let mut offset = 0; + + let version_int = read_num_be_bytes!(i32, 4, other[offset..]); + offset += 4; + if version_int != self.version.to_int() { + return Err(DataFusionError::Internal(format!( + "BloomFilter merge: version mismatch (got {}, expected {})", + version_int, + self.version.to_int(), + ))); + } + + let num_hash = read_num_be_bytes!(i32, 4, other[offset..]) as u32; + offset += 4; + if num_hash != self.num_hash_functions { + return Err(DataFusionError::Internal(format!( + "BloomFilter merge: num_hash_functions mismatch (got {}, expected {})", + num_hash, self.num_hash_functions, + ))); + } + + if let SparkBloomFilterVersion::V2 = self.version { + let seed = read_num_be_bytes!(i32, 4, other[offset..]); + offset += 4; + if seed != self.seed { + return Err(DataFusionError::Internal(format!( + "BloomFilter merge: seed mismatch (got {}, expected {})", + seed, self.seed, + ))); + } + } + + let num_words = read_num_be_bytes!(i32, 4, other[offset..]) as usize; + offset += 4; + if num_words != self.bits.word_size() { + return Err(DataFusionError::Internal(format!( + "BloomFilter merge: num_words mismatch (got {}, expected {})", + num_words, + self.bits.word_size(), + ))); + } + + let expected_bytes = num_words * 8; + if other.len() - offset < expected_bytes { + return Err(DataFusionError::Internal(format!( + "BloomFilter merge: truncated bit array (got {} bytes, expected {})", + other.len() - offset, + expected_bytes, + ))); + } + + self.bits + .merge_be_words(&other[offset..offset + expected_bytes]); + Ok(()) } } @@ -396,4 +452,97 @@ mod tests { buf.extend_from_slice(&[0u8; 32]); // 4 words * 8 bytes let _ = SparkBloomFilter::from(buf.as_slice()); } + + /// Two V1 filters with identical parameters. Populate the first, serialize via + /// state_as_bytes, merge into the empty second, and verify the second contains + /// everything the first did. Exercises the aggregator state → merge_batch path. + #[test] + fn state_round_trip_v1_merge() { + let num_bits = 1024; + let num_hash = optimal_num_hash_functions(100, num_bits); + let mut a = SparkBloomFilter::new(SparkBloomFilterVersion::V1, num_hash, num_bits, 0); + for v in [1_i64, 7, 42, 99, -3, i64::MAX] { + a.put_long(v); + } + + let mut b = SparkBloomFilter::new(SparkBloomFilterVersion::V1, num_hash, num_bits, 0); + b.merge_filter(&a.state_as_bytes()).unwrap(); + + for v in [1_i64, 7, 42, 99, -3, i64::MAX] { + assert!(b.might_contain_long(v), "missing {v} after merge"); + } + } + + /// V2 default seed (0) round-trip through state_as_bytes → merge_filter. + #[test] + fn state_round_trip_v2_default_seed() { + let num_bits = 1024; + let num_hash = optimal_num_hash_functions(100, num_bits); + let mut a = SparkBloomFilter::new(SparkBloomFilterVersion::V2, num_hash, num_bits, 0); + for v in [11_i64, 222, 3333] { + a.put_long(v); + } + + let mut b = SparkBloomFilter::new(SparkBloomFilterVersion::V2, num_hash, num_bits, 0); + b.merge_filter(&a.state_as_bytes()).unwrap(); + + for v in [11_i64, 222, 3333] { + assert!(b.might_contain_long(v)); + } + } + + /// V2 non-zero seed round-trip; verifies the seed field is parsed and that + /// both filters use the same seed-dependent hash scattering. + #[test] + fn state_round_trip_v2_nonzero_seed() { + let num_bits = 1024; + let num_hash = optimal_num_hash_functions(100, num_bits); + let seed = 0x5eed_5eed_u32 as i32; + let mut a = SparkBloomFilter::new(SparkBloomFilterVersion::V2, num_hash, num_bits, seed); + a.put_long(123); + + let mut b = SparkBloomFilter::new(SparkBloomFilterVersion::V2, num_hash, num_bits, seed); + b.merge_filter(&a.state_as_bytes()).unwrap(); + + assert!(b.might_contain_long(123)); + } + + fn assert_merge_err_contains(filter: &mut SparkBloomFilter, buf: &[u8], needle: &str) { + let err = filter.merge_filter(buf).unwrap_err().to_string(); + assert!(err.contains(needle), "expected `{needle}` in error: {err}"); + } + + #[test] + fn merge_rejects_version_mismatch() { + let num_bits = 1024; + let num_hash = optimal_num_hash_functions(100, num_bits); + let a = SparkBloomFilter::new(SparkBloomFilterVersion::V2, num_hash, num_bits, 0); + let mut b = SparkBloomFilter::new(SparkBloomFilterVersion::V1, num_hash, num_bits, 0); + assert_merge_err_contains(&mut b, &a.state_as_bytes(), "version mismatch"); + } + + #[test] + fn merge_rejects_num_hash_mismatch() { + let num_bits = 1024; + let a = SparkBloomFilter::new(SparkBloomFilterVersion::V1, 5, num_bits, 0); + let mut b = SparkBloomFilter::new(SparkBloomFilterVersion::V1, 7, num_bits, 0); + assert_merge_err_contains(&mut b, &a.state_as_bytes(), "num_hash_functions mismatch"); + } + + #[test] + fn merge_rejects_seed_mismatch_v2() { + let num_bits = 1024; + let num_hash = optimal_num_hash_functions(100, num_bits); + let a = SparkBloomFilter::new(SparkBloomFilterVersion::V2, num_hash, num_bits, 1); + let mut b = SparkBloomFilter::new(SparkBloomFilterVersion::V2, num_hash, num_bits, 2); + assert_merge_err_contains(&mut b, &a.state_as_bytes(), "seed mismatch"); + } + + #[test] + fn merge_rejects_num_words_mismatch() { + let num_hash = 5; + let a = SparkBloomFilter::new(SparkBloomFilterVersion::V1, num_hash, 512, 0); + let mut b = SparkBloomFilter::new(SparkBloomFilterVersion::V1, num_hash, 1024, 0); + assert_merge_err_contains(&mut b, &a.state_as_bytes(), "num_words mismatch"); + } } diff --git a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala index e7227ccca9..2714a7e466 100644 --- a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala +++ b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala @@ -21,7 +21,7 @@ package org.apache.comet.serde import scala.jdk.CollectionConverters._ -import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal} import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, BloomFilterAggregate, CentralMomentAgg, CollectSet, Corr, Count, Covariance, CovPopulation, CovSample, First, Last, Max, Min, StddevPop, StddevSamp, Sum, VariancePop, VarianceSamp} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{ByteType, DataTypes, DecimalType, IntegerType, LongType, ShortType, StringType} @@ -638,6 +638,8 @@ object CometCorr extends CometAggregateExpressionSerde[Corr] { object CometBloomFilterAggregate extends CometAggregateExpressionSerde[BloomFilterAggregate] { + override def supportsMixedPartialFinal: Boolean = true + override def convert( aggExpr: AggregateExpression, bloomFilter: BloomFilterAggregate, @@ -647,8 +649,20 @@ object CometBloomFilterAggregate extends CometAggregateExpressionSerde[BloomFilt // We ignore mutableAggBufferOffset and inputAggBufferOffset because they are // implementation details for Spark's ObjectHashAggregate. val childExpr = exprToProto(bloomFilter.child, inputs, binding) - val numItemsExpr = exprToProto(bloomFilter.estimatedNumItemsExpression, inputs, binding) - val numBitsExpr = exprToProto(bloomFilter.numBitsExpression, inputs, binding) + // Spark's BloomFilterAggregate caps numItems / numBits at the configured maxima + // (its `estimatedNumItems` / `numBits` lazy vals). Comet's native aggregate stores + // these as i32, so an uncapped Long literal (e.g. the Long.MaxValue cases in + // BloomFilterAggregateQuerySuite) would wrap to a bogus negative size and abort the + // executor with a multi-exabyte allocation. Apply the same cap here so the native + // side always receives a sane, Spark-equivalent value. + val numItems = math.min( + bloomFilter.estimatedNumItemsExpression.eval().asInstanceOf[Number].longValue, + conf.getConf(SQLConf.RUNTIME_BLOOM_FILTER_MAX_NUM_ITEMS)) + val numBits = math.min( + bloomFilter.numBitsExpression.eval().asInstanceOf[Number].longValue, + conf.getConf(SQLConf.RUNTIME_BLOOM_FILTER_MAX_NUM_BITS)) + val numItemsExpr = exprToProto(Literal(numItems, LongType), inputs, binding) + val numBitsExpr = exprToProto(Literal(numBits, LongType), inputs, binding) val dataType = serializeDataType(bloomFilter.dataType) if (childExpr.isDefined && diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index b24b96abd1..f7109df795 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -1697,6 +1697,20 @@ object CometObjectHashAggregateExec override def enabledConfig: Option[ConfigEntry[Boolean]] = Some( CometConf.COMET_EXEC_AGGREGATE_ENABLED) + override def getSupportLevel(op: ObjectHashAggregateExec): SupportLevel = { + // Mirror the same test-knobs as CometHashAggregateExec so that mixed-execution + // unit tests can selectively disable partial or final ObjectHashAggregateExec conversion. + if (!CometConf.COMET_ENABLE_PARTIAL_HASH_AGGREGATE.get(op.conf) && + op.aggregateExpressions.exists(expr => expr.mode == Partial || expr.mode == PartialMerge)) { + return Unsupported(Some("Partial aggregates disabled via test config")) + } + if (!CometConf.COMET_ENABLE_FINAL_HASH_AGGREGATE.get(op.conf) && + op.aggregateExpressions.exists(_.mode == Final)) { + return Unsupported(Some("Final aggregates disabled via test config")) + } + Compatible() + } + override def convert( aggregate: ObjectHashAggregateExec, builder: Operator.Builder, diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExec3_4PlusSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExec3_4PlusSuite.scala index b187c30b44..3a406f9a72 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometExec3_4PlusSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometExec3_4PlusSuite.scala @@ -29,6 +29,7 @@ import org.scalatest.Tag import org.apache.spark.sql.CometTestBase import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.expressions.{BloomFilterMightContain, Expression, ExpressionInfo} +import org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate import org.apache.spark.sql.functions.{col, lit} import org.apache.spark.util.sketch.BloomFilter @@ -42,6 +43,7 @@ class CometExec3_4PlusSuite extends CometTestBase { import testImplicits._ val func_might_contain = new FunctionIdentifier("might_contain") + val func_bloom_filter_agg = new FunctionIdentifier("bloom_filter_agg") override def beforeAll(): Unit = { super.beforeAll() @@ -51,12 +53,23 @@ class CometExec3_4PlusSuite extends CometTestBase { func_might_contain, new ExpressionInfo(classOf[BloomFilterMightContain].getName, "might_contain"), (children: Seq[Expression]) => BloomFilterMightContain(children.head, children(1))) + // Register 'bloom_filter_agg' to builtin. + spark.sessionState.functionRegistry.registerFunction( + func_bloom_filter_agg, + new ExpressionInfo(classOf[BloomFilterAggregate].getName, "bloom_filter_agg"), + (children: Seq[Expression]) => + children.size match { + case 1 => new BloomFilterAggregate(children.head) + case 2 => new BloomFilterAggregate(children.head, children(1)) + case 3 => new BloomFilterAggregate(children.head, children(1), children(2)) + }) } } override def afterAll(): Unit = { if (!isSpark42Plus) { spark.sessionState.functionRegistry.dropFunction(func_might_contain) + spark.sessionState.functionRegistry.dropFunction(func_bloom_filter_agg) } super.afterAll() } @@ -185,6 +198,24 @@ class CometExec3_4PlusSuite extends CometTestBase { } } + test("bloom_filter_agg caps oversized numItems / numBits like Spark") { + assume(!isSpark42Plus, "https://github.com/apache/datafusion-comet/issues/4142") + val table = "test" + withTable(table) { + sql(s"create table $table(col1 long) using parquet") + sql(s"insert into $table values (1), (2), (3), (201), (null)") + // numItems / numBits exceed the Int range. Spark's BloomFilterAggregate caps + // them at maxNumItems / maxNumBits; Comet must apply the same cap, otherwise the + // oversized values truncate to a negative i32 and abort the executor with a + // multi-exabyte allocation. + checkSparkAnswerAndOperator(s""" + |SELECT bloom_filter_agg(col1, + | cast(9223372036854775807 as long), + | cast(9223372036854775807 as long)) FROM $table + |""".stripMargin) + } + } + private def bloomFilterFromRandomInput( expectedItems: Long, expectedBits: Long): (Seq[Long], Array[Byte]) = { diff --git a/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala b/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala index 7f353c36e2..7fa06a26cc 100644 --- a/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala +++ b/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala @@ -22,16 +22,19 @@ package org.apache.comet.rules import scala.util.Random import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.FunctionIdentifier +import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo} +import org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate import org.apache.spark.sql.comet._ import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.adaptive.QueryStageExec -import org.apache.spark.sql.execution.aggregate.HashAggregateExec +import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec} import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec} import org.apache.spark.sql.types.{DataTypes, StructField, StructType} import org.apache.comet.CometConf -import org.apache.comet.CometSparkSessionExtensions.isSpark40Plus +import org.apache.comet.CometSparkSessionExtensions.{isSpark40Plus, isSpark42Plus} import org.apache.comet.testing.{DataGenOptions, FuzzDataGenerator} /** @@ -228,6 +231,83 @@ class CometExecRuleSuite extends CometTestBase { } } + test("CometExecRule should allow BloomFilter mixed Comet partial and Spark final") { + assume(!isSpark42Plus, "https://github.com/apache/datafusion-comet/issues/4142") + val funcId = new FunctionIdentifier("bloom_filter_agg") + spark.sessionState.functionRegistry.registerFunction( + funcId, + new ExpressionInfo(classOf[BloomFilterAggregate].getName, "bloom_filter_agg"), + (children: Seq[Expression]) => + children.size match { + case 1 => new BloomFilterAggregate(children.head) + case 2 => new BloomFilterAggregate(children.head, children(1)) + case 3 => new BloomFilterAggregate(children.head, children(1), children(2)) + }) + try { + withTempView("test_data") { + createTestDataFrame.createOrReplaceTempView("test_data") + + // Cast to bigint: Spark 3.4's bloom_filter_agg only accepts a long-typed first + // argument; later versions widened it to any integral type. + val sparkPlan = + createSparkPlan(spark, "SELECT bloom_filter_agg(CAST(id AS BIGINT)) FROM test_data") + + val originalObjectAggCount = countOperators(sparkPlan, classOf[ObjectHashAggregateExec]) + assert(originalObjectAggCount == 2) + + withSQLConf( + CometConf.COMET_ENABLE_FINAL_HASH_AGGREGATE.key -> "false", + CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED.key -> "true") { + val transformedPlan = applyCometExecRule(sparkPlan) + + // BloomFilter is mixed-safe: partial converts to Comet, final stays Spark. + assert(countOperators(transformedPlan, classOf[ObjectHashAggregateExec]) == 1) + assert(countOperators(transformedPlan, classOf[CometHashAggregateExec]) == 1) + } + } + } finally { + spark.sessionState.functionRegistry.dropFunction(funcId) + } + } + + test("CometExecRule should allow BloomFilter mixed Spark partial and Comet final") { + assume(!isSpark42Plus, "https://github.com/apache/datafusion-comet/issues/4142") + val funcId = new FunctionIdentifier("bloom_filter_agg") + spark.sessionState.functionRegistry.registerFunction( + funcId, + new ExpressionInfo(classOf[BloomFilterAggregate].getName, "bloom_filter_agg"), + (children: Seq[Expression]) => + children.size match { + case 1 => new BloomFilterAggregate(children.head) + case 2 => new BloomFilterAggregate(children.head, children(1)) + case 3 => new BloomFilterAggregate(children.head, children(1), children(2)) + }) + try { + withTempView("test_data") { + createTestDataFrame.createOrReplaceTempView("test_data") + + // Cast to bigint: Spark 3.4's bloom_filter_agg only accepts a long-typed first + // argument; later versions widened it to any integral type. + val sparkPlan = + createSparkPlan(spark, "SELECT bloom_filter_agg(CAST(id AS BIGINT)) FROM test_data") + + val originalObjectAggCount = countOperators(sparkPlan, classOf[ObjectHashAggregateExec]) + assert(originalObjectAggCount == 2) + + withSQLConf( + CometConf.COMET_ENABLE_PARTIAL_HASH_AGGREGATE.key -> "false", + CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED.key -> "true") { + val transformedPlan = applyCometExecRule(sparkPlan) + + assert(countOperators(transformedPlan, classOf[ObjectHashAggregateExec]) == 1) + assert(countOperators(transformedPlan, classOf[CometHashAggregateExec]) == 1) + } + } + } finally { + spark.sessionState.functionRegistry.dropFunction(funcId) + } + } + test("CometExecRule should not convert hash aggregate when grouping key contains map type") { // Spark 3.4/3.5 reject `array>` as a grouping key in the analyzer (not orderable), // so the plan never reaches CometExecRule on those versions. The guard we're exercising