diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs index 4bfdef7096..37667f0df8 100644 --- a/native/spark-expr/src/comet_scalar_funcs.rs +++ b/native/spark-expr/src/comet_scalar_funcs.rs @@ -20,10 +20,11 @@ use crate::math_funcs::abs::abs; use crate::math_funcs::checked_arithmetic::{checked_add, checked_div, checked_mul, checked_sub}; use crate::math_funcs::modulo_expr::spark_modulo; use crate::{ - spark_ceil, spark_decimal_div, spark_decimal_integral_div, spark_floor, spark_isnan, - spark_lpad, spark_make_decimal, spark_read_side_padding, spark_round, spark_rpad, spark_unhex, - spark_unscaled_value, EvalMode, SparkBitwiseCount, SparkContains, SparkDateDiff, - SparkDateTrunc, SparkMakeDate, SparkSizeFunc, SparkStringSpace, + spark_ceil, spark_char_type_write_side_check, spark_decimal_div, + spark_decimal_integral_div, spark_floor, spark_isnan, spark_lpad, spark_make_decimal, + spark_read_side_padding, spark_round, spark_rpad, spark_unhex, spark_unscaled_value, + spark_varchar_type_write_side_check, EvalMode, SparkBitwiseCount, SparkContains, + SparkDateDiff, SparkDateTrunc, SparkMakeDate, SparkSizeFunc, SparkStringSpace, }; use arrow::datatypes::DataType; use datafusion::common::{DataFusionError, Result as DataFusionResult}; @@ -112,6 +113,14 @@ pub fn create_comet_physical_fun_with_eval_mode( let func = Arc::new(spark_read_side_padding); make_comet_scalar_udf!("read_side_padding", func, without data_type) } + "char_type_write_side_check" => { + let func = Arc::new(spark_char_type_write_side_check); + make_comet_scalar_udf!("char_type_write_side_check", func, without data_type) + } + "varchar_type_write_side_check" => { + let func = Arc::new(spark_varchar_type_write_side_check); + make_comet_scalar_udf!("varchar_type_write_side_check", func, without data_type) + } "rpad" => { let func = Arc::new(spark_rpad); make_comet_scalar_udf!("rpad", func, without data_type) diff --git a/native/spark-expr/src/static_invoke/char_varchar_utils/mod.rs b/native/spark-expr/src/static_invoke/char_varchar_utils/mod.rs index 5bb94a7ad5..c57112c6b6 100644 --- a/native/spark-expr/src/static_invoke/char_varchar_utils/mod.rs +++ b/native/spark-expr/src/static_invoke/char_varchar_utils/mod.rs @@ -16,5 +16,7 @@ // under the License. mod read_side_padding; +mod write_side_check; pub use read_side_padding::{spark_lpad, spark_read_side_padding, spark_rpad}; +pub use write_side_check::{spark_char_type_write_side_check, spark_varchar_type_write_side_check}; diff --git a/native/spark-expr/src/static_invoke/char_varchar_utils/write_side_check.rs b/native/spark-expr/src/static_invoke/char_varchar_utils/write_side_check.rs new file mode 100644 index 0000000000..a1284a9cf4 --- /dev/null +++ b/native/spark-expr/src/static_invoke/char_varchar_utils/write_side_check.rs @@ -0,0 +1,133 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::builder::GenericStringBuilder; +use arrow::array::cast::as_dictionary_array; +use arrow::array::types::Int32Type; +use arrow::array::{make_array, Array, DictionaryArray}; +use arrow::array::{ArrayRef, OffsetSizeTrait}; +use arrow::datatypes::DataType; +use datafusion::common::{cast::as_generic_string_array, DataFusionError, ScalarValue}; +use datafusion::physical_plan::ColumnarValue; +use std::sync::Arc; + +/// Spark's charTypeWriteSideCheck: pad if shorter, trim trailing spaces if longer. +/// Throws if string exceeds limit after trimming. +pub fn spark_char_type_write_side_check( + args: &[ColumnarValue], +) -> Result { + write_side_check_impl(args, true) +} + +/// Spark's varcharTypeWriteSideCheck: return as-is if within limit, trim trailing spaces if longer. +/// Throws if string exceeds limit after trimming. +pub fn spark_varchar_type_write_side_check( + args: &[ColumnarValue], +) -> Result { + write_side_check_impl(args, false) +} + +fn write_side_check_impl( + args: &[ColumnarValue], + pad_if_shorter: bool, +) -> Result { + match args { + [ColumnarValue::Array(array), ColumnarValue::Scalar(ScalarValue::Int32(Some(limit)))] => { + let limit = *limit as usize; + match array.data_type() { + DataType::Utf8 => { + write_side_check_internal::(array, limit, pad_if_shorter) + } + DataType::LargeUtf8 => { + write_side_check_internal::(array, limit, pad_if_shorter) + } + DataType::Dictionary(_, value_type) => { + let dict = as_dictionary_array::(array); + let col = if value_type.as_ref() == &DataType::Utf8 { + write_side_check_internal::(dict.values(), limit, pad_if_shorter)? + } else { + write_side_check_internal::(dict.values(), limit, pad_if_shorter)? + }; + let values = col.to_array(0)?; + let result = DictionaryArray::try_new(dict.keys().clone(), values)?; + Ok(ColumnarValue::Array(make_array(result.into()))) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {other:?} for write_side_check", + ))), + } + } + other => Err(DataFusionError::Internal(format!( + "Unsupported arguments {other:?} for write_side_check", + ))), + } +} + +fn write_side_check_internal( + array: &ArrayRef, + limit: usize, + pad_if_shorter: bool, +) -> Result { + let string_array = as_generic_string_array::(array)?; + + let mut builder = + GenericStringBuilder::::with_capacity(string_array.len(), string_array.len() * limit); + let mut buffer = String::with_capacity(limit); + + for string in string_array.iter() { + match string { + Some(s) => { + let char_len = s.chars().count(); + if char_len <= limit { + if pad_if_shorter && char_len < limit { + // Pad with spaces to reach limit + buffer.clear(); + buffer.push_str(s); + for _ in 0..(limit - char_len) { + buffer.push(' '); + } + builder.append_value(&buffer); + } else { + builder.append_value(s); + } + } else { + // Trim trailing spaces + let trimmed = s.trim_end_matches(' '); + let trimmed_char_len = trimmed.chars().count(); + if trimmed_char_len > limit { + return Err(DataFusionError::Execution(format!( + "Exceeds char/varchar type length limitation: {limit}" + ))); + } + if pad_if_shorter && trimmed_char_len < limit { + // For CHAR type: pad back to limit after trimming + buffer.clear(); + buffer.push_str(trimmed); + for _ in 0..(limit - trimmed_char_len) { + buffer.push(' '); + } + builder.append_value(&buffer); + } else { + builder.append_value(trimmed); + } + } + } + None => builder.append_null(), + } + } + Ok(ColumnarValue::Array(Arc::new(builder.finish()))) +} diff --git a/native/spark-expr/src/static_invoke/mod.rs b/native/spark-expr/src/static_invoke/mod.rs index 6a2176b5f9..9e85bc5d86 100644 --- a/native/spark-expr/src/static_invoke/mod.rs +++ b/native/spark-expr/src/static_invoke/mod.rs @@ -17,4 +17,7 @@ mod char_varchar_utils; -pub use char_varchar_utils::{spark_lpad, spark_read_side_padding, spark_rpad}; +pub use char_varchar_utils::{ + spark_char_type_write_side_check, spark_lpad, spark_read_side_padding, spark_rpad, + spark_varchar_type_write_side_check, +}; diff --git a/spark/src/main/scala/org/apache/comet/serde/statics.scala b/spark/src/main/scala/org/apache/comet/serde/statics.scala index 0737644ab9..3970a00b80 100644 --- a/spark/src/main/scala/org/apache/comet/serde/statics.scala +++ b/spark/src/main/scala/org/apache/comet/serde/statics.scala @@ -34,7 +34,11 @@ object CometStaticInvoke extends CometExpressionSerde[StaticInvoke] { : Map[(String, Class[_]), CometExpressionSerde[StaticInvoke]] = Map( ("readSidePadding", classOf[CharVarcharCodegenUtils]) -> CometScalarFunction( - "read_side_padding")) + "read_side_padding"), + ("charTypeWriteSideCheck", classOf[CharVarcharCodegenUtils]) -> CometScalarFunction( + "char_type_write_side_check"), + ("varcharTypeWriteSideCheck", classOf[CharVarcharCodegenUtils]) -> CometScalarFunction( + "varchar_type_write_side_check")) override def convert( expr: StaticInvoke, diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index f0f022868f..14597194d9 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -2252,6 +2252,32 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } + test("charTypeWriteSideCheck") { + val table = "test" + withTable(table) { + sql(s"create table $table(col CHAR(5)) using parquet") + sql(s"insert into $table values('ab')") + sql(s"insert into $table values('abcde')") + sql(s"insert into $table values('abc ')") // trailing spaces, equals limit after trim+pad + sql(s"insert into $table values('')") + // Read back — CHAR(5) should pad to 5 characters + checkSparkAnswerAndOperator(s"SELECT col FROM $table") + } + } + + test("varcharTypeWriteSideCheck") { + val table = "test" + withTable(table) { + sql(s"create table $table(col VARCHAR(5)) using parquet") + sql(s"insert into $table values('ab')") + sql(s"insert into $table values('abcde')") + sql(s"insert into $table values('abc ')") // trailing spaces within limit + sql(s"insert into $table values('')") + // Read back — VARCHAR(5) should NOT pad + checkSparkAnswerAndOperator(s"SELECT col FROM $table") + } + } + test("isnan") { Seq("true", "false").foreach { dictionary => withSQLConf("parquet.enable.dictionary" -> dictionary) {