From 12665f5076e3559cffbce1005372a6a334f87e67 Mon Sep 17 00:00:00 2001 From: noroshi <253434427+n0r0shi@users.noreply.github.com> Date: Wed, 25 Feb 2026 07:07:31 +0000 Subject: [PATCH] feat: support binary lpad/rpad via StaticInvoke Add Rust UDFs for Spark's ByteArray.lpad/rpad (triggered when calling lpad/rpad on binary columns): - binary_lpad: left-pad binary array with cyclic byte pattern - binary_rpad: right-pad binary array with cyclic byte pattern Wire via StaticInvoke handler with foldability checks: str must be a column (not literal), len and pad must be scalar. --- native/spark-expr/src/comet_scalar_funcs.rs | 17 ++- .../src/static_invoke/binary_pad.rs | 114 ++++++++++++++++++ native/spark-expr/src/static_invoke/mod.rs | 2 + .../org/apache/comet/serde/statics.scala | 46 ++++++- .../comet/CometStringExpressionSuite.scala | 16 ++- 5 files changed, 186 insertions(+), 9 deletions(-) create mode 100644 native/spark-expr/src/static_invoke/binary_pad.rs diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs index 4bfdef7096..67c11ec863 100644 --- a/native/spark-expr/src/comet_scalar_funcs.rs +++ b/native/spark-expr/src/comet_scalar_funcs.rs @@ -20,10 +20,11 @@ use crate::math_funcs::abs::abs; use crate::math_funcs::checked_arithmetic::{checked_add, checked_div, checked_mul, checked_sub}; use crate::math_funcs::modulo_expr::spark_modulo; use crate::{ - spark_ceil, spark_decimal_div, spark_decimal_integral_div, spark_floor, spark_isnan, - spark_lpad, spark_make_decimal, spark_read_side_padding, spark_round, spark_rpad, spark_unhex, - spark_unscaled_value, EvalMode, SparkBitwiseCount, SparkContains, SparkDateDiff, - SparkDateTrunc, SparkMakeDate, SparkSizeFunc, SparkStringSpace, + spark_binary_lpad, spark_binary_rpad, spark_ceil, spark_decimal_div, + spark_decimal_integral_div, spark_floor, spark_isnan, spark_lpad, spark_make_decimal, + spark_read_side_padding, spark_round, spark_rpad, spark_unhex, 
spark_unscaled_value, EvalMode, + SparkBitwiseCount, SparkContains, SparkDateDiff, SparkDateTrunc, SparkMakeDate, + SparkSizeFunc, SparkStringSpace, }; use arrow::datatypes::DataType; use datafusion::common::{DataFusionError, Result as DataFusionResult}; @@ -112,6 +113,14 @@ pub fn create_comet_physical_fun_with_eval_mode( let func = Arc::new(spark_read_side_padding); make_comet_scalar_udf!("read_side_padding", func, without data_type) } + "binary_lpad" => { + let func = Arc::new(spark_binary_lpad); + make_comet_scalar_udf!("binary_lpad", func, without data_type) + } + "binary_rpad" => { + let func = Arc::new(spark_binary_rpad); + make_comet_scalar_udf!("binary_rpad", func, without data_type) + } "rpad" => { let func = Arc::new(spark_rpad); make_comet_scalar_udf!("rpad", func, without data_type) diff --git a/native/spark-expr/src/static_invoke/binary_pad.rs b/native/spark-expr/src/static_invoke/binary_pad.rs new file mode 100644 index 0000000000..0b6bc1d088 --- /dev/null +++ b/native/spark-expr/src/static_invoke/binary_pad.rs @@ -0,0 +1,114 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use arrow::array::builder::BinaryBuilder; +use arrow::array::{Array, ArrayRef, AsArray}; +use arrow::datatypes::DataType; +use datafusion::common::{DataFusionError, ScalarValue}; +use datafusion::physical_plan::ColumnarValue; +use std::sync::Arc; + +/// Spark's ByteArray.lpad: left-pad binary array with cyclic pattern. +pub fn spark_binary_lpad(args: &[ColumnarValue]) -> Result { + binary_pad_impl(args, true) +} + +/// Spark's ByteArray.rpad: right-pad binary array with cyclic pattern. +pub fn spark_binary_rpad(args: &[ColumnarValue]) -> Result { + binary_pad_impl(args, false) +} + +fn binary_pad_impl( + args: &[ColumnarValue], + is_left_pad: bool, +) -> Result { + match args { + [ColumnarValue::Array(array), ColumnarValue::Scalar(ScalarValue::Int32(Some(len))), ColumnarValue::Scalar(ScalarValue::Binary(Some(pad)))] => + { + let len = *len; + match array.data_type() { + DataType::Binary => { + let binary_array = array.as_binary::(); + let mut builder = BinaryBuilder::with_capacity(binary_array.len(), 0); + + for i in 0..binary_array.len() { + if binary_array.is_null(i) { + builder.append_null(); + } else { + let bytes = binary_array.value(i); + let result = pad_bytes(bytes, len as usize, pad, is_left_pad); + builder.append_value(&result); + } + } + Ok(ColumnarValue::Array(Arc::new(builder.finish()) as ArrayRef)) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {other:?} for binary_pad", + ))), + } + } + other => Err(DataFusionError::Internal(format!( + "Unsupported arguments {other:?} for binary_pad", + ))), + } +} + +/// Pad bytes to target length using cyclic pad pattern. +/// Matches Spark's ByteArray.lpad/rpad behavior. 
/// Pad `bytes` out to exactly `len` bytes using the cyclic pattern `pad`.
///
/// * If `len == 0` the result is empty.
/// * If `pad` is empty, no padding is possible: the input is truncated to at
///   most `len` bytes (mirrors Spark's `substring(0, len)` behavior for an
///   empty pad).
/// * Otherwise the input is copied to the right (`is_left_pad == true`) or the
///   left (`is_left_pad == false`) end of the result, truncated to `len` if it
///   is longer, and the remainder is filled with `pad` repeated cyclically.
fn pad_bytes(bytes: &[u8], len: usize, pad: &[u8], is_left_pad: bool) -> Vec<u8> {
    if len == 0 {
        return Vec::new();
    }

    if pad.is_empty() {
        // Empty pattern: nothing to pad with, so just truncate the input.
        let take = bytes.len().min(len);
        return bytes[..take].to_vec();
    }

    let mut result = vec![0u8; len];
    let min_len = bytes.len().min(len);

    if is_left_pad {
        // Copy input bytes to the right side of result.
        result[len - min_len..].copy_from_slice(&bytes[..min_len]);
        // Fill remaining left side with pad pattern.
        if bytes.len() < len {
            fill_with_pattern(&mut result, 0, len - bytes.len(), pad);
        }
    } else {
        // Copy input bytes to the left side of result.
        result[..min_len].copy_from_slice(&bytes[..min_len]);
        // Fill remaining right side with pad pattern.
        if bytes.len() < len {
            fill_with_pattern(&mut result, bytes.len(), len, pad);
        }
    }

    result
}

/// Fill `result[first_pos..beyond_pos]` with the cyclic pad pattern; the final
/// repetition is truncated if the gap is not a multiple of `pad.len()`.
fn fill_with_pattern(result: &mut [u8], first_pos: usize, beyond_pos: usize, pad: &[u8]) {
    let mut pos = first_pos;
    while pos < beyond_pos {
        let remaining = beyond_pos - pos;
        let take = pad.len().min(remaining);
        result[pos..pos + take].copy_from_slice(&pad[..take]);
        pos += take;
    }
}

// --- residue from the surrounding git patch (mod.rs hunk header), preserved ---
// diff --git a/native/spark-expr/src/static_invoke/mod.rs b/native/spark-expr/src/static_invoke/mod.rs
// index 6a2176b5f9..9b27fc314d 100644
// --- a/native/spark-expr/src/static_invoke/mod.rs
// +++ b/native/spark-expr/src/static_invoke/mod.rs
// @@ -15,6 +15,8 @@
//  // specific language governing permissions and limitations
//  // under the License.
+mod binary_pad; mod char_varchar_utils; +pub use binary_pad::{spark_binary_lpad, spark_binary_rpad}; pub use char_varchar_utils::{spark_lpad, spark_read_side_padding, spark_rpad}; diff --git a/spark/src/main/scala/org/apache/comet/serde/statics.scala b/spark/src/main/scala/org/apache/comet/serde/statics.scala index 0737644ab9..4640cc4b76 100644 --- a/spark/src/main/scala/org/apache/comet/serde/statics.scala +++ b/spark/src/main/scala/org/apache/comet/serde/statics.scala @@ -22,8 +22,11 @@ package org.apache.comet.serde import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils +import org.apache.spark.unsafe.types.ByteArray import org.apache.comet.CometSparkSessionExtensions.withInfo +import org.apache.comet.serde.ExprOuterClass.Expr +import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProtoWithReturnType} object CometStaticInvoke extends CometExpressionSerde[StaticInvoke] { @@ -34,7 +37,9 @@ object CometStaticInvoke extends CometExpressionSerde[StaticInvoke] { : Map[(String, Class[_]), CometExpressionSerde[StaticInvoke]] = Map( ("readSidePadding", classOf[CharVarcharCodegenUtils]) -> CometScalarFunction( - "read_side_padding")) + "read_side_padding"), + ("lpad", classOf[ByteArray]) -> CometBinaryPad("binary_lpad"), + ("rpad", classOf[ByteArray]) -> CometBinaryPad("binary_rpad")) override def convert( expr: StaticInvoke, @@ -52,3 +57,42 @@ object CometStaticInvoke extends CometExpressionSerde[StaticInvoke] { } } } + +/** + * Handler for ByteArray.lpad/rpad StaticInvoke (Spark 3.2+, via BinaryPad). Maps to Comet's + * binary_lpad/binary_rpad UDFs. 
+ */ +private case class CometBinaryPad(funcName: String) extends CometExpressionSerde[StaticInvoke] { + + override def convert( + expr: StaticInvoke, + inputs: Seq[Attribute], + binding: Boolean): Option[Expr] = { + val str = expr.arguments(0) + val len = expr.arguments(1) + val pad = expr.arguments(2) + if (str.foldable) { + withInfo(expr, "Scalar values are not supported for the str argument", str) + return None + } + if (!len.foldable) { + withInfo(expr, "Only scalar values are supported for the len argument", len) + return None + } + if (!pad.foldable) { + withInfo(expr, "Only scalar values are supported for the pad argument", pad) + return None + } + val strExpr = exprToProtoInternal(str, inputs, binding) + val lenExpr = exprToProtoInternal(len, inputs, binding) + val padExpr = exprToProtoInternal(pad, inputs, binding) + val optExpr = scalarFunctionExprToProtoWithReturnType( + funcName, + expr.dataType, + false, + strExpr, + lenExpr, + padExpr) + optExprWithInfo(optExpr, expr, expr.arguments: _*) + } +} diff --git a/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala index 121d7f7d5a..60e6675c5f 100644 --- a/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala @@ -136,12 +136,20 @@ class CometStringExpressionSuite extends CometTestBase { // all arguments are literal, so Spark constant folding will kick in // and pad function will not be evaluated by Comet checkSparkAnswerAndOperator(sql) - } else { - // Comet will fall back to Spark because the plan contains a staticinvoke instruction - // which is not supported + } else if (isLiteralStr) { + checkSparkAnswerAndFallbackReason( + sql, + "Scalar values are not supported for the str argument") + } else if (!isLiteralLen) { checkSparkAnswerAndFallbackReason( sql, - s"Static invoke expression: $expr is not supported") + "Only 
scalar values are supported for the len argument") + } else if (!isLiteralPad) { + checkSparkAnswerAndFallbackReason( + sql, + "Only scalar values are supported for the pad argument") + } else { + checkSparkAnswerAndOperator(sql) } } }