From dbf2aa5ad9b61bec17c6f6010359383f8707b5ba Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Sat, 22 Nov 2025 15:29:54 -0300 Subject: [PATCH 01/12] add lambda support --- Cargo.lock | 1 + .../examples/custom_file_casts.rs | 11 +- .../examples/default_column_values.rs | 14 +- datafusion-examples/examples/expr_api.rs | 8 +- .../examples/json_shredding.rs | 14 +- datafusion/catalog-listing/src/helpers.rs | 9 +- datafusion/common/src/column.rs | 6 + datafusion/common/src/cse.rs | 22 +- datafusion/common/src/dfschema.rs | 16 +- datafusion/common/src/lib.rs | 2 + datafusion/common/src/utils/mod.rs | 127 +++- .../core/src/execution/session_state.rs | 5 +- datafusion/core/tests/parquet/mod.rs | 2 +- .../core/tests/parquet/schema_adapter.rs | 8 +- .../datasource-parquet/src/row_filter.rs | 16 +- datafusion/expr/src/expr.rs | 69 +- datafusion/expr/src/expr_rewriter/mod.rs | 49 +- datafusion/expr/src/expr_rewriter/order_by.rs | 4 + datafusion/expr/src/expr_schema.rs | 60 +- datafusion/expr/src/lib.rs | 7 +- datafusion/expr/src/tree_node.rs | 702 +++++++++++++++++- datafusion/expr/src/udf.rs | 564 +++++++++++++- datafusion/expr/src/utils.rs | 41 +- datafusion/ffi/src/udf/mod.rs | 8 +- datafusion/ffi/src/udf/return_type_args.rs | 9 +- .../functions-nested/src/array_transform.rs | 266 +++++++ .../src/analyzer/function_rewrite.rs | 21 +- .../optimizer/src/analyzer/type_coercion.rs | 98 +-- .../optimizer/src/common_subexpr_eliminate.rs | 23 +- datafusion/optimizer/src/decorrelate.rs | 20 +- .../optimizer/src/optimize_projections/mod.rs | 4 +- datafusion/optimizer/src/push_down_filter.rs | 37 +- .../optimizer/src/scalar_subquery_to_join.rs | 67 +- .../simplify_expressions/expr_simplifier.rs | 105 ++- datafusion/optimizer/src/utils.rs | 4 +- .../src/schema_rewriter.rs | 29 +- datafusion/physical-expr/Cargo.toml | 4 + .../src/async_scalar_function.rs | 2 + .../physical-expr/src/expressions/column.rs | 21 +- 
.../physical-expr/src/expressions/lambda.rs | 139 ++++ .../physical-expr/src/expressions/mod.rs | 2 + datafusion/physical-expr/src/lib.rs | 2 + datafusion/physical-expr/src/physical_expr.rs | 10 +- datafusion/physical-expr/src/planner.rs | 29 +- datafusion/physical-expr/src/projection.rs | 53 +- .../physical-expr/src/scalar_function.rs | 701 ++++++++++++++++- .../physical-expr/src/simplifier/mod.rs | 20 +- .../src/simplifier/unwrap_cast.rs | 12 +- datafusion/physical-expr/src/utils/mod.rs | 21 +- .../src/enforce_sorting/sort_pushdown.rs | 60 +- .../src/projection_pushdown.rs | 55 +- datafusion/physical-plan/src/async_func.rs | 6 +- .../src/joins/stream_join_utils.rs | 29 +- datafusion/physical-plan/src/projection.rs | 61 +- datafusion/proto/src/logical_plan/to_proto.rs | 5 + datafusion/pruning/src/pruning_predicate.rs | 11 +- datafusion/sql/src/expr/function.rs | 28 +- datafusion/sql/src/expr/identifier.rs | 13 + datafusion/sql/src/planner.rs | 30 +- datafusion/sql/src/select.rs | 6 +- datafusion/sql/src/unparser/expr.rs | 17 +- datafusion/sql/src/unparser/plan.rs | 6 +- datafusion/sql/src/unparser/rewrite.rs | 14 +- datafusion/sql/src/unparser/utils.rs | 44 +- datafusion/sql/src/utils.rs | 32 +- datafusion/sqllogictest/test_files/array.slt | 8 +- datafusion/sqllogictest/test_files/lambda.slt | 166 +++++ .../src/logical_plan/producer/expr/mod.rs | 1 + 68 files changed, 3573 insertions(+), 483 deletions(-) create mode 100644 datafusion/functions-nested/src/array_transform.rs create mode 100644 datafusion/physical-expr/src/expressions/lambda.rs create mode 100644 datafusion/sqllogictest/test_files/lambda.slt diff --git a/Cargo.lock b/Cargo.lock index f500265108ff5..4a315ff38f2aa 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2489,6 +2489,7 @@ dependencies = [ "paste", "petgraph 0.8.3", "rand 0.9.2", + "recursive", "rstest", ] diff --git a/datafusion-examples/examples/custom_file_casts.rs b/datafusion-examples/examples/custom_file_casts.rs index 
4d97ecd91dc64..d8db97d1e0440 100644 --- a/datafusion-examples/examples/custom_file_casts.rs +++ b/datafusion-examples/examples/custom_file_casts.rs @@ -22,7 +22,7 @@ use arrow::datatypes::{DataType, Field, FieldRef, Schema, SchemaRef}; use datafusion::assert_batches_eq; use datafusion::common::not_impl_err; -use datafusion::common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion::common::tree_node::{Transformed, TransformedResult}; use datafusion::common::{Result, ScalarValue}; use datafusion::datasource::listing::{ ListingTable, ListingTableConfig, ListingTableConfigExt, ListingTableUrl, @@ -31,7 +31,7 @@ use datafusion::execution::context::SessionContext; use datafusion::execution::object_store::ObjectStoreUrl; use datafusion::parquet::arrow::ArrowWriter; use datafusion::physical_expr::expressions::CastExpr; -use datafusion::physical_expr::PhysicalExpr; +use datafusion::physical_expr::{PhysicalExpr, PhysicalExprExt}; use datafusion::prelude::SessionConfig; use datafusion_physical_expr_adapter::{ DefaultPhysicalExprAdapterFactory, PhysicalExprAdapter, PhysicalExprAdapterFactory, @@ -181,11 +181,10 @@ impl PhysicalExprAdapter for CustomCastsPhysicalExprAdapter { expr = self.inner.rewrite(expr)?; // Now we can apply custom casting rules or even swap out all CastExprs for a custom cast kernel / expression // For example, [DataFusion Comet](https://github.com/apache/datafusion-comet) has a [custom cast kernel](https://github.com/apache/datafusion-comet/blob/b4ac876ab420ed403ac7fc8e1b29f42f1f442566/native/spark-expr/src/conversion_funcs/cast.rs#L133-L138). 
- expr.transform(|expr| { + expr.transform_with_schema(&self.physical_file_schema, |expr, schema| { if let Some(cast) = expr.as_any().downcast_ref::() { - let input_data_type = - cast.expr().data_type(&self.physical_file_schema)?; - let output_data_type = cast.data_type(&self.physical_file_schema)?; + let input_data_type = cast.expr().data_type(schema)?; + let output_data_type = cast.data_type(schema)?; if !cast.is_bigger_cast(&input_data_type) { return not_impl_err!( "Unsupported CAST from {input_data_type} to {output_data_type}" diff --git a/datafusion-examples/examples/default_column_values.rs b/datafusion-examples/examples/default_column_values.rs index d3a7d2ec67f3c..0d00d2c3af827 100644 --- a/datafusion-examples/examples/default_column_values.rs +++ b/datafusion-examples/examples/default_column_values.rs @@ -26,8 +26,8 @@ use async_trait::async_trait; use datafusion::assert_batches_eq; use datafusion::catalog::memory::DataSourceExec; use datafusion::catalog::{Session, TableProvider}; -use datafusion::common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion::common::DFSchema; +use datafusion::common::tree_node::{Transformed, TransformedResult}; +use datafusion::common::{DFSchema, HashSet}; use datafusion::common::{Result, ScalarValue}; use datafusion::datasource::listing::PartitionedFile; use datafusion::datasource::physical_plan::{FileScanConfigBuilder, ParquetSource}; @@ -38,7 +38,7 @@ use datafusion::logical_expr::{Expr, TableProviderFilterPushDown, TableType}; use datafusion::parquet::arrow::ArrowWriter; use datafusion::parquet::file::properties::WriterProperties; use datafusion::physical_expr::expressions::{CastExpr, Column, Literal}; -use datafusion::physical_expr::PhysicalExpr; +use datafusion::physical_expr::{PhysicalExpr, PhysicalExprExt}; use datafusion::physical_plan::ExecutionPlan; use datafusion::prelude::{lit, SessionConfig}; use datafusion_physical_expr_adapter::{ @@ -308,11 +308,12 @@ impl PhysicalExprAdapter for 
DefaultValuePhysicalExprAdapter { fn rewrite(&self, expr: Arc) -> Result> { // First try our custom default value injection for missing columns let rewritten = expr - .transform(|expr| { + .transform_with_lambdas_params(|expr, lambdas_params| { self.inject_default_values( expr, &self.logical_file_schema, &self.physical_file_schema, + lambdas_params, ) }) .data()?; @@ -348,12 +349,15 @@ impl DefaultValuePhysicalExprAdapter { expr: Arc, logical_file_schema: &Schema, physical_file_schema: &Schema, + lambdas_params: &HashSet, ) -> Result>> { if let Some(column) = expr.as_any().downcast_ref::() { let column_name = column.name(); // Check if this column exists in the physical schema - if physical_file_schema.index_of(column_name).is_err() { + if !lambdas_params.contains(column_name) + && physical_file_schema.index_of(column_name).is_err() + { // Column is missing from physical schema, check if logical schema has a default if let Ok(logical_field) = logical_file_schema.field_with_name(column_name) diff --git a/datafusion-examples/examples/expr_api.rs b/datafusion-examples/examples/expr_api.rs index 56f960870e58a..29f074e2b400c 100644 --- a/datafusion-examples/examples/expr_api.rs +++ b/datafusion-examples/examples/expr_api.rs @@ -23,7 +23,7 @@ use arrow::record_batch::RecordBatch; use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit}; use datafusion::common::stats::Precision; -use datafusion::common::tree_node::{Transformed, TreeNode}; +use datafusion::common::tree_node::Transformed; use datafusion::common::{ColumnStatistics, DFSchema}; use datafusion::common::{ScalarValue, ToDFSchema}; use datafusion::error::Result; @@ -556,7 +556,7 @@ fn type_coercion_demo() -> Result<()> { // 3. Type coercion with `TypeCoercionRewriter`. let coerced_expr = expr .clone() - .rewrite(&mut TypeCoercionRewriter::new(&df_schema))? + .rewrite_with_schema(&df_schema, &mut TypeCoercionRewriter::new(&df_schema))? 
.data; let physical_expr = datafusion::physical_expr::create_physical_expr( &coerced_expr, @@ -567,7 +567,7 @@ fn type_coercion_demo() -> Result<()> { // 4. Apply explicit type coercion by manually rewriting the expression let coerced_expr = expr - .transform(|e| { + .transform_with_schema(&df_schema, |e, df_schema| { // Only type coerces binary expressions. let Expr::BinaryExpr(e) = e else { return Ok(Transformed::no(e)); @@ -575,7 +575,7 @@ fn type_coercion_demo() -> Result<()> { if let Expr::Column(ref col_expr) = *e.left { let field = df_schema.field_with_name(None, col_expr.name())?; let cast_to_type = field.data_type(); - let coerced_right = e.right.cast_to(cast_to_type, &df_schema)?; + let coerced_right = e.right.cast_to(cast_to_type, df_schema)?; Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr::new( e.left, e.op, diff --git a/datafusion-examples/examples/json_shredding.rs b/datafusion-examples/examples/json_shredding.rs index 5ef8b59b64200..e97f27b818d8d 100644 --- a/datafusion-examples/examples/json_shredding.rs +++ b/datafusion-examples/examples/json_shredding.rs @@ -22,10 +22,8 @@ use arrow::array::{RecordBatch, StringArray}; use arrow::datatypes::{DataType, Field, FieldRef, Schema, SchemaRef}; use datafusion::assert_batches_eq; -use datafusion::common::tree_node::{ - Transformed, TransformedResult, TreeNode, TreeNodeRecursion, -}; -use datafusion::common::{assert_contains, exec_datafusion_err, Result}; +use datafusion::common::tree_node::{Transformed, TransformedResult, TreeNodeRecursion}; +use datafusion::common::{assert_contains, exec_datafusion_err, HashSet, Result}; use datafusion::datasource::listing::{ ListingTable, ListingTableConfig, ListingTableConfigExt, ListingTableUrl, }; @@ -36,8 +34,8 @@ use datafusion::logical_expr::{ }; use datafusion::parquet::arrow::ArrowWriter; use datafusion::parquet::file::properties::WriterProperties; -use datafusion::physical_expr::PhysicalExpr; use datafusion::physical_expr::{expressions, ScalarFunctionExpr}; 
+use datafusion::physical_expr::{PhysicalExpr, PhysicalExprExt}; use datafusion::prelude::SessionConfig; use datafusion::scalar::ScalarValue; use datafusion_physical_expr_adapter::{ @@ -302,7 +300,9 @@ impl PhysicalExprAdapter for ShreddedJsonRewriter { fn rewrite(&self, expr: Arc) -> Result> { // First try our custom JSON shredding rewrite let rewritten = expr - .transform(|expr| self.rewrite_impl(expr, &self.physical_file_schema)) + .transform_with_lambdas_params(|expr, lambdas_params| { + self.rewrite_impl(expr, &self.physical_file_schema, lambdas_params) + }) .data()?; // Then apply the default adapter as a fallback to handle standard schema differences @@ -335,6 +335,7 @@ impl ShreddedJsonRewriter { &self, expr: Arc, physical_file_schema: &Schema, + lambdas_params: &HashSet, ) -> Result>> { if let Some(func) = expr.as_any().downcast_ref::() { if func.name() == "json_get_str" && func.args().len() == 2 { @@ -348,6 +349,7 @@ impl ShreddedJsonRewriter { if let Some(column) = func.args()[1] .as_any() .downcast_ref::() + .filter(|col| !lambdas_params.contains(col.name())) { let column_name = column.name(); // Check if there's a flat column with underscore prefix diff --git a/datafusion/catalog-listing/src/helpers.rs b/datafusion/catalog-listing/src/helpers.rs index 82cc36867939e..444f505f4280b 100644 --- a/datafusion/catalog-listing/src/helpers.rs +++ b/datafusion/catalog-listing/src/helpers.rs @@ -52,9 +52,9 @@ use object_store::{ObjectMeta, ObjectStore}; /// was performed pub fn expr_applicable_for_cols(col_names: &[&str], expr: &Expr) -> bool { let mut is_applicable = true; - expr.apply(|expr| match expr { - Expr::Column(Column { ref name, .. 
}) => { - is_applicable &= col_names.contains(&name.as_str()); + expr.apply_with_lambdas_params(|expr, lambdas_params| match expr { + Expr::Column(col) => { + is_applicable &= col_names.contains(&col.name()) || col.is_lambda_parameter(lambdas_params); if is_applicable { Ok(TreeNodeRecursion::Jump) } else { @@ -86,7 +86,8 @@ pub fn expr_applicable_for_cols(col_names: &[&str], expr: &Expr) -> bool { | Expr::InSubquery(_) | Expr::ScalarSubquery(_) | Expr::GroupingSet(_) - | Expr::Case(_) => Ok(TreeNodeRecursion::Continue), + | Expr::Case(_) + | Expr::Lambda(_) => Ok(TreeNodeRecursion::Continue), Expr::ScalarFunction(scalar_function) => { match scalar_function.func.signature().volatility { diff --git a/datafusion/common/src/column.rs b/datafusion/common/src/column.rs index c7f0b5a4f4881..dd9b985e6485c 100644 --- a/datafusion/common/src/column.rs +++ b/datafusion/common/src/column.rs @@ -22,6 +22,7 @@ use crate::utils::parse_identifiers_normalized; use crate::utils::quote_identifier; use crate::{DFSchema, Diagnostic, Result, SchemaError, Spans, TableReference}; use arrow::datatypes::{Field, FieldRef}; +use std::borrow::Borrow; use std::collections::HashSet; use std::fmt; @@ -325,6 +326,11 @@ impl Column { ..self.clone() } } + + pub fn is_lambda_parameter(&self, lambdas_params: &crate::HashSet + Eq + std::hash::Hash>) -> bool { + // currently, references to lambda parameters are always unqualified + self.relation.is_none() && lambdas_params.contains(self.name()) + } } impl From<&str> for Column { diff --git a/datafusion/common/src/cse.rs b/datafusion/common/src/cse.rs index 674d3386171f8..a7ffde52c93b2 100644 --- a/datafusion/common/src/cse.rs +++ b/datafusion/common/src/cse.rs @@ -178,6 +178,14 @@ pub trait CSEController { /// if all are always evaluated. fn conditional_children(node: &Self::Node) -> Option>; + // A helper method called on each node before is_ignored, during top-down traversal during the first, + // visiting traversal of CSE. 
+ fn visit_f_down(&mut self, _node: &Self::Node) {} + + // A helper method called on each node after is_ignored, during bottom-up traversal during the first, + // visiting traversal of CSE. + fn visit_f_up(&mut self, _node: &Self::Node) {} + // Returns true if a node is valid. If a node is invalid then it can't be eliminated. // Validity is propagated up which means no subtree can be eliminated that contains // an invalid node. @@ -274,7 +282,7 @@ where /// thus can not be extracted as a common [`TreeNode`]. conditional: bool, - controller: &'a C, + controller: &'a mut C, } /// Record item that used when traversing a [`TreeNode`] tree. @@ -352,6 +360,7 @@ where self.visit_stack .push(VisitRecord::EnterMark(self.down_index)); self.down_index += 1; + self.controller.visit_f_down(node); // If a node can short-circuit then some of its children might not be executed so // count the occurrence either normal or conditional. @@ -414,6 +423,7 @@ where self.visit_stack .push(VisitRecord::NodeItem(node_id, is_valid)); self.up_index += 1; + self.controller.visit_f_up(node); Ok(TreeNodeRecursion::Continue) } @@ -532,7 +542,7 @@ where /// Add an identifier to `id_array` for every [`TreeNode`] in this tree. fn node_to_id_array<'n>( - &self, + &mut self, node: &'n N, node_stats: &mut NodeStats<'n, N>, id_array: &mut IdArray<'n, N>, @@ -546,7 +556,7 @@ where random_state: &self.random_state, found_common: false, conditional: false, - controller: &self.controller, + controller: &mut self.controller, }; node.visit(&mut visitor)?; @@ -561,7 +571,7 @@ where /// Each element is itself the result of [`CSE::node_to_id_array`] for that node /// (e.g. 
the identifiers for each node in the tree) fn to_arrays<'n>( - &self, + &mut self, nodes: &'n [N], node_stats: &mut NodeStats<'n, N>, ) -> Result<(bool, Vec>)> { @@ -761,7 +771,7 @@ mod test { #[test] fn id_array_visitor() -> Result<()> { let alias_generator = AliasGenerator::new(); - let eliminator = CSE::new(TestTreeNodeCSEController::new( + let mut eliminator = CSE::new(TestTreeNodeCSEController::new( &alias_generator, TestTreeNodeMask::Normal, )); @@ -853,7 +863,7 @@ mod test { assert_eq!(expected, id_array); // include aggregates - let eliminator = CSE::new(TestTreeNodeCSEController::new( + let mut eliminator = CSE::new(TestTreeNodeCSEController::new( &alias_generator, TestTreeNodeMask::NormalAndAggregates, )); diff --git a/datafusion/common/src/dfschema.rs b/datafusion/common/src/dfschema.rs index 24d152a7dba8c..8a09d61292b27 100644 --- a/datafusion/common/src/dfschema.rs +++ b/datafusion/common/src/dfschema.rs @@ -314,8 +314,10 @@ impl DFSchema { return; } - let self_fields: HashSet<(Option<&TableReference>, &FieldRef)> = - self.iter().collect(); + let self_fields: HashSet<(Option<&TableReference>, &str)> = self + .iter() + .map(|(qualifier, field)| (qualifier, field.name().as_str())) + .collect(); let self_unqualified_names: HashSet<&str> = self .inner .fields @@ -328,7 +330,10 @@ impl DFSchema { for (qualifier, field) in other_schema.iter() { // skip duplicate columns let duplicated_field = match qualifier { - Some(q) => self_fields.contains(&(Some(q), field)), + Some(q) => { + self_fields.contains(&(Some(q), field.name().as_str())) + || self_fields.contains(&(None, field.name().as_str())) + } // for unqualified columns, check as unqualified name None => self_unqualified_names.contains(field.name().as_str()), }; @@ -867,6 +872,11 @@ impl DFSchema { &self.functional_dependencies } + /// Get functional dependencies + pub fn field_qualifiers(&self) -> &[Option] { + &self.field_qualifiers + } + /// Iterate over the qualifiers and fields in the DFSchema pub fn 
iter(&self) -> impl Iterator, &FieldRef)> { self.field_qualifiers diff --git a/datafusion/common/src/lib.rs b/datafusion/common/src/lib.rs index 76c7b46e32737..8923df683f899 100644 --- a/datafusion/common/src/lib.rs +++ b/datafusion/common/src/lib.rs @@ -117,6 +117,8 @@ pub mod hash_set { pub use hashbrown::hash_set::Entry; } +pub use hashbrown; + /// Downcast an Arrow Array to a concrete type, return an `DataFusionError::Internal` if the cast is /// not possible. In normal usage of DataFusion the downcast should always succeed. /// diff --git a/datafusion/common/src/utils/mod.rs b/datafusion/common/src/utils/mod.rs index 7b145ac3ae21d..ec2dad505a561 100644 --- a/datafusion/common/src/utils/mod.rs +++ b/datafusion/common/src/utils/mod.rs @@ -22,15 +22,20 @@ pub mod memory; pub mod proxy; pub mod string_utils; -use crate::error::{_exec_datafusion_err, _internal_datafusion_err, _internal_err}; +use crate::error::{ + _exec_datafusion_err, _exec_err, _internal_datafusion_err, _internal_err, +}; use crate::{Result, ScalarValue}; use arrow::array::{ cast::AsArray, Array, ArrayRef, FixedSizeListArray, LargeListArray, ListArray, OffsetSizeTrait, }; +use arrow::array::{ArrowPrimitiveType, PrimitiveArray}; use arrow::buffer::OffsetBuffer; use arrow::compute::{partition, SortColumn, SortOptions}; -use arrow::datatypes::{DataType, Field, SchemaRef}; +use arrow::datatypes::{ + ArrowNativeType, DataType, Field, Int32Type, Int64Type, SchemaRef, +}; #[cfg(feature = "sql")] use sqlparser::{ast::Ident, dialect::GenericDialect, parser::Parser}; use std::borrow::{Borrow, Cow}; @@ -939,6 +944,124 @@ pub fn take_function_args( }) } +/// [0, 2, 2, 5, 6] -> [0, 0, 2, 2, 2, 3] +pub fn make_list_array_indices( + offsets: &OffsetBuffer, +) -> PrimitiveArray { + let mut indices = Vec::with_capacity( + offsets.last().unwrap().as_usize() - offsets.first().unwrap().as_usize(), + ); + + for (i, (&start, &end)) in std::iter::zip(&offsets[..], &offsets[1..]).enumerate() { + 
indices.extend(std::iter::repeat_n( + T::Native::usize_as(i), + end.as_usize() - start.as_usize(), + )); + } + + PrimitiveArray::new(indices.into(), None) +} + +/// [0, 2, 2, 5, 6] -> [0, 1, 0, 1, 2, 0] +pub fn make_list_element_indices( + offsets: &OffsetBuffer, +) -> PrimitiveArray { + let mut indices = vec![ + T::default_value(); + offsets.last().unwrap().as_usize() + - offsets.first().unwrap().as_usize() + ]; + + for (&start, &end) in std::iter::zip(&offsets[..], &offsets[1..]) { + for i in 0..end.as_usize() - start.as_usize() { + indices[start.as_usize() + i] = T::Native::usize_as(i); + } + } + + PrimitiveArray::new(indices.into(), None) +} + +/// (3, 2) -> [0, 0, 1, 1, 2, 2] +pub fn make_fsl_array_indices( + list_size: i32, + array_len: usize, +) -> PrimitiveArray { + let mut indices = vec![0; list_size as usize * array_len]; + + for i in 0..array_len { + for j in 0..list_size as usize { + indices[i + j] = i as i32; + } + } + + PrimitiveArray::new(indices.into(), None) +} + +/// (3, 2) -> [0, 1, 0, 1, 0, 1] +pub fn make_fsl_element_indices( + list_size: i32, + array_len: usize, +) -> PrimitiveArray { + let mut indices = vec![0; list_size as usize * array_len]; + + for i in 0..array_len { + for j in 0..list_size as usize { + indices[i + j] = j as i32; + } + } + + PrimitiveArray::new(indices.into(), None) +} + +pub fn list_values(array: &dyn Array) -> Result<&ArrayRef> { + match array.data_type() { + DataType::List(_) => Ok(array.as_list::().values()), + DataType::LargeList(_) => Ok(array.as_list::().values()), + DataType::FixedSizeList(_, _) => Ok(array.as_fixed_size_list().values()), + other => _exec_err!("expected list, got {other}"), + } +} + +pub fn list_indices(array: &dyn Array) -> Result { + match array.data_type() { + DataType::List(_) => Ok(Arc::new(make_list_array_indices::( + array.as_list().offsets(), + ))), + DataType::LargeList(_) => Ok(Arc::new(make_list_array_indices::( + array.as_list().offsets(), + ))), + DataType::FixedSizeList(_, _) => { + 
let fixed_size_list = array.as_fixed_size_list(); + + Ok(Arc::new(make_fsl_array_indices( + fixed_size_list.value_length(), + fixed_size_list.len(), + ))) + } + other => _exec_err!("expected list, got {other}"), + } +} + +pub fn elements_indices(array: &dyn Array) -> Result { + match array.data_type() { + DataType::List(_) => Ok(Arc::new(make_list_element_indices::( + array.as_list::().offsets(), + ))), + DataType::LargeList(_) => Ok(Arc::new(make_list_element_indices::( + array.as_list::().offsets(), + ))), + DataType::FixedSizeList(_, _) => { + let fixed_size_list = array.as_fixed_size_list(); + + Ok(Arc::new(make_fsl_element_indices( + fixed_size_list.value_length(), + fixed_size_list.len(), + ))) + } + other => _exec_err!("expected list, got {other}"), + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index c15b7eae08432..ad4ffb487ee1d 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -41,7 +41,6 @@ use datafusion_common::alias::AliasGenerator; use datafusion_common::config::Dialect; use datafusion_common::config::{ConfigExtension, ConfigOptions, TableOptions}; use datafusion_common::display::{PlanType, StringifiedPlan, ToStringifiedPlan}; -use datafusion_common::tree_node::TreeNode; use datafusion_common::{ config_err, exec_err, plan_datafusion_err, DFSchema, DataFusionError, ResolvedTableReference, TableReference, @@ -701,7 +700,9 @@ impl SessionState { let config_options = self.config_options(); for rewrite in self.analyzer.function_rewrites() { expr = expr - .transform_up(|expr| rewrite.rewrite(expr, df_schema, config_options))? + .transform_up_with_schema(df_schema, |expr, df_schema| { + rewrite.rewrite(expr, df_schema, config_options) + })? 
.data; } create_physical_expr(&expr, df_schema, self.execution_props()) diff --git a/datafusion/core/tests/parquet/mod.rs b/datafusion/core/tests/parquet/mod.rs index 097600e45eadd..eea6085c02b9f 100644 --- a/datafusion/core/tests/parquet/mod.rs +++ b/datafusion/core/tests/parquet/mod.rs @@ -516,7 +516,7 @@ fn make_uint_batches(start: u8, end: u8) -> RecordBatch { Field::new("u64", DataType::UInt64, true), ])); let v8: Vec = (start..end).collect(); - let v16: Vec = (start as _..end as _).collect(); + let v16: Vec = (start as u16..end as _).collect(); let v32: Vec = (start as _..end as _).collect(); let v64: Vec = (start as _..end as _).collect(); RecordBatch::try_new( diff --git a/datafusion/core/tests/parquet/schema_adapter.rs b/datafusion/core/tests/parquet/schema_adapter.rs index 40fc6176e212b..dfa4c91ba5dd8 100644 --- a/datafusion/core/tests/parquet/schema_adapter.rs +++ b/datafusion/core/tests/parquet/schema_adapter.rs @@ -27,7 +27,7 @@ use datafusion::datasource::listing::{ ListingTable, ListingTableConfig, ListingTableConfigExt, }; use datafusion::prelude::{SessionConfig, SessionContext}; -use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_common::tree_node::{Transformed, TransformedResult}; use datafusion_common::DataFusionError; use datafusion_common::{ColumnStatistics, ScalarValue}; use datafusion_datasource::file::FileSource; @@ -39,7 +39,7 @@ use datafusion_datasource::ListingTableUrl; use datafusion_datasource_parquet::source::ParquetSource; use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_physical_expr::expressions::{self, Column}; -use datafusion_physical_expr::PhysicalExpr; +use datafusion_physical_expr::{PhysicalExpr, PhysicalExprExt}; use datafusion_physical_expr_adapter::{ DefaultPhysicalExprAdapter, DefaultPhysicalExprAdapterFactory, PhysicalExprAdapter, PhysicalExprAdapterFactory, @@ -181,10 +181,10 @@ struct CustomPhysicalExprAdapter { impl PhysicalExprAdapter for 
CustomPhysicalExprAdapter { fn rewrite(&self, mut expr: Arc) -> Result> { expr = expr - .transform(|expr| { + .transform_with_lambdas_params(|expr, lambdas_params| { if let Some(column) = expr.as_any().downcast_ref::() { let field_name = column.name(); - if self + if !lambdas_params.contains(field_name) && self .physical_file_schema .field_with_name(field_name) .ok() diff --git a/datafusion/datasource-parquet/src/row_filter.rs b/datafusion/datasource-parquet/src/row_filter.rs index 660b32f486120..45441ad71086c 100644 --- a/datafusion/datasource-parquet/src/row_filter.rs +++ b/datafusion/datasource-parquet/src/row_filter.rs @@ -77,7 +77,7 @@ use datafusion_common::Result; use datafusion_datasource::schema_adapter::{SchemaAdapterFactory, SchemaMapper}; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::utils::reassign_expr_columns; -use datafusion_physical_expr::{split_conjunction, PhysicalExpr}; +use datafusion_physical_expr::{split_conjunction, PhysicalExpr, PhysicalExprExt}; use datafusion_physical_plan::metrics; @@ -336,6 +336,20 @@ impl<'schema> PushdownChecker<'schema> { fn prevents_pushdown(&self) -> bool { self.non_primitive_columns || self.projected_columns } + + fn check(&mut self, node: Arc) -> Result { + node.apply_with_lambdas_params(|node, lamdas_params| { + if let Some(column) = node.as_any().downcast_ref::() { + if !lamdas_params.contains(column.name()) { + if let Some(recursion) = self.check_single_column(column.name()) { + return Ok(recursion); + } + } + } + + Ok(TreeNodeRecursion::Continue) + }) + } } impl TreeNodeVisitor<'_> for PushdownChecker<'_> { diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 13160d573ab4d..e2845ea5a7de8 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -398,6 +398,10 @@ pub enum Expr { OuterReferenceColumn(FieldRef, Column), /// Unnest expression Unnest(Unnest), + /// Lambda expression, valid only as a scalar function argument + /// Note 
that it has it's own scoped schema, different from the plan schema, + /// that can be constructed with ScalarUDF::arguments_schemas and variants + Lambda(Lambda), } impl Default for Expr { @@ -1211,6 +1215,23 @@ impl GroupingSet { } } +/// Lambda expression. +#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] +pub struct Lambda { + pub params: Vec, + pub body: Box, +} + +impl Lambda { + /// Create a new lambda expression + pub fn new(params: Vec, body: Expr) -> Self { + Self { + params, + body: Box::new(body), + } + } +} + #[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] #[cfg(not(feature = "sql"))] pub struct IlikeSelectItem { @@ -1525,6 +1546,7 @@ impl Expr { #[expect(deprecated)] Expr::Wildcard { .. } => "Wildcard", Expr::Unnest { .. } => "Unnest", + Expr::Lambda { .. } => "Lambda", } } @@ -1908,9 +1930,11 @@ impl Expr { /// /// See [`Self::column_refs`] for details pub fn add_column_refs<'a>(&'a self, set: &mut HashSet<&'a Column>) { - self.apply(|expr| { + self.apply_with_lambdas_params(|expr, lambdas_params| { if let Expr::Column(col) = expr { - set.insert(col); + if col.relation.is_some() || !lambdas_params.contains(col.name()) { + set.insert(col); + } } Ok(TreeNodeRecursion::Continue) }) @@ -1943,9 +1967,11 @@ impl Expr { /// /// See [`Self::column_refs_counts`] for details pub fn add_column_ref_counts<'a>(&'a self, map: &mut HashMap<&'a Column, usize>) { - self.apply(|expr| { + self.apply_with_lambdas_params(|expr, lambdas_params| { if let Expr::Column(col) = expr { - *map.entry(col).or_default() += 1; + if !col.is_lambda_parameter(lambdas_params) { + *map.entry(col).or_default() += 1; + } } Ok(TreeNodeRecursion::Continue) }) @@ -1954,8 +1980,10 @@ impl Expr { /// Returns true if there are any column references in this Expr pub fn any_column_refs(&self) -> bool { - self.exists(|expr| Ok(matches!(expr, Expr::Column(_)))) - .expect("exists closure is infallible") + self.exists_with_lambdas_params(|expr, lambdas_params| { + Ok(matches!(expr, 
Expr::Column(c) if !c.is_lambda_parameter(lambdas_params))) + }) + .expect("exists closure is infallible") } /// Return true if the expression contains out reference(correlated) expressions. @@ -1995,7 +2023,7 @@ impl Expr { /// at least one placeholder. pub fn infer_placeholder_types(self, schema: &DFSchema) -> Result<(Expr, bool)> { let mut has_placeholder = false; - self.transform(|mut expr| { + self.transform_with_schema(schema, |mut expr, schema| { match &mut expr { // Default to assuming the arguments are the same type Expr::BinaryExpr(BinaryExpr { left, op: _, right }) => { @@ -2078,7 +2106,8 @@ impl Expr { | Expr::Wildcard { .. } | Expr::WindowFunction(..) | Expr::Literal(..) - | Expr::Placeholder(..) => false, + | Expr::Placeholder(..) + | Expr::Lambda { .. } => false, } } @@ -2674,6 +2703,12 @@ impl HashNode for Expr { column.hash(state); } Expr::Unnest(Unnest { expr: _expr }) => {} + Expr::Lambda(Lambda { + params, + body: _, + }) => { + params.hash(state); + } }; } } @@ -2987,6 +3022,12 @@ impl Display for SchemaDisplay<'_> { } } } + Expr::Lambda(Lambda { + params, + body, + }) => { + write!(f, "({}) -> {body}", display_comma_separated(params)) + } } } } @@ -3167,6 +3208,12 @@ impl Display for SqlDisplay<'_> { } } } + Expr::Lambda(Lambda { + params, + body, + }) => { + write!(f, "({}) -> {}", params.join(", "), SchemaDisplay(body)) + } _ => write!(f, "{}", self.0), } } @@ -3474,6 +3521,12 @@ impl Display for Expr { Expr::Unnest(Unnest { expr }) => { write!(f, "{UNNEST_COLUMN_PREFIX}({expr})") } + Expr::Lambda(Lambda { + params, + body, + }) => { + write!(f, "({}) -> {body}", params.join(", ")) + } } } } diff --git a/datafusion/expr/src/expr_rewriter/mod.rs b/datafusion/expr/src/expr_rewriter/mod.rs index 9c3c5df7007ff..81ec6e7acbe38 100644 --- a/datafusion/expr/src/expr_rewriter/mod.rs +++ b/datafusion/expr/src/expr_rewriter/mod.rs @@ -62,11 +62,15 @@ pub trait FunctionRewrite: Debug { /// Recursively call `LogicalPlanBuilder::normalize` on all 
[`Column`] expressions /// in the `expr` expression tree. pub fn normalize_col(expr: Expr, plan: &LogicalPlan) -> Result { - expr.transform(|expr| { + expr.transform_up_with_lambdas_params(|expr, lambdas_params| { Ok({ if let Expr::Column(c) = expr { - let col = LogicalPlanBuilder::normalize(plan, c)?; - Transformed::yes(Expr::Column(col)) + if c.relation.is_some() || !lambdas_params.contains(c.name()) { + let col = LogicalPlanBuilder::normalize(plan, c)?; + Transformed::yes(Expr::Column(col)) + } else { + Transformed::no(Expr::Column(c)) + } } else { Transformed::no(expr) } @@ -91,14 +95,21 @@ pub fn normalize_col_with_schemas_and_ambiguity_check( return Ok(Expr::Unnest(Unnest { expr: Box::new(e) })); } - expr.transform(|expr| { + expr.transform_up_with_lambdas_params(|expr, lambdas_params| { Ok({ - if let Expr::Column(c) = expr { - let col = - c.normalize_with_schemas_and_ambiguity_check(schemas, using_columns)?; - Transformed::yes(Expr::Column(col)) - } else { - Transformed::no(expr) + match expr { + Expr::Column(c) => { + if c.relation.is_none() && lambdas_params.contains(c.name()) { + Transformed::no(Expr::Column(c)) + } else { + let col = c.normalize_with_schemas_and_ambiguity_check( + schemas, + using_columns, + )?; + Transformed::yes(Expr::Column(col)) + } + } + _ => Transformed::no(expr), } }) }) @@ -133,15 +144,18 @@ pub fn normalize_sorts( /// Recursively replace all [`Column`] expressions in a given expression tree with /// `Column` expressions provided by the hash map argument. 
pub fn replace_col(expr: Expr, replace_map: &HashMap<&Column, &Column>) -> Result { - expr.transform(|expr| { + expr.transform_up_with_lambdas_params(|expr, lambdas_params| { Ok({ - if let Expr::Column(c) = &expr { - match replace_map.get(c) { - Some(new_c) => Transformed::yes(Expr::Column((*new_c).to_owned())), - None => Transformed::no(expr), + match &expr { + Expr::Column(c) if !c.is_lambda_parameter(lambdas_params) => { + match replace_map.get(c) { + Some(new_c) => { + Transformed::yes(Expr::Column((*new_c).to_owned())) + } + None => Transformed::no(expr), + } } - } else { - Transformed::no(expr) + _ => Transformed::no(expr), } }) }) @@ -201,6 +215,7 @@ pub fn strip_outer_reference(expr: Expr) -> Expr { expr.transform(|expr| { Ok({ if let Expr::OuterReferenceColumn(_, col) = expr { + //todo: what if this col collides with a lambda parameter? Transformed::yes(Expr::Column(col)) } else { Transformed::no(expr) diff --git a/datafusion/expr/src/expr_rewriter/order_by.rs b/datafusion/expr/src/expr_rewriter/order_by.rs index 6db95555502da..b94c632ce74b3 100644 --- a/datafusion/expr/src/expr_rewriter/order_by.rs +++ b/datafusion/expr/src/expr_rewriter/order_by.rs @@ -77,6 +77,10 @@ fn rewrite_in_terms_of_projection( // assumption is that each item in exprs, such as "b + c" is // available as an output column named "b + c" expr.transform(|expr| { + if matches!(expr, Expr::Lambda(_)) { + return Ok(Transformed::new(expr, false, TreeNodeRecursion::Jump)) + } + // search for unnormalized names first such as "c1" (such as aliases) if let Some(found) = proj_exprs.iter().find(|a| (**a) == expr) { let (qualifier, field_name) = found.qualified_name(); diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 9e8d6080b82c8..4a1efadccd0ec 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -16,19 +16,24 @@ // under the License. 
use super::{Between, Expr, Like}; +use crate::expr::FieldMetadata; use crate::expr::{ AggregateFunction, AggregateFunctionParams, Alias, BinaryExpr, Cast, InList, - InSubquery, Placeholder, ScalarFunction, TryCast, Unnest, WindowFunction, + InSubquery, Lambda, Placeholder, ScalarFunction, TryCast, Unnest, WindowFunction, WindowFunctionParams, }; use crate::type_coercion::functions::{ - data_types_with_scalar_udf, fields_with_aggregate_udf, fields_with_window_udf, + fields_with_aggregate_udf, fields_with_window_udf, +}; +use crate::{ + type_coercion::functions::data_types_with_scalar_udf, udf::ReturnFieldArgs, utils, + LogicalPlan, Projection, Subquery, WindowFunctionDefinition, +}; +use arrow::datatypes::FieldRef; +use arrow::{ + compute::can_cast_types, + datatypes::{DataType, Field}, }; -use crate::udf::ReturnFieldArgs; -use crate::{utils, LogicalPlan, Projection, Subquery, WindowFunctionDefinition}; -use arrow::compute::can_cast_types; -use arrow::datatypes::{DataType, Field, FieldRef}; -use datafusion_common::metadata::FieldMetadata; use datafusion_common::{ not_impl_err, plan_datafusion_err, plan_err, Column, DataFusionError, ExprSchema, Result, Spans, TableReference, @@ -229,6 +234,7 @@ impl ExprSchemable for Expr { // Grouping sets do not really have a type and do not appear in projections Ok(DataType::Null) } + Expr::Lambda { .. } => Ok(DataType::Null), } } @@ -347,6 +353,7 @@ impl ExprSchemable for Expr { // in projections Ok(true) } + Expr::Lambda { .. 
} => Ok(false), } } @@ -535,14 +542,31 @@ impl ExprSchemable for Expr { func.return_field(&new_fields) } + // Expr::Lambda(Lambda { params, body}) => body.to_field(schema), Expr::ScalarFunction(ScalarFunction { func, args }) => { - let (arg_types, fields): (Vec, Vec>) = args + let fields = if args.iter().any(|arg| matches!(arg, Expr::Lambda(_))) { + let lambdas_schemas = func.arguments_expr_schema(args, schema)?; + + std::iter::zip(args, lambdas_schemas) + // .map(|(e, schema)| e.to_field(schema).map(|(_, f)| f)) + .map(|(e, schema)| match e { + Expr::Lambda(Lambda { params: _, body }) => { + body.to_field(&schema).map(|(_, f)| f) + } + _ => e.to_field(&schema).map(|(_, f)| f), + }) + .collect::>>()? + } else { + args.iter() + .map(|e| e.to_field(schema).map(|(_, f)| f)) + .collect::>>()? + }; + + let arg_types = fields .iter() - .map(|e| e.to_field(schema).map(|(_, f)| f)) - .collect::>>()? - .into_iter() - .map(|f| (f.data_type().clone(), f)) - .unzip(); + .map(|f| f.data_type().clone()) + .collect::>(); + // Verify that function is invoked with correct number and type of arguments as defined in `TypeSignature` let new_data_types = data_types_with_scalar_udf(&arg_types, func) .map_err(|err| { @@ -573,9 +597,16 @@ impl ExprSchemable for Expr { _ => None, }) .collect::>(); + + let lambdas = args + .iter() + .map(|e| matches!(e, Expr::Lambda { .. })) + .collect::>(); + let args = ReturnFieldArgs { arg_fields: &new_fields, scalar_arguments: &arguments, + lambdas: &lambdas, }; func.return_field_from_args(args) @@ -600,7 +631,8 @@ impl ExprSchemable for Expr { | Expr::Wildcard { .. 
} | Expr::GroupingSet(_) | Expr::Placeholder(_) - | Expr::Unnest(_) => Ok(Arc::new(Field::new( + | Expr::Unnest(_) + | Expr::Lambda(_) => Ok(Arc::new(Field::new( &schema_name, self.get_type(schema)?, self.nullable(schema)?, diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 2b7cc9d46ad34..46c7422814ace 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -117,7 +117,12 @@ pub use udaf::{ udaf_default_window_function_schema_name, AggregateUDF, AggregateUDFImpl, ReversedUDAF, SetMonotonicity, StatisticsArgs, }; -pub use udf::{ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl}; +pub use udf::{ + merge_captures_with_args, merge_captures_with_boxed_lazy_args, + merge_captures_with_lazy_args, ReturnFieldArgs, ScalarFunctionArgs, + ScalarFunctionLambdaArg, ScalarUDF, ScalarUDFImpl, ValueOrLambda, ValueOrLambdaField, + ValueOrLambdaParameter, +}; pub use udwf::{LimitEffect, ReversedUDWF, WindowUDF, WindowUDFImpl}; pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits}; diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs index 81846b4f80608..63c535b43ee8b 100644 --- a/datafusion/expr/src/tree_node.rs +++ b/datafusion/expr/src/tree_node.rs @@ -17,17 +17,20 @@ //! 
Tree node implementation for Logical Expressions -use crate::expr::{ - AggregateFunction, AggregateFunctionParams, Alias, Between, BinaryExpr, Case, Cast, - GroupingSet, InList, InSubquery, Like, Placeholder, ScalarFunction, TryCast, Unnest, - WindowFunction, WindowFunctionParams, +use crate::{ + expr::{ + AggregateFunction, AggregateFunctionParams, Alias, Between, BinaryExpr, Case, + Cast, GroupingSet, InList, InSubquery, Lambda, Like, Placeholder, ScalarFunction, + TryCast, Unnest, WindowFunction, WindowFunctionParams, + }, + Expr, }; -use crate::Expr; - -use datafusion_common::tree_node::{ - Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion, TreeNodeRefContainer, +use datafusion_common::{ + tree_node::{ + Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion, TreeNodeRefContainer, + }, + DFSchema, HashSet, Result, }; -use datafusion_common::Result; /// Implementation of the [`TreeNode`] trait /// @@ -106,6 +109,7 @@ impl TreeNode for Expr { Expr::InList(InList { expr, list, .. }) => { (expr, list).apply_ref_elements(f) } + Expr::Lambda (Lambda{ params: _, body}) => body.apply_elements(f) } } @@ -311,6 +315,686 @@ impl TreeNode for Expr { .update_data(|(new_expr, new_list)| { Expr::InList(InList::new(new_expr, new_list, negated)) }), + Expr::Lambda(Lambda { params, body }) => body + .map_elements(f)? + .update_data(|body| Expr::Lambda(Lambda { params, body })), + }) + } +} + +impl Expr { + /// Similarly to [`Self::rewrite`], rewrites this expr and its inputs using `f`, + /// including lambdas that may appear in expressions such as `array_transform([1, 2], v -> v*2)`. + #[cfg_attr(feature = "recursive_protection", recursive::recursive)] + pub fn rewrite_with_schema< + R: for<'a> TreeNodeRewriterWithPayload = &'a DFSchema>, + >( + self, + schema: &DFSchema, + rewriter: &mut R, + ) -> Result> { + rewriter + .f_down(self, schema)? 
+ .transform_children(|n| match &n { + Expr::ScalarFunction(ScalarFunction { func, args }) + if args.iter().any(|arg| matches!(arg, Expr::Lambda(_))) => + { + let mut lambdas_schemas = func + .arguments_schema_from_logical_args(args, schema)? + .into_iter(); + + n.map_children(|n| { + n.rewrite_with_schema(&lambdas_schemas.next().unwrap(), rewriter) + }) + } + _ => n.map_children(|n| n.rewrite_with_schema(schema, rewriter)), + })? + .transform_parent(|n| rewriter.f_up(n, schema)) + } + + /// Similarly to [`Self::rewrite`], rewrites this expr and its inputs using `f`, + /// including lambdas that may appear in expressions such as `array_transform([1, 2], v -> v*2)`. + pub fn rewrite_with_lambdas_params< + R: for<'a> TreeNodeRewriterWithPayload< + Node = Expr, + Payload<'a> = &'a HashSet, + >, + >( + self, + rewriter: &mut R, + ) -> Result> { + self.rewrite_with_lambdas_params_impl(&HashSet::new(), rewriter) + } + + #[cfg_attr(feature = "recursive_protection", recursive::recursive)] + fn rewrite_with_lambdas_params_impl< + R: for<'a> TreeNodeRewriterWithPayload< + Node = Expr, + Payload<'a> = &'a HashSet, + >, + >( + self, + args: &HashSet, + rewriter: &mut R, + ) -> Result> { + rewriter + .f_down(self, args)? + .transform_children(|n| match n { + Expr::Lambda(Lambda { + ref params, + body: _, + }) => { + let mut args = args.clone(); + + args.extend(params.iter().cloned()); + + n.map_children(|n| { + n.rewrite_with_lambdas_params_impl(&args, rewriter) + }) + } + _ => { + n.map_children(|n| n.rewrite_with_lambdas_params_impl(args, rewriter)) + } + })? + .transform_parent(|n| rewriter.f_up(n, args)) + } + + /// Similarly to [`Self::map_children`], rewrites all lambdas that may + /// appear in expressions such as `array_transform([1, 2], v -> v*2)`. + /// + /// Returns the current node. 
+ pub fn map_children_with_lambdas_params< + F: FnMut(Self, &HashSet) -> Result>, + >( + self, + args: &HashSet, + mut f: F, + ) -> Result> { + match &self { + Expr::Lambda(Lambda { params, body: _ }) => { + let mut args = args.clone(); + + args.extend(params.iter().cloned()); + + self.map_children(|expr| f(expr, &args)) + } + _ => self.map_children(|expr| f(expr, args)), + } + } + + /// Similarly to [`Self::transform_up`], rewrites this expr and its inputs using `f`, + /// including lambdas that may appear in expressions such as `array_transform([1, 2], v -> v*2)`. + pub fn transform_up_with_lambdas_params< + F: FnMut(Self, &HashSet) -> Result>, + >( + self, + mut f: F, + ) -> Result> { + #[cfg_attr(feature = "recursive_protection", recursive::recursive)] + fn transform_up_with_lambdas_params_impl< + F: FnMut(Expr, &HashSet) -> Result>, + >( + node: Expr, + args: &HashSet, + f: &mut F, + ) -> Result> { + node.map_children_with_lambdas_params(args, |node, args| { + transform_up_with_lambdas_params_impl(node, args, f) + })? + .transform_parent(|node| f(node, args)) + /*match &node { + Expr::Lambda(Lambda { params, body: _ }) => { + let mut args = args.clone(); + + args.extend(params.iter().cloned()); + + node.map_children(|n| { + transform_up_with_lambdas_params_impl(n, &args, f) + })? + .transform_parent(|n| f(n, &args)) + } + _ => node + .map_children(|n| transform_up_with_lambdas_params_impl(n, args, f))? + .transform_parent(|n| f(n, args)), + }*/ + } + + transform_up_with_lambdas_params_impl(self, &HashSet::new(), &mut f) + } + + /// Similarly to [`Self::transform_down`], rewrites this expr and its inputs using `f`, + /// including lambdas that may appear in expressions such as `array_transform([1, 2], v -> v*2)`. 
+ pub fn transform_down_with_lambdas_params< + F: FnMut(Self, &HashSet) -> Result>, + >( + self, + mut f: F, + ) -> Result> { + #[cfg_attr(feature = "recursive_protection", recursive::recursive)] + fn transform_down_with_lambdas_params_impl< + F: FnMut(Expr, &HashSet) -> Result>, + >( + node: Expr, + args: &HashSet, + f: &mut F, + ) -> Result> { + f(node, args)?.transform_children(|node| { + node.map_children_with_lambdas_params(args, |node, args| { + transform_down_with_lambdas_params_impl(node, args, f) + }) + }) + } + + transform_down_with_lambdas_params_impl(self, &HashSet::new(), &mut f) + } + + pub fn apply_with_lambdas_params< + 'n, + F: FnMut(&'n Self, &HashSet<&'n str>) -> Result, + >( + &'n self, + mut f: F, + ) -> Result { + #[cfg_attr(feature = "recursive_protection", recursive::recursive)] + fn apply_with_lambdas_params_impl< + 'n, + F: FnMut(&'n Expr, &HashSet<&'n str>) -> Result, + >( + node: &'n Expr, + args: &HashSet<&'n str>, + f: &mut F, + ) -> Result { + match node { + Expr::Lambda(Lambda { params, body: _ }) => { + let mut args = args.clone(); + + args.extend(params.iter().map(|v| v.as_str())); + + f(node, &args)?.visit_children(|| { + node.apply_children(|c| { + apply_with_lambdas_params_impl(c, &args, f) + }) + }) + } + _ => f(node, args)?.visit_children(|| { + node.apply_children(|c| apply_with_lambdas_params_impl(c, args, f)) + }), + } + } + + apply_with_lambdas_params_impl(self, &HashSet::new(), &mut f) + } + + /// Similarly to [`Self::transform`], rewrites this expr and its inputs using `f`, + /// including lambdas that may appear in expressions such as `array_transform([1, 2], v -> v*2)`. 
+ pub fn transform_with_schema< + F: FnMut(Self, &DFSchema) -> Result>, + >( + self, + schema: &DFSchema, + f: F, + ) -> Result> { + self.transform_up_with_schema(schema, f) + } + + /// Similarly to [`Self::transform_up`], rewrites this expr and its inputs using `f`, + /// including lambdas that may appear in expressions such as `array_transform([1, 2], v -> v*2)`. + pub fn transform_up_with_schema< + F: FnMut(Self, &DFSchema) -> Result>, + >( + self, + schema: &DFSchema, + mut f: F, + ) -> Result> { + #[cfg_attr(feature = "recursive_protection", recursive::recursive)] + fn transform_up_with_schema_impl< + F: FnMut(Expr, &DFSchema) -> Result>, + >( + node: Expr, + schema: &DFSchema, + f: &mut F, + ) -> Result> { + node.map_children_with_schema(schema, |n, schema| { + transform_up_with_schema_impl(n, schema, f) + })? + .transform_parent(|n| f(n, schema)) + } + + transform_up_with_schema_impl(self, schema, &mut f) + } + + pub fn map_children_with_schema< + F: FnMut(Self, &DFSchema) -> Result>, + >( + self, + schema: &DFSchema, + mut f: F, + ) -> Result> { + match self { + Expr::ScalarFunction(ref fun) + if fun.args.iter().any(|arg| matches!(arg, Expr::Lambda(_))) => + { + let mut args_schemas = fun + .func + .arguments_schema_from_logical_args(&fun.args, schema)? + .into_iter(); + + self.map_children(|expr| f(expr, &args_schemas.next().unwrap())) + } + _ => self.map_children(|expr| f(expr, schema)), + } + } + + pub fn exists_with_lambdas_params) -> Result>( + &self, + mut f: F, + ) -> Result { + let mut found = false; + + self.apply_with_lambdas_params(|n, lambdas_params| { + if f(n, lambdas_params)? { + found = true; + Ok(TreeNodeRecursion::Stop) + } else { + Ok(TreeNodeRecursion::Continue) + } + })?; + + Ok(found) + } +} + +pub trait ExprWithLambdasRewriter2: Sized { + /// Invoked while traversing down the tree before any children are rewritten. + /// Default implementation returns the node as is and continues recursion. 
+ fn f_down(&mut self, node: Expr, _schema: &DFSchema) -> Result> { + Ok(Transformed::no(node)) + } + + /// Invoked while traversing up the tree after all children have been rewritten. + /// Default implementation returns the node as is and continues recursion. + fn f_up(&mut self, node: Expr, _schema: &DFSchema) -> Result> { + Ok(Transformed::no(node)) + } +} +pub trait TreeNodeRewriterWithPayload: Sized { + type Node; + type Payload<'a>; + + /// Invoked while traversing down the tree before any children are rewritten. + /// Default implementation returns the node as is and continues recursion. + fn f_down<'a>( + &mut self, + node: Self::Node, + _payload: Self::Payload<'a>, + ) -> Result> { + Ok(Transformed::no(node)) + } + + /// Invoked while traversing up the tree after all children have been rewritten. + /// Default implementation returns the node as is and continues recursion. + fn f_up<'a>( + &mut self, + node: Self::Node, + _payload: Self::Payload<'a>, + ) -> Result> { + Ok(Transformed::no(node)) + } +} + +/* +struct LambdaColumnNormalizer<'a> { + existing_qualifiers: HashSet<&'a str>, + alias_generator: AliasGenerator, + lambdas_columns: HashMap>, +} + +impl<'a> LambdaColumnNormalizer<'a> { + fn new(dfschema: &'a DFSchema, expr: &'a Expr) -> Self { + let mut existing_qualifiers: HashSet<&'a str> = dfschema + .field_qualifiers() + .iter() + .flatten() + .map(|tbl| tbl.table()) + .filter(|table| table.starts_with("lambda_")) + .collect(); + + expr.apply(|node| { + if let Expr::Lambda(lambda) = node { + if let Some(qualifier) = &lambda.qualifier { + existing_qualifiers.insert(qualifier); + } + } + + Ok(TreeNodeRecursion::Continue) }) + .unwrap(); + + Self { + existing_qualifiers, + alias_generator: AliasGenerator::new(), + lambdas_columns: HashMap::new(), + } + } +} + +impl TreeNodeRewriter for LambdaColumnNormalizer<'_> { + type Node = Expr; + + fn f_down(&mut self, node: Self::Node) -> Result> { + match node { + Expr::Lambda(mut lambda) => { + let tbl = 
lambda.qualifier.as_ref().map_or_else( + || loop { + let table = self.alias_generator.next("lambda"); + + if !self.existing_qualifiers.contains(table.as_str()) { + break TableReference::bare(table); + } + }, + |qualifier| TableReference::bare(qualifier.as_str()), + ); + + for param in &lambda.params { + self.lambdas_columns + .entry_ref(param) + .or_default() + .push(tbl.clone()); + } + + if lambda.qualifier.is_none() { + lambda.qualifier = Some(tbl.table().to_owned()); + + Ok(Transformed::yes(Expr::Lambda(lambda))) + } else { + Ok(Transformed::no(Expr::Lambda(lambda))) + } + } + Expr::Column(c) if c.relation.is_none() => { + if let Some(lambda_qualifier) = self.lambdas_columns.get(c.name()) { + Ok(Transformed::yes(Expr::Column( + c.with_relation(lambda_qualifier.last().unwrap().clone()), + ))) + } else { + Ok(Transformed::no(Expr::Column(c))) + } + } + _ => Ok(Transformed::no(node)) + } + } + + fn f_up(&mut self, node: Self::Node) -> Result> { + if let Expr::Lambda(lambda) = &node { + for param in &lambda.params { + match self.lambdas_columns.entry_ref(param) { + EntryRef::Occupied(mut entry) => { + let chain = entry.get_mut(); + + chain.pop(); + + if chain.is_empty() { + entry.remove(); + } + } + EntryRef::Vacant(_) => unreachable!(), + } + } + } + + Ok(Transformed::no(node)) + } +} +*/ + +// helpers used in udf.rs +#[cfg(test)] +pub(crate) mod tests { + use super::TreeNodeRewriterWithPayload; + use crate::{ + col, expr::Lambda, Expr, ScalarUDF, ScalarUDFImpl, ValueOrLambdaParameter, + }; + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_common::{ + tree_node::{Transformed, TreeNodeRecursion}, + DFSchema, HashSet, Result, + }; + use datafusion_expr_common::signature::{Signature, Volatility}; + + pub(crate) fn list_list_int() -> DFSchema { + DFSchema::try_from(Schema::new(vec![Field::new( + "v", + DataType::new_list(DataType::new_list(DataType::Int32, false), false), + false, + )])) + .unwrap() + } + + pub(crate) fn list_int() -> DFSchema { + 
DFSchema::try_from(Schema::new(vec![Field::new( + "v", + DataType::new_list(DataType::Int32, false), + false, + )])) + .unwrap() + } + + fn int() -> DFSchema { + DFSchema::try_from(Schema::new(vec![Field::new("v", DataType::Int32, false)])) + .unwrap() + } + + pub(crate) fn array_transform_udf() -> ScalarUDF { + ScalarUDF::new_from_impl(ArrayTransformFunc::new()) + } + + pub(crate) fn args() -> Vec { + vec![ + col("v"), + Expr::Lambda(Lambda::new( + vec!["v".into()], + array_transform_udf().call(vec![ + col("v"), + Expr::Lambda(Lambda::new(vec!["v".into()], -col("v"))), + ]), + )), + ] + } + + // array_transform(v, |v| -> array_transform(v, |v| -> -v)) + fn array_transform() -> Expr { + array_transform_udf().call(args()) + } + + #[derive(Debug, PartialEq, Eq, Hash)] + pub(crate) struct ArrayTransformFunc { + signature: Signature, + } + + impl ArrayTransformFunc { + pub fn new() -> Self { + Self { + signature: Signature::any(2, Volatility::Immutable), + } + } + } + + impl ScalarUDFImpl for ArrayTransformFunc { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "array_transform" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(arg_types[0].clone()) + } + + fn lambdas_parameters( + &self, + args: &[ValueOrLambdaParameter], + ) -> Result>>> { + let ValueOrLambdaParameter::Value(value_field) = &args[0] else { + unreachable!() + }; + + let DataType::List(field) = value_field.data_type() else { + unreachable!() + }; + + Ok(vec![ + None, + Some(vec![Field::new( + "", + field.data_type().clone(), + field.is_nullable(), + )]), + ]) + } + + fn invoke_with_args( + &self, + _args: crate::ScalarFunctionArgs, + ) -> Result { + unimplemented!() + } + } + + #[test] + fn test_rewrite_with_schema() { + let schema = list_list_int(); + let array_transform = array_transform(); + + let mut rewriter = OkRewriter::default(); + + array_transform + 
.rewrite_with_schema(&schema, &mut rewriter) + .unwrap(); + + let expected = [ + ( + "f_down array_transform(v, (v) -> array_transform(v, (v) -> (- v)))", + list_list_int(), + ), + ("f_down v", list_list_int()), + ("f_up v", list_list_int()), + ("f_down (v) -> array_transform(v, (v) -> (- v))", list_int()), + ("f_down array_transform(v, (v) -> (- v))", list_int()), + ("f_down v", list_int()), + ("f_up v", list_int()), + ("f_down (v) -> (- v)", int()), + ("f_down (- v)", int()), + ("f_down v", int()), + ("f_up v", int()), + ("f_up (- v)", int()), + ("f_up (v) -> (- v)", int()), + ("f_up array_transform(v, (v) -> (- v))", list_int()), + ("f_up (v) -> array_transform(v, (v) -> (- v))", list_int()), + ( + "f_up array_transform(v, (v) -> array_transform(v, (v) -> (- v)))", + list_list_int(), + ), + ] + .map(|(a, b)| (String::from(a), b)); + + assert_eq!(rewriter.steps, expected) + } + + #[derive(Default)] + struct OkRewriter { + steps: Vec<(String, DFSchema)>, + } + + impl TreeNodeRewriterWithPayload for OkRewriter { + type Node = Expr; + type Payload<'a> = &'a DFSchema; + + fn f_down( + &mut self, + node: Expr, + schema: &DFSchema, + ) -> Result> { + self.steps.push((format!("f_down {node}"), schema.clone())); + + Ok(Transformed::no(node)) + } + + fn f_up( + &mut self, + node: Expr, + schema: &DFSchema, + ) -> Result> { + self.steps.push((format!("f_up {node}"), schema.clone())); + + Ok(Transformed::no(node)) + } + } + + #[test] + fn test_transform_up_with_lambdas_params() { + let mut steps = vec![]; + + array_transform() + .transform_up_with_lambdas_params(|node, params| { + steps.push((node.to_string(), params.clone())); + + Ok(Transformed::no(node)) + }) + .unwrap(); + + let lambdas_params = &HashSet::from([String::from("v")]); + + let expected = [ + ("v", lambdas_params), + ("v", lambdas_params), + ("v", lambdas_params), + ("(- v)", lambdas_params), + ("(v) -> (- v)", lambdas_params), + ("array_transform(v, (v) -> (- v))", lambdas_params), + ("(v) -> 
array_transform(v, (v) -> (- v))", lambdas_params), + ( + "array_transform(v, (v) -> array_transform(v, (v) -> (- v)))", + lambdas_params, + ), + ] + .map(|(a, b)| (String::from(a), b.clone())); + + assert_eq!(steps, expected); + } + + #[test] + fn test_apply_with_lambdas_params() { + let array_transform = array_transform(); + let mut steps = vec![]; + + array_transform + .apply_with_lambdas_params(|node, params| { + steps.push((node.to_string(), params.clone())); + + Ok(TreeNodeRecursion::Continue) + }) + .unwrap(); + + let expected = [ + ("v", HashSet::from(["v"])), + ("v", HashSet::from(["v"])), + ("v", HashSet::from(["v"])), + ("(- v)", HashSet::from(["v"])), + ("(v) -> (- v)", HashSet::from(["v"])), + ("array_transform(v, (v) -> (- v))", HashSet::from(["v"])), + ("(v) -> array_transform(v, (v) -> (- v))", HashSet::from(["v"])), + ( + "array_transform(v, (v) -> array_transform(v, (v) -> (- v)))", + HashSet::from(["v"]), + ), + ] + .map(|(a, b)| (String::from(a), b)); + + assert_eq!(steps, expected); } } diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index fd54bb13a62f3..74ac1b456ff04 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -18,21 +18,30 @@ //! 
[`ScalarUDF`]: Scalar User Defined Functions use crate::async_udf::AsyncScalarUDF; -use crate::expr::schema_name_from_exprs_comma_separated_without_space; +use crate::expr::{schema_name_from_exprs_comma_separated_without_space, Lambda}; use crate::simplify::{ExprSimplifyResult, SimplifyInfo}; use crate::sort_properties::{ExprProperties, SortProperties}; use crate::udf_eq::UdfEq; -use crate::{ColumnarValue, Documentation, Expr, Signature}; -use arrow::datatypes::{DataType, Field, FieldRef}; +use crate::{ColumnarValue, Documentation, Expr, ExprSchemable, Signature}; +use arrow::array::{ArrayRef, RecordBatch}; +use arrow::datatypes::{DataType, Field, FieldRef, Fields, Schema}; +use datafusion_common::alias::AliasGenerator; use datafusion_common::config::ConfigOptions; -use datafusion_common::{not_impl_err, ExprSchema, Result, ScalarValue}; +use datafusion_common::tree_node::TreeNodeRecursion; +use datafusion_common::{ + exec_err, not_impl_err, DFSchema, ExprSchema, Result, ScalarValue, +}; use datafusion_expr_common::dyn_eq::{DynEq, DynHash}; use datafusion_expr_common::interval_arithmetic::Interval; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use indexmap::IndexMap; use std::any::Any; +use std::borrow::Cow; use std::cmp::Ordering; +use std::collections::HashMap; use std::fmt::Debug; use std::hash::{Hash, Hasher}; -use std::sync::Arc; +use std::sync::{Arc, LazyLock}; /// Logical representation of a Scalar User Defined Function. 
/// @@ -343,6 +352,272 @@ impl ScalarUDF { pub fn as_async(&self) -> Option<&AsyncScalarUDF> { self.inner().as_any().downcast_ref::() } + + /// Variation of arguments_from_logical_args that works with arrow Schema's and ScalarFunctionArgMetadata instead + pub(crate) fn arguments_expr_schema<'a>( + &self, + args: &[Expr], + schema: &'a dyn ExprSchema, + ) -> Result> { + self.arguments_scope_with( + &lambda_parameters(args, schema)?, + ExtendableExprSchema::new(schema), + ) + } + + /// Variation of arguments_from_logical_args that works with arrow Schema's and ScalarFunctionArgMetadata instead, + pub fn arguments_arrow_schema<'a>( + &self, + args: &[ValueOrLambdaParameter], + schema: &'a Schema, + ) -> Result>> { + self.arguments_scope_with(args, Cow::Borrowed(schema)) + } + + pub fn arguments_schema_from_logical_args<'a>( + &self, + args: &[Expr], + schema: &'a DFSchema, + ) -> Result>> { + self.arguments_scope_with( + &lambda_parameters(args, schema)?, + Cow::Borrowed(schema), + ) + } + + /// Scalar function supports lambdas as arguments, which will be evaluated with + /// a different schema that of the function itself. 
This functions returns a vec + /// with the correspoding schema that each argument will run + /// + /// Return a vec with a value for each argument in args that, if it's a value, it's a clone of base_scope, + /// if it's a lambda, it's the return of merge called with the index and the fields from lambdas_parameters + /// updated with names from metadata + fn arguments_scope_with( + &self, + args: &[ValueOrLambdaParameter], + schema: T, + ) -> Result> { + let parameters = self.inner().lambdas_parameters(args)?; + + if parameters.len() != args.len() { + return exec_err!( + "lambdas_schemas: {} lambdas_parameters returned {} values instead of {}", + self.name(), + args.len(), + parameters.len() + ); + } + + std::iter::zip(args, parameters) + .enumerate() + .map(|(i, (arg, parameters))| match (arg, parameters) { + (ValueOrLambdaParameter::Value(_), None) => Ok(schema.clone()), + (ValueOrLambdaParameter::Value(_), Some(_)) => exec_err!("lambdas_schemas: {} argument {} (0-indexed) is a value but lambdas_parameters result treat it as a lambda", self.name(), i), + (ValueOrLambdaParameter::Lambda(_, _), None) => exec_err!("lambdas_schemas: {} argument {} (0-indexed) is a lambda but lambdas_parameters result treat it as a value", self.name(), i), + (ValueOrLambdaParameter::Lambda(names, captures), Some(args)) => { + if names.len() > args.len() { + return exec_err!("lambdas_schemas: {} argument {} (0-indexed), a lambda, supports up to {} arguments, but got {}", self.name(), i, args.len(), names.len()) + } + + let fields = std::iter::zip(*names, args) + .map(|(name, arg)| arg.with_name(name)) + .collect::(); + + if *captures { + schema.extend(fields) + } else { + T::from_fields(fields) + } + } + }) + .collect() + } +} + +pub trait ExtendSchema: Sized { + fn from_fields(params: Fields) -> Result; + fn extend(&self, params: Fields) -> Result; +} + +impl ExtendSchema for DFSchema { + fn from_fields(params: Fields) -> Result { + DFSchema::from_unqualified_fields(params, 
Default::default()) + } + + fn extend(&self, params: Fields) -> Result { + let qualified_fields = self + .iter() + .map(|(qualifier, field)| { + if params.find(field.name().as_str()).is_none() { + return (qualifier.cloned(), Arc::clone(field)); + } + + let alias_gen = AliasGenerator::new(); + + loop { + let alias = alias_gen.next(field.name().as_str()); + + if params.find(&alias).is_none() + && !self.has_column_with_unqualified_name(&alias) + { + return ( + qualifier.cloned(), + Arc::new(Field::new( + alias, + field.data_type().clone(), + field.is_nullable(), + )), + ); + } + } + }) + .collect(); + + let mut schema = DFSchema::new_with_metadata(qualified_fields, HashMap::new())?; + let fields_schema = DFSchema::from_unqualified_fields(params, HashMap::new())?; + + schema.merge(&fields_schema); + + assert_eq!( + schema.fields().len(), + self.fields().len() + fields_schema.fields().len() + ); + + Ok(schema) + } +} + +impl ExtendSchema for Schema { + fn from_fields(params: Fields) -> Result { + Ok(Schema::new(params)) + } + + fn extend(&self, params: Fields) -> Result { + let mut params2 = params.iter() + .map(|f| (f.name().as_str(), Some(Arc::clone(f)))) + .collect::>(); + + let mut fields = self.fields() + .iter() + .map(|field| { + match params2.get_mut(field.name().as_str()).and_then(|p| p.take()) { + Some(param) => param, + None => Arc::clone(field), + } + }) + .collect::>(); + + fields.extend(params2.into_values().flatten()); + + let fields = self + .fields() + .iter() + .map(|field| { + if params.find(field.name().as_str()).is_none() { + return Arc::clone(field); + } + + let alias_gen = AliasGenerator::new(); + + loop { + let alias = alias_gen.next(field.name().as_str()); + + if params.find(&alias).is_none() + && self.column_with_name(&alias).is_none() + { + return Arc::new(Field::new( + alias, + field.data_type().clone(), + field.is_nullable(), + )); + } + } + }) + .chain(params.iter().cloned()) + .collect::(); + + assert_eq!(fields.len(), self.fields().len() 
+ params.len()); + + Ok(Schema::new_with_metadata(fields, self.metadata.clone())) + } +} + +impl ExtendSchema for Cow<'_, T> { + fn from_fields(params: Fields) -> Result { + Ok(Cow::Owned(T::from_fields(params)?)) + } + + fn extend(&self, params: Fields) -> Result { + Ok(Cow::Owned(self.as_ref().extend(params)?)) + } +} + +impl ExtendSchema for Arc { + fn from_fields(params: Fields) -> Result { + Ok(Arc::new(T::from_fields(params)?)) + } + + fn extend(&self, params: Fields) -> Result { + Ok(Arc::new(self.as_ref().extend(params)?)) + } +} + +impl ExtendSchema for ExtendableExprSchema<'_> { + fn from_fields(params: Fields) -> Result { + static EMPTY_DFSCHEMA: LazyLock = LazyLock::new(DFSchema::empty); + + Ok(ExtendableExprSchema { + fields_chain: vec![params], + outer_schema: &*EMPTY_DFSCHEMA, + }) + } + + fn extend(&self, params: Fields) -> Result { + Ok(ExtendableExprSchema { + fields_chain: std::iter::once(params) + .chain(self.fields_chain.iter().cloned()) + .collect(), + outer_schema: self.outer_schema, + }) + } +} + +/// A `&dyn ExprSchema` wrapper that supports adding the parameters of a lambda +#[derive(Clone, Debug)] +struct ExtendableExprSchema<'a> { + fields_chain: Vec, + outer_schema: &'a dyn ExprSchema, +} + +impl<'a> ExtendableExprSchema<'a> { + fn new(schema: &'a dyn ExprSchema) -> Self { + Self { + fields_chain: vec![], + outer_schema: schema, + } + } +} + +impl ExprSchema for ExtendableExprSchema<'_> { + fn field_from_column(&self, col: &datafusion_common::Column) -> Result<&Field> { + if col.relation.is_none() { + for fields in &self.fields_chain { + if let Some((_index, lambda_param)) = fields.find(&col.name) { + return Ok(lambda_param); + } + } + } + + self.outer_schema.field_from_column(col) + } +} + +#[derive(Clone, Debug)] +pub enum ValueOrLambdaParameter<'a> { + /// A columnar value with the given field + Value(FieldRef), + /// A lambda with the given parameters names and a flag indicating wheter it captures any columns + Lambda(&'a [String], 
bool), } impl From for ScalarUDF @@ -359,6 +634,7 @@ where #[derive(Debug, Clone)] pub struct ScalarFunctionArgs { /// The evaluated arguments to the function + /// If it's a lambda, will be `ColumnarValue::Scalar(ScalarValue::Null)` pub args: Vec, /// Field associated with each arg, if it exists pub arg_fields: Vec, @@ -370,6 +646,30 @@ pub struct ScalarFunctionArgs { pub return_field: FieldRef, /// The config options at execution time pub config_options: Arc, + /// The lambdas passed to the function + /// If it's not a lambda it will be `None` + pub lambdas: Option>>, +} + +/// A lambda argument to a ScalarFunction +#[derive(Clone, Debug)] +pub struct ScalarFunctionLambdaArg { + /// The parameters defined in this lambda + /// + /// For example, for `array_transform([2], v -> -v)`, + /// this will be `vec![Field::new("v", DataType::Int32, true)]` + pub params: Vec, + /// The body of the lambda + /// + /// For example, for `array_transform([2], v -> -v)`, + /// this will be the physical expression of `-v` + pub body: Arc, + /// A RecordBatch containing at least the captured columns inside this lambda body, if any + /// Note that it may contain additional, non-specified columns, but that's implementation detail + /// + /// For example, for `array_transform([2], v -> v + a + b)`, + /// this will be a `RecordBatch` with two columns, `a` and `b` + pub captures: Option, } impl ScalarFunctionArgs { @@ -378,6 +678,25 @@ impl ScalarFunctionArgs { pub fn return_type(&self) -> &DataType { self.return_field.data_type() } + + pub fn to_lambda_args(&self) -> Vec> { + match &self.lambdas { + Some(lambdas) => std::iter::zip(&self.args, lambdas) + .map(|(arg, lambda)| match lambda { + Some(lambda) => ValueOrLambda::Lambda(lambda), + None => ValueOrLambda::Value(arg), + }) + .collect(), + None => self.args.iter().map(ValueOrLambda::Value).collect(), + } + } +} + +// An argument to a ScalarUDF that supports lambdas +#[derive(Debug)] +pub enum ValueOrLambda<'a> { + Value(&'a 
ColumnarValue), + Lambda(&'a ScalarFunctionLambdaArg), } /// Information about arguments passed to the function @@ -390,6 +709,12 @@ impl ScalarFunctionArgs { #[derive(Debug)] pub struct ReturnFieldArgs<'a> { /// The data types of the arguments to the function + /// + /// If argument `i` to the function is a lambda, it will be the field returned by the + /// lambda when executed with the arguments returned from `ScalarUDFImpl::lambdas_parameters` + /// + /// For example, with `array_transform([1], v -> v == 5)` + /// this field will be `[Field::new("", DataType::List(DataType::Int32), false), Field::new("", DataType::Boolean, false)]` pub arg_fields: &'a [FieldRef], /// Is argument `i` to the function a scalar (constant)? /// @@ -398,6 +723,36 @@ pub struct ReturnFieldArgs<'a> { /// For example, if a function is called like `my_function(column_a, 5)` /// this field will be `[None, Some(ScalarValue::Int32(Some(5)))]` pub scalar_arguments: &'a [Option<&'a ScalarValue>], + /// Is argument `i` to the function a lambda? + /// + /// For example, with `array_transform([1], v -> v == 5)` + /// this field will be `[false, true]` + pub lambdas: &'a [bool], +} + +/// A tagged Field indicating whether it correspond to a value or a lambda argument +#[derive(Debug)] +pub enum ValueOrLambdaField<'a> { + /// The Field of a ColumnarValue argument + Value(&'a FieldRef), + /// The Field of the return of the lambda body when evaluated with the parameters from ScalarUDF::lambda_parameters + Lambda(&'a FieldRef), +} + +impl<'a> ReturnFieldArgs<'a> { + /// Based on self.lambdas, encodes self.arg_fields to tagged enums + /// indicating whether it correspond to a value or a lambda argument + pub fn to_lambda_args(&self) -> Vec> { + std::iter::zip(self.arg_fields, self.lambdas) + .map(|(field, is_lambda)| { + if *is_lambda { + ValueOrLambdaField::Lambda(field) + } else { + ValueOrLambdaField::Value(field) + } + }) + .collect() + } } /// Trait for implementing user defined scalar functions. 
@@ -841,6 +1196,14 @@ pub trait ScalarUDFImpl: Debug + DynEq + DynHash + Send + Sync { fn documentation(&self) -> Option<&Documentation> { None } + + /// Returns the parameters that any lambda supports + fn lambdas_parameters( + &self, + args: &[ValueOrLambdaParameter], + ) -> Result>>> { + Ok(vec![None; args.len()]) + } } /// ScalarUDF that adds an alias to the underlying function. It is better to @@ -959,6 +1322,118 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl { fn documentation(&self) -> Option<&Documentation> { self.inner.documentation() } + + fn lambdas_parameters( + &self, + args: &[ValueOrLambdaParameter], + ) -> Result>>> { + self.inner.lambdas_parameters(args) + } +} + +fn lambda_parameters<'a>( + args: &'a [Expr], + schema: &dyn ExprSchema, +) -> Result>> { + args.iter() + .map(|e| match e { + Expr::Lambda(Lambda { params, body: _ }) => { + let mut captures = false; + + e.apply_with_lambdas_params(|expr, lambdas_params| match expr { + Expr::Column(c) if !c.is_lambda_parameter(lambdas_params) => { + captures = true; + + Ok(TreeNodeRecursion::Stop) + } + _ => Ok(TreeNodeRecursion::Continue), + }) + .unwrap(); + + Ok(ValueOrLambdaParameter::Lambda(params.as_slice(), captures)) + } + _ => Ok(ValueOrLambdaParameter::Value(e.to_field(schema)?.1)), + }) + .collect() +} + +/// Merge the lambda body captured columns with it's arguments +/// Datafusion relies on an unspecified field ordering implemented in this function +/// As such, this is the only correct way to merge the captured values with the arguments +/// The number of args should not be lower than the number of params +/// +/// See also merge_captures_with_lazy_args and merge_captures_with_boxed_lazy_args that lazily +/// computes only the necessary arguments to match the number of params +pub fn merge_captures_with_args( + captures: Option<&RecordBatch>, + params: &[FieldRef], + args: &[ArrayRef], +) -> Result { + if args.len() < params.len() { + return exec_err!( + "merge_captures_with_args called 
with {} params but with {} args", + params.len(), + args.len() + ); + } + + // the order of the merged batch must be kept in sync with ScalarFunction::lambdas_schemas variants + let (fields, columns) = match captures { + Some(captures) => { + let fields = captures + .schema() + .fields() + .iter() + .chain(params) + .cloned() + .collect::>(); + + let columns = [captures.columns(), args].concat(); + + (fields, columns) + } + None => (params.to_vec(), args.to_vec()), + }; + + Ok(RecordBatch::try_new( + Arc::new(Schema::new(fields)), + columns, + )?) +} + +/// Lazy version of merge_captures_with_args that receives closures to compute the arguments, +/// and calls only the necessary to match the number of params +pub fn merge_captures_with_lazy_args( + captures: Option<&RecordBatch>, + params: &[FieldRef], + args: &[&dyn Fn() -> Result], +) -> Result { + merge_captures_with_args( + captures, + params, + &args + .iter() + .take(params.len()) + .map(|arg| arg()) + .collect::>>()?, + ) +} + +/// Variation of merge_captures_with_lazy_args that take boxed closures +pub fn merge_captures_with_boxed_lazy_args( + captures: Option<&RecordBatch>, + params: &[FieldRef], + args: &[Box Result>], +) -> Result { + merge_captures_with_args( + captures, + params, + &args + .iter() + .take(params.len()) + .map(|arg| arg()) + .collect::>>()?, + ) } #[cfg(test)] @@ -1039,4 +1514,83 @@ mod tests { value.hash(hasher); hasher.finish() } + + use std::borrow::Cow; + + use arrow::datatypes::Fields; + + use crate::{ + tree_node::tests::{args, list_int, list_list_int, array_transform_udf}, + udf::{lambda_parameters, ExtendableExprSchema}, + }; + + #[test] + fn test_arguments_expr_schema() { + let args = args(); + let schema = list_list_int(); + + let schemas = array_transform_udf() + .arguments_expr_schema(&args, &schema) + .unwrap() + .into_iter() + .map(|s| format!("{s:?}")) + .collect::>(); + + let mut lambdas_parameters = array_transform_udf() + .inner() + 
.lambdas_parameters(&lambda_parameters(&args, &schema).unwrap()) + .unwrap(); + + assert_eq!( + schemas, + &[ + format!("{}", &list_list_int()), + format!( + "{:?}", + ExtendableExprSchema { + fields_chain: vec![Fields::from( + lambdas_parameters[0].take().unwrap() + )], + outer_schema: &list_list_int() + } + ), + ] + ) + } + + #[test] + fn test_arguments_arrow_schema() { + let list_int = list_int(); + let list_list_int = list_list_int(); + + let schemas = array_transform_udf() + .arguments_arrow_schema( + &lambda_parameters(&args(), &list_list_int).unwrap(), + //&[HashSet::new(), HashSet::from([0])], + list_list_int.as_arrow(), + ) + .unwrap(); + + assert_eq!( + schemas, + &[ + Cow::Borrowed(list_list_int.as_arrow()), + Cow::Owned(list_int.as_arrow().clone()) + ] + ) + } + + #[test] + fn test_arguments_schema_from_logical_args() { + let list_list_int = list_list_int(); + + let schemas = array_transform_udf() + .arguments_schema_from_logical_args(&args(), &list_list_int) + .unwrap(); + + assert_eq!( + schemas, + &[Cow::Borrowed(&list_list_int), Cow::Owned(list_int())] + ) + } } diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index cd733e0a130a9..93fcfaef882ff 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -266,10 +266,12 @@ pub fn grouping_set_to_exprlist(group_expr: &[Expr]) -> Result> { /// Recursively walk an expression tree, collecting the unique set of columns /// referenced in the expression pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet) -> Result<()> { - expr.apply(|expr| { + expr.apply_with_lambdas_params(|expr, lambdas_params| { match expr { Expr::Column(qc) => { - accum.insert(qc.clone()); + if qc.relation.is_some() || !lambdas_params.contains(qc.name()) { + accum.insert(qc.clone()); + } } // Use explicit pattern match instead of a default // implementation, so that in the future if someone adds @@ -307,7 +309,8 @@ pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet) -> Result<()> { | 
Expr::ScalarSubquery(_) | Expr::Wildcard { .. } | Expr::Placeholder(_) - | Expr::OuterReferenceColumn { .. } => {} + | Expr::OuterReferenceColumn { .. } + | Expr::Lambda { .. } => {} } Ok(TreeNodeRecursion::Continue) }) @@ -650,6 +653,7 @@ where /// Search an `Expr`, and all of its nested `Expr`'s, for any that pass the /// provided test. The returned `Expr`'s are deduplicated and returned in order /// of appearance (depth first). +/// todo: document about that columns may refer to a lambda parameter? fn find_exprs_in_expr(expr: &Expr, test_fn: &F) -> Vec where F: Fn(&Expr) -> bool, @@ -672,6 +676,7 @@ where } /// Recursively inspect an [`Expr`] and all its children. +/// todo: document about that columns may refer to a lambda parameter? pub fn inspect_expr_pre(expr: &Expr, mut f: F) -> Result<(), E> where F: FnMut(&Expr) -> Result<(), E>, @@ -743,13 +748,19 @@ pub fn columnize_expr(e: Expr, input: &LogicalPlan) -> Result { _ => return Ok(e), }; let exprs_map: HashMap<&Expr, Column> = output_exprs.into_iter().collect(); - e.transform_down(|node: Expr| match exprs_map.get(&node) { - Some(column) => Ok(Transformed::new( - Expr::Column(column.clone()), - true, - TreeNodeRecursion::Jump, - )), - None => Ok(Transformed::no(node)), + e.transform_down_with_lambdas_params(|node: Expr, lambdas_params| { + if matches!(&node, Expr::Column(c) if c.is_lambda_parameter(lambdas_params)) { + return Ok(Transformed::no(node)); + } + + match exprs_map.get(&node) { + Some(column) => Ok(Transformed::new( + Expr::Column(column.clone()), + true, + TreeNodeRecursion::Jump, + )), + None => Ok(Transformed::no(node)), + } }) .data() } @@ -766,9 +777,11 @@ pub fn find_column_exprs(exprs: &[Expr]) -> Vec { pub(crate) fn find_columns_referenced_by_expr(e: &Expr) -> Vec { let mut exprs = vec![]; - e.apply(|expr| { + e.apply_with_lambdas_params(|expr, lambdas_params| { if let Expr::Column(c) = expr { - exprs.push(c.clone()) + if !c.is_lambda_parameter(lambdas_params) { + exprs.push(c.clone()) + } 
} Ok(TreeNodeRecursion::Continue) }) @@ -797,9 +810,9 @@ pub(crate) fn find_column_indexes_referenced_by_expr( schema: &DFSchemaRef, ) -> Vec { let mut indexes = vec![]; - e.apply(|expr| { + e.apply_with_lambdas_params(|expr, lambdas_params| { match expr { - Expr::Column(qc) => { + Expr::Column(qc) if !qc.is_lambda_parameter(lambdas_params) => { if let Ok(idx) = schema.index_of_column(qc) { indexes.push(idx); } diff --git a/datafusion/ffi/src/udf/mod.rs b/datafusion/ffi/src/udf/mod.rs index 5e59cfc5ecb07..400ad44696047 100644 --- a/datafusion/ffi/src/udf/mod.rs +++ b/datafusion/ffi/src/udf/mod.rs @@ -33,7 +33,7 @@ use arrow::{ }; use arrow_schema::FieldRef; use datafusion::config::ConfigOptions; -use datafusion::logical_expr::ReturnFieldArgs; +use datafusion::{common::exec_err, logical_expr::ReturnFieldArgs}; use datafusion::{ error::DataFusionError, logical_expr::type_coercion::functions::data_types_with_scalar_udf, @@ -210,6 +210,7 @@ unsafe extern "C" fn invoke_with_args_fn_wrapper( return_field, // TODO: pass config options: https://github.com/apache/datafusion/issues/17035 config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let result = rresult_return!(udf @@ -382,10 +383,15 @@ impl ScalarUDFImpl for ForeignScalarUDF { arg_fields, number_rows, return_field, + lambdas, // TODO: pass config options: https://github.com/apache/datafusion/issues/17035 config_options: _config_options, } = invoke_args; + if lambdas.is_some_and(|lambdas| lambdas.iter().any(|l| l.is_some())) { + return exec_err!("ForeignScalarUDF doesn't support lambdas"); + } + let args = args .into_iter() .map(|v| v.to_array(number_rows)) diff --git a/datafusion/ffi/src/udf/return_type_args.rs b/datafusion/ffi/src/udf/return_type_args.rs index c437c9537be6f..d5cbfff1d3a4b 100644 --- a/datafusion/ffi/src/udf/return_type_args.rs +++ b/datafusion/ffi/src/udf/return_type_args.rs @@ -21,7 +21,7 @@ use abi_stable::{ }; use arrow_schema::FieldRef; use datafusion::{ - 
common::exec_datafusion_err, error::DataFusionError, logical_expr::ReturnFieldArgs, + common::{exec_datafusion_err, exec_err}, error::DataFusionError, logical_expr::ReturnFieldArgs, scalar::ScalarValue, }; @@ -42,6 +42,10 @@ impl TryFrom> for FFI_ReturnFieldArgs { type Error = DataFusionError; fn try_from(value: ReturnFieldArgs) -> Result { + if value.lambdas.iter().any(|l| *l) { + return exec_err!("FFI_ReturnFieldArgs doesn't support lambdas") + } + let arg_fields = vec_fieldref_to_rvec_wrapped(value.arg_fields)?; let scalar_arguments: Result, Self::Error> = value .scalar_arguments @@ -77,6 +81,7 @@ pub struct ForeignReturnFieldArgsOwned { pub struct ForeignReturnFieldArgs<'a> { arg_fields: &'a [FieldRef], scalar_arguments: Vec>, + lambdas: Vec, // currently always false, used to return a reference in From<&Self> for ReturnFieldArgs } impl TryFrom<&FFI_ReturnFieldArgs> for ForeignReturnFieldArgsOwned { @@ -116,6 +121,7 @@ impl<'a> From<&'a ForeignReturnFieldArgsOwned> for ForeignReturnFieldArgs<'a> { .iter() .map(|opt| opt.as_ref()) .collect(), + lambdas: vec![false; value.arg_fields.len()] } } } @@ -125,6 +131,7 @@ impl<'a> From<&'a ForeignReturnFieldArgs<'a>> for ReturnFieldArgs<'a> { ReturnFieldArgs { arg_fields: value.arg_fields, scalar_arguments: &value.scalar_arguments, + lambdas: &value.lambdas, } } } diff --git a/datafusion/functions-nested/src/array_transform.rs b/datafusion/functions-nested/src/array_transform.rs new file mode 100644 index 0000000000000..700fed477b4cb --- /dev/null +++ b/datafusion/functions-nested/src/array_transform.rs @@ -0,0 +1,266 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`ScalarUDFImpl`] definitions for array_transform function. + +use arrow::{ + array::{Array, ArrayRef, AsArray, FixedSizeListArray, LargeListArray, ListArray}, + compute::take_record_batch, + datatypes::{DataType, Field}, +}; +use datafusion_common::{ + HashSet, Result, exec_err, internal_err, tree_node::{Transformed, TreeNode}, utils::{elements_indices, list_indices, list_values, take_function_args} +}; +use datafusion_expr::{ + ColumnarValue, Documentation, Expr, ScalarFunctionArgs, ScalarUDFImpl, Signature, ValueOrLambda, ValueOrLambdaField, ValueOrLambdaParameter, Volatility, expr::Lambda, merge_captures_with_lazy_args +}; +use datafusion_macros::user_doc; +use std::{any::Any, sync::Arc}; + +make_udf_expr_and_func!( + ArrayTransform, + array_transform, + array lambda, + "transforms the values of a array", + array_transform_udf +); + +#[user_doc( + doc_section(label = "Array Functions"), + description = "transforms the values of a array", + syntax_example = "array_transform(array, x -> x*2)", + sql_example = r#"```sql +> select array_transform([1, 2, 3, 4, 5], x -> x*2); ++-------------------------------------------+ +| array_transform([1, 2, 3, 4, 5], x -> x*2) | ++-------------------------------------------+ +| [2, 4, 6, 8, 10] | ++-------------------------------------------+ +```"#, + argument( + name = "array", + description = "Array expression. Can be a constant, column, or function, and any combination of array operators." 
+ ), + argument(name = "lambda", description = "Lambda") +)] +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct ArrayTransform { + signature: Signature, + aliases: Vec, +} + +impl Default for ArrayTransform { + fn default() -> Self { + Self::new() + } +} + +impl ArrayTransform { + pub fn new() -> Self { + Self { + signature: Signature::any(2, Volatility::Immutable), + aliases: vec![String::from("list_transform")], + } + } +} + +impl ScalarUDFImpl for ArrayTransform { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "array_transform" + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_type called instead of return_field_from_args") + } + + fn return_field_from_args( + &self, + args: datafusion_expr::ReturnFieldArgs, + ) -> Result> { + let args = args.to_lambda_args(); + + let [ValueOrLambdaField::Value(list), ValueOrLambdaField::Lambda(lambda)] = + take_function_args(self.name(), &args)? + else { + return exec_err!( + "{} expects a value follewed by a lambda, got {:?}", + self.name(), + args + ); + }; + + //TODO: should metadata be passed? If so, with the same keys or prefixed/suffixed? 
+
+        // lambda is the resulting field of executing the lambda body
+        // with the parameters returned in lambdas_parameters
+        let field = Arc::new(Field::new(
+            Field::LIST_FIELD_DEFAULT_NAME,
+            lambda.data_type().clone(),
+            lambda.is_nullable(),
+        ));
+
+        let return_type = match list.data_type() {
+            DataType::List(_) => DataType::List(field),
+            DataType::LargeList(_) => DataType::LargeList(field),
+            DataType::FixedSizeList(_, size) => DataType::FixedSizeList(field, *size),
+            _ => unreachable!(),
+        };
+
+        Ok(Arc::new(Field::new("", return_type, list.is_nullable())))
+    }
+
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
+        // args.lambda_args allows the convenient match below, instead of inspecting both args.args and args.lambdas
+        let lambda_args = args.to_lambda_args();
+        let [list_value, lambda] = take_function_args(self.name(), &lambda_args)?;
+
+        let (ValueOrLambda::Value(list_value), ValueOrLambda::Lambda(lambda)) =
+            (list_value, lambda)
+        else {
+            return exec_err!(
+                "{} expects a value followed by a lambda, got {:?}",
+                self.name(),
+                &lambda_args
+            );
+        };
+
+        let list_array = list_value.to_array(args.number_rows)?;
+
+        // if any column got captured, we need to adjust it to the values arrays,
+        // duplicating values of list with multiple values and removing values of empty lists
+        // list_indices is not cheap, so it is important to avoid it when no column is captured
+        let adjusted_captures = lambda
+            .captures
+            .as_ref()
+            .map(|captures| take_record_batch(captures, &list_indices(&list_array)?))
+            .transpose()?;
+
+        // use closures and merge_captures_with_lazy_args so that it calls only the needed ones based on the number of arguments
+        // avoiding unnecessary computations
+        let values_param = || Ok(Arc::clone(list_values(&list_array)?));
+        let indices_param = || elements_indices(&list_array);
+
+        // the order of the merged schema is an unspecified implementation detail that may change in the future,
+        // using this function is the correct
way to merge as it return the correct ordering and will change in sync + // the implementation without the need for fixes. It also computes only the parameters requested + let lambda_batch = merge_captures_with_lazy_args( + adjusted_captures.as_ref(), + &lambda.params, // ScalarUDF already merged the fields returned in lambdas_parameters with the parameters names definied in the lambda, so we don't need to + &[&values_param, &indices_param], + )?; + + // call the transforming expression with the record batch composed of the list values merged with captured columns + let transformed_values = lambda + .body + .evaluate(&lambda_batch)? + .into_array(lambda_batch.num_rows())?; + + let field = match args.return_field.data_type() { + DataType::List(field) + | DataType::LargeList(field) + | DataType::FixedSizeList(field, _) => Arc::clone(field), + _ => { + return exec_err!( + "{} expected ScalarFunctionArgs.return_field to be a list, got {}", + self.name(), + args.return_field + ) + } + }; + + let transformed_list = match list_array.data_type() { + DataType::List(_) => { + let list = list_array.as_list(); + + Arc::new(ListArray::new( + field, + list.offsets().clone(), + transformed_values, + list.nulls().cloned(), + )) as ArrayRef + } + DataType::LargeList(_) => { + let large_list = list_array.as_list(); + + Arc::new(LargeListArray::new( + field, + large_list.offsets().clone(), + transformed_values, + large_list.nulls().cloned(), + )) + } + DataType::FixedSizeList(_, value_length) => { + Arc::new(FixedSizeListArray::new( + field, + *value_length, + transformed_values, + list_array.as_fixed_size_list().nulls().cloned(), + )) + } + other => exec_err!("expected list, got {other}")?, + }; + + Ok(ColumnarValue::Array(transformed_list)) + } + + fn lambdas_parameters( + &self, + args: &[ValueOrLambdaParameter], + ) -> Result>>> { + let [ValueOrLambdaParameter::Value(list), ValueOrLambdaParameter::Lambda(_, _)] = + args + else { + return exec_err!( + "{} expects a value follewed 
by a lambda, got {:?}", + self.name(), + args + ); + }; + + let (field, index_type) = match list.data_type() { + DataType::List(field) => (field, DataType::Int32), + DataType::LargeList(field) => (field, DataType::Int64), + DataType::FixedSizeList(field, _) => (field, DataType::Int32), + _ => return exec_err!("expected list, got {list}"), + }; + + // we don't need to omit the index in the case the lambda don't specify, e.g. array_transform([], v -> v*2), + // nor check whether the lambda contains more than two parameters, e.g. array_transform([], (v, i, j) -> v+i+j), + // as datafusion will do that for us + let value = Field::new("value", field.data_type().clone(), field.is_nullable()) + .with_metadata(field.metadata().clone()); + let index = Field::new("index", index_type, false); + + Ok(vec![None, Some(vec![value, index])]) + } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } +} diff --git a/datafusion/optimizer/src/analyzer/function_rewrite.rs b/datafusion/optimizer/src/analyzer/function_rewrite.rs index c6bf14ebce2e3..0e5e602f8238e 100644 --- a/datafusion/optimizer/src/analyzer/function_rewrite.rs +++ b/datafusion/optimizer/src/analyzer/function_rewrite.rs @@ -19,7 +19,7 @@ use super::AnalyzerRule; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::tree_node::Transformed; use datafusion_common::{DFSchema, Result}; use crate::utils::NamePreserver; @@ -64,15 +64,16 @@ impl ApplyFunctionRewrites { let original_name = name_preserver.save(&expr); // recursively transform the expression, applying the rewrites at each step - let transformed_expr = expr.transform_up(|expr| { - let mut result = Transformed::no(expr); - for rewriter in self.function_rewrites.iter() { - result = result.transform_data(|expr| { - rewriter.rewrite(expr, &schema, options) - })?; - } - Ok(result) - })?; + let transformed_expr = + expr.transform_up_with_schema(&schema, |expr, schema| { + let 
mut result = Transformed::no(expr); + for rewriter in self.function_rewrites.iter() { + result = result.transform_data(|expr| { + rewriter.rewrite(expr, schema, options) + })?; + } + Ok(result) + })?; Ok(transformed_expr.update_data(|expr| original_name.restore(expr))) }) diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 4fb0f8553b4ba..1b82182e8600f 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -20,6 +20,7 @@ use std::sync::Arc; use datafusion_expr::binary::BinaryTypeCoercer; +use datafusion_expr::tree_node::TreeNodeRewriterWithPayload; use itertools::{izip, Itertools as _}; use arrow::datatypes::{DataType, Field, IntervalUnit, Schema}; @@ -27,7 +28,7 @@ use arrow::datatypes::{DataType, Field, IntervalUnit, Schema}; use crate::analyzer::AnalyzerRule; use crate::utils::NamePreserver; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter}; +use datafusion_common::tree_node::Transformed; use datafusion_common::{ exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_datafusion_err, plan_err, Column, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, @@ -140,7 +141,7 @@ fn analyze_internal( // apply coercion rewrite all expressions in the plan individually plan.map_expressions(|expr| { let original_name = name_preserver.save(&expr); - expr.rewrite(&mut expr_rewrite) + expr.rewrite_with_schema(&schema, &mut expr_rewrite) .map(|transformed| transformed.update_data(|e| original_name.restore(e))) })? 
// some plans need extra coercion after their expressions are coerced @@ -304,10 +305,11 @@ impl<'a> TypeCoercionRewriter<'a> { } } -impl TreeNodeRewriter for TypeCoercionRewriter<'_> { +impl TreeNodeRewriterWithPayload for TypeCoercionRewriter<'_> { type Node = Expr; + type Payload<'a> = &'a DFSchema; - fn f_up(&mut self, expr: Expr) -> Result> { + fn f_up(&mut self, expr: Expr, schema: &DFSchema) -> Result> { match expr { Expr::Unnest(_) => not_impl_err!( "Unnest should be rewritten to LogicalPlan::Unnest before type coercion" @@ -318,7 +320,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { spans, }) => { let new_plan = - analyze_internal(self.schema, Arc::unwrap_or_clone(subquery))?.data; + analyze_internal(schema, Arc::unwrap_or_clone(subquery))?.data; Ok(Transformed::yes(Expr::ScalarSubquery(Subquery { subquery: Arc::new(new_plan), outer_ref_columns, @@ -327,7 +329,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { } Expr::Exists(Exists { subquery, negated }) => { let new_plan = analyze_internal( - self.schema, + schema, Arc::unwrap_or_clone(subquery.subquery), )? .data; @@ -346,11 +348,11 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { negated, }) => { let new_plan = analyze_internal( - self.schema, + schema, Arc::unwrap_or_clone(subquery.subquery), )? 
.data; - let expr_type = expr.get_type(self.schema)?; + let expr_type = expr.get_type(schema)?; let subquery_type = new_plan.schema().field(0).data_type(); let common_type = comparison_coercion(&expr_type, subquery_type).ok_or( plan_datafusion_err!( @@ -363,32 +365,32 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { spans: subquery.spans, }; Ok(Transformed::yes(Expr::InSubquery(InSubquery::new( - Box::new(expr.cast_to(&common_type, self.schema)?), + Box::new(expr.cast_to(&common_type, schema)?), cast_subquery(new_subquery, &common_type)?, negated, )))) } Expr::Not(expr) => Ok(Transformed::yes(not(get_casted_expr_for_bool_op( *expr, - self.schema, + schema, )?))), Expr::IsTrue(expr) => Ok(Transformed::yes(is_true( - get_casted_expr_for_bool_op(*expr, self.schema)?, + get_casted_expr_for_bool_op(*expr, schema)?, ))), Expr::IsNotTrue(expr) => Ok(Transformed::yes(is_not_true( - get_casted_expr_for_bool_op(*expr, self.schema)?, + get_casted_expr_for_bool_op(*expr, schema)?, ))), Expr::IsFalse(expr) => Ok(Transformed::yes(is_false( - get_casted_expr_for_bool_op(*expr, self.schema)?, + get_casted_expr_for_bool_op(*expr, schema)?, ))), Expr::IsNotFalse(expr) => Ok(Transformed::yes(is_not_false( - get_casted_expr_for_bool_op(*expr, self.schema)?, + get_casted_expr_for_bool_op(*expr, schema)?, ))), Expr::IsUnknown(expr) => Ok(Transformed::yes(is_unknown( - get_casted_expr_for_bool_op(*expr, self.schema)?, + get_casted_expr_for_bool_op(*expr, schema)?, ))), Expr::IsNotUnknown(expr) => Ok(Transformed::yes(is_not_unknown( - get_casted_expr_for_bool_op(*expr, self.schema)?, + get_casted_expr_for_bool_op(*expr, schema)?, ))), Expr::Like(Like { negated, @@ -397,8 +399,8 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { escape_char, case_insensitive, }) => { - let left_type = expr.get_type(self.schema)?; - let right_type = pattern.get_type(self.schema)?; + let left_type = expr.get_type(schema)?; + let right_type = pattern.get_type(schema)?; let coerced_type = 
like_coercion(&left_type, &right_type).ok_or_else(|| { let op_name = if case_insensitive { "ILIKE" @@ -411,9 +413,9 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { })?; let expr = match left_type { DataType::Dictionary(_, inner) if *inner == DataType::Utf8 => expr, - _ => Box::new(expr.cast_to(&coerced_type, self.schema)?), + _ => Box::new(expr.cast_to(&coerced_type, schema)?), }; - let pattern = Box::new(pattern.cast_to(&coerced_type, self.schema)?); + let pattern = Box::new(pattern.cast_to(&coerced_type, schema)?); Ok(Transformed::yes(Expr::Like(Like::new( negated, expr, @@ -424,7 +426,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { } Expr::BinaryExpr(BinaryExpr { left, op, right }) => { let (left, right) = - self.coerce_binary_op(*left, self.schema, op, *right, self.schema)?; + self.coerce_binary_op(*left, schema, op, *right, schema)?; Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr::new( Box::new(left), op, @@ -437,15 +439,15 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { low, high, }) => { - let expr_type = expr.get_type(self.schema)?; - let low_type = low.get_type(self.schema)?; + let expr_type = expr.get_type(schema)?; + let low_type = low.get_type(schema)?; let low_coerced_type = comparison_coercion(&expr_type, &low_type) .ok_or_else(|| { internal_datafusion_err!( "Failed to coerce types {expr_type} and {low_type} in BETWEEN expression" ) })?; - let high_type = high.get_type(self.schema)?; + let high_type = high.get_type(schema)?; let high_coerced_type = comparison_coercion(&expr_type, &high_type) .ok_or_else(|| { internal_datafusion_err!( @@ -460,10 +462,10 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { ) })?; Ok(Transformed::yes(Expr::Between(Between::new( - Box::new(expr.cast_to(&coercion_type, self.schema)?), + Box::new(expr.cast_to(&coercion_type, schema)?), negated, - Box::new(low.cast_to(&coercion_type, self.schema)?), - Box::new(high.cast_to(&coercion_type, self.schema)?), + Box::new(low.cast_to(&coercion_type, 
schema)?), + Box::new(high.cast_to(&coercion_type, schema)?), )))) } Expr::InList(InList { @@ -471,10 +473,10 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { list, negated, }) => { - let expr_data_type = expr.get_type(self.schema)?; + let expr_data_type = expr.get_type(schema)?; let list_data_types = list .iter() - .map(|list_expr| list_expr.get_type(self.schema)) + .map(|list_expr| list_expr.get_type(schema)) .collect::>>()?; let result_type = get_coerce_type_for_list(&expr_data_type, &list_data_types); @@ -484,11 +486,11 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { ), Some(coerced_type) => { // find the coerced type - let cast_expr = expr.cast_to(&coerced_type, self.schema)?; + let cast_expr = expr.cast_to(&coerced_type, schema)?; let cast_list_expr = list .into_iter() .map(|list_expr| { - list_expr.cast_to(&coerced_type, self.schema) + list_expr.cast_to(&coerced_type, schema) }) .collect::>>()?; Ok(Transformed::yes(Expr::InList(InList ::new( @@ -500,13 +502,13 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { } } Expr::Case(case) => { - let case = coerce_case_expression(case, self.schema)?; + let case = coerce_case_expression(case, schema)?; Ok(Transformed::yes(Expr::Case(case))) } Expr::ScalarFunction(ScalarFunction { func, args }) => { let new_expr = coerce_arguments_for_signature_with_scalar_udf( args, - self.schema, + schema, &func, )?; Ok(Transformed::yes(Expr::ScalarFunction( @@ -526,7 +528,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { }) => { let new_expr = coerce_arguments_for_signature_with_aggregate_udf( args, - self.schema, + schema, &func, )?; Ok(Transformed::yes(Expr::AggregateFunction( @@ -555,13 +557,13 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { }, } = *window_fun; let window_frame = - coerce_window_frame(window_frame, self.schema, &order_by)?; + coerce_window_frame(window_frame, schema, &order_by)?; let args = match &fun { expr::WindowFunctionDefinition::AggregateUDF(udf) => { 
coerce_arguments_for_signature_with_aggregate_udf( args, - self.schema, + schema, udf, )? } @@ -597,7 +599,8 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { | Expr::Wildcard { .. } | Expr::GroupingSet(_) | Expr::Placeholder(_) - | Expr::OuterReferenceColumn(_, _) => Ok(Transformed::no(expr)), + | Expr::OuterReferenceColumn(_, _) + | Expr::Lambda { .. } => Ok(Transformed::no(expr)), } } } @@ -793,9 +796,11 @@ fn coerce_arguments_for_signature_with_scalar_udf( return Ok(expressions); } - let current_types = expressions - .iter() - .map(|e| e.get_type(schema)) + let current_types = expressions.iter() + .map(|e| match e { + Expr::Lambda { .. } => Ok(DataType::Null), + _ => e.get_type(schema), + }) .collect::>>()?; let new_types = data_types_with_scalar_udf(¤t_types, func)?; @@ -803,7 +808,10 @@ fn coerce_arguments_for_signature_with_scalar_udf( expressions .into_iter() .enumerate() - .map(|(i, expr)| expr.cast_to(&new_types[i], schema)) + .map(|(i, expr)| match expr { + lambda @ Expr::Lambda { .. 
} => Ok(lambda), + _ => expr.cast_to(&new_types[i], schema), + }) .collect() } @@ -1125,7 +1133,7 @@ mod test { use crate::analyzer::Analyzer; use crate::assert_analyzed_plan_with_config_eq_snapshot; use datafusion_common::config::ConfigOptions; - use datafusion_common::tree_node::{TransformedResult, TreeNode}; + use datafusion_common::tree_node::{TransformedResult}; use datafusion_common::{DFSchema, DFSchemaRef, Result, ScalarValue, Spans}; use datafusion_expr::expr::{self, InSubquery, Like, ScalarFunction}; use datafusion_expr::logical_plan::{EmptyRelation, Projection, Sort}; @@ -2076,7 +2084,7 @@ mod test { let mut rewriter = TypeCoercionRewriter { schema: &schema }; let expr = is_true(lit(12i32).gt(lit(13i64))); let expected = is_true(cast(lit(12i32), DataType::Int64).gt(lit(13i64))); - let result = expr.rewrite(&mut rewriter).data()?; + let result = expr.rewrite_with_schema(&schema, &mut rewriter).data()?; assert_eq!(expected, result); // eq @@ -2087,7 +2095,7 @@ mod test { let mut rewriter = TypeCoercionRewriter { schema: &schema }; let expr = is_true(lit(12i32).eq(lit(13i64))); let expected = is_true(cast(lit(12i32), DataType::Int64).eq(lit(13i64))); - let result = expr.rewrite(&mut rewriter).data()?; + let result = expr.rewrite_with_schema(&schema, &mut rewriter).data()?; assert_eq!(expected, result); // lt @@ -2098,7 +2106,7 @@ mod test { let mut rewriter = TypeCoercionRewriter { schema: &schema }; let expr = is_true(lit(12i32).lt(lit(13i64))); let expected = is_true(cast(lit(12i32), DataType::Int64).lt(lit(13i64))); - let result = expr.rewrite(&mut rewriter).data()?; + let result = expr.rewrite_with_schema(&schema, &mut rewriter).data()?; assert_eq!(expected, result); Ok(()) diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 2510068494591..e06ed6e547eb5 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ 
-29,7 +29,7 @@ use datafusion_common::alias::AliasGenerator; use datafusion_common::cse::{CSEController, FoundCommonNodes, CSE}; use datafusion_common::tree_node::{Transformed, TreeNode}; -use datafusion_common::{qualified_name, Column, DFSchema, DFSchemaRef, Result}; +use datafusion_common::{qualified_name, Column, DFSchema, DFSchemaRef, HashSet, Result}; use datafusion_expr::expr::{Alias, ScalarFunction}; use datafusion_expr::logical_plan::{ Aggregate, Filter, LogicalPlan, Projection, Sort, Window, @@ -632,6 +632,7 @@ struct ExprCSEController<'a> { // how many aliases have we seen so far alias_counter: usize, + lambdas_params: HashSet, } impl<'a> ExprCSEController<'a> { @@ -640,6 +641,7 @@ impl<'a> ExprCSEController<'a> { alias_generator, mask, alias_counter: 0, + lambdas_params: HashSet::new(), } } } @@ -693,11 +695,30 @@ impl CSEController for ExprCSEController<'_> { } } + fn visit_f_down(&mut self, node: &Expr) { + if let Expr::Lambda(lambda) = node { + self.lambdas_params + .extend(lambda.params.iter().cloned()); + } + } + + fn visit_f_up(&mut self, node: &Expr) { + if let Expr::Lambda(lambda) = node { + for param in &lambda.params { + self.lambdas_params.remove(param); + } + } + } + fn is_valid(node: &Expr) -> bool { !node.is_volatile_node() } fn is_ignored(&self, node: &Expr) -> bool { + if matches!(node, Expr::Column(c) if c.is_lambda_parameter(&self.lambdas_params)) { + return true + } + // TODO: remove the next line after `Expr::Wildcard` is removed #[expect(deprecated)] let is_normal_minus_aggregates = matches!( diff --git a/datafusion/optimizer/src/decorrelate.rs b/datafusion/optimizer/src/decorrelate.rs index 63236787743a4..0f43741834009 100644 --- a/datafusion/optimizer/src/decorrelate.rs +++ b/datafusion/optimizer/src/decorrelate.rs @@ -527,18 +527,17 @@ fn proj_exprs_evaluation_result_on_empty_batch( for expr in proj_expr.iter() { let result_expr = expr .clone() - .transform_up(|expr| { - if let Expr::Column(Column { name, .. 
}) = &expr { + .transform_up_with_lambdas_params(|expr, lambdas_params| match &expr { + Expr::Column(col) if !col.is_lambda_parameter(lambdas_params) => { if let Some(result_expr) = - input_expr_result_map_for_count_bug.get(name) + input_expr_result_map_for_count_bug.get(col.name()) { Ok(Transformed::yes(result_expr.clone())) } else { Ok(Transformed::no(expr)) } - } else { - Ok(Transformed::no(expr)) } + _ => Ok(Transformed::no(expr)), }) .data()?; @@ -570,16 +569,17 @@ fn filter_exprs_evaluation_result_on_empty_batch( ) -> Result> { let result_expr = filter_expr .clone() - .transform_up(|expr| { - if let Expr::Column(Column { name, .. }) = &expr { - if let Some(result_expr) = input_expr_result_map_for_count_bug.get(name) { + .transform_up_with_lambdas_params(|expr, lambdas_params| match &expr { + Expr::Column(col) if !col.is_lambda_parameter(lambdas_params) => { + if let Some(result_expr) = + input_expr_result_map_for_count_bug.get(col.name()) + { Ok(Transformed::yes(result_expr.clone())) } else { Ok(Transformed::no(expr)) } - } else { - Ok(Transformed::no(expr)) } + _ => Ok(Transformed::no(expr)), }) .data()?; diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs index 5db71417bc8fd..f0187b618ccc0 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -639,7 +639,7 @@ fn is_expr_trivial(expr: &Expr) -> bool { /// --Source(a, b) /// ``` fn rewrite_expr(expr: Expr, input: &Projection) -> Result> { - expr.transform_up(|expr| { + expr.transform_up_with_lambdas_params(|expr, lambdas_params| { match expr { // remove any intermediate aliases if they do not carry metadata Expr::Alias(alias) => { @@ -653,7 +653,7 @@ fn rewrite_expr(expr: Expr, input: &Projection) -> Result> { false => Ok(Transformed::no(Expr::Alias(alias))), } } - Expr::Column(col) => { + Expr::Column(col) if !col.is_lambda_parameter(lambdas_params) => { // Find index 
of column: let idx = input.schema.index_of_column(&col)?; // get the corresponding unaliased input expression diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 1c0790b3e3acd..54cb026543270 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -293,7 +293,8 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result { Expr::AggregateFunction(_) | Expr::WindowFunction(_) | Expr::Wildcard { .. } - | Expr::GroupingSet(_) => internal_err!("Unsupported predicate type"), + | Expr::GroupingSet(_) + | Expr::Lambda { .. } => internal_err!("Unsupported predicate type"), })?; Ok(is_evaluate) } @@ -1389,14 +1390,15 @@ pub fn replace_cols_by_name( e: Expr, replace_map: &HashMap, ) -> Result { - e.transform_up(|expr| { - Ok(if let Expr::Column(c) = &expr { - match replace_map.get(&c.flat_name()) { - Some(new_c) => Transformed::yes(new_c.clone()), - None => Transformed::no(expr), + e.transform_up_with_lambdas_params(|expr, lambdas_params| { + Ok(match &expr { + Expr::Column(c) if !c.is_lambda_parameter(lambdas_params) => { + match replace_map.get(&c.flat_name()) { + Some(new_c) => Transformed::yes(new_c.clone()), + None => Transformed::no(expr), + } } - } else { - Transformed::no(expr) + _ => Transformed::no(expr), }) }) .data() @@ -1405,17 +1407,18 @@ pub fn replace_cols_by_name( /// check whether the expression uses the columns in `check_map`. 
fn contain(e: &Expr, check_map: &HashMap) -> bool { let mut is_contain = false; - e.apply(|expr| { - Ok(if let Expr::Column(c) = &expr { - match check_map.get(&c.flat_name()) { - Some(_) => { - is_contain = true; - TreeNodeRecursion::Stop + e.apply_with_lambdas_params(|expr, lambdas_params| { + Ok(match &expr { + Expr::Column(c) if !c.is_lambda_parameter(lambdas_params) => { + match check_map.get(&c.flat_name()) { + Some(_) => { + is_contain = true; + TreeNodeRecursion::Stop + } + None => TreeNodeRecursion::Continue, } - None => TreeNodeRecursion::Continue, } - } else { - TreeNodeRecursion::Continue + _ => TreeNodeRecursion::Continue, }) }) .unwrap(); diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs b/datafusion/optimizer/src/scalar_subquery_to_join.rs index 48d1182527013..f1e619750f9c8 100644 --- a/datafusion/optimizer/src/scalar_subquery_to_join.rs +++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs @@ -106,17 +106,22 @@ impl OptimizerRule for ScalarSubqueryToJoin { { if !expr_check_map.is_empty() { rewrite_expr = rewrite_expr - .transform_up(|expr| { - // replace column references with entry in map, if it exists - if let Some(map_expr) = expr - .try_as_col() - .and_then(|col| expr_check_map.get(&col.name)) - { - Ok(Transformed::yes(map_expr.clone())) - } else { - Ok(Transformed::no(expr)) - } - }) + .transform_up_with_lambdas_params( + |expr, lambdas_params| { + // replace column references with entry in map, if it exists + if let Some(map_expr) = expr + .try_as_col() + .filter(|c| { + !c.is_lambda_parameter(lambdas_params) + }) + .and_then(|col| expr_check_map.get(&col.name)) + { + Ok(Transformed::yes(map_expr.clone())) + } else { + Ok(Transformed::no(expr)) + } + }, + ) .data()?; } cur_input = optimized_subquery; @@ -171,18 +176,26 @@ impl OptimizerRule for ScalarSubqueryToJoin { { let new_expr = rewrite_expr .clone() - .transform_up(|expr| { - // replace column references with entry in map, if it exists - if let Some(map_expr) = - 
expr.try_as_col().and_then(|col| { - expr_check_map.get(&col.name) - }) - { - Ok(Transformed::yes(map_expr.clone())) - } else { - Ok(Transformed::no(expr)) - } - }) + .transform_up_with_lambdas_params( + |expr, lambdas_params| { + // replace column references with entry in map, if it exists + if let Some(map_expr) = expr + .try_as_col() + .filter(|c| { + !c.is_lambda_parameter( + lambdas_params, + ) + }) + .and_then(|col| { + expr_check_map.get(&col.name) + }) + { + Ok(Transformed::yes(map_expr.clone())) + } else { + Ok(Transformed::no(expr)) + } + }, + ) .data()?; expr_to_rewrite_expr_map.insert(expr, new_expr); } @@ -396,8 +409,12 @@ fn build_join( let mut expr_rewrite = TypeCoercionRewriter { schema: new_plan.schema(), }; - computation_project_expr - .insert(name, computer_expr.rewrite(&mut expr_rewrite).data()?); + computation_project_expr.insert( + name, + computer_expr + .rewrite_with_schema(new_plan.schema(), &mut expr_rewrite) + .data()?, + ); } } diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 05b8c28fadd6c..a824f6b7be49f 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -17,27 +17,30 @@ //! 
Expression simplification API +use std::collections::HashSet; +use std::ops::Not; +use std::{borrow::Cow, sync::Arc}; + use arrow::{ array::{new_null_array, AsArray}, datatypes::{DataType, Field, Schema}, record_batch::RecordBatch, }; -use std::borrow::Cow; -use std::collections::HashSet; -use std::ops::Not; -use std::sync::Arc; +use datafusion_common::tree_node::TreeNodeContainer; use datafusion_common::{ cast::{as_large_list_array, as_list_array}, metadata::FieldMetadata, - tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter}, + tree_node::{ + Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter, + }, }; use datafusion_common::{ exec_datafusion_err, internal_err, DFSchema, DataFusionError, Result, ScalarValue, }; use datafusion_expr::{ - and, binary::BinaryTypeCoercer, lit, or, BinaryExpr, Case, ColumnarValue, Expr, Like, - Operator, Volatility, + and, binary::BinaryTypeCoercer, lit, or, simplify::SimplifyContext, BinaryExpr, Case, + ColumnarValue, Expr, Like, Operator, Volatility, }; use datafusion_expr::{expr::ScalarFunction, interval_arithmetic::NullableInterval}; use datafusion_expr::{ @@ -267,7 +270,7 @@ impl ExprSimplifier { /// documentation for more details on type coercion pub fn coerce(&self, expr: Expr, schema: &DFSchema) -> Result { let mut expr_rewrite = TypeCoercionRewriter { schema }; - expr.rewrite(&mut expr_rewrite).data() + expr.rewrite_with_schema(schema, &mut expr_rewrite).data() } /// Input guarantees about the values of columns. @@ -649,7 +652,8 @@ impl<'a> ConstEvaluator<'a> { | Expr::WindowFunction { .. } | Expr::GroupingSet(_) | Expr::Wildcard { .. } - | Expr::Placeholder(_) => false, + | Expr::Placeholder(_) + | Expr::Lambda { .. } => false, Expr::ScalarFunction(ScalarFunction { func, .. 
}) => { Self::volatility_ok(func.signature().volatility) } @@ -754,6 +758,89 @@ impl<'a, S> Simplifier<'a, S> { impl TreeNodeRewriter for Simplifier<'_, S> { type Node = Expr; + fn f_down(&mut self, expr: Self::Node) -> Result> { + match expr { + Expr::ScalarFunction(ScalarFunction { func, args }) + if args.iter().any(|arg| matches!(arg, Expr::Lambda(_))) => + { + // there's currently no way to adapt a generic SimplifyInfo with lambda parameters, + // so, if the scalar function has any lambda, we materialize a DFSchema using all the + // column references in every argument. Then we can call arguments_schema_from_logical_args, + // and for each argument, we create a new SimplifyContext with the scoped schema, and + // simplify the argument using this 'sub-context'. Finally, we set Transformed.tnr to + // Jump so the parent context doesn't try to simplify the argument again, without the + // parameters info + + // get all column references + let mut columns_refs = HashSet::new(); + + for arg in &args { + arg.add_column_refs(&mut columns_refs); + } + + // materialize column references into qualified fields + let qualified_fields = columns_refs + .into_iter() + .map(|captured_column| { + let expr = Expr::Column(captured_column.clone()); + + Ok(( + captured_column.relation.clone(), + Arc::new(Field::new( + captured_column.name(), + self.info.get_data_type(&expr)?, + self.info.nullable(&expr)?, + )), + )) + }) + .collect::>()?; + + // create a schema using the materialized fields + let dfschema = + DFSchema::new_with_metadata(qualified_fields, Default::default())?; + + let mut scoped_schemas = func + .arguments_schema_from_logical_args(&args, &dfschema)?
+ .into_iter(); + + let transformed_args = args + .map_elements(|arg| { + let scoped_schema = scoped_schemas.next().unwrap(); + + // create a sub-context, using the scoped schema, that includes information about the lambda parameters + let simplify_context = + SimplifyContext::new(self.info.execution_props()) + .with_schema(Arc::new(scoped_schema.into_owned())); + + let mut simplifier = Simplifier::new(&simplify_context); + + // simplify the argument using its context + arg.rewrite(&mut simplifier) + })? + .update_data(|args| { + Expr::ScalarFunction(ScalarFunction { func, args }) + }); + + Ok(Transformed::new( + transformed_args.data, + transformed_args.transformed, + // return at least Jump so the parent context doesn't try again to simplify the arguments + // (and fail because it doesn't contain info about lambda parameters) + match transformed_args.tnr { + TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => { + TreeNodeRecursion::Jump + } + TreeNodeRecursion::Stop => TreeNodeRecursion::Stop, + }, + )) + + // Ok(transformed_args.update_data(|args| Expr::ScalarFunction(ScalarFunction { func, args}))) + } + // Expr::Lambda(_) => Ok(Transformed::new(expr, false, TreeNodeRecursion::Jump)), + _ => Ok(Transformed::no(expr)), + } + } + /// rewrite the expression simplifying any constant expressions fn f_up(&mut self, expr: Expr) -> Result> { use datafusion_expr::Operator::{ diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index 81763fa0552fb..d0ae4932628f3 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -23,7 +23,7 @@ use crate::analyzer::type_coercion::TypeCoercionRewriter; use arrow::array::{new_null_array, Array, RecordBatch}; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::cast::as_boolean_array; -use datafusion_common::tree_node::{TransformedResult, TreeNode}; +use datafusion_common::tree_node::TransformedResult; use datafusion_common::{Column, DFSchema,
Result, ScalarValue}; use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::expr_rewriter::replace_col; @@ -148,7 +148,7 @@ fn evaluate_expr_with_null_column<'a>( fn coerce(expr: Expr, schema: &DFSchema) -> Result { let mut expr_rewrite = TypeCoercionRewriter { schema }; - expr.rewrite(&mut expr_rewrite).data() + expr.rewrite_with_schema(schema, &mut expr_rewrite).data() } #[cfg(test)] diff --git a/datafusion/physical-expr-adapter/src/schema_rewriter.rs b/datafusion/physical-expr-adapter/src/schema_rewriter.rs index 61cc97dae300e..4a81a5c99ac75 100644 --- a/datafusion/physical-expr-adapter/src/schema_rewriter.rs +++ b/datafusion/physical-expr-adapter/src/schema_rewriter.rs @@ -21,12 +21,14 @@ use std::sync::Arc; use arrow::compute::can_cast_types; use arrow::datatypes::{DataType, FieldRef, Schema, SchemaRef}; +use datafusion_common::HashSet; use datafusion_common::{ exec_err, - tree_node::{Transformed, TransformedResult, TreeNode}, + tree_node::{Transformed, TransformedResult}, Result, ScalarValue, }; use datafusion_functions::core::getfield::GetFieldFunc; +use datafusion_physical_expr::PhysicalExprExt; use datafusion_physical_expr::{ expressions::{self, CastExpr, Column}, ScalarFunctionExpr, @@ -217,8 +219,10 @@ impl PhysicalExprAdapter for DefaultPhysicalExprAdapter { physical_file_schema: &self.physical_file_schema, partition_fields: &self.partition_values, }; - expr.transform(|expr| rewriter.rewrite_expr(Arc::clone(&expr))) - .data() + expr.transform_with_lambdas_params(|expr, lambdas_params| { + rewriter.rewrite_expr(Arc::clone(&expr), lambdas_params) + }) + .data() } fn with_partition_values( @@ -242,13 +246,18 @@ impl<'a> DefaultPhysicalExprAdapterRewriter<'a> { fn rewrite_expr( &self, expr: Arc, + lambdas_params: &HashSet, ) -> Result>> { - if let Some(transformed) = self.try_rewrite_struct_field_access(&expr)? { + if let Some(transformed) = + self.try_rewrite_struct_field_access(&expr, lambdas_params)? 
+ { return Ok(Transformed::yes(transformed)); } if let Some(column) = expr.as_any().downcast_ref::() { - return self.rewrite_column(Arc::clone(&expr), column); + if !lambdas_params.contains(column.name()) { + return self.rewrite_column(Arc::clone(&expr), column); + } } Ok(Transformed::no(expr)) @@ -260,6 +269,7 @@ impl<'a> DefaultPhysicalExprAdapterRewriter<'a> { fn try_rewrite_struct_field_access( &self, expr: &Arc, + lambdas_params: &HashSet, ) -> Result>> { let get_field_expr = match ScalarFunctionExpr::try_downcast_func::(expr.as_ref()) { @@ -291,8 +301,8 @@ impl<'a> DefaultPhysicalExprAdapterRewriter<'a> { }; let column = match source_expr.as_any().downcast_ref::() { - Some(column) => column, - None => return Ok(None), + Some(column) if !lambdas_params.contains(column.name()) => column, + _ => return Ok(None), }; let physical_field = @@ -446,6 +456,7 @@ mod tests { use super::*; use arrow::array::{RecordBatch, RecordBatchOptions}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; + use datafusion_common::hashbrown::HashSet; use datafusion_common::{assert_contains, record_batch, Result, ScalarValue}; use datafusion_expr::Operator; use datafusion_physical_expr::expressions::{col, lit, CastExpr, Column, Literal}; @@ -852,7 +863,9 @@ mod tests { // Test that when a field exists in physical schema, it returns None let column = Arc::new(Column::new("struct_col", 0)) as Arc; - let result = rewriter.try_rewrite_struct_field_access(&column).unwrap(); + let result = rewriter + .try_rewrite_struct_field_access(&column, &HashSet::new()) + .unwrap(); assert!(result.is_none()); // The actual test for the get_field expression would require creating a proper ScalarFunctionExpr diff --git a/datafusion/physical-expr/Cargo.toml b/datafusion/physical-expr/Cargo.toml index b7654a0f6f603..d4c0e1cbe6eb7 100644 --- a/datafusion/physical-expr/Cargo.toml +++ b/datafusion/physical-expr/Cargo.toml @@ -37,6 +37,9 @@ workspace = true [lib] name = "datafusion_physical_expr" 
+[features] +recursive_protection = ["dep:recursive"] + [dependencies] ahash = { workspace = true } arrow = { workspace = true } @@ -52,6 +55,7 @@ itertools = { workspace = true, features = ["use_std"] } parking_lot = { workspace = true } paste = "^1.0" petgraph = "0.8.3" +recursive = { workspace = true, optional = true } [dev-dependencies] arrow = { workspace = true, features = ["test_utils"] } diff --git a/datafusion/physical-expr/src/async_scalar_function.rs b/datafusion/physical-expr/src/async_scalar_function.rs index b434694a20cc8..a34d3cda47682 100644 --- a/datafusion/physical-expr/src/async_scalar_function.rs +++ b/datafusion/physical-expr/src/async_scalar_function.rs @@ -168,6 +168,7 @@ impl AsyncFuncExpr { number_rows: current_batch.num_rows(), return_field: Arc::clone(&self.return_field), config_options: Arc::clone(&config_options), + lambdas: None, }) .await?, ); @@ -187,6 +188,7 @@ impl AsyncFuncExpr { number_rows: batch.num_rows(), return_field: Arc::clone(&self.return_field), config_options: Arc::clone(&config_options), + lambdas: None, }) .await?, ); diff --git a/datafusion/physical-expr/src/expressions/column.rs b/datafusion/physical-expr/src/expressions/column.rs index 9ca464b304306..c55f42ae333bc 100644 --- a/datafusion/physical-expr/src/expressions/column.rs +++ b/datafusion/physical-expr/src/expressions/column.rs @@ -22,12 +22,13 @@ use std::hash::Hash; use std::sync::Arc; use crate::physical_expr::PhysicalExpr; +use crate::PhysicalExprExt; use arrow::datatypes::FieldRef; use arrow::{ datatypes::{DataType, Schema, SchemaRef}, record_batch::RecordBatch, }; -use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::tree_node::{Transformed, TransformedResult}; use datafusion_common::{internal_err, plan_err, Result}; use datafusion_expr::ColumnarValue; @@ -67,7 +68,8 @@ use datafusion_expr::ColumnarValue; pub struct Column { /// The name of the column (used for debugging and display purposes) name: String, - /// The index of 
the column in its schema + /// The index of the column in its schema. + /// Within a lambda body, this refer to the lambda scoped schema, not the plan schema. index: usize, } @@ -178,9 +180,9 @@ pub fn with_new_schema( expr: Arc, schema: &SchemaRef, ) -> Result> { - Ok(expr - .transform_up(|expr| { - if let Some(col) = expr.as_any().downcast_ref::() { + expr.transform_up_with_lambdas_params(|expr, lambdas_params| { + match expr.as_any().downcast_ref::() { + Some(col) if !lambdas_params.contains(col.name()) => { let idx = col.index(); let Some(field) = schema.fields().get(idx) else { return plan_err!( @@ -188,12 +190,13 @@ pub fn with_new_schema( ); }; let new_col = Column::new(field.name(), idx); + Ok(Transformed::yes(Arc::new(new_col) as _)) - } else { - Ok(Transformed::no(expr)) } - })? - .data) + _ => Ok(Transformed::no(expr)), + } + }) + .data() } #[cfg(test)] diff --git a/datafusion/physical-expr/src/expressions/lambda.rs b/datafusion/physical-expr/src/expressions/lambda.rs new file mode 100644 index 0000000000000..55110fdf5bf6b --- /dev/null +++ b/datafusion/physical-expr/src/expressions/lambda.rs @@ -0,0 +1,139 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Physical lambda expression: [`LambdaExpr`] + +use std::hash::Hash; +use std::sync::Arc; +use std::{any::Any, sync::OnceLock}; + +use crate::expressions::Column; +use crate::physical_expr::PhysicalExpr; +use crate::PhysicalExprExt; +use arrow::{ + datatypes::{DataType, Schema}, + record_batch::RecordBatch, +}; +use datafusion_common::tree_node::TreeNodeRecursion; +use datafusion_common::{internal_err, HashSet, Result}; +use datafusion_expr::ColumnarValue; + +/// Represents a lambda with the given parameter names and body +#[derive(Debug, Eq, Clone)] +pub struct LambdaExpr { + params: Vec, + body: Arc, + captures: OnceLock>, +} + +// Manually derive PartialEq and Hash to work around https://github.com/rust-lang/rust/issues/78808 [https://github.com/apache/datafusion/issues/13196] +impl PartialEq for LambdaExpr { + fn eq(&self, other: &Self) -> bool { + self.params.eq(&other.params) && self.body.eq(&other.body) + } +} + +impl Hash for LambdaExpr { + fn hash(&self, state: &mut H) { + self.params.hash(state); + self.body.hash(state); + } +} + +impl LambdaExpr { + /// Create a new lambda expression with the given parameters and body + pub fn new(params: Vec, body: Arc) -> Self { + Self { + params, + body, + captures: OnceLock::new(), + } + } + + /// Get the lambda's parameter names + pub fn params(&self) -> &[String] { + &self.params + } + + /// Get the lambda's body + pub fn body(&self) -> &Arc { + &self.body + } + + pub fn captures(&self) -> &HashSet { + self.captures.get_or_init(|| { + let mut indices = HashSet::new(); + + self.body + .apply_with_lambdas_params(|expr, lambdas_params| { + if let Some(column) = expr.as_any().downcast_ref::() { + if !lambdas_params.contains(column.name()) { + indices.insert(column.index()); + } + } + + Ok(TreeNodeRecursion::Continue) + }) + .unwrap(); + + indices + }) + } +} + +impl std::fmt::Display for LambdaExpr { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "({}) -> {}", self.params.join(", "), self.body)
} +} + +impl PhysicalExpr for LambdaExpr { + fn as_any(&self) -> &dyn Any { + self + } + + fn data_type(&self, _input_schema: &Schema) -> Result { + Ok(DataType::Null) + } + + fn nullable(&self, _input_schema: &Schema) -> Result { + Ok(true) + } + + fn evaluate(&self, _batch: &RecordBatch) -> Result { + internal_err!("Lambda::evaluate() should not be called") + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.body] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + Ok(Arc::new(Self { + params: self.params.clone(), + body: Arc::clone(&children[0]), + captures: OnceLock::new(), + })) + } + + fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "({}) -> {}", self.params.join(", "), self.body) + } +} diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index 59d675753d985..e87941da5ef4c 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -27,6 +27,7 @@ mod dynamic_filters; mod in_list; mod is_not_null; mod is_null; +mod lambda; mod like; mod literal; mod negative; @@ -49,6 +50,7 @@ pub use dynamic_filters::DynamicFilterPhysicalExpr; pub use in_list::{in_list, InListExpr}; pub use is_not_null::{is_not_null, IsNotNullExpr}; pub use is_null::{is_null, IsNullExpr}; +pub use lambda::LambdaExpr; pub use like::{like, LikeExpr}; pub use literal::{lit, Literal}; pub use negative::{negative, NegativeExpr}; diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs index aa8c9e50fd71e..873205f28bef4 100644 --- a/datafusion/physical-expr/src/lib.rs +++ b/datafusion/physical-expr/src/lib.rs @@ -70,6 +70,8 @@ pub use datafusion_physical_expr_common::sort_expr::{ pub use planner::{create_physical_expr, create_physical_exprs}; pub use scalar_function::ScalarFunctionExpr; +pub use scalar_function::PhysicalExprExt; + pub use simplifier::PhysicalExprSimplifier; pub use 
utils::{conjunction, conjunction_opt, split_conjunction}; diff --git a/datafusion/physical-expr/src/physical_expr.rs b/datafusion/physical-expr/src/physical_expr.rs index c658a8eddc233..2584fc22885c2 100644 --- a/datafusion/physical-expr/src/physical_expr.rs +++ b/datafusion/physical-expr/src/physical_expr.rs @@ -18,11 +18,11 @@ use std::sync::Arc; use crate::expressions::{self, Column}; -use crate::{create_physical_expr, LexOrdering, PhysicalSortExpr}; +use crate::{create_physical_expr, LexOrdering, PhysicalExprExt, PhysicalSortExpr}; use arrow::compute::SortOptions; use arrow::datatypes::{Schema, SchemaRef}; -use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_common::tree_node::{Transformed, TransformedResult}; use datafusion_common::{plan_err, Result}; use datafusion_common::{DFSchema, HashMap}; use datafusion_expr::execution_props::ExecutionProps; @@ -38,14 +38,14 @@ pub fn add_offset_to_expr( expr: Arc, offset: isize, ) -> Result> { - expr.transform_down(|e| match e.as_any().downcast_ref::() { - Some(col) => { + expr.transform_down_with_lambdas_params(|e, lambdas_params| match e.as_any().downcast_ref::() { + Some(col) if !lambdas_params.contains(col.name()) => { let Some(idx) = col.index().checked_add_signed(offset) else { return plan_err!("Column index overflow"); }; Ok(Transformed::yes(Arc::new(Column::new(col.name(), idx)))) } - None => Ok(Transformed::no(e)), + _ => Ok(Transformed::no(e)), }) .data() } diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index 7790380dffd56..0119c81b8ed94 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -17,6 +17,7 @@ use std::sync::Arc; +use crate::expressions::LambdaExpr; use crate::ScalarFunctionExpr; use crate::{ expressions::{self, binary, like, similar_to, Column, Literal}, @@ -30,7 +31,7 @@ use datafusion_common::{ exec_err, not_impl_err, plan_err, DFSchema, Result, ScalarValue, 
ToDFSchema, }; use datafusion_expr::execution_props::ExecutionProps; -use datafusion_expr::expr::{Alias, Cast, InList, Placeholder, ScalarFunction}; +use datafusion_expr::expr::{Alias, Cast, InList, Lambda, Placeholder, ScalarFunction}; use datafusion_expr::var_provider::is_system_variables; use datafusion_expr::var_provider::VarType; use datafusion_expr::{ @@ -104,7 +105,8 @@ use datafusion_expr::{ /// /// * `e` - The logical expression /// * `input_dfschema` - The DataFusion schema for the input, used to resolve `Column` references -/// to qualified or unqualified fields by name. +/// to qualified or unqualified fields by name. Note that for creating a lambda, this must be +/// scoped lambda schema, and not the outer schema pub fn create_physical_expr( e: &Expr, input_dfschema: &DFSchema, @@ -314,9 +316,28 @@ pub fn create_physical_expr( input_dfschema, execution_props, )?), + Expr::Lambda { .. } => { + exec_err!("Expr::Lambda should be handled by Expr::ScalarFunction, as it can only exist within it") + } Expr::ScalarFunction(ScalarFunction { func, args }) => { - let physical_args = - create_physical_exprs(args, input_dfschema, execution_props)?; + let lambdas_schemas = + func.arguments_schema_from_logical_args(args, input_dfschema)?; + + let physical_args = std::iter::zip(args, lambdas_schemas) + .map(|(expr, schema)| match expr { + Expr::Lambda(Lambda { params, body }) => { + Ok(Arc::new(LambdaExpr::new( + params.clone(), + create_physical_expr(body, &schema, execution_props)?, + )) as Arc) + } + expr => create_physical_expr(expr, &schema, execution_props), + }) + .collect::>>()?; + + //let physical_args = + // create_physical_exprs(args, input_dfschema, execution_props)?; + let config_options = match execution_props.config_options.as_ref() { Some(config_options) => Arc::clone(config_options), None => Arc::new(ConfigOptions::default()), diff --git a/datafusion/physical-expr/src/projection.rs b/datafusion/physical-expr/src/projection.rs index 
a120ab427e1de..70be717a8436c 100644 --- a/datafusion/physical-expr/src/projection.rs +++ b/datafusion/physical-expr/src/projection.rs @@ -20,11 +20,11 @@ use std::sync::Arc; use crate::expressions::Column; use crate::utils::collect_columns; -use crate::PhysicalExpr; +use crate::{PhysicalExpr, PhysicalExprExt}; use arrow::datatypes::{Field, Schema, SchemaRef}; use datafusion_common::stats::{ColumnStatistics, Precision}; -use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_common::tree_node::{Transformed, TransformedResult}; use datafusion_common::{internal_datafusion_err, internal_err, plan_err, Result}; use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; @@ -499,13 +499,16 @@ pub fn update_expr( let mut state = RewriteState::Unchanged; let new_expr = Arc::clone(expr) - .transform_up(|expr| { + .transform_up_with_lambdas_params(|expr, lambdas_params| { if state == RewriteState::RewrittenInvalid { return Ok(Transformed::no(expr)); } - let Some(column) = expr.as_any().downcast_ref::() else { - return Ok(Transformed::no(expr)); + let column = match expr.as_any().downcast_ref::() { + Some(column) if !lambdas_params.contains(column.name()) => column, + _ => { + return Ok(Transformed::no(expr)); + } }; if sync_with_child { state = RewriteState::RewrittenValid; @@ -616,14 +619,14 @@ impl ProjectionMapping { let mut map = IndexMap::<_, ProjectionTargets>::new(); for (expr_idx, (expr, name)) in expr.into_iter().enumerate() { let target_expr = Arc::new(Column::new(&name, expr_idx)) as _; - let source_expr = expr.transform_down(|e| match e.as_any().downcast_ref::() { + let source_expr = expr.transform_down_with_schema(input_schema, |e, schema| match e.as_any().downcast_ref::() { Some(col) => { - // Sometimes, an expression and its name in the input_schema + // Sometimes, an expression and its name in the schema // doesn't match. 
This can cause problems, so we make sure - // that the expression name matches with the name in `input_schema`. + // that the expression name matches with the name in `schema`. // Conceptually, `source_expr` and `expression` should be the same. let idx = col.index(); - let matching_field = input_schema.field(idx); + let matching_field = schema.field(idx); let matching_name = matching_field.name(); if col.name() != matching_name { return internal_err!( @@ -737,21 +740,25 @@ pub fn project_ordering( ) -> Option { let mut projected_exprs = vec![]; for PhysicalSortExpr { expr, options } in ordering.iter() { - let transformed = Arc::clone(expr).transform_up(|expr| { - let Some(col) = expr.as_any().downcast_ref::() else { - return Ok(Transformed::no(expr)); - }; + let transformed = + Arc::clone(expr).transform_up_with_lambdas_params(|expr, lambdas_params| { + let col = match expr.as_any().downcast_ref::() { + Some(col) if !lambdas_params.contains(col.name()) => col, + _ => { + return Ok(Transformed::no(expr)); + } + }; - let name = col.name(); - if let Some((idx, _)) = schema.column_with_name(name) { - // Compute the new column expression (with correct index) after projection: - Ok(Transformed::yes(Arc::new(Column::new(name, idx)))) - } else { - // Cannot find expression in the projected_schema, - // signal this using an Err result - plan_err!("") - } - }); + let name = col.name(); + if let Some((idx, _)) = schema.column_with_name(name) { + // Compute the new column expression (with correct index) after projection: + Ok(Transformed::yes(Arc::new(Column::new(name, idx)))) + } else { + // Cannot find expression in the projected_schema, + // signal this using an Err result + plan_err!("") + } + }); match transformed { Ok(transformed) => { diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index 743d5b99cde95..22fa300f05df4 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ 
b/datafusion/physical-expr/src/scalar_function.rs @@ -30,23 +30,25 @@ //! to a function that supports f64, it is coerced to f64. use std::any::Any; +use std::borrow::Cow; use std::fmt::{self, Debug, Formatter}; use std::hash::{Hash, Hasher}; use std::sync::Arc; -use crate::expressions::Literal; +use crate::expressions::{Column, LambdaExpr, Literal}; use crate::PhysicalExpr; -use arrow::array::{Array, RecordBatch}; -use arrow::datatypes::{DataType, FieldRef, Schema}; +use arrow::array::{Array, NullArray, RecordBatch}; +use arrow::datatypes::{DataType, Field, FieldRef, Schema}; use datafusion_common::config::{ConfigEntry, ConfigOptions}; -use datafusion_common::{internal_err, Result, ScalarValue}; +use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; +use datafusion_common::{internal_err, HashSet, Result, ScalarValue}; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::sort_properties::ExprProperties; use datafusion_expr::type_coercion::functions::data_types_with_scalar_udf; use datafusion_expr::{ - expr_vec_fmt, ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, - Volatility, + expr_vec_fmt, ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, + ScalarFunctionLambdaArg, ScalarUDF, ValueOrLambdaParameter, Volatility, }; /// Physical expression of a scalar function @@ -94,10 +96,16 @@ impl ScalarFunctionExpr { schema: &Schema, config_options: Arc, ) -> Result { - let name = fun.name().to_string(); - let arg_fields = args - .iter() - .map(|e| e.return_field(schema)) + let lambdas_schemas = lambdas_schemas_from_args(&fun, &args, schema)?; + + let arg_fields = std::iter::zip(&args, lambdas_schemas) + .map(|(e, schema)| { + if let Some(lambda) = e.as_any().downcast_ref::() { + lambda.body().return_field(&schema) + } else { + e.return_field(&schema) + } + }) .collect::>>()?; // verify that input data types is consistent with function's `TypeSignature` @@ -105,6 +113,7 @@ impl ScalarFunctionExpr { .iter() .map(|f| 
f.data_type().clone()) .collect::>(); + data_types_with_scalar_udf(&arg_types, &fun)?; let arguments = args @@ -115,11 +124,21 @@ impl ScalarFunctionExpr { .map(|literal| literal.value()) }) .collect::>(); + + let lambdas = args + .iter() + .map(|e| e.as_any().is::()) + .collect::>(); + let ret_args = ReturnFieldArgs { arg_fields: &arg_fields, scalar_arguments: &arguments, + lambdas: &lambdas, }; + let return_field = fun.return_field_from_args(ret_args)?; + let name = fun.name().to_string(); + Ok(Self { fun, name, @@ -260,7 +279,10 @@ impl PhysicalExpr for ScalarFunctionExpr { let args = self .args .iter() - .map(|e| e.evaluate(batch)) + .map(|e| match e.as_any().downcast_ref::() { + Some(_) => Ok(ColumnarValue::Scalar(ScalarValue::Null)), + None => Ok(e.evaluate(batch)?), + }) .collect::>>()?; let arg_fields = self @@ -274,6 +296,111 @@ impl PhysicalExpr for ScalarFunctionExpr { .iter() .all(|arg| matches!(arg, ColumnarValue::Scalar(_))); + let lambdas = if self.args.iter().any(|arg| arg.as_any().is::()) { + let args_metadata = std::iter::zip(&self.args, &arg_fields) + .map( + |(expr, field)| match expr.as_any().downcast_ref::() { + Some(lambda) => { + let mut captures = false; + + expr.apply_with_lambdas_params(|expr, lambdas_params| { + match expr.as_any().downcast_ref::() { + Some(col) if !lambdas_params.contains(col.name()) => { + captures = true; + + Ok(TreeNodeRecursion::Stop) + } + _ => Ok(TreeNodeRecursion::Continue), + } + }) + .unwrap(); + + ValueOrLambdaParameter::Lambda(lambda.params(), captures) + } + None => ValueOrLambdaParameter::Value(Arc::clone(field)), + }, + ) + .collect::>(); + + let params = self.fun().inner().lambdas_parameters(&args_metadata)?; + + let lambdas = std::iter::zip(&self.args, params) + .map(|(arg, lambda_params)| { + arg.as_any() + .downcast_ref::() + .map(|lambda| { + let mut indices = HashSet::new(); + + arg.apply_with_lambdas_params(|expr, lambdas_params| { + if let Some(column) = + expr.as_any().downcast_ref::() + { + if 
!lambdas_params.contains(column.name()) { + indices.insert( + column.index(), //batch + // .schema_ref() + // .index_of(column.name())?, + ); + } + } + + Ok(TreeNodeRecursion::Continue) + })?; + + //let mut indices = indices.into_iter().collect::>(); + + //indices.sort_unstable(); + + let params = + std::iter::zip(lambda.params(), lambda_params.unwrap()) + .map(|(name, param)| Arc::new(param.with_name(name))) + .collect(); + + let captures = if !indices.is_empty() { + let (fields, columns): (Vec<_>, _) = std::iter::zip( + batch.schema_ref().fields(), + batch.columns(), + ) + .enumerate() + .map(|(column_index, (field, column))| { + if indices.contains(&column_index) { + (Arc::clone(field), Arc::clone(column)) + } else { + ( + Arc::new(Field::new( + field.name(), + DataType::Null, + false, + )), + Arc::new(NullArray::new(column.len())) as _, + ) + } + }) + .unzip(); + + let schema = Arc::new(Schema::new(fields)); + + Some(RecordBatch::try_new(schema, columns)?) + //Some(batch.project(&indices)?) 
+ } else { + None + }; + + Ok(ScalarFunctionLambdaArg { + params, + body: Arc::clone(lambda.body()), + captures, + }) + }) + .transpose() + }) + .collect::>>()?; + + Some(lambdas) + } else { + None + }; + // evaluate the function let output = self.fun.invoke_with_args(ScalarFunctionArgs { args, @@ -281,6 +408,7 @@ impl PhysicalExpr for ScalarFunctionExpr { number_rows: batch.num_rows(), return_field: Arc::clone(&self.return_field), config_options: Arc::clone(&self.config_options), + lambdas, })?; if let ColumnarValue::Array(array) = &output { @@ -365,14 +493,377 @@ impl PhysicalExpr for ScalarFunctionExpr { } } +pub fn lambdas_schemas_from_args<'a>( + fun: &ScalarUDF, + args: &[Arc], + schema: &'a Schema, +) -> Result>> { + let args_metadata = args + .iter() + .map(|e| match e.as_any().downcast_ref::() { + Some(lambda) => { + let mut captures = false; + + e.apply_with_lambdas_params(|expr, lambdas_params| { + match expr.as_any().downcast_ref::() { + Some(col) if !lambdas_params.contains(col.name()) => { + captures = true; + + Ok(TreeNodeRecursion::Stop) + } + _ => Ok(TreeNodeRecursion::Continue), + } + }) + .unwrap(); + + Ok(ValueOrLambdaParameter::Lambda(lambda.params(), captures)) + } + None => Ok(ValueOrLambdaParameter::Value(e.return_field(schema)?)), + }) + .collect::>>()?; + + /*let captures = args + .iter() + .map(|arg| { + if arg.as_any().is::() { + let mut columns = HashSet::new(); + + arg.apply_with_lambdas_params(|n, lambdas_params| { + if let Some(column) = n.as_any().downcast_ref::() { + if !lambdas_params.contains(column.name()) { + columns.insert(schema.index_of(column.name())?); + } + // columns.insert(column.index()); + } + + Ok(TreeNodeRecursion::Continue) + })?; + + Ok(columns) + } else { + Ok(HashSet::new()) + } + }) + .collect::>>()?; */ + + fun.arguments_arrow_schema(&args_metadata, schema) +} + +pub trait PhysicalExprExt: Sized { + fn apply_with_lambdas_params< + 'n, + F: FnMut(&'n Self, &HashSet<&'n str>) -> Result, + >( + &'n self, + f: F, 
+ ) -> Result; + + fn apply_with_schema<'n, F: FnMut(&'n Self, &Schema) -> Result>( + &'n self, + schema: &Schema, + f: F, + ) -> Result; + + fn apply_children_with_schema< + 'n, + F: FnMut(&'n Self, &Schema) -> Result, + >( + &'n self, + schema: &Schema, + f: F, + ) -> Result; + + fn transform_down_with_schema Result>>( + self, + schema: &Schema, + f: F, + ) -> Result>; + + fn transform_up_with_schema Result>>( + self, + schema: &Schema, + f: F, + ) -> Result>; + + fn transform_with_schema Result>>( + self, + schema: &Schema, + f: F, + ) -> Result> { + self.transform_up_with_schema(schema, f) + } + + fn transform_down_with_lambdas_params( + self, + f: impl FnMut(Self, &HashSet) -> Result>, + ) -> Result>; + + fn transform_up_with_lambdas_params( + self, + f: impl FnMut(Self, &HashSet) -> Result>, + ) -> Result>; + + fn transform_with_lambdas_params( + self, + f: impl FnMut(Self, &HashSet) -> Result>, + ) -> Result> { + self.transform_up_with_lambdas_params(f) + } +} + +impl PhysicalExprExt for Arc { + fn apply_with_lambdas_params< + 'n, + F: FnMut(&'n Self, &HashSet<&'n str>) -> Result, + >( + &'n self, + mut f: F, + ) -> Result { + #[cfg_attr(feature = "recursive_protection", recursive::recursive)] + fn apply_with_lambdas_params_impl< + 'n, + F: FnMut( + &'n Arc, + &HashSet<&'n str>, + ) -> Result, + >( + node: &'n Arc, + args: &HashSet<&'n str>, + f: &mut F, + ) -> Result { + match node.as_any().downcast_ref::() { + Some(lambda) => { + let mut args = args.clone(); + + args.extend(lambda.params().iter().map(|v| v.as_str())); + + f(node, &args)?.visit_children(|| { + node.apply_children(|c| { + apply_with_lambdas_params_impl(c, &args, f) + }) + }) + } + _ => f(node, args)?.visit_children(|| { + node.apply_children(|c| apply_with_lambdas_params_impl(c, args, f)) + }), + } + } + + apply_with_lambdas_params_impl(self, &HashSet::new(), &mut f) + } + + fn apply_with_schema<'n, F: FnMut(&'n Self, &Schema) -> Result>( + &'n self, + schema: &Schema, + mut f: F, + ) -> 
Result { + #[cfg_attr(feature = "recursive_protection", recursive::recursive)] + fn apply_with_lambdas_impl< + 'n, + F: FnMut(&'n Arc, &Schema) -> Result, + >( + node: &'n Arc, + schema: &Schema, + f: &mut F, + ) -> Result { + f(node, schema)?.visit_children(|| { + node.apply_children_with_schema(schema, |c, schema| { + apply_with_lambdas_impl(c, schema, f) + }) + }) + } + + apply_with_lambdas_impl(self, schema, &mut f) + } + + fn apply_children_with_schema< + 'n, + F: FnMut(&'n Self, &Schema) -> Result, + >( + &'n self, + schema: &Schema, + mut f: F, + ) -> Result { + match self.as_any().downcast_ref::() { + Some(scalar_function) + if scalar_function + .args() + .iter() + .any(|arg| arg.as_any().is::()) => + { + let mut lambdas_schemas = lambdas_schemas_from_args( + scalar_function.fun(), + scalar_function.args(), + schema, + )? + .into_iter(); + + self.apply_children(|expr| f(expr, &lambdas_schemas.next().unwrap())) + } + _ => self.apply_children(|e| f(e, schema)), + } + } + + fn transform_down_with_schema< + F: FnMut(Self, &Schema) -> Result>, + >( + self, + schema: &Schema, + mut f: F, + ) -> Result> { + #[cfg_attr(feature = "recursive_protection", recursive::recursive)] + fn transform_down_with_schema_impl< + F: FnMut( + Arc, + &Schema, + ) -> Result>>, + >( + node: Arc, + schema: &Schema, + f: &mut F, + ) -> Result>> { + f(node, schema)?.transform_children(|node| { + map_children_with_schema(node, schema, |n, schema| { + transform_down_with_schema_impl(n, schema, f) + }) + }) + } + + transform_down_with_schema_impl(self, schema, &mut f) + } + + fn transform_up_with_schema Result>>( + self, + schema: &Schema, + mut f: F, + ) -> Result> { + #[cfg_attr(feature = "recursive_protection", recursive::recursive)] + fn transform_up_with_schema_impl< + F: FnMut( + Arc, + &Schema, + ) -> Result>>, + >( + node: Arc, + schema: &Schema, + f: &mut F, + ) -> Result>> { + map_children_with_schema(node, schema, |n, schema| { + transform_up_with_schema_impl(n, schema, f) + })? 
+ .transform_parent(|n| f(n, schema)) + } + + transform_up_with_schema_impl(self, schema, &mut f) + } + + fn transform_up_with_lambdas_params( + self, + mut f: impl FnMut(Self, &HashSet) -> Result>, + ) -> Result> { + #[cfg_attr(feature = "recursive_protection", recursive::recursive)] + fn transform_up_with_lambdas_params_impl< + F: FnMut( + Arc, + &HashSet, + ) -> Result>>, + >( + node: Arc, + params: &HashSet, + f: &mut F, + ) -> Result>> { + map_children_with_lambdas_params(node, params, |n, params| { + transform_up_with_lambdas_params_impl(n, params, f) + })? + .transform_parent(|n| f(n, params)) + } + + transform_up_with_lambdas_params_impl(self, &HashSet::new(), &mut f) + } + + fn transform_down_with_lambdas_params( + self, + mut f: impl FnMut(Self, &HashSet) -> Result>, + ) -> Result> { + #[cfg_attr(feature = "recursive_protection", recursive::recursive)] + fn transform_down_with_lambdas_params_impl< + F: FnMut( + Arc, + &HashSet, + ) -> Result>>, + >( + node: Arc, + params: &HashSet, + f: &mut F, + ) -> Result>> { + f(node, params)?.transform_children(|node| { + map_children_with_lambdas_params(node, params, |node, args| { + transform_down_with_lambdas_params_impl(node, args, f) + }) + }) + } + + transform_down_with_lambdas_params_impl(self, &HashSet::new(), &mut f) + } +} + +fn map_children_with_schema( + node: Arc, + schema: &Schema, + mut f: impl FnMut( + Arc, + &Schema, + ) -> Result>>, +) -> Result>> { + match node.as_any().downcast_ref::() { + Some(fun) if fun.args().iter().any(|arg| arg.as_any().is::()) => { + let mut args_schemas = + lambdas_schemas_from_args(fun.fun(), fun.args(), schema)?.into_iter(); + + node.map_children(|node| f(node, &args_schemas.next().unwrap())) + } + _ => node.map_children(|node| f(node, schema)), + } +} + +fn map_children_with_lambdas_params( + node: Arc, + params: &HashSet, + mut f: impl FnMut( + Arc, + &HashSet, + ) -> Result>>, +) -> Result>> { + match node.as_any().downcast_ref::() { + Some(lambda) => { + let mut 
params = params.clone(); + + params.extend(lambda.params().iter().cloned()); + + node.map_children(|node| f(node, ¶ms)) + } + None => node.map_children(|node| f(node, params)), + } +} + #[cfg(test)] mod tests { + use std::any::Any; + use std::{borrow::Cow, sync::Arc}; + use super::*; + use super::{lambdas_schemas_from_args, PhysicalExprExt}; use crate::expressions::Column; + use crate::{create_physical_expr, ScalarFunctionExpr}; use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_common::{tree_node::TreeNodeRecursion, DFSchema, HashSet, Result}; + use datafusion_expr::{ + col, expr::Lambda, Expr, ScalarFunctionArgs, ValueOrLambdaParameter, Volatility, + }; use datafusion_expr::{ScalarUDF, ScalarUDFImpl, Signature}; + use datafusion_expr_common::columnar_value::ColumnarValue; use datafusion_physical_expr_common::physical_expr::is_volatile; - use std::any::Any; + use datafusion_physical_expr_common::physical_expr::PhysicalExpr; /// Test helper to create a mock UDF with a specific volatility #[derive(Debug, PartialEq, Eq, Hash)] @@ -444,4 +935,190 @@ mod tests { let stable_arc: Arc = Arc::new(stable_expr); assert!(!is_volatile(&stable_arc)); } + + fn list_list_int() -> Schema { + Schema::new(vec![Field::new( + "v", + DataType::new_list(DataType::new_list(DataType::Int32, false), false), + false, + )]) + } + + fn list_int() -> Schema { + Schema::new(vec![Field::new( + "v", + DataType::new_list(DataType::Int32, false), + false, + )]) + } + + fn int() -> Schema { + Schema::new(vec![Field::new("v", DataType::Int32, false)]) + } + + fn array_transform_udf() -> ScalarUDF { + ScalarUDF::new_from_impl(ArrayTransformFunc::new()) + } + + fn args() -> Vec { + vec![ + col("v"), + Expr::Lambda(Lambda::new( + vec!["v".into()], + array_transform_udf().call(vec![ + col("v"), + Expr::Lambda(Lambda::new(vec!["v".into()], -col("v"))), + ]), + )), + ] + } + + // array_transform(v, |v| -> array_transform(v, |v| -> -v)) + fn array_transform() -> Arc { + let e = 
array_transform_udf().call(args()); + + create_physical_expr( + &e, + &DFSchema::try_from(list_list_int()).unwrap(), + &Default::default(), + ) + .unwrap() + } + + #[derive(Debug, PartialEq, Eq, Hash)] + struct ArrayTransformFunc { + signature: Signature, + } + + impl ArrayTransformFunc { + pub fn new() -> Self { + Self { + signature: Signature::any(2, Volatility::Immutable), + } + } + } + + impl ScalarUDFImpl for ArrayTransformFunc { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "array_transform" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(arg_types[0].clone()) + } + + fn lambdas_parameters( + &self, + args: &[ValueOrLambdaParameter], + ) -> Result>>> { + let ValueOrLambdaParameter::Value(value_field) = &args[0] else { + unimplemented!() + }; + let DataType::List(field) = value_field.data_type() else { + unimplemented!() + }; + + Ok(vec![ + None, + Some(vec![Field::new( + "", + field.data_type().clone(), + field.is_nullable(), + )]), + ]) + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + unimplemented!() + } + } + + #[test] + fn test_lambdas_schemas_from_args() { + let schema = list_list_int(); + let expr = array_transform(); + + let args = expr + .as_any() + .downcast_ref::() + .unwrap() + .args(); + + let schemas = + lambdas_schemas_from_args(&array_transform_udf(), args, &schema).unwrap(); + + assert_eq!(schemas, &[Cow::Borrowed(&schema), Cow::Owned(list_int())]); + } + + #[test] + fn test_apply_with_schema() { + let mut steps = vec![]; + + array_transform() + .apply_with_schema(&list_list_int(), |node, schema| { + steps.push((node.to_string(), schema.clone())); + + Ok(TreeNodeRecursion::Continue) + }) + .unwrap(); + + let expected = [ + ( + "array_transform(v@0, (v) -> array_transform(v@0, (v) -> (- v@0)))", + list_list_int(), + ), + ("(v) -> array_transform(v@0, (v) -> (- v@0))", list_int()), + 
("array_transform(v@0, (v) -> (- v@0))", list_int()), + ("(v) -> (- v@0)", int()), + ("(- v@0)", int()), + ("v@0", int()), + ("v@0", int()), + ("v@0", int()), + ] + .map(|(a, b)| (String::from(a), b)); + + assert_eq!(steps, expected); + } + + #[test] + fn test_apply_with_lambdas_params() { + let array_transform = array_transform(); + let mut steps = vec![]; + + array_transform + .apply_with_lambdas_params(|node, params| { + steps.push((node.to_string(), params.clone())); + + Ok(TreeNodeRecursion::Continue) + }) + .unwrap(); + + let expected = [ + ( + "array_transform(v@0, (v) -> array_transform(v@0, (v) -> (- v@0)))", + HashSet::from(["v"]), + ), + ( + "(v) -> array_transform(v@0, (v) -> (- v@0))", + HashSet::from(["v"]), + ), + ("array_transform(v@0, (v) -> (- v@0))", HashSet::from(["v"])), + ("(v) -> (- v@0)", HashSet::from(["v"])), + ("(- v@0)", HashSet::from(["v"])), + ("v@0", HashSet::from(["v"])), + ("v@0", HashSet::from(["v"])), + ("v@0", HashSet::from(["v"])), + ] + .map(|(a, b)| (String::from(a), b)); + + assert_eq!(steps, expected); + } } diff --git a/datafusion/physical-expr/src/simplifier/mod.rs b/datafusion/physical-expr/src/simplifier/mod.rs index 80d6ee0a7b914..dd7e6e314672f 100644 --- a/datafusion/physical-expr/src/simplifier/mod.rs +++ b/datafusion/physical-expr/src/simplifier/mod.rs @@ -19,12 +19,12 @@ use arrow::datatypes::Schema; use datafusion_common::{ - tree_node::{Transformed, TreeNode, TreeNodeRewriter}, + tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter}, Result, }; use std::sync::Arc; -use crate::PhysicalExpr; +use crate::{PhysicalExpr, PhysicalExprExt}; pub mod unwrap_cast; @@ -48,6 +48,22 @@ impl<'a> PhysicalExprSimplifier<'a> { &mut self, expr: Arc, ) -> Result> { + return expr + .transform_up_with_schema(self.schema, |node, schema| { + // Apply unwrap cast optimization + #[cfg(test)] + let original_type = node.data_type(schema).unwrap(); + let unwrapped = unwrap_cast::unwrap_cast_in_comparison(node, schema)?; + 
#[cfg(test)] + assert_eq!( + unwrapped.data.data_type(schema).unwrap(), + original_type, + "Simplified expression should have the same data type as the original" + ); + Ok(unwrapped) + }) + .data(); + Ok(expr.rewrite(self)?.data) } } diff --git a/datafusion/physical-expr/src/simplifier/unwrap_cast.rs b/datafusion/physical-expr/src/simplifier/unwrap_cast.rs index d409ce9cb5bf2..1ccfc1cfe84d8 100644 --- a/datafusion/physical-expr/src/simplifier/unwrap_cast.rs +++ b/datafusion/physical-expr/src/simplifier/unwrap_cast.rs @@ -34,22 +34,22 @@ use std::sync::Arc; use arrow::datatypes::{DataType, Schema}; -use datafusion_common::{ - tree_node::{Transformed, TreeNode}, - Result, ScalarValue, -}; +use datafusion_common::{tree_node::Transformed, Result, ScalarValue}; use datafusion_expr::Operator; use datafusion_expr_common::casts::try_cast_literal_to_type; -use crate::expressions::{lit, BinaryExpr, CastExpr, Literal, TryCastExpr}; use crate::PhysicalExpr; +use crate::{ + expressions::{lit, BinaryExpr, CastExpr, Literal, TryCastExpr}, + PhysicalExprExt, +}; /// Attempts to unwrap casts in comparison expressions. pub(crate) fn unwrap_cast_in_comparison( expr: Arc, schema: &Schema, ) -> Result>> { - expr.transform_down(|e| { + expr.transform_down_with_schema(schema, |e, schema| { if let Some(binary) = e.as_any().downcast_ref::() { if let Some(unwrapped) = try_unwrap_cast_binary(binary, schema)? 
{ return Ok(Transformed::yes(unwrapped)); diff --git a/datafusion/physical-expr/src/utils/mod.rs b/datafusion/physical-expr/src/utils/mod.rs index 745ae855efee2..92ecbb7176dc9 100644 --- a/datafusion/physical-expr/src/utils/mod.rs +++ b/datafusion/physical-expr/src/utils/mod.rs @@ -22,6 +22,7 @@ use std::borrow::Borrow; use std::sync::Arc; use crate::expressions::{BinaryExpr, Column}; +use crate::scalar_function::PhysicalExprExt; use crate::tree_node::ExprContext; use crate::PhysicalExpr; use crate::PhysicalSortExpr; @@ -227,9 +228,11 @@ where /// Recursively extract referenced [`Column`]s within a [`PhysicalExpr`]. pub fn collect_columns(expr: &Arc) -> HashSet { let mut columns = HashSet::::new(); - expr.apply(|expr| { + expr.apply_with_lambdas_params(|expr, lambdas_params| { if let Some(column) = expr.as_any().downcast_ref::() { - columns.get_or_insert_owned(column); + if !lambdas_params.contains(column.name()) { + columns.get_or_insert_owned(column); + } } Ok(TreeNodeRecursion::Continue) }) @@ -251,14 +254,16 @@ pub fn reassign_expr_columns( expr: Arc, schema: &Schema, ) -> Result> { - expr.transform_down(|expr| { + expr.transform_down_with_lambdas_params(|expr, lambdas_params| { if let Some(column) = expr.as_any().downcast_ref::() { - let index = schema.index_of(column.name())?; + if !lambdas_params.contains(column.name()) { + let index = schema.index_of(column.name())?; - return Ok(Transformed::yes(Arc::new(Column::new( - column.name(), - index, - )))); + return Ok(Transformed::yes(Arc::new(Column::new( + column.name(), + index, + )))); + } } Ok(Transformed::no(expr)) }) diff --git a/datafusion/physical-optimizer/src/enforce_sorting/sort_pushdown.rs b/datafusion/physical-optimizer/src/enforce_sorting/sort_pushdown.rs index 6e4e784866129..d87e001946414 100644 --- a/datafusion/physical-optimizer/src/enforce_sorting/sort_pushdown.rs +++ b/datafusion/physical-optimizer/src/enforce_sorting/sort_pushdown.rs @@ -29,7 +29,7 @@ use datafusion_expr::JoinType; use 
datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::utils::collect_columns; use datafusion_physical_expr::{ - add_offset_to_physical_sort_exprs, EquivalenceProperties, + add_offset_to_physical_sort_exprs, EquivalenceProperties, PhysicalExprExt, }; use datafusion_physical_expr_common::sort_expr::{ LexOrdering, LexRequirement, OrderingRequirements, PhysicalSortExpr, @@ -661,20 +661,21 @@ fn handle_custom_pushdown( .into_iter() .map(|req| { let child_schema = plan.children()[maintained_child_idx].schema(); - let updated_columns = req - .expr - .transform_up(|expr| { - if let Some(col) = expr.as_any().downcast_ref::() { - let new_index = col.index() - sub_offset; - Ok(Transformed::yes(Arc::new(Column::new( - child_schema.field(new_index).name(), - new_index, - )))) - } else { - Ok(Transformed::no(expr)) - } - })? - .data; + let updated_columns = + req.expr + .transform_up_with_lambdas_params(|expr, lambdas_params| { + match expr.as_any().downcast_ref::() { + Some(col) if !lambdas_params.contains(col.name()) => { + let new_index = col.index() - sub_offset; + Ok(Transformed::yes(Arc::new(Column::new( + child_schema.field(new_index).name(), + new_index, + )))) + } + _ => Ok(Transformed::no(expr)), + } + })? + .data; Ok(PhysicalSortRequirement::new(updated_columns, req.options)) }) .collect::>>()?; @@ -742,20 +743,21 @@ fn handle_hash_join( .into_iter() .map(|req| { let child_schema = plan.children()[1].schema(); - let updated_columns = req - .expr - .transform_up(|expr| { - if let Some(col) = expr.as_any().downcast_ref::() { - let index = projected_indices[col.index()].index; - Ok(Transformed::yes(Arc::new(Column::new( - child_schema.field(index).name(), - index, - )))) - } else { - Ok(Transformed::no(expr)) - } - })? 
- .data; + let updated_columns = + req.expr + .transform_up_with_lambdas_params(|expr, lambdas_params| { + match expr.as_any().downcast_ref::() { + Some(col) if !lambdas_params.contains(col.name()) => { + let index = projected_indices[col.index()].index; + Ok(Transformed::yes(Arc::new(Column::new( + child_schema.field(index).name(), + index, + )))) + } + _ => Ok(Transformed::no(expr)), + } + })? + .data; Ok(PhysicalSortRequirement::new(updated_columns, req.options)) }) .collect::>>()?; diff --git a/datafusion/physical-optimizer/src/projection_pushdown.rs b/datafusion/physical-optimizer/src/projection_pushdown.rs index 987e3cb6f713e..8ed81d3874d64 100644 --- a/datafusion/physical-optimizer/src/projection_pushdown.rs +++ b/datafusion/physical-optimizer/src/projection_pushdown.rs @@ -23,6 +23,7 @@ use crate::PhysicalOptimizerRule; use arrow::datatypes::{Fields, Schema, SchemaRef}; use datafusion_common::alias::AliasGenerator; +use datafusion_physical_expr::PhysicalExprExt; use std::collections::HashSet; use std::sync::Arc; @@ -243,9 +244,11 @@ fn minimize_join_filter( rhs_schema: &Schema, ) -> JoinFilter { let mut used_columns = HashSet::new(); - expr.apply(|expr| { + expr.apply_with_lambdas_params(|expr, lambdas_params| { if let Some(col) = expr.as_any().downcast_ref::() { - used_columns.insert(col.index()); + if !lambdas_params.contains(col.name()) { + used_columns.insert(col.index()); + } } Ok(TreeNodeRecursion::Continue) }) @@ -267,17 +270,19 @@ fn minimize_join_filter( .collect::(); let final_expr = expr - .transform_up(|expr| match expr.as_any().downcast_ref::() { - None => Ok(Transformed::no(expr)), - Some(column) => { - let new_idx = used_columns - .iter() - .filter(|idx| **idx < column.index()) - .count(); - let new_column = Column::new(column.name(), new_idx); - Ok(Transformed::yes( - Arc::new(new_column) as Arc - )) + .transform_up_with_lambdas_params(|expr, lambdas_params| { + match expr.as_any().downcast_ref::() { + Some(column) if 
!lambdas_params.contains(column.name()) => { + let new_idx = used_columns + .iter() + .filter(|idx| **idx < column.index()) + .count(); + let new_column = Column::new(column.name(), new_idx); + Ok(Transformed::yes( + Arc::new(new_column) as Arc + )) + } + _ => Ok(Transformed::no(expr)), } }) .expect("Closure cannot fail"); @@ -380,10 +385,9 @@ impl<'a> JoinFilterRewriter<'a> { // First, add a new projection. The expression must be rewritten, as it is no longer // executed against the filter schema. let new_idx = self.join_side_projections.len(); - let rewritten_expr = expr.transform_up(|expr| { + let rewritten_expr = expr.transform_up_with_lambdas_params(|expr, lambdas_params| { Ok(match expr.as_any().downcast_ref::() { - None => Transformed::no(expr), - Some(column) => { + Some(column) if !lambdas_params.contains(column.name()) => { let intermediate_column = &self.intermediate_column_indices[column.index()]; assert_eq!(intermediate_column.side, self.join_side); @@ -393,6 +397,7 @@ impl<'a> JoinFilterRewriter<'a> { let new_column = Column::new(field.name(), join_side_index); Transformed::yes(Arc::new(new_column) as Arc) } + _ => Transformed::no(expr), }) })?; self.join_side_projections.push((rewritten_expr.data, name)); @@ -415,15 +420,17 @@ impl<'a> JoinFilterRewriter<'a> { join_side: JoinSide, ) -> Result { let mut result = false; - expr.apply(|expr| match expr.as_any().downcast_ref::() { - None => Ok(TreeNodeRecursion::Continue), - Some(c) => { - let column_index = &self.intermediate_column_indices[c.index()]; - if column_index.side == join_side { - result = true; - return Ok(TreeNodeRecursion::Stop); + expr.apply_with_lambdas_params(|expr, lambdas_params| { + match expr.as_any().downcast_ref::() { + Some(c) if !lambdas_params.contains(c.name()) => { + let column_index = &self.intermediate_column_indices[c.index()]; + if column_index.side == join_side { + result = true; + return Ok(TreeNodeRecursion::Stop); + } + Ok(TreeNodeRecursion::Continue) } - 
Ok(TreeNodeRecursion::Continue) + _ => Ok(TreeNodeRecursion::Continue), } })?; diff --git a/datafusion/physical-plan/src/async_func.rs b/datafusion/physical-plan/src/async_func.rs index 54a76e0ebb971..be72a6af2b509 100644 --- a/datafusion/physical-plan/src/async_func.rs +++ b/datafusion/physical-plan/src/async_func.rs @@ -22,13 +22,13 @@ use crate::{ }; use arrow::array::RecordBatch; use arrow_schema::{Fields, Schema, SchemaRef}; -use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; +use datafusion_common::tree_node::{Transformed, TreeNodeRecursion}; use datafusion_common::{internal_err, Result}; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use datafusion_physical_expr::async_scalar_function::AsyncFuncExpr; use datafusion_physical_expr::equivalence::ProjectionMapping; use datafusion_physical_expr::expressions::Column; -use datafusion_physical_expr::ScalarFunctionExpr; +use datafusion_physical_expr::{PhysicalExprExt, ScalarFunctionExpr}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use futures::stream::StreamExt; use log::trace; @@ -249,7 +249,7 @@ impl AsyncMapper { schema: &Schema, ) -> Result<()> { // recursively look for references to async functions - physical_expr.apply(|expr| { + physical_expr.apply_with_schema(schema, |expr, schema| { if let Some(scalar_func_expr) = expr.as_any().downcast_ref::() { diff --git a/datafusion/physical-plan/src/joins/stream_join_utils.rs b/datafusion/physical-plan/src/joins/stream_join_utils.rs index 80221a77992ce..b70a8f60508a5 100644 --- a/datafusion/physical-plan/src/joins/stream_join_utils.rs +++ b/datafusion/physical-plan/src/joins/stream_join_utils.rs @@ -35,7 +35,7 @@ use arrow::array::{ }; use arrow::compute::concat_batches; use arrow::datatypes::{ArrowNativeType, Schema, SchemaRef}; -use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_common::tree_node::{Transformed, TransformedResult}; use 
datafusion_common::utils::memory::estimate_memory_size; use datafusion_common::{ arrow_datafusion_err, DataFusionError, HashSet, JoinSide, Result, ScalarValue, @@ -44,7 +44,7 @@ use datafusion_expr::interval_arithmetic::Interval; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::intervals::cp_solver::ExprIntervalGraph; use datafusion_physical_expr::utils::collect_columns; -use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; +use datafusion_physical_expr::{PhysicalExpr, PhysicalExprExt, PhysicalSortExpr}; use datafusion_physical_expr_common::sort_expr::LexOrdering; use hashbrown::HashTable; @@ -312,13 +312,13 @@ pub fn convert_sort_expr_with_filter_schema( // Since we are sure that one to one column mapping includes all columns, we convert // the sort expression into a filter expression. let converted_filter_expr = expr - .transform_up(|p| { - convert_filter_columns(p.as_ref(), &column_map).map(|transformed| { - match transformed { + .transform_up_with_lambdas_params(|p, lambdas_params| { + convert_filter_columns(p.as_ref(), &column_map, lambdas_params).map( + |transformed| match transformed { Some(transformed) => Transformed::yes(transformed), None => Transformed::no(p), - } - }) + }, + ) }) .data()?; // Search the converted `PhysicalExpr` in filter expression; if an exact @@ -361,14 +361,17 @@ pub fn build_filter_input_order( fn convert_filter_columns( input: &dyn PhysicalExpr, column_map: &HashMap, + lambdas_params: &HashSet, ) -> Result>> { // Attempt to downcast the input expression to a Column type. - Ok(if let Some(col) = input.as_any().downcast_ref::() { - // If the downcast is successful, retrieve the corresponding filter column. - column_map.get(col).map(|c| Arc::new(c.clone()) as _) - } else { - // If the downcast fails, return the input expression as is. 
- None + Ok(match input.as_any().downcast_ref::() { + Some(col) if !lambdas_params.contains(col.name()) => { + column_map.get(col).map(|c| Arc::new(c.clone()) as _) + } + _ => { + // If the downcast fails, return the input expression as is. + None + } }) } diff --git a/datafusion/physical-plan/src/projection.rs b/datafusion/physical-plan/src/projection.rs index ead2196860cde..ab654e4eee1df 100644 --- a/datafusion/physical-plan/src/projection.rs +++ b/datafusion/physical-plan/src/projection.rs @@ -42,14 +42,13 @@ use std::task::{Context, Poll}; use arrow::datatypes::SchemaRef; use arrow::record_batch::{RecordBatch, RecordBatchOptions}; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{ - Transformed, TransformedResult, TreeNode, TreeNodeRecursion, -}; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNodeRecursion}; use datafusion_common::{internal_err, JoinSide, Result}; use datafusion_execution::TaskContext; use datafusion_physical_expr::equivalence::ProjectionMapping; use datafusion_physical_expr::utils::collect_columns; -use datafusion_physical_expr_common::physical_expr::{fmt_sql, PhysicalExprRef}; +use datafusion_physical_expr::{PhysicalExprExt, PhysicalExprRef}; +use datafusion_physical_expr_common::physical_expr::fmt_sql; use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexRequirement}; // Re-exported from datafusion-physical-expr for backwards compatibility // We recommend updating your imports to use datafusion-physical-expr directly @@ -866,10 +865,12 @@ fn try_unifying_projections( projection.expr().iter().for_each(|proj_expr| { proj_expr .expr - .apply(|expr| { + .apply_with_lambdas_params(|expr, lambdas_params| { Ok({ if let Some(column) = expr.as_any().downcast_ref::() { - *column_ref_map.entry(column.clone()).or_default() += 1; + if !lambdas_params.contains(column.name()) { + *column_ref_map.entry(column.clone()).or_default() += 1; + } } TreeNodeRecursion::Continue }) @@ -957,31 
+958,31 @@ fn new_columns_for_join_on( .filter_map(|on| { // Rewrite all columns in `on` Arc::clone(*on) - .transform(|expr| { - if let Some(column) = expr.as_any().downcast_ref::() { - // Find the column in the projection expressions - let new_column = projection_exprs - .iter() - .enumerate() - .find(|(_, (proj_column, _))| { - column.name() == proj_column.name() - && column.index() + column_index_offset - == proj_column.index() - }) - .map(|(index, (_, alias))| Column::new(alias, index)); - if let Some(new_column) = new_column { - Ok(Transformed::yes(Arc::new(new_column))) - } else { - // If the column is not found in the projection expressions, - // it means that the column is not projected. In this case, - // we cannot push the projection down. - internal_err!( - "Column {:?} not found in projection expressions", - column - ) + .transform_with_lambdas_params(|expr, lambdas_params| { + match expr.as_any().downcast_ref::() { + Some(column) if !lambdas_params.contains(column.name()) => { + let new_column = projection_exprs + .iter() + .enumerate() + .find(|(_, (proj_column, _))| { + column.name() == proj_column.name() + && column.index() + column_index_offset + == proj_column.index() + }) + .map(|(index, (_, alias))| Column::new(alias, index)); + if let Some(new_column) = new_column { + Ok(Transformed::yes(Arc::new(new_column))) + } else { + // If the column is not found in the projection expressions, + // it means that the column is not projected. In this case, + // we cannot push the projection down. 
+ internal_err!( + "Column {:?} not found in projection expressions", + column + ) + } } - } else { - Ok(Transformed::no(expr)) + _ => Ok(Transformed::no(expr)), } }) .data() diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 2774b5b6ba7c3..b87a50b3f5281 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -622,6 +622,11 @@ pub fn serialize_expr( .unwrap_or(HashMap::new()), })), }, + Expr::Lambda { .. } => { + return Err(Error::General( + "Proto serialization error: Lambda not supported".to_string(), + )) + } }; Ok(expr_node) diff --git a/datafusion/pruning/src/pruning_predicate.rs b/datafusion/pruning/src/pruning_predicate.rs index 380ada10df6e1..c9df93f8b693c 100644 --- a/datafusion/pruning/src/pruning_predicate.rs +++ b/datafusion/pruning/src/pruning_predicate.rs @@ -38,13 +38,14 @@ use datafusion_common::error::Result; use datafusion_common::tree_node::TransformedResult; use datafusion_common::{ internal_datafusion_err, internal_err, plan_datafusion_err, plan_err, - tree_node::{Transformed, TreeNode}, - ScalarValue, + tree_node::Transformed, ScalarValue, }; use datafusion_common::{Column, DFSchema}; use datafusion_expr_common::operator::Operator; use datafusion_physical_expr::utils::{collect_columns, Guarantee, LiteralGuarantee}; -use datafusion_physical_expr::{expressions as phys_expr, PhysicalExprRef}; +use datafusion_physical_expr::{ + expressions as phys_expr, PhysicalExprExt, PhysicalExprRef, +}; use datafusion_physical_expr_common::physical_expr::snapshot_physical_expr; use datafusion_physical_plan::{ColumnarValue, PhysicalExpr}; @@ -1204,9 +1205,9 @@ fn rewrite_column_expr( column_old: &phys_expr::Column, column_new: &phys_expr::Column, ) -> Result> { - e.transform(|expr| { + e.transform_with_lambdas_params(|expr, lambdas_params| { if let Some(column) = expr.as_any().downcast_ref::() { - if column == column_old { + if 
!lambdas_params.contains(column.name()) && column == column_old { return Ok(Transformed::yes(Arc::new(column_new.clone()))); } } diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 50e479af36204..c13fd33104eb0 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -22,10 +22,10 @@ use datafusion_common::{ internal_datafusion_err, internal_err, not_impl_err, plan_datafusion_err, plan_err, DFSchema, Dependency, Diagnostic, Result, Span, }; -use datafusion_expr::expr::{ - NullTreatment, ScalarFunction, Unnest, WildcardOptions, WindowFunction, -}; -use datafusion_expr::planner::{PlannerResult, RawAggregateExpr, RawWindowExpr}; +use datafusion_expr::expr::{Lambda, ScalarFunction, Unnest}; +use datafusion_expr::expr::{NullTreatment, WildcardOptions, WindowFunction}; +use datafusion_expr::planner::PlannerResult; +use datafusion_expr::planner::{RawAggregateExpr, RawWindowExpr}; use datafusion_expr::{expr, Expr, ExprSchemable, WindowFrame, WindowFunctionDefinition}; use sqlparser::ast::{ DuplicateTreatment, Expr as SQLExpr, Function as SQLFunction, FunctionArg, @@ -724,6 +724,26 @@ impl SqlToRel<'_, S> { let arg_name = crate::utils::normalize_ident(name); Ok((expr, Some(arg_name))) } + FunctionArg::Unnamed(FunctionArgExpr::Expr(SQLExpr::Lambda( + sqlparser::ast::LambdaFunction { params, body }, + ))) => { + let params = params + .into_iter() + .map(|v| v.to_string()) + .collect::>(); + + Ok(( + Expr::Lambda(Lambda { + params: params.clone(), + body: Box::new(self.sql_expr_to_logical_expr( + *body, + schema, + &mut planner_context.clone().with_lambda_parameters(params), + )?), + }), + None, + )) + } FunctionArg::Unnamed(FunctionArgExpr::Expr(arg)) => { let expr = self.sql_expr_to_logical_expr(arg, schema, planner_context)?; Ok((expr, None)) diff --git a/datafusion/sql/src/expr/identifier.rs b/datafusion/sql/src/expr/identifier.rs index 3c57d195ade67..dc39cb4de055d 100644 --- 
a/datafusion/sql/src/expr/identifier.rs +++ b/datafusion/sql/src/expr/identifier.rs @@ -53,6 +53,19 @@ impl SqlToRel<'_, S> { // identifier. (e.g. it is "foo.bar" not foo.bar) let normalize_ident = self.ident_normalizer.normalize(id); + if planner_context + .lambdas_parameters() + .contains(&normalize_ident) + { + let mut column = Column::new_unqualified(normalize_ident); + if self.options.collect_spans { + if let Some(span) = Span::try_from_sqlparser_span(id_span) { + column.spans_mut().add_span(span); + } + } + return Ok(Expr::Column(column)); + } + // Check for qualified field with unqualified name if let Ok((qualifier, _)) = schema.qualified_field_with_unqualified_name(normalize_ident.as_str()) diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index 7bac0337672dc..2992378fd1d6c 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -28,13 +28,11 @@ use datafusion_common::datatype::{DataTypeExt, FieldExt}; use datafusion_common::error::add_possible_columns_to_diag; use datafusion_common::TableReference; use datafusion_common::{ - field_not_found, internal_err, plan_datafusion_err, DFSchemaRef, Diagnostic, - SchemaError, + field_not_found, plan_datafusion_err, DFSchemaRef, Diagnostic, HashSet, SchemaError, }; use datafusion_common::{not_impl_err, plan_err, DFSchema, DataFusionError, Result}; use datafusion_expr::logical_plan::{LogicalPlan, LogicalPlanBuilder}; pub use datafusion_expr::planner::ContextProvider; -use datafusion_expr::utils::find_column_exprs; use datafusion_expr::{col, Expr}; use sqlparser::ast::{ArrayElemTypeDef, ExactNumberInfo, TimezoneInfo}; use sqlparser::ast::{ColumnDef as SQLColumnDef, ColumnOption}; @@ -267,6 +265,8 @@ pub struct PlannerContext { outer_from_schema: Option, /// The query schema defined by the table create_table_schema: Option, + /// The lambda introduced columns names + lambdas_parameters: HashSet, } impl Default for PlannerContext { @@ -284,6 +284,7 @@ impl PlannerContext 
{ outer_query_schema: None, outer_from_schema: None, create_table_schema: None, + lambdas_parameters: HashSet::new(), } } @@ -370,6 +371,19 @@ impl PlannerContext { self.ctes.get(cte_name).map(|cte| cte.as_ref()) } + pub fn lambdas_parameters(&self) -> &HashSet { + &self.lambdas_parameters + } + + pub fn with_lambda_parameters( + mut self, + arguments: impl IntoIterator, + ) -> Self { + self.lambdas_parameters.extend(arguments); + + self + } + /// Remove the plan of CTE / Subquery for the specified name pub(super) fn remove_cte(&mut self, cte_name: &str) { self.ctes.remove(cte_name); @@ -531,10 +545,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { schema: &DFSchema, exprs: &[Expr], ) -> Result<()> { - find_column_exprs(exprs) + exprs .iter() - .try_for_each(|col| match col { - Expr::Column(col) => match &col.relation { + .flat_map(|expr| expr.column_refs()) + .try_for_each(|col| { + match &col.relation { Some(r) => schema.field_with_qualified_name(r, &col.name).map(|_| ()), None => { if !schema.fields_with_unqualified_name(&col.name).is_empty() { @@ -584,8 +599,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { err.with_diagnostic(diagnostic) } _ => err, - }), - _ => internal_err!("Not a column"), + }) }) } diff --git a/datafusion/sql/src/select.rs b/datafusion/sql/src/select.rs index 42013a76a8657..0e7490d2c780b 100644 --- a/datafusion/sql/src/select.rs +++ b/datafusion/sql/src/select.rs @@ -540,9 +540,11 @@ impl SqlToRel<'_, S> { None => { let mut columns = HashSet::new(); for expr in &aggr_expr { - expr.apply(|expr| { + expr.apply_with_lambdas_params(|expr, lambdas_params| { if let Expr::Column(c) = expr { - columns.insert(Expr::Column(c.clone())); + if !c.is_lambda_parameter(lambdas_params) { + columns.insert(Expr::Column(c.clone())); + } } Ok(TreeNodeRecursion::Continue) }) diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 97f2b58bf8402..67ca92bb1c1f1 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ 
b/datafusion/sql/src/unparser/expr.rs @@ -15,13 +15,14 @@ // specific language governing permissions and limitations // under the License. -use datafusion_expr::expr::{AggregateFunctionParams, Unnest, WindowFunctionParams}; +use datafusion_expr::expr::{AggregateFunctionParams, WindowFunctionParams}; +use datafusion_expr::expr::{Lambda, Unnest}; use sqlparser::ast::Value::SingleQuotedString; use sqlparser::ast::{ - self, Array, BinaryOperator, CaseWhen, DuplicateTreatment, Expr as AstExpr, Function, - Ident, Interval, ObjectName, OrderByOptions, Subscript, TimezoneInfo, UnaryOperator, - ValueWithSpan, + self, Array, BinaryOperator, Expr as AstExpr, Function, Ident, Interval, + LambdaFunction, ObjectName, Subscript, TimezoneInfo, UnaryOperator, }; +use sqlparser::ast::{CaseWhen, DuplicateTreatment, OrderByOptions, ValueWithSpan}; use std::sync::Arc; use std::vec; @@ -527,6 +528,14 @@ impl Unparser<'_> { } Expr::OuterReferenceColumn(_, col) => self.col_to_sql(col), Expr::Unnest(unnest) => self.unnest_to_sql(unnest), + Expr::Lambda(Lambda { params, body }) => { + Ok(ast::Expr::Lambda(LambdaFunction { + params: ast::OneOrManyWithParens::Many( + params.iter().map(|param| param.as_str().into()).collect(), + ), + body: Box::new(self.expr_to_sql_inner(body)?), + })) + } } } diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index e7535338b7677..c218ce547b312 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -40,7 +40,7 @@ use crate::unparser::{ast::UnnestRelationBuilder, rewrite::rewrite_qualify}; use crate::utils::UNNEST_PLACEHOLDER; use datafusion_common::{ internal_err, not_impl_err, - tree_node::{TransformedResult, TreeNode}, + tree_node::TransformedResult, Column, DataFusionError, Result, ScalarValue, TableReference, }; use datafusion_expr::expr::OUTER_REFERENCE_COLUMN_PREFIX; @@ -1131,7 +1131,7 @@ impl Unparser<'_> { .cloned() .map(|expr| { if let Some(ref mut rewriter) = 
filter_alias_rewriter { - expr.rewrite(rewriter).data() + expr.rewrite_with_lambdas_params(rewriter).data() } else { Ok(expr) } @@ -1197,7 +1197,7 @@ impl Unparser<'_> { .cloned() .map(|expr| { if let Some(ref mut rewriter) = alias_rewriter { - expr.rewrite(rewriter).data() + expr.rewrite_with_lambdas_params(rewriter).data() } else { Ok(expr) } diff --git a/datafusion/sql/src/unparser/rewrite.rs b/datafusion/sql/src/unparser/rewrite.rs index c961f1d6f1f0c..58f4435095517 100644 --- a/datafusion/sql/src/unparser/rewrite.rs +++ b/datafusion/sql/src/unparser/rewrite.rs @@ -20,10 +20,11 @@ use std::{collections::HashSet, sync::Arc}; use arrow::datatypes::Schema; use datafusion_common::tree_node::TreeNodeContainer; use datafusion_common::{ - tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter}, + tree_node::{Transformed, TransformedResult, TreeNode}, Column, HashMap, Result, TableReference, }; use datafusion_expr::expr::{Alias, UNNEST_COLUMN_PREFIX}; +use datafusion_expr::tree_node::TreeNodeRewriterWithPayload; use datafusion_expr::{Expr, LogicalPlan, Projection, Sort, SortExpr}; use sqlparser::ast::Ident; @@ -466,12 +467,17 @@ pub struct TableAliasRewriter<'a> { pub alias_name: TableReference, } -impl TreeNodeRewriter for TableAliasRewriter<'_> { +impl TreeNodeRewriterWithPayload for TableAliasRewriter<'_> { type Node = Expr; + type Payload<'a> = &'a datafusion_common::HashSet; - fn f_down(&mut self, expr: Expr) -> Result> { + fn f_down( + &mut self, + expr: Expr, + lambdas_params: &datafusion_common::HashSet, + ) -> Result> { match expr { - Expr::Column(column) => { + Expr::Column(column) if !column.is_lambda_parameter(lambdas_params) => { if let Ok(field) = self.table_schema.field_with_name(&column.name) { let new_column = Column::new(Some(self.alias_name.clone()), field.name().clone()); diff --git a/datafusion/sql/src/unparser/utils.rs b/datafusion/sql/src/unparser/utils.rs index 8b3791017a8af..f785f640dbcee 100644 --- 
a/datafusion/sql/src/unparser/utils.rs +++ b/datafusion/sql/src/unparser/utils.rs @@ -161,11 +161,11 @@ pub(crate) fn find_window_nodes_within_select<'a>( /// For example, if expr contains the column expr "__unnest_placeholder(make_array(Int64(1),Int64(2),Int64(2),Int64(5),NULL),depth=1)" /// it will be transformed into an actual unnest expression UNNEST([1, 2, 2, 5, NULL]) pub(crate) fn unproject_unnest_expr(expr: Expr, unnest: &Unnest) -> Result { - expr.transform(|sub_expr| { + expr.transform_up_with_lambdas_params(|sub_expr, lambdas_params| { if let Expr::Column(col_ref) = &sub_expr { // Check if the column is among the columns to run unnest on. // Currently, only List/Array columns (defined in `list_type_columns`) are supported for unnesting. - if unnest.list_type_columns.iter().any(|e| e.1.output_column.name == col_ref.name) { + if !col_ref.is_lambda_parameter(lambdas_params) && unnest.list_type_columns.iter().any(|e| e.1.output_column.name == col_ref.name) { if let Ok(idx) = unnest.schema.index_of_column(col_ref) { if let LogicalPlan::Projection(Projection { expr, .. }) = unnest.input.as_ref() { if let Some(unprojected_expr) = expr.get(idx) { @@ -195,22 +195,21 @@ pub(crate) fn unproject_agg_exprs( agg: &Aggregate, windows: Option<&[&Window]>, ) -> Result { - expr.transform(|sub_expr| { - if let Expr::Column(c) = sub_expr { - if let Some(unprojected_expr) = find_agg_expr(agg, &c)? { - Ok(Transformed::yes(unprojected_expr.clone())) - } else if let Some(unprojected_expr) = - windows.and_then(|w| find_window_expr(w, &c.name).cloned()) - { - // Window function can contain an aggregation columns, e.g., 'avg(sum(ss_sales_price)) over ...' 
that needs to be unprojected - Ok(Transformed::yes(unproject_agg_exprs(unprojected_expr, agg, None)?)) - } else { - internal_err!( - "Tried to unproject agg expr for column '{}' that was not found in the provided Aggregate!", &c.name - ) - } - } else { - Ok(Transformed::no(sub_expr)) + expr.transform_up_with_lambdas_params(|sub_expr, lambdas_params| { + match sub_expr { + Expr::Column(c) if !c.is_lambda_parameter(lambdas_params) => if let Some(unprojected_expr) = find_agg_expr(agg, &c)? { + Ok(Transformed::yes(unprojected_expr.clone())) + } else if let Some(unprojected_expr) = + windows.and_then(|w| find_window_expr(w, &c.name).cloned()) + { + // Window function can contain an aggregation columns, e.g., 'avg(sum(ss_sales_price)) over ...' that needs to be unprojected + Ok(Transformed::yes(unproject_agg_exprs(unprojected_expr, agg, None)?)) + } else { + internal_err!( + "Tried to unproject agg expr for column '{}' that was not found in the provided Aggregate!", &c.name + ) + }, + _ => Ok(Transformed::no(sub_expr)), } }) .map(|e| e.data) @@ -222,16 +221,15 @@ pub(crate) fn unproject_agg_exprs( /// For example, if expr contains the column expr "COUNT(*) PARTITION BY id" it will be transformed /// into an actual window expression as identified in the window node. 
pub(crate) fn unproject_window_exprs(expr: Expr, windows: &[&Window]) -> Result { - expr.transform(|sub_expr| { - if let Expr::Column(c) = sub_expr { + expr.transform_up_with_lambdas_params(|sub_expr, lambdas_params| match sub_expr { + Expr::Column(c) if !c.is_lambda_parameter(lambdas_params) => { if let Some(unproj) = find_window_expr(windows, &c.name) { Ok(Transformed::yes(unproj.clone())) } else { Ok(Transformed::no(Expr::Column(c))) } - } else { - Ok(Transformed::no(sub_expr)) } + _ => Ok(Transformed::no(sub_expr)), }) .map(|e| e.data) } @@ -376,7 +374,7 @@ pub(crate) fn try_transform_to_simple_table_scan_with_filters( .cloned() .map(|expr| { if let Some(ref mut rewriter) = filter_alias_rewriter { - expr.rewrite(rewriter).data() + expr.rewrite_with_lambdas_params(rewriter).data() } else { Ok(expr) } diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs index 3c86d2d04905f..6380412e3b5ee 100644 --- a/datafusion/sql/src/utils.rs +++ b/datafusion/sql/src/utils.rs @@ -23,16 +23,16 @@ use arrow::datatypes::{ DataType, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, DECIMAL_DEFAULT_SCALE, }; use datafusion_common::tree_node::{ - Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter, + Transformed, TransformedResult, TreeNode, TreeNodeRecursion, }; use datafusion_common::{ - exec_datafusion_err, exec_err, internal_err, plan_err, Column, DFSchemaRef, - Diagnostic, HashMap, Result, ScalarValue, + exec_datafusion_err, exec_err, internal_err, plan_err, Column, DFSchema, Diagnostic, HashMap, Result, ScalarValue }; use datafusion_expr::builder::get_struct_unnested_columns; use datafusion_expr::expr::{ Alias, GroupingSet, Unnest, WindowFunction, WindowFunctionParams, }; +use datafusion_expr::tree_node::TreeNodeRewriterWithPayload; use datafusion_expr::utils::{expr_as_column_expr, find_column_exprs}; use datafusion_expr::{ col, expr_vec_fmt, ColumnUnnestList, Expr, ExprSchemable, LogicalPlan, @@ -44,9 +44,9 @@ use 
sqlparser::ast::{Ident, Value}; /// Make a best-effort attempt at resolving all columns in the expression tree pub(crate) fn resolve_columns(expr: &Expr, plan: &LogicalPlan) -> Result { expr.clone() - .transform_up(|nested_expr| { + .transform_up_with_lambdas_params(|nested_expr, lambdas_params| { match nested_expr { - Expr::Column(col) => { + Expr::Column(col) if !col.is_lambda_parameter(lambdas_params) => { let (qualifier, field) = plan.schema().qualified_field_from_column(&col)?; Ok(Transformed::yes(Expr::Column(Column::from(( @@ -81,6 +81,7 @@ pub(crate) fn rebase_expr( base_exprs: &[Expr], plan: &LogicalPlan, ) -> Result { + //todo user transform_down_with_lambdas_params expr.clone() .transform_down(|nested_expr| { if base_exprs.contains(&nested_expr) { @@ -231,8 +232,8 @@ pub(crate) fn resolve_aliases_to_exprs( expr: Expr, aliases: &HashMap, ) -> Result { - expr.transform_up(|nested_expr| match nested_expr { - Expr::Column(c) if c.relation.is_none() => { + expr.transform_up_with_lambdas_params(|nested_expr, lambdas_params| match nested_expr { + Expr::Column(c) if c.relation.is_none() && !c.is_lambda_parameter(lambdas_params) => { if let Some(aliased_expr) = aliases.get(&c.name) { Ok(Transformed::yes(aliased_expr.clone())) } else { @@ -371,7 +372,6 @@ This is only usedful when used with transform down up A full example of how the transformation works: */ struct RecursiveUnnestRewriter<'a> { - input_schema: &'a DFSchemaRef, root_expr: &'a Expr, // Useful to detect which child expr is a part of/ not a part of unnest operation top_most_unnest: Option, @@ -405,6 +405,7 @@ impl RecursiveUnnestRewriter<'_> { alias_name: String, expr_in_unnest: &Expr, struct_allowed: bool, + input_schema: &DFSchema, ) -> Result> { let inner_expr_name = expr_in_unnest.schema_name().to_string(); @@ -418,7 +419,7 @@ impl RecursiveUnnestRewriter<'_> { // column name as is, to comply with group by and order by let placeholder_column = Column::from_name(placeholder_name.clone()); - let 
(data_type, _) = expr_in_unnest.data_type_and_nullable(self.input_schema)?; + let (data_type, _) = expr_in_unnest.data_type_and_nullable(input_schema)?; match data_type { DataType::Struct(inner_fields) => { @@ -468,17 +469,18 @@ impl RecursiveUnnestRewriter<'_> { } } -impl TreeNodeRewriter for RecursiveUnnestRewriter<'_> { +impl TreeNodeRewriterWithPayload for RecursiveUnnestRewriter<'_> { type Node = Expr; + type Payload<'a> = &'a DFSchema; /// This downward traversal needs to keep track of: /// - Whether or not some unnest expr has been visited from the top util the current node /// - If some unnest expr has been visited, maintain a stack of such information, this /// is used to detect if some recursive unnest expr exists (e.g **unnest(unnest(unnest(3d column))))** - fn f_down(&mut self, expr: Expr) -> Result> { + fn f_down(&mut self, expr: Expr, input_schema: &DFSchema) -> Result> { if let Expr::Unnest(ref unnest_expr) = expr { let (data_type, _) = - unnest_expr.expr.data_type_and_nullable(self.input_schema)?; + unnest_expr.expr.data_type_and_nullable(input_schema)?; self.consecutive_unnest.push(Some(unnest_expr.clone())); // if expr inside unnest is a struct, do not consider // the next unnest as consecutive unnest (if any) @@ -532,7 +534,7 @@ impl TreeNodeRewriter for RecursiveUnnestRewriter<'_> { /// column2 /// ``` /// - fn f_up(&mut self, expr: Expr) -> Result> { + fn f_up(&mut self, expr: Expr, input_schema: &DFSchema) -> Result> { if let Expr::Unnest(ref traversing_unnest) = expr { if traversing_unnest == self.top_most_unnest.as_ref().unwrap() { self.top_most_unnest = None; @@ -568,6 +570,7 @@ impl TreeNodeRewriter for RecursiveUnnestRewriter<'_> { expr.schema_name().to_string(), inner_expr, struct_allowed, + input_schema, )?; if struct_allowed { self.transformed_root_exprs = Some(transformed_exprs.clone()); @@ -619,7 +622,6 @@ pub(crate) fn rewrite_recursive_unnest_bottom_up( original_expr: &Expr, ) -> Result> { let mut rewriter = RecursiveUnnestRewriter 
{ - input_schema: input.schema(), root_expr: original_expr, top_most_unnest: None, consecutive_unnest: vec![], @@ -641,7 +643,7 @@ pub(crate) fn rewrite_recursive_unnest_bottom_up( data: transformed_expr, transformed, tnr: _, - } = original_expr.clone().rewrite(&mut rewriter)?; + } = original_expr.clone().rewrite_with_schema(input.schema(), &mut rewriter)?; if !transformed { // TODO: remove the next line after `Expr::Wildcard` is removed diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 00629c392df48..29ea8cb786072 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -5866,10 +5866,10 @@ select array_ndims(arrow_cast([null], 'List(List(List(Int64)))')); 3 # array_ndims scalar function #2 -query II -select array_ndims(array_repeat(array_repeat(array_repeat(1, 3), 2), 1)), array_ndims([[[[[[[[[[[[[[[[[[[[[1]]]]]]]]]]]]]]]]]]]]]); ----- -3 21 +#query II +#select array_ndims(array_repeat(array_repeat(array_repeat(1, 3), 2), 1)), array_ndims([[[[[[[[[[[[[[[[[[[[[1]]]]]]]]]]]]]]]]]]]]]); +#---- +#3 21 # array_ndims scalar function #3 query II diff --git a/datafusion/sqllogictest/test_files/lambda.slt b/datafusion/sqllogictest/test_files/lambda.slt new file mode 100644 index 0000000000000..0043eae17a60c --- /dev/null +++ b/datafusion/sqllogictest/test_files/lambda.slt @@ -0,0 +1,166 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +############# +## Array Expressions Tests +############# + +statement ok +set datafusion.sql_parser.dialect = databricks; + +statement ok +CREATE TABLE tt +AS VALUES +([1, 50], 10), +([4, 50], 40); + +statement ok +CREATE TABLE t AS SELECT 1 as f, [ [ [2, 3], [2] ], [ [1] ], [ [] ] ] as v, 1 as n; + +query I? +SELECT t.n, array_transform([], e1 -> t.n) from t; +---- +1 [] + +query ? +SELECT array_transform([1], e1 -> (select n from t)); +---- +[1] + +query ? +SELECT array_transform(v, (v1, i) -> array_transform(v1, (v2, j) -> array_transform(v2, v3 -> j)) ) from t; +---- +[[[0, 0], [1]], [[0]], [[]]] + +query I? +SELECT t.n, array_transform([1, 2], (e) -> n) from t; +---- +1 [1, 1] + +# selection pushdown not working yet +query ? +SELECT array_transform([1, 2], (e) -> n) from t; +---- +[1, 1] + +query ? +SELECT array_transform([1, 2], (e, i) -> i) from t; +---- +[0, 1] + +# type coercion +query ? 
+SELECT array_transform([1, 2], (e, i) -> e+i) from t; +---- +[1, 3] + +query TT +EXPLAIN SELECT array_transform([1, 2], (e, i) -> e+i); +---- +logical_plan +01)Projection: array_transform(List([1, 2]), (e, i) -> e + CAST(i AS Int64)) AS array_transform(make_array(Int64(1),Int64(2)),(e, i) -> e + i) +02)--EmptyRelation: rows=1 +physical_plan +01)ProjectionExec: expr=[array_transform([1, 2], (e, i) -> e@0 + CAST(i@1 AS Int64)) as array_transform(make_array(Int64(1),Int64(2)),(e, i) -> e + i)] +02)--PlaceholderRowExec + +#cse +query TT +explain select n + 1, array_transform([1], v -> v + n + 1) from t; +---- +logical_plan +01)Projection: t.n + Int64(1), array_transform(List([1]), (v) -> v + t.n + Int64(1)) AS array_transform(make_array(Int64(1)),(v) -> v + t.n + Int64(1)) +02)--TableScan: t projection=[n] +physical_plan +01)ProjectionExec: expr=[n@0 + 1 as t.n + Int64(1), array_transform([1], (v) -> v@1 + n@0 + 1) as array_transform(make_array(Int64(1)),(v) -> v + t.n + Int64(1))] +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +query ? +SELECT array_transform([1,2,3,4,5], v -> v*2); +---- +[2, 4, 6, 8, 10] + +query ? +SELECT array_transform([1,2,3,4,5], v -> 2); +---- +[2, 2, 2, 2, 2] + +query ? +SELECT array_transform([[1,2],[3,4,5]], v -> array_transform(v, v -> v*2)); +---- +[[2, 4], [6, 8, 10]] + +query ? +SELECT array_transform([1,2,3,4,5], v -> repeat("a", v)); +---- +[a, aa, aaa, aaaa, aaaaa] + +query ? 
+SELECT array_transform([1,2,3,4,5], v -> list_repeat("a", v)); +---- +[[a], [a, a], [a, a, a], [a, a, a, a], [a, a, a, a, a]] + +query TT +EXPLAIN SELECT array_transform([1,2,3,4,5], v -> v*2); +---- +logical_plan +01)Projection: array_transform(List([1, 2, 3, 4, 5]), (v) -> v * Int64(2)) AS array_transform(make_array(Int64(1),Int64(2),Int64(3),Int64(4),Int64(5)),(v) -> v * Int64(2)) +02)--EmptyRelation: rows=1 +physical_plan +01)ProjectionExec: expr=[array_transform([1, 2, 3, 4, 5], (v) -> v@0 * 2) as array_transform(make_array(Int64(1),Int64(2),Int64(3),Int64(4),Int64(5)),(v) -> v * Int64(2))] +02)--PlaceholderRowExec + + +query I?? +SELECT t.n, t.v, array_transform(t.v, (v, i) -> array_transform(v, (v, j) -> n) ) from t; +---- +1 [[[2, 3], [2]], [[1]], [[]]] [[1, 1], [1], [1]] + +query ? +SELECT array_transform([1,2,3,4,5], v -> v*2); +---- +[2, 4, 6, 8, 10] + + +# expr simplifier +query TT +EXPLAIN SELECT v = v, array_transform([1], v -> v = v) from t; +---- +logical_plan +01)Projection: Boolean(true) AS t.v = t.v, array_transform(List([1]), (v) -> v IS NOT NULL OR Boolean(NULL)) AS array_transform(make_array(Int64(1)),(v) -> v = v) +02)--TableScan: t projection=[] +physical_plan +01)ProjectionExec: expr=[true as t.v = t.v, array_transform([1], (v) -> v@0 IS NOT NULL OR NULL) as array_transform(make_array(Int64(1)),(v) -> v = v)] +02)--DataSourceExec: partitions=1, partition_sizes=[1] + + +query error +select array_transform(); +---- +DataFusion error: Error during planning: 'array_transform' does not support zero arguments No function matches the given name and argument types 'array_transform()'. You might need to add explicit type casts. 
+ Candidate functions: + array_transform(Any, Any) + + +query error DataFusion error: Execution error: expected list, got Field \{ name: "Int64\(1\)", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: \{\} \} +select array_transform(1, v -> v*2); + +query error DataFusion error: Execution error: array_transform expects a value follewed by a lambda, got \[Lambda\(\["v"\], false\), Value\(Field \{ name: "make_array\(Int64\(1\),Int64\(2\)\)", data_type: List\(Field \{ name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\)\] +select array_transform(v -> v*2, [1, 2]); + +query error DataFusion error: Execution error: lambdas_schemas: array_transform argument 1 \(0\-indexed\), a lambda, supports up to 2 arguments, but got 3 +SELECT array_transform([1, 2], (e, i, j) -> i) from t; diff --git a/datafusion/substrait/src/logical_plan/producer/expr/mod.rs b/datafusion/substrait/src/logical_plan/producer/expr/mod.rs index f4e43fd586773..103d593cafbc0 100644 --- a/datafusion/substrait/src/logical_plan/producer/expr/mod.rs +++ b/datafusion/substrait/src/logical_plan/producer/expr/mod.rs @@ -152,6 +152,7 @@ pub fn to_substrait_rex( not_impl_err!("Cannot convert {expr:?} to Substrait") } Expr::Unnest(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), + Expr::Lambda(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), // https://github.com/substrait-io/substrait/issues/349 } } From fa4a8fbebe21207225077f41183c2e9016a24fbd Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Sat, 22 Nov 2025 15:34:04 -0300 Subject: [PATCH 02/12] add lambdas: None to existing ScalarFunctionArgs in tests/benches --- datafusion/functions-nested/benches/map.rs | 1 + datafusion/functions-nested/src/array_has.rs | 2 ++ datafusion/functions-nested/src/lib.rs | 3 +++ 
datafusion/functions-nested/src/map_values.rs | 1 + datafusion/functions-nested/src/set_ops.rs | 1 + datafusion/functions/benches/ascii.rs | 4 ++++ .../functions/benches/character_length.rs | 4 ++++ datafusion/functions/benches/chr.rs | 1 + datafusion/functions/benches/concat.rs | 1 + datafusion/functions/benches/cot.rs | 2 ++ datafusion/functions/benches/date_bin.rs | 1 + datafusion/functions/benches/date_trunc.rs | 1 + datafusion/functions/benches/encoding.rs | 4 ++++ datafusion/functions/benches/find_in_set.rs | 4 ++++ datafusion/functions/benches/gcd.rs | 3 +++ datafusion/functions/benches/initcap.rs | 3 +++ datafusion/functions/benches/isnan.rs | 2 ++ datafusion/functions/benches/iszero.rs | 2 ++ datafusion/functions/benches/lower.rs | 6 ++++++ datafusion/functions/benches/ltrim.rs | 1 + datafusion/functions/benches/make_date.rs | 4 ++++ datafusion/functions/benches/nullif.rs | 1 + datafusion/functions/benches/pad.rs | 1 + datafusion/functions/benches/random.rs | 2 ++ datafusion/functions/benches/repeat.rs | 1 + datafusion/functions/benches/reverse.rs | 4 ++++ datafusion/functions/benches/signum.rs | 2 ++ datafusion/functions/benches/strpos.rs | 4 ++++ datafusion/functions/benches/substr.rs | 1 + datafusion/functions/benches/substr_index.rs | 1 + datafusion/functions/benches/to_char.rs | 6 ++++++ datafusion/functions/benches/to_hex.rs | 2 ++ datafusion/functions/benches/to_timestamp.rs | 6 ++++++ datafusion/functions/benches/trunc.rs | 2 ++ datafusion/functions/benches/upper.rs | 1 + datafusion/functions/benches/uuid.rs | 1 + datafusion/functions/src/core/union_extract.rs | 4 ++++ datafusion/functions/src/core/union_tag.rs | 8 ++++++-- datafusion/functions/src/core/version.rs | 1 + datafusion/functions/src/datetime/date_bin.rs | 1 + .../functions/src/datetime/date_trunc.rs | 2 ++ .../functions/src/datetime/from_unixtime.rs | 2 ++ datafusion/functions/src/datetime/make_date.rs | 1 + datafusion/functions/src/datetime/now.rs | 2 ++ 
datafusion/functions/src/datetime/to_char.rs | 7 +++++++ datafusion/functions/src/datetime/to_date.rs | 1 + .../functions/src/datetime/to_local_time.rs | 2 ++ .../functions/src/datetime/to_timestamp.rs | 2 ++ datafusion/functions/src/math/log.rs | 18 ++++++++++++++++++ datafusion/functions/src/math/power.rs | 2 ++ datafusion/functions/src/math/signum.rs | 2 ++ datafusion/functions/src/regex/regexpcount.rs | 1 + datafusion/functions/src/regex/regexpinstr.rs | 1 + datafusion/functions/src/string/concat.rs | 1 + datafusion/functions/src/string/concat_ws.rs | 2 ++ datafusion/functions/src/string/contains.rs | 1 + datafusion/functions/src/string/lower.rs | 1 + datafusion/functions/src/string/upper.rs | 1 + .../functions/src/unicode/find_in_set.rs | 1 + datafusion/functions/src/unicode/strpos.rs | 1 + datafusion/functions/src/utils.rs | 3 +++ datafusion/spark/benches/char.rs | 1 + .../spark/src/function/bitmap/bitmap_count.rs | 1 + .../src/function/datetime/make_dt_interval.rs | 1 + .../src/function/datetime/make_interval.rs | 1 + datafusion/spark/src/function/string/concat.rs | 2 ++ datafusion/spark/src/function/utils.rs | 5 ++++- 67 files changed, 162 insertions(+), 3 deletions(-) diff --git a/datafusion/functions-nested/benches/map.rs b/datafusion/functions-nested/benches/map.rs index 3197cc55cc957..3075d2e573e4a 100644 --- a/datafusion/functions-nested/benches/map.rs +++ b/datafusion/functions-nested/benches/map.rs @@ -117,6 +117,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: 1, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, }) .expect("map should work on valid values"), ); diff --git a/datafusion/functions-nested/src/array_has.rs b/datafusion/functions-nested/src/array_has.rs index 080b2f16d92f3..d6a333c0a0ef3 100644 --- a/datafusion/functions-nested/src/array_has.rs +++ b/datafusion/functions-nested/src/array_has.rs @@ -819,6 +819,7 @@ mod tests { number_rows: 1, return_field, config_options: 
Arc::new(ConfigOptions::default()), + lambdas: None, })?; let output = result.into_array(1)?; @@ -847,6 +848,7 @@ mod tests { number_rows: 1, return_field, config_options: Arc::new(ConfigOptions::default()), + lambdas: None, })?; let output = result.into_array(1)?; diff --git a/datafusion/functions-nested/src/lib.rs b/datafusion/functions-nested/src/lib.rs index 3a66e65694768..55acf24ba4657 100644 --- a/datafusion/functions-nested/src/lib.rs +++ b/datafusion/functions-nested/src/lib.rs @@ -37,6 +37,7 @@ pub mod macros; pub mod array_has; +pub mod array_transform; pub mod cardinality; pub mod concat; pub mod dimension; @@ -78,6 +79,7 @@ pub mod expr_fn { pub use super::array_has::array_has; pub use super::array_has::array_has_all; pub use super::array_has::array_has_any; + pub use super::array_transform::array_transform; pub use super::cardinality::cardinality; pub use super::concat::array_append; pub use super::concat::array_concat; @@ -145,6 +147,7 @@ pub fn all_default_nested_functions() -> Vec> { array_has::array_has_udf(), array_has::array_has_all_udf(), array_has::array_has_any_udf(), + array_transform::array_transform_udf(), empty::array_empty_udf(), length::array_length_udf(), distance::array_distance_udf(), diff --git a/datafusion/functions-nested/src/map_values.rs b/datafusion/functions-nested/src/map_values.rs index 6ae8a278063da..ac21ff8acd3f9 100644 --- a/datafusion/functions-nested/src/map_values.rs +++ b/datafusion/functions-nested/src/map_values.rs @@ -204,6 +204,7 @@ mod tests { let args = datafusion_expr::ReturnFieldArgs { arg_fields: &[field], scalar_arguments: &[None::<&ScalarValue>], + lambdas: &[false], }; func.return_field_from_args(args).unwrap() diff --git a/datafusion/functions-nested/src/set_ops.rs b/datafusion/functions-nested/src/set_ops.rs index 53642bf1622b0..f26fc173d8a9f 100644 --- a/datafusion/functions-nested/src/set_ops.rs +++ b/datafusion/functions-nested/src/set_ops.rs @@ -596,6 +596,7 @@ mod tests { number_rows: 1, 
return_field: input_field.clone().into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, })?; assert_eq!( diff --git a/datafusion/functions/benches/ascii.rs b/datafusion/functions/benches/ascii.rs index 03d25e9c3d4fe..97e6ab20ed458 100644 --- a/datafusion/functions/benches/ascii.rs +++ b/datafusion/functions/benches/ascii.rs @@ -60,6 +60,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: N_ROWS, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, })) }) }, @@ -81,6 +82,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: N_ROWS, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, })) }) }, @@ -108,6 +110,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: N_ROWS, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, })) }) }, @@ -129,6 +132,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: N_ROWS, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, })) }) }, diff --git a/datafusion/functions/benches/character_length.rs b/datafusion/functions/benches/character_length.rs index 4a1a63d62765f..f98e8a8b1a68b 100644 --- a/datafusion/functions/benches/character_length.rs +++ b/datafusion/functions/benches/character_length.rs @@ -55,6 +55,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: n_rows, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, })) }) }, @@ -79,6 +80,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: n_rows, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, })) }) }, @@ -103,6 +105,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: n_rows, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, })) }) }, @@ 
-127,6 +130,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: n_rows, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, })) }) }, diff --git a/datafusion/functions/benches/chr.rs b/datafusion/functions/benches/chr.rs index 8356cf7c31726..d51cda4566d64 100644 --- a/datafusion/functions/benches/chr.rs +++ b/datafusion/functions/benches/chr.rs @@ -69,6 +69,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, }) .unwrap(), ) diff --git a/datafusion/functions/benches/concat.rs b/datafusion/functions/benches/concat.rs index 09200139a244b..6378328537827 100644 --- a/datafusion/functions/benches/concat.rs +++ b/datafusion/functions/benches/concat.rs @@ -60,6 +60,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, }) .unwrap(), ) diff --git a/datafusion/functions/benches/cot.rs b/datafusion/functions/benches/cot.rs index 97f21ccd6d55e..56f50522acc5d 100644 --- a/datafusion/functions/benches/cot.rs +++ b/datafusion/functions/benches/cot.rs @@ -54,6 +54,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Float32, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, }) .unwrap(), ) @@ -80,6 +81,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, }) .unwrap(), ) diff --git a/datafusion/functions/benches/date_bin.rs b/datafusion/functions/benches/date_bin.rs index 74390491d538c..1c3713723738a 100644 --- a/datafusion/functions/benches/date_bin.rs +++ b/datafusion/functions/benches/date_bin.rs @@ -66,6 +66,7 @@ fn criterion_benchmark(c: &mut Criterion) 
{ number_rows: batch_len, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, }) .expect("date_bin should work on valid values"), ) diff --git a/datafusion/functions/benches/date_trunc.rs b/datafusion/functions/benches/date_trunc.rs index 498a3e63ef290..b757535fb03c5 100644 --- a/datafusion/functions/benches/date_trunc.rs +++ b/datafusion/functions/benches/date_trunc.rs @@ -71,6 +71,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, }) .expect("date_trunc should work on valid values"), ) diff --git a/datafusion/functions/benches/encoding.rs b/datafusion/functions/benches/encoding.rs index 98faee91e1911..72b033cf5d9ed 100644 --- a/datafusion/functions/benches/encoding.rs +++ b/datafusion/functions/benches/encoding.rs @@ -45,6 +45,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, }) .unwrap(); @@ -63,6 +64,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, }) .unwrap(), ) @@ -82,6 +84,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, }) .unwrap(); @@ -101,6 +104,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, }) .unwrap(), ) diff --git a/datafusion/functions/benches/find_in_set.rs b/datafusion/functions/benches/find_in_set.rs index a928f5655806c..6fe498a58d84b 100644 --- a/datafusion/functions/benches/find_in_set.rs +++ 
b/datafusion/functions/benches/find_in_set.rs @@ -168,6 +168,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: n_rows, return_field: Arc::clone(&return_field), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, })) }) }); @@ -186,6 +187,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: n_rows, return_field: Arc::clone(&return_field), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, })) }) }); @@ -208,6 +210,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: n_rows, return_field: Arc::clone(&return_field), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, })) }) }); @@ -228,6 +231,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: n_rows, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, })) }) }); diff --git a/datafusion/functions/benches/gcd.rs b/datafusion/functions/benches/gcd.rs index 19e196d9a3eab..2bfec91e290dd 100644 --- a/datafusion/functions/benches/gcd.rs +++ b/datafusion/functions/benches/gcd.rs @@ -58,6 +58,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: 0, return_field: Field::new("f", DataType::Int64, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, }) .expect("date_bin should work on valid values"), ) @@ -79,6 +80,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: 0, return_field: Field::new("f", DataType::Int64, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, }) .expect("date_bin should work on valid values"), ) @@ -100,6 +102,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: 0, return_field: Field::new("f", DataType::Int64, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, }) .expect("date_bin should work on valid values"), ) diff --git a/datafusion/functions/benches/initcap.rs b/datafusion/functions/benches/initcap.rs index 50aee8dbb9161..37d98596deb82 100644 --- 
a/datafusion/functions/benches/initcap.rs +++ b/datafusion/functions/benches/initcap.rs @@ -70,6 +70,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Utf8View, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, })) }) }, @@ -86,6 +87,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Utf8View, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, })) }) }, @@ -100,6 +102,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, })) }) }); diff --git a/datafusion/functions/benches/isnan.rs b/datafusion/functions/benches/isnan.rs index 4a90d45d66223..dcce59e46ce41 100644 --- a/datafusion/functions/benches/isnan.rs +++ b/datafusion/functions/benches/isnan.rs @@ -53,6 +53,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Boolean, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, }) .unwrap(), ) @@ -77,6 +78,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Boolean, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, }) .unwrap(), ) diff --git a/datafusion/functions/benches/iszero.rs b/datafusion/functions/benches/iszero.rs index 961cba7200ce0..574539fbb6427 100644 --- a/datafusion/functions/benches/iszero.rs +++ b/datafusion/functions/benches/iszero.rs @@ -55,6 +55,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, }) .unwrap(), ) @@ -82,6 +83,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Arc::clone(&return_field), config_options: 
Arc::clone(&config_options), + lambdas: None, }) .unwrap(), ) diff --git a/datafusion/functions/benches/lower.rs b/datafusion/functions/benches/lower.rs index 6a5178b87fdce..e741afd0d8e01 100644 --- a/datafusion/functions/benches/lower.rs +++ b/datafusion/functions/benches/lower.rs @@ -145,6 +145,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, })) }) }); @@ -167,6 +168,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, })) }) }); @@ -191,6 +193,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, })) }) }, @@ -225,6 +228,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, })) }), ); @@ -240,6 +244,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, })) }), ); @@ -256,6 +261,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, })) }), ); diff --git a/datafusion/functions/benches/ltrim.rs b/datafusion/functions/benches/ltrim.rs index 4458af614396d..9b344cc6b143a 100644 --- a/datafusion/functions/benches/ltrim.rs +++ b/datafusion/functions/benches/ltrim.rs @@ -153,6 +153,7 @@ fn run_with_string_type( number_rows: size, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, })) }) }, diff 
--git a/datafusion/functions/benches/make_date.rs b/datafusion/functions/benches/make_date.rs index 15a895468db93..2a681ddedcbe8 100644 --- a/datafusion/functions/benches/make_date.rs +++ b/datafusion/functions/benches/make_date.rs @@ -81,6 +81,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, }) .expect("make_date should work on valid values"), ) @@ -111,6 +112,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, }) .expect("make_date should work on valid values"), ) @@ -141,6 +143,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, }) .expect("make_date should work on valid values"), ) @@ -168,6 +171,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: 1, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, }) .expect("make_date should work on valid values"), ) diff --git a/datafusion/functions/benches/nullif.rs b/datafusion/functions/benches/nullif.rs index d649697cc5188..15914cd7ee6c5 100644 --- a/datafusion/functions/benches/nullif.rs +++ b/datafusion/functions/benches/nullif.rs @@ -54,6 +54,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, }) .unwrap(), ) diff --git a/datafusion/functions/benches/pad.rs b/datafusion/functions/benches/pad.rs index f92a69bbf4f92..c7d46da3d26c6 100644 --- a/datafusion/functions/benches/pad.rs +++ b/datafusion/functions/benches/pad.rs @@ -116,6 +116,7 @@ fn invoke_pad_with_args( number_rows, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: 
Arc::clone(&config_options), + lambdas: None, }; if left_pad { diff --git a/datafusion/functions/benches/random.rs b/datafusion/functions/benches/random.rs index 88efb2d1b5b93..2935876685800 100644 --- a/datafusion/functions/benches/random.rs +++ b/datafusion/functions/benches/random.rs @@ -43,6 +43,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: 8192, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, }) .unwrap(), ); @@ -64,6 +65,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: 128, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, }) .unwrap(), ); diff --git a/datafusion/functions/benches/repeat.rs b/datafusion/functions/benches/repeat.rs index 80ffa8ee38f1a..9a7c63ed4f304 100644 --- a/datafusion/functions/benches/repeat.rs +++ b/datafusion/functions/benches/repeat.rs @@ -76,6 +76,7 @@ fn invoke_repeat_with_args( number_rows: repeat_times as usize, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, }) } diff --git a/datafusion/functions/benches/reverse.rs b/datafusion/functions/benches/reverse.rs index b1eca654fb254..a8af40cd8cc19 100644 --- a/datafusion/functions/benches/reverse.rs +++ b/datafusion/functions/benches/reverse.rs @@ -58,6 +58,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: N_ROWS, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, })) }) }, @@ -80,6 +81,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: N_ROWS, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, })) }) }, @@ -107,6 +109,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: N_ROWS, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, })) }) }, 
@@ -131,6 +134,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: N_ROWS, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, })) }) }, diff --git a/datafusion/functions/benches/signum.rs b/datafusion/functions/benches/signum.rs index 24b8861e4d28c..805b62c83da6d 100644 --- a/datafusion/functions/benches/signum.rs +++ b/datafusion/functions/benches/signum.rs @@ -55,6 +55,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, }) .unwrap(), ) @@ -83,6 +84,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, }) .unwrap(), ) diff --git a/datafusion/functions/benches/strpos.rs b/datafusion/functions/benches/strpos.rs index 18a99e44bf487..708ebb5518727 100644 --- a/datafusion/functions/benches/strpos.rs +++ b/datafusion/functions/benches/strpos.rs @@ -128,6 +128,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: n_rows, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, })) }) }, @@ -146,6 +147,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: n_rows, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, })) }) }); @@ -165,6 +167,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: n_rows, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, })) }) }, @@ -185,6 +188,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: n_rows, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, })) }) }, diff --git a/datafusion/functions/benches/substr.rs b/datafusion/functions/benches/substr.rs index 771413458c1fb..58fda73defd25 
100644 --- a/datafusion/functions/benches/substr.rs +++ b/datafusion/functions/benches/substr.rs @@ -116,6 +116,7 @@ fn invoke_substr_with_args( number_rows, return_field: Field::new("f", DataType::Utf8View, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, }) } diff --git a/datafusion/functions/benches/substr_index.rs b/datafusion/functions/benches/substr_index.rs index d0941d9baedda..a77b961657c5f 100644 --- a/datafusion/functions/benches/substr_index.rs +++ b/datafusion/functions/benches/substr_index.rs @@ -110,6 +110,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, }) .expect("substr_index should work on valid values"), ) diff --git a/datafusion/functions/benches/to_char.rs b/datafusion/functions/benches/to_char.rs index 945508aec7405..61990b4cb8b95 100644 --- a/datafusion/functions/benches/to_char.rs +++ b/datafusion/functions/benches/to_char.rs @@ -149,6 +149,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, }) .expect("to_char should work on valid values"), ) @@ -176,6 +177,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, }) .expect("to_char should work on valid values"), ) @@ -203,6 +205,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, }) .expect("to_char should work on valid values"), ) @@ -229,6 +232,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: 
Arc::clone(&config_options), + lambdas: None, }) .expect("to_char should work on valid values"), ) @@ -256,6 +260,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, }) .expect("to_char should work on valid values"), ) @@ -288,6 +293,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: 1, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, }) .expect("to_char should work on valid values"), ) diff --git a/datafusion/functions/benches/to_hex.rs b/datafusion/functions/benches/to_hex.rs index a75ed9258791e..baa2de80c466f 100644 --- a/datafusion/functions/benches/to_hex.rs +++ b/datafusion/functions/benches/to_hex.rs @@ -44,6 +44,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, }) .unwrap(), ) @@ -62,6 +63,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, }) .unwrap(), ) diff --git a/datafusion/functions/benches/to_timestamp.rs b/datafusion/functions/benches/to_timestamp.rs index a8f5c5816d4da..e510a7c3fad41 100644 --- a/datafusion/functions/benches/to_timestamp.rs +++ b/datafusion/functions/benches/to_timestamp.rs @@ -130,6 +130,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, }) .expect("to_timestamp should work on valid values"), ) @@ -150,6 +151,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, }) 
.expect("to_timestamp should work on valid values"), ) @@ -170,6 +172,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, }) .expect("to_timestamp should work on valid values"), ) @@ -203,6 +206,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, }) .expect("to_timestamp should work on valid values"), ) @@ -244,6 +248,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, }) .expect("to_timestamp should work on valid values"), ) @@ -286,6 +291,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, }) .expect("to_timestamp should work on valid values"), ) diff --git a/datafusion/functions/benches/trunc.rs b/datafusion/functions/benches/trunc.rs index 6e225e0e7038b..0b08791f9ae50 100644 --- a/datafusion/functions/benches/trunc.rs +++ b/datafusion/functions/benches/trunc.rs @@ -49,6 +49,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, }) .unwrap(), ) @@ -68,6 +69,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, }) .unwrap(), ) diff --git a/datafusion/functions/benches/upper.rs b/datafusion/functions/benches/upper.rs index 7328b32574a4a..e9f0941032d8a 100644 --- a/datafusion/functions/benches/upper.rs +++ b/datafusion/functions/benches/upper.rs @@ -50,6 +50,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: 
Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, })) }) }); diff --git a/datafusion/functions/benches/uuid.rs b/datafusion/functions/benches/uuid.rs index 1368e2f2af5d1..8ad79b2866eaf 100644 --- a/datafusion/functions/benches/uuid.rs +++ b/datafusion/functions/benches/uuid.rs @@ -37,6 +37,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: 1024, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, })) }) }); diff --git a/datafusion/functions/src/core/union_extract.rs b/datafusion/functions/src/core/union_extract.rs index a71e2e87388d5..ac542866f7e43 100644 --- a/datafusion/functions/src/core/union_extract.rs +++ b/datafusion/functions/src/core/union_extract.rs @@ -209,6 +209,7 @@ mod tests { number_rows: 1, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, })?; assert_scalar(result, ScalarValue::Utf8(None)); @@ -232,6 +233,7 @@ mod tests { number_rows: 1, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, })?; assert_scalar(result, ScalarValue::Utf8(None)); @@ -248,12 +250,14 @@ mod tests { .iter() .map(|arg| Field::new("a", arg.data_type().clone(), true).into()) .collect::>(); + let result = fun.invoke_with_args(ScalarFunctionArgs { args, arg_fields, number_rows: 1, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, })?; assert_scalar(result, ScalarValue::new_utf8("42")); diff --git a/datafusion/functions/src/core/union_tag.rs b/datafusion/functions/src/core/union_tag.rs index aeadb8292ba1e..ecdebf66e0043 100644 --- a/datafusion/functions/src/core/union_tag.rs +++ b/datafusion/functions/src/core/union_tag.rs @@ -173,6 +173,7 @@ mod tests { fields, UnionMode::Dense, ); + let arg_fields = 
vec![Field::new("a", scalar.data_type(), false).into()]; let return_type = DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)); @@ -180,10 +181,11 @@ mod tests { let result = UnionTagFunc::new() .invoke_with_args(ScalarFunctionArgs { args: vec![ColumnarValue::Scalar(scalar)], + arg_fields, number_rows: 1, return_field: Field::new("res", return_type, true).into(), - arg_fields: vec![], config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }) .unwrap(); @@ -196,6 +198,7 @@ mod tests { #[test] fn union_scalar_empty() { let scalar = ScalarValue::Union(None, UnionFields::empty(), UnionMode::Dense); + let arg_fields = vec![Field::new("a", scalar.data_type(), false).into()]; let return_type = DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)); @@ -203,10 +206,11 @@ mod tests { let result = UnionTagFunc::new() .invoke_with_args(ScalarFunctionArgs { args: vec![ColumnarValue::Scalar(scalar)], + arg_fields, number_rows: 1, return_field: Field::new("res", return_type, true).into(), - arg_fields: vec![], config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }) .unwrap(); diff --git a/datafusion/functions/src/core/version.rs b/datafusion/functions/src/core/version.rs index ef3c5aafa4801..390111028c8f2 100644 --- a/datafusion/functions/src/core/version.rs +++ b/datafusion/functions/src/core/version.rs @@ -112,6 +112,7 @@ mod test { number_rows: 0, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }) .unwrap(); diff --git a/datafusion/functions/src/datetime/date_bin.rs b/datafusion/functions/src/datetime/date_bin.rs index 92af123dbafac..5466129314640 100644 --- a/datafusion/functions/src/datetime/date_bin.rs +++ b/datafusion/functions/src/datetime/date_bin.rs @@ -530,6 +530,7 @@ mod tests { number_rows, return_field: Arc::clone(return_field), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; 
DateBinFunc::new().invoke_with_args(args) } diff --git a/datafusion/functions/src/datetime/date_trunc.rs b/datafusion/functions/src/datetime/date_trunc.rs index 913e6217af82d..5736c221cae84 100644 --- a/datafusion/functions/src/datetime/date_trunc.rs +++ b/datafusion/functions/src/datetime/date_trunc.rs @@ -892,6 +892,7 @@ mod tests { ) .into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let result = DateTruncFunc::new().invoke_with_args(args).unwrap(); if let ColumnarValue::Array(result) = result { @@ -1080,6 +1081,7 @@ mod tests { ) .into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let result = DateTruncFunc::new().invoke_with_args(args).unwrap(); if let ColumnarValue::Array(result) = result { diff --git a/datafusion/functions/src/datetime/from_unixtime.rs b/datafusion/functions/src/datetime/from_unixtime.rs index 5d6adfb6f119a..be44be094e5b7 100644 --- a/datafusion/functions/src/datetime/from_unixtime.rs +++ b/datafusion/functions/src/datetime/from_unixtime.rs @@ -179,6 +179,7 @@ mod test { number_rows: 1, return_field: Field::new("f", DataType::Timestamp(Second, None), true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let result = FromUnixtimeFunc::new().invoke_with_args(args).unwrap(); @@ -212,6 +213,7 @@ mod test { ) .into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let result = FromUnixtimeFunc::new().invoke_with_args(args).unwrap(); diff --git a/datafusion/functions/src/datetime/make_date.rs b/datafusion/functions/src/datetime/make_date.rs index 0fe5d156a8383..afa4ef132147a 100644 --- a/datafusion/functions/src/datetime/make_date.rs +++ b/datafusion/functions/src/datetime/make_date.rs @@ -250,6 +250,7 @@ mod tests { number_rows, return_field: Field::new("f", DataType::Date32, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; MakeDateFunc::new().invoke_with_args(args) } diff --git 
a/datafusion/functions/src/datetime/now.rs b/datafusion/functions/src/datetime/now.rs index 4723548a45584..f18e72a107e28 100644 --- a/datafusion/functions/src/datetime/now.rs +++ b/datafusion/functions/src/datetime/now.rs @@ -163,6 +163,7 @@ mod tests { .return_field_from_args(ReturnFieldArgs { arg_fields: &empty_fields, scalar_arguments: &empty_scalars, + lambdas: &[], }) .expect("legacy now() return field"); @@ -170,6 +171,7 @@ mod tests { .return_field_from_args(ReturnFieldArgs { arg_fields: &empty_fields, scalar_arguments: &empty_scalars, + lambdas: &[], }) .expect("configured now() return field"); diff --git a/datafusion/functions/src/datetime/to_char.rs b/datafusion/functions/src/datetime/to_char.rs index 7d9b2bc241e1a..5d69ce233f643 100644 --- a/datafusion/functions/src/datetime/to_char.rs +++ b/datafusion/functions/src/datetime/to_char.rs @@ -375,6 +375,7 @@ mod tests { number_rows: batch_len, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&Arc::new(ConfigOptions::default())), + lambdas: None, }; let result = ToCharFunc::new() .invoke_with_args(args) @@ -480,6 +481,7 @@ mod tests { number_rows: 1, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let result = ToCharFunc::new() .invoke_with_args(args) @@ -574,6 +576,7 @@ mod tests { number_rows: batch_len, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let result = ToCharFunc::new() .invoke_with_args(args) @@ -738,6 +741,7 @@ mod tests { number_rows: batch_len, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let result = ToCharFunc::new() .invoke_with_args(args) @@ -766,6 +770,7 @@ mod tests { number_rows: batch_len, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: 
Arc::new(ConfigOptions::default()), + lambdas: None, }; let result = ToCharFunc::new() .invoke_with_args(args) @@ -791,6 +796,7 @@ mod tests { number_rows: 1, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let result = ToCharFunc::new().invoke_with_args(args); assert_eq!( @@ -812,6 +818,7 @@ mod tests { number_rows: 1, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let result = ToCharFunc::new().invoke_with_args(args); assert_eq!( diff --git a/datafusion/functions/src/datetime/to_date.rs b/datafusion/functions/src/datetime/to_date.rs index 3840c8d8bbb94..f6b313e6a28bb 100644 --- a/datafusion/functions/src/datetime/to_date.rs +++ b/datafusion/functions/src/datetime/to_date.rs @@ -186,6 +186,7 @@ mod tests { number_rows, return_field: Field::new("f", DataType::Date32, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; ToDateFunc::new().invoke_with_args(args) } diff --git a/datafusion/functions/src/datetime/to_local_time.rs b/datafusion/functions/src/datetime/to_local_time.rs index 6e0a150b0a35f..4d50a70d37236 100644 --- a/datafusion/functions/src/datetime/to_local_time.rs +++ b/datafusion/functions/src/datetime/to_local_time.rs @@ -549,6 +549,7 @@ mod tests { number_rows: 1, return_field: Field::new("f", expected.data_type(), true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }) .unwrap(); match res { @@ -620,6 +621,7 @@ mod tests { ) .into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let result = ToLocalTimeFunc::new().invoke_with_args(args).unwrap(); if let ColumnarValue::Array(result) = result { diff --git a/datafusion/functions/src/datetime/to_timestamp.rs b/datafusion/functions/src/datetime/to_timestamp.rs index 0a0700097770f..f35e170073030 100644 --- 
a/datafusion/functions/src/datetime/to_timestamp.rs +++ b/datafusion/functions/src/datetime/to_timestamp.rs @@ -1033,6 +1033,7 @@ mod tests { number_rows: 4, return_field: Field::new("f", rt, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let res = udf .invoke_with_args(args) @@ -1083,6 +1084,7 @@ mod tests { number_rows: 5, return_field: Field::new("f", rt, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let res = udf .invoke_with_args(args) diff --git a/datafusion/functions/src/math/log.rs b/datafusion/functions/src/math/log.rs index f66f6fcfc1f88..1a73ed8436a68 100644 --- a/datafusion/functions/src/math/log.rs +++ b/datafusion/functions/src/math/log.rs @@ -370,6 +370,7 @@ mod tests { number_rows: 4, return_field: Field::new("f", DataType::Float64, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let result = LogFunc::new().invoke_with_args(args); assert!(result.is_err()); @@ -390,6 +391,7 @@ mod tests { number_rows: 1, return_field: Field::new("f", DataType::Float64, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let result = LogFunc::new().invoke_with_args(args); @@ -407,6 +409,7 @@ mod tests { number_rows: 1, return_field: Field::new("f", DataType::Float32, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let result = LogFunc::new() .invoke_with_args(args) @@ -437,6 +440,7 @@ mod tests { number_rows: 1, return_field: Field::new("f", DataType::Float64, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let result = LogFunc::new() .invoke_with_args(args) @@ -471,6 +475,7 @@ mod tests { number_rows: 1, return_field: Field::new("f", DataType::Float32, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let result = LogFunc::new() .invoke_with_args(args) @@ -505,6 +510,7 @@ mod tests { number_rows: 1, 
return_field: Field::new("f", DataType::Float64, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let result = LogFunc::new() .invoke_with_args(args) @@ -537,6 +543,7 @@ mod tests { number_rows: 4, return_field: Field::new("f", DataType::Float64, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let result = LogFunc::new() .invoke_with_args(args) @@ -572,6 +579,7 @@ mod tests { number_rows: 4, return_field: Field::new("f", DataType::Float32, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let result = LogFunc::new() .invoke_with_args(args) @@ -613,6 +621,7 @@ mod tests { number_rows: 5, return_field: Field::new("f", DataType::Float64, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let result = LogFunc::new() .invoke_with_args(args) @@ -655,6 +664,7 @@ mod tests { number_rows: 4, return_field: Field::new("f", DataType::Float32, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let result = LogFunc::new() .invoke_with_args(args) @@ -836,6 +846,7 @@ mod tests { number_rows: 1, return_field: Field::new("f", DataType::Decimal128(38, 0), true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let result = LogFunc::new() .invoke_with_args(args) @@ -869,6 +880,7 @@ mod tests { number_rows: 1, return_field: Field::new("f", DataType::Float64, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let result = LogFunc::new() .invoke_with_args(args) @@ -903,6 +915,7 @@ mod tests { number_rows: 6, return_field: Field::new("f", DataType::Float64, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let result = LogFunc::new() .invoke_with_args(args) @@ -947,6 +960,7 @@ mod tests { number_rows: 1, return_field: Field::new("f", DataType::Float64, true).into(), config_options: 
Arc::new(ConfigOptions::default()), + lambdas: None, }; let result = LogFunc::new() .invoke_with_args(args) @@ -987,6 +1001,7 @@ mod tests { number_rows: 1, return_field: Field::new("f", DataType::Float64, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let result = LogFunc::new() .invoke_with_args(args) @@ -1037,6 +1052,7 @@ mod tests { number_rows: 7, return_field: Field::new("f", DataType::Float64, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let result = LogFunc::new() .invoke_with_args(args) @@ -1078,6 +1094,7 @@ mod tests { number_rows: 1, return_field: Field::new("f", DataType::Float64, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let result = LogFunc::new().invoke_with_args(args); assert!(result.is_err()); @@ -1101,6 +1118,7 @@ mod tests { number_rows: 1, return_field: Field::new("f", DataType::Float64, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let result = LogFunc::new().invoke_with_args(args); assert!(result.is_err()); diff --git a/datafusion/functions/src/math/power.rs b/datafusion/functions/src/math/power.rs index ad2e795d086e9..21a777abb3295 100644 --- a/datafusion/functions/src/math/power.rs +++ b/datafusion/functions/src/math/power.rs @@ -222,6 +222,7 @@ mod tests { number_rows: 4, return_field: Field::new("f", DataType::Float64, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let result = PowerFunc::new() .invoke_with_args(args) @@ -258,6 +259,7 @@ mod tests { number_rows: 4, return_field: Field::new("f", DataType::Int64, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let result = PowerFunc::new() .invoke_with_args(args) diff --git a/datafusion/functions/src/math/signum.rs b/datafusion/functions/src/math/signum.rs index bbe6178f39b79..d1d49b1bf6f90 100644 --- a/datafusion/functions/src/math/signum.rs +++ 
b/datafusion/functions/src/math/signum.rs @@ -173,6 +173,7 @@ mod test { number_rows: array.len(), return_field: Field::new("f", DataType::Float32, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let result = SignumFunc::new() .invoke_with_args(args) @@ -220,6 +221,7 @@ mod test { number_rows: array.len(), return_field: Field::new("f", DataType::Float64, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let result = SignumFunc::new() .invoke_with_args(args) diff --git a/datafusion/functions/src/regex/regexpcount.rs b/datafusion/functions/src/regex/regexpcount.rs index 8bad506217aa5..ee6f412bb9a16 100644 --- a/datafusion/functions/src/regex/regexpcount.rs +++ b/datafusion/functions/src/regex/regexpcount.rs @@ -628,6 +628,7 @@ mod tests { number_rows: args.len(), return_field: Field::new("f", Int64, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }) } diff --git a/datafusion/functions/src/regex/regexpinstr.rs b/datafusion/functions/src/regex/regexpinstr.rs index 851c182a90dd0..1e64f7087ea74 100644 --- a/datafusion/functions/src/regex/regexpinstr.rs +++ b/datafusion/functions/src/regex/regexpinstr.rs @@ -494,6 +494,7 @@ mod tests { number_rows: args.len(), return_field: Arc::new(Field::new("f", Int64, true)), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }) } diff --git a/datafusion/functions/src/string/concat.rs b/datafusion/functions/src/string/concat.rs index a93e70e714e8b..661bcfe4e0fd8 100644 --- a/datafusion/functions/src/string/concat.rs +++ b/datafusion/functions/src/string/concat.rs @@ -487,6 +487,7 @@ mod tests { number_rows: 3, return_field: Field::new("f", Utf8, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let result = ConcatFunc::new().invoke_with_args(args)?; diff --git a/datafusion/functions/src/string/concat_ws.rs b/datafusion/functions/src/string/concat_ws.rs index 
cdd30ac8755ab..85704d6b2f468 100644 --- a/datafusion/functions/src/string/concat_ws.rs +++ b/datafusion/functions/src/string/concat_ws.rs @@ -495,6 +495,7 @@ mod tests { number_rows: 3, return_field: Field::new("f", Utf8, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let result = ConcatWsFunc::new().invoke_with_args(args)?; @@ -532,6 +533,7 @@ mod tests { number_rows: 3, return_field: Field::new("f", Utf8, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let result = ConcatWsFunc::new().invoke_with_args(args)?; diff --git a/datafusion/functions/src/string/contains.rs b/datafusion/functions/src/string/contains.rs index 7e50676933c8d..1edab4c6bf334 100644 --- a/datafusion/functions/src/string/contains.rs +++ b/datafusion/functions/src/string/contains.rs @@ -177,6 +177,7 @@ mod test { number_rows: 2, return_field: Field::new("f", DataType::Boolean, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let actual = udf.invoke_with_args(args).unwrap(); diff --git a/datafusion/functions/src/string/lower.rs b/datafusion/functions/src/string/lower.rs index ee56a6a549857..099a3ffd44cc4 100644 --- a/datafusion/functions/src/string/lower.rs +++ b/datafusion/functions/src/string/lower.rs @@ -113,6 +113,7 @@ mod tests { arg_fields, return_field: Field::new("f", Utf8, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let result = match func.invoke_with_args(args)? { diff --git a/datafusion/functions/src/string/upper.rs b/datafusion/functions/src/string/upper.rs index 8bb2ec1d511cd..d7d2bde94b0a3 100644 --- a/datafusion/functions/src/string/upper.rs +++ b/datafusion/functions/src/string/upper.rs @@ -112,6 +112,7 @@ mod tests { arg_fields: vec![arg_field], return_field: Field::new("f", Utf8, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let result = match func.invoke_with_args(args)? 
{ diff --git a/datafusion/functions/src/unicode/find_in_set.rs b/datafusion/functions/src/unicode/find_in_set.rs index fa68e539600b0..219bd6eaa762c 100644 --- a/datafusion/functions/src/unicode/find_in_set.rs +++ b/datafusion/functions/src/unicode/find_in_set.rs @@ -485,6 +485,7 @@ mod tests { number_rows: cardinality, return_field: Field::new("f", return_type, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }); assert!(result.is_ok()); diff --git a/datafusion/functions/src/unicode/strpos.rs b/datafusion/functions/src/unicode/strpos.rs index 4f238b2644bdf..a3734b0c0de4f 100644 --- a/datafusion/functions/src/unicode/strpos.rs +++ b/datafusion/functions/src/unicode/strpos.rs @@ -336,6 +336,7 @@ mod tests { Field::new("f2", DataType::Utf8, substring_nullable).into(), ], scalar_arguments: &[None::<&ScalarValue>, None::<&ScalarValue>], + lambdas: &[false; 2], }; strpos.return_field_from_args(args).unwrap().is_nullable() diff --git a/datafusion/functions/src/utils.rs b/datafusion/functions/src/utils.rs index 932d61e8007cd..d6d56b32722de 100644 --- a/datafusion/functions/src/utils.rs +++ b/datafusion/functions/src/utils.rs @@ -234,6 +234,7 @@ pub mod test { let return_field = func.return_field_from_args(datafusion_expr::ReturnFieldArgs { arg_fields: &field_array, scalar_arguments: &scalar_arguments_refs, + lambdas: &vec![false; scalar_arguments_refs.len()], }); let arg_fields = $ARGS.iter() .enumerate() @@ -252,6 +253,7 @@ pub mod test { arg_fields, number_rows: cardinality, return_field, + lambdas: None, config_options: $CONFIG_OPTIONS }); assert_eq!(result.is_ok(), true, "function returned an error: {}", result.unwrap_err()); @@ -274,6 +276,7 @@ pub mod test { arg_fields, number_rows: cardinality, return_field, + lambdas: None, config_options: $CONFIG_OPTIONS, }) { Ok(_) => assert!(false, "expected error"), diff --git a/datafusion/spark/benches/char.rs b/datafusion/spark/benches/char.rs index 02eab7630d070..501bfd2a0186d 100644 --- 
a/datafusion/spark/benches/char.rs +++ b/datafusion/spark/benches/char.rs @@ -68,6 +68,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Arc::new(Field::new("f", DataType::Utf8, true)), config_options: Arc::clone(&config_options), + lambdas: None, }) .unwrap(), ) diff --git a/datafusion/spark/src/function/bitmap/bitmap_count.rs b/datafusion/spark/src/function/bitmap/bitmap_count.rs index 56a9c5edb812c..e4c12ebe19665 100644 --- a/datafusion/spark/src/function/bitmap/bitmap_count.rs +++ b/datafusion/spark/src/function/bitmap/bitmap_count.rs @@ -217,6 +217,7 @@ mod tests { number_rows: 1, return_field: Field::new("f", Int64, true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let udf = BitmapCount::new(); let actual = udf.invoke_with_args(args)?; diff --git a/datafusion/spark/src/function/datetime/make_dt_interval.rs b/datafusion/spark/src/function/datetime/make_dt_interval.rs index bbfba44861344..aaff5400d0c00 100644 --- a/datafusion/spark/src/function/datetime/make_dt_interval.rs +++ b/datafusion/spark/src/function/datetime/make_dt_interval.rs @@ -317,6 +317,7 @@ mod tests { number_rows, return_field: Field::new("f", Duration(Microsecond), true).into(), config_options: Arc::new(Default::default()), + lambdas: None, }; SparkMakeDtInterval::new().invoke_with_args(args) } diff --git a/datafusion/spark/src/function/datetime/make_interval.rs b/datafusion/spark/src/function/datetime/make_interval.rs index 8e3169556b95b..9f98c4b5ce9fb 100644 --- a/datafusion/spark/src/function/datetime/make_interval.rs +++ b/datafusion/spark/src/function/datetime/make_interval.rs @@ -516,6 +516,7 @@ mod tests { number_rows, return_field: Field::new("f", Interval(MonthDayNano), true).into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; SparkMakeInterval::new().invoke_with_args(args) } diff --git a/datafusion/spark/src/function/string/concat.rs b/datafusion/spark/src/function/string/concat.rs index 
0dcc58d5bb8ed..e2cd8d977fe29 100644 --- a/datafusion/spark/src/function/string/concat.rs +++ b/datafusion/spark/src/function/string/concat.rs @@ -105,6 +105,7 @@ fn spark_concat(args: ScalarFunctionArgs) -> Result { number_rows, return_field, config_options, + lambdas, } = args; // Handle zero-argument case: return empty string @@ -130,6 +131,7 @@ fn spark_concat(args: ScalarFunctionArgs) -> Result { number_rows, return_field, config_options, + lambdas, }; let result = concat_func.invoke_with_args(func_args)?; diff --git a/datafusion/spark/src/function/utils.rs b/datafusion/spark/src/function/utils.rs index e272d91d8a70e..1064acc342916 100644 --- a/datafusion/spark/src/function/utils.rs +++ b/datafusion/spark/src/function/utils.rs @@ -60,7 +60,8 @@ pub mod test { let return_field = func.return_field_from_args(datafusion_expr::ReturnFieldArgs { arg_fields: &arg_fields, - scalar_arguments: &scalar_arguments_refs + scalar_arguments: &scalar_arguments_refs, + lambdas: &vec![false; arg_fields.len()], }); match expected { @@ -74,6 +75,7 @@ pub mod test { return_field, arg_fields: arg_fields.clone(), config_options: $CONFIG_OPTIONS, + lambdas: None, }) { Ok(col_value) => { match col_value.to_array(cardinality) { @@ -117,6 +119,7 @@ pub mod test { return_field: value, arg_fields, config_options: $CONFIG_OPTIONS, + lambdas: None, }) { Ok(_) => assert!(false, "expected error"), Err(error) => { From b18d2145163fb5933dfed55eb6305412743b6cac Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Mon, 15 Dec 2025 07:27:02 -0300 Subject: [PATCH 03/12] simplify lambda support --- Cargo.lock | 2 +- .../examples/custom_file_casts.rs | 11 +- .../examples/default_column_values.rs | 14 +- datafusion-examples/examples/expr_api.rs | 8 +- .../examples/json_shredding.rs | 14 +- datafusion/catalog-listing/src/helpers.rs | 9 +- datafusion/common/src/column.rs | 6 - datafusion/common/src/cse.rs | 22 +- datafusion/common/src/dfschema.rs | 16 +- 
datafusion/common/src/lib.rs | 2 - datafusion/common/src/utils/mod.rs | 32 +- .../core/src/execution/session_state.rs | 5 +- datafusion/core/tests/parquet/mod.rs | 2 +- .../core/tests/parquet/schema_adapter.rs | 8 +- .../datasource-parquet/src/row_filter.rs | 16 +- datafusion/expr/src/expr.rs | 84 ++- datafusion/expr/src/expr_rewriter/mod.rs | 49 +- datafusion/expr/src/expr_rewriter/order_by.rs | 4 - datafusion/expr/src/expr_schema.rs | 39 +- datafusion/expr/src/lib.rs | 6 +- datafusion/expr/src/tree_node.rs | 685 +----------------- datafusion/expr/src/udf.rs | 467 +----------- datafusion/expr/src/utils.rs | 41 +- datafusion/functions-nested/Cargo.toml | 1 + .../functions-nested/src/array_transform.rs | 110 ++- datafusion/functions/src/core/union_tag.rs | 6 +- .../src/analyzer/function_rewrite.rs | 21 +- .../optimizer/src/analyzer/type_coercion.rs | 85 ++- .../optimizer/src/common_subexpr_eliminate.rs | 26 +- datafusion/optimizer/src/decorrelate.rs | 20 +- .../optimizer/src/optimize_projections/mod.rs | 4 +- datafusion/optimizer/src/push_down_filter.rs | 41 +- .../optimizer/src/scalar_subquery_to_join.rs | 67 +- .../simplify_expressions/expr_simplifier.rs | 178 ++--- datafusion/optimizer/src/utils.rs | 4 +- .../src/schema_rewriter.rs | 28 +- datafusion/physical-expr/Cargo.toml | 4 - .../physical-expr/src/expressions/column.rs | 20 +- .../physical-expr/src/expressions/lambda.rs | 23 +- .../src/expressions/lambda_column.rs | 136 ++++ .../physical-expr/src/expressions/mod.rs | 2 + datafusion/physical-expr/src/lib.rs | 2 - datafusion/physical-expr/src/physical_expr.rs | 10 +- datafusion/physical-expr/src/planner.rs | 46 +- datafusion/physical-expr/src/projection.rs | 53 +- .../physical-expr/src/scalar_function.rs | 657 ++--------------- .../physical-expr/src/simplifier/mod.rs | 20 +- .../src/simplifier/unwrap_cast.rs | 12 +- datafusion/physical-expr/src/utils/mod.rs | 21 +- .../src/enforce_sorting/sort_pushdown.rs | 60 +- .../src/projection_pushdown.rs | 55 +- 
datafusion/physical-plan/src/async_func.rs | 6 +- .../src/joins/stream_join_utils.rs | 29 +- datafusion/physical-plan/src/projection.rs | 61 +- datafusion/proto/src/logical_plan/to_proto.rs | 4 +- datafusion/pruning/src/pruning_predicate.rs | 11 +- datafusion/sql/src/expr/function.rs | 117 ++- datafusion/sql/src/expr/identifier.rs | 11 +- datafusion/sql/src/planner.rs | 28 +- datafusion/sql/src/select.rs | 6 +- datafusion/sql/src/unparser/expr.rs | 3 + datafusion/sql/src/unparser/plan.rs | 6 +- datafusion/sql/src/unparser/rewrite.rs | 14 +- datafusion/sql/src/unparser/utils.rs | 44 +- datafusion/sql/src/utils.rs | 32 +- datafusion/sqllogictest/test_files/array.slt | 8 +- datafusion/sqllogictest/test_files/lambda.slt | 38 +- .../src/logical_plan/producer/expr/mod.rs | 1 + 68 files changed, 1007 insertions(+), 2666 deletions(-) create mode 100644 datafusion/physical-expr/src/expressions/lambda_column.rs diff --git a/Cargo.lock b/Cargo.lock index 4a315ff38f2aa..8377a263cd0cf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2386,6 +2386,7 @@ dependencies = [ "datafusion-functions-aggregate", "datafusion-functions-aggregate-common", "datafusion-macros", + "datafusion-physical-expr", "datafusion-physical-expr-common", "itertools 0.14.0", "log", @@ -2489,7 +2490,6 @@ dependencies = [ "paste", "petgraph 0.8.3", "rand 0.9.2", - "recursive", "rstest", ] diff --git a/datafusion-examples/examples/custom_file_casts.rs b/datafusion-examples/examples/custom_file_casts.rs index d8db97d1e0440..4d97ecd91dc64 100644 --- a/datafusion-examples/examples/custom_file_casts.rs +++ b/datafusion-examples/examples/custom_file_casts.rs @@ -22,7 +22,7 @@ use arrow::datatypes::{DataType, Field, FieldRef, Schema, SchemaRef}; use datafusion::assert_batches_eq; use datafusion::common::not_impl_err; -use datafusion::common::tree_node::{Transformed, TransformedResult}; +use datafusion::common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion::common::{Result, ScalarValue}; use 
datafusion::datasource::listing::{ ListingTable, ListingTableConfig, ListingTableConfigExt, ListingTableUrl, @@ -31,7 +31,7 @@ use datafusion::execution::context::SessionContext; use datafusion::execution::object_store::ObjectStoreUrl; use datafusion::parquet::arrow::ArrowWriter; use datafusion::physical_expr::expressions::CastExpr; -use datafusion::physical_expr::{PhysicalExpr, PhysicalExprExt}; +use datafusion::physical_expr::PhysicalExpr; use datafusion::prelude::SessionConfig; use datafusion_physical_expr_adapter::{ DefaultPhysicalExprAdapterFactory, PhysicalExprAdapter, PhysicalExprAdapterFactory, @@ -181,10 +181,11 @@ impl PhysicalExprAdapter for CustomCastsPhysicalExprAdapter { expr = self.inner.rewrite(expr)?; // Now we can apply custom casting rules or even swap out all CastExprs for a custom cast kernel / expression // For example, [DataFusion Comet](https://github.com/apache/datafusion-comet) has a [custom cast kernel](https://github.com/apache/datafusion-comet/blob/b4ac876ab420ed403ac7fc8e1b29f42f1f442566/native/spark-expr/src/conversion_funcs/cast.rs#L133-L138). 
- expr.transform_with_schema(&self.physical_file_schema, |expr, schema| { + expr.transform(|expr| { if let Some(cast) = expr.as_any().downcast_ref::() { - let input_data_type = cast.expr().data_type(schema)?; - let output_data_type = cast.data_type(schema)?; + let input_data_type = + cast.expr().data_type(&self.physical_file_schema)?; + let output_data_type = cast.data_type(&self.physical_file_schema)?; if !cast.is_bigger_cast(&input_data_type) { return not_impl_err!( "Unsupported CAST from {input_data_type} to {output_data_type}" diff --git a/datafusion-examples/examples/default_column_values.rs b/datafusion-examples/examples/default_column_values.rs index 0d00d2c3af827..d3a7d2ec67f3c 100644 --- a/datafusion-examples/examples/default_column_values.rs +++ b/datafusion-examples/examples/default_column_values.rs @@ -26,8 +26,8 @@ use async_trait::async_trait; use datafusion::assert_batches_eq; use datafusion::catalog::memory::DataSourceExec; use datafusion::catalog::{Session, TableProvider}; -use datafusion::common::tree_node::{Transformed, TransformedResult}; -use datafusion::common::{DFSchema, HashSet}; +use datafusion::common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion::common::DFSchema; use datafusion::common::{Result, ScalarValue}; use datafusion::datasource::listing::PartitionedFile; use datafusion::datasource::physical_plan::{FileScanConfigBuilder, ParquetSource}; @@ -38,7 +38,7 @@ use datafusion::logical_expr::{Expr, TableProviderFilterPushDown, TableType}; use datafusion::parquet::arrow::ArrowWriter; use datafusion::parquet::file::properties::WriterProperties; use datafusion::physical_expr::expressions::{CastExpr, Column, Literal}; -use datafusion::physical_expr::{PhysicalExpr, PhysicalExprExt}; +use datafusion::physical_expr::PhysicalExpr; use datafusion::physical_plan::ExecutionPlan; use datafusion::prelude::{lit, SessionConfig}; use datafusion_physical_expr_adapter::{ @@ -308,12 +308,11 @@ impl PhysicalExprAdapter for 
DefaultValuePhysicalExprAdapter { fn rewrite(&self, expr: Arc) -> Result> { // First try our custom default value injection for missing columns let rewritten = expr - .transform_with_lambdas_params(|expr, lambdas_params| { + .transform(|expr| { self.inject_default_values( expr, &self.logical_file_schema, &self.physical_file_schema, - lambdas_params, ) }) .data()?; @@ -349,15 +348,12 @@ impl DefaultValuePhysicalExprAdapter { expr: Arc, logical_file_schema: &Schema, physical_file_schema: &Schema, - lambdas_params: &HashSet, ) -> Result>> { if let Some(column) = expr.as_any().downcast_ref::() { let column_name = column.name(); // Check if this column exists in the physical schema - if !lambdas_params.contains(column_name) - && physical_file_schema.index_of(column_name).is_err() - { + if physical_file_schema.index_of(column_name).is_err() { // Column is missing from physical schema, check if logical schema has a default if let Ok(logical_field) = logical_file_schema.field_with_name(column_name) diff --git a/datafusion-examples/examples/expr_api.rs b/datafusion-examples/examples/expr_api.rs index 29f074e2b400c..56f960870e58a 100644 --- a/datafusion-examples/examples/expr_api.rs +++ b/datafusion-examples/examples/expr_api.rs @@ -23,7 +23,7 @@ use arrow::record_batch::RecordBatch; use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit}; use datafusion::common::stats::Precision; -use datafusion::common::tree_node::Transformed; +use datafusion::common::tree_node::{Transformed, TreeNode}; use datafusion::common::{ColumnStatistics, DFSchema}; use datafusion::common::{ScalarValue, ToDFSchema}; use datafusion::error::Result; @@ -556,7 +556,7 @@ fn type_coercion_demo() -> Result<()> { // 3. Type coercion with `TypeCoercionRewriter`. let coerced_expr = expr .clone() - .rewrite_with_schema(&df_schema, &mut TypeCoercionRewriter::new(&df_schema))? + .rewrite(&mut TypeCoercionRewriter::new(&df_schema))? 
.data; let physical_expr = datafusion::physical_expr::create_physical_expr( &coerced_expr, @@ -567,7 +567,7 @@ fn type_coercion_demo() -> Result<()> { // 4. Apply explicit type coercion by manually rewriting the expression let coerced_expr = expr - .transform_with_schema(&df_schema, |e, df_schema| { + .transform(|e| { // Only type coerces binary expressions. let Expr::BinaryExpr(e) = e else { return Ok(Transformed::no(e)); @@ -575,7 +575,7 @@ fn type_coercion_demo() -> Result<()> { if let Expr::Column(ref col_expr) = *e.left { let field = df_schema.field_with_name(None, col_expr.name())?; let cast_to_type = field.data_type(); - let coerced_right = e.right.cast_to(cast_to_type, df_schema)?; + let coerced_right = e.right.cast_to(cast_to_type, &df_schema)?; Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr::new( e.left, e.op, diff --git a/datafusion-examples/examples/json_shredding.rs b/datafusion-examples/examples/json_shredding.rs index e97f27b818d8d..5ef8b59b64200 100644 --- a/datafusion-examples/examples/json_shredding.rs +++ b/datafusion-examples/examples/json_shredding.rs @@ -22,8 +22,10 @@ use arrow::array::{RecordBatch, StringArray}; use arrow::datatypes::{DataType, Field, FieldRef, Schema, SchemaRef}; use datafusion::assert_batches_eq; -use datafusion::common::tree_node::{Transformed, TransformedResult, TreeNodeRecursion}; -use datafusion::common::{assert_contains, exec_datafusion_err, HashSet, Result}; +use datafusion::common::tree_node::{ + Transformed, TransformedResult, TreeNode, TreeNodeRecursion, +}; +use datafusion::common::{assert_contains, exec_datafusion_err, Result}; use datafusion::datasource::listing::{ ListingTable, ListingTableConfig, ListingTableConfigExt, ListingTableUrl, }; @@ -34,8 +36,8 @@ use datafusion::logical_expr::{ }; use datafusion::parquet::arrow::ArrowWriter; use datafusion::parquet::file::properties::WriterProperties; +use datafusion::physical_expr::PhysicalExpr; use datafusion::physical_expr::{expressions, ScalarFunctionExpr}; 
-use datafusion::physical_expr::{PhysicalExpr, PhysicalExprExt}; use datafusion::prelude::SessionConfig; use datafusion::scalar::ScalarValue; use datafusion_physical_expr_adapter::{ @@ -300,9 +302,7 @@ impl PhysicalExprAdapter for ShreddedJsonRewriter { fn rewrite(&self, expr: Arc) -> Result> { // First try our custom JSON shredding rewrite let rewritten = expr - .transform_with_lambdas_params(|expr, lambdas_params| { - self.rewrite_impl(expr, &self.physical_file_schema, lambdas_params) - }) + .transform(|expr| self.rewrite_impl(expr, &self.physical_file_schema)) .data()?; // Then apply the default adapter as a fallback to handle standard schema differences @@ -335,7 +335,6 @@ impl ShreddedJsonRewriter { &self, expr: Arc, physical_file_schema: &Schema, - lambdas_params: &HashSet, ) -> Result>> { if let Some(func) = expr.as_any().downcast_ref::() { if func.name() == "json_get_str" && func.args().len() == 2 { @@ -349,7 +348,6 @@ impl ShreddedJsonRewriter { if let Some(column) = func.args()[1] .as_any() .downcast_ref::() - .filter(|col| !lambdas_params.contains(col.name())) { let column_name = column.name(); // Check if there's a flat column with underscore prefix diff --git a/datafusion/catalog-listing/src/helpers.rs b/datafusion/catalog-listing/src/helpers.rs index 444f505f4280b..78b46171006a7 100644 --- a/datafusion/catalog-listing/src/helpers.rs +++ b/datafusion/catalog-listing/src/helpers.rs @@ -52,9 +52,9 @@ use object_store::{ObjectMeta, ObjectStore}; /// was performed pub fn expr_applicable_for_cols(col_names: &[&str], expr: &Expr) -> bool { let mut is_applicable = true; - expr.apply_with_lambdas_params(|expr, lambdas_params| match expr { - Expr::Column(col) => { - is_applicable &= col_names.contains(&col.name()) || col.is_lambda_parameter(lambdas_params); + expr.apply(|expr| match expr { + Expr::Column(Column { ref name, .. 
}) => { + is_applicable &= col_names.contains(&name.as_str()); if is_applicable { Ok(TreeNodeRecursion::Jump) } else { @@ -87,7 +87,8 @@ pub fn expr_applicable_for_cols(col_names: &[&str], expr: &Expr) -> bool { | Expr::ScalarSubquery(_) | Expr::GroupingSet(_) | Expr::Case(_) - | Expr::Lambda(_) => Ok(TreeNodeRecursion::Continue), + | Expr::Lambda(_) + | Expr::LambdaColumn(_) => Ok(TreeNodeRecursion::Continue), Expr::ScalarFunction(scalar_function) => { match scalar_function.func.signature().volatility { diff --git a/datafusion/common/src/column.rs b/datafusion/common/src/column.rs index dd9b985e6485c..c7f0b5a4f4881 100644 --- a/datafusion/common/src/column.rs +++ b/datafusion/common/src/column.rs @@ -22,7 +22,6 @@ use crate::utils::parse_identifiers_normalized; use crate::utils::quote_identifier; use crate::{DFSchema, Diagnostic, Result, SchemaError, Spans, TableReference}; use arrow::datatypes::{Field, FieldRef}; -use std::borrow::Borrow; use std::collections::HashSet; use std::fmt; @@ -326,11 +325,6 @@ impl Column { ..self.clone() } } - - pub fn is_lambda_parameter(&self, lambdas_params: &crate::HashSet + Eq + std::hash::Hash>) -> bool { - // currently, references to lambda parameters are always unqualified - self.relation.is_none() && lambdas_params.contains(self.name()) - } } impl From<&str> for Column { diff --git a/datafusion/common/src/cse.rs b/datafusion/common/src/cse.rs index a7ffde52c93b2..674d3386171f8 100644 --- a/datafusion/common/src/cse.rs +++ b/datafusion/common/src/cse.rs @@ -178,14 +178,6 @@ pub trait CSEController { /// if all are always evaluated. fn conditional_children(node: &Self::Node) -> Option>; - // A helper method called on each node before is_ignored, during top-down traversal during the first, - // visiting traversal of CSE. - fn visit_f_down(&mut self, _node: &Self::Node) {} - - // A helper method called on each node after is_ignored, during bottom-up traversal during the first, - // visiting traversal of CSE. 
- fn visit_f_up(&mut self, _node: &Self::Node) {} - // Returns true if a node is valid. If a node is invalid then it can't be eliminated. // Validity is propagated up which means no subtree can be eliminated that contains // an invalid node. @@ -282,7 +274,7 @@ where /// thus can not be extracted as a common [`TreeNode`]. conditional: bool, - controller: &'a mut C, + controller: &'a C, } /// Record item that used when traversing a [`TreeNode`] tree. @@ -360,7 +352,6 @@ where self.visit_stack .push(VisitRecord::EnterMark(self.down_index)); self.down_index += 1; - self.controller.visit_f_down(node); // If a node can short-circuit then some of its children might not be executed so // count the occurrence either normal or conditional. @@ -423,7 +414,6 @@ where self.visit_stack .push(VisitRecord::NodeItem(node_id, is_valid)); self.up_index += 1; - self.controller.visit_f_up(node); Ok(TreeNodeRecursion::Continue) } @@ -542,7 +532,7 @@ where /// Add an identifier to `id_array` for every [`TreeNode`] in this tree. fn node_to_id_array<'n>( - &mut self, + &self, node: &'n N, node_stats: &mut NodeStats<'n, N>, id_array: &mut IdArray<'n, N>, @@ -556,7 +546,7 @@ where random_state: &self.random_state, found_common: false, conditional: false, - controller: &mut self.controller, + controller: &self.controller, }; node.visit(&mut visitor)?; @@ -571,7 +561,7 @@ where /// Each element is itself the result of [`CSE::node_to_id_array`] for that node /// (e.g. 
the identifiers for each node in the tree) fn to_arrays<'n>( - &mut self, + &self, nodes: &'n [N], node_stats: &mut NodeStats<'n, N>, ) -> Result<(bool, Vec>)> { @@ -771,7 +761,7 @@ mod test { #[test] fn id_array_visitor() -> Result<()> { let alias_generator = AliasGenerator::new(); - let mut eliminator = CSE::new(TestTreeNodeCSEController::new( + let eliminator = CSE::new(TestTreeNodeCSEController::new( &alias_generator, TestTreeNodeMask::Normal, )); @@ -863,7 +853,7 @@ mod test { assert_eq!(expected, id_array); // include aggregates - let mut eliminator = CSE::new(TestTreeNodeCSEController::new( + let eliminator = CSE::new(TestTreeNodeCSEController::new( &alias_generator, TestTreeNodeMask::NormalAndAggregates, )); diff --git a/datafusion/common/src/dfschema.rs b/datafusion/common/src/dfschema.rs index 8a09d61292b27..24d152a7dba8c 100644 --- a/datafusion/common/src/dfschema.rs +++ b/datafusion/common/src/dfschema.rs @@ -314,10 +314,8 @@ impl DFSchema { return; } - let self_fields: HashSet<(Option<&TableReference>, &str)> = self - .iter() - .map(|(qualifier, field)| (qualifier, field.name().as_str())) - .collect(); + let self_fields: HashSet<(Option<&TableReference>, &FieldRef)> = + self.iter().collect(); let self_unqualified_names: HashSet<&str> = self .inner .fields @@ -330,10 +328,7 @@ impl DFSchema { for (qualifier, field) in other_schema.iter() { // skip duplicate columns let duplicated_field = match qualifier { - Some(q) => { - self_fields.contains(&(Some(q), field.name().as_str())) - || self_fields.contains(&(None, field.name().as_str())) - } + Some(q) => self_fields.contains(&(Some(q), field)), // for unqualified columns, check as unqualified name None => self_unqualified_names.contains(field.name().as_str()), }; @@ -872,11 +867,6 @@ impl DFSchema { &self.functional_dependencies } - /// Get functional dependencies - pub fn field_qualifiers(&self) -> &[Option] { - &self.field_qualifiers - } - /// Iterate over the qualifiers and fields in the DFSchema pub fn 
iter(&self) -> impl Iterator, &FieldRef)> { self.field_qualifiers diff --git a/datafusion/common/src/lib.rs b/datafusion/common/src/lib.rs index 8923df683f899..76c7b46e32737 100644 --- a/datafusion/common/src/lib.rs +++ b/datafusion/common/src/lib.rs @@ -117,8 +117,6 @@ pub mod hash_set { pub use hashbrown::hash_set::Entry; } -pub use hashbrown; - /// Downcast an Arrow Array to a concrete type, return an `DataFusionError::Internal` if the cast is /// not possible. In normal usage of DataFusion the downcast should always succeed. /// diff --git a/datafusion/common/src/utils/mod.rs b/datafusion/common/src/utils/mod.rs index ec2dad505a561..3fd0683659caf 100644 --- a/datafusion/common/src/utils/mod.rs +++ b/datafusion/common/src/utils/mod.rs @@ -41,6 +41,7 @@ use sqlparser::{ast::Ident, dialect::GenericDialect, parser::Parser}; use std::borrow::{Borrow, Cow}; use std::cmp::{min, Ordering}; use std::collections::HashSet; +use std::iter::repeat_n; use std::num::NonZero; use std::ops::Range; use std::sync::Arc; @@ -953,7 +954,7 @@ pub fn make_list_array_indices( ); for (i, (&start, &end)) in std::iter::zip(&offsets[..], &offsets[1..]).enumerate() { - indices.extend(std::iter::repeat_n( + indices.extend(repeat_n( T::Native::usize_as(i), end.as_usize() - start.as_usize(), )); @@ -966,16 +967,13 @@ pub fn make_list_array_indices( pub fn make_list_element_indices( offsets: &OffsetBuffer, ) -> PrimitiveArray { - let mut indices = vec![ - T::default_value(); - offsets.last().unwrap().as_usize() - - offsets.first().unwrap().as_usize() - ]; + let mut indices = + Vec::with_capacity(offsets.last().unwrap().as_usize() - offsets[0].as_usize()); for (&start, &end) in std::iter::zip(&offsets[..], &offsets[1..]) { - for i in 0..end.as_usize() - start.as_usize() { - indices[start.as_usize() + i] = T::Native::usize_as(i); - } + indices.extend( + (0..end.as_usize() - start.as_usize()).map(|i| T::Native::usize_as(i)), + ); } PrimitiveArray::new(indices.into(), None) @@ -986,12 +984,10 @@ 
pub fn make_fsl_array_indices( list_size: i32, array_len: usize, ) -> PrimitiveArray { - let mut indices = vec![0; list_size as usize * array_len]; + let mut indices = Vec::with_capacity(list_size as usize * array_len); for i in 0..array_len { - for j in 0..list_size as usize { - indices[i + j] = i as i32; - } + indices.extend(repeat_n(i as i32, list_size as usize)); } PrimitiveArray::new(indices.into(), None) @@ -1002,11 +998,13 @@ pub fn make_fsl_element_indices( list_size: i32, array_len: usize, ) -> PrimitiveArray { - let mut indices = vec![0; list_size as usize * array_len]; + let mut indices = Vec::with_capacity(list_size as usize * array_len); - for i in 0..array_len { - for j in 0..list_size as usize { - indices[i + j] = j as i32; + if array_len > 0 { + indices.extend((0..list_size as usize).map(|j| j as i32)); + + for _ in 1..array_len { + indices.extend_from_within(0..list_size as usize); } } diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index ad4ffb487ee1d..c15b7eae08432 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -41,6 +41,7 @@ use datafusion_common::alias::AliasGenerator; use datafusion_common::config::Dialect; use datafusion_common::config::{ConfigExtension, ConfigOptions, TableOptions}; use datafusion_common::display::{PlanType, StringifiedPlan, ToStringifiedPlan}; +use datafusion_common::tree_node::TreeNode; use datafusion_common::{ config_err, exec_err, plan_datafusion_err, DFSchema, DataFusionError, ResolvedTableReference, TableReference, @@ -700,9 +701,7 @@ impl SessionState { let config_options = self.config_options(); for rewrite in self.analyzer.function_rewrites() { expr = expr - .transform_up_with_schema(df_schema, |expr, df_schema| { - rewrite.rewrite(expr, df_schema, config_options) - })? + .transform_up(|expr| rewrite.rewrite(expr, df_schema, config_options))? 
.data; } create_physical_expr(&expr, df_schema, self.execution_props()) diff --git a/datafusion/core/tests/parquet/mod.rs b/datafusion/core/tests/parquet/mod.rs index eea6085c02b9f..097600e45eadd 100644 --- a/datafusion/core/tests/parquet/mod.rs +++ b/datafusion/core/tests/parquet/mod.rs @@ -516,7 +516,7 @@ fn make_uint_batches(start: u8, end: u8) -> RecordBatch { Field::new("u64", DataType::UInt64, true), ])); let v8: Vec = (start..end).collect(); - let v16: Vec = (start as u16..end as _).collect(); + let v16: Vec = (start as _..end as _).collect(); let v32: Vec = (start as _..end as _).collect(); let v64: Vec = (start as _..end as _).collect(); RecordBatch::try_new( diff --git a/datafusion/core/tests/parquet/schema_adapter.rs b/datafusion/core/tests/parquet/schema_adapter.rs index dfa4c91ba5dd8..40fc6176e212b 100644 --- a/datafusion/core/tests/parquet/schema_adapter.rs +++ b/datafusion/core/tests/parquet/schema_adapter.rs @@ -27,7 +27,7 @@ use datafusion::datasource::listing::{ ListingTable, ListingTableConfig, ListingTableConfigExt, }; use datafusion::prelude::{SessionConfig, SessionContext}; -use datafusion_common::tree_node::{Transformed, TransformedResult}; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::DataFusionError; use datafusion_common::{ColumnStatistics, ScalarValue}; use datafusion_datasource::file::FileSource; @@ -39,7 +39,7 @@ use datafusion_datasource::ListingTableUrl; use datafusion_datasource_parquet::source::ParquetSource; use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_physical_expr::expressions::{self, Column}; -use datafusion_physical_expr::{PhysicalExpr, PhysicalExprExt}; +use datafusion_physical_expr::PhysicalExpr; use datafusion_physical_expr_adapter::{ DefaultPhysicalExprAdapter, DefaultPhysicalExprAdapterFactory, PhysicalExprAdapter, PhysicalExprAdapterFactory, @@ -181,10 +181,10 @@ struct CustomPhysicalExprAdapter { impl PhysicalExprAdapter for 
CustomPhysicalExprAdapter { fn rewrite(&self, mut expr: Arc) -> Result> { expr = expr - .transform_with_lambdas_params(|expr, lambdas_params| { + .transform(|expr| { if let Some(column) = expr.as_any().downcast_ref::() { let field_name = column.name(); - if !lambdas_params.contains(field_name) && self + if self .physical_file_schema .field_with_name(field_name) .ok() diff --git a/datafusion/datasource-parquet/src/row_filter.rs b/datafusion/datasource-parquet/src/row_filter.rs index 45441ad71086c..660b32f486120 100644 --- a/datafusion/datasource-parquet/src/row_filter.rs +++ b/datafusion/datasource-parquet/src/row_filter.rs @@ -77,7 +77,7 @@ use datafusion_common::Result; use datafusion_datasource::schema_adapter::{SchemaAdapterFactory, SchemaMapper}; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::utils::reassign_expr_columns; -use datafusion_physical_expr::{split_conjunction, PhysicalExpr, PhysicalExprExt}; +use datafusion_physical_expr::{split_conjunction, PhysicalExpr}; use datafusion_physical_plan::metrics; @@ -336,20 +336,6 @@ impl<'schema> PushdownChecker<'schema> { fn prevents_pushdown(&self) -> bool { self.non_primitive_columns || self.projected_columns } - - fn check(&mut self, node: Arc) -> Result { - node.apply_with_lambdas_params(|node, lamdas_params| { - if let Some(column) = node.as_any().downcast_ref::() { - if !lamdas_params.contains(column.name()) { - if let Some(recursion) = self.check_single_column(column.name()) { - return Ok(recursion); - } - } - } - - Ok(TreeNodeRecursion::Continue) - }) - } } impl TreeNodeVisitor<'_> for PushdownChecker<'_> { diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index e2845ea5a7de8..6387fc4a44f38 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -398,10 +398,30 @@ pub enum Expr { OuterReferenceColumn(FieldRef, Column), /// Unnest expression Unnest(Unnest), - /// Lambda expression, valid only as a scalar function argument - /// Note 
that it has it's own scoped schema, different from the plan schema, - /// that can be constructed with ScalarUDF::arguments_schemas and variants + /// Lambda expression Lambda(Lambda), + LambdaColumn(LambdaColumn), +} + +#[derive(Clone, PartialEq, PartialOrd, Eq, Debug, Hash)] +pub struct LambdaColumn { + pub name: String, + pub field: FieldRef, + pub spans: Spans, +} + +impl LambdaColumn { + pub fn new(name: String, field: FieldRef) -> Self { + Self { + name, + field, + spans: Spans::new(), + } + } + + pub fn spans_mut(&mut self) -> &mut Spans { + &mut self.spans + } } impl Default for Expr { @@ -1547,6 +1567,7 @@ impl Expr { Expr::Wildcard { .. } => "Wildcard", Expr::Unnest { .. } => "Unnest", Expr::Lambda { .. } => "Lambda", + Expr::LambdaColumn { .. } => "LambdaColumn", } } @@ -1930,11 +1951,9 @@ impl Expr { /// /// See [`Self::column_refs`] for details pub fn add_column_refs<'a>(&'a self, set: &mut HashSet<&'a Column>) { - self.apply_with_lambdas_params(|expr, lambdas_params| { + self.apply(|expr| { if let Expr::Column(col) = expr { - if col.relation.is_some() || !lambdas_params.contains(col.name()) { - set.insert(col); - } + set.insert(col); } Ok(TreeNodeRecursion::Continue) }) @@ -1967,11 +1986,9 @@ impl Expr { /// /// See [`Self::column_refs_counts`] for details pub fn add_column_ref_counts<'a>(&'a self, map: &mut HashMap<&'a Column, usize>) { - self.apply_with_lambdas_params(|expr, lambdas_params| { + self.apply(|expr| { if let Expr::Column(col) = expr { - if !col.is_lambda_parameter(lambdas_params) { - *map.entry(col).or_default() += 1; - } + *map.entry(col).or_default() += 1; } Ok(TreeNodeRecursion::Continue) }) @@ -1980,10 +1997,8 @@ impl Expr { /// Returns true if there are any column references in this Expr pub fn any_column_refs(&self) -> bool { - self.exists_with_lambdas_params(|expr, lambdas_params| { - Ok(matches!(expr, Expr::Column(c) if !c.is_lambda_parameter(lambdas_params))) - }) - .expect("exists closure is infallible") + self.exists(|expr| 
Ok(matches!(expr, Expr::Column(_)))) + .expect("exists closure is infallible") } /// Return true if the expression contains out reference(correlated) expressions. @@ -2023,7 +2038,7 @@ impl Expr { /// at least one placeholder. pub fn infer_placeholder_types(self, schema: &DFSchema) -> Result<(Expr, bool)> { let mut has_placeholder = false; - self.transform_with_schema(schema, |mut expr, schema| { + self.transform(|mut expr| { match &mut expr { // Default to assuming the arguments are the same type Expr::BinaryExpr(BinaryExpr { left, op: _, right }) => { @@ -2107,7 +2122,8 @@ impl Expr { | Expr::WindowFunction(..) | Expr::Literal(..) | Expr::Placeholder(..) - | Expr::Lambda { .. } => false, + | Expr::Lambda(..) + | Expr::LambdaColumn(..) => false, } } @@ -2703,12 +2719,17 @@ impl HashNode for Expr { column.hash(state); } Expr::Unnest(Unnest { expr: _expr }) => {} - Expr::Lambda(Lambda { - params, - body: _, - }) => { + Expr::Lambda(Lambda { params, body: _ }) => { params.hash(state); } + Expr::LambdaColumn(LambdaColumn { + name, + field, + spans: _, + }) => { + name.hash(state); + field.hash(state); + } }; } } @@ -3022,12 +3043,12 @@ impl Display for SchemaDisplay<'_> { } } } - Expr::Lambda(Lambda { - params, - body, - }) => { + Expr::Lambda(Lambda { params, body }) => { write!(f, "({}) -> {body}", display_comma_separated(params)) } + Expr::LambdaColumn(c) => { + write!(f, "{}", c.name) + } } } } @@ -3208,10 +3229,7 @@ impl Display for SqlDisplay<'_> { } } } - Expr::Lambda(Lambda { - params, - body, - }) => { + Expr::Lambda(Lambda { params, body }) => { write!(f, "({}) -> {}", params.join(", "), SchemaDisplay(body)) } _ => write!(f, "{}", self.0), @@ -3521,12 +3539,12 @@ impl Display for Expr { Expr::Unnest(Unnest { expr }) => { write!(f, "{UNNEST_COLUMN_PREFIX}({expr})") } - Expr::Lambda(Lambda { - params, - body, - }) => { + Expr::Lambda(Lambda { params, body }) => { write!(f, "({}) -> {body}", params.join(", ")) } + Expr::LambdaColumn(c) => { + write!(f, "{}", 
c.name) + } } } } diff --git a/datafusion/expr/src/expr_rewriter/mod.rs b/datafusion/expr/src/expr_rewriter/mod.rs index 81ec6e7acbe38..9c3c5df7007ff 100644 --- a/datafusion/expr/src/expr_rewriter/mod.rs +++ b/datafusion/expr/src/expr_rewriter/mod.rs @@ -62,15 +62,11 @@ pub trait FunctionRewrite: Debug { /// Recursively call `LogicalPlanBuilder::normalize` on all [`Column`] expressions /// in the `expr` expression tree. pub fn normalize_col(expr: Expr, plan: &LogicalPlan) -> Result { - expr.transform_up_with_lambdas_params(|expr, lambdas_params| { + expr.transform(|expr| { Ok({ if let Expr::Column(c) = expr { - if c.relation.is_some() || !lambdas_params.contains(c.name()) { - let col = LogicalPlanBuilder::normalize(plan, c)?; - Transformed::yes(Expr::Column(col)) - } else { - Transformed::no(Expr::Column(c)) - } + let col = LogicalPlanBuilder::normalize(plan, c)?; + Transformed::yes(Expr::Column(col)) } else { Transformed::no(expr) } @@ -95,21 +91,14 @@ pub fn normalize_col_with_schemas_and_ambiguity_check( return Ok(Expr::Unnest(Unnest { expr: Box::new(e) })); } - expr.transform_up_with_lambdas_params(|expr, lambdas_params| { + expr.transform(|expr| { Ok({ - match expr { - Expr::Column(c) => { - if c.relation.is_none() && lambdas_params.contains(c.name()) { - Transformed::no(Expr::Column(c)) - } else { - let col = c.normalize_with_schemas_and_ambiguity_check( - schemas, - using_columns, - )?; - Transformed::yes(Expr::Column(col)) - } - } - _ => Transformed::no(expr), + if let Expr::Column(c) = expr { + let col = + c.normalize_with_schemas_and_ambiguity_check(schemas, using_columns)?; + Transformed::yes(Expr::Column(col)) + } else { + Transformed::no(expr) } }) }) @@ -144,18 +133,15 @@ pub fn normalize_sorts( /// Recursively replace all [`Column`] expressions in a given expression tree with /// `Column` expressions provided by the hash map argument. 
pub fn replace_col(expr: Expr, replace_map: &HashMap<&Column, &Column>) -> Result { - expr.transform_up_with_lambdas_params(|expr, lambdas_params| { + expr.transform(|expr| { Ok({ - match &expr { - Expr::Column(c) if !c.is_lambda_parameter(lambdas_params) => { - match replace_map.get(c) { - Some(new_c) => { - Transformed::yes(Expr::Column((*new_c).to_owned())) - } - None => Transformed::no(expr), - } + if let Expr::Column(c) = &expr { + match replace_map.get(c) { + Some(new_c) => Transformed::yes(Expr::Column((*new_c).to_owned())), + None => Transformed::no(expr), } - _ => Transformed::no(expr), + } else { + Transformed::no(expr) } }) }) @@ -215,7 +201,6 @@ pub fn strip_outer_reference(expr: Expr) -> Expr { expr.transform(|expr| { Ok({ if let Expr::OuterReferenceColumn(_, col) = expr { - //todo: what if this col collides with a lambda parameter? Transformed::yes(Expr::Column(col)) } else { Transformed::no(expr) diff --git a/datafusion/expr/src/expr_rewriter/order_by.rs b/datafusion/expr/src/expr_rewriter/order_by.rs index b94c632ce74b3..6db95555502da 100644 --- a/datafusion/expr/src/expr_rewriter/order_by.rs +++ b/datafusion/expr/src/expr_rewriter/order_by.rs @@ -77,10 +77,6 @@ fn rewrite_in_terms_of_projection( // assumption is that each item in exprs, such as "b + c" is // available as an output column named "b + c" expr.transform(|expr| { - if matches!(expr, Expr::Lambda(_)) { - return Ok(Transformed::new(expr, false, TreeNodeRecursion::Jump)) - } - // search for unnormalized names first such as "c1" (such as aliases) if let Some(found) = proj_exprs.iter().find(|a| (**a) == expr) { let (qualifier, field_name) = found.qualified_name(); diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 4a1efadccd0ec..1b7f5f0212c6b 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -16,12 +16,12 @@ // under the License. 
use super::{Between, Expr, Like}; -use crate::expr::FieldMetadata; use crate::expr::{ AggregateFunction, AggregateFunctionParams, Alias, BinaryExpr, Cast, InList, InSubquery, Lambda, Placeholder, ScalarFunction, TryCast, Unnest, WindowFunction, WindowFunctionParams, }; +use crate::expr::{FieldMetadata, LambdaColumn}; use crate::type_coercion::functions::{ fields_with_aggregate_udf, fields_with_window_udf, }; @@ -234,7 +234,10 @@ impl ExprSchemable for Expr { // Grouping sets do not really have a type and do not appear in projections Ok(DataType::Null) } - Expr::Lambda { .. } => Ok(DataType::Null), + Expr::Lambda(Lambda { params: _, body }) => body.get_type(schema), + Expr::LambdaColumn(LambdaColumn { name: _, field, .. }) => { + Ok(field.data_type().clone()) + } } } @@ -353,7 +356,8 @@ impl ExprSchemable for Expr { // in projections Ok(true) } - Expr::Lambda { .. } => Ok(false), + Expr::Lambda(l) => l.body.nullable(input_schema), + Expr::LambdaColumn(c) => Ok(c.field.is_nullable()), } } @@ -542,30 +546,14 @@ impl ExprSchemable for Expr { func.return_field(&new_fields) } - // Expr::Lambda(Lambda { params, body}) => body.to_field(schema), Expr::ScalarFunction(ScalarFunction { func, args }) => { - let fields = if args.iter().any(|arg| matches!(arg, Expr::Lambda(_))) { - let lambdas_schemas = func.arguments_expr_schema(args, schema)?; - - std::iter::zip(args, lambdas_schemas) - // .map(|(e, schema)| e.to_field(schema).map(|(_, f)| f)) - .map(|(e, schema)| match e { - Expr::Lambda(Lambda { params: _, body }) => { - body.to_field(&schema).map(|(_, f)| f) - } - _ => e.to_field(&schema).map(|(_, f)| f), - }) - .collect::>>()? - } else { - args.iter() - .map(|e| e.to_field(schema).map(|(_, f)| f)) - .collect::>>()? - }; - - let arg_types = fields + let (arg_types, fields): (Vec, Vec>) = args .iter() - .map(|f| f.data_type().clone()) - .collect::>(); + .map(|e| e.to_field(schema).map(|(_, f)| f)) + .collect::>>()? 
+ .into_iter() + .map(|f| (f.data_type().clone(), f)) + .unzip(); // Verify that function is invoked with correct number and type of arguments as defined in `TypeSignature` let new_data_types = data_types_with_scalar_udf(&arg_types, func) @@ -637,6 +625,7 @@ impl ExprSchemable for Expr { self.get_type(schema)?, self.nullable(schema)?, ))), + Expr::LambdaColumn(c) => Ok(Arc::clone(&c.field)), }?; Ok(( diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 46c7422814ace..0f26218e74779 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -118,10 +118,8 @@ pub use udaf::{ ReversedUDAF, SetMonotonicity, StatisticsArgs, }; pub use udf::{ - merge_captures_with_args, merge_captures_with_boxed_lazy_args, - merge_captures_with_lazy_args, ReturnFieldArgs, ScalarFunctionArgs, - ScalarFunctionLambdaArg, ScalarUDF, ScalarUDFImpl, ValueOrLambda, ValueOrLambdaField, - ValueOrLambdaParameter, + ReturnFieldArgs, ScalarFunctionArgs, ScalarFunctionLambdaArg, ScalarUDF, + ScalarUDFImpl, ValueOrLambda, ValueOrLambdaField, ValueOrLambdaParameter, }; pub use udwf::{LimitEffect, ReversedUDWF, WindowUDF, WindowUDFImpl}; pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits}; diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs index 63c535b43ee8b..df98f720a0f08 100644 --- a/datafusion/expr/src/tree_node.rs +++ b/datafusion/expr/src/tree_node.rs @@ -29,7 +29,7 @@ use datafusion_common::{ tree_node::{ Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion, TreeNodeRefContainer, }, - DFSchema, HashSet, Result, + Result, }; /// Implementation of the [`TreeNode`] trait @@ -80,7 +80,8 @@ impl TreeNode for Expr { | Expr::Exists { .. } | Expr::ScalarSubquery(_) | Expr::Wildcard { .. } - | Expr::Placeholder(_) => Ok(TreeNodeRecursion::Continue), + | Expr::Placeholder(_) + | Expr::LambdaColumn(_) => Ok(TreeNodeRecursion::Continue), Expr::BinaryExpr(BinaryExpr { left, right, .. 
}) => { (left, right).apply_ref_elements(f) } @@ -131,7 +132,8 @@ impl TreeNode for Expr { | Expr::Exists { .. } | Expr::ScalarSubquery(_) | Expr::ScalarVariable(_, _) - | Expr::Literal(_, _) => Transformed::no(self), + | Expr::Literal(_, _) + | Expr::LambdaColumn(_) => Transformed::no(self), Expr::Unnest(Unnest { expr, .. }) => expr .map_elements(f)? .update_data(|expr| Expr::Unnest(Unnest { expr })), @@ -321,680 +323,3 @@ impl TreeNode for Expr { }) } } - -impl Expr { - /// Similarly to [`Self::rewrite`], rewrites this expr and its inputs using `f`, - /// including lambdas that may appear in expressions such as `array_transform([1, 2], v -> v*2)`. - #[cfg_attr(feature = "recursive_protection", recursive::recursive)] - pub fn rewrite_with_schema< - R: for<'a> TreeNodeRewriterWithPayload = &'a DFSchema>, - >( - self, - schema: &DFSchema, - rewriter: &mut R, - ) -> Result> { - rewriter - .f_down(self, schema)? - .transform_children(|n| match &n { - Expr::ScalarFunction(ScalarFunction { func, args }) - if args.iter().any(|arg| matches!(arg, Expr::Lambda(_))) => - { - let mut lambdas_schemas = func - .arguments_schema_from_logical_args(args, schema)? - .into_iter(); - - n.map_children(|n| { - n.rewrite_with_schema(&lambdas_schemas.next().unwrap(), rewriter) - }) - } - _ => n.map_children(|n| n.rewrite_with_schema(schema, rewriter)), - })? - .transform_parent(|n| rewriter.f_up(n, schema)) - } - - /// Similarly to [`Self::rewrite`], rewrites this expr and its inputs using `f`, - /// including lambdas that may appear in expressions such as `array_transform([1, 2], v -> v*2)`. 
- pub fn rewrite_with_lambdas_params< - R: for<'a> TreeNodeRewriterWithPayload< - Node = Expr, - Payload<'a> = &'a HashSet, - >, - >( - self, - rewriter: &mut R, - ) -> Result> { - self.rewrite_with_lambdas_params_impl(&HashSet::new(), rewriter) - } - - #[cfg_attr(feature = "recursive_protection", recursive::recursive)] - fn rewrite_with_lambdas_params_impl< - R: for<'a> TreeNodeRewriterWithPayload< - Node = Expr, - Payload<'a> = &'a HashSet, - >, - >( - self, - args: &HashSet, - rewriter: &mut R, - ) -> Result> { - rewriter - .f_down(self, args)? - .transform_children(|n| match n { - Expr::Lambda(Lambda { - ref params, - body: _, - }) => { - let mut args = args.clone(); - - args.extend(params.iter().cloned()); - - n.map_children(|n| { - n.rewrite_with_lambdas_params_impl(&args, rewriter) - }) - } - _ => { - n.map_children(|n| n.rewrite_with_lambdas_params_impl(args, rewriter)) - } - })? - .transform_parent(|n| rewriter.f_up(n, args)) - } - - /// Similarly to [`Self::map_children`], rewrites all lambdas that may - /// appear in expressions such as `array_transform([1, 2], v -> v*2)`. - /// - /// Returns the current node. - pub fn map_children_with_lambdas_params< - F: FnMut(Self, &HashSet) -> Result>, - >( - self, - args: &HashSet, - mut f: F, - ) -> Result> { - match &self { - Expr::Lambda(Lambda { params, body: _ }) => { - let mut args = args.clone(); - - args.extend(params.iter().cloned()); - - self.map_children(|expr| f(expr, &args)) - } - _ => self.map_children(|expr| f(expr, args)), - } - } - - /// Similarly to [`Self::transform_up`], rewrites this expr and its inputs using `f`, - /// including lambdas that may appear in expressions such as `array_transform([1, 2], v -> v*2)`. 
- pub fn transform_up_with_lambdas_params< - F: FnMut(Self, &HashSet) -> Result>, - >( - self, - mut f: F, - ) -> Result> { - #[cfg_attr(feature = "recursive_protection", recursive::recursive)] - fn transform_up_with_lambdas_params_impl< - F: FnMut(Expr, &HashSet) -> Result>, - >( - node: Expr, - args: &HashSet, - f: &mut F, - ) -> Result> { - node.map_children_with_lambdas_params(args, |node, args| { - transform_up_with_lambdas_params_impl(node, args, f) - })? - .transform_parent(|node| f(node, args)) - /*match &node { - Expr::Lambda(Lambda { params, body: _ }) => { - let mut args = args.clone(); - - args.extend(params.iter().cloned()); - - node.map_children(|n| { - transform_up_with_lambdas_params_impl(n, &args, f) - })? - .transform_parent(|n| f(n, &args)) - } - _ => node - .map_children(|n| transform_up_with_lambdas_params_impl(n, args, f))? - .transform_parent(|n| f(n, args)), - }*/ - } - - transform_up_with_lambdas_params_impl(self, &HashSet::new(), &mut f) - } - - /// Similarly to [`Self::transform_down`], rewrites this expr and its inputs using `f`, - /// including lambdas that may appear in expressions such as `array_transform([1, 2], v -> v*2)`. 
- pub fn transform_down_with_lambdas_params< - F: FnMut(Self, &HashSet) -> Result>, - >( - self, - mut f: F, - ) -> Result> { - #[cfg_attr(feature = "recursive_protection", recursive::recursive)] - fn transform_down_with_lambdas_params_impl< - F: FnMut(Expr, &HashSet) -> Result>, - >( - node: Expr, - args: &HashSet, - f: &mut F, - ) -> Result> { - f(node, args)?.transform_children(|node| { - node.map_children_with_lambdas_params(args, |node, args| { - transform_down_with_lambdas_params_impl(node, args, f) - }) - }) - } - - transform_down_with_lambdas_params_impl(self, &HashSet::new(), &mut f) - } - - pub fn apply_with_lambdas_params< - 'n, - F: FnMut(&'n Self, &HashSet<&'n str>) -> Result, - >( - &'n self, - mut f: F, - ) -> Result { - #[cfg_attr(feature = "recursive_protection", recursive::recursive)] - fn apply_with_lambdas_params_impl< - 'n, - F: FnMut(&'n Expr, &HashSet<&'n str>) -> Result, - >( - node: &'n Expr, - args: &HashSet<&'n str>, - f: &mut F, - ) -> Result { - match node { - Expr::Lambda(Lambda { params, body: _ }) => { - let mut args = args.clone(); - - args.extend(params.iter().map(|v| v.as_str())); - - f(node, &args)?.visit_children(|| { - node.apply_children(|c| { - apply_with_lambdas_params_impl(c, &args, f) - }) - }) - } - _ => f(node, args)?.visit_children(|| { - node.apply_children(|c| apply_with_lambdas_params_impl(c, args, f)) - }), - } - } - - apply_with_lambdas_params_impl(self, &HashSet::new(), &mut f) - } - - /// Similarly to [`Self::transform`], rewrites this expr and its inputs using `f`, - /// including lambdas that may appear in expressions such as `array_transform([1, 2], v -> v*2)`. 
- pub fn transform_with_schema< - F: FnMut(Self, &DFSchema) -> Result>, - >( - self, - schema: &DFSchema, - f: F, - ) -> Result> { - self.transform_up_with_schema(schema, f) - } - - /// Similarly to [`Self::transform_up`], rewrites this expr and its inputs using `f`, - /// including lambdas that may appear in expressions such as `array_transform([1, 2], v -> v*2)`. - pub fn transform_up_with_schema< - F: FnMut(Self, &DFSchema) -> Result>, - >( - self, - schema: &DFSchema, - mut f: F, - ) -> Result> { - #[cfg_attr(feature = "recursive_protection", recursive::recursive)] - fn transform_up_with_schema_impl< - F: FnMut(Expr, &DFSchema) -> Result>, - >( - node: Expr, - schema: &DFSchema, - f: &mut F, - ) -> Result> { - node.map_children_with_schema(schema, |n, schema| { - transform_up_with_schema_impl(n, schema, f) - })? - .transform_parent(|n| f(n, schema)) - } - - transform_up_with_schema_impl(self, schema, &mut f) - } - - pub fn map_children_with_schema< - F: FnMut(Self, &DFSchema) -> Result>, - >( - self, - schema: &DFSchema, - mut f: F, - ) -> Result> { - match self { - Expr::ScalarFunction(ref fun) - if fun.args.iter().any(|arg| matches!(arg, Expr::Lambda(_))) => - { - let mut args_schemas = fun - .func - .arguments_schema_from_logical_args(&fun.args, schema)? - .into_iter(); - - self.map_children(|expr| f(expr, &args_schemas.next().unwrap())) - } - _ => self.map_children(|expr| f(expr, schema)), - } - } - - pub fn exists_with_lambdas_params) -> Result>( - &self, - mut f: F, - ) -> Result { - let mut found = false; - - self.apply_with_lambdas_params(|n, lambdas_params| { - if f(n, lambdas_params)? { - found = true; - Ok(TreeNodeRecursion::Stop) - } else { - Ok(TreeNodeRecursion::Continue) - } - })?; - - Ok(found) - } -} - -pub trait ExprWithLambdasRewriter2: Sized { - /// Invoked while traversing down the tree before any children are rewritten. - /// Default implementation returns the node as is and continues recursion. 
- fn f_down(&mut self, node: Expr, _schema: &DFSchema) -> Result> { - Ok(Transformed::no(node)) - } - - /// Invoked while traversing up the tree after all children have been rewritten. - /// Default implementation returns the node as is and continues recursion. - fn f_up(&mut self, node: Expr, _schema: &DFSchema) -> Result> { - Ok(Transformed::no(node)) - } -} -pub trait TreeNodeRewriterWithPayload: Sized { - type Node; - type Payload<'a>; - - /// Invoked while traversing down the tree before any children are rewritten. - /// Default implementation returns the node as is and continues recursion. - fn f_down<'a>( - &mut self, - node: Self::Node, - _payload: Self::Payload<'a>, - ) -> Result> { - Ok(Transformed::no(node)) - } - - /// Invoked while traversing up the tree after all children have been rewritten. - /// Default implementation returns the node as is and continues recursion. - fn f_up<'a>( - &mut self, - node: Self::Node, - _payload: Self::Payload<'a>, - ) -> Result> { - Ok(Transformed::no(node)) - } -} - -/* -struct LambdaColumnNormalizer<'a> { - existing_qualifiers: HashSet<&'a str>, - alias_generator: AliasGenerator, - lambdas_columns: HashMap>, -} - -impl<'a> LambdaColumnNormalizer<'a> { - fn new(dfschema: &'a DFSchema, expr: &'a Expr) -> Self { - let mut existing_qualifiers: HashSet<&'a str> = dfschema - .field_qualifiers() - .iter() - .flatten() - .map(|tbl| tbl.table()) - .filter(|table| table.starts_with("lambda_")) - .collect(); - - expr.apply(|node| { - if let Expr::Lambda(lambda) = node { - if let Some(qualifier) = &lambda.qualifier { - existing_qualifiers.insert(qualifier); - } - } - - Ok(TreeNodeRecursion::Continue) - }) - .unwrap(); - - Self { - existing_qualifiers, - alias_generator: AliasGenerator::new(), - lambdas_columns: HashMap::new(), - } - } -} - -impl TreeNodeRewriter for LambdaColumnNormalizer<'_> { - type Node = Expr; - - fn f_down(&mut self, node: Self::Node) -> Result> { - match node { - Expr::Lambda(mut lambda) => { - let tbl = 
lambda.qualifier.as_ref().map_or_else( - || loop { - let table = self.alias_generator.next("lambda"); - - if !self.existing_qualifiers.contains(table.as_str()) { - break TableReference::bare(table); - } - }, - |qualifier| TableReference::bare(qualifier.as_str()), - ); - - for param in &lambda.params { - self.lambdas_columns - .entry_ref(param) - .or_default() - .push(tbl.clone()); - } - - if lambda.qualifier.is_none() { - lambda.qualifier = Some(tbl.table().to_owned()); - - Ok(Transformed::yes(Expr::Lambda(lambda))) - } else { - Ok(Transformed::no(Expr::Lambda(lambda))) - } - } - Expr::Column(c) if c.relation.is_none() => { - if let Some(lambda_qualifier) = self.lambdas_columns.get(c.name()) { - Ok(Transformed::yes(Expr::Column( - c.with_relation(lambda_qualifier.last().unwrap().clone()), - ))) - } else { - Ok(Transformed::no(Expr::Column(c))) - } - } - _ => Ok(Transformed::no(node)) - } - } - - fn f_up(&mut self, node: Self::Node) -> Result> { - if let Expr::Lambda(lambda) = &node { - for param in &lambda.params { - match self.lambdas_columns.entry_ref(param) { - EntryRef::Occupied(mut entry) => { - let chain = entry.get_mut(); - - chain.pop(); - - if chain.is_empty() { - entry.remove(); - } - } - EntryRef::Vacant(_) => unreachable!(), - } - } - } - - Ok(Transformed::no(node)) - } -} -*/ - -// helpers used in udf.rs -#[cfg(test)] -pub(crate) mod tests { - use super::TreeNodeRewriterWithPayload; - use crate::{ - col, expr::Lambda, Expr, ScalarUDF, ScalarUDFImpl, ValueOrLambdaParameter, - }; - use arrow::datatypes::{DataType, Field, Schema}; - use datafusion_common::{ - tree_node::{Transformed, TreeNodeRecursion}, - DFSchema, HashSet, Result, - }; - use datafusion_expr_common::signature::{Signature, Volatility}; - - pub(crate) fn list_list_int() -> DFSchema { - DFSchema::try_from(Schema::new(vec![Field::new( - "v", - DataType::new_list(DataType::new_list(DataType::Int32, false), false), - false, - )])) - .unwrap() - } - - pub(crate) fn list_int() -> DFSchema { - 
DFSchema::try_from(Schema::new(vec![Field::new( - "v", - DataType::new_list(DataType::Int32, false), - false, - )])) - .unwrap() - } - - fn int() -> DFSchema { - DFSchema::try_from(Schema::new(vec![Field::new("v", DataType::Int32, false)])) - .unwrap() - } - - pub(crate) fn array_transform_udf() -> ScalarUDF { - ScalarUDF::new_from_impl(ArrayTransformFunc::new()) - } - - pub(crate) fn args() -> Vec { - vec![ - col("v"), - Expr::Lambda(Lambda::new( - vec!["v".into()], - array_transform_udf().call(vec![ - col("v"), - Expr::Lambda(Lambda::new(vec!["v".into()], -col("v"))), - ]), - )), - ] - } - - // array_transform(v, |v| -> array_transform(v, |v| -> -v)) - fn array_transform() -> Expr { - array_transform_udf().call(args()) - } - - #[derive(Debug, PartialEq, Eq, Hash)] - pub(crate) struct ArrayTransformFunc { - signature: Signature, - } - - impl ArrayTransformFunc { - pub fn new() -> Self { - Self { - signature: Signature::any(2, Volatility::Immutable), - } - } - } - - impl ScalarUDFImpl for ArrayTransformFunc { - fn as_any(&self) -> &dyn std::any::Any { - self - } - - fn name(&self) -> &str { - "array_transform" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, arg_types: &[DataType]) -> Result { - Ok(arg_types[0].clone()) - } - - fn lambdas_parameters( - &self, - args: &[ValueOrLambdaParameter], - ) -> Result>>> { - let ValueOrLambdaParameter::Value(value_field) = &args[0] else { - unreachable!() - }; - - let DataType::List(field) = value_field.data_type() else { - unreachable!() - }; - - Ok(vec![ - None, - Some(vec![Field::new( - "", - field.data_type().clone(), - field.is_nullable(), - )]), - ]) - } - - fn invoke_with_args( - &self, - _args: crate::ScalarFunctionArgs, - ) -> Result { - unimplemented!() - } - } - - #[test] - fn test_rewrite_with_schema() { - let schema = list_list_int(); - let array_transform = array_transform(); - - let mut rewriter = OkRewriter::default(); - - array_transform - 
.rewrite_with_schema(&schema, &mut rewriter) - .unwrap(); - - let expected = [ - ( - "f_down array_transform(v, (v) -> array_transform(v, (v) -> (- v)))", - list_list_int(), - ), - ("f_down v", list_list_int()), - ("f_up v", list_list_int()), - ("f_down (v) -> array_transform(v, (v) -> (- v))", list_int()), - ("f_down array_transform(v, (v) -> (- v))", list_int()), - ("f_down v", list_int()), - ("f_up v", list_int()), - ("f_down (v) -> (- v)", int()), - ("f_down (- v)", int()), - ("f_down v", int()), - ("f_up v", int()), - ("f_up (- v)", int()), - ("f_up (v) -> (- v)", int()), - ("f_up array_transform(v, (v) -> (- v))", list_int()), - ("f_up (v) -> array_transform(v, (v) -> (- v))", list_int()), - ( - "f_up array_transform(v, (v) -> array_transform(v, (v) -> (- v)))", - list_list_int(), - ), - ] - .map(|(a, b)| (String::from(a), b)); - - assert_eq!(rewriter.steps, expected) - } - - #[derive(Default)] - struct OkRewriter { - steps: Vec<(String, DFSchema)>, - } - - impl TreeNodeRewriterWithPayload for OkRewriter { - type Node = Expr; - type Payload<'a> = &'a DFSchema; - - fn f_down( - &mut self, - node: Expr, - schema: &DFSchema, - ) -> Result> { - self.steps.push((format!("f_down {node}"), schema.clone())); - - Ok(Transformed::no(node)) - } - - fn f_up( - &mut self, - node: Expr, - schema: &DFSchema, - ) -> Result> { - self.steps.push((format!("f_up {node}"), schema.clone())); - - Ok(Transformed::no(node)) - } - } - - #[test] - fn test_transform_up_with_lambdas_params() { - let mut steps = vec![]; - - array_transform() - .transform_up_with_lambdas_params(|node, params| { - steps.push((node.to_string(), params.clone())); - - Ok(Transformed::no(node)) - }) - .unwrap(); - - let lambdas_params = &HashSet::from([String::from("v")]); - - let expected = [ - ("v", lambdas_params), - ("v", lambdas_params), - ("v", lambdas_params), - ("(- v)", lambdas_params), - ("(v) -> (- v)", lambdas_params), - ("array_transform(v, (v) -> (- v))", lambdas_params), - ("(v) -> 
array_transform(v, (v) -> (- v))", lambdas_params), - ( - "array_transform(v, (v) -> array_transform(v, (v) -> (- v)))", - lambdas_params, - ), - ] - .map(|(a, b)| (String::from(a), b.clone())); - - assert_eq!(steps, expected); - } - - #[test] - fn test_apply_with_lambdas_params() { - let array_transform = array_transform(); - let mut steps = vec![]; - - array_transform - .apply_with_lambdas_params(|node, params| { - steps.push((node.to_string(), params.clone())); - - Ok(TreeNodeRecursion::Continue) - }) - .unwrap(); - - let expected = [ - ("v", HashSet::from(["v"])), - ("v", HashSet::from(["v"])), - ("v", HashSet::from(["v"])), - ("(- v)", HashSet::from(["v"])), - ("(v) -> (- v)", HashSet::from(["v"])), - ("array_transform(v, (v) -> (- v))", HashSet::from(["v"])), - ("(v) -> array_transform(v, (v) -> (- v))", HashSet::from(["v"])), - ( - "array_transform(v, (v) -> array_transform(v, (v) -> (- v)))", - HashSet::from(["v"]), - ), - ] - .map(|(a, b)| (String::from(a), b)); - - assert_eq!(steps, expected); - } -} diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 74ac1b456ff04..911fc890e2bc5 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -18,30 +18,23 @@ //! 
[`ScalarUDF`]: Scalar User Defined Functions use crate::async_udf::AsyncScalarUDF; -use crate::expr::{schema_name_from_exprs_comma_separated_without_space, Lambda}; +use crate::expr::schema_name_from_exprs_comma_separated_without_space; use crate::simplify::{ExprSimplifyResult, SimplifyInfo}; use crate::sort_properties::{ExprProperties, SortProperties}; use crate::udf_eq::UdfEq; -use crate::{ColumnarValue, Documentation, Expr, ExprSchemable, Signature}; -use arrow::array::{ArrayRef, RecordBatch}; -use arrow::datatypes::{DataType, Field, FieldRef, Fields, Schema}; -use datafusion_common::alias::AliasGenerator; +use crate::{ColumnarValue, Documentation, Expr, Signature}; +use arrow::array::RecordBatch; +use arrow::datatypes::{DataType, Field, FieldRef}; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::TreeNodeRecursion; -use datafusion_common::{ - exec_err, not_impl_err, DFSchema, ExprSchema, Result, ScalarValue, -}; +use datafusion_common::{not_impl_err, ExprSchema, Result, ScalarValue}; use datafusion_expr_common::dyn_eq::{DynEq, DynHash}; use datafusion_expr_common::interval_arithmetic::Interval; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; -use indexmap::IndexMap; use std::any::Any; -use std::borrow::Cow; use std::cmp::Ordering; -use std::collections::HashMap; use std::fmt::Debug; use std::hash::{Hash, Hasher}; -use std::sync::{Arc, LazyLock}; +use std::sync::Arc; /// Logical representation of a Scalar User Defined Function. 
/// @@ -352,272 +345,14 @@ impl ScalarUDF { pub fn as_async(&self) -> Option<&AsyncScalarUDF> { self.inner().as_any().downcast_ref::() } - - /// Variation of arguments_from_logical_args that works with arrow Schema's and ScalarFunctionArgMetadata instead - pub(crate) fn arguments_expr_schema<'a>( - &self, - args: &[Expr], - schema: &'a dyn ExprSchema, - ) -> Result> { - self.arguments_scope_with( - &lambda_parameters(args, schema)?, - ExtendableExprSchema::new(schema), - ) - } - - /// Variation of arguments_from_logical_args that works with arrow Schema's and ScalarFunctionArgMetadata instead, - pub fn arguments_arrow_schema<'a>( - &self, - args: &[ValueOrLambdaParameter], - schema: &'a Schema, - ) -> Result>> { - self.arguments_scope_with(args, Cow::Borrowed(schema)) - } - - pub fn arguments_schema_from_logical_args<'a>( - &self, - args: &[Expr], - schema: &'a DFSchema, - ) -> Result>> { - self.arguments_scope_with( - &lambda_parameters(args, schema)?, - Cow::Borrowed(schema), - ) - } - - /// Scalar function supports lambdas as arguments, which will be evaluated with - /// a different schema that of the function itself. 
This functions returns a vec - /// with the correspoding schema that each argument will run - /// - /// Return a vec with a value for each argument in args that, if it's a value, it's a clone of base_scope, - /// if it's a lambda, it's the return of merge called with the index and the fields from lambdas_parameters - /// updated with names from metadata - fn arguments_scope_with( - &self, - args: &[ValueOrLambdaParameter], - schema: T, - ) -> Result> { - let parameters = self.inner().lambdas_parameters(args)?; - - if parameters.len() != args.len() { - return exec_err!( - "lambdas_schemas: {} lambdas_parameters returned {} values instead of {}", - self.name(), - args.len(), - parameters.len() - ); - } - - std::iter::zip(args, parameters) - .enumerate() - .map(|(i, (arg, parameters))| match (arg, parameters) { - (ValueOrLambdaParameter::Value(_), None) => Ok(schema.clone()), - (ValueOrLambdaParameter::Value(_), Some(_)) => exec_err!("lambdas_schemas: {} argument {} (0-indexed) is a value but lambdas_parameters result treat it as a lambda", self.name(), i), - (ValueOrLambdaParameter::Lambda(_, _), None) => exec_err!("lambdas_schemas: {} argument {} (0-indexed) is a lambda but lambdas_parameters result treat it as a value", self.name(), i), - (ValueOrLambdaParameter::Lambda(names, captures), Some(args)) => { - if names.len() > args.len() { - return exec_err!("lambdas_schemas: {} argument {} (0-indexed), a lambda, supports up to {} arguments, but got {}", self.name(), i, args.len(), names.len()) - } - - let fields = std::iter::zip(*names, args) - .map(|(name, arg)| arg.with_name(name)) - .collect::(); - - if *captures { - schema.extend(fields) - } else { - T::from_fields(fields) - } - } - }) - .collect() - } -} - -pub trait ExtendSchema: Sized { - fn from_fields(params: Fields) -> Result; - fn extend(&self, params: Fields) -> Result; -} - -impl ExtendSchema for DFSchema { - fn from_fields(params: Fields) -> Result { - DFSchema::from_unqualified_fields(params, 
Default::default()) - } - - fn extend(&self, params: Fields) -> Result { - let qualified_fields = self - .iter() - .map(|(qualifier, field)| { - if params.find(field.name().as_str()).is_none() { - return (qualifier.cloned(), Arc::clone(field)); - } - - let alias_gen = AliasGenerator::new(); - - loop { - let alias = alias_gen.next(field.name().as_str()); - - if params.find(&alias).is_none() - && !self.has_column_with_unqualified_name(&alias) - { - return ( - qualifier.cloned(), - Arc::new(Field::new( - alias, - field.data_type().clone(), - field.is_nullable(), - )), - ); - } - } - }) - .collect(); - - let mut schema = DFSchema::new_with_metadata(qualified_fields, HashMap::new())?; - let fields_schema = DFSchema::from_unqualified_fields(params, HashMap::new())?; - - schema.merge(&fields_schema); - - assert_eq!( - schema.fields().len(), - self.fields().len() + fields_schema.fields().len() - ); - - Ok(schema) - } -} - -impl ExtendSchema for Schema { - fn from_fields(params: Fields) -> Result { - Ok(Schema::new(params)) - } - - fn extend(&self, params: Fields) -> Result { - let mut params2 = params.iter() - .map(|f| (f.name().as_str(), Some(Arc::clone(f)))) - .collect::>(); - - let mut fields = self.fields() - .iter() - .map(|field| { - match params2.get_mut(field.name().as_str()).and_then(|p| p.take()) { - Some(param) => param, - None => Arc::clone(field), - } - }) - .collect::>(); - - fields.extend(params2.into_values().flatten()); - - let fields = self - .fields() - .iter() - .map(|field| { - if params.find(field.name().as_str()).is_none() { - return Arc::clone(field); - } - - let alias_gen = AliasGenerator::new(); - - loop { - let alias = alias_gen.next(field.name().as_str()); - - if params.find(&alias).is_none() - && self.column_with_name(&alias).is_none() - { - return Arc::new(Field::new( - alias, - field.data_type().clone(), - field.is_nullable(), - )); - } - } - }) - .chain(params.iter().cloned()) - .collect::(); - - assert_eq!(fields.len(), self.fields().len() 
+ params.len()); - - Ok(Schema::new_with_metadata(fields, self.metadata.clone())) - } -} - -impl ExtendSchema for Cow<'_, T> { - fn from_fields(params: Fields) -> Result { - Ok(Cow::Owned(T::from_fields(params)?)) - } - - fn extend(&self, params: Fields) -> Result { - Ok(Cow::Owned(self.as_ref().extend(params)?)) - } -} - -impl ExtendSchema for Arc { - fn from_fields(params: Fields) -> Result { - Ok(Arc::new(T::from_fields(params)?)) - } - - fn extend(&self, params: Fields) -> Result { - Ok(Arc::new(self.as_ref().extend(params)?)) - } -} - -impl ExtendSchema for ExtendableExprSchema<'_> { - fn from_fields(params: Fields) -> Result { - static EMPTY_DFSCHEMA: LazyLock = LazyLock::new(DFSchema::empty); - - Ok(ExtendableExprSchema { - fields_chain: vec![params], - outer_schema: &*EMPTY_DFSCHEMA, - }) - } - - fn extend(&self, params: Fields) -> Result { - Ok(ExtendableExprSchema { - fields_chain: std::iter::once(params) - .chain(self.fields_chain.iter().cloned()) - .collect(), - outer_schema: self.outer_schema, - }) - } -} - -/// A `&dyn ExprSchema` wrapper that supports adding the parameters of a lambda -#[derive(Clone, Debug)] -struct ExtendableExprSchema<'a> { - fields_chain: Vec, - outer_schema: &'a dyn ExprSchema, -} - -impl<'a> ExtendableExprSchema<'a> { - fn new(schema: &'a dyn ExprSchema) -> Self { - Self { - fields_chain: vec![], - outer_schema: schema, - } - } -} - -impl ExprSchema for ExtendableExprSchema<'_> { - fn field_from_column(&self, col: &datafusion_common::Column) -> Result<&Field> { - if col.relation.is_none() { - for fields in &self.fields_chain { - if let Some((_index, lambda_param)) = fields.find(&col.name) { - return Ok(lambda_param); - } - } - } - - self.outer_schema.field_from_column(col) - } } #[derive(Clone, Debug)] -pub enum ValueOrLambdaParameter<'a> { +pub enum ValueOrLambdaParameter { /// A columnar value with the given field Value(FieldRef), - /// A lambda with the given parameters names and a flag indicating wheter it captures any 
columns - Lambda(&'a [String], bool), + /// A lambda + Lambda, } impl From for ScalarUDF @@ -1331,111 +1066,6 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl { } } -fn lambda_parameters<'a>( - args: &'a [Expr], - schema: &dyn ExprSchema, -) -> Result>> { - args.iter() - .map(|e| match e { - Expr::Lambda(Lambda { params, body: _ }) => { - let mut captures = false; - - e.apply_with_lambdas_params(|expr, lambdas_params| match expr { - Expr::Column(c) if !c.is_lambda_parameter(lambdas_params) => { - captures = true; - - Ok(TreeNodeRecursion::Stop) - } - _ => Ok(TreeNodeRecursion::Continue), - }) - .unwrap(); - - Ok(ValueOrLambdaParameter::Lambda(params.as_slice(), captures)) - } - _ => Ok(ValueOrLambdaParameter::Value(e.to_field(schema)?.1)), - }) - .collect() -} - -/// Merge the lambda body captured columns with it's arguments -/// Datafusion relies on an unspecified field ordering implemented in this function -/// As such, this is the only correct way to merge the captured values with the arguments -/// The number of args should not be lower than the number of params -/// -/// See also merge_captures_with_lazy_args and merge_captures_with_boxed_lazy_args that lazily -/// computes only the necessary arguments to match the number of params -pub fn merge_captures_with_args( - captures: Option<&RecordBatch>, - params: &[FieldRef], - args: &[ArrayRef], -) -> Result { - if args.len() < params.len() { - return exec_err!( - "merge_captures_with_args called with {} params but with {} args", - params.len(), - args.len() - ); - } - - // the order of the merged batch must be kept in sync with ScalarFunction::lambdas_schemas variants - let (fields, columns) = match captures { - Some(captures) => { - let fields = captures - .schema() - .fields() - .iter() - .chain(params) - .cloned() - .collect::>(); - - let columns = [captures.columns(), args].concat(); - - (fields, columns) - } - None => (params.to_vec(), args.to_vec()), - }; - - Ok(RecordBatch::try_new( - 
Arc::new(Schema::new(fields)), - columns, - )?) -} - -/// Lazy version of merge_captures_with_args that receives closures to compute the arguments, -/// and calls only the necessary to match the number of params -pub fn merge_captures_with_lazy_args( - captures: Option<&RecordBatch>, - params: &[FieldRef], - args: &[&dyn Fn() -> Result], -) -> Result { - merge_captures_with_args( - captures, - params, - &args - .iter() - .take(params.len()) - .map(|arg| arg()) - .collect::>>()?, - ) -} - -/// Variation of merge_captures_with_lazy_args that take boxed closures -pub fn merge_captures_with_boxed_lazy_args( - captures: Option<&RecordBatch>, - params: &[FieldRef], - args: &[Box Result>], -) -> Result { - merge_captures_with_args( - captures, - params, - &args - .iter() - .take(params.len()) - .map(|arg| arg()) - .collect::>>()?, - ) -} - #[cfg(test)] mod tests { use super::*; @@ -1514,83 +1144,4 @@ mod tests { value.hash(hasher); hasher.finish() } - - use std::borrow::Cow; - - use arrow::datatypes::Fields; - - use crate::{ - tree_node::tests::{args, list_int, list_list_int, array_transform_udf}, - udf::{lambda_parameters, ExtendableExprSchema}, - }; - - #[test] - fn test_arguments_expr_schema() { - let args = args(); - let schema = list_list_int(); - - let schemas = array_transform_udf() - .arguments_expr_schema(&args, &schema) - .unwrap() - .into_iter() - .map(|s| format!("{s:?}")) - .collect::>(); - - let mut lambdas_parameters = array_transform_udf() - .inner() - .lambdas_parameters(&lambda_parameters(&args, &schema).unwrap()) - .unwrap(); - - assert_eq!( - schemas, - &[ - format!("{}", &list_list_int()), - format!( - "{:?}", - ExtendableExprSchema { - fields_chain: vec![Fields::from( - lambdas_parameters[0].take().unwrap() - )], - outer_schema: &list_list_int() - } - ), - ] - ) - } - - #[test] - fn test_arguments_arrow_schema() { - let list_int = list_int(); - let list_list_int = list_list_int(); - - let schemas = array_transform_udf() - .arguments_arrow_schema( - 
&lambda_parameters(&args(), &list_list_int).unwrap(), - //&[HashSet::new(), HashSet::from([0])], - list_list_int.as_arrow(), - ) - .unwrap(); - - assert_eq!( - schemas, - &[ - Cow::Borrowed(list_list_int.as_arrow()), - Cow::Owned(list_int.as_arrow().clone()) - ] - ) - } - - #[test] - fn test_arguments_schema_from_logical_args() { - let list_list_int = list_list_int(); - - let schemas = array_transform_udf() - .arguments_schema_from_logical_args(&args(), &list_list_int) - .unwrap(); - - assert_eq!( - schemas, - &[Cow::Borrowed(&list_list_int), Cow::Owned(list_int())] - ) - } } diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 93fcfaef882ff..1e5e4c9f5b99a 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -266,12 +266,10 @@ pub fn grouping_set_to_exprlist(group_expr: &[Expr]) -> Result> { /// Recursively walk an expression tree, collecting the unique set of columns /// referenced in the expression pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet) -> Result<()> { - expr.apply_with_lambdas_params(|expr, lambdas_params| { + expr.apply(|expr| { match expr { Expr::Column(qc) => { - if qc.relation.is_some() || !lambdas_params.contains(qc.name()) { - accum.insert(qc.clone()); - } + accum.insert(qc.clone()); } // Use explicit pattern match instead of a default // implementation, so that in the future if someone adds @@ -310,7 +308,8 @@ pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet) -> Result<()> { | Expr::Wildcard { .. } | Expr::Placeholder(_) | Expr::OuterReferenceColumn { .. } - | Expr::Lambda { .. } => {} + | Expr::Lambda(_) + | Expr::LambdaColumn(_) => {} } Ok(TreeNodeRecursion::Continue) }) @@ -653,7 +652,6 @@ where /// Search an `Expr`, and all of its nested `Expr`'s, for any that pass the /// provided test. The returned `Expr`'s are deduplicated and returned in order /// of appearance (depth first). -/// todo: document about that columns may refer to a lambda parameter? 
fn find_exprs_in_expr(expr: &Expr, test_fn: &F) -> Vec where F: Fn(&Expr) -> bool, @@ -676,7 +674,6 @@ where } /// Recursively inspect an [`Expr`] and all its children. -/// todo: document about that columns may refer to a lambda parameter? pub fn inspect_expr_pre(expr: &Expr, mut f: F) -> Result<(), E> where F: FnMut(&Expr) -> Result<(), E>, @@ -748,19 +745,13 @@ pub fn columnize_expr(e: Expr, input: &LogicalPlan) -> Result { _ => return Ok(e), }; let exprs_map: HashMap<&Expr, Column> = output_exprs.into_iter().collect(); - e.transform_down_with_lambdas_params(|node: Expr, lambdas_params| { - if matches!(&node, Expr::Column(c) if c.is_lambda_parameter(lambdas_params)) { - return Ok(Transformed::no(node)); - } - - match exprs_map.get(&node) { - Some(column) => Ok(Transformed::new( - Expr::Column(column.clone()), - true, - TreeNodeRecursion::Jump, - )), - None => Ok(Transformed::no(node)), - } + e.transform_down(|node: Expr| match exprs_map.get(&node) { + Some(column) => Ok(Transformed::new( + Expr::Column(column.clone()), + true, + TreeNodeRecursion::Jump, + )), + None => Ok(Transformed::no(node)), }) .data() } @@ -777,11 +768,9 @@ pub fn find_column_exprs(exprs: &[Expr]) -> Vec { pub(crate) fn find_columns_referenced_by_expr(e: &Expr) -> Vec { let mut exprs = vec![]; - e.apply_with_lambdas_params(|expr, lambdas_params| { + e.apply(|expr| { if let Expr::Column(c) = expr { - if !c.is_lambda_parameter(lambdas_params) { - exprs.push(c.clone()) - } + exprs.push(c.clone()) } Ok(TreeNodeRecursion::Continue) }) @@ -810,9 +799,9 @@ pub(crate) fn find_column_indexes_referenced_by_expr( schema: &DFSchemaRef, ) -> Vec { let mut indexes = vec![]; - e.apply_with_lambdas_params(|expr, lambdas_params| { + e.apply(|expr| { match expr { - Expr::Column(qc) if !qc.is_lambda_parameter(lambdas_params) => { + Expr::Column(qc) => { if let Ok(idx) = schema.index_of_column(qc) { indexes.push(idx); } diff --git a/datafusion/functions-nested/Cargo.toml 
b/datafusion/functions-nested/Cargo.toml index 6e0d1048f9697..0299aebdcac47 100644 --- a/datafusion/functions-nested/Cargo.toml +++ b/datafusion/functions-nested/Cargo.toml @@ -53,6 +53,7 @@ datafusion-functions = { workspace = true } datafusion-functions-aggregate = { workspace = true } datafusion-functions-aggregate-common = { workspace = true } datafusion-macros = { workspace = true } +datafusion-physical-expr = { workspace = true } datafusion-physical-expr-common = { workspace = true } itertools = { workspace = true, features = ["use_std"] } log = { workspace = true } diff --git a/datafusion/functions-nested/src/array_transform.rs b/datafusion/functions-nested/src/array_transform.rs index 700fed477b4cb..dbc08473c246b 100644 --- a/datafusion/functions-nested/src/array_transform.rs +++ b/datafusion/functions-nested/src/array_transform.rs @@ -18,17 +18,23 @@ //! [`ScalarUDFImpl`] definitions for array_transform function. use arrow::{ - array::{Array, ArrayRef, AsArray, FixedSizeListArray, LargeListArray, ListArray}, + array::{ + Array, ArrayRef, AsArray, FixedSizeListArray, LargeListArray, ListArray, + RecordBatch, RecordBatchOptions, + }, compute::take_record_batch, - datatypes::{DataType, Field}, + datatypes::{DataType, Field, FieldRef, Schema}, }; use datafusion_common::{ - HashSet, Result, exec_err, internal_err, tree_node::{Transformed, TreeNode}, utils::{elements_indices, list_indices, list_values, take_function_args} + HashMap, Result, exec_err, internal_err, tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter}, utils::{elements_indices, list_indices, list_values, take_function_args} }; use datafusion_expr::{ - ColumnarValue, Documentation, Expr, ScalarFunctionArgs, ScalarUDFImpl, Signature, ValueOrLambda, ValueOrLambdaField, ValueOrLambdaParameter, Volatility, expr::Lambda, merge_captures_with_lazy_args + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + ValueOrLambda, ValueOrLambdaField, 
ValueOrLambdaParameter, Volatility, }; use datafusion_macros::user_doc; +use datafusion_physical_expr::expressions::{LambdaColumn, LambdaExpr}; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use std::{any::Any, sync::Arc}; make_udf_expr_and_func!( @@ -115,7 +121,7 @@ impl ScalarUDFImpl for ArrayTransform { ); }; - //TODO: should metadata be passed? If so, with the same keys or prefixed/suffixed? + //TODO: should metadata be copied into the transformed array? // lambda is the resulting field of executing the lambda body // with the parameters returned in lambdas_parameters @@ -151,6 +157,7 @@ impl ScalarUDFImpl for ArrayTransform { }; let list_array = list_value.to_array(args.number_rows)?; + let list_values = list_values(&list_array)?; // if any column got captured, we need to adjust it to the values arrays, // duplicating values of list with mulitple values and removing values of empty lists @@ -163,23 +170,26 @@ impl ScalarUDFImpl for ArrayTransform { // use closures and merge_captures_with_lazy_args so that it calls only the needed ones based on the number of arguments // avoiding unnecessary computations - let values_param = || Ok(Arc::clone(list_values(&list_array)?)); + let values_param = || Ok(Arc::clone(list_values)); let indices_param = || elements_indices(&list_array); - // the order of the merged schema is an unspecified implementation detail that may change in the future, - // using this function is the correct way to merge as it return the correct ordering and will change in sync - // the implementation without the need for fixes. 
It also computes only the parameters requested - let lambda_batch = merge_captures_with_lazy_args( - adjusted_captures.as_ref(), - &lambda.params, // ScalarUDF already merged the fields returned in lambdas_parameters with the parameters names definied in the lambda, so we don't need to + let binded_body = bind_lambda_columns( + Arc::clone(&lambda.body), + &lambda.params, &[&values_param, &indices_param], )?; // call the transforming expression with the record batch composed of the list values merged with captured columns - let transformed_values = lambda - .body - .evaluate(&lambda_batch)? - .into_array(lambda_batch.num_rows())?; + let transformed_values = binded_body + .evaluate(&adjusted_captures.unwrap_or_else(|| { + RecordBatch::try_new_with_options( + Arc::new(Schema::empty()), + vec![], + &RecordBatchOptions::new().with_row_count(Some(list_values.len())), + ) + .unwrap() + }))? + .into_array(list_values.len())?; let field = match args.return_field.data_type() { DataType::List(field) @@ -233,7 +243,7 @@ impl ScalarUDFImpl for ArrayTransform { &self, args: &[ValueOrLambdaParameter], ) -> Result>>> { - let [ValueOrLambdaParameter::Value(list), ValueOrLambdaParameter::Lambda(_, _)] = + let [ValueOrLambdaParameter::Value(list), ValueOrLambdaParameter::Lambda] = args else { return exec_err!( @@ -253,9 +263,9 @@ impl ScalarUDFImpl for ArrayTransform { // we don't need to omit the index in the case the lambda don't specify, e.g. array_transform([], v -> v*2), // nor check whether the lambda contains more than two parameters, e.g. 
array_transform([], (v, i, j) -> v+i+j), // as datafusion will do that for us - let value = Field::new("value", field.data_type().clone(), field.is_nullable()) + let value = Field::new("", field.data_type().clone(), field.is_nullable()) .with_metadata(field.metadata().clone()); - let index = Field::new("index", index_type, false); + let index = Field::new("", index_type, false); Ok(vec![None, Some(vec![value, index])]) } @@ -264,3 +274,65 @@ impl ScalarUDFImpl for ArrayTransform { self.doc() } } + +fn bind_lambda_columns( + expr: Arc, + params: &[FieldRef], + args: &[&dyn Fn() -> Result], +) -> Result> { + let columns = std::iter::zip(params, args) + .map(|(param, arg)| Ok((param.name().as_str(), (arg()?, 0)))) + .collect::>>()?; + + expr.rewrite(&mut BindLambdaColumn::new(columns)).data() +} + +struct BindLambdaColumn<'a> { + columns: HashMap<&'a str, (ArrayRef, usize)>, +} + +impl<'a> BindLambdaColumn<'a> { + fn new(columns: HashMap<&'a str, (ArrayRef, usize)>) -> Self { + Self { columns } + } +} + +impl TreeNodeRewriter for BindLambdaColumn<'_> { + type Node = Arc; + + fn f_down(&mut self, node: Self::Node) -> Result> { + if let Some(lambda_column) = node.as_any().downcast_ref::() { + if let Some((value, shadows)) = self.columns.get(lambda_column.name()) { + if *shadows == 0 { + return Ok(Transformed::yes(Arc::new( + lambda_column.clone().with_value(value.clone()), + ))); + } + } + } else if let Some(inner_lambda) = node.as_any().downcast_ref::() { + for param in inner_lambda.params() { + if let Some((_value, shadows)) = self.columns.get_mut(param.as_str()) { + *shadows += 1; + } + } + + if self.columns.values().all(|(_value, shadows)| *shadows > 0) { + return Ok(Transformed::new(node, false, TreeNodeRecursion::Jump)) + } + } + + Ok(Transformed::no(node)) + } + + fn f_up(&mut self, node: Self::Node) -> Result> { + if let Some(inner_lambda) = node.as_any().downcast_ref::() { + for param in inner_lambda.params() { + if let Some((_value, shadows)) = 
self.columns.get_mut(param.as_str()) { + *shadows -= 1; + } + } + } + + Ok(Transformed::no(node)) + } +} diff --git a/datafusion/functions/src/core/union_tag.rs b/datafusion/functions/src/core/union_tag.rs index ecdebf66e0043..4832c368872bf 100644 --- a/datafusion/functions/src/core/union_tag.rs +++ b/datafusion/functions/src/core/union_tag.rs @@ -173,7 +173,6 @@ mod tests { fields, UnionMode::Dense, ); - let arg_fields = vec![Field::new("a", scalar.data_type(), false).into()]; let return_type = DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)); @@ -181,9 +180,9 @@ mod tests { let result = UnionTagFunc::new() .invoke_with_args(ScalarFunctionArgs { args: vec![ColumnarValue::Scalar(scalar)], - arg_fields, number_rows: 1, return_field: Field::new("res", return_type, true).into(), + arg_fields: vec![], config_options: Arc::new(ConfigOptions::default()), lambdas: None, }) @@ -198,7 +197,6 @@ mod tests { #[test] fn union_scalar_empty() { let scalar = ScalarValue::Union(None, UnionFields::empty(), UnionMode::Dense); - let arg_fields = vec![Field::new("a", scalar.data_type(), false).into()]; let return_type = DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)); @@ -206,9 +204,9 @@ mod tests { let result = UnionTagFunc::new() .invoke_with_args(ScalarFunctionArgs { args: vec![ColumnarValue::Scalar(scalar)], - arg_fields, number_rows: 1, return_field: Field::new("res", return_type, true).into(), + arg_fields: vec![], config_options: Arc::new(ConfigOptions::default()), lambdas: None, }) diff --git a/datafusion/optimizer/src/analyzer/function_rewrite.rs b/datafusion/optimizer/src/analyzer/function_rewrite.rs index 0e5e602f8238e..c6bf14ebce2e3 100644 --- a/datafusion/optimizer/src/analyzer/function_rewrite.rs +++ b/datafusion/optimizer/src/analyzer/function_rewrite.rs @@ -19,7 +19,7 @@ use super::AnalyzerRule; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::Transformed; +use 
datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{DFSchema, Result}; use crate::utils::NamePreserver; @@ -64,16 +64,15 @@ impl ApplyFunctionRewrites { let original_name = name_preserver.save(&expr); // recursively transform the expression, applying the rewrites at each step - let transformed_expr = - expr.transform_up_with_schema(&schema, |expr, schema| { - let mut result = Transformed::no(expr); - for rewriter in self.function_rewrites.iter() { - result = result.transform_data(|expr| { - rewriter.rewrite(expr, schema, options) - })?; - } - Ok(result) - })?; + let transformed_expr = expr.transform_up(|expr| { + let mut result = Transformed::no(expr); + for rewriter in self.function_rewrites.iter() { + result = result.transform_data(|expr| { + rewriter.rewrite(expr, &schema, options) + })?; + } + Ok(result) + })?; Ok(transformed_expr.update_data(|expr| original_name.restore(expr))) }) diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 1b82182e8600f..763f693f2f607 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -20,7 +20,6 @@ use std::sync::Arc; use datafusion_expr::binary::BinaryTypeCoercer; -use datafusion_expr::tree_node::TreeNodeRewriterWithPayload; use itertools::{izip, Itertools as _}; use arrow::datatypes::{DataType, Field, IntervalUnit, Schema}; @@ -28,7 +27,7 @@ use arrow::datatypes::{DataType, Field, IntervalUnit, Schema}; use crate::analyzer::AnalyzerRule; use crate::utils::NamePreserver; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::Transformed; +use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter}; use datafusion_common::{ exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_datafusion_err, plan_err, Column, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, @@ -141,7 +140,7 @@ fn analyze_internal( 
// apply coercion rewrite all expressions in the plan individually plan.map_expressions(|expr| { let original_name = name_preserver.save(&expr); - expr.rewrite_with_schema(&schema, &mut expr_rewrite) + expr.rewrite(&mut expr_rewrite) .map(|transformed| transformed.update_data(|e| original_name.restore(e))) })? // some plans need extra coercion after their expressions are coerced @@ -305,11 +304,10 @@ impl<'a> TypeCoercionRewriter<'a> { } } -impl TreeNodeRewriterWithPayload for TypeCoercionRewriter<'_> { +impl TreeNodeRewriter for TypeCoercionRewriter<'_> { type Node = Expr; - type Payload<'a> = &'a DFSchema; - fn f_up(&mut self, expr: Expr, schema: &DFSchema) -> Result> { + fn f_up(&mut self, expr: Expr) -> Result> { match expr { Expr::Unnest(_) => not_impl_err!( "Unnest should be rewritten to LogicalPlan::Unnest before type coercion" @@ -320,7 +318,7 @@ impl TreeNodeRewriterWithPayload for TypeCoercionRewriter<'_> { spans, }) => { let new_plan = - analyze_internal(schema, Arc::unwrap_or_clone(subquery))?.data; + analyze_internal(self.schema, Arc::unwrap_or_clone(subquery))?.data; Ok(Transformed::yes(Expr::ScalarSubquery(Subquery { subquery: Arc::new(new_plan), outer_ref_columns, @@ -329,7 +327,7 @@ impl TreeNodeRewriterWithPayload for TypeCoercionRewriter<'_> { } Expr::Exists(Exists { subquery, negated }) => { let new_plan = analyze_internal( - schema, + self.schema, Arc::unwrap_or_clone(subquery.subquery), )? .data; @@ -348,11 +346,11 @@ impl TreeNodeRewriterWithPayload for TypeCoercionRewriter<'_> { negated, }) => { let new_plan = analyze_internal( - schema, + self.schema, Arc::unwrap_or_clone(subquery.subquery), )? 
.data; - let expr_type = expr.get_type(schema)?; + let expr_type = expr.get_type(self.schema)?; let subquery_type = new_plan.schema().field(0).data_type(); let common_type = comparison_coercion(&expr_type, subquery_type).ok_or( plan_datafusion_err!( @@ -365,32 +363,32 @@ impl TreeNodeRewriterWithPayload for TypeCoercionRewriter<'_> { spans: subquery.spans, }; Ok(Transformed::yes(Expr::InSubquery(InSubquery::new( - Box::new(expr.cast_to(&common_type, schema)?), + Box::new(expr.cast_to(&common_type, self.schema)?), cast_subquery(new_subquery, &common_type)?, negated, )))) } Expr::Not(expr) => Ok(Transformed::yes(not(get_casted_expr_for_bool_op( *expr, - schema, + self.schema, )?))), Expr::IsTrue(expr) => Ok(Transformed::yes(is_true( - get_casted_expr_for_bool_op(*expr, schema)?, + get_casted_expr_for_bool_op(*expr, self.schema)?, ))), Expr::IsNotTrue(expr) => Ok(Transformed::yes(is_not_true( - get_casted_expr_for_bool_op(*expr, schema)?, + get_casted_expr_for_bool_op(*expr, self.schema)?, ))), Expr::IsFalse(expr) => Ok(Transformed::yes(is_false( - get_casted_expr_for_bool_op(*expr, schema)?, + get_casted_expr_for_bool_op(*expr, self.schema)?, ))), Expr::IsNotFalse(expr) => Ok(Transformed::yes(is_not_false( - get_casted_expr_for_bool_op(*expr, schema)?, + get_casted_expr_for_bool_op(*expr, self.schema)?, ))), Expr::IsUnknown(expr) => Ok(Transformed::yes(is_unknown( - get_casted_expr_for_bool_op(*expr, schema)?, + get_casted_expr_for_bool_op(*expr, self.schema)?, ))), Expr::IsNotUnknown(expr) => Ok(Transformed::yes(is_not_unknown( - get_casted_expr_for_bool_op(*expr, schema)?, + get_casted_expr_for_bool_op(*expr, self.schema)?, ))), Expr::Like(Like { negated, @@ -399,8 +397,8 @@ impl TreeNodeRewriterWithPayload for TypeCoercionRewriter<'_> { escape_char, case_insensitive, }) => { - let left_type = expr.get_type(schema)?; - let right_type = pattern.get_type(schema)?; + let left_type = expr.get_type(self.schema)?; + let right_type = pattern.get_type(self.schema)?; let 
coerced_type = like_coercion(&left_type, &right_type).ok_or_else(|| { let op_name = if case_insensitive { "ILIKE" @@ -413,9 +411,9 @@ impl TreeNodeRewriterWithPayload for TypeCoercionRewriter<'_> { })?; let expr = match left_type { DataType::Dictionary(_, inner) if *inner == DataType::Utf8 => expr, - _ => Box::new(expr.cast_to(&coerced_type, schema)?), + _ => Box::new(expr.cast_to(&coerced_type, self.schema)?), }; - let pattern = Box::new(pattern.cast_to(&coerced_type, schema)?); + let pattern = Box::new(pattern.cast_to(&coerced_type, self.schema)?); Ok(Transformed::yes(Expr::Like(Like::new( negated, expr, @@ -426,7 +424,7 @@ impl TreeNodeRewriterWithPayload for TypeCoercionRewriter<'_> { } Expr::BinaryExpr(BinaryExpr { left, op, right }) => { let (left, right) = - self.coerce_binary_op(*left, schema, op, *right, schema)?; + self.coerce_binary_op(*left, self.schema, op, *right, self.schema)?; Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr::new( Box::new(left), op, @@ -439,15 +437,15 @@ impl TreeNodeRewriterWithPayload for TypeCoercionRewriter<'_> { low, high, }) => { - let expr_type = expr.get_type(schema)?; - let low_type = low.get_type(schema)?; + let expr_type = expr.get_type(self.schema)?; + let low_type = low.get_type(self.schema)?; let low_coerced_type = comparison_coercion(&expr_type, &low_type) .ok_or_else(|| { internal_datafusion_err!( "Failed to coerce types {expr_type} and {low_type} in BETWEEN expression" ) })?; - let high_type = high.get_type(schema)?; + let high_type = high.get_type(self.schema)?; let high_coerced_type = comparison_coercion(&expr_type, &high_type) .ok_or_else(|| { internal_datafusion_err!( @@ -462,10 +460,10 @@ impl TreeNodeRewriterWithPayload for TypeCoercionRewriter<'_> { ) })?; Ok(Transformed::yes(Expr::Between(Between::new( - Box::new(expr.cast_to(&coercion_type, schema)?), + Box::new(expr.cast_to(&coercion_type, self.schema)?), negated, - Box::new(low.cast_to(&coercion_type, schema)?), - Box::new(high.cast_to(&coercion_type, 
schema)?), + Box::new(low.cast_to(&coercion_type, self.schema)?), + Box::new(high.cast_to(&coercion_type, self.schema)?), )))) } Expr::InList(InList { @@ -473,10 +471,10 @@ impl TreeNodeRewriterWithPayload for TypeCoercionRewriter<'_> { list, negated, }) => { - let expr_data_type = expr.get_type(schema)?; + let expr_data_type = expr.get_type(self.schema)?; let list_data_types = list .iter() - .map(|list_expr| list_expr.get_type(schema)) + .map(|list_expr| list_expr.get_type(self.schema)) .collect::>>()?; let result_type = get_coerce_type_for_list(&expr_data_type, &list_data_types); @@ -486,11 +484,11 @@ impl TreeNodeRewriterWithPayload for TypeCoercionRewriter<'_> { ), Some(coerced_type) => { // find the coerced type - let cast_expr = expr.cast_to(&coerced_type, schema)?; + let cast_expr = expr.cast_to(&coerced_type, self.schema)?; let cast_list_expr = list .into_iter() .map(|list_expr| { - list_expr.cast_to(&coerced_type, schema) + list_expr.cast_to(&coerced_type, self.schema) }) .collect::>>()?; Ok(Transformed::yes(Expr::InList(InList ::new( @@ -502,13 +500,13 @@ impl TreeNodeRewriterWithPayload for TypeCoercionRewriter<'_> { } } Expr::Case(case) => { - let case = coerce_case_expression(case, schema)?; + let case = coerce_case_expression(case, self.schema)?; Ok(Transformed::yes(Expr::Case(case))) } Expr::ScalarFunction(ScalarFunction { func, args }) => { let new_expr = coerce_arguments_for_signature_with_scalar_udf( args, - schema, + self.schema, &func, )?; Ok(Transformed::yes(Expr::ScalarFunction( @@ -528,7 +526,7 @@ impl TreeNodeRewriterWithPayload for TypeCoercionRewriter<'_> { }) => { let new_expr = coerce_arguments_for_signature_with_aggregate_udf( args, - schema, + self.schema, &func, )?; Ok(Transformed::yes(Expr::AggregateFunction( @@ -557,13 +555,13 @@ impl TreeNodeRewriterWithPayload for TypeCoercionRewriter<'_> { }, } = *window_fun; let window_frame = - coerce_window_frame(window_frame, schema, &order_by)?; + coerce_window_frame(window_frame, 
self.schema, &order_by)?; let args = match &fun { expr::WindowFunctionDefinition::AggregateUDF(udf) => { coerce_arguments_for_signature_with_aggregate_udf( args, - schema, + self.schema, udf, )? } @@ -600,7 +598,8 @@ impl TreeNodeRewriterWithPayload for TypeCoercionRewriter<'_> { | Expr::GroupingSet(_) | Expr::Placeholder(_) | Expr::OuterReferenceColumn(_, _) - | Expr::Lambda { .. } => Ok(Transformed::no(expr)), + | Expr::Lambda(_) + | Expr::LambdaColumn(_) => Ok(Transformed::no(expr)), } } } @@ -1133,7 +1132,7 @@ mod test { use crate::analyzer::Analyzer; use crate::assert_analyzed_plan_with_config_eq_snapshot; use datafusion_common::config::ConfigOptions; - use datafusion_common::tree_node::{TransformedResult}; + use datafusion_common::tree_node::{TransformedResult, TreeNode}; use datafusion_common::{DFSchema, DFSchemaRef, Result, ScalarValue, Spans}; use datafusion_expr::expr::{self, InSubquery, Like, ScalarFunction}; use datafusion_expr::logical_plan::{EmptyRelation, Projection, Sort}; @@ -2084,7 +2083,7 @@ mod test { let mut rewriter = TypeCoercionRewriter { schema: &schema }; let expr = is_true(lit(12i32).gt(lit(13i64))); let expected = is_true(cast(lit(12i32), DataType::Int64).gt(lit(13i64))); - let result = expr.rewrite_with_schema(&schema, &mut rewriter).data()?; + let result = expr.rewrite(&mut rewriter).data()?; assert_eq!(expected, result); // eq @@ -2095,7 +2094,7 @@ mod test { let mut rewriter = TypeCoercionRewriter { schema: &schema }; let expr = is_true(lit(12i32).eq(lit(13i64))); let expected = is_true(cast(lit(12i32), DataType::Int64).eq(lit(13i64))); - let result = expr.rewrite_with_schema(&schema, &mut rewriter).data()?; + let result = expr.rewrite(&mut rewriter).data()?; assert_eq!(expected, result); // lt @@ -2106,7 +2105,7 @@ mod test { let mut rewriter = TypeCoercionRewriter { schema: &schema }; let expr = is_true(lit(12i32).lt(lit(13i64))); let expected = is_true(cast(lit(12i32), DataType::Int64).lt(lit(13i64))); - let result = 
expr.rewrite_with_schema(&schema, &mut rewriter).data()?; + let result = expr.rewrite(&mut rewriter).data()?; assert_eq!(expected, result); Ok(()) diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index e06ed6e547eb5..f95a0f908b813 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -29,7 +29,7 @@ use datafusion_common::alias::AliasGenerator; use datafusion_common::cse::{CSEController, FoundCommonNodes, CSE}; use datafusion_common::tree_node::{Transformed, TreeNode}; -use datafusion_common::{qualified_name, Column, DFSchema, DFSchemaRef, HashSet, Result}; +use datafusion_common::{qualified_name, Column, DFSchema, DFSchemaRef, Result}; use datafusion_expr::expr::{Alias, ScalarFunction}; use datafusion_expr::logical_plan::{ Aggregate, Filter, LogicalPlan, Projection, Sort, Window, @@ -632,7 +632,6 @@ struct ExprCSEController<'a> { // how many aliases have we seen so far alias_counter: usize, - lambdas_params: HashSet, } impl<'a> ExprCSEController<'a> { @@ -641,7 +640,6 @@ impl<'a> ExprCSEController<'a> { alias_generator, mask, alias_counter: 0, - lambdas_params: HashSet::new(), } } } @@ -695,30 +693,11 @@ impl CSEController for ExprCSEController<'_> { } } - fn visit_f_down(&mut self, node: &Expr) { - if let Expr::Lambda(lambda) = node { - self.lambdas_params - .extend(lambda.params.iter().cloned()); - } - } - - fn visit_f_up(&mut self, node: &Expr) { - if let Expr::Lambda(lambda) = node { - for param in &lambda.params { - self.lambdas_params.remove(param); - } - } - } - fn is_valid(node: &Expr) -> bool { - !node.is_volatile_node() + !node.is_volatile_node() && !matches!(node, Expr::LambdaColumn(_)) } fn is_ignored(&self, node: &Expr) -> bool { - if matches!(node, Expr::Column(c) if c.is_lambda_parameter(&self.lambdas_params)) { - return true - } - // TODO: remove the next line after `Expr::Wildcard` is removed 
#[expect(deprecated)] let is_normal_minus_aggregates = matches!( @@ -728,6 +707,7 @@ impl CSEController for ExprCSEController<'_> { | Expr::ScalarVariable(..) | Expr::Alias(..) | Expr::Wildcard { .. } + | Expr::LambdaColumn(_) ); let is_aggr = matches!(node, Expr::AggregateFunction(..)); diff --git a/datafusion/optimizer/src/decorrelate.rs b/datafusion/optimizer/src/decorrelate.rs index 0f43741834009..63236787743a4 100644 --- a/datafusion/optimizer/src/decorrelate.rs +++ b/datafusion/optimizer/src/decorrelate.rs @@ -527,17 +527,18 @@ fn proj_exprs_evaluation_result_on_empty_batch( for expr in proj_expr.iter() { let result_expr = expr .clone() - .transform_up_with_lambdas_params(|expr, lambdas_params| match &expr { - Expr::Column(col) if !col.is_lambda_parameter(lambdas_params) => { + .transform_up(|expr| { + if let Expr::Column(Column { name, .. }) = &expr { if let Some(result_expr) = - input_expr_result_map_for_count_bug.get(col.name()) + input_expr_result_map_for_count_bug.get(name) { Ok(Transformed::yes(result_expr.clone())) } else { Ok(Transformed::no(expr)) } + } else { + Ok(Transformed::no(expr)) } - _ => Ok(Transformed::no(expr)), }) .data()?; @@ -569,17 +570,16 @@ fn filter_exprs_evaluation_result_on_empty_batch( ) -> Result> { let result_expr = filter_expr .clone() - .transform_up_with_lambdas_params(|expr, lambdas_params| match &expr { - Expr::Column(col) if !col.is_lambda_parameter(lambdas_params) => { - if let Some(result_expr) = - input_expr_result_map_for_count_bug.get(col.name()) - { + .transform_up(|expr| { + if let Expr::Column(Column { name, .. 
}) = &expr { + if let Some(result_expr) = input_expr_result_map_for_count_bug.get(name) { Ok(Transformed::yes(result_expr.clone())) } else { Ok(Transformed::no(expr)) } + } else { + Ok(Transformed::no(expr)) } - _ => Ok(Transformed::no(expr)), }) .data()?; diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs index f0187b618ccc0..5db71417bc8fd 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -639,7 +639,7 @@ fn is_expr_trivial(expr: &Expr) -> bool { /// --Source(a, b) /// ``` fn rewrite_expr(expr: Expr, input: &Projection) -> Result> { - expr.transform_up_with_lambdas_params(|expr, lambdas_params| { + expr.transform_up(|expr| { match expr { // remove any intermediate aliases if they do not carry metadata Expr::Alias(alias) => { @@ -653,7 +653,7 @@ fn rewrite_expr(expr: Expr, input: &Projection) -> Result> { false => Ok(Transformed::no(Expr::Alias(alias))), } } - Expr::Column(col) if !col.is_lambda_parameter(lambdas_params) => { + Expr::Column(col) => { // Find index of column: let idx = input.schema.index_of_column(&col)?; // get the corresponding unaliased input expression diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 54cb026543270..77c533ce2f01e 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -287,14 +287,15 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result { | Expr::Cast(_) | Expr::TryCast(_) | Expr::InList { .. } - | Expr::ScalarFunction(_) => Ok(TreeNodeRecursion::Continue), + | Expr::ScalarFunction(_) + | Expr::Lambda(_) + | Expr::LambdaColumn(_) => Ok(TreeNodeRecursion::Continue), // TODO: remove the next line after `Expr::Wildcard` is removed #[expect(deprecated)] Expr::AggregateFunction(_) | Expr::WindowFunction(_) | Expr::Wildcard { .. 
} - | Expr::GroupingSet(_) - | Expr::Lambda { .. } => internal_err!("Unsupported predicate type"), + | Expr::GroupingSet(_) => internal_err!("Unsupported predicate type"), })?; Ok(is_evaluate) } @@ -1390,15 +1391,14 @@ pub fn replace_cols_by_name( e: Expr, replace_map: &HashMap, ) -> Result { - e.transform_up_with_lambdas_params(|expr, lambdas_params| { - Ok(match &expr { - Expr::Column(c) if !c.is_lambda_parameter(lambdas_params) => { - match replace_map.get(&c.flat_name()) { - Some(new_c) => Transformed::yes(new_c.clone()), - None => Transformed::no(expr), - } + e.transform_up(|expr| { + Ok(if let Expr::Column(c) = &expr { + match replace_map.get(&c.flat_name()) { + Some(new_c) => Transformed::yes(new_c.clone()), + None => Transformed::no(expr), } - _ => Transformed::no(expr), + } else { + Transformed::no(expr) }) }) .data() @@ -1407,18 +1407,17 @@ pub fn replace_cols_by_name( /// check whether the expression uses the columns in `check_map`. fn contain(e: &Expr, check_map: &HashMap) -> bool { let mut is_contain = false; - e.apply_with_lambdas_params(|expr, lambdas_params| { - Ok(match &expr { - Expr::Column(c) if !c.is_lambda_parameter(lambdas_params) => { - match check_map.get(&c.flat_name()) { - Some(_) => { - is_contain = true; - TreeNodeRecursion::Stop - } - None => TreeNodeRecursion::Continue, + e.apply(|expr| { + Ok(if let Expr::Column(c) = &expr { + match check_map.get(&c.flat_name()) { + Some(_) => { + is_contain = true; + TreeNodeRecursion::Stop } + None => TreeNodeRecursion::Continue, } - _ => TreeNodeRecursion::Continue, + } else { + TreeNodeRecursion::Continue }) }) .unwrap(); diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs b/datafusion/optimizer/src/scalar_subquery_to_join.rs index f1e619750f9c8..48d1182527013 100644 --- a/datafusion/optimizer/src/scalar_subquery_to_join.rs +++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs @@ -106,22 +106,17 @@ impl OptimizerRule for ScalarSubqueryToJoin { { if !expr_check_map.is_empty() { 
rewrite_expr = rewrite_expr - .transform_up_with_lambdas_params( - |expr, lambdas_params| { - // replace column references with entry in map, if it exists - if let Some(map_expr) = expr - .try_as_col() - .filter(|c| { - !c.is_lambda_parameter(lambdas_params) - }) - .and_then(|col| expr_check_map.get(&col.name)) - { - Ok(Transformed::yes(map_expr.clone())) - } else { - Ok(Transformed::no(expr)) - } - }, - ) + .transform_up(|expr| { + // replace column references with entry in map, if it exists + if let Some(map_expr) = expr + .try_as_col() + .and_then(|col| expr_check_map.get(&col.name)) + { + Ok(Transformed::yes(map_expr.clone())) + } else { + Ok(Transformed::no(expr)) + } + }) .data()?; } cur_input = optimized_subquery; @@ -176,26 +171,18 @@ impl OptimizerRule for ScalarSubqueryToJoin { { let new_expr = rewrite_expr .clone() - .transform_up_with_lambdas_params( - |expr, lambdas_params| { - // replace column references with entry in map, if it exists - if let Some(map_expr) = expr - .try_as_col() - .filter(|c| { - !c.is_lambda_parameter( - lambdas_params, - ) - }) - .and_then(|col| { - expr_check_map.get(&col.name) - }) - { - Ok(Transformed::yes(map_expr.clone())) - } else { - Ok(Transformed::no(expr)) - } - }, - ) + .transform_up(|expr| { + // replace column references with entry in map, if it exists + if let Some(map_expr) = + expr.try_as_col().and_then(|col| { + expr_check_map.get(&col.name) + }) + { + Ok(Transformed::yes(map_expr.clone())) + } else { + Ok(Transformed::no(expr)) + } + }) .data()?; expr_to_rewrite_expr_map.insert(expr, new_expr); } @@ -409,12 +396,8 @@ fn build_join( let mut expr_rewrite = TypeCoercionRewriter { schema: new_plan.schema(), }; - computation_project_expr.insert( - name, - computer_expr - .rewrite_with_schema(new_plan.schema(), &mut expr_rewrite) - .data()?, - ); + computation_project_expr + .insert(name, computer_expr.rewrite(&mut expr_rewrite).data()?); } } diff --git 
a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index a824f6b7be49f..779c0acea9963 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -27,20 +27,18 @@ use arrow::{ record_batch::RecordBatch, }; -use datafusion_common::tree_node::TreeNodeContainer; use datafusion_common::{ cast::{as_large_list_array, as_list_array}, metadata::FieldMetadata, - tree_node::{ - Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter, - }, + tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter}, }; use datafusion_common::{ - exec_datafusion_err, internal_err, DFSchema, DataFusionError, Result, ScalarValue, + exec_datafusion_err, internal_err, DFSchema, DataFusionError, Result, + ScalarValue, }; use datafusion_expr::{ - and, binary::BinaryTypeCoercer, lit, or, simplify::SimplifyContext, BinaryExpr, Case, - ColumnarValue, Expr, Like, Operator, Volatility, + and, binary::BinaryTypeCoercer, lit, or, BinaryExpr, Case, ColumnarValue, Expr, Like, + Operator, Volatility, }; use datafusion_expr::{expr::ScalarFunction, interval_arithmetic::NullableInterval}; use datafusion_expr::{ @@ -270,7 +268,7 @@ impl ExprSimplifier { /// documentation for more details on type coercion pub fn coerce(&self, expr: Expr, schema: &DFSchema) -> Result { let mut expr_rewrite = TypeCoercionRewriter { schema }; - expr.rewrite_with_schema(schema, &mut expr_rewrite).data() + expr.rewrite(&mut expr_rewrite).data() } /// Input guarantees about the values of columns. 
@@ -469,9 +467,11 @@ impl TreeNodeRewriter for Canonicalizer { }; match (left.as_ref(), right.as_ref(), op.swap()) { // - (Expr::Column(left_col), Expr::Column(right_col), Some(swapped_op)) - if right_col > left_col => - { + ( + left_col @ (Expr::Column(_) | Expr::LambdaColumn(_)), + right_col @ (Expr::Column(_) | Expr::LambdaColumn(_)), + Some(swapped_op), + ) if right_col > left_col => { Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr { left: right, op: swapped_op, @@ -479,13 +479,15 @@ impl TreeNodeRewriter for Canonicalizer { }))) } // - (Expr::Literal(_a, _), Expr::Column(_b), Some(swapped_op)) => { - Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr { - left: right, - op: swapped_op, - right: left, - }))) - } + ( + Expr::Literal(_, _), + Expr::Column(_) | Expr::LambdaColumn(_), + Some(swapped_op), + ) => Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr { + left: right, + op: swapped_op, + right: left, + }))), _ => Ok(Transformed::no(Expr::BinaryExpr(BinaryExpr { left, op, @@ -653,7 +655,7 @@ impl<'a> ConstEvaluator<'a> { | Expr::GroupingSet(_) | Expr::Wildcard { .. } | Expr::Placeholder(_) - | Expr::Lambda { .. } => false, + | Expr::LambdaColumn(_) => false, Expr::ScalarFunction(ScalarFunction { func, .. }) => { Self::volatility_ok(func.signature().volatility) } @@ -677,7 +679,8 @@ impl<'a> ConstEvaluator<'a> { | Expr::Case(_) | Expr::Cast { .. } | Expr::TryCast { .. } - | Expr::InList { .. } => true, + | Expr::InList { .. 
} + | Expr::Lambda(_) => true, } } @@ -758,89 +761,6 @@ impl<'a, S> Simplifier<'a, S> { impl TreeNodeRewriter for Simplifier<'_, S> { type Node = Expr; - fn f_down(&mut self, expr: Self::Node) -> Result> { - match expr { - Expr::ScalarFunction(ScalarFunction { func, args }) - if args.iter().any(|arg| matches!(arg, Expr::Lambda(_))) => - { - // there's currently no way to adapt a generic SimplifyInfo with lambda parameters, - // so, if the scalar function has any lambda, we materialize a DFSchema using all the - // columns references in every arguments. Than we can call lambdas_schemas_from_args, - // and for each argument, we create a new SimplifyContext with the scoped schema, and - // simplify the argument using this 'sub-context'. Finally, we set Transformed.tnr to - // Jump so the parent context doesn't try to simplify the argument again, without the - // parameters info - - // get all columns references - let mut columns_refs = HashSet::new(); - - for arg in &args { - arg.add_column_refs(&mut columns_refs); - } - - // materialize columns references into qualified fields - let qualified_fields = columns_refs - .into_iter() - .map(|captured_column| { - let expr = Expr::Column(captured_column.clone()); - - Ok(( - captured_column.relation.clone(), - Arc::new(Field::new( - captured_column.name(), - self.info.get_data_type(&expr)?, - self.info.nullable(&expr)?, - )), - )) - }) - .collect::>()?; - - // create a schema using the materialized fields - let dfschema = - DFSchema::new_with_metadata(qualified_fields, Default::default())?; - - let mut scoped_schemas = func - .arguments_schema_from_logical_args(&args, &dfschema)? 
- .into_iter(); - - let transformed_args = args - .map_elements(|arg| { - let scoped_schema = scoped_schemas.next().unwrap(); - - // create a sub-context, using the scoped schema, that includes information about the lambda parameters - let simplify_context = - SimplifyContext::new(self.info.execution_props()) - .with_schema(Arc::new(scoped_schema.into_owned())); - - let mut simplifier = Simplifier::new(&simplify_context); - - // simplify the argument using it's context - arg.rewrite(&mut simplifier) - })? - .update_data(|args| { - Expr::ScalarFunction(ScalarFunction { func, args }) - }); - - Ok(Transformed::new( - transformed_args.data, - transformed_args.transformed, - // return at least Jump so the parent contex doesn't try again to simplify the arguments - // (and fail because it doesn't contain info about lambdas paramters) - match transformed_args.tnr { - TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => { - TreeNodeRecursion::Jump - } - TreeNodeRecursion::Stop => TreeNodeRecursion::Stop, - }, - )) - - // Ok(transformed_args.update_data(|args| Expr::ScalarFunction(ScalarFunction { func, args}))) - } - // Expr::Lambda(_) => Ok(Transformed::new(expr, false, TreeNodeRecursion::Jump)), - _ => Ok(Transformed::no(expr)), - } - } - /// rewrite the expression simplifying any constant expressions fn f_up(&mut self, expr: Expr) -> Result> { use datafusion_expr::Operator::{ @@ -2092,8 +2012,8 @@ fn are_inlist_and_eq(left: &Expr, right: &Expr) -> bool { let left = as_inlist(left); let right = as_inlist(right); if let (Some(lhs), Some(rhs)) = (left, right) { - matches!(lhs.expr.as_ref(), Expr::Column(_)) - && matches!(rhs.expr.as_ref(), Expr::Column(_)) + matches!(lhs.expr.as_ref(), Expr::Column(_) | Expr::LambdaColumn(_)) + && matches!(rhs.expr.as_ref(), Expr::Column(_) | Expr::LambdaColumn(_)) && lhs.expr == rhs.expr && !lhs.negated && !rhs.negated @@ -2108,16 +2028,20 @@ fn as_inlist(expr: &'_ Expr) -> Option> { Expr::InList(inlist) => 
Some(Cow::Borrowed(inlist)), Expr::BinaryExpr(BinaryExpr { left, op, right }) if *op == Operator::Eq => { match (left.as_ref(), right.as_ref()) { - (Expr::Column(_), Expr::Literal(_, _)) => Some(Cow::Owned(InList { - expr: left.clone(), - list: vec![*right.clone()], - negated: false, - })), - (Expr::Literal(_, _), Expr::Column(_)) => Some(Cow::Owned(InList { - expr: right.clone(), - list: vec![*left.clone()], - negated: false, - })), + (Expr::Column(_) | Expr::LambdaColumn(_), Expr::Literal(_, _)) => { + Some(Cow::Owned(InList { + expr: left.clone(), + list: vec![*right.clone()], + negated: false, + })) + } + (Expr::Literal(_, _), Expr::Column(_) | Expr::LambdaColumn(_)) => { + Some(Cow::Owned(InList { + expr: right.clone(), + list: vec![*left.clone()], + negated: false, + })) + } _ => None, } } @@ -2133,16 +2057,20 @@ fn to_inlist(expr: Expr) -> Option { op: Operator::Eq, right, }) => match (left.as_ref(), right.as_ref()) { - (Expr::Column(_), Expr::Literal(_, _)) => Some(InList { - expr: left, - list: vec![*right], - negated: false, - }), - (Expr::Literal(_, _), Expr::Column(_)) => Some(InList { - expr: right, - list: vec![*left], - negated: false, - }), + (Expr::Column(_) | Expr::LambdaColumn(_), Expr::Literal(_, _)) => { + Some(InList { + expr: left, + list: vec![*right], + negated: false, + }) + } + (Expr::Literal(_, _), Expr::Column(_) | Expr::LambdaColumn(_)) => { + Some(InList { + expr: right, + list: vec![*left], + negated: false, + }) + } _ => None, }, _ => None, diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index d0ae4932628f3..81763fa0552fb 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -23,7 +23,7 @@ use crate::analyzer::type_coercion::TypeCoercionRewriter; use arrow::array::{new_null_array, Array, RecordBatch}; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::cast::as_boolean_array; -use datafusion_common::tree_node::TransformedResult; +use 
datafusion_common::tree_node::{TransformedResult, TreeNode}; use datafusion_common::{Column, DFSchema, Result, ScalarValue}; use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::expr_rewriter::replace_col; @@ -148,7 +148,7 @@ fn evaluate_expr_with_null_column<'a>( fn coerce(expr: Expr, schema: &DFSchema) -> Result { let mut expr_rewrite = TypeCoercionRewriter { schema }; - expr.rewrite_with_schema(schema, &mut expr_rewrite).data() + expr.rewrite(&mut expr_rewrite).data() } #[cfg(test)] diff --git a/datafusion/physical-expr-adapter/src/schema_rewriter.rs b/datafusion/physical-expr-adapter/src/schema_rewriter.rs index 4a81a5c99ac75..d1957ae1892ea 100644 --- a/datafusion/physical-expr-adapter/src/schema_rewriter.rs +++ b/datafusion/physical-expr-adapter/src/schema_rewriter.rs @@ -21,14 +21,13 @@ use std::sync::Arc; use arrow::compute::can_cast_types; use arrow::datatypes::{DataType, FieldRef, Schema, SchemaRef}; -use datafusion_common::HashSet; +use datafusion_common::tree_node::TreeNode; use datafusion_common::{ exec_err, tree_node::{Transformed, TransformedResult}, Result, ScalarValue, }; use datafusion_functions::core::getfield::GetFieldFunc; -use datafusion_physical_expr::PhysicalExprExt; use datafusion_physical_expr::{ expressions::{self, CastExpr, Column}, ScalarFunctionExpr, @@ -219,10 +218,8 @@ impl PhysicalExprAdapter for DefaultPhysicalExprAdapter { physical_file_schema: &self.physical_file_schema, partition_fields: &self.partition_values, }; - expr.transform_with_lambdas_params(|expr, lambdas_params| { - rewriter.rewrite_expr(Arc::clone(&expr), lambdas_params) - }) - .data() + expr.transform(|expr| rewriter.rewrite_expr(Arc::clone(&expr))) + .data() } fn with_partition_values( @@ -246,18 +243,13 @@ impl<'a> DefaultPhysicalExprAdapterRewriter<'a> { fn rewrite_expr( &self, expr: Arc, - lambdas_params: &HashSet, ) -> Result>> { - if let Some(transformed) = - self.try_rewrite_struct_field_access(&expr, lambdas_params)? 
- { + if let Some(transformed) = self.try_rewrite_struct_field_access(&expr)? { return Ok(Transformed::yes(transformed)); } if let Some(column) = expr.as_any().downcast_ref::() { - if !lambdas_params.contains(column.name()) { - return self.rewrite_column(Arc::clone(&expr), column); - } + return self.rewrite_column(Arc::clone(&expr), column); } Ok(Transformed::no(expr)) @@ -269,7 +261,6 @@ impl<'a> DefaultPhysicalExprAdapterRewriter<'a> { fn try_rewrite_struct_field_access( &self, expr: &Arc, - lambdas_params: &HashSet, ) -> Result>> { let get_field_expr = match ScalarFunctionExpr::try_downcast_func::(expr.as_ref()) { @@ -301,8 +292,8 @@ impl<'a> DefaultPhysicalExprAdapterRewriter<'a> { }; let column = match source_expr.as_any().downcast_ref::() { - Some(column) if !lambdas_params.contains(column.name()) => column, - _ => return Ok(None), + Some(column) => column, + None => return Ok(None), }; let physical_field = @@ -456,7 +447,6 @@ mod tests { use super::*; use arrow::array::{RecordBatch, RecordBatchOptions}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; - use datafusion_common::hashbrown::HashSet; use datafusion_common::{assert_contains, record_batch, Result, ScalarValue}; use datafusion_expr::Operator; use datafusion_physical_expr::expressions::{col, lit, CastExpr, Column, Literal}; @@ -863,9 +853,7 @@ mod tests { // Test that when a field exists in physical schema, it returns None let column = Arc::new(Column::new("struct_col", 0)) as Arc; - let result = rewriter - .try_rewrite_struct_field_access(&column, &HashSet::new()) - .unwrap(); + let result = rewriter.try_rewrite_struct_field_access(&column).unwrap(); assert!(result.is_none()); // The actual test for the get_field expression would require creating a proper ScalarFunctionExpr diff --git a/datafusion/physical-expr/Cargo.toml b/datafusion/physical-expr/Cargo.toml index d4c0e1cbe6eb7..b7654a0f6f603 100644 --- a/datafusion/physical-expr/Cargo.toml +++ b/datafusion/physical-expr/Cargo.toml @@ 
-37,9 +37,6 @@ workspace = true [lib] name = "datafusion_physical_expr" -[features] -recursive_protection = ["dep:recursive"] - [dependencies] ahash = { workspace = true } arrow = { workspace = true } @@ -55,7 +52,6 @@ itertools = { workspace = true, features = ["use_std"] } parking_lot = { workspace = true } paste = "^1.0" petgraph = "0.8.3" -recursive = { workspace = true, optional = true } [dev-dependencies] arrow = { workspace = true, features = ["test_utils"] } diff --git a/datafusion/physical-expr/src/expressions/column.rs b/datafusion/physical-expr/src/expressions/column.rs index c55f42ae333bc..7040fa2bfc9b4 100644 --- a/datafusion/physical-expr/src/expressions/column.rs +++ b/datafusion/physical-expr/src/expressions/column.rs @@ -22,13 +22,12 @@ use std::hash::Hash; use std::sync::Arc; use crate::physical_expr::PhysicalExpr; -use crate::PhysicalExprExt; use arrow::datatypes::FieldRef; use arrow::{ datatypes::{DataType, Schema, SchemaRef}, record_batch::RecordBatch, }; -use datafusion_common::tree_node::{Transformed, TransformedResult}; +use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{internal_err, plan_err, Result}; use datafusion_expr::ColumnarValue; @@ -68,8 +67,7 @@ use datafusion_expr::ColumnarValue; pub struct Column { /// The name of the column (used for debugging and display purposes) name: String, - /// The index of the column in its schema. - /// Within a lambda body, this refer to the lambda scoped schema, not the plan schema. 
+ /// The index of the column in its schema index: usize, } @@ -180,9 +178,9 @@ pub fn with_new_schema( expr: Arc, schema: &SchemaRef, ) -> Result> { - expr.transform_up_with_lambdas_params(|expr, lambdas_params| { - match expr.as_any().downcast_ref::() { - Some(col) if !lambdas_params.contains(col.name()) => { + Ok(expr + .transform_up(|expr| { + if let Some(col) = expr.as_any().downcast_ref::() { let idx = col.index(); let Some(field) = schema.fields().get(idx) else { return plan_err!( @@ -192,11 +190,11 @@ pub fn with_new_schema( let new_col = Column::new(field.name(), idx); Ok(Transformed::yes(Arc::new(new_col) as _)) + } else { + Ok(Transformed::no(expr)) } - _ => Ok(Transformed::no(expr)), - } - }) - .data() + })? + .data) } #[cfg(test)] diff --git a/datafusion/physical-expr/src/expressions/lambda.rs b/datafusion/physical-expr/src/expressions/lambda.rs index 55110fdf5bf6b..38b64e3c7f3e1 100644 --- a/datafusion/physical-expr/src/expressions/lambda.rs +++ b/datafusion/physical-expr/src/expressions/lambda.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! Physical column reference: [`Column`] +//! 
Physical lambda expression: [`LambdaExpr`] use std::hash::Hash; use std::sync::Arc; @@ -23,16 +23,15 @@ use std::{any::Any, sync::OnceLock}; use crate::expressions::Column; use crate::physical_expr::PhysicalExpr; -use crate::PhysicalExprExt; use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, }; -use datafusion_common::tree_node::TreeNodeRecursion; +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion_common::{internal_err, HashSet, Result}; use datafusion_expr::ColumnarValue; -/// Represents a lambda with the given parameters name and body +/// Represents a lambda with the given parameter names and body #[derive(Debug, Eq, Clone)] pub struct LambdaExpr { params: Vec, body: Arc, @@ -79,16 +78,14 @@ impl LambdaExpr { let mut indices = HashSet::new(); self.body - .apply_with_lambdas_params(|expr, lambdas_params| { + .apply(|expr| { if let Some(column) = expr.as_any().downcast_ref::() { - if !lambdas_params.contains(column.name()) { - indices.insert(column.index()); - } + indices.insert(column.index()); } Ok(TreeNodeRecursion::Continue) }) - .unwrap(); + .expect("closure should be infallible"); indices }) @@ -106,12 +103,12 @@ impl PhysicalExpr for LambdaExpr { self } - fn data_type(&self, _input_schema: &Schema) -> Result { - Ok(DataType::Null) + fn data_type(&self, input_schema: &Schema) -> Result { + self.body.data_type(input_schema) } - fn nullable(&self, _input_schema: &Schema) -> Result { - Ok(true) + fn nullable(&self, input_schema: &Schema) -> Result { + self.body.nullable(input_schema) } fn evaluate(&self, _batch: &RecordBatch) -> Result { diff --git a/datafusion/physical-expr/src/expressions/lambda_column.rs b/datafusion/physical-expr/src/expressions/lambda_column.rs new file mode 100644 index 0000000000000..4aed16186ba6f --- /dev/null +++ b/datafusion/physical-expr/src/expressions/lambda_column.rs @@ -0,0 +1,136 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements.
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Physical lambda column reference: [`LambdaColumn`] + +use std::any::Any; +use std::hash::Hash; +use std::sync::Arc; + +use crate::physical_expr::PhysicalExpr; +use arrow::array::ArrayRef; +use arrow::datatypes::FieldRef; +use arrow::{ + datatypes::{DataType, Schema}, + record_batch::RecordBatch, +}; +use datafusion_common::{Result, exec_datafusion_err}; +use datafusion_expr::ColumnarValue; + +/// Represents the lambda column with a given name and field +#[derive(Debug, Clone)] +pub struct LambdaColumn { + name: String, + field: FieldRef, + value: Option, +} + +impl Eq for LambdaColumn {} + +impl PartialEq for LambdaColumn { + fn eq(&self, other: &Self) -> bool { + self.name == other.name && self.field == other.field + } +} + +impl Hash for LambdaColumn { + fn hash(&self, state: &mut H) { + self.name.hash(state); + self.field.hash(state); + } +} + +impl LambdaColumn { + /// Create a new lambda column expression + pub fn new(name: &str, field: FieldRef) -> Self { + Self { + name: name.to_owned(), + field, + value: None, + } + } + + /// Get the column's name + pub fn name(&self) -> &str { + &self.name + } + + /// Get the column's field + pub fn field(&self) -> &FieldRef { + &self.field + } + + pub fn with_value(self, value: ArrayRef) -> Self { + Self { + name: self.name, 
+ field: self.field, + value: Some(ColumnarValue::Array(value)), + } + } +} + +impl std::fmt::Display for LambdaColumn { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "{}@-1", self.name) + } +} + +impl PhysicalExpr for LambdaColumn { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + /// Get the data type of this expression, given the schema of the input + fn data_type(&self, _input_schema: &Schema) -> Result { + Ok(self.field.data_type().clone()) + } + + /// Decide whether this expression is nullable, given the schema of the input + fn nullable(&self, _input_schema: &Schema) -> Result { + Ok(self.field.is_nullable()) + } + + /// Evaluate the expression + fn evaluate(&self, _batch: &RecordBatch) -> Result { + self.value.clone().ok_or_else(|| exec_datafusion_err!("Physical LambdaColumn {} missing value", self.name)) + } + + fn return_field(&self, _input_schema: &Schema) -> Result { + Ok(Arc::clone(&self.field)) + } + + fn children(&self) -> Vec<&Arc> { + vec![] + } + + fn with_new_children( + self: Arc, + _children: Vec>, + ) -> Result> { + Ok(self) + } + + fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.name) + } +} + +/// Create a lambda column expression + pub fn lambda_col(name: &str, field: FieldRef) -> Result> { + Ok(Arc::new(LambdaColumn::new(name, field))) +} diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index e87941da5ef4c..5d044ab848550 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -23,6 +23,7 @@ mod case; mod cast; mod cast_column; mod column; +mod lambda_column; mod dynamic_filters; mod in_list; mod is_not_null; @@ -45,6 +46,7 @@ pub use case::{case, CaseExpr}; pub use cast::{cast, CastExpr}; pub use cast_column::CastColumnExpr; pub use column::{col, with_new_schema, Column}; +pub use
lambda_column::{lambda_col, LambdaColumn}; pub use datafusion_expr::utils::format_state_name; pub use dynamic_filters::DynamicFilterPhysicalExpr; pub use in_list::{in_list, InListExpr}; diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs index 873205f28bef4..aa8c9e50fd71e 100644 --- a/datafusion/physical-expr/src/lib.rs +++ b/datafusion/physical-expr/src/lib.rs @@ -70,8 +70,6 @@ pub use datafusion_physical_expr_common::sort_expr::{ pub use planner::{create_physical_expr, create_physical_exprs}; pub use scalar_function::ScalarFunctionExpr; -pub use scalar_function::PhysicalExprExt; - pub use simplifier::PhysicalExprSimplifier; pub use utils::{conjunction, conjunction_opt, split_conjunction}; diff --git a/datafusion/physical-expr/src/physical_expr.rs b/datafusion/physical-expr/src/physical_expr.rs index 2584fc22885c2..c658a8eddc233 100644 --- a/datafusion/physical-expr/src/physical_expr.rs +++ b/datafusion/physical-expr/src/physical_expr.rs @@ -18,11 +18,11 @@ use std::sync::Arc; use crate::expressions::{self, Column}; -use crate::{create_physical_expr, LexOrdering, PhysicalExprExt, PhysicalSortExpr}; +use crate::{create_physical_expr, LexOrdering, PhysicalSortExpr}; use arrow::compute::SortOptions; use arrow::datatypes::{Schema, SchemaRef}; -use datafusion_common::tree_node::{Transformed, TransformedResult}; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{plan_err, Result}; use datafusion_common::{DFSchema, HashMap}; use datafusion_expr::execution_props::ExecutionProps; @@ -38,14 +38,14 @@ pub fn add_offset_to_expr( expr: Arc, offset: isize, ) -> Result> { - expr.transform_down_with_lambdas_params(|e, lambdas_params| match e.as_any().downcast_ref::() { - Some(col) if !lambdas_params.contains(col.name()) => { + expr.transform_down(|e| match e.as_any().downcast_ref::() { + Some(col) => { let Some(idx) = col.index().checked_add_signed(offset) else { return plan_err!("Column index 
overflow"); }; Ok(Transformed::yes(Arc::new(Column::new(col.name(), idx)))) } - _ => Ok(Transformed::no(e)), + None => Ok(Transformed::no(e)), }) .data() } diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index 0119c81b8ed94..4c3d1352cce0f 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -17,7 +17,7 @@ use std::sync::Arc; -use crate::expressions::LambdaExpr; +use crate::expressions::{lambda_col, LambdaExpr}; use crate::ScalarFunctionExpr; use crate::{ expressions::{self, binary, like, similar_to, Column, Literal}, @@ -28,10 +28,13 @@ use arrow::datatypes::Schema; use datafusion_common::config::ConfigOptions; use datafusion_common::metadata::FieldMetadata; use datafusion_common::{ - exec_err, not_impl_err, plan_err, DFSchema, Result, ScalarValue, ToDFSchema, + exec_err, not_impl_err, plan_err, DFSchema, Result, ScalarValue, + ToDFSchema, }; use datafusion_expr::execution_props::ExecutionProps; -use datafusion_expr::expr::{Alias, Cast, InList, Lambda, Placeholder, ScalarFunction}; +use datafusion_expr::expr::{ + Alias, Cast, InList, Lambda, LambdaColumn, Placeholder, ScalarFunction, +}; use datafusion_expr::var_provider::is_system_variables; use datafusion_expr::var_provider::VarType; use datafusion_expr::{ @@ -105,8 +108,7 @@ use datafusion_expr::{ /// /// * `e` - The logical expression /// * `input_dfschema` - The DataFusion schema for the input, used to resolve `Column` references -/// to qualified or unqualified fields by name. Note that for creating a lambda, this must be -/// scoped lambda schema, and not the outer schema +/// to qualified or unqualified fields by name. pub fn create_physical_expr( e: &Expr, input_dfschema: &DFSchema, @@ -316,27 +318,13 @@ pub fn create_physical_expr( input_dfschema, execution_props, )?), - Expr::Lambda { .. 
} => { - exec_err!("Expr::Lambda should be handled by Expr::ScalarFunction, as it can only exist within it") - } + Expr::Lambda(Lambda { params, body }) => Ok(Arc::new(LambdaExpr::new( + params.clone(), + create_physical_expr(body, input_dfschema, execution_props)?, + ))), Expr::ScalarFunction(ScalarFunction { func, args }) => { - let lambdas_schemas = - func.arguments_schema_from_logical_args(args, input_dfschema)?; - - let physical_args = std::iter::zip(args, lambdas_schemas) - .map(|(expr, schema)| match expr { - Expr::Lambda(Lambda { params, body }) => { - Ok(Arc::new(LambdaExpr::new( - params.clone(), - create_physical_expr(body, &schema, execution_props)?, - )) as Arc) - } - expr => create_physical_expr(expr, &schema, execution_props), - }) - .collect::>>()?; - - //let physical_args = - // create_physical_exprs(args, input_dfschema, execution_props)?; + let physical_args = + create_physical_exprs(args, input_dfschema, execution_props)?; let config_options = match execution_props.config_options.as_ref() { Some(config_options) => Arc::clone(config_options), @@ -404,6 +392,14 @@ pub fn create_physical_expr( Expr::Placeholder(Placeholder { id, .. 
}) => { exec_err!("Placeholder '{id}' was not provided a value for execution.") } + Expr::LambdaColumn(LambdaColumn { + name, + field, + spans: _, + }) => lambda_col( + name, + Arc::clone(field), + ), other => { not_impl_err!("Physical plan does not support logical expression {other:?}") } diff --git a/datafusion/physical-expr/src/projection.rs b/datafusion/physical-expr/src/projection.rs index 70be717a8436c..a120ab427e1de 100644 --- a/datafusion/physical-expr/src/projection.rs +++ b/datafusion/physical-expr/src/projection.rs @@ -20,11 +20,11 @@ use std::sync::Arc; use crate::expressions::Column; use crate::utils::collect_columns; -use crate::{PhysicalExpr, PhysicalExprExt}; +use crate::PhysicalExpr; use arrow::datatypes::{Field, Schema, SchemaRef}; use datafusion_common::stats::{ColumnStatistics, Precision}; -use datafusion_common::tree_node::{Transformed, TransformedResult}; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{internal_datafusion_err, internal_err, plan_err, Result}; use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; @@ -499,16 +499,13 @@ pub fn update_expr( let mut state = RewriteState::Unchanged; let new_expr = Arc::clone(expr) - .transform_up_with_lambdas_params(|expr, lambdas_params| { + .transform_up(|expr| { if state == RewriteState::RewrittenInvalid { return Ok(Transformed::no(expr)); } - let column = match expr.as_any().downcast_ref::() { - Some(column) if !lambdas_params.contains(column.name()) => column, - _ => { - return Ok(Transformed::no(expr)); - } + let Some(column) = expr.as_any().downcast_ref::() else { + return Ok(Transformed::no(expr)); }; if sync_with_child { state = RewriteState::RewrittenValid; @@ -619,14 +616,14 @@ impl ProjectionMapping { let mut map = IndexMap::<_, ProjectionTargets>::new(); for (expr_idx, (expr, name)) in expr.into_iter().enumerate() { let target_expr = Arc::new(Column::new(&name, expr_idx)) as _; - let source_expr = 
expr.transform_down_with_schema(input_schema, |e, schema| match e.as_any().downcast_ref::() { + let source_expr = expr.transform_down(|e| match e.as_any().downcast_ref::() { Some(col) => { - // Sometimes, an expression and its name in the schema + // Sometimes, an expression and its name in the input_schema // doesn't match. This can cause problems, so we make sure - // that the expression name matches with the name in `schema`. + // that the expression name matches with the name in `input_schema`. // Conceptually, `source_expr` and `expression` should be the same. let idx = col.index(); - let matching_field = schema.field(idx); + let matching_field = input_schema.field(idx); let matching_name = matching_field.name(); if col.name() != matching_name { return internal_err!( @@ -740,25 +737,21 @@ pub fn project_ordering( ) -> Option { let mut projected_exprs = vec![]; for PhysicalSortExpr { expr, options } in ordering.iter() { - let transformed = - Arc::clone(expr).transform_up_with_lambdas_params(|expr, lambdas_params| { - let col = match expr.as_any().downcast_ref::() { - Some(col) if !lambdas_params.contains(col.name()) => col, - _ => { - return Ok(Transformed::no(expr)); - } - }; + let transformed = Arc::clone(expr).transform_up(|expr| { + let Some(col) = expr.as_any().downcast_ref::() else { + return Ok(Transformed::no(expr)); + }; - let name = col.name(); - if let Some((idx, _)) = schema.column_with_name(name) { - // Compute the new column expression (with correct index) after projection: - Ok(Transformed::yes(Arc::new(Column::new(name, idx)))) - } else { - // Cannot find expression in the projected_schema, - // signal this using an Err result - plan_err!("") - } - }); + let name = col.name(); + if let Some((idx, _)) = schema.column_with_name(name) { + // Compute the new column expression (with correct index) after projection: + Ok(Transformed::yes(Arc::new(Column::new(name, idx)))) + } else { + // Cannot find expression in the projected_schema, + // signal this 
using an Err result + plan_err!("") + } + }); match transformed { Ok(transformed) => { diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index 22fa300f05df4..2527e84241fe3 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -30,19 +30,17 @@ //! to a function that supports f64, it is coerced to f64. use std::any::Any; -use std::borrow::Cow; use std::fmt::{self, Debug, Formatter}; use std::hash::{Hash, Hasher}; use std::sync::Arc; -use crate::expressions::{Column, LambdaExpr, Literal}; +use crate::expressions::{LambdaExpr, Literal}; use crate::PhysicalExpr; use arrow::array::{Array, NullArray, RecordBatch}; use arrow::datatypes::{DataType, Field, FieldRef, Schema}; use datafusion_common::config::{ConfigEntry, ConfigOptions}; -use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; -use datafusion_common::{internal_err, HashSet, Result, ScalarValue}; +use datafusion_common::{exec_err, internal_err, Result, ScalarValue}; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::sort_properties::ExprProperties; use datafusion_expr::type_coercion::functions::data_types_with_scalar_udf; @@ -96,16 +94,10 @@ impl ScalarFunctionExpr { schema: &Schema, config_options: Arc, ) -> Result { - let lambdas_schemas = lambdas_schemas_from_args(&fun, &args, schema)?; - - let arg_fields = std::iter::zip(&args, lambdas_schemas) - .map(|(e, schema)| { - if let Some(lambda) = e.as_any().downcast_ref::() { - lambda.body().return_field(&schema) - } else { - e.return_field(&schema) - } - }) + let name = fun.name().to_string(); + let arg_fields = args + .iter() + .map(|e| e.return_field(schema)) .collect::>>()?; // verify that input data types is consistent with function's `TypeSignature` @@ -137,7 +129,6 @@ impl ScalarFunctionExpr { }; let return_field = fun.return_field_from_args(ret_args)?; - let name = fun.name().to_string(); 
Ok(Self { fun, @@ -300,23 +291,7 @@ impl PhysicalExpr for ScalarFunctionExpr { let args_metadata = std::iter::zip(&self.args, &arg_fields) .map( |(expr, field)| match expr.as_any().downcast_ref::() { - Some(lambda) => { - let mut captures = false; - - expr.apply_with_lambdas_params(|expr, lambdas_params| { - match expr.as_any().downcast_ref::() { - Some(col) if !lambdas_params.contains(col.name()) => { - captures = true; - - Ok(TreeNodeRecursion::Stop) - } - _ => Ok(TreeNodeRecursion::Continue), - } - }) - .unwrap(); - - ValueOrLambdaParameter::Lambda(lambda.params(), captures) - } + Some(_lambda) => ValueOrLambdaParameter::Lambda, None => ValueOrLambdaParameter::Value(Arc::clone(field)), }, ) @@ -326,44 +301,30 @@ impl PhysicalExpr for ScalarFunctionExpr { let lambdas = std::iter::zip(&self.args, params) .map(|(arg, lambda_params)| { - arg.as_any() - .downcast_ref::() - .map(|lambda| { - let mut indices = HashSet::new(); - - arg.apply_with_lambdas_params(|expr, lambdas_params| { - if let Some(column) = - expr.as_any().downcast_ref::() - { - if !lambdas_params.contains(column.name()) { - indices.insert( - column.index(), //batch - // .schema_ref() - // .index_of(column.name())?, - ); - } - } - - Ok(TreeNodeRecursion::Continue) - })?; - - //let mut indices = indices.into_iter().collect::>(); - - //indices.sort_unstable(); - - let params = - std::iter::zip(lambda.params(), lambda_params.unwrap()) - .map(|(name, param)| Arc::new(param.with_name(name))) - .collect(); - - let captures = if !indices.is_empty() { + match (arg.as_any().downcast_ref::(), lambda_params) { + (Some(lambda), Some(lambda_params)) => { + if lambda.params().len() > lambda_params.len() { + return exec_err!( + "lambda defined {} params but UDF support only {}", + lambda.params().len(), + lambda_params.len() + ); + } + + let captures = lambda.captures(); + + let params = std::iter::zip(lambda.params(), lambda_params) + .map(|(name, param)| Arc::new(param.with_name(name))) + .collect(); + + let 
captures = if !captures.is_empty() { let (fields, columns): (Vec<_>, _) = std::iter::zip( batch.schema_ref().fields(), batch.columns(), ) .enumerate() .map(|(column_index, (field, column))| { - if indices.contains(&column_index) { + if captures.contains(&column_index) { (Arc::clone(field), Arc::clone(column)) } else { ( @@ -381,18 +342,26 @@ impl PhysicalExpr for ScalarFunctionExpr { let schema = Arc::new(Schema::new(fields)); Some(RecordBatch::try_new(schema, columns)?) - //Some(batch.project(&indices)?) } else { None }; - Ok(ScalarFunctionLambdaArg { + Ok(Some(ScalarFunctionLambdaArg { params, body: Arc::clone(lambda.body()), captures, - }) - }) - .transpose() + })) + } + (Some(_lambda), None) => exec_err!( + "{} don't reported the parameters of one of it's lambdas", + self.fun.name() + ), + (None, Some(_lambda_params)) => exec_err!( + "{} reported parameters for an argument that is not a lambda", + self.fun.name() + ), + _ => Ok(None), + } }) .collect::>>()?; @@ -493,373 +462,17 @@ impl PhysicalExpr for ScalarFunctionExpr { } } -pub fn lambdas_schemas_from_args<'a>( - fun: &ScalarUDF, - args: &[Arc], - schema: &'a Schema, -) -> Result>> { - let args_metadata = args - .iter() - .map(|e| match e.as_any().downcast_ref::() { - Some(lambda) => { - let mut captures = false; - - e.apply_with_lambdas_params(|expr, lambdas_params| { - match expr.as_any().downcast_ref::() { - Some(col) if !lambdas_params.contains(col.name()) => { - captures = true; - - Ok(TreeNodeRecursion::Stop) - } - _ => Ok(TreeNodeRecursion::Continue), - } - }) - .unwrap(); - - Ok(ValueOrLambdaParameter::Lambda(lambda.params(), captures)) - } - None => Ok(ValueOrLambdaParameter::Value(e.return_field(schema)?)), - }) - .collect::>>()?; - - /*let captures = args - .iter() - .map(|arg| { - if arg.as_any().is::() { - let mut columns = HashSet::new(); - - arg.apply_with_lambdas_params(|n, lambdas_params| { - if let Some(column) = n.as_any().downcast_ref::() { - if !lambdas_params.contains(column.name()) { 
- columns.insert(schema.index_of(column.name())?); - } - // columns.insert(column.index()); - } - - Ok(TreeNodeRecursion::Continue) - })?; - - Ok(columns) - } else { - Ok(HashSet::new()) - } - }) - .collect::>>()?; */ - - fun.arguments_arrow_schema(&args_metadata, schema) -} - -pub trait PhysicalExprExt: Sized { - fn apply_with_lambdas_params< - 'n, - F: FnMut(&'n Self, &HashSet<&'n str>) -> Result, - >( - &'n self, - f: F, - ) -> Result; - - fn apply_with_schema<'n, F: FnMut(&'n Self, &Schema) -> Result>( - &'n self, - schema: &Schema, - f: F, - ) -> Result; - - fn apply_children_with_schema< - 'n, - F: FnMut(&'n Self, &Schema) -> Result, - >( - &'n self, - schema: &Schema, - f: F, - ) -> Result; - - fn transform_down_with_schema Result>>( - self, - schema: &Schema, - f: F, - ) -> Result>; - - fn transform_up_with_schema Result>>( - self, - schema: &Schema, - f: F, - ) -> Result>; - - fn transform_with_schema Result>>( - self, - schema: &Schema, - f: F, - ) -> Result> { - self.transform_up_with_schema(schema, f) - } - - fn transform_down_with_lambdas_params( - self, - f: impl FnMut(Self, &HashSet) -> Result>, - ) -> Result>; - - fn transform_up_with_lambdas_params( - self, - f: impl FnMut(Self, &HashSet) -> Result>, - ) -> Result>; - - fn transform_with_lambdas_params( - self, - f: impl FnMut(Self, &HashSet) -> Result>, - ) -> Result> { - self.transform_up_with_lambdas_params(f) - } -} - -impl PhysicalExprExt for Arc { - fn apply_with_lambdas_params< - 'n, - F: FnMut(&'n Self, &HashSet<&'n str>) -> Result, - >( - &'n self, - mut f: F, - ) -> Result { - #[cfg_attr(feature = "recursive_protection", recursive::recursive)] - fn apply_with_lambdas_params_impl< - 'n, - F: FnMut( - &'n Arc, - &HashSet<&'n str>, - ) -> Result, - >( - node: &'n Arc, - args: &HashSet<&'n str>, - f: &mut F, - ) -> Result { - match node.as_any().downcast_ref::() { - Some(lambda) => { - let mut args = args.clone(); - - args.extend(lambda.params().iter().map(|v| v.as_str())); - - f(node, 
&args)?.visit_children(|| { - node.apply_children(|c| { - apply_with_lambdas_params_impl(c, &args, f) - }) - }) - } - _ => f(node, args)?.visit_children(|| { - node.apply_children(|c| apply_with_lambdas_params_impl(c, args, f)) - }), - } - } - - apply_with_lambdas_params_impl(self, &HashSet::new(), &mut f) - } - - fn apply_with_schema<'n, F: FnMut(&'n Self, &Schema) -> Result>( - &'n self, - schema: &Schema, - mut f: F, - ) -> Result { - #[cfg_attr(feature = "recursive_protection", recursive::recursive)] - fn apply_with_lambdas_impl< - 'n, - F: FnMut(&'n Arc, &Schema) -> Result, - >( - node: &'n Arc, - schema: &Schema, - f: &mut F, - ) -> Result { - f(node, schema)?.visit_children(|| { - node.apply_children_with_schema(schema, |c, schema| { - apply_with_lambdas_impl(c, schema, f) - }) - }) - } - - apply_with_lambdas_impl(self, schema, &mut f) - } - - fn apply_children_with_schema< - 'n, - F: FnMut(&'n Self, &Schema) -> Result, - >( - &'n self, - schema: &Schema, - mut f: F, - ) -> Result { - match self.as_any().downcast_ref::() { - Some(scalar_function) - if scalar_function - .args() - .iter() - .any(|arg| arg.as_any().is::()) => - { - let mut lambdas_schemas = lambdas_schemas_from_args( - scalar_function.fun(), - scalar_function.args(), - schema, - )? 
- .into_iter(); - - self.apply_children(|expr| f(expr, &lambdas_schemas.next().unwrap())) - } - _ => self.apply_children(|e| f(e, schema)), - } - } - - fn transform_down_with_schema< - F: FnMut(Self, &Schema) -> Result>, - >( - self, - schema: &Schema, - mut f: F, - ) -> Result> { - #[cfg_attr(feature = "recursive_protection", recursive::recursive)] - fn transform_down_with_schema_impl< - F: FnMut( - Arc, - &Schema, - ) -> Result>>, - >( - node: Arc, - schema: &Schema, - f: &mut F, - ) -> Result>> { - f(node, schema)?.transform_children(|node| { - map_children_with_schema(node, schema, |n, schema| { - transform_down_with_schema_impl(n, schema, f) - }) - }) - } - - transform_down_with_schema_impl(self, schema, &mut f) - } - - fn transform_up_with_schema Result>>( - self, - schema: &Schema, - mut f: F, - ) -> Result> { - #[cfg_attr(feature = "recursive_protection", recursive::recursive)] - fn transform_up_with_schema_impl< - F: FnMut( - Arc, - &Schema, - ) -> Result>>, - >( - node: Arc, - schema: &Schema, - f: &mut F, - ) -> Result>> { - map_children_with_schema(node, schema, |n, schema| { - transform_up_with_schema_impl(n, schema, f) - })? - .transform_parent(|n| f(n, schema)) - } - - transform_up_with_schema_impl(self, schema, &mut f) - } - - fn transform_up_with_lambdas_params( - self, - mut f: impl FnMut(Self, &HashSet) -> Result>, - ) -> Result> { - #[cfg_attr(feature = "recursive_protection", recursive::recursive)] - fn transform_up_with_lambdas_params_impl< - F: FnMut( - Arc, - &HashSet, - ) -> Result>>, - >( - node: Arc, - params: &HashSet, - f: &mut F, - ) -> Result>> { - map_children_with_lambdas_params(node, params, |n, params| { - transform_up_with_lambdas_params_impl(n, params, f) - })? 
- .transform_parent(|n| f(n, params)) - } - - transform_up_with_lambdas_params_impl(self, &HashSet::new(), &mut f) - } - - fn transform_down_with_lambdas_params( - self, - mut f: impl FnMut(Self, &HashSet) -> Result>, - ) -> Result> { - #[cfg_attr(feature = "recursive_protection", recursive::recursive)] - fn transform_down_with_lambdas_params_impl< - F: FnMut( - Arc, - &HashSet, - ) -> Result>>, - >( - node: Arc, - params: &HashSet, - f: &mut F, - ) -> Result>> { - f(node, params)?.transform_children(|node| { - map_children_with_lambdas_params(node, params, |node, args| { - transform_down_with_lambdas_params_impl(node, args, f) - }) - }) - } - - transform_down_with_lambdas_params_impl(self, &HashSet::new(), &mut f) - } -} - -fn map_children_with_schema( - node: Arc, - schema: &Schema, - mut f: impl FnMut( - Arc, - &Schema, - ) -> Result>>, -) -> Result>> { - match node.as_any().downcast_ref::() { - Some(fun) if fun.args().iter().any(|arg| arg.as_any().is::()) => { - let mut args_schemas = - lambdas_schemas_from_args(fun.fun(), fun.args(), schema)?.into_iter(); - - node.map_children(|node| f(node, &args_schemas.next().unwrap())) - } - _ => node.map_children(|node| f(node, schema)), - } -} - -fn map_children_with_lambdas_params( - node: Arc, - params: &HashSet, - mut f: impl FnMut( - Arc, - &HashSet, - ) -> Result>>, -) -> Result>> { - match node.as_any().downcast_ref::() { - Some(lambda) => { - let mut params = params.clone(); - - params.extend(lambda.params().iter().cloned()); - - node.map_children(|node| f(node, ¶ms)) - } - None => node.map_children(|node| f(node, params)), - } -} - #[cfg(test)] mod tests { use std::any::Any; - use std::{borrow::Cow, sync::Arc}; + use std::sync::Arc; use super::*; - use super::{lambdas_schemas_from_args, PhysicalExprExt}; use crate::expressions::Column; - use crate::{create_physical_expr, ScalarFunctionExpr}; + use crate::ScalarFunctionExpr; use arrow::datatypes::{DataType, Field, Schema}; - use 
datafusion_common::{tree_node::TreeNodeRecursion, DFSchema, HashSet, Result}; - use datafusion_expr::{ - col, expr::Lambda, Expr, ScalarFunctionArgs, ValueOrLambdaParameter, Volatility, - }; + use datafusion_common::Result; + use datafusion_expr::{ScalarFunctionArgs, Volatility}; use datafusion_expr::{ScalarUDF, ScalarUDFImpl, Signature}; use datafusion_expr_common::columnar_value::ColumnarValue; use datafusion_physical_expr_common::physical_expr::is_volatile; @@ -935,190 +548,4 @@ mod tests { let stable_arc: Arc = Arc::new(stable_expr); assert!(!is_volatile(&stable_arc)); } - - fn list_list_int() -> Schema { - Schema::new(vec![Field::new( - "v", - DataType::new_list(DataType::new_list(DataType::Int32, false), false), - false, - )]) - } - - fn list_int() -> Schema { - Schema::new(vec![Field::new( - "v", - DataType::new_list(DataType::Int32, false), - false, - )]) - } - - fn int() -> Schema { - Schema::new(vec![Field::new("v", DataType::Int32, false)]) - } - - fn array_transform_udf() -> ScalarUDF { - ScalarUDF::new_from_impl(ArrayTransformFunc::new()) - } - - fn args() -> Vec { - vec![ - col("v"), - Expr::Lambda(Lambda::new( - vec!["v".into()], - array_transform_udf().call(vec![ - col("v"), - Expr::Lambda(Lambda::new(vec!["v".into()], -col("v"))), - ]), - )), - ] - } - - // array_transform(v, |v| -> array_transform(v, |v| -> -v)) - fn array_transform() -> Arc { - let e = array_transform_udf().call(args()); - - create_physical_expr( - &e, - &DFSchema::try_from(list_list_int()).unwrap(), - &Default::default(), - ) - .unwrap() - } - - #[derive(Debug, PartialEq, Eq, Hash)] - struct ArrayTransformFunc { - signature: Signature, - } - - impl ArrayTransformFunc { - pub fn new() -> Self { - Self { - signature: Signature::any(2, Volatility::Immutable), - } - } - } - - impl ScalarUDFImpl for ArrayTransformFunc { - fn as_any(&self) -> &dyn std::any::Any { - self - } - - fn name(&self) -> &str { - "array_transform" - } - - fn signature(&self) -> &Signature { - &self.signature - 
} - - fn return_type(&self, arg_types: &[DataType]) -> Result { - Ok(arg_types[0].clone()) - } - - fn lambdas_parameters( - &self, - args: &[ValueOrLambdaParameter], - ) -> Result>>> { - let ValueOrLambdaParameter::Value(value_field) = &args[0] else { - unimplemented!() - }; - let DataType::List(field) = value_field.data_type() else { - unimplemented!() - }; - - Ok(vec![ - None, - Some(vec![Field::new( - "", - field.data_type().clone(), - field.is_nullable(), - )]), - ]) - } - - fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { - unimplemented!() - } - } - - #[test] - fn test_lambdas_schemas_from_args() { - let schema = list_list_int(); - let expr = array_transform(); - - let args = expr - .as_any() - .downcast_ref::() - .unwrap() - .args(); - - let schemas = - lambdas_schemas_from_args(&array_transform_udf(), args, &schema).unwrap(); - - assert_eq!(schemas, &[Cow::Borrowed(&schema), Cow::Owned(list_int())]); - } - - #[test] - fn test_apply_with_schema() { - let mut steps = vec![]; - - array_transform() - .apply_with_schema(&list_list_int(), |node, schema| { - steps.push((node.to_string(), schema.clone())); - - Ok(TreeNodeRecursion::Continue) - }) - .unwrap(); - - let expected = [ - ( - "array_transform(v@0, (v) -> array_transform(v@0, (v) -> (- v@0)))", - list_list_int(), - ), - ("(v) -> array_transform(v@0, (v) -> (- v@0))", list_int()), - ("array_transform(v@0, (v) -> (- v@0))", list_int()), - ("(v) -> (- v@0)", int()), - ("(- v@0)", int()), - ("v@0", int()), - ("v@0", int()), - ("v@0", int()), - ] - .map(|(a, b)| (String::from(a), b)); - - assert_eq!(steps, expected); - } - - #[test] - fn test_apply_with_lambdas_params() { - let array_transform = array_transform(); - let mut steps = vec![]; - - array_transform - .apply_with_lambdas_params(|node, params| { - steps.push((node.to_string(), params.clone())); - - Ok(TreeNodeRecursion::Continue) - }) - .unwrap(); - - let expected = [ - ( - "array_transform(v@0, (v) -> array_transform(v@0, (v) -> (- 
v@0)))", - HashSet::from(["v"]), - ), - ( - "(v) -> array_transform(v@0, (v) -> (- v@0))", - HashSet::from(["v"]), - ), - ("array_transform(v@0, (v) -> (- v@0))", HashSet::from(["v"])), - ("(v) -> (- v@0)", HashSet::from(["v"])), - ("(- v@0)", HashSet::from(["v"])), - ("v@0", HashSet::from(["v"])), - ("v@0", HashSet::from(["v"])), - ("v@0", HashSet::from(["v"])), - ] - .map(|(a, b)| (String::from(a), b)); - - assert_eq!(steps, expected); - } } diff --git a/datafusion/physical-expr/src/simplifier/mod.rs b/datafusion/physical-expr/src/simplifier/mod.rs index dd7e6e314672f..80d6ee0a7b914 100644 --- a/datafusion/physical-expr/src/simplifier/mod.rs +++ b/datafusion/physical-expr/src/simplifier/mod.rs @@ -19,12 +19,12 @@ use arrow::datatypes::Schema; use datafusion_common::{ - tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter}, + tree_node::{Transformed, TreeNode, TreeNodeRewriter}, Result, }; use std::sync::Arc; -use crate::{PhysicalExpr, PhysicalExprExt}; +use crate::PhysicalExpr; pub mod unwrap_cast; @@ -48,22 +48,6 @@ impl<'a> PhysicalExprSimplifier<'a> { &mut self, expr: Arc, ) -> Result> { - return expr - .transform_up_with_schema(self.schema, |node, schema| { - // Apply unwrap cast optimization - #[cfg(test)] - let original_type = node.data_type(schema).unwrap(); - let unwrapped = unwrap_cast::unwrap_cast_in_comparison(node, schema)?; - #[cfg(test)] - assert_eq!( - unwrapped.data.data_type(schema).unwrap(), - original_type, - "Simplified expression should have the same data type as the original" - ); - Ok(unwrapped) - }) - .data(); - Ok(expr.rewrite(self)?.data) } } diff --git a/datafusion/physical-expr/src/simplifier/unwrap_cast.rs b/datafusion/physical-expr/src/simplifier/unwrap_cast.rs index 1ccfc1cfe84d8..d409ce9cb5bf2 100644 --- a/datafusion/physical-expr/src/simplifier/unwrap_cast.rs +++ b/datafusion/physical-expr/src/simplifier/unwrap_cast.rs @@ -34,22 +34,22 @@ use std::sync::Arc; use arrow::datatypes::{DataType, Schema}; -use 
datafusion_common::{tree_node::Transformed, Result, ScalarValue}; +use datafusion_common::{ + tree_node::{Transformed, TreeNode}, + Result, ScalarValue, +}; use datafusion_expr::Operator; use datafusion_expr_common::casts::try_cast_literal_to_type; +use crate::expressions::{lit, BinaryExpr, CastExpr, Literal, TryCastExpr}; use crate::PhysicalExpr; -use crate::{ - expressions::{lit, BinaryExpr, CastExpr, Literal, TryCastExpr}, - PhysicalExprExt, -}; /// Attempts to unwrap casts in comparison expressions. pub(crate) fn unwrap_cast_in_comparison( expr: Arc, schema: &Schema, ) -> Result>> { - expr.transform_down_with_schema(schema, |e, schema| { + expr.transform_down(|e| { if let Some(binary) = e.as_any().downcast_ref::() { if let Some(unwrapped) = try_unwrap_cast_binary(binary, schema)? { return Ok(Transformed::yes(unwrapped)); diff --git a/datafusion/physical-expr/src/utils/mod.rs b/datafusion/physical-expr/src/utils/mod.rs index 92ecbb7176dc9..745ae855efee2 100644 --- a/datafusion/physical-expr/src/utils/mod.rs +++ b/datafusion/physical-expr/src/utils/mod.rs @@ -22,7 +22,6 @@ use std::borrow::Borrow; use std::sync::Arc; use crate::expressions::{BinaryExpr, Column}; -use crate::scalar_function::PhysicalExprExt; use crate::tree_node::ExprContext; use crate::PhysicalExpr; use crate::PhysicalSortExpr; @@ -228,11 +227,9 @@ where /// Recursively extract referenced [`Column`]s within a [`PhysicalExpr`]. 
pub fn collect_columns(expr: &Arc) -> HashSet { let mut columns = HashSet::::new(); - expr.apply_with_lambdas_params(|expr, lambdas_params| { + expr.apply(|expr| { if let Some(column) = expr.as_any().downcast_ref::() { - if !lambdas_params.contains(column.name()) { - columns.get_or_insert_owned(column); - } + columns.get_or_insert_owned(column); } Ok(TreeNodeRecursion::Continue) }) @@ -254,16 +251,14 @@ pub fn reassign_expr_columns( expr: Arc, schema: &Schema, ) -> Result> { - expr.transform_down_with_lambdas_params(|expr, lambdas_params| { + expr.transform_down(|expr| { if let Some(column) = expr.as_any().downcast_ref::() { - if !lambdas_params.contains(column.name()) { - let index = schema.index_of(column.name())?; + let index = schema.index_of(column.name())?; - return Ok(Transformed::yes(Arc::new(Column::new( - column.name(), - index, - )))); - } + return Ok(Transformed::yes(Arc::new(Column::new( + column.name(), + index, + )))); } Ok(Transformed::no(expr)) }) diff --git a/datafusion/physical-optimizer/src/enforce_sorting/sort_pushdown.rs b/datafusion/physical-optimizer/src/enforce_sorting/sort_pushdown.rs index d87e001946414..6e4e784866129 100644 --- a/datafusion/physical-optimizer/src/enforce_sorting/sort_pushdown.rs +++ b/datafusion/physical-optimizer/src/enforce_sorting/sort_pushdown.rs @@ -29,7 +29,7 @@ use datafusion_expr::JoinType; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::utils::collect_columns; use datafusion_physical_expr::{ - add_offset_to_physical_sort_exprs, EquivalenceProperties, PhysicalExprExt, + add_offset_to_physical_sort_exprs, EquivalenceProperties, }; use datafusion_physical_expr_common::sort_expr::{ LexOrdering, LexRequirement, OrderingRequirements, PhysicalSortExpr, @@ -661,21 +661,20 @@ fn handle_custom_pushdown( .into_iter() .map(|req| { let child_schema = plan.children()[maintained_child_idx].schema(); - let updated_columns = - req.expr - .transform_up_with_lambdas_params(|expr, lambdas_params| { 
- match expr.as_any().downcast_ref::() { - Some(col) if !lambdas_params.contains(col.name()) => { - let new_index = col.index() - sub_offset; - Ok(Transformed::yes(Arc::new(Column::new( - child_schema.field(new_index).name(), - new_index, - )))) - } - _ => Ok(Transformed::no(expr)), - } - })? - .data; + let updated_columns = req + .expr + .transform_up(|expr| { + if let Some(col) = expr.as_any().downcast_ref::() { + let new_index = col.index() - sub_offset; + Ok(Transformed::yes(Arc::new(Column::new( + child_schema.field(new_index).name(), + new_index, + )))) + } else { + Ok(Transformed::no(expr)) + } + })? + .data; Ok(PhysicalSortRequirement::new(updated_columns, req.options)) }) .collect::>>()?; @@ -743,21 +742,20 @@ fn handle_hash_join( .into_iter() .map(|req| { let child_schema = plan.children()[1].schema(); - let updated_columns = - req.expr - .transform_up_with_lambdas_params(|expr, lambdas_params| { - match expr.as_any().downcast_ref::() { - Some(col) if !lambdas_params.contains(col.name()) => { - let index = projected_indices[col.index()].index; - Ok(Transformed::yes(Arc::new(Column::new( - child_schema.field(index).name(), - index, - )))) - } - _ => Ok(Transformed::no(expr)), - } - })? - .data; + let updated_columns = req + .expr + .transform_up(|expr| { + if let Some(col) = expr.as_any().downcast_ref::() { + let index = projected_indices[col.index()].index; + Ok(Transformed::yes(Arc::new(Column::new( + child_schema.field(index).name(), + index, + )))) + } else { + Ok(Transformed::no(expr)) + } + })? 
+ .data; Ok(PhysicalSortRequirement::new(updated_columns, req.options)) }) .collect::>>()?; diff --git a/datafusion/physical-optimizer/src/projection_pushdown.rs b/datafusion/physical-optimizer/src/projection_pushdown.rs index 8ed81d3874d64..987e3cb6f713e 100644 --- a/datafusion/physical-optimizer/src/projection_pushdown.rs +++ b/datafusion/physical-optimizer/src/projection_pushdown.rs @@ -23,7 +23,6 @@ use crate::PhysicalOptimizerRule; use arrow::datatypes::{Fields, Schema, SchemaRef}; use datafusion_common::alias::AliasGenerator; -use datafusion_physical_expr::PhysicalExprExt; use std::collections::HashSet; use std::sync::Arc; @@ -244,11 +243,9 @@ fn minimize_join_filter( rhs_schema: &Schema, ) -> JoinFilter { let mut used_columns = HashSet::new(); - expr.apply_with_lambdas_params(|expr, lambdas_params| { + expr.apply(|expr| { if let Some(col) = expr.as_any().downcast_ref::() { - if !lambdas_params.contains(col.name()) { - used_columns.insert(col.index()); - } + used_columns.insert(col.index()); } Ok(TreeNodeRecursion::Continue) }) @@ -270,19 +267,17 @@ fn minimize_join_filter( .collect::(); let final_expr = expr - .transform_up_with_lambdas_params(|expr, lambdas_params| { - match expr.as_any().downcast_ref::() { - Some(column) if !lambdas_params.contains(column.name()) => { - let new_idx = used_columns - .iter() - .filter(|idx| **idx < column.index()) - .count(); - let new_column = Column::new(column.name(), new_idx); - Ok(Transformed::yes( - Arc::new(new_column) as Arc - )) - } - _ => Ok(Transformed::no(expr)), + .transform_up(|expr| match expr.as_any().downcast_ref::() { + None => Ok(Transformed::no(expr)), + Some(column) => { + let new_idx = used_columns + .iter() + .filter(|idx| **idx < column.index()) + .count(); + let new_column = Column::new(column.name(), new_idx); + Ok(Transformed::yes( + Arc::new(new_column) as Arc + )) } }) .expect("Closure cannot fail"); @@ -385,9 +380,10 @@ impl<'a> JoinFilterRewriter<'a> { // First, add a new projection. 
The expression must be rewritten, as it is no longer // executed against the filter schema. let new_idx = self.join_side_projections.len(); - let rewritten_expr = expr.transform_up_with_lambdas_params(|expr, lambdas_params| { + let rewritten_expr = expr.transform_up(|expr| { Ok(match expr.as_any().downcast_ref::() { - Some(column) if !lambdas_params.contains(column.name()) => { + None => Transformed::no(expr), + Some(column) => { let intermediate_column = &self.intermediate_column_indices[column.index()]; assert_eq!(intermediate_column.side, self.join_side); @@ -397,7 +393,6 @@ impl<'a> JoinFilterRewriter<'a> { let new_column = Column::new(field.name(), join_side_index); Transformed::yes(Arc::new(new_column) as Arc) } - _ => Transformed::no(expr), }) })?; self.join_side_projections.push((rewritten_expr.data, name)); @@ -420,17 +415,15 @@ impl<'a> JoinFilterRewriter<'a> { join_side: JoinSide, ) -> Result { let mut result = false; - expr.apply_with_lambdas_params(|expr, lambdas_params| { - match expr.as_any().downcast_ref::() { - Some(c) if !lambdas_params.contains(c.name()) => { - let column_index = &self.intermediate_column_indices[c.index()]; - if column_index.side == join_side { - result = true; - return Ok(TreeNodeRecursion::Stop); - } - Ok(TreeNodeRecursion::Continue) + expr.apply(|expr| match expr.as_any().downcast_ref::() { + None => Ok(TreeNodeRecursion::Continue), + Some(c) => { + let column_index = &self.intermediate_column_indices[c.index()]; + if column_index.side == join_side { + result = true; + return Ok(TreeNodeRecursion::Stop); } - _ => Ok(TreeNodeRecursion::Continue), + Ok(TreeNodeRecursion::Continue) } })?; diff --git a/datafusion/physical-plan/src/async_func.rs b/datafusion/physical-plan/src/async_func.rs index be72a6af2b509..54a76e0ebb971 100644 --- a/datafusion/physical-plan/src/async_func.rs +++ b/datafusion/physical-plan/src/async_func.rs @@ -22,13 +22,13 @@ use crate::{ }; use arrow::array::RecordBatch; use arrow_schema::{Fields, Schema, 
SchemaRef}; -use datafusion_common::tree_node::{Transformed, TreeNodeRecursion}; +use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; use datafusion_common::{internal_err, Result}; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use datafusion_physical_expr::async_scalar_function::AsyncFuncExpr; use datafusion_physical_expr::equivalence::ProjectionMapping; use datafusion_physical_expr::expressions::Column; -use datafusion_physical_expr::{PhysicalExprExt, ScalarFunctionExpr}; +use datafusion_physical_expr::ScalarFunctionExpr; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use futures::stream::StreamExt; use log::trace; @@ -249,7 +249,7 @@ impl AsyncMapper { schema: &Schema, ) -> Result<()> { // recursively look for references to async functions - physical_expr.apply_with_schema(schema, |expr, schema| { + physical_expr.apply(|expr| { if let Some(scalar_func_expr) = expr.as_any().downcast_ref::() { diff --git a/datafusion/physical-plan/src/joins/stream_join_utils.rs b/datafusion/physical-plan/src/joins/stream_join_utils.rs index b70a8f60508a5..80221a77992ce 100644 --- a/datafusion/physical-plan/src/joins/stream_join_utils.rs +++ b/datafusion/physical-plan/src/joins/stream_join_utils.rs @@ -35,7 +35,7 @@ use arrow::array::{ }; use arrow::compute::concat_batches; use arrow::datatypes::{ArrowNativeType, Schema, SchemaRef}; -use datafusion_common::tree_node::{Transformed, TransformedResult}; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::utils::memory::estimate_memory_size; use datafusion_common::{ arrow_datafusion_err, DataFusionError, HashSet, JoinSide, Result, ScalarValue, @@ -44,7 +44,7 @@ use datafusion_expr::interval_arithmetic::Interval; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::intervals::cp_solver::ExprIntervalGraph; use datafusion_physical_expr::utils::collect_columns; -use 
datafusion_physical_expr::{PhysicalExpr, PhysicalExprExt, PhysicalSortExpr}; +use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; use datafusion_physical_expr_common::sort_expr::LexOrdering; use hashbrown::HashTable; @@ -312,13 +312,13 @@ pub fn convert_sort_expr_with_filter_schema( // Since we are sure that one to one column mapping includes all columns, we convert // the sort expression into a filter expression. let converted_filter_expr = expr - .transform_up_with_lambdas_params(|p, lambdas_params| { - convert_filter_columns(p.as_ref(), &column_map, lambdas_params).map( - |transformed| match transformed { + .transform_up(|p| { + convert_filter_columns(p.as_ref(), &column_map).map(|transformed| { + match transformed { Some(transformed) => Transformed::yes(transformed), None => Transformed::no(p), - }, - ) + } + }) }) .data()?; // Search the converted `PhysicalExpr` in filter expression; if an exact @@ -361,17 +361,14 @@ pub fn build_filter_input_order( fn convert_filter_columns( input: &dyn PhysicalExpr, column_map: &HashMap, - lambdas_params: &HashSet, ) -> Result>> { // Attempt to downcast the input expression to a Column type. - Ok(match input.as_any().downcast_ref::() { - Some(col) if !lambdas_params.contains(col.name()) => { - column_map.get(col).map(|c| Arc::new(c.clone()) as _) - } - _ => { - // If the downcast fails, return the input expression as is. - None - } + Ok(if let Some(col) = input.as_any().downcast_ref::() { + // If the downcast is successful, retrieve the corresponding filter column. + column_map.get(col).map(|c| Arc::new(c.clone()) as _) + } else { + // If the downcast fails, return the input expression as is. 
+ None }) } diff --git a/datafusion/physical-plan/src/projection.rs b/datafusion/physical-plan/src/projection.rs index ab654e4eee1df..ead2196860cde 100644 --- a/datafusion/physical-plan/src/projection.rs +++ b/datafusion/physical-plan/src/projection.rs @@ -42,13 +42,14 @@ use std::task::{Context, Poll}; use arrow::datatypes::SchemaRef; use arrow::record_batch::{RecordBatch, RecordBatchOptions}; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNodeRecursion}; +use datafusion_common::tree_node::{ + Transformed, TransformedResult, TreeNode, TreeNodeRecursion, +}; use datafusion_common::{internal_err, JoinSide, Result}; use datafusion_execution::TaskContext; use datafusion_physical_expr::equivalence::ProjectionMapping; use datafusion_physical_expr::utils::collect_columns; -use datafusion_physical_expr::{PhysicalExprExt, PhysicalExprRef}; -use datafusion_physical_expr_common::physical_expr::fmt_sql; +use datafusion_physical_expr_common::physical_expr::{fmt_sql, PhysicalExprRef}; use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexRequirement}; // Re-exported from datafusion-physical-expr for backwards compatibility // We recommend updating your imports to use datafusion-physical-expr directly @@ -865,12 +866,10 @@ fn try_unifying_projections( projection.expr().iter().for_each(|proj_expr| { proj_expr .expr - .apply_with_lambdas_params(|expr, lambdas_params| { + .apply(|expr| { Ok({ if let Some(column) = expr.as_any().downcast_ref::() { - if !lambdas_params.contains(column.name()) { - *column_ref_map.entry(column.clone()).or_default() += 1; - } + *column_ref_map.entry(column.clone()).or_default() += 1; } TreeNodeRecursion::Continue }) @@ -958,31 +957,31 @@ fn new_columns_for_join_on( .filter_map(|on| { // Rewrite all columns in `on` Arc::clone(*on) - .transform_with_lambdas_params(|expr, lambdas_params| { - match expr.as_any().downcast_ref::() { - Some(column) if 
!lambdas_params.contains(column.name()) => { - let new_column = projection_exprs - .iter() - .enumerate() - .find(|(_, (proj_column, _))| { - column.name() == proj_column.name() - && column.index() + column_index_offset - == proj_column.index() - }) - .map(|(index, (_, alias))| Column::new(alias, index)); - if let Some(new_column) = new_column { - Ok(Transformed::yes(Arc::new(new_column))) - } else { - // If the column is not found in the projection expressions, - // it means that the column is not projected. In this case, - // we cannot push the projection down. - internal_err!( - "Column {:?} not found in projection expressions", - column - ) - } + .transform(|expr| { + if let Some(column) = expr.as_any().downcast_ref::() { + // Find the column in the projection expressions + let new_column = projection_exprs + .iter() + .enumerate() + .find(|(_, (proj_column, _))| { + column.name() == proj_column.name() + && column.index() + column_index_offset + == proj_column.index() + }) + .map(|(index, (_, alias))| Column::new(alias, index)); + if let Some(new_column) = new_column { + Ok(Transformed::yes(Arc::new(new_column))) + } else { + // If the column is not found in the projection expressions, + // it means that the column is not projected. In this case, + // we cannot push the projection down. + internal_err!( + "Column {:?} not found in projection expressions", + column + ) } - _ => Ok(Transformed::no(expr)), + } else { + Ok(Transformed::no(expr)) } }) .data() diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index b87a50b3f5281..c6e33159aa89f 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -622,9 +622,9 @@ pub fn serialize_expr( .unwrap_or(HashMap::new()), })), }, - Expr::Lambda { .. 
} => { + Expr::Lambda(_) | Expr::LambdaColumn(_) => { return Err(Error::General( - "Proto serialization error: Lambda not supported".to_string(), + "Proto serialization error: Lambda not implemented".to_string(), )) } }; diff --git a/datafusion/pruning/src/pruning_predicate.rs b/datafusion/pruning/src/pruning_predicate.rs index c9df93f8b693c..380ada10df6e1 100644 --- a/datafusion/pruning/src/pruning_predicate.rs +++ b/datafusion/pruning/src/pruning_predicate.rs @@ -38,14 +38,13 @@ use datafusion_common::error::Result; use datafusion_common::tree_node::TransformedResult; use datafusion_common::{ internal_datafusion_err, internal_err, plan_datafusion_err, plan_err, - tree_node::Transformed, ScalarValue, + tree_node::{Transformed, TreeNode}, + ScalarValue, }; use datafusion_common::{Column, DFSchema}; use datafusion_expr_common::operator::Operator; use datafusion_physical_expr::utils::{collect_columns, Guarantee, LiteralGuarantee}; -use datafusion_physical_expr::{ - expressions as phys_expr, PhysicalExprExt, PhysicalExprRef, -}; +use datafusion_physical_expr::{expressions as phys_expr, PhysicalExprRef}; use datafusion_physical_expr_common::physical_expr::snapshot_physical_expr; use datafusion_physical_plan::{ColumnarValue, PhysicalExpr}; @@ -1205,9 +1204,9 @@ fn rewrite_column_expr( column_old: &phys_expr::Column, column_new: &phys_expr::Column, ) -> Result> { - e.transform_with_lambdas_params(|expr, lambdas_params| { + e.transform(|expr| { if let Some(column) = expr.as_any().downcast_ref::() { - if !lambdas_params.contains(column.name()) && column == column_old { + if column == column_old { return Ok(Transformed::yes(Arc::new(column_new.clone()))); } } diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index c13fd33104eb0..439e65e8f7e47 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. 
+use std::sync::Arc; + use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use arrow::datatypes::DataType; @@ -26,7 +28,10 @@ use datafusion_expr::expr::{Lambda, ScalarFunction, Unnest}; use datafusion_expr::expr::{NullTreatment, WildcardOptions, WindowFunction}; use datafusion_expr::planner::PlannerResult; use datafusion_expr::planner::{RawAggregateExpr, RawWindowExpr}; -use datafusion_expr::{expr, Expr, ExprSchemable, WindowFrame, WindowFunctionDefinition}; +use datafusion_expr::{ + expr, Expr, ExprSchemable, ValueOrLambdaParameter, WindowFrame, + WindowFunctionDefinition, +}; use sqlparser::ast::{ DuplicateTreatment, Expr as SQLExpr, Function as SQLFunction, FunctionArg, FunctionArgExpr, FunctionArgumentClause, FunctionArgumentList, FunctionArguments, @@ -274,8 +279,94 @@ impl SqlToRel<'_, S> { } // User-defined function (UDF) should have precedence if let Some(fm) = self.context_provider.get_function_meta(&name) { - let (args, arg_names) = - self.function_args_to_expr_with_names(args, schema, planner_context)?; + enum ExprOrLambda { + ExprWithName((Expr, Option)), + Lambda(sqlparser::ast::LambdaFunction), + } + + let pairs = args + .into_iter() + .map(|a| match a { + FunctionArg::Unnamed(FunctionArgExpr::Expr(SQLExpr::Lambda( + lambda, + ))) => Ok(ExprOrLambda::Lambda(lambda)), + _ => Ok(ExprOrLambda::ExprWithName( + self.sql_fn_arg_to_logical_expr_with_name( + a, + schema, + planner_context, + )?, + )), + }) + .collect::>>()?; + + let metadata = pairs + .iter() + .map(|e| match e { + ExprOrLambda::ExprWithName((expr, _name)) => { + Ok(ValueOrLambdaParameter::Value(expr.to_field(schema)?.1)) + } + ExprOrLambda::Lambda(_lambda_function) => { + Ok(ValueOrLambdaParameter::Lambda) + } + }) + .collect::>>()?; + + let lambdas_parameters = fm.inner().lambdas_parameters(&metadata)?; + + let pairs = pairs + .into_iter() + .zip(lambdas_parameters) + .map(|(e, lambda_parameters)| match (e, lambda_parameters) { + (ExprOrLambda::ExprWithName(expr_with_name), 
None) => { + Ok(expr_with_name) + } + (ExprOrLambda::Lambda(lambda), Some(lambda_params)) => { + if lambda.params.len() > lambda_params.len() { + return plan_err!( + "lambda defined {} params but UDF support only {}", + lambda.params.len(), + lambda_params.len() + ); + } + + let params = + lambda.params.iter().map(|p| p.value.clone()).collect(); + + let lambda_parameters = lambda_params + .into_iter() + .zip(¶ms) + .map(|(f, n)| Arc::new(f.with_name(n))); + + let mut planner_context = planner_context + .clone() + .with_lambda_parameters(lambda_parameters); + + Ok(( + Expr::Lambda(Lambda { + params, + body: Box::new(self.sql_expr_to_logical_expr( + *lambda.body, + schema, + &mut planner_context, + )?), + }), + None, + )) + } + (ExprOrLambda::ExprWithName(_), Some(_)) => plan_err!( + "{} reported parameters for an argument that is not a lambda", + fm.name() + ), + (ExprOrLambda::Lambda(_), None) => plan_err!( + "{} don't reported the parameters of one of it's lambdas", + fm.name() + ), + }) + .collect::>>()?; + + let (args, arg_names): (Vec, Vec>) = + pairs.into_iter().unzip(); let resolved_args = if arg_names.iter().any(|name| name.is_some()) { if let Some(param_names) = &fm.signature().parameter_names { @@ -724,26 +815,6 @@ impl SqlToRel<'_, S> { let arg_name = crate::utils::normalize_ident(name); Ok((expr, Some(arg_name))) } - FunctionArg::Unnamed(FunctionArgExpr::Expr(SQLExpr::Lambda( - sqlparser::ast::LambdaFunction { params, body }, - ))) => { - let params = params - .into_iter() - .map(|v| v.to_string()) - .collect::>(); - - Ok(( - Expr::Lambda(Lambda { - params: params.clone(), - body: Box::new(self.sql_expr_to_logical_expr( - *body, - schema, - &mut planner_context.clone().with_lambda_parameters(params), - )?), - }), - None, - )) - } FunctionArg::Unnamed(FunctionArgExpr::Expr(arg)) => { let expr = self.sql_expr_to_logical_expr(arg, schema, planner_context)?; Ok((expr, None)) diff --git a/datafusion/sql/src/expr/identifier.rs 
b/datafusion/sql/src/expr/identifier.rs index dc39cb4de055d..73be980d686d0 100644 --- a/datafusion/sql/src/expr/identifier.rs +++ b/datafusion/sql/src/expr/identifier.rs @@ -20,6 +20,7 @@ use datafusion_common::{ exec_datafusion_err, internal_err, not_impl_err, plan_datafusion_err, plan_err, Column, DFSchema, Result, Span, TableReference, }; +use datafusion_expr::expr::LambdaColumn; use datafusion_expr::planner::PlannerResult; use datafusion_expr::{Case, Expr}; use sqlparser::ast::{CaseWhen, Expr as SQLExpr, Ident}; @@ -53,17 +54,17 @@ impl SqlToRel<'_, S> { // identifier. (e.g. it is "foo.bar" not foo.bar) let normalize_ident = self.ident_normalizer.normalize(id); - if planner_context + if let Some(field) = planner_context .lambdas_parameters() - .contains(&normalize_ident) + .get(&normalize_ident) { - let mut column = Column::new_unqualified(normalize_ident); + let mut lambda_column = LambdaColumn::new(normalize_ident, Arc::clone(field)); if self.options.collect_spans { if let Some(span) = Span::try_from_sqlparser_span(id_span) { - column.spans_mut().add_span(span); + lambda_column.spans_mut().add_span(span); } } - return Ok(Expr::Column(column)); + return Ok(Expr::LambdaColumn(lambda_column)); } // Check for qualified field with unqualified name diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index 2992378fd1d6c..8cc7747ffe16b 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -26,13 +26,14 @@ use arrow::datatypes::*; use datafusion_common::config::SqlParserOptions; use datafusion_common::datatype::{DataTypeExt, FieldExt}; use datafusion_common::error::add_possible_columns_to_diag; -use datafusion_common::TableReference; use datafusion_common::{ - field_not_found, plan_datafusion_err, DFSchemaRef, Diagnostic, HashSet, SchemaError, + field_not_found, plan_datafusion_err, DFSchemaRef, Diagnostic, SchemaError, }; +use datafusion_common::{internal_err, TableReference}; use datafusion_common::{not_impl_err, 
plan_err, DFSchema, DataFusionError, Result}; use datafusion_expr::logical_plan::{LogicalPlan, LogicalPlanBuilder}; pub use datafusion_expr::planner::ContextProvider; +use datafusion_expr::utils::find_column_exprs; use datafusion_expr::{col, Expr}; use sqlparser::ast::{ArrayElemTypeDef, ExactNumberInfo, TimezoneInfo}; use sqlparser::ast::{ColumnDef as SQLColumnDef, ColumnOption}; @@ -265,8 +266,8 @@ pub struct PlannerContext { outer_from_schema: Option, /// The query schema defined by the table create_table_schema: Option, - /// The lambda introduced columns names - lambdas_parameters: HashSet, + /// The parameters of all lambdas seen so far + lambdas_parameters: HashMap, } impl Default for PlannerContext { @@ -284,7 +285,7 @@ impl PlannerContext { outer_query_schema: None, outer_from_schema: None, create_table_schema: None, - lambdas_parameters: HashSet::new(), + lambdas_parameters: HashMap::new(), } } @@ -371,15 +372,16 @@ impl PlannerContext { self.ctes.get(cte_name).map(|cte| cte.as_ref()) } - pub fn lambdas_parameters(&self) -> &HashSet { + pub fn lambdas_parameters(&self) -> &HashMap { &self.lambdas_parameters } pub fn with_lambda_parameters( mut self, - arguments: impl IntoIterator, + arguments: impl IntoIterator, ) -> Self { - self.lambdas_parameters.extend(arguments); + self.lambdas_parameters + .extend(arguments.into_iter().map(|f| (f.name().clone(), f))); self } @@ -545,11 +547,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { schema: &DFSchema, exprs: &[Expr], ) -> Result<()> { - exprs + find_column_exprs(exprs) .iter() - .flat_map(|expr| expr.column_refs()) - .try_for_each(|col| { - match &col.relation { + .try_for_each(|col| match col { + Expr::Column(col) => match &col.relation { Some(r) => schema.field_with_qualified_name(r, &col.name).map(|_| ()), None => { if !schema.fields_with_unqualified_name(&col.name).is_empty() { @@ -599,7 +600,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { err.with_diagnostic(diagnostic) } _ => err, - }) + }), + _ 
=> internal_err!("Not a column"), }) } diff --git a/datafusion/sql/src/select.rs b/datafusion/sql/src/select.rs index 0e7490d2c780b..42013a76a8657 100644 --- a/datafusion/sql/src/select.rs +++ b/datafusion/sql/src/select.rs @@ -540,11 +540,9 @@ impl SqlToRel<'_, S> { None => { let mut columns = HashSet::new(); for expr in &aggr_expr { - expr.apply_with_lambdas_params(|expr, lambdas_params| { + expr.apply(|expr| { if let Expr::Column(c) = expr { - if !c.is_lambda_parameter(lambdas_params) { - columns.insert(Expr::Column(c.clone())); - } + columns.insert(Expr::Column(c.clone())); } Ok(TreeNodeRecursion::Continue) }) diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 67ca92bb1c1f1..71a1a342a9c5e 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -536,6 +536,9 @@ impl Unparser<'_> { body: Box::new(self.expr_to_sql_inner(body)?), })) } + Expr::LambdaColumn(l) => Ok(ast::Expr::Identifier( + self.new_ident_quoted_if_needs(l.name.clone()), + )), } } diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index c218ce547b312..e7535338b7677 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -40,7 +40,7 @@ use crate::unparser::{ast::UnnestRelationBuilder, rewrite::rewrite_qualify}; use crate::utils::UNNEST_PLACEHOLDER; use datafusion_common::{ internal_err, not_impl_err, - tree_node::TransformedResult, + tree_node::{TransformedResult, TreeNode}, Column, DataFusionError, Result, ScalarValue, TableReference, }; use datafusion_expr::expr::OUTER_REFERENCE_COLUMN_PREFIX; @@ -1131,7 +1131,7 @@ impl Unparser<'_> { .cloned() .map(|expr| { if let Some(ref mut rewriter) = filter_alias_rewriter { - expr.rewrite_with_lambdas_params(rewriter).data() + expr.rewrite(rewriter).data() } else { Ok(expr) } @@ -1197,7 +1197,7 @@ impl Unparser<'_> { .cloned() .map(|expr| { if let Some(ref mut rewriter) = alias_rewriter { - 
expr.rewrite_with_lambdas_params(rewriter).data() + expr.rewrite(rewriter).data() } else { Ok(expr) } diff --git a/datafusion/sql/src/unparser/rewrite.rs b/datafusion/sql/src/unparser/rewrite.rs index 58f4435095517..c961f1d6f1f0c 100644 --- a/datafusion/sql/src/unparser/rewrite.rs +++ b/datafusion/sql/src/unparser/rewrite.rs @@ -20,11 +20,10 @@ use std::{collections::HashSet, sync::Arc}; use arrow::datatypes::Schema; use datafusion_common::tree_node::TreeNodeContainer; use datafusion_common::{ - tree_node::{Transformed, TransformedResult, TreeNode}, + tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter}, Column, HashMap, Result, TableReference, }; use datafusion_expr::expr::{Alias, UNNEST_COLUMN_PREFIX}; -use datafusion_expr::tree_node::TreeNodeRewriterWithPayload; use datafusion_expr::{Expr, LogicalPlan, Projection, Sort, SortExpr}; use sqlparser::ast::Ident; @@ -467,17 +466,12 @@ pub struct TableAliasRewriter<'a> { pub alias_name: TableReference, } -impl TreeNodeRewriterWithPayload for TableAliasRewriter<'_> { +impl TreeNodeRewriter for TableAliasRewriter<'_> { type Node = Expr; - type Payload<'a> = &'a datafusion_common::HashSet; - fn f_down( - &mut self, - expr: Expr, - lambdas_params: &datafusion_common::HashSet, - ) -> Result> { + fn f_down(&mut self, expr: Expr) -> Result> { match expr { - Expr::Column(column) if !column.is_lambda_parameter(lambdas_params) => { + Expr::Column(column) => { if let Ok(field) = self.table_schema.field_with_name(&column.name) { let new_column = Column::new(Some(self.alias_name.clone()), field.name().clone()); diff --git a/datafusion/sql/src/unparser/utils.rs b/datafusion/sql/src/unparser/utils.rs index f785f640dbcee..8b3791017a8af 100644 --- a/datafusion/sql/src/unparser/utils.rs +++ b/datafusion/sql/src/unparser/utils.rs @@ -161,11 +161,11 @@ pub(crate) fn find_window_nodes_within_select<'a>( /// For example, if expr contains the column expr 
"__unnest_placeholder(make_array(Int64(1),Int64(2),Int64(2),Int64(5),NULL),depth=1)" /// it will be transformed into an actual unnest expression UNNEST([1, 2, 2, 5, NULL]) pub(crate) fn unproject_unnest_expr(expr: Expr, unnest: &Unnest) -> Result { - expr.transform_up_with_lambdas_params(|sub_expr, lambdas_params| { + expr.transform(|sub_expr| { if let Expr::Column(col_ref) = &sub_expr { // Check if the column is among the columns to run unnest on. // Currently, only List/Array columns (defined in `list_type_columns`) are supported for unnesting. - if !col_ref.is_lambda_parameter(lambdas_params) && unnest.list_type_columns.iter().any(|e| e.1.output_column.name == col_ref.name) { + if unnest.list_type_columns.iter().any(|e| e.1.output_column.name == col_ref.name) { if let Ok(idx) = unnest.schema.index_of_column(col_ref) { if let LogicalPlan::Projection(Projection { expr, .. }) = unnest.input.as_ref() { if let Some(unprojected_expr) = expr.get(idx) { @@ -195,21 +195,22 @@ pub(crate) fn unproject_agg_exprs( agg: &Aggregate, windows: Option<&[&Window]>, ) -> Result { - expr.transform_up_with_lambdas_params(|sub_expr, lambdas_params| { - match sub_expr { - Expr::Column(c) if !c.is_lambda_parameter(lambdas_params) => if let Some(unprojected_expr) = find_agg_expr(agg, &c)? { - Ok(Transformed::yes(unprojected_expr.clone())) - } else if let Some(unprojected_expr) = - windows.and_then(|w| find_window_expr(w, &c.name).cloned()) - { - // Window function can contain an aggregation columns, e.g., 'avg(sum(ss_sales_price)) over ...' that needs to be unprojected - Ok(Transformed::yes(unproject_agg_exprs(unprojected_expr, agg, None)?)) - } else { - internal_err!( - "Tried to unproject agg expr for column '{}' that was not found in the provided Aggregate!", &c.name - ) - }, - _ => Ok(Transformed::no(sub_expr)), + expr.transform(|sub_expr| { + if let Expr::Column(c) = sub_expr { + if let Some(unprojected_expr) = find_agg_expr(agg, &c)? 
{ + Ok(Transformed::yes(unprojected_expr.clone())) + } else if let Some(unprojected_expr) = + windows.and_then(|w| find_window_expr(w, &c.name).cloned()) + { + // Window function can contain an aggregation columns, e.g., 'avg(sum(ss_sales_price)) over ...' that needs to be unprojected + Ok(Transformed::yes(unproject_agg_exprs(unprojected_expr, agg, None)?)) + } else { + internal_err!( + "Tried to unproject agg expr for column '{}' that was not found in the provided Aggregate!", &c.name + ) + } + } else { + Ok(Transformed::no(sub_expr)) } }) .map(|e| e.data) @@ -221,15 +222,16 @@ pub(crate) fn unproject_agg_exprs( /// For example, if expr contains the column expr "COUNT(*) PARTITION BY id" it will be transformed /// into an actual window expression as identified in the window node. pub(crate) fn unproject_window_exprs(expr: Expr, windows: &[&Window]) -> Result { - expr.transform_up_with_lambdas_params(|sub_expr, lambdas_params| match sub_expr { - Expr::Column(c) if !c.is_lambda_parameter(lambdas_params) => { + expr.transform(|sub_expr| { + if let Expr::Column(c) = sub_expr { if let Some(unproj) = find_window_expr(windows, &c.name) { Ok(Transformed::yes(unproj.clone())) } else { Ok(Transformed::no(Expr::Column(c))) } + } else { + Ok(Transformed::no(sub_expr)) } - _ => Ok(Transformed::no(sub_expr)), }) .map(|e| e.data) } @@ -374,7 +376,7 @@ pub(crate) fn try_transform_to_simple_table_scan_with_filters( .cloned() .map(|expr| { if let Some(ref mut rewriter) = filter_alias_rewriter { - expr.rewrite_with_lambdas_params(rewriter).data() + expr.rewrite(rewriter).data() } else { Ok(expr) } diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs index 6380412e3b5ee..3c86d2d04905f 100644 --- a/datafusion/sql/src/utils.rs +++ b/datafusion/sql/src/utils.rs @@ -23,16 +23,16 @@ use arrow::datatypes::{ DataType, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, DECIMAL_DEFAULT_SCALE, }; use datafusion_common::tree_node::{ - Transformed, TransformedResult, TreeNode, 
TreeNodeRecursion, + Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter, }; use datafusion_common::{ - exec_datafusion_err, exec_err, internal_err, plan_err, Column, DFSchema, Diagnostic, HashMap, Result, ScalarValue + exec_datafusion_err, exec_err, internal_err, plan_err, Column, DFSchemaRef, + Diagnostic, HashMap, Result, ScalarValue, }; use datafusion_expr::builder::get_struct_unnested_columns; use datafusion_expr::expr::{ Alias, GroupingSet, Unnest, WindowFunction, WindowFunctionParams, }; -use datafusion_expr::tree_node::TreeNodeRewriterWithPayload; use datafusion_expr::utils::{expr_as_column_expr, find_column_exprs}; use datafusion_expr::{ col, expr_vec_fmt, ColumnUnnestList, Expr, ExprSchemable, LogicalPlan, @@ -44,9 +44,9 @@ use sqlparser::ast::{Ident, Value}; /// Make a best-effort attempt at resolving all columns in the expression tree pub(crate) fn resolve_columns(expr: &Expr, plan: &LogicalPlan) -> Result { expr.clone() - .transform_up_with_lambdas_params(|nested_expr, lambdas_params| { + .transform_up(|nested_expr| { match nested_expr { - Expr::Column(col) if !col.is_lambda_parameter(lambdas_params) => { + Expr::Column(col) => { let (qualifier, field) = plan.schema().qualified_field_from_column(&col)?; Ok(Transformed::yes(Expr::Column(Column::from(( @@ -81,7 +81,6 @@ pub(crate) fn rebase_expr( base_exprs: &[Expr], plan: &LogicalPlan, ) -> Result { - //todo user transform_down_with_lambdas_params expr.clone() .transform_down(|nested_expr| { if base_exprs.contains(&nested_expr) { @@ -232,8 +231,8 @@ pub(crate) fn resolve_aliases_to_exprs( expr: Expr, aliases: &HashMap, ) -> Result { - expr.transform_up_with_lambdas_params(|nested_expr, lambdas_params| match nested_expr { - Expr::Column(c) if c.relation.is_none() && !c.is_lambda_parameter(lambdas_params) => { + expr.transform_up(|nested_expr| match nested_expr { + Expr::Column(c) if c.relation.is_none() => { if let Some(aliased_expr) = aliases.get(&c.name) { 
Ok(Transformed::yes(aliased_expr.clone())) } else { @@ -372,6 +371,7 @@ This is only usedful when used with transform down up A full example of how the transformation works: */ struct RecursiveUnnestRewriter<'a> { + input_schema: &'a DFSchemaRef, root_expr: &'a Expr, // Useful to detect which child expr is a part of/ not a part of unnest operation top_most_unnest: Option, @@ -405,7 +405,6 @@ impl RecursiveUnnestRewriter<'_> { alias_name: String, expr_in_unnest: &Expr, struct_allowed: bool, - input_schema: &DFSchema, ) -> Result> { let inner_expr_name = expr_in_unnest.schema_name().to_string(); @@ -419,7 +418,7 @@ impl RecursiveUnnestRewriter<'_> { // column name as is, to comply with group by and order by let placeholder_column = Column::from_name(placeholder_name.clone()); - let (data_type, _) = expr_in_unnest.data_type_and_nullable(input_schema)?; + let (data_type, _) = expr_in_unnest.data_type_and_nullable(self.input_schema)?; match data_type { DataType::Struct(inner_fields) => { @@ -469,18 +468,17 @@ impl RecursiveUnnestRewriter<'_> { } } -impl TreeNodeRewriterWithPayload for RecursiveUnnestRewriter<'_> { +impl TreeNodeRewriter for RecursiveUnnestRewriter<'_> { type Node = Expr; - type Payload<'a> = &'a DFSchema; /// This downward traversal needs to keep track of: /// - Whether or not some unnest expr has been visited from the top util the current node /// - If some unnest expr has been visited, maintain a stack of such information, this /// is used to detect if some recursive unnest expr exists (e.g **unnest(unnest(unnest(3d column))))** - fn f_down(&mut self, expr: Expr, input_schema: &DFSchema) -> Result> { + fn f_down(&mut self, expr: Expr) -> Result> { if let Expr::Unnest(ref unnest_expr) = expr { let (data_type, _) = - unnest_expr.expr.data_type_and_nullable(input_schema)?; + unnest_expr.expr.data_type_and_nullable(self.input_schema)?; self.consecutive_unnest.push(Some(unnest_expr.clone())); // if expr inside unnest is a struct, do not consider // the 
next unnest as consecutive unnest (if any) @@ -534,7 +532,7 @@ impl TreeNodeRewriterWithPayload for RecursiveUnnestRewriter<'_> { /// column2 /// ``` /// - fn f_up(&mut self, expr: Expr, input_schema: &DFSchema) -> Result> { + fn f_up(&mut self, expr: Expr) -> Result> { if let Expr::Unnest(ref traversing_unnest) = expr { if traversing_unnest == self.top_most_unnest.as_ref().unwrap() { self.top_most_unnest = None; @@ -570,7 +568,6 @@ impl TreeNodeRewriterWithPayload for RecursiveUnnestRewriter<'_> { expr.schema_name().to_string(), inner_expr, struct_allowed, - input_schema, )?; if struct_allowed { self.transformed_root_exprs = Some(transformed_exprs.clone()); @@ -622,6 +619,7 @@ pub(crate) fn rewrite_recursive_unnest_bottom_up( original_expr: &Expr, ) -> Result> { let mut rewriter = RecursiveUnnestRewriter { + input_schema: input.schema(), root_expr: original_expr, top_most_unnest: None, consecutive_unnest: vec![], @@ -643,7 +641,7 @@ pub(crate) fn rewrite_recursive_unnest_bottom_up( data: transformed_expr, transformed, tnr: _, - } = original_expr.clone().rewrite_with_schema(input.schema(), &mut rewriter)?; + } = original_expr.clone().rewrite(&mut rewriter)?; if !transformed { // TODO: remove the next line after `Expr::Wildcard` is removed diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 29ea8cb786072..00629c392df48 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -5866,10 +5866,10 @@ select array_ndims(arrow_cast([null], 'List(List(List(Int64)))')); 3 # array_ndims scalar function #2 -#query II -#select array_ndims(array_repeat(array_repeat(array_repeat(1, 3), 2), 1)), array_ndims([[[[[[[[[[[[[[[[[[[[[1]]]]]]]]]]]]]]]]]]]]]); -#---- -#3 21 +query II +select array_ndims(array_repeat(array_repeat(array_repeat(1, 3), 2), 1)), array_ndims([[[[[[[[[[[[[[[[[[[[[1]]]]]]]]]]]]]]]]]]]]]); +---- +3 21 # array_ndims scalar function #3 query II diff 
--git a/datafusion/sqllogictest/test_files/lambda.slt b/datafusion/sqllogictest/test_files/lambda.slt index 0043eae17a60c..af5334a644421 100644 --- a/datafusion/sqllogictest/test_files/lambda.slt +++ b/datafusion/sqllogictest/test_files/lambda.slt @@ -42,9 +42,9 @@ SELECT array_transform([1], e1 -> (select n from t)); [1] query ? -SELECT array_transform(v, (v1, i) -> array_transform(v1, (v2, j) -> array_transform(v2, v3 -> j)) ) from t; +SELECT array_transform(t.v, (v1, i) -> array_transform(v1, (v2, j) -> array_transform(v2, v3 -> j)) ) from t; ---- -[[[0, 0], [1]], [[0]], [[]]] +[[[0, 1], [0]], [[0]], [[]]] query I? SELECT t.n, array_transform([1, 2], (e) -> n) from t; @@ -75,7 +75,7 @@ logical_plan 01)Projection: array_transform(List([1, 2]), (e, i) -> e + CAST(i AS Int64)) AS array_transform(make_array(Int64(1),Int64(2)),(e, i) -> e + i) 02)--EmptyRelation: rows=1 physical_plan -01)ProjectionExec: expr=[array_transform([1, 2], (e, i) -> e@0 + CAST(i@1 AS Int64)) as array_transform(make_array(Int64(1),Int64(2)),(e, i) -> e + i)] +01)ProjectionExec: expr=[array_transform([1, 2], (e, i) -> e@-1 + CAST(i@-1 AS Int64)) as array_transform(make_array(Int64(1),Int64(2)),(e, i) -> e + i)] 02)--PlaceholderRowExec #cse @@ -86,7 +86,7 @@ logical_plan 01)Projection: t.n + Int64(1), array_transform(List([1]), (v) -> v + t.n + Int64(1)) AS array_transform(make_array(Int64(1)),(v) -> v + t.n + Int64(1)) 02)--TableScan: t projection=[n] physical_plan -01)ProjectionExec: expr=[n@0 + 1 as t.n + Int64(1), array_transform([1], (v) -> v@1 + n@0 + 1) as array_transform(make_array(Int64(1)),(v) -> v + t.n + Int64(1))] +01)ProjectionExec: expr=[n@0 + 1 as t.n + Int64(1), array_transform([1], (v) -> v@-1 + n@0 + 1) as array_transform(make_array(Int64(1)),(v) -> v + t.n + Int64(1))] 02)--DataSourceExec: partitions=1, partition_sizes=[1] query ? 
@@ -121,9 +121,19 @@ logical_plan 01)Projection: array_transform(List([1, 2, 3, 4, 5]), (v) -> v * Int64(2)) AS array_transform(make_array(Int64(1),Int64(2),Int64(3),Int64(4),Int64(5)),(v) -> v * Int64(2)) 02)--EmptyRelation: rows=1 physical_plan -01)ProjectionExec: expr=[array_transform([1, 2, 3, 4, 5], (v) -> v@0 * 2) as array_transform(make_array(Int64(1),Int64(2),Int64(3),Int64(4),Int64(5)),(v) -> v * Int64(2))] +01)ProjectionExec: expr=[array_transform([1, 2, 3, 4, 5], (v) -> v@-1 * 2) as array_transform(make_array(Int64(1),Int64(2),Int64(3),Int64(4),Int64(5)),(v) -> v * Int64(2))] 02)--PlaceholderRowExec +query ? +SELECT array_transform( + [[1]], + v -> array_concat( + array_transform(v, v -> v), + array_transform(v, v1 -> v1 + v[0]) + ) +); +---- +[[1, NULL]] query I?? SELECT t.n, t.v, array_transform(t.v, (v, i) -> array_transform(v, (v, j) -> n) ) from t; @@ -144,23 +154,27 @@ logical_plan 01)Projection: Boolean(true) AS t.v = t.v, array_transform(List([1]), (v) -> v IS NOT NULL OR Boolean(NULL)) AS array_transform(make_array(Int64(1)),(v) -> v = v) 02)--TableScan: t projection=[] physical_plan -01)ProjectionExec: expr=[true as t.v = t.v, array_transform([1], (v) -> v@0 IS NOT NULL OR NULL) as array_transform(make_array(Int64(1)),(v) -> v = v)] +01)ProjectionExec: expr=[true as t.v = t.v, array_transform([1], (v) -> v@-1 IS NOT NULL OR NULL) as array_transform(make_array(Int64(1)),(v) -> v = v)] 02)--DataSourceExec: partitions=1, partition_sizes=[1] query error select array_transform(); ---- -DataFusion error: Error during planning: 'array_transform' does not support zero arguments No function matches the given name and argument types 'array_transform()'. You might need to add explicit type casts. 
- Candidate functions: - array_transform(Any, Any) +DataFusion error: Execution error: array_transform expects a value follewed by a lambda, got [] -query error DataFusion error: Execution error: expected list, got Field \{ name: "Int64\(1\)", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: \{\} \} +query error DataFusion error: Execution error: expected list, got Field \{ "Int64\(1\)": Int64 \} select array_transform(1, v -> v*2); -query error DataFusion error: Execution error: array_transform expects a value follewed by a lambda, got \[Lambda\(\["v"\], false\), Value\(Field \{ name: "make_array\(Int64\(1\),Int64\(2\)\)", data_type: List\(Field \{ name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\)\] +query error DataFusion error: Execution error: array_transform expects a value follewed by a lambda, got \[Lambda, Value\(Field \{ name: "make_array\(Int64\(1\),Int64\(2\)\)", data_type: List\(Field \{ data_type: Int64, nullable: true \}\), nullable: true \}\)\] select array_transform(v -> v*2, [1, 2]); -query error DataFusion error: Execution error: lambdas_schemas: array_transform argument 1 \(0\-indexed\), a lambda, supports up to 2 arguments, but got 3 +query error DataFusion error: Error during planning: lambda defined 3 params but UDF support only 2 SELECT array_transform([1, 2], (e, i, j) -> i) from t; + +#todo: this should error due to duplicate names +query ? 
+SELECT array_transform([1], (v, v) -> v*2); +---- +[0] diff --git a/datafusion/substrait/src/logical_plan/producer/expr/mod.rs b/datafusion/substrait/src/logical_plan/producer/expr/mod.rs index 103d593cafbc0..b16fd8032877f 100644 --- a/datafusion/substrait/src/logical_plan/producer/expr/mod.rs +++ b/datafusion/substrait/src/logical_plan/producer/expr/mod.rs @@ -153,6 +153,7 @@ pub fn to_substrait_rex( } Expr::Unnest(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), Expr::Lambda(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), // https://github.com/substrait-io/substrait/issues/349 + Expr::LambdaColumn(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), // https://github.com/substrait-io/substrait/issues/349 } } From d844b2d81d9ab5c4bc95671e9134ea6475af79bb Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Fri, 19 Dec 2025 05:07:33 -0300 Subject: [PATCH 04/12] rename LambdaColumn to LambdaVariable --- datafusion/catalog-listing/src/helpers.rs | 2 +- datafusion/expr/src/expr.rs | 16 ++++++------- datafusion/expr/src/expr_schema.rs | 8 +++---- datafusion/expr/src/tree_node.rs | 4 ++-- datafusion/expr/src/utils.rs | 2 +- .../functions-nested/src/array_transform.rs | 20 ++++++++-------- .../optimizer/src/analyzer/type_coercion.rs | 2 +- .../optimizer/src/common_subexpr_eliminate.rs | 4 ++-- datafusion/optimizer/src/push_down_filter.rs | 2 +- .../simplify_expressions/expr_simplifier.rs | 20 ++++++++-------- .../{lambda_column.rs => lambda_variable.rs} | 24 +++++++++---------- .../physical-expr/src/expressions/mod.rs | 4 ++-- datafusion/physical-expr/src/planner.rs | 8 +++---- datafusion/proto/src/logical_plan/to_proto.rs | 2 +- datafusion/sql/src/expr/identifier.rs | 8 +++---- datafusion/sql/src/unparser/expr.rs | 2 +- .../src/logical_plan/producer/expr/mod.rs | 2 +- 17 files changed, 65 insertions(+), 65 deletions(-) rename datafusion/physical-expr/src/expressions/{lambda_column.rs => 
lambda_variable.rs} (86%) diff --git a/datafusion/catalog-listing/src/helpers.rs b/datafusion/catalog-listing/src/helpers.rs index 78b46171006a7..eca681e3c604c 100644 --- a/datafusion/catalog-listing/src/helpers.rs +++ b/datafusion/catalog-listing/src/helpers.rs @@ -88,7 +88,7 @@ pub fn expr_applicable_for_cols(col_names: &[&str], expr: &Expr) -> bool { | Expr::GroupingSet(_) | Expr::Case(_) | Expr::Lambda(_) - | Expr::LambdaColumn(_) => Ok(TreeNodeRecursion::Continue), + | Expr::LambdaVariable(_) => Ok(TreeNodeRecursion::Continue), Expr::ScalarFunction(scalar_function) => { match scalar_function.func.signature().volatility { diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 6387fc4a44f38..07f1bc129c597 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -400,17 +400,17 @@ pub enum Expr { Unnest(Unnest), /// Lambda expression Lambda(Lambda), - LambdaColumn(LambdaColumn), + LambdaVariable(LambdaVariable), } #[derive(Clone, PartialEq, PartialOrd, Eq, Debug, Hash)] -pub struct LambdaColumn { +pub struct LambdaVariable { pub name: String, pub field: FieldRef, pub spans: Spans, } -impl LambdaColumn { +impl LambdaVariable { pub fn new(name: String, field: FieldRef) -> Self { Self { name, @@ -1567,7 +1567,7 @@ impl Expr { Expr::Wildcard { .. } => "Wildcard", Expr::Unnest { .. } => "Unnest", Expr::Lambda { .. } => "Lambda", - Expr::LambdaColumn { .. } => "LambdaColumn", + Expr::LambdaVariable { .. } => "LambdaVariable", } } @@ -2123,7 +2123,7 @@ impl Expr { | Expr::Literal(..) | Expr::Placeholder(..) | Expr::Lambda(..) - | Expr::LambdaColumn(..) => false, + | Expr::LambdaVariable(..) 
=> false, } } @@ -2722,7 +2722,7 @@ impl HashNode for Expr { Expr::Lambda(Lambda { params, body: _ }) => { params.hash(state); } - Expr::LambdaColumn(LambdaColumn { + Expr::LambdaVariable(LambdaVariable { name, field, spans: _, @@ -3046,7 +3046,7 @@ impl Display for SchemaDisplay<'_> { Expr::Lambda(Lambda { params, body }) => { write!(f, "({}) -> {body}", display_comma_separated(params)) } - Expr::LambdaColumn(c) => { + Expr::LambdaVariable(c) => { write!(f, "{}", c.name) } } @@ -3542,7 +3542,7 @@ impl Display for Expr { Expr::Lambda(Lambda { params, body }) => { write!(f, "({}) -> {body}", params.join(", ")) } - Expr::LambdaColumn(c) => { + Expr::LambdaVariable(c) => { write!(f, "{}", c.name) } } diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 1b7f5f0212c6b..3e3ff7dacb9d8 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -21,7 +21,7 @@ use crate::expr::{ InSubquery, Lambda, Placeholder, ScalarFunction, TryCast, Unnest, WindowFunction, WindowFunctionParams, }; -use crate::expr::{FieldMetadata, LambdaColumn}; +use crate::expr::{FieldMetadata, LambdaVariable}; use crate::type_coercion::functions::{ fields_with_aggregate_udf, fields_with_window_udf, }; @@ -235,7 +235,7 @@ impl ExprSchemable for Expr { Ok(DataType::Null) } Expr::Lambda(Lambda { params: _, body }) => body.get_type(schema), - Expr::LambdaColumn(LambdaColumn { name: _, field, .. }) => { + Expr::LambdaVariable(LambdaVariable { name: _, field, .. 
}) => { Ok(field.data_type().clone()) } } @@ -357,7 +357,7 @@ impl ExprSchemable for Expr { Ok(true) } Expr::Lambda(l) => l.body.nullable(input_schema), - Expr::LambdaColumn(c) => Ok(c.field.is_nullable()), + Expr::LambdaVariable(c) => Ok(c.field.is_nullable()), } } @@ -625,7 +625,7 @@ impl ExprSchemable for Expr { self.get_type(schema)?, self.nullable(schema)?, ))), - Expr::LambdaColumn(c) => Ok(Arc::clone(&c.field)), + Expr::LambdaVariable(c) => Ok(Arc::clone(&c.field)), }?; Ok(( diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs index df98f720a0f08..a818c32948d09 100644 --- a/datafusion/expr/src/tree_node.rs +++ b/datafusion/expr/src/tree_node.rs @@ -81,7 +81,7 @@ impl TreeNode for Expr { | Expr::ScalarSubquery(_) | Expr::Wildcard { .. } | Expr::Placeholder(_) - | Expr::LambdaColumn(_) => Ok(TreeNodeRecursion::Continue), + | Expr::LambdaVariable(_) => Ok(TreeNodeRecursion::Continue), Expr::BinaryExpr(BinaryExpr { left, right, .. }) => { (left, right).apply_ref_elements(f) } @@ -133,7 +133,7 @@ impl TreeNode for Expr { | Expr::ScalarSubquery(_) | Expr::ScalarVariable(_, _) | Expr::Literal(_, _) - | Expr::LambdaColumn(_) => Transformed::no(self), + | Expr::LambdaVariable(_) => Transformed::no(self), Expr::Unnest(Unnest { expr, .. }) => expr .map_elements(f)? .update_data(|expr| Expr::Unnest(Unnest { expr })), diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 1e5e4c9f5b99a..ab58b1c3f835f 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -309,7 +309,7 @@ pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet) -> Result<()> { | Expr::Placeholder(_) | Expr::OuterReferenceColumn { .. 
} | Expr::Lambda(_) - | Expr::LambdaColumn(_) => {} + | Expr::LambdaVariable(_) => {} } Ok(TreeNodeRecursion::Continue) }) diff --git a/datafusion/functions-nested/src/array_transform.rs b/datafusion/functions-nested/src/array_transform.rs index dbc08473c246b..123df27b339be 100644 --- a/datafusion/functions-nested/src/array_transform.rs +++ b/datafusion/functions-nested/src/array_transform.rs @@ -33,7 +33,7 @@ use datafusion_expr::{ ValueOrLambda, ValueOrLambdaField, ValueOrLambdaParameter, Volatility, }; use datafusion_macros::user_doc; -use datafusion_physical_expr::expressions::{LambdaColumn, LambdaExpr}; +use datafusion_physical_expr::expressions::{LambdaVariable, LambdaExpr}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use std::{any::Any, sync::Arc}; @@ -173,7 +173,7 @@ impl ScalarUDFImpl for ArrayTransform { let values_param = || Ok(Arc::clone(list_values)); let indices_param = || elements_indices(&list_array); - let binded_body = bind_lambda_columns( + let binded_body = bind_lambda_variables( Arc::clone(&lambda.body), &lambda.params, &[&values_param, &indices_param], @@ -275,7 +275,7 @@ impl ScalarUDFImpl for ArrayTransform { } } -fn bind_lambda_columns( +fn bind_lambda_variables( expr: Arc, params: &[FieldRef], args: &[&dyn Fn() -> Result], @@ -284,28 +284,28 @@ fn bind_lambda_columns( .map(|(param, arg)| Ok((param.name().as_str(), (arg()?, 0)))) .collect::>>()?; - expr.rewrite(&mut BindLambdaColumn::new(columns)).data() + expr.rewrite(&mut BindLambdaVariable::new(columns)).data() } -struct BindLambdaColumn<'a> { +struct BindLambdaVariable<'a> { columns: HashMap<&'a str, (ArrayRef, usize)>, } -impl<'a> BindLambdaColumn<'a> { +impl<'a> BindLambdaVariable<'a> { fn new(columns: HashMap<&'a str, (ArrayRef, usize)>) -> Self { Self { columns } } } -impl TreeNodeRewriter for BindLambdaColumn<'_> { +impl TreeNodeRewriter for BindLambdaVariable<'_> { type Node = Arc; fn f_down(&mut self, node: Self::Node) -> Result> { - if let 
Some(lambda_column) = node.as_any().downcast_ref::() { - if let Some((value, shadows)) = self.columns.get(lambda_column.name()) { + if let Some(lambda_variable) = node.as_any().downcast_ref::() { + if let Some((value, shadows)) = self.columns.get(lambda_variable.name()) { if *shadows == 0 { return Ok(Transformed::yes(Arc::new( - lambda_column.clone().with_value(value.clone()), + lambda_variable.clone().with_value(value.clone()), ))); } } diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 763f693f2f607..626c2ba550594 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -599,7 +599,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { | Expr::Placeholder(_) | Expr::OuterReferenceColumn(_, _) | Expr::Lambda(_) - | Expr::LambdaColumn(_) => Ok(Transformed::no(expr)), + | Expr::LambdaVariable(_) => Ok(Transformed::no(expr)), } } } diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index f95a0f908b813..74e77011b71ed 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -694,7 +694,7 @@ impl CSEController for ExprCSEController<'_> { } fn is_valid(node: &Expr) -> bool { - !node.is_volatile_node() && !matches!(node, Expr::LambdaColumn(_)) + !node.is_volatile_node() && !matches!(node, Expr::LambdaVariable(_)) } fn is_ignored(&self, node: &Expr) -> bool { @@ -707,7 +707,7 @@ impl CSEController for ExprCSEController<'_> { | Expr::ScalarVariable(..) | Expr::Alias(..) | Expr::Wildcard { .. 
} - | Expr::LambdaColumn(_) + | Expr::LambdaVariable(_) ); let is_aggr = matches!(node, Expr::AggregateFunction(..)); diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 77c533ce2f01e..b7e8626aa7bd5 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -289,7 +289,7 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result { | Expr::InList { .. } | Expr::ScalarFunction(_) | Expr::Lambda(_) - | Expr::LambdaColumn(_) => Ok(TreeNodeRecursion::Continue), + | Expr::LambdaVariable(_) => Ok(TreeNodeRecursion::Continue), // TODO: remove the next line after `Expr::Wildcard` is removed #[expect(deprecated)] Expr::AggregateFunction(_) diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 779c0acea9963..74794115755ba 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -468,8 +468,8 @@ impl TreeNodeRewriter for Canonicalizer { match (left.as_ref(), right.as_ref(), op.swap()) { // ( - left_col @ (Expr::Column(_) | Expr::LambdaColumn(_)), - right_col @ (Expr::Column(_) | Expr::LambdaColumn(_)), + left_col @ (Expr::Column(_) | Expr::LambdaVariable(_)), + right_col @ (Expr::Column(_) | Expr::LambdaVariable(_)), Some(swapped_op), ) if right_col > left_col => { Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr { @@ -481,7 +481,7 @@ impl TreeNodeRewriter for Canonicalizer { // ( Expr::Literal(_, _), - Expr::Column(_) | Expr::LambdaColumn(_), + Expr::Column(_) | Expr::LambdaVariable(_), Some(swapped_op), ) => Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr { left: right, @@ -655,7 +655,7 @@ impl<'a> ConstEvaluator<'a> { | Expr::GroupingSet(_) | Expr::Wildcard { .. 
} | Expr::Placeholder(_) - | Expr::LambdaColumn(_) => false, + | Expr::LambdaVariable(_) => false, Expr::ScalarFunction(ScalarFunction { func, .. }) => { Self::volatility_ok(func.signature().volatility) } @@ -2012,8 +2012,8 @@ fn are_inlist_and_eq(left: &Expr, right: &Expr) -> bool { let left = as_inlist(left); let right = as_inlist(right); if let (Some(lhs), Some(rhs)) = (left, right) { - matches!(lhs.expr.as_ref(), Expr::Column(_) | Expr::LambdaColumn(_)) - && matches!(rhs.expr.as_ref(), Expr::Column(_) | Expr::LambdaColumn(_)) + matches!(lhs.expr.as_ref(), Expr::Column(_) | Expr::LambdaVariable(_)) + && matches!(rhs.expr.as_ref(), Expr::Column(_) | Expr::LambdaVariable(_)) && lhs.expr == rhs.expr && !lhs.negated && !rhs.negated @@ -2028,14 +2028,14 @@ fn as_inlist(expr: &'_ Expr) -> Option> { Expr::InList(inlist) => Some(Cow::Borrowed(inlist)), Expr::BinaryExpr(BinaryExpr { left, op, right }) if *op == Operator::Eq => { match (left.as_ref(), right.as_ref()) { - (Expr::Column(_) | Expr::LambdaColumn(_), Expr::Literal(_, _)) => { + (Expr::Column(_) | Expr::LambdaVariable(_), Expr::Literal(_, _)) => { Some(Cow::Owned(InList { expr: left.clone(), list: vec![*right.clone()], negated: false, })) } - (Expr::Literal(_, _), Expr::Column(_) | Expr::LambdaColumn(_)) => { + (Expr::Literal(_, _), Expr::Column(_) | Expr::LambdaVariable(_)) => { Some(Cow::Owned(InList { expr: right.clone(), list: vec![*left.clone()], @@ -2057,14 +2057,14 @@ fn to_inlist(expr: Expr) -> Option { op: Operator::Eq, right, }) => match (left.as_ref(), right.as_ref()) { - (Expr::Column(_) | Expr::LambdaColumn(_), Expr::Literal(_, _)) => { + (Expr::Column(_) | Expr::LambdaVariable(_), Expr::Literal(_, _)) => { Some(InList { expr: left, list: vec![*right], negated: false, }) } - (Expr::Literal(_, _), Expr::Column(_) | Expr::LambdaColumn(_)) => { + (Expr::Literal(_, _), Expr::Column(_) | Expr::LambdaVariable(_)) => { Some(InList { expr: right, list: vec![*left], diff --git 
a/datafusion/physical-expr/src/expressions/lambda_column.rs b/datafusion/physical-expr/src/expressions/lambda_variable.rs similarity index 86% rename from datafusion/physical-expr/src/expressions/lambda_column.rs rename to datafusion/physical-expr/src/expressions/lambda_variable.rs index 4aed16186ba6f..305774c3c02da 100644 --- a/datafusion/physical-expr/src/expressions/lambda_column.rs +++ b/datafusion/physical-expr/src/expressions/lambda_variable.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! Physical lambda column reference: [`LambdaColumn`] +//! Physical lambda column reference: [`LambdaVariable`] use std::any::Any; use std::hash::Hash; @@ -33,28 +33,28 @@ use datafusion_expr::ColumnarValue; /// Represents the lambda column with a given name and field #[derive(Debug, Clone)] -pub struct LambdaColumn { +pub struct LambdaVariable { name: String, field: FieldRef, value: Option, } -impl Eq for LambdaColumn {} +impl Eq for LambdaVariable {} -impl PartialEq for LambdaColumn { +impl PartialEq for LambdaVariable { fn eq(&self, other: &Self) -> bool { self.name == other.name && self.field == other.field } } -impl Hash for LambdaColumn { +impl Hash for LambdaVariable { fn hash(&self, state: &mut H) { self.name.hash(state); self.field.hash(state); } } -impl LambdaColumn { +impl LambdaVariable { /// Create a new lambda column expression pub fn new(name: &str, field: FieldRef) -> Self { Self { @@ -83,13 +83,13 @@ impl LambdaColumn { } } -impl std::fmt::Display for LambdaColumn { +impl std::fmt::Display for LambdaVariable { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { write!(f, "{}@-1", self.name) } } -impl PhysicalExpr for LambdaColumn { +impl PhysicalExpr for LambdaVariable { /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { self @@ -107,7 +107,7 @@ impl PhysicalExpr for LambdaColumn { /// Evaluate the expression fn evaluate(&self, _batch: 
&RecordBatch) -> Result { - self.value.clone().ok_or_else(|| exec_datafusion_err!("Physical LambdaColumn {} missing value", self.name)) + self.value.clone().ok_or_else(|| exec_datafusion_err!("Physical LambdaVariable {} missing value", self.name)) } fn return_field(&self, _input_schema: &Schema) -> Result { @@ -130,7 +130,7 @@ impl PhysicalExpr for LambdaColumn { } } -/// Create a column expression -pub fn lambda_col(name: &str, field: FieldRef) -> Result> { - Ok(Arc::new(LambdaColumn::new(name, field))) +/// Create a lambda variable expression +pub fn lambda_variable(name: &str, field: FieldRef) -> Result> { + Ok(Arc::new(LambdaVariable::new(name, field))) } diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index 5d044ab848550..990e53fa23b2c 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -23,7 +23,7 @@ mod case; mod cast; mod cast_column; mod column; -mod lambda_column; +mod lambda_variable; mod dynamic_filters; mod in_list; mod is_not_null; @@ -46,7 +46,7 @@ pub use case::{case, CaseExpr}; pub use cast::{cast, CastExpr}; pub use cast_column::CastColumnExpr; pub use column::{col, with_new_schema, Column}; -pub use lambda_column::{lambda_col, LambdaColumn}; +pub use lambda_variable::{lambda_variable, LambdaVariable}; pub use datafusion_expr::utils::format_state_name; pub use dynamic_filters::DynamicFilterPhysicalExpr; pub use in_list::{in_list, InListExpr}; diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index 4c3d1352cce0f..8a53aa81da8fa 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -17,7 +17,7 @@ use std::sync::Arc; -use crate::expressions::{lambda_col, LambdaExpr}; +use crate::expressions::{lambda_variable, LambdaExpr}; use crate::ScalarFunctionExpr; use crate::{ expressions::{self, binary, like, similar_to, Column, Literal}, @@ 
-33,7 +33,7 @@ use datafusion_common::{ }; use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::expr::{ - Alias, Cast, InList, Lambda, LambdaColumn, Placeholder, ScalarFunction, + Alias, Cast, InList, Lambda, LambdaVariable, Placeholder, ScalarFunction, }; use datafusion_expr::var_provider::is_system_variables; use datafusion_expr::var_provider::VarType; @@ -392,11 +392,11 @@ pub fn create_physical_expr( Expr::Placeholder(Placeholder { id, .. }) => { exec_err!("Placeholder '{id}' was not provided a value for execution.") } - Expr::LambdaColumn(LambdaColumn { + Expr::LambdaVariable(LambdaVariable { name, field, spans: _, - }) => lambda_col( + }) => lambda_variable( name, Arc::clone(field), ), diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index c6e33159aa89f..41a9172fff276 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -622,7 +622,7 @@ pub fn serialize_expr( .unwrap_or(HashMap::new()), })), }, - Expr::Lambda(_) | Expr::LambdaColumn(_) => { + Expr::Lambda(_) | Expr::LambdaVariable(_) => { return Err(Error::General( "Proto serialization error: Lambda not implemented".to_string(), )) diff --git a/datafusion/sql/src/expr/identifier.rs b/datafusion/sql/src/expr/identifier.rs index 73be980d686d0..14433e9cf7eba 100644 --- a/datafusion/sql/src/expr/identifier.rs +++ b/datafusion/sql/src/expr/identifier.rs @@ -20,7 +20,7 @@ use datafusion_common::{ exec_datafusion_err, internal_err, not_impl_err, plan_datafusion_err, plan_err, Column, DFSchema, Result, Span, TableReference, }; -use datafusion_expr::expr::LambdaColumn; +use datafusion_expr::expr::LambdaVariable; use datafusion_expr::planner::PlannerResult; use datafusion_expr::{Case, Expr}; use sqlparser::ast::{CaseWhen, Expr as SQLExpr, Ident}; @@ -58,13 +58,13 @@ impl SqlToRel<'_, S> { .lambdas_parameters() .get(&normalize_ident) { - let mut lambda_column = 
LambdaColumn::new(normalize_ident, Arc::clone(field)); + let mut lambda_var = LambdaVariable::new(normalize_ident, Arc::clone(field)); if self.options.collect_spans { if let Some(span) = Span::try_from_sqlparser_span(id_span) { - lambda_column.spans_mut().add_span(span); + lambda_var.spans_mut().add_span(span); } } - return Ok(Expr::LambdaColumn(lambda_column)); + return Ok(Expr::LambdaVariable(lambda_var)); } // Check for qualified field with unqualified name diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 71a1a342a9c5e..013a6f1128957 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -536,7 +536,7 @@ impl Unparser<'_> { body: Box::new(self.expr_to_sql_inner(body)?), })) } - Expr::LambdaColumn(l) => Ok(ast::Expr::Identifier( + Expr::LambdaVariable(l) => Ok(ast::Expr::Identifier( self.new_ident_quoted_if_needs(l.name.clone()), )), } diff --git a/datafusion/substrait/src/logical_plan/producer/expr/mod.rs b/datafusion/substrait/src/logical_plan/producer/expr/mod.rs index b16fd8032877f..b3ca88a690291 100644 --- a/datafusion/substrait/src/logical_plan/producer/expr/mod.rs +++ b/datafusion/substrait/src/logical_plan/producer/expr/mod.rs @@ -153,7 +153,7 @@ pub fn to_substrait_rex( } Expr::Unnest(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), Expr::Lambda(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), // https://github.com/substrait-io/substrait/issues/349 - Expr::LambdaColumn(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), // https://github.com/substrait-io/substrait/issues/349 + Expr::LambdaVariable(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), // https://github.com/substrait-io/substrait/issues/349 } } From e1921eb377a56766b2dd4729311def1c3515c8c4 Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Mon, 23 Feb 2026 14:18:01 -0300 Subject: [PATCH 05/12] feat: add LambdaUDF --- 
datafusion-examples/examples/sql_frontend.rs | 8 +- datafusion/catalog-listing/src/helpers.rs | 18 +- .../core/src/bin/print_functions_docs.rs | 15 +- datafusion/core/src/execution/context/mod.rs | 16 + .../core/src/execution/session_state.rs | 87 ++- .../src/execution/session_state_defaults.rs | 8 +- datafusion/core/tests/optimizer/mod.rs | 7 +- .../datasource-arrow/src/file_format.rs | 6 +- datafusion/datasource/src/url.rs | 6 +- datafusion/execution/src/task.rs | 41 +- datafusion/expr/src/expr.rs | 49 ++ datafusion/expr/src/expr_schema.rs | 39 ++ datafusion/expr/src/lib.rs | 8 +- datafusion/expr/src/planner.rs | 6 +- datafusion/expr/src/registry.rs | 48 ++ datafusion/expr/src/tree_node.rs | 8 +- datafusion/expr/src/udf_eq.rs | 14 +- datafusion/expr/src/udlf.rs | 649 ++++++++++++++++++ datafusion/expr/src/utils.rs | 1 + .../functions-nested/src/array_transform.rs | 52 +- datafusion/functions-nested/src/lib.rs | 4 +- .../optimizer/src/analyzer/type_coercion.rs | 1 + datafusion/optimizer/src/push_down_filter.rs | 1 + .../simplify_expressions/expr_simplifier.rs | 4 + .../optimizer/tests/optimizer_integration.rs | 4 + .../physical-expr/src/lambda_function.rs | 530 ++++++++++++++ datafusion/physical-expr/src/lib.rs | 2 + datafusion/physical-expr/src/planner.rs | 28 +- datafusion/proto/src/bytes/mod.rs | 60 +- datafusion/proto/src/bytes/registry.rs | 10 + datafusion/proto/src/logical_plan/mod.rs | 10 +- datafusion/proto/src/logical_plan/to_proto.rs | 11 + datafusion/session/src/session.rs | 6 +- datafusion/sql/examples/sql.rs | 6 +- datafusion/sql/src/expr/function.rs | 57 +- datafusion/sql/src/expr/mod.rs | 6 +- datafusion/sql/src/unparser/expr.rs | 18 +- datafusion/sql/tests/common/mod.rs | 7 +- .../src/logical_plan/producer/expr/mod.rs | 5 +- .../producer/expr/scalar_function.rs | 23 +- .../producer/substrait_producer.rs | 15 +- 41 files changed, 1812 insertions(+), 82 deletions(-) create mode 100644 datafusion/expr/src/udlf.rs create mode 100644 
datafusion/physical-expr/src/lambda_function.rs diff --git a/datafusion-examples/examples/sql_frontend.rs b/datafusion-examples/examples/sql_frontend.rs index 1fc9ce24ecbb5..e6341080a9e11 100644 --- a/datafusion-examples/examples/sql_frontend.rs +++ b/datafusion-examples/examples/sql_frontend.rs @@ -20,8 +20,8 @@ use datafusion::common::{plan_err, TableReference}; use datafusion::config::ConfigOptions; use datafusion::error::Result; use datafusion::logical_expr::{ - AggregateUDF, Expr, LogicalPlan, ScalarUDF, TableProviderFilterPushDown, TableSource, - WindowUDF, + AggregateUDF, Expr, LambdaUDF, LogicalPlan, ScalarUDF, TableProviderFilterPushDown, + TableSource, WindowUDF, }; use datafusion::optimizer::{ Analyzer, AnalyzerRule, Optimizer, OptimizerConfig, OptimizerContext, OptimizerRule, @@ -153,6 +153,10 @@ impl ContextProvider for MyContextProvider { None } + fn get_lambda_meta(&self, _name: &str) -> Option> { + None + } + fn get_aggregate_meta(&self, _name: &str) -> Option> { None } diff --git a/datafusion/catalog-listing/src/helpers.rs b/datafusion/catalog-listing/src/helpers.rs index eca681e3c604c..34fee4eb6bd41 100644 --- a/datafusion/catalog-listing/src/helpers.rs +++ b/datafusion/catalog-listing/src/helpers.rs @@ -100,6 +100,16 @@ pub fn expr_applicable_for_cols(col_names: &[&str], expr: &Expr) -> bool { } } } + Expr::LambdaFunction(lambda_function) => { + match lambda_function.func.signature().volatility { + Volatility::Immutable => Ok(TreeNodeRecursion::Continue), + // TODO: Stable functions could be `applicable`, but that would require access to the context + Volatility::Stable | Volatility::Volatile => { + is_applicable = false; + Ok(TreeNodeRecursion::Stop) + } + } + } // TODO other expressions are not handled yet: // - AGGREGATE and WINDOW should not end up in filter conditions, except maybe in some edge cases @@ -555,7 +565,7 @@ mod tests { use super::*; use datafusion_expr::{ - case, col, lit, AggregateUDF, Expr, LogicalPlan, ScalarUDF, WindowUDF, + 
case, col, lit, AggregateUDF, Expr, LambdaUDF, LogicalPlan, ScalarUDF, WindowUDF, }; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_plan::ExecutionPlan; @@ -1066,6 +1076,12 @@ mod tests { unimplemented!() } + fn lambda_functions( + &self, + ) -> &std::collections::HashMap> { + unimplemented!() + } + fn aggregate_functions( &self, ) -> &std::collections::HashMap> { diff --git a/datafusion/core/src/bin/print_functions_docs.rs b/datafusion/core/src/bin/print_functions_docs.rs index 63387c023b11a..97282edf49381 100644 --- a/datafusion/core/src/bin/print_functions_docs.rs +++ b/datafusion/core/src/bin/print_functions_docs.rs @@ -18,8 +18,7 @@ use datafusion::execution::SessionStateDefaults; use datafusion_common::{not_impl_err, HashSet, Result}; use datafusion_expr::{ - aggregate_doc_sections, scalar_doc_sections, window_doc_sections, AggregateUDF, - DocSection, Documentation, ScalarUDF, WindowUDF, + AggregateUDF, DocSection, Documentation, LambdaUDF, ScalarUDF, WindowUDF, aggregate_doc_sections, scalar_doc_sections, window_doc_sections }; use itertools::Itertools; use std::env::args; @@ -303,6 +302,18 @@ impl DocProvider for WindowUDF { } } +impl DocProvider for dyn LambdaUDF { + fn get_name(&self) -> String { + self.name().to_string() + } + fn get_aliases(&self) -> Vec { + self.aliases().iter().map(|a| a.to_string()).collect() + } + fn get_documentation(&self) -> Option<&Documentation> { + self.documentation() + } +} + #[allow(clippy::borrowed_box)] #[allow(clippy::ptr_arg)] fn get_names_and_aliases(functions: &Vec<&Box>) -> Vec { diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 687779787ab50..083ecdaf575af 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -75,6 +75,7 @@ use datafusion_common::{ pub use datafusion_execution::config::SessionConfig; use datafusion_execution::registry::SerializerRegistry; 
pub use datafusion_execution::TaskContext; +use datafusion_expr::LambdaUDF; pub use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::{ expr_rewriter::FunctionRewrite, @@ -1786,6 +1787,21 @@ impl FunctionRegistry for SessionContext { fn udwfs(&self) -> HashSet { self.state.read().udwfs() } + + fn udlfs(&self) -> HashSet { + self.state.read().udlfs() + } + + fn udlf(&self, name: &str) -> Result> { + self.state.read().udlf(name) + } + + fn register_udlf( + &mut self, + udlf: Arc, + ) -> Result>> { + self.state.write().register_udlf(udlf) + } } /// Create a new task context instance from SessionContext diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index c15b7eae08432..f33f3d3412f4d 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -59,7 +59,7 @@ use datafusion_expr::simplify::SimplifyInfo; #[cfg(feature = "sql")] use datafusion_expr::TableSource; use datafusion_expr::{ - AggregateUDF, Explain, Expr, ExprSchemable, LogicalPlan, ScalarUDF, WindowUDF, + AggregateUDF, Explain, Expr, ExprSchemable, LambdaUDF, LogicalPlan, ScalarUDF, WindowUDF }; use datafusion_optimizer::simplify_expressions::ExprSimplifier; use datafusion_optimizer::{ @@ -154,6 +154,8 @@ pub struct SessionState { table_functions: HashMap>, /// Scalar functions that are registered with the context scalar_functions: HashMap>, + /// Lambda functions that are registered with the context + lambda_functions: HashMap>, /// Aggregate functions registered in the context aggregate_functions: HashMap>, /// Window functions registered in the context @@ -252,6 +254,10 @@ impl Session for SessionState { fn scalar_functions(&self) -> &HashMap> { &self.scalar_functions } + + fn lambda_functions(&self) -> &HashMap> { + &self.lambda_functions + } fn aggregate_functions(&self) -> &HashMap> { &self.aggregate_functions @@ -921,6 +927,7 @@ pub struct SessionStateBuilder { 
catalog_list: Option>, table_functions: Option>>, scalar_functions: Option>>, + lambda_functions: Option>>, aggregate_functions: Option>>, window_functions: Option>>, serializer_registry: Option>, @@ -958,6 +965,7 @@ impl SessionStateBuilder { catalog_list: None, table_functions: None, scalar_functions: None, + lambda_functions: None, aggregate_functions: None, window_functions: None, serializer_registry: None, @@ -1008,6 +1016,7 @@ impl SessionStateBuilder { catalog_list: Some(existing.catalog_list), table_functions: Some(existing.table_functions), scalar_functions: Some(existing.scalar_functions.into_values().collect_vec()), + lambda_functions: Some(existing.lambda_functions.into_values().collect_vec()), aggregate_functions: Some( existing.aggregate_functions.into_values().collect_vec(), ), @@ -1048,6 +1057,10 @@ impl SessionStateBuilder { self.scalar_functions .get_or_insert_with(Vec::new) .extend(SessionStateDefaults::default_scalar_functions()); + + self.lambda_functions + .get_or_insert_with(Vec::new) + .extend(SessionStateDefaults::default_lambda_functions()); self.aggregate_functions .get_or_insert_with(Vec::new) @@ -1362,6 +1375,7 @@ impl SessionStateBuilder { catalog_list, table_functions, scalar_functions, + lambda_functions, aggregate_functions, window_functions, serializer_registry, @@ -1395,6 +1409,7 @@ impl SessionStateBuilder { }), table_functions: table_functions.unwrap_or_default(), scalar_functions: HashMap::new(), + lambda_functions: HashMap::new(), aggregate_functions: HashMap::new(), window_functions: HashMap::new(), serializer_registry: serializer_registry @@ -1446,6 +1461,34 @@ impl SessionStateBuilder { } } } + + if let Some(lambda_functions) = lambda_functions { + for udlf in lambda_functions { + let config_options = state.config().options(); + match udlf.with_updated_config(config_options) { + Some(new_udf) => { + if let Err(err) = state.register_udlf(new_udf) { + debug!( + "Failed to re-register updated UDLF '{}': {}", + udlf.name(), + 
err + ); + } + } + None => match state.register_udlf(Arc::clone(&udlf)) { + Ok(Some(existing)) => { + debug!("Overwrote existing UDLF '{}'", existing.name()); + } + Ok(None) => { + debug!("Registered UDLF '{}'", udlf.name()); + } + Err(err) => { + debug!("Failed to register UDLF '{}': {}", udlf.name(), err); + } + }, + } + } + } if let Some(aggregate_functions) = aggregate_functions { aggregate_functions.into_iter().for_each(|udaf| { @@ -1661,6 +1704,7 @@ impl Debug for SessionStateBuilder { .field("physical_optimizers", &self.physical_optimizers) .field("table_functions", &self.table_functions) .field("scalar_functions", &self.scalar_functions) + .field("lambda_functions", &self.lambda_functions) .field("aggregate_functions", &self.aggregate_functions) .field("window_functions", &self.window_functions) .finish() @@ -1755,6 +1799,10 @@ impl ContextProvider for SessionContextProvider<'_> { self.state.scalar_functions().get(name).cloned() } + fn get_lambda_meta(&self, name: &str) -> Option> { + self.state.lambda_functions().get(name).cloned() + } + fn get_aggregate_meta(&self, name: &str) -> Option> { self.state.aggregate_functions().get(name).cloned() } @@ -1918,6 +1966,37 @@ impl FunctionRegistry for SessionState { Ok(udwf) } + fn udlfs(&self) -> HashSet { + self.lambda_functions.keys().cloned().collect() + } + + fn udlf(&self, name: &str) -> datafusion_common::Result> { + self.lambda_functions + .get(name) + .cloned() + .ok_or_else(|| plan_datafusion_err!("Lambda Function {name} not found")) + } + + fn register_udlf( + &mut self, + udlf: Arc, + ) -> datafusion_common::Result>> { + Ok(self.lambda_functions.insert(udlf.name().into(), udlf)) + } + + fn deregister_udlf( + &mut self, + name: &str, + ) -> datafusion_common::Result>> { + let udlf = self.lambda_functions.remove(name); + if let Some(udlf) = &udlf { + for alias in udlf.aliases() { + self.lambda_functions.remove(alias); + } + } + Ok(udlf) + } + fn register_function_rewrite( &mut self, rewrite: Arc, @@ 
-1974,6 +2053,7 @@ impl From<&SessionState> for TaskContext { state.session_id.clone(), state.config.clone(), state.scalar_functions.clone(), + state.lambda_functions.clone(), state.aggregate_functions.clone(), state.window_functions.clone(), Arc::clone(&state.runtime_env), @@ -2062,6 +2142,7 @@ mod tests { use datafusion_optimizer::optimizer::OptimizerRule; use datafusion_optimizer::Optimizer; use datafusion_physical_plan::display::DisplayableExecutionPlan; + use datafusion_session::Session; use datafusion_sql::planner::{PlannerContext, SqlToRel}; use std::collections::HashMap; use std::sync::Arc; @@ -2338,6 +2419,10 @@ mod tests { self.state.scalar_functions().get(name).cloned() } + fn get_lambda_meta(&self, name: &str) -> Option> { + self.state.lambda_functions().get(name).cloned() + } + fn get_aggregate_meta(&self, name: &str) -> Option> { self.state.aggregate_functions().get(name).cloned() } diff --git a/datafusion/core/src/execution/session_state_defaults.rs b/datafusion/core/src/execution/session_state_defaults.rs index 62a575541a5d8..54037c0a96f9c 100644 --- a/datafusion/core/src/execution/session_state_defaults.rs +++ b/datafusion/core/src/execution/session_state_defaults.rs @@ -36,7 +36,8 @@ use datafusion_execution::config::SessionConfig; use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_expr::planner::ExprPlanner; -use datafusion_expr::{AggregateUDF, ScalarUDF, WindowUDF}; +use datafusion_expr::{AggregateUDF, LambdaUDF, ScalarUDF, WindowUDF}; +use datafusion_functions_nested::array_transform::ArrayTransform; use std::collections::HashMap; use std::sync::Arc; use url::Url; @@ -112,6 +113,11 @@ impl SessionStateDefaults { functions } + /// returns the list of default [`LambdaUDF`]s + pub fn default_lambda_functions() -> Vec> { + vec![Arc::new(ArrayTransform::new())] + } + /// returns the list of default [`AggregateUDF`]s pub fn default_aggregate_functions() -> Vec> { 
functions_aggregate::all_default_aggregate_functions() diff --git a/datafusion/core/tests/optimizer/mod.rs b/datafusion/core/tests/optimizer/mod.rs index 9b2a5596827d0..44e40143fe171 100644 --- a/datafusion/core/tests/optimizer/mod.rs +++ b/datafusion/core/tests/optimizer/mod.rs @@ -31,8 +31,7 @@ use datafusion_common::tree_node::{TransformedResult, TreeNode}; use datafusion_common::{plan_err, DFSchema, Result, ScalarValue, TableReference}; use datafusion_expr::interval_arithmetic::{Interval, NullableInterval}; use datafusion_expr::{ - col, lit, AggregateUDF, BinaryExpr, Expr, ExprSchemable, LogicalPlan, Operator, - ScalarUDF, TableSource, WindowUDF, + AggregateUDF, BinaryExpr, Expr, ExprSchemable, LambdaUDF, LogicalPlan, Operator, ScalarUDF, TableSource, WindowUDF, col, lit }; use datafusion_functions::core::expr_ext::FieldAccessor; use datafusion_optimizer::analyzer::Analyzer; @@ -217,6 +216,10 @@ impl ContextProvider for MyContextProvider { self.udfs.get(name).cloned() } + fn get_lambda_meta(&self, _name: &str) -> Option> { + None + } + fn get_aggregate_meta(&self, _name: &str) -> Option> { None } diff --git a/datafusion/datasource-arrow/src/file_format.rs b/datafusion/datasource-arrow/src/file_format.rs index 3b85640804219..31a880e688aeb 100644 --- a/datafusion/datasource-arrow/src/file_format.rs +++ b/datafusion/datasource-arrow/src/file_format.rs @@ -442,7 +442,7 @@ mod tests { use datafusion_execution::config::SessionConfig; use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_expr::execution_props::ExecutionProps; - use datafusion_expr::{AggregateUDF, Expr, LogicalPlan, ScalarUDF, WindowUDF}; + use datafusion_expr::{AggregateUDF, Expr, LambdaUDF, LogicalPlan, ScalarUDF, WindowUDF}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use object_store::{chunked::ChunkedStore, memory::InMemory, path::Path}; @@ -488,6 +488,10 @@ mod tests { fn scalar_functions(&self) -> &HashMap> { unimplemented!() } + + fn lambda_functions(&self) -> 
&HashMap> { + unimplemented!() + } fn aggregate_functions(&self) -> &HashMap> { unimplemented!() diff --git a/datafusion/datasource/src/url.rs b/datafusion/datasource/src/url.rs index 08e5b6a5df83a..b0b84bd3cc943 100644 --- a/datafusion/datasource/src/url.rs +++ b/datafusion/datasource/src/url.rs @@ -415,7 +415,7 @@ mod tests { use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_execution::TaskContext; use datafusion_expr::execution_props::ExecutionProps; - use datafusion_expr::{AggregateUDF, Expr, LogicalPlan, ScalarUDF, WindowUDF}; + use datafusion_expr::{AggregateUDF, Expr, LambdaUDF, LogicalPlan, ScalarUDF, WindowUDF}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_plan::ExecutionPlan; use object_store::{ @@ -874,6 +874,10 @@ mod tests { unimplemented!() } + fn lambda_functions(&self) -> &HashMap> { + unimplemented!() + } + fn aggregate_functions(&self) -> &HashMap> { unimplemented!() } diff --git a/datafusion/execution/src/task.rs b/datafusion/execution/src/task.rs index c2a6cfe2c833f..70c59b6375943 100644 --- a/datafusion/execution/src/task.rs +++ b/datafusion/execution/src/task.rs @@ -21,7 +21,7 @@ use crate::{ }; use datafusion_common::{internal_datafusion_err, plan_datafusion_err, Result}; use datafusion_expr::planner::ExprPlanner; -use datafusion_expr::{AggregateUDF, ScalarUDF, WindowUDF}; +use datafusion_expr::{AggregateUDF, ScalarUDF, LambdaUDF, WindowUDF}; use std::collections::HashSet; use std::{collections::HashMap, sync::Arc}; @@ -42,6 +42,8 @@ pub struct TaskContext { session_config: SessionConfig, /// Scalar functions associated with this task context scalar_functions: HashMap>, + /// Lambda functions associated with this task context + lambda_functions: HashMap>, /// Aggregate functions associated with this task context aggregate_functions: HashMap>, /// Window functions associated with this task context @@ -60,6 +62,7 @@ impl Default for TaskContext { task_id: None, session_config: 
SessionConfig::new(), scalar_functions: HashMap::new(), + lambda_functions: HashMap::new(), aggregate_functions: HashMap::new(), window_functions: HashMap::new(), runtime, @@ -73,11 +76,13 @@ impl TaskContext { /// Most users will use [`SessionContext::task_ctx`] to create [`TaskContext`]s /// /// [`SessionContext::task_ctx`]: https://docs.rs/datafusion/latest/datafusion/execution/context/struct.SessionContext.html#method.task_ctx + #[allow(clippy::too_many_arguments)] pub fn new( task_id: Option, session_id: String, session_config: SessionConfig, scalar_functions: HashMap>, + lambda_functions: HashMap>, aggregate_functions: HashMap>, window_functions: HashMap>, runtime: Arc, @@ -87,6 +92,7 @@ impl TaskContext { session_id, session_config, scalar_functions, + lambda_functions, aggregate_functions, window_functions, runtime, @@ -198,6 +204,37 @@ impl FunctionRegistry for TaskContext { Ok(self.scalar_functions.insert(udf.name().into(), udf)) } + fn udlfs(&self) -> HashSet { + self.lambda_functions.keys().cloned().collect() + } + + fn udlf(&self, name: &str) -> Result> { + self.lambda_functions + .get(name) + .cloned() + .ok_or_else(|| plan_datafusion_err!("Lambda Function {name} not found")) + } + + fn register_udlf( + &mut self, + udlf: Arc, + ) -> Result>> { + Ok(self.lambda_functions.insert(udlf.name().into(), udlf)) + } + + fn deregister_udlf( + &mut self, + name: &str, + ) -> Result>> { + let udlf = self.lambda_functions.remove(name); + if let Some(udlf) = &udlf { + for alias in udlf.aliases() { + self.lambda_functions.remove(alias); + } + } + Ok(udlf) + } + fn expr_planners(&self) -> Vec> { vec![] } @@ -248,6 +285,7 @@ mod tests { HashMap::default(), HashMap::default(), HashMap::default(), + HashMap::default(), runtime, ); @@ -280,6 +318,7 @@ mod tests { HashMap::default(), HashMap::default(), HashMap::default(), + HashMap::default(), runtime, ); diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 07f1bc129c597..ef4ba4ef5cdfd 100644 --- 
a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -27,6 +27,7 @@ use std::sync::Arc; use crate::expr_fn::binary_expr; use crate::function::WindowFunctionSimplification; use crate::logical_plan::Subquery; +use crate::udlf::LambdaUDF; use crate::{AggregateUDF, Volatility}; use crate::{ExprSchemable, Operator, Signature, WindowFrame, WindowUDF}; @@ -398,11 +399,41 @@ pub enum Expr { OuterReferenceColumn(FieldRef, Column), /// Unnest expression Unnest(Unnest), + LambdaFunction(LambdaFunction), /// Lambda expression Lambda(Lambda), LambdaVariable(LambdaVariable), } +#[derive(Clone, Eq, PartialOrd, Debug)] +pub struct LambdaFunction { + pub func: Arc, + pub args: Vec, +} + +impl LambdaFunction { + pub fn new(func: Arc, args: Vec) -> Self { + Self { func, args } + } + + pub fn name(&self) -> &str { + self.func.name() + } +} + +impl Hash for LambdaFunction { + fn hash(&self, state: &mut H) { + self.func.hash(state); + self.args.hash(state); + } +} + +impl PartialEq for LambdaFunction { + fn eq(&self, other: &Self) -> bool { + self.func.as_ref() == other.func.as_ref() && self.args == other.args + } +} + #[derive(Clone, PartialEq, PartialOrd, Eq, Debug, Hash)] pub struct LambdaVariable { pub name: String, @@ -1566,6 +1597,7 @@ impl Expr { #[expect(deprecated)] Expr::Wildcard { .. } => "Wildcard", Expr::Unnest { .. } => "Unnest", + Expr::LambdaFunction { .. } => "LambdaFunction", Expr::Lambda { .. } => "Lambda", Expr::LambdaVariable { .. } => "LambdaVariable", } @@ -2083,6 +2115,7 @@ impl Expr { pub fn short_circuits(&self) -> bool { match self { Expr::ScalarFunction(ScalarFunction { func, .. }) => func.short_circuits(), + Expr::LambdaFunction(LambdaFunction { func, .. }) => func.short_circuits(), Expr::BinaryExpr(BinaryExpr { op, .. 
}) => { matches!(op, Operator::And | Operator::Or) } @@ -2719,6 +2752,9 @@ impl HashNode for Expr { column.hash(state); } Expr::Unnest(Unnest { expr: _expr }) => {} + Expr::LambdaFunction(LambdaFunction { func, args: _args }) => { + func.hash(state); + } Expr::Lambda(Lambda { params, body: _ }) => { params.hash(state); } @@ -3043,6 +3079,16 @@ impl Display for SchemaDisplay<'_> { } } } + Expr::LambdaFunction(LambdaFunction { func, args }) => { + match func.schema_name(args) { + Ok(name) => { + write!(f, "{name}") + } + Err(e) => { + write!(f, "got error from schema_name {e}") + } + } + } Expr::Lambda(Lambda { params, body }) => { write!(f, "({}) -> {body}", display_comma_separated(params)) } @@ -3539,6 +3585,9 @@ impl Display for Expr { Expr::Unnest(Unnest { expr }) => { write!(f, "{UNNEST_COLUMN_PREFIX}({expr})") } + Expr::LambdaFunction(fun) => { + fmt_function(f, fun.name(), false, &fun.args, true) + } Expr::Lambda(Lambda { params, body }) => { write!(f, "({}) -> {body}", params.join(", ")) } diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 3e3ff7dacb9d8..6a3a7cbc85e76 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -25,6 +25,7 @@ use crate::expr::{FieldMetadata, LambdaVariable}; use crate::type_coercion::functions::{ fields_with_aggregate_udf, fields_with_window_udf, }; +use crate::udlf::{LambdaReturnFieldArgs, ValueOrLambdaField}; use crate::{ type_coercion::functions::data_types_with_scalar_udf, udf::ReturnFieldArgs, utils, LogicalPlan, Projection, Subquery, WindowFunctionDefinition, @@ -234,6 +235,10 @@ impl ExprSchemable for Expr { // Grouping sets do not really have a type and do not appear in projections Ok(DataType::Null) } + Expr::LambdaFunction(_func) => { + let (return_type, _) = self.data_type_and_nullable(schema)?; + Ok(return_type) + } Expr::Lambda(Lambda { params: _, body }) => body.get_type(schema), Expr::LambdaVariable(LambdaVariable { name: _, field, .. 
}) => { Ok(field.data_type().clone()) @@ -356,6 +361,10 @@ impl ExprSchemable for Expr { // in projections Ok(true) } + Expr::LambdaFunction(_func) => { + let (_, nullable) = self.data_type_and_nullable(input_schema)?; + Ok(nullable) + } Expr::Lambda(l) => l.body.nullable(input_schema), Expr::LambdaVariable(c) => Ok(c.field.is_nullable()), } @@ -625,6 +634,36 @@ impl ExprSchemable for Expr { self.get_type(schema)?, self.nullable(schema)?, ))), + Expr::LambdaFunction(func) => { + let arg_fields = func + .args + .iter() + .map(|arg| { + let field = arg.to_field(schema)?.1; + match arg { + Expr::Lambda(_lambda) => { + Ok(ValueOrLambdaField::Lambda(field)) + } + _ => Ok(ValueOrLambdaField::Value(field)), + } + }) + .collect::>>()?; + + let arguments = func.args + .iter() + .map(|e| match e { + Expr::Literal(sv, _) => Some(sv), + _ => None, + }) + .collect::>(); + + let args = LambdaReturnFieldArgs { + arg_fields: &arg_fields, + scalar_arguments: &arguments, + }; + + func.func.return_field_from_args(args) + } Expr::LambdaVariable(c) => Ok(Arc::clone(&c.field)), }?; diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 0f26218e74779..e6dfd9fc1483b 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -42,6 +42,7 @@ mod partition_evaluator; mod table_source; mod udaf; mod udf; +mod udlf; mod udwf; pub mod arguments; @@ -117,9 +118,10 @@ pub use udaf::{ udaf_default_window_function_schema_name, AggregateUDF, AggregateUDFImpl, ReversedUDAF, SetMonotonicity, StatisticsArgs, }; -pub use udf::{ - ReturnFieldArgs, ScalarFunctionArgs, ScalarFunctionLambdaArg, ScalarUDF, - ScalarUDFImpl, ValueOrLambda, ValueOrLambdaField, ValueOrLambdaParameter, +pub use udf::{ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl}; +pub use udlf::{ + LambdaFunctionArgs, LambdaFunctionLambdaArg, LambdaReturnFieldArgs, LambdaUDF, + ValueOrLambda, ValueOrLambdaField, ValueOrLambdaParameter, }; pub use udwf::{LimitEffect, ReversedUDWF, WindowUDF, 
WindowUDFImpl}; pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits}; diff --git a/datafusion/expr/src/planner.rs b/datafusion/expr/src/planner.rs index 25a0f83947eee..7696faca0922a 100644 --- a/datafusion/expr/src/planner.rs +++ b/datafusion/expr/src/planner.rs @@ -22,8 +22,7 @@ use std::sync::Arc; use crate::expr::NullTreatment; use crate::{ - AggregateUDF, Expr, GetFieldAccess, ScalarUDF, SortExpr, TableSource, WindowFrame, - WindowFunctionDefinition, WindowUDF, + AggregateUDF, Expr, GetFieldAccess, LambdaUDF, ScalarUDF, SortExpr, TableSource, WindowFrame, WindowFunctionDefinition, WindowUDF }; use arrow::datatypes::{DataType, Field, SchemaRef}; use datafusion_common::{ @@ -91,6 +90,9 @@ pub trait ContextProvider { /// Return the scalar function with a given name, if any fn get_function_meta(&self, name: &str) -> Option>; + + /// Return the lambda function with a given name, if any + fn get_lambda_meta(&self, name: &str) -> Option>; /// Return the aggregate function with a given name, if any fn get_aggregate_meta(&self, name: &str) -> Option>; diff --git a/datafusion/expr/src/registry.rs b/datafusion/expr/src/registry.rs index 9554dd68e1758..92aa39d64c98d 100644 --- a/datafusion/expr/src/registry.rs +++ b/datafusion/expr/src/registry.rs @@ -19,6 +19,7 @@ use crate::expr_rewriter::FunctionRewrite; use crate::planner::ExprPlanner; +use crate::udlf::LambdaUDF; use crate::{AggregateUDF, ScalarUDF, UserDefinedLogicalNode, WindowUDF}; use datafusion_common::{not_impl_err, plan_datafusion_err, HashMap, Result}; use std::collections::HashSet; @@ -30,6 +31,9 @@ pub trait FunctionRegistry { /// Returns names of all available scalar user defined functions. fn udfs(&self) -> HashSet; + /// Returns names of all available lambda user defined functions. + fn udlfs(&self) -> HashSet; + /// Returns names of all available aggregate user defined functions. fn udafs(&self) -> HashSet; @@ -40,6 +44,10 @@ pub trait FunctionRegistry { /// `name`. 
fn udf(&self, name: &str) -> Result>; + /// Returns a reference to the user defined lambda function (udlf) named + /// `name`. + fn udlf(&self, name: &str) -> Result>; + /// Returns a reference to the user defined aggregate function (udaf) named /// `name`. fn udaf(&self, name: &str) -> Result>; @@ -56,6 +64,17 @@ pub trait FunctionRegistry { fn register_udf(&mut self, _udf: Arc) -> Result>> { not_impl_err!("Registering ScalarUDF") } + /// Registers a new [`LambdaUDF`], returning any previously registered + /// implementation. + /// + /// Returns an error (the default) if the function can not be registered, + /// for example if the registry is read only. + fn register_udlf( + &mut self, + _udlf: Arc, + ) -> Result>> { + not_impl_err!("Registering LambdaUDF") + } /// Registers a new [`AggregateUDF`], returning any previously registered /// implementation. /// @@ -85,6 +104,15 @@ pub trait FunctionRegistry { not_impl_err!("Deregistering ScalarUDF") } + /// Deregisters a [`LambdaUDF`], returning the implementation that was + /// deregistered. + /// + /// Returns an error (the default) if the function can not be deregistered, + /// for example if the registry is read only. + fn deregister_udlf(&mut self, _name: &str) -> Result>> { + not_impl_err!("Deregistering LambdaUDF") + } + /// Deregisters a [`AggregateUDF`], returning the implementation that was /// deregistered.
/// @@ -152,6 +180,8 @@ pub trait SerializerRegistry: Debug + Send + Sync { pub struct MemoryFunctionRegistry { /// Scalar Functions udfs: HashMap>, + /// Lambda Functions + udlfs: HashMap>, /// Aggregate Functions udafs: HashMap>, /// Window Functions @@ -214,4 +244,22 @@ impl FunctionRegistry for MemoryFunctionRegistry { fn udwfs(&self) -> HashSet { self.udwfs.keys().cloned().collect() } + + fn udlfs(&self) -> HashSet { + self.udlfs.keys().cloned().collect() + } + + fn udlf(&self, name: &str) -> Result> { + self.udlfs + .get(name) + .cloned() + .ok_or_else(|| plan_datafusion_err!("Lambda Function {name} not found")) + } + + fn register_udlf( + &mut self, + udlf: Arc, + ) -> Result>> { + Ok(self.udlfs.insert(udlf.name().into(), udlf)) + } } diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs index a818c32948d09..82179f095937b 100644 --- a/datafusion/expr/src/tree_node.rs +++ b/datafusion/expr/src/tree_node.rs @@ -20,8 +20,8 @@ use crate::{ expr::{ AggregateFunction, AggregateFunctionParams, Alias, Between, BinaryExpr, Case, - Cast, GroupingSet, InList, InSubquery, Lambda, Like, Placeholder, ScalarFunction, - TryCast, Unnest, WindowFunction, WindowFunctionParams, + Cast, GroupingSet, InList, InSubquery, Lambda, LambdaFunction, Like, Placeholder, + ScalarFunction, TryCast, Unnest, WindowFunction, WindowFunctionParams, }, Expr, }; @@ -110,6 +110,7 @@ impl TreeNode for Expr { Expr::InList(InList { expr, list, .. }) => { (expr, list).apply_ref_elements(f) } + Expr::LambdaFunction(LambdaFunction { func: _, args}) => args.apply_elements(f), Expr::Lambda (Lambda{ params: _, body}) => body.apply_elements(f) } } @@ -317,6 +318,9 @@ impl TreeNode for Expr { .update_data(|(new_expr, new_list)| { Expr::InList(InList::new(new_expr, new_list, negated)) }), + Expr::LambdaFunction(LambdaFunction { func, args }) => args + .map_elements(f)? 
+ .update_data(|args| Expr::LambdaFunction(LambdaFunction { func, args })), Expr::Lambda(Lambda { params, body }) => body .map_elements(f)? .update_data(|body| Expr::Lambda(Lambda { params, body })), diff --git a/datafusion/expr/src/udf_eq.rs b/datafusion/expr/src/udf_eq.rs index 6664495267129..a003f05c15b5e 100644 --- a/datafusion/expr/src/udf_eq.rs +++ b/datafusion/expr/src/udf_eq.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use crate::{AggregateUDFImpl, ScalarUDFImpl, WindowUDFImpl}; +use crate::{AggregateUDFImpl, LambdaUDF, ScalarUDFImpl, WindowUDFImpl}; use std::fmt::Debug; use std::hash::{DefaultHasher, Hash, Hasher}; use std::ops::Deref; @@ -93,6 +93,18 @@ impl UdfPointer for Arc { } } +impl UdfPointer for Arc { + fn equals(&self, other: &Self::Target) -> bool { + self.as_ref().dyn_eq(other.as_any()) + } + + fn hash_value(&self) -> u64 { + let hasher = &mut DefaultHasher::new(); + self.as_ref().dyn_hash(hasher); + hasher.finish() + } +} + impl UdfPointer for Arc { fn equals(&self, other: &(dyn AggregateUDFImpl + '_)) -> bool { self.as_ref().dyn_eq(other.as_any()) diff --git a/datafusion/expr/src/udlf.rs b/datafusion/expr/src/udlf.rs new file mode 100644 index 0000000000000..84f9494d2edd8 --- /dev/null +++ b/datafusion/expr/src/udlf.rs @@ -0,0 +1,649 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`LambdaUDF`]: Lambda User Defined Functions + +use crate::expr::schema_name_from_exprs_comma_separated_without_space; +use crate::simplify::{ExprSimplifyResult, SimplifyInfo}; +use crate::sort_properties::{ExprProperties, SortProperties}; +use crate::{ColumnarValue, Documentation, Expr}; +use arrow::array::RecordBatch; +use arrow::datatypes::{DataType, Field, FieldRef}; +use datafusion_common::config::ConfigOptions; +use datafusion_common::{Result, ScalarValue, not_impl_err}; +use datafusion_expr_common::dyn_eq::{DynEq, DynHash}; +use datafusion_expr_common::interval_arithmetic::Interval; +use datafusion_expr_common::signature::Signature; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use std::any::Any; +use std::cmp::Ordering; +use std::fmt::Debug; +use std::hash::{Hash, Hasher}; +use std::sync::Arc; + +impl PartialEq for dyn LambdaUDF { + fn eq(&self, other: &Self) -> bool { + self.dyn_eq(other.as_any()) + } +} + +impl PartialOrd for dyn LambdaUDF { + fn partial_cmp(&self, other: &Self) -> Option { + let mut cmp = self.name().cmp(other.name()); + if cmp == Ordering::Equal { + cmp = self.signature().partial_cmp(other.signature())?; + } + if cmp == Ordering::Equal { + cmp = self.aliases().partial_cmp(other.aliases())?; + } + // Contract for PartialOrd and PartialEq consistency requires that + // a == b if and only if partial_cmp(a, b) == Some(Equal). + if cmp == Ordering::Equal && self != other { + // Functions may have other properties besides name and signature + // that differentiate two instances (e.g. 
type, or arbitrary parameters). + // We cannot return Some(Equal) in such case. + return None; + } + debug_assert!( + cmp == Ordering::Equal || self != other, + "Detected incorrect implementation of PartialEq when comparing functions: '{}' and '{}'. \ + The functions compare as equal, but they are not equal based on general properties that \ + the PartialOrd implementation observes,", + self.name(), other.name() + ); + Some(cmp) + } +} + +impl Eq for dyn LambdaUDF {} + +impl Hash for dyn LambdaUDF { + fn hash(&self, state: &mut H) { + self.dyn_hash(state) + } +} + +#[derive(Clone, Debug)] +pub enum ValueOrLambdaParameter { + /// A columnar value with the given field + Value(FieldRef), + /// A lambda + Lambda, +} + +/// Arguments passed to [`LambdaUDF::invoke_with_args`] when invoking a +/// lambda function. +#[derive(Debug, Clone)] +pub struct LambdaFunctionArgs { + /// The evaluated arguments to the function + pub args: Vec, + /// Field associated with each arg, if it exists + pub arg_fields: Vec, + /// The number of rows in record batch being evaluated + pub number_rows: usize, + /// The return field of the lambda function returned (from `return_type` + /// or `return_field_from_args`) when creating the physical expression + /// from the logical expression + pub return_field: FieldRef, + /// The config options at execution time + pub config_options: Arc, +} + +/// A lambda argument to a LambdaFunction +#[derive(Clone, Debug)] +pub struct LambdaFunctionLambdaArg { + /// The parameters defined in this lambda + /// + /// For example, for `array_transform([2], v -> -v)`, + /// this will be `vec![Field::new("v", DataType::Int32, true)]` + pub params: Vec, + /// The body of the lambda + /// + /// For example, for `array_transform([2], v -> -v)`, + /// this will be the physical expression of `-v` + pub body: Arc, + /// A RecordBatch containing at least the captured columns inside this lambda body, if any + /// Note that it may contain additional, non-specified columns, 
but that's implementation detail + /// + /// For example, for `array_transform([2], v -> v + a + b)`, + /// this will be a `RecordBatch` with two columns, `a` and `b` + pub captures: Option, +} + +impl LambdaFunctionArgs { + /// The return type of the function. See [`Self::return_field`] for more + /// details. + pub fn return_type(&self) -> &DataType { + self.return_field.data_type() + } +} + +// An argument to a LambdaUDF that supports lambdas +#[derive(Clone, Debug)] +pub enum ValueOrLambda { + Value(ColumnarValue), + Lambda(LambdaFunctionLambdaArg), +} + +/// Information about arguments passed to the function +/// +/// This structure contains metadata about how the function was called +/// such as the type of the arguments, any scalar arguments and if the +/// arguments can (ever) be null +/// +/// See [`LambdaUDF::return_field_from_args`] for more information +#[derive(Clone, Debug)] +pub struct LambdaReturnFieldArgs<'a> { + /// The data types of the arguments to the function + /// + /// If argument `i` to the function is a lambda, it will be the field returned by the + /// lambda when executed with the arguments returned from `LambdaUDF::lambdas_parameters` + /// + /// For example, with `array_transform([1], v -> v == 5)` + /// this field will be `[Field::new("", DataType::List(DataType::Int32), false), Field::new("", DataType::Boolean, false)]` + pub arg_fields: &'a [ValueOrLambdaField], + /// Is argument `i` to the function a scalar (constant)? 
+ /// + /// If the argument `i` is not a scalar, it will be None + /// + /// For example, if a function is called like `my_function(column_a, 5)` + /// this field will be `[None, Some(ScalarValue::Int32(Some(5)))]` + pub scalar_arguments: &'a [Option<&'a ScalarValue>], + } + + /// A tagged Field indicating whether it corresponds to a value or a lambda argument + #[derive(Clone, Debug)] + pub enum ValueOrLambdaField { + /// The Field of a ColumnarValue argument + Value(FieldRef), + /// The Field of the return of the lambda body when evaluated with the parameters from LambdaUDF::lambda_parameters + Lambda(FieldRef), + } + + /// Trait for implementing user defined lambda functions. + /// + /// This trait exposes the full API for implementing user defined functions and + /// can be used to implement any function. + /// + /// See [`advanced_udf.rs`] for a full example with complete implementation and + /// [`LambdaUDF`] for other available options. + /// + /// [`advanced_udf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/advanced_udf.rs + /// + /// # Basic Example + /// ``` + /// # use std::any::Any; + /// # use std::sync::LazyLock; + /// # use arrow::datatypes::DataType; + /// # use datafusion_common::{DataFusionError, plan_err, Result}; + /// # use datafusion_expr::{col, ColumnarValue, Documentation, LambdaFunctionArgs, Signature, Volatility}; + /// # use datafusion_expr::LambdaUDF; + /// # use datafusion_expr::lambda_doc_sections::DOC_SECTION_MATH; + /// /// This struct is for a simple UDF that adds one to an int32 + /// #[derive(Debug, PartialEq, Eq, Hash)] + /// struct AddOne { + /// signature: Signature, + /// } + /// + /// impl AddOne { + /// fn new() -> Self { + /// Self { + /// signature: Signature::uniform(1, vec![DataType::Int32], Volatility::Immutable), + /// } + /// } + /// } + /// + /// static DOCUMENTATION: LazyLock = LazyLock::new(|| { + /// Documentation::builder(DOC_SECTION_MATH, "Add one to an int32", "add_one(2)") + /// .with_argument("arg1", "The int32 number
to add one to") +/// .build() +/// }); +/// +/// fn get_doc() -> &'static Documentation { +/// &DOCUMENTATION +/// } +/// +/// /// Implement the LambdaUDF trait for AddOne +/// impl LambdaUDF for AddOne { +/// fn as_any(&self) -> &dyn Any { self } +/// fn name(&self) -> &str { "add_one" } +/// fn signature(&self) -> &Signature { &self.signature } +/// fn return_type(&self, args: &[DataType]) -> Result { +/// if !matches!(args.get(0), Some(&DataType::Int32)) { +/// return plan_err!("add_one only accepts Int32 arguments"); +/// } +/// Ok(DataType::Int32) +/// } +/// // The actual implementation would add one to the argument +/// fn invoke_with_args(&self, args: LambdaFunctionArgs) -> Result { +/// unimplemented!() +/// } +/// fn documentation(&self) -> Option<&Documentation> { +/// Some(get_doc()) +/// } +/// } +/// +/// // Create a new LambdaUDF from the implementation +/// let add_one = LambdaUDF::from(AddOne::new()); +/// +/// // Call the function `add_one(col)` +/// let expr = add_one.call(vec![col("a")]); +/// ``` +pub trait LambdaUDF: Debug + DynEq + DynHash + Send + Sync { + /// Returns this object as an [`Any`] trait object + fn as_any(&self) -> &dyn Any; + + /// Returns this function's name + fn name(&self) -> &str; + + /// Returns any aliases (alternate names) for this function. + /// + /// Aliases can be used to invoke the same function using different names. + /// For example in some databases `now()` and `current_timestamp()` are + /// aliases for the same function. This behavior can be obtained by + /// returning `current_timestamp` as an alias for the `now` function. + /// + /// Note: `aliases` should only include names other than [`Self::name`]. 
+ /// Defaults to `[]` (no aliases) + fn aliases(&self) -> &[String] { + &[] + } + + /// Returns the name of the column this expression would create + /// + /// See [`Expr::schema_name`] for details + fn schema_name(&self, args: &[Expr]) -> Result { + Ok(format!( + "{}({})", + self.name(), + schema_name_from_exprs_comma_separated_without_space(args)? + )) + } + + /// Returns a [`Signature`] describing the argument types for which this + /// function has an implementation, and the function's [`Volatility`]. + /// + /// See [`Signature`] for more details on argument type handling + /// and [`Self::return_type`] for computing the return type. + /// + /// [`Volatility`]: datafusion_expr_common::signature::Volatility + fn signature(&self) -> &Signature; + + /// Create a new instance of this function with updated configuration. + /// + /// This method is called when configuration options change at runtime + /// (e.g., via `SET` statements) to allow functions that depend on + /// configuration to update themselves accordingly. + /// + /// Note the current [`ConfigOptions`] are also passed to [`Self::invoke_with_args`] so + /// this API is not needed for functions where the values may + /// depend on the current options. + /// + /// This API is useful for functions where the return + /// **type** depends on the configuration options, such as the `now()` function + /// which depends on the current timezone. + /// + /// # Arguments + /// + /// * `config` - The updated configuration options + /// + /// # Returns + /// + /// * `Some(LambdaUDF)` - A new instance of this function configured with the new settings + /// * `None` - If this function does not change with new configuration settings (the default) + fn with_updated_config(&self, _config: &ConfigOptions) -> Option> { + None + } + + /// What type will be returned by this function, given the arguments? + /// + /// By default, this function calls [`Self::return_type`] with the + /// types of each argument. 
+    ///
+    /// # Notes
+    ///
+    /// For the majority of UDFs, implementing [`Self::return_type`] is sufficient,
+    /// as the result type is typically a deterministic function of the input types
+    /// (e.g., `sqrt(f32)` consistently yields `f32`). Implementing this method directly
+    /// is generally unnecessary unless the return type depends on runtime values.
+    ///
+    /// This function can be used for more advanced cases such as:
+    ///
+    /// 1. specifying nullability
+    /// 2. return types based on the **values** of the arguments (rather than
+    ///    their **types**).
+    ///
+    /// # Example creating `Field`
+    ///
+    /// Note the name of the [`Field`] is ignored, except for structured types such as
+    /// `DataType::Struct`.
+    ///
+    /// ```rust
+    /// # use std::sync::Arc;
+    /// # use arrow::datatypes::{DataType, Field, FieldRef};
+    /// # use datafusion_common::Result;
+    /// # use datafusion_expr::LambdaReturnFieldArgs;
+    /// # struct Example{}
+    /// # impl Example {
+    /// fn return_field_from_args(&self, args: LambdaReturnFieldArgs) -> Result {
+    ///   // report output is only nullable if any one of the arguments are nullable
+    ///   let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
+    ///   let field = Arc::new(Field::new("ignored_name", DataType::Int32, true));
+    ///   Ok(field)
+    /// }
+    /// # }
+    /// ```
+    ///
+    /// # Output Type based on Values
+    ///
+    /// For example, the following two function calls get the same argument
+    /// types (something and a `Utf8` string) but return different types based
+    /// on the value of the second argument:
+    ///
+    /// * `arrow_cast(x, 'Int16')` --> `Int16`
+    /// * `arrow_cast(x, 'Float32')` --> `Float32`
+    ///
+    /// # Requirements
+    ///
+    /// This function **must** consistently return the same type for the same
+    /// logical input even if the input is simplified (e.g. it must return the same
+    /// value for `('foo' | 'bar')` as it does for `('foobar')`).
+ fn return_field_from_args(&self, args: LambdaReturnFieldArgs) -> Result; + + /// Invoke the function returning the appropriate result. + /// + /// # Performance + /// + /// For the best performance, the implementations should handle the common case + /// when one or more of their arguments are constant values (aka + /// [`ColumnarValue::Scalar`]). + /// + /// [`ColumnarValue::values_to_arrays`] can be used to convert the arguments + /// to arrays, which will likely be simpler code, but be slower. + fn invoke_with_args(&self, args: LambdaFunctionArgs) -> Result; + + /// Optionally apply per-UDF simplification / rewrite rules. + /// + /// This can be used to apply function specific simplification rules during + /// optimization (e.g. `arrow_cast` --> `Expr::Cast`). The default + /// implementation does nothing. + /// + /// Note that DataFusion handles simplifying arguments and "constant + /// folding" (replacing a function call with constant arguments such as + /// `my_add(1,2) --> 3` ). Thus, there is no need to implement such + /// optimizations manually for specific UDFs. + /// + /// # Arguments + /// * `args`: The arguments of the function + /// * `info`: The necessary information for simplification + /// + /// # Returns + /// [`ExprSimplifyResult`] indicating the result of the simplification NOTE + /// if the function cannot be simplified, the arguments *MUST* be returned + /// unmodified + /// + /// # Notes + /// + /// The returned expression must have the same schema as the original + /// expression, including both the data type and nullability. For example, + /// if the original expression is nullable, the returned expression must + /// also be nullable, otherwise it may lead to schema verification errors + /// later in query planning. 
+    fn simplify(
+        &self,
+        args: Vec,
+        _info: &dyn SimplifyInfo,
+    ) -> Result {
+        Ok(ExprSimplifyResult::Original(args))
+    }
+
+    /// Returns true if some of this `exprs` subexpressions may not be evaluated
+    /// and thus any side effects (like divide by zero) may not be encountered.
+    ///
+    /// Setting this to true prevents certain optimizations such as common
+    /// subexpression elimination
+    ///
+    /// When overriding this function to return `true`, [LambdaUDF::conditional_arguments] can also be
+    /// overridden to report more accurately which arguments are eagerly evaluated and which ones
+    /// lazily.
+    fn short_circuits(&self) -> bool {
+        false
+    }
+
+    /// Determines which of the arguments passed to this function are evaluated eagerly
+    /// and which may be evaluated lazily.
+    ///
+    /// If this function returns `None`, all arguments are eagerly evaluated.
+    /// Returning `None` is a micro optimization that saves a needless `Vec`
+    /// allocation.
+    ///
+    /// If the function returns `Some`, returns (`eager`, `lazy`) where `eager`
+    /// are the arguments that are always evaluated, and `lazy` are the
+    /// arguments that may be evaluated lazily (i.e. may not be evaluated at all
+    /// in some cases).
+    ///
+    /// Implementations must ensure that the two returned `Vec`s are disjoint,
+    /// and that each argument from `args` is present in one of the two `Vec`s.
+    ///
+    /// When overriding this function, [LambdaUDF::short_circuits] must
+    /// be overridden to return `true`.
+    fn conditional_arguments<'a>(
+        &self,
+        args: &'a [Expr],
+    ) -> Option<(Vec<&'a Expr>, Vec<&'a Expr>)> {
+        if self.short_circuits() {
+            Some((vec![], args.iter().collect()))
+        } else {
+            None
+        }
+    }
+
+    /// Computes the output [`Interval`] for a [`LambdaUDF`], given the input
+    /// intervals.
+    ///
+    /// # Parameters
+    ///
+    /// * `children` are the intervals for the children (inputs) of this function.
+ /// + /// # Example + /// + /// If the function is `ABS(a)`, and the input interval is `a: [-3, 2]`, + /// then the output interval would be `[0, 3]`. + fn evaluate_bounds(&self, _input: &[&Interval]) -> Result { + // We cannot assume the input datatype is the same of output type. + Interval::make_unbounded(&DataType::Null) + } + + /// Updates bounds for child expressions, given a known [`Interval`]s for this + /// function. + /// + /// This function is used to propagate constraints down through an + /// expression tree. + /// + /// # Parameters + /// + /// * `interval` is the currently known interval for this function. + /// * `inputs` are the current intervals for the inputs (children) of this function. + /// + /// # Returns + /// + /// A `Vec` of new intervals for the children, in order. + /// + /// If constraint propagation reveals an infeasibility for any child, returns + /// [`None`]. If none of the children intervals change as a result of + /// propagation, may return an empty vector instead of cloning `children`. + /// This is the default (and conservative) return value. + /// + /// # Example + /// + /// If the function is `ABS(a)`, the current `interval` is `[4, 5]` and the + /// input `a` is given as `[-7, 3]`, then propagation would return `[-5, 3]`. + fn propagate_constraints( + &self, + _interval: &Interval, + _inputs: &[&Interval], + ) -> Result>> { + Ok(Some(vec![])) + } + + /// Calculates the [`SortProperties`] of this function based on its children's properties. + fn output_ordering(&self, inputs: &[ExprProperties]) -> Result { + if !self.preserves_lex_ordering(inputs)? 
{ + return Ok(SortProperties::Unordered); + } + + let Some(first_order) = inputs.first().map(|p| &p.sort_properties) else { + return Ok(SortProperties::Singleton); + }; + + if inputs + .iter() + .skip(1) + .all(|input| &input.sort_properties == first_order) + { + Ok(*first_order) + } else { + Ok(SortProperties::Unordered) + } + } + + /// Returns true if the function preserves lexicographical ordering based on + /// the input ordering. + /// + /// For example, `concat(a || b)` preserves lexicographical ordering, but `abs(a)` does not. + fn preserves_lex_ordering(&self, _inputs: &[ExprProperties]) -> Result { + Ok(false) + } + + /// Coerce arguments of a function call to types that the function can evaluate. + /// + /// See the [type coercion module](crate::type_coercion) + /// documentation for more details on type coercion + /// + /// [`TypeSignature`]: crate::TypeSignature + /// + /// For example, if your function requires a floating point arguments, but the user calls + /// it like `my_func(1::int)` (i.e. with `1` as an integer), coerce_types can return `[DataType::Float64]` + /// to ensure the argument is converted to `1::double` + /// + /// # Parameters + /// * `arg_types`: The argument types of the arguments this function with + /// + /// # Return value + /// A Vec the same length as `arg_types`. DataFusion will `CAST` the function call + /// arguments to these specific types. + fn coerce_types(&self, _arg_types: &[DataType]) -> Result> { + not_impl_err!("Function {} does not implement coerce_types", self.name()) + } + + /// Returns the documentation for this Lambda UDF. + /// + /// Documentation can be accessed programmatically as well as generating + /// publicly facing documentation. 
+ fn documentation(&self) -> Option<&Documentation> { + None + } + + /// Returns the parameters that any lambda supports + fn lambdas_parameters( + &self, + args: &[ValueOrLambdaParameter], + ) -> Result>>> { + Ok(vec![None; args.len()]) + } +} + +#[cfg(test)] +mod tests { + use datafusion_expr_common::signature::Volatility; + + use super::*; + use std::hash::DefaultHasher; + + #[derive(Debug, PartialEq, Eq, Hash)] + struct TestLambdaUDF { + name: &'static str, + field: &'static str, + signature: Signature, + } + impl LambdaUDF for TestLambdaUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_field_from_args(&self, _args: LambdaReturnFieldArgs) -> Result { + unimplemented!() + } + + fn invoke_with_args(&self, _args: LambdaFunctionArgs) -> Result { + unimplemented!() + } + } + + // PartialEq and Hash must be consistent, and also PartialEq and PartialOrd + // must be consistent, so they are tested together. 
+ #[test] + fn test_partial_eq_hash_and_partial_ord() { + // A parameterized function + let f = test_func("foo", "a"); + + // Same like `f`, different instance + let f2 = test_func("foo", "a"); + assert_eq!(&f, &f2); + assert_eq!(hash(&f), hash(&f2)); + assert_eq!(f.partial_cmp(&f2), Some(Ordering::Equal)); + + // Different parameter + let b = test_func("foo", "b"); + assert_ne!(&f, &b); + assert_ne!(hash(&f), hash(&b)); // hash can collide for different values but does not collide in this test + assert_eq!(f.partial_cmp(&b), None); + + // Different name + let o = test_func("other", "a"); + assert_ne!(&f, &o); + assert_ne!(hash(&f), hash(&o)); // hash can collide for different values but does not collide in this test + assert_eq!(f.partial_cmp(&o), Some(Ordering::Less)); + + // Different name and parameter + assert_ne!(&b, &o); + assert_ne!(hash(&b), hash(&o)); // hash can collide for different values but does not collide in this test + assert_eq!(b.partial_cmp(&o), Some(Ordering::Less)); + } + + fn test_func(name: &'static str, parameter: &'static str) -> Arc { + Arc::new(TestLambdaUDF { + name, + field: parameter, + signature: Signature::any(1, Volatility::Immutable), + }) + } + + fn hash(value: &T) -> u64 { + let hasher = &mut DefaultHasher::new(); + value.hash(hasher); + hasher.finish() + } +} diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index ab58b1c3f835f..e7beba8c4b090 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -308,6 +308,7 @@ pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet) -> Result<()> { | Expr::Wildcard { .. } | Expr::Placeholder(_) | Expr::OuterReferenceColumn { .. 
} + | Expr::LambdaFunction(_) | Expr::Lambda(_) | Expr::LambdaVariable(_) => {} } diff --git a/datafusion/functions-nested/src/array_transform.rs b/datafusion/functions-nested/src/array_transform.rs index 123df27b339be..746cfe6b62be5 100644 --- a/datafusion/functions-nested/src/array_transform.rs +++ b/datafusion/functions-nested/src/array_transform.rs @@ -26,24 +26,29 @@ use arrow::{ datatypes::{DataType, Field, FieldRef, Schema}, }; use datafusion_common::{ - HashMap, Result, exec_err, internal_err, tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter}, utils::{elements_indices, list_indices, list_values, take_function_args} + exec_err, + tree_node::{ + Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter, + }, + utils::{elements_indices, list_indices, list_values, take_function_args}, + HashMap, Result, }; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + ColumnarValue, Documentation, LambdaFunctionArgs, LambdaUDF, Signature, ValueOrLambda, ValueOrLambdaField, ValueOrLambdaParameter, Volatility, }; use datafusion_macros::user_doc; -use datafusion_physical_expr::expressions::{LambdaVariable, LambdaExpr}; +use datafusion_physical_expr::expressions::{LambdaExpr, LambdaVariable}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use std::{any::Any, sync::Arc}; -make_udf_expr_and_func!( - ArrayTransform, - array_transform, - array lambda, - "transforms the values of a array", - array_transform_udf -); +//make_udf_expr_and_func!( +// ArrayTransform, +// array_transform, +// array lambda, +// "transforms the values of a array", +// array_transform_udf +//); #[user_doc( doc_section(label = "Array Functions"), @@ -84,7 +89,7 @@ impl ArrayTransform { } } -impl ScalarUDFImpl for ArrayTransform { +impl LambdaUDF for ArrayTransform { fn as_any(&self) -> &dyn Any { self } @@ -101,18 +106,12 @@ impl ScalarUDFImpl for ArrayTransform { 
&self.signature } - fn return_type(&self, _arg_types: &[DataType]) -> Result { - internal_err!("return_type called instead of return_field_from_args") - } - fn return_field_from_args( &self, - args: datafusion_expr::ReturnFieldArgs, + args: datafusion_expr::LambdaReturnFieldArgs, ) -> Result> { - let args = args.to_lambda_args(); - let [ValueOrLambdaField::Value(list), ValueOrLambdaField::Lambda(lambda)] = - take_function_args(self.name(), &args)? + take_function_args(self.name(), args.arg_fields)? else { return exec_err!( "{} expects a value follewed by a lambda, got {:?}", @@ -141,10 +140,8 @@ impl ScalarUDFImpl for ArrayTransform { Ok(Arc::new(Field::new("", return_type, list.is_nullable()))) } - fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - // args.lambda_args allows the convenient match below, instead of inspecting both args.args and args.lambdas - let lambda_args = args.to_lambda_args(); - let [list_value, lambda] = take_function_args(self.name(), &lambda_args)?; + fn invoke_with_args(&self, args: LambdaFunctionArgs) -> Result { + let [list_value, lambda] = take_function_args(self.name(), &args.args)?; let (ValueOrLambda::Value(list_value), ValueOrLambda::Lambda(lambda)) = (list_value, lambda) @@ -152,7 +149,7 @@ impl ScalarUDFImpl for ArrayTransform { return exec_err!( "{} expects a value followed by a lambda, got {:?}", self.name(), - &lambda_args + &args.args ); }; @@ -243,8 +240,7 @@ impl ScalarUDFImpl for ArrayTransform { &self, args: &[ValueOrLambdaParameter], ) -> Result>>> { - let [ValueOrLambdaParameter::Value(list), ValueOrLambdaParameter::Lambda] = - args + let [ValueOrLambdaParameter::Value(list), ValueOrLambdaParameter::Lambda] = args else { return exec_err!( "{} expects a value follewed by a lambda, got {:?}", @@ -305,7 +301,7 @@ impl TreeNodeRewriter for BindLambdaVariable<'_> { if let Some((value, shadows)) = self.columns.get(lambda_variable.name()) { if *shadows == 0 { return Ok(Transformed::yes(Arc::new( - 
lambda_variable.clone().with_value(value.clone()), + lambda_variable.clone().with_value(Arc::clone(value)), ))); } } @@ -317,7 +313,7 @@ impl TreeNodeRewriter for BindLambdaVariable<'_> { } if self.columns.values().all(|(_value, shadows)| *shadows > 0) { - return Ok(Transformed::new(node, false, TreeNodeRecursion::Jump)) + return Ok(Transformed::new(node, false, TreeNodeRecursion::Jump)); } } diff --git a/datafusion/functions-nested/src/lib.rs b/datafusion/functions-nested/src/lib.rs index 55acf24ba4657..c93a55cce1a4f 100644 --- a/datafusion/functions-nested/src/lib.rs +++ b/datafusion/functions-nested/src/lib.rs @@ -79,7 +79,7 @@ pub mod expr_fn { pub use super::array_has::array_has; pub use super::array_has::array_has_all; pub use super::array_has::array_has_any; - pub use super::array_transform::array_transform; + //pub use super::array_transform::array_transform; pub use super::cardinality::cardinality; pub use super::concat::array_append; pub use super::concat::array_concat; @@ -147,7 +147,7 @@ pub fn all_default_nested_functions() -> Vec> { array_has::array_has_udf(), array_has::array_has_all_udf(), array_has::array_has_any_udf(), - array_transform::array_transform_udf(), + //array_transform::array_transform_udf(), empty::array_empty_udf(), length::array_length_udf(), distance::array_distance_udf(), diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 626c2ba550594..e0b1d9096b415 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -598,6 +598,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { | Expr::GroupingSet(_) | Expr::Placeholder(_) | Expr::OuterReferenceColumn(_, _) + | Expr::LambdaFunction(_) | Expr::Lambda(_) | Expr::LambdaVariable(_) => Ok(Transformed::no(expr)), } diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index b7e8626aa7bd5..63314b48facfd 
100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -288,6 +288,7 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result { | Expr::TryCast(_) | Expr::InList { .. } | Expr::ScalarFunction(_) + | Expr::LambdaFunction(_) | Expr::Lambda(_) | Expr::LambdaVariable(_) => Ok(TreeNodeRecursion::Continue), // TODO: remove the next line after `Expr::Wildcard` is removed diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 74794115755ba..8e09fbbae48d8 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -36,6 +36,7 @@ use datafusion_common::{ exec_datafusion_err, internal_err, DFSchema, DataFusionError, Result, ScalarValue, }; +use datafusion_expr::expr::LambdaFunction; use datafusion_expr::{ and, binary::BinaryTypeCoercer, lit, or, BinaryExpr, Case, ColumnarValue, Expr, Like, Operator, Volatility, @@ -659,6 +660,9 @@ impl<'a> ConstEvaluator<'a> { Expr::ScalarFunction(ScalarFunction { func, .. }) => { Self::volatility_ok(func.signature().volatility) } + Expr::LambdaFunction(LambdaFunction { func, .. }) => { + Self::volatility_ok(func.signature().volatility) + } Expr::Literal(_, _) | Expr::Alias(..) 
| Expr::Unnest(_) diff --git a/datafusion/optimizer/tests/optimizer_integration.rs b/datafusion/optimizer/tests/optimizer_integration.rs index c0f48b8ebfc40..11a656f2abb4c 100644 --- a/datafusion/optimizer/tests/optimizer_integration.rs +++ b/datafusion/optimizer/tests/optimizer_integration.rs @@ -728,6 +728,10 @@ impl ContextProvider for MyContextProvider { None } + fn get_lambda_meta(&self, _name: &str) -> Option> { + None + } + fn get_aggregate_meta(&self, name: &str) -> Option> { self.udafs.get(name).cloned() } diff --git a/datafusion/physical-expr/src/lambda_function.rs b/datafusion/physical-expr/src/lambda_function.rs new file mode 100644 index 0000000000000..0b0c33cdd6be9 --- /dev/null +++ b/datafusion/physical-expr/src/lambda_function.rs @@ -0,0 +1,530 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Declaration of built-in (lambda) functions. +//! This module contains built-in functions' enumeration and metadata. +//! +//! Generally, a function has: +//! * a signature +//! * a return type, that is a function of the incoming argument's types +//! * the computation, that must accept each valid signature +//! +//! * Signature: see `Signature` +//! * Return type: a function `(arg_types) -> return_type`. 
E.g. for sqrt, ([f32]) -> f32, ([f64]) -> f64. +//! +//! This module also has a set of coercion rules to improve user experience: if an argument i32 is passed +//! to a function that supports f64, it is coerced to f64. + +use std::any::Any; +use std::fmt::{self, Debug, Formatter}; +use std::hash::{Hash, Hasher}; +use std::sync::Arc; + +use crate::expressions::{LambdaExpr, Literal}; +use crate::PhysicalExpr; + +use arrow::array::{Array, NullArray, RecordBatch}; +use arrow::datatypes::{DataType, Field, FieldRef, Schema}; +use datafusion_common::config::{ConfigEntry, ConfigOptions}; +use datafusion_common::{exec_err, internal_err, Result, ScalarValue}; +use datafusion_expr::interval_arithmetic::Interval; +use datafusion_expr::sort_properties::ExprProperties; +use datafusion_expr::{ + expr_vec_fmt, ColumnarValue, LambdaFunctionArgs, LambdaFunctionLambdaArg, + LambdaReturnFieldArgs, LambdaUDF, ValueOrLambda, ValueOrLambdaField, + ValueOrLambdaParameter, Volatility, +}; + +/// Physical expression of a lambda function +pub struct LambdaFunctionExpr { + fun: Arc, + name: String, + args: Vec>, + return_field: FieldRef, + config_options: Arc, +} + +impl Debug for LambdaFunctionExpr { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + f.debug_struct("LambdaFunctionExpr") + .field("fun", &"") + .field("name", &self.name) + .field("args", &self.args) + .field("return_field", &self.return_field) + .finish() + } +} + +impl LambdaFunctionExpr { + /// Create a new Lambda function + pub fn new( + name: &str, + fun: Arc, + args: Vec>, + return_field: FieldRef, + config_options: Arc, + ) -> Self { + Self { + fun, + name: name.to_owned(), + args, + return_field, + config_options, + } + } + + /// Create a new Lambda function + pub fn try_new( + fun: Arc, + args: Vec>, + schema: &Schema, + config_options: Arc, + ) -> Result { + let name = fun.name().to_string(); + let arg_fields = args + .iter() + .map(|e| { + let field = e.return_field(schema)?; + match e.as_any().downcast_ref::() { 
+ Some(_lambda) => Ok(ValueOrLambdaField::Lambda(field)), + None => Ok(ValueOrLambdaField::Value(field)), + } + }) + .collect::>>()?; + + // TODO: verify that input data types is consistent with function's `TypeSignature` + + let arguments = args + .iter() + .map(|e| { + e.as_any() + .downcast_ref::() + .map(|literal| literal.value()) + }) + .collect::>(); + + let ret_args = LambdaReturnFieldArgs { + arg_fields: &arg_fields, + scalar_arguments: &arguments, + }; + + let return_field = fun.return_field_from_args(ret_args)?; + + Ok(Self { + fun, + name, + args, + return_field, + config_options, + }) + } + + /// Get the lambda function implementation + pub fn fun(&self) -> &dyn LambdaUDF { + self.fun.as_ref() + } + + /// The name for this expression + pub fn name(&self) -> &str { + &self.name + } + + /// Input arguments + pub fn args(&self) -> &[Arc] { + &self.args + } + + /// Data type produced by this expression + pub fn return_type(&self) -> &DataType { + self.return_field.data_type() + } + + pub fn with_nullable(mut self, nullable: bool) -> Self { + self.return_field = self + .return_field + .as_ref() + .clone() + .with_nullable(nullable) + .into(); + self + } + + pub fn nullable(&self) -> bool { + self.return_field.is_nullable() + } + + pub fn config_options(&self) -> &ConfigOptions { + &self.config_options + } + + /// Given an arbitrary PhysicalExpr attempt to downcast it to a LambdaFunctionExpr + /// and verify that its inner function is of type T. + /// If the downcast fails, or the function is not of type T, returns `None`. + /// Otherwise returns `Some(LambdaFunctionExpr)`. 
+ pub fn try_downcast_func(expr: &dyn PhysicalExpr) -> Option<&LambdaFunctionExpr> + where + T: 'static, + { + match expr.as_any().downcast_ref::() { + Some(lambda_expr) + if lambda_expr.fun().as_any().downcast_ref::().is_some() => + { + Some(lambda_expr) + } + _ => None, + } + } +} + +impl fmt::Display for LambdaFunctionExpr { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + write!(f, "{}({})", self.name, expr_vec_fmt!(self.args)) + } +} + +impl PartialEq for LambdaFunctionExpr { + fn eq(&self, o: &Self) -> bool { + if std::ptr::eq(self, o) { + // The equality implementation is somewhat expensive, so let's short-circuit when possible. + return true; + } + let Self { + fun, + name, + args, + return_field, + config_options, + } = self; + fun.eq(&o.fun) + && name.eq(&o.name) + && args.eq(&o.args) + && return_field.eq(&o.return_field) + && (Arc::ptr_eq(config_options, &o.config_options) + || sorted_config_entries(config_options) + == sorted_config_entries(&o.config_options)) + } +} +impl Eq for LambdaFunctionExpr {} +impl Hash for LambdaFunctionExpr { + fn hash(&self, state: &mut H) { + let Self { + fun, + name, + args, + return_field, + config_options: _, // expensive to hash, and often equal + } = self; + fun.hash(state); + name.hash(state); + args.hash(state); + return_field.hash(state); + } +} + +fn sorted_config_entries(config_options: &ConfigOptions) -> Vec { + let mut entries = config_options.entries(); + entries.sort_by(|l, r| l.key.cmp(&r.key)); + entries +} + +impl PhysicalExpr for LambdaFunctionExpr { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn data_type(&self, _input_schema: &Schema) -> Result { + Ok(self.return_field.data_type().clone()) + } + + fn nullable(&self, _input_schema: &Schema) -> Result { + Ok(self.return_field.is_nullable()) + } + + fn evaluate(&self, batch: &RecordBatch) -> Result { + let arg_fields = self + .args + .iter() + .map(|e| { + let field = 
e.return_field(batch.schema_ref())?; + match e.as_any().downcast_ref::() { + Some(_lambda) => Ok(ValueOrLambdaField::Lambda(field)), + None => Ok(ValueOrLambdaField::Value(field)), + } + }) + .collect::>>()?; + + let args_metadata = arg_fields.iter() + .map(|field| match field { + ValueOrLambdaField::Value(field) => ValueOrLambdaParameter::Value(Arc::clone(field)), + ValueOrLambdaField::Lambda(_field) => ValueOrLambdaParameter::Lambda, + }) + .collect::>(); + + let params = self.fun().lambdas_parameters(&args_metadata)?; + + let args = std::iter::zip(&self.args, params) + .map(|(arg, lambda_params)| { + match (arg.as_any().downcast_ref::(), lambda_params) { + (Some(lambda), Some(lambda_params)) => { + if lambda.params().len() > lambda_params.len() { + return exec_err!( + "lambda defined {} params but UDF support only {}", + lambda.params().len(), + lambda_params.len() + ); + } + + let captures = lambda.captures(); + + let params = std::iter::zip(lambda.params(), lambda_params) + .map(|(name, param)| Arc::new(param.with_name(name))) + .collect(); + + let captures = if !captures.is_empty() { + let (fields, columns): (Vec<_>, _) = std::iter::zip( + batch.schema_ref().fields(), + batch.columns(), + ) + .enumerate() + .map(|(column_index, (field, column))| { + if captures.contains(&column_index) { + (Arc::clone(field), Arc::clone(column)) + } else { + ( + Arc::new(Field::new( + field.name(), + DataType::Null, + false, + )), + Arc::new(NullArray::new(column.len())) as _, + ) + } + }) + .unzip(); + + let schema = Arc::new(Schema::new(fields)); + + Some(RecordBatch::try_new(schema, columns)?) 
+ } else { + None + }; + + Ok(ValueOrLambda::Lambda(LambdaFunctionLambdaArg { + params, + body: Arc::clone(lambda.body()), + captures, + })) + } + (Some(_lambda), None) => exec_err!( + "{} don't reported the parameters of one of it's lambdas", + self.fun.name() + ), + (None, Some(_lambda_params)) => exec_err!( + "{} reported parameters for an argument that is not a lambda", + self.fun.name() + ), + (None, None) => Ok(ValueOrLambda::Value(arg.evaluate(batch)?)), + } + }) + .collect::>>()?; + + let input_empty = args.is_empty(); + let input_all_scalar = args + .iter() + .all(|arg| matches!(arg, ValueOrLambda::Value(ColumnarValue::Scalar(_)))); + + // evaluate the function + let output = self.fun.invoke_with_args(LambdaFunctionArgs { + args, + arg_fields, + number_rows: batch.num_rows(), + return_field: Arc::clone(&self.return_field), + config_options: Arc::clone(&self.config_options), + })?; + + if let ColumnarValue::Array(array) = &output { + if array.len() != batch.num_rows() { + // If the arguments are a non-empty slice of scalar values, we can assume that + // returning a one-element array is equivalent to returning a scalar. + let preserve_scalar = + array.len() == 1 && !input_empty && input_all_scalar; + return if preserve_scalar { + ScalarValue::try_from_array(array, 0).map(ColumnarValue::Scalar) + } else { + internal_err!("UDF {} returned a different number of rows than expected. 
Expected: {}, Got: {}", + self.name, batch.num_rows(), array.len()) + }; + } + } + Ok(output) + } + + fn return_field(&self, _input_schema: &Schema) -> Result { + Ok(Arc::clone(&self.return_field)) + } + + fn children(&self) -> Vec<&Arc> { + self.args.iter().collect() + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + Ok(Arc::new(LambdaFunctionExpr::new( + &self.name, + Arc::clone(&self.fun), + children, + Arc::clone(&self.return_field), + Arc::clone(&self.config_options), + ))) + } + + fn evaluate_bounds(&self, children: &[&Interval]) -> Result { + self.fun.evaluate_bounds(children) + } + + fn propagate_constraints( + &self, + interval: &Interval, + children: &[&Interval], + ) -> Result>> { + self.fun.propagate_constraints(interval, children) + } + + fn get_properties(&self, children: &[ExprProperties]) -> Result { + let sort_properties = self.fun.output_ordering(children)?; + let preserves_lex_ordering = self.fun.preserves_lex_ordering(children)?; + let children_range = children + .iter() + .map(|props| &props.range) + .collect::>(); + let range = self.fun().evaluate_bounds(&children_range)?; + + Ok(ExprProperties { + sort_properties, + range, + preserves_lex_ordering, + }) + } + + fn fmt_sql(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "{}(", self.name)?; + for (i, expr) in self.args.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + expr.fmt_sql(f)?; + } + write!(f, ")") + } + + fn is_volatile_node(&self) -> bool { + self.fun.signature().volatility == Volatility::Volatile + } +} + +#[cfg(test)] +mod tests { + use std::any::Any; + use std::sync::Arc; + + use super::*; + use crate::expressions::Column; + use crate::LambdaFunctionExpr; + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_common::Result; + use datafusion_expr::{LambdaFunctionArgs, LambdaUDF, Signature}; + use datafusion_expr_common::columnar_value::ColumnarValue; + use datafusion_physical_expr_common::physical_expr::is_volatile; + 
use datafusion_physical_expr_common::physical_expr::PhysicalExpr; + + /// Test helper to create a mock UDF with a specific volatility + #[derive(Debug, PartialEq, Eq, Hash)] + struct MockLambdaUDF { + signature: Signature, + } + + impl LambdaUDF for MockLambdaUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "mock_function" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_field_from_args( + &self, + _args: LambdaReturnFieldArgs, + ) -> Result { + Ok(Arc::new(Field::new("", DataType::Int32, false))) + } + + fn invoke_with_args(&self, _args: LambdaFunctionArgs) -> Result { + Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(42)))) + } + } + + #[test] + fn test_lambda_function_volatile_node() { + // Create a volatile UDF + let volatile_udf = Arc::new(MockLambdaUDF { + signature: Signature::uniform( + 1, + vec![DataType::Float32], + Volatility::Volatile, + ), + }); + + // Create a non-volatile UDF + let stable_udf = Arc::new(MockLambdaUDF { + signature: Signature::uniform(1, vec![DataType::Float32], Volatility::Stable), + }); + + let schema = Schema::new(vec![Field::new("a", DataType::Float32, false)]); + let args = vec![Arc::new(Column::new("a", 0)) as Arc]; + let config_options = Arc::new(ConfigOptions::new()); + + // Test volatile function + let volatile_expr = LambdaFunctionExpr::try_new( + volatile_udf, + args.clone(), + &schema, + Arc::clone(&config_options), + ) + .unwrap(); + + assert!(volatile_expr.is_volatile_node()); + let volatile_arc: Arc = Arc::new(volatile_expr); + assert!(is_volatile(&volatile_arc)); + + // Test non-volatile function + let stable_expr = + LambdaFunctionExpr::try_new(stable_udf, args, &schema, config_options) + .unwrap(); + + assert!(!stable_expr.is_volatile_node()); + let stable_arc: Arc = Arc::new(stable_expr); + assert!(!is_volatile(&stable_arc)); + } +} diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs index aa8c9e50fd71e..a05d24d2ba2c5 
100644 --- a/datafusion/physical-expr/src/lib.rs +++ b/datafusion/physical-expr/src/lib.rs @@ -31,6 +31,7 @@ pub mod binary_map { pub use datafusion_physical_expr_common::binary_map::{ArrowBytesSet, OutputType}; } pub mod async_scalar_function; +pub mod lambda_function; pub mod equivalence; pub mod expressions; pub mod intervals; @@ -70,6 +71,7 @@ pub use datafusion_physical_expr_common::sort_expr::{ pub use planner::{create_physical_expr, create_physical_exprs}; pub use scalar_function::ScalarFunctionExpr; +pub use lambda_function::LambdaFunctionExpr; pub use simplifier::PhysicalExprSimplifier; pub use utils::{conjunction, conjunction_opt, split_conjunction}; diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index 8a53aa81da8fa..f7be4aedf555e 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -18,7 +18,7 @@ use std::sync::Arc; use crate::expressions::{lambda_variable, LambdaExpr}; -use crate::ScalarFunctionExpr; +use crate::{LambdaFunctionExpr, ScalarFunctionExpr}; use crate::{ expressions::{self, binary, like, similar_to, Column, Literal}, PhysicalExpr, @@ -33,7 +33,7 @@ use datafusion_common::{ }; use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::expr::{ - Alias, Cast, InList, Lambda, LambdaVariable, Placeholder, ScalarFunction, + Alias, Cast, InList, Lambda, LambdaFunction, LambdaVariable, Placeholder, ScalarFunction }; use datafusion_expr::var_provider::is_system_variables; use datafusion_expr::var_provider::VarType; @@ -318,10 +318,6 @@ pub fn create_physical_expr( input_dfschema, execution_props, )?), - Expr::Lambda(Lambda { params, body }) => Ok(Arc::new(LambdaExpr::new( - params.clone(), - create_physical_expr(body, input_dfschema, execution_props)?, - ))), Expr::ScalarFunction(ScalarFunction { func, args }) => { let physical_args = create_physical_exprs(args, input_dfschema, execution_props)?; @@ -392,6 +388,26 @@ pub fn 
create_physical_expr( Expr::Placeholder(Placeholder { id, .. }) => { exec_err!("Placeholder '{id}' was not provided a value for execution.") } + Expr::LambdaFunction(LambdaFunction { func, args}) => { + let physical_args = + create_physical_exprs(args, input_dfschema, execution_props)?; + + let config_options = match execution_props.config_options.as_ref() { + Some(config_options) => Arc::clone(config_options), + None => Arc::new(ConfigOptions::default()), + }; + + Ok(Arc::new(LambdaFunctionExpr::try_new( + Arc::clone(func), + physical_args, + input_schema, + config_options, + )?)) + } + Expr::Lambda(Lambda { params, body }) => Ok(Arc::new(LambdaExpr::new( + params.clone(), + create_physical_expr(body, input_dfschema, execution_props)?, + ))), Expr::LambdaVariable(LambdaVariable { name, field, diff --git a/datafusion/proto/src/bytes/mod.rs b/datafusion/proto/src/bytes/mod.rs index 6eab2239015a7..e421ea11c4b2d 100644 --- a/datafusion/proto/src/bytes/mod.rs +++ b/datafusion/proto/src/bytes/mod.rs @@ -24,11 +24,11 @@ use crate::physical_plan::{ AsExecutionPlan, DefaultPhysicalExtensionCodec, PhysicalExtensionCodec, }; use crate::protobuf; -use datafusion_common::{plan_datafusion_err, Result}; +use datafusion_common::{not_impl_err, plan_datafusion_err, Result}; use datafusion_execution::TaskContext; use datafusion_expr::{ - create_udaf, create_udf, create_udwf, AggregateUDF, Expr, LogicalPlan, Volatility, - WindowUDF, + create_udaf, create_udf, create_udwf, AggregateUDF, Expr, LambdaUDF, LogicalPlan, + Signature, Volatility, WindowUDF, }; use prost::{ bytes::{Bytes, BytesMut}, @@ -167,6 +167,15 @@ impl Serializeable for Expr { ) } + fn register_udlf( + &mut self, + _udlf: Arc, + ) -> Result>> { + datafusion_common::internal_err!( + "register_udlf called in Placeholder Registry!" 
+ ) + } + fn expr_planners(&self) -> Vec> { vec![] } @@ -178,6 +187,51 @@ impl Serializeable for Expr { fn udwfs(&self) -> std::collections::HashSet { std::collections::HashSet::default() } + + fn udlfs(&self) -> std::collections::HashSet { + std::collections::HashSet::default() + } + + fn udlf(&self, name: &str) -> Result> { + #[derive(Debug, PartialEq, Eq, Hash)] + struct MockLambdaUDF { + name: String, + signature: Signature, + } + + impl LambdaUDF for MockLambdaUDF { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_field_from_args( + &self, + _args: datafusion_expr::LambdaReturnFieldArgs, + ) -> Result { + not_impl_err!("mock LambdaUDF") + } + + fn invoke_with_args( + &self, + _args: datafusion_expr::LambdaFunctionArgs, + ) -> Result { + not_impl_err!("mock LambdaUDF") + } + } + + Ok(Arc::new(MockLambdaUDF { + name: name.to_string(), + signature: Signature::variadic_any(Volatility::Immutable), + })) + } } Expr::from_bytes_with_registry(&bytes, &PlaceHolderRegistry)?; diff --git a/datafusion/proto/src/bytes/registry.rs b/datafusion/proto/src/bytes/registry.rs index 087e073db21af..98f4928457679 100644 --- a/datafusion/proto/src/bytes/registry.rs +++ b/datafusion/proto/src/bytes/registry.rs @@ -67,4 +67,14 @@ impl FunctionRegistry for NoRegistry { fn udwfs(&self) -> HashSet { HashSet::new() } + + fn udlfs(&self) -> HashSet { + HashSet::new() + } + + fn udlf(&self, name: &str) -> Result> { + plan_err!("No function registry provided to deserialize, so can not deserialize User Defined Lambda Function '{name}'") + } + + } diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index 9644c9f69feae..1122952771c79 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -63,7 +63,7 @@ use datafusion_expr::{ Statement, WindowUDF, }; use datafusion_expr::{ - 
AggregateUDF, DmlStatement, FetchType, RecursiveQuery, SkipType, TableSource, Unnest, + AggregateUDF, DmlStatement, FetchType, LambdaUDF, RecursiveQuery, SkipType, TableSource, Unnest }; use self::to_proto::{serialize_expr, serialize_exprs}; @@ -153,6 +153,14 @@ pub trait LogicalExtensionCodec: Debug + Send + Sync { fn try_encode_udf(&self, _node: &ScalarUDF, _buf: &mut Vec) -> Result<()> { Ok(()) } + + fn try_decode_udlf(&self, name: &str, _buf: &[u8]) -> Result> { + not_impl_err!("LogicalExtensionCodec is not provided for lambda function {name}") + } + + fn try_encode_udlf(&self, _node: &dyn LambdaUDF, _buf: &mut Vec) -> Result<()> { + Ok(()) + } fn try_decode_udaf(&self, name: &str, _buf: &[u8]) -> Result> { not_impl_err!( diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 41a9172fff276..e080411b49e95 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -622,6 +622,17 @@ pub fn serialize_expr( .unwrap_or(HashMap::new()), })), }, + Expr::LambdaFunction(func) => { + let mut buf = Vec::new(); + let _ = codec.try_encode_udlf(func.func.as_ref(), &mut buf); + protobuf::LogicalExprNode { + expr_type: Some(ExprType::ScalarUdfExpr(protobuf::ScalarUdfExprNode { + fun_name: func.name().to_string(), + fun_definition: (!buf.is_empty()).then_some(buf), + args: serialize_exprs(&func.args, codec)?, + })), + } + } Expr::Lambda(_) | Expr::LambdaVariable(_) => { return Err(Error::General( "Proto serialization error: Lambda not implemented".to_string(), diff --git a/datafusion/session/src/session.rs b/datafusion/session/src/session.rs index fd033172f224f..625b3fb77a4d8 100644 --- a/datafusion/session/src/session.rs +++ b/datafusion/session/src/session.rs @@ -22,7 +22,7 @@ use datafusion_execution::config::SessionConfig; use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_execution::TaskContext; use datafusion_expr::execution_props::ExecutionProps; 
-use datafusion_expr::{AggregateUDF, Expr, LogicalPlan, ScalarUDF, WindowUDF}; +use datafusion_expr::{AggregateUDF, Expr, LambdaUDF, LogicalPlan, ScalarUDF, WindowUDF}; use datafusion_physical_plan::{ExecutionPlan, PhysicalExpr}; use parking_lot::{Mutex, RwLock}; use std::any::Any; @@ -109,6 +109,9 @@ pub trait Session: Send + Sync { /// Return reference to scalar_functions fn scalar_functions(&self) -> &HashMap>; + + /// Return reference to lambda_functions + fn lambda_functions(&self) -> &HashMap>; /// Return reference to aggregate_functions fn aggregate_functions(&self) -> &HashMap>; @@ -149,6 +152,7 @@ impl From<&dyn Session> for TaskContext { state.session_id().to_string(), state.config().clone(), state.scalar_functions().clone(), + state.lambda_functions().clone(), state.aggregate_functions().clone(), state.window_functions().clone(), Arc::clone(state.runtime_env()), diff --git a/datafusion/sql/examples/sql.rs b/datafusion/sql/examples/sql.rs index 2c0bb86cd8087..3d2ff0528081c 100644 --- a/datafusion/sql/examples/sql.rs +++ b/datafusion/sql/examples/sql.rs @@ -22,7 +22,7 @@ use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::config::ConfigOptions; use datafusion_common::{plan_err, Result, TableReference}; use datafusion_expr::planner::ExprPlanner; -use datafusion_expr::WindowUDF; +use datafusion_expr::{LambdaUDF, WindowUDF}; use datafusion_expr::{ logical_plan::builder::LogicalTableSource, AggregateUDF, ScalarUDF, TableSource, }; @@ -138,6 +138,10 @@ impl ContextProvider for MyContextProvider { None } + fn get_lambda_meta(&self, _name: &str) -> Option> { + None + } + fn get_aggregate_meta(&self, name: &str) -> Option> { self.udafs.get(name).cloned() } diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 439e65e8f7e47..47f132d065980 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -24,7 +24,7 @@ use datafusion_common::{ internal_datafusion_err, 
internal_err, not_impl_err, plan_datafusion_err, plan_err, DFSchema, Dependency, Diagnostic, Result, Span, }; -use datafusion_expr::expr::{Lambda, ScalarFunction, Unnest}; +use datafusion_expr::expr::{Lambda, LambdaFunction, ScalarFunction, Unnest}; use datafusion_expr::expr::{NullTreatment, WildcardOptions, WindowFunction}; use datafusion_expr::planner::PlannerResult; use datafusion_expr::planner::{RawAggregateExpr, RawWindowExpr}; @@ -277,8 +277,8 @@ impl SqlToRel<'_, S> { } } } - // User-defined function (UDF) should have precedence - if let Some(fm) = self.context_provider.get_function_meta(&name) { + + if let Some(fm) = self.context_provider.get_lambda_meta(&name) { enum ExprOrLambda { ExprWithName((Expr, Option)), Lambda(sqlparser::ast::LambdaFunction), @@ -312,7 +312,7 @@ impl SqlToRel<'_, S> { }) .collect::>>()?; - let lambdas_parameters = fm.inner().lambdas_parameters(&metadata)?; + let lambdas_parameters = fm.lambdas_parameters(&metadata)?; let pairs = pairs .into_iter() @@ -385,6 +385,55 @@ impl SqlToRel<'_, S> { args }; + // After resolution, all arguments are positional + let inner = LambdaFunction::new(fm, resolved_args); + + if name.eq_ignore_ascii_case(inner.name()) { + return Ok(Expr::LambdaFunction(inner)); + } else { + // If the function is called by an alias, a verbose string representation is created + // (e.g., "my_alias(arg1, arg2)") and the expression is wrapped in an `Alias` + // to ensure the output column name matches the user's query. 
+ let arg_names = inner + .args + .iter() + .map(|arg| arg.to_string()) + .collect::>() + .join(","); + let verbose_alias = format!("{name}({arg_names})"); + + return Ok(Expr::LambdaFunction(inner).alias(verbose_alias)); + } + } + + // User-defined function (UDF) should have precedence + if let Some(fm) = self.context_provider.get_function_meta(&name) { + let (args, arg_names): (Vec, Vec>) = args + .into_iter() + .map(|a| { + self.sql_fn_arg_to_logical_expr_with_name(a, schema, planner_context) + }) + .collect::>>()? + .into_iter() + .unzip(); + + let resolved_args = if arg_names.iter().any(|name| name.is_some()) { + if let Some(param_names) = &fm.signature().parameter_names { + datafusion_expr::arguments::resolve_function_arguments( + param_names, + args, + arg_names, + )? + } else { + return plan_err!( + "Function '{}' does not support named arguments", + fm.name() + ); + } + } else { + args + }; + // After resolution, all arguments are positional let inner = ScalarFunction::new_udf(fm, resolved_args); diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index 715a02db8b027..e51f1c04cf157 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -1207,7 +1207,7 @@ mod tests { use datafusion_common::config::ConfigOptions; use datafusion_common::TableReference; use datafusion_expr::logical_plan::builder::LogicalTableSource; - use datafusion_expr::{AggregateUDF, ScalarUDF, TableSource, WindowUDF}; + use datafusion_expr::{AggregateUDF, LambdaUDF, ScalarUDF, TableSource, WindowUDF}; use super::*; @@ -1247,6 +1247,10 @@ mod tests { None } + fn get_lambda_meta(&self, _name: &str) -> Option> { + None + } + fn get_aggregate_meta(&self, name: &str) -> Option> { match name { "sum" => Some(datafusion_functions_aggregate::sum::sum_udaf()), diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 013a6f1128957..3bd669dbec071 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ 
b/datafusion/sql/src/unparser/expr.rs @@ -15,12 +15,12 @@ // specific language governing permissions and limitations // under the License. -use datafusion_expr::expr::{AggregateFunctionParams, WindowFunctionParams}; +use datafusion_expr::expr::{AggregateFunctionParams, LambdaFunction, WindowFunctionParams}; use datafusion_expr::expr::{Lambda, Unnest}; use sqlparser::ast::Value::SingleQuotedString; use sqlparser::ast::{ self, Array, BinaryOperator, Expr as AstExpr, Function, Ident, Interval, - LambdaFunction, ObjectName, Subscript, TimezoneInfo, UnaryOperator, + ObjectName, Subscript, TimezoneInfo, UnaryOperator, }; use sqlparser::ast::{CaseWhen, DuplicateTreatment, OrderByOptions, ValueWithSpan}; use std::sync::Arc; @@ -528,8 +528,20 @@ impl Unparser<'_> { } Expr::OuterReferenceColumn(_, col) => self.col_to_sql(col), Expr::Unnest(unnest) => self.unnest_to_sql(unnest), + Expr::LambdaFunction(LambdaFunction { func, args }) => { + let func_name = func.name(); + + if let Some(expr) = self + .dialect + .scalar_function_to_sql_overrides(self, func_name, args)? 
+ { + return Ok(expr); + } + + self.scalar_function_to_sql(func_name, args) + } Expr::Lambda(Lambda { params, body }) => { - Ok(ast::Expr::Lambda(LambdaFunction { + Ok(ast::Expr::Lambda(ast::LambdaFunction { params: ast::OneOrManyWithParens::Many( params.iter().map(|param| param.as_str().into()).collect(), ), diff --git a/datafusion/sql/tests/common/mod.rs b/datafusion/sql/tests/common/mod.rs index 5d9fd9f2c3740..6c9ac4bf70046 100644 --- a/datafusion/sql/tests/common/mod.rs +++ b/datafusion/sql/tests/common/mod.rs @@ -26,7 +26,7 @@ use datafusion_common::config::ConfigOptions; use datafusion_common::file_options::file_type::FileType; use datafusion_common::{plan_err, DFSchema, GetExt, Result, TableReference}; use datafusion_expr::planner::{ExprPlanner, PlannerResult, TypePlanner}; -use datafusion_expr::{AggregateUDF, Expr, ScalarUDF, TableSource, WindowUDF}; +use datafusion_expr::{AggregateUDF, Expr, LambdaUDF, ScalarUDF, TableSource, WindowUDF}; use datafusion_functions_nested::expr_fn::make_array; use datafusion_sql::planner::ContextProvider; @@ -53,6 +53,7 @@ impl Display for MockCsvType { #[derive(Default)] pub(crate) struct MockSessionState { scalar_functions: HashMap>, + lambda_functions: HashMap>, aggregate_functions: HashMap>, expr_planners: Vec>, type_planner: Option>, @@ -240,6 +241,10 @@ impl ContextProvider for MockContextProvider { self.state.scalar_functions.get(name).cloned() } + fn get_lambda_meta(&self, name: &str) -> Option> { + self.state.lambda_functions.get(name).cloned() + } + fn get_aggregate_meta(&self, name: &str) -> Option> { self.state.aggregate_functions.get(name).cloned() } diff --git a/datafusion/substrait/src/logical_plan/producer/expr/mod.rs b/datafusion/substrait/src/logical_plan/producer/expr/mod.rs index b3ca88a690291..d1112b99536d9 100644 --- a/datafusion/substrait/src/logical_plan/producer/expr/mod.rs +++ b/datafusion/substrait/src/logical_plan/producer/expr/mod.rs @@ -152,8 +152,9 @@ pub fn to_substrait_rex( 
not_impl_err!("Cannot convert {expr:?} to Substrait") } Expr::Unnest(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), - Expr::Lambda(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), // https://github.com/substrait-io/substrait/issues/349 - Expr::LambdaVariable(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), // https://github.com/substrait-io/substrait/issues/349 + Expr::LambdaFunction(expr) => producer.handle_lambda_function(expr, schema), + Expr::Lambda(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), // not yet implemented in substrait-rs + Expr::LambdaVariable(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), // not yet implemented in substrait-rs } } diff --git a/datafusion/substrait/src/logical_plan/producer/expr/scalar_function.rs b/datafusion/substrait/src/logical_plan/producer/expr/scalar_function.rs index abb26f6f66822..b2057a9d914f8 100644 --- a/datafusion/substrait/src/logical_plan/producer/expr/scalar_function.rs +++ b/datafusion/substrait/src/logical_plan/producer/expr/scalar_function.rs @@ -26,17 +26,34 @@ pub fn from_scalar_function( producer: &mut impl SubstraitProducer, fun: &expr::ScalarFunction, schema: &DFSchemaRef, +) -> datafusion::common::Result { + from_function(producer, fun.name(), &fun.args, schema) +} + +pub fn from_lambda_function( + producer: &mut impl SubstraitProducer, + fun: &expr::LambdaFunction, + schema: &DFSchemaRef, +) -> datafusion::common::Result { + from_function(producer, fun.name(), &fun.args, schema) +} + +fn from_function( + producer: &mut impl SubstraitProducer, + name: &str, + args: &[Expr], + schema: &DFSchemaRef, ) -> datafusion::common::Result { let mut arguments: Vec = vec![]; - for arg in &fun.args { + for arg in args { arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(producer.handle_expr(arg, schema)?)), }); } - let arguments = custom_argument_handler(fun.name(), arguments); + let arguments = custom_argument_handler(name, 
arguments); - let function_anchor = producer.register_function(fun.name().to_string()); + let function_anchor = producer.register_function(name.to_string()); #[allow(deprecated)] Ok(Expression { rex_type: Some(RexType::ScalarFunction(ScalarFunction { diff --git a/datafusion/substrait/src/logical_plan/producer/substrait_producer.rs b/datafusion/substrait/src/logical_plan/producer/substrait_producer.rs index db08e0f7bfd0c..d065bcf41586a 100644 --- a/datafusion/substrait/src/logical_plan/producer/substrait_producer.rs +++ b/datafusion/substrait/src/logical_plan/producer/substrait_producer.rs @@ -17,12 +17,7 @@ use crate::extensions::Extensions; use crate::logical_plan::producer::{ - from_aggregate, from_aggregate_function, from_alias, from_between, from_binary_expr, - from_case, from_cast, from_column, from_distinct, from_empty_relation, from_filter, - from_in_list, from_in_subquery, from_join, from_like, from_limit, from_literal, - from_projection, from_repartition, from_scalar_function, from_sort, - from_subquery_alias, from_table_scan, from_try_cast, from_unary_expr, from_union, - from_values, from_window, from_window_function, to_substrait_rel, to_substrait_rex, + from_aggregate, from_aggregate_function, from_alias, from_between, from_binary_expr, from_case, from_cast, from_column, from_distinct, from_empty_relation, from_filter, from_in_list, from_in_subquery, from_join, from_lambda_function, from_like, from_limit, from_literal, from_projection, from_repartition, from_scalar_function, from_sort, from_subquery_alias, from_table_scan, from_try_cast, from_unary_expr, from_union, from_values, from_window, from_window_function, to_substrait_rel, to_substrait_rex }; use datafusion::common::{substrait_err, Column, DFSchemaRef, ScalarValue}; use datafusion::execution::registry::SerializerRegistry; @@ -327,6 +322,14 @@ pub trait SubstraitProducer: Send + Sync + Sized { ) -> datafusion::common::Result { from_scalar_function(self, scalar_fn, schema) } + + fn 
handle_lambda_function( + &mut self, + scalar_fn: &expr::LambdaFunction, + schema: &DFSchemaRef, + ) -> datafusion::common::Result { + from_lambda_function(self, scalar_fn, schema) + } fn handle_aggregate_function( &mut self, From 1f19c6406b97e05ffbd5088bfeaf5c0e7320a622 Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Mon, 23 Feb 2026 14:40:03 -0300 Subject: [PATCH 06/12] feat: remove lambda support for ScalarUDF --- datafusion/expr/src/expr_schema.rs | 6 - datafusion/expr/src/udf.rs | 105 ----------------- datafusion/ffi/src/udf/mod.rs | 8 +- datafusion/ffi/src/udf/return_type_args.rs | 9 +- datafusion/functions-nested/benches/map.rs | 1 - datafusion/functions-nested/src/array_has.rs | 2 - datafusion/functions-nested/src/map_values.rs | 1 - datafusion/functions-nested/src/set_ops.rs | 1 - datafusion/functions/benches/ascii.rs | 4 - .../functions/benches/character_length.rs | 4 - datafusion/functions/benches/chr.rs | 1 - datafusion/functions/benches/concat.rs | 1 - datafusion/functions/benches/cot.rs | 2 - datafusion/functions/benches/date_bin.rs | 1 - datafusion/functions/benches/date_trunc.rs | 1 - datafusion/functions/benches/encoding.rs | 4 - datafusion/functions/benches/find_in_set.rs | 4 - datafusion/functions/benches/gcd.rs | 3 - datafusion/functions/benches/initcap.rs | 3 - datafusion/functions/benches/isnan.rs | 2 - datafusion/functions/benches/iszero.rs | 2 - datafusion/functions/benches/lower.rs | 6 - datafusion/functions/benches/ltrim.rs | 1 - datafusion/functions/benches/make_date.rs | 4 - datafusion/functions/benches/nullif.rs | 1 - datafusion/functions/benches/pad.rs | 1 - datafusion/functions/benches/random.rs | 2 - datafusion/functions/benches/repeat.rs | 1 - datafusion/functions/benches/reverse.rs | 4 - datafusion/functions/benches/signum.rs | 2 - datafusion/functions/benches/strpos.rs | 4 - datafusion/functions/benches/substr.rs | 1 - datafusion/functions/benches/substr_index.rs | 1 - 
datafusion/functions/benches/to_char.rs | 6 - datafusion/functions/benches/to_hex.rs | 2 - datafusion/functions/benches/to_timestamp.rs | 6 - datafusion/functions/benches/trunc.rs | 2 - datafusion/functions/benches/upper.rs | 1 - datafusion/functions/benches/uuid.rs | 1 - .../functions/src/core/union_extract.rs | 3 - datafusion/functions/src/core/union_tag.rs | 2 - datafusion/functions/src/core/version.rs | 1 - datafusion/functions/src/datetime/date_bin.rs | 1 - .../functions/src/datetime/date_trunc.rs | 2 - .../functions/src/datetime/from_unixtime.rs | 2 - .../functions/src/datetime/make_date.rs | 1 - datafusion/functions/src/datetime/now.rs | 2 - datafusion/functions/src/datetime/to_char.rs | 7 -- datafusion/functions/src/datetime/to_date.rs | 1 - .../functions/src/datetime/to_local_time.rs | 2 - .../functions/src/datetime/to_timestamp.rs | 2 - datafusion/functions/src/math/log.rs | 18 --- datafusion/functions/src/math/power.rs | 2 - datafusion/functions/src/math/signum.rs | 2 - datafusion/functions/src/regex/regexpcount.rs | 1 - datafusion/functions/src/regex/regexpinstr.rs | 1 - datafusion/functions/src/string/concat.rs | 1 - datafusion/functions/src/string/concat_ws.rs | 2 - datafusion/functions/src/string/contains.rs | 1 - datafusion/functions/src/string/lower.rs | 1 - datafusion/functions/src/string/upper.rs | 1 - .../functions/src/unicode/find_in_set.rs | 1 - datafusion/functions/src/unicode/strpos.rs | 1 - datafusion/functions/src/utils.rs | 3 - .../src/async_scalar_function.rs | 2 - .../physical-expr/src/scalar_function.rs | 108 ++---------------- datafusion/spark/benches/char.rs | 1 - .../spark/src/function/bitmap/bitmap_count.rs | 1 - .../src/function/datetime/make_dt_interval.rs | 1 - .../src/function/datetime/make_interval.rs | 1 - .../spark/src/function/string/concat.rs | 2 - datafusion/spark/src/function/utils.rs | 3 - 72 files changed, 10 insertions(+), 381 deletions(-) diff --git a/datafusion/expr/src/expr_schema.rs 
b/datafusion/expr/src/expr_schema.rs index 6a3a7cbc85e76..f3789ca9fd115 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -595,15 +595,9 @@ impl ExprSchemable for Expr { }) .collect::>(); - let lambdas = args - .iter() - .map(|e| matches!(e, Expr::Lambda { .. })) - .collect::>(); - let args = ReturnFieldArgs { arg_fields: &new_fields, scalar_arguments: &arguments, - lambdas: &lambdas, }; func.return_field_from_args(args) diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 911fc890e2bc5..fd54bb13a62f3 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -23,13 +23,11 @@ use crate::simplify::{ExprSimplifyResult, SimplifyInfo}; use crate::sort_properties::{ExprProperties, SortProperties}; use crate::udf_eq::UdfEq; use crate::{ColumnarValue, Documentation, Expr, Signature}; -use arrow::array::RecordBatch; use arrow::datatypes::{DataType, Field, FieldRef}; use datafusion_common::config::ConfigOptions; use datafusion_common::{not_impl_err, ExprSchema, Result, ScalarValue}; use datafusion_expr_common::dyn_eq::{DynEq, DynHash}; use datafusion_expr_common::interval_arithmetic::Interval; -use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use std::any::Any; use std::cmp::Ordering; use std::fmt::Debug; @@ -347,14 +345,6 @@ impl ScalarUDF { } } -#[derive(Clone, Debug)] -pub enum ValueOrLambdaParameter { - /// A columnar value with the given field - Value(FieldRef), - /// A lambda - Lambda, -} - impl From for ScalarUDF where F: ScalarUDFImpl + 'static, @@ -369,7 +359,6 @@ where #[derive(Debug, Clone)] pub struct ScalarFunctionArgs { /// The evaluated arguments to the function - /// If it's a lambda, will be `ColumnarValue::Scalar(ScalarValue::Null)` pub args: Vec, /// Field associated with each arg, if it exists pub arg_fields: Vec, @@ -381,30 +370,6 @@ pub struct ScalarFunctionArgs { pub return_field: FieldRef, /// The config options at execution time pub config_options: Arc, 
- /// The lambdas passed to the function - /// If it's not a lambda it will be `None` - pub lambdas: Option>>, -} - -/// A lambda argument to a ScalarFunction -#[derive(Clone, Debug)] -pub struct ScalarFunctionLambdaArg { - /// The parameters defined in this lambda - /// - /// For example, for `array_transform([2], v -> -v)`, - /// this will be `vec![Field::new("v", DataType::Int32, true)]` - pub params: Vec, - /// The body of the lambda - /// - /// For example, for `array_transform([2], v -> -v)`, - /// this will be the physical expression of `-v` - pub body: Arc, - /// A RecordBatch containing at least the captured columns inside this lambda body, if any - /// Note that it may contain additional, non-specified columns, but that's implementation detail - /// - /// For example, for `array_transform([2], v -> v + a + b)`, - /// this will be a `RecordBatch` with two columns, `a` and `b` - pub captures: Option, } impl ScalarFunctionArgs { @@ -413,25 +378,6 @@ impl ScalarFunctionArgs { pub fn return_type(&self) -> &DataType { self.return_field.data_type() } - - pub fn to_lambda_args(&self) -> Vec> { - match &self.lambdas { - Some(lambdas) => std::iter::zip(&self.args, lambdas) - .map(|(arg, lambda)| match lambda { - Some(lambda) => ValueOrLambda::Lambda(lambda), - None => ValueOrLambda::Value(arg), - }) - .collect(), - None => self.args.iter().map(ValueOrLambda::Value).collect(), - } - } -} - -// An argument to a ScalarUDF that supports lambdas -#[derive(Debug)] -pub enum ValueOrLambda<'a> { - Value(&'a ColumnarValue), - Lambda(&'a ScalarFunctionLambdaArg), } /// Information about arguments passed to the function @@ -444,12 +390,6 @@ pub enum ValueOrLambda<'a> { #[derive(Debug)] pub struct ReturnFieldArgs<'a> { /// The data types of the arguments to the function - /// - /// If argument `i` to the function is a lambda, it will be the field returned by the - /// lambda when executed with the arguments returned from `ScalarUDFImpl::lambdas_parameters` - /// - /// For 
example, with `array_transform([1], v -> v == 5)` - /// this field will be `[Field::new("", DataType::List(DataType::Int32), false), Field::new("", DataType::Boolean, false)]` pub arg_fields: &'a [FieldRef], /// Is argument `i` to the function a scalar (constant)? /// @@ -458,36 +398,6 @@ pub struct ReturnFieldArgs<'a> { /// For example, if a function is called like `my_function(column_a, 5)` /// this field will be `[None, Some(ScalarValue::Int32(Some(5)))]` pub scalar_arguments: &'a [Option<&'a ScalarValue>], - /// Is argument `i` to the function a lambda? - /// - /// For example, with `array_transform([1], v -> v == 5)` - /// this field will be `[false, true]` - pub lambdas: &'a [bool], -} - -/// A tagged Field indicating whether it correspond to a value or a lambda argument -#[derive(Debug)] -pub enum ValueOrLambdaField<'a> { - /// The Field of a ColumnarValue argument - Value(&'a FieldRef), - /// The Field of the return of the lambda body when evaluated with the parameters from ScalarUDF::lambda_parameters - Lambda(&'a FieldRef), -} - -impl<'a> ReturnFieldArgs<'a> { - /// Based on self.lambdas, encodes self.arg_fields to tagged enums - /// indicating whether it correspond to a value or a lambda argument - pub fn to_lambda_args(&self) -> Vec> { - std::iter::zip(self.arg_fields, self.lambdas) - .map(|(field, is_lambda)| { - if *is_lambda { - ValueOrLambdaField::Lambda(field) - } else { - ValueOrLambdaField::Value(field) - } - }) - .collect() - } } /// Trait for implementing user defined scalar functions. @@ -931,14 +841,6 @@ pub trait ScalarUDFImpl: Debug + DynEq + DynHash + Send + Sync { fn documentation(&self) -> Option<&Documentation> { None } - - /// Returns the parameters that any lambda supports - fn lambdas_parameters( - &self, - args: &[ValueOrLambdaParameter], - ) -> Result>>> { - Ok(vec![None; args.len()]) - } } /// ScalarUDF that adds an alias to the underlying function. 
It is better to @@ -1057,13 +959,6 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl { fn documentation(&self) -> Option<&Documentation> { self.inner.documentation() } - - fn lambdas_parameters( - &self, - args: &[ValueOrLambdaParameter], - ) -> Result>>> { - self.inner.lambdas_parameters(args) - } } #[cfg(test)] diff --git a/datafusion/ffi/src/udf/mod.rs b/datafusion/ffi/src/udf/mod.rs index 400ad44696047..5e59cfc5ecb07 100644 --- a/datafusion/ffi/src/udf/mod.rs +++ b/datafusion/ffi/src/udf/mod.rs @@ -33,7 +33,7 @@ use arrow::{ }; use arrow_schema::FieldRef; use datafusion::config::ConfigOptions; -use datafusion::{common::exec_err, logical_expr::ReturnFieldArgs}; +use datafusion::logical_expr::ReturnFieldArgs; use datafusion::{ error::DataFusionError, logical_expr::type_coercion::functions::data_types_with_scalar_udf, @@ -210,7 +210,6 @@ unsafe extern "C" fn invoke_with_args_fn_wrapper( return_field, // TODO: pass config options: https://github.com/apache/datafusion/issues/17035 config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let result = rresult_return!(udf @@ -383,15 +382,10 @@ impl ScalarUDFImpl for ForeignScalarUDF { arg_fields, number_rows, return_field, - lambdas, // TODO: pass config options: https://github.com/apache/datafusion/issues/17035 config_options: _config_options, } = invoke_args; - if lambdas.is_some_and(|lambdas| lambdas.iter().any(|l| l.is_some())) { - return exec_err!("ForeignScalarUDF doesn't support lambdas"); - } - let args = args .into_iter() .map(|v| v.to_array(number_rows)) diff --git a/datafusion/ffi/src/udf/return_type_args.rs b/datafusion/ffi/src/udf/return_type_args.rs index d5cbfff1d3a4b..c437c9537be6f 100644 --- a/datafusion/ffi/src/udf/return_type_args.rs +++ b/datafusion/ffi/src/udf/return_type_args.rs @@ -21,7 +21,7 @@ use abi_stable::{ }; use arrow_schema::FieldRef; use datafusion::{ - common::{exec_datafusion_err, exec_err}, error::DataFusionError, logical_expr::ReturnFieldArgs, + 
common::exec_datafusion_err, error::DataFusionError, logical_expr::ReturnFieldArgs, scalar::ScalarValue, }; @@ -42,10 +42,6 @@ impl TryFrom> for FFI_ReturnFieldArgs { type Error = DataFusionError; fn try_from(value: ReturnFieldArgs) -> Result { - if value.lambdas.iter().any(|l| *l) { - return exec_err!("FFI_ReturnFieldArgs doesn't support lambdas") - } - let arg_fields = vec_fieldref_to_rvec_wrapped(value.arg_fields)?; let scalar_arguments: Result, Self::Error> = value .scalar_arguments @@ -81,7 +77,6 @@ pub struct ForeignReturnFieldArgsOwned { pub struct ForeignReturnFieldArgs<'a> { arg_fields: &'a [FieldRef], scalar_arguments: Vec>, - lambdas: Vec, // currently always false, used to return a reference in From<&Self> for ReturnFieldArgs } impl TryFrom<&FFI_ReturnFieldArgs> for ForeignReturnFieldArgsOwned { @@ -121,7 +116,6 @@ impl<'a> From<&'a ForeignReturnFieldArgsOwned> for ForeignReturnFieldArgs<'a> { .iter() .map(|opt| opt.as_ref()) .collect(), - lambdas: vec![false; value.arg_fields.len()] } } } @@ -131,7 +125,6 @@ impl<'a> From<&'a ForeignReturnFieldArgs<'a>> for ReturnFieldArgs<'a> { ReturnFieldArgs { arg_fields: value.arg_fields, scalar_arguments: &value.scalar_arguments, - lambdas: &value.lambdas, } } } diff --git a/datafusion/functions-nested/benches/map.rs b/datafusion/functions-nested/benches/map.rs index 3075d2e573e4a..3197cc55cc957 100644 --- a/datafusion/functions-nested/benches/map.rs +++ b/datafusion/functions-nested/benches/map.rs @@ -117,7 +117,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: 1, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), - lambdas: None, }) .expect("map should work on valid values"), ); diff --git a/datafusion/functions-nested/src/array_has.rs b/datafusion/functions-nested/src/array_has.rs index d6a333c0a0ef3..080b2f16d92f3 100644 --- a/datafusion/functions-nested/src/array_has.rs +++ b/datafusion/functions-nested/src/array_has.rs @@ -819,7 +819,6 @@ mod tests { 
number_rows: 1, return_field, config_options: Arc::new(ConfigOptions::default()), - lambdas: None, })?; let output = result.into_array(1)?; @@ -848,7 +847,6 @@ mod tests { number_rows: 1, return_field, config_options: Arc::new(ConfigOptions::default()), - lambdas: None, })?; let output = result.into_array(1)?; diff --git a/datafusion/functions-nested/src/map_values.rs b/datafusion/functions-nested/src/map_values.rs index ac21ff8acd3f9..6ae8a278063da 100644 --- a/datafusion/functions-nested/src/map_values.rs +++ b/datafusion/functions-nested/src/map_values.rs @@ -204,7 +204,6 @@ mod tests { let args = datafusion_expr::ReturnFieldArgs { arg_fields: &[field], scalar_arguments: &[None::<&ScalarValue>], - lambdas: &[false], }; func.return_field_from_args(args).unwrap() diff --git a/datafusion/functions-nested/src/set_ops.rs b/datafusion/functions-nested/src/set_ops.rs index f26fc173d8a9f..53642bf1622b0 100644 --- a/datafusion/functions-nested/src/set_ops.rs +++ b/datafusion/functions-nested/src/set_ops.rs @@ -596,7 +596,6 @@ mod tests { number_rows: 1, return_field: input_field.clone().into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, })?; assert_eq!( diff --git a/datafusion/functions/benches/ascii.rs b/datafusion/functions/benches/ascii.rs index 97e6ab20ed458..03d25e9c3d4fe 100644 --- a/datafusion/functions/benches/ascii.rs +++ b/datafusion/functions/benches/ascii.rs @@ -60,7 +60,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: N_ROWS, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), - lambdas: None, })) }) }, @@ -82,7 +81,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: N_ROWS, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), - lambdas: None, })) }) }, @@ -110,7 +108,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: N_ROWS, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), - lambdas: None, })) }) 
}, @@ -132,7 +129,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: N_ROWS, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), - lambdas: None, })) }) }, diff --git a/datafusion/functions/benches/character_length.rs b/datafusion/functions/benches/character_length.rs index f98e8a8b1a68b..4a1a63d62765f 100644 --- a/datafusion/functions/benches/character_length.rs +++ b/datafusion/functions/benches/character_length.rs @@ -55,7 +55,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: n_rows, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), - lambdas: None, })) }) }, @@ -80,7 +79,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: n_rows, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), - lambdas: None, })) }) }, @@ -105,7 +103,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: n_rows, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), - lambdas: None, })) }) }, @@ -130,7 +127,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: n_rows, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), - lambdas: None, })) }) }, diff --git a/datafusion/functions/benches/chr.rs b/datafusion/functions/benches/chr.rs index d51cda4566d64..8356cf7c31726 100644 --- a/datafusion/functions/benches/chr.rs +++ b/datafusion/functions/benches/chr.rs @@ -69,7 +69,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), - lambdas: None, }) .unwrap(), ) diff --git a/datafusion/functions/benches/concat.rs b/datafusion/functions/benches/concat.rs index 6378328537827..09200139a244b 100644 --- a/datafusion/functions/benches/concat.rs +++ b/datafusion/functions/benches/concat.rs @@ -60,7 +60,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, 
return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), - lambdas: None, }) .unwrap(), ) diff --git a/datafusion/functions/benches/cot.rs b/datafusion/functions/benches/cot.rs index 56f50522acc5d..97f21ccd6d55e 100644 --- a/datafusion/functions/benches/cot.rs +++ b/datafusion/functions/benches/cot.rs @@ -54,7 +54,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Float32, true).into(), config_options: Arc::clone(&config_options), - lambdas: None, }) .unwrap(), ) @@ -81,7 +80,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), - lambdas: None, }) .unwrap(), ) diff --git a/datafusion/functions/benches/date_bin.rs b/datafusion/functions/benches/date_bin.rs index 1c3713723738a..74390491d538c 100644 --- a/datafusion/functions/benches/date_bin.rs +++ b/datafusion/functions/benches/date_bin.rs @@ -66,7 +66,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), - lambdas: None, }) .expect("date_bin should work on valid values"), ) diff --git a/datafusion/functions/benches/date_trunc.rs b/datafusion/functions/benches/date_trunc.rs index b757535fb03c5..498a3e63ef290 100644 --- a/datafusion/functions/benches/date_trunc.rs +++ b/datafusion/functions/benches/date_trunc.rs @@ -71,7 +71,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), - lambdas: None, }) .expect("date_trunc should work on valid values"), ) diff --git a/datafusion/functions/benches/encoding.rs b/datafusion/functions/benches/encoding.rs index 72b033cf5d9ed..98faee91e1911 100644 --- a/datafusion/functions/benches/encoding.rs +++ b/datafusion/functions/benches/encoding.rs @@ -45,7 +45,6 @@ fn 
criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), - lambdas: None, }) .unwrap(); @@ -64,7 +63,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), - lambdas: None, }) .unwrap(), ) @@ -84,7 +82,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), - lambdas: None, }) .unwrap(); @@ -104,7 +101,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), - lambdas: None, }) .unwrap(), ) diff --git a/datafusion/functions/benches/find_in_set.rs b/datafusion/functions/benches/find_in_set.rs index 6fe498a58d84b..a928f5655806c 100644 --- a/datafusion/functions/benches/find_in_set.rs +++ b/datafusion/functions/benches/find_in_set.rs @@ -168,7 +168,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: n_rows, return_field: Arc::clone(&return_field), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, })) }) }); @@ -187,7 +186,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: n_rows, return_field: Arc::clone(&return_field), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, })) }) }); @@ -210,7 +208,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: n_rows, return_field: Arc::clone(&return_field), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, })) }) }); @@ -231,7 +228,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: n_rows, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), - lambdas: None, })) }) }); diff --git a/datafusion/functions/benches/gcd.rs b/datafusion/functions/benches/gcd.rs index 
2bfec91e290dd..19e196d9a3eab 100644 --- a/datafusion/functions/benches/gcd.rs +++ b/datafusion/functions/benches/gcd.rs @@ -58,7 +58,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: 0, return_field: Field::new("f", DataType::Int64, true).into(), config_options: Arc::clone(&config_options), - lambdas: None, }) .expect("date_bin should work on valid values"), ) @@ -80,7 +79,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: 0, return_field: Field::new("f", DataType::Int64, true).into(), config_options: Arc::clone(&config_options), - lambdas: None, }) .expect("date_bin should work on valid values"), ) @@ -102,7 +100,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: 0, return_field: Field::new("f", DataType::Int64, true).into(), config_options: Arc::clone(&config_options), - lambdas: None, }) .expect("date_bin should work on valid values"), ) diff --git a/datafusion/functions/benches/initcap.rs b/datafusion/functions/benches/initcap.rs index 37d98596deb82..50aee8dbb9161 100644 --- a/datafusion/functions/benches/initcap.rs +++ b/datafusion/functions/benches/initcap.rs @@ -70,7 +70,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Utf8View, true).into(), config_options: Arc::clone(&config_options), - lambdas: None, })) }) }, @@ -87,7 +86,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Utf8View, true).into(), config_options: Arc::clone(&config_options), - lambdas: None, })) }) }, @@ -102,7 +100,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), - lambdas: None, })) }) }); diff --git a/datafusion/functions/benches/isnan.rs b/datafusion/functions/benches/isnan.rs index dcce59e46ce41..4a90d45d66223 100644 --- a/datafusion/functions/benches/isnan.rs +++ b/datafusion/functions/benches/isnan.rs @@ -53,7 
+53,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Boolean, true).into(), config_options: Arc::clone(&config_options), - lambdas: None, }) .unwrap(), ) @@ -78,7 +77,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Boolean, true).into(), config_options: Arc::clone(&config_options), - lambdas: None, }) .unwrap(), ) diff --git a/datafusion/functions/benches/iszero.rs b/datafusion/functions/benches/iszero.rs index 574539fbb6427..961cba7200ce0 100644 --- a/datafusion/functions/benches/iszero.rs +++ b/datafusion/functions/benches/iszero.rs @@ -55,7 +55,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), - lambdas: None, }) .unwrap(), ) @@ -83,7 +82,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), - lambdas: None, }) .unwrap(), ) diff --git a/datafusion/functions/benches/lower.rs b/datafusion/functions/benches/lower.rs index e741afd0d8e01..6a5178b87fdce 100644 --- a/datafusion/functions/benches/lower.rs +++ b/datafusion/functions/benches/lower.rs @@ -145,7 +145,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), - lambdas: None, })) }) }); @@ -168,7 +167,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), - lambdas: None, })) }) }); @@ -193,7 +191,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), - lambdas: None, })) }) }, @@ -228,7 +225,6 @@ fn criterion_benchmark(c: &mut Criterion) 
{ number_rows: size, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), - lambdas: None, })) }), ); @@ -244,7 +240,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), - lambdas: None, })) }), ); @@ -261,7 +256,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), - lambdas: None, })) }), ); diff --git a/datafusion/functions/benches/ltrim.rs b/datafusion/functions/benches/ltrim.rs index 9b344cc6b143a..4458af614396d 100644 --- a/datafusion/functions/benches/ltrim.rs +++ b/datafusion/functions/benches/ltrim.rs @@ -153,7 +153,6 @@ fn run_with_string_type( number_rows: size, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), - lambdas: None, })) }) }, diff --git a/datafusion/functions/benches/make_date.rs b/datafusion/functions/benches/make_date.rs index 2a681ddedcbe8..15a895468db93 100644 --- a/datafusion/functions/benches/make_date.rs +++ b/datafusion/functions/benches/make_date.rs @@ -81,7 +81,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), - lambdas: None, }) .expect("make_date should work on valid values"), ) @@ -112,7 +111,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), - lambdas: None, }) .expect("make_date should work on valid values"), ) @@ -143,7 +141,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), - lambdas: None, }) .expect("make_date should work on valid values"), ) @@ -171,7 +168,6 @@ fn 
criterion_benchmark(c: &mut Criterion) { number_rows: 1, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), - lambdas: None, }) .expect("make_date should work on valid values"), ) diff --git a/datafusion/functions/benches/nullif.rs b/datafusion/functions/benches/nullif.rs index 15914cd7ee6c5..d649697cc5188 100644 --- a/datafusion/functions/benches/nullif.rs +++ b/datafusion/functions/benches/nullif.rs @@ -54,7 +54,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), - lambdas: None, }) .unwrap(), ) diff --git a/datafusion/functions/benches/pad.rs b/datafusion/functions/benches/pad.rs index c7d46da3d26c6..f92a69bbf4f92 100644 --- a/datafusion/functions/benches/pad.rs +++ b/datafusion/functions/benches/pad.rs @@ -116,7 +116,6 @@ fn invoke_pad_with_args( number_rows, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), - lambdas: None, }; if left_pad { diff --git a/datafusion/functions/benches/random.rs b/datafusion/functions/benches/random.rs index 2935876685800..88efb2d1b5b93 100644 --- a/datafusion/functions/benches/random.rs +++ b/datafusion/functions/benches/random.rs @@ -43,7 +43,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: 8192, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), - lambdas: None, }) .unwrap(), ); @@ -65,7 +64,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: 128, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), - lambdas: None, }) .unwrap(), ); diff --git a/datafusion/functions/benches/repeat.rs b/datafusion/functions/benches/repeat.rs index 9a7c63ed4f304..80ffa8ee38f1a 100644 --- a/datafusion/functions/benches/repeat.rs +++ b/datafusion/functions/benches/repeat.rs @@ -76,7 +76,6 @@ fn invoke_repeat_with_args( number_rows: repeat_times as 
usize, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), - lambdas: None, }) } diff --git a/datafusion/functions/benches/reverse.rs b/datafusion/functions/benches/reverse.rs index a8af40cd8cc19..b1eca654fb254 100644 --- a/datafusion/functions/benches/reverse.rs +++ b/datafusion/functions/benches/reverse.rs @@ -58,7 +58,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: N_ROWS, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), - lambdas: None, })) }) }, @@ -81,7 +80,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: N_ROWS, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), - lambdas: None, })) }) }, @@ -109,7 +107,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: N_ROWS, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), - lambdas: None, })) }) }, @@ -134,7 +131,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: N_ROWS, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), - lambdas: None, })) }) }, diff --git a/datafusion/functions/benches/signum.rs b/datafusion/functions/benches/signum.rs index 805b62c83da6d..24b8861e4d28c 100644 --- a/datafusion/functions/benches/signum.rs +++ b/datafusion/functions/benches/signum.rs @@ -55,7 +55,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), - lambdas: None, }) .unwrap(), ) @@ -84,7 +83,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), - lambdas: None, }) .unwrap(), ) diff --git a/datafusion/functions/benches/strpos.rs b/datafusion/functions/benches/strpos.rs index 708ebb5518727..18a99e44bf487 100644 
--- a/datafusion/functions/benches/strpos.rs +++ b/datafusion/functions/benches/strpos.rs @@ -128,7 +128,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: n_rows, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), - lambdas: None, })) }) }, @@ -147,7 +146,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: n_rows, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), - lambdas: None, })) }) }); @@ -167,7 +165,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: n_rows, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), - lambdas: None, })) }) }, @@ -188,7 +185,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: n_rows, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), - lambdas: None, })) }) }, diff --git a/datafusion/functions/benches/substr.rs b/datafusion/functions/benches/substr.rs index 58fda73defd25..771413458c1fb 100644 --- a/datafusion/functions/benches/substr.rs +++ b/datafusion/functions/benches/substr.rs @@ -116,7 +116,6 @@ fn invoke_substr_with_args( number_rows, return_field: Field::new("f", DataType::Utf8View, true).into(), config_options: Arc::clone(&config_options), - lambdas: None, }) } diff --git a/datafusion/functions/benches/substr_index.rs b/datafusion/functions/benches/substr_index.rs index a77b961657c5f..d0941d9baedda 100644 --- a/datafusion/functions/benches/substr_index.rs +++ b/datafusion/functions/benches/substr_index.rs @@ -110,7 +110,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), - lambdas: None, }) .expect("substr_index should work on valid values"), ) diff --git a/datafusion/functions/benches/to_char.rs b/datafusion/functions/benches/to_char.rs index 61990b4cb8b95..945508aec7405 100644 --- a/datafusion/functions/benches/to_char.rs 
+++ b/datafusion/functions/benches/to_char.rs @@ -149,7 +149,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), - lambdas: None, }) .expect("to_char should work on valid values"), ) @@ -177,7 +176,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), - lambdas: None, }) .expect("to_char should work on valid values"), ) @@ -205,7 +203,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), - lambdas: None, }) .expect("to_char should work on valid values"), ) @@ -232,7 +229,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), - lambdas: None, }) .expect("to_char should work on valid values"), ) @@ -260,7 +256,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), - lambdas: None, }) .expect("to_char should work on valid values"), ) @@ -293,7 +288,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: 1, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), - lambdas: None, }) .expect("to_char should work on valid values"), ) diff --git a/datafusion/functions/benches/to_hex.rs b/datafusion/functions/benches/to_hex.rs index baa2de80c466f..a75ed9258791e 100644 --- a/datafusion/functions/benches/to_hex.rs +++ b/datafusion/functions/benches/to_hex.rs @@ -44,7 +44,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Field::new("f", DataType::Utf8, true).into(), 
config_options: Arc::clone(&config_options), - lambdas: None, }) .unwrap(), ) @@ -63,7 +62,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), - lambdas: None, }) .unwrap(), ) diff --git a/datafusion/functions/benches/to_timestamp.rs b/datafusion/functions/benches/to_timestamp.rs index e510a7c3fad41..a8f5c5816d4da 100644 --- a/datafusion/functions/benches/to_timestamp.rs +++ b/datafusion/functions/benches/to_timestamp.rs @@ -130,7 +130,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), - lambdas: None, }) .expect("to_timestamp should work on valid values"), ) @@ -151,7 +150,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), - lambdas: None, }) .expect("to_timestamp should work on valid values"), ) @@ -172,7 +170,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), - lambdas: None, }) .expect("to_timestamp should work on valid values"), ) @@ -206,7 +203,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), - lambdas: None, }) .expect("to_timestamp should work on valid values"), ) @@ -248,7 +244,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), - lambdas: None, }) .expect("to_timestamp should work on valid values"), ) @@ -291,7 +286,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), - lambdas: None, }) 
.expect("to_timestamp should work on valid values"), ) diff --git a/datafusion/functions/benches/trunc.rs b/datafusion/functions/benches/trunc.rs index 0b08791f9ae50..6e225e0e7038b 100644 --- a/datafusion/functions/benches/trunc.rs +++ b/datafusion/functions/benches/trunc.rs @@ -49,7 +49,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), - lambdas: None, }) .unwrap(), ) @@ -69,7 +68,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), - lambdas: None, }) .unwrap(), ) diff --git a/datafusion/functions/benches/upper.rs b/datafusion/functions/benches/upper.rs index e9f0941032d8a..7328b32574a4a 100644 --- a/datafusion/functions/benches/upper.rs +++ b/datafusion/functions/benches/upper.rs @@ -50,7 +50,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), - lambdas: None, })) }) }); diff --git a/datafusion/functions/benches/uuid.rs b/datafusion/functions/benches/uuid.rs index 8ad79b2866eaf..1368e2f2af5d1 100644 --- a/datafusion/functions/benches/uuid.rs +++ b/datafusion/functions/benches/uuid.rs @@ -37,7 +37,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: 1024, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), - lambdas: None, })) }) }); diff --git a/datafusion/functions/src/core/union_extract.rs b/datafusion/functions/src/core/union_extract.rs index ac542866f7e43..7f93500b9cfb9 100644 --- a/datafusion/functions/src/core/union_extract.rs +++ b/datafusion/functions/src/core/union_extract.rs @@ -209,7 +209,6 @@ mod tests { number_rows: 1, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, })?; 
assert_scalar(result, ScalarValue::Utf8(None)); @@ -233,7 +232,6 @@ mod tests { number_rows: 1, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, })?; assert_scalar(result, ScalarValue::Utf8(None)); @@ -257,7 +255,6 @@ mod tests { number_rows: 1, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, })?; assert_scalar(result, ScalarValue::new_utf8("42")); diff --git a/datafusion/functions/src/core/union_tag.rs b/datafusion/functions/src/core/union_tag.rs index 4832c368872bf..aeadb8292ba1e 100644 --- a/datafusion/functions/src/core/union_tag.rs +++ b/datafusion/functions/src/core/union_tag.rs @@ -184,7 +184,6 @@ mod tests { return_field: Field::new("res", return_type, true).into(), arg_fields: vec![], config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }) .unwrap(); @@ -208,7 +207,6 @@ mod tests { return_field: Field::new("res", return_type, true).into(), arg_fields: vec![], config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }) .unwrap(); diff --git a/datafusion/functions/src/core/version.rs b/datafusion/functions/src/core/version.rs index 390111028c8f2..ef3c5aafa4801 100644 --- a/datafusion/functions/src/core/version.rs +++ b/datafusion/functions/src/core/version.rs @@ -112,7 +112,6 @@ mod test { number_rows: 0, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }) .unwrap(); diff --git a/datafusion/functions/src/datetime/date_bin.rs b/datafusion/functions/src/datetime/date_bin.rs index 5466129314640..92af123dbafac 100644 --- a/datafusion/functions/src/datetime/date_bin.rs +++ b/datafusion/functions/src/datetime/date_bin.rs @@ -530,7 +530,6 @@ mod tests { number_rows, return_field: Arc::clone(return_field), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; 
DateBinFunc::new().invoke_with_args(args) } diff --git a/datafusion/functions/src/datetime/date_trunc.rs b/datafusion/functions/src/datetime/date_trunc.rs index 5736c221cae84..913e6217af82d 100644 --- a/datafusion/functions/src/datetime/date_trunc.rs +++ b/datafusion/functions/src/datetime/date_trunc.rs @@ -892,7 +892,6 @@ mod tests { ) .into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let result = DateTruncFunc::new().invoke_with_args(args).unwrap(); if let ColumnarValue::Array(result) = result { @@ -1081,7 +1080,6 @@ mod tests { ) .into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let result = DateTruncFunc::new().invoke_with_args(args).unwrap(); if let ColumnarValue::Array(result) = result { diff --git a/datafusion/functions/src/datetime/from_unixtime.rs b/datafusion/functions/src/datetime/from_unixtime.rs index be44be094e5b7..5d6adfb6f119a 100644 --- a/datafusion/functions/src/datetime/from_unixtime.rs +++ b/datafusion/functions/src/datetime/from_unixtime.rs @@ -179,7 +179,6 @@ mod test { number_rows: 1, return_field: Field::new("f", DataType::Timestamp(Second, None), true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let result = FromUnixtimeFunc::new().invoke_with_args(args).unwrap(); @@ -213,7 +212,6 @@ mod test { ) .into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let result = FromUnixtimeFunc::new().invoke_with_args(args).unwrap(); diff --git a/datafusion/functions/src/datetime/make_date.rs b/datafusion/functions/src/datetime/make_date.rs index afa4ef132147a..0fe5d156a8383 100644 --- a/datafusion/functions/src/datetime/make_date.rs +++ b/datafusion/functions/src/datetime/make_date.rs @@ -250,7 +250,6 @@ mod tests { number_rows, return_field: Field::new("f", DataType::Date32, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; MakeDateFunc::new().invoke_with_args(args) } diff --git 
a/datafusion/functions/src/datetime/now.rs b/datafusion/functions/src/datetime/now.rs index f18e72a107e28..4723548a45584 100644 --- a/datafusion/functions/src/datetime/now.rs +++ b/datafusion/functions/src/datetime/now.rs @@ -163,7 +163,6 @@ mod tests { .return_field_from_args(ReturnFieldArgs { arg_fields: &empty_fields, scalar_arguments: &empty_scalars, - lambdas: &[], }) .expect("legacy now() return field"); @@ -171,7 +170,6 @@ mod tests { .return_field_from_args(ReturnFieldArgs { arg_fields: &empty_fields, scalar_arguments: &empty_scalars, - lambdas: &[], }) .expect("configured now() return field"); diff --git a/datafusion/functions/src/datetime/to_char.rs b/datafusion/functions/src/datetime/to_char.rs index 5d69ce233f643..7d9b2bc241e1a 100644 --- a/datafusion/functions/src/datetime/to_char.rs +++ b/datafusion/functions/src/datetime/to_char.rs @@ -375,7 +375,6 @@ mod tests { number_rows: batch_len, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&Arc::new(ConfigOptions::default())), - lambdas: None, }; let result = ToCharFunc::new() .invoke_with_args(args) @@ -481,7 +480,6 @@ mod tests { number_rows: 1, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let result = ToCharFunc::new() .invoke_with_args(args) @@ -576,7 +574,6 @@ mod tests { number_rows: batch_len, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let result = ToCharFunc::new() .invoke_with_args(args) @@ -741,7 +738,6 @@ mod tests { number_rows: batch_len, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let result = ToCharFunc::new() .invoke_with_args(args) @@ -770,7 +766,6 @@ mod tests { number_rows: batch_len, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: 
Arc::new(ConfigOptions::default()), - lambdas: None, }; let result = ToCharFunc::new() .invoke_with_args(args) @@ -796,7 +791,6 @@ mod tests { number_rows: 1, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let result = ToCharFunc::new().invoke_with_args(args); assert_eq!( @@ -818,7 +812,6 @@ mod tests { number_rows: 1, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let result = ToCharFunc::new().invoke_with_args(args); assert_eq!( diff --git a/datafusion/functions/src/datetime/to_date.rs b/datafusion/functions/src/datetime/to_date.rs index f6b313e6a28bb..3840c8d8bbb94 100644 --- a/datafusion/functions/src/datetime/to_date.rs +++ b/datafusion/functions/src/datetime/to_date.rs @@ -186,7 +186,6 @@ mod tests { number_rows, return_field: Field::new("f", DataType::Date32, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; ToDateFunc::new().invoke_with_args(args) } diff --git a/datafusion/functions/src/datetime/to_local_time.rs b/datafusion/functions/src/datetime/to_local_time.rs index 4d50a70d37236..6e0a150b0a35f 100644 --- a/datafusion/functions/src/datetime/to_local_time.rs +++ b/datafusion/functions/src/datetime/to_local_time.rs @@ -549,7 +549,6 @@ mod tests { number_rows: 1, return_field: Field::new("f", expected.data_type(), true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }) .unwrap(); match res { @@ -621,7 +620,6 @@ mod tests { ) .into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let result = ToLocalTimeFunc::new().invoke_with_args(args).unwrap(); if let ColumnarValue::Array(result) = result { diff --git a/datafusion/functions/src/datetime/to_timestamp.rs b/datafusion/functions/src/datetime/to_timestamp.rs index f35e170073030..0a0700097770f 100644 --- 
a/datafusion/functions/src/datetime/to_timestamp.rs +++ b/datafusion/functions/src/datetime/to_timestamp.rs @@ -1033,7 +1033,6 @@ mod tests { number_rows: 4, return_field: Field::new("f", rt, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let res = udf .invoke_with_args(args) @@ -1084,7 +1083,6 @@ mod tests { number_rows: 5, return_field: Field::new("f", rt, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let res = udf .invoke_with_args(args) diff --git a/datafusion/functions/src/math/log.rs b/datafusion/functions/src/math/log.rs index 1a73ed8436a68..f66f6fcfc1f88 100644 --- a/datafusion/functions/src/math/log.rs +++ b/datafusion/functions/src/math/log.rs @@ -370,7 +370,6 @@ mod tests { number_rows: 4, return_field: Field::new("f", DataType::Float64, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let result = LogFunc::new().invoke_with_args(args); assert!(result.is_err()); @@ -391,7 +390,6 @@ mod tests { number_rows: 1, return_field: Field::new("f", DataType::Float64, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let result = LogFunc::new().invoke_with_args(args); @@ -409,7 +407,6 @@ mod tests { number_rows: 1, return_field: Field::new("f", DataType::Float32, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let result = LogFunc::new() .invoke_with_args(args) @@ -440,7 +437,6 @@ mod tests { number_rows: 1, return_field: Field::new("f", DataType::Float64, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let result = LogFunc::new() .invoke_with_args(args) @@ -475,7 +471,6 @@ mod tests { number_rows: 1, return_field: Field::new("f", DataType::Float32, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let result = LogFunc::new() .invoke_with_args(args) @@ -510,7 +505,6 @@ mod tests { number_rows: 1, 
return_field: Field::new("f", DataType::Float64, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let result = LogFunc::new() .invoke_with_args(args) @@ -543,7 +537,6 @@ mod tests { number_rows: 4, return_field: Field::new("f", DataType::Float64, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let result = LogFunc::new() .invoke_with_args(args) @@ -579,7 +572,6 @@ mod tests { number_rows: 4, return_field: Field::new("f", DataType::Float32, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let result = LogFunc::new() .invoke_with_args(args) @@ -621,7 +613,6 @@ mod tests { number_rows: 5, return_field: Field::new("f", DataType::Float64, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let result = LogFunc::new() .invoke_with_args(args) @@ -664,7 +655,6 @@ mod tests { number_rows: 4, return_field: Field::new("f", DataType::Float32, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let result = LogFunc::new() .invoke_with_args(args) @@ -846,7 +836,6 @@ mod tests { number_rows: 1, return_field: Field::new("f", DataType::Decimal128(38, 0), true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let result = LogFunc::new() .invoke_with_args(args) @@ -880,7 +869,6 @@ mod tests { number_rows: 1, return_field: Field::new("f", DataType::Float64, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let result = LogFunc::new() .invoke_with_args(args) @@ -915,7 +903,6 @@ mod tests { number_rows: 6, return_field: Field::new("f", DataType::Float64, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let result = LogFunc::new() .invoke_with_args(args) @@ -960,7 +947,6 @@ mod tests { number_rows: 1, return_field: Field::new("f", DataType::Float64, true).into(), config_options: 
Arc::new(ConfigOptions::default()), - lambdas: None, }; let result = LogFunc::new() .invoke_with_args(args) @@ -1001,7 +987,6 @@ mod tests { number_rows: 1, return_field: Field::new("f", DataType::Float64, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let result = LogFunc::new() .invoke_with_args(args) @@ -1052,7 +1037,6 @@ mod tests { number_rows: 7, return_field: Field::new("f", DataType::Float64, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let result = LogFunc::new() .invoke_with_args(args) @@ -1094,7 +1078,6 @@ mod tests { number_rows: 1, return_field: Field::new("f", DataType::Float64, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let result = LogFunc::new().invoke_with_args(args); assert!(result.is_err()); @@ -1118,7 +1101,6 @@ mod tests { number_rows: 1, return_field: Field::new("f", DataType::Float64, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let result = LogFunc::new().invoke_with_args(args); assert!(result.is_err()); diff --git a/datafusion/functions/src/math/power.rs b/datafusion/functions/src/math/power.rs index 21a777abb3295..ad2e795d086e9 100644 --- a/datafusion/functions/src/math/power.rs +++ b/datafusion/functions/src/math/power.rs @@ -222,7 +222,6 @@ mod tests { number_rows: 4, return_field: Field::new("f", DataType::Float64, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let result = PowerFunc::new() .invoke_with_args(args) @@ -259,7 +258,6 @@ mod tests { number_rows: 4, return_field: Field::new("f", DataType::Int64, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let result = PowerFunc::new() .invoke_with_args(args) diff --git a/datafusion/functions/src/math/signum.rs b/datafusion/functions/src/math/signum.rs index d1d49b1bf6f90..bbe6178f39b79 100644 --- a/datafusion/functions/src/math/signum.rs +++ 
b/datafusion/functions/src/math/signum.rs @@ -173,7 +173,6 @@ mod test { number_rows: array.len(), return_field: Field::new("f", DataType::Float32, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let result = SignumFunc::new() .invoke_with_args(args) @@ -221,7 +220,6 @@ mod test { number_rows: array.len(), return_field: Field::new("f", DataType::Float64, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let result = SignumFunc::new() .invoke_with_args(args) diff --git a/datafusion/functions/src/regex/regexpcount.rs b/datafusion/functions/src/regex/regexpcount.rs index ee6f412bb9a16..8bad506217aa5 100644 --- a/datafusion/functions/src/regex/regexpcount.rs +++ b/datafusion/functions/src/regex/regexpcount.rs @@ -628,7 +628,6 @@ mod tests { number_rows: args.len(), return_field: Field::new("f", Int64, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }) } diff --git a/datafusion/functions/src/regex/regexpinstr.rs b/datafusion/functions/src/regex/regexpinstr.rs index 1e64f7087ea74..851c182a90dd0 100644 --- a/datafusion/functions/src/regex/regexpinstr.rs +++ b/datafusion/functions/src/regex/regexpinstr.rs @@ -494,7 +494,6 @@ mod tests { number_rows: args.len(), return_field: Arc::new(Field::new("f", Int64, true)), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }) } diff --git a/datafusion/functions/src/string/concat.rs b/datafusion/functions/src/string/concat.rs index 661bcfe4e0fd8..a93e70e714e8b 100644 --- a/datafusion/functions/src/string/concat.rs +++ b/datafusion/functions/src/string/concat.rs @@ -487,7 +487,6 @@ mod tests { number_rows: 3, return_field: Field::new("f", Utf8, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let result = ConcatFunc::new().invoke_with_args(args)?; diff --git a/datafusion/functions/src/string/concat_ws.rs b/datafusion/functions/src/string/concat_ws.rs index 
85704d6b2f468..cdd30ac8755ab 100644 --- a/datafusion/functions/src/string/concat_ws.rs +++ b/datafusion/functions/src/string/concat_ws.rs @@ -495,7 +495,6 @@ mod tests { number_rows: 3, return_field: Field::new("f", Utf8, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let result = ConcatWsFunc::new().invoke_with_args(args)?; @@ -533,7 +532,6 @@ mod tests { number_rows: 3, return_field: Field::new("f", Utf8, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let result = ConcatWsFunc::new().invoke_with_args(args)?; diff --git a/datafusion/functions/src/string/contains.rs b/datafusion/functions/src/string/contains.rs index 1edab4c6bf334..7e50676933c8d 100644 --- a/datafusion/functions/src/string/contains.rs +++ b/datafusion/functions/src/string/contains.rs @@ -177,7 +177,6 @@ mod test { number_rows: 2, return_field: Field::new("f", DataType::Boolean, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let actual = udf.invoke_with_args(args).unwrap(); diff --git a/datafusion/functions/src/string/lower.rs b/datafusion/functions/src/string/lower.rs index 099a3ffd44cc4..ee56a6a549857 100644 --- a/datafusion/functions/src/string/lower.rs +++ b/datafusion/functions/src/string/lower.rs @@ -113,7 +113,6 @@ mod tests { arg_fields, return_field: Field::new("f", Utf8, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let result = match func.invoke_with_args(args)? { diff --git a/datafusion/functions/src/string/upper.rs b/datafusion/functions/src/string/upper.rs index d7d2bde94b0a3..8bb2ec1d511cd 100644 --- a/datafusion/functions/src/string/upper.rs +++ b/datafusion/functions/src/string/upper.rs @@ -112,7 +112,6 @@ mod tests { arg_fields: vec![arg_field], return_field: Field::new("f", Utf8, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let result = match func.invoke_with_args(args)? 
{ diff --git a/datafusion/functions/src/unicode/find_in_set.rs b/datafusion/functions/src/unicode/find_in_set.rs index 219bd6eaa762c..fa68e539600b0 100644 --- a/datafusion/functions/src/unicode/find_in_set.rs +++ b/datafusion/functions/src/unicode/find_in_set.rs @@ -485,7 +485,6 @@ mod tests { number_rows: cardinality, return_field: Field::new("f", return_type, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }); assert!(result.is_ok()); diff --git a/datafusion/functions/src/unicode/strpos.rs b/datafusion/functions/src/unicode/strpos.rs index a3734b0c0de4f..4f238b2644bdf 100644 --- a/datafusion/functions/src/unicode/strpos.rs +++ b/datafusion/functions/src/unicode/strpos.rs @@ -336,7 +336,6 @@ mod tests { Field::new("f2", DataType::Utf8, substring_nullable).into(), ], scalar_arguments: &[None::<&ScalarValue>, None::<&ScalarValue>], - lambdas: &[false; 2], }; strpos.return_field_from_args(args).unwrap().is_nullable() diff --git a/datafusion/functions/src/utils.rs b/datafusion/functions/src/utils.rs index d6d56b32722de..932d61e8007cd 100644 --- a/datafusion/functions/src/utils.rs +++ b/datafusion/functions/src/utils.rs @@ -234,7 +234,6 @@ pub mod test { let return_field = func.return_field_from_args(datafusion_expr::ReturnFieldArgs { arg_fields: &field_array, scalar_arguments: &scalar_arguments_refs, - lambdas: &vec![false; scalar_arguments_refs.len()], }); let arg_fields = $ARGS.iter() .enumerate() @@ -253,7 +252,6 @@ pub mod test { arg_fields, number_rows: cardinality, return_field, - lambdas: None, config_options: $CONFIG_OPTIONS }); assert_eq!(result.is_ok(), true, "function returned an error: {}", result.unwrap_err()); @@ -276,7 +274,6 @@ pub mod test { arg_fields, number_rows: cardinality, return_field, - lambdas: None, config_options: $CONFIG_OPTIONS, }) { Ok(_) => assert!(false, "expected error"), diff --git a/datafusion/physical-expr/src/async_scalar_function.rs b/datafusion/physical-expr/src/async_scalar_function.rs index 
a34d3cda47682..b434694a20cc8 100644 --- a/datafusion/physical-expr/src/async_scalar_function.rs +++ b/datafusion/physical-expr/src/async_scalar_function.rs @@ -168,7 +168,6 @@ impl AsyncFuncExpr { number_rows: current_batch.num_rows(), return_field: Arc::clone(&self.return_field), config_options: Arc::clone(&config_options), - lambdas: None, }) .await?, ); @@ -188,7 +187,6 @@ impl AsyncFuncExpr { number_rows: batch.num_rows(), return_field: Arc::clone(&self.return_field), config_options: Arc::clone(&config_options), - lambdas: None, }) .await?, ); diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index 2527e84241fe3..6ad22671ba847 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -34,19 +34,19 @@ use std::fmt::{self, Debug, Formatter}; use std::hash::{Hash, Hasher}; use std::sync::Arc; -use crate::expressions::{LambdaExpr, Literal}; +use crate::expressions::Literal; use crate::PhysicalExpr; -use arrow::array::{Array, NullArray, RecordBatch}; -use arrow::datatypes::{DataType, Field, FieldRef, Schema}; +use arrow::array::{Array, RecordBatch}; +use arrow::datatypes::{DataType, FieldRef, Schema}; use datafusion_common::config::{ConfigEntry, ConfigOptions}; -use datafusion_common::{exec_err, internal_err, Result, ScalarValue}; +use datafusion_common::{internal_err, Result, ScalarValue}; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::sort_properties::ExprProperties; use datafusion_expr::type_coercion::functions::data_types_with_scalar_udf; use datafusion_expr::{ - expr_vec_fmt, ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, - ScalarFunctionLambdaArg, ScalarUDF, ValueOrLambdaParameter, Volatility, + expr_vec_fmt, ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, + Volatility, }; /// Physical expression of a scalar function @@ -117,15 +117,9 @@ impl ScalarFunctionExpr { }) .collect::>(); - let lambdas = 
args - .iter() - .map(|e| e.as_any().is::()) - .collect::>(); - let ret_args = ReturnFieldArgs { arg_fields: &arg_fields, scalar_arguments: &arguments, - lambdas: &lambdas, }; let return_field = fun.return_field_from_args(ret_args)?; @@ -270,10 +264,7 @@ impl PhysicalExpr for ScalarFunctionExpr { let args = self .args .iter() - .map(|e| match e.as_any().downcast_ref::() { - Some(_) => Ok(ColumnarValue::Scalar(ScalarValue::Null)), - None => Ok(e.evaluate(batch)?), - }) + .map(|e| e.evaluate(batch)) .collect::>>()?; let arg_fields = self @@ -287,89 +278,6 @@ impl PhysicalExpr for ScalarFunctionExpr { .iter() .all(|arg| matches!(arg, ColumnarValue::Scalar(_))); - let lambdas = if self.args.iter().any(|arg| arg.as_any().is::()) { - let args_metadata = std::iter::zip(&self.args, &arg_fields) - .map( - |(expr, field)| match expr.as_any().downcast_ref::() { - Some(_lambda) => ValueOrLambdaParameter::Lambda, - None => ValueOrLambdaParameter::Value(Arc::clone(field)), - }, - ) - .collect::>(); - - let params = self.fun().inner().lambdas_parameters(&args_metadata)?; - - let lambdas = std::iter::zip(&self.args, params) - .map(|(arg, lambda_params)| { - match (arg.as_any().downcast_ref::(), lambda_params) { - (Some(lambda), Some(lambda_params)) => { - if lambda.params().len() > lambda_params.len() { - return exec_err!( - "lambda defined {} params but UDF support only {}", - lambda.params().len(), - lambda_params.len() - ); - } - - let captures = lambda.captures(); - - let params = std::iter::zip(lambda.params(), lambda_params) - .map(|(name, param)| Arc::new(param.with_name(name))) - .collect(); - - let captures = if !captures.is_empty() { - let (fields, columns): (Vec<_>, _) = std::iter::zip( - batch.schema_ref().fields(), - batch.columns(), - ) - .enumerate() - .map(|(column_index, (field, column))| { - if captures.contains(&column_index) { - (Arc::clone(field), Arc::clone(column)) - } else { - ( - Arc::new(Field::new( - field.name(), - DataType::Null, - false, - )), - 
Arc::new(NullArray::new(column.len())) as _, - ) - } - }) - .unzip(); - - let schema = Arc::new(Schema::new(fields)); - - Some(RecordBatch::try_new(schema, columns)?) - } else { - None - }; - - Ok(Some(ScalarFunctionLambdaArg { - params, - body: Arc::clone(lambda.body()), - captures, - })) - } - (Some(_lambda), None) => exec_err!( - "{} don't reported the parameters of one of it's lambdas", - self.fun.name() - ), - (None, Some(_lambda_params)) => exec_err!( - "{} reported parameters for an argument that is not a lambda", - self.fun.name() - ), - _ => Ok(None), - } - }) - .collect::>>()?; - - Some(lambdas) - } else { - None - }; - // evaluate the function let output = self.fun.invoke_with_args(ScalarFunctionArgs { args, @@ -377,9 +285,9 @@ impl PhysicalExpr for ScalarFunctionExpr { number_rows: batch.num_rows(), return_field: Arc::clone(&self.return_field), config_options: Arc::clone(&self.config_options), - lambdas, })?; + if let ColumnarValue::Array(array) = &output { if array.len() != batch.num_rows() { // If the arguments are a non-empty slice of scalar values, we can assume that diff --git a/datafusion/spark/benches/char.rs b/datafusion/spark/benches/char.rs index 501bfd2a0186d..02eab7630d070 100644 --- a/datafusion/spark/benches/char.rs +++ b/datafusion/spark/benches/char.rs @@ -68,7 +68,6 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Arc::new(Field::new("f", DataType::Utf8, true)), config_options: Arc::clone(&config_options), - lambdas: None, }) .unwrap(), ) diff --git a/datafusion/spark/src/function/bitmap/bitmap_count.rs b/datafusion/spark/src/function/bitmap/bitmap_count.rs index e4c12ebe19665..56a9c5edb812c 100644 --- a/datafusion/spark/src/function/bitmap/bitmap_count.rs +++ b/datafusion/spark/src/function/bitmap/bitmap_count.rs @@ -217,7 +217,6 @@ mod tests { number_rows: 1, return_field: Field::new("f", Int64, true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; let udf = 
BitmapCount::new(); let actual = udf.invoke_with_args(args)?; diff --git a/datafusion/spark/src/function/datetime/make_dt_interval.rs b/datafusion/spark/src/function/datetime/make_dt_interval.rs index aaff5400d0c00..bbfba44861344 100644 --- a/datafusion/spark/src/function/datetime/make_dt_interval.rs +++ b/datafusion/spark/src/function/datetime/make_dt_interval.rs @@ -317,7 +317,6 @@ mod tests { number_rows, return_field: Field::new("f", Duration(Microsecond), true).into(), config_options: Arc::new(Default::default()), - lambdas: None, }; SparkMakeDtInterval::new().invoke_with_args(args) } diff --git a/datafusion/spark/src/function/datetime/make_interval.rs b/datafusion/spark/src/function/datetime/make_interval.rs index 9f98c4b5ce9fb..8e3169556b95b 100644 --- a/datafusion/spark/src/function/datetime/make_interval.rs +++ b/datafusion/spark/src/function/datetime/make_interval.rs @@ -516,7 +516,6 @@ mod tests { number_rows, return_field: Field::new("f", Interval(MonthDayNano), true).into(), config_options: Arc::new(ConfigOptions::default()), - lambdas: None, }; SparkMakeInterval::new().invoke_with_args(args) } diff --git a/datafusion/spark/src/function/string/concat.rs b/datafusion/spark/src/function/string/concat.rs index e2cd8d977fe29..0dcc58d5bb8ed 100644 --- a/datafusion/spark/src/function/string/concat.rs +++ b/datafusion/spark/src/function/string/concat.rs @@ -105,7 +105,6 @@ fn spark_concat(args: ScalarFunctionArgs) -> Result { number_rows, return_field, config_options, - lambdas, } = args; // Handle zero-argument case: return empty string @@ -131,7 +130,6 @@ fn spark_concat(args: ScalarFunctionArgs) -> Result { number_rows, return_field, config_options, - lambdas, }; let result = concat_func.invoke_with_args(func_args)?; diff --git a/datafusion/spark/src/function/utils.rs b/datafusion/spark/src/function/utils.rs index 1064acc342916..b939dabda388d 100644 --- a/datafusion/spark/src/function/utils.rs +++ b/datafusion/spark/src/function/utils.rs @@ -61,7 +61,6 @@ 
pub mod test { let return_field = func.return_field_from_args(datafusion_expr::ReturnFieldArgs { arg_fields: &arg_fields, scalar_arguments: &scalar_arguments_refs, - lambdas: &vec![false; arg_fields.len()], }); match expected { @@ -75,7 +74,6 @@ pub mod test { return_field, arg_fields: arg_fields.clone(), config_options: $CONFIG_OPTIONS, - lambdas: None, }) { Ok(col_value) => { match col_value.to_array(cardinality) { @@ -119,7 +117,6 @@ pub mod test { return_field: value, arg_fields, config_options: $CONFIG_OPTIONS, - lambdas: None, }) { Ok(_) => assert!(false, "expected error"), Err(error) => { From 570cc53367e9d4a4697c0dd4d462641f87670c4e Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Sun, 1 Mar 2026 02:13:40 -0300 Subject: [PATCH 07/12] temporarily add pr description as DOC.md --- DOC.md | 1166 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 1166 insertions(+) create mode 100644 DOC.md diff --git a/DOC.md b/DOC.md new file mode 100644 index 0000000000000..10a0ab4e19407 --- /dev/null +++ b/DOC.md @@ -0,0 +1,1166 @@ +This PR adds support for lambdas with column capture and the `array_transform` function used to test the lambda implementation. Example usage: + +```sql +CREATE TABLE t as SELECT 2 as n; + +SELECT array_transform([2, 3], v -> v != t.n) from t; + +[false, true] + +-- arbitrally nested lambdas are also supported +SELECT array_transform([[[2, 3]]], m -> array_transform(m, l -> array_transform(l, v -> v*2))); + +[[[4, 6]]] +``` + +Some comments on code snippets of this doc show what value each struct, variant or field would hold after planning the first example above. Some literals are simplified pseudo code + +3 new `Expr` variants are added, `LambdaFunction`, owing a new trait `LambdaUDF`, which is like a `ScalarFunction`/`ScalarUDFImpl` with support for lambdas, `Lambda`, for the lambda body and it's parameters names, and `LambdaVariable`, which is like `Column` but for lambdas parameters. 
The reasoning why not using `Column` instead is later on this doc. + +Their logical representations: + +```rust +enum Expr { + LambdaFunction(LambdaFunction), // array_transform([2, 3], v -> v != t.n) + Lambda(Lambda), // v -> v != t.n + LambdaVariable(LambdaVariable), // v, of the lambda body: v != t.n + ... +} + +// array_transform([2, 3], v -> v != t.n) +struct LambdaFunction { + pub func: Arc, // global instance of array_transform + pub args: Vec, // [Expr::ScalarValue([2, 3]), Expr::Lambda(v -> v != n)] +} + +// v -> v != t.n +struct Lambda { + pub params: Vec, // ["v"] + pub body: Box, // v != n +} + +// v, of the lambda body: v != t.n +struct LambdaVariable { + pub name: String, // "v" + pub field: Option, // Some(Field::new("", DataType::Int32, false)) + pub spans: Spans, +} + +``` + +The example would be planned into a tree like this: + +``` +LambdaFunctionExpression + name: array_transform + children: + 1. ListExpression [2,3] + 2. LambdaExpression + parameters: ["v"] + body: + ComparisonExpression (!=) + left: + LambdaVariableExpression("v", Some(Field::new("", Int32, false))) + right: + ColumnExpression("t.n") +``` + +The physical counterparts definition: + +```rust + +struct LambdaFunctionExpr { + fun: Arc, // global instance of array_transform + name: String, // "array_transform" + args: Vec>, // [LiteralExpr([2, 3], LambdaExpr("v -> v != t.n"))] + return_field: FieldRef, // Field::new("", DataType::new_list(DataType::Boolean, false), false) + config_options: Arc, +} + + +struct LambdaExpr { + params: Vec, // ["v"] + body: Arc, // v -> v != t.n +} + +struct LambdaVariable { + name: String, // "v", of the lambda body: v != t.n + field: FieldRef, // Field::new("", DataType::Int32, false) + value: Option, // reasoning later on +} +``` + +Note: For those who primarly wants to check if this lambda implementation supports their usecase and don't want to spend much time here, it's okay to skip most collapsed blocks, as those serve mostly to help code 
reviewers, with the exception of `LambdaUDF` and the `array_transform` implementation of `LambdaUDF` relevant methods, collapsed due to their size + +
Physical planning implementation is trivial: + +```rust +fn create_physical_expr( + e: &Expr, + input_dfschema: &DFSchema, + execution_props: &ExecutionProps, +) -> Result> { + let input_schema = input_dfschema.as_arrow(); + + match e { + ... + Expr::LambdaFunction(LambdaFunction { func, args}) => { + let physical_args = + create_physical_exprs(args, input_dfschema, execution_props)?; + + Ok(Arc::new(LambdaFunctionExpr::try_new( + Arc::clone(func), + physical_args, + input_schema, + config_options: ... // irrelevant + )?)) + } + Expr::Lambda(Lambda { params, body }) => Ok(Arc::new(LambdaExpr::new( + params.clone(), + create_physical_expr(body, input_dfschema, execution_props)?, + ))), + Expr::LambdaVariable(LambdaVariable { + name, + field, + spans: _, + }) => lambda_variable( + name, + Arc::clone(field), + ), + } +} +``` + +
+
+ +The added `LambdaUDF` trait is almost a clone of `ScalarUDFImpl`, with the exception of: +1. `return_field_from_args` and `invoke_with_args`, where now `args.args` is a list of enums with two variants: `Value` or `Lambda` instead of a list of values +2. the addition of `lambdas_parameters`, which return a `Field` for each parameter supported for every lambda argument based on the `Field` of the non lambda arguments +3. the removal of `return_field` and the deprecated ones `is_nullable` and `display_name`. + +
LambdaUDF + +```rust + +trait LambdaUDF { + /// Returns a list of the same size as args where each value is the logic below applied to value at the correspondent position in args: + /// + /// If it's a value, return None + /// If it's a lambda, return the list of all parameters that that lambda supports + /// based on the Field of the non-lambda arguments + /// + /// Example for array_transform: + /// + /// `array_transform([2, 8], v -> v > 4)` + /// + /// let lambdas_parameters = array_transform.lambdas_parameters(&[ + /// ValueOrLambdaParameter::Value(Field::new("", DataType::new_list(DataType::Int32, false)))]), // the Field associated with the literal `[2, 8]` + /// ValueOrLambdaParameter::Lambda, // A lambda + /// ]?; + /// + /// assert_eq!( + /// lambdas_parameters, + /// vec![ + /// None, // it's a value, return None + /// // it's a lambda, return it's supported parameters, regardless of how many are actually used + /// Some(vec![ + /// Field::new("", DataType::Int32, false), // the value being transformed, + /// Field::new("", DataType::Int32, false), // the 1-based index being transformed, not used on the example above, but implementations doesn't need to care about it + /// ]) + /// ] + /// ) + fn lambdas_parameters( + &self, + args: &[ValueOrLambdaParameter], + ) -> Result>>>; + fn return_field_from_args(&self, args: LambdaReturnFieldArgs) -> Result; + fn invoke_with_args(&self, args: LambdaFunctionArgs) -> Result; + // ... 
omitted methods that are similar in ScalarUDFImpl +} + +pub enum ValueOrLambdaParameter { + /// A columnar value with the given field + Value(FieldRef), + /// A lambda + Lambda, +} + +/// Information about arguments passed to the function +/// +/// This structure contains metadata about how the function was called +/// such as the type of the arguments, any scalar arguments and if the +/// arguments can (ever) be null +/// +/// See [`LambdaUDF::return_field_from_args`] for more information +pub struct LambdaReturnFieldArgs<'a> { + /// The data types of the arguments to the function + /// + /// If argument `i` to the function is a lambda, it will be the field returned by the + /// lambda when executed with the arguments returned from `LambdaUDF::lambdas_parameters` + /// + /// For example, with `array_transform([1], v -> v == 5)` + /// this field will be `[ + // ValueOrLambdaField::Value(Field::new("", DataType::new_list(DataType::Int32, false), false)), + // ValueOrLambdaField::Lambda(Field::new("", DataType::Boolean, false)) + // ]` + pub arg_fields: &'a [ValueOrLambdaField], + /// Is argument `i` to the function a scalar (constant)? + /// + /// If the argument `i` is not a scalar, it will be None + /// + /// For example, if a function is called like `my_function(column_a, 5)` + /// this field will be `[None, Some(ScalarValue::Int32(Some(5)))]` + pub scalar_arguments: &'a [Option<&'a ScalarValue>], +} + +/// A tagged FieldRef indicating whether it correspond the field of a value or the field of the output of a lambda argument +pub enum ValueOrLambdaField { + /// The FieldRef of a ColumnarValue argument + Value(FieldRef), + /// The return FieldRef of the lambda body when evaluated with the parameters from LambdaUDF::lambda_parameters + Lambda(FieldRef), +} + +/// Arguments passed to [`LambdaUDF::invoke_with_args`] when invoking a +/// lambda function. 
/// but that's an implementation detail and should not be relied upon,
+ +
"{} expects a value followed by a lambda, got {:?}",
+ +
"{} expects a value followed by a lambda, got {:?}",
+ +
// duplicating values of list with multiple values and removing values of empty lists
+ .unwrap_or_else(|| { + RecordBatch::try_new_with_options( + Arc::new(Schema::empty()), + vec![], + &RecordBatchOptions::new().with_row_count(Some(list_values.len())), + ) + .unwrap() + }); + + // by using closures, bind_lambda_variables can evaluate only the needed ones avoiding unnecessary computations + let values_param = || Ok(Arc::clone(list_values)); + //elements_indices return the index of each element within its sublist: [[5, 3], [7, 1, 1]] => [1, 2, 1, 2, 3], not included here + let indices_param = || elements_indices(&list_array); + + let binded_body = bind_lambda_variables( + Arc::clone(&lambda.body), + &lambda.params, + &[&values_param, &indices_param], + )?; + + // call the transforming expression with the record batch + let transformed_values = binded_body + .evaluate(&adjusted_captures)? + .into_array(list_values.len())?; + + let field = match args.return_field.data_type() { + DataType::List(field) + | DataType::LargeList(field) + | DataType::FixedSizeList(field, _) => Arc::clone(field), + _ => { + return exec_err!( + "{} expected ScalarFunctionArgs.return_field to be a list, got {}", + self.name(), + args.return_field + ) + } + }; + + let transformed_list = match list_array.data_type() { + DataType::List(_) => { + let list = list_array.as_list(); + + Arc::new(ListArray::new( + field, + list.offsets().clone(), + transformed_values, + list.nulls().cloned(), + )) as ArrayRef + } + DataType::LargeList(_) => { + let large_list = list_array.as_list(); + + Arc::new(LargeListArray::new( + field, + large_list.offsets().clone(), + transformed_values, + large_list.nulls().cloned(), + )) + } + DataType::FixedSizeList(_, value_length) => { + Arc::new(FixedSizeListArray::new( + field, + *value_length, + transformed_values, + list_array.as_fixed_size_list().nulls().cloned(), + )) + } + other => exec_err!("expected list, got {other}")?, + }; + + Ok(ColumnarValue::Array(transformed_list)) + } +} +``` + +
+ +
ValueOrLambdaParameter::Lambda, // An unspecified lambda. In the example, v -> v != t.n
+ captures: Some(record_batch!("t.n", Int32, [2])) + }), + ], + arg_fields, // same as above + number_rows: 1, + return_field, // same as above + config_options, // irrelevant +})?; + +assert_eq!(value, BooleanArray::from([false, true])) +``` + +
+
+
Instead of evaluating all its arguments as ScalarFunction, LambdaFunction does the following:
+
+1. If it's a non-lambda argument, evaluate as usual, and provide the resulting `ColumnarValue` to `LambdaUDF::evaluate` as a `ValueOrLambda::Value`
+2. If it's a lambda, construct a `LambdaFunctionLambdaArg` containing the lambda body physical expression and a record batch containing any captured columns as a `ValueOrLambda::Lambda` and provide it to `LambdaUDF::evaluate`. To avoid costly copies of uncaptured columns, we swap them with a `NullArray` while keeping the number of columns on the batch the same so captured column indices are kept stable across the whole tree. The recent #18329 instead projects-out uncaptured columns and rewrites the expr adjusting column indexes. If that is preferable we can generalize that implementation and use it here too.
LambdaFunction evaluation
+
Furthermore, the implementation of `ExprSchemable` and `PhysicalExpr::return_field` for `Column` expects that the schema it receives as an argument contains an entry for its name, which is not the case for lambda parameters.
+
+By including a `FieldRef` on `LambdaVariable` that should be resolved either during construction time, as in the sql planner, or later by an `AnalyzerRule`, `ExprSchemable` and `PhysicalExpr::return_field` simply return its own Field:
LambdaVariable ExprSchemable and PhysicalExpr::return_field implementation + +```rust +impl ExprSchemable for Expr { + fn to_field( + &self, + schema: &dyn ExprSchema, + ) -> Result<(Option, Arc)> { + let (relation, schema_name) = self.qualified_name(); + let field = match self { + Expr::LambdaVariable(l) => Ok(Arc::clone(&l.field.ok_or_else(|| plan_err!("Unresolved LambdaVariable {}", l.name)))), + ... + }?; + + Ok(( + relation, + Arc::new(field.as_ref().clone().with_name(schema_name)), + )) + } + ... +} + +impl PhysicalExpr for LambdaVariable { + fn return_field(&self, _input_schema: &Schema) -> Result { + Ok(Arc::clone(&self.field)) + } + ... +} +``` + +
+
There are also discussions on making every expr own its type: #18845, #12604
Any downstream user that wants to support lambdas would need to use those methods instead of the existing ones.
2. Make available through LogicalPlan and ExecutionPlan nodes a schema that includes all lambdas parameters from all expressions owned by the node, and use this schema for tree traversals.
For nodes which won't own any expression, the regular schema can be returned + + +```rust +impl LogicalPlan { + fn lambda_extended_schema(&self) -> &DFSchema; +} + +trait ExecutionPlan { + fn lambda_extended_schema(&self) -> &DFSchema; +} + +//usage +impl LogicalPlan { + pub fn replace_params_with_values( + self, + param_values: &ParamValues, + ) -> Result { + self.transform_up_with_subqueries(|plan| { + // use plan.lambda_extended_schema() containing lambdas parameters + // instead of plan.schema() which wont + let lambda_extended_schema = Arc::clone(plan.lambda_extended_schema()); + let name_preserver = NamePreserver::new(&plan); + plan.map_expressions(|e| { + // if this expression is child of lambda and contain columns referring it's parameters + // the lambda_extended_schema already contain them + let (e, has_placeholder) = e.infer_placeholder_types(&lambda_extended_schema)?; + .... + +``` +
+
The physical `LambdaVariable` contains an optional `ColumnarValue` that must be bound for each batch before evaluation with the helper function `bind_lambda_variables`, which rewrites the whole lambda body, binding any variable of the tree.
LambdaVariable::evaluate + +```rust +impl PhysicalExpr for LambdaVariable { + fn evaluate(&self, _batch: &RecordBatch) -> Result { + self.value.clone().ok_or_else(|| exec_datafusion_err!("Physical LambdaVariable {} unbinded value", self.name)) + } +} +``` + +
+
Make the `LambdaVariable` evaluate its value from the batch passed to `PhysicalExpr::evaluate` as a regular column. For that, instead of binding the body, the `LambdaUDF` implementation would merge the captured batch of a lambda with the values of its parameters. So that it happens via an index as a regular column, the schema used to plan the physical `LambdaVariable` must contain the lambda parameters. This would be the only place during planning that a schema would contain those parameters. Otherwise it can only get the value from the batch via name instead of index
how physical planning would look like + +```rust +fn create_physical_expr( + e: &Expr, + input_dfschema: &DFSchema, + execution_props: &ExecutionProps, +) -> Result> { + let input_schema = input_dfschema.as_arrow(); + + match e { + ... + Expr::LambdaFunction(LambdaFunction { func, args}) => { + let args_metadata = args.iter() + .map(|arg| if arg.is::() { + Ok(ValueOrLambdaParameter::Lambda) + } else { + Ok(ValueOrLambdaParameter::Value(arg.to_field(input_dfschema)?)) + }) + .collect()?; + + let lambdas_parameters = func.lambdas_parameters(&args_metadata)?; + + let physical_args = std::iter::zip(args, lambdas_parameters) + .map(|(arg, lambda_parameters)| { + match (arg.downcast_ref::(), lambda_parameters) { + (Some(lambda), Some(lambda_parameters)) => { + let extended_dfschema = merge_schema_and_parameters(input_dfschame, lambda_parameters)?; + + create_physical_expr(body, extended_dfschema, execution_props) + } + (None, None) => create_physical_expr(arg, input_dfschema, execution_props), + (Some(_), None) => plan_err!("lambdas_parameters returned None for a lambda") + (None, Some(_)) => plan_err!("lambdas_parameters returned Some for a non lambda") + } + }) + .collect()?; + + Ok(Arc::new(LambdaFunctionExpr::try_new( + Arc::clone(func), + physical_args, + input_schema, + config_options: ... // irrelevant + )?)) + } + } +} +``` + +
+
3. Insert the lambda parameter values into the RecordBatch during the evaluation phase: the LambdaUDF, instead of binding the lambda body variables, inserts its parameters on the captured RecordBatch it receives on LambdaFunctionLambdaArg.
So expr_api users can construct a LambdaVariable just by using its name, without having to set its field. An `AnalyzerRule` will then set the `LambdaVariable` field based on the returned values from `LambdaUDF::lambdas_parameters` of any `LambdaFunction` it finds while traversing down an expr tree. We may include that rule on the default rules list for when the plan/expression tree is transformed by another rule in a way that changes the types of non-lambda arguments of a lambda function, as it may change the types of its lambda parameters, which would render `LambdaVariable` fields out of sync, as the rule would fix it. Or, to not increase planning time, we don't include it by default and instruct `expr_api` users to add it manually if needed
FieldAccessPlanner + +```rust +pub struct FieldAccessPlanner; + +impl ExprPlanner for FieldAccessPlanner { + fn plan_field_access( + &self, + expr: RawFieldAccessExpr, // "v[1]" + schema: &DFSchema, + ) -> Result> { + // { "v", "[1]" } + let RawFieldAccessExpr { expr, field_access } = expr; + + match field_access { + ... + // expr[idx] ==> array_element(expr, idx) + GetFieldAccess::ListIndex { key: index } => { + match expr { + ... + // ExprSchemable::get_type called + _ if matches!(expr.get_type(schema)?, DataType::Map(_, _)) => { + Ok(PlannerResult::Planned(Expr::ScalarFunction( + ScalarFunction::new_udf( + get_field_inner(), + vec![expr, *index], + ), + ))) + } + } + } + } + } +} +``` + +
+
Therefore we can't plan all arguments on a single pass, and must first plan the non-lambda arguments, collect their types and nullability, pass them to `LambdaUDF::lambdas_parameters`, which will derive the type of its lambda parameters based on the type of its non-lambda argument, and return it to the planner, which, for each unplanned lambda argument, will create a new `PlannerContext` via `with_lambda_parameters`, which contains a mapping of lambda parameter names to their types. Then, when planning an `ast::Identifier`, it first checks whether a lambda parameter with the given name exists, and if so, plans it into an `Expr::LambdaVariable` with a resolved field, otherwise plans it into a regular `Expr::Column`.
/// the planning would happen as follows:
Arc::new(field.with_name(name))); + + let mut planner_context = planner_context + .clone() + .with_lambda_parameters(lambda_parameters); + + Ok(( + Expr::Lambda(Lambda { + params, + body: Box::new(self.sql_expr_to_logical_expr( + *lambda.body, + schema, + &mut planner_context, + )?), + }), + None, + )) + } + (ExprOrLambda::Expr(planned_expr), Some(lambda_parameters)) => plan_err!("lambdas_parameters returned Some for a value"), + (ExprOrLambda::Lambda(unplanned_lambda), None) => plan_err!("lambdas_parameters returned None for a lambda"), + }) + .collect::>>()?; + + Ok(Expr::LambdaFunction(LambdaFunction { + func: array_transform, + args, + })) + } + + fn sql_identifier_to_expr( + &self, + id: ast::Ident, + schema: &DFSchema, + planner_context: &mut PlannerContext, + ) -> Result { + // simplified implementation + if let Some(field) = planner_context.lambdas_parameters.get(id) { + Ok(Expr::LambdaVariable(LambdaVariable { + name: id, // "v" + field, // Field::new("", DataType::Int32, false) + })) + } else { + Ok(Expr::Column(Column::new(id))) + } + } +} + +``` + +
+
Currently, `LambdaUDF::signature` returns the same `Signature` as `ScalarUDF`, but its `type_signature` field is never used, as most variants of the `TypeSignature` enum aren't applicable to a lambda, and no type coercion is applied on its arguments, being currently an implementation responsibility. We should either add lambda-compatible variants to the `TypeSignature` enum, create a new `LambdaTypeSignature` and `LambdaSignature`, or support no automatic type coercion at all on lambda functions.
a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index e6dfd9fc1483b..4fc6c738ea06b 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -120,8 +120,8 @@ pub use udaf::{ }; pub use udf::{ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl}; pub use udlf::{ - LambdaFunctionArgs, LambdaFunctionLambdaArg, LambdaReturnFieldArgs, LambdaUDF, - ValueOrLambda, ValueOrLambdaField, ValueOrLambdaParameter, + LambdaFunctionArgs, LambdaFunctionLambdaArg, LambdaReturnFieldArgs, LambdaSignature, + LambdaUDF, ValueOrLambda, ValueOrLambdaField, ValueOrLambdaParameter, }; pub use udwf::{LimitEffect, ReversedUDWF, WindowUDF, WindowUDFImpl}; pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits}; diff --git a/datafusion/expr/src/udlf.rs b/datafusion/expr/src/udlf.rs index 84f9494d2edd8..883a083ca9c9b 100644 --- a/datafusion/expr/src/udlf.rs +++ b/datafusion/expr/src/udlf.rs @@ -27,7 +27,7 @@ use datafusion_common::config::ConfigOptions; use datafusion_common::{Result, ScalarValue, not_impl_err}; use datafusion_expr_common::dyn_eq::{DynEq, DynHash}; use datafusion_expr_common::interval_arithmetic::Interval; -use datafusion_expr_common::signature::Signature; +use datafusion_expr_common::signature::{Volatility}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use std::any::Any; use std::cmp::Ordering; @@ -35,6 +35,28 @@ use std::fmt::Debug; use std::hash::{Hash, Hasher}; use std::sync::Arc; +/// Provides information necessary for calling a lambda function. +/// +/// - [`Volatility`] defines how the output of the function changes with the input. +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] +pub struct LambdaSignature { + /// The volatility of the function. See [Volatility] for more information. + pub volatility: Volatility, + /// Optional parameter names for the function arguments. + /// + /// If provided, enables named argument notation for function calls (e.g., `func(a => 1, b => 2)`). 
+ /// + /// Defaults to `None`, meaning only positional arguments are supported. + pub parameter_names: Option>, +} + +impl LambdaSignature { + /// Creates a new Signature from a given volatility. + pub fn new(volatility: Volatility) -> LambdaSignature { + LambdaSignature { volatility, parameter_names: None } + } +} + impl PartialEq for dyn LambdaUDF { fn eq(&self, other: &Self) -> bool { self.dyn_eq(other.as_any()) @@ -190,19 +212,19 @@ pub enum ValueOrLambdaField { /// # use std::sync::LazyLock; /// # use arrow::datatypes::DataType; /// # use datafusion_common::{DataFusionError, plan_err, Result}; -/// # use datafusion_expr::{col, ColumnarValue, Documentation, LambdaFunctionArgs, Signature, Volatility}; +/// # use datafusion_expr::{col, ColumnarValue, Documentation, LambdaFunctionArgs, LambdaSignature, Volatility}; /// # use datafusion_expr::LambdaUDF; /// # use datafusion_expr::lambda_doc_sections::DOC_SECTION_MATH; /// /// This struct for a simple UDF that adds one to an int32 /// #[derive(Debug, PartialEq, Eq, Hash)] /// struct AddOne { -/// signature: Signature, +/// signature: LambdaSignature, /// } /// /// impl AddOne { /// fn new() -> Self { /// Self { -/// signature: Signature::uniform(1, vec![DataType::Int32], Volatility::Immutable), +/// signature: LambdaSignature::new(Volatility::Immutable), /// } /// } /// } @@ -221,7 +243,7 @@ pub enum ValueOrLambdaField { /// impl LambdaUDF for AddOne { /// fn as_any(&self) -> &dyn Any { self } /// fn name(&self) -> &str { "add_one" } -/// fn signature(&self) -> &Signature { &self.signature } +/// fn signature(&self) -> &LambdaSignature { &self.signature } /// fn return_type(&self, args: &[DataType]) -> Result { /// if !matches!(args.get(0), Some(&DataType::Int32)) { /// return plan_err!("add_one only accepts Int32 arguments"); @@ -274,14 +296,14 @@ pub trait LambdaUDF: Debug + DynEq + DynHash + Send + Sync { )) } - /// Returns a [`Signature`] describing the argument types for which this + /// Returns a 
[`LambdaSignature`] describing the argument types for which this /// function has an implementation, and the function's [`Volatility`]. /// - /// See [`Signature`] for more details on argument type handling + /// See [`LambdaSignature`] for more details on argument type handling /// and [`Self::return_type`] for computing the return type. /// /// [`Volatility`]: datafusion_expr_common::signature::Volatility - fn signature(&self) -> &Signature; + fn signature(&self) -> &LambdaSignature; /// Create a new instance of this function with updated configuration. /// @@ -534,8 +556,6 @@ pub trait LambdaUDF: Debug + DynEq + DynHash + Send + Sync { /// See the [type coercion module](crate::type_coercion) /// documentation for more details on type coercion /// - /// [`TypeSignature`]: crate::TypeSignature - /// /// For example, if your function requires a floating point arguments, but the user calls /// it like `my_func(1::int)` (i.e. with `1` as an integer), coerce_types can return `[DataType::Float64]` /// to ensure the argument is converted to `1::double` @@ -578,7 +598,7 @@ mod tests { struct TestLambdaUDF { name: &'static str, field: &'static str, - signature: Signature, + signature: LambdaSignature, } impl LambdaUDF for TestLambdaUDF { fn as_any(&self) -> &dyn Any { @@ -589,7 +609,7 @@ mod tests { self.name } - fn signature(&self) -> &Signature { + fn signature(&self) -> &LambdaSignature { &self.signature } @@ -637,7 +657,7 @@ mod tests { Arc::new(TestLambdaUDF { name, field: parameter, - signature: Signature::any(1, Volatility::Immutable), + signature: LambdaSignature::new(Volatility::Immutable), }) } diff --git a/datafusion/functions-nested/src/array_transform.rs b/datafusion/functions-nested/src/array_transform.rs index 746cfe6b62be5..b660f83b3ba2b 100644 --- a/datafusion/functions-nested/src/array_transform.rs +++ b/datafusion/functions-nested/src/array_transform.rs @@ -34,7 +34,7 @@ use datafusion_common::{ HashMap, Result, }; use datafusion_expr::{ - 
ColumnarValue, Documentation, LambdaFunctionArgs, LambdaUDF, Signature, + ColumnarValue, Documentation, LambdaFunctionArgs, LambdaSignature, LambdaUDF, ValueOrLambda, ValueOrLambdaField, ValueOrLambdaParameter, Volatility, }; use datafusion_macros::user_doc; @@ -70,7 +70,7 @@ use std::{any::Any, sync::Arc}; )] #[derive(Debug, PartialEq, Eq, Hash)] pub struct ArrayTransform { - signature: Signature, + signature: LambdaSignature, aliases: Vec, } @@ -83,7 +83,7 @@ impl Default for ArrayTransform { impl ArrayTransform { pub fn new() -> Self { Self { - signature: Signature::any(2, Volatility::Immutable), + signature: LambdaSignature::new(Volatility::Immutable), aliases: vec![String::from("list_transform")], } } @@ -102,7 +102,7 @@ impl LambdaUDF for ArrayTransform { &self.aliases } - fn signature(&self) -> &Signature { + fn signature(&self) -> &LambdaSignature { &self.signature } diff --git a/datafusion/physical-expr/src/lambda_function.rs b/datafusion/physical-expr/src/lambda_function.rs index 0b0c33cdd6be9..5ba354f4e8a4f 100644 --- a/datafusion/physical-expr/src/lambda_function.rs +++ b/datafusion/physical-expr/src/lambda_function.rs @@ -449,7 +449,7 @@ mod tests { use crate::LambdaFunctionExpr; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::Result; - use datafusion_expr::{LambdaFunctionArgs, LambdaUDF, Signature}; + use datafusion_expr::{LambdaFunctionArgs, LambdaUDF, LambdaSignature}; use datafusion_expr_common::columnar_value::ColumnarValue; use datafusion_physical_expr_common::physical_expr::is_volatile; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; @@ -457,7 +457,7 @@ mod tests { /// Test helper to create a mock UDF with a specific volatility #[derive(Debug, PartialEq, Eq, Hash)] struct MockLambdaUDF { - signature: Signature, + signature: LambdaSignature, } impl LambdaUDF for MockLambdaUDF { @@ -469,7 +469,7 @@ mod tests { "mock_function" } - fn signature(&self) -> &Signature { + fn signature(&self) -> &LambdaSignature 
{ &self.signature } @@ -489,16 +489,12 @@ mod tests { fn test_lambda_function_volatile_node() { // Create a volatile UDF let volatile_udf = Arc::new(MockLambdaUDF { - signature: Signature::uniform( - 1, - vec![DataType::Float32], - Volatility::Volatile, - ), + signature: LambdaSignature::new(Volatility::Volatile), }); // Create a non-volatile UDF let stable_udf = Arc::new(MockLambdaUDF { - signature: Signature::uniform(1, vec![DataType::Float32], Volatility::Stable), + signature: LambdaSignature::new(Volatility::Stable), }); let schema = Schema::new(vec![Field::new("a", DataType::Float32, false)]); diff --git a/datafusion/proto/src/bytes/mod.rs b/datafusion/proto/src/bytes/mod.rs index e421ea11c4b2d..732831b9464c7 100644 --- a/datafusion/proto/src/bytes/mod.rs +++ b/datafusion/proto/src/bytes/mod.rs @@ -27,8 +27,8 @@ use crate::protobuf; use datafusion_common::{not_impl_err, plan_datafusion_err, Result}; use datafusion_execution::TaskContext; use datafusion_expr::{ - create_udaf, create_udf, create_udwf, AggregateUDF, Expr, LambdaUDF, LogicalPlan, - Signature, Volatility, WindowUDF, + create_udaf, create_udf, create_udwf, AggregateUDF, Expr, LambdaUDF, LambdaSignature, LogicalPlan, + Volatility, WindowUDF, }; use prost::{ bytes::{Bytes, BytesMut}, @@ -196,7 +196,7 @@ impl Serializeable for Expr { #[derive(Debug, PartialEq, Eq, Hash)] struct MockLambdaUDF { name: String, - signature: Signature, + signature: LambdaSignature, } impl LambdaUDF for MockLambdaUDF { @@ -208,7 +208,7 @@ impl Serializeable for Expr { &self.name } - fn signature(&self) -> &Signature { + fn signature(&self) -> &LambdaSignature { &self.signature } @@ -229,7 +229,7 @@ impl Serializeable for Expr { Ok(Arc::new(MockLambdaUDF { name: name.to_string(), - signature: Signature::variadic_any(Volatility::Immutable), + signature: LambdaSignature::new(Volatility::Immutable), })) } } From 3ded1154a4aa7adc167ec033622308c2baa8d524 Mon Sep 17 00:00:00 2001 From: gstvg 
<28798827+gstvg@users.noreply.github.com> Date: Sun, 8 Mar 2026 20:05:16 -0300 Subject: [PATCH 10/12] improve lambda type coercion --- datafusion/expr/src/lib.rs | 2 +- datafusion/expr/src/udlf.rs | 87 ++++++++++++++++--- .../functions-nested/src/array_transform.rs | 4 +- .../optimizer/src/analyzer/type_coercion.rs | 64 +++++++++++++- .../physical-expr/src/lambda_function.rs | 4 +- datafusion/proto/src/bytes/mod.rs | 2 +- 6 files changed, 143 insertions(+), 20 deletions(-) diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 4fc6c738ea06b..b03bab622a357 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -121,7 +121,7 @@ pub use udaf::{ pub use udf::{ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl}; pub use udlf::{ LambdaFunctionArgs, LambdaFunctionLambdaArg, LambdaReturnFieldArgs, LambdaSignature, - LambdaUDF, ValueOrLambda, ValueOrLambdaField, ValueOrLambdaParameter, + LambdaTypeSignature, LambdaUDF, ValueOrLambda, ValueOrLambdaField, ValueOrLambdaParameter, }; pub use udwf::{LimitEffect, ReversedUDWF, WindowUDF, WindowUDFImpl}; pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits}; diff --git a/datafusion/expr/src/udlf.rs b/datafusion/expr/src/udlf.rs index 883a083ca9c9b..b2191e0883d30 100644 --- a/datafusion/expr/src/udlf.rs +++ b/datafusion/expr/src/udlf.rs @@ -35,11 +35,44 @@ use std::fmt::Debug; use std::hash::{Hash, Hasher}; use std::sync::Arc; +/// The types of arguments for which a function has implementations. +/// +/// [`LambdaTypeSignature`] **DOES NOT** define the types that a user query could call the +/// function with. DataFusion will automatically coerce (cast) argument types to +/// one of the supported function signatures, if possible. +/// +/// # Overview +/// Functions typically provide implementations for a small number of different +/// argument [`DataType`]s, rather than all possible combinations. 
If a user +/// calls a function with arguments that do not match any of the declared types, +/// DataFusion will attempt to automatically coerce (add casts to) function +/// arguments so they match the [`LambdaTypeSignature`]. See the [`type_coercion`] module +/// for more details. +/// +/// [`type_coercion`]: crate::type_coercion +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] +pub enum LambdaTypeSignature { + /// The acceptable signature and coercion rules are special for this + /// function. + /// + /// If this signature is specified, + /// DataFusion will call [`LambdaUDF::coerce_value_types`] to prepare argument types. + /// + /// [`LambdaUDF::coerce_value_types`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/trait.LambdaUDF.html#method.coerce_value_types + UserDefined, + /// One or more lambdas or arguments with arbitrary types. + VariadicAny, + /// The specified number of lambdas or arguments with arbitrary types. + Any(usize), +} + /// Provides information necessary for calling a lambda function. /// /// - [`Volatility`] defines how the output of the function changes with the input. #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] pub struct LambdaSignature { + /// The data types that the function accepts. See [LambdaTypeSignature] for more information. + pub type_signature: LambdaTypeSignature, /// The volatility of the function. See [Volatility] for more information. pub volatility: Volatility, /// Optional parameter names for the function arguments. @@ -51,9 +84,40 @@ pub struct LambdaSignature { } impl LambdaSignature { - /// Creates a new Signature from a given volatility. - pub fn new(volatility: Volatility) -> LambdaSignature { - LambdaSignature { volatility, parameter_names: None } + /// Creates a new LambdaSignature from a given type signature and volatility. 
+ pub fn new(type_signature: LambdaTypeSignature, volatility: Volatility) -> Self { + LambdaSignature { + type_signature, + volatility, + parameter_names: None, + } + } + + /// User-defined coercion rules for the function. + pub fn user_defined(volatility: Volatility) -> Self { + Self { + type_signature: LambdaTypeSignature::UserDefined, + volatility, + parameter_names: None, + } + } + + /// An arbitrary number of lambdas or arguments of any type. + pub fn variadic_any(volatility: Volatility) -> Self { + Self { + type_signature: LambdaTypeSignature::VariadicAny, + volatility, + parameter_names: None, + } + } + + /// A specified number of lambdas or arguments of any type. + pub fn any(arg_count: usize, volatility: Volatility) -> Self { + Self { + type_signature: LambdaTypeSignature::Any(arg_count), + volatility, + parameter_names: None, + } + } } @@ -99,10 +163,10 @@ impl Hash for dyn LambdaUDF { } } -#[derive(Clone, Debug)] -pub enum ValueOrLambdaParameter { - /// A columnar value with the given field - Value(FieldRef), +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum ValueOrLambdaParameter { + /// A value with the given associated data + Value(T), /// A lambda Lambda, } @@ -566,7 +630,10 @@ pub trait LambdaUDF: Debug + DynEq + DynHash + Send + Sync { /// # Return value /// A Vec the same length as `arg_types`. DataFusion will `CAST` the function call /// arguments to these specific types.
- fn coerce_types(&self, _arg_types: &[DataType]) -> Result> { + fn coerce_value_types( + &self, + _arg_types: &[ValueOrLambdaParameter], + ) -> Result>> { not_impl_err!("Function {} does not implement coerce_types", self.name()) } @@ -581,7 +648,7 @@ pub trait LambdaUDF: Debug + DynEq + DynHash + Send + Sync { /// Returns the parameters that any lambda supports fn lambdas_parameters( &self, - args: &[ValueOrLambdaParameter], + args: &[ValueOrLambdaParameter], ) -> Result>>> { Ok(vec![None; args.len()]) } @@ -657,7 +724,7 @@ mod tests { Arc::new(TestLambdaUDF { name, field: parameter, - signature: LambdaSignature::new(Volatility::Immutable), + signature: LambdaSignature::variadic_any(Volatility::Immutable), }) } diff --git a/datafusion/functions-nested/src/array_transform.rs b/datafusion/functions-nested/src/array_transform.rs index b660f83b3ba2b..e0c4ab28c1fef 100644 --- a/datafusion/functions-nested/src/array_transform.rs +++ b/datafusion/functions-nested/src/array_transform.rs @@ -83,7 +83,7 @@ impl Default for ArrayTransform { impl ArrayTransform { pub fn new() -> Self { Self { - signature: LambdaSignature::new(Volatility::Immutable), + signature: LambdaSignature::any(2, Volatility::Immutable), aliases: vec![String::from("list_transform")], } } @@ -238,7 +238,7 @@ impl LambdaUDF for ArrayTransform { fn lambdas_parameters( &self, - args: &[ValueOrLambdaParameter], + args: &[ValueOrLambdaParameter], ) -> Result>>> { let [ValueOrLambdaParameter::Value(list), ValueOrLambdaParameter::Lambda] = args else { diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index e0b1d9096b415..8c6c42e7e630c 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -35,7 +35,7 @@ use datafusion_common::{ }; use datafusion_expr::expr::{ self, AggregateFunctionParams, Alias, Between, BinaryExpr, Case, Exists, InList, - InSubquery, Like, ScalarFunction, 
Sort, WindowFunction, + InSubquery, LambdaFunction, Like, ScalarFunction, Sort, WindowFunction, }; use datafusion_expr::expr_rewriter::coerce_plan_expr_for_schema; use datafusion_expr::expr_schema::cast_subquery; @@ -51,8 +51,9 @@ use datafusion_expr::type_coercion::{is_datetime, is_utf8_or_utf8view_or_large_u use datafusion_expr::utils::merge_schema; use datafusion_expr::{ is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown, not, - AggregateUDF, Expr, ExprSchemable, Join, Limit, LogicalPlan, Operator, Projection, - ScalarUDF, Union, WindowFrame, WindowFrameBound, WindowFrameUnits, + AggregateUDF, Expr, ExprSchemable, Join, LambdaTypeSignature, Limit, LogicalPlan, + Operator, Projection, ScalarUDF, Union, ValueOrLambdaParameter, WindowFrame, + WindowFrameBound, WindowFrameUnits, }; /// Performs type coercion by determining the schema @@ -582,6 +583,62 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { }); Ok(Transformed::yes(new_expr)) } + Expr::LambdaFunction(LambdaFunction { func, args }) => { + match func.signature().type_signature { + LambdaTypeSignature::UserDefined => { + let args_types = args + .iter() + .map(|arg| match arg { + Expr::Lambda(_) => Ok(ValueOrLambdaParameter::Lambda), + _ => Ok(ValueOrLambdaParameter::Value( + arg.get_type(self.schema)?, + )), + }) + .collect::>>()?; + + let value_types = func.coerce_value_types(&args_types)?; + + if args_types.iter().eq_by(&value_types, |a, b| match (a, b) { + (ValueOrLambdaParameter::Value(_type), None) => false, + (ValueOrLambdaParameter::Value(from), Some(to)) => from == to, + (ValueOrLambdaParameter::Lambda, None) => true, + (ValueOrLambdaParameter::Lambda, Some(_ty)) => false, + }) { + return Ok(Transformed::no(Expr::LambdaFunction( + LambdaFunction::new(func, args), + ))); + } + + let args = std::iter::zip(args, value_types) + .map(|(arg, ty)| match (&arg, ty) { + (Expr::Lambda(_), None) => Ok(arg), + (Expr::Lambda(_), Some(_ty)) => plan_err!("{} coerce_value_types returned 
Some for a lambda argument", func.name()), + (_, Some(ty)) => arg.cast_to(&ty, self.schema), + (_, None) => plan_err!("{} coerce_value_types returned None for a value argument", func.name()), + }) + .collect::>>()?; + + Ok(Transformed::yes(Expr::LambdaFunction(LambdaFunction::new( + func, args, + )))) + } + LambdaTypeSignature::VariadicAny => Ok(Transformed::no( + Expr::LambdaFunction(LambdaFunction::new(func, args)), + )), + LambdaTypeSignature::Any(number) => { + if args.len() != number { + return plan_err!( + "The function '{}' expected {number} arguments but received {}", + func.name(), args.len() + ); + } + + Ok(Transformed::no(Expr::LambdaFunction(LambdaFunction::new( + func, args, + )))) + } + } + } // TODO: remove the next line after `Expr::Wildcard` is removed #[expect(deprecated)] Expr::Alias(_) @@ -598,7 +655,6 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { | Expr::GroupingSet(_) | Expr::Placeholder(_) | Expr::OuterReferenceColumn(_, _) - | Expr::LambdaFunction(_) | Expr::Lambda(_) | Expr::LambdaVariable(_) => Ok(Transformed::no(expr)), } diff --git a/datafusion/physical-expr/src/lambda_function.rs b/datafusion/physical-expr/src/lambda_function.rs index 5ba354f4e8a4f..97af1f9b13891 100644 --- a/datafusion/physical-expr/src/lambda_function.rs +++ b/datafusion/physical-expr/src/lambda_function.rs @@ -489,12 +489,12 @@ mod tests { fn test_lambda_function_volatile_node() { // Create a volatile UDF let volatile_udf = Arc::new(MockLambdaUDF { - signature: LambdaSignature::new(Volatility::Volatile), + signature: LambdaSignature::variadic_any(Volatility::Volatile), }); // Create a non-volatile UDF let stable_udf = Arc::new(MockLambdaUDF { - signature: LambdaSignature::new(Volatility::Stable), + signature: LambdaSignature::variadic_any(Volatility::Stable), }); let schema = Schema::new(vec![Field::new("a", DataType::Float32, false)]); diff --git a/datafusion/proto/src/bytes/mod.rs b/datafusion/proto/src/bytes/mod.rs index 732831b9464c7..e9060c0f2c986 
100644 --- a/datafusion/proto/src/bytes/mod.rs +++ b/datafusion/proto/src/bytes/mod.rs @@ -229,7 +229,7 @@ impl Serializeable for Expr { Ok(Arc::new(MockLambdaUDF { name: name.to_string(), - signature: LambdaSignature::new(Volatility::Immutable), + signature: LambdaSignature::variadic_any(Volatility::Immutable), })) } } From 82930ec92df0ec3a129a2cfcd34c897019e38a7f Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Sun, 8 Mar 2026 23:58:54 -0300 Subject: [PATCH 11/12] lambda function type coercion: stop using unstable Iterator::eq_by --- datafusion/optimizer/src/analyzer/type_coercion.rs | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 8c6c42e7e630c..e97d79e98e977 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -598,12 +598,14 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { let value_types = func.coerce_value_types(&args_types)?; - if args_types.iter().eq_by(&value_types, |a, b| match (a, b) { - (ValueOrLambdaParameter::Value(_type), None) => false, - (ValueOrLambdaParameter::Value(from), Some(to)) => from == to, - (ValueOrLambdaParameter::Lambda, None) => true, - (ValueOrLambdaParameter::Lambda, Some(_ty)) => false, - }) { + if args_types + .iter() + .map(|a| match a { + ValueOrLambdaParameter::Value(ty) => Some(ty), + ValueOrLambdaParameter::Lambda => None, + }) + .eq(value_types.iter().map(|v| v.as_ref())) + { return Ok(Transformed::no(Expr::LambdaFunction( LambdaFunction::new(func, args), ))); From 86d5999056b49b4b597c1103ee9f57f65cacc882 Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Mon, 9 Mar 2026 01:30:48 -0300 Subject: [PATCH 12/12] remove signature section from DOC.md --- DOC.md | 4 ---- 1 file changed, 4 deletions(-) diff --git a/DOC.md b/DOC.md index 
10a0ab4e19407..a88e69a5689e3 100644 --- a/DOC.md +++ b/DOC.md @@ -1160,7 +1160,3 @@ impl SqlToRel {
- -`LambdaFunction` `Signature` is non functional - -Currenty, `LambdaUDF::signature` returns the same `Signature` as `ScalarUDF`, but it's `type_signature` field is never used, as most variants of the `TypeSignature` enum aren't applicable to a lambda, and no type coercion is applied on it's arguments, being currently a implementation responsability. We should either add lambda compatible variants to the `TypeSignature` enum, create a new `LambdaTypeSignature` and `LambdaSignature`, or support no automatic type coercion at all on lambda functions.