diff --git a/.github/workflows/pr_build_linux.yml b/.github/workflows/pr_build_linux.yml index c8ac8adac0..6e3f38743f 100644 --- a/.github/workflows/pr_build_linux.yml +++ b/.github/workflows/pr_build_linux.yml @@ -380,6 +380,7 @@ jobs: org.apache.comet.expressions.conditional.CometIfSuite org.apache.comet.expressions.conditional.CometCoalesceSuite org.apache.comet.expressions.conditional.CometCaseWhenSuite + org.apache.comet.CometRegExpJvmSuite - name: "sql" value: | org.apache.spark.sql.CometToPrettyStringSuite diff --git a/.github/workflows/pr_build_macos.yml b/.github/workflows/pr_build_macos.yml index 3c6953aade..dde953c872 100644 --- a/.github/workflows/pr_build_macos.yml +++ b/.github/workflows/pr_build_macos.yml @@ -232,6 +232,7 @@ jobs: org.apache.comet.expressions.conditional.CometIfSuite org.apache.comet.expressions.conditional.CometCoalesceSuite org.apache.comet.expressions.conditional.CometCaseWhenSuite + org.apache.comet.CometRegExpJvmSuite - name: "sql" value: | org.apache.spark.sql.CometToPrettyStringSuite diff --git a/.gitignore b/.gitignore index a3c97ff992..eed00f3262 100644 --- a/.gitignore +++ b/.gitignore @@ -27,3 +27,4 @@ output docs/comet-*/ docs/build/ docs/temp/ +docs/superpowers/ diff --git a/docs/source/user-guide/latest/compatibility/regex.md b/docs/source/user-guide/latest/compatibility/regex.md index 4d9d5b650c..53e1cd6b6d 100644 --- a/docs/source/user-guide/latest/compatibility/regex.md +++ b/docs/source/user-guide/latest/compatibility/regex.md @@ -19,6 +19,100 @@ under the License. # Regular Expressions -Comet uses the Rust regexp crate for evaluating regular expressions, and this has different behavior from Java's -regular expression engine. Comet will fall back to Spark for patterns that are known to produce different results, but -this can be overridden by setting `spark.comet.expression.regexp.allowIncompatible=true`. 
+Comet provides two regexp engines for evaluating regular expressions: a **Rust engine** that uses the Rust +[`regex`] crate natively, and an experimental **Java engine** that calls back into the JVM via Comet's JVM +UDF framework. The engine is selected with `spark.comet.exec.regexp.engine`, which accepts: + +- `java` (default) — route through the Java engine for full Spark compatibility. Requires + `spark.comet.jvmUdf.enabled=true`; otherwise regex expressions fall back to Spark with an explanatory + message. +- `rust` — run the Rust engine when an expression has a native implementation. Setting this is itself + the opt-in for the semantic differences between Java and Rust regex (no separate `allowIncompatible` + flag needed). Expressions without a native Rust implementation (`regexp_extract`, + `regexp_extract_all`, `regexp_instr`) fall back to Spark. + +The JVM UDF framework is experimental and disabled by default. With pure defaults +(`engine=java`, `jvmUdf.enabled=false`), all regex expressions fall back to Spark. + +## Choosing an engine + +| | Rust engine | Java engine (experimental, default) | +| -------------------- | --------------------------------------- | ------------------------------------------------------------------------------------------------------------------- | +| **Compatibility** | Differs from Java regex (see below) | 100% compatible with Spark | +| **Feature coverage** | `rlike`, `regexp_replace`, `split` only | All regexp expressions (`rlike`, `regexp_extract`, `regexp_extract_all`, `regexp_instr`, `regexp_replace`, `split`) | +| **Performance** | Fully native, no JNI overhead | One JNI round-trip per batch (Arrow vectors stay columnar) | +| **Pattern support** | Linear-time subset only | All Java regex features (backreferences, lookaround, etc.) | + +The **Rust engine** is faster but cannot match Java regex semantics for every pattern. 
Because the engine +choice is itself the opt-in, setting `spark.comet.exec.regexp.engine=rust` declares acceptance of those +differences without a separate per-expression flag. + +The **Java engine** is the default but the underlying JVM UDF framework is experimental and gated behind +`spark.comet.jvmUdf.enabled=true`; the behavior, configuration, and supported expressions may change in +future releases. + +## Why the engines differ + +Java's `java.util.regex` is a backtracking engine in the Perl/PCRE family. It supports the full range of +features that style of engine provides, including some whose worst-case running time grows exponentially with +the input. + +Rust's [`regex`] crate is a finite-automaton engine in the [RE2] family. It deliberately omits features that +cannot be implemented with a guarantee of linear-time matching. In exchange, every pattern it does accept runs +in time linear in the size of the input. This is the same trade-off RE2, Go's `regexp`, and several other +engines make. + +The practical consequence is that Java accepts a strictly larger set of patterns than the Rust engine, and +several constructs that look the same in source have different semantics on the two sides. + +## Features supported by Java but not by the Rust engine + +Patterns that use any of the following will not compile in Comet's Rust engine and must run on Spark (or use +the Java engine): + +- **Backreferences** such as `\1`, `\2`, or `\k<name>`. The Rust engine has no backtracking and cannot match + a previously captured group. +- **Lookaround**, including lookahead (`(?=...)`, `(?!...)`) and lookbehind (`(?<=...)`, `(?<!...)`). +- **Possessive quantifiers** (`*+`, `++`, `?+`, `{n,m}+`). Rust supports greedy and lazy quantifiers but not + possessive. +- **Embedded code, conditionals, and recursion** such as `(?(cond)yes|no)` or `(?R)`. Rust accepts none of + these.
+ +## Features that exist on both sides but behave differently + +Even where both engines accept a construct, the matching behavior is not always the same. + +- **Unicode-aware character classes.** In the Rust engine, `\d`, `\w`, `\s`, and `.` are Unicode-aware by + default, so `\d` matches every digit codepoint defined by Unicode rather than only `0`-`9`. Java's defaults + match ASCII only and require the `UNICODE_CHARACTER_CLASS` flag (or `(?U)` inline) to switch to Unicode + semantics. The same pattern can therefore match a different set of characters on each side. +- **Line terminators.** In multiline mode, Java treats `\r`, `\n`, `\r\n`, and a few additional Unicode line + separators as line boundaries by default. The Rust engine treats only `\n` as a line boundary unless CRLF + mode is enabled. `^`, `$`, and `.` (with `(?s)` off) all depend on this definition. +- **Case-insensitive matching.** Both engines support `(?i)`, but Java's default is ASCII case folding while + the Rust engine uses full Unicode simple case folding when Unicode mode is on. Patterns that match characters + outside ASCII can produce different results. +- **POSIX character classes.** The Rust engine supports `[[:alpha:]]` style POSIX classes inside bracket + expressions but not Java's `\p{Alpha}` shorthand. Java accepts both. Unicode property escapes (`\p{L}`, + `\p{Greek}`, etc.) are supported by both engines but cover slightly different sets of properties. +- **Octal and Unicode escapes.** Java accepts `\0nnn` for octal and `\uXXXX` for a BMP codepoint. Rust uses + `\x{...}` for arbitrary codepoints and does not accept Java's bare `\uXXXX` form. +- **Empty matches in `split`.** Spark's `StringSplit`, which is built on Java's regex, includes leading empty + strings produced by zero-width matches at the start of the input. 
The Rust engine's `split` follows different + rules, so split results can differ in edge cases involving empty matches even when the pattern itself is + identical on both sides. + +## When the Rust engine is safe + +For most ASCII-only, non-anchored patterns that use only literal characters, simple character classes, and +ordinary quantifiers, the two engines produce the same results. If you are confident your patterns fit this +shape and want to avoid the JNI overhead of the Java engine, switching to the Rust engine with +`spark.comet.exec.regexp.engine=rust` is generally safe. + +For anything that uses backreferences, lookaround, or relies on Java's specific Unicode or line-handling +defaults, use the experimental Java engine. + +[`java.util.regex`]: https://docs.oracle.com/javase/8/docs/api/java/util/regex/Pattern.html +[`regex`]: https://docs.rs/regex/latest/regex/ +[RE2]: https://github.com/google/re2/wiki/Syntax diff --git a/native/jni-bridge/src/lib.rs b/native/jni-bridge/src/lib.rs index d72323c961..f95d3cc174 100644 --- a/native/jni-bridge/src/lib.rs +++ b/native/jni-bridge/src/lib.rs @@ -231,7 +231,8 @@ pub struct JVMClasses<'a> { /// acquire & release native memory. pub comet_task_memory_manager: CometTaskMemoryManager<'a>, /// The CometUdfBridge class used to dispatch JVM scalar UDFs. - /// `None` if the class is not on the classpath. + /// `None` if the class is not on the classpath; the JVM-UDF dispatch path + /// reports a clear error rather than crashing executor init. pub comet_udf_bridge: Option>, } @@ -304,6 +305,9 @@ impl JVMClasses<'_> { comet_shuffle_block_iterator: CometShuffleBlockIterator::new(env).unwrap(), comet_task_memory_manager: CometTaskMemoryManager::new(env).unwrap(), comet_udf_bridge: { + // Optional: if the bridge class is absent (e.g. comet shading + // dropped org.apache.comet.udf.*), record None and clear the + // pending JVM exception so other JNI calls keep working.
let bridge = CometUdfBridge::new(env).ok(); if env.exception_check() { env.exception_clear(); diff --git a/native/spark-expr/src/jvm_udf/mod.rs b/native/spark-expr/src/jvm_udf/mod.rs index 4ed25de6ee..99efc0c80b 100644 --- a/native/spark-expr/src/jvm_udf/mod.rs +++ b/native/spark-expr/src/jvm_udf/mod.rs @@ -171,7 +171,8 @@ impl PhysicalExpr for JvmScalarUdfExpr { let bridge = JVMClasses::get().comet_udf_bridge.as_ref().ok_or_else(|| { CometError::from(ExecutionError::GeneralError( "JVM UDF bridge unavailable: org.apache.comet.udf.CometUdfBridge \ - class was not found on the JVM classpath." + class was not found on the JVM classpath. Set \ + spark.comet.exec.regexp.engine=rust to disable this path." .to_string(), )) })?; @@ -237,7 +238,19 @@ impl PhysicalExpr for JvmScalarUdfExpr { // exactly once when the Box drops at end of scope. let result_data = unsafe { from_ffi(*out_array, &out_schema) } .map_err(|e| CometError::Arrow { source: e })?; - Ok(ColumnarValue::Array(make_array(result_data))) + let result_array = make_array(result_data); + + // The JVM may produce arrays with different field names (e.g. Arrow Java's + // ListVector uses "$data$" for child fields) than what DataFusion expects + // (e.g. "item"). Cast to the declared return_type to normalize schema. + let result_array = if result_array.data_type() != &self.return_type { + arrow::compute::cast(&result_array, &self.return_type) + .map_err(|e| CometError::Arrow { source: e })? + } else { + result_array + }; + + Ok(ColumnarValue::Array(result_array)) } fn children(&self) -> Vec<&Arc> { diff --git a/pom.xml b/pom.xml index 7419fecc92..129af7485e 100644 --- a/pom.xml +++ b/pom.xml @@ -1153,6 +1153,7 @@ under the License. 
native/proto/src/generated/** benchmarks/tpc/queries/** .claude/** + docs/superpowers/** diff --git a/spark/src/main/scala/org/apache/comet/CometConf.scala b/spark/src/main/scala/org/apache/comet/CometConf.scala index 1581abccff..eac1ed14ca 100644 --- a/spark/src/main/scala/org/apache/comet/CometConf.scala +++ b/spark/src/main/scala/org/apache/comet/CometConf.scala @@ -381,6 +381,40 @@ object CometConf extends ShimCometConf { .booleanConf .createWithDefault(false) + val COMET_JVM_UDF_ENABLED: ConfigEntry[Boolean] = + conf("spark.comet.jvmUdf.enabled") + .category(CATEGORY_EXEC) + .doc( + "Master switch for the JVM UDF framework, which lets native execution call back into " + + "the JVM to evaluate selected expressions for full Spark compatibility (at the cost " + + "of a JNI round-trip per batch). The framework is experimental and may change in " + + "future releases. Disabled by default; expressions that would otherwise route " + + "through the JVM UDF bridge fall back to native or to Spark while this is false.") + .booleanConf + .createWithDefault(false) + + val REGEXP_ENGINE_RUST = "rust" + val REGEXP_ENGINE_JAVA = "java" + + val COMET_REGEXP_ENGINE: ConfigEntry[String] = + conf("spark.comet.exec.regexp.engine") + .category(CATEGORY_EXEC) + .doc( + "Selects the engine used to evaluate Spark regular-expression expressions. " + + s"`$REGEXP_ENGINE_JAVA` (default) routes through a JVM-side UDF " + + "(java.util.regex.Pattern) for Spark-compatible semantics, at the cost of JNI " + + s"roundtrips per batch; this requires ${COMET_JVM_UDF_ENABLED.key}=true and " + + s"falls back to Spark otherwise. `$REGEXP_ENGINE_RUST` runs the native DataFusion " + + "regexp engine when an implementation exists; setting this is itself the opt-in " + + "for the semantic differences between Java and Rust regex. 
Affected expressions: " + + "rlike, regexp_extract, regexp_extract_all, regexp_replace, regexp_instr, and " + + "split (the extract/instr family has no native Rust path; they fall back to Spark " + + s"under `$REGEXP_ENGINE_RUST`).") + .stringConf + .transform(_.toLowerCase(Locale.ROOT)) + .checkValues(Set(REGEXP_ENGINE_RUST, REGEXP_ENGINE_JAVA)) + .createWithDefault(REGEXP_ENGINE_JAVA) + val COMET_EXEC_SHUFFLE_WITH_HASH_PARTITIONING_ENABLED: ConfigEntry[Boolean] = conf("spark.comet.native.shuffle.partitioning.hash.enabled") .category(CATEGORY_SHUFFLE) diff --git a/spark/src/main/scala/org/apache/comet/GenerateDocs.scala b/spark/src/main/scala/org/apache/comet/GenerateDocs.scala index 870fb5e47d..4fb3c3e35b 100644 --- a/spark/src/main/scala/org/apache/comet/GenerateDocs.scala +++ b/spark/src/main/scala/org/apache/comet/GenerateDocs.scala @@ -303,7 +303,7 @@ object GenerateDocs { annotations += ((fromTypeName, toTypeName, note.trim.replace("(10,2)", ""))) } "C" - case Incompatible(notes) => + case Incompatible(notes, _) => notes.filter(_.trim.nonEmpty).foreach { note => annotations += ((fromTypeName, toTypeName, note.trim.replace("(10,2)", ""))) } diff --git a/spark/src/main/scala/org/apache/comet/expressions/RegExp.scala b/spark/src/main/scala/org/apache/comet/expressions/RegExp.scala deleted file mode 100644 index 8dc0c828c7..0000000000 --- a/spark/src/main/scala/org/apache/comet/expressions/RegExp.scala +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.comet.expressions - -object RegExp { - - /** Determine whether the regexp pattern is supported natively and compatible with Spark */ - def isSupportedPattern(pattern: String): Boolean = { - // this is a placeholder for implementing logic to determine if the pattern - // is known to be compatible with Spark, so that we can enable regexp automatically - // for common cases and fallback to Spark for more complex cases - false - } - -} diff --git a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala index 72c2bea9e4..d3062d03de 100644 --- a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala +++ b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala @@ -709,7 +709,7 @@ case class CometExecRule(session: SparkSession) case Unsupported(notes) => withInfo(op, notes.getOrElse("")) false - case Incompatible(notes) => + case Incompatible(notes, _) => val allowIncompat = CometConf.isOperatorAllowIncompat(opName) val incompatConf = CometConf.getOperatorAllowIncompatConfigKey(opName) if (allowIncompat) { diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index d85a2c30cb..d48c5ffac4 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -182,6 +182,9 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim { classOf[Like] -> 
CometLike, classOf[Lower] -> CometLower, classOf[OctetLength] -> CometScalarFunction("octet_length"), + classOf[RegExpExtract] -> CometRegExpExtract, + classOf[RegExpExtractAll] -> CometRegExpExtractAll, + classOf[RegExpInStr] -> CometRegExpInStr, classOf[RegExpReplace] -> CometRegExpReplace, classOf[Reverse] -> CometReverse, classOf[RLike] -> CometRLike, @@ -527,23 +530,29 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim { case Unsupported(notes) => withInfo(fn, notes.getOrElse("")) None - case Incompatible(notes) => + case Incompatible(notes, optedInBy) => val exprAllowIncompat = CometConf.isExprAllowIncompat(exprConfName) - if (exprAllowIncompat) { + val namedConfOptIn = optedInBy.exists(isOptedInVia) + if (exprAllowIncompat || namedConfOptIn) { if (notes.isDefined) { - logWarning( - s"Comet supports $fn when " + - s"${CometConf.getExprAllowIncompatConfigKey(exprConfName)}=true " + - s"but has notes: ${notes.get}") + val optInDesc = if (namedConfOptIn) { + optedInBy.get + } else { + s"${CometConf.getExprAllowIncompatConfigKey(exprConfName)}=true" + } + logWarning(s"Comet supports $fn when $optInDesc but has notes: ${notes.get}") } aggHandler.convert(aggExpr, fn, inputs, binding, conf) } else { val optionalNotes = notes.map(str => s" ($str)").getOrElse("") + val extraOptIn = optedInBy + .map(kv => s" or by setting $kv") + .getOrElse("") withInfo( fn, s"$fn is not fully compatible with Spark$optionalNotes. To enable it anyway, " + - s"set ${CometConf.getExprAllowIncompatConfigKey(exprConfName)}=true. " + - s"${CometConf.COMPAT_GUIDE}.") + s"set ${CometConf.getExprAllowIncompatConfigKey(exprConfName)}=true" + + s"$extraOptIn. ${CometConf.COMPAT_GUIDE}.") None } case Compatible(notes) => @@ -619,6 +628,21 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim { exprToProtoInternal(newExpr, inputs, binding) } + /** + * True when the current SQLConf has the named config set to the given value. 
The argument is a + * `key=value` string used by `Incompatible.optedInBy` to declare which config opts the user + * into running an otherwise-incompatible expression. The configured value is compared + * case-insensitively after splitting on the first `=`. + */ + private def isOptedInVia(keyEqualsValue: String): Boolean = { + keyEqualsValue.split("=", 2) match { + case Array(key, expected) => + Option(SQLConf.get.getConfString(key, null)) + .exists(_.equalsIgnoreCase(expected)) + case _ => false + } + } + /** * Convert a Spark expression to a protocol-buffer representation of a native Comet/DataFusion * expression. @@ -652,23 +676,29 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim { case Unsupported(notes) => withInfo(expr, notes.getOrElse("")) None - case Incompatible(notes) => + case Incompatible(notes, optedInBy) => val exprAllowIncompat = CometConf.isExprAllowIncompat(exprConfName) - if (exprAllowIncompat) { + val namedConfOptIn = optedInBy.exists(isOptedInVia) + if (exprAllowIncompat || namedConfOptIn) { if (notes.isDefined) { - logWarning( - s"Comet supports $expr when " + - s"${CometConf.getExprAllowIncompatConfigKey(exprConfName)}=true " + - s"but has notes: ${notes.get}") + val optInDesc = if (namedConfOptIn) { + optedInBy.get + } else { + s"${CometConf.getExprAllowIncompatConfigKey(exprConfName)}=true" + } + logWarning(s"Comet supports $expr when $optInDesc but has notes: ${notes.get}") } handler.convert(expr, inputs, binding) } else { val optionalNotes = notes.map(str => s" ($str)").getOrElse("") + val extraOptIn = optedInBy + .map(kv => s" or by setting $kv") + .getOrElse("") withInfo( expr, s"$expr is not fully compatible with Spark$optionalNotes. To enable it anyway, " + - s"set ${CometConf.getExprAllowIncompatConfigKey(exprConfName)}=true. " + - s"${CometConf.COMPAT_GUIDE}.") + s"set ${CometConf.getExprAllowIncompatConfigKey(exprConfName)}=true" + + s"$extraOptIn. 
${CometConf.COMPAT_GUIDE}.") None } case Compatible(notes) => diff --git a/spark/src/main/scala/org/apache/comet/serde/SupportLevel.scala b/spark/src/main/scala/org/apache/comet/serde/SupportLevel.scala index cb78c7d2d4..4fc1a31337 100644 --- a/spark/src/main/scala/org/apache/comet/serde/SupportLevel.scala +++ b/spark/src/main/scala/org/apache/comet/serde/SupportLevel.scala @@ -37,8 +37,17 @@ case class Compatible(notes: Option[String] = None) extends SupportLevel * * Any compatibility differences are noted in the * [[https://datafusion.apache.org/comet/user-guide/compatibility.html Comet Compatibility Guide]]. + * + * @param notes + * Optional human-readable notes about the incompatibility. + * @param optedInBy + * Optional `key=value` pair naming a SQLConf entry that, when set to `value`, opts the user + * into running this expression despite the incompatibility, in addition to the per-expression + * `spark.comet.expression.<name>.allowIncompatible` flag. Use this when a broader config (for + * example, an engine selector) already encodes the user's acceptance of the trade-off.
*/ -case class Incompatible(notes: Option[String] = None) extends SupportLevel +case class Incompatible(notes: Option[String] = None, optedInBy: Option[String] = None) + extends SupportLevel /** Comet does not support this feature */ case class Unsupported(notes: Option[String] = None) extends SupportLevel diff --git a/spark/src/main/scala/org/apache/comet/serde/strings.scala b/spark/src/main/scala/org/apache/comet/serde/strings.scala index aec4b19111..6db99a0a45 100644 --- a/spark/src/main/scala/org/apache/comet/serde/strings.scala +++ b/spark/src/main/scala/org/apache/comet/serde/strings.scala @@ -21,15 +21,76 @@ package org.apache.comet.serde import java.util.Locale -import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Concat, ConcatWs, Expression, GetJsonObject, If, InitCap, IsNull, Left, Length, Like, Literal, Lower, RegExpReplace, Right, RLike, StringLPad, StringRepeat, StringRPad, StringSplit, Substring, SubstringIndex, Upper} -import org.apache.spark.sql.types.{BinaryType, DataTypes, LongType, StringType} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Concat, ConcatWs, Expression, GetJsonObject, If, InitCap, IsNull, Left, Length, Like, Literal, Lower, RegExpExtract, RegExpExtractAll, RegExpInStr, RegExpReplace, Right, RLike, StringLPad, StringRepeat, StringRPad, StringSplit, Substring, SubstringIndex, Upper} +import org.apache.spark.sql.types.{ArrayType, BinaryType, DataTypes, IntegerType, LongType, StringType} import org.apache.spark.unsafe.types.UTF8String import org.apache.comet.CometConf import org.apache.comet.CometSparkSessionExtensions.withInfo -import org.apache.comet.expressions.{CometCast, CometEvalMode, RegExp} +import org.apache.comet.expressions.{CometCast, CometEvalMode} import org.apache.comet.serde.ExprOuterClass.Expr -import org.apache.comet.serde.QueryPlanSerde.{createBinaryExpr, exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto, scalarFunctionExprToProtoWithReturnType} +import 
org.apache.comet.serde.QueryPlanSerde.{createBinaryExpr, exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto, scalarFunctionExprToProtoWithReturnType, serializeDataType} + +/** + * Routing decision for a regex expression. Each regex serde delegates to [[RegexpRoute.choose]] + * to pick between the native Rust regex engine, the JVM UDF regex engine, or Spark fallback. + */ +private sealed trait RegexpRoute +private object RegexpRoute { + + /** Run the native Rust regex implementation. */ + case object Native extends RegexpRoute + + /** Run the JVM-side UDF regex implementation. */ + case object JvmUdf extends RegexpRoute + + /** Decline to run natively; the operator falls back to Spark with the given reason. */ + case class Fallback(reason: String) extends RegexpRoute + + /** + * SupportLevel returned by serdes whose route is [[Native]]. Surfaced as `Incompatible` so the + * standard gating in `QueryPlanSerde` can recognize `spark.comet.exec.regexp.engine=rust` as + * the opt-in (via `optedInBy`) without each serde repeating the check. + */ + val nativeIncompatible: Incompatible = + Incompatible( + Some("Rust regexp engine has different semantics from Java regexp"), + Some(s"${CometConf.COMET_REGEXP_ENGINE.key}=${CometConf.REGEXP_ENGINE_RUST}")) + + /** + * Pick a route given the user's config and whether a native Rust implementation exists for the + * expression. `engine=java` (default) routes to the JVM UDF if the master switch is on; else + * Spark fallback. `engine=rust` runs native if available; else Spark fallback. + */ + def choose(exprName: String, hasNative: Boolean): RegexpRoute = { + val engine = CometConf.COMET_REGEXP_ENGINE.get() + val jvmUdfEnabled = CometConf.COMET_JVM_UDF_ENABLED.get() + + engine match { + case CometConf.REGEXP_ENGINE_RUST => + if (hasNative) { + Native + } else { + Fallback( + s"$exprName has no native Rust implementation. 
Set " + + s"${CometConf.COMET_REGEXP_ENGINE.key}=${CometConf.REGEXP_ENGINE_JAVA} with " + + s"${CometConf.COMET_JVM_UDF_ENABLED.key}=true to use the JVM regex engine.") + } + + case CometConf.REGEXP_ENGINE_JAVA => + if (jvmUdfEnabled) { + JvmUdf + } else { + Fallback( + s"$exprName requires ${CometConf.COMET_JVM_UDF_ENABLED.key}=true when " + + s"${CometConf.COMET_REGEXP_ENGINE.key}=${CometConf.REGEXP_ENGINE_JAVA}. " + + "The JVM UDF framework is experimental and disabled by default.") + } + + case other => Fallback(s"Unknown ${CometConf.COMET_REGEXP_ENGINE.key}=$other") + } + } +} object CometStringRepeat extends CometExpressionSerde[StringRepeat] { @@ -279,29 +340,83 @@ object CometLike extends CometExpressionSerde[Like] { object CometRLike extends CometExpressionSerde[RLike] { - override def getIncompatibleReasons(): Seq[String] = Seq( - "Uses Rust regexp engine, which has different behavior to Java regexp engine") + override def getSupportLevel(expr: RLike): SupportLevel = { + expr.right match { + case _: Literal => + RegexpRoute.choose("rlike", hasNative = true) match { + case RegexpRoute.Native => RegexpRoute.nativeIncompatible + case RegexpRoute.JvmUdf => Compatible(None) + case RegexpRoute.Fallback(reason) => Unsupported(Some(reason)) + } + case _ => Unsupported(Some("Only scalar regexp patterns are supported")) + } + } override def convert(expr: RLike, inputs: Seq[Attribute], binding: Boolean): Option[Expr] = { + RegexpRoute.choose("rlike", hasNative = true) match { + case RegexpRoute.Native => convertViaNativeRegex(expr, inputs, binding) + case RegexpRoute.JvmUdf => convertViaJvmUdf(expr, inputs, binding) + case RegexpRoute.Fallback(reason) => + withInfo(expr, reason) + None + } + } + + private def convertViaNativeRegex( + expr: RLike, + inputs: Seq[Attribute], + binding: Boolean): Option[Expr] = { expr.right match { - case Literal(pattern, DataTypes.StringType) => - if (!RegExp.isSupportedPattern(pattern.toString) && - 
!CometConf.isExprAllowIncompat("regexp")) { - withInfo( - expr, - s"Regexp pattern $pattern is not compatible with Spark. " + - s"Set ${CometConf.getExprAllowIncompatConfigKey("regexp")}=true " + - "to allow it anyway.") - None - } else { - createBinaryExpr( - expr, - expr.left, - expr.right, - inputs, - binding, - (builder, binaryExpr) => builder.setRlike(binaryExpr)) + case Literal(_, DataTypes.StringType) => + createBinaryExpr( + expr, + expr.left, + expr.right, + inputs, + binding, + (builder, binaryExpr) => builder.setRlike(binaryExpr)) + case _ => + withInfo(expr, "Only scalar regexp patterns are supported") + None + } + } + + private def convertViaJvmUdf( + expr: RLike, + inputs: Seq[Attribute], + binding: Boolean): Option[Expr] = { + expr.right match { + case Literal(value, DataTypes.StringType) => + if (value == null) { + withInfo(expr, "Null literal pattern is handled by Spark fallback") + return None + } + val patternStr = value.toString + try { + java.util.regex.Pattern.compile(patternStr) + } catch { + case e: java.util.regex.PatternSyntaxException => + withInfo(expr, s"Invalid regex pattern: ${e.getDescription}") + return None } + val subjectProto = exprToProtoInternal(expr.left, inputs, binding) + val patternProto = exprToProtoInternal(expr.right, inputs, binding) + if (subjectProto.isEmpty || patternProto.isEmpty) { + return None + } + val returnType = serializeDataType(DataTypes.BooleanType).getOrElse(return None) + val udfBuilder = ExprOuterClass.JvmScalarUdf + .newBuilder() + .setClassName("org.apache.comet.udf.RegExpLikeUDF") + .addArgs(subjectProto.get) + .addArgs(patternProto.get) + .setReturnType(returnType) + .setReturnNullable(expr.nullable) + Some( + ExprOuterClass.Expr + .newBuilder() + .setJvmScalarUdf(udfBuilder.build()) + .build()) case _ => withInfo(expr, "Only scalar regexp patterns are supported") None @@ -309,6 +424,188 @@ object CometRLike extends CometExpressionSerde[RLike] { } } +object CometRegExpExtract extends 
CometExpressionSerde[RegExpExtract] { + + override def getSupportLevel(expr: RegExpExtract): SupportLevel = { + (expr.regexp, expr.idx) match { + case (_: Literal, _: Literal) => + RegexpRoute.choose("regexp_extract", hasNative = false) match { + case RegexpRoute.JvmUdf => Compatible(None) + case RegexpRoute.Fallback(reason) => Unsupported(Some(reason)) + case RegexpRoute.Native => Unsupported(Some("regexp_extract has no native impl")) + } + case (_: Literal, _) => + Unsupported(Some("Only scalar group index is supported")) + case _ => Unsupported(Some("Only scalar regexp patterns are supported")) + } + } + + override def convert( + expr: RegExpExtract, + inputs: Seq[Attribute], + binding: Boolean): Option[Expr] = { + (expr.regexp, expr.idx) match { + case (Literal(pattern, DataTypes.StringType), Literal(_, _: IntegerType)) => + if (pattern == null) { + withInfo(expr, "Null literal pattern is handled by Spark fallback") + return None + } + try { + java.util.regex.Pattern.compile(pattern.toString) + } catch { + case e: java.util.regex.PatternSyntaxException => + withInfo(expr, s"Invalid regex pattern: ${e.getDescription}") + return None + } + val subjectProto = exprToProtoInternal(expr.subject, inputs, binding) + val patternProto = exprToProtoInternal(expr.regexp, inputs, binding) + val idxProto = exprToProtoInternal(expr.idx, inputs, binding) + if (subjectProto.isEmpty || patternProto.isEmpty || idxProto.isEmpty) { + return None + } + val returnType = serializeDataType(DataTypes.StringType).getOrElse(return None) + val udfBuilder = ExprOuterClass.JvmScalarUdf + .newBuilder() + .setClassName("org.apache.comet.udf.RegExpExtractUDF") + .addArgs(subjectProto.get) + .addArgs(patternProto.get) + .addArgs(idxProto.get) + .setReturnType(returnType) + .setReturnNullable(expr.nullable) + Some( + ExprOuterClass.Expr + .newBuilder() + .setJvmScalarUdf(udfBuilder.build()) + .build()) + case _ => + withInfo(expr, "Only scalar regexp patterns and group index are supported") + 
None + } + } +} + +object CometRegExpExtractAll extends CometExpressionSerde[RegExpExtractAll] { + + override def getSupportLevel(expr: RegExpExtractAll): SupportLevel = { + (expr.regexp, expr.idx) match { + case (_: Literal, _: Literal) => + RegexpRoute.choose("regexp_extract_all", hasNative = false) match { + case RegexpRoute.JvmUdf => Compatible(None) + case RegexpRoute.Fallback(reason) => Unsupported(Some(reason)) + case RegexpRoute.Native => + Unsupported(Some("regexp_extract_all has no native impl")) + } + case (_: Literal, _) => + Unsupported(Some("Only scalar group index is supported")) + case _ => Unsupported(Some("Only scalar regexp patterns are supported")) + } + } + + override def convert( + expr: RegExpExtractAll, + inputs: Seq[Attribute], + binding: Boolean): Option[Expr] = { + (expr.regexp, expr.idx) match { + case (Literal(pattern, DataTypes.StringType), Literal(_, _: IntegerType)) => + if (pattern == null) { + withInfo(expr, "Null literal pattern is handled by Spark fallback") + return None + } + try { + java.util.regex.Pattern.compile(pattern.toString) + } catch { + case e: java.util.regex.PatternSyntaxException => + withInfo(expr, s"Invalid regex pattern: ${e.getDescription}") + return None + } + val subjectProto = exprToProtoInternal(expr.subject, inputs, binding) + val patternProto = exprToProtoInternal(expr.regexp, inputs, binding) + val idxProto = exprToProtoInternal(expr.idx, inputs, binding) + if (subjectProto.isEmpty || patternProto.isEmpty || idxProto.isEmpty) { + return None + } + val returnType = + serializeDataType(ArrayType(StringType, containsNull = true)).getOrElse(return None) + val udfBuilder = ExprOuterClass.JvmScalarUdf + .newBuilder() + .setClassName("org.apache.comet.udf.RegExpExtractAllUDF") + .addArgs(subjectProto.get) + .addArgs(patternProto.get) + .addArgs(idxProto.get) + .setReturnType(returnType) + .setReturnNullable(expr.nullable) + Some( + ExprOuterClass.Expr + .newBuilder() + .setJvmScalarUdf(udfBuilder.build()) + 
.build()) + case _ => + withInfo(expr, "Only scalar regexp patterns and group index are supported") + None + } + } +} + +object CometRegExpInStr extends CometExpressionSerde[RegExpInStr] { + + override def getSupportLevel(expr: RegExpInStr): SupportLevel = { + (expr.regexp, expr.idx) match { + case (_: Literal, _: Literal) => + RegexpRoute.choose("regexp_instr", hasNative = false) match { + case RegexpRoute.JvmUdf => Compatible(None) + case RegexpRoute.Fallback(reason) => Unsupported(Some(reason)) + case RegexpRoute.Native => Unsupported(Some("regexp_instr has no native impl")) + } + case (_: Literal, _) => + Unsupported(Some("Only scalar group index is supported")) + case _ => Unsupported(Some("Only scalar regexp patterns are supported")) + } + } + + override def convert( + expr: RegExpInStr, + inputs: Seq[Attribute], + binding: Boolean): Option[Expr] = { + (expr.regexp, expr.idx) match { + case (Literal(pattern, DataTypes.StringType), Literal(_, _: IntegerType)) => + if (pattern == null) { + withInfo(expr, "Null literal pattern is handled by Spark fallback") + return None + } + try { + java.util.regex.Pattern.compile(pattern.toString) + } catch { + case e: java.util.regex.PatternSyntaxException => + withInfo(expr, s"Invalid regex pattern: ${e.getDescription}") + return None + } + val subjectProto = exprToProtoInternal(expr.subject, inputs, binding) + val patternProto = exprToProtoInternal(expr.regexp, inputs, binding) + val idxProto = exprToProtoInternal(expr.idx, inputs, binding) + if (subjectProto.isEmpty || patternProto.isEmpty || idxProto.isEmpty) { + return None + } + val returnType = serializeDataType(DataTypes.IntegerType).getOrElse(return None) + val udfBuilder = ExprOuterClass.JvmScalarUdf + .newBuilder() + .setClassName("org.apache.comet.udf.RegExpInStrUDF") + .addArgs(subjectProto.get) + .addArgs(patternProto.get) + .addArgs(idxProto.get) + .setReturnType(returnType) + .setReturnNullable(expr.nullable) + Some( + ExprOuterClass.Expr + .newBuilder() + 
.setJvmScalarUdf(udfBuilder.build()) + .build()) + case _ => + withInfo(expr, "Only scalar regexp patterns and group index are supported") + None + } + } +} + object CometStringRPad extends CometExpressionSerde[StringRPad] { override def getUnsupportedReasons(): Seq[String] = Seq( @@ -367,24 +664,22 @@ object CometStringLPad extends CometExpressionSerde[StringLPad] { } object CometRegExpReplace extends CometExpressionSerde[RegExpReplace] { - override def getIncompatibleReasons(): Seq[String] = Seq( - "Regexp pattern may not be compatible with Spark") override def getUnsupportedReasons(): Seq[String] = Seq( "Only supports `regexp_replace` with an offset of 1 (no offset)") override def getSupportLevel(expr: RegExpReplace): SupportLevel = { - if (!RegExp.isSupportedPattern(expr.regexp.toString) && - !CometConf.isExprAllowIncompat("regexp")) { - withInfo( - expr, - s"Regexp pattern ${expr.regexp} is not compatible with Spark. " + - s"Set ${CometConf.getExprAllowIncompatConfigKey("regexp")}=true " + - "to allow it anyway.") - return Incompatible() - } expr.pos match { - case Literal(value, DataTypes.IntegerType) if value == 1 => Compatible() + case Literal(value, DataTypes.IntegerType) if value == 1 => + expr.regexp match { + case _: Literal => + RegexpRoute.choose("regexp_replace", hasNative = true) match { + case RegexpRoute.Native => RegexpRoute.nativeIncompatible + case RegexpRoute.JvmUdf => Compatible(None) + case RegexpRoute.Fallback(reason) => Unsupported(Some(reason)) + } + case _ => Unsupported(Some("Only scalar regexp patterns are supported")) + } case _ => Unsupported(Some("Comet only supports regexp_replace with an offset of 1 (no offset).")) } @@ -394,6 +689,19 @@ object CometRegExpReplace extends CometExpressionSerde[RegExpReplace] { expr: RegExpReplace, inputs: Seq[Attribute], binding: Boolean): Option[Expr] = { + RegexpRoute.choose("regexp_replace", hasNative = true) match { + case RegexpRoute.Native => convertViaNativeRegex(expr, inputs, binding) + case 
RegexpRoute.JvmUdf => convertViaJvmUdf(expr, inputs, binding) + case RegexpRoute.Fallback(reason) => + withInfo(expr, reason) + None + } + } + + private def convertViaNativeRegex( + expr: RegExpReplace, + inputs: Seq[Attribute], + binding: Boolean): Option[Expr] = { val subjectExpr = exprToProtoInternal(expr.subject, inputs, binding) val patternExpr = exprToProtoInternal(expr.regexp, inputs, binding) val replacementExpr = exprToProtoInternal(expr.rep, inputs, binding) @@ -408,6 +716,49 @@ object CometRegExpReplace extends CometExpressionSerde[RegExpReplace] { flagsExpr) optExprWithInfo(optExpr, expr, expr.subject, expr.regexp, expr.rep, expr.pos) } + + private def convertViaJvmUdf( + expr: RegExpReplace, + inputs: Seq[Attribute], + binding: Boolean): Option[Expr] = { + expr.regexp match { + case Literal(pattern, DataTypes.StringType) => + if (pattern == null) { + withInfo(expr, "Null literal pattern is handled by Spark fallback") + return None + } + try { + java.util.regex.Pattern.compile(pattern.toString) + } catch { + case e: java.util.regex.PatternSyntaxException => + withInfo(expr, s"Invalid regex pattern: ${e.getDescription}") + return None + } + val subjectProto = exprToProtoInternal(expr.subject, inputs, binding) + val patternProto = exprToProtoInternal(expr.regexp, inputs, binding) + val repProto = exprToProtoInternal(expr.rep, inputs, binding) + if (subjectProto.isEmpty || patternProto.isEmpty || repProto.isEmpty) { + return None + } + val returnType = serializeDataType(DataTypes.StringType).getOrElse(return None) + val udfBuilder = ExprOuterClass.JvmScalarUdf + .newBuilder() + .setClassName("org.apache.comet.udf.RegExpReplaceUDF") + .addArgs(subjectProto.get) + .addArgs(patternProto.get) + .addArgs(repProto.get) + .setReturnType(returnType) + .setReturnNullable(expr.nullable) + Some( + ExprOuterClass.Expr + .newBuilder() + .setJvmScalarUdf(udfBuilder.build()) + .build()) + case _ => + withInfo(expr, "Only scalar regexp patterns are supported") + None + } 
+ } } /** @@ -417,16 +768,35 @@ object CometRegExpReplace extends CometExpressionSerde[RegExpReplace] { */ object CometStringSplit extends CometExpressionSerde[StringSplit] { - override def getIncompatibleReasons(): Seq[String] = Seq( - "Regex engine differences between Java and Rust") - - override def getSupportLevel(expr: StringSplit): SupportLevel = - Incompatible(Some("Regex engine differences between Java and Rust")) + override def getSupportLevel(expr: StringSplit): SupportLevel = { + expr.regex match { + case _: Literal => + RegexpRoute.choose("split", hasNative = true) match { + case RegexpRoute.Native => RegexpRoute.nativeIncompatible + case RegexpRoute.JvmUdf => Compatible(None) + case RegexpRoute.Fallback(reason) => Unsupported(Some(reason)) + } + case _ => Unsupported(Some("Only scalar regex patterns are supported")) + } + } override def convert( expr: StringSplit, inputs: Seq[Attribute], binding: Boolean): Option[Expr] = { + RegexpRoute.choose("split", hasNative = true) match { + case RegexpRoute.Native => convertViaNativeRegex(expr, inputs, binding) + case RegexpRoute.JvmUdf => convertViaJvmUdf(expr, inputs, binding) + case RegexpRoute.Fallback(reason) => + withInfo(expr, reason) + None + } + } + + private def convertViaNativeRegex( + expr: StringSplit, + inputs: Seq[Attribute], + binding: Boolean): Option[Expr] = { val strExpr = exprToProtoInternal(expr.str, inputs, binding) val regexExpr = exprToProtoInternal(expr.regex, inputs, binding) val limitExpr = exprToProtoInternal(expr.limit, inputs, binding) @@ -439,6 +809,50 @@ object CometStringSplit extends CometExpressionSerde[StringSplit] { limitExpr) optExprWithInfo(optExpr, expr, expr.str, expr.regex, expr.limit) } + + private def convertViaJvmUdf( + expr: StringSplit, + inputs: Seq[Attribute], + binding: Boolean): Option[Expr] = { + expr.regex match { + case Literal(pattern, DataTypes.StringType) => + if (pattern == null) { + withInfo(expr, "Null literal pattern is handled by Spark fallback") + 
return None + } + try { + java.util.regex.Pattern.compile(pattern.toString) + } catch { + case e: java.util.regex.PatternSyntaxException => + withInfo(expr, s"Invalid regex pattern: ${e.getDescription}") + return None + } + val strProto = exprToProtoInternal(expr.str, inputs, binding) + val regexProto = exprToProtoInternal(expr.regex, inputs, binding) + val limitProto = exprToProtoInternal(expr.limit, inputs, binding) + if (strProto.isEmpty || regexProto.isEmpty || limitProto.isEmpty) { + return None + } + val returnType = + serializeDataType(ArrayType(StringType, containsNull = false)).getOrElse(return None) + val udfBuilder = ExprOuterClass.JvmScalarUdf + .newBuilder() + .setClassName("org.apache.comet.udf.StringSplitUDF") + .addArgs(strProto.get) + .addArgs(regexProto.get) + .addArgs(limitProto.get) + .setReturnType(returnType) + .setReturnNullable(expr.nullable) + Some( + ExprOuterClass.Expr + .newBuilder() + .setJvmScalarUdf(udfBuilder.build()) + .build()) + case _ => + withInfo(expr, "Only scalar regex patterns are supported") + None + } + } } object CometGetJsonObject extends CometExpressionSerde[GetJsonObject] { diff --git a/spark/src/main/scala/org/apache/comet/udf/RegExpExtractAllUDF.scala b/spark/src/main/scala/org/apache/comet/udf/RegExpExtractAllUDF.scala new file mode 100644 index 0000000000..ddcb44a6a9 --- /dev/null +++ b/spark/src/main/scala/org/apache/comet/udf/RegExpExtractAllUDF.scala @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.udf + +import java.nio.charset.StandardCharsets +import java.util.concurrent.ConcurrentHashMap +import java.util.regex.Pattern + +import org.apache.arrow.vector.{IntVector, ValueVector, VarCharVector} +import org.apache.arrow.vector.complex.ListVector +import org.apache.arrow.vector.types.pojo.{ArrowType, FieldType} + +import org.apache.comet.CometArrowAllocator + +/** + * `regexp_extract_all(subject, pattern, idx)` implemented with java.util.regex.Pattern. + * + * Returns an array of strings: for every match of pattern in subject, extracts the idx-th + * capturing group. idx=0 returns the entire match. + * + * Inputs: + * - inputs(0): VarCharVector subject column + * - inputs(1): VarCharVector pattern (scalar, length-1) + * - inputs(2): IntVector group index (scalar, length-1) + * + * Output: ListVector of VarChar, same length as subject. 
+ */
+class RegExpExtractAllUDF extends CometUDF {
+
+  // Compiled patterns keyed by pattern text, so later batches skip recompilation.
+  private val patternCache = new ConcurrentHashMap[String, Pattern]()
+
+  /**
+   * Collects, for each non-null subject row, group `idx` of every match of the scalar pattern.
+   *
+   * A group index above the pattern's group count, or a group that did not take part in a
+   * match, contributes an empty string for that match. Null subject rows yield null lists.
+   *
+   * @param inputs
+   *   subject vector, scalar pattern vector and scalar group-index vector, in that order
+   * @param numRows
+   *   not read; the row count is taken from the subject vector
+   * @return
+   *   a list-of-strings vector with one entry per subject row
+   * @throws IllegalArgumentException
+   *   if the arity is wrong or the pattern / group index is absent or null
+   */
+  override def evaluate(inputs: Array[ValueVector], numRows: Int): ValueVector = {
+    require(inputs.length == 3, s"RegExpExtractAllUDF expects 3 inputs, got ${inputs.length}")
+    val subject = inputs(0).asInstanceOf[VarCharVector]
+    val patternVec = inputs(1).asInstanceOf[VarCharVector]
+    val idxVec = inputs(2).asInstanceOf[IntVector]
+    require(
+      patternVec.getValueCount >= 1 && !patternVec.isNull(0),
+      "RegExpExtractAllUDF requires a non-null scalar pattern")
+    require(
+      idxVec.getValueCount >= 1 && !idxVec.isNull(0),
+      "RegExpExtractAllUDF requires a non-null scalar group index")
+
+    val patternStr = new String(patternVec.get(0), StandardCharsets.UTF_8)
+    val pattern = patternCache.computeIfAbsent(patternStr, Pattern.compile)
+    val idx = idxVec.get(0)
+
+    val n = subject.getValueCount
+    val out = ListVector.empty("regexp_extract_all_result", CometArrowAllocator)
+    // `out` only escapes through the return value, so it must be released here if the batch
+    // fails part-way (for example when Matcher.group rejects a negative group index).
+    var completed = false
+    try {
+      out.addOrGetVector[VarCharVector](new FieldType(true, ArrowType.Utf8.INSTANCE, null))
+      val writer = out.getWriter
+
+      var i = 0
+      while (i < n) {
+        if (subject.isNull(i)) {
+          out.setNull(i)
+        } else {
+          val matcher = pattern.matcher(new String(subject.get(i), StandardCharsets.UTF_8))
+          writer.setPosition(i)
+          writer.startList()
+          while (matcher.find()) {
+            // Out-of-range index and non-participating group both collapse to "".
+            val group = if (idx <= matcher.groupCount()) Option(matcher.group(idx)) else None
+            val bytes = group.getOrElse("").getBytes(StandardCharsets.UTF_8)
+            val buf = CometArrowAllocator.buffer(bytes.length)
+            try {
+              buf.writeBytes(bytes)
+              writer.varChar().writeVarChar(0, bytes.length, buf)
+            } finally {
+              // Release the scratch buffer even when the write fails.
+              buf.close()
+            }
+          }
+          writer.endList()
+        }
+        i += 1
+      }
+      out.setValueCount(n)
+      completed = true
+      out
+    } finally {
+      if (!completed) {
+        out.close()
+      }
+    }
+  }
+}
diff
--git a/spark/src/main/scala/org/apache/comet/udf/RegExpExtractUDF.scala b/spark/src/main/scala/org/apache/comet/udf/RegExpExtractUDF.scala new file mode 100644 index 0000000000..9e056557a1 --- /dev/null +++ b/spark/src/main/scala/org/apache/comet/udf/RegExpExtractUDF.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.udf + +import java.nio.charset.StandardCharsets +import java.util.concurrent.ConcurrentHashMap +import java.util.regex.Pattern + +import org.apache.arrow.vector.{IntVector, ValueVector, VarCharVector} + +import org.apache.comet.CometArrowAllocator + +/** + * `regexp_extract(subject, pattern, idx)` implemented with java.util.regex.Pattern. + * + * Returns the string matching the idx-th capturing group of the first match, or empty string if + * no match. idx=0 returns the entire match. + * + * Inputs: + * - inputs(0): VarCharVector subject column + * - inputs(1): VarCharVector pattern (scalar, length-1) + * - inputs(2): IntVector group index (scalar, length-1) + * + * Output: VarCharVector, same length as subject. 
+ */
+class RegExpExtractUDF extends CometUDF {
+
+  // Compiled patterns keyed by pattern text, so later batches skip recompilation.
+  private val patternCache = new ConcurrentHashMap[String, Pattern]()
+
+  /**
+   * Writes, for each non-null subject row, group `idx` of the first match of the scalar pattern.
+   *
+   * Rows with no match, a group index above the pattern's group count, or a group that did not
+   * take part in the match receive an empty string. Null subject rows stay null.
+   *
+   * @param inputs
+   *   subject vector, scalar pattern vector and scalar group-index vector, in that order
+   * @param numRows
+   *   not read; the row count is taken from the subject vector
+   * @return
+   *   a string vector with one entry per subject row
+   * @throws IllegalArgumentException
+   *   if the arity is wrong or the pattern / group index is absent or null
+   */
+  override def evaluate(inputs: Array[ValueVector], numRows: Int): ValueVector = {
+    require(inputs.length == 3, s"RegExpExtractUDF expects 3 inputs, got ${inputs.length}")
+    val subject = inputs(0).asInstanceOf[VarCharVector]
+    val patternVec = inputs(1).asInstanceOf[VarCharVector]
+    val idxVec = inputs(2).asInstanceOf[IntVector]
+    require(
+      patternVec.getValueCount >= 1 && !patternVec.isNull(0),
+      "RegExpExtractUDF requires a non-null scalar pattern")
+    require(
+      idxVec.getValueCount >= 1 && !idxVec.isNull(0),
+      "RegExpExtractUDF requires a non-null scalar group index")
+
+    val patternStr = new String(patternVec.get(0), StandardCharsets.UTF_8)
+    val pattern = patternCache.computeIfAbsent(patternStr, Pattern.compile)
+    val idx = idxVec.get(0)
+
+    val n = subject.getValueCount
+    // Reused for every row that resolves to the empty string.
+    val emptyBytes = "".getBytes(StandardCharsets.UTF_8)
+    val out = new VarCharVector("regexp_extract_result", CometArrowAllocator)
+    // `out` only escapes through the return value, so it must be released here if the batch
+    // fails part-way (for example when Matcher.group rejects a negative group index).
+    var completed = false
+    try {
+      out.allocateNew(n)
+
+      var i = 0
+      while (i < n) {
+        if (subject.isNull(i)) {
+          out.setNull(i)
+        } else {
+          val matcher = pattern.matcher(new String(subject.get(i), StandardCharsets.UTF_8))
+          val group =
+            if (matcher.find() && idx <= matcher.groupCount()) Option(matcher.group(idx))
+            else None
+          out.setSafe(i, group.fold(emptyBytes)(_.getBytes(StandardCharsets.UTF_8)))
+        }
+        i += 1
+      }
+      out.setValueCount(n)
+      completed = true
+      out
+    } finally {
+      if (!completed) {
+        out.close()
+      }
+    }
+  }
+}
diff --git a/spark/src/main/scala/org/apache/comet/udf/RegExpInStrUDF.scala b/spark/src/main/scala/org/apache/comet/udf/RegExpInStrUDF.scala
new file mode 100644
index 0000000000..44c02fda18
--- /dev/null
+++ b/spark/src/main/scala/org/apache/comet/udf/RegExpInStrUDF.scala
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.udf + +import java.nio.charset.StandardCharsets +import java.util.concurrent.ConcurrentHashMap +import java.util.regex.Pattern + +import org.apache.arrow.vector.{IntVector, ValueVector, VarCharVector} + +import org.apache.comet.CometArrowAllocator + +/** + * `regexp_instr(subject, pattern, idx)` implemented with java.util.regex.Pattern. + * + * Returns the 1-based position at which the first match of the whole pattern starts, or 0 if + * there is no match. The group index is validated as non-null but otherwise ignored. + * + * Inputs: + * - inputs(0): VarCharVector subject column + * - inputs(1): VarCharVector pattern (scalar, length-1) + * - inputs(2): IntVector group index (scalar, length-1) + * + * Output: IntVector, same length as subject. 
+ */
+class RegExpInStrUDF extends CometUDF {
+
+  // Compiled patterns keyed by pattern text, so later batches skip recompilation.
+  private val patternCache = new ConcurrentHashMap[String, Pattern]()
+
+  /**
+   * Computes, for each non-null subject row, where the first match of the scalar pattern begins.
+   *
+   * The result is `matcher.start() + 1` when a match exists and 0 otherwise. The group-index
+   * input must be present and non-null, but its value is never read. Null subject rows stay
+   * null.
+   *
+   * @param inputs
+   *   subject vector, scalar pattern vector and scalar group-index vector, in that order
+   * @param numRows
+   *   not read; the row count is taken from the subject vector
+   * @return
+   *   an int vector with one entry per subject row
+   */
+  override def evaluate(inputs: Array[ValueVector], numRows: Int): ValueVector = {
+    require(inputs.length == 3, s"RegExpInStrUDF expects 3 inputs, got ${inputs.length}")
+    val subjects = inputs(0).asInstanceOf[VarCharVector]
+    val patternArg = inputs(1).asInstanceOf[VarCharVector]
+    val groupArg = inputs(2).asInstanceOf[IntVector]
+    require(
+      patternArg.getValueCount >= 1 && !patternArg.isNull(0),
+      "RegExpInStrUDF requires a non-null scalar pattern")
+    require(
+      groupArg.getValueCount >= 1 && !groupArg.isNull(0),
+      "RegExpInStrUDF requires a non-null scalar group index")
+
+    val patternText = new String(patternArg.get(0), StandardCharsets.UTF_8)
+    val regex = patternCache.computeIfAbsent(patternText, Pattern.compile)
+
+    val rowCount = subjects.getValueCount
+    val positions = new IntVector("regexp_instr_result", CometArrowAllocator)
+    positions.allocateNew(rowCount)
+
+    var row = 0
+    while (row < rowCount) {
+      if (subjects.isNull(row)) {
+        positions.setNull(row)
+      } else {
+        val matcher = regex.matcher(new String(subjects.get(row), StandardCharsets.UTF_8))
+        // Spark regexp_instr always returns 1-based position of the entire match start
+        val position = if (matcher.find()) matcher.start() + 1 else 0
+        positions.set(row, position)
+      }
+      row += 1
+    }
+    positions.setValueCount(rowCount)
+    positions
+  }
+}
diff --git a/spark/src/main/scala/org/apache/comet/udf/RegExpLikeUDF.scala b/spark/src/main/scala/org/apache/comet/udf/RegExpLikeUDF.scala
new file mode 100644
index 0000000000..c8073dadfd
--- /dev/null
+++ b/spark/src/main/scala/org/apache/comet/udf/RegExpLikeUDF.scala
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.udf + +import java.nio.charset.StandardCharsets +import java.util.concurrent.ConcurrentHashMap +import java.util.regex.Pattern + +import org.apache.arrow.vector.{BitVector, ValueVector, VarCharVector} + +import org.apache.comet.CometArrowAllocator + +/** + * `regexp` / `RLike` implemented with java.util.regex.Pattern (Java semantics). + * + * Inputs: + * - inputs(0): VarCharVector subject column + * - inputs(1): VarCharVector pattern, 1-row scalar (serde guarantees this) + * + * Output: BitVector (Arrow boolean), same length as the subject vector. 
+ */
+class RegExpLikeUDF extends CometUDF {
+
+  // Compiled patterns keyed by pattern text, so later batches skip recompilation.
+  private val patternCache = new ConcurrentHashMap[String, Pattern]()
+
+  /**
+   * Flags each non-null subject row with 1 when the scalar pattern is found anywhere in it
+   * (`Matcher.find`, not a full-string match) and 0 otherwise. Null subject rows stay null.
+   *
+   * @param inputs
+   *   subject vector followed by the scalar pattern vector
+   * @param numRows
+   *   not read; the row count is taken from the subject vector
+   * @return
+   *   a bit vector with one entry per subject row
+   */
+  override def evaluate(inputs: Array[ValueVector], numRows: Int): ValueVector = {
+    require(inputs.length == 2, s"RegExpLikeUDF expects 2 inputs, got ${inputs.length}")
+    val subjects = inputs(0).asInstanceOf[VarCharVector]
+    val patternArg = inputs(1).asInstanceOf[VarCharVector]
+    require(
+      patternArg.getValueCount >= 1 && !patternArg.isNull(0),
+      "RegExpLikeUDF requires a non-null scalar pattern")
+
+    val patternText = new String(patternArg.get(0), StandardCharsets.UTF_8)
+    val regex = patternCache.computeIfAbsent(patternText, Pattern.compile)
+
+    val rowCount = subjects.getValueCount
+    val flags = new BitVector("rlike_result", CometArrowAllocator)
+    flags.allocateNew(rowCount)
+
+    var row = 0
+    while (row < rowCount) {
+      if (subjects.isNull(row)) {
+        flags.setNull(row)
+      } else {
+        val text = new String(subjects.get(row), StandardCharsets.UTF_8)
+        val matched = if (regex.matcher(text).find()) 1 else 0
+        flags.set(row, matched)
+      }
+      row += 1
+    }
+    flags.setValueCount(rowCount)
+    flags
+  }
+}
diff --git a/spark/src/main/scala/org/apache/comet/udf/RegExpReplaceUDF.scala b/spark/src/main/scala/org/apache/comet/udf/RegExpReplaceUDF.scala
new file mode 100644
index 0000000000..7fc598d42a
--- /dev/null
+++ b/spark/src/main/scala/org/apache/comet/udf/RegExpReplaceUDF.scala
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.udf + +import java.nio.charset.StandardCharsets +import java.util.concurrent.ConcurrentHashMap +import java.util.regex.Pattern + +import org.apache.arrow.vector.{ValueVector, VarCharVector} + +import org.apache.comet.CometArrowAllocator + +/** + * `regexp_replace(subject, pattern, replacement)` implemented with java.util.regex.Pattern. + * + * Replaces all occurrences of pattern in subject with replacement. + * + * Inputs: + * - inputs(0): VarCharVector subject column + * - inputs(1): VarCharVector pattern (scalar, length-1) + * - inputs(2): VarCharVector replacement (scalar, length-1) + * + * Output: VarCharVector, same length as subject. 
+ */
+class RegExpReplaceUDF extends CometUDF {
+
+  // Compiled patterns keyed by pattern text, so later batches skip recompilation.
+  private val patternCache = new ConcurrentHashMap[String, Pattern]()
+
+  /**
+   * Replaces every match of the scalar pattern in each non-null subject row with the scalar
+   * replacement via `Matcher.replaceAll`, so group references in the replacement follow Java
+   * rules. Null subject rows stay null.
+   *
+   * @param inputs
+   *   subject vector, scalar pattern vector and scalar replacement vector, in that order
+   * @param numRows
+   *   not read; the row count is taken from the subject vector
+   * @return
+   *   a string vector with one entry per subject row
+   * @throws IllegalArgumentException
+   *   if the arity is wrong or the pattern / replacement is absent or null
+   */
+  override def evaluate(inputs: Array[ValueVector], numRows: Int): ValueVector = {
+    require(inputs.length == 3, s"RegExpReplaceUDF expects 3 inputs, got ${inputs.length}")
+    val subject = inputs(0).asInstanceOf[VarCharVector]
+    val patternVec = inputs(1).asInstanceOf[VarCharVector]
+    val replacementVec = inputs(2).asInstanceOf[VarCharVector]
+    require(
+      patternVec.getValueCount >= 1 && !patternVec.isNull(0),
+      "RegExpReplaceUDF requires a non-null scalar pattern")
+    require(
+      replacementVec.getValueCount >= 1 && !replacementVec.isNull(0),
+      "RegExpReplaceUDF requires a non-null scalar replacement")
+
+    val patternStr = new String(patternVec.get(0), StandardCharsets.UTF_8)
+    val pattern = patternCache.computeIfAbsent(patternStr, Pattern.compile)
+    val replacement = new String(replacementVec.get(0), StandardCharsets.UTF_8)
+
+    val n = subject.getValueCount
+    val out = new VarCharVector("regexp_replace_result", CometArrowAllocator)
+    // replaceAll throws on a malformed replacement (e.g. a reference to a group the pattern
+    // does not have). `out` only escapes through the return value, so release it on failure.
+    var completed = false
+    try {
+      out.allocateNew(n)
+
+      var i = 0
+      while (i < n) {
+        if (subject.isNull(i)) {
+          out.setNull(i)
+        } else {
+          val s = new String(subject.get(i), StandardCharsets.UTF_8)
+          val result = pattern.matcher(s).replaceAll(replacement)
+          out.setSafe(i, result.getBytes(StandardCharsets.UTF_8))
+        }
+        i += 1
+      }
+      out.setValueCount(n)
+      completed = true
+      out
+    } finally {
+      if (!completed) {
+        out.close()
+      }
+    }
+  }
+}
diff --git a/spark/src/main/scala/org/apache/comet/udf/StringSplitUDF.scala b/spark/src/main/scala/org/apache/comet/udf/StringSplitUDF.scala
new file mode 100644
index 0000000000..bce92538cd
--- /dev/null
+++ b/spark/src/main/scala/org/apache/comet/udf/StringSplitUDF.scala
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.udf + +import java.nio.charset.StandardCharsets +import java.util.concurrent.ConcurrentHashMap +import java.util.regex.Pattern + +import org.apache.arrow.vector.ValueVector +import org.apache.arrow.vector.VarCharVector +import org.apache.arrow.vector.complex.ListVector +import org.apache.arrow.vector.types.pojo.{ArrowType, FieldType} + +import org.apache.comet.CometArrowAllocator + +/** + * `split(subject, pattern, limit)` implemented with java.util.regex.Pattern. + * + * Splits the subject string around matches of the pattern, up to the specified limit. + * + * Inputs: + * - inputs(0): VarCharVector subject column + * - inputs(1): VarCharVector pattern (scalar, length-1) + * - inputs(2): IntVector limit (scalar, length-1) + * + * Output: ListVector of VarChar, same length as subject. 
+ */
+class StringSplitUDF extends CometUDF {
+
+  // Compiled patterns keyed by pattern text, so later batches skip recompilation.
+  private val patternCache = new ConcurrentHashMap[String, Pattern]()
+
+  /**
+   * Splits each non-null subject row around matches of the scalar pattern and writes the pieces
+   * as one list per row. A limit of zero or less is passed to `Pattern.split` as -1; a positive
+   * limit is passed through unchanged. Null subject rows yield null lists.
+   *
+   * @param inputs
+   *   subject vector, scalar pattern vector and scalar limit vector, in that order
+   * @param numRows
+   *   not read; the row count is taken from the subject vector
+   * @return
+   *   a list-of-strings vector with one entry per subject row
+   * @throws IllegalArgumentException
+   *   if the arity is wrong or the pattern / limit is absent or null
+   */
+  override def evaluate(inputs: Array[ValueVector], numRows: Int): ValueVector = {
+    require(inputs.length == 3, s"StringSplitUDF expects 3 inputs, got ${inputs.length}")
+    val subject = inputs(0).asInstanceOf[VarCharVector]
+    val patternVec = inputs(1).asInstanceOf[VarCharVector]
+    val limitVec = inputs(2).asInstanceOf[org.apache.arrow.vector.IntVector]
+    require(
+      patternVec.getValueCount >= 1 && !patternVec.isNull(0),
+      "StringSplitUDF requires a non-null scalar pattern")
+    require(
+      limitVec.getValueCount >= 1 && !limitVec.isNull(0),
+      "StringSplitUDF requires a non-null scalar limit")
+
+    val patternStr = new String(patternVec.get(0), StandardCharsets.UTF_8)
+    val pattern = patternCache.computeIfAbsent(patternStr, Pattern.compile)
+    val limit = limitVec.get(0)
+    // Spark semantics: limit <= 0 means no limit (split returns all)
+    val effectiveLimit = if (limit <= 0) -1 else limit
+
+    val n = subject.getValueCount
+    val out = ListVector.empty("string_split_result", CometArrowAllocator)
+    // `out` only escapes through the return value, so release it if the batch fails part-way.
+    var completed = false
+    try {
+      out.addOrGetVector[VarCharVector](new FieldType(true, ArrowType.Utf8.INSTANCE, null))
+      val writer = out.getWriter
+
+      var i = 0
+      while (i < n) {
+        if (subject.isNull(i)) {
+          out.setNull(i)
+        } else {
+          val s = new String(subject.get(i), StandardCharsets.UTF_8)
+          val parts = pattern.split(s, effectiveLimit)
+          writer.setPosition(i)
+          writer.startList()
+          var j = 0
+          while (j < parts.length) {
+            val bytes = parts(j).getBytes(StandardCharsets.UTF_8)
+            val buf = CometArrowAllocator.buffer(bytes.length)
+            try {
+              buf.writeBytes(bytes)
+              writer.varChar().writeVarChar(0, bytes.length, buf)
+            } finally {
+              // Release the scratch buffer even when the write fails.
+              buf.close()
+            }
+            j += 1
+          }
+          writer.endList()
+        }
+        i += 1
+      }
+      out.setValueCount(n)
+      completed = true
+      out
+    } finally {
+      if (!completed) {
+        out.close()
+      }
+    }
+  }
+}
diff --git a/spark/src/main/spark-3.5/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-3.5/org/apache/comet/shims/CometExprShim.scala
index d3e3270700..9aca45b29d 100644
---
a/spark/src/main/spark-3.5/org/apache/comet/shims/CometExprShim.scala +++ b/spark/src/main/spark-3.5/org/apache/comet/shims/CometExprShim.scala @@ -56,7 +56,7 @@ trait CometExprShim extends CommonStringExprs { val isCastSupported = castSupported match { case Compatible(_) => true - case Incompatible(_) => true + case Incompatible(_, _) => true case _ => false } diff --git a/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala index 3d5b34bfd2..26405c9f75 100644 --- a/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala +++ b/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala @@ -102,7 +102,7 @@ trait CometExprShim extends CommonStringExprs { val isCastSupported = castSupported match { case Compatible(_) => true - case Incompatible(_) => true + case Incompatible(_, _) => true case _ => false } diff --git a/spark/src/main/spark-4.1/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-4.1/org/apache/comet/shims/CometExprShim.scala index 5e906a0d83..d6f201b70c 100644 --- a/spark/src/main/spark-4.1/org/apache/comet/shims/CometExprShim.scala +++ b/spark/src/main/spark-4.1/org/apache/comet/shims/CometExprShim.scala @@ -101,7 +101,7 @@ trait CometExprShim extends CommonStringExprs { val isCastSupported = castSupported match { case Compatible(_) => true - case Incompatible(_) => true + case Incompatible(_, _) => true case _ => false } diff --git a/spark/src/main/spark-4.2/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-4.2/org/apache/comet/shims/CometExprShim.scala index 5e906a0d83..d6f201b70c 100644 --- a/spark/src/main/spark-4.2/org/apache/comet/shims/CometExprShim.scala +++ b/spark/src/main/spark-4.2/org/apache/comet/shims/CometExprShim.scala @@ -101,7 +101,7 @@ trait CometExprShim extends CommonStringExprs { val isCastSupported = castSupported match { case Compatible(_) => true - case Incompatible(_) => true + case 
Incompatible(_, _) => true case _ => false } diff --git a/spark/src/test/resources/sql-tests/expressions/string/regexp_extract.sql b/spark/src/test/resources/sql-tests/expressions/string/regexp_extract.sql new file mode 100644 index 0000000000..fac725839a --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/string/regexp_extract.sql @@ -0,0 +1,58 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. 
+ +-- Test regexp_extract via JVM regex engine +-- Config: spark.comet.jvmUdf.enabled=true +-- Config: spark.comet.exec.regexp.engine=java + +statement +CREATE TABLE test_regexp_extract(s string) USING parquet + +statement +INSERT INTO test_regexp_extract VALUES ('abc123def'), ('no match'), (NULL), ('xyz789'), ('hello world'), ('aa') + +-- group 0: entire match +query +SELECT regexp_extract(s, '\d+', 0) FROM test_regexp_extract + +-- group 1: first capturing group +query +SELECT regexp_extract(s, '([a-z]+)(\d+)', 1) FROM test_regexp_extract + +-- group 2: second capturing group +query +SELECT regexp_extract(s, '([a-z]+)(\d+)', 2) FROM test_regexp_extract + +-- no match returns empty string +query +SELECT regexp_extract(s, 'NOMATCH', 0) FROM test_regexp_extract + +-- backreference pattern (Java-only) +query +SELECT regexp_extract(s, '(\w)\1', 0) FROM test_regexp_extract + +-- lookahead (Java-only) +query +SELECT regexp_extract(s, 'abc(?=\d)', 0) FROM test_regexp_extract + +-- embedded flags (Java-only) +query +SELECT regexp_extract(s, '(?i)HELLO', 0) FROM test_regexp_extract + +-- literal arguments +query +SELECT regexp_extract('abc123', '(\d+)', 1), regexp_extract('no digits', '(\d+)', 1), regexp_extract(NULL, '(\d+)', 1) diff --git a/spark/src/test/resources/sql-tests/expressions/string/regexp_extract_all.sql b/spark/src/test/resources/sql-tests/expressions/string/regexp_extract_all.sql new file mode 100644 index 0000000000..f56b5258ff --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/string/regexp_extract_all.sql @@ -0,0 +1,54 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. 
You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +-- Test regexp_extract_all via JVM regex engine +-- Config: spark.comet.jvmUdf.enabled=true +-- Config: spark.comet.exec.regexp.engine=java + +statement +CREATE TABLE test_regexp_extract_all(s string) USING parquet + +statement +INSERT INTO test_regexp_extract_all VALUES ('abc123def456'), ('no match'), (NULL), ('100-200-300'), ('hello world') + +-- group 0: all entire matches +query +SELECT regexp_extract_all(s, '\d+', 0) FROM test_regexp_extract_all + +-- group 1: first capturing group from each match +query +SELECT regexp_extract_all(s, '([a-z]+)(\d+)', 1) FROM test_regexp_extract_all + +-- group 2: second capturing group from each match +query +SELECT regexp_extract_all(s, '([a-z]+)(\d+)', 2) FROM test_regexp_extract_all + +-- no match returns empty array +query +SELECT regexp_extract_all(s, 'NOMATCH', 0) FROM test_regexp_extract_all + +-- backreference pattern (Java-only) +query +SELECT regexp_extract_all(s, '(\d)\1', 0) FROM test_regexp_extract_all + +-- embedded flags (Java-only) +query +SELECT regexp_extract_all(s, '(?i)[A-Z]+', 0) FROM test_regexp_extract_all + +-- literal arguments +query +SELECT regexp_extract_all('abc123def456', '(\d+)', 1), regexp_extract_all('no digits', '(\d+)', 1), regexp_extract_all(NULL, '(\d+)', 1) diff --git a/spark/src/test/resources/sql-tests/expressions/string/regexp_instr.sql b/spark/src/test/resources/sql-tests/expressions/string/regexp_instr.sql new file mode 100644 index 0000000000..fb4e42074e --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/string/regexp_instr.sql @@ -0,0 +1,50 @@ +-- 
Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +-- Test regexp_instr via JVM regex engine +-- Config: spark.comet.jvmUdf.enabled=true +-- Config: spark.comet.exec.regexp.engine=java + +statement +CREATE TABLE test_regexp_instr(s string) USING parquet + +statement +INSERT INTO test_regexp_instr VALUES ('abc123def'), ('no match'), (NULL), ('123xyz'), ('hello world'), ('aa') + +-- basic: position of first digit sequence +query +SELECT regexp_instr(s, '\d+', 0) FROM test_regexp_instr + +-- group 1 (still returns position of entire match per Spark semantics) +query +SELECT regexp_instr(s, '([a-z]+)(\d+)', 1) FROM test_regexp_instr + +-- no match returns 0 +query +SELECT regexp_instr(s, 'NOMATCH', 0) FROM test_regexp_instr + +-- backreference pattern (Java-only) +query +SELECT regexp_instr(s, '(\w)\1', 0) FROM test_regexp_instr + +-- embedded flags (Java-only) +query +SELECT regexp_instr(s, '(?i)HELLO', 0) FROM test_regexp_instr + +-- literal arguments +query +SELECT regexp_instr('abc123', '\d+', 0), regexp_instr('no digits', '\d+', 0), regexp_instr(NULL, '\d+', 0) diff --git a/spark/src/test/resources/sql-tests/expressions/string/regexp_replace_java.sql 
b/spark/src/test/resources/sql-tests/expressions/string/regexp_replace_java.sql new file mode 100644 index 0000000000..549d9d77b0 --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/string/regexp_replace_java.sql @@ -0,0 +1,52 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. 
+ +-- Test regexp_replace via JVM regex engine +-- Config: spark.comet.jvmUdf.enabled=true +-- Config: spark.comet.exec.regexp.engine=java + +statement +CREATE TABLE test_regexp_replace_java(s string) USING parquet + +statement +INSERT INTO test_regexp_replace_java VALUES ('100-200'), ('abc'), (''), (NULL), ('phone 123-456-7890'), ('aabbcc') + +query +SELECT regexp_replace(s, '\d+', 'X') FROM test_regexp_replace_java + +query +SELECT regexp_replace(s, '\d+', 'X', 1) FROM test_regexp_replace_java + +-- backreference in replacement +query +SELECT regexp_replace(s, '(\d+)-(\d+)', '$2-$1') FROM test_regexp_replace_java + +-- backreference in pattern (Java-only) +query +SELECT regexp_replace(s, '(\w)\1', 'Z') FROM test_regexp_replace_java + +-- lookahead (Java-only) +query +SELECT regexp_replace(s, '\d+(?=-)', 'X') FROM test_regexp_replace_java + +-- embedded flags (Java-only) +query +SELECT regexp_replace(s, '(?i)ABC', 'X') FROM test_regexp_replace_java + +-- literal arguments +query +SELECT regexp_replace('100-200', '(\d+)', 'X'), regexp_replace('abc', '(\d+)', 'X'), regexp_replace(NULL, '(\d+)', 'X') diff --git a/spark/src/test/resources/sql-tests/expressions/string/regexp_replace_enabled.sql b/spark/src/test/resources/sql-tests/expressions/string/regexp_replace_rust.sql similarity index 90% rename from spark/src/test/resources/sql-tests/expressions/string/regexp_replace_enabled.sql rename to spark/src/test/resources/sql-tests/expressions/string/regexp_replace_rust.sql index 97b4917c33..164517084a 100644 --- a/spark/src/test/resources/sql-tests/expressions/string/regexp_replace_enabled.sql +++ b/spark/src/test/resources/sql-tests/expressions/string/regexp_replace_rust.sql @@ -15,8 +15,8 @@ -- specific language governing permissions and limitations -- under the License. 
--- Test regexp_replace() with regexp allowIncompatible enabled (happy path) --- Config: spark.comet.expression.regexp.allowIncompatible=true +-- Test regexp_replace() with Rust regexp engine +-- Config: spark.comet.exec.regexp.engine=rust statement CREATE TABLE test_regexp_replace_enabled(s string) USING parquet diff --git a/spark/src/test/resources/sql-tests/expressions/string/rlike.sql b/spark/src/test/resources/sql-tests/expressions/string/rlike_java.sql similarity index 50% rename from spark/src/test/resources/sql-tests/expressions/string/rlike.sql rename to spark/src/test/resources/sql-tests/expressions/string/rlike_java.sql index 97350918ba..7994c3915c 100644 --- a/spark/src/test/resources/sql-tests/expressions/string/rlike.sql +++ b/spark/src/test/resources/sql-tests/expressions/string/rlike_java.sql @@ -15,17 +15,37 @@ -- specific language governing permissions and limitations -- under the License. +-- Test RLIKE via JVM regex engine +-- Config: spark.comet.jvmUdf.enabled=true +-- Config: spark.comet.exec.regexp.engine=java + statement -CREATE TABLE test_rlike(s string) USING parquet +CREATE TABLE test_rlike_java(s string) USING parquet statement -INSERT INTO test_rlike VALUES ('hello'), ('12345'), (''), (NULL), ('Hello World'), ('abc123') +INSERT INTO test_rlike_java VALUES ('hello'), ('12345'), (''), (NULL), ('Hello World'), ('abc123'), ('aa'), ('ab') + +query +SELECT s RLIKE '^\d+$' FROM test_rlike_java + +query +SELECT s RLIKE '^[a-z]+$' FROM test_rlike_java + +query +SELECT s RLIKE '' FROM test_rlike_java + +-- backreference (Java-only) +query +SELECT s RLIKE '^(\w)\1$' FROM test_rlike_java -query expect_fallback(Regexp pattern) -SELECT s RLIKE '^[0-9]+$' FROM test_rlike +-- lookahead (Java-only) +query +SELECT s RLIKE 'abc(?=\d)' FROM test_rlike_java -query expect_fallback(Regexp pattern) -SELECT s RLIKE '^[a-z]+$' FROM test_rlike +-- embedded flags (Java-only) +query +SELECT s RLIKE '(?i)hello' FROM test_rlike_java -query spark_answer_only -SELECT s 
RLIKE '' FROM test_rlike +-- literal arguments +query +SELECT 'hello' RLIKE '^[a-z]+$', '12345' RLIKE '^\d+$', '' RLIKE '', NULL RLIKE 'a' diff --git a/spark/src/test/resources/sql-tests/expressions/string/rlike_enabled.sql b/spark/src/test/resources/sql-tests/expressions/string/rlike_rust.sql similarity index 90% rename from spark/src/test/resources/sql-tests/expressions/string/rlike_enabled.sql rename to spark/src/test/resources/sql-tests/expressions/string/rlike_rust.sql index 5b2bd05fb3..11d98fff12 100644 --- a/spark/src/test/resources/sql-tests/expressions/string/rlike_enabled.sql +++ b/spark/src/test/resources/sql-tests/expressions/string/rlike_rust.sql @@ -15,8 +15,8 @@ -- specific language governing permissions and limitations -- under the License. --- Test RLIKE with regexp allowIncompatible enabled (happy path) --- Config: spark.comet.expression.regexp.allowIncompatible=true +-- Test RLIKE with Rust regexp engine +-- Config: spark.comet.exec.regexp.engine=rust statement CREATE TABLE test_rlike_enabled(s string) USING parquet diff --git a/spark/src/test/resources/sql-tests/expressions/string/split_java.sql b/spark/src/test/resources/sql-tests/expressions/string/split_java.sql new file mode 100644 index 0000000000..913b3c61fc --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/string/split_java.sql @@ -0,0 +1,54 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. 
You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +-- Test split via JVM regex engine +-- Config: spark.comet.jvmUdf.enabled=true +-- Config: spark.comet.exec.regexp.engine=java + +statement +CREATE TABLE test_split_java(s string) USING parquet + +statement +INSERT INTO test_split_java VALUES ('one,two,three'), ('hello'), (''), (NULL), ('a::b::c'), ('aXbXc') + +-- basic split on comma +query +SELECT split(s, ',', -1) FROM test_split_java + +-- split with limit +query +SELECT split(s, ',', 2) FROM test_split_java + +-- split on regex pattern +query +SELECT split(s, '[,:]', -1) FROM test_split_java + +-- split on multi-char separator +query +SELECT split(s, '::', -1) FROM test_split_java + +-- lookahead in pattern (Java-only) +query +SELECT split(s, '(?=X)', -1) FROM test_split_java + +-- embedded flags (Java-only) +query +SELECT split(s, '(?i)x', -1) FROM test_split_java + +-- literal arguments +query +SELECT split('a,b,c', ',', -1), split('hello', ',', -1), split(NULL, ',', -1) diff --git a/spark/src/test/resources/sql-tests/expressions/string/regexp_replace.sql b/spark/src/test/resources/sql-tests/expressions/string/split_rust.sql similarity index 60% rename from spark/src/test/resources/sql-tests/expressions/string/regexp_replace.sql rename to spark/src/test/resources/sql-tests/expressions/string/split_rust.sql index 967674a894..99f08fec28 100644 --- a/spark/src/test/resources/sql-tests/expressions/string/regexp_replace.sql +++ b/spark/src/test/resources/sql-tests/expressions/string/split_rust.sql @@ -15,14 +15,24 @@ -- specific language governing permissions and limitations -- under the 
License. +-- Test split with Rust regexp engine +-- Config: spark.comet.exec.regexp.engine=rust + statement -CREATE TABLE test_regexp_replace(s string) USING parquet +CREATE TABLE test_split_rust_enabled(s string) USING parquet statement -INSERT INTO test_regexp_replace VALUES ('100-200'), ('abc'), (''), (NULL), ('phone 123-456-7890') +INSERT INTO test_split_rust_enabled VALUES ('one,two,three'), ('hello'), (''), (NULL), ('a::b::c') + +query +SELECT split(s, ',', -1) FROM test_split_rust_enabled + +query +SELECT split(s, ',', 2) FROM test_split_rust_enabled -query expect_fallback(Regexp pattern) -SELECT regexp_replace(s, '(\\d+)', 'X') FROM test_regexp_replace +query +SELECT split(s, '::', -1) FROM test_split_rust_enabled -query expect_fallback(Regexp pattern) -SELECT regexp_replace(s, '(\\d+)', 'X', 1) FROM test_regexp_replace +-- literal arguments +query +SELECT split('a,b,c', ',', -1), split('hello', ',', -1), split(NULL, ',', -1) diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index 9b5e85be3b..3bcb7bdc41 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -934,7 +934,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { // add repetitive data to trigger dictionary encoding Range(0, 100).map(_ => "John Smith") withParquetFile(data.zipWithIndex, withDictionary) { file => - withSQLConf(CometConf.getExprAllowIncompatConfigKey("regexp") -> "true") { + withSQLConf(CometConf.COMET_REGEXP_ENGINE.key -> CometConf.REGEXP_ENGINE_RUST) { spark.read.parquet(file).createOrReplaceTempView(table) val query = sql(s"select _2 as id, _1 rlike 'R[a-z]+s [Rr]ose' from $table") checkSparkAnswerAndOperator(query) @@ -1006,7 +1006,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { // "Smith$", "Smith\\Z", "Smith\\z") - 
withSQLConf(CometConf.getExprAllowIncompatConfigKey("regexp") -> "true") { + withSQLConf(CometConf.COMET_REGEXP_ENGINE.key -> CometConf.REGEXP_ENGINE_RUST) { patterns.foreach { pattern => val query2 = sql(s"select name, '$pattern', name rlike '$pattern' from $table") checkSparkAnswerAndOperator(query2) @@ -1066,7 +1066,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { "\\V") val qualifiers = Seq("", "+", "*", "?", "{1,}") - withSQLConf(CometConf.getExprAllowIncompatConfigKey("regexp") -> "true") { + withSQLConf(CometConf.COMET_REGEXP_ENGINE.key -> CometConf.REGEXP_ENGINE_RUST) { // testing every possible combination takes too long, so we pick some // random combinations for (_ <- 0 until 100) { diff --git a/spark/src/test/scala/org/apache/comet/CometFuzzIcebergSuite.scala b/spark/src/test/scala/org/apache/comet/CometFuzzIcebergSuite.scala index f37d997c41..6fa6061d34 100644 --- a/spark/src/test/scala/org/apache/comet/CometFuzzIcebergSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometFuzzIcebergSuite.scala @@ -133,7 +133,7 @@ class CometFuzzIcebergSuite extends CometFuzzIcebergBase { } test("regexp_replace") { - withSQLConf(CometConf.getExprAllowIncompatConfigKey("regexp") -> "true") { + withSQLConf(CometConf.COMET_REGEXP_ENGINE.key -> CometConf.REGEXP_ENGINE_RUST) { val df = spark.table(icebergTableName) // We want to make sure that the schema generator wasn't modified to accidentally omit // StringType, since then this test would not run any queries and silently pass. 
diff --git a/spark/src/test/scala/org/apache/comet/CometFuzzTestSuite.scala b/spark/src/test/scala/org/apache/comet/CometFuzzTestSuite.scala index c1d7d8e72e..f4f903bf1d 100644 --- a/spark/src/test/scala/org/apache/comet/CometFuzzTestSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometFuzzTestSuite.scala @@ -219,7 +219,7 @@ class CometFuzzTestSuite extends CometFuzzTestBase { } test("regexp_replace") { - withSQLConf(CometConf.getExprAllowIncompatConfigKey("regexp") -> "true") { + withSQLConf(CometConf.COMET_REGEXP_ENGINE.key -> CometConf.REGEXP_ENGINE_RUST) { val df = spark.read.parquet(filename) df.createOrReplaceTempView("t1") // We want to make sure that the schema generator wasn't modified to accidentally omit diff --git a/spark/src/test/scala/org/apache/comet/CometRegExpJvmSuite.scala b/spark/src/test/scala/org/apache/comet/CometRegExpJvmSuite.scala new file mode 100644 index 0000000000..dfe474996b --- /dev/null +++ b/spark/src/test/scala/org/apache/comet/CometRegExpJvmSuite.scala @@ -0,0 +1,393 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.comet + +import org.apache.spark.SparkConf +import org.apache.spark.sql.CometTestBase +import org.apache.spark.sql.comet.{CometFilterExec, CometProjectExec} +import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper + +class CometRegExpJvmSuite extends CometTestBase with AdaptiveSparkPlanHelper { + + override protected def sparkConf: SparkConf = + super.sparkConf + .set(CometConf.COMET_JVM_UDF_ENABLED.key, "true") + .set(CometConf.COMET_REGEXP_ENGINE.key, CometConf.REGEXP_ENGINE_JAVA) + + // Patterns that the Rust regex crate cannot handle. Using one of these proves + // the JVM path was taken: if the pattern reached native, native would have + // rejected it and the operator would not be Comet. + private val backreference = "^(\\\\w)\\\\1$" + private val lookahead = "foo(?=bar)" + private val lookbehind = "(?<=foo)bar" + private val embeddedFlags = "(?i)foo" + private val namedGroup = "(?<digit>\\\\d)" + + private def withSubjects(values: String*)(f: => Unit): Unit = { + withTable("t") { + sql("CREATE TABLE t (s STRING) USING parquet") + val rows = values + .map(v => if (v == null) "(NULL)" else s"('${v.replace("'", "''")}')") + .mkString(", ") + sql(s"INSERT INTO t VALUES $rows") + f + } + } + + // ========== rlike tests ========== + + test("rlike: projection produces Java regex semantics with null handling") { + withSubjects("abc123", "no digits", null, "mixed_42_data") { + val df = sql("SELECT s, s rlike '\\\\d+' AS m FROM t") + checkSparkAnswerAndOperator(df) + } + } + + test("rlike: predicate filters rows using Java regex semantics") { + withSubjects("abc123", "no digits", null, "mixed_42_data") { + val df = sql("SELECT s FROM t WHERE s rlike '\\\\d+'") + checkSparkAnswerAndOperator(df) + } + } + + test("rlike: backreference in projection (Java-only construct)") { + withSubjects("aa", "ab", "xyzzy", null) { + val df = sql(s"SELECT s, s rlike '$backreference' FROM t") + checkSparkAnswerAndOperator(df) + val plan = 
df.queryExecution.executedPlan + assert( + collect(plan) { case p: CometProjectExec => p }.nonEmpty, + s"Expected CometProjectExec in:\n$plan") + } + } + + test("rlike: backreference in predicate (Java-only construct)") { + withSubjects("aa", "ab", "xyzzy", null) { + val df = sql(s"SELECT s FROM t WHERE s rlike '$backreference'") + checkSparkAnswerAndOperator(df) + val plan = df.queryExecution.executedPlan + assert( + collect(plan) { case f: CometFilterExec => f }.nonEmpty, + s"Expected CometFilterExec in:\n$plan") + } + } + + test("rlike: lookahead pattern (Java-only construct)") { + withSubjects("foobar", "foobaz", "barfoo", null) { + checkSparkAnswerAndOperator(sql(s"SELECT s, s rlike '$lookahead' FROM t")) + checkSparkAnswerAndOperator(sql(s"SELECT s FROM t WHERE s rlike '$lookahead'")) + } + } + + test("rlike: lookbehind pattern (Java-only construct)") { + withSubjects("foobar", "barbar", "foofoo", null) { + checkSparkAnswerAndOperator(sql(s"SELECT s, s rlike '$lookbehind' FROM t")) + } + } + + test("rlike: embedded case-insensitive flag (Java-only construct)") { + withSubjects("FOO", "foo", "fOO", "bar") { + checkSparkAnswerAndOperator(sql(s"SELECT s, s rlike '$embeddedFlags' FROM t")) + } + } + + test("rlike: named groups (Java-only construct)") { + withSubjects("a1", "ab", "9z", null) { + checkSparkAnswerAndOperator(sql(s"SELECT s, s rlike '$namedGroup' FROM t")) + } + } + + test("rlike: empty pattern matches every non-null row") { + withSubjects("abc", "", null) { + checkSparkAnswerAndOperator(sql("SELECT s, s rlike '' FROM t")) + } + } + + test("rlike: empty subject string is handled correctly") { + withSubjects("", "x", null) { + checkSparkAnswerAndOperator(sql("SELECT s, s rlike '^$' FROM t")) + } + } + + test("rlike: all-null subject column produces all-null result") { + withSubjects(null, null, null) { + checkSparkAnswerAndOperator(sql("SELECT s rlike '\\\\d+' FROM t")) + } + } + + test("rlike: null literal pattern falls back to Spark") { + 
withSubjects("a", "b", null) { + checkSparkAnswer(sql("SELECT s rlike CAST(NULL AS STRING) FROM t")) + } + } + + test("rlike: invalid pattern falls back to Spark") { + withSubjects("a") { + val ex = intercept[Throwable](sql("SELECT s rlike '[' FROM t").collect()) + assert( + ex.getMessage.toLowerCase.contains("regex") || + ex.getMessage.contains("PatternSyntax") || + ex.getMessage.contains("Unclosed"), + s"Unexpected error: ${ex.getMessage}") + } + } + + test("rlike: combines with filter, projection, and aggregate") { + withTable("t") { + sql("CREATE TABLE t (s STRING, k INT) USING parquet") + sql("""INSERT INTO t VALUES + | ('aa', 1), ('ab', 1), ('aa', 2), ('xyzzy', 2), ('aa', 3), (NULL, 3)""".stripMargin) + val df = sql(s"""SELECT k, COUNT(*) AS c + |FROM t + |WHERE s rlike '$backreference' + |GROUP BY k + |ORDER BY k""".stripMargin) + checkSparkAnswerAndOperator(df) + } + } + + test("rlike: many rows spanning multiple batches") { + withTable("t") { + sql("CREATE TABLE t (s STRING) USING parquet") + val values = (0 until 5000) + .map(i => if (i % 7 == 0) "(NULL)" else s"('row_${i}_aa')") + .mkString(", ") + sql(s"INSERT INTO t VALUES $values") + checkSparkAnswerAndOperator(sql(s"SELECT s, s rlike '$backreference' FROM t")) + checkSparkAnswerAndOperator(sql(s"SELECT s FROM t WHERE s rlike '$backreference'")) + } + } + + // ========== regexp_extract tests ========== + + test("regexp_extract: basic group extraction") { + withSubjects("abc123def", "no match", null, "xyz789") { + checkSparkAnswerAndOperator( + sql("SELECT s, regexp_extract(s, '([a-z]+)(\\\\d+)', 1) FROM t")) + checkSparkAnswerAndOperator( + sql("SELECT s, regexp_extract(s, '([a-z]+)(\\\\d+)', 2) FROM t")) + } + } + + test("regexp_extract: group 0 returns entire match") { + withSubjects("hello world", "foo123bar", null) { + checkSparkAnswerAndOperator(sql("SELECT s, regexp_extract(s, '\\\\d+', 0) FROM t")) + } + } + + test("regexp_extract: no match returns empty string") { + withSubjects("abc", "def", 
null) { + checkSparkAnswerAndOperator(sql("SELECT s, regexp_extract(s, '\\\\d+', 0) FROM t")) + } + } + + test("regexp_extract: backreference pattern (Java-only)") { + withSubjects("aa", "ab", "bb", null) { + checkSparkAnswerAndOperator(sql("SELECT s, regexp_extract(s, '(\\\\w)\\\\1', 0) FROM t")) + } + } + + test("regexp_extract: lookahead pattern (Java-only)") { + withSubjects("foobar", "foobaz", null) { + checkSparkAnswerAndOperator(sql("SELECT s, regexp_extract(s, 'foo(?=bar)', 0) FROM t")) + } + } + + test("regexp_extract: embedded flags (Java-only)") { + withSubjects("FOO123", "foo456", "bar789") { + checkSparkAnswerAndOperator( + sql("SELECT s, regexp_extract(s, '(?i)(foo)(\\\\d+)', 2) FROM t")) + } + } + + test("regexp_extract: all-null column") { + withSubjects(null, null, null) { + checkSparkAnswerAndOperator(sql("SELECT regexp_extract(s, '(\\\\d+)', 1) FROM t")) + } + } + + // ========== regexp_extract_all tests ========== + + test("regexp_extract_all: basic extraction of all matches") { + withSubjects("abc123def456", "no match", null, "x1y2z3") { + checkSparkAnswerAndOperator(sql("SELECT s, regexp_extract_all(s, '(\\\\d+)', 1) FROM t")) + } + } + + test("regexp_extract_all: group 0 returns full matches") { + withSubjects("cat bat hat", "no vowels", null) { + checkSparkAnswerAndOperator(sql("SELECT s, regexp_extract_all(s, '[a-z]at', 0) FROM t")) + } + } + + test("regexp_extract_all: multiple groups") { + withSubjects("a1b2c3", "x9y8", null) { + checkSparkAnswerAndOperator( + sql("SELECT s, regexp_extract_all(s, '([a-z])(\\\\d)', 1) FROM t")) + checkSparkAnswerAndOperator( + sql("SELECT s, regexp_extract_all(s, '([a-z])(\\\\d)', 2) FROM t")) + } + } + + test("regexp_extract_all: no matches returns empty array") { + withSubjects("abc", "def") { + checkSparkAnswerAndOperator(sql("SELECT s, regexp_extract_all(s, '\\\\d+', 0) FROM t")) + } + } + + test("regexp_extract_all: lookahead pattern (Java-only)") { + withSubjects("foobar foobaz fooqux") { + 
checkSparkAnswerAndOperator( + sql("SELECT s, regexp_extract_all(s, 'foo(?=ba[rz])', 0) FROM t")) + } + } + + // ========== regexp_replace tests ========== + + test("regexp_replace: basic replacement") { + withSubjects("abc123def456", "no digits", null) { + checkSparkAnswerAndOperator(sql("SELECT s, regexp_replace(s, '\\\\d+', 'NUM') FROM t")) + } + } + + test("regexp_replace: backreference in pattern (Java-only)") { + withSubjects("aabbcc", "abcabc", null) { + checkSparkAnswerAndOperator(sql("SELECT s, regexp_replace(s, '(\\\\w)\\\\1', 'X') FROM t")) + } + } + + test("regexp_replace: backreference in replacement") { + withSubjects("hello world", "foo bar", null) { + checkSparkAnswerAndOperator( + sql("SELECT s, regexp_replace(s, '(\\\\w+) (\\\\w+)', '$2 $1') FROM t")) + } + } + + test("regexp_replace: lookahead pattern (Java-only)") { + withSubjects("foobar", "foobaz", null) { + checkSparkAnswerAndOperator(sql("SELECT s, regexp_replace(s, 'foo(?=bar)', 'XXX') FROM t")) + } + } + + test("regexp_replace: empty pattern replaces between characters") { + withSubjects("abc", "", null) { + checkSparkAnswerAndOperator(sql("SELECT s, regexp_replace(s, '', '-') FROM t")) + } + } + + test("regexp_replace: all-null column") { + withSubjects(null, null, null) { + checkSparkAnswerAndOperator(sql("SELECT regexp_replace(s, '\\\\d', 'X') FROM t")) + } + } + + // ========== regexp_instr tests ========== + + test("regexp_instr: basic position finding") { + withSubjects("abc123def", "no match", null, "456xyz") { + checkSparkAnswerAndOperator(sql("SELECT s, regexp_instr(s, '\\\\d+', 0) FROM t")) + } + } + + test("regexp_instr: specific group position") { + withSubjects("abc123def456", "xyz", null) { + checkSparkAnswerAndOperator(sql("SELECT s, regexp_instr(s, '([a-z]+)(\\\\d+)', 1) FROM t")) + checkSparkAnswerAndOperator(sql("SELECT s, regexp_instr(s, '([a-z]+)(\\\\d+)', 2) FROM t")) + } + } + + test("regexp_instr: no match returns 0") { + withSubjects("abc", "def", null) { + 
checkSparkAnswerAndOperator(sql("SELECT s, regexp_instr(s, '\\\\d+', 0) FROM t")) + } + } + + test("regexp_instr: lookahead (Java-only)") { + withSubjects("foobar", "foobaz", null) { + checkSparkAnswerAndOperator(sql("SELECT s, regexp_instr(s, 'foo(?=bar)', 0) FROM t")) + } + } + + // ========== split tests ========== + + test("split: basic regex split") { + withSubjects("a,b,c", "x,,y", null, "single") { + checkSparkAnswerAndOperator(sql("SELECT s, split(s, ',') FROM t")) + } + } + + test("split: regex pattern") { + withSubjects("abc123def456ghi", "no-digits", null) { + checkSparkAnswerAndOperator(sql("SELECT s, split(s, '\\\\d+') FROM t")) + } + } + + test("split: with limit") { + withSubjects("a,b,c,d,e") { + checkSparkAnswerAndOperator(sql("SELECT s, split(s, ',', 3) FROM t")) + } + } + + test("split: limit -1 returns all") { + withSubjects("a,,b,,c") { + checkSparkAnswerAndOperator(sql("SELECT s, split(s, ',', -1) FROM t")) + } + } + + test("split: lookahead pattern (Java-only)") { + withSubjects("camelCaseString", "anotherOne", null) { + checkSparkAnswerAndOperator(sql("SELECT s, split(s, '(?=[A-Z])') FROM t")) + } + } + + test("split: all-null column") { + withSubjects(null, null, null) { + checkSparkAnswerAndOperator(sql("SELECT split(s, ',') FROM t")) + } + } + + // ========== multi-batch and combined tests ========== + + test("regexp_extract: many rows spanning multiple batches") { + withTable("t") { + sql("CREATE TABLE t (s STRING) USING parquet") + val values = (0 until 5000) + .map(i => if (i % 7 == 0) "(NULL)" else s"('item_${i}_value')") + .mkString(", ") + sql(s"INSERT INTO t VALUES $values") + checkSparkAnswerAndOperator( + sql("SELECT s, regexp_extract(s, 'item_(\\\\d+)_value', 1) FROM t")) + } + } + + test("all regexp expressions combined in one query") { + withSubjects("abc123def456", "hello world", null, "aa") { + checkSparkAnswerAndOperator(sql(""" + |SELECT + | s, + | s rlike '\\d+' AS has_digits, + | regexp_extract(s, '(\\d+)', 1) AS 
first_num, + | regexp_replace(s, '\\d+', 'N') AS replaced, + | regexp_instr(s, '\\d+', 0) AS num_pos + |FROM t + |""".stripMargin)) + } + } +} diff --git a/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala index 89d5dfd4bc..b2e0560ad6 100644 --- a/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala @@ -149,7 +149,7 @@ class CometStringExpressionSuite extends CometTestBase { } test("split string basic") { - withSQLConf("spark.comet.expression.StringSplit.allowIncompatible" -> "true") { + withSQLConf(CometConf.COMET_REGEXP_ENGINE.key -> CometConf.REGEXP_ENGINE_RUST) { withParquetTable((0 until 5).map(i => (s"value$i,test$i", i)), "tbl") { checkSparkAnswerAndOperator("SELECT split(_1, ',') FROM tbl") checkSparkAnswerAndOperator("SELECT split('one,two,three', ',') FROM tbl") @@ -159,7 +159,7 @@ class CometStringExpressionSuite extends CometTestBase { } test("split string with limit") { - withSQLConf("spark.comet.expression.StringSplit.allowIncompatible" -> "true") { + withSQLConf(CometConf.COMET_REGEXP_ENGINE.key -> CometConf.REGEXP_ENGINE_RUST) { withParquetTable((0 until 5).map(i => ("a,b,c,d,e", i)), "tbl") { checkSparkAnswerAndOperator("SELECT split(_1, ',', 2) FROM tbl") checkSparkAnswerAndOperator("SELECT split(_1, ',', 3) FROM tbl") @@ -170,7 +170,7 @@ class CometStringExpressionSuite extends CometTestBase { } test("split string with regex patterns") { - withSQLConf("spark.comet.expression.StringSplit.allowIncompatible" -> "true") { + withSQLConf(CometConf.COMET_REGEXP_ENGINE.key -> CometConf.REGEXP_ENGINE_RUST) { withParquetTable((0 until 5).map(i => ("word1 word2 word3", i)), "tbl") { checkSparkAnswerAndOperator("SELECT split(_1, ' ') FROM tbl") checkSparkAnswerAndOperator("SELECT split(_1, '\\\\s+') FROM tbl") @@ -183,7 +183,7 @@ class CometStringExpressionSuite extends 
CometTestBase { } test("split string edge cases") { - withSQLConf("spark.comet.expression.StringSplit.allowIncompatible" -> "true") { + withSQLConf(CometConf.COMET_REGEXP_ENGINE.key -> CometConf.REGEXP_ENGINE_RUST) { withParquetTable(Seq(("", 0), ("single", 1), (null, 2), ("a", 3)), "tbl") { checkSparkAnswerAndOperator("SELECT split(_1, ',') FROM tbl") } @@ -191,7 +191,7 @@ class CometStringExpressionSuite extends CometTestBase { } test("split string with UTF-8 characters") { - withSQLConf("spark.comet.expression.StringSplit.allowIncompatible" -> "true") { + withSQLConf(CometConf.COMET_REGEXP_ENGINE.key -> CometConf.REGEXP_ENGINE_RUST) { // CJK characters withParquetTable(Seq(("你好,世界", 0), ("こんにちは,世界", 1)), "tbl_cjk") { checkSparkAnswerAndOperator("SELECT split(_1, ',') FROM tbl_cjk") diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometRegExpBenchmark.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometRegExpBenchmark.scala new file mode 100644 index 0000000000..ae2ddc6a66 --- /dev/null +++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometRegExpBenchmark.scala @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
+
+package org.apache.spark.sql.benchmark
+
+import org.apache.spark.benchmark.Benchmark
+
+import org.apache.comet.CometConf
+
+/**
+ * Configuration for a single rlike pattern under benchmark.
+ *
+ * @param name
+ *   short label for the pattern
+ * @param pattern
+ *   the regex literal supplied to rlike
+ */
+case class RegExpPattern(name: String, pattern: String)
+
+/**
+ * Benchmark `rlike` across all execution modes:
+ *   - Spark
+ *   - Comet (Scan only)
+ *   - Comet (Scan + Exec, native Rust regex)
+ *   - Comet (Scan + Exec, JVM-side java.util.regex)
+ *
+ * To run:
+ * {{{
+ *   SPARK_GENERATE_BENCHMARK_FILES=1 \
+ *     make benchmark-org.apache.spark.sql.benchmark.CometRegExpBenchmark
+ * }}}
+ *
+ * Results land in `spark/benchmarks/CometRegExpBenchmark-**results.txt`.
+ */
+object CometRegExpBenchmark extends CometBenchmarkBase {
+
+  // Patterns chosen to span common rlike shapes. Avoid Java-only constructs
+  // that the native (Rust) path cannot accept, since those would be skipped
+  // rather than benchmarked in the native case.
+  private val patterns = List(
+    RegExpPattern("character_class", "[0-9]+"),
+    RegExpPattern("anchored", "^[0-9]"),
+    RegExpPattern("alternation", "abc|def|ghi"),
+    RegExpPattern("multi_class", "[a-zA-Z][0-9]+"),
+    RegExpPattern("repetition", "(ab){2,}"))
+
+  /**
+   * Builds the benchmark table through `prepareTable` from a query that exposes each value's
+   * string form repeated 10 times as column `c1`, then runs every entry of `patterns` against
+   * it, one benchmark group per pattern.
+   */
+  override def runCometBenchmark(mainArgs: Array[String]): Unit = {
+    runBenchmarkWithTable("rlike modes", 1024) { v =>
+      withTempPath { dir =>
+        withTempTable("parquetV1Table") {
+          prepareTable(
+            dir,
+            spark.sql(s"SELECT REPEAT(CAST(value AS STRING), 10) AS c1 FROM $tbl"))
+
+          patterns.foreach { p =>
+            val query = s"select c1 rlike '${p.pattern}' from parquetV1Table"
+            runBenchmark(p.name) {
+              runRLikeModes(p.name, v, query)
+            }
+          }
+        }
+      }
+    }
+  }
+
+  /**
+   * Runs all four modes for a single rlike query.
+   *
+   * Each mode is registered as its own case on one `Benchmark`, so the four timings are
+   * reported together under `name`.
+   *
+   * @param name
+   *   label given to the `Benchmark` that groups the four cases
+   * @param cardinality
+   *   values-per-iteration count handed to the `Benchmark` constructor
+   * @param query
+   *   SQL text executed by every case
+   */
+  private def runRLikeModes(name: String, cardinality: Long, query: String): Unit = {
+    val benchmark = new Benchmark(name, cardinality, output = output)
+
+    benchmark.addCase("Spark") { _ =>
+      withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
+        spark.sql(query).noop()
+      }
+    }
+
+    benchmark.addCase("Comet (Scan)") { _ =>
+      withSQLConf(
+        CometConf.COMET_ENABLED.key -> "true",
+        CometConf.COMET_EXEC_ENABLED.key -> "false") {
+        spark.sql(query).noop()
+      }
+    }
+
+    // Settings shared by the two Comet exec cases below.
+    // NOTE(review): constant folding is disabled only for these two cases, presumably to keep
+    // the optimizer from rewriting the expression before Comet sees it -- confirm.
+    val baseExec = Map(
+      CometConf.COMET_ENABLED.key -> "true",
+      CometConf.COMET_EXEC_ENABLED.key -> "true",
+      "spark.sql.optimizer.constantFolding.enabled" -> "false")
+
+    // The engine setting defaults to `java`, and choosing `rust` is itself the opt-in for
+    // the Rust regex semantics, so the native case must select the engine explicitly; an
+    // allowIncompatible flag alone would leave the expression on the fallback path.
+    benchmark.addCase("Comet (Exec, native Rust regex)") { _ =>
+      val configs =
+        baseExec ++ Map(CometConf.COMET_REGEXP_ENGINE.key -> CometConf.REGEXP_ENGINE_RUST)
+      withSQLConf(configs.toSeq: _*) {
+        spark.sql(query).noop()
+      }
+    }
+
+    // The Java engine goes through the JVM UDF framework, which is disabled by default and
+    // therefore has to be switched on alongside the engine selection.
+    benchmark.addCase("Comet (Exec, JVM regex)") { _ =>
+      val configs =
+        baseExec ++ Map(
+          CometConf.COMET_JVM_UDF_ENABLED.key -> "true",
+          CometConf.COMET_REGEXP_ENGINE.key -> CometConf.REGEXP_ENGINE_JAVA)
+      withSQLConf(configs.toSeq: _*) {
+        spark.sql(query).noop()
+      }
+    }
+
+    benchmark.run()
+  }
+}