diff --git a/spark/src/main/scala/org/apache/comet/serde/arithmetic.scala b/spark/src/main/scala/org/apache/comet/serde/arithmetic.scala index 85574fbab7..259da5d26f 100644 --- a/spark/src/main/scala/org/apache/comet/serde/arithmetic.scala +++ b/spark/src/main/scala/org/apache/comet/serde/arithmetic.scala @@ -303,23 +303,14 @@ object CometRound extends CometExpressionSerde[Round] { exprToProtoInternal(Literal(null), inputs, binding) case _: ByteType | ShortType | IntegerType | LongType if _scale >= 0 => childExpr // _scale(I.e. decimal place) >= 0 is a no-op for integer types in Spark - case _: FloatType | DoubleType => - // We cannot properly match with the Spark behavior for floating-point numbers. - // Spark uses BigDecimal for rounding float/double, and BigDecimal fist converts a - // double to string internally in order to create its own internal representation. - // The problem is BigDecimal uses java.lang.Double.toString() and it has complicated - // rounding algorithm. E.g. -5.81855622136895E8 is actually - // -581855622.13689494132995605468750. Note the 5th fractional digit is 4 instead of - // 5. Java(Scala)'s toString() rounds it up to -581855622.136895. This makes a - // difference when rounding at 5th digit, I.e. round(-5.81855622136895E8, 5) should be - // -5.818556221369E8, instead of -5.8185562213689E8. There is also an example that - // toString() does NOT round up. 6.1317116247283497E18 is 6131711624728349696. It can - // be rounded up to 6.13171162472835E18 that still represents the same double number. - // I.e. 6.13171162472835E18 == 6.1317116247283497E18. However, toString() does not. - // That results in round(6.1317116247283497E18, -5) == 6.1317116247282995E18 instead - // of 6.1317116247283999E18. - withInfo(r, "Comet does not support Spark's BigDecimal rounding") - None + case _: FloatType => + // Spark rounds floats by widening to double, building a BigDecimal via + // java.lang.Double.toString, applying HALF_UP, and narrowing back. 
The toString
+          // algorithm differs between JDKs (notably 17 vs 21), so a native implementation
+          // can't match every JDK. Delegate to a JVM UDF that runs on the executor's JDK.
+          convertViaJvmUdf(r, "org.apache.comet.udf.RoundFloatUDF", _scale, inputs, binding)
+        case _: DoubleType =>
+          convertViaJvmUdf(r, "org.apache.comet.udf.RoundDoubleUDF", _scale, inputs, binding)
         case _ =>
           // `scale` must be Int64 type in DataFusion
           val scaleExpr = exprToProtoInternal(Literal(_scale.toLong, LongType), inputs, binding)
@@ -334,6 +325,51 @@ object CometRound extends CometExpressionSerde[Round] {
     }
 
   }
+
+  /**
+   * Serializes `r` as a `JvmScalarUdf` expression that calls `className` with two arguments:
+   * the child value and `scale` as an `IntegerType` literal.
+   *
+   * @param r the `Round` expression being converted
+   * @param className fully-qualified class name stored in the UDF message
+   * @param scale number of decimal places, passed to the UDF as its second argument
+   * @param inputs forwarded to `exprToProtoInternal` when serializing the arguments
+   * @param binding forwarded to `exprToProtoInternal` when serializing the arguments
+   * @return the serialized expression, or `None` after tagging `r` via `withInfo` when the
+   *   arguments or the return type cannot be serialized
+   */
+  private def convertViaJvmUdf(
+      r: Round,
+      className: String,
+      scale: Int,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
+    val valueProto = exprToProtoInternal(r.child, inputs, binding)
+    val scaleProto = exprToProtoInternal(Literal(scale, IntegerType), inputs, binding)
+    // Nested matches keep every failure path a plain expression, so no `return` (and hence
+    // no non-local return out of a lambda) is needed.
+    (valueProto, scaleProto) match {
+      case (Some(valueArg), Some(scaleArg)) =>
+        serializeDataType(r.dataType) match {
+          case Some(returnType) =>
+            val udf = ExprOuterClass.JvmScalarUdf
+              .newBuilder()
+              .setClassName(className)
+              .addArgs(valueArg)
+              .addArgs(scaleArg)
+              .setReturnType(returnType)
+              .setReturnNullable(r.nullable)
+              .build()
+            Some(ExprOuterClass.Expr.newBuilder().setJvmScalarUdf(udf).build())
+          case None =>
+            withInfo(r, s"Unsupported return type ${r.dataType} for Round JVM UDF")
+            None
+        }
+      case _ =>
+        withInfo(r, r.child)
+        None
+    }
+  }
 }
 
 object CometUnaryMinus extends CometExpressionSerde[UnaryMinus] {
diff --git a/spark/src/main/scala/org/apache/comet/udf/RoundDoubleUDF.scala b/spark/src/main/scala/org/apache/comet/udf/RoundDoubleUDF.scala
new file mode 100644
index 0000000000..217c74199b
--- /dev/null
+++ b/spark/src/main/scala/org/apache/comet/udf/RoundDoubleUDF.scala
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.udf + +import org.apache.arrow.vector.{Float8Vector, IntVector, ValueVector} + +import org.apache.comet.CometArrowAllocator + +/** + * `round(double, scale)` implemented by delegating to Scala's `BigDecimal(d)`, which goes through + * `java.lang.Double.toString` before applying the requested scale. This matches Spark's + * `RoundBase` for `DoubleType` exactly on whatever JDK the executor is running, so output stays + * consistent across Java 17 / 21 even though the underlying `Double.toString` algorithm differs. + * + * Inputs: + * - inputs(0): Float8Vector value column (length = numRows, or length 1 when literal-folded) + * - inputs(1): IntVector scale, length-1 scalar (serde guarantees this) + * + * Output: Float8Vector, length numRows. 
+ */
+class RoundDoubleUDF extends CometUDF {
+
+  /**
+   * Rounds the values in `inputs(0)` to the scale stored at index 0 of `inputs(1)`.
+   *
+   * @param inputs exactly two vectors: a `Float8Vector` of values and an `IntVector` scale
+   * @param numRows number of rows written to the result
+   * @return a new `Float8Vector` with `numRows` entries; null inputs produce null outputs
+   */
+  override def evaluate(inputs: Array[ValueVector], numRows: Int): ValueVector = {
+    require(inputs.length == 2, s"RoundDoubleUDF expects 2 inputs, got ${inputs.length}")
+    val values = inputs(0).asInstanceOf[Float8Vector]
+    val scaleVec = inputs(1).asInstanceOf[IntVector]
+    require(
+      scaleVec.getValueCount >= 1 && !scaleVec.isNull(0),
+      "RoundDoubleUDF requires a non-null scalar scale")
+    val scale = scaleVec.get(0)
+
+    val out = new Float8Vector("round_double", CometArrowAllocator)
+    // `out` is handed to the caller only when everything below succeeds. On any failure
+    // (allocation, rounding) it is closed here; otherwise nothing would ever release it.
+    var succeeded = false
+    try {
+      out.allocateNew(numRows)
+
+      // A single-element value vector with numRows != 1 is treated as a scalar to broadcast.
+      val valueIsScalar = values.getValueCount == 1 && numRows != 1
+      if (valueIsScalar) {
+        if (values.isNull(0)) {
+          var i = 0
+          while (i < numRows) { out.setNull(i); i += 1 }
+        } else {
+          // Round once and reuse the result for every row.
+          val rounded = RoundDoubleUDF.roundDouble(values.get(0), scale)
+          var i = 0
+          while (i < numRows) { out.set(i, rounded); i += 1 }
+        }
+      } else {
+        var i = 0
+        while (i < numRows) {
+          if (values.isNull(i)) {
+            out.setNull(i)
+          } else {
+            out.set(i, RoundDoubleUDF.roundDouble(values.get(i), scale))
+          }
+          i += 1
+        }
+      }
+      out.setValueCount(numRows)
+      succeeded = true
+      out
+    } finally {
+      if (!succeeded) out.close()
+    }
+  }
+}
+
+object RoundDoubleUDF {
+
+  /**
+   * Rounds `v` to `scale` decimal places using HALF_UP on `BigDecimal(v)`. NaN and infinite
+   * values are returned unchanged.
+   */
+  def roundDouble(v: Double, scale: Int): Double = {
+    if (v.isNaN || v.isInfinite) v
+    else BigDecimal(v).setScale(scale, BigDecimal.RoundingMode.HALF_UP).doubleValue
+  }
+}
diff --git a/spark/src/main/scala/org/apache/comet/udf/RoundFloatUDF.scala b/spark/src/main/scala/org/apache/comet/udf/RoundFloatUDF.scala
new file mode 100644
index 0000000000..37c38d5b0a
--- /dev/null
+++ b/spark/src/main/scala/org/apache/comet/udf/RoundFloatUDF.scala
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.udf + +import org.apache.arrow.vector.{Float4Vector, IntVector, ValueVector} + +import org.apache.comet.CometArrowAllocator + +/** + * `round(float, scale)` implemented to mirror Spark's `RoundBase` for `FloatType`: widen to + * double, build a `BigDecimal` via `java.lang.Double.toString`, apply HALF_UP at the requested + * scale, then narrow back to float. The widening before BigDecimal construction is intentional: + * it matches Spark and produces the same result string the JDK uses for the value. + * + * Inputs: + * - inputs(0): Float4Vector value column (length = numRows, or length 1 when literal-folded) + * - inputs(1): IntVector scale, length-1 scalar (serde guarantees this) + * + * Output: Float4Vector, length numRows. 
+ */
+class RoundFloatUDF extends CometUDF {
+
+  /**
+   * Rounds the values in `inputs(0)` to the scale stored at index 0 of `inputs(1)`.
+   *
+   * @param inputs exactly two vectors: a `Float4Vector` of values and an `IntVector` scale
+   * @param numRows number of rows written to the result
+   * @return a new `Float4Vector` with `numRows` entries; null inputs produce null outputs
+   */
+  override def evaluate(inputs: Array[ValueVector], numRows: Int): ValueVector = {
+    require(inputs.length == 2, s"RoundFloatUDF expects 2 inputs, got ${inputs.length}")
+    val values = inputs(0).asInstanceOf[Float4Vector]
+    val scaleVec = inputs(1).asInstanceOf[IntVector]
+    require(
+      scaleVec.getValueCount >= 1 && !scaleVec.isNull(0),
+      "RoundFloatUDF requires a non-null scalar scale")
+    val scale = scaleVec.get(0)
+
+    val out = new Float4Vector("round_float", CometArrowAllocator)
+    // `out` is handed to the caller only when everything below succeeds. On any failure
+    // (allocation, rounding) it is closed here; otherwise nothing would ever release it.
+    var succeeded = false
+    try {
+      out.allocateNew(numRows)
+
+      // A single-element value vector with numRows != 1 is treated as a scalar to broadcast.
+      val valueIsScalar = values.getValueCount == 1 && numRows != 1
+      if (valueIsScalar) {
+        if (values.isNull(0)) {
+          var i = 0
+          while (i < numRows) { out.setNull(i); i += 1 }
+        } else {
+          // Round once and reuse the result for every row.
+          val rounded = RoundFloatUDF.roundFloat(values.get(0), scale)
+          var i = 0
+          while (i < numRows) { out.set(i, rounded); i += 1 }
+        }
+      } else {
+        var i = 0
+        while (i < numRows) {
+          if (values.isNull(i)) {
+            out.setNull(i)
+          } else {
+            out.set(i, RoundFloatUDF.roundFloat(values.get(i), scale))
+          }
+          i += 1
+        }
+      }
+      out.setValueCount(numRows)
+      succeeded = true
+      out
+    } finally {
+      if (!succeeded) out.close()
+    }
+  }
+}
+
+object RoundFloatUDF {
+
+  /**
+   * Widens `v` to double, rounds it to `scale` decimal places using HALF_UP on `BigDecimal`,
+   * and narrows the result back to float. NaN and infinite values are returned unchanged.
+   */
+  def roundFloat(v: Float, scale: Int): Float = {
+    if (v.isNaN || v.isInfinite) v
+    else BigDecimal(v.toDouble).setScale(scale, BigDecimal.RoundingMode.HALF_UP).floatValue
+  }
+}
diff --git a/spark/src/test/resources/sql-tests/expressions/math/round.sql b/spark/src/test/resources/sql-tests/expressions/math/round.sql
index 7a821b8027..dfd00188d4 100644
--- a/spark/src/test/resources/sql-tests/expressions/math/round.sql
+++ b/spark/src/test/resources/sql-tests/expressions/math/round.sql
@@ -21,18 +21,86 @@ CREATE TABLE test_round(d double, i int) USING parquet
 statement
 INSERT INTO test_round VALUES (2.5, 0), (3.5, 0), (-2.5, 0), (123.456, 2), (123.456, -1), (NULL, 0), (cast('NaN' as double), 0), (cast('Infinity' as double), 0), (0.0, 0)
 
-query expect_fallback(BigDecimal rounding)
+query
 SELECT round(d, 0) FROM test_round WHERE i = 0
 
-query
expect_fallback(BigDecimal rounding) +query SELECT round(d, 2) FROM test_round WHERE i = 2 -query expect_fallback(BigDecimal rounding) +query SELECT round(d, -1) FROM test_round WHERE i = -1 -query expect_fallback(BigDecimal rounding) +query SELECT round(d) FROM test_round -- literal + literal -query expect_fallback(BigDecimal rounding) +query SELECT round(123.456, 2), round(2.5, 0), round(3.5, 0), round(-2.5, 0), round(NULL, 0) + +-- HALF_UP semantics: .5 always rounds away from zero +statement +CREATE TABLE test_round_half_up(d double) USING parquet + +statement +INSERT INTO test_round_half_up VALUES (0.5), (1.5), (2.5), (-0.5), (-1.5), (-2.5) + +query +SELECT d, round(d, 0) FROM test_round_half_up + +-- various scales on a single value +query +SELECT round(123.456, 0), round(123.456, 1), round(123.456, 2), round(123.456, 3), round(123.456, 5) + +query +SELECT round(123.456, -1), round(123.456, -2), round(123.456, -3) + +-- special values +query +SELECT round(cast('NaN' as double), 2), round(cast('Infinity' as double), 2), round(cast('-Infinity' as double), 2) + +query +SELECT round(0.0, 5), round(-0.0, 5) + +-- very small values +query +SELECT round(1.0E-10, 15), round(1.0E-10, 10), round(1.0E-10, 5) + +-- negative scale on doubles +query +SELECT round(9999.9, -1), round(9999.9, -2), round(9999.9, -3), round(9999.9, -4) + +query +SELECT round(-9999.9, -1), round(-9999.9, -2), round(-9999.9, -3), round(-9999.9, -4) + +-- float type +statement +CREATE TABLE test_round_float(f float) USING parquet + +statement +INSERT INTO test_round_float VALUES (cast(2.5 as float)), (cast(3.5 as float)), (cast(-2.5 as float)), (cast(0.125 as float)), (cast(0.785 as float)), (cast(123.456 as float)), (cast('NaN' as float)), (cast('Infinity' as float)), (NULL) + +query +SELECT round(f, 0) FROM test_round_float + +query +SELECT round(f, 2) FROM test_round_float + +query +SELECT round(f, -1) FROM test_round_float + +-- BigDecimal rounding edge case from Spark +statement +CREATE TABLE 
test_round_edge(d double) USING parquet + +statement +INSERT INTO test_round_edge VALUES (-5.81855622136895E8), (6.1317116247283497E18), (6.13171162472835E18) + +query +SELECT round(d, 4), round(d, 5), round(d, 6) FROM test_round_edge + +query +SELECT round('-8316362075006449156', -5) + +-- round with column from table (not literals) +query +SELECT d, round(d, 0), round(d, 2), round(d, -1) FROM test_round diff --git a/spark/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q78/extended.txt b/spark/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q78/extended.txt index d29dbc13e5..d344860519 100644 --- a/spark/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q78/extended.txt +++ b/spark/src/test/resources/tpcds-plan-stability/approved-plans-v1_4/q78/extended.txt @@ -1,6 +1,6 @@ -TakeOrderedAndProject -+- Project [COMET: Comet does not support Spark's BigDecimal rounding] - +- CometNativeColumnarToRow +CometNativeColumnarToRow ++- CometTakeOrderedAndProject + +- CometProject +- CometSortMergeJoin :- CometProject : +- CometSortMergeJoin @@ -76,4 +76,4 @@ TakeOrderedAndProject +- CometFilter +- CometNativeScan parquet spark_catalog.default.date_dim -Comet accelerated 71 out of 76 eligible operators (93%). Final plan contains 1 transitions between Spark and Comet. \ No newline at end of file +Comet accelerated 73 out of 76 eligible operators (96%). Final plan contains 1 transitions between Spark and Comet. 
\ No newline at end of file diff --git a/spark/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q78/extended.txt b/spark/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q78/extended.txt index d29dbc13e5..d344860519 100644 --- a/spark/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q78/extended.txt +++ b/spark/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q78/extended.txt @@ -1,6 +1,6 @@ -TakeOrderedAndProject -+- Project [COMET: Comet does not support Spark's BigDecimal rounding] - +- CometNativeColumnarToRow +CometNativeColumnarToRow ++- CometTakeOrderedAndProject + +- CometProject +- CometSortMergeJoin :- CometProject : +- CometSortMergeJoin @@ -76,4 +76,4 @@ TakeOrderedAndProject +- CometFilter +- CometNativeScan parquet spark_catalog.default.date_dim -Comet accelerated 71 out of 76 eligible operators (93%). Final plan contains 1 transitions between Spark and Comet. \ No newline at end of file +Comet accelerated 73 out of 76 eligible operators (96%). Final plan contains 1 transitions between Spark and Comet. 
\ No newline at end of file
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index a172538f45..9641343568 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -2874,10 +2874,24 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
         Byte.MinValue,
         Byte.MaxValue,
         Short.MinValue,
-        Short.MaxValue)).foreach { value =>
+        Short.MaxValue,
+        Float.MinPositiveValue,
+        Float.MaxValue,
+        Float.NaN,
+        Float.MinValue,
+        Float.NegativeInfinity,
+        Float.PositiveInfinity,
+        Double.MinPositiveValue,
+        Double.MaxValue,
+        Double.NaN,
+        Double.MinValue,
+        Double.NegativeInfinity,
+        Double.PositiveInfinity,
+        -5.81855622136895e8,
+        6.1317116247283497e18)).foreach { value =>
       val data = Seq(value)
       withParquetTable(data, "tbl") {
-        Seq(-1000, -100, -10, -1, 0, 1, 10, 100, 1000).foreach { scale =>
+        Seq(-1000, -100, -10, -5, -1, 0, 1, 5, 10, 100, 1000).foreach { scale =>
           Seq(true, false).foreach { ansi =>
             withSQLConf(SQLConf.ANSI_ENABLED.key -> ansi.toString) {
               val res = spark.sql(s"SELECT round(_1, $scale) from tbl")
@@ -2899,6 +2913,60 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
+  // Runs round() through `checkSparkAnswerAndOperator` for columns of the all-primitive-types
+  // Parquet file and for literal inputs, at a range of scales, with and without dictionary
+  // encoding.
+  test("round") {
+    Seq(true, false).foreach { dictionaryEnabled =>
+      withTempDir { dir =>
+        val path = new Path(dir.toURI.toString, "test.parquet")
+        makeParquetFileAllPrimitiveTypes(
+          path,
+          dictionaryEnabled = dictionaryEnabled,
+          -128,
+          128,
+          randomSize = 100)
+        withParquetTable(path.toString, "tbl") {
+          // The trailing `null` is interpolated as the text "null", so it also covers
+          // round(x, null).
+          for (s <- Seq(-5, -1, 0, 1, 5, -1000, 1000, -323, -308, 308, -15, 15, -16, 16, null)) {
+            // array tests
+            // TODO: enable test for unsigned ints (_9, _10, _11, _12)
+            for (c <- Seq(2, 3, 4, 5, 6, 7, 8, 13, 15, 16, 17)) {
+              checkSparkAnswerAndOperator(s"select _${c}, round(_${c}, ${s}) FROM tbl")
+            }
+            // scalar tests
+            // Exclude the constant folding optimizer in order to actually execute the native round
+            // operations for scalar (literal) values.
+            withSQLConf(
+              "spark.sql.optimizer.excludedRules" -> "org.apache.spark.sql.catalyst.optimizer.ConstantFolding") {
+              for (n <- Seq("0.0", "-0.0", "0.5", "-0.5", "1.2", "-1.2")) {
+                checkSparkAnswerAndOperator(s"select round(cast(${n} as tinyint), ${s}) FROM tbl")
+                checkSparkAnswerAndOperator(s"select round(cast(${n} as float), ${s}) FROM tbl")
+                checkSparkAnswerAndOperator(
+                  s"select round(cast(${n} as decimal(38, 18)), ${s}) FROM tbl")
+                checkSparkAnswerAndOperator(
+                  s"select round(cast(${n} as decimal(20, 0)), ${s}) FROM tbl")
+              }
+              checkSparkAnswerAndOperator(s"select round(double('infinity'), ${s}) FROM tbl")
+              checkSparkAnswerAndOperator(s"select round(double('-infinity'), ${s}) FROM tbl")
+              checkSparkAnswerAndOperator(s"select round(double('NaN'), ${s}) FROM tbl")
+              checkSparkAnswerAndOperator(
+                s"select round(double('0.000000000000000000000000000000000001'), ${s}) FROM tbl")
+            }
+          }
+        }
+      }
+    }
+  }
+
+  // Rounds a double obtained by casting a large integer string column, at a negative scale.
+  test("round double from large integer string") {
+    withParquetTable(Seq(Tuple1("-8316362075006449156")), "tbl") {
+      checkSparkAnswerAndOperator("SELECT round(cast(_1 as double), -5) FROM tbl")
+    }
+  }
+
   test("test integral divide overflow for decimal") {
     // All inserted values produce a quotient > Decimal(38,0).max (~1e38), so they overflow
     // the intermediate decimal result type. In legacy/try mode both Spark and Comet return