diff --git a/datafusion/sql/src/unparser/dialect.rs b/datafusion/sql/src/unparser/dialect.rs index d9344622405fc..5a545bfaa5319 100644 --- a/datafusion/sql/src/unparser/dialect.rs +++ b/datafusion/sql/src/unparser/dialect.rs @@ -416,9 +416,11 @@ impl PostgreSqlDialect { }; Ok(Some(ast::Expr::AnyOp { - left: Box::new(unparser.expr_to_sql(needle)?), + // Recurse through the annotated entry point so the stack-growth + // protection engages on nested arguments; see issue #23056. + left: Box::new(unparser.expr_to_sql_with_nesting(needle)?), compare_op: BinaryOperator::Eq, - right: Box::new(unparser.expr_to_sql(haystack)?), + right: Box::new(unparser.expr_to_sql_with_nesting(haystack)?), is_some: false, })) } diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index d83c6b6e13bb7..7ce0948c990f4 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -31,6 +31,7 @@ use std::vec; use super::Unparser; use super::dialect::IntervalStyle; +use crate::stack::StackGuard; use arrow::array::{ ArrayRef, Date32Array, Date64Array, PrimitiveArray, types::{ @@ -94,6 +95,26 @@ const IS: &BinaryOperator = &BinaryOperator::BitwiseAnd; impl Unparser<'_> { pub fn expr_to_sql(&self, expr: &Expr) -> Result { + // Unparsing recurses once per nesting level. The function-argument and + // dialect scalar-function-override paths cost more per level than the + // default `recursive` red zone, so without raising the minimum stack + // size the stack-growing trampoline engages too late and the OS stack + // overflows on deeply nested expressions (issue #23056). The size + // mirrors the planner's `StackGuard` usage in `query.rs`. + let _guard = StackGuard::new(256 * 1024); + self.expr_to_sql_with_nesting(expr) + } + + /// Recursive entry point shared by the public [`Self::expr_to_sql`] and the + /// internal recursion sites (scalar-function arguments, arrays, maps, and + /// dialect scalar-function overrides). + /// + /// This carries the `recursive` annotation so every nesting level becomes a + /// stack-growth checkpoint. Internal recursion must call this rather than + /// the public [`Self::expr_to_sql`]: the public entry point is not + /// annotated and would re-install the [`StackGuard`] on every level. + #[cfg_attr(feature = "recursive_protection", recursive::recursive)] + pub(crate) fn expr_to_sql_with_nesting(&self, expr: &Expr) -> Result { let mut root_expr = self.expr_to_sql_inner(expr)?; if self.pretty { root_expr = self.remove_unnecessary_nesting(root_expr, LOWEST, LOWEST); @@ -660,7 +681,7 @@ impl Unparser<'_> { fn make_array_to_sql(&self, args: &[Expr]) -> Result { let args = args .iter() - .map(|e| self.expr_to_sql(e)) + .map(|e| self.expr_to_sql_with_nesting(e)) .collect::>>()?; Ok(ast::Expr::Array(Array { elem: args, @@ -687,8 +708,8 @@ impl Unparser<'_> { 2, "array_element must have exactly 2 arguments" ); - let array = self.expr_to_sql(&args[0])?; - let index = self.expr_to_sql(&args[1])?; + let array = self.expr_to_sql_with_nesting(&args[0])?; + let index = self.expr_to_sql_with_nesting(&args[1])?; Ok(ast::Expr::CompoundFieldAccess { root: Box::new(array), access_chain: vec![ast::AccessExpr::Subscript(Subscript::Index { index })], @@ -711,7 +732,7 @@ impl Unparser<'_> { Ok(ast::DictionaryField { key, - value: Box::new(self.expr_to_sql(&chunk[1])?), + value: Box::new(self.expr_to_sql_with_nesting(&chunk[1])?), }) }) .collect::>>()?; @@ -781,7 +802,8 @@ impl Unparser<'_> { fn map_to_sql(&self, args: &[Expr]) -> Result { assert_eq_or_internal_err!(args.len(), 2, "map must have exactly 2 arguments"); - let ast::Expr::Array(Array { elem: keys, .. }) = self.expr_to_sql(&args[0])? + let ast::Expr::Array(Array { elem: keys, .. }) = + self.expr_to_sql_with_nesting(&args[0])? else { return internal_err!( "map expects first argument to be an array, but received: {:?}", @@ -789,7 +811,8 @@ impl Unparser<'_> { ); }; - let ast::Expr::Array(Array { elem: values, .. }) = self.expr_to_sql(&args[1])? + let ast::Expr::Array(Array { elem: values, .. }) = + self.expr_to_sql_with_nesting(&args[1])? else { return internal_err!( "map expects second argument to be an array, but received: {:?}", @@ -923,7 +946,7 @@ impl Unparser<'_> { ) { Ok(ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Wildcard)) } else { - self.expr_to_sql(e) + self.expr_to_sql_with_nesting(e) .map(|e| ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(e))) } }) @@ -968,6 +991,7 @@ impl Unparser<'_> { /// /// Also note that when fetching the precedence of a nested expression, we ignore other nested /// expressions, so precedence of expr `(a * (b + c))` equals `*` and not `+`. + #[cfg_attr(feature = "recursive_protection", recursive::recursive)] fn remove_unnecessary_nesting( &self, expr: ast::Expr, @@ -3278,6 +3302,61 @@ mod tests { Ok(()) } + /// Regression test for https://github.com/apache/datafusion/issues/23056 + /// + /// Deeply-nested expressions whose unparse path routes through scalar + /// function arguments and dialect scalar-function overrides used to + /// overflow the OS stack even with `recursive_protection` enabled, + /// because the per-level stack cost of those paths exceeds the default + /// `recursive` red zone and the unparser installed no [`StackGuard`]. + /// + /// This test only asserts the protected behavior, so it is gated on the + /// `recursive_protection` feature. Without that feature the unparser is + /// not stack-safe by design and a deep enough expression will overflow. + #[cfg(feature = "recursive_protection")] + #[test] + fn test_deeply_nested_expr_does_not_overflow_stack() { + // Far deeper than the ~60 levels that overflow without protection, but + // bounded so the trampoline's heap stacks stay reasonable in debug. + const DEPTH: usize = 2_000; + + // Run on an explicit, realistically-sized thread stack. The work is + // performed on a spawned thread so an overflow (in the unfixed code) + // aborts the process and fails the test deterministically rather than + // depending on the harness thread's stack size. + let handle = std::thread::Builder::new() + .stack_size(2 * 1024 * 1024) + .spawn(|| { + // 1. Linear chain through a dialect scalar-function override: + // array_has(array_has(... array_has(col, 'x') ...), 'x'). + // PostgreSqlDialect unparses array_has via array_has_to_sql_any, + // which recurses back into the unparser for each argument. + let mut nested_fn: Expr = col("c"); + for _ in 0..DEPTH { + nested_fn = array_has(nested_fn, lit("x")); + } + let pg = PostgreSqlDialect {}; + Unparser::new(&pg) + .expr_to_sql(&nested_fn) + .expect("deeply nested scalar function should unparse"); + + // 2. Linear chain of plain binary operators, exercising the + // inner -> inner recursion on the default dialect. + let mut nested_binary: Expr = col("c"); + for _ in 0..DEPTH { + nested_binary = nested_binary + lit(1); + } + Unparser::default() + .expr_to_sql(&nested_binary) + .expect("deeply nested binary expression should unparse"); + }) + .unwrap(); + + // If the unparser overflows, the process aborts and this join is never + // reached; otherwise the spawned thread returns cleanly. + handle.join().expect("unparsing thread should not panic"); + } + #[test] fn test_window_func_support_window_frame() -> Result<()> { let default_dialect: Arc = diff --git a/datafusion/sql/src/unparser/utils.rs b/datafusion/sql/src/unparser/utils.rs index 732e030b335d8..1cc023d1125f4 100644 --- a/datafusion/sql/src/unparser/utils.rs +++ b/datafusion/sql/src/unparser/utils.rs @@ -448,7 +448,7 @@ pub(crate) fn date_part_to_sql( ) -> Result> { match (style, date_part_args.len()) { (DateFieldExtractStyle::Extract, 2) => { - let date_expr = unparser.expr_to_sql(&date_part_args[1])?; + let date_expr = unparser.expr_to_sql_with_nesting(&date_part_args[1])?; if let Expr::Literal(ScalarValue::Utf8(Some(field)), _) = &date_part_args[0] { let field = match field.to_lowercase().as_str() { "year" => ast::DateTimeField::Year, @@ -468,7 +468,7 @@ pub(crate) fn date_part_to_sql( } } (DateFieldExtractStyle::Strftime, 2) => { - let column = unparser.expr_to_sql(&date_part_args[1])?; + let column = unparser.expr_to_sql_with_nesting(&date_part_args[1])?; if let Expr::Literal(ScalarValue::Utf8(Some(field)), _) = &date_part_args[0] { let field = match field.to_lowercase().as_str() {