native/spark-expr/src/conversion_funcs/string.rs: 82 additions & 35 deletions
@@ -37,19 +37,29 @@ macro_rules! cast_utf8_to_timestamp {
     ($array:expr, $eval_mode:expr, $array_type:ty, $cast_method:ident, $tz:expr) => {{
         let len = $array.len();
         let mut cast_array = PrimitiveArray::<$array_type>::builder(len).with_timezone("UTC");
+        let mut cast_err: Option<SparkError> = None;
         for i in 0..len {
             if $array.is_null(i) {
                 cast_array.append_null()
-            } else if let Ok(Some(cast_value)) =
-                $cast_method($array.value(i).trim(), $eval_mode, $tz)
-            {
-                cast_array.append_value(cast_value);
             } else {
-                cast_array.append_null()
+                match $cast_method($array.value(i).trim(), $eval_mode, $tz) {
+                    Ok(Some(cast_value)) => cast_array.append_value(cast_value),
+                    Ok(None) => cast_array.append_null(),
+                    Err(e) => {
+                        if $eval_mode == EvalMode::Ansi {
+                            cast_err = Some(e);
+                            break;
+                        }
+                        cast_array.append_null()
+                    }
+                }
             }
         }
-        let result: ArrayRef = Arc::new(cast_array.finish()) as ArrayRef;
-        result
+        if let Some(e) = cast_err {
+            Err(e)
+        } else {
+            Ok(Arc::new(cast_array.finish()) as ArrayRef)
+        }
     }};
 }
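Note on the hunk above: in ANSI mode the loop now records the first parse error and breaks, instead of appending a null as legacy mode does. A minimal standalone sketch of that control flow, with a plain `Vec` and a toy `parse` function standing in for the Arrow builder and the real timestamp parser (both names are hypothetical):

```rust
// Sketch only: `EvalMode` and `parse` stand in for the real types in this PR.
#[derive(PartialEq)]
enum EvalMode { Legacy, Ansi }

fn parse(s: &str) -> Result<Option<i64>, String> {
    s.trim().parse::<i64>().map(Some).map_err(|e| e.to_string())
}

fn cast_all(values: &[&str], eval_mode: EvalMode) -> Result<Vec<Option<i64>>, String> {
    let mut out = Vec::with_capacity(values.len());
    let mut cast_err: Option<String> = None;
    for v in values {
        match parse(v) {
            Ok(opt) => out.push(opt),
            Err(e) => {
                if eval_mode == EvalMode::Ansi {
                    // Fail fast: remember the error and stop scanning further rows.
                    cast_err = Some(e);
                    break;
                }
                out.push(None) // legacy mode: malformed input becomes null
            }
        }
    }
    match cast_err {
        Some(e) => Err(e),
        None => Ok(out),
    }
}

fn main() {
    assert_eq!(cast_all(&["1", "x"], EvalMode::Legacy), Ok(vec![Some(1), None]));
    assert!(cast_all(&["1", "x"], EvalMode::Ansi).is_err());
}
```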

@@ -668,15 +678,13 @@ pub(crate) fn cast_string_to_timestamp(
     let tz = &timezone::Tz::from_str(timezone_str).unwrap();
 
     let cast_array: ArrayRef = match to_type {
-        DataType::Timestamp(_, _) => {
-            cast_utf8_to_timestamp!(
-                string_array,
-                eval_mode,
-                TimestampMicrosecondType,
-                timestamp_parser,
-                tz
-            )
-        }
+        DataType::Timestamp(_, _) => cast_utf8_to_timestamp!(
+            string_array,
+            eval_mode,
+            TimestampMicrosecondType,
+            timestamp_parser,
+            tz
+        )?,
         _ => unreachable!("Invalid data type {:?} in cast from string", to_type),
     };
     Ok(cast_array)
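With the macro now expanding to a `SparkResult`, the call site unwraps it with `?` inside the match arm while `cast_string_to_timestamp` keeps returning `Ok(cast_array)`. A reduced sketch of that shape, with hypothetical function names:

```rust
// Stand-in for the macro expansion, which now yields a Result.
fn fallible_cast() -> Result<Vec<i64>, String> {
    Ok(vec![1, 2, 3])
}

fn cast_string_to_timestamp_like() -> Result<Vec<i64>, String> {
    // `?` propagates the inner error exactly where the macro is used.
    let cast_array = fallible_cast()?;
    Ok(cast_array)
}

fn main() {
    assert!(cast_string_to_timestamp_like().is_ok());
}
```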
@@ -961,6 +969,12 @@ fn get_timestamp_values<T: TimeZone>(
 ) -> SparkResult<Option<i64>> {
     let values: Vec<_> = value.split(['T', '-', ':', '.']).collect();
     let year = values[0].parse::<i32>().unwrap_or_default();
+
+    // NaiveDate (used internally by chrono's with_ymd_and_hms) is bounded to ±262142.
+    if !(-262143..=262142).contains(&year) {
+        return Ok(None);
+    }
+
     let month = values.get(1).map_or(1, |m| m.parse::<u32>().unwrap_or(1));
     let day = values.get(2).map_or(1, |d| d.parse::<u32>().unwrap_or(1));
     let hour = values.get(3).map_or(0, |h| h.parse::<u32>().unwrap_or(0));
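The guard above turns years that chrono cannot represent into `Ok(None)` (a null result) before `with_ymd_and_hms` ever runs. A small check of the idea, assuming the `chrono` crate; the edge values mirror this PR's range test rather than chrono's documented constants:

```rust
use chrono::NaiveDate;

fn main() {
    // The PR's guard: years outside this range become null instead of an error.
    let in_range = |year: i32| (-262143..=262142).contains(&year);
    assert!(in_range(2020));
    assert!(!in_range(300_000));

    // chrono itself signals unrepresentable dates by returning None.
    assert!(NaiveDate::from_ymd_opt(2020, 1, 1).is_some());
    assert!(NaiveDate::from_ymd_opt(300_000, 1, 1).is_none());
}
```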
@@ -1004,7 +1018,7 @@ fn get_timestamp_values<T: TimeZone>(
             .with_second(second)
             .with_microsecond(microsecond),
         _ => {
-            return Err(SparkError::CastInvalidValue {
+            return Err(SparkError::InvalidInputInCastToDatetime {
                 value: value.to_string(),
                 from_type: "STRING".to_string(),
                 to_type: "TIMESTAMP".to_string(),
@@ -1095,31 +1109,31 @@ fn timestamp_parser<T: TimeZone>(
     // Define regex patterns and corresponding parsing functions
     let patterns = &[
         (
-            Regex::new(r"^\d{4,5}$").unwrap(),
+            Regex::new(r"^\d{4,7}$").unwrap(),
             parse_str_to_year_timestamp as fn(&str, &T) -> SparkResult<Option<i64>>,
         ),
         (
-            Regex::new(r"^\d{4,5}-\d{2}$").unwrap(),
+            Regex::new(r"^\d{4,7}-\d{2}$").unwrap(),
             parse_str_to_month_timestamp,
         ),
         (
-            Regex::new(r"^\d{4,5}-\d{2}-\d{2}$").unwrap(),
+            Regex::new(r"^\d{4,7}-\d{2}-\d{2}$").unwrap(),
             parse_str_to_day_timestamp,
         ),
         (
-            Regex::new(r"^\d{4,5}-\d{2}-\d{2}T\d{1,2}$").unwrap(),
+            Regex::new(r"^\d{4,7}-\d{2}-\d{2}T\d{1,2}$").unwrap(),
             parse_str_to_hour_timestamp,
         ),
         (
-            Regex::new(r"^\d{4,5}-\d{2}-\d{2}T\d{2}:\d{2}$").unwrap(),
+            Regex::new(r"^\d{4,7}-\d{2}-\d{2}T\d{2}:\d{2}$").unwrap(),
             parse_str_to_minute_timestamp,
         ),
         (
-            Regex::new(r"^\d{4,5}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}$").unwrap(),
+            Regex::new(r"^\d{4,7}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}$").unwrap(),
             parse_str_to_second_timestamp,
         ),
         (
-            Regex::new(r"^\d{4,5}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{1,6}$").unwrap(),
+            Regex::new(r"^\d{4,7}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{1,6}$").unwrap(),
             parse_str_to_microsecond_timestamp,
         ),
         (
@@ -1140,7 +1154,7 @@ fn timestamp_parser<T: TimeZone>(
 
     if timestamp.is_none() {
         return if eval_mode == EvalMode::Ansi {
-            Err(SparkError::CastInvalidValue {
+            Err(SparkError::InvalidInputInCastToDatetime {
                 value: value.to_string(),
                 from_type: "STRING".to_string(),
                 to_type: "TIMESTAMP".to_string(),
@@ -1202,17 +1216,20 @@ fn date_parser(date_str: &str, eval_mode: EvalMode) -> SparkResult<Option<i32>>
 }
 
 fn is_valid_digits(segment: i32, digits: usize) -> bool {
-    // An integer is able to represent a date within [+-]5 million years.
+    // NaiveDate is bounded to [-262142, 262142] (6 digits). We allow up to 7 digits to support
+    // leading-zero year strings like "0002020" (= year 2020), matching Spark's
+    // isValidDigits. Values outside the bounds are caught by an explicit bounds
+    // check below.
     let max_digits_year = 7;
-    //year (segment 0) can be between 4 to 7 digits,
-    //month and day (segment 1 and 2) can be between 1 to 2 digits
+    // year (segment 0) can be between 4 to 7 digits,
+    // month and day (segment 1 and 2) can be between 1 to 2 digits
     (segment == 0 && digits >= 4 && digits <= max_digits_year)
         || (segment != 0 && digits > 0 && digits <= 2)
 }
 
 fn return_result(date_str: &str, eval_mode: EvalMode) -> SparkResult<Option<i32>> {
     if eval_mode == EvalMode::Ansi {
-        Err(SparkError::CastInvalidValue {
+        Err(SparkError::InvalidInputInCastToDatetime {
             value: date_str.to_string(),
             from_type: "STRING".to_string(),
             to_type: "DATE".to_string(),
@@ -1285,11 +1302,13 @@ fn date_parser(date_str: &str, eval_mode: EvalMode) -> SparkResult<Option<i32>>
 
         date_segments[current_segment as usize] = current_segment_value.0;
 
-        match NaiveDate::from_ymd_opt(
-            sign * date_segments[0],
-            date_segments[1] as u32,
-            date_segments[2] as u32,
-        ) {
+        // Reject out-of-range years explicitly
+        let year = sign * date_segments[0];
+        if !(-262143..=262142).contains(&year) {
+            return Ok(None);
+        }
+
+        match NaiveDate::from_ymd_opt(year, date_segments[1] as u32, date_segments[2] as u32) {
             Some(date) => {
                 let duration_since_epoch = date
                     .signed_duration_since(DateTime::UNIX_EPOCH.naive_utc().date())
@@ -1341,7 +1360,8 @@ mod tests {
             TimestampMicrosecondType,
             timestamp_parser,
             tz
-        );
+        )
+        .unwrap();
 
         assert_eq!(
             result.data_type(),
@@ -1350,6 +1370,33 @@ mod tests {
         assert_eq!(result.len(), 4);
     }
 
+    #[test]
+    fn test_cast_string_to_timestamp_ansi_error() {
+        // In ANSI mode, an invalid timestamp string must produce an error rather than null.
+        let array: ArrayRef = Arc::new(StringArray::from(vec![
+            Some("2020-01-01T12:34:56.123456"),
+            Some("not_a_timestamp"),
+        ]));
+        let tz = &timezone::Tz::from_str("UTC").unwrap();
+        let string_array = array
+            .as_any()
+            .downcast_ref::<GenericStringArray<i32>>()
+            .expect("Expected a string array");
+
+        let eval_mode = EvalMode::Ansi;
+        let result = cast_utf8_to_timestamp!(
+            &string_array,
+            eval_mode,
+            TimestampMicrosecondType,
+            timestamp_parser,
+            tz
+        );
+        assert!(
+            result.is_err(),
+            "ANSI mode should return Err for an invalid timestamp string"
+        );
+    }
+
     #[test]
     fn test_cast_dict_string_to_timestamp() -> DataFusionResult<()> {
         // prepare input data
native/spark-expr/src/error.rs: 24 additions & 1 deletion
@@ -32,6 +32,19 @@ pub enum SparkError {
         to_type: String,
     },
 
+    /// Like CastInvalidValue but maps to SparkDateTimeException instead of SparkNumberFormatException.
+    /// Used for string → timestamp/date cast failures where Spark throws SparkDateTimeException
+    /// with the CAST_INVALID_INPUT error class.
+    #[error("[CAST_INVALID_INPUT] The value '{value}' of the type \"{from_type}\" cannot be cast to \"{to_type}\" \
+    because it is malformed. Correct the value as per the syntax, or change its target type. \
+    Use `try_cast` to tolerate malformed input and return NULL instead. If necessary \
+    set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")]
+    InvalidInputInCastToDatetime {
+        value: String,
+        from_type: String,
+        to_type: String,
+    },
+
     #[error("[NUMERIC_VALUE_OUT_OF_RANGE.WITH_SUGGESTION] {value} cannot be represented as Decimal({precision}, {scale}). If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error, and return NULL instead.")]
     NumericValueOutOfRange {
         value: String,
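The new variant mirrors `CastInvalidValue`'s fields and derives its `Display` text from the `#[error(...)]` attribute. A reduced sketch assuming the `thiserror` crate that the derive implies, with the message shortened:

```rust
use thiserror::Error;

// Sketch of the new variant in isolation; the real enum has many more variants.
#[derive(Error, Debug)]
enum DemoSparkError {
    #[error("[CAST_INVALID_INPUT] The value '{value}' of the type \"{from_type}\" \
             cannot be cast to \"{to_type}\" because it is malformed.")]
    InvalidInputInCastToDatetime {
        value: String,
        from_type: String,
        to_type: String,
    },
}

fn main() {
    let e = DemoSparkError::InvalidInputInCastToDatetime {
        value: "not_a_timestamp".to_string(),
        from_type: "STRING".to_string(),
        to_type: "TIMESTAMP".to_string(),
    };
    // Display output comes from the #[error] attribute.
    println!("{e}");
}
```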
@@ -199,6 +212,7 @@ impl SparkError {
     fn error_type_name(&self) -> &'static str {
         match self {
             SparkError::CastInvalidValue { .. } => "CastInvalidValue",
+            SparkError::InvalidInputInCastToDatetime { .. } => "InvalidInputInCastToDatetime",
             SparkError::NumericValueOutOfRange { .. } => "NumericValueOutOfRange",
             SparkError::NumericOutOfRange { .. } => "NumericOutOfRange",
             SparkError::CastOverFlow { .. } => "CastOverFlow",
@@ -248,6 +262,11 @@ impl SparkError {
                 value,
                 from_type,
                 to_type,
+            }
+            | SparkError::InvalidInputInCastToDatetime {
+                value,
+                from_type,
+                to_type,
             } => {
                 serde_json::json!({
                     "value": value,
@@ -456,9 +475,12 @@ impl SparkError {
             // CastOverflow gets special handling with CastOverflowException
             SparkError::CastOverFlow { .. } => "org/apache/spark/sql/comet/CastOverflowException",
 
-            // NumberFormatException (for cast invalid input errors)
+            // NumberFormatException (for cast invalid input errors on numeric types)
             SparkError::CastInvalidValue { .. } => "org/apache/spark/SparkNumberFormatException",
 
+            // DateTimeException (for cast invalid input errors on datetime types)
+            SparkError::InvalidInputInCastToDatetime { .. } => "org/apache/spark/SparkDateTimeException",
+
             // ArrayIndexOutOfBoundsException
             SparkError::InvalidArrayIndex { .. }
             | SparkError::InvalidElementAtIndex { .. }
@@ -497,6 +519,7 @@ impl SparkError {
         match self {
             // Cast errors
             SparkError::CastInvalidValue { .. } => Some("CAST_INVALID_INPUT"),
+            SparkError::InvalidInputInCastToDatetime { .. } => Some("CAST_INVALID_INPUT"),
             SparkError::CastOverFlow { .. } => Some("CAST_OVERFLOW"),
             SparkError::NumericValueOutOfRange { .. } => {
                 Some("NUMERIC_VALUE_OUT_OF_RANGE.WITH_SUGGESTION")
@@ -100,7 +100,7 @@ object SparkErrorConverter extends ShimSparkErrorConverter {
       case None => Array.empty[QueryContext] // No context
     }
 
-    val summary: String = errorJson.summary.orNull
+    val summary: String = errorJson.summary.getOrElse("")
 
     // Delegate to version-specific shim - let conversion exceptions propagate
     val optEx = convertErrorType(errorJson.errorType, errorClass, params, sparkContext, summary)
@@ -217,10 +217,7 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {
         Compatible(Some("Only supports years between 262143 BC and 262142 AD"))
       case DataTypes.TimestampType if timeZoneId.exists(tz => tz != "UTC") =>
         Incompatible(Some(s"Cast will use UTC instead of $timeZoneId"))
-      case DataTypes.TimestampType if evalMode == CometEvalMode.ANSI =>
-        Incompatible(Some("ANSI mode not supported"))
       case DataTypes.TimestampType =>
-        // https://github.com/apache/datafusion-comet/issues/328
         Incompatible(Some("Not all valid formats are supported"))
       case _ =>
         unsupported(DataTypes.StringType, toType)
@@ -19,7 +19,7 @@
 
 package org.apache.spark.sql.comet.shims
 
-import org.apache.spark.{QueryContext, SparkException}
+import org.apache.spark.{QueryContext, SparkDateTimeException, SparkException}
 import org.apache.spark.sql.catalyst.trees.SQLQueryContext
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.types._
@@ -164,6 +164,22 @@ trait ShimSparkErrorConverter {
           QueryExecutionErrors
             .invalidInputInCastToNumberError(targetType, str, sqlCtx(context)))
 
+      case "InvalidInputInCastToDatetime" =>
+        val expression =
+          s"'${params("value").toString.replace("\\", "\\\\").replace("'", "\\'")}'"
+        val sourceType = s""""${params("fromType").toString}""""
+        val targetType = s""""${params("toType").toString}""""
+        Some(
+          new SparkDateTimeException(
+            errorClass = "CAST_INVALID_INPUT",
+            messageParameters = Map(
+              "expression" -> expression,
+              "sourceType" -> sourceType,
+              "targetType" -> targetType,
+              "ansiConfig" -> "\"spark.sql.ansi.enabled\""),
+            context = context,
+            summary = summary))
+
       case "CastOverFlow" =>
         val fromType = getDataType(params("fromType").toString)
         val toType = getDataType(params("toType").toString)
@@ -19,7 +19,7 @@
 
 package org.apache.spark.sql.comet.shims
 
-import org.apache.spark.{QueryContext, SparkException}
+import org.apache.spark.{QueryContext, SparkDateTimeException, SparkException}
 import org.apache.spark.sql.catalyst.trees.SQLQueryContext
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.types._
@@ -162,6 +162,22 @@ trait ShimSparkErrorConverter {
           QueryExecutionErrors
             .invalidInputInCastToNumberError(targetType, str, sqlCtx(context)))
 
+      case "InvalidInputInCastToDatetime" =>
+        val expression =
+          s"'${params("value").toString.replace("\\", "\\\\").replace("'", "\\'")}'"
+        val sourceType = s""""${params("fromType").toString}""""
+        val targetType = s""""${params("toType").toString}""""
+        Some(
+          new SparkDateTimeException(
+            errorClass = "CAST_INVALID_INPUT",
+            messageParameters = Map(
+              "expression" -> expression,
+              "sourceType" -> sourceType,
+              "targetType" -> targetType,
+              "ansiConfig" -> "\"spark.sql.ansi.enabled\""),
+            context = context,
+            summary = summary))
+
       case "CastOverFlow" =>
         val fromType = getDataType(params("fromType").toString)
         val toType = getDataType(params("toType").toString)
@@ -175,6 +175,13 @@ trait ShimSparkErrorConverter {
           QueryExecutionErrors
             .invalidInputInCastToNumberError(targetType, str, context.headOption.orNull))
 
+      case "InvalidInputInCastToDatetime" =>
+        val str = UTF8String.fromString(params("value").toString)
+        val targetType = getDataType(params("toType").toString)
+        Some(
+          QueryExecutionErrors
+            .invalidInputInCastToDatetimeError(str, targetType, context.headOption.orNull))
+
       case "CastOverFlow" =>
         val fromType = getDataType(params("fromType").toString)
         val toType = getDataType(params("toType").toString)