From 4b3fd485871b9f3f7464489f195e71feda3e6ecc Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Mon, 2 Mar 2026 10:53:10 -0700 Subject: [PATCH 1/6] feat: fused WideDecimalBinaryExpr for Decimal128 add/sub/mul MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace the 4-node expression tree (Cast→BinaryExpr→Cast→Cast) used for Decimal128 arithmetic that may overflow with a single fused expression that performs i256 register arithmetic directly. This reduces per-batch allocation from 4 intermediate arrays (112 bytes/elem) to 1 output array (16 bytes/elem). The new WideDecimalBinaryExpr evaluates children, performs add/sub/mul using i256 intermediates via try_binary, applies scale adjustment with HALF_UP rounding, checks precision bounds, and outputs a single Decimal128 array. Follows the same pattern as decimal_div. --- native/core/src/execution/planner.rs | 43 +- native/spark-expr/src/lib.rs | 3 +- native/spark-expr/src/math_funcs/mod.rs | 2 + .../math_funcs/wide_decimal_binary_expr.rs | 502 ++++++++++++++++++ 4 files changed, 528 insertions(+), 22 deletions(-) create mode 100644 native/spark-expr/src/math_funcs/wide_decimal_binary_expr.rs diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 094777e796..a56ac99ce1 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -128,6 +128,7 @@ use datafusion_comet_spark_expr::{ ArrayInsert, Avg, AvgDecimal, Cast, CheckOverflow, Correlation, Covariance, CreateNamedStruct, GetArrayStructFields, GetStructField, IfExpr, ListExtract, NormalizeNaNAndZero, RandExpr, RandnExpr, SparkCastOptions, Stddev, SumDecimal, ToJson, UnboundColumn, Variance, + WideDecimalBinaryExpr, WideDecimalOp, }; use itertools::Itertools; use jni::objects::GlobalRef; @@ -674,31 +675,31 @@ impl PhysicalPlanner { ) { ( DataFusionOperator::Plus | DataFusionOperator::Minus | DataFusionOperator::Multiply, - Ok(DataType::Decimal128(p1, s1)), - Ok(DataType::Decimal128(p2, s2)), + Ok(DataType::Decimal128(_p1, _s1)), + Ok(DataType::Decimal128(_p2, _s2)), ) if ((op == DataFusionOperator::Plus || op == DataFusionOperator::Minus) - && max(s1, s2) as u8 + max(p1 - s1 as u8, p2 - s2 as u8) + && max(_s1, _s2) as u8 + max(_p1 - _s1 as u8, _p2 - _s2 as u8) >= DECIMAL128_MAX_PRECISION) - || (op == DataFusionOperator::Multiply && p1 + p2 >= DECIMAL128_MAX_PRECISION) => + || (op == DataFusionOperator::Multiply + && _p1 + _p2 >= DECIMAL128_MAX_PRECISION) => { let data_type = return_type.map(to_arrow_datatype).unwrap(); - // For some Decimal128 operations, we need wider internal digits. - // Cast left and right to Decimal256 and cast the result back to Decimal128 - let left = Arc::new(Cast::new( - left, - DataType::Decimal256(p1, s1), - SparkCastOptions::new_without_timezone(EvalMode::Legacy, false), - )); - let right = Arc::new(Cast::new( - right, - DataType::Decimal256(p2, s2), - SparkCastOptions::new_without_timezone(EvalMode::Legacy, false), - )); - let child = Arc::new(BinaryExpr::new(left, op, right)); - Ok(Arc::new(Cast::new( - child, - data_type, - SparkCastOptions::new_without_timezone(EvalMode::Legacy, false), + let (p_out, s_out) = match &data_type { + DataType::Decimal128(p, s) => (*p, *s), + dt => { + return Err(ExecutionError::GeneralError(format!( + "Expected Decimal128 return type, got {dt:?}" + ))) + } + }; + let wide_op = match op { + DataFusionOperator::Plus => WideDecimalOp::Add, + DataFusionOperator::Minus => WideDecimalOp::Subtract, + DataFusionOperator::Multiply => WideDecimalOp::Multiply, + _ => unreachable!(), + }; + Ok(Arc::new(WideDecimalBinaryExpr::new( + left, right, wide_op, p_out, s_out, eval_mode, ))) } ( diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs index 40eb180ab8..47f0c31c1b 100644 --- a/native/spark-expr/src/lib.rs +++ b/native/spark-expr/src/lib.rs @@ -81,7 +81,8 @@ pub use json_funcs::{FromJson, ToJson}; pub use math_funcs::{ create_modulo_expr, create_negate_expr, spark_ceil, spark_decimal_div, spark_decimal_integral_div, spark_floor, spark_make_decimal, spark_round, spark_unhex, - spark_unscaled_value, CheckOverflow, NegativeExpr, NormalizeNaNAndZero, + spark_unscaled_value, CheckOverflow, NegativeExpr, NormalizeNaNAndZero, WideDecimalBinaryExpr, + WideDecimalOp, }; pub use string_funcs::*; diff --git a/native/spark-expr/src/math_funcs/mod.rs b/native/spark-expr/src/math_funcs/mod.rs index 35c1dc6504..1219bc7208 100644 --- a/native/spark-expr/src/math_funcs/mod.rs +++ b/native/spark-expr/src/math_funcs/mod.rs @@ -26,6 +26,7 @@ mod negative; mod round; pub(crate) mod unhex; mod utils; +mod wide_decimal_binary_expr; pub use ceil::spark_ceil; pub use div::spark_decimal_div; @@ -36,3 +37,4 @@ pub use modulo_expr::create_modulo_expr; pub use negative::{create_negate_expr, NegativeExpr}; pub use round::spark_round; pub use unhex::spark_unhex; +pub use wide_decimal_binary_expr::{WideDecimalBinaryExpr, WideDecimalOp}; diff --git a/native/spark-expr/src/math_funcs/wide_decimal_binary_expr.rs b/native/spark-expr/src/math_funcs/wide_decimal_binary_expr.rs new file mode 100644 index 0000000000..a1d5c92d3d --- /dev/null +++ b/native/spark-expr/src/math_funcs/wide_decimal_binary_expr.rs @@ -0,0 +1,502 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Fused wide-decimal binary expression for Decimal128 add/sub/mul that may overflow. +//! +//! Instead of building a 4-node expression tree (Cast→BinaryExpr→Cast→Cast), this performs +//! i256 intermediate arithmetic in a single expression, producing only one output array. + +use crate::math_funcs::utils::get_precision_scale; +use crate::EvalMode; +use arrow::array::{Array, ArrayRef, AsArray, Decimal128Array}; +use arrow::datatypes::{i256, DataType, Decimal128Type, Schema}; +use arrow::error::ArrowError; +use arrow::record_batch::RecordBatch; +use datafusion::common::Result; +use datafusion::logical_expr::ColumnarValue; +use datafusion::physical_expr::PhysicalExpr; +use std::fmt::{Display, Formatter}; +use std::hash::Hash; +use std::{any::Any, sync::Arc}; + +/// The arithmetic operation to perform. +#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy)] +pub enum WideDecimalOp { + Add, + Subtract, + Multiply, +} + +impl Display for WideDecimalOp { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + WideDecimalOp::Add => write!(f, "+"), + WideDecimalOp::Subtract => write!(f, "-"), + WideDecimalOp::Multiply => write!(f, "*"), + } + } +} + +/// A fused expression that evaluates Decimal128 add/sub/mul using i256 intermediate arithmetic, +/// applies scale adjustment with HALF_UP rounding, checks precision bounds, and outputs +/// a single Decimal128 array. +#[derive(Debug, Eq)] +pub struct WideDecimalBinaryExpr { + left: Arc, + right: Arc, + op: WideDecimalOp, + output_precision: u8, + output_scale: i8, + eval_mode: EvalMode, +} + +impl Hash for WideDecimalBinaryExpr { + fn hash(&self, state: &mut H) { + self.left.hash(state); + self.right.hash(state); + self.op.hash(state); + self.output_precision.hash(state); + self.output_scale.hash(state); + self.eval_mode.hash(state); + } +} + +impl PartialEq for WideDecimalBinaryExpr { + fn eq(&self, other: &Self) -> bool { + self.left.eq(&other.left) + && self.right.eq(&other.right) + && self.op == other.op + && self.output_precision == other.output_precision + && self.output_scale == other.output_scale + && self.eval_mode == other.eval_mode + } +} + +impl WideDecimalBinaryExpr { + pub fn new( + left: Arc, + right: Arc, + op: WideDecimalOp, + output_precision: u8, + output_scale: i8, + eval_mode: EvalMode, + ) -> Self { + Self { + left, + right, + op, + output_precision, + output_scale, + eval_mode, + } + } +} + +impl Display for WideDecimalBinaryExpr { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "WideDecimalBinaryExpr [{} {} {}, output: Decimal128({}, {})]", + self.left, self.op, self.right, self.output_precision, self.output_scale + ) + } +} + +/// Compute `value / divisor` with HALF_UP rounding. +#[inline] +fn div_round_half_up(value: i256, divisor: i256) -> i256 { + let (quot, rem) = (value / divisor, value % divisor); + // HALF_UP: if |remainder| * 2 >= |divisor|, round away from zero + let abs_rem_x2 = if rem < i256::ZERO { + rem.wrapping_neg() + } else { + rem + } + .wrapping_mul(i256::from_i128(2)); + let abs_divisor = if divisor < i256::ZERO { + divisor.wrapping_neg() + } else { + divisor + }; + if abs_rem_x2 >= abs_divisor { + if (value < i256::ZERO) != (divisor < i256::ZERO) { + quot.wrapping_sub(i256::ONE) + } else { + quot.wrapping_add(i256::ONE) + } + } else { + quot + } +} + +/// i256 constant for 10. +const I256_TEN: i256 = i256::from_i128(10); + +/// Compute 10^exp as i256. +#[inline] +fn i256_pow10(exp: u32) -> i256 { + let mut result = i256::ONE; + for _ in 0..exp { + result = result.wrapping_mul(I256_TEN); + } + result +} + +/// Maximum i128 value for a given decimal precision (1-indexed). +/// precision p allows values in [-10^p + 1, 10^p - 1]. +#[inline] +fn max_for_precision(precision: u8) -> i256 { + i256_pow10(precision as u32).wrapping_sub(i256::ONE) +} + +impl PhysicalExpr for WideDecimalBinaryExpr { + fn as_any(&self) -> &dyn Any { + self + } + + fn data_type(&self, _input_schema: &Schema) -> Result { + Ok(DataType::Decimal128( + self.output_precision, + self.output_scale, + )) + } + + fn nullable(&self, _input_schema: &Schema) -> Result { + Ok(true) + } + + fn evaluate(&self, batch: &RecordBatch) -> Result { + let left_val = self.left.evaluate(batch)?; + let right_val = self.right.evaluate(batch)?; + + let (left_arr, right_arr): (ArrayRef, ArrayRef) = match (&left_val, &right_val) { + (ColumnarValue::Array(l), ColumnarValue::Array(r)) => (Arc::clone(l), Arc::clone(r)), + (ColumnarValue::Scalar(l), ColumnarValue::Array(r)) => { + (l.to_array_of_size(r.len())?, Arc::clone(r)) + } + (ColumnarValue::Array(l), ColumnarValue::Scalar(r)) => { + (Arc::clone(l), r.to_array_of_size(l.len())?) + } + (ColumnarValue::Scalar(l), ColumnarValue::Scalar(r)) => (l.to_array()?, r.to_array()?), + }; + + let left = left_arr.as_primitive::(); + let right = right_arr.as_primitive::(); + let (_p1, s1) = get_precision_scale(left.data_type()); + let (_p2, s2) = get_precision_scale(right.data_type()); + + let p_out = self.output_precision; + let s_out = self.output_scale; + let op = self.op; + let eval_mode = self.eval_mode; + + let bound = max_for_precision(p_out); + let neg_bound = i256::ZERO.wrapping_sub(bound); + + let result: Decimal128Array = match op { + WideDecimalOp::Add | WideDecimalOp::Subtract => { + let max_scale = std::cmp::max(s1, s2); + let l_scale_up = i256_pow10((max_scale - s1) as u32); + let r_scale_up = i256_pow10((max_scale - s2) as u32); + let need_rescale = s_out < max_scale; + let rescale_divisor = if need_rescale { + i256_pow10((max_scale - s_out) as u32) + } else { + i256::ONE + }; + + arrow::compute::kernels::arity::try_binary(left, right, |l, r| { + let l256 = i256::from_i128(l).wrapping_mul(l_scale_up); + let r256 = i256::from_i128(r).wrapping_mul(r_scale_up); + let raw = match op { + WideDecimalOp::Add => l256.wrapping_add(r256), + WideDecimalOp::Subtract => l256.wrapping_sub(r256), + _ => unreachable!(), + }; + let result = if need_rescale { + div_round_half_up(raw, rescale_divisor) + } else { + raw + }; + check_overflow_and_convert(result, bound, neg_bound, eval_mode) + })? + } + WideDecimalOp::Multiply => { + let natural_scale = s1 + s2; + let need_rescale = s_out < natural_scale; + let rescale_divisor = if need_rescale { + i256_pow10((natural_scale - s_out) as u32) + } else { + i256::ONE + }; + + arrow::compute::kernels::arity::try_binary(left, right, |l, r| { + let raw = i256::from_i128(l).wrapping_mul(i256::from_i128(r)); + let result = if need_rescale { + div_round_half_up(raw, rescale_divisor) + } else { + raw + }; + check_overflow_and_convert(result, bound, neg_bound, eval_mode) + })? + } + }; + + let result = if eval_mode != EvalMode::Ansi { + result.null_if_overflow_precision(p_out) + } else { + result + }; + let result = result.with_data_type(DataType::Decimal128(p_out, s_out)); + Ok(ColumnarValue::Array(Arc::new(result))) + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.left, &self.right] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + Ok(Arc::new(WideDecimalBinaryExpr::new( + Arc::clone(&children[0]), + Arc::clone(&children[1]), + self.op, + self.output_precision, + self.output_scale, + self.eval_mode, + ))) + } + + fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + Display::fmt(self, f) + } +} + +/// Check if the i256 result fits in the output precision. In Ansi mode, return an error +/// on overflow. In Legacy/Try mode, return i128::MAX as a sentinel value that will be +/// nullified by `null_if_overflow_precision`. +#[inline] +fn check_overflow_and_convert( + result: i256, + bound: i256, + neg_bound: i256, + eval_mode: EvalMode, +) -> Result { + if result > bound || result < neg_bound { + if eval_mode == EvalMode::Ansi { + return Err(ArrowError::ComputeError("Arithmetic overflow".to_string())); + } + // Sentinel value — will be nullified by null_if_overflow_precision + Ok(i128::MAX) + } else { + Ok(result.to_i128().unwrap()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::Decimal128Array; + use arrow::datatypes::{Field, Schema}; + use arrow::record_batch::RecordBatch; + use datafusion::physical_expr::expressions::Column; + + fn make_batch( + left_values: Vec>, + left_precision: u8, + left_scale: i8, + right_values: Vec>, + right_precision: u8, + right_scale: i8, + ) -> RecordBatch { + let left_arr = Decimal128Array::from(left_values) + .with_data_type(DataType::Decimal128(left_precision, left_scale)); + let right_arr = Decimal128Array::from(right_values) + .with_data_type(DataType::Decimal128(right_precision, right_scale)); + let schema = Schema::new(vec![ + Field::new("left", left_arr.data_type().clone(), true), + Field::new("right", right_arr.data_type().clone(), true), + ]); + RecordBatch::try_new( + Arc::new(schema), + vec![Arc::new(left_arr), Arc::new(right_arr)], + ) + .unwrap() + } + + fn eval_expr( + batch: &RecordBatch, + op: WideDecimalOp, + output_precision: u8, + output_scale: i8, + eval_mode: EvalMode, + ) -> Result { + let left: Arc = Arc::new(Column::new("left", 0)); + let right: Arc = Arc::new(Column::new("right", 1)); + let expr = + WideDecimalBinaryExpr::new(left, right, op, output_precision, output_scale, eval_mode); + match expr.evaluate(batch)? { + ColumnarValue::Array(arr) => Ok(arr), + _ => panic!("expected array"), + } + } + + #[test] + fn test_add_same_scale() { + // Decimal128(38, 10) + Decimal128(38, 10) -> Decimal128(38, 10) + let batch = make_batch( + vec![Some(1000000000), Some(2500000000)], // 0.1, 0.25 (scale 10 → divide by 10^10 mentally) + 38, + 10, + vec![Some(2000000000), Some(7500000000)], + 38, + 10, + ); + let result = eval_expr(&batch, WideDecimalOp::Add, 38, 10, EvalMode::Legacy).unwrap(); + let arr = result.as_primitive::(); + assert_eq!(arr.value(0), 3000000000); // 0.1 + 0.2 + assert_eq!(arr.value(1), 10000000000); // 0.25 + 0.75 + } + + #[test] + fn test_subtract_same_scale() { + let batch = make_batch( + vec![Some(5000), Some(1000)], + 38, + 2, + vec![Some(3000), Some(2000)], + 38, + 2, + ); + let result = eval_expr(&batch, WideDecimalOp::Subtract, 38, 2, EvalMode::Legacy).unwrap(); + let arr = result.as_primitive::(); + assert_eq!(arr.value(0), 2000); // 50.00 - 30.00 + assert_eq!(arr.value(1), -1000); // 10.00 - 20.00 + } + + #[test] + fn test_add_different_scales() { + // Decimal128(10, 2) + Decimal128(10, 4) -> output scale 4 + let batch = make_batch( + vec![Some(150)], // 1.50 + 10, + 2, + vec![Some(2500)], // 0.2500 + 10, + 4, + ); + let result = eval_expr(&batch, WideDecimalOp::Add, 38, 4, EvalMode::Legacy).unwrap(); + let arr = result.as_primitive::(); + assert_eq!(arr.value(0), 17500); // 1.5000 + 0.2500 = 1.7500 + } + + #[test] + fn test_multiply_with_scale_reduction() { + // Decimal128(20, 5) * Decimal128(20, 5) -> natural scale 10, output scale 6 + // 1.00000 * 2.00000 = 2.000000 + let batch = make_batch( + vec![Some(100000)], // 1.00000 + 20, + 5, + vec![Some(200000)], // 2.00000 + 20, + 5, + ); + let result = eval_expr(&batch, WideDecimalOp::Multiply, 38, 6, EvalMode::Legacy).unwrap(); + let arr = result.as_primitive::(); + assert_eq!(arr.value(0), 2000000); // 2.000000 + } + + #[test] + fn test_multiply_half_up_rounding() { + // Test HALF_UP rounding: 1.5 * 1.5 = 2.25, but if output scale=1, should round to 2.3 + // Input: scale 1, values 15 (1.5) * 15 (1.5) = natural scale 2, raw = 225 + // Output scale 1: 225 / 10 = 22 remainder 5 -> HALF_UP rounds to 23 + let batch = make_batch( + vec![Some(15)], // 1.5 + 10, + 1, + vec![Some(15)], // 1.5 + 10, + 1, + ); + let result = eval_expr(&batch, WideDecimalOp::Multiply, 38, 1, EvalMode::Legacy).unwrap(); + let arr = result.as_primitive::(); + assert_eq!(arr.value(0), 23); // 2.3 + } + + #[test] + fn test_multiply_half_up_rounding_negative() { + // -1.5 * 1.5 = -2.25, output scale 1: -225/10 => -22 rem -5 -> HALF_UP rounds to -23 + let batch = make_batch( + vec![Some(-15)], // -1.5 + 10, + 1, + vec![Some(15)], // 1.5 + 10, + 1, + ); + let result = eval_expr(&batch, WideDecimalOp::Multiply, 38, 1, EvalMode::Legacy).unwrap(); + let arr = result.as_primitive::(); + assert_eq!(arr.value(0), -23); // -2.3 + } + + #[test] + fn test_overflow_legacy_mode_returns_null() { + // Use precision 1 (max value 9), so 5 + 5 = 10 overflows + let batch = make_batch(vec![Some(5)], 38, 0, vec![Some(5)], 38, 0); + let result = eval_expr(&batch, WideDecimalOp::Add, 1, 0, EvalMode::Legacy).unwrap(); + let arr = result.as_primitive::(); + assert!(arr.is_null(0)); + } + + #[test] + fn test_overflow_ansi_mode_returns_error() { + let batch = make_batch(vec![Some(5)], 38, 0, vec![Some(5)], 38, 0); + let result = eval_expr(&batch, WideDecimalOp::Add, 1, 0, EvalMode::Ansi); + assert!(result.is_err()); + } + + #[test] + fn test_null_propagation() { + let batch = make_batch(vec![Some(100), None], 10, 2, vec![None, Some(200)], 10, 2); + let result = eval_expr(&batch, WideDecimalOp::Add, 38, 2, EvalMode::Legacy).unwrap(); + let arr = result.as_primitive::(); + assert!(arr.is_null(0)); + assert!(arr.is_null(1)); + } + + #[test] + fn test_zeros() { + let batch = make_batch(vec![Some(0)], 38, 10, vec![Some(0)], 38, 10); + let result = eval_expr(&batch, WideDecimalOp::Multiply, 38, 10, EvalMode::Legacy).unwrap(); + let arr = result.as_primitive::(); + assert_eq!(arr.value(0), 0); + } + + #[test] + fn test_max_precision_values() { + // Max Decimal128(38,0) value: 10^38 - 1 + let max_val = 10i128.pow(38) - 1; + let batch = make_batch(vec![Some(max_val)], 38, 0, vec![Some(0)], 38, 0); + let result = eval_expr(&batch, WideDecimalOp::Add, 38, 0, EvalMode::Legacy).unwrap(); + let arr = result.as_primitive::(); + assert_eq!(arr.value(0), max_val); + } +} From d7495bdec1626d167aeca65680e30b82075c1f1e Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Mon, 2 Mar 2026 10:53:19 -0700 Subject: [PATCH 2/6] feat: add criterion benchmark for wide decimal binary expr Add benchmark comparing old Cast->BinaryExpr->Cast chain vs fused WideDecimalBinaryExpr for Decimal128 add/sub/mul. Covers four cases: add with same scale, add with different scales, multiply, and subtract. --- native/spark-expr/Cargo.toml | 4 + native/spark-expr/benches/wide_decimal.rs | 162 ++++++++++++++++++++++ 2 files changed, 166 insertions(+) create mode 100644 native/spark-expr/benches/wide_decimal.rs diff --git a/native/spark-expr/Cargo.toml b/native/spark-expr/Cargo.toml index e7c238f7eb..ccc0859f1a 100644 --- a/native/spark-expr/Cargo.toml +++ b/native/spark-expr/Cargo.toml @@ -103,3 +103,7 @@ path = "tests/spark_expr_reg.rs" [[bench]] name = "cast_from_boolean" harness = false + +[[bench]] +name = "wide_decimal" +harness = false diff --git a/native/spark-expr/benches/wide_decimal.rs b/native/spark-expr/benches/wide_decimal.rs new file mode 100644 index 0000000000..77d7a5c3c4 --- /dev/null +++ b/native/spark-expr/benches/wide_decimal.rs @@ -0,0 +1,162 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Benchmarks comparing the old Cast->BinaryExpr->Cast chain vs the fused WideDecimalBinaryExpr +//! for Decimal128 arithmetic that requires wider intermediate precision. + +use arrow::array::builder::Decimal128Builder; +use arrow::array::RecordBatch; +use arrow::datatypes::{DataType, Field, Schema}; +use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; +use datafusion::logical_expr::Operator; +use datafusion::physical_expr::expressions::{BinaryExpr, Column}; +use datafusion::physical_expr::PhysicalExpr; +use datafusion_comet_spark_expr::{ + Cast, EvalMode, SparkCastOptions, WideDecimalBinaryExpr, WideDecimalOp, +}; +use std::sync::Arc; + +const BATCH_SIZE: usize = 8192; + +/// Build a RecordBatch with two Decimal128 columns. +fn make_decimal_batch(p1: u8, s1: i8, p2: u8, s2: i8) -> RecordBatch { + let mut left = Decimal128Builder::new(); + let mut right = Decimal128Builder::new(); + for i in 0..BATCH_SIZE as i128 { + left.append_value(123456789012345_i128 + i * 1000); + right.append_value(987654321098765_i128 - i * 1000); + } + let left = left.finish().with_data_type(DataType::Decimal128(p1, s1)); + let right = right.finish().with_data_type(DataType::Decimal128(p2, s2)); + let schema = Schema::new(vec![ + Field::new("left", DataType::Decimal128(p1, s1), false), + Field::new("right", DataType::Decimal128(p2, s2), false), + ]); + RecordBatch::try_new(Arc::new(schema), vec![Arc::new(left), Arc::new(right)]).unwrap() +} + +/// Old approach: Cast(Decimal128->Decimal256) both sides, BinaryExpr, Cast(Decimal256->Decimal128). +fn build_old_expr( + p1: u8, + s1: i8, + p2: u8, + s2: i8, + op: Operator, + out_type: DataType, +) -> Arc { + let left_col: Arc = Arc::new(Column::new("left", 0)); + let right_col: Arc = Arc::new(Column::new("right", 1)); + let cast_opts = SparkCastOptions::new_without_timezone(EvalMode::Legacy, false); + let left_cast = Arc::new(Cast::new( + left_col, + DataType::Decimal256(p1, s1), + cast_opts.clone(), + )); + let right_cast = Arc::new(Cast::new( + right_col, + DataType::Decimal256(p2, s2), + cast_opts.clone(), + )); + let binary = Arc::new(BinaryExpr::new(left_cast, op, right_cast)); + Arc::new(Cast::new(binary, out_type, cast_opts)) +} + +/// New approach: single fused WideDecimalBinaryExpr. +fn build_new_expr(op: WideDecimalOp, p_out: u8, s_out: i8) -> Arc { + let left_col: Arc = Arc::new(Column::new("left", 0)); + let right_col: Arc = Arc::new(Column::new("right", 1)); + Arc::new(WideDecimalBinaryExpr::new( + left_col, + right_col, + op, + p_out, + s_out, + EvalMode::Legacy, + )) +} + +fn bench_case( + group: &mut criterion::BenchmarkGroup, + name: &str, + batch: &RecordBatch, + old_expr: &Arc, + new_expr: &Arc, +) { + group.bench_with_input(BenchmarkId::new("old", name), batch, |b, batch| { + b.iter(|| old_expr.evaluate(batch).unwrap()); + }); + group.bench_with_input(BenchmarkId::new("fused", name), batch, |b, batch| { + b.iter(|| new_expr.evaluate(batch).unwrap()); + }); +} + +fn criterion_benchmark(c: &mut Criterion) { + let mut group = c.benchmark_group("wide_decimal"); + + // Case 1: Add with same scale - Decimal128(38,10) + Decimal128(38,10) -> Decimal128(38,10) + // Triggers wide path because max(s1,s2) + max(p1-s1, p2-s2) = 10 + 28 = 38 >= 38 + { + let batch = make_decimal_batch(38, 10, 38, 10); + let old = build_old_expr(38, 10, 38, 10, Operator::Plus, DataType::Decimal128(38, 10)); + let new = build_new_expr(WideDecimalOp::Add, 38, 10); + bench_case(&mut group, "add_same_scale", &batch, &old, &new); + } + + // Case 2: Add with different scales - Decimal128(38,6) + Decimal128(38,4) -> Decimal128(38,6) + { + let batch = make_decimal_batch(38, 6, 38, 4); + let old = build_old_expr(38, 6, 38, 4, Operator::Plus, DataType::Decimal128(38, 6)); + let new = build_new_expr(WideDecimalOp::Add, 38, 6); + bench_case(&mut group, "add_diff_scale", &batch, &old, &new); + } + + // Case 3: Multiply - Decimal128(20,10) * Decimal128(20,10) -> Decimal128(38,6) + // Triggers wide path because p1 + p2 = 40 >= 38 + { + let batch = make_decimal_batch(20, 10, 20, 10); + let old = build_old_expr( + 20, + 10, + 20, + 10, + Operator::Multiply, + DataType::Decimal128(38, 6), + ); + let new = build_new_expr(WideDecimalOp::Multiply, 38, 6); + bench_case(&mut group, "multiply", &batch, &old, &new); + } + + // Case 4: Subtract with same scale - Decimal128(38,18) - Decimal128(38,18) -> Decimal128(38,18) + { + let batch = make_decimal_batch(38, 18, 38, 18); + let old = build_old_expr( + 38, + 18, + 38, + 18, + Operator::Minus, + DataType::Decimal128(38, 18), + ); + let new = build_new_expr(WideDecimalOp::Subtract, 38, 18); + bench_case(&mut group, "subtract", &batch, &old, &new); + } + + group.finish(); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); From 91092a68e949fe9e37f72c4fe821024247441ffd Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Mon, 2 Mar 2026 11:16:40 -0700 Subject: [PATCH 3/6] feat: fuse CheckOverflow with Cast and WideDecimalBinaryExpr MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Eliminate redundant CheckOverflow when wrapping WideDecimalBinaryExpr (which already handles overflow). Fuse Cast(Decimal128→Decimal128) + CheckOverflow into a single DecimalRescaleCheckOverflow expression that rescales and validates precision in one pass. --- native/core/src/execution/planner.rs | 35 +- native/spark-expr/src/lib.rs | 4 +- .../internal/decimal_rescale_check.rs | 431 ++++++++++++++++++ .../spark-expr/src/math_funcs/internal/mod.rs | 2 + 4 files changed, 466 insertions(+), 6 deletions(-) create mode 100644 native/spark-expr/src/math_funcs/internal/decimal_rescale_check.rs diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index a56ac99ce1..2d40160716 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -126,9 +126,9 @@ use datafusion_comet_proto::{ use datafusion_comet_spark_expr::monotonically_increasing_id::MonotonicallyIncreasingId; use datafusion_comet_spark_expr::{ ArrayInsert, Avg, AvgDecimal, Cast, CheckOverflow, Correlation, Covariance, CreateNamedStruct, - GetArrayStructFields, GetStructField, IfExpr, ListExtract, NormalizeNaNAndZero, RandExpr, - RandnExpr, SparkCastOptions, Stddev, SumDecimal, ToJson, UnboundColumn, Variance, - WideDecimalBinaryExpr, WideDecimalOp, + DecimalRescaleCheckOverflow, GetArrayStructFields, GetStructField, IfExpr, ListExtract, + NormalizeNaNAndZero, RandExpr, RandnExpr, SparkCastOptions, Stddev, SumDecimal, ToJson, + UnboundColumn, Variance, WideDecimalBinaryExpr, WideDecimalOp, }; use itertools::Itertools; use jni::objects::GlobalRef; @@ -377,10 +377,37 @@ impl PhysicalPlanner { ))) } ExprStruct::CheckOverflow(expr) => { - let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema)?; + let child = + self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&input_schema))?; let data_type = to_arrow_datatype(expr.datatype.as_ref().unwrap()); let fail_on_error = expr.fail_on_error; + // WideDecimalBinaryExpr already handles overflow — skip redundant check + if child + .as_any() + .downcast_ref::() + .is_some() + { + return Ok(child); + } + + // Fuse Cast(Decimal128→Decimal128) + CheckOverflow into single rescale+check + if let Some(cast) = child.as_any().downcast_ref::() { + if let ( + DataType::Decimal128(p_out, s_out), + Ok(DataType::Decimal128(_p_in, s_in)), + ) = (&data_type, cast.child.data_type(&input_schema)) + { + return Ok(Arc::new(DecimalRescaleCheckOverflow::new( + Arc::clone(&cast.child), + s_in, + *p_out, + *s_out, + fail_on_error, + ))); + } + } + Ok(Arc::new(CheckOverflow::new( child, data_type, diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs index 47f0c31c1b..49c8755cc9 100644 --- a/native/spark-expr/src/lib.rs +++ b/native/spark-expr/src/lib.rs @@ -81,8 +81,8 @@ pub use json_funcs::{FromJson, ToJson}; pub use math_funcs::{ create_modulo_expr, create_negate_expr, spark_ceil, spark_decimal_div, spark_decimal_integral_div, spark_floor, spark_make_decimal, spark_round, spark_unhex, - spark_unscaled_value, CheckOverflow, NegativeExpr, NormalizeNaNAndZero, WideDecimalBinaryExpr, - WideDecimalOp, + spark_unscaled_value, CheckOverflow, DecimalRescaleCheckOverflow, NegativeExpr, + NormalizeNaNAndZero, WideDecimalBinaryExpr, WideDecimalOp, }; pub use string_funcs::*; diff --git a/native/spark-expr/src/math_funcs/internal/decimal_rescale_check.rs b/native/spark-expr/src/math_funcs/internal/decimal_rescale_check.rs new file mode 100644 index 0000000000..ef20af6d5b --- /dev/null +++ b/native/spark-expr/src/math_funcs/internal/decimal_rescale_check.rs @@ -0,0 +1,431 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Fused decimal rescale + overflow check expression. +//! +//! Replaces the pattern `CheckOverflow(Cast(expr, Decimal128(p2,s2)), Decimal128(p2,s2))` +//! with a single expression that rescales and validates precision in one pass. + +use arrow::array::{as_primitive_array, Array, ArrayRef, Decimal128Array}; +use arrow::datatypes::{DataType, Decimal128Type, Schema}; +use arrow::error::ArrowError; +use arrow::record_batch::RecordBatch; +use datafusion::common::{DataFusionError, ScalarValue}; +use datafusion::logical_expr::ColumnarValue; +use datafusion::physical_expr::PhysicalExpr; +use std::hash::Hash; +use std::{ + any::Any, + fmt::{Display, Formatter}, + sync::Arc, +}; + +/// A fused expression that rescales a Decimal128 value (changing scale) and checks +/// for precision overflow in a single pass. Replaces the two-step +/// `CheckOverflow(Cast(expr, Decimal128(p,s)))` pattern. +#[derive(Debug, Eq)] +pub struct DecimalRescaleCheckOverflow { + child: Arc, + input_scale: i8, + output_precision: u8, + output_scale: i8, + fail_on_error: bool, +} + +impl Hash for DecimalRescaleCheckOverflow { + fn hash(&self, state: &mut H) { + self.child.hash(state); + self.input_scale.hash(state); + self.output_precision.hash(state); + self.output_scale.hash(state); + self.fail_on_error.hash(state); + } +} + +impl PartialEq for DecimalRescaleCheckOverflow { + fn eq(&self, other: &Self) -> bool { + self.child.eq(&other.child) + && self.input_scale == other.input_scale + && self.output_precision == other.output_precision + && self.output_scale == other.output_scale + && self.fail_on_error == other.fail_on_error + } +} + +impl DecimalRescaleCheckOverflow { + pub fn new( + child: Arc, + input_scale: i8, + output_precision: u8, + output_scale: i8, + fail_on_error: bool, + ) -> Self { + Self { + child, + input_scale, + output_precision, + output_scale, + fail_on_error, + } + } +} + +impl Display for DecimalRescaleCheckOverflow { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "DecimalRescaleCheckOverflow [child: {}, input_scale: {}, output: Decimal128({}, {}), fail_on_error: {}]", + self.child, self.input_scale, self.output_precision, self.output_scale, self.fail_on_error + ) + } +} + +/// Maximum absolute value for a given decimal precision: 10^p - 1. +#[inline] +fn precision_bound(precision: u8) -> i128 { + 10i128.pow(precision as u32) - 1 +} + +/// Rescale a single i128 value by the given delta (output_scale - input_scale) +/// and check precision bounds. Returns `Ok(value)` or `Ok(i128::MAX)` as sentinel +/// for overflow in legacy mode, or `Err` in ANSI mode. +#[inline] +fn rescale_and_check( + value: i128, + delta: i8, + scale_factor: i128, + bound: i128, + fail_on_error: bool, +) -> Result { + let rescaled = if delta > 0 { + // Scale up: multiply. Check for overflow. + match value.checked_mul(scale_factor) { + Some(v) => v, + None => { + if fail_on_error { + return Err(ArrowError::ComputeError( + "Decimal overflow during rescale".to_string(), + )); + } + return Ok(i128::MAX); // sentinel + } + } + } else if delta < 0 { + // Scale down with HALF_UP rounding + // divisor = 10^(-delta), half = divisor / 2 + let divisor = scale_factor; // already 10^abs(delta) + let half = divisor / 2; + let sign = if value < 0 { -1i128 } else { 1i128 }; + (value + sign * half) / divisor + } else { + value + }; + + // Precision check + if rescaled.abs() > bound { + if fail_on_error { + return Err(ArrowError::ComputeError( + "Decimal overflow: value does not fit in precision".to_string(), + )); + } + Ok(i128::MAX) // sentinel for null_if_overflow_precision + } else { + Ok(rescaled) + } +} + +impl PhysicalExpr for DecimalRescaleCheckOverflow { + fn as_any(&self) -> &dyn Any { + self + } + + fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + Display::fmt(self, f) + } + + fn data_type(&self, _: &Schema) -> datafusion::common::Result { + Ok(DataType::Decimal128( + self.output_precision, + self.output_scale, + )) + } + + fn nullable(&self, _: &Schema) -> datafusion::common::Result { + Ok(true) + } + + fn evaluate(&self, batch: &RecordBatch) -> datafusion::common::Result { + let arg = self.child.evaluate(batch)?; + let delta = self.output_scale - self.input_scale; + let abs_delta = delta.unsigned_abs(); + let scale_factor = 10i128.pow(abs_delta as u32); + let bound = precision_bound(self.output_precision); + let fail_on_error = self.fail_on_error; + let p_out = self.output_precision; + let s_out = self.output_scale; + + match arg { + ColumnarValue::Array(array) + if matches!(array.data_type(), DataType::Decimal128(_, _)) => + { + let decimal_array = as_primitive_array::(&array); + + let result: Decimal128Array = + arrow::compute::kernels::arity::try_unary(decimal_array, |value| { + rescale_and_check(value, delta, scale_factor, bound, fail_on_error) + })?; + + let result = if !fail_on_error { + result.null_if_overflow_precision(p_out) + } else { + result + }; + + let result = result + .with_precision_and_scale(p_out, s_out) + .map(|a| Arc::new(a) as ArrayRef)?; + + Ok(ColumnarValue::Array(result)) + } + ColumnarValue::Scalar(ScalarValue::Decimal128(v, _precision, _scale)) => { + let new_v = v.and_then(|val| { + rescale_and_check(val, delta, scale_factor, bound, fail_on_error) + .ok() + .and_then(|r| if r == i128::MAX { None } else { Some(r) }) + }); + Ok(ColumnarValue::Scalar(ScalarValue::Decimal128( + new_v, p_out, s_out, + ))) + } + v => Err(DataFusionError::Execution(format!( + "DecimalRescaleCheckOverflow expects Decimal128, but found {v:?}" + ))), + } + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.child] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> datafusion::common::Result> { + Ok(Arc::new(DecimalRescaleCheckOverflow::new( + Arc::clone(&children[0]), + self.input_scale, + self.output_precision, + self.output_scale, + self.fail_on_error, + ))) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{AsArray, Decimal128Array}; + use arrow::datatypes::{Field, Schema}; + use arrow::record_batch::RecordBatch; + use datafusion::physical_expr::expressions::Column; + + fn make_batch(values: Vec>, precision: u8, scale: i8) -> RecordBatch { + let arr = + Decimal128Array::from(values).with_data_type(DataType::Decimal128(precision, scale)); + let schema = Schema::new(vec![Field::new("col", arr.data_type().clone(), true)]); + RecordBatch::try_new(Arc::new(schema), vec![Arc::new(arr)]).unwrap() + } + + fn eval_expr( + batch: &RecordBatch, + input_scale: i8, + output_precision: u8, + output_scale: i8, + fail_on_error: bool, + ) -> datafusion::common::Result { + let child: Arc = Arc::new(Column::new("col", 0)); + let expr = DecimalRescaleCheckOverflow::new( + child, + input_scale, + output_precision, + output_scale, + fail_on_error, + ); + match expr.evaluate(batch)? { + ColumnarValue::Array(arr) => Ok(arr), + _ => panic!("expected array"), + } + } + + #[test] + fn test_scale_up() { + // Decimal128(10,2) -> Decimal128(10,4): 1.50 (150) -> 1.5000 (15000) + let batch = make_batch(vec![Some(150), Some(-300)], 10, 2); + let result = eval_expr(&batch, 2, 10, 4, false).unwrap(); + let arr = result.as_primitive::(); + assert_eq!(arr.value(0), 15000); // 1.5000 + assert_eq!(arr.value(1), -30000); // -3.0000 + } + + #[test] + fn test_scale_down_with_half_up_rounding() { + // Decimal128(10,4) -> Decimal128(10,2) + // 1.2350 (12350) -> round to 1.24 (124) with HALF_UP + // 1.2349 (12349) -> round to 1.23 (123) with HALF_UP + // -1.2350 (-12350) -> round to -1.24 (-124) with HALF_UP + let batch = make_batch(vec![Some(12350), Some(12349), Some(-12350)], 10, 4); + let result = eval_expr(&batch, 4, 10, 2, false).unwrap(); + let arr = result.as_primitive::(); + assert_eq!(arr.value(0), 124); // 1.24 + assert_eq!(arr.value(1), 123); // 1.23 + assert_eq!(arr.value(2), -124); // -1.24 + } + + #[test] + fn test_same_scale_precision_check_only() { + // Same scale, just check precision. Value 999 fits in precision 3, 1000 does not. + let batch = make_batch(vec![Some(999), Some(1000)], 38, 0); + let result = eval_expr(&batch, 0, 3, 0, false).unwrap(); + let arr = result.as_primitive::(); + assert_eq!(arr.value(0), 999); + assert!(arr.is_null(1)); // overflow -> null in legacy mode + } + + #[test] + fn test_overflow_null_in_legacy_mode() { + // Scale up causes overflow: 10^37 * 100 > i128::MAX range for precision 38 + // Use precision 3, value 10 (which is 10 at scale 0), scale up to scale 2 -> 1000, which overflows precision 3 + let batch = make_batch(vec![Some(10)], 38, 0); + let result = eval_expr(&batch, 0, 3, 2, false).unwrap(); + let arr = result.as_primitive::(); + assert!(arr.is_null(0)); // 10 * 100 = 1000 > 999 (max for precision 3) + } + + #[test] + fn test_overflow_error_in_ansi_mode() { + let batch = make_batch(vec![Some(10)], 38, 0); + let result = eval_expr(&batch, 0, 3, 2, true); + assert!(result.is_err()); + } + + #[test] + fn test_null_propagation() { + let batch = make_batch(vec![Some(100), None, Some(200)], 10, 2); + let result = eval_expr(&batch, 2, 10, 4, false).unwrap(); + let arr = result.as_primitive::(); + assert!(!arr.is_null(0)); + assert!(arr.is_null(1)); + assert!(!arr.is_null(2)); + } + + #[test] + fn test_scalar_path() { + let schema = Schema::new(vec![Field::new("col", DataType::Decimal128(10, 2), true)]); + let batch = RecordBatch::new_empty(Arc::new(schema)); + + let scalar_expr = DecimalRescaleCheckOverflow::new( + Arc::new(ScalarChild(Some(150), 10, 2)), + 2, + 10, + 4, + false, + ); + let result = scalar_expr.evaluate(&batch).unwrap(); + match result { + ColumnarValue::Scalar(ScalarValue::Decimal128(v, p, s)) => { + assert_eq!(v, Some(15000)); + assert_eq!(p, 10); + assert_eq!(s, 4); + } + _ => panic!("expected decimal scalar"), + } + } + + /// Helper expression that always returns a Decimal128 scalar. + #[derive(Debug, Eq, PartialEq, Hash)] + struct ScalarChild(Option, u8, i8); + + impl Display for ScalarChild { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "ScalarChild({:?})", self.0) + } + } + + impl PhysicalExpr for ScalarChild { + fn as_any(&self) -> &dyn Any { + self + } + fn data_type(&self, _: &Schema) -> datafusion::common::Result { + Ok(DataType::Decimal128(self.1, self.2)) + } + fn nullable(&self, _: &Schema) -> datafusion::common::Result { + Ok(true) + } + fn evaluate(&self, _batch: &RecordBatch) -> datafusion::common::Result { + Ok(ColumnarValue::Scalar(ScalarValue::Decimal128( + self.0, self.1, self.2, + ))) + } + fn children(&self) -> Vec<&Arc> { + vec![] + } + fn with_new_children( + self: Arc, + _children: Vec>, + ) -> datafusion::common::Result> { + Ok(self) + } + fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + Display::fmt(self, f) + } + } + + #[test] + fn test_scalar_null() { + let schema = Schema::new(vec![Field::new("col", DataType::Decimal128(10, 2), true)]); + let batch = RecordBatch::new_empty(Arc::new(schema)); + let expr = + DecimalRescaleCheckOverflow::new(Arc::new(ScalarChild(None, 10, 2)), 2, 10, 4, false); + let result = expr.evaluate(&batch).unwrap(); + match result { + ColumnarValue::Scalar(ScalarValue::Decimal128(v, _, _)) => { + assert_eq!(v, None); + } + _ => panic!("expected decimal scalar"), + } + } + + #[test] + fn test_scalar_overflow_legacy() { + let schema = Schema::new(vec![Field::new("col", DataType::Decimal128(38, 0), true)]); + let batch = RecordBatch::new_empty(Arc::new(schema)); + let expr = DecimalRescaleCheckOverflow::new( + Arc::new(ScalarChild(Some(10), 38, 0)), + 0, + 3, + 2, + false, + ); + let result = expr.evaluate(&batch).unwrap(); + match result { + ColumnarValue::Scalar(ScalarValue::Decimal128(v, _, _)) => { + assert_eq!(v, None); // 10 * 100 = 1000 > 999 + } + _ => panic!("expected decimal scalar"), + } + } +} diff --git a/native/spark-expr/src/math_funcs/internal/mod.rs b/native/spark-expr/src/math_funcs/internal/mod.rs index 29295f0d52..dff26146e8 100644 --- a/native/spark-expr/src/math_funcs/internal/mod.rs +++ b/native/spark-expr/src/math_funcs/internal/mod.rs @@ -16,11 +16,13 @@ // under the License. mod checkoverflow; +mod decimal_rescale_check; mod make_decimal; mod normalize_nan; mod unscaled_value; pub use checkoverflow::CheckOverflow; +pub use decimal_rescale_check::DecimalRescaleCheckOverflow; pub use make_decimal::spark_make_decimal; pub use normalize_nan::NormalizeNaNAndZero; pub use unscaled_value::spark_unscaled_value; From 35241bcb7144a0f68625127881885ffb9b5653ed Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 5 Mar 2026 15:26:47 -0700 Subject: [PATCH 4/6] fix: address PR review feedback for decimal optimizations - Handle scale-up when s_out > max(s1, s2) in add/subtract - Propagate errors in scalar path when fail_on_error=true - Guard against large scale delta (>38) overflow in rescale - Assert precision <= 38 in precision_bound - Assert exp <= 76 in i256_pow10 - Remove unnecessary _ prefix on used variables in planner - Use value.signum() instead of manual sign check - Verify Cast target type matches before fusing with CheckOverflow - Validate children count in with_new_children for both expressions - Add tests for scale-up, scalar error propagation, and large delta --- native/core/src/execution/planner.rs | 26 ++++---- .../internal/decimal_rescale_check.rs | 59 ++++++++++++++++-- .../math_funcs/wide_decimal_binary_expr.rs | 61 +++++++++++++++++-- 3 files changed, 124 insertions(+), 22 deletions(-) diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 2d40160716..4c10e0a343 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -392,19 +392,23 @@ impl PhysicalPlanner { } // Fuse Cast(Decimal128→Decimal128) + CheckOverflow into single rescale+check + // Only fuse when the Cast target type matches the CheckOverflow output type if let Some(cast) = child.as_any().downcast_ref::() { if let ( DataType::Decimal128(p_out, s_out), Ok(DataType::Decimal128(_p_in, s_in)), ) = (&data_type, cast.child.data_type(&input_schema)) { - return Ok(Arc::new(DecimalRescaleCheckOverflow::new( - Arc::clone(&cast.child), - s_in, - *p_out, - *s_out, - fail_on_error, - ))); + let cast_target = cast.data_type(&input_schema)?; + if cast_target == data_type { + return Ok(Arc::new(DecimalRescaleCheckOverflow::new( + Arc::clone(&cast.child), + s_in, + *p_out, + *s_out, + fail_on_error, + ))); + } } } @@ -702,13 +706,13 @@ impl PhysicalPlanner { ) { ( DataFusionOperator::Plus | DataFusionOperator::Minus | DataFusionOperator::Multiply, - Ok(DataType::Decimal128(_p1, _s1)), - Ok(DataType::Decimal128(_p2, _s2)), + Ok(DataType::Decimal128(p1, s1)), + Ok(DataType::Decimal128(p2, s2)), ) if ((op == DataFusionOperator::Plus || op == DataFusionOperator::Minus) - && max(_s1, _s2) as u8 + max(_p1 - _s1 as u8, _p2 - _s2 as u8) + && max(s1, s2) as u8 + max(p1 - s1 as u8, p2 - s2 as u8) >= DECIMAL128_MAX_PRECISION) || (op == DataFusionOperator::Multiply - && _p1 + _p2 >= DECIMAL128_MAX_PRECISION) => + && p1 + p2 >= DECIMAL128_MAX_PRECISION) => { let data_type = return_type.map(to_arrow_datatype).unwrap(); let (p_out, s_out) = match &data_type { diff --git a/native/spark-expr/src/math_funcs/internal/decimal_rescale_check.rs b/native/spark-expr/src/math_funcs/internal/decimal_rescale_check.rs index ef20af6d5b..36af22849e 100644 --- a/native/spark-expr/src/math_funcs/internal/decimal_rescale_check.rs +++ b/native/spark-expr/src/math_funcs/internal/decimal_rescale_check.rs @@ -95,8 +95,13 @@ impl Display for DecimalRescaleCheckOverflow { } /// Maximum absolute value for a given decimal precision: 10^p - 1. +/// Precision must be <= 38 (max for Decimal128). #[inline] fn precision_bound(precision: u8) -> i128 { + assert!( + precision <= 38, + "precision_bound: precision {precision} exceeds maximum 38" + ); 10i128.pow(precision as u32) - 1 } @@ -129,7 +134,7 @@ fn rescale_and_check( // divisor = 10^(-delta), half = divisor / 2 let divisor = scale_factor; // already 10^abs(delta) let half = divisor / 2; - let sign = if value < 0 { -1i128 } else { 1i128 }; + let sign = value.signum(); (value + sign * half) / divisor } else { value @@ -172,6 +177,14 @@ impl PhysicalExpr for DecimalRescaleCheckOverflow { let arg = self.child.evaluate(batch)?; let delta = self.output_scale - self.input_scale; let abs_delta = delta.unsigned_abs(); + // If abs_delta > 38, the scale factor overflows i128. In that case, + // any non-zero value will overflow the output precision, so we treat + // it as an immediate overflow condition. + if abs_delta > 38 { + return Err(DataFusionError::Execution(format!( + "DecimalRescaleCheckOverflow: scale delta {delta} exceeds maximum supported range" + ))); + } let scale_factor = 10i128.pow(abs_delta as u32); let bound = precision_bound(self.output_precision); let fail_on_error = self.fail_on_error; @@ -202,11 +215,14 @@ impl PhysicalExpr for DecimalRescaleCheckOverflow { Ok(ColumnarValue::Array(result)) } ColumnarValue::Scalar(ScalarValue::Decimal128(v, _precision, _scale)) => { - let new_v = v.and_then(|val| { - rescale_and_check(val, delta, scale_factor, bound, fail_on_error) - .ok() - .and_then(|r| if r == i128::MAX { None } else { Some(r) }) - }); + let new_v = match v { + Some(val) => { + let r = rescale_and_check(val, delta, scale_factor, bound, fail_on_error) + .map_err(|e| DataFusionError::ArrowError(Box::new(e), None))?; + if r == i128::MAX { None } else { Some(r) } + } + None => None, + }; Ok(ColumnarValue::Scalar(ScalarValue::Decimal128( new_v, p_out, s_out, ))) @@ -225,6 +241,12 @@ impl PhysicalExpr for DecimalRescaleCheckOverflow { self: Arc, children: Vec>, ) -> datafusion::common::Result> { + if children.len() != 1 { + return Err(DataFusionError::Internal(format!( + "DecimalRescaleCheckOverflow expects 1 child, got {}", + children.len() + ))); + } Ok(Arc::new(DecimalRescaleCheckOverflow::new( Arc::clone(&children[0]), self.input_scale, @@ -428,4 +450,29 @@ mod tests { _ => panic!("expected decimal scalar"), } } + + #[test] + fn test_scalar_overflow_ansi_returns_error() { + // fail_on_error=true must propagate the error, not silently return None + let schema = Schema::new(vec![Field::new("col", DataType::Decimal128(38, 0), true)]); + let batch = RecordBatch::new_empty(Arc::new(schema)); + let expr = DecimalRescaleCheckOverflow::new( + Arc::new(ScalarChild(Some(10), 38, 0)), + 0, + 3, + 2, + true, // fail_on_error = true + ); + let result = expr.evaluate(&batch); + assert!(result.is_err()); // must be error, not Ok(None) + } + + #[test] + fn test_large_scale_delta_returns_error() { + // delta = output_scale - input_scale = 38 - (-1) = 39 + // 10i128.pow(39) would overflow, so we must reject gracefully + let batch = make_batch(vec![Some(1)], 38, -1); + let result = eval_expr(&batch, -1, 38, 38, false); + assert!(result.is_err()); + } } diff --git a/native/spark-expr/src/math_funcs/wide_decimal_binary_expr.rs b/native/spark-expr/src/math_funcs/wide_decimal_binary_expr.rs index a1d5c92d3d..7c516c0989 100644 --- a/native/spark-expr/src/math_funcs/wide_decimal_binary_expr.rs +++ b/native/spark-expr/src/math_funcs/wide_decimal_binary_expr.rs @@ -146,9 +146,10 @@ fn div_round_half_up(value: i256, divisor: i256) -> i256 { /// i256 constant for 10. const I256_TEN: i256 = i256::from_i128(10); -/// Compute 10^exp as i256. +/// Compute 10^exp as i256. Panics if exp > 76 (max representable power of 10 in i256). #[inline] fn i256_pow10(exp: u32) -> i256 { + assert!(exp <= 76, "i256_pow10: exponent {exp} exceeds maximum 76"); let mut result = i256::ONE; for _ in 0..exp { result = result.wrapping_mul(I256_TEN); @@ -212,9 +213,16 @@ impl PhysicalExpr for WideDecimalBinaryExpr { let max_scale = std::cmp::max(s1, s2); let l_scale_up = i256_pow10((max_scale - s1) as u32); let r_scale_up = i256_pow10((max_scale - s2) as u32); - let need_rescale = s_out < max_scale; - let rescale_divisor = if need_rescale { - i256_pow10((max_scale - s_out) as u32) + // After add/sub at max_scale, we may need to rescale to s_out + let scale_diff = max_scale as i16 - s_out as i16; + let (need_scale_down, need_scale_up) = (scale_diff > 0, scale_diff < 0); + let rescale_divisor = if need_scale_down { + i256_pow10(scale_diff as u32) + } else { + i256::ONE + }; + let scale_up_factor = if need_scale_up { + i256_pow10((-scale_diff) as u32) } else { i256::ONE }; @@ -227,8 +235,10 @@ impl PhysicalExpr for WideDecimalBinaryExpr { WideDecimalOp::Subtract => l256.wrapping_sub(r256), _ => unreachable!(), }; - let result = if need_rescale { + let result = if need_scale_down { div_round_half_up(raw, rescale_divisor) + } else if need_scale_up { + raw.wrapping_mul(scale_up_factor) } else { raw }; @@ -273,6 +283,12 @@ impl PhysicalExpr for WideDecimalBinaryExpr { self: Arc, children: Vec>, ) -> Result> { + if children.len() != 2 { + return Err(datafusion::common::DataFusionError::Internal(format!( + "WideDecimalBinaryExpr expects 2 children, got {}", + children.len() + ))); + } Ok(Arc::new(WideDecimalBinaryExpr::new( Arc::clone(&children[0]), Arc::clone(&children[1]), @@ -499,4 +515,39 @@ mod tests { let arr = result.as_primitive::(); assert_eq!(arr.value(0), max_val); } + + #[test] + fn test_add_scale_up_to_output() { + // When s_out > max(s1, s2), result must be scaled UP + // Decimal128(10, 2) + Decimal128(10, 2) with output scale 4 + // 1.50 + 0.25 = 1.75, at scale 4 = 17500 + let batch = make_batch( + vec![Some(150)], // 1.50 + 10, + 2, + vec![Some(25)], // 0.25 + 10, + 2, + ); + let result = eval_expr(&batch, WideDecimalOp::Add, 38, 4, EvalMode::Legacy).unwrap(); + let arr = result.as_primitive::(); + assert_eq!(arr.value(0), 17500); // 1.7500 + } + + #[test] + fn test_subtract_scale_up_to_output() { + // s_out (4) > max(s1, s2) (2) — verify scale-up path for subtract + let batch = make_batch( + vec![Some(300)], // 3.00 + 10, + 2, + vec![Some(100)], // 1.00 + 10, + 2, + ); + let result = + eval_expr(&batch, WideDecimalOp::Subtract, 38, 4, EvalMode::Legacy).unwrap(); + let arr = result.as_primitive::(); + assert_eq!(arr.value(0), 20000); // 2.0000 + } } From 1141b399dbdb2abffc78645059c0c6e316d79bd9 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 5 Mar 2026 15:28:06 -0700 Subject: [PATCH 5/6] style: apply cargo fmt --- native/core/src/execution/planner.rs | 3 +-- .../src/math_funcs/internal/decimal_rescale_check.rs | 6 +++++- .../spark-expr/src/math_funcs/wide_decimal_binary_expr.rs | 3 +-- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 4c10e0a343..8c84c91737 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -711,8 +711,7 @@ impl PhysicalPlanner { ) if ((op == DataFusionOperator::Plus || op == DataFusionOperator::Minus) && max(s1, s2) as u8 + max(p1 - s1 as u8, p2 - s2 as u8) >= DECIMAL128_MAX_PRECISION) - || (op == DataFusionOperator::Multiply - && p1 + p2 >= DECIMAL128_MAX_PRECISION) => + || (op == DataFusionOperator::Multiply && p1 + p2 >= DECIMAL128_MAX_PRECISION) => { let data_type = return_type.map(to_arrow_datatype).unwrap(); let (p_out, s_out) = match &data_type { diff --git a/native/spark-expr/src/math_funcs/internal/decimal_rescale_check.rs b/native/spark-expr/src/math_funcs/internal/decimal_rescale_check.rs index 36af22849e..1322404951 100644 --- a/native/spark-expr/src/math_funcs/internal/decimal_rescale_check.rs +++ b/native/spark-expr/src/math_funcs/internal/decimal_rescale_check.rs @@ -219,7 +219,11 @@ impl PhysicalExpr for DecimalRescaleCheckOverflow { Some(val) => { let r = rescale_and_check(val, delta, scale_factor, bound, fail_on_error) .map_err(|e| DataFusionError::ArrowError(Box::new(e), None))?; - if r == i128::MAX { None } else { Some(r) } + if r == i128::MAX { + None + } else { + Some(r) + } } None => None, }; diff --git a/native/spark-expr/src/math_funcs/wide_decimal_binary_expr.rs b/native/spark-expr/src/math_funcs/wide_decimal_binary_expr.rs index 7c516c0989..3dfb6bf046 100644 --- a/native/spark-expr/src/math_funcs/wide_decimal_binary_expr.rs +++ b/native/spark-expr/src/math_funcs/wide_decimal_binary_expr.rs @@ -545,8 +545,7 @@ mod tests { 10, 2, ); - let result = - eval_expr(&batch, WideDecimalOp::Subtract, 38, 4, EvalMode::Legacy).unwrap(); + let result = eval_expr(&batch, WideDecimalOp::Subtract, 38, 4, EvalMode::Legacy).unwrap(); let arr = result.as_primitive::(); assert_eq!(arr.value(0), 20000); // 2.0000 } From 02d6a3bc0764a70a74cc3ffe5e8b0e479e575103 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 5 Mar 2026 15:39:20 -0700 Subject: [PATCH 6/6] fix: add defensive checks for CheckOverflow bypass and multiply scale-up - Validate WideDecimalBinaryExpr output type matches CheckOverflow data_type before bypassing the overflow check - Handle s_out > natural_scale (scale-up) in multiply path for consistency with add/subtract --- native/core/src/execution/planner.rs | 6 +++++- .../src/math_funcs/wide_decimal_binary_expr.rs | 16 ++++++++++++---- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 8c84c91737..a752512416 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -383,12 +383,16 @@ impl PhysicalPlanner { let fail_on_error = expr.fail_on_error; // WideDecimalBinaryExpr already handles overflow — skip redundant check + // but only if its output type matches CheckOverflow's declared type if child .as_any() .downcast_ref::() .is_some() { - return Ok(child); + let child_type = child.data_type(&input_schema)?; + if child_type == data_type { + return Ok(child); + } } // Fuse Cast(Decimal128→Decimal128) + CheckOverflow into single rescale+check diff --git a/native/spark-expr/src/math_funcs/wide_decimal_binary_expr.rs b/native/spark-expr/src/math_funcs/wide_decimal_binary_expr.rs index 3dfb6bf046..644252b46b 100644 --- a/native/spark-expr/src/math_funcs/wide_decimal_binary_expr.rs +++ b/native/spark-expr/src/math_funcs/wide_decimal_binary_expr.rs @@ -247,17 +247,25 @@ impl PhysicalExpr for WideDecimalBinaryExpr { } WideDecimalOp::Multiply => { let natural_scale = s1 + s2; - let need_rescale = s_out < natural_scale; - let rescale_divisor = if need_rescale { - i256_pow10((natural_scale - s_out) as u32) + let scale_diff = natural_scale as i16 - s_out as i16; + let (need_scale_down, need_scale_up) = (scale_diff > 0, scale_diff < 0); + let rescale_divisor = if need_scale_down { + i256_pow10(scale_diff as u32) + } else { + i256::ONE + }; + let scale_up_factor = if need_scale_up { + i256_pow10((-scale_diff) as u32) } else { i256::ONE }; arrow::compute::kernels::arity::try_binary(left, right, |l, r| { let raw = i256::from_i128(l).wrapping_mul(i256::from_i128(r)); - let result = if need_rescale { + let result = if need_scale_down { div_round_half_up(raw, rescale_divisor) + } else if need_scale_up { + raw.wrapping_mul(scale_up_factor) } else { raw };