diff --git a/datafusion/substrait/src/logical_plan/consumer/expr/mod.rs b/datafusion/substrait/src/logical_plan/consumer/expr/mod.rs index a0468dbd451b9..295456e95f9f3 100644 --- a/datafusion/substrait/src/logical_plan/consumer/expr/mod.rs +++ b/datafusion/substrait/src/logical_plan/consumer/expr/mod.rs @@ -21,6 +21,7 @@ mod field_reference; mod function_arguments; mod if_then; mod literal; +mod nested; mod scalar_function; mod singular_or_list; mod subquery; @@ -32,6 +33,7 @@ pub use field_reference::*; pub use function_arguments::*; pub use if_then::*; pub use literal::*; +pub use nested::*; pub use scalar_function::*; pub use singular_or_list::*; pub use subquery::*; diff --git a/datafusion/substrait/src/logical_plan/consumer/expr/nested.rs b/datafusion/substrait/src/logical_plan/consumer/expr/nested.rs new file mode 100644 index 0000000000000..f94a701342826 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/expr/nested.rs @@ -0,0 +1,151 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::consumer::SubstraitConsumer; +use datafusion::common::{DFSchema, not_impl_err, substrait_err}; +use datafusion::execution::FunctionRegistry; +use datafusion::logical_expr::Expr; +use substrait::proto::expression::Nested; +use substrait::proto::expression::nested::NestedType; + +/// Converts a Substrait [Nested] expression into a DataFusion [Expr]. +/// +/// Substrait Nested expressions represent complex type constructors (list, struct, map) +/// where elements are full expressions rather than just literals. This is used by +/// producers that emit `Nested { list: ... }` for array construction, as opposed to +/// `Literal { list: ... }` which only supports scalar values. +pub async fn from_nested( + consumer: &impl SubstraitConsumer, + nested: &Nested, + input_schema: &DFSchema, +) -> datafusion::common::Result { + let Some(nested_type) = &nested.nested_type else { + return substrait_err!("Nested expression requires a nested_type"); + }; + + match nested_type { + NestedType::List(list) => { + if list.values.is_empty() { + return substrait_err!( + "Empty Nested lists are not supported; use Literal.empty_list instead" + ); + } + + let mut args = Vec::with_capacity(list.values.len()); + for value in &list.values { + args.push(consumer.consume_expression(value, input_schema).await?); + } + + let make_array_udf = consumer.get_function_registry().udf("make_array")?; + Ok(Expr::ScalarFunction( + datafusion::logical_expr::expr::ScalarFunction::new_udf( + make_array_udf, + args, + ), + )) + } + NestedType::Struct(_) => { + not_impl_err!("Nested struct expressions are not yet supported") + } + NestedType::Map(_) => { + not_impl_err!("Nested map expressions are not yet supported") + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::logical_plan::consumer::utils::tests::test_consumer; + use substrait::proto::expression::Literal; + use substrait::proto::expression::nested::List; + use substrait::proto::{self, Expression}; + + fn make_i64_literal(value: i64) -> Expression { + Expression { + rex_type: Some(proto::expression::RexType::Literal(Literal { + nullable: false, + type_variation_reference: 0, + literal_type: Some(proto::expression::literal::LiteralType::I64(value)), + })), + } + } + + #[tokio::test] + async fn nested_list_with_literals() -> datafusion::common::Result<()> { + let consumer = test_consumer(); + let schema = DFSchema::empty(); + let nested = Nested { + nullable: false, + type_variation_reference: 0, + nested_type: Some(NestedType::List(List { + values: vec![ + make_i64_literal(1), + make_i64_literal(2), + make_i64_literal(3), + ], + })), + }; + + let expr = from_nested(&consumer, &nested, &schema).await?; + assert_eq!( + format!("{expr}"), + "make_array(Int64(1), Int64(2), Int64(3))" + ); + + Ok(()) + } + + #[tokio::test] + async fn nested_list_empty_rejected() -> datafusion::common::Result<()> { + let consumer = test_consumer(); + let schema = DFSchema::empty(); + let nested = Nested { + nullable: true, + type_variation_reference: 0, + nested_type: Some(NestedType::List(List { values: vec![] })), + }; + + let result = from_nested(&consumer, &nested, &schema).await; + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("Empty Nested lists are not supported") + ); + + Ok(()) + } + + #[tokio::test] + async fn nested_missing_type() -> datafusion::common::Result<()> { + let consumer = test_consumer(); + let schema = DFSchema::empty(); + let nested = Nested { + nullable: false, + type_variation_reference: 0, + nested_type: None, + }; + + let result = from_nested(&consumer, &nested, &schema).await; + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("nested_type")); + + Ok(()) + } +} diff --git a/datafusion/substrait/src/logical_plan/consumer/substrait_consumer.rs b/datafusion/substrait/src/logical_plan/consumer/substrait_consumer.rs index 730ceab8ccef3..14385888a8de4 100644 --- a/datafusion/substrait/src/logical_plan/consumer/substrait_consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer/substrait_consumer.rs @@ -18,7 +18,7 @@ use super::{ from_aggregate_rel, from_cast, from_cross_rel, from_exchange_rel, from_fetch_rel, from_field_reference, from_filter_rel, from_if_then, from_join_rel, from_literal, - from_project_rel, from_read_rel, from_scalar_function, from_set_rel, + from_nested, from_project_rel, from_read_rel, from_scalar_function, from_set_rel, from_singular_or_list, from_sort_rel, from_subquery, from_substrait_rel, from_substrait_rex, from_window_function, }; @@ -350,10 +350,10 @@ pub trait SubstraitConsumer: Send + Sync + Sized { async fn consume_nested( &self, - _expr: &Nested, - _input_schema: &DFSchema, + expr: &Nested, + input_schema: &DFSchema, ) -> datafusion::common::Result { - not_impl_err!("Nested expression not supported") + from_nested(self, expr, input_schema).await } async fn consume_enum( diff --git a/datafusion/substrait/tests/cases/logical_plans.rs b/datafusion/substrait/tests/cases/logical_plans.rs index 9de7cb8f3835e..663a372fe2e4f 100644 --- a/datafusion/substrait/tests/cases/logical_plans.rs +++ b/datafusion/substrait/tests/cases/logical_plans.rs @@ -270,4 +270,27 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn nested_list_expressions() -> Result<()> { + // Tests that a Substrait Nested list expression containing non-literal + // expressions (column references) uses the make_array UDF. + let proto_plan = + read_json("tests/testdata/test_plans/nested_list_expressions.substrait.json"); + let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan)?; + let plan = from_substrait_plan(&ctx.state(), &proto_plan).await?; + + assert_snapshot!( + plan, + @r" + Projection: make_array(DATA.a, DATA.b) AS my_list + TableScan: DATA + " + ); + + // Trigger execution to ensure plan validity + DataFrame::new(ctx.state(), plan).show().await?; + + Ok(()) + } } diff --git a/datafusion/substrait/tests/testdata/test_plans/nested_list_expressions.substrait.json b/datafusion/substrait/tests/testdata/test_plans/nested_list_expressions.substrait.json new file mode 100644 index 0000000000000..85a69c41c5eb1 --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/nested_list_expressions.substrait.json @@ -0,0 +1,77 @@ +{ + "relations": [ + { + "root": { + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [2] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": ["a", "b"], + "struct": { + "types": [ + { + "i32": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "i32": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["DATA"] + } + } + }, + "expressions": [ + { + "nested": { + "nullable": false, + "list": { + "values": [ + { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": {} + } + } + ] + } + } + } + ] + } + }, + "names": ["my_list"] + } + } + ] +}