diff --git a/docs/source/contributor-guide/spark_expressions_support.md b/docs/source/contributor-guide/spark_expressions_support.md index 65f941210a..e3182c11ae 100644 --- a/docs/source/contributor-guide/spark_expressions_support.md +++ b/docs/source/contributor-guide/spark_expressions_support.md @@ -165,7 +165,7 @@ - [x] bit_get - [x] getbit - [x] shiftright -- [ ] shiftrightunsigned +- [x] shiftrightunsigned - [x] `|` - [x] `~` diff --git a/docs/source/user-guide/latest/expressions.md b/docs/source/user-guide/latest/expressions.md index a610a83cea..d7932731ca 100644 --- a/docs/source/user-guide/latest/expressions.md +++ b/docs/source/user-guide/latest/expressions.md @@ -207,16 +207,17 @@ of expressions that be disabled. ## Bitwise Expressions -| Expression | SQL | -| ------------ | ---- | -| BitwiseAnd | `&` | -| BitwiseCount | | -| BitwiseGet | | -| BitwiseOr | `\|` | -| BitwiseNot | `~` | -| BitwiseXor | `^` | -| ShiftLeft | `<<` | -| ShiftRight | `>>` | +| Expression | SQL | +| ------------------ | ----- | +| BitwiseAnd | `&` | +| BitwiseCount | | +| BitwiseGet | | +| BitwiseOr | `\|` | +| BitwiseNot | `~` | +| BitwiseXor | `^` | +| ShiftLeft | `<<` | +| ShiftRight | `>>` | +| ShiftRightUnsigned | `>>>` | ## Aggregate Expressions diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs index 97e3f851c5..0dcd78ba0f 100644 --- a/native/core/src/execution/jni_api.rs +++ b/native/core/src/execution/jni_api.rs @@ -45,6 +45,7 @@ use datafusion_comet_proto::spark_operator::Operator; use datafusion_spark::function::array::array_contains::SparkArrayContains; use datafusion_spark::function::bitwise::bit_count::SparkBitCount; use datafusion_spark::function::bitwise::bit_get::SparkBitGet; +use datafusion_spark::function::bitwise::bit_shift::SparkBitShift; use datafusion_spark::function::bitwise::bitwise_not::SparkBitwiseNot; use datafusion_spark::function::datetime::date_add::SparkDateAdd; use datafusion_spark::function::datetime::date_sub::SparkDateSub; @@ -607,6 +608,7 @@ fn register_datafusion_spark_function(session_ctx: &SessionContext) { session_ctx.register_udf(ScalarUDF::new_from_impl(SparkFactorial::default())); session_ctx.register_udf(ScalarUDF::new_from_impl(SparkSec::default())); session_ctx.register_udf(ScalarUDF::new_from_impl(SparkRint::default())); + session_ctx.register_udf(ScalarUDF::new_from_impl(SparkBitShift::right_unsigned())); } /// Prepares arrow arrays for output. diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 8d48239e76..b72ac38a85 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -215,7 +215,8 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim { classOf[BitwiseNot] -> CometBitwiseNot, classOf[BitwiseXor] -> CometBitwiseXor, classOf[ShiftLeft] -> CometShiftLeft, - classOf[ShiftRight] -> CometShiftRight) + classOf[ShiftRight] -> CometShiftRight, + classOf[ShiftRightUnsigned] -> CometScalarFunction("shiftrightunsigned")) private[comet] val temporalExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map( diff --git a/spark/src/test/resources/sql-tests/expressions/bitwise/bitwise.sql b/spark/src/test/resources/sql-tests/expressions/bitwise/bitwise.sql index 8b7ade25f0..f4a481fb2f 100644 --- a/spark/src/test/resources/sql-tests/expressions/bitwise/bitwise.sql +++ b/spark/src/test/resources/sql-tests/expressions/bitwise/bitwise.sql @@ -44,6 +44,50 @@ SELECT shiftright(col1, 2), shiftright(col1, col2) FROM test query SELECT shiftleft(col1, 2), shiftleft(col1, col2) FROM test +-- ShiftRightUnsigned: first arg is Int or Long, second is Int. Returns the +-- same integer type as the first argument. Shift amount is normalized to the +-- bit width (Java semantics) for negative and large shifts. +statement +CREATE TABLE test_shiftrightunsigned_int(v int, s int) USING parquet + +statement +INSERT INTO test_shiftrightunsigned_int VALUES + (1, 1), + (-1, 1), + (8, 2), + (2147483647, 1), + (-2147483648, 1), + (0, 0), + (1, 0), + (1, 31), + (1, 32), + (1, 33), + (1, -1), + (NULL, 1), + (1, NULL) + +query +SELECT shiftrightunsigned(v, s) FROM test_shiftrightunsigned_int + +statement +CREATE TABLE test_shiftrightunsigned_long(v bigint, s int) USING parquet + +statement +INSERT INTO test_shiftrightunsigned_long VALUES + (1, 1), + (-1, 1), + (9223372036854775807, 1), + (-9223372036854775808, 1), + (0, 0), + (1, 63), + (1, 64), + (1, -1), + (NULL, 1), + (1, NULL) + +query +SELECT shiftrightunsigned(v, s) FROM test_shiftrightunsigned_long + query SELECT ~(11), ~col1, ~col2 FROM test @@ -79,3 +123,6 @@ SELECT bit_get(11, 0), bit_get(11, 1), bit_get(11, 2), bit_get(11, 3) query SELECT shiftright(1111, 2), shiftleft(1111, 2) + +query +SELECT shiftrightunsigned(1, 1), shiftrightunsigned(-1, 1), shiftrightunsigned(2147483647, 1), shiftrightunsigned(cast(-1 as bigint), 1)