Skip to content

Commit 91092a6

Browse files
committed
feat: fuse CheckOverflow with Cast and WideDecimalBinaryExpr
Eliminate redundant CheckOverflow when wrapping WideDecimalBinaryExpr (which already handles overflow). Fuse Cast(Decimal128→Decimal128) + CheckOverflow into a single DecimalRescaleCheckOverflow expression that rescales and validates precision in one pass.
1 parent d7495bd commit 91092a6

4 files changed

Lines changed: 466 additions & 6 deletions

File tree

native/core/src/execution/planner.rs

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -126,9 +126,9 @@ use datafusion_comet_proto::{
126126
use datafusion_comet_spark_expr::monotonically_increasing_id::MonotonicallyIncreasingId;
127127
use datafusion_comet_spark_expr::{
128128
ArrayInsert, Avg, AvgDecimal, Cast, CheckOverflow, Correlation, Covariance, CreateNamedStruct,
129-
GetArrayStructFields, GetStructField, IfExpr, ListExtract, NormalizeNaNAndZero, RandExpr,
130-
RandnExpr, SparkCastOptions, Stddev, SumDecimal, ToJson, UnboundColumn, Variance,
131-
WideDecimalBinaryExpr, WideDecimalOp,
129+
DecimalRescaleCheckOverflow, GetArrayStructFields, GetStructField, IfExpr, ListExtract,
130+
NormalizeNaNAndZero, RandExpr, RandnExpr, SparkCastOptions, Stddev, SumDecimal, ToJson,
131+
UnboundColumn, Variance, WideDecimalBinaryExpr, WideDecimalOp,
132132
};
133133
use itertools::Itertools;
134134
use jni::objects::GlobalRef;
@@ -377,10 +377,37 @@ impl PhysicalPlanner {
377377
)))
378378
}
379379
ExprStruct::CheckOverflow(expr) => {
380-
let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema)?;
380+
let child =
381+
self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&input_schema))?;
381382
let data_type = to_arrow_datatype(expr.datatype.as_ref().unwrap());
382383
let fail_on_error = expr.fail_on_error;
383384

385+
// WideDecimalBinaryExpr already handles overflow — skip redundant check
386+
if child
387+
.as_any()
388+
.downcast_ref::<WideDecimalBinaryExpr>()
389+
.is_some()
390+
{
391+
return Ok(child);
392+
}
393+
394+
// Fuse Cast(Decimal128→Decimal128) + CheckOverflow into single rescale+check
395+
if let Some(cast) = child.as_any().downcast_ref::<Cast>() {
396+
if let (
397+
DataType::Decimal128(p_out, s_out),
398+
Ok(DataType::Decimal128(_p_in, s_in)),
399+
) = (&data_type, cast.child.data_type(&input_schema))
400+
{
401+
return Ok(Arc::new(DecimalRescaleCheckOverflow::new(
402+
Arc::clone(&cast.child),
403+
s_in,
404+
*p_out,
405+
*s_out,
406+
fail_on_error,
407+
)));
408+
}
409+
}
410+
384411
Ok(Arc::new(CheckOverflow::new(
385412
child,
386413
data_type,

native/spark-expr/src/lib.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,8 @@ pub use json_funcs::{FromJson, ToJson};
8181
pub use math_funcs::{
8282
create_modulo_expr, create_negate_expr, spark_ceil, spark_decimal_div,
8383
spark_decimal_integral_div, spark_floor, spark_make_decimal, spark_round, spark_unhex,
84-
spark_unscaled_value, CheckOverflow, NegativeExpr, NormalizeNaNAndZero, WideDecimalBinaryExpr,
85-
WideDecimalOp,
84+
spark_unscaled_value, CheckOverflow, DecimalRescaleCheckOverflow, NegativeExpr,
85+
NormalizeNaNAndZero, WideDecimalBinaryExpr, WideDecimalOp,
8686
};
8787
pub use string_funcs::*;
8888

0 commit comments

Comments
 (0)