diff --git a/Cargo.lock b/Cargo.lock index 8101e2fe22d9..7c4ccb6ea7a0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2609,6 +2609,7 @@ dependencies = [ "datafusion-functions-aggregate", "datafusion-functions-nested", "log", + "num-traits", "percent-encoding", "rand 0.9.2", "serde_json", diff --git a/datafusion/spark/Cargo.toml b/datafusion/spark/Cargo.toml index 162b6d814e80..1c0e8bad7823 100644 --- a/datafusion/spark/Cargo.toml +++ b/datafusion/spark/Cargo.toml @@ -57,6 +57,7 @@ datafusion-functions = { workspace = true, features = ["crypto_expressions"] } datafusion-functions-aggregate = { workspace = true } datafusion-functions-nested = { workspace = true } log = { workspace = true } +num-traits = { workspace = true } percent-encoding = "2.3.2" rand = { workspace = true } serde_json = { workspace = true } diff --git a/datafusion/spark/src/function/math/mod.rs b/datafusion/spark/src/function/math/mod.rs index 7f7d04e06b0b..be3f34d3323f 100644 --- a/datafusion/spark/src/function/math/mod.rs +++ b/datafusion/spark/src/function/math/mod.rs @@ -23,6 +23,7 @@ pub mod hex; pub mod modulus; pub mod negative; pub mod rint; +pub mod round; pub mod trigonometry; pub mod unhex; pub mod width_bucket; @@ -38,6 +39,7 @@ make_udf_function!(hex::SparkHex, hex); make_udf_function!(modulus::SparkMod, modulus); make_udf_function!(modulus::SparkPmod, pmod); make_udf_function!(rint::SparkRint, rint); +make_udf_function!(round::SparkRound, round); make_udf_function!(unhex::SparkUnhex, unhex); make_udf_function!(width_bucket::SparkWidthBucket, width_bucket); make_udf_function!(trigonometry::SparkCsc, csc); @@ -63,6 +65,11 @@ pub mod expr_fn { "Returns the double value that is closest in value to the argument and is equal to a mathematical integer.", arg1 )); + export_functions!(( + round, + "Rounds the value of expr to scale decimal places using HALF_UP rounding mode.", + arg1 arg2 + )); export_functions!((unhex, "Converts hexadecimal string to binary.", arg1)); export_functions!((width_bucket, "Returns the bucket number into which the value of this expression would fall after being evaluated.", arg1 arg2 arg3 arg4)); export_functions!((csc, "Returns the cosecant of expr.", arg1)); @@ -88,6 +95,7 @@ pub fn functions() -> Vec> { modulus(), pmod(), rint(), + round(), unhex(), width_bucket(), csc(), diff --git a/datafusion/spark/src/function/math/round.rs b/datafusion/spark/src/function/math/round.rs new file mode 100644 index 000000000000..c6d94f881def --- /dev/null +++ b/datafusion/spark/src/function/math/round.rs @@ -0,0 +1,659 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::*; +use arrow::datatypes::{ + ArrowNativeTypeOp, DataType, Decimal32Type, Decimal64Type, Decimal128Type, + Decimal256Type, Float16Type, Float32Type, Float64Type, Int8Type, Int16Type, + Int32Type, Int64Type, UInt8Type, UInt16Type, UInt32Type, UInt64Type, +}; +use datafusion_common::types::{ + NativeType, logical_float32, logical_float64, logical_int32, +}; +use datafusion_common::{Result, ScalarValue, exec_err, not_impl_err}; +use datafusion_expr::{ + Coercion, ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, + TypeSignatureClass, Volatility, +}; + +/// Spark-compatible `round` expression +/// +/// +/// Rounds the value of `expr` to `scale` decimal places using HALF_UP rounding mode. +/// Returns the same type as the input expression. +/// +/// - `round(expr)` rounds to 0 decimal places (default scale = 0) +/// - `round(expr, scale)` rounds to `scale` decimal places +/// - For integer types with negative scale: `round(25, -1)` → `30` +/// - Uses HALF_UP rounding: 2.5 → 3, -2.5 → -3 (away from zero) +/// +/// Supported types: Int8, Int16, Int32, Int64, UInt8, UInt16, UInt32, UInt64, +/// Float16, Float32, Float64, Decimal32, Decimal64, Decimal128, Decimal256 +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkRound { + signature: Signature, +} + +impl Default for SparkRound { + fn default() -> Self { + Self::new() + } +} + +impl SparkRound { + pub fn new() -> Self { + let decimal = Coercion::new_exact(TypeSignatureClass::Decimal); + let integer = Coercion::new_exact(TypeSignatureClass::Integer); + let decimal_places = Coercion::new_implicit( + TypeSignatureClass::Native(logical_int32()), + vec![TypeSignatureClass::Integer], + NativeType::Int32, + ); + let float32 = Coercion::new_exact(TypeSignatureClass::Native(logical_float32())); + let float64 = Coercion::new_implicit( + TypeSignatureClass::Native(logical_float64()), + vec![TypeSignatureClass::Numeric], + NativeType::Float64, + ); + Self { + signature: Signature::one_of( + vec![ + // round(decimal, scale) + TypeSignature::Coercible(vec![ + decimal.clone(), + decimal_places.clone(), + ]), + // round(decimal) + TypeSignature::Coercible(vec![decimal]), + // round(integer, scale) + TypeSignature::Coercible(vec![ + integer.clone(), + decimal_places.clone(), + ]), + // round(integer) + TypeSignature::Coercible(vec![integer]), + // round(float32, scale) + TypeSignature::Coercible(vec![ + float32.clone(), + decimal_places.clone(), + ]), + // round(float32) + TypeSignature::Coercible(vec![float32]), + // round(float64, scale) + TypeSignature::Coercible(vec![float64.clone(), decimal_places]), + // round(float64) + TypeSignature::Coercible(vec![float64]), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for SparkRound { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "round" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(arg_types[0].clone()) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + spark_round(&args.args, args.config_options.execution.enable_ansi_mode) + } +} + +/// Extract the scale (decimal places) from the second argument. +/// Returns `Some(0)` if no second argument is provided. +/// Returns `None` if the scale argument is NULL (Spark returns NULL for `round(expr, NULL)`). +fn get_scale(args: &[ColumnarValue]) -> Result> { + if args.len() < 2 { + return Ok(Some(0)); + } + + match &args[1] { + ColumnarValue::Scalar(ScalarValue::Int8(Some(v))) => Ok(Some(i32::from(*v))), + ColumnarValue::Scalar(ScalarValue::Int16(Some(v))) => Ok(Some(i32::from(*v))), + ColumnarValue::Scalar(ScalarValue::Int32(Some(v))) => Ok(Some(*v)), + ColumnarValue::Scalar(ScalarValue::Int64(Some(v))) => { + i32::try_from(*v).map(Some).map_err(|_| { + (exec_err!("round scale {v} is out of supported i32 range") + as Result<(), _>) + .unwrap_err() + }) + } + ColumnarValue::Scalar(ScalarValue::UInt8(Some(v))) => Ok(Some(i32::from(*v))), + ColumnarValue::Scalar(ScalarValue::UInt16(Some(v))) => Ok(Some(i32::from(*v))), + ColumnarValue::Scalar(ScalarValue::UInt32(Some(v))) => { + i32::try_from(*v).map(Some).map_err(|_| { + (exec_err!("round scale {v} is out of supported i32 range") + as Result<(), _>) + .unwrap_err() + }) + } + ColumnarValue::Scalar(ScalarValue::UInt64(Some(v))) => { + i32::try_from(*v).map(Some).map_err(|_| { + (exec_err!("round scale {v} is out of supported i32 range") + as Result<(), _>) + .unwrap_err() + }) + } + ColumnarValue::Scalar(sv) if sv.is_null() => Ok(None), + other => exec_err!("Unsupported type for round scale: {}", other.data_type()), + } +} + +/// Round a floating-point value to the given number of decimal places using +/// HALF_UP rounding mode (ties round away from zero). +/// +/// This matches Spark's `RoundBase` behaviour for `FloatType` / `DoubleType`, +/// which internally converts the value to `BigDecimal` and rounds with +/// `RoundingMode.HALF_UP`. +/// +/// # Arguments +/// * `value` – the floating-point number to round +/// * `scale` – number of decimal places to keep. +/// - `scale >= 0`: rounds to that many fractional digits +/// (e.g. `round_float(2.345, 2) == 2.35`) +/// - `scale < 0`: rounds to the left of the decimal point +/// (e.g. `round_float(125.0, -1) == 130.0`) +/// +/// # Examples +/// ```text +/// round_float(2.5, 0) → 3.0 // half rounds up +/// round_float(-2.5, 0) → -3.0 // half rounds away from zero +/// round_float(1.4, 0) → 1.0 +/// round_float(125.0, -1) → 130.0 +/// ``` +fn round_float(value: T, scale: i32) -> T { + if scale >= 0 { + let factor = T::from(10.0f64.powi(scale)).unwrap_or_else(T::infinity); + if factor.is_infinite() { + // Very large positive scale — value is already precise enough, return as-is + return value; + } + (value * factor).round() / factor + } else { + let factor = T::from(10.0f64.powi(-scale)).unwrap_or_else(T::infinity); + if factor.is_infinite() { + // Very large negative scale — any finite value rounds to 0 + return T::zero(); + } + (value / factor).round() * factor + } +} + +/// Round an integer value to the given scale using HALF_UP rounding mode. +/// +/// Only meaningful when `scale` is negative — a non-negative scale leaves +/// the integer unchanged because integers have no fractional part. +/// +/// This matches Spark's `RoundBase` behaviour for `ByteType`, `ShortType`, +/// `IntegerType`, and `LongType`, which round to the nearest power-of-ten +/// boundary and return the same integer type. +/// +/// In ANSI mode, overflow conditions return an error instead of wrapping. +/// +/// # Arguments +/// * `value` – the integer to round (widened to `i64` by callers) +/// * `scale` – rounding position relative to the ones digit. +/// - `scale >= 0`: returns `value` as-is +/// - `scale == -1`: rounds to the nearest 10 +/// - `scale == -2`: rounds to the nearest 100 +/// - If `10^|scale|` overflows `i64`, returns `0` +/// * `enable_ansi_mode` – when true, overflow returns an error +/// +/// # Examples +/// ```text +/// round_integer(25, -1, false) → Ok(30) +/// round_integer(-25, -1, false) → Ok(-30) +/// round_integer(123, -1, false) → Ok(120) +/// round_integer(150, -2, false) → Ok(200) +/// round_integer(42, 2, false) → Ok(42) // no-op for positive scale +/// round_integer(42, -10, false) → Ok(0) // factor overflows → 0 +/// ``` +fn round_integer(value: i64, scale: i32, enable_ansi_mode: bool) -> Result { + if scale >= 0 { + return Ok(value); + } + let abs_scale = (-scale) as u32; + let Some(factor) = 10_i64.checked_pow(abs_scale) else { + return Ok(0); + }; + let remainder = value % factor; + let threshold = factor / 2; + let result = if remainder >= threshold { + if enable_ansi_mode { + value + .checked_sub(remainder) + .and_then(|v| v.checked_add(factor)) + .ok_or_else(|| { + (exec_err!("Int64 overflow on round({value}, {scale})") + as Result<(), _>) + .unwrap_err() + })? + } else { + value.wrapping_sub(remainder).wrapping_add(factor) + } + } else if remainder <= -threshold { + if enable_ansi_mode { + value + .checked_sub(remainder) + .and_then(|v| v.checked_sub(factor)) + .ok_or_else(|| { + (exec_err!("Int64 overflow on round({value}, {scale})") + as Result<(), _>) + .unwrap_err() + })? + } else { + value.wrapping_sub(remainder).wrapping_sub(factor) + } + } else { + value - remainder + }; + Ok(result) +} + +// --------------------------------------------------------------------------- +// Decimal rounding using ArrowNativeTypeOp (HALF_UP) +// --------------------------------------------------------------------------- + +/// Round a decimal value represented as its unscaled integer using HALF_UP +/// rounding mode (ties round away from zero). +/// +/// This matches Spark's `RoundBase` behaviour for `DecimalType`, which calls +/// `BigDecimal.setScale(scale, RoundingMode.HALF_UP)`. +/// +/// Decimals are stored as `(unscaled_value, precision, scale)` where the real +/// value equals `unscaled_value * 10^(-scale)`. This function operates on the +/// unscaled integer directly: +/// +/// 1. Compute `diff = input_scale - decimal_places`. +/// If `diff <= 0` the requested precision is finer than (or equal to) the +/// stored scale, so nothing needs to be rounded — return as-is. +/// 2. Divide by `10^diff` to shift the rounding boundary into the ones digit. +/// 3. Inspect the remainder to decide whether to round up or down (HALF_UP). +/// 4. Multiply back by `10^diff` so the result is expressed at the original +/// `input_scale`. +/// +/// # Arguments +/// * `value` – unscaled decimal value +/// * `input_scale` – scale of the incoming decimal +/// * `decimal_places` – number of fractional digits to keep (may be negative) +/// +/// # Returns +/// The rounded unscaled value at the same `input_scale`, or an error +/// on overflow. +/// +/// # Examples +/// ```text +/// // 2.5 (unscaled 25, scale 1) rounded to 0 places → 3.0 (unscaled 30) +/// round_decimal(25_i128, 1, 0) → Ok(30) +/// +/// // 2.345 (unscaled 2345, scale 3) rounded to 2 places → 2.350 (unscaled 2350) +/// round_decimal(2345_i128, 3, 2) → Ok(2350) +/// ``` +fn round_decimal( + value: V, + input_scale: i8, + decimal_places: i32, +) -> Result { + let diff = i64::from(input_scale) - i64::from(decimal_places); + if diff <= 0 { + // Nothing to round – the requested precision is finer than (or equal to) the + // stored scale. + return Ok(value); + } + + let diff = diff as u32; + + let one = V::ONE; + let two = V::from_usize(2).ok_or_else(|| { + (exec_err!("Internal error: could not create constant 2") as Result<(), _>) + .unwrap_err() + })?; + let ten = V::from_usize(10).ok_or_else(|| { + (exec_err!("Internal error: could not create constant 10") as Result<(), _>) + .unwrap_err() + })?; + + let Ok(factor) = ten.pow_checked(diff) else { + // 10^diff overflows the decimal type — the rounding position is beyond + // the representable range, so any value rounds to 0. + // This matches Spark's BigDecimal.setScale behavior where rounding to a + // scale far beyond the number's magnitude yields 0. + return Ok(V::ZERO); + }; + + let mut quotient = value.div_wrapping(factor); + let remainder = value.mod_wrapping(factor); + + // HALF_UP: round away from zero when remainder is exactly half + let threshold = factor.div_wrapping(two); + if remainder >= threshold { + quotient = quotient.add_checked(one).map_err(|_| { + (exec_err!("Overflow while rounding decimal") as Result<(), _>).unwrap_err() + })?; + } else if remainder <= threshold.neg_wrapping() { + quotient = quotient.sub_checked(one).map_err(|_| { + (exec_err!("Overflow while rounding decimal") as Result<(), _>).unwrap_err() + })?; + } + + // Re-scale the quotient back to `input_scale` so the returned unscaled integer is + // at the original scale. `factor` is already `10^diff` which is exactly the shift + // we need. + quotient.mul_checked(factor).map_err(|_| { + (exec_err!("Overflow while rounding decimal") as Result<(), _>).unwrap_err() + }) +} + +// --------------------------------------------------------------------------- +// Macros for array dispatch +// --------------------------------------------------------------------------- + +macro_rules! impl_integer_array_round { + ($array:expr, $arrow_type:ty, $scale:expr, $enable_ansi_mode:expr) => {{ + let array = $array.as_primitive::<$arrow_type>(); + type Native = <$arrow_type as arrow::datatypes::ArrowPrimitiveType>::Native; + let result: PrimitiveArray<$arrow_type> = if $enable_ansi_mode { + array.try_unary(|x| { + let v = round_integer(x as i64, $scale, true)?; + Native::try_from(v).map_err(|_| { + (exec_err!( + "{} overflow on round({x}, {})", + stringify!($arrow_type), + $scale + ) as Result<(), _>) + .unwrap_err() + }) + })? + } else { + array.unary(|x| round_integer(x as i64, $scale, false).unwrap() as Native) + }; + Ok(ColumnarValue::Array(Arc::new(result))) + }}; +} + +macro_rules! impl_float_array_round { + ($array:expr, $arrow_type:ty, $scale:expr) => {{ + let array = $array.as_primitive::<$arrow_type>(); + let result: PrimitiveArray<$arrow_type> = array.unary(|x| round_float(x, $scale)); + Ok(ColumnarValue::Array(Arc::new(result))) + }}; +} + +macro_rules! impl_decimal_array_round { + ($array:expr, $arrow_type:ty, $input_scale:expr, $scale:expr) => {{ + let array = $array.as_primitive::<$arrow_type>(); + let result: PrimitiveArray<$arrow_type> = array + .try_unary(|x| round_decimal(x, $input_scale, $scale))? + .with_data_type($array.data_type().clone()); + Ok(ColumnarValue::Array(Arc::new(result))) + }}; +} + +// --------------------------------------------------------------------------- +// Core dispatch +// --------------------------------------------------------------------------- + +fn spark_round(args: &[ColumnarValue], enable_ansi_mode: bool) -> Result { + if args.is_empty() || args.len() > 2 { + return exec_err!("round requires 1 or 2 arguments, got {}", args.len()); + } + + let scale = match get_scale(args)? { + Some(s) => s, + None => { + // NULL scale → return NULL with the same data type as the first argument + return Ok(ColumnarValue::Scalar(ScalarValue::try_from( + args[0].data_type(), + )?)); + } + }; + + match &args[0] { + ColumnarValue::Array(array) => match array.data_type() { + DataType::Null => Ok(args[0].clone()), + + // Integer types + DataType::Int8 => { + impl_integer_array_round!(array, Int8Type, scale, enable_ansi_mode) + } + DataType::Int16 => { + impl_integer_array_round!(array, Int16Type, scale, enable_ansi_mode) + } + DataType::Int32 => { + impl_integer_array_round!(array, Int32Type, scale, enable_ansi_mode) + } + DataType::Int64 => { + impl_integer_array_round!(array, Int64Type, scale, enable_ansi_mode) + } + + // Unsigned integer types + DataType::UInt8 => { + impl_integer_array_round!(array, UInt8Type, scale, enable_ansi_mode) + } + DataType::UInt16 => { + impl_integer_array_round!(array, UInt16Type, scale, enable_ansi_mode) + } + DataType::UInt32 => { + impl_integer_array_round!(array, UInt32Type, scale, enable_ansi_mode) + } + DataType::UInt64 => { + let array = array.as_primitive::(); + let result: PrimitiveArray = array.try_unary(|x| { + let v_i64 = i64::try_from(x).map_err(|_| { + (exec_err!( + "round: UInt64 value {x} exceeds i64::MAX and cannot be rounded" + ) as Result<(), _>) + .unwrap_err() + })?; + round_integer(v_i64, scale, enable_ansi_mode) + .map(|v| v as u64) + })?; + Ok(ColumnarValue::Array(Arc::new(result))) + } + + // Float types + DataType::Float16 => impl_float_array_round!(array, Float16Type, scale), + DataType::Float32 => impl_float_array_round!(array, Float32Type, scale), + DataType::Float64 => impl_float_array_round!(array, Float64Type, scale), + + // Decimal types + DataType::Decimal32(_, input_scale) => { + impl_decimal_array_round!(array, Decimal32Type, *input_scale, scale) + } + DataType::Decimal64(_, input_scale) => { + impl_decimal_array_round!(array, Decimal64Type, *input_scale, scale) + } + DataType::Decimal128(_, input_scale) => { + impl_decimal_array_round!(array, Decimal128Type, *input_scale, scale) + } + DataType::Decimal256(_, input_scale) => { + impl_decimal_array_round!(array, Decimal256Type, *input_scale, scale) + } + + dt => not_impl_err!("Unsupported data type for Spark round(): {dt}"), + }, + + ColumnarValue::Scalar(sv) => match sv { + ScalarValue::Null => Ok(args[0].clone()), + _ if sv.is_null() => Ok(args[0].clone()), + + // Integer scalars + ScalarValue::Int8(Some(v)) => { + let r = round_integer(i64::from(*v), scale, enable_ansi_mode)?; + let result = if enable_ansi_mode { + i8::try_from(r).map_err(|_| { + (exec_err!("Int8 overflow on round({v}, {scale})") + as Result<(), _>) + .unwrap_err() + })? + } else { + r as i8 + }; + Ok(ColumnarValue::Scalar(ScalarValue::Int8(Some(result)))) + } + ScalarValue::Int16(Some(v)) => { + let r = round_integer(i64::from(*v), scale, enable_ansi_mode)?; + let result = if enable_ansi_mode { + i16::try_from(r).map_err(|_| { + (exec_err!("Int16 overflow on round({v}, {scale})") + as Result<(), _>) + .unwrap_err() + })? + } else { + r as i16 + }; + Ok(ColumnarValue::Scalar(ScalarValue::Int16(Some(result)))) + } + ScalarValue::Int32(Some(v)) => { + let r = round_integer(i64::from(*v), scale, enable_ansi_mode)?; + let result = if enable_ansi_mode { + i32::try_from(r).map_err(|_| { + (exec_err!("Int32 overflow on round({v}, {scale})") + as Result<(), _>) + .unwrap_err() + })? + } else { + r as i32 + }; + Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(result)))) + } + ScalarValue::Int64(Some(v)) => { + let result = round_integer(*v, scale, enable_ansi_mode)?; + Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some(result)))) + } + + // Unsigned integer scalars + ScalarValue::UInt8(Some(v)) => { + let r = round_integer(i64::from(*v), scale, enable_ansi_mode)?; + let result = if enable_ansi_mode { + u8::try_from(r).map_err(|_| { + (exec_err!("UInt8 overflow on round({v}, {scale})") + as Result<(), _>) + .unwrap_err() + })? + } else { + r as u8 + }; + Ok(ColumnarValue::Scalar(ScalarValue::UInt8(Some(result)))) + } + ScalarValue::UInt16(Some(v)) => { + let r = round_integer(i64::from(*v), scale, enable_ansi_mode)?; + let result = if enable_ansi_mode { + u16::try_from(r).map_err(|_| { + (exec_err!("UInt16 overflow on round({v}, {scale})") + as Result<(), _>) + .unwrap_err() + })? + } else { + r as u16 + }; + Ok(ColumnarValue::Scalar(ScalarValue::UInt16(Some(result)))) + } + ScalarValue::UInt32(Some(v)) => { + let r = round_integer(i64::from(*v), scale, enable_ansi_mode)?; + let result = if enable_ansi_mode { + u32::try_from(r).map_err(|_| { + (exec_err!("UInt32 overflow on round({v}, {scale})") + as Result<(), _>) + .unwrap_err() + })? + } else { + r as u32 + }; + Ok(ColumnarValue::Scalar(ScalarValue::UInt32(Some(result)))) + } + ScalarValue::UInt64(Some(v)) => { + let v_i64 = i64::try_from(*v).map_err(|_| { + (exec_err!( + "round: UInt64 value {v} exceeds i64::MAX and cannot be rounded" + ) as Result<(), _>) + .unwrap_err() + })?; + let result = round_integer(v_i64, scale, enable_ansi_mode)?; + Ok(ColumnarValue::Scalar(ScalarValue::UInt64(Some( + result as u64, + )))) + } + + // Float scalars + ScalarValue::Float16(Some(v)) => { + let result = round_float(*v, scale); + Ok(ColumnarValue::Scalar(ScalarValue::Float16(Some(result)))) + } + ScalarValue::Float32(Some(v)) => { + let result = round_float(*v, scale); + Ok(ColumnarValue::Scalar(ScalarValue::Float32(Some(result)))) + } + ScalarValue::Float64(Some(v)) => { + let result = round_float(*v, scale); + Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some(result)))) + } + + // Decimal scalars + ScalarValue::Decimal32(Some(v), precision, input_scale) => { + let rounded = round_decimal(*v, *input_scale, scale)?; + Ok(ColumnarValue::Scalar(ScalarValue::Decimal32( + Some(rounded), + *precision, + *input_scale, + ))) + } + ScalarValue::Decimal64(Some(v), precision, input_scale) => { + let rounded = round_decimal(*v, *input_scale, scale)?; + Ok(ColumnarValue::Scalar(ScalarValue::Decimal64( + Some(rounded), + *precision, + *input_scale, + ))) + } + ScalarValue::Decimal128(Some(v), precision, input_scale) => { + let rounded = round_decimal(*v, *input_scale, scale)?; + Ok(ColumnarValue::Scalar(ScalarValue::Decimal128( + Some(rounded), + *precision, + *input_scale, + ))) + } + ScalarValue::Decimal256(Some(v), precision, input_scale) => { + let rounded = round_decimal(*v, *input_scale, scale)?; + Ok(ColumnarValue::Scalar(ScalarValue::Decimal256( + Some(rounded), + *precision, + *input_scale, + ))) + } + + dt => not_impl_err!("Unsupported data type for Spark round(): {dt}"), + }, + } +} diff --git a/datafusion/sqllogictest/test_files/spark/math/round.slt b/datafusion/sqllogictest/test_files/spark/math/round.slt index bc1f6b72247a..91c5bdf0506f 100644 --- a/datafusion/sqllogictest/test_files/spark/math/round.slt +++ b/datafusion/sqllogictest/test_files/spark/math/round.slt @@ -15,13 +15,567 @@ # specific language governing permissions and limitations # under the License. -# This file was originally created by a porting script from: -# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function # This file is part of the implementation of the datafusion-spark function library. # For more information, please see: # https://github.com/apache/datafusion/issues/15914 -## Original Query: SELECT round(2.5, 0); -## PySpark 3.5.5 Result: {'round(2.5, 0)': Decimal('3'), 'typeof(round(2.5, 0))': 'decimal(2,0)', 'typeof(2.5)': 'decimal(2,1)', 'typeof(0)': 'int'} -#query -#SELECT round(2.5::decimal(2,1), 0::int); +# ------------------------------------------------------------------- +# Float / Double tests (HALF_UP rounding: .5 rounds away from zero) +# ------------------------------------------------------------------- + +# round(double) default scale = 0 +query R +SELECT round(2.5::double); +---- +3 + +query R +SELECT round(3.5::double); +---- +4 + +query R +SELECT round(-2.5::double); +---- +-3 + +query R +SELECT round(-3.5::double); +---- +-4 + +query R +SELECT round(1.4::double); +---- +1 + +query R +SELECT round(1.6::double); +---- +2 + +# round(double, scale) +query R +SELECT round(2.345::double, 2::int); +---- +2.35 + +query R +SELECT round(2.355::double, 2::int); +---- +2.36 + +query R +SELECT round(123.456::double, 1::int); +---- +123.5 + +# round(float) +query R +SELECT round(arrow_cast(2.5, 'Float32')); +---- +3 + +query R +SELECT round(arrow_cast(-2.5, 'Float32')); +---- +-3 + +# round(float) with scale +query R +SELECT round(arrow_cast(2.345, 'Float32'), 2::int); +---- +2.35 + +# round(float32) with negative scale +query R +SELECT round(arrow_cast(125.0, 'Float32'), -1::int); +---- +130 + +# ------------------------------------------------------------------- +# Float16 tests +# ------------------------------------------------------------------- + +# round(float16) default scale = 0 +query R +SELECT round(arrow_cast(2.5, 'Float16')); +---- +3 + +query R +SELECT round(arrow_cast(-2.5, 'Float16')); +---- +-3 + +# round(float16) with negative scale +query R +SELECT round(arrow_cast(125, 'Float16'), -1::int); +---- +130 + +# round(double) with negative scale +query R +SELECT round(125.0::double, -1::int); +---- +130 + +query R +SELECT round(125.0::double, -2::int); +---- +100 + +query R +SELECT round(0.0::double); +---- +0 + +# round(double) with Infinity and NaN +query R +SELECT round('Infinity'::double); +---- +Infinity + +query R +SELECT round('-Infinity'::double); +---- +-Infinity + +query R +SELECT round('NaN'::double); +---- +NaN + +query R +SELECT round('Infinity'::double, 2::int); +---- +Infinity + +query R +SELECT round('NaN'::double, 2::int); +---- +NaN + +# round(double) with extreme negative scale — should return 0, not NaN +query R +SELECT round(42.0::double, -400::int); +---- +0 + +# round(double) with extreme positive scale — should return value as-is +query R +SELECT round(2.5::double, 400::int); +---- +2.5 + +# ------------------------------------------------------------------- +# Integer tests (negative scale rounds to tens, hundreds, etc.) +# ------------------------------------------------------------------- + +# round(int, -1) — round to nearest 10 +query I +SELECT round(25::int, -1::int); +---- +30 + +query I +SELECT round(24::int, -1::int); +---- +20 + +query I +SELECT round(-25::int, -1::int); +---- +-30 + +query I +SELECT round(123::int, -1::int); +---- +120 + +# round(int, -2) — round to nearest 100 +query I +SELECT round(150::int, -2::int); +---- +200 + +query I +SELECT round(149::int, -2::int); +---- +100 + +# round(int, positive scale) — no-op for integers +query I +SELECT round(42::int, 2::int); +---- +42 + +# round(int) default scale = 0 — returns unchanged +query I +SELECT round(42::int); +---- +42 + +# round(bigint, -1) +query I +SELECT round(25::bigint, -1::int); +---- +30 + +# round(smallint, -1) +query I +SELECT round(25::smallint, -1::int); +---- +30 + +# round(tinyint, -1) +query I +SELECT round(arrow_cast(25, 'Int8'), -1::int); +---- +30 + +# round(int) with very large negative scale — should return 0 +query I +SELECT round(42::int, -10::int); +---- +0 + +# ------------------------------------------------------------------- +# Unsigned integer tests +# ------------------------------------------------------------------- + +# round(uint8, -1) +query I +SELECT round(arrow_cast(25, 'UInt8'), -1::int); +---- +30 + +# round(uint16, -1) +query I +SELECT round(arrow_cast(25, 'UInt16'), -1::int); +---- +30 + +# round(uint32, -1) +query I +SELECT round(arrow_cast(150, 'UInt32'), -2::int); +---- +200 + +# round(uint64, -1) +query I +SELECT round(arrow_cast(25, 'UInt64'), -1::int); +---- +30 + +# round(uint32, positive scale) — no-op for integers +query I +SELECT round(arrow_cast(42, 'UInt32'), 2::int); +---- +42 + +# ------------------------------------------------------------------- +# Decimal tests (HALF_UP rounding) +# ------------------------------------------------------------------- + +# --- Decimal32 --- + +# round(decimal32, 0) — round to integer +query ? +SELECT round(arrow_cast(2.5, 'Decimal32(9, 1)'), 0::int); +---- +3.0 + +query ? +SELECT round(arrow_cast(-2.5, 'Decimal32(9, 1)'), 0::int); +---- +-3.0 + +# round(decimal32, 2) +query ? +SELECT round(arrow_cast(2.345, 'Decimal32(9, 3)'), 2::int); +---- +2.350 + +# round(decimal32) default scale = 0 +query ? +SELECT round(arrow_cast(3.5, 'Decimal32(9, 1)')); +---- +4.0 + +# --- Decimal64 --- + +# round(decimal64, 0) — round to integer +query ? +SELECT round(arrow_cast(2.5, 'Decimal64(18, 1)'), 0::int); +---- +3.0 + +query ? +SELECT round(arrow_cast(-2.5, 'Decimal64(18, 1)'), 0::int); +---- +-3.0 + +# round(decimal64, 2) +query ? +SELECT round(arrow_cast(2.345, 'Decimal64(18, 3)'), 2::int); +---- +2.350 + +# round(decimal64) default scale = 0 +query ? +SELECT round(arrow_cast(3.5, 'Decimal64(18, 1)')); +---- +4.0 + +# --- Decimal128 --- + +# round(decimal, 0) — round to integer +query R +SELECT round(2.5::decimal(2,1), 0::int); +---- +3 + +query R +SELECT round(3.5::decimal(2,1), 0::int); +---- +4 + +query R +SELECT round(-2.5::decimal(2,1), 0::int); +---- +-3 + +# round(decimal) default scale = 0 +query R +SELECT round(2.5::decimal(2,1)); +---- +3 + +# round(decimal, 2) — keep 2 decimal places +query R +SELECT round(2.345::decimal(10,3), 2::int); +---- +2.35 + +query R +SELECT round(2.355::decimal(10,3), 2::int); +---- +2.36 + +# round(decimal, scale larger than input scale) — no change +query R +SELECT round(2.5::decimal(2,1), 5::int); +---- +2.5 + +# round(decimal, 1) +query R +SELECT round(123.456::decimal(10,3), 1::int); +---- +123.5 + +# round(decimal, negative scale) — round to tens +query R +SELECT round(125.0::decimal(10,1), -1::int); +---- +130 + +# round(decimal, extreme negative scale) — should return 0, not error +query R +SELECT round(2.5::decimal(10,1), -400::int); +---- +0 + +# --- Decimal256 --- + +# round(decimal256, 0) — round to integer +query R +SELECT round(arrow_cast(2.5, 'Decimal256(38, 1)'), 0::int); +---- +3 + +query R +SELECT round(arrow_cast(-2.5, 'Decimal256(38, 1)'), 0::int); +---- +-3 + +# round(decimal256, 2) +query R +SELECT round(arrow_cast(2.345, 'Decimal256(38, 3)'), 2::int); +---- +2.35 + +# round(decimal256) default scale = 0 +query R +SELECT round(arrow_cast(3.5, 'Decimal256(38, 1)')); +---- +4 + +# ------------------------------------------------------------------- +# NULL handling +# ------------------------------------------------------------------- + +query I +SELECT round(NULL::int); +---- +NULL + +query R +SELECT round(NULL::double); +---- +NULL + +query R +SELECT round(NULL::decimal(10,2)); +---- +NULL + +# round with NULL scale — Spark returns NULL +query I +SELECT round(42::int, NULL::int); +---- +NULL + +query R +SELECT round(2.5::double, NULL::int); +---- +NULL + +# ------------------------------------------------------------------- +# Column-based tests +# ------------------------------------------------------------------- + +statement ok +CREATE TABLE test_round (id int, int_val int, float_val double, dec_val decimal(10,3)) AS VALUES + (1, 25, 2.5, 2.345), + (2, 35, 3.5, 3.555), + (3, -25, -2.5, -2.345), + (4, 123, 1.4, 1.005), + (5, NULL, NULL, NULL); + +query IIRR rowsort +SELECT id, round(int_val, -1::int), round(float_val), round(dec_val, 2::int) FROM test_round; +---- +1 30 3 2.35 +2 40 4 3.56 +3 -30 -3 -2.35 +4 120 1 1.01 +5 NULL NULL NULL + +statement ok +DROP TABLE test_round; + +# ------------------------------------------------------------------- +# Expression tests +# ------------------------------------------------------------------- + +query R +SELECT round(3.14159::double, 2::int) + 1.0; +---- +4.140000000000001 + +# ------------------------------------------------------------------- +# Non-ANSI wrapping behavior +# When ANSI mode is off, integer overflow wraps silently. +# ------------------------------------------------------------------- + +# round(127::tinyint, -1) → 130, wraps as i8 → -126 +query I +SELECT round(arrow_cast(127, 'Int8'), -1::int); +---- +-126 + +# round(32767::smallint, -1) → 32770, wraps as i16 → -32766 +query I +SELECT round(arrow_cast(32767, 'Int16'), -1::int); +---- +-32766 + +# round(2147483647::int, -1) → 2147483650, wraps as i32 → -2147483646 +query I +SELECT round(2147483647::int, -1::int); +---- +-2147483646 + +# round(i64::MAX, -1) wraps as i64 → -9223372036854775806 +query I +SELECT round(9223372036854775807::bigint, -1::int); +---- +-9223372036854775806 + +# ------------------------------------------------------------------- +# ANSI mode tests: overflow detection for integer rounding +# ------------------------------------------------------------------- + +statement ok +set datafusion.execution.enable_ansi_mode = true; + +# ANSI mode: normal rounding should still work +query I +SELECT round(25::int, -1::int); +---- +30 + +query I +SELECT round(-25::int, -1::int); +---- +-30 + +query I +SELECT round(150::int, -2::int); +---- +200 + +# ANSI mode: positive scale on integers — no-op, no overflow +query I +SELECT round(42::int, 2::int); +---- +42 + +# ANSI mode: floats and decimals should work normally +query R +SELECT round(2.5::double); +---- +3 + +query R +SELECT round(2.5::decimal(2,1), 0::int); +---- +3 + +# ANSI mode: integer overflow should error +query error DataFusion error: Execution error: Int64 overflow on round +SELECT round(9223372036854775807::bigint, -1::int); + +# ANSI mode: Int32 overflow should error +query error DataFusion error: Execution error: Int32 overflow on round +SELECT round(2147483647::int, -1::int); + +# ANSI mode: Int16 overflow should error +query error DataFusion error: Execution error: Int16 overflow on round +SELECT round(arrow_cast(32767, 'Int16'), -1::int); + +# ANSI mode: Int8 overflow should error +query error DataFusion error: Execution error: Int8 overflow on round +SELECT round(arrow_cast(127, 'Int8'), -1::int); + +# Reset ANSI mode +statement ok +set datafusion.execution.enable_ansi_mode = false; + +# ------------------------------------------------------------------- +# Negative tests: unsupported data types +# ------------------------------------------------------------------- + +# round(string) should fail +query error Error during planning: Internal error: Function 'round' failed to match any signature +SELECT round('hello'::varchar); + +# round(boolean) should fail +query error Error during planning: Internal error: Function 'round' failed to match any signature +SELECT round(true); + +# round(timestamp) should fail +query error Error during planning: Internal error: Function 'round' failed to match any signature +SELECT round('2023-01-01T00:00:00'::timestamp);