diff --git a/datafusion/functions-aggregate/Cargo.toml b/datafusion/functions-aggregate/Cargo.toml index c1b992a6d89b0..9561471a8f4d0 100644 --- a/datafusion/functions-aggregate/Cargo.toml +++ b/datafusion/functions-aggregate/Cargo.toml @@ -95,5 +95,9 @@ harness = false name = "percentile_cont" harness = false +[[bench]] +name = "any_value" +harness = false + [features] force_hash_collisions = ["datafusion-common/force_hash_collisions"] diff --git a/datafusion/functions-aggregate/benches/any_value.rs b/datafusion/functions-aggregate/benches/any_value.rs new file mode 100644 index 0000000000000..6f7096c1f9745 --- /dev/null +++ b/datafusion/functions-aggregate/benches/any_value.rs @@ -0,0 +1,130 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::hint::black_box; +use std::sync::Arc; + +use arrow::array::{ArrayRef, Int64Array, StringArray}; +use arrow::datatypes::{DataType, Field, Schema}; +use criterion::{BatchSize, Criterion, criterion_group, criterion_main}; +use datafusion_expr::function::AccumulatorArgs; +use datafusion_expr::{AggregateUDFImpl, EmitTo, GroupsAccumulator}; +use datafusion_functions_aggregate::any_value::AnyValue; +use datafusion_physical_expr::GroupsAccumulatorAdapter; +use datafusion_physical_expr::expressions::col; + +const BATCH_SIZE: usize = 8192; +const NUM_GROUPS: usize = 4096; + +fn with_accumulator_args( + data_type: DataType, + f: impl FnOnce(AccumulatorArgs<'_>) -> T, +) -> T { + let schema = Schema::new(vec![Field::new("value", data_type.clone(), true)]); + let expr = col("value", &schema).unwrap(); + let expr_fields = vec![expr.return_field(&schema).unwrap()]; + let exprs = vec![expr]; + let return_field = Field::new("any_value", data_type, true).into(); + + f(AccumulatorArgs { + return_field, + schema: &schema, + expr_fields: &expr_fields, + ignore_nulls: false, + order_bys: &[], + is_reversed: false, + name: "any_value(value)", + is_distinct: false, + exprs: &exprs, + }) +} + +fn native_accumulator(data_type: DataType) -> Box { + with_accumulator_args(data_type, |args| { + AnyValue::new().create_groups_accumulator(args).unwrap() + }) +} + +fn adapter_accumulator(data_type: DataType) -> Box { + Box::new(GroupsAccumulatorAdapter::new(move || { + with_accumulator_args(data_type.clone(), |args| AnyValue::new().accumulator(args)) + })) +} + +fn run_grouped( + accumulator: &mut dyn GroupsAccumulator, + values: &ArrayRef, + group_indices: &[usize], +) { + accumulator + .update_batch( + std::slice::from_ref(values), + group_indices, + None, + NUM_GROUPS, + ) + .unwrap(); + black_box(accumulator.evaluate(EmitTo::All).unwrap()); +} + +fn benchmark_type(c: &mut Criterion, name: &str, values: &ArrayRef) { + let group_indices = (0..BATCH_SIZE) + .map(|row| row % NUM_GROUPS) + .collect::>(); + let data_type = values.data_type().clone(); + + let mut group = c.benchmark_group(format!("any_value grouped {name}")); + group.bench_function("native", |b| { + b.iter_batched( + || native_accumulator(data_type.clone()), + |mut accumulator| { + run_grouped(accumulator.as_mut(), values, &group_indices); + }, + BatchSize::SmallInput, + ) + }); + group.bench_function("adapter", |b| { + b.iter_batched( + || adapter_accumulator(data_type.clone()), + |mut accumulator| { + run_grouped(accumulator.as_mut(), values, &group_indices); + }, + BatchSize::SmallInput, + ) + }); + group.finish(); +} + +fn criterion_benchmark(c: &mut Criterion) { + let int_values = Arc::new( + (0..BATCH_SIZE) + .map(|row| (row % 17 != 0).then_some(row as i64)) + .collect::(), + ) as ArrayRef; + benchmark_type(c, "int64", &int_values); + + let strings = (0..BATCH_SIZE) + .map(|row| (row % 17 != 0).then(|| format!("value-{row}"))) + .collect::>(); + let string_values = + Arc::new(StringArray::from_iter(strings.iter().map(Option::as_deref))) + as ArrayRef; + benchmark_type(c, "utf8", &string_values); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions-aggregate/src/any_value.rs b/datafusion/functions-aggregate/src/any_value.rs new file mode 100644 index 0000000000000..3fe85c38af8ee --- /dev/null +++ b/datafusion/functions-aggregate/src/any_value.rs @@ -0,0 +1,356 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines the ANY_VALUE aggregation. + +use std::fmt::Debug; +use std::hash::Hash; +use std::mem::size_of_val; +use std::sync::Arc; + +use arrow::array::{Array, ArrayRef, BooleanArray, BooleanBufferBuilder}; +use arrow::buffer::BooleanBuffer; +use arrow::datatypes::{DataType, Field, FieldRef}; +use datafusion_common::cast::as_boolean_array; +use datafusion_common::{Result, ScalarValue, not_impl_err}; +use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion_expr::utils::{AggregateOrderSensitivity, format_state_name}; +use datafusion_expr::{ + Accumulator, AggregateUDFImpl, Documentation, EmitTo, GroupsAccumulator, Signature, + Volatility, +}; +use datafusion_macros::user_doc; + +use crate::first_last::TrivialFirstValueAccumulator; + +make_udaf_expr_and_func!( + AnyValue, + any_value, + expression, + "Returns an arbitrary non-null value", + any_value_udaf +); + +#[user_doc( + doc_section(label = "General Functions"), + description = "Returns an arbitrary non-null value from a group, or NULL if the group contains only NULL values.", + syntax_example = "any_value(expression)", + sql_example = r#"```sql +> SELECT any_value(column_name) FROM table_name; ++------------------------+ +| any_value(column_name) | ++------------------------+ +| arbitrary_value | ++------------------------+ +```"#, + standard_argument(name = "expression",) +)] +#[derive(PartialEq, Eq, Hash, Debug)] +pub struct AnyValue { + signature: Signature, +} + +impl Default for AnyValue { + fn default() -> Self { + Self::new() + } +} + +impl AnyValue { + pub fn new() -> Self { + Self { + signature: Signature::any(1, Volatility::Immutable), + } + } +} + +impl AggregateUDFImpl for AnyValue { + fn name(&self) -> &str { + "any_value" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + not_impl_err!("Not called because return_field is implemented") + } + + fn return_field(&self, arg_fields: &[FieldRef]) -> Result { + Ok(Arc::new( + Field::new(self.name(), arg_fields[0].data_type().clone(), true) + .with_metadata(arg_fields[0].metadata().clone()), + )) + } + + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + TrivialFirstValueAccumulator::try_new(acc_args.return_field.data_type(), true) + .map(|acc| Box::new(acc) as _) + } + + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + Ok(vec![ + Field::new( + format_state_name(args.name, "any_value"), + args.return_type().clone(), + true, + ) + .into(), + Field::new( + format_state_name(args.name, "any_value_is_set"), + DataType::Boolean, + true, + ) + .into(), + ]) + } + + fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { + true + } + + fn create_groups_accumulator( + &self, + args: AccumulatorArgs, + ) -> Result> { + Ok(Box::new(AnyValueGroupsAccumulator::try_new( + args.return_field.data_type(), + )?)) + } + + fn order_sensitivity(&self) -> AggregateOrderSensitivity { + AggregateOrderSensitivity::Insensitive + } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } +} + +#[derive(Debug)] +struct AnyValueGroupsAccumulator { + values: Vec, + is_set: BooleanBufferBuilder, + null_value: ScalarValue, +} + +impl AnyValueGroupsAccumulator { + fn try_new(data_type: &DataType) -> Result { + Ok(Self { + values: vec![], + is_set: BooleanBufferBuilder::new(0), + null_value: ScalarValue::try_from(data_type)?, + }) + } + + fn ensure_groups(&mut self, total_num_groups: usize) { + if self.values.len() < total_num_groups { + self.values + .resize(total_num_groups, self.null_value.clone()); + self.is_set.resize(total_num_groups); + } + } + + fn take_state(&mut self, emit_to: EmitTo) -> (Vec, BooleanBuffer) { + let values = emit_to.take_needed(&mut self.values); + let is_set = self.is_set.finish(); + match emit_to { + EmitTo::All => (values, is_set), + EmitTo::First(n) => { + let emitted = is_set.slice(0, n); + self.is_set + .append_buffer(&is_set.slice(n, is_set.len() - n)); + (values, emitted) + } + } + } + + fn update_row( + &mut self, + values: &ArrayRef, + group_index: usize, + row: usize, + ) -> Result<()> { + if !self.is_set.get_bit(group_index) && values.is_valid(row) { + self.values[group_index] = ScalarValue::try_from_array(values, row)?; + self.is_set.set_bit(group_index, true); + } + Ok(()) + } +} + +impl GroupsAccumulator for AnyValueGroupsAccumulator { + fn update_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + assert_eq!(values.len(), 1, "any_value expects one argument"); + let values = &values[0]; + self.ensure_groups(total_num_groups); + + for (row, &group_index) in group_indices.iter().enumerate() { + if opt_filter.is_none_or(|filter| filter.is_valid(row) && filter.value(row)) { + self.update_row(values, group_index, row)?; + } + } + Ok(()) + } + + fn evaluate(&mut self, emit_to: EmitTo) -> Result { + let (values, _) = self.take_state(emit_to); + ScalarValue::iter_to_array(values) + } + + fn state(&mut self, emit_to: EmitTo) -> Result> { + let (values, is_set) = self.take_state(emit_to); + Ok(vec![ + ScalarValue::iter_to_array(values)?, + Arc::new(BooleanArray::new(is_set, None)), + ]) + } + + fn merge_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + total_num_groups: usize, + ) -> Result<()> { + assert_eq!(values.len(), 2, "any_value expects value and is_set state"); + let is_set = as_boolean_array(&values[1])?; + self.ensure_groups(total_num_groups); + + for (row, &group_index) in group_indices.iter().enumerate() { + if is_set.is_valid(row) + && is_set.value(row) + && !self.is_set.get_bit(group_index) + { + self.values[group_index] = ScalarValue::try_from_array(&values[0], row)?; + self.is_set.set_bit(group_index, true); + } + } + Ok(()) + } + + fn supports_convert_to_state(&self) -> bool { + true + } + + fn convert_to_state( + &self, + values: &[ArrayRef], + opt_filter: Option<&BooleanArray>, + ) -> Result> { + assert_eq!(values.len(), 1, "any_value expects one argument"); + let values = &values[0]; + let is_set = BooleanArray::from_iter((0..values.len()).map(|row| { + values.is_valid(row) + && opt_filter + .is_none_or(|filter| filter.is_valid(row) && filter.value(row)) + })); + Ok(vec![Arc::clone(values), Arc::new(is_set)]) + } + + fn size(&self) -> usize { + size_of_val(self) + + ScalarValue::size_of_vec(&self.values) + + self.is_set.capacity() / 8 + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{Int64Array, StringArray}; + + #[test] + fn groups_accumulator_uses_first_non_null_value() -> Result<()> { + let mut acc = AnyValueGroupsAccumulator::try_new(&DataType::Int64)?; + let values = Arc::new(Int64Array::from(vec![ + None, + Some(10), + Some(11), + Some(20), + None, + Some(30), + ])) as ArrayRef; + let filter = BooleanArray::from(vec![true, true, true, false, true, true]); + + acc.update_batch(&[values], &[0, 0, 0, 1, 1, 2], Some(&filter), 4)?; + let result = acc.evaluate(EmitTo::All)?; + let expected = + Arc::new(Int64Array::from(vec![Some(10), None, Some(30), None])) as ArrayRef; + assert_eq!(&result, &expected); + Ok(()) + } + + #[test] + fn groups_accumulator_merges_partial_state() -> Result<()> { + let mut acc = AnyValueGroupsAccumulator::try_new(&DataType::Utf8)?; + let values = Arc::new(StringArray::from(vec![Some("a"), Some("b"), Some("c")])) + as ArrayRef; + let is_set = Arc::new(BooleanArray::from(vec![false, true, true])) as ArrayRef; + + acc.merge_batch(&[values, is_set], &[0, 0, 1], 3)?; + let state = acc.state(EmitTo::All)?; + let expected_values = + Arc::new(StringArray::from(vec![Some("b"), Some("c"), None])) as ArrayRef; + let expected_is_set = + Arc::new(BooleanArray::from(vec![true, true, false])) as ArrayRef; + assert_eq!(&state[0], &expected_values); + assert_eq!(&state[1], &expected_is_set); + Ok(()) + } + + #[test] + fn groups_accumulator_convert_to_state_applies_filter_and_nulls() -> Result<()> { + let acc = AnyValueGroupsAccumulator::try_new(&DataType::Int64)?; + let values = Arc::new(Int64Array::from(vec![Some(1), None, Some(3)])) as ArrayRef; + let filter = BooleanArray::from(vec![true, true, false]); + + let state = acc.convert_to_state(&[Arc::clone(&values)], Some(&filter))?; + let expected_is_set = + Arc::new(BooleanArray::from(vec![true, false, false])) as ArrayRef; + assert_eq!(&state[0], &values); + assert_eq!(&state[1], &expected_is_set); + Ok(()) + } + + #[test] + fn groups_accumulator_emit_first_retains_remaining_groups() -> Result<()> { + let mut acc = AnyValueGroupsAccumulator::try_new(&DataType::Int64)?; + let values = + Arc::new(Int64Array::from(vec![Some(10), Some(20), Some(30)])) as ArrayRef; + acc.update_batch(&[values], &[0, 1, 2], None, 3)?; + + let first = acc.evaluate(EmitTo::First(2))?; + let expected_first = + Arc::new(Int64Array::from(vec![Some(10), Some(20)])) as ArrayRef; + assert_eq!(&first, &expected_first); + + let values = Arc::new(Int64Array::from(vec![Some(31), Some(40)])) as ArrayRef; + acc.update_batch(&[values], &[0, 1], None, 2)?; + let remaining = acc.evaluate(EmitTo::All)?; + let expected_remaining = + Arc::new(Int64Array::from(vec![Some(30), Some(40)])) as ArrayRef; + assert_eq!(&remaining, &expected_remaining); + Ok(()) + } +} diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs index 1b9996220d882..e3f2714abbf25 100644 --- a/datafusion/functions-aggregate/src/lib.rs +++ b/datafusion/functions-aggregate/src/lib.rs @@ -65,6 +65,7 @@ #[macro_use] pub mod macros; +pub mod any_value; pub mod approx_distinct; pub mod approx_median; pub mod approx_percentile_cont; @@ -102,6 +103,7 @@ use std::sync::Arc; /// Fluent-style API for creating `Expr`s pub mod expr_fn { + pub use super::any_value::any_value; pub use super::approx_distinct::approx_distinct; pub use super::approx_median::approx_median; pub use super::approx_percentile_cont::approx_percentile_cont; @@ -147,6 +149,7 @@ pub mod expr_fn { /// Returns all default aggregate functions pub fn all_default_aggregate_functions() -> Vec> { vec![ + any_value::any_value_udaf(), array_agg::array_agg_udaf(), first_last::first_value_udaf(), first_last::last_value_udaf(), diff --git a/datafusion/sqllogictest/test_files/aggregate_any_value.slt b/datafusion/sqllogictest/test_files/aggregate_any_value.slt new file mode 100644 index 0000000000000..3fe6f787d346d --- /dev/null +++ b/datafusion/sqllogictest/test_files/aggregate_any_value.slt @@ -0,0 +1,57 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +statement ok +CREATE TABLE any_value_test AS VALUES + (1, NULL, NULL), + (1, 10, 'first'), + (1, 20, 'second'), + (2, NULL, NULL), + (2, NULL, NULL), + (3, 30, 'third'); + +query B +SELECT any_value(column2) IN (10, 20) FROM any_value_test; +---- +true + +query IBB rowsort +SELECT + column1, + any_value(column2) IN (10, 20, 30), + any_value(column3) IN ('first', 'second', 'third') +FROM any_value_test +GROUP BY column1; +---- +1 true true +2 NULL NULL +3 true true + +query T +SELECT arrow_typeof(any_value(column3)) FROM any_value_test; +---- +Utf8 + +query I +SELECT any_value(column2) FROM any_value_test WHERE false; +---- +NULL + +query I +SELECT any_value(column2) FROM any_value_test WHERE column1 = 2; +---- +NULL