diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs index 1030e30aaf..bf06accbc9 100644 --- a/native/core/src/execution/jni_api.rs +++ b/native/core/src/execution/jni_api.rs @@ -56,6 +56,7 @@ use datafusion_spark::function::math::hex::SparkHex; use datafusion_spark::function::math::width_bucket::SparkWidthBucket; use datafusion_spark::function::string::char::CharFunc; use datafusion_spark::function::string::concat::SparkConcat; +use datafusion_spark::function::string::luhn_check::SparkLuhnCheck; use datafusion_spark::function::string::space::SparkSpace; use futures::poll; use futures::stream::StreamExt; @@ -402,6 +403,7 @@ fn register_datafusion_spark_function(session_ctx: &SessionContext) { session_ctx.register_udf(ScalarUDF::new_from_impl(SparkWidthBucket::default())); session_ctx.register_udf(ScalarUDF::new_from_impl(MapFromEntries::default())); session_ctx.register_udf(ScalarUDF::new_from_impl(SparkCrc32::default())); + session_ctx.register_udf(ScalarUDF::new_from_impl(SparkLuhnCheck::default())); session_ctx.register_udf(ScalarUDF::new_from_impl(SparkSpace::default())); session_ctx.register_udf(ScalarUDF::new_from_impl(SparkBitCount::default())); } diff --git a/spark/src/main/scala/org/apache/comet/serde/statics.scala b/spark/src/main/scala/org/apache/comet/serde/statics.scala index 0737644ab9..9dbc6d169f 100644 --- a/spark/src/main/scala/org/apache/comet/serde/statics.scala +++ b/spark/src/main/scala/org/apache/comet/serde/statics.scala @@ -19,7 +19,7 @@ package org.apache.comet.serde -import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.{Attribute, ExpressionImplUtils} import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils @@ -34,7 +34,8 @@ object CometStaticInvoke extends CometExpressionSerde[StaticInvoke] { : Map[(String, Class[_]), CometExpressionSerde[StaticInvoke]] = Map( ("readSidePadding", classOf[CharVarcharCodegenUtils]) -> CometScalarFunction( - "read_side_padding")) + "read_side_padding"), + ("isLuhnNumber", classOf[ExpressionImplUtils]) -> CometScalarFunction("luhn_check")) override def convert( expr: StaticInvoke, diff --git a/spark/src/test/resources/sql-tests/expressions/string/luhn_check.sql b/spark/src/test/resources/sql-tests/expressions/string/luhn_check.sql new file mode 100644 index 0000000000..ea10808467 --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/string/luhn_check.sql @@ -0,0 +1,33 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +-- MinSparkVersion: 3.5 +-- ConfigMatrix: parquet.enable.dictionary=false,true + +statement +CREATE TABLE test_luhn(s string) USING parquet + +statement +INSERT INTO test_luhn VALUES ('79927398710'), ('79927398713'), ('1234567812345670'), ('0'), (''), ('abc'), (NULL) + +-- column input +query +SELECT luhn_check(s) FROM test_luhn + +-- literal arguments +query +SELECT luhn_check('79927398713'), luhn_check('79927398710'), luhn_check(''), luhn_check(NULL)