diff --git a/docs/spark_expressions_support.md b/docs/spark_expressions_support.md
index 5474894108..bfff05b2e0 100644
--- a/docs/spark_expressions_support.md
+++ b/docs/spark_expressions_support.md
@@ -460,7 +460,7 @@
 - [ ] try_to_binary
 - [ ] try_to_number
 - [x] ucase
-- [ ] unbase64
+- [x] unbase64
 - [x] upper
 
 ### struct_funcs
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 9d13ccd9ed..51c55df755 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -178,6 +178,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
     classOf[Left] -> CometLeft,
     classOf[Right] -> CometRight,
     classOf[Substring] -> CometSubstring,
+    classOf[UnBase64] -> CometUnBase64,
     classOf[Upper] -> CometUpper)
 
   private val bitwiseExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map(
diff --git a/spark/src/main/scala/org/apache/comet/serde/strings.scala b/spark/src/main/scala/org/apache/comet/serde/strings.scala
index 64ba644048..892aa152f1 100644
--- a/spark/src/main/scala/org/apache/comet/serde/strings.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/strings.scala
@@ -21,7 +21,7 @@ package org.apache.comet.serde
 
 import java.util.Locale
 
-import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Concat, ConcatWs, Expression, If, InitCap, IsNull, Left, Length, Like, Literal, Lower, RegExpReplace, Right, RLike, StringLPad, StringRepeat, StringRPad, StringSplit, Substring, Upper}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Concat, ConcatWs, Expression, If, InitCap, IsNull, Left, Length, Like, Literal, Lower, RegExpReplace, Right, RLike, StringLPad, StringRepeat, StringRPad, StringSplit, Substring, UnBase64, Upper}
 import org.apache.spark.sql.types.{BinaryType, DataTypes, LongType, StringType}
 import org.apache.spark.unsafe.types.UTF8String
 
@@ -31,6 +31,29 @@ import org.apache.comet.expressions.{CometCast, CometEvalMode, RegExp}
 import org.apache.comet.serde.ExprOuterClass.Expr
 import org.apache.comet.serde.QueryPlanSerde.{createBinaryExpr, exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto, scalarFunctionExprToProtoWithReturnType}
 
+object CometUnBase64 extends CometExpressionSerde[UnBase64] {
+
+  override def getSupportLevel(expr: UnBase64): SupportLevel = {
+    if (expr.failOnError) {
+      Unsupported(Some("unbase64 with failOnError is not supported"))
+    } else {
+      Compatible()
+    }
+  }
+
+  override def convert(
+      expr: UnBase64,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
+    val inputExpr = exprToProtoInternal(expr.child, inputs, binding)
+    val encodingExpr = exprToProtoInternal(Literal("base64"), inputs, binding)
+    // NOTE(review): Spark's unbase64 with failOnError=false is lenient about
+    // malformed input; confirm DataFusion's decode(..., 'base64') matches.
+    val optExpr = scalarFunctionExprToProto("decode", inputExpr, encodingExpr)
+    optExprWithInfo(optExpr, expr, expr.child)
+  }
+}
+
 object CometStringRepeat extends CometExpressionSerde[StringRepeat] {
 
   override def convert(
diff --git a/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala
index 121d7f7d5a..f5be9ca55b 100644
--- a/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala
@@ -148,6 +148,31 @@
     }
   }
 
+  test("unbase64") {
+    val data = Seq(
+      "SGVsbG8=", // base64("Hello")
+      "U3BhcmsgU1FM", // base64("Spark SQL")
+      "", // empty
+      null).map(Tuple1(_))
+    withParquetTable(data, "tbl") {
+      // unbase64 decoding from column
+      checkSparkAnswerAndOperator("SELECT unbase64(_1) FROM tbl")
+      // unbase64 with inline literal
+      checkSparkAnswerAndOperator("SELECT unbase64('U3BhcmsgU1FM') FROM tbl")
+      // null handling
+      checkSparkAnswerAndOperator("SELECT unbase64(NULL) FROM tbl")
+    }
+  }
+
+  test("to_binary with base64 falls back (failOnError)") {
+    val data = Seq("SGVsbG8=", "U3BhcmsgU1FM").map(Tuple1(_))
+    withParquetTable(data, "tbl") {
+      checkSparkAnswerAndFallbackReason(
+        "SELECT to_binary(_1, 'base64') FROM tbl",
+        "unbase64 with failOnError is not supported")
+    }
+  }
+
   test("split string basic") {
     withSQLConf("spark.comet.expression.StringSplit.allowIncompatible" -> "true") {
       withParquetTable((0 until 5).map(i => (s"value$i,test$i", i)), "tbl") {