From e0cc868f0ad161eaa05d0330630bf81cd2997e49 Mon Sep 17 00:00:00 2001 From: xiedeyantu Date: Fri, 20 Mar 2026 00:28:07 +0800 Subject: [PATCH 1/3] fix: preserve duplicate GROUPING SETS rows --- datafusion/core/tests/sql/aggregates/basic.rs | 51 ++++++++ .../physical-plan/src/aggregates/mod.rs | 118 +++++++++++++++++- .../sqllogictest/test_files/group_by.slt | 22 ++++ 3 files changed, 186 insertions(+), 5 deletions(-) diff --git a/datafusion/core/tests/sql/aggregates/basic.rs b/datafusion/core/tests/sql/aggregates/basic.rs index d1b376b735ab..572b19a20875 100644 --- a/datafusion/core/tests/sql/aggregates/basic.rs +++ b/datafusion/core/tests/sql/aggregates/basic.rs @@ -175,6 +175,57 @@ async fn count_aggregated_cube() -> Result<()> { Ok(()) } +#[tokio::test] +async fn duplicate_grouping_sets_are_preserved() -> Result<()> { + let ctx = SessionContext::new(); + let schema = Arc::new(Schema::new(vec![ + Field::new("deptno", DataType::Int32, false), + Field::new("job", DataType::Utf8, true), + Field::new("sal", DataType::Int32, true), + Field::new("comm", DataType::Int32, true), + ])); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![10, 20])), + Arc::new(StringArray::from(vec![Some("CLERK"), Some("MANAGER")])), + Arc::new(Int32Array::from(vec![1300, 3000])), + Arc::new(Int32Array::from(vec![None, None])), + ], + )?; + let provider = MemTable::try_new(Arc::clone(&schema), vec![vec![batch]])?; + ctx.register_table("dup_grouping_sets", Arc::new(provider))?; + + let results = plan_and_collect( + &ctx, + " + SELECT deptno, job, sal, sum(comm) AS sum_comm, + grouping(deptno) AS deptno_flag, + grouping(job) AS job_flag, + grouping(sal) AS sal_flag + FROM dup_grouping_sets + GROUP BY GROUPING SETS ((deptno, job), (deptno, sal), (deptno, job)) + ORDER BY deptno, job, sal, deptno_flag, job_flag, sal_flag + ", + ) + .await?; + + assert_eq!(results.len(), 1); + assert_snapshot!(batches_to_string(&results), @r" + +--------+---------+------+----------+-------------+----------+----------+ + | deptno | job | sal | sum_comm | deptno_flag | job_flag | sal_flag | + +--------+---------+------+----------+-------------+----------+----------+ + | 10 | CLERK | | | 0 | 0 | 1 | + | 10 | CLERK | | | 0 | 0 | 1 | + | 10 | | 1300 | | 0 | 1 | 0 | + | 20 | MANAGER | | | 0 | 0 | 1 | + | 20 | MANAGER | | | 0 | 0 | 1 | + | 20 | | 3000 | | 0 | 1 | 0 | + +--------+---------+------+----------+-------------+----------+----------+ + "); + Ok(()) +} + async fn run_count_distinct_integers_aggregated_scenario( partitions: Vec>, ) -> Result> { diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 42df1a8b07cd..d728707df13c 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -38,7 +38,7 @@ use crate::{ use datafusion_common::config::ConfigOptions; use datafusion_physical_expr::utils::collect_columns; use parking_lot::Mutex; -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; use arrow::array::{ArrayRef, UInt8Array, UInt16Array, UInt32Array, UInt64Array}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; @@ -1937,15 +1937,53 @@ fn evaluate_optional( .collect() } -fn group_id_array(group: &[bool], batch: &RecordBatch) -> Result { +fn group_id_array( + group: &[bool], + group_ordinal: usize, + batch: &RecordBatch, +) -> Result { if group.len() > 64 { return not_impl_err!( "Grouping sets with more than 64 columns are not supported" ); } + let width_bits = if group.len() <= 8 { + 8 + } else if group.len() <= 16 { + 16 + } else if group.len() <= 32 { + 32 + } else { + 64 + }; + let extra_bits = width_bits - group.len(); + if extra_bits == 0 && group_ordinal > 0 { + return not_impl_err!( + "Duplicate grouping sets with more than {} grouping columns are not supported", + width_bits + ); + } + if extra_bits < usize::BITS as usize { + let max_group_ordinal = 1usize << extra_bits; + if group_ordinal >= max_group_ordinal { + return not_impl_err!( + "Duplicate grouping sets exceed the supported grouping id capacity" + ); + } + } let group_id = group.iter().fold(0u64, |acc, &is_null| { (acc << 1) | if is_null { 1 } else { 0 } }); + let group_id = if group.len() == 64 { + if group_ordinal > 0 { + return not_impl_err!( + "Duplicate grouping sets with 64 grouping columns are not supported" + ); + } + group_id + } else { + ((group_ordinal as u64) << group.len()) | group_id + }; let num_rows = batch.num_rows(); if group.len() <= 8 { Ok(Arc::new(UInt8Array::from(vec![group_id as u8; num_rows]))) @@ -1972,6 +2010,7 @@ pub fn evaluate_group_by( group_by: &PhysicalGroupBy, batch: &RecordBatch, ) -> Result>> { + let mut group_ordinals: HashMap, usize> = HashMap::new(); let exprs = evaluate_expressions_to_arrays( group_by.expr.iter().map(|(expr, _)| expr), batch, @@ -1985,6 +2024,10 @@ pub fn evaluate_group_by( .groups .iter() .map(|group| { + let group_ordinal = group_ordinals.entry(group.clone()).or_insert(0); + let current_group_ordinal = *group_ordinal; + *group_ordinal += 1; + let mut group_values = Vec::with_capacity(group_by.num_group_exprs()); group_values.extend(group.iter().enumerate().map(|(idx, is_null)| { if *is_null { @@ -1994,7 +2037,7 @@ pub fn evaluate_group_by( } })); if !group_by.is_single() { - group_values.push(group_id_array(group, batch)?); + group_values.push(group_id_array(group, current_group_ordinal, batch)?); } Ok(group_values) }) @@ -2018,8 +2061,8 @@ mod tests { use crate::test::exec::{BlockingExec, assert_strong_count_converges_to_zero}; use arrow::array::{ - DictionaryArray, Float32Array, Float64Array, Int32Array, Int64Array, StructArray, - UInt32Array, UInt64Array, + DictionaryArray, Float32Array, Float64Array, Int32Array, Int64Array, StringArray, + StructArray, UInt32Array, UInt64Array, }; use arrow::compute::{SortOptions, concat_batches}; use arrow::datatypes::{DataType, Int32Type}; @@ -3478,6 +3521,71 @@ mod tests { Ok(()) } + #[tokio::test] + async fn grouping_sets_preserve_duplicate_groups() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("deptno", DataType::Int32, false), + Field::new("job", DataType::Utf8, true), + Field::new("sal", DataType::Float64, true), + Field::new("comm", DataType::Float64, true), + ])); + + let input = TestMemoryExec::try_new_exec( + &[vec![RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![10, 20])), + Arc::new(StringArray::from(vec![Some("CLERK"), Some("MANAGER")])), + Arc::new(Float64Array::from(vec![1300.0, 3000.0])), + Arc::new(Float64Array::from(vec![None, None])), + ], + )?]], + Arc::clone(&schema), + None, + )?; + + let group_by = PhysicalGroupBy::new( + vec![ + (col("deptno", &schema)?, "deptno".to_string()), + (col("job", &schema)?, "job".to_string()), + (col("sal", &schema)?, "sal".to_string()), + ], + vec![ + (lit(ScalarValue::Int32(None)), "deptno".to_string()), + (lit(ScalarValue::Utf8(None)), "job".to_string()), + (lit(ScalarValue::Float64(None)), "sal".to_string()), + ], + vec![ + vec![false, false, true], + vec![false, true, false], + vec![false, false, true], + ], + true, + ); + + let aggr_exprs: Vec> = vec![Arc::new( + AggregateExprBuilder::new(count_udaf(), vec![lit(1)]) + .schema(Arc::clone(&schema)) + .alias("COUNT(1)") + .build()?, + )]; + + let aggregate_exec = Arc::new(AggregateExec::try_new( + AggregateMode::Single, + group_by, + aggr_exprs, + vec![None], + Arc::clone(&input) as Arc, + Arc::clone(&schema), + )?); + + let output = + collect(aggregate_exec.execute(0, Arc::new(TaskContext::default()))?).await?; + let batch = concat_batches(&output[0].schema(), &output)?; + assert_eq!(batch.num_rows(), 6); + Ok(()) + } + // test for https://github.com/apache/datafusion/issues/13949 async fn run_test_with_spill_pool_if_necessary( pool_size: usize, diff --git a/datafusion/sqllogictest/test_files/group_by.slt b/datafusion/sqllogictest/test_files/group_by.slt index 294841552a66..de3ff9f20171 100644 --- a/datafusion/sqllogictest/test_files/group_by.slt +++ b/datafusion/sqllogictest/test_files/group_by.slt @@ -5203,6 +5203,28 @@ NULL NULL 1 statement ok drop table t; +# regression: duplicate grouping sets must not be collapsed into one +statement ok +create table duplicate_grouping_sets(deptno int, job varchar, sal int, comm int) as values +(10, 'CLERK', 1300, null), +(20, 'MANAGER', 3000, null); + +query IT?I?I?III +select deptno, job, sal, sum(comm), grouping(deptno), grouping(job), grouping(sal) +from duplicate_grouping_sets +group by grouping sets ((deptno, job), (deptno, sal), (deptno, job)) +order by deptno, job, sal, grouping(deptno), grouping(job), grouping(sal); +---- +10 CLERK NULL NULL 0 0 1 +10 CLERK NULL NULL 0 0 1 +10 NULL 1300 NULL 0 1 0 +20 MANAGER NULL NULL 0 0 1 +20 MANAGER NULL NULL 0 0 1 +20 NULL 3000 NULL 0 1 0 + +statement ok +drop table duplicate_grouping_sets; + # test multi group by for binary type without nulls statement ok create table t(a int, b bytea) as values (1, 0xa), (1, 0xa), (2, 0xb), (3, 0xb), (3, 0xb); From 09c92c84db9acce8db4fb54f2f13af522037f4c1 Mon Sep 17 00:00:00 2001 From: Jensen Date: Fri, 20 Mar 2026 00:40:19 +0800 Subject: [PATCH 2/3] Fix query syntax in group_by.slt test file --- datafusion/sqllogictest/test_files/group_by.slt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/sqllogictest/test_files/group_by.slt b/datafusion/sqllogictest/test_files/group_by.slt index de3ff9f20171..b855902ca67a 100644 --- a/datafusion/sqllogictest/test_files/group_by.slt +++ b/datafusion/sqllogictest/test_files/group_by.slt @@ -5209,7 +5209,7 @@ create table duplicate_grouping_sets(deptno int, job varchar, sal int, comm int) (10, 'CLERK', 1300, null), (20, 'MANAGER', 3000, null); -query IT?I?I?III +query ITIIIII select deptno, job, sal, sum(comm), grouping(deptno), grouping(job), grouping(sal) from duplicate_grouping_sets group by grouping sets ((deptno, job), (deptno, sal), (deptno, job)) From 1f584626e21f11046bca0360a95c5caf99bb6ffb Mon Sep 17 00:00:00 2001 From: Jensen Date: Fri, 20 Mar 2026 06:35:12 +0800 Subject: [PATCH 3/3] Fix formatting in group_by.slt test file --- datafusion/sqllogictest/test_files/group_by.slt | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/datafusion/sqllogictest/test_files/group_by.slt b/datafusion/sqllogictest/test_files/group_by.slt index b855902ca67a..6934ba3dba1a 100644 --- a/datafusion/sqllogictest/test_files/group_by.slt +++ b/datafusion/sqllogictest/test_files/group_by.slt @@ -5215,12 +5215,12 @@ from duplicate_grouping_sets group by grouping sets ((deptno, job), (deptno, sal), (deptno, job)) order by deptno, job, sal, grouping(deptno), grouping(job), grouping(sal); ---- -10 CLERK NULL NULL 0 0 1 -10 CLERK NULL NULL 0 0 1 -10 NULL 1300 NULL 0 1 0 +10 CLERK NULL NULL 0 0 1 +10 CLERK NULL NULL 0 0 1 +10 NULL 1300 NULL 0 1 0 20 MANAGER NULL NULL 0 0 1 20 MANAGER NULL NULL 0 0 1 -20 NULL 3000 NULL 0 1 0 +20 NULL 3000 NULL 0 1 0 statement ok drop table duplicate_grouping_sets;