Skip to content

Commit 224a532

Browse files
committed
Fix Unparser for stacked Projections & Expand Unparser tests
1 parent 57b275a commit 224a532

3 files changed

Lines changed: 194 additions & 17 deletions

File tree

datafusion/core/tests/sql/unparser.rs

Lines changed: 127 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,9 @@ use datafusion_common::Column;
4545
use datafusion_expr::Expr;
4646
use datafusion_physical_plan::ExecutionPlanProperties;
4747
use datafusion_sql::unparser::Unparser;
48-
use datafusion_sql::unparser::dialect::DefaultDialect;
48+
use datafusion_sql::unparser::dialect::{
49+
DefaultDialect, Dialect, DuckDBDialect, MySqlDialect, PostgreSqlDialect, SqliteDialect,
50+
};
4951
use itertools::Itertools;
5052
use recursive::{set_minimum_stack_size, set_stack_allocation_size};
5153

@@ -142,6 +144,35 @@ fn tpch_queries() -> Vec<TestQuery> {
142144
queries
143145
}
144146

147+
/// Regression queries for specific bugs found in the unparser.
148+
fn regression_queries() -> Vec<TestQuery> {
149+
vec![
150+
// Stacked Projections from CSE + PostgreSqlDialect's required derived table
151+
// alias ("derived_projection") break SQL roundtrip: outer column refs keep the
152+
// original SubqueryAlias qualifier ("base") which doesn't match the alias.
153+
TestQuery {
154+
name: "cse_derived_projection".to_string(),
155+
sql: "\
156+
WITH base AS (SELECT name, salary FROM t) \
157+
SELECT name, \
158+
CASE WHEN SUM(salary) > 0 THEN 1 ELSE 0 END AS x, \
159+
CASE WHEN SUM(salary) > 0 THEN SUM(salary) ELSE 0 END AS y \
160+
FROM base GROUP BY name"
161+
.to_string(),
162+
},
163+
]
164+
}
165+
166+
/// Create a new SessionContext for regression tests.
167+
async fn regression_test_context() -> Result<SessionContext> {
168+
let ctx = SessionContext::new();
169+
ctx.sql("CREATE TABLE t (name TEXT, salary DOUBLE) AS VALUES ('a', 1.0), ('b', 2.0)")
170+
.await?
171+
.collect()
172+
.await?;
173+
Ok(ctx)
174+
}
175+
145176
/// Create a new SessionContext for testing that has all Clickbench tables registered.
146177
async fn clickbench_test_context() -> Result<SessionContext> {
147178
let ctx = SessionContext::new();
@@ -284,22 +315,28 @@ impl TestCaseResult {
284315
///
285316
/// This is the core test logic that:
286317
/// 1. Parses the original SQL and creates a logical plan
287-
/// 2. Unparses the logical plan back to SQL
288-
/// 3. Executes both the original and unparsed queries
289-
/// 4. Compares the results (sorting if the query has no ORDER BY)
290-
///
291-
/// This always uses [`DefaultDialect`] for unparsing.
318+
/// 2. Optionally optimizes the plan (to test unparser against optimizer output like CSE)
319+
/// 3. Unparses the logical plan back to SQL using the given dialect
320+
/// 4. Executes both the original and unparsed queries
321+
/// 5. Compares the results (sorting if the query has no ORDER BY)
292322
///
293323
/// # Arguments
294324
///
295325
/// * `ctx` - Session context with tables registered
296326
/// * `original` - The original SQL query to test
327+
/// * `dialect` - The unparser dialect to use
328+
/// * `optimize` - Whether to optimize the logical plan before unparsing
297329
///
298330
/// # Returns
299331
///
300332
/// A [`TestCaseResult`] indicating success or the specific failure mode.
301-
async fn collect_results(ctx: &SessionContext, original: &str) -> TestCaseResult {
302-
let unparser = Unparser::new(&DefaultDialect {});
333+
async fn collect_results(
334+
ctx: &SessionContext,
335+
original: &str,
336+
dialect: &dyn Dialect,
337+
optimize: bool,
338+
) -> TestCaseResult {
339+
let unparser = Unparser::new(dialect);
303340

304341
// Parse and create logical plan from original SQL
305342
let df = match ctx.sql(original).await {
@@ -312,8 +349,23 @@ async fn collect_results(ctx: &SessionContext, original: &str) -> TestCaseResult
312349
}
313350
};
314351

352+
// Optionally optimize the plan before unparsing
353+
let plan = if optimize {
354+
match ctx.state().optimize(df.logical_plan()) {
355+
Ok(optimized) => optimized,
356+
Err(e) => {
357+
return TestCaseResult::ExecutionError {
358+
original: original.to_string(),
359+
error: format!("Failed to optimize plan: {e}"),
360+
};
361+
}
362+
}
363+
} else {
364+
df.logical_plan().clone()
365+
};
366+
315367
// Unparse the logical plan back to SQL
316-
let unparsed = match unparser.plan_to_sql(df.logical_plan()) {
368+
let unparsed = match unparser.plan_to_sql(&plan) {
317369
Ok(sql) => format!("{sql:#}"),
318370
Err(e) => {
319371
return TestCaseResult::UnparseError {
@@ -419,6 +471,8 @@ async fn run_roundtrip_tests<F, Fut>(
419471
suite_name: &str,
420472
queries: Vec<TestQuery>,
421473
create_context: F,
474+
dialect: &dyn Dialect,
475+
optimize: bool,
422476
) where
423477
F: Fn() -> Fut,
424478
Fut: Future<Output = Result<SessionContext>>,
@@ -433,7 +487,7 @@ async fn run_roundtrip_tests<F, Fut>(
433487
continue;
434488
}
435489
};
436-
let result = collect_results(&ctx, &sql.sql).await;
490+
let result = collect_results(&ctx, &sql.sql, dialect, optimize).await;
437491
if result.is_failure() {
438492
println!("\x1b[31m✗\x1b[0m {} query: {}", suite_name, sql.name);
439493
errors.push(result.format_error(&sql.name));
@@ -451,16 +505,76 @@ async fn run_roundtrip_tests<F, Fut>(
451505
}
452506
}
453507

508+
/// Returns all dialects to test, paired with their display names.
509+
fn all_dialects() -> Vec<(&'static str, Box<dyn Dialect>)> {
510+
vec![
511+
("Default", Box::new(DefaultDialect {})),
512+
("PostgreSQL", Box::new(PostgreSqlDialect {})),
513+
("MySQL", Box::new(MySqlDialect {})),
514+
("SQLite", Box::new(SqliteDialect {})),
515+
("DuckDB", Box::new(DuckDBDialect::default())),
516+
// BigQuery has known issues with col_alias_overrides encoding
517+
// special characters in CSE-generated column names (e.g.
518+
// "sum(base.salary)" becomes "sum_40base_46salary_41"), which
519+
// breaks roundtrip tests with optimized plans. Should be fixed
520+
// in the BigQuery dialect's unparser.
521+
// ("BigQuery", Box::new(BigQueryDialect {})),
522+
]
523+
}
524+
454525
#[tokio::test]
455526
async fn test_clickbench_unparser_roundtrip() {
456-
run_roundtrip_tests("Clickbench", clickbench_queries(), clickbench_test_context)
457-
.await;
527+
for (dialect_name, dialect) in all_dialects() {
528+
for optimize in [false, true] {
529+
let opt_label = if optimize { ", optimized" } else { "" };
530+
let suite = format!("Clickbench ({dialect_name}{opt_label})");
531+
run_roundtrip_tests(
532+
&suite,
533+
clickbench_queries(),
534+
clickbench_test_context,
535+
dialect.as_ref(),
536+
optimize,
537+
)
538+
.await;
539+
}
540+
}
458541
}
459542

460543
#[tokio::test]
461544
async fn test_tpch_unparser_roundtrip() {
462545
// Grow stacker segments earlier to avoid deep unparser recursion overflow in q20.
463546
set_minimum_stack_size(512 * 1024);
464547
set_stack_allocation_size(8 * 1024 * 1024);
465-
run_roundtrip_tests("TPC-H", tpch_queries(), tpch_test_context).await;
548+
for (dialect_name, dialect) in all_dialects() {
549+
for optimize in [false, true] {
550+
let opt_label = if optimize { ", optimized" } else { "" };
551+
let suite = format!("TPC-H ({dialect_name}{opt_label})");
552+
run_roundtrip_tests(
553+
&suite,
554+
tpch_queries(),
555+
tpch_test_context,
556+
dialect.as_ref(),
557+
optimize,
558+
)
559+
.await;
560+
}
561+
}
562+
}
563+
564+
#[tokio::test]
565+
async fn test_regression_unparser_roundtrip() {
566+
for (dialect_name, dialect) in all_dialects() {
567+
for optimize in [false, true] {
568+
let opt_label = if optimize { ", optimized" } else { "" };
569+
let suite = format!("Regression ({dialect_name}{opt_label})");
570+
run_roundtrip_tests(
571+
&suite,
572+
regression_queries(),
573+
regression_test_context,
574+
dialect.as_ref(),
575+
optimize,
576+
)
577+
.await;
578+
}
579+
}
466580
}

datafusion/sql/src/unparser/plan.rs

Lines changed: 65 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,8 @@ use datafusion_expr::{
4949
LogicalPlanBuilder, Operator, Projection, SortExpr, TableScan, Unnest,
5050
UserDefinedLogicalNode, expr::Alias,
5151
};
52-
use sqlparser::ast::{self, Ident, OrderByKind, SetExpr, TableAliasColumnDef};
52+
use sqlparser::ast::{self, Ident, OrderByKind, SetExpr, TableAliasColumnDef, VisitMut};
53+
use std::ops::ControlFlow;
5354
use std::{sync::Arc, vec};
5455

5556
/// Convert a DataFusion [`LogicalPlan`] to [`ast::Statement`]
@@ -413,8 +414,23 @@ impl Unparser<'_> {
413414
};
414415
// Projection can be top-level plan for derived table
415416
if select.already_projected() {
417+
let alias = "derived_projection";
418+
if self.dialect.requires_derived_table_alias() {
419+
// When wrapping this projection as a derived subquery
420+
// with an alias, strip qualifiers from the outer
421+
// SELECT items so they don't reference tables that are
422+
// now hidden inside the subquery.
423+
// Unqualified column names are valid because the derived
424+
// table is the only FROM source.
425+
let items = select.pop_projections();
426+
let dequalified = items
427+
.into_iter()
428+
.map(dequalify_select_item)
429+
.collect();
430+
select.projection(dequalified);
431+
}
416432
return self.derive_with_dialect_alias(
417-
"derived_projection",
433+
alias,
418434
plan,
419435
relation,
420436
unnest_input_type
@@ -1415,6 +1431,53 @@ impl Unparser<'_> {
14151431
}
14161432
}
14171433

1434+
/// Strips table qualifiers from column references in a [`ast::SelectItem`].
1435+
///
1436+
/// When a projection is wrapped in a derived subquery (e.g. `(SELECT ...) AS alias`),
1437+
/// column references in the outer SELECT that used the original table qualifier
1438+
/// become invalid because those tables are now hidden inside the subquery.
1439+
/// Stripping the qualifier makes them unqualified, which is valid because the
1440+
/// derived table is the sole FROM source.
1441+
///
1442+
/// For example, `"base"."name"` becomes `"name"`.
1443+
///
1444+
/// This only affects [`ast::Expr::CompoundIdentifier`] nodes with 2+ parts,
1445+
/// keeping only the final column name identifier.
1446+
///
1447+
/// Safety: this strips *all* qualifiers, not just stale ones, so it is only
1448+
/// correct when the derived table is the sole FROM source (no JOINs in the
1449+
/// outer query). The `already_projected()` guard in the caller ensures this.
1450+
fn dequalify_select_item(mut item: ast::SelectItem) -> ast::SelectItem {
1451+
struct Dequalifier;
1452+
1453+
impl ast::VisitorMut for Dequalifier {
1454+
type Break = ();
1455+
1456+
fn pre_visit_expr(
1457+
&mut self,
1458+
expr: &mut ast::Expr,
1459+
) -> ControlFlow<Self::Break> {
1460+
if let ast::Expr::CompoundIdentifier(idents) = expr {
1461+
if idents.len() >= 2 {
1462+
let col_name = idents.last().unwrap().clone();
1463+
*expr = ast::Expr::Identifier(col_name);
1464+
}
1465+
}
1466+
ControlFlow::Continue(())
1467+
}
1468+
}
1469+
1470+
let mut visitor = Dequalifier;
1471+
match &mut item {
1472+
ast::SelectItem::UnnamedExpr(expr)
1473+
| ast::SelectItem::ExprWithAlias { expr, .. } => {
1474+
let _ = expr.visit(&mut visitor);
1475+
}
1476+
ast::SelectItem::QualifiedWildcard(..) | ast::SelectItem::Wildcard(..) => {}
1477+
}
1478+
item
1479+
}
1480+
14181481
impl From<BuilderError> for DataFusionError {
14191482
fn from(e: BuilderError) -> Self {
14201483
DataFusionError::External(Box::new(e))

datafusion/sql/tests/cases/plan_to_sql.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2634,14 +2634,14 @@ fn test_unparse_window() -> Result<()> {
26342634
let sql = unparser.plan_to_sql(&plan)?;
26352635
assert_snapshot!(
26362636
sql,
2637-
@r#"SELECT "test"."k", "test"."v", "rank() PARTITION BY [test.k] ORDER BY [test.v ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING" FROM (SELECT "test"."k" AS "k", "test"."v" AS "v", rank() OVER (PARTITION BY "test"."k" ORDER BY "test"."v" ASC NULLS FIRST ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS "rank() PARTITION BY [test.k] ORDER BY [test.v ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING" FROM "test") AS "test" WHERE ("rank() PARTITION BY [test.k] ORDER BY [test.v ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING" = 1)"#
2637+
@r#"SELECT "k", "v", "rank() PARTITION BY [test.k] ORDER BY [test.v ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING" FROM (SELECT "test"."k" AS "k", "test"."v" AS "v", rank() OVER (PARTITION BY "test"."k" ORDER BY "test"."v" ASC NULLS FIRST ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS "rank() PARTITION BY [test.k] ORDER BY [test.v ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING" FROM "test") AS "test" WHERE ("rank() PARTITION BY [test.k] ORDER BY [test.v ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING" = 1)"#
26382638
);
26392639

26402640
let unparser = Unparser::new(&UnparserMySqlDialect {});
26412641
let sql = unparser.plan_to_sql(&plan)?;
26422642
assert_snapshot!(
26432643
sql,
2644-
@"SELECT `test`.`k`, `test`.`v`, `rank() PARTITION BY [test.k] ORDER BY [test.v ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING` FROM (SELECT `test`.`k` AS `k`, `test`.`v` AS `v`, rank() OVER (PARTITION BY `test`.`k` ORDER BY `test`.`v` ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS `rank() PARTITION BY [test.k] ORDER BY [test.v ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING` FROM `test`) AS `test` WHERE (`rank() PARTITION BY [test.k] ORDER BY [test.v ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING` = 1)"
2644+
@"SELECT `k`, `v`, `rank() PARTITION BY [test.k] ORDER BY [test.v ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING` FROM (SELECT `test`.`k` AS `k`, `test`.`v` AS `v`, rank() OVER (PARTITION BY `test`.`k` ORDER BY `test`.`v` ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS `rank() PARTITION BY [test.k] ORDER BY [test.v ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING` FROM `test`) AS `test` WHERE (`rank() PARTITION BY [test.k] ORDER BY [test.v ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING` = 1)"
26452645
);
26462646

26472647
let unparser = Unparser::new(&SqliteDialect {});

0 commit comments

Comments
 (0)