Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 128 additions & 13 deletions datafusion/core/tests/sql/unparser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,10 @@ use datafusion_common::Column;
use datafusion_expr::Expr;
use datafusion_physical_plan::ExecutionPlanProperties;
use datafusion_sql::unparser::Unparser;
use datafusion_sql::unparser::dialect::DefaultDialect;
use datafusion_sql::unparser::dialect::{
DefaultDialect, Dialect, DuckDBDialect, MySqlDialect, PostgreSqlDialect,
SqliteDialect,
};
use itertools::Itertools;
use recursive::{set_minimum_stack_size, set_stack_allocation_size};

Expand Down Expand Up @@ -142,6 +145,35 @@ fn tpch_queries() -> Vec<TestQuery> {
queries
}

/// Regression queries for specific bugs found in the unparser.
fn regression_queries() -> Vec<TestQuery> {
    // Stacked Projections from CSE + PostgreSqlDialect's required derived table
    // alias ("derived_projection") break SQL roundtrip: outer column refs keep the
    // original SubqueryAlias qualifier ("base") which doesn't match the alias.
    let cse_derived_projection = TestQuery {
        name: "cse_derived_projection".to_string(),
        sql: concat!(
            "WITH base AS (SELECT name, salary FROM t) ",
            "SELECT name, ",
            "CASE WHEN SUM(salary) > 0 THEN 1 ELSE 0 END AS x, ",
            "CASE WHEN SUM(salary) > 0 THEN SUM(salary) ELSE 0 END AS y ",
            "FROM base GROUP BY name",
        )
        .to_string(),
    };
    vec![cse_derived_projection]
}

/// Create a new SessionContext for regression tests.
///
/// Registers a small two-row table `t(name TEXT, salary DOUBLE)` that the
/// regression queries reference.
async fn regression_test_context() -> Result<SessionContext> {
    let context = SessionContext::new();
    let create_table =
        "CREATE TABLE t (name TEXT, salary DOUBLE) AS VALUES ('a', 1.0), ('b', 2.0)";
    // Run the DDL to completion so the table exists before any query uses it.
    context.sql(create_table).await?.collect().await?;
    Ok(context)
}

/// Create a new SessionContext for testing that has all Clickbench tables registered.
async fn clickbench_test_context() -> Result<SessionContext> {
let ctx = SessionContext::new();
Expand Down Expand Up @@ -284,22 +316,28 @@ impl TestCaseResult {
///
/// This is the core test logic that:
/// 1. Parses the original SQL and creates a logical plan
/// 2. Unparses the logical plan back to SQL
/// 3. Executes both the original and unparsed queries
/// 4. Compares the results (sorting if the query has no ORDER BY)
///
/// This always uses [`DefaultDialect`] for unparsing.
/// 2. Optionally optimizes the plan (to test unparser against optimizer output like CSE)
/// 3. Unparses the logical plan back to SQL using the given dialect
/// 4. Executes both the original and unparsed queries
/// 5. Compares the results (sorting if the query has no ORDER BY)
///
/// # Arguments
///
/// * `ctx` - Session context with tables registered
/// * `original` - The original SQL query to test
/// * `dialect` - The unparser dialect to use
/// * `optimize` - Whether to optimize the logical plan before unparsing
///
/// # Returns
///
/// A [`TestCaseResult`] indicating success or the specific failure mode.
async fn collect_results(ctx: &SessionContext, original: &str) -> TestCaseResult {
let unparser = Unparser::new(&DefaultDialect {});
async fn collect_results(
ctx: &SessionContext,
original: &str,
dialect: &dyn Dialect,
optimize: bool,
) -> TestCaseResult {
let unparser = Unparser::new(dialect);

// Parse and create logical plan from original SQL
let df = match ctx.sql(original).await {
Expand All @@ -312,8 +350,23 @@ async fn collect_results(ctx: &SessionContext, original: &str) -> TestCaseResult
}
};

// Optionally optimize the plan before unparsing
let plan = if optimize {
match ctx.state().optimize(df.logical_plan()) {
Ok(optimized) => optimized,
Err(e) => {
return TestCaseResult::ExecutionError {
original: original.to_string(),
error: format!("Failed to optimize plan: {e}"),
};
}
}
} else {
df.logical_plan().clone()
};

// Unparse the logical plan back to SQL
let unparsed = match unparser.plan_to_sql(df.logical_plan()) {
let unparsed = match unparser.plan_to_sql(&plan) {
Ok(sql) => format!("{sql:#}"),
Err(e) => {
return TestCaseResult::UnparseError {
Expand Down Expand Up @@ -419,6 +472,8 @@ async fn run_roundtrip_tests<F, Fut>(
suite_name: &str,
queries: Vec<TestQuery>,
create_context: F,
dialect: &dyn Dialect,
optimize: bool,
) where
F: Fn() -> Fut,
Fut: Future<Output = Result<SessionContext>>,
Expand All @@ -433,7 +488,7 @@ async fn run_roundtrip_tests<F, Fut>(
continue;
}
};
let result = collect_results(&ctx, &sql.sql).await;
let result = collect_results(&ctx, &sql.sql, dialect, optimize).await;
if result.is_failure() {
println!("\x1b[31m✗\x1b[0m {} query: {}", suite_name, sql.name);
errors.push(result.format_error(&sql.name));
Expand All @@ -451,16 +506,76 @@ async fn run_roundtrip_tests<F, Fut>(
}
}

/// Returns all dialects to test, paired with their display names.
fn all_dialects() -> Vec<(&'static str, Box<dyn Dialect>)> {
    let mut dialects: Vec<(&'static str, Box<dyn Dialect>)> = Vec::new();
    dialects.push(("Default", Box::new(DefaultDialect {})));
    dialects.push(("PostgreSQL", Box::new(PostgreSqlDialect {})));
    dialects.push(("MySQL", Box::new(MySqlDialect {})));
    dialects.push(("SQLite", Box::new(SqliteDialect {})));
    dialects.push(("DuckDB", Box::new(DuckDBDialect::default())));
    // BigQuery has known issues with col_alias_overrides encoding
    // special characters in CSE-generated column names (e.g.
    // "sum(base.salary)" becomes "sum_40base_46salary_41"), which
    // breaks roundtrip tests with optimized plans. Should be fixed
    // in the BigQuery dialect's unparser.
    // dialects.push(("BigQuery", Box::new(BigQueryDialect {})));
    dialects
}

#[tokio::test]
async fn test_clickbench_unparser_roundtrip() {
    // Exercise every dialect, both on the raw logical plan and on the
    // optimizer's output (which can introduce CSE projections).
    for (dialect_name, dialect) in all_dialects() {
        for optimize in [false, true] {
            let suite = if optimize {
                format!("Clickbench ({dialect_name}, optimized)")
            } else {
                format!("Clickbench ({dialect_name})")
            };
            run_roundtrip_tests(
                &suite,
                clickbench_queries(),
                clickbench_test_context,
                dialect.as_ref(),
                optimize,
            )
            .await;
        }
    }
}

#[tokio::test]
async fn test_tpch_unparser_roundtrip() {
    // Grow stacker segments earlier to avoid deep unparser recursion overflow in q20.
    set_minimum_stack_size(512 * 1024);
    set_stack_allocation_size(8 * 1024 * 1024);
    // Exercise every dialect, both on the raw logical plan and on the
    // optimizer's output (which can introduce CSE projections).
    for (dialect_name, dialect) in all_dialects() {
        for optimize in [false, true] {
            let suite = if optimize {
                format!("TPC-H ({dialect_name}, optimized)")
            } else {
                format!("TPC-H ({dialect_name})")
            };
            run_roundtrip_tests(
                &suite,
                tpch_queries(),
                tpch_test_context,
                dialect.as_ref(),
                optimize,
            )
            .await;
        }
    }
}

#[tokio::test]
async fn test_regression_unparser_roundtrip() {
    // Run the unparser regression queries against every dialect, both with
    // and without optimization before unparsing.
    for (dialect_name, dialect) in all_dialects() {
        for optimize in [false, true] {
            let suite = if optimize {
                format!("Regression ({dialect_name}, optimized)")
            } else {
                format!("Regression ({dialect_name})")
            };
            run_roundtrip_tests(
                &suite,
                regression_queries(),
                regression_test_context,
                dialect.as_ref(),
                optimize,
            )
            .await;
        }
    }
}
62 changes: 60 additions & 2 deletions datafusion/sql/src/unparser/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ use datafusion_expr::{
LogicalPlanBuilder, Operator, Projection, SortExpr, TableScan, Unnest,
UserDefinedLogicalNode, expr::Alias,
};
use sqlparser::ast::{self, Ident, OrderByKind, SetExpr, TableAliasColumnDef};
use sqlparser::ast::{self, Ident, OrderByKind, SetExpr, TableAliasColumnDef, VisitMut};
use std::ops::ControlFlow;
use std::{sync::Arc, vec};

/// Convert a DataFusion [`LogicalPlan`] to [`ast::Statement`]
Expand Down Expand Up @@ -413,8 +414,21 @@ impl Unparser<'_> {
};
// Projection can be top-level plan for derived table
if select.already_projected() {
let alias = "derived_projection";
if self.dialect.requires_derived_table_alias() {
// When wrapping this projection as a derived subquery
// with an alias, strip qualifiers from the outer
// SELECT items so they don't reference tables that are
// now hidden inside the subquery.
// Unqualified column names are valid because the derived
// table is the only FROM source.
let items = select.pop_projections();
let dequalified =
items.into_iter().map(dequalify_select_item).collect();
select.projection(dequalified);
}
return self.derive_with_dialect_alias(
"derived_projection",
alias,
plan,
relation,
unnest_input_type
Expand Down Expand Up @@ -1415,6 +1429,50 @@ impl Unparser<'_> {
}
}

/// Strips table qualifiers from column references in a [`ast::SelectItem`].
///
/// When a projection is wrapped in a derived subquery (e.g. `(SELECT ...) AS alias`),
/// column references in the outer SELECT that used the original table qualifier
/// become invalid because those tables are now hidden inside the subquery.
/// Stripping the qualifier makes them unqualified, which is valid because the
/// derived table is the sole FROM source.
///
/// For example, `"base"."name"` becomes `"name"`.
///
/// This only affects [`ast::Expr::CompoundIdentifier`] nodes with 2+ parts,
/// keeping only the final column name identifier.
///
/// Safety: this strips *all* qualifiers, not just stale ones, so it is only
/// correct when the derived table is the sole FROM source (no JOINs in the
/// outer query). The `already_projected()` guard in the caller ensures this.
fn dequalify_select_item(mut item: ast::SelectItem) -> ast::SelectItem {
    struct Dequalifier;

    impl ast::VisitorMut for Dequalifier {
        type Break = ();

        fn pre_visit_expr(&mut self, expr: &mut ast::Expr) -> ControlFlow<Self::Break> {
            if let ast::Expr::CompoundIdentifier(idents) = expr
                && idents.len() >= 2
            {
                // `len() >= 2` guarantees `pop()` succeeds; popping moves the
                // final (column-name) identifier out instead of cloning it.
                let col_name = idents.pop().expect("len() >= 2 was just checked");
                *expr = ast::Expr::Identifier(col_name);
            }
            ControlFlow::Continue(())
        }
    }

    let mut visitor = Dequalifier;
    match &mut item {
        ast::SelectItem::UnnamedExpr(expr)
        | ast::SelectItem::ExprWithAlias { expr, .. } => {
            // The visitor never produces `Break`, so the traversal result
            // can be safely ignored.
            let _ = expr.visit(&mut visitor);
        }
        // Wildcards carry no explicit column references to dequalify.
        ast::SelectItem::QualifiedWildcard(..) | ast::SelectItem::Wildcard(..) => {}
    }
    item
}

impl From<BuilderError> for DataFusionError {
fn from(e: BuilderError) -> Self {
DataFusionError::External(Box::new(e))
Expand Down
4 changes: 2 additions & 2 deletions datafusion/sql/tests/cases/plan_to_sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2634,14 +2634,14 @@ fn test_unparse_window() -> Result<()> {
let sql = unparser.plan_to_sql(&plan)?;
assert_snapshot!(
sql,
@r#"SELECT "test"."k", "test"."v", "rank() PARTITION BY [test.k] ORDER BY [test.v ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING" FROM (SELECT "test"."k" AS "k", "test"."v" AS "v", rank() OVER (PARTITION BY "test"."k" ORDER BY "test"."v" ASC NULLS FIRST ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS "rank() PARTITION BY [test.k] ORDER BY [test.v ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING" FROM "test") AS "test" WHERE ("rank() PARTITION BY [test.k] ORDER BY [test.v ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING" = 1)"#
@r#"SELECT "k", "v", "rank() PARTITION BY [test.k] ORDER BY [test.v ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING" FROM (SELECT "test"."k" AS "k", "test"."v" AS "v", rank() OVER (PARTITION BY "test"."k" ORDER BY "test"."v" ASC NULLS FIRST ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS "rank() PARTITION BY [test.k] ORDER BY [test.v ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING" FROM "test") AS "test" WHERE ("rank() PARTITION BY [test.k] ORDER BY [test.v ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING" = 1)"#
);

let unparser = Unparser::new(&UnparserMySqlDialect {});
let sql = unparser.plan_to_sql(&plan)?;
assert_snapshot!(
sql,
@"SELECT `test`.`k`, `test`.`v`, `rank() PARTITION BY [test.k] ORDER BY [test.v ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING` FROM (SELECT `test`.`k` AS `k`, `test`.`v` AS `v`, rank() OVER (PARTITION BY `test`.`k` ORDER BY `test`.`v` ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS `rank() PARTITION BY [test.k] ORDER BY [test.v ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING` FROM `test`) AS `test` WHERE (`rank() PARTITION BY [test.k] ORDER BY [test.v ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING` = 1)"
@"SELECT `k`, `v`, `rank() PARTITION BY [test.k] ORDER BY [test.v ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING` FROM (SELECT `test`.`k` AS `k`, `test`.`v` AS `v`, rank() OVER (PARTITION BY `test`.`k` ORDER BY `test`.`v` ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS `rank() PARTITION BY [test.k] ORDER BY [test.v ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING` FROM `test`) AS `test` WHERE (`rank() PARTITION BY [test.k] ORDER BY [test.v ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING` = 1)"
);

let unparser = Unparser::new(&SqliteDialect {});
Expand Down
Loading