diff --git a/datafusion/core/tests/execution/logical_plan.rs b/datafusion/core/tests/execution/logical_plan.rs index 3eaa3fb2ed5e6..6104cec03f1a9 100644 --- a/datafusion/core/tests/execution/logical_plan.rs +++ b/datafusion/core/tests/execution/logical_plan.rs @@ -71,6 +71,7 @@ async fn count_only_nulls() -> Result<()> { order_by: vec![], null_treatment: None, }, + spans: Spans::new(), })], )?); diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 7e4308976169d..3bfa7863256e1 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -930,6 +930,8 @@ pub struct ScalarFunction { pub func: Arc, /// List of expressions to feed to the functions as arguments pub args: Vec, + /// Original source code location, if known + pub spans: Spans, } impl ScalarFunction { @@ -937,6 +939,11 @@ impl ScalarFunction { pub fn name(&self) -> &str { self.func.name() } + + /// Returns a mutable reference to the spans + pub fn spans_mut(&mut self) -> &mut Spans { + &mut self.spans + } } impl ScalarFunction { @@ -944,7 +951,11 @@ impl ScalarFunction { /// /// [`ScalarUDF`]: crate::ScalarUDF pub fn new_udf(udf: Arc, args: Vec) -> Self { - Self { func: udf, args } + Self { + func: udf, + args, + spans: Spans::new(), + } } } @@ -1094,6 +1105,8 @@ pub struct AggregateFunction { /// Name of the function pub func: Arc, pub params: AggregateFunctionParams, + /// Original source code location, if known + pub spans: Spans, } #[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] @@ -1127,8 +1140,14 @@ impl AggregateFunction { order_by, null_treatment, }, + spans: Spans::new(), } } + + /// Returns a mutable reference to the spans + pub fn spans_mut(&mut self) -> &mut Spans { + &mut self.spans + } } /// A function used as a SQL window function @@ -2287,6 +2306,8 @@ impl Expr { match self { Expr::Column(col) => Some(&col.spans), Expr::Not(inner) | Expr::Negative(inner) => inner.spans(), + Expr::ScalarFunction(func) => Some(&func.spans), + Expr::AggregateFunction(func) => Some(&func.spans), _ => None, } } @@ -2482,10 +2503,12 @@ impl NormalizeEq for Expr { Expr::ScalarFunction(ScalarFunction { func: self_func, args: self_args, + .. }), Expr::ScalarFunction(ScalarFunction { func: other_func, args: other_args, + .. }), ) => { self_func.name() == other_func.name() @@ -2506,6 +2529,7 @@ impl NormalizeEq for Expr { order_by: self_order_by, null_treatment: self_null_treatment, }, + .. }), Expr::AggregateFunction(AggregateFunction { func: other_func, @@ -2517,6 +2541,7 @@ impl NormalizeEq for Expr { order_by: other_order_by, null_treatment: other_null_treatment, }, + .. }), ) => { self_func.name() == other_func.name() @@ -2797,7 +2822,9 @@ impl HashNode for Expr { | Expr::TryCast(TryCast { expr: _expr, field }) => { field.hash(state); } - Expr::ScalarFunction(ScalarFunction { func, args: _args }) => { + Expr::ScalarFunction(ScalarFunction { + func, args: _args, .. + }) => { func.hash(state); } Expr::AggregateFunction(AggregateFunction { @@ -2810,6 +2837,7 @@ impl HashNode for Expr { order_by: _, null_treatment, }, + .. }) => { func.hash(state); distinct.hash(state); @@ -2952,7 +2980,7 @@ impl Display for SchemaDisplay<'_> { | Expr::OuterReferenceColumn(..) | Expr::Placeholder(_) | Expr::Wildcard { .. } => write!(f, "{}", self.0), - Expr::AggregateFunction(AggregateFunction { func, params }) => { + Expr::AggregateFunction(AggregateFunction { func, params, .. }) => { match func.schema_name(params) { Ok(name) => { write!(f, "{name}") @@ -3109,7 +3137,7 @@ impl Display for SchemaDisplay<'_> { Expr::Unnest(Unnest { expr }) => { write!(f, "UNNEST({})", SchemaDisplay(expr)) } - Expr::ScalarFunction(ScalarFunction { func, args }) => { + Expr::ScalarFunction(ScalarFunction { func, args, .. }) => { match func.schema_name(args) { Ok(name) => { write!(f, "{name}") @@ -3408,7 +3436,7 @@ impl Display for SqlDisplay<'_> { Ok(()) } - Expr::AggregateFunction(AggregateFunction { func, params }) => { + Expr::AggregateFunction(AggregateFunction { func, params, .. }) => { match func.human_display(params) { Ok(name) => { write!(f, "{name}") @@ -3639,7 +3667,7 @@ impl Display for Expr { } } } - Expr::AggregateFunction(AggregateFunction { func, params }) => { + Expr::AggregateFunction(AggregateFunction { func, params, .. }) => { match func.display_name(params) { Ok(name) => { write!(f, "{name}") @@ -4399,6 +4427,7 @@ mod test { let udf = Expr::ScalarFunction(ScalarFunction { func: Arc::new(ScalarUDF::new_from_impl(TestUDF {})), args: vec![expr(), expr()], + spans: Spans::new(), }); let Expr::ScalarFunction(scalar) = &udf else { unreachable!() diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 039bbad65a660..0fb1b2915f365 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -33,8 +33,8 @@ use arrow::datatypes::FieldRef; use arrow::datatypes::{DataType, Field}; use datafusion_common::datatype::FieldExt; use datafusion_common::{ - Column, DataFusionError, ExprSchema, Result, ScalarValue, Spans, TableReference, - not_impl_err, plan_datafusion_err, plan_err, + Column, DataFusionError, Diagnostic, ExprSchema, Result, ScalarValue, Span, Spans, + TableReference, not_impl_err, plan_datafusion_err, plan_err, }; use datafusion_expr_common::type_coercion::binary::BinaryTypeCoercer; use datafusion_functions_window_common::field::WindowUDFFieldArgs; @@ -549,13 +549,13 @@ impl ExprSchemable for Expr { match fun { WindowFunctionDefinition::AggregateUDF(udaf) => { let new_fields = - verify_function_arguments(udaf.as_ref(), &fields)?; + verify_function_arguments(udaf.as_ref(), &fields, None)?; let return_field = udaf.return_field(&new_fields)?; Ok(return_field) } WindowFunctionDefinition::WindowUDF(udwf) => { let new_fields = - verify_function_arguments(udwf.as_ref(), &fields)?; + verify_function_arguments(udwf.as_ref(), &fields, None)?; let return_field = udwf .field(WindowUDFFieldArgs::new(&new_fields, &schema_name))?; Ok(return_field) @@ -565,20 +565,23 @@ impl ExprSchemable for Expr { Expr::AggregateFunction(AggregateFunction { func, params: AggregateFunctionParams { args, .. }, + spans, }) => { let fields = args .iter() .map(|e| e.to_field(schema).map(|(_, f)| f)) .collect::>>()?; - let new_fields = verify_function_arguments(func.as_ref(), &fields)?; + let new_fields = + verify_function_arguments(func.as_ref(), &fields, spans.first())?; func.return_field(&new_fields) } - Expr::ScalarFunction(ScalarFunction { func, args }) => { + Expr::ScalarFunction(ScalarFunction { func, args, spans }) => { let fields = args .iter() .map(|e| e.to_field(schema).map(|(_, f)| f)) .collect::>>()?; - let new_fields = verify_function_arguments(func.as_ref(), &fields)?; + let new_fields = + verify_function_arguments(func.as_ref(), &fields, spans.first())?; let arguments = args .iter() @@ -720,6 +723,7 @@ impl ExprSchemable for Expr { fn verify_function_arguments( function: &F, input_fields: &[FieldRef], + func_span: Option, ) -> Result> { fields_with_udf(input_fields, function).map_err(|err| { let data_types = input_fields @@ -727,18 +731,38 @@ fn verify_function_arguments( .map(|f| f.data_type()) .cloned() .collect::>(); - plan_datafusion_err!( - "{}. {}", - match err { - DataFusionError::Plan(msg) => msg, - err => err.to_string(), - }, - utils::generate_signature_error_message( - function.name(), - function.signature(), - &data_types - ) + let name = function.name(); + let signature_msg = utils::generate_signature_error_message( + name, + function.signature(), + &data_types, + ); + let err_msg = match err { + DataFusionError::Plan(msg) => msg, + err => err.to_string(), + }; + + let types_str = data_types + .iter() + .map(|dt| dt.to_string()) + .collect::>() + .join(", "); + let candidates = function + .signature() + .type_signature + .to_string_repr_with_names(function.signature().parameter_names.as_deref()) + .iter() + .map(|args_str| format!("{name}({args_str})")) + .collect::>() + .join(", "); + let diagnostic = Diagnostic::new_error( + format!("invalid argument type(s) for '{name}'"), + func_span, ) + .with_note(format!("called with argument type(s): {types_str}"), None) + .with_help(format!("candidate function(s): {candidates}"), None); + + plan_datafusion_err!("{err_msg}. {signature_msg}").with_diagnostic(diagnostic) }) } diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs index 010441b5a25d1..10f7818c366e2 100644 --- a/datafusion/expr/src/tree_node.rs +++ b/datafusion/expr/src/tree_node.rs @@ -253,11 +253,13 @@ impl TreeNode for Expr { Expr::TryCast(TryCast { expr, field }) => expr .map_elements(f)? .update_data(|be| Expr::TryCast(TryCast::new_from_field(be, field))), - Expr::ScalarFunction(ScalarFunction { func, args }) => { + Expr::ScalarFunction(ScalarFunction { func, args, spans }) => { args.map_elements(f)?.map_data(|new_args| { - Ok(Expr::ScalarFunction(ScalarFunction::new_udf( - func, new_args, - ))) + Ok(Expr::ScalarFunction(ScalarFunction { + func, + args: new_args, + spans, + })) })? } Expr::WindowFunction(window_fun) => { @@ -304,16 +306,20 @@ impl TreeNode for Expr { order_by, null_treatment, }, + spans, }) => (args, filter, order_by).map_elements(f)?.map_data( |(new_args, new_filter, new_order_by)| { - Ok(Expr::AggregateFunction(AggregateFunction::new_udf( + Ok(Expr::AggregateFunction(AggregateFunction { func, - new_args, - distinct, - new_filter, - new_order_by, - null_treatment, - ))) + params: AggregateFunctionParams { + args: new_args, + distinct, + filter: new_filter, + order_by: new_order_by, + null_treatment, + }, + spans, + })) }, )?, Expr::GroupingSet(grouping_set) => match grouping_set { diff --git a/datafusion/functions-aggregate/src/planner.rs b/datafusion/functions-aggregate/src/planner.rs index 8a6d9b9bb1e9f..48736706473c3 100644 --- a/datafusion/functions-aggregate/src/planner.rs +++ b/datafusion/functions-aggregate/src/planner.rs @@ -17,7 +17,7 @@ //! SQL planning extensions like [`AggregateFunctionPlanner`] -use datafusion_common::Result; +use datafusion_common::{Result, Spans}; use datafusion_expr::{ Expr, expr::{AggregateFunction, AggregateFunctionParams}, @@ -52,6 +52,7 @@ impl ExprPlanner for AggregateFunctionPlanner { order_by, null_treatment, }, + spans: Spans::new(), }); let saved_name = NamePreserver::new_for_projection().save(&origin_expr); @@ -66,6 +67,7 @@ impl ExprPlanner for AggregateFunctionPlanner { order_by, null_treatment, }, + .. }) = origin_expr else { unreachable!("") diff --git a/datafusion/functions-nested/src/array_has.rs b/datafusion/functions-nested/src/array_has.rs index 04818258f040b..355a3d2a044c9 100644 --- a/datafusion/functions-nested/src/array_has.rs +++ b/datafusion/functions-nested/src/array_has.rs @@ -163,7 +163,7 @@ impl ScalarUDFImpl for ArrayHas { ))); } } - Expr::ScalarFunction(ScalarFunction { func, args }) + Expr::ScalarFunction(ScalarFunction { func, args, .. }) if func == &make_array_udf() => { // make_array has a static set of arguments, so we can pull the arguments out from it diff --git a/datafusion/functions-nested/src/extract.rs b/datafusion/functions-nested/src/extract.rs index 202a76bd0b035..e37c3787552a0 100644 --- a/datafusion/functions-nested/src/extract.rs +++ b/datafusion/functions-nested/src/extract.rs @@ -1162,6 +1162,7 @@ mod tests { Expr::Column(Column::new_unqualified("my_array")), Expr::Column(Column::new_unqualified("my_index")), ], + spans: datafusion_common::Spans::new(), }); assert_eq!( ExprSchemable::get_type(&udf_expr, &schema).unwrap(), diff --git a/datafusion/functions-nested/src/planner.rs b/datafusion/functions-nested/src/planner.rs index e96fdb7d4baca..e5590613c18de 100644 --- a/datafusion/functions-nested/src/planner.rs +++ b/datafusion/functions-nested/src/planner.rs @@ -154,6 +154,7 @@ impl ExprPlanner for FieldAccessPlanner { order_by, null_treatment, }, + .. }) if is_array_agg(&func) => Ok(PlannerResult::Planned( Expr::AggregateFunction(AggregateFunction::new_udf( nth_value_udaf(), diff --git a/datafusion/functions/src/core/getfield.rs b/datafusion/functions/src/core/getfield.rs index 93a4cddef453e..60d8639152c12 100644 --- a/datafusion/functions/src/core/getfield.rs +++ b/datafusion/functions/src/core/getfield.rs @@ -307,6 +307,7 @@ fn simplify_get_field_over_struct_constructor(args: &[Expr]) -> Option { let Expr::ScalarFunction(ScalarFunction { func, args: ctor_args, + .. }) = base else { return None; @@ -607,6 +608,7 @@ impl ScalarUDFImpl for GetFieldFunc { if let Expr::ScalarFunction(ScalarFunction { func, args: inner_args, + .. }) = current_expr && func.inner().is::() { diff --git a/datafusion/functions/src/math/log.rs b/datafusion/functions/src/math/log.rs index 2ca2ed1b572be..892cec76f0c3d 100644 --- a/datafusion/functions/src/math/log.rs +++ b/datafusion/functions/src/math/log.rs @@ -367,7 +367,7 @@ impl ScalarUDFImpl for LogFunc { &info.get_data_type(&base)?, )?))) } - Expr::ScalarFunction(ScalarFunction { func, mut args }) + Expr::ScalarFunction(ScalarFunction { func, mut args, .. }) if is_pow(&func) && args.len() == 2 && base == args[0] => { let b = args.pop().unwrap(); // length checked above diff --git a/datafusion/functions/src/math/power.rs b/datafusion/functions/src/math/power.rs index 252a3ea0b31d7..246401a853758 100644 --- a/datafusion/functions/src/math/power.rs +++ b/datafusion/functions/src/math/power.rs @@ -174,7 +174,7 @@ impl ScalarUDFImpl for PowerFunc { base, &base_type, ))) } - Expr::ScalarFunction(ScalarFunction { func, mut args }) + Expr::ScalarFunction(ScalarFunction { func, mut args, .. }) if is_log(&func) && args.len() == 2 && base == args[0] => { let b = args.pop().unwrap(); // length checked above diff --git a/datafusion/functions/src/string/concat.rs b/datafusion/functions/src/string/concat.rs index b10db23472c99..43a4f1b0286a6 100644 --- a/datafusion/functions/src/string/concat.rs +++ b/datafusion/functions/src/string/concat.rs @@ -27,7 +27,7 @@ use crate::strings::{ }; use datafusion_common::cast::{as_binary_array, as_string_array, as_string_view_array}; use datafusion_common::{ - Result, ScalarValue, exec_datafusion_err, internal_err, plan_err, + Result, ScalarValue, Spans, exec_datafusion_err, internal_err, plan_err, }; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext}; @@ -384,6 +384,7 @@ pub(crate) fn simplify_concat(args: Vec) -> Result { ScalarFunction { func: concat(), args: new_args, + spans: Spans::new(), }, ))) } else { diff --git a/datafusion/functions/src/string/concat_ws.rs b/datafusion/functions/src/string/concat_ws.rs index 2c2d4bd42165b..8ac433826d046 100644 --- a/datafusion/functions/src/string/concat_ws.rs +++ b/datafusion/functions/src/string/concat_ws.rs @@ -30,7 +30,7 @@ use crate::strings::{ use datafusion_common::cast::{ as_large_string_array, as_string_array, as_string_view_array, }; -use datafusion_common::{Result, ScalarValue, exec_err, internal_err, plan_err}; +use datafusion_common::{Result, ScalarValue, Spans, exec_err, internal_err, plan_err}; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext}; use datafusion_expr::{ColumnarValue, Documentation, Expr, Volatility, lit}; @@ -431,6 +431,7 @@ fn simplify_concat_ws(delimiter: &Expr, args: &[Expr]) -> Result Result { let case = coerce_case_expression(case, self.schema)?; Ok(Transformed::yes(Expr::Case(case))) } - Expr::ScalarFunction(ScalarFunction { func, args }) => { + Expr::ScalarFunction(ScalarFunction { func, args, .. }) => { let new_expr = coerce_arguments_for_signature(args, self.schema, func.as_ref())?; Ok(Transformed::yes(Expr::ScalarFunction( @@ -712,6 +712,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { order_by, null_treatment, }, + .. }) => { let new_expr = coerce_arguments_for_signature(args, self.schema, func.as_ref())?; diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 2775d62144c56..4803ea39c236a 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -654,7 +654,7 @@ impl CSEController for ExprCSEController<'_> { // In case of `ScalarFunction`s and `HigherOrderFunction`s we don't know which children are surely // executed so start visiting all children conditionally and stop the // recursion with `TreeNodeRecursion::Jump`. - Expr::ScalarFunction(ScalarFunction { func, args }) => { + Expr::ScalarFunction(ScalarFunction { func, args, .. }) => { func.conditional_arguments(args) } Expr::HigherOrderFunction(HigherOrderFunction { func, args }) => { diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 39c8541b51b2f..c73fc134a1b2a 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -1608,17 +1608,20 @@ impl TreeNodeRewriter for Simplifier<'_> { .not(), ) } - Expr::ScalarFunction(ScalarFunction { func: udf, args }) => { - match udf.simplify(args, info)? { - ExprSimplifyResult::Original(args) => { - Transformed::no(Expr::ScalarFunction(ScalarFunction { - func: udf, - args, - })) - } - ExprSimplifyResult::Simplified(expr) => Transformed::yes(expr), + Expr::ScalarFunction(ScalarFunction { + func: udf, + args, + spans, + }) => match udf.simplify(args, info)? { + ExprSimplifyResult::Original(args) => { + Transformed::no(Expr::ScalarFunction(ScalarFunction { + func: udf, + args, + spans, + })) } - } + ExprSimplifyResult::Simplified(expr) => Transformed::yes(expr), + }, Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction { ref func, @@ -2108,7 +2111,7 @@ fn get_preimage( right_expr: &Expr, info: &SimplifyContext, ) -> Result { - let Expr::ScalarFunction(ScalarFunction { func, args }) = left_expr else { + let Expr::ScalarFunction(ScalarFunction { func, args, .. }) = left_expr else { return Ok(PreimageResult::None); }; if !is_literal_or_literal_cast(right_expr) { diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index 00c8fab228117..a3121c936ec0c 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -76,6 +76,7 @@ fn is_single_distinct_agg(aggr_expr: &[Expr]) -> Result { order_by, null_treatment: _, }, + .. }) = expr { if filter.is_some() || !order_by.is_empty() { @@ -190,6 +191,7 @@ impl OptimizerRule for SingleDistinctToGroupBy { order_by, null_treatment, }, + .. }) => { if distinct { assert_eq_or_internal_err!( diff --git a/datafusion/physical-expr/src/aggregate.rs b/datafusion/physical-expr/src/aggregate.rs index e5d55aba4f51c..5ecdfc031f8ff 100644 --- a/datafusion/physical-expr/src/aggregate.rs +++ b/datafusion/physical-expr/src/aggregate.rs @@ -504,6 +504,7 @@ impl<'a> LoweredAggregateBuilder<'a> { order_by, null_treatment, }, + .. }) = &expr else { return internal_err!("Invalid aggregate expression '{expr:?}'"); diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index d0d0508a106a5..7cdfc3a278d4d 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -336,7 +336,7 @@ pub fn create_physical_expr( input_dfschema, execution_props, )?), - Expr::ScalarFunction(ScalarFunction { func, args }) => { + Expr::ScalarFunction(ScalarFunction { func, args, .. }) => { let physical_args = create_physical_exprs(args, input_dfschema, execution_props)?; let config_options = match execution_props.config_options.as_ref() { diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 516aca4094451..48904317e6a5c 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -375,6 +375,7 @@ pub fn serialize_expr( order_by, null_treatment, }, + .. }) => { let mut buf = Vec::new(); let _ = codec.try_encode_udaf(func, &mut buf); @@ -402,7 +403,7 @@ pub fn serialize_expr( "Proto serialization error: Scalar Variable not supported".to_string(), )); } - Expr::ScalarFunction(ScalarFunction { func, args }) => { + Expr::ScalarFunction(ScalarFunction { func, args, .. }) => { let mut buf = Vec::new(); let _ = codec.try_encode_udf(func, &mut buf); protobuf::LogicalExprNode { diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 701485eee733c..70003441bac95 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -348,7 +348,12 @@ impl SqlToRel<'_, S> { }; // After resolution, all arguments are positional - let inner = ScalarFunction::new_udf(fm, resolved_args); + let mut inner = ScalarFunction::new_udf(fm, resolved_args); + if self.options.collect_spans + && let Some(span) = Span::try_from_sqlparser_span(sql_parser_span) + { + inner.spans_mut().add_span(span); + } if name.eq_ignore_ascii_case(inner.name()) { return Ok(Expr::ScalarFunction(inner)); @@ -844,7 +849,7 @@ impl SqlToRel<'_, S> { null_treatment, } = aggregate_expr; - let inner = expr::AggregateFunction::new_udf( + let mut inner = expr::AggregateFunction::new_udf( func, args, distinct, @@ -852,6 +857,11 @@ impl SqlToRel<'_, S> { order_by, null_treatment, ); + if self.options.collect_spans + && let Some(span) = Span::try_from_sqlparser_span(sql_parser_span) + { + inner.spans_mut().add_span(span); + } if name.eq_ignore_ascii_case(inner.func.name()) { return Ok(Expr::AggregateFunction(inner)); diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index bcc46e837bba2..43695b4de2ca6 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -119,7 +119,7 @@ impl Unparser<'_> { negated: *negated, }) } - Expr::ScalarFunction(ScalarFunction { func, args }) => { + Expr::ScalarFunction(ScalarFunction { func, args, .. }) => { let func_name = func.name(); if let Some(expr) = self @@ -3257,6 +3257,7 @@ mod tests { Expr::Cast(Cast::new(Box::new(col("a")), DataType::Float64)), Expr::Literal(ScalarValue::Int64(Some(2)), None), ], + spans: Spans::new(), }); let ast = unparser.expr_to_sql(&expr)?; @@ -3332,6 +3333,7 @@ mod tests { let expr = Expr::ScalarFunction(ScalarFunction { func: Arc::new(ScalarUDF::from(FromUnixtimeFunc::new())), args: vec![col("date_col")], + spans: Spans::new(), }); let ast = unparser.expr_to_sql(&expr)?; @@ -3416,6 +3418,7 @@ mod tests { Expr::Literal(ScalarValue::Utf8(Some(precision.to_string())), None), col("date_col"), ], + spans: Spans::new(), }); let ast = unparser.expr_to_sql(&expr)?; @@ -3674,6 +3677,7 @@ mod tests { datafusion_functions::datetime::date_part::DatePartFunc::new(), )), args: vec![lit("YEAR"), col("date_col")], + spans: Spans::new(), }); let actual = format!("{}", unparser.expr_to_sql(&expr)?); assert_eq!(actual, "EXTRACT(YEAR FROM `date_col`)"); diff --git a/datafusion/sql/tests/cases/diagnostic.rs b/datafusion/sql/tests/cases/diagnostic.rs index df46a48d88579..69996db353d6d 100644 --- a/datafusion/sql/tests/cases/diagnostic.rs +++ b/datafusion/sql/tests/cases/diagnostic.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use datafusion_expr::test::function_stub::sum_udaf; use datafusion_functions::string; use insta::assert_snapshot; use std::{collections::HashMap, ops::ControlFlow, sync::Arc}; @@ -33,6 +34,27 @@ use regex::Regex; use crate::{MockContextProvider, MockSessionState}; +fn do_query_with_state(sql: &'static str, state: MockSessionState) -> Diagnostic { + let statement = DFParserBuilder::new(sql) + .build() + .expect("unable to create parser") + .parse_statement() + .expect("unable to parse query"); + let options = ParserOptions { + collect_spans: true, + ..ParserOptions::default() + }; + let context = MockContextProvider { state }; + let sql_to_rel = SqlToRel::new_with_options(&context, options); + match sql_to_rel.statement_to_plan(statement) { + Ok(_) => panic!("expected error"), + Err(err) => match err.diagnostic() { + Some(diag) => diag.clone(), + None => panic!("expected diagnostic, but got: {err}"), + }, + } +} + fn do_query(sql: &'static str) -> Diagnostic { let statement = DFParserBuilder::new(sql) .build() @@ -481,6 +503,7 @@ fn test_syntax_error() -> Result<()> { } } + #[test] fn test_eq_null_warning_in_where() -> Result<()> { let query = "SELECT * FROM person WHERE /*cmp*/first_name = /*null*/NULL/*null+cmp*/"; @@ -671,3 +694,16 @@ fn test_multiple_null_comparison_warnings() -> Result<()> { ); Ok(()) } + +#[test] +fn test_invalid_aggregate_function_argument_types() -> Result<()> { + let state = MockSessionState::default().with_aggregate_function(sum_udaf()); + let query = "SELECT /*a*/sum/*a*/(first_name) FROM person"; + let spans = get_spans(query); + let diag = do_query_with_state(query, state); + assert_snapshot!(diag.message, @"invalid argument type(s) for 'sum'"); + assert_eq!(diag.span, Some(spans["a"])); + assert_snapshot!(diag.notes[0].message, @"called with argument type(s): Utf8"); + assert_snapshot!(diag.helps[0].message, @"candidate function(s): sum(UserDefined)"); + Ok(()) +}