Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
f7fa33c
fix: allow safe mixed Spark/Comet partial/final aggregate execution
andygrove Apr 21, 2026
f2a8207
fix: address review feedback on mixed partial/final aggregate guard
andygrove Apr 21, 2026
9826403
fix: skip partial aggregate tag when partial itself cannot be converted
andygrove Apr 21, 2026
753a9a5
fix: narrow partial aggregate tag lookup and regenerate TPC-DS golden…
andygrove Apr 21, 2026
6ae483d
fix: reject grouping on nested map types in hash aggregate conversion
andygrove Apr 21, 2026
53405f6
fix: remove COUNT from mixed-safe aggregates to fix AQE/count-bug reg…
andygrove Apr 22, 2026
9e2c25a
spotless
andygrove Apr 22, 2026
f53e3c1
test: ignore SPARK-33853 explain codegen subquery test under Comet
andygrove Apr 23, 2026
3285485
Merge remote-tracking branch 'apache/main' into fix/safe-mixed-partia…
andygrove Apr 25, 2026
671afa6
Merge remote-tracking branch 'apache/main' into fix/safe-mixed-partia…
andygrove May 6, 2026
4322852
test: regenerate Spark 4.2 TPC-DS golden files after merge from main
andygrove May 6, 2026
12018c3
Merge remote-tracking branch 'apache/main' into fix/safe-mixed-partia…
andygrove May 20, 2026
43e0c0b
fix: address review feedback on safe mixed aggregate guard
andygrove May 20, 2026
4bbfe74
fix: drop unused StructType import and regenerate TPC-DS golden files
andygrove May 20, 2026
56e5da6
chore: revert .gitignore change
andygrove May 20, 2026
08b3924
test: ignore SPARK-33853 explain codegen test on Spark 4.1.1
andygrove May 21, 2026
64575f2
test: use descriptive reason for SPARK-33853 IgnoreComet tag
andygrove May 21, 2026
8db42b0
fix: emit Spark-compatible BloomFilter intermediate buffer
andygrove May 21, 2026
36cf0e8
feat: enable BloomFilter for mixed Spark/Comet partial/final aggregate
andygrove May 21, 2026
9bab432
Merge remote-tracking branch 'apache/main' into feat/bloom-filter-int…
andygrove May 21, 2026
2406272
refactor: move SparkBitArray test-only methods into the test module
andygrove May 21, 2026
16299be
refactor: address review feedback on SparkBloomFilter::merge_filter
andygrove May 21, 2026
d51c9d6
Merge remote-tracking branch 'apache/main' into feat/bloom-filter-int…
andygrove May 21, 2026
264510c
fix: cap bloom_filter_agg numItems/numBits and skip null inputs
andygrove May 22, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions native/spark-expr/src/bloom_filter/bloom_filter_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use crate::bloom_filter::spark_bloom_filter::{SparkBloomFilter, SparkBloomFilter
use arrow::array::ArrayRef;
use arrow::array::BinaryArray;
use datafusion::common::{downcast_value, ScalarValue};
use datafusion::error::Result;
use datafusion::error::{DataFusionError, Result};
use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs};
use datafusion::logical_expr::{AggregateUDFImpl, Signature};
use datafusion::physical_expr::expressions::Literal;
Expand Down Expand Up @@ -141,8 +141,16 @@ impl Accumulator for SparkBloomFilter {
ScalarValue::Utf8(Some(value)) => {
self.put_binary(value.as_bytes());
}
_ => {
unreachable!()
// Spark's BloomFilterAggregate.update ignores null inputs.
ScalarValue::Int8(None)
| ScalarValue::Int16(None)
| ScalarValue::Int32(None)
| ScalarValue::Int64(None)
| ScalarValue::Utf8(None) => {}
other => {
return Err(DataFusionError::Internal(format!(
"bloom_filter_agg received an unsupported input type: {other:?}"
)));
}
}
Ok(())
Expand Down Expand Up @@ -173,7 +181,6 @@ impl Accumulator for SparkBloomFilter {
);
assert_eq!(states[0].len(), 1);
let state_sv = downcast_value!(states[0], BinaryArray);
self.merge_filter(state_sv.value_data());
Ok(())
self.merge_filter(state_sv.value_data())
}
}
72 changes: 43 additions & 29 deletions native/spark-expr/src/bloom_filter/spark_bit_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,6 @@
// specific language governing permissions and limitations
// under the License.

use arrow::datatypes::ToByteSlice;
use std::iter::zip;

/// A simple bit array implementation that simulates the behavior of Spark's BitArray which is
/// used in the BloomFilter implementation. Some methods are not implemented as they are not
/// required for the current use case.
Expand Down Expand Up @@ -61,41 +58,23 @@ impl SparkBitArray {
self.word_size() as u64 * 64
}

pub fn byte_size(&self) -> usize {
self.word_size() * 8
}

pub fn word_size(&self) -> usize {
self.data.len()
}

#[allow(dead_code)] // this is only called from tests
pub fn cardinality(&self) -> usize {
self.bit_count
}

pub fn to_bytes(&self) -> Vec<u8> {
Vec::from(self.data.to_byte_slice())
}

pub fn data(&self) -> Vec<u64> {
self.data.clone()
}

// Combines SparkBitArrays, however other is a &[u8] because we anticipate to come from an
// Arrow ScalarValue::Binary which is a byte vector underneath, rather than a word vector.
pub fn merge_bits(&mut self, other: &[u8]) {
assert_eq!(self.byte_size(), other.len());
/// OR-merge `incoming` (big-endian `u64` words, one per word in `self`) into
/// `self.data` in place and refresh `bit_count` in the same pass. The caller
/// is responsible for ensuring `incoming.len() == self.word_size() * 8`.
pub fn merge_be_words(&mut self, incoming: &[u8]) {
debug_assert_eq!(self.data.len() * 8, incoming.len());
let mut bit_count: usize = 0;
// For each word, merge the bits into self, and accumulate a new bit_count.
for i in zip(
self.data.iter_mut(),
other
.chunks(8)
.map(|chunk| u64::from_ne_bytes(chunk.try_into().unwrap())),
) {
*i.0 |= i.1;
bit_count += i.0.count_ones() as usize;
for (word, chunk) in self.data.iter_mut().zip(incoming.chunks_exact(8)) {
*word |= u64::from_be_bytes(chunk.try_into().unwrap());
bit_count += word.count_ones() as usize;
}
self.bit_count = bit_count;
}
Expand All @@ -108,6 +87,41 @@ pub fn num_words(num_bits: usize) -> usize {
#[cfg(test)]
mod test {
use super::*;
use arrow::datatypes::ToByteSlice;
use std::iter::zip;

impl SparkBitArray {
fn byte_size(&self) -> usize {
self.word_size() * 8
}

fn cardinality(&self) -> usize {
self.bit_count
}

fn to_bytes(&self) -> Vec<u8> {
Vec::from(self.data.to_byte_slice())
}

/// Combines SparkBitArrays, however other is a &[u8] because we anticipate to come from
/// an Arrow ScalarValue::Binary which is a byte vector underneath, rather than a word
/// vector.
fn merge_bits(&mut self, other: &[u8]) {
assert_eq!(self.byte_size(), other.len());
let mut bit_count: usize = 0;
// For each word, merge the bits into self, and accumulate a new bit_count.
for i in zip(
self.data.iter_mut(),
other
.chunks(8)
.map(|chunk| u64::from_ne_bytes(chunk.try_into().unwrap())),
) {
*i.0 |= i.1;
bit_count += i.0.count_ones() as usize;
}
self.bit_count = bit_count;
}
}

#[test]
fn test_spark_bit_array() {
Expand Down
158 changes: 150 additions & 8 deletions native/spark-expr/src/bloom_filter/spark_bloom_filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

use arrow::array::{ArrowNativeTypeOp, BooleanArray, Int64Array};
use arrow::datatypes::ToByteSlice;
use datafusion::common::{DataFusionError, Result as DFResult};
use std::cmp;

use crate::bloom_filter::spark_bit_array;
Expand Down Expand Up @@ -272,16 +273,64 @@ impl SparkBloomFilter {
}

pub fn state_as_bytes(&self) -> Vec<u8> {
self.bits.to_bytes()
self.spark_serialization()
}

pub fn merge_filter(&mut self, other: &[u8]) {
assert_eq!(
other.len(),
self.bits.byte_size(),
"Cannot merge SparkBloomFilters with different lengths."
);
self.bits.merge_bits(other);
pub fn merge_filter(&mut self, other: &[u8]) -> DFResult<()> {
let mut offset = 0;

let version_int = read_num_be_bytes!(i32, 4, other[offset..]);
offset += 4;
if version_int != self.version.to_int() {
return Err(DataFusionError::Internal(format!(
"BloomFilter merge: version mismatch (got {}, expected {})",
version_int,
self.version.to_int(),
)));
}

let num_hash = read_num_be_bytes!(i32, 4, other[offset..]) as u32;
offset += 4;
if num_hash != self.num_hash_functions {
return Err(DataFusionError::Internal(format!(
"BloomFilter merge: num_hash_functions mismatch (got {}, expected {})",
num_hash, self.num_hash_functions,
)));
}

if let SparkBloomFilterVersion::V2 = self.version {
let seed = read_num_be_bytes!(i32, 4, other[offset..]);
offset += 4;
if seed != self.seed {
return Err(DataFusionError::Internal(format!(
"BloomFilter merge: seed mismatch (got {}, expected {})",
seed, self.seed,
)));
}
}

let num_words = read_num_be_bytes!(i32, 4, other[offset..]) as usize;
offset += 4;
if num_words != self.bits.word_size() {
return Err(DataFusionError::Internal(format!(
"BloomFilter merge: num_words mismatch (got {}, expected {})",
num_words,
self.bits.word_size(),
)));
}

let expected_bytes = num_words * 8;
if other.len() - offset < expected_bytes {
return Err(DataFusionError::Internal(format!(
"BloomFilter merge: truncated bit array (got {} bytes, expected {})",
other.len() - offset,
expected_bytes,
)));
}

self.bits
.merge_be_words(&other[offset..offset + expected_bytes]);
Ok(())
}
}

Expand Down Expand Up @@ -396,4 +445,97 @@ mod tests {
buf.extend_from_slice(&[0u8; 32]); // 4 words * 8 bytes
let _ = SparkBloomFilter::from(buf.as_slice());
}

/// Two V1 filters with identical parameters. Populate the first, serialize via
/// state_as_bytes, merge into the empty second, and verify the second contains
/// everything the first did. Exercises the aggregator state → merge_batch path.
#[test]
fn state_round_trip_v1_merge() {
let num_bits = 1024;
let num_hash = optimal_num_hash_functions(100, num_bits);
let mut a = SparkBloomFilter::new(SparkBloomFilterVersion::V1, num_hash, num_bits, 0);
for v in [1_i64, 7, 42, 99, -3, i64::MAX] {
a.put_long(v);
}

let mut b = SparkBloomFilter::new(SparkBloomFilterVersion::V1, num_hash, num_bits, 0);
b.merge_filter(&a.state_as_bytes()).unwrap();

for v in [1_i64, 7, 42, 99, -3, i64::MAX] {
assert!(b.might_contain_long(v), "missing {v} after merge");
}
}

/// V2 default seed (0) round-trip through state_as_bytes → merge_filter.
#[test]
fn state_round_trip_v2_default_seed() {
let num_bits = 1024;
let num_hash = optimal_num_hash_functions(100, num_bits);
let mut a = SparkBloomFilter::new(SparkBloomFilterVersion::V2, num_hash, num_bits, 0);
for v in [11_i64, 222, 3333] {
a.put_long(v);
}

let mut b = SparkBloomFilter::new(SparkBloomFilterVersion::V2, num_hash, num_bits, 0);
b.merge_filter(&a.state_as_bytes()).unwrap();

for v in [11_i64, 222, 3333] {
assert!(b.might_contain_long(v));
}
}

/// V2 non-zero seed round-trip; verifies the seed field is parsed and that
/// both filters use the same seed-dependent hash scattering.
#[test]
fn state_round_trip_v2_nonzero_seed() {
let num_bits = 1024;
let num_hash = optimal_num_hash_functions(100, num_bits);
let seed = 0x5eed_5eed_u32 as i32;
let mut a = SparkBloomFilter::new(SparkBloomFilterVersion::V2, num_hash, num_bits, seed);
a.put_long(123);

let mut b = SparkBloomFilter::new(SparkBloomFilterVersion::V2, num_hash, num_bits, seed);
b.merge_filter(&a.state_as_bytes()).unwrap();

assert!(b.might_contain_long(123));
}

fn assert_merge_err_contains(filter: &mut SparkBloomFilter, buf: &[u8], needle: &str) {
let err = filter.merge_filter(buf).unwrap_err().to_string();
assert!(err.contains(needle), "expected `{needle}` in error: {err}");
}

#[test]
fn merge_rejects_version_mismatch() {
let num_bits = 1024;
let num_hash = optimal_num_hash_functions(100, num_bits);
let a = SparkBloomFilter::new(SparkBloomFilterVersion::V2, num_hash, num_bits, 0);
let mut b = SparkBloomFilter::new(SparkBloomFilterVersion::V1, num_hash, num_bits, 0);
assert_merge_err_contains(&mut b, &a.state_as_bytes(), "version mismatch");
}

#[test]
fn merge_rejects_num_hash_mismatch() {
let num_bits = 1024;
let a = SparkBloomFilter::new(SparkBloomFilterVersion::V1, 5, num_bits, 0);
let mut b = SparkBloomFilter::new(SparkBloomFilterVersion::V1, 7, num_bits, 0);
assert_merge_err_contains(&mut b, &a.state_as_bytes(), "num_hash_functions mismatch");
}

#[test]
fn merge_rejects_seed_mismatch_v2() {
let num_bits = 1024;
let num_hash = optimal_num_hash_functions(100, num_bits);
let a = SparkBloomFilter::new(SparkBloomFilterVersion::V2, num_hash, num_bits, 1);
let mut b = SparkBloomFilter::new(SparkBloomFilterVersion::V2, num_hash, num_bits, 2);
assert_merge_err_contains(&mut b, &a.state_as_bytes(), "seed mismatch");
}

#[test]
fn merge_rejects_num_words_mismatch() {
let num_hash = 5;
let a = SparkBloomFilter::new(SparkBloomFilterVersion::V1, num_hash, 512, 0);
let mut b = SparkBloomFilter::new(SparkBloomFilterVersion::V1, num_hash, 1024, 0);
assert_merge_err_contains(&mut b, &a.state_as_bytes(), "num_words mismatch");
}
}
20 changes: 17 additions & 3 deletions spark/src/main/scala/org/apache/comet/serde/aggregates.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ package org.apache.comet.serde

import scala.jdk.CollectionConverters._

import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, BloomFilterAggregate, CentralMomentAgg, CollectSet, Corr, Count, Covariance, CovPopulation, CovSample, First, Last, Max, Min, StddevPop, StddevSamp, Sum, VariancePop, VarianceSamp}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{ByteType, DataTypes, DecimalType, IntegerType, LongType, ShortType, StringType}
Expand Down Expand Up @@ -638,6 +638,8 @@ object CometCorr extends CometAggregateExpressionSerde[Corr] {

object CometBloomFilterAggregate extends CometAggregateExpressionSerde[BloomFilterAggregate] {

override def supportsMixedPartialFinal: Boolean = true

override def convert(
aggExpr: AggregateExpression,
bloomFilter: BloomFilterAggregate,
Expand All @@ -647,8 +649,20 @@ object CometBloomFilterAggregate extends CometAggregateExpressionSerde[BloomFilt
// We ignore mutableAggBufferOffset and inputAggBufferOffset because they are
// implementation details for Spark's ObjectHashAggregate.
val childExpr = exprToProto(bloomFilter.child, inputs, binding)
val numItemsExpr = exprToProto(bloomFilter.estimatedNumItemsExpression, inputs, binding)
val numBitsExpr = exprToProto(bloomFilter.numBitsExpression, inputs, binding)
// Spark's BloomFilterAggregate caps numItems / numBits at the configured maxima
// (its `estimatedNumItems` / `numBits` lazy vals). Comet's native aggregate stores
// these as i32, so an uncapped Long literal (e.g. the Long.MaxValue cases in
// BloomFilterAggregateQuerySuite) would wrap to a bogus negative size and abort the
// executor with a multi-exabyte allocation. Apply the same cap here so the native
// side always receives a sane, Spark-equivalent value.
val numItems = math.min(
bloomFilter.estimatedNumItemsExpression.eval().asInstanceOf[Number].longValue,
conf.getConf(SQLConf.RUNTIME_BLOOM_FILTER_MAX_NUM_ITEMS))
val numBits = math.min(
bloomFilter.numBitsExpression.eval().asInstanceOf[Number].longValue,
conf.getConf(SQLConf.RUNTIME_BLOOM_FILTER_MAX_NUM_BITS))
val numItemsExpr = exprToProto(Literal(numItems, LongType), inputs, binding)
val numBitsExpr = exprToProto(Literal(numBits, LongType), inputs, binding)
val dataType = serializeDataType(bloomFilter.dataType)

if (childExpr.isDefined &&
Expand Down
14 changes: 14 additions & 0 deletions spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1697,6 +1697,20 @@ object CometObjectHashAggregateExec
override def enabledConfig: Option[ConfigEntry[Boolean]] = Some(
CometConf.COMET_EXEC_AGGREGATE_ENABLED)

override def getSupportLevel(op: ObjectHashAggregateExec): SupportLevel = {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The body here is identical to CometHashAggregateExec.getSupportLevel at operators.scala:1658-1670, including the conf names. That is fine for the test-knob purpose called out in the comment, but COMET_ENABLE_PARTIAL_HASH_AGGREGATE and COMET_ENABLE_FINAL_HASH_AGGREGATE now gate both HashAggregateExec and ObjectHashAggregateExec. As a follow-up, consider renaming to COMET_ENABLE_PARTIAL_AGGREGATE / COMET_ENABLE_FINAL_AGGREGATE so the conf names match the scope.

// Mirror the same test-knobs as CometHashAggregateExec so that mixed-execution
// unit tests can selectively disable partial or final ObjectHashAggregateExec conversion.
if (!CometConf.COMET_ENABLE_PARTIAL_HASH_AGGREGATE.get(op.conf) &&
op.aggregateExpressions.exists(expr => expr.mode == Partial || expr.mode == PartialMerge)) {
return Unsupported(Some("Partial aggregates disabled via test config"))
}
if (!CometConf.COMET_ENABLE_FINAL_HASH_AGGREGATE.get(op.conf) &&
op.aggregateExpressions.exists(_.mode == Final)) {
return Unsupported(Some("Final aggregates disabled via test config"))
}
Compatible()
}

override def convert(
aggregate: ObjectHashAggregateExec,
builder: Operator.Builder,
Expand Down
Loading
Loading