Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 30 additions & 2 deletions datafusion/expr/src/logical_plan/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -593,7 +593,23 @@ impl LogicalPlanBuilder {
self,
expr: Vec<(impl Into<SelectExpr>, bool)>,
) -> Result<Self> {
project_with_validation(Arc::unwrap_or_clone(self.plan), expr).map(Self::new)
project_with_validation(Arc::unwrap_or_clone(self.plan), expr, None)
.map(Self::new)
}

/// Build a projection over the current plan in which every expression is
/// validated, and any expression that is neither a plain column nor already
/// aliased is given the name of the field at the same position in `schema`.
///
/// Returns a new builder wrapping the resulting projection, or the error
/// reported while constructing it.
pub fn project_with_validation_and_schema(
    self,
    expr: impl IntoIterator<Item = impl Into<SelectExpr>>,
    schema: &DFSchemaRef,
) -> Result<Self> {
    let input = Arc::unwrap_or_clone(self.plan);
    // Pair each expression with `true` so that it is validated.
    let validated = expr.into_iter().map(|item| (item, true));
    let projection = project_with_validation(input, validated, Some(schema))?;
    Ok(Self::new(projection))
}

/// Select the given column indices
Expand Down Expand Up @@ -1916,7 +1932,7 @@ pub fn project(
plan: LogicalPlan,
expr: impl IntoIterator<Item = impl Into<SelectExpr>>,
) -> Result<LogicalPlan> {
project_with_validation(plan, expr.into_iter().map(|e| (e, true)))
project_with_validation(plan, expr.into_iter().map(|e| (e, true)), None)
}

/// Create Projection. Similar to project except that the expressions
Expand All @@ -1929,6 +1945,7 @@ pub fn project(
fn project_with_validation(
plan: LogicalPlan,
expr: impl IntoIterator<Item = (impl Into<SelectExpr>, bool)>,
schema: Option<&DFSchemaRef>,
) -> Result<LogicalPlan> {
let mut projected_expr = vec![];
for (e, validate) in expr {
Expand Down Expand Up @@ -1984,6 +2001,17 @@ fn project_with_validation(
}
}
}

// When inside a set expression, alias non-Column/non-Alias expressions
// to match the left side's field names, avoiding duplicate name errors.
if let Some(schema) = &schema {
for (expr, field) in projected_expr.iter_mut().zip(schema.fields()) {
if !matches!(expr, Expr::Column(_) | Expr::Alias(_)) {
*expr = expr.clone().alias(field.name());
}
}
}

validate_unique_names("Projections", projected_expr.iter())?;

Projection::try_new(projected_expr, Arc::new(plan)).map(LogicalPlan::Projection)
Expand Down
13 changes: 13 additions & 0 deletions datafusion/sql/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,10 @@ pub struct PlannerContext {
outer_from_schema: Option<DFSchemaRef>,
/// The query schema defined by the table
create_table_schema: Option<DFSchemaRef>,
/// When planning non-first queries in a set expression
/// (UNION/INTERSECT/EXCEPT), holds the schema of the left-most query.
/// Used to alias expressions that are neither columns nor already aliased
/// so that they take the left side's field names.
set_expr_left_schema: Option<DFSchemaRef>,
}

impl Default for PlannerContext {
Expand All @@ -287,6 +291,7 @@ impl PlannerContext {
outer_queries_schemas_stack: vec![],
outer_from_schema: None,
create_table_schema: None,
set_expr_left_schema: None,
}
}

Expand Down Expand Up @@ -400,6 +405,14 @@ impl PlannerContext {
pub(super) fn remove_cte(&mut self, cte_name: &str) {
self.ctes.remove(cte_name);
}

/// Installs `schema` as the left-most set expression schema and hands back
/// the value that was stored before the call.
pub(super) fn set_set_expr_left_schema(
    &mut self,
    schema: Option<DFSchemaRef>,
) -> Option<DFSchemaRef> {
    // Move the old value out first, then store the replacement.
    let previous = self.set_expr_left_schema.take();
    self.set_expr_left_schema = schema;
    previous
}
}

/// SQL query planner and binder
Expand Down
4 changes: 2 additions & 2 deletions datafusion/sql/src/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
let empty_from = matches!(plan, LogicalPlan::EmptyRelation(_));
let select_exprs =
self.prepare_select_exprs(&plan, exprs, empty_from, planner_context)?;
self.project(plan, select_exprs)
self.project(plan, select_exprs, None)
}
PipeOperator::Extend { exprs } => {
let empty_from = matches!(plan, LogicalPlan::EmptyRelation(_));
Expand All @@ -162,7 +162,7 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
std::iter::once(SelectExpr::Wildcard(WildcardOptions::default()))
.chain(extend_exprs)
.collect();
self.project(plan, all_exprs)
self.project(plan, all_exprs, None)
}
PipeOperator::As { alias } => self.apply_table_alias(
plan,
Expand Down
28 changes: 22 additions & 6 deletions datafusion/sql/src/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use crate::utils::{

use datafusion_common::error::DataFusionErrorBuilder;
use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
use datafusion_common::{Column, DFSchema, Result, not_impl_err, plan_err};
use datafusion_common::{Column, DFSchema, DFSchemaRef, Result, not_impl_err, plan_err};
use datafusion_common::{RecursionUnnestOption, UnnestOptions};
use datafusion_expr::expr::{Alias, PlannedReplaceSelectItem, WildcardOptions};
use datafusion_expr::expr_rewriter::{
Expand Down Expand Up @@ -90,6 +90,10 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
return not_impl_err!("SORT BY");
}

// Capture and clear set expression schema so it doesn't leak
// into subqueries planned during FROM clause handling.
let set_expr_left_schema = planner_context.set_set_expr_left_schema(None);

// Process `from` clause
let plan = self.plan_from_tables(select.from, planner_context)?;
let empty_from = matches!(plan, LogicalPlan::EmptyRelation(_));
Expand All @@ -110,7 +114,8 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
)?;

// Having and group by clause may reference aliases defined in select projection
let projected_plan = self.project(base_plan.clone(), select_exprs)?;
let projected_plan =
self.project(base_plan.clone(), select_exprs, set_expr_left_schema)?;
let select_exprs = projected_plan.expressions();

let order_by =
Expand Down Expand Up @@ -879,18 +884,29 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
&self,
input: LogicalPlan,
expr: Vec<SelectExpr>,
set_expr_left_schema: Option<DFSchemaRef>,
) -> Result<LogicalPlan> {
// convert to Expr for validate_schema_satisfies_exprs
let exprs = expr
let plain_exprs = expr
.iter()
.filter_map(|e| match e {
SelectExpr::Expression(expr) => Some(expr.to_owned()),
_ => None,
})
.collect::<Vec<_>>();
self.validate_schema_satisfies_exprs(input.schema(), &exprs)?;

LogicalPlanBuilder::from(input).project(expr)?.build()
self.validate_schema_satisfies_exprs(input.schema(), &plain_exprs)?;

// When inside a set expression, pass the left-most schema so
// that expressions get aliased to match, avoiding duplicate
// name errors from expressions like `count(*), count(*)`.
let builder = LogicalPlanBuilder::from(input);
if let Some(left_schema) = set_expr_left_schema {
builder
.project_with_validation_and_schema(expr, &left_schema)?
.build()
} else {
builder.project(expr)?.build()
}
}

/// Create an aggregate plan.
Expand Down
18 changes: 18 additions & 0 deletions datafusion/sql/src/set_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
// specific language governing permissions and limitations
// under the License.

use std::sync::Arc;

use crate::planner::{ContextProvider, PlannerContext, SqlToRel};
use datafusion_common::{
DataFusionError, Diagnostic, Result, Span, not_impl_err, plan_err,
Expand Down Expand Up @@ -42,7 +44,23 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
let left_span = Span::try_from_sqlparser_span(left.span());
let right_span = Span::try_from_sqlparser_span(right.span());
let left_plan = self.set_expr_to_plan(*left, planner_context);
// Store the left plan's schema so that the right side can
// alias duplicate expressions to match. Skip for BY NAME
// operations since those match columns by name, not position.
if let Ok(plan) = &left_plan
&& plan.schema().fields().len() > 1
&& !matches!(
set_quantifier,
SetQuantifier::ByName
| SetQuantifier::AllByName
| SetQuantifier::DistinctByName
)
{
planner_context
.set_set_expr_left_schema(Some(Arc::clone(plan.schema())));
}
let right_plan = self.set_expr_to_plan(*right, planner_context);
planner_context.set_set_expr_left_schema(None);
let (left_plan, right_plan) = match (left_plan, right_plan) {
(Ok(left_plan), Ok(right_plan)) => (left_plan, right_plan),
(Err(left_err), Err(right_err)) => {
Expand Down
100 changes: 100 additions & 0 deletions datafusion/sql/tests/sql_integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2655,6 +2655,106 @@ fn union_all_by_name_same_column_names() {
);
}

// Checks that a chained UNION ALL plans successfully when the later branches
// repeat an unaliased expression (`1, 1` and `count(*), count(*)`).
// The expected plan shows every branch's projection aliased to the first
// branch's column names `a` and `b`.
#[test]
fn union_all_with_duplicate_expressions() {
let sql = "\
SELECT 0 a, 0 b \
UNION ALL SELECT 1, 1 \
UNION ALL SELECT count(*), count(*) FROM orders";
let plan = logical_plan(sql).unwrap();
// Compare the rendered logical plan against the inline snapshot.
assert_snapshot!(
plan,
@r"
Union
Union
Projection: Int64(0) AS a, Int64(0) AS b
EmptyRelation: rows=1
Projection: Int64(1) AS a, Int64(1) AS b
EmptyRelation: rows=1
Projection: count(*) AS a, count(*) AS b
Aggregate: groupBy=[[]], aggr=[[count(*)]]
TableScan: orders
"
);
}

// Checks a UNION whose right side mixes repeated literals with a wildcard.
// In the expected plan the literals take the left side's names `a` and `d`,
// while the wildcard-expanded columns `test_decimal.id` and
// `test_decimal.price` are left without an alias.
#[test]
fn union_with_qualified_and_duplicate_expressions() {
let sql = "\
SELECT 0 a, id b, price c, 0 d FROM test_decimal \
UNION SELECT 1, *, 1 FROM test_decimal";
let plan = logical_plan(sql).unwrap();
// Compare the rendered logical plan against the inline snapshot.
assert_snapshot!(
plan,
@"
Distinct:
Union
Projection: Int64(0) AS a, test_decimal.id AS b, test_decimal.price AS c, Int64(0) AS d
TableScan: test_decimal
Projection: Int64(1) AS a, test_decimal.id, test_decimal.price, Int64(1) AS d
TableScan: test_decimal
"
);
}

// Checks that chained INTERSECT plans successfully when the later branches
// repeat an unaliased expression. The expected plan is a nest of LeftSemi
// joins on `a` and `b`, with every branch's projection aliased to the first
// branch's column names.
#[test]
fn intersect_with_duplicate_expressions() {
let sql = "\
SELECT 0 a, 0 b \
INTERSECT SELECT 1, 1 \
INTERSECT SELECT count(*), count(*) FROM orders";
let plan = logical_plan(sql).unwrap();
// Compare the rendered logical plan against the inline snapshot.
assert_snapshot!(
plan,
@r"
LeftSemi Join: left.a = right.a, left.b = right.b
Distinct:
SubqueryAlias: left
LeftSemi Join: left.a = right.a, left.b = right.b
Distinct:
SubqueryAlias: left
Projection: Int64(0) AS a, Int64(0) AS b
EmptyRelation: rows=1
SubqueryAlias: right
Projection: Int64(1) AS a, Int64(1) AS b
EmptyRelation: rows=1
SubqueryAlias: right
Projection: count(*) AS a, count(*) AS b
Aggregate: groupBy=[[]], aggr=[[count(*)]]
TableScan: orders
"
);
}

// Checks that chained EXCEPT plans successfully when the later branches
// repeat an unaliased expression. The expected plan is a nest of LeftAnti
// joins on `a` and `b`, with every branch's projection aliased to the first
// branch's column names.
#[test]
fn except_with_duplicate_expressions() {
let sql = "\
SELECT 0 a, 0 b \
EXCEPT SELECT 1, 1 \
EXCEPT SELECT count(*), count(*) FROM orders";
let plan = logical_plan(sql).unwrap();
// Compare the rendered logical plan against the inline snapshot.
assert_snapshot!(
plan,
@r"
LeftAnti Join: left.a = right.a, left.b = right.b
Distinct:
SubqueryAlias: left
LeftAnti Join: left.a = right.a, left.b = right.b
Distinct:
SubqueryAlias: left
Projection: Int64(0) AS a, Int64(0) AS b
EmptyRelation: rows=1
SubqueryAlias: right
Projection: Int64(1) AS a, Int64(1) AS b
EmptyRelation: rows=1
SubqueryAlias: right
Projection: count(*) AS a, count(*) AS b
Aggregate: groupBy=[[]], aggr=[[count(*)]]
TableScan: orders
"
);
}

#[test]
fn empty_over() {
let sql = "SELECT order_id, MAX(order_id) OVER () from orders";
Expand Down
24 changes: 24 additions & 0 deletions datafusion/sqllogictest/test_files/union.slt
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,30 @@ Bob_new
John
John_new

# Test UNION ALL with unaliased duplicate literal values on the right side.
# The second projection will inherit field names from the first one, and so
# pass the unique projection expression name check.
query TII rowsort
SELECT name, 1 as table, 1 as row FROM t1 WHERE id = 1
UNION ALL
SELECT name, 2, 2 FROM t2 WHERE id = 2
----
Alex 1 1
Bob 2 2

# Test nested UNION, EXCEPT, INTERSECT with duplicate unaliased literals.
# Only the first SELECT has column aliases, which should propagate to all projections.
query III rowsort
SELECT 1 as a, 0 as b, 0 as c
UNION ALL
((SELECT 2, 0, 0 UNION ALL SELECT 3, 0, 0) EXCEPT SELECT 3, 0, 0)
UNION ALL
(SELECT 4, 0, 0 INTERSECT SELECT 4, 0, 0)
----
1 0 0
2 0 0
4 0 0

# Plan is unnested
query TT
EXPLAIN SELECT name FROM t1 UNION ALL (SELECT name from t2 UNION ALL SELECT name || '_new' from t2)
Expand Down
Loading