Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions datafusion/sql/src/unparser/dialect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -416,9 +416,11 @@ impl PostgreSqlDialect {
};

Ok(Some(ast::Expr::AnyOp {
left: Box::new(unparser.expr_to_sql(needle)?),
// Recurse through the annotated entry point so the stack-growth
// protection engages on nested arguments; see issue #23056.
left: Box::new(unparser.expr_to_sql_with_nesting(needle)?),
compare_op: BinaryOperator::Eq,
right: Box::new(unparser.expr_to_sql(haystack)?),
right: Box::new(unparser.expr_to_sql_with_nesting(haystack)?),
is_some: false,
}))
}
Expand Down
93 changes: 86 additions & 7 deletions datafusion/sql/src/unparser/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ use std::vec;

use super::Unparser;
use super::dialect::IntervalStyle;
use crate::stack::StackGuard;
use arrow::array::{
ArrayRef, Date32Array, Date64Array, PrimitiveArray,
types::{
Expand Down Expand Up @@ -94,6 +95,26 @@ const IS: &BinaryOperator = &BinaryOperator::BitwiseAnd;

impl Unparser<'_> {
pub fn expr_to_sql(&self, expr: &Expr) -> Result<ast::Expr> {
// Unparsing recurses once per nesting level. The function-argument and
// dialect scalar-function-override paths cost more per level than the
// default `recursive` red zone, so without raising the minimum stack
// size the stack-growing trampoline engages too late and the OS stack
// overflows on deeply nested expressions (issue #23056). The size
// mirrors the planner's `StackGuard` usage in `query.rs`.
let _guard = StackGuard::new(256 * 1024);
self.expr_to_sql_with_nesting(expr)
}

/// Recursive entry point shared by the public [`Self::expr_to_sql`] and the
/// internal recursion sites (scalar-function arguments, arrays, maps, and
/// dialect scalar-function overrides).
///
/// This carries the `recursive` annotation so every nesting level becomes a
/// stack-growth checkpoint. Internal recursion must call this rather than
/// the public [`Self::expr_to_sql`]: the public entry point is not
/// annotated and would re-install the [`StackGuard`] on every level.
#[cfg_attr(feature = "recursive_protection", recursive::recursive)]
pub(crate) fn expr_to_sql_with_nesting(&self, expr: &Expr) -> Result<ast::Expr> {
let mut root_expr = self.expr_to_sql_inner(expr)?;
if self.pretty {
root_expr = self.remove_unnecessary_nesting(root_expr, LOWEST, LOWEST);
Expand Down Expand Up @@ -660,7 +681,7 @@ impl Unparser<'_> {
fn make_array_to_sql(&self, args: &[Expr]) -> Result<ast::Expr> {
let args = args
.iter()
.map(|e| self.expr_to_sql(e))
.map(|e| self.expr_to_sql_with_nesting(e))
.collect::<Result<Vec<_>>>()?;
Ok(ast::Expr::Array(Array {
elem: args,
Expand All @@ -687,8 +708,8 @@ impl Unparser<'_> {
2,
"array_element must have exactly 2 arguments"
);
let array = self.expr_to_sql(&args[0])?;
let index = self.expr_to_sql(&args[1])?;
let array = self.expr_to_sql_with_nesting(&args[0])?;
let index = self.expr_to_sql_with_nesting(&args[1])?;
Ok(ast::Expr::CompoundFieldAccess {
root: Box::new(array),
access_chain: vec![ast::AccessExpr::Subscript(Subscript::Index { index })],
Expand All @@ -711,7 +732,7 @@ impl Unparser<'_> {

Ok(ast::DictionaryField {
key,
value: Box::new(self.expr_to_sql(&chunk[1])?),
value: Box::new(self.expr_to_sql_with_nesting(&chunk[1])?),
})
})
.collect::<Result<Vec<_>>>()?;
Expand Down Expand Up @@ -781,15 +802,17 @@ impl Unparser<'_> {
fn map_to_sql(&self, args: &[Expr]) -> Result<ast::Expr> {
assert_eq_or_internal_err!(args.len(), 2, "map must have exactly 2 arguments");

let ast::Expr::Array(Array { elem: keys, .. }) = self.expr_to_sql(&args[0])?
let ast::Expr::Array(Array { elem: keys, .. }) =
self.expr_to_sql_with_nesting(&args[0])?
else {
return internal_err!(
"map expects first argument to be an array, but received: {:?}",
&args[0]
);
};

let ast::Expr::Array(Array { elem: values, .. }) = self.expr_to_sql(&args[1])?
let ast::Expr::Array(Array { elem: values, .. }) =
self.expr_to_sql_with_nesting(&args[1])?
else {
return internal_err!(
"map expects second argument to be an array, but received: {:?}",
Expand Down Expand Up @@ -923,7 +946,7 @@ impl Unparser<'_> {
) {
Ok(ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Wildcard))
} else {
self.expr_to_sql(e)
self.expr_to_sql_with_nesting(e)
.map(|e| ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(e)))
}
})
Expand Down Expand Up @@ -968,6 +991,7 @@ impl Unparser<'_> {
///
/// Also note that when fetching the precedence of a nested expression, we ignore other nested
/// expressions, so precedence of expr `(a * (b + c))` equals `*` and not `+`.
#[cfg_attr(feature = "recursive_protection", recursive::recursive)]
fn remove_unnecessary_nesting(
&self,
expr: ast::Expr,
Expand Down Expand Up @@ -3278,6 +3302,61 @@ mod tests {
Ok(())
}

/// Regression test for https://github.com/apache/datafusion/issues/23056
///
/// Deeply-nested expressions whose unparse path routes through scalar
/// function arguments and dialect scalar-function overrides used to
/// overflow the OS stack even with `recursive_protection` enabled,
/// because the per-level stack cost of those paths exceeds the default
/// `recursive` red zone and the unparser installed no [`StackGuard`].
///
/// This test only asserts the protected behavior, so it is gated on the
/// `recursive_protection` feature. Without that feature the unparser is
/// not stack-safe by design and a deep enough expression will overflow.
#[cfg(feature = "recursive_protection")]
#[test]
fn test_deeply_nested_expr_does_not_overflow_stack() {
// Far deeper than the ~60 levels that overflow without protection, but
// bounded so the trampoline's heap stacks stay reasonable in debug.
const DEPTH: usize = 2_000;

// Run on an explicit, realistically-sized thread stack. The work is
// performed on a spawned thread so an overflow (in the unfixed code)
// aborts the process and fails the test deterministically rather than
// depending on the harness thread's stack size.
let handle = std::thread::Builder::new()
.stack_size(2 * 1024 * 1024)
.spawn(|| {
// 1. Linear chain through a dialect scalar-function override:
// array_has(array_has(... array_has(col, 'x') ...), 'x').
// PostgreSqlDialect unparses array_has via array_has_to_sql_any,
// which recurses back into the unparser for each argument.
let mut nested_fn: Expr = col("c");
for _ in 0..DEPTH {
nested_fn = array_has(nested_fn, lit("x"));
}
let pg = PostgreSqlDialect {};
Unparser::new(&pg)
.expr_to_sql(&nested_fn)
.expect("deeply nested scalar function should unparse");

// 2. Linear chain of plain binary operators, exercising the
// inner -> inner recursion on the default dialect.
let mut nested_binary: Expr = col("c");
for _ in 0..DEPTH {
nested_binary = nested_binary + lit(1);
}
Unparser::default()
.expr_to_sql(&nested_binary)
.expect("deeply nested binary expression should unparse");
})
.unwrap();

// If the unparser overflows, the process aborts and this join is never
// reached; otherwise the spawned thread returns cleanly.
handle.join().expect("unparsing thread should not panic");
}

#[test]
fn test_window_func_support_window_frame() -> Result<()> {
let default_dialect: Arc<dyn Dialect> =
Expand Down
4 changes: 2 additions & 2 deletions datafusion/sql/src/unparser/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,7 @@ pub(crate) fn date_part_to_sql(
) -> Result<Option<ast::Expr>> {
match (style, date_part_args.len()) {
(DateFieldExtractStyle::Extract, 2) => {
let date_expr = unparser.expr_to_sql(&date_part_args[1])?;
let date_expr = unparser.expr_to_sql_with_nesting(&date_part_args[1])?;
if let Expr::Literal(ScalarValue::Utf8(Some(field)), _) = &date_part_args[0] {
let field = match field.to_lowercase().as_str() {
"year" => ast::DateTimeField::Year,
Expand All @@ -468,7 +468,7 @@ pub(crate) fn date_part_to_sql(
}
}
(DateFieldExtractStyle::Strftime, 2) => {
let column = unparser.expr_to_sql(&date_part_args[1])?;
let column = unparser.expr_to_sql_with_nesting(&date_part_args[1])?;

if let Expr::Literal(ScalarValue::Utf8(Some(field)), _) = &date_part_args[0] {
let field = match field.to_lowercase().as_str() {
Expand Down
Loading