Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions datafusion/common/src/hash_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1852,6 +1852,8 @@ mod tests {
assert_eq!(hashes1, hashes2);
}

// Tests actual values of hashes, which are different if forcing collisions
#[cfg(not(feature = "force_hash_collisions"))]
#[test]
#[cfg(not(feature = "force_hash_collisions"))]
fn test_create_hashes_with_quality_hash_state() {
Expand Down
97 changes: 97 additions & 0 deletions datafusion/core/tests/user_defined/user_defined_aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,55 @@ async fn test_udaf() {
assert!(!test_state.retract_batch());
}

/// Exercises a user defined aggregate that takes no arguments.
///
/// Registers `NullaryAccumulator` under the name `window_start` and checks
/// its output in three query shapes: a plain projection, an aggregate with a
/// `FILTER` clause, and a `GROUP BY` aggregation.
#[tokio::test]
async fn test_zero_argument_udaf() {
let TestContext { mut ctx, .. } = TestContext::new();
NullaryAccumulator::register(&mut ctx, "window_start");

// Ungrouped aggregation: the snapshot expects a single row holding 7.
let actual = execute(&ctx, "SELECT window_start() from t").await.unwrap();

insta::assert_snapshot!(batches_to_string(&actual), @r"
+----------------+
| window_start() |
+----------------+
| 7 |
+----------------+
");

// Aggregation with a FILTER clause: the snapshot expects one row with an
// empty cell.
// NOTE(review): presumably no row of `t` has `value < 0.0`, so the
// accumulator is never updated and the result is NULL; confirm against the
// fixture data in `TestContext`.
let actual = execute(
&ctx,
"SELECT window_start() FILTER (WHERE value < 0.0) FROM t",
)
.await
.unwrap();

insta::assert_snapshot!(batches_to_string(&actual), @r"
+----------------------------------------------------+
| window_start() FILTER (WHERE t.value < Float64(0)) |
+----------------------------------------------------+
| |
+----------------------------------------------------+
");

// Grouped aggregation: the snapshot expects 7 for each of the four
// distinct `value` groups.
let actual = execute(
&ctx,
"SELECT value, window_start() FROM t GROUP BY value ORDER BY value",
)
.await
.unwrap();

insta::assert_snapshot!(batches_to_string(&actual), @r"
+-------+----------------+
| value | window_start() |
+-------+----------------+
| 1.0 | 7 |
| 2.0 | 7 |
| 3.0 | 7 |
| 5.0 | 7 |
+-------+----------------+
");
}

/// User defined aggregate used as a window function
#[tokio::test]
async fn test_udaf_as_window() {
Expand Down Expand Up @@ -559,6 +608,54 @@ impl TestState {
}
}

/// Accumulator backing a zero-argument user defined aggregate in these tests.
///
/// The derived `Default` starts with `value` set to `None`; the `Accumulator`
/// implementation for this type sets it to `Some(7)` whenever a batch is
/// updated or merged.
#[derive(Debug, Default)]
struct NullaryAccumulator {
// Current aggregate result; `None` until a batch has been processed.
value: Option<i64>,
}

impl NullaryAccumulator {
    /// Registers a zero-argument aggregate named `name` on `ctx`.
    ///
    /// The aggregate is declared with a nullary, immutable signature, returns
    /// `Int64`, and exposes one nullable `Int64` state field called `value`.
    fn register(ctx: &mut SessionContext, name: &str) {
        // Every request for an accumulator yields a fresh default instance;
        // the factory's argument is not inspected.
        let factory: AccumulatorFactoryFunction =
            Arc::new(|_| Ok(Box::<Self>::default()));

        let signature = Signature::nullary(Volatility::Immutable);

        let simple_udaf = SimpleAggregateUDF::new_with_signature(
            name,
            signature,
            DataType::Int64,
            factory,
            vec![Field::new("value", DataType::Int64, true).into()],
        );

        ctx.register_udaf(AggregateUDF::from(simple_udaf))
    }
}

/// `Accumulator` implementation that never looks at input values: once any
/// batch is updated or merged, the result becomes the constant 7.
impl Accumulator for NullaryAccumulator {
    /// Returns the intermediate state as one scalar: the current result.
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
        let current = self.evaluate()?;
        Ok(vec![current])
    }

    /// Records that a batch was seen. Panics if any argument arrays are
    /// passed, since this aggregate takes no arguments.
    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        assert!(values.is_empty());
        self.value = Some(7);
        Ok(())
    }

    /// Merges partial states. Panics unless exactly one state array is
    /// passed; the array's contents are not read.
    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        assert_eq!(states.len(), 1);
        self.value = Some(7);
        Ok(())
    }

    /// Produces the result as an `Int64` scalar wrapping `self.value`.
    fn evaluate(&mut self) -> Result<ScalarValue> {
        let result = ScalarValue::Int64(self.value);
        Ok(result)
    }

    /// Reports the in-memory size of this accumulator in bytes.
    fn size(&self) -> usize {
        std::mem::size_of::<Self>()
    }
}

/// Models a user defined aggregate function that computes a sum
/// of timestamps (not a quantity that has much real world meaning)
#[derive(Debug)]
Expand Down
8 changes: 8 additions & 0 deletions datafusion/expr-common/src/type_coercion/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,14 @@ pub fn check_arg_count(
);
}
}
TypeSignature::Nullary => {
if !input_fields.is_empty() {
return plan_err!(
"The function {func_name} expects zero arguments, but {} were provided",
input_fields.len()
);
}
}
TypeSignature::OneOf(variants) => {
let ok = variants
.iter()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use std::mem::{size_of, size_of_val};

use arrow::array::new_empty_array;
use arrow::{
array::{ArrayRef, AsArray, BooleanArray, PrimitiveArray},
array::{Array, ArrayRef, AsArray, BooleanArray, PrimitiveArray},
compute,
compute::take_arrays,
datatypes::UInt32Type,
Expand Down Expand Up @@ -100,6 +100,12 @@ pub struct GroupsAccumulatorAdapter {
/// bottleneck in earlier implementations when there were many
/// distinct groups.
allocation_bytes: usize,

/// Whether this adapter can convert raw input rows directly to aggregate
/// state. Nullary aggregates do not carry a value array with the input row
/// cardinality, so this optimization cannot safely reconstruct one state
/// per input row for them.
supports_convert_to_state: bool,
}

struct AccumulatorState {
Expand Down Expand Up @@ -137,6 +143,24 @@ impl GroupsAccumulatorAdapter {
factory: Box::new(factory),
states: vec![],
allocation_bytes: 0,
supports_convert_to_state: true,
}
}

/// Builds an adapter and lets the caller decide whether row-to-state
/// conversion is advertised.
///
/// `factory` is boxed and stored for creating accumulators. The adapter
/// starts with no group states and zero tracked allocation bytes.
/// `supports_convert_to_state` is stored as given and is the value later
/// returned by `supports_convert_to_state()`.
pub fn new_with_convert_to_state<F>(factory: F, supports_convert_to_state: bool) -> Self
where
    F: Fn() -> Result<Box<dyn Accumulator>> + Send + 'static,
{
    Self {
        supports_convert_to_state,
        allocation_bytes: 0,
        states: Vec::new(),
        factory: Box::new(factory),
    }
}

Expand Down Expand Up @@ -196,13 +220,24 @@ impl GroupsAccumulatorAdapter {
{
self.make_accumulators_if_needed(total_num_groups)?;

assert_eq!(values[0].len(), group_indices.len());
if values.is_empty() {
for (idx, group_index) in group_indices.iter().enumerate() {
if opt_filter
.is_some_and(|filter| !filter.is_valid(idx) || !filter.value(idx))
{
continue;
}
self.states[*group_index].indices.push(idx as u32);
}
} else {
assert_eq!(values[0].len(), group_indices.len());

// figure out which input rows correspond to which groups.
// Note that self.state.indices starts empty for all groups
// (it is cleared out below)
for (idx, group_index) in group_indices.iter().enumerate() {
self.states[*group_index].indices.push(idx as u32);
// figure out which input rows correspond to which groups.
// Note that self.state.indices starts empty for all groups
// (it is cleared out below)
for (idx, group_index) in group_indices.iter().enumerate() {
self.states[*group_index].indices.push(idx as u32);
}
}

// groups_with_rows holds a list of group indexes that have
Expand Down Expand Up @@ -230,13 +265,18 @@ impl GroupsAccumulatorAdapter {
offset_so_far += indices.len();
offsets.push(offset_so_far);
}
let batch_indices = batch_indices.into();
let values_and_filter = if values.is_empty() {
None
} else {
let batch_indices = batch_indices.into();

// reorder the values and opt_filter by batch_indices so that
// all values for each group are contiguous, then invoke the
// accumulator once per group with values
let values = take_arrays(values, &batch_indices, None)?;
let opt_filter = get_filter_at_indices(opt_filter, &batch_indices)?;
// reorder the values and opt_filter by batch_indices so that
// all values for each group are contiguous, then invoke the
// accumulator once per group with values
let values = take_arrays(values, &batch_indices, None)?;
let opt_filter = get_filter_at_indices(opt_filter, &batch_indices)?;
Some((values, opt_filter))
};

// invoke each accumulator with the appropriate rows, first
// pulling the input arguments for this group into their own
Expand All @@ -249,11 +289,14 @@ impl GroupsAccumulatorAdapter {
let state = &mut self.states[group_idx];
sizes_pre += state.size();

let values_to_accumulate = slice_and_maybe_filter(
&values,
opt_filter.as_ref().map(|f| f.as_boolean()),
offsets,
)?;
let values_to_accumulate = match &values_and_filter {
Some((values, opt_filter)) => slice_and_maybe_filter(
values,
opt_filter.as_ref().map(|f| f.as_boolean()),
offsets,
)?,
None => vec![],
};
f(state.accumulator.as_mut(), &values_to_accumulate)?;

// clear out the state so they are empty for next
Expand Down Expand Up @@ -443,7 +486,7 @@ impl GroupsAccumulator for GroupsAccumulatorAdapter {
}

fn supports_convert_to_state(&self) -> bool {
true
self.supports_convert_to_state
}
}

Expand Down
10 changes: 7 additions & 3 deletions datafusion/physical-expr/src/aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ use arrow::compute::SortOptions;
use arrow::datatypes::{DataType, FieldRef, Schema, SchemaRef};
use datafusion_common::metadata::FieldMetadata;
use datafusion_common::{
DFSchema, Result, ScalarValue, assert_or_internal_err, internal_err, not_impl_err,
DFSchema, Result, ScalarValue, internal_err, not_impl_err, plan_err,
};
use datafusion_expr::execution_props::ExecutionProps;
use datafusion_expr::expr::{
Expand Down Expand Up @@ -261,8 +261,6 @@ impl AggregateExprBuilder {
is_distinct,
is_reversed,
} = self;
assert_or_internal_err!(!args.is_empty(), "args should not be empty");

let ordering_types = order_bys
.iter()
.map(|e| e.expr.data_type(&schema))
Expand All @@ -275,6 +273,12 @@ impl AggregateExprBuilder {
.map(|arg| arg.return_field(&schema))
.collect::<Result<Vec<_>>>()?;

if args.is_empty() && fun.name() == "count" {
return plan_err!(
"Physical count aggregate requires an argument; use COUNT_STAR_EXPANSION for count(*)"
);
}

check_arg_count(
fun.name(),
&input_exprs_fields,
Expand Down
15 changes: 14 additions & 1 deletion datafusion/physical-plan/src/aggregates/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2433,7 +2433,7 @@ mod tests {
use arrow::compute::{SortOptions, concat_batches};
use arrow::datatypes::Int32Type;
use datafusion_common::test_util::{batches_to_sort_string, batches_to_string};
use datafusion_common::{DataFusionError, internal_err};
use datafusion_common::{DataFusionError, assert_contains, internal_err};
use datafusion_execution::config::SessionConfig;
use datafusion_execution::memory_pool::FairSpillPool;
use datafusion_execution::runtime_env::RuntimeEnvBuilder;
Expand Down Expand Up @@ -2469,6 +2469,19 @@ mod tests {
Ok(schema)
}

/// Building a physical `count` aggregate with an empty argument list must
/// fail with an error that mentions the missing argument.
#[test]
fn count_requires_physical_argument() {
    let build_result = AggregateExprBuilder::new(count_udaf(), vec![])
        .alias("count")
        .build();

    let err = build_result.expect_err("empty-argument physical count should fail");
    let message = err.to_string();

    assert_contains!(message, "Physical count aggregate requires an argument");
}

/// some mock data to aggregates
fn some_data() -> (Arc<Schema>, Vec<RecordBatch>) {
// define a schema.
Expand Down
7 changes: 7 additions & 0 deletions datafusion/physical-plan/src/aggregates/no_grouping.rs
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,13 @@ fn aggregate_batch(
// 1.3
let values = evaluate_expressions_to_arrays(expr, batch.as_ref())?;

if values.is_empty()
&& batch.num_rows() == 0
&& mode.input_mode() == AggregateInputMode::Raw
{
return Ok(());
}

// 1.4
let size_pre = accum.size();
let res = match mode.input_mode() {
Expand Down
8 changes: 7 additions & 1 deletion datafusion/physical-plan/src/aggregates/row_hash.rs
Original file line number Diff line number Diff line change
Expand Up @@ -621,8 +621,14 @@ pub(crate) fn create_group_accumulator(
agg_expr.name()
);
let agg_expr_captured = Arc::clone(agg_expr);
let supports_convert_to_state = !agg_expr.all_expressions().args.is_empty();
let factory = move || agg_expr_captured.create_accumulator();
Ok(Box::new(GroupsAccumulatorAdapter::new(factory)))
Ok(Box::new(
GroupsAccumulatorAdapter::new_with_convert_to_state(
factory,
supports_convert_to_state,
),
))
}
}

Expand Down
Loading