Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions datafusion/core/tests/sql/aggregates/basic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,57 @@ async fn count_aggregated_cube() -> Result<()> {
Ok(())
}

#[tokio::test]
async fn duplicate_grouping_sets_are_preserved() -> Result<()> {
    // Two-row employee-style table. `comm` is entirely NULL, so `sum(comm)`
    // is NULL in every output group.
    let schema = Arc::new(Schema::new(vec![
        Field::new("deptno", DataType::Int32, false),
        Field::new("job", DataType::Utf8, true),
        Field::new("sal", DataType::Int32, true),
        Field::new("comm", DataType::Int32, true),
    ]));
    let data = RecordBatch::try_new(
        Arc::clone(&schema),
        vec![
            Arc::new(Int32Array::from(vec![10, 20])),
            Arc::new(StringArray::from(vec![Some("CLERK"), Some("MANAGER")])),
            Arc::new(Int32Array::from(vec![1300, 3000])),
            Arc::new(Int32Array::from(vec![None, None])),
        ],
    )?;

    let ctx = SessionContext::new();
    let table = MemTable::try_new(Arc::clone(&schema), vec![vec![data]])?;
    ctx.register_table("dup_grouping_sets", Arc::new(table))?;

    // The (deptno, job) grouping set is listed twice on purpose: both copies
    // must survive planning and each must contribute its own output rows.
    let batches = plan_and_collect(
        &ctx,
        "
        SELECT deptno, job, sal, sum(comm) AS sum_comm,
               grouping(deptno) AS deptno_flag,
               grouping(job) AS job_flag,
               grouping(sal) AS sal_flag
        FROM dup_grouping_sets
        GROUP BY GROUPING SETS ((deptno, job), (deptno, sal), (deptno, job))
        ORDER BY deptno, job, sal, deptno_flag, job_flag, sal_flag
        ",
    )
    .await?;

    assert_eq!(batches.len(), 1);
    // Each (deptno, job) row appears twice — once per duplicated grouping set.
    assert_snapshot!(batches_to_string(&batches), @r"
    +--------+---------+------+----------+-------------+----------+----------+
    | deptno | job     | sal  | sum_comm | deptno_flag | job_flag | sal_flag |
    +--------+---------+------+----------+-------------+----------+----------+
    | 10     | CLERK   |      |          | 0           | 0        | 1        |
    | 10     | CLERK   |      |          | 0           | 0        | 1        |
    | 10     |         | 1300 |          | 0           | 1        | 0        |
    | 20     | MANAGER |      |          | 0           | 0        | 1        |
    | 20     | MANAGER |      |          | 0           | 0        | 1        |
    | 20     |         | 3000 |          | 0           | 1        | 0        |
    +--------+---------+------+----------+-------------+----------+----------+
    ");
    Ok(())
}

async fn run_count_distinct_integers_aggregated_scenario(
partitions: Vec<Vec<(&str, u64)>>,
) -> Result<Vec<RecordBatch>> {
Expand Down
118 changes: 113 additions & 5 deletions datafusion/physical-plan/src/aggregates/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ use crate::{
use datafusion_common::config::ConfigOptions;
use datafusion_physical_expr::utils::collect_columns;
use parking_lot::Mutex;
use std::collections::HashSet;
use std::collections::{HashMap, HashSet};

use arrow::array::{ArrayRef, UInt8Array, UInt16Array, UInt32Array, UInt64Array};
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
Expand Down Expand Up @@ -1937,15 +1937,53 @@ fn evaluate_optional(
.collect()
}

fn group_id_array(group: &[bool], batch: &RecordBatch) -> Result<ArrayRef> {
fn group_id_array(
group: &[bool],
group_ordinal: usize,
batch: &RecordBatch,
) -> Result<ArrayRef> {
if group.len() > 64 {
return not_impl_err!(
"Grouping sets with more than 64 columns are not supported"
);
}
let width_bits = if group.len() <= 8 {
8
} else if group.len() <= 16 {
16
} else if group.len() <= 32 {
32
} else {
64
};
let extra_bits = width_bits - group.len();
if extra_bits == 0 && group_ordinal > 0 {
return not_impl_err!(
"Duplicate grouping sets with more than {} grouping columns are not supported",
width_bits
);
}
if extra_bits < usize::BITS as usize {
let max_group_ordinal = 1usize << extra_bits;
if group_ordinal >= max_group_ordinal {
return not_impl_err!(
"Duplicate grouping sets exceed the supported grouping id capacity"
);
}
}
let group_id = group.iter().fold(0u64, |acc, &is_null| {
(acc << 1) | if is_null { 1 } else { 0 }
});
let group_id = if group.len() == 64 {
if group_ordinal > 0 {
return not_impl_err!(
"Duplicate grouping sets with 64 grouping columns are not supported"
);
}
group_id
} else {
((group_ordinal as u64) << group.len()) | group_id
};
let num_rows = batch.num_rows();
if group.len() <= 8 {
Ok(Arc::new(UInt8Array::from(vec![group_id as u8; num_rows])))
Expand All @@ -1972,6 +2010,7 @@ pub fn evaluate_group_by(
group_by: &PhysicalGroupBy,
batch: &RecordBatch,
) -> Result<Vec<Vec<ArrayRef>>> {
let mut group_ordinals: HashMap<Vec<bool>, usize> = HashMap::new();
let exprs = evaluate_expressions_to_arrays(
group_by.expr.iter().map(|(expr, _)| expr),
batch,
Expand All @@ -1985,6 +2024,10 @@ pub fn evaluate_group_by(
.groups
.iter()
.map(|group| {
let group_ordinal = group_ordinals.entry(group.clone()).or_insert(0);
let current_group_ordinal = *group_ordinal;
*group_ordinal += 1;

let mut group_values = Vec::with_capacity(group_by.num_group_exprs());
group_values.extend(group.iter().enumerate().map(|(idx, is_null)| {
if *is_null {
Expand All @@ -1994,7 +2037,7 @@ pub fn evaluate_group_by(
}
}));
if !group_by.is_single() {
group_values.push(group_id_array(group, batch)?);
group_values.push(group_id_array(group, current_group_ordinal, batch)?);
}
Ok(group_values)
})
Expand All @@ -2018,8 +2061,8 @@ mod tests {
use crate::test::exec::{BlockingExec, assert_strong_count_converges_to_zero};

use arrow::array::{
DictionaryArray, Float32Array, Float64Array, Int32Array, Int64Array, StructArray,
UInt32Array, UInt64Array,
DictionaryArray, Float32Array, Float64Array, Int32Array, Int64Array, StringArray,
StructArray, UInt32Array, UInt64Array,
};
use arrow::compute::{SortOptions, concat_batches};
use arrow::datatypes::{DataType, Int32Type};
Expand Down Expand Up @@ -3478,6 +3521,71 @@ mod tests {
Ok(())
}

#[tokio::test]
async fn grouping_sets_preserve_duplicate_groups() -> Result<()> {
    // Regression test: the `PhysicalGroupBy` below lists the grouping set
    // (deptno, job) twice. The AggregateExec must keep both copies and emit
    // one set of output rows per listed grouping set, rather than collapsing
    // the duplicates into a single group.
    let schema = Arc::new(Schema::new(vec![
        Field::new("deptno", DataType::Int32, false),
        Field::new("job", DataType::Utf8, true),
        Field::new("sal", DataType::Float64, true),
        Field::new("comm", DataType::Float64, true),
    ]));

    // Single partition with two input rows feeding the aggregate.
    let input = TestMemoryExec::try_new_exec(
        &[vec![RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(Int32Array::from(vec![10, 20])),
                Arc::new(StringArray::from(vec![Some("CLERK"), Some("MANAGER")])),
                Arc::new(Float64Array::from(vec![1300.0, 3000.0])),
                Arc::new(Float64Array::from(vec![None, None])),
            ],
        )?]],
        Arc::clone(&schema),
        None,
    )?;

    // Group-by expressions, the NULL literals substituted for columns that are
    // absent from a given grouping set, and the per-set boolean masks.
    // In each `groups` entry, `true` marks a column that is replaced by its
    // NULL literal for that set:
    //   [false, false, true]  -> (deptno, job)
    //   [false, true,  false] -> (deptno, sal)
    //   [false, false, true]  -> (deptno, job)  — deliberate duplicate
    let group_by = PhysicalGroupBy::new(
        vec![
            (col("deptno", &schema)?, "deptno".to_string()),
            (col("job", &schema)?, "job".to_string()),
            (col("sal", &schema)?, "sal".to_string()),
        ],
        vec![
            (lit(ScalarValue::Int32(None)), "deptno".to_string()),
            (lit(ScalarValue::Utf8(None)), "job".to_string()),
            (lit(ScalarValue::Float64(None)), "sal".to_string()),
        ],
        vec![
            vec![false, false, true],
            vec![false, true, false],
            vec![false, false, true],
        ],
        true,
    );

    // A trivial COUNT(1) aggregate; any aggregate works — the test only
    // checks the number of produced groups.
    let aggr_exprs: Vec<Arc<AggregateFunctionExpr>> = vec![Arc::new(
        AggregateExprBuilder::new(count_udaf(), vec![lit(1)])
            .schema(Arc::clone(&schema))
            .alias("COUNT(1)")
            .build()?,
    )];

    let aggregate_exec = Arc::new(AggregateExec::try_new(
        AggregateMode::Single,
        group_by,
        aggr_exprs,
        vec![None],
        Arc::clone(&input) as Arc<dyn ExecutionPlan>,
        Arc::clone(&schema),
    )?);

    // Run partition 0 and merge the output; indexing `output[0]` assumes at
    // least one batch was produced.
    let output =
        collect(aggregate_exec.execute(0, Arc::new(TaskContext::default()))?).await?;
    let batch = concat_batches(&output[0].schema(), &output)?;
    // 2 input rows x 3 grouping sets (duplicate preserved) = 6 output rows.
    assert_eq!(batch.num_rows(), 6);
    Ok(())
}

// test for https://github.com/apache/datafusion/issues/13949
async fn run_test_with_spill_pool_if_necessary(
pool_size: usize,
Expand Down
22 changes: 22 additions & 0 deletions datafusion/sqllogictest/test_files/group_by.slt
Original file line number Diff line number Diff line change
Expand Up @@ -5203,6 +5203,28 @@ NULL NULL 1
statement ok
drop table t;

# regression: duplicate grouping sets must not be collapsed into one
statement ok
create table duplicate_grouping_sets(deptno int, job varchar, sal int, comm int) as values
(10, 'CLERK', 1300, null),
(20, 'MANAGER', 3000, null);

query ITIIIII
select deptno, job, sal, sum(comm), grouping(deptno), grouping(job), grouping(sal)
from duplicate_grouping_sets
group by grouping sets ((deptno, job), (deptno, sal), (deptno, job))
order by deptno, job, sal, grouping(deptno), grouping(job), grouping(sal);
----
10 CLERK NULL NULL 0 0 1
10 CLERK NULL NULL 0 0 1
10 NULL 1300 NULL 0 1 0
20 MANAGER NULL NULL 0 0 1
20 MANAGER NULL NULL 0 0 1
20 NULL 3000 NULL 0 1 0

statement ok
drop table duplicate_grouping_sets;

# test multi group by for binary type without nulls
statement ok
create table t(a int, b bytea) as values (1, 0xa), (1, 0xa), (2, 0xb), (3, 0xb), (3, 0xb);
Expand Down
Loading