diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 2e23fef1da76..16bbffabe29b 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -593,7 +593,23 @@ impl LogicalPlanBuilder { self, expr: Vec<(impl Into, bool)>, ) -> Result { - project_with_validation(Arc::unwrap_or_clone(self.plan), expr).map(Self::new) + project_with_validation(Arc::unwrap_or_clone(self.plan), expr, None) + .map(Self::new) + } + + /// Apply a projection, aliasing non-Column/non-Alias expressions to + /// match the field names from the provided schema. + pub fn project_with_validation_and_schema( + self, + expr: impl IntoIterator>, + schema: &DFSchemaRef, + ) -> Result { + project_with_validation( + Arc::unwrap_or_clone(self.plan), + expr.into_iter().map(|e| (e, true)), + Some(schema), + ) + .map(Self::new) } /// Select the given column indices @@ -1916,7 +1932,7 @@ pub fn project( plan: LogicalPlan, expr: impl IntoIterator>, ) -> Result { - project_with_validation(plan, expr.into_iter().map(|e| (e, true))) + project_with_validation(plan, expr.into_iter().map(|e| (e, true)), None) } /// Create Projection. Similar to project except that the expressions @@ -1929,6 +1945,7 @@ pub fn project( fn project_with_validation( plan: LogicalPlan, expr: impl IntoIterator, bool)>, + schema: Option<&DFSchemaRef>, ) -> Result { let mut projected_expr = vec![]; for (e, validate) in expr { @@ -1984,6 +2001,17 @@ fn project_with_validation( } } } + + // When inside a set expression, alias non-Column/non-Alias expressions + // to match the left side's field names, avoiding duplicate name errors. + if let Some(schema) = &schema { + for (expr, field) in projected_expr.iter_mut().zip(schema.fields()) { + if !matches!(expr, Expr::Column(_) | Expr::Alias(_)) { + *expr = expr.clone().alias(field.name()); + } + } + } + validate_unique_names("Projections", projected_expr.iter())?; Projection::try_new(projected_expr, Arc::new(plan)).map(LogicalPlan::Projection) diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index b7e270e4f057..32daf65a71fa 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -270,6 +270,10 @@ pub struct PlannerContext { outer_from_schema: Option, /// The query schema defined by the table create_table_schema: Option, + /// When planning non-first queries in a set expression + /// (UNION/INTERSECT/EXCEPT), holds the schema of the left-most query. + /// Used to alias duplicate expressions to match the left side's field names. + set_expr_left_schema: Option, } impl Default for PlannerContext { @@ -287,6 +291,7 @@ impl PlannerContext { outer_queries_schemas_stack: vec![], outer_from_schema: None, create_table_schema: None, + set_expr_left_schema: None, } } @@ -400,6 +405,14 @@ impl PlannerContext { pub(super) fn remove_cte(&mut self, cte_name: &str) { self.ctes.remove(cte_name); } + + /// Sets the left-most set expression schema, returning the previous value + pub(super) fn set_set_expr_left_schema( + &mut self, + schema: Option, + ) -> Option { + std::mem::replace(&mut self.set_expr_left_schema, schema) + } } /// SQL query planner and binder diff --git a/datafusion/sql/src/query.rs b/datafusion/sql/src/query.rs index 1b7bb856a592..e320d2ee6e9c 100644 --- a/datafusion/sql/src/query.rs +++ b/datafusion/sql/src/query.rs @@ -152,7 +152,7 @@ impl SqlToRel<'_, S> { let empty_from = matches!(plan, LogicalPlan::EmptyRelation(_)); let select_exprs = self.prepare_select_exprs(&plan, exprs, empty_from, planner_context)?; - self.project(plan, select_exprs) + self.project(plan, select_exprs, None) } PipeOperator::Extend { exprs } => { let empty_from = matches!(plan, LogicalPlan::EmptyRelation(_)); @@ -162,7 +162,7 @@ impl SqlToRel<'_, S> { std::iter::once(SelectExpr::Wildcard(WildcardOptions::default())) .chain(extend_exprs) .collect(); - self.project(plan, all_exprs) + self.project(plan, all_exprs, None) } PipeOperator::As { alias } => self.apply_table_alias( plan, diff --git a/datafusion/sql/src/select.rs b/datafusion/sql/src/select.rs index 7e291afa04b6..d11fd6265ff3 100644 --- a/datafusion/sql/src/select.rs +++ b/datafusion/sql/src/select.rs @@ -29,7 +29,7 @@ use crate::utils::{ use datafusion_common::error::DataFusionErrorBuilder; use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; -use datafusion_common::{Column, DFSchema, Result, not_impl_err, plan_err}; +use datafusion_common::{Column, DFSchema, DFSchemaRef, Result, not_impl_err, plan_err}; use datafusion_common::{RecursionUnnestOption, UnnestOptions}; use datafusion_expr::expr::{Alias, PlannedReplaceSelectItem, WildcardOptions}; use datafusion_expr::expr_rewriter::{ @@ -90,6 +90,10 @@ impl SqlToRel<'_, S> { return not_impl_err!("SORT BY"); } + // Capture and clear set expression schema so it doesn't leak + // into subqueries planned during FROM clause handling. + let set_expr_left_schema = planner_context.set_set_expr_left_schema(None); + // Process `from` clause let plan = self.plan_from_tables(select.from, planner_context)?; let empty_from = matches!(plan, LogicalPlan::EmptyRelation(_)); @@ -110,7 +114,8 @@ impl SqlToRel<'_, S> { )?; // Having and group by clause may reference aliases defined in select projection - let projected_plan = self.project(base_plan.clone(), select_exprs)?; + let projected_plan = + self.project(base_plan.clone(), select_exprs, set_expr_left_schema)?; let select_exprs = projected_plan.expressions(); let order_by = @@ -879,18 +884,29 @@ impl SqlToRel<'_, S> { &self, input: LogicalPlan, expr: Vec, + set_expr_left_schema: Option, ) -> Result { // convert to Expr for validate_schema_satisfies_exprs - let exprs = expr + let plain_exprs = expr .iter() .filter_map(|e| match e { SelectExpr::Expression(expr) => Some(expr.to_owned()), _ => None, }) .collect::>(); - self.validate_schema_satisfies_exprs(input.schema(), &exprs)?; - - LogicalPlanBuilder::from(input).project(expr)?.build() + self.validate_schema_satisfies_exprs(input.schema(), &plain_exprs)?; + + // When inside a set expression, pass the left-most schema so + // that expressions get aliased to match, avoiding duplicate + // name errors from expressions like `count(*), count(*)`. + let builder = LogicalPlanBuilder::from(input); + if let Some(left_schema) = set_expr_left_schema { + builder + .project_with_validation_and_schema(expr, &left_schema)? + .build() + } else { + builder.project(expr)?.build() + } } /// Create an aggregate plan. diff --git a/datafusion/sql/src/set_expr.rs b/datafusion/sql/src/set_expr.rs index d4e771cb4858..dc8e4f14d1ee 100644 --- a/datafusion/sql/src/set_expr.rs +++ b/datafusion/sql/src/set_expr.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +use std::sync::Arc; + use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use datafusion_common::{ DataFusionError, Diagnostic, Result, Span, not_impl_err, plan_err, @@ -42,7 +44,23 @@ impl SqlToRel<'_, S> { let left_span = Span::try_from_sqlparser_span(left.span()); let right_span = Span::try_from_sqlparser_span(right.span()); let left_plan = self.set_expr_to_plan(*left, planner_context); + // Store the left plan's schema so that the right side can + // alias duplicate expressions to match. Skip for BY NAME + // operations since those match columns by name, not position. + if let Ok(plan) = &left_plan + && plan.schema().fields().len() > 1 + && !matches!( + set_quantifier, + SetQuantifier::ByName + | SetQuantifier::AllByName + | SetQuantifier::DistinctByName + ) + { + planner_context + .set_set_expr_left_schema(Some(Arc::clone(plan.schema()))); + } let right_plan = self.set_expr_to_plan(*right, planner_context); + planner_context.set_set_expr_left_schema(None); let (left_plan, right_plan) = match (left_plan, right_plan) { (Ok(left_plan), Ok(right_plan)) => (left_plan, right_plan), (Err(left_err), Err(right_err)) => { diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index 29c17be69ce5..b1f4929eaa22 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -2655,6 +2655,106 @@ fn union_all_by_name_same_column_names() { ); } +#[test] +fn union_all_with_duplicate_expressions() { + let sql = "\ + SELECT 0 a, 0 b \ + UNION ALL SELECT 1, 1 \ + UNION ALL SELECT count(*), count(*) FROM orders"; + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r" + Union + Union + Projection: Int64(0) AS a, Int64(0) AS b + EmptyRelation: rows=1 + Projection: Int64(1) AS a, Int64(1) AS b + EmptyRelation: rows=1 + Projection: count(*) AS a, count(*) AS b + Aggregate: groupBy=[[]], aggr=[[count(*)]] + TableScan: orders + " + ); +} + +#[test] +fn union_with_qualified_and_duplicate_expressions() { + let sql = "\ + SELECT 0 a, id b, price c, 0 d FROM test_decimal \ + UNION SELECT 1, *, 1 FROM test_decimal"; + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @" + Distinct: + Union + Projection: Int64(0) AS a, test_decimal.id AS b, test_decimal.price AS c, Int64(0) AS d + TableScan: test_decimal + Projection: Int64(1) AS a, test_decimal.id, test_decimal.price, Int64(1) AS d + TableScan: test_decimal + " + ); +} + +#[test] +fn intersect_with_duplicate_expressions() { + let sql = "\ + SELECT 0 a, 0 b \ + INTERSECT SELECT 1, 1 \ + INTERSECT SELECT count(*), count(*) FROM orders"; + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r" + LeftSemi Join: left.a = right.a, left.b = right.b + Distinct: + SubqueryAlias: left + LeftSemi Join: left.a = right.a, left.b = right.b + Distinct: + SubqueryAlias: left + Projection: Int64(0) AS a, Int64(0) AS b + EmptyRelation: rows=1 + SubqueryAlias: right + Projection: Int64(1) AS a, Int64(1) AS b + EmptyRelation: rows=1 + SubqueryAlias: right + Projection: count(*) AS a, count(*) AS b + Aggregate: groupBy=[[]], aggr=[[count(*)]] + TableScan: orders + " + ); +} + +#[test] +fn except_with_duplicate_expressions() { + let sql = "\ + SELECT 0 a, 0 b \ + EXCEPT SELECT 1, 1 \ + EXCEPT SELECT count(*), count(*) FROM orders"; + let plan = logical_plan(sql).unwrap(); + assert_snapshot!( + plan, + @r" + LeftAnti Join: left.a = right.a, left.b = right.b + Distinct: + SubqueryAlias: left + LeftAnti Join: left.a = right.a, left.b = right.b + Distinct: + SubqueryAlias: left + Projection: Int64(0) AS a, Int64(0) AS b + EmptyRelation: rows=1 + SubqueryAlias: right + Projection: Int64(1) AS a, Int64(1) AS b + EmptyRelation: rows=1 + SubqueryAlias: right + Projection: count(*) AS a, count(*) AS b + Aggregate: groupBy=[[]], aggr=[[count(*)]] + TableScan: orders + " + ); +} + #[test] fn empty_over() { let sql = "SELECT order_id, MAX(order_id) OVER () from orders"; diff --git a/datafusion/sqllogictest/test_files/union.slt b/datafusion/sqllogictest/test_files/union.slt index d858d0ae3ea4..d0ad5c8bb3c5 100644 --- a/datafusion/sqllogictest/test_files/union.slt +++ b/datafusion/sqllogictest/test_files/union.slt @@ -256,6 +256,30 @@ Bob_new John John_new +# Test UNION ALL with unaliased duplicate literal values on the right side. +# The second projection will inherit field names from the first one, and so +# pass the unique projection expression name check. +query TII rowsort +SELECT name, 1 as table, 1 as row FROM t1 WHERE id = 1 +UNION ALL +SELECT name, 2, 2 FROM t2 WHERE id = 2 +---- +Alex 1 1 +Bob 2 2 + +# Test nested UNION, EXCEPT, INTERSECT with duplicate unaliased literals. +# Only the first SELECT has column aliases, which should propagate to all projections. +query III rowsort +SELECT 1 as a, 0 as b, 0 as c +UNION ALL +((SELECT 2, 0, 0 UNION ALL SELECT 3, 0, 0) EXCEPT SELECT 3, 0, 0) +UNION ALL +(SELECT 4, 0, 0 INTERSECT SELECT 4, 0, 0) +---- +1 0 0 +2 0 0 +4 0 0 + # Plan is unnested query TT EXPLAIN SELECT name FROM t1 UNION ALL (SELECT name from t2 UNION ALL SELECT name || '_new' from t2)