diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs index 0193f3012c..8522ac4bde 100644 --- a/native/core/src/execution/jni_api.rs +++ b/native/core/src/execution/jni_api.rs @@ -41,6 +41,7 @@ use datafusion::{ }; use datafusion_comet_proto::spark_operator::Operator; use datafusion_spark::function::bitwise::bit_get::SparkBitGet; +use datafusion_spark::function::bitwise::bit_shift::SparkShiftRightUnsigned; use datafusion_spark::function::bitwise::bitwise_not::SparkBitwiseNot; use datafusion_spark::function::datetime::date_add::SparkDateAdd; use datafusion_spark::function::datetime::date_sub::SparkDateSub; @@ -400,6 +401,7 @@ fn register_datafusion_spark_function(session_ctx: &SessionContext) { session_ctx.register_udf(ScalarUDF::new_from_impl(SparkWidthBucket::default())); session_ctx.register_udf(ScalarUDF::new_from_impl(MapFromEntries::default())); session_ctx.register_udf(ScalarUDF::new_from_impl(SparkCrc32::default())); + session_ctx.register_udf(ScalarUDF::new_from_impl(SparkShiftRightUnsigned::default())); } /// Prepares arrow arrays for output. diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 9d13ccd9ed..9fb8c86b49 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -188,7 +188,8 @@ object QueryPlanSerde extends Logging with CometExprShim { classOf[BitwiseNot] -> CometBitwiseNot, classOf[BitwiseXor] -> CometBitwiseXor, classOf[ShiftLeft] -> CometShiftLeft, - classOf[ShiftRight] -> CometShiftRight) + classOf[ShiftRight] -> CometShiftRight, + classOf[ShiftRightUnsigned] -> CometScalarFunction("shiftrightunsigned")) private val temporalExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map( classOf[DateAdd] -> CometDateAdd, diff --git a/spark/src/test/scala/org/apache/comet/CometBitwiseExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometBitwiseExpressionSuite.scala index 99a57b1575..50471c4d21 100644 --- a/spark/src/test/scala/org/apache/comet/CometBitwiseExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometBitwiseExpressionSuite.scala @@ -198,4 +198,12 @@ class CometBitwiseExpressionSuite extends CometTestBase with AdaptiveSparkPlanHe } } } + + test("shiftrightunsigned") { + withParquetTable(Seq((42, 2), (-1, 1), (255, 4), (0, 3), (Int.MinValue, 1)), "tbl") { + checkSparkAnswerAndOperator("SELECT shiftrightunsigned(_1, _2) FROM tbl") + checkSparkAnswerAndOperator("SELECT shiftrightunsigned(-1, 1) FROM tbl") + checkSparkAnswerAndOperator("SELECT shiftrightunsigned(NULL, 1) FROM tbl") + } + } }