From 0342e5c29066c858d76b6a1a504dee9e4a7ae790 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Fri, 22 Aug 2025 10:23:31 -0400 Subject: [PATCH 1/3] Add in test and resolve errors related to return types --- datafusion/core/tests/dataframe/mod.rs | 46 +- datafusion/functions-nested/src/lib.rs | 1 + datafusion/functions-nested/src/transform.rs | 438 +++++++++++++++++++ 3 files changed, 480 insertions(+), 5 deletions(-) create mode 100644 datafusion/functions-nested/src/transform.rs diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index 27afbd7246601..9ff901fc85483 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -20,12 +20,12 @@ mod dataframe_functions; mod describe; use arrow::array::{ - record_batch, Array, ArrayRef, BooleanArray, DictionaryArray, FixedSizeListArray, - FixedSizeListBuilder, Float32Array, Float64Array, Int32Array, Int32Builder, - Int8Array, LargeListArray, ListArray, ListBuilder, RecordBatch, StringArray, - StringBuilder, StructBuilder, UInt32Array, UInt32Builder, UnionArray, + as_list_array, record_batch, Array, ArrayRef, BooleanArray, DictionaryArray, + FixedSizeListArray, FixedSizeListBuilder, Float32Array, Float64Array, Int32Array, + Int32Builder, Int8Array, LargeListArray, ListArray, ListBuilder, RecordBatch, + StringArray, StringBuilder, StructBuilder, UInt32Array, UInt32Builder, UnionArray, }; -use arrow::buffer::ScalarBuffer; +use arrow::buffer::{NullBuffer, OffsetBuffer, ScalarBuffer}; use arrow::datatypes::{ DataType, Field, Float32Type, Int32Type, Schema, SchemaRef, UInt64Type, UnionFields, UnionMode, @@ -79,6 +79,7 @@ use datafusion_expr::{ LogicalPlanBuilder, ScalarFunctionImplementation, SortExpr, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; +use datafusion_functions_nested::transform::array_transform_udf; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::Partitioning; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; @@ -6245,3 +6246,38 @@ async fn test_copy_to_preserves_order() -> Result<()> { ); Ok(()) } + +#[tokio::test] +async fn test_function_array_transform() -> Result<()> { + let values = Arc::new(Int32Array::from(vec![1, -2, 3, -4, 5, -6, 7, -8])); + let field = Arc::new(Field::new("a", DataType::Int32, false)); + let offsets: OffsetBuffer = OffsetBuffer::from_lengths([3, 2, 3]); + + let outer = Arc::new(ListArray::try_new(field, offsets, values, None)?); + + let df = DataFrame::from_columns(vec![("a", outer)])?; + + let udf = array_transform_udf(datafusion_functions::math::abs(), 0); + + let df = df.select([col("a"), udf.call(vec![col("a")]).alias("abs(a[])")])?; + + let results = df.collect().await?; + let result_column = as_list_array(results[0].column(1)); + assert_eq!(result_column.len(), 3); + + let expected_values = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8])); + let expected_field = Arc::new(Field::new("abs", DataType::Int32, true)); + let expected_offsets: OffsetBuffer = OffsetBuffer::from_lengths([3, 2, 3]); + let expected_nulls = NullBuffer::new_valid(3); + + let expected = Arc::new(ListArray::try_new( + expected_field, + expected_offsets, + expected_values, + Some(expected_nulls), + )?) as ArrayRef; + + assert_eq!(results[0].column(1), &expected); + + Ok(()) +} diff --git a/datafusion/functions-nested/src/lib.rs b/datafusion/functions-nested/src/lib.rs index 1d3f11b50c613..5a1bb6bac66d0 100644 --- a/datafusion/functions-nested/src/lib.rs +++ b/datafusion/functions-nested/src/lib.rs @@ -66,6 +66,7 @@ pub mod reverse; pub mod set_ops; pub mod sort; pub mod string; +pub mod transform; pub mod utils; use datafusion_common::Result; diff --git a/datafusion/functions-nested/src/transform.rs b/datafusion/functions-nested/src/transform.rs new file mode 100644 index 0000000000000..7f28ae3c91dc1 --- /dev/null +++ b/datafusion/functions-nested/src/transform.rs @@ -0,0 +1,438 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`ScalarUDFImpl`] definition for array_transform function. + +use arrow::array::{Array, ArrayRef, GenericListArray, NullArray}; +use arrow::datatypes::{DataType, Field, FieldRef}; +use datafusion_common::cast::as_large_list_array; +use datafusion_common::cast::as_list_array; +use datafusion_common::{exec_datafusion_err, exec_err, plan_err, Result}; +use datafusion_expr::type_coercion::functions::data_types_with_scalar_udf; +use datafusion_expr::{ + ColumnarValue, Documentation, ReturnFieldArgs, ScalarUDFImpl, Signature, + TypeSignature, +}; +use datafusion_expr::{Expr, ScalarUDF}; +use datafusion_macros::user_doc; +use std::any::Any; +use std::sync::Arc; + +#[doc = "ScalarFunction that returns a [`ScalarUDF`](datafusion_expr::ScalarUDF) for "] +#[doc = "ArrayTransform"] +pub fn array_transform_udf( + inner: Arc, + argument_index: usize, +) -> Arc { + Arc::new(ScalarUDF::new_from_impl(::new( + inner, + argument_index, + ))) +} + +#[user_doc( + doc_section(label = "Array Transform"), + description = "Transform every element of an array according to a scalar function.", + syntax_example = "array_transform(inner_function(arg1, arg2, arg3), arg_index)", + sql_example = r#"```sql +> select array_transform(abs([-3,1,-4,2]), 0); ++-----------------------------------------+ +| array_transform(abs([-3,1,-4,2]), 0) | ++-----------------------------------------+ +| [3,1,4,2] | ++-----------------------------------------+ +```"#, + argument( + name = "inner_function", + description = "Scalar function with arguments." + ), + argument( + name = "arg_index", + description = "0 based index that specifies which argument to the scalar function represents the array to transform." + ) +)] +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct ArrayTransform { + signature: Signature, + aliases: Vec, + function: Arc, + argument_index: usize, +} + +impl ArrayTransform { + pub fn new(function: Arc, argument_index: usize) -> Self { + let signature = Signature { + type_signature: TypeSignature::UserDefined, + volatility: function.signature().volatility, + }; + + Self { + signature, + aliases: vec![String::from("list_transform")], + function, + argument_index, + } + } +} + +impl ScalarUDFImpl for ArrayTransform { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "array_transform" + } + + fn display_name(&self, args: &[Expr]) -> Result { + let mut arg_names = args.iter().map(ToString::to_string).collect::>(); + arg_names.insert(0, "[]".to_string()); + + Ok(format!( + "{}({})]", + self.function.name(), + arg_names.join(", ") + )) + } + + fn schema_name(&self, args: &[Expr]) -> Result { + let mut arg_names = args.iter().map(ToString::to_string).collect::>(); + arg_names.insert(0, "[]".to_string()); + + Ok(format!( + "{}({})]", + self.function.name(), + arg_names.join(", ") + )) + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + unimplemented!() + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let replacement_field = + args.arg_fields + .get(self.argument_index) + .ok_or(exec_datafusion_err!( + "Invalid argument index {} for number of arguments provided {}", + self.argument_index, + args.arg_fields.len() + ))?; + + let replacement_field = match replacement_field.data_type() { + DataType::Null => { + return Ok(Arc::new(Field::new("null", DataType::Null, true))) + } + DataType::List(field) + | DataType::LargeList(field) + | DataType::ListView(field) + | DataType::LargeListView(field) => Ok(Arc::clone(field)), + arg_type => plan_err!("{} does not support type {arg_type}", self.name()), + }?; + + let mut inner_arg_fields = args.arg_fields.to_vec(); + inner_arg_fields[self.argument_index] = replacement_field; + + let inner_args = ReturnFieldArgs { + arg_fields: &inner_arg_fields, + scalar_arguments: args.scalar_arguments, + }; + + let inner_return = self.function.return_field_from_args(inner_args)?; + let name = inner_return.name().to_owned(); + + match args.arg_fields[self.argument_index].data_type() { + DataType::List(_) => Ok(Arc::new(Field::new( + name, + DataType::List(inner_return), + true, + ))), + DataType::ListView(_) => Ok(Arc::new(Field::new( + name, + DataType::ListView(inner_return), + true, + ))), + DataType::LargeList(_) => Ok(Arc::new(Field::new( + name, + DataType::LargeList(inner_return), + true, + ))), + DataType::LargeListView(_) => Ok(Arc::new(Field::new( + name, + DataType::LargeListView(inner_return), + true, + ))), + _ => unreachable!(), + } + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + if self.argument_index >= arg_types.len() { + return exec_err!( + "Invalid argument index {} for array_transform with {} arguments.", + self.argument_index, + arg_types.len() + ); + } + + let mut replacement_types = arg_types.to_vec(); + let replacement = match &arg_types[self.argument_index] { + DataType::List(field) + | DataType::LargeList(field) + | DataType::ListView(field) + | DataType::LargeListView(field) => field.data_type().clone(), + _ => { + return exec_err!( + "Expected list type for the argument index {} in array_transform", + self.argument_index + ) + } + }; + replacement_types[self.argument_index] = replacement; + + let mut return_types = + data_types_with_scalar_udf(&replacement_types, self.function.as_ref())?; + + let replacement_type = return_types[self.argument_index].clone(); + return_types[self.argument_index] = match &arg_types[self.argument_index] { + DataType::List(field) => { + DataType::List(Arc::new(Field::new(field.name(), replacement_type, true))) + } + DataType::LargeList(field) => DataType::LargeList(Arc::new(Field::new( + field.name(), + replacement_type, + true, + ))), + DataType::ListView(field) => DataType::ListView(Arc::new(Field::new( + field.name(), + replacement_type, + true, + ))), + DataType::LargeListView(field) => DataType::LargeListView(Arc::new( + Field::new(field.name(), replacement_type, true), + )), + _ => unreachable!(), + }; + + Ok(return_types) + } + + fn invoke_with_args( + &self, + mut args: datafusion_expr::ScalarFunctionArgs, + ) -> Result { + let return_field = match args.return_field.data_type() { + DataType::List(field) + | DataType::LargeList(field) + | DataType::ListView(field) + | DataType::LargeListView(field) => Arc::clone(field), + _ => { + return exec_err!( + "Unexpected return field for array_transform. Expected list data type." + ) + } + }; + + let replacement_array = + args.args + .get(self.argument_index) + .ok_or(exec_datafusion_err!( + "Invalid number of arguments. Expected at least {} but received {}", + self.argument_index + 1, + args.args.len() + ))?; + + let ColumnarValue::Array(replacement_array) = replacement_array else { + return exec_err!("Unexpected scalar value in array_transform"); + }; + + let result = match &replacement_array.data_type() { + DataType::Null => { + Ok(Arc::new(NullArray::new(replacement_array.len())) as ArrayRef) + } + DataType::List(_) => { + let array = as_list_array(&replacement_array)?; + let offsets = array.offsets().clone(); + let nulls = array.nulls().cloned(); + + let values = array.values(); + + args.args[self.argument_index] = ColumnarValue::Array(Arc::clone(values)); + + let results = self.function.invoke_with_args(args)?; + + let ColumnarValue::Array(result_array) = results else { + return Ok(results); + }; + + Ok(Arc::new(GenericListArray::try_new( + return_field, + offsets, + result_array, + nulls, + )?) as ArrayRef) + } + DataType::LargeList(_) => { + let array = as_large_list_array(&replacement_array)?; + let offsets = array.offsets().clone(); + let nulls = array.nulls().cloned(); + + let values = array.values(); + + args.args[self.argument_index] = ColumnarValue::Array(Arc::clone(values)); + + let results = self.function.invoke_with_args(args)?; + + let ColumnarValue::Array(result_array) = results else { + return Ok(results); + }; + + Ok(Arc::new(GenericListArray::try_new( + return_field, + offsets, + result_array, + nulls, + )?) as ArrayRef) + } + arg_type => { + exec_err!("array_transform does not support type {arg_type}") + } + }?; + + Ok(ColumnarValue::Array(result)) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } +} + +#[cfg(test)] +mod tests { + use super::array_transform_udf; + use arrow::array::{create_array, ArrayRef, GenericListArray}; + use arrow::datatypes::{DataType, Field}; + use datafusion_common::utils::SingleRowListArrayBuilder; + use datafusion_common::{exec_err, DataFusionError, ScalarValue}; + use datafusion_expr::{ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs}; + use datafusion_functions::math::{abs, round}; + use std::sync::Arc; + + #[test] + fn test_array_transform_apply_single_valued_function() -> Result<(), DataFusionError> + { + let udf = array_transform_udf(abs(), 0); + + let data = SingleRowListArrayBuilder::new(create_array!( + Int32, + [Some(1), Some(-2), None] + )) + .build_list_array(); + let data = Arc::new(data) as ArrayRef; + let input_field = Arc::new(Field::new( + "a", + DataType::List(Field::new("b", DataType::Int32, true).into()), + true, + )); + let return_field = udf.return_field_from_args(ReturnFieldArgs { + arg_fields: &[Arc::clone(&input_field)], + scalar_arguments: &[None], + })?; + + let args = ScalarFunctionArgs { + args: vec![ColumnarValue::Array(data)], + arg_fields: vec![input_field], + number_rows: 3, + return_field, + config_options: Arc::new(Default::default()), + }; + + let ColumnarValue::Array(result) = udf.invoke_with_args(args)? else { + return exec_err!("Invalid return type"); + }; + let list_array = result + .as_any() + .downcast_ref::>() + .unwrap(); + + let expected = create_array!(Int32, [Some(1), Some(2), None]) as ArrayRef; + + assert_eq!(&list_array.value(0), &expected); + + Ok(()) + } + + #[test] + fn test_array_transform_test_argument_index() -> Result<(), DataFusionError> { + let udf = array_transform_udf(round(), 1); + + let data = SingleRowListArrayBuilder::new(create_array!( + Int32, + [Some(1), Some(2), Some(3), None] + )) + .build_list_array(); + let data = Arc::new(data) as ArrayRef; + let input_fields = vec![ + Arc::new(Field::new("b", DataType::Float64, true)), + Arc::new(Field::new( + "a", + DataType::List(Field::new("b", DataType::Int64, true).into()), + true, + )), + ]; + + let original_value = ScalarValue::Float64(Some(0.123456)); + let return_field = udf.return_field_from_args(ReturnFieldArgs { + arg_fields: &input_fields, + scalar_arguments: &[Some(&original_value), None], + })?; + + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Scalar(original_value), + ColumnarValue::Array(data), + ], + arg_fields: input_fields, + number_rows: 4, + return_field, + config_options: Arc::new(Default::default()), + }; + + let ColumnarValue::Array(result) = udf.invoke_with_args(args)? else { + return exec_err!("Invalid return type"); + }; + let list_array = result + .as_any() + .downcast_ref::>() + .unwrap(); + + let expected = create_array!(Float64, [Some(0.1), Some(0.12), Some(0.123), None]) + as ArrayRef; + + assert_eq!(&list_array.value(0), &expected); + + Ok(()) + } +} From e15211f194036a65c7a5bb51ab5b56d9d9fd18b8 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sat, 30 Aug 2025 15:35:31 -0400 Subject: [PATCH 2/3] Add unit tests --- datafusion/common/src/cast.rs | 13 +- datafusion/functions-nested/src/transform.rs | 348 ++++++++++++++----- 2 files changed, 267 insertions(+), 94 deletions(-) diff --git a/datafusion/common/src/cast.rs b/datafusion/common/src/cast.rs index 68b753a6678a4..4872b933c02a1 100644 --- a/datafusion/common/src/cast.rs +++ b/datafusion/common/src/cast.rs @@ -24,7 +24,8 @@ use crate::{downcast_value, Result}; use arrow::array::{ BinaryViewArray, DurationMicrosecondArray, DurationMillisecondArray, DurationNanosecondArray, DurationSecondArray, Float16Array, Int16Array, Int8Array, - LargeBinaryArray, LargeStringArray, StringViewArray, UInt16Array, + LargeBinaryArray, LargeListViewArray, LargeStringArray, ListViewArray, + StringViewArray, UInt16Array, }; use arrow::{ array::{ @@ -147,6 +148,16 @@ pub fn as_list_array(array: &dyn Array) -> Result<&ListArray> { Ok(downcast_value!(array, ListArray)) } +// Downcast Array to ListViewArray +pub fn as_list_view_array(array: &dyn Array) -> Result<&ListViewArray> { + Ok(downcast_value!(array, ListViewArray)) +} + +// Downcast Array to LargeListViewArray +pub fn as_large_list_view_array(array: &dyn Array) -> Result<&LargeListViewArray> { + Ok(downcast_value!(array, LargeListViewArray)) +} + // Downcast Array to DictionaryArray pub fn as_dictionary_array( array: &dyn Array, diff --git a/datafusion/functions-nested/src/transform.rs b/datafusion/functions-nested/src/transform.rs index 7f28ae3c91dc1..6bbbef991268d 100644 --- a/datafusion/functions-nested/src/transform.rs +++ b/datafusion/functions-nested/src/transform.rs @@ -17,23 +17,23 @@ //! [`ScalarUDFImpl`] definition for array_transform function. -use arrow::array::{Array, ArrayRef, GenericListArray, NullArray}; +use arrow::array::{Array, ArrayRef, GenericListArray, GenericListViewArray}; use arrow::datatypes::{DataType, Field, FieldRef}; -use datafusion_common::cast::as_large_list_array; -use datafusion_common::cast::as_list_array; +use datafusion_common::cast::{ + as_large_list_array, as_large_list_view_array, as_list_array, as_list_view_array, +}; use datafusion_common::{exec_datafusion_err, exec_err, plan_err, Result}; use datafusion_expr::type_coercion::functions::data_types_with_scalar_udf; use datafusion_expr::{ - ColumnarValue, Documentation, ReturnFieldArgs, ScalarUDFImpl, Signature, - TypeSignature, + ColumnarValue, Documentation, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, + Signature, TypeSignature, }; use datafusion_expr::{Expr, ScalarUDF}; use datafusion_macros::user_doc; use std::any::Any; use std::sync::Arc; -#[doc = "ScalarFunction that returns a [`ScalarUDF`](datafusion_expr::ScalarUDF) for "] -#[doc = "ArrayTransform"] +#[doc = "ScalarFunction that returns a [`ScalarUDF`] for ArrayTransform"] pub fn array_transform_udf( inner: Arc, argument_index: usize, @@ -46,7 +46,7 @@ pub fn array_transform_udf( #[user_doc( doc_section(label = "Array Transform"), - description = "Transform every element of an array according to a scalar function.", + description = "Transform every element of an array according to a scalar function. This work is under development and currently only supports passing a single ListArray as input to the inner function. Other inputs must be scalar values.", syntax_example = "array_transform(inner_function(arg1, arg2, arg3), arg_index)", sql_example = r#"```sql > select array_transform(abs([-3,1,-4,2]), 0); @@ -89,6 +89,90 @@ impl ArrayTransform { } } +macro_rules! invoke_by_list_type { + ($fn_name:ident, $downcast_fn:ident, $return_type:ty) => { + fn $fn_name( + &self, + replacement_array: ArrayRef, + mut args: ScalarFunctionArgs, + return_field: FieldRef, + ) -> Result { + let array = $downcast_fn(&replacement_array)?; + let offsets = array.offsets().clone(); + let nulls = array.nulls().cloned(); + + let values = array.values(); + + args.args[self.argument_index] = ColumnarValue::Array(Arc::clone(values)); + + let results = self.function.invoke_with_args(args)?; + + let ColumnarValue::Array(result_array) = results else { + return Ok(results); + }; + + Ok(ColumnarValue::Array(Arc::new(<$return_type>::try_new( + return_field, + offsets, + result_array, + nulls, + )?) as ArrayRef)) + } + }; +} +macro_rules! invoke_by_list_view_type { + ($fn_name:ident, $downcast_fn:ident, $return_type:ty) => { + fn $fn_name( + &self, + replacement_array: ArrayRef, + mut args: ScalarFunctionArgs, + return_field: FieldRef, + ) -> Result { + let array = $downcast_fn(&replacement_array)?; + let offsets = array.offsets().clone(); + let nulls = array.nulls().cloned(); + let sizes = array.sizes().clone(); + + let values = array.values(); + + args.args[self.argument_index] = ColumnarValue::Array(Arc::clone(values)); + + let results = self.function.invoke_with_args(args)?; + + let ColumnarValue::Array(result_array) = results else { + return Ok(results); + }; + + Ok(ColumnarValue::Array(Arc::new(<$return_type>::try_new( + return_field, + offsets, + sizes, + result_array, + nulls, + )?) as ArrayRef)) + } + }; +} + +impl ArrayTransform { + invoke_by_list_type!(invoke_list, as_list_array, GenericListArray); + invoke_by_list_type!( + invoke_large_list, + as_large_list_array, + GenericListArray + ); + invoke_by_list_view_type!( + invoke_list_view, + as_list_view_array, + GenericListViewArray + ); + invoke_by_list_view_type!( + invoke_large_list_view, + as_large_list_view_array, + GenericListViewArray + ); +} + impl ScalarUDFImpl for ArrayTransform { fn as_any(&self) -> &dyn Any { self @@ -235,10 +319,7 @@ impl ScalarUDFImpl for ArrayTransform { Ok(return_types) } - fn invoke_with_args( - &self, - mut args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { let return_field = match args.return_field.data_type() { DataType::List(field) | DataType::LargeList(field) @@ -263,61 +344,25 @@ impl ScalarUDFImpl for ArrayTransform { let ColumnarValue::Array(replacement_array) = replacement_array else { return exec_err!("Unexpected scalar value in array_transform"); }; + let replacement_array = Arc::clone(replacement_array); let result = match &replacement_array.data_type() { - DataType::Null => { - Ok(Arc::new(NullArray::new(replacement_array.len())) as ArrayRef) - } - DataType::List(_) => { - let array = as_list_array(&replacement_array)?; - let offsets = array.offsets().clone(); - let nulls = array.nulls().cloned(); - - let values = array.values(); - - args.args[self.argument_index] = ColumnarValue::Array(Arc::clone(values)); - - let results = self.function.invoke_with_args(args)?; - - let ColumnarValue::Array(result_array) = results else { - return Ok(results); - }; - - Ok(Arc::new(GenericListArray::try_new( - return_field, - offsets, - result_array, - nulls, - )?) as ArrayRef) + DataType::List(_) => self.invoke_list(replacement_array, args, return_field), + DataType::ListView(_) => { + self.invoke_list_view(replacement_array, args, return_field) } DataType::LargeList(_) => { - let array = as_large_list_array(&replacement_array)?; - let offsets = array.offsets().clone(); - let nulls = array.nulls().cloned(); - - let values = array.values(); - - args.args[self.argument_index] = ColumnarValue::Array(Arc::clone(values)); - - let results = self.function.invoke_with_args(args)?; - - let ColumnarValue::Array(result_array) = results else { - return Ok(results); - }; - - Ok(Arc::new(GenericListArray::try_new( - return_field, - offsets, - result_array, - nulls, - )?) as ArrayRef) + self.invoke_large_list(replacement_array, args, return_field) + } + DataType::LargeListView(_) => { + self.invoke_large_list_view(replacement_array, args, return_field) } arg_type => { exec_err!("array_transform does not support type {arg_type}") } }?; - Ok(ColumnarValue::Array(result)) + Ok(result) } fn aliases(&self) -> &[String] { @@ -332,7 +377,10 @@ impl ScalarUDFImpl for ArrayTransform { #[cfg(test)] mod tests { use super::array_transform_udf; - use arrow::array::{create_array, ArrayRef, GenericListArray}; + use arrow::array::{ + create_array, Array, ArrayRef, GenericListArray, GenericListViewArray, Int32Array, + }; + use arrow::buffer::{NullBuffer, OffsetBuffer, ScalarBuffer}; use arrow::datatypes::{DataType, Field}; use datafusion_common::utils::SingleRowListArrayBuilder; use datafusion_common::{exec_err, DataFusionError, ScalarValue}; @@ -340,50 +388,164 @@ mod tests { use datafusion_functions::math::{abs, round}; use std::sync::Arc; - #[test] - fn test_array_transform_apply_single_valued_function() -> Result<(), DataFusionError> - { - let udf = array_transform_udf(abs(), 0); + macro_rules! test_array_transform_generic_list_test { + ($test_name:ident, $array_type:ty, $data_type:ident, $offset_type:ty) => { + #[test] + fn $test_name() -> Result<(), DataFusionError> { + let udf = array_transform_udf(abs(), 0); + + let field = Arc::new(Field::new_list_field(DataType::Int32, true)); + let offsets: OffsetBuffer<$offset_type> = + OffsetBuffer::from_lengths(vec![3, 3, 2]); + let values = Int32Array::from(vec![ + Some(0), + Some(-1), + Some(-2), + None, + Some(4), + Some(-5), + Some(-6), + Some(7), + ]); + let nulls = NullBuffer::from(vec![true, true, false]); + let data = + <$array_type>::new(field, offsets, Arc::new(values), Some(nulls)); + + let data = Arc::new(data) as ArrayRef; + let input_field = Arc::new(Field::new( + "a", + DataType::$data_type(Field::new("b", DataType::Int32, true).into()), + true, + )); + let return_field = udf.return_field_from_args(ReturnFieldArgs { + arg_fields: &[Arc::clone(&input_field)], + scalar_arguments: &[None], + })?; + + let args = ScalarFunctionArgs { + args: vec![ColumnarValue::Array(data)], + arg_fields: vec![input_field], + number_rows: 3, + return_field, + config_options: Arc::new(Default::default()), + }; - let data = SingleRowListArrayBuilder::new(create_array!( - Int32, - [Some(1), Some(-2), None] - )) - .build_list_array(); - let data = Arc::new(data) as ArrayRef; - let input_field = Arc::new(Field::new( - "a", - DataType::List(Field::new("b", DataType::Int32, true).into()), - true, - )); - let return_field = udf.return_field_from_args(ReturnFieldArgs { - arg_fields: &[Arc::clone(&input_field)], - scalar_arguments: &[None], - })?; + let ColumnarValue::Array(result) = udf.invoke_with_args(args)? else { + return exec_err!("Invalid return type"); + }; + let list_array = result.as_any().downcast_ref::<$array_type>().unwrap(); - let args = ScalarFunctionArgs { - args: vec![ColumnarValue::Array(data)], - arg_fields: vec![input_field], - number_rows: 3, - return_field, - config_options: Arc::new(Default::default()), - }; + let expected = + create_array!(Int32, [Some(0), Some(1), Some(2)]) as ArrayRef; + assert_eq!(&list_array.value(0), &expected); - let ColumnarValue::Array(result) = udf.invoke_with_args(args)? else { - return exec_err!("Invalid return type"); + // assert!(list_array.is_null(1)); + let expected = create_array!(Int32, [None, Some(4), Some(5)]) as ArrayRef; + assert_eq!(&list_array.value(1), &expected); + + assert!(list_array.is_null(2)); + + Ok(()) + } }; - let list_array = result - .as_any() - .downcast_ref::>() - .unwrap(); + } - let expected = create_array!(Int32, [Some(1), Some(2), None]) as ArrayRef; + macro_rules! test_array_transform_generic_view_test { + ($test_name:ident, $array_type:ty, $data_type:ident, $offset_type:ty) => { + #[test] + fn $test_name() -> Result<(), DataFusionError> { + let udf = array_transform_udf(abs(), 0); + + let field = Arc::new(Field::new_list_field(DataType::Int32, true)); + let sizes: ScalarBuffer<$offset_type> = ScalarBuffer::from(vec![3, 3, 2]); + let offsets: ScalarBuffer<$offset_type> = + ScalarBuffer::from(vec![0, 3, 6]); + let values = Int32Array::from(vec![ + Some(0), + Some(-1), + Some(-2), + None, + Some(4), + Some(-5), + Some(-6), + Some(7), + ]); + let nulls = NullBuffer::from(vec![true, true, false]); + let data = <$array_type>::new( + field, + offsets, + sizes, + Arc::new(values), + Some(nulls), + ); + + let data = Arc::new(data) as ArrayRef; + let input_field = Arc::new(Field::new( + "a", + DataType::$data_type(Field::new("b", DataType::Int32, true).into()), + true, + )); + let return_field = udf.return_field_from_args(ReturnFieldArgs { + arg_fields: &[Arc::clone(&input_field)], + scalar_arguments: &[None], + })?; + + let args = ScalarFunctionArgs { + args: vec![ColumnarValue::Array(data)], + arg_fields: vec![input_field], + number_rows: 3, + return_field, + config_options: Arc::new(Default::default()), + }; - assert_eq!(&list_array.value(0), &expected); + let ColumnarValue::Array(result) = udf.invoke_with_args(args)? else { + return exec_err!("Invalid return type"); + }; + let list_array = result.as_any().downcast_ref::<$array_type>().unwrap(); - Ok(()) + let expected = + create_array!(Int32, [Some(0), Some(1), Some(2)]) as ArrayRef; + assert_eq!(&list_array.value(0), &expected); + + // assert!(list_array.is_null(1)); + let expected = create_array!(Int32, [None, Some(4), Some(5)]) as ArrayRef; + assert_eq!(&list_array.value(1), &expected); + + assert!(list_array.is_null(2)); + + Ok(()) + } + }; } + test_array_transform_generic_list_test!( + test_array_transform_list_array_test, + GenericListArray, + List, + i32 + ); + + test_array_transform_generic_list_test!( + test_array_transform_large_list_array_test, + GenericListArray, + LargeList, + i64 + ); + + test_array_transform_generic_view_test!( + test_array_transform_list_view_array_test, + GenericListViewArray, + ListView, + i32 + ); + + test_array_transform_generic_view_test!( + test_array_transform_large_list_view_array_test, + GenericListViewArray, + LargeListView, + i64 + ); + #[test] fn test_array_transform_test_argument_index() -> Result<(), DataFusionError> { let udf = array_transform_udf(round(), 1); From 492cf818be75878b4ff1fff621e3c4901697f85c Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sat, 30 Aug 2025 15:58:57 -0400 Subject: [PATCH 3/3] Support fixed size list array --- datafusion/functions-nested/src/transform.rs | 118 ++++++++++++++++++- 1 file changed, 113 insertions(+), 5 deletions(-) diff --git a/datafusion/functions-nested/src/transform.rs b/datafusion/functions-nested/src/transform.rs index 6bbbef991268d..fce66ea8ceecd 100644 --- a/datafusion/functions-nested/src/transform.rs +++ b/datafusion/functions-nested/src/transform.rs @@ -17,10 +17,13 @@ //! [`ScalarUDFImpl`] definition for array_transform function. -use arrow::array::{Array, ArrayRef, GenericListArray, GenericListViewArray}; +use arrow::array::{ + Array, ArrayRef, FixedSizeListArray, GenericListArray, GenericListViewArray, +}; use arrow::datatypes::{DataType, Field, FieldRef}; use datafusion_common::cast::{ - as_large_list_array, as_large_list_view_array, as_list_array, as_list_view_array, + as_fixed_size_list_array, as_large_list_array, as_large_list_view_array, + as_list_array, as_list_view_array, }; use datafusion_common::{exec_datafusion_err, exec_err, plan_err, Result}; use datafusion_expr::type_coercion::functions::data_types_with_scalar_udf; @@ -120,6 +123,37 @@ macro_rules! invoke_by_list_type { } }; } +macro_rules! invoke_by_fixed_size_type { + ($fn_name:ident, $downcast_fn:ident, $return_type:ty) => { + fn $fn_name( + &self, + replacement_array: ArrayRef, + mut args: ScalarFunctionArgs, + return_field: FieldRef, + ) -> Result { + let array = $downcast_fn(&replacement_array)?; + let size = array.value_length(); + let nulls = array.nulls().cloned(); + + let values = array.values(); + + args.args[self.argument_index] = ColumnarValue::Array(Arc::clone(values)); + + let results = self.function.invoke_with_args(args)?; + + let ColumnarValue::Array(result_array) = results else { + return Ok(results); + }; + + Ok(ColumnarValue::Array(Arc::new(<$return_type>::try_new( + return_field, + size, + result_array, + nulls, + )?) as ArrayRef)) + } + }; +} macro_rules! invoke_by_list_view_type { ($fn_name:ident, $downcast_fn:ident, $return_type:ty) => { fn $fn_name( @@ -171,6 +205,11 @@ impl ArrayTransform { as_large_list_view_array, GenericListViewArray ); + invoke_by_fixed_size_type!( + invoke_fixed_size_list, + as_fixed_size_list_array, + FixedSizeListArray + ); } impl ScalarUDFImpl for ArrayTransform { @@ -228,7 +267,8 @@ impl ScalarUDFImpl for ArrayTransform { DataType::List(field) | DataType::LargeList(field) | DataType::ListView(field) - | DataType::LargeListView(field) => Ok(Arc::clone(field)), + | DataType::LargeListView(field) + | DataType::FixedSizeList(field, _) => Ok(Arc::clone(field)), arg_type => plan_err!("{} does not support type {arg_type}", self.name()), }?; @@ -264,6 +304,11 @@ impl ScalarUDFImpl for ArrayTransform { DataType::LargeListView(inner_return), true, ))), + DataType::FixedSizeList(_, size) => Ok(Arc::new(Field::new( + name, + DataType::FixedSizeList(inner_return, *size), + true, + ))), _ => unreachable!(), } } @@ -324,7 +369,8 @@ impl ScalarUDFImpl for ArrayTransform { DataType::List(field) | DataType::LargeList(field) | DataType::ListView(field) - | DataType::LargeListView(field) => Arc::clone(field), + | DataType::LargeListView(field) + | DataType::FixedSizeList(field, _) => Arc::clone(field), _ => { return exec_err!( "Unexpected return field for array_transform. Expected list data type." @@ -357,6 +403,9 @@ impl ScalarUDFImpl for ArrayTransform { DataType::LargeListView(_) => { self.invoke_large_list_view(replacement_array, args, return_field) } + DataType::FixedSizeList(_, _) => { + self.invoke_fixed_size_list(replacement_array, args, return_field) + } arg_type => { exec_err!("array_transform does not support type {arg_type}") } @@ -378,7 +427,8 @@ impl ScalarUDFImpl for ArrayTransform { mod tests { use super::array_transform_udf; use arrow::array::{ - create_array, Array, ArrayRef, GenericListArray, GenericListViewArray, Int32Array, + create_array, Array, ArrayRef, FixedSizeListArray, GenericListArray, + GenericListViewArray, Int32Array, }; use arrow::buffer::{NullBuffer, OffsetBuffer, ScalarBuffer}; use arrow::datatypes::{DataType, Field}; @@ -546,6 +596,64 @@ mod tests { i64 ); + #[test] + fn test_array_transform_fixed_size_list_array_test() -> Result<(), DataFusionError> { + let udf = array_transform_udf(abs(), 0); + + let field = Arc::new(Field::new_list_field(DataType::Int32, true)); + let values = Int32Array::from(vec![ + Some(0), + Some(-1), + Some(-2), + None, + Some(4), + Some(-5), + Some(-6), + Some(7), + None, + ]); + let nulls = NullBuffer::from(vec![true, true, false]); + let data = FixedSizeListArray::new(field, 3, Arc::new(values), Some(nulls)); + + let data = Arc::new(data) as ArrayRef; + let input_field = Arc::new(Field::new( + "a", + DataType::FixedSizeList(Field::new("b", DataType::Int32, true).into(), 3), + true, + )); + let return_field = udf.return_field_from_args(ReturnFieldArgs { + arg_fields: &[Arc::clone(&input_field)], + scalar_arguments: &[None], + })?; + + let args = ScalarFunctionArgs { + args: vec![ColumnarValue::Array(data)], + arg_fields: vec![input_field], + number_rows: 3, + return_field, + config_options: Arc::new(Default::default()), + }; + + let ColumnarValue::Array(result) = udf.invoke_with_args(args)? else { + return exec_err!("Invalid return type"); + }; + let list_array = result + .as_any() + .downcast_ref::() + .unwrap(); + + let expected = create_array!(Int32, [Some(0), Some(1), Some(2)]) as ArrayRef; + assert_eq!(&list_array.value(0), &expected); + + // assert!(list_array.is_null(1)); + let expected = create_array!(Int32, [None, Some(4), Some(5)]) as ArrayRef; + assert_eq!(&list_array.value(1), &expected); + + assert!(list_array.is_null(2)); + + Ok(()) + } + #[test] fn test_array_transform_test_argument_index() -> Result<(), DataFusionError> { let udf = array_transform_udf(round(), 1);