diff --git a/encodings/fsst/src/dfa/mod.rs b/encodings/fsst/src/dfa/mod.rs index 358fd3a7ab5..01abf95d3b9 100644 --- a/encodings/fsst/src/dfa/mod.rs +++ b/encodings/fsst/src/dfa/mod.rs @@ -208,6 +208,13 @@ enum LikeKind<'a> { impl<'a> LikeKind<'a> { fn parse(pattern: &'a [u8]) -> Option { + // The fast-path matchers below do not understand SQL LIKE escape sequences (e.g. `\%` + // matching a literal `%`). If the pattern contains a backslash we fall back to the + // general implementation, which correctly interprets escapes. + if pattern.contains(&b'\\') { + return None; + } + // `prefix%` (including just `%` where prefix is empty) if let Some(prefix) = pattern.strip_suffix(b"%") && !prefix.contains(&b'%') diff --git a/encodings/fsst/src/dfa/tests.rs b/encodings/fsst/src/dfa/tests.rs index 6ad30ca685d..5e84362eb68 100644 --- a/encodings/fsst/src/dfa/tests.rs +++ b/encodings/fsst/src/dfa/tests.rs @@ -64,6 +64,16 @@ fn test_like_kind_parse() { // Suffix and underscore patterns are not supported. assert!(LikeKind::parse(b"%suffix").is_none()); assert!(LikeKind::parse(b"a_c").is_none()); + + // Patterns containing the SQL LIKE escape character must NOT be parsed by the fast path, + // because that path treats `%` and `_` literally and would misinterpret escapes. For + // example, `%\%` (the pattern produced by Spark's `endsWith("%")`) means "ends with `%`", + // not "contains `\`". The fast path should bail so the general implementation handles it. + assert!(LikeKind::parse(br"%\%").is_none()); + assert!(LikeKind::parse(br"\%%").is_none()); + assert!(LikeKind::parse(br"%\_%").is_none()); + assert!(LikeKind::parse(br"\_%").is_none()); + assert!(LikeKind::parse(br"%\\%").is_none()); } /// No symbols — all bytes escaped. Simplest case to see the two tables. diff --git a/java/build.gradle.kts b/java/build.gradle.kts index 3c09fab6950..0561bd837b8 100644 --- a/java/build.gradle.kts +++ b/java/build.gradle.kts @@ -78,5 +78,10 @@ allprojects { } } - tasks.register("format").get().dependsOn("spotlessApply") + if (project.name == "vortex-spark_2.12") { + // vortex-spark_2.12 and vortex-spark_2.13 share a projectDir; format from the 2.13 variant only. + tasks.register("format") { enabled = false } + } else { + tasks.register("format").get().dependsOn("spotlessApply") + } } diff --git a/java/vortex-jni/src/main/java/dev/vortex/api/Expression.java b/java/vortex-jni/src/main/java/dev/vortex/api/Expression.java index e4d8978e112..3c372b3bcb2 100644 --- a/java/vortex-jni/src/main/java/dev/vortex/api/Expression.java +++ b/java/vortex-jni/src/main/java/dev/vortex/api/Expression.java @@ -6,6 +6,7 @@ import com.google.common.base.Preconditions; import dev.vortex.VortexCleaner; import dev.vortex.jni.NativeExpression; +import java.math.BigInteger; import java.util.Arrays; /** @@ -44,6 +45,19 @@ public static Expression column(String fieldName) { return getItem(fieldName, root()); } + /** + * Access a nested field by walking {@code fieldNames} starting from the root of the array. With a single name this + * is equivalent to {@link #column(String)}. + */ + public static Expression column(String[] fieldNames) { + Preconditions.checkArgument(fieldNames.length > 0, "column requires at least one field name"); + Expression expr = root(); + for (String name : fieldNames) { + expr = getItem(name, expr); + } + return expr; + } + /** Project a subset of fields out of a struct expression. */ public static Expression select(String[] fieldNames, Expression child) { return new Expression(NativeExpression.select(fieldNames, child.nativePointer())); @@ -73,6 +87,33 @@ public static Expression isNull(Expression child) { return new Expression(NativeExpression.isNull(child.nativePointer())); } + public static Expression isNotNull(Expression child) { + return new Expression(NativeExpression.isNotNull(child.nativePointer())); + } + + /** + * SQL {@code LIKE} pattern match. + * + * @param negated whether to invert the result (i.e. {@code NOT LIKE}) + * @param caseInsensitive whether to perform case-insensitive matching ({@code ILIKE}) + */ + public static Expression like(Expression child, Expression pattern, boolean negated, boolean caseInsensitive) { + return new Expression( + NativeExpression.like(child.nativePointer(), pattern.nativePointer(), negated, caseInsensitive)); + } + + /** + * {@code value BETWEEN lower AND upper}. + * + * @param lowerStrict {@code true} for {@code lower < value}; {@code false} for {@code lower <= value}. + * @param upperStrict {@code true} for {@code value < upper}; {@code false} for {@code value <= upper}. + */ + public static Expression between( + Expression value, Expression lower, Expression upper, boolean lowerStrict, boolean upperStrict) { + return new Expression(NativeExpression.between( + value.nativePointer(), lower.nativePointer(), upper.nativePointer(), lowerStrict, upperStrict)); + } + public static Expression literal(boolean value) { return new Expression(NativeExpression.literalBool(value, false)); } @@ -109,6 +150,59 @@ public static Expression literal(String value) { return new Expression(NativeExpression.literalString(value)); } + public static Expression literal(byte[] value) { + Preconditions.checkArgument(value != null, "use nullLiteral(DType.BINARY) for a null binary literal"); + return new Expression(NativeExpression.literalBinary(value)); + } + + /** + * Create a decimal literal from its unscaled two's-complement big-endian byte representation (i.e. the value + * returned by {@link BigInteger#toByteArray()}). + */ + public static Expression literalDecimal(BigInteger unscaledValue, int precision, int scale) { + Preconditions.checkArgument(unscaledValue != null, "unscaledValue must not be null"); + return new Expression(NativeExpression.literalDecimal(unscaledValue.toByteArray(), precision, scale, false)); + } + + /** Create a null decimal literal with the specified precision and scale. */ + public static Expression nullLiteralDecimal(int precision, int scale) { + return new Expression(NativeExpression.literalDecimal(new byte[] {0}, precision, scale, true)); + } + + /** + * Create a Date literal. The {@code value} is the number of {@code unit} units since the Unix epoch. + * + * @param unit only {@link TimeUnit#DAYS} and {@link TimeUnit#MILLISECONDS} are valid for Date. + */ + public static Expression literalDate(long value, TimeUnit unit) { + return new Expression(NativeExpression.literalDate(value, unit.tag(), false)); + } + + /** Null Date literal. See {@link #literalDate(long, TimeUnit)} for the {@code unit} constraints. */ + public static Expression nullLiteralDate(TimeUnit unit) { + return new Expression(NativeExpression.literalDate(0L, unit.tag(), true)); + } + + /** + * Create a Timestamp literal. The {@code value} is the number of {@code unit} units since the Unix epoch. + * + * @param timezone optional IANA timezone identifier (e.g. {@code "UTC"}, {@code "America/Los_Angeles"}). Pass + * {@code null} for a local (zone-naive) timestamp. + */ + public static Expression literalTimestamp(long value, TimeUnit unit, String timezone) { + return new Expression(NativeExpression.literalTimestamp(value, unit.tag(), timezone, false)); + } + + /** Null Timestamp literal. See {@link #literalTimestamp(long, TimeUnit, String)} for parameter semantics. */ + public static Expression nullLiteralTimestamp(TimeUnit unit, String timezone) { + return new Expression(NativeExpression.literalTimestamp(0L, unit.tag(), timezone, true)); + } + + /** Create a typed null literal of the given primitive {@link DType}. */ + public static Expression nullLiteral(DType dtype) { + return new Expression(NativeExpression.literalNull(dtype.tag())); + } + private static long[] nativePointers(Expression[] exprs) { return Arrays.stream(exprs).mapToLong(Expression::nativePointer).toArray(); } @@ -138,4 +232,46 @@ public byte code() { return code; } } + + /** Time units for Date/Timestamp literals. Tag values must match the Rust {@code parse_time_unit} table. */ + public enum TimeUnit { + NANOSECONDS((byte) 0), + MICROSECONDS((byte) 1), + MILLISECONDS((byte) 2), + SECONDS((byte) 3), + DAYS((byte) 4); + + private final byte tag; + + TimeUnit(byte tag) { + this.tag = tag; + } + + public byte tag() { + return tag; + } + } + + /** Primitive {@link DType}s that can be used to construct typed null literals via {@link #nullLiteral(DType)}. */ + public enum DType { + BOOL((byte) 0), + I8((byte) 1), + I16((byte) 2), + I32((byte) 3), + I64((byte) 4), + F32((byte) 5), + F64((byte) 6), + UTF8((byte) 7), + BINARY((byte) 8); + + private final byte tag; + + DType(byte tag) { + this.tag = tag; + } + + public byte tag() { + return tag; + } + } } diff --git a/java/vortex-jni/src/main/java/dev/vortex/jni/NativeExpression.java b/java/vortex-jni/src/main/java/dev/vortex/jni/NativeExpression.java index fc5bedbe7b5..0587b937d79 100644 --- a/java/vortex-jni/src/main/java/dev/vortex/jni/NativeExpression.java +++ b/java/vortex-jni/src/main/java/dev/vortex/jni/NativeExpression.java @@ -27,6 +27,13 @@ private NativeExpression() {} public static native long isNull(long childPointer); + public static native long isNotNull(long childPointer); + + public static native long like(long childPointer, long patternPointer, boolean negated, boolean caseInsensitive); + + public static native long between( + long valuePointer, long lowerPointer, long upperPointer, boolean lowerStrict, boolean upperStrict); + public static native long literalBool(boolean value, boolean isNull); public static native long literalI8(byte value, boolean isNull); @@ -43,5 +50,15 @@ private NativeExpression() {} public static native long literalString(String value); + public static native long literalBinary(byte[] value); + + public static native long literalDecimal(byte[] unscaledBigEndian, int precision, int scale, boolean isNull); + + public static native long literalDate(long value, byte timeUnitTag, boolean isNull); + + public static native long literalTimestamp(long value, byte timeUnitTag, String timezone, boolean isNull); + + public static native long literalNull(byte dtypeTag); + public static native void free(long pointer); } diff --git a/java/vortex-spark/src/main/java/dev/vortex/spark/VortexDataSourceV2.java b/java/vortex-spark/src/main/java/dev/vortex/spark/VortexDataSourceV2.java index 1d564b8c0e0..b3d7d637504 100644 --- a/java/vortex-spark/src/main/java/dev/vortex/spark/VortexDataSourceV2.java +++ b/java/vortex-spark/src/main/java/dev/vortex/spark/VortexDataSourceV2.java @@ -20,6 +20,7 @@ import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.connector.catalog.Table; import org.apache.spark.sql.connector.catalog.TableProvider; +import org.apache.spark.sql.connector.expressions.Expressions; import org.apache.spark.sql.connector.expressions.Transform; import org.apache.spark.sql.sources.DataSourceRegister; import org.apache.spark.sql.types.DataType; @@ -118,6 +119,38 @@ public StructType inferSchema(CaseInsensitiveStringMap options) { return dataSchema; } + /** + * Infers partition transforms by inspecting Hive-style {@code key=value} segments in the first listed file path. + * + *

Spark calls this before {@link #getTable(StructType, Transform[], Map)} when the caller did not provide + * explicit partitioning. Returning identity transforms here lets downstream components (notably + * {@link dev.vortex.spark.read.VortexScanBuilder}) tell which schema columns are encoded in the directory layout + * rather than stored inside the Vortex files, which matters for predicate pushdown. + */ + @Override + public Transform[] inferPartitioning(CaseInsensitiveStringMap options) { + var paths = getPaths(options); + if (paths.isEmpty()) { + return new Transform[0]; + } + var formatOptions = buildDataSourceOptions(options.asCaseSensitiveMap()); + String pathToInfer = Objects.requireNonNull(Iterables.getLast(paths)); + if (!pathToInfer.endsWith(".vortex")) { + Optional firstFile = + NativeFiles.listFiles(VortexSparkSession.get(formatOptions), pathToInfer, formatOptions).stream() + .findFirst(); + if (firstFile.isEmpty()) { + return new Transform[0]; + } + pathToInfer = firstFile.get(); + } + Map partitionValues = PartitionPathUtils.parsePartitionValues(pathToInfer); + if (partitionValues.isEmpty()) { + return new Transform[0]; + } + return partitionValues.keySet().stream().map(Expressions::identity).toArray(Transform[]::new); + } + /** * Creates a Vortex table instance with the given schema and properties. * diff --git a/java/vortex-spark/src/main/java/dev/vortex/spark/VortexTable.java b/java/vortex-spark/src/main/java/dev/vortex/spark/VortexTable.java index 92cc55ff211..f65f74ccf19 100644 --- a/java/vortex-spark/src/main/java/dev/vortex/spark/VortexTable.java +++ b/java/vortex-spark/src/main/java/dev/vortex/spark/VortexTable.java @@ -58,7 +58,7 @@ public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) { Map opts = Maps.newHashMap(); opts.putAll(formatOptions); opts.putAll(options); - return new VortexScanBuilder(opts) + return new VortexScanBuilder(opts, partitionTransforms) .addAllPaths(paths) .addAllColumns(Arrays.asList(CatalogV2Util.structTypeToV2Columns(schema))); } diff --git a/java/vortex-spark/src/main/java/dev/vortex/spark/read/PrefetchingIterator.java b/java/vortex-spark/src/main/java/dev/vortex/spark/read/PrefetchingIterator.java deleted file mode 100644 index e5cd96a3958..00000000000 --- a/java/vortex-spark/src/main/java/dev/vortex/spark/read/PrefetchingIterator.java +++ /dev/null @@ -1,94 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -package dev.vortex.spark.read; - -import java.util.Iterator; -import java.util.concurrent.BlockingQueue; -import java.util.concurrent.LinkedBlockingQueue; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicLong; -import java.util.function.ToLongFunction; - -final class PrefetchingIterator implements Iterator, AutoCloseable { - // Global condition variable shared between the prefetcher and consumer threads, - // to coordinate wake ups for when the buffer may no longer be full. - private static final Object CONDITION = new Object(); - - private final BlockingQueue fetched = new LinkedBlockingQueue<>(); - private final Thread producerThread; - private final Iterator delegate; - private final AtomicBoolean closed = new AtomicBoolean(false); - private final AtomicLong bufferBytes = new AtomicLong(0); - private final long maxBufferSize; - private final ToLongFunction sizeFunc; - - PrefetchingIterator(Iterator delegate, long maxBufferSize, ToLongFunction sizeFunc) { - this.delegate = delegate; - this.maxBufferSize = maxBufferSize; - this.sizeFunc = sizeFunc; - this.producerThread = new Thread(this::prefetchLoop, "vortex-prefetch-thread"); - producerThread.setDaemon(true); - producerThread.start(); - } - - private void prefetchLoop() { - try { - while (!closed.get() && delegate.hasNext()) { - while (bufferBytes.get() > maxBufferSize) { - synchronized (CONDITION) { - CONDITION.wait(); - } - } - T nextElem = delegate.next(); - long elemSize = sizeFunc.applyAsLong(nextElem); - bufferBytes.addAndGet(elemSize); - fetched.put(nextElem); - } - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - throw new RuntimeException("Prefetching interrupted", e); - } catch (Exception e) { - throw new RuntimeException("Prefetching failed", e); - } finally { - closed.set(true); - } - } - - @Override - public boolean hasNext() { - // If the prefetcher is not finished, then we could be waiting for - // a fetched item. - while (!closed.get()) { - if (!fetched.isEmpty()) { - return true; - } - } - // If the prefetcher is finished, then we can examine fetched and immediately return a result. - return !fetched.isEmpty(); - } - - @Override - public T next() { - // We assume that this has been called after hasNext() returned true, so it is - // safe to call take() without checking if the queue maybe be empty. - try { - T nextElem = this.fetched.take(); - long elemSize = sizeFunc.applyAsLong(nextElem); - bufferBytes.addAndGet(-elemSize); - // Notify the producer that it may now be able to add more items to the queue. - synchronized (CONDITION) { - CONDITION.notify(); - } - return nextElem; - } catch (InterruptedException e) { - throw new RuntimeException("Prefetch queue take interrupted", e); - } - } - - @Override - public void close() { - closed.set(true); - producerThread.interrupt(); - } -} diff --git a/java/vortex-spark/src/main/java/dev/vortex/spark/read/SparkPredicateToVortexExpression.java b/java/vortex-spark/src/main/java/dev/vortex/spark/read/SparkPredicateToVortexExpression.java new file mode 100644 index 00000000000..8d3fe697153 --- /dev/null +++ b/java/vortex-spark/src/main/java/dev/vortex/spark/read/SparkPredicateToVortexExpression.java @@ -0,0 +1,522 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +package dev.vortex.spark.read; + +import dev.vortex.api.Expression; +import dev.vortex.api.Expression.BinaryOp; +import dev.vortex.api.Expression.TimeUnit; +import org.apache.spark.sql.connector.expressions.Literal; +import org.apache.spark.sql.connector.expressions.NamedReference; +import org.apache.spark.sql.connector.expressions.filter.AlwaysFalse; +import org.apache.spark.sql.connector.expressions.filter.AlwaysTrue; +import org.apache.spark.sql.connector.expressions.filter.And; +import org.apache.spark.sql.connector.expressions.filter.Not; +import org.apache.spark.sql.connector.expressions.filter.Or; +import org.apache.spark.sql.connector.expressions.filter.Predicate; +import org.apache.spark.sql.types.BinaryType; +import org.apache.spark.sql.types.BooleanType; +import org.apache.spark.sql.types.ByteType; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DateType; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.sql.types.DecimalType; +import org.apache.spark.sql.types.DoubleType; +import org.apache.spark.sql.types.FloatType; +import org.apache.spark.sql.types.IntegerType; +import org.apache.spark.sql.types.LongType; +import org.apache.spark.sql.types.ShortType; +import org.apache.spark.sql.types.StringType; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.types.TimestampNTZType; +import org.apache.spark.sql.types.TimestampType; +import org.apache.spark.unsafe.types.UTF8String; + +import java.math.BigDecimal; +import java.math.BigInteger; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +/** + * Translates {@link Predicate Spark V2 predicates} into Vortex {@link Expression}s for predicate pushdown. + * + *

The translator aims to express every Spark predicate Vortex can evaluate. Predicates that cannot be translated + * (unsupported functions, literals on user-defined types, references to columns not present in the file, etc.) are left + * to Spark for post-scan evaluation. + */ +final class SparkPredicateToVortexExpression { + + private SparkPredicateToVortexExpression() { + } + + /** + * Returns true if the given Spark predicate can be translated to a Vortex expression and every named reference + * resolves to a real field path under {@code dataColumnTypes}. + * + *

{@code dataColumnTypes} maps each pushable top-level column name to its top-level Spark {@link DataType}; + * partition columns and columns the scan does not project should not appear in the map. For nested references + * (for example {@code info.email}) the validator walks the named reference part by part, descending into + * {@link StructType} fields so that {@code info} must be a struct that contains an {@code email} field. + * + *

This is the cheap check used in {@code SupportsPushDownV2Filters.pushPredicates} to decide which predicates + * Spark can drop. It does not allocate any native expressions; if it returns true, {@link #convert(Predicate)} must + * succeed (otherwise callers would silently drop predicates). + */ + static boolean isPushable(Predicate predicate, Map dataColumnTypes) { + for (NamedReference ref : predicate.references()) { + if (!resolveFieldPath(ref.fieldNames(), dataColumnTypes)) { + return false; + } + } + return isStructurallyPushable(predicate); + } + + /** + * Walks {@code parts} against {@code dataColumnTypes}, descending through {@link StructType} fields for + * dot-separated nested references. Returns true only when every part resolves to an actual field in the + * schema. + */ + private static boolean resolveFieldPath(String[] parts, Map dataColumnTypes) { + if (parts.length == 0) { + return false; + } + DataType current = dataColumnTypes.get(parts[0]); + if (current == null) { + return false; + } + for (int i = 1; i < parts.length; i++) { + if (!(current instanceof StructType struct)) { + return false; + } + Optional field = findField(struct, parts[i]); + if (field.isEmpty()) { + return false; + } + current = field.get().dataType(); + } + return true; + } + + private static Optional findField(StructType struct, String name) { + return Arrays.stream(struct.fields()).filter(structField -> structField.name().equals(name)).findFirst(); + } + + private static boolean isStructurallyPushable(Predicate predicate) { + if (predicate instanceof AlwaysTrue || predicate instanceof AlwaysFalse) { + return true; + } + if (predicate instanceof And a) { + return isStructurallyPushable(a.left()) && isStructurallyPushable(a.right()); + } + if (predicate instanceof Or o) { + return isStructurallyPushable(o.left()) && isStructurallyPushable(o.right()); + } + if (predicate instanceof Not n) { + return isStructurallyPushable(n.child()); + } + + org.apache.spark.sql.connector.expressions.Expression[] children = predicate.children(); + return switch (predicate.name()) { + case "=", "<>", "!=", ">", ">=", "<", "<=" -> isPushableComparison(children); + case "IS_NULL", "IS_NOT_NULL" -> children.length == 1 && isPushableFieldRef(children[0]); + case "IN" -> { + if (children.length < 2 || !isPushableFieldRef(children[0])) { + yield false; + } + for (int i = 1; i < children.length; i++) { + if (!isPushableLiteral(children[i])) { + yield false; + } + } + yield true; + } + case "STARTS_WITH", "ENDS_WITH", "CONTAINS" -> + children.length == 2 && isPushableFieldRef(children[0]) && isPushableStringLiteral(children[1]); + // `BOOLEAN_EXPRESSION` wraps a bare boolean-valued child. We only handle the case + // where the child itself is a field reference (e.g. `WHERE bool_col`). + case "BOOLEAN_EXPRESSION" -> children.length == 1 && isPushableFieldRef(children[0]); + default -> false; + }; + } + + /** + * Converts a Spark predicate to a Vortex expression. Returns {@link Optional#empty()} if the predicate cannot be + * translated; callers should normally pre-check with {@link #isPushable}. + */ + static Optional convert(Predicate predicate) { + if (predicate instanceof AlwaysTrue) { + return Optional.of(Expression.literal(true)); + } + if (predicate instanceof AlwaysFalse) { + return Optional.of(Expression.literal(false)); + } + if (predicate instanceof And a) { + Optional left = convert(a.left()); + Optional right = convert(a.right()); + if (left.isPresent() && right.isPresent()) { + return Optional.of(Expression.and(left.get(), right.get())); + } + return Optional.empty(); + } + if (predicate instanceof Or o) { + Optional left = convert(o.left()); + Optional right = convert(o.right()); + if (left.isPresent() && right.isPresent()) { + return Optional.of(Expression.or(left.get(), right.get())); + } + return Optional.empty(); + } + if (predicate instanceof Not n) { + return convert(n.child()).map(Expression::not); + } + org.apache.spark.sql.connector.expressions.Expression[] children = predicate.children(); + return switch (predicate.name()) { + case "=", "<>", "!=", ">", ">=", "<", "<=" -> convertComparison(predicate.name(), children); + case "IS_NULL" -> children.length == 1 ? columnOf(children[0]).map(Expression::isNull) : Optional.empty(); + case "IS_NOT_NULL" -> + children.length == 1 ? columnOf(children[0]).map(Expression::isNotNull) : Optional.empty(); + case "IN" -> convertIn(children); + case "STARTS_WITH" -> + convertStringMatch(children, /* leadingWildcard= */ false, /* trailingWildcard= */ true); + case "ENDS_WITH" -> + convertStringMatch(children, /* leadingWildcard= */ true, /* trailingWildcard= */ false); + case "CONTAINS" -> convertStringMatch(children, /* leadingWildcard= */ true, /* trailingWildcard= */ true); + case "BOOLEAN_EXPRESSION" -> children.length == 1 ? columnOf(children[0]) : Optional.empty(); + default -> Optional.empty(); + }; + } + + private static Optional convertComparison( + String op, org.apache.spark.sql.connector.expressions.Expression[] children) { + if (children.length != 2) { + return Optional.empty(); + } + // Allow either side to be the column; Spark's V2 builder sometimes commutes. + Optional lhs = exprOf(children[0]); + Optional rhs = exprOf(children[1]); + if (lhs.isEmpty() || rhs.isEmpty()) { + return Optional.empty(); + } + // We require at least one side to be a column reference to keep the surface small and to + // match what Vortex pushdown understands. + boolean lhsIsCol = isFieldRefExpr(children[0]); + boolean rhsIsCol = isFieldRefExpr(children[1]); + if (!lhsIsCol && !rhsIsCol) { + return Optional.empty(); + } + BinaryOp binaryOp = toBinaryOp(op); + // Canonicalize so the column is on the left when only one side is a column. + if (!lhsIsCol) { + binaryOp = swap(binaryOp); + Expression tmp = lhs.get(); + return Optional.of(Expression.binary(binaryOp, rhs.get(), tmp)); + } + return Optional.of(Expression.binary(binaryOp, lhs.get(), rhs.get())); + } + + private static Optional convertIn(org.apache.spark.sql.connector.expressions.Expression[] children) { + if (children.length < 2) { + return Optional.empty(); + } + Optional column = columnOf(children[0]); + if (column.isEmpty()) { + return Optional.empty(); + } + Expression columnExpr = column.get(); + List eqs = new ArrayList<>(children.length - 1); + for (int i = 1; i < children.length; i++) { + Optional literal = literalOf(children[i]); + if (literal.isEmpty()) { + return Optional.empty(); + } + eqs.add(Expression.binary(BinaryOp.EQ, columnExpr, literal.get())); + } + if (eqs.size() == 1) { + return Optional.of(eqs.get(0)); + } + return Optional.of(Expression.or(eqs.toArray(new Expression[0]))); + } + + private static Optional convertStringMatch( + org.apache.spark.sql.connector.expressions.Expression[] children, + boolean leadingWildcard, + boolean trailingWildcard) { + if (children.length != 2) { + return Optional.empty(); + } + Optional column = columnOf(children[0]); + Optional needle = stringValueOf(children[1]); + if (column.isEmpty() || needle.isEmpty()) { + return Optional.empty(); + } + String pattern = buildLikePattern(needle.get(), leadingWildcard, trailingWildcard); + return Optional.of(Expression.like( + column.get(), Expression.literal(pattern), /* negated= */ false, /* caseInsensitive= */ false)); + } + + /** + * Build a LIKE pattern from a literal substring, escaping the {@code %}, {@code _}, and {@code \} meta-characters + * so the Spark {@code STARTS_WITH}/{@code ENDS_WITH}/{@code CONTAINS} semantics (exact substring match) are + * preserved. + */ + private static String buildLikePattern(String literal, boolean leadingWildcard, boolean trailingWildcard) { + StringBuilder sb = new StringBuilder(literal.length() + 2); + if (leadingWildcard) { + sb.append('%'); + } + for (int i = 0; i < literal.length(); i++) { + char c = literal.charAt(i); + if (c == '%' || c == '_' || c == '\\') { + sb.append('\\'); + } + sb.append(c); + } + if (trailingWildcard) { + sb.append('%'); + } + return sb.toString(); + } + + private static BinaryOp toBinaryOp(String name) { + return switch (name) { + case "=" -> BinaryOp.EQ; + case "<>", "!=" -> BinaryOp.NOT_EQ; + case ">" -> BinaryOp.GT; + case ">=" -> BinaryOp.GTE; + case "<" -> BinaryOp.LT; + case "<=" -> BinaryOp.LTE; + default -> throw new IllegalArgumentException("not a pushable comparison operator: " + name); + }; + } + + private static BinaryOp swap(BinaryOp op) { + return switch (op) { + case EQ, NOT_EQ -> op; + case GT -> BinaryOp.LT; + case GTE -> BinaryOp.LTE; + case LT -> BinaryOp.GT; + case LTE -> BinaryOp.GTE; + default -> throw new IllegalArgumentException("not a comparison operator: " + op); + }; + } + + private static boolean isPushableComparison(org.apache.spark.sql.connector.expressions.Expression[] children) { + if (children.length != 2) { + return false; + } + boolean lhsCol = isPushableFieldRef(children[0]); + boolean lhsLit = isPushableLiteral(children[0]); + boolean rhsCol = isPushableFieldRef(children[1]); + boolean rhsLit = isPushableLiteral(children[1]); + boolean lhsOk = lhsCol || lhsLit; + boolean rhsOk = rhsCol || rhsLit; + // We need at least one column reference; otherwise the predicate is comparing two + // constants — Spark normally folds those, so we don't bother. + return lhsOk && rhsOk && (lhsCol || rhsCol); + } + + private static boolean isPushableFieldRef(org.apache.spark.sql.connector.expressions.Expression expr) { + return expr instanceof NamedReference && ((NamedReference) expr).fieldNames().length >= 1; + } + + private static boolean isFieldRefExpr(org.apache.spark.sql.connector.expressions.Expression expr) { + return expr instanceof NamedReference; + } + + /** + * Returns the Vortex column expression for a Spark named reference, walking nested struct fields. + */ + private static Optional columnOf(org.apache.spark.sql.connector.expressions.Expression expr) { + if (!(expr instanceof NamedReference)) { + return Optional.empty(); + } + String[] parts = ((NamedReference) expr).fieldNames(); + if (parts.length == 0) { + return Optional.empty(); + } + return Optional.of(Expression.column(parts)); + } + + private static Optional exprOf(org.apache.spark.sql.connector.expressions.Expression expr) { + Optional col = columnOf(expr); + if (col.isPresent()) { + return col; + } + return literalOf(expr); + } + + private static Optional stringValueOf(org.apache.spark.sql.connector.expressions.Expression expr) { + if (!(expr instanceof Literal)) { + return Optional.empty(); + } + Object value = ((Literal) expr).value(); + if (value == null) { + return Optional.empty(); + } + if (value instanceof UTF8String) { + return Optional.of(value.toString()); + } + if (value instanceof CharSequence) { + return Optional.of(value.toString()); + } + return Optional.empty(); + } + + private static boolean isPushableStringLiteral(org.apache.spark.sql.connector.expressions.Expression expr) { + return stringValueOf(expr).isPresent(); + } + + private static boolean isPushableLiteral(org.apache.spark.sql.connector.expressions.Expression expr) { + if (!(expr instanceof Literal)) { + return false; + } + Literal lit = (Literal) expr; + DataType dataType = lit.dataType(); + // Null literals are pushable (we emit a typed null literal). + if (lit.value() == null) { + return dataType instanceof BooleanType + || dataType instanceof ByteType + || dataType instanceof ShortType + || dataType instanceof IntegerType + || dataType instanceof LongType + || dataType instanceof FloatType + || dataType instanceof DoubleType + || dataType instanceof StringType + || dataType instanceof BinaryType + || dataType instanceof DateType + || dataType instanceof TimestampType + || dataType instanceof TimestampNTZType + || dataType instanceof DecimalType; + } + return literalOf(expr).isPresent(); + } + + private static Optional literalOf(org.apache.spark.sql.connector.expressions.Expression expr) { + if (!(expr instanceof Literal)) { + return Optional.empty(); + } + Literal lit = (Literal) expr; + Object value = lit.value(); + DataType dataType = lit.dataType(); + return convertLiteral(value, dataType); + } + + private static Optional convertLiteral(Object value, DataType dataType) { + if (dataType instanceof BooleanType) { + if (value == null) { + return Optional.of(Expression.nullLiteralBool()); + } + return Optional.of(Expression.literal((Boolean) value)); + } + if (dataType instanceof ByteType) { + if (value == null) { + return Optional.of(Expression.nullLiteral(Expression.DType.I8)); + } + return Optional.of(Expression.literal(((Number) value).byteValue())); + } + if (dataType instanceof ShortType) { + if (value == null) { + return Optional.of(Expression.nullLiteral(Expression.DType.I16)); + } + return Optional.of(Expression.literal(((Number) value).shortValue())); + } + if (dataType instanceof IntegerType) { + if (value == null) { + return Optional.of(Expression.nullLiteral(Expression.DType.I32)); + } + return Optional.of(Expression.literal(((Number) value).intValue())); + } + if (dataType instanceof LongType) { + if (value == null) { + return Optional.of(Expression.nullLiteral(Expression.DType.I64)); + } + return Optional.of(Expression.literal(((Number) value).longValue())); + } + if (dataType instanceof FloatType) { + if (value == null) { + return Optional.of(Expression.nullLiteral(Expression.DType.F32)); + } + return Optional.of(Expression.literal(((Number) value).floatValue())); + } + if (dataType instanceof DoubleType) { + if (value == null) { + return Optional.of(Expression.nullLiteral(Expression.DType.F64)); + } + return Optional.of(Expression.literal(((Number) value).doubleValue())); + } + if (dataType instanceof StringType) { + if (value == null) { + return Optional.of(Expression.nullLiteral(Expression.DType.UTF8)); + } + if (value instanceof UTF8String || value instanceof CharSequence) { + return Optional.of(Expression.literal(value.toString())); + } + } + if (dataType instanceof BinaryType) { + if (value == null) { + return Optional.of(Expression.nullLiteral(Expression.DType.BINARY)); + } + if (value instanceof byte[]) { + return Optional.of(Expression.literal((byte[]) value)); + } + } + if (dataType instanceof DateType) { + // Spark stores DateType as a 32-bit int day count since 1970-01-01. + if (value == null) { + return Optional.of(Expression.nullLiteralDate(TimeUnit.DAYS)); + } + return Optional.of(Expression.literalDate(((Number) value).longValue(), TimeUnit.DAYS)); + } + if (dataType instanceof TimestampType) { + // Spark stores TimestampType as a 64-bit microseconds-since-epoch in UTC. + if (value == null) { + return Optional.of(Expression.nullLiteralTimestamp(TimeUnit.MICROSECONDS, "UTC")); + } + return Optional.of(Expression.literalTimestamp(((Number) value).longValue(), TimeUnit.MICROSECONDS, "UTC")); + } + if (dataType instanceof TimestampNTZType) { + if (value == null) { + return Optional.of(Expression.nullLiteralTimestamp(TimeUnit.MICROSECONDS, null)); + } + return Optional.of(Expression.literalTimestamp(((Number) value).longValue(), TimeUnit.MICROSECONDS, null)); + } + if (dataType instanceof DecimalType) { + DecimalType decimalType = (DecimalType) dataType; + int precision = decimalType.precision(); + int scale = decimalType.scale(); + if (value == null) { + return Optional.of(Expression.nullLiteralDecimal(precision, scale)); + } + BigInteger unscaled = unscaledValueOf(value, scale); + if (unscaled == null) { + return Optional.empty(); + } + return Optional.of(Expression.literalDecimal(unscaled, precision, scale)); + } + // Some Spark literals (e.g. NullType, GeographyType) have no Vortex representation. + return Optional.empty(); + } + + /** + * Extract the unscaled integer value of a Spark decimal literal at the supplied {@code scale}. + */ + private static BigInteger unscaledValueOf(Object value, int scale) { + BigDecimal decimal; + if (value instanceof Decimal) { + decimal = ((Decimal) value).toJavaBigDecimal(); + } else if (value instanceof BigDecimal) { + decimal = (BigDecimal) value; + } else { + return null; + } + try { + return decimal.setScale(scale).unscaledValue(); + } catch (ArithmeticException ignored) { + return null; + } + } +} diff --git a/java/vortex-spark/src/main/java/dev/vortex/spark/read/VortexBatchExec.java b/java/vortex-spark/src/main/java/dev/vortex/spark/read/VortexBatchExec.java index e32d26f5643..8df7ce8e1db 100644 --- a/java/vortex-spark/src/main/java/dev/vortex/spark/read/VortexBatchExec.java +++ b/java/vortex-spark/src/main/java/dev/vortex/spark/read/VortexBatchExec.java @@ -16,6 +16,7 @@ import java.util.stream.Stream; import org.apache.spark.sql.connector.catalog.CatalogV2Util; import org.apache.spark.sql.connector.catalog.Column; +import org.apache.spark.sql.connector.expressions.filter.Predicate; import org.apache.spark.sql.connector.read.Batch; import org.apache.spark.sql.connector.read.InputPartition; import org.apache.spark.sql.connector.read.PartitionReaderFactory; @@ -26,6 +27,7 @@ public final class VortexBatchExec implements Batch { private final List paths; private final StructType readSchema; private final Map formatOptions; + private final Predicate[] pushedPredicates; private List resolvedPaths; /** @@ -33,11 +35,15 @@ public final class VortexBatchExec implements Batch { * * @param paths the list of file paths to scan * @param columns the list of columns to read from the files + * @param pushedPredicates predicates pushed down by Spark; converted to a single Vortex filter expression at read + * time */ - public VortexBatchExec(List paths, List columns, Map formatOptions) { + public VortexBatchExec( + List paths, List columns, Map formatOptions, Predicate[] pushedPredicates) { this.paths = List.copyOf(paths); this.readSchema = CatalogV2Util.v2ColumnsToStructType(columns.toArray(new Column[0])); this.formatOptions = Map.copyOf(formatOptions); + this.pushedPredicates = pushedPredicates == null ? new Predicate[0] : pushedPredicates.clone(); } /** @@ -66,7 +72,7 @@ public PartitionReaderFactory createReaderFactory() { List dataColumnNames = Arrays.stream(readSchema.fieldNames()) .filter(name -> !partitionColumns.contains(name)) .collect(Collectors.toList()); - return new VortexPartitionReaderFactory(dataColumnNames, formatOptions); + return new VortexPartitionReaderFactory(dataColumnNames, formatOptions, pushedPredicates); } private List resolvePaths() { diff --git a/java/vortex-spark/src/main/java/dev/vortex/spark/read/VortexPartitionReader.java b/java/vortex-spark/src/main/java/dev/vortex/spark/read/VortexPartitionReader.java index 9df44c07e6c..f9cd2363d59 100644 --- a/java/vortex-spark/src/main/java/dev/vortex/spark/read/VortexPartitionReader.java +++ b/java/vortex-spark/src/main/java/dev/vortex/spark/read/VortexPartitionReader.java @@ -18,6 +18,8 @@ import java.io.IOException; import java.util.List; import java.util.Map; +import java.util.Optional; +import org.apache.spark.sql.connector.expressions.filter.Predicate; import org.apache.spark.sql.connector.read.PartitionReader; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.vectorized.ColumnVector; @@ -46,7 +48,11 @@ final class VortexPartitionReader implements PartitionReader { private boolean currentBatchLoaded; private boolean exhausted; - VortexPartitionReader(VortexFilePartition spark, List dataColumnNames, Map formatOptions) { + VortexPartitionReader( + VortexFilePartition spark, + List dataColumnNames, + Map formatOptions, + Predicate[] pushedPredicates) { this.spark = spark; this.allocator = ArrowAllocation.rootAllocator(); @@ -58,9 +64,24 @@ final class VortexPartitionReader implements PartitionReader { Expression projection = Expression.select(dataColumnNames.toArray(new String[0]), Expression.root()); options.projection(projection); } + if (pushedPredicates != null && pushedPredicates.length > 0) { + buildFilterExpression(pushedPredicates).ifPresent(options::filter); + } scan = dataSource.scan(options.build()); } + private static Optional buildFilterExpression(Predicate[] predicates) { + Expression combined = null; + for (Predicate predicate : predicates) { + Optional expr = SparkPredicateToVortexExpression.convert(predicate); + if (expr.isEmpty()) { + continue; + } + combined = combined == null ? expr.get() : Expression.and(combined, expr.get()); + } + return Optional.ofNullable(combined); + } + @Override public boolean next() { if (exhausted) { diff --git a/java/vortex-spark/src/main/java/dev/vortex/spark/read/VortexPartitionReaderFactory.java b/java/vortex-spark/src/main/java/dev/vortex/spark/read/VortexPartitionReaderFactory.java index 9ffbfcc3cfb..e187e4863b1 100644 --- a/java/vortex-spark/src/main/java/dev/vortex/spark/read/VortexPartitionReaderFactory.java +++ b/java/vortex-spark/src/main/java/dev/vortex/spark/read/VortexPartitionReaderFactory.java @@ -5,10 +5,13 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import dev.vortex.jni.NativeRuntime; import dev.vortex.spark.VortexFilePartition; import java.io.Serializable; import java.util.List; +import java.util.Map; import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.expressions.filter.Predicate; import org.apache.spark.sql.connector.read.InputPartition; import org.apache.spark.sql.connector.read.PartitionReader; import org.apache.spark.sql.connector.read.PartitionReaderFactory; @@ -27,10 +30,13 @@ public final class VortexPartitionReaderFactory implements PartitionReaderFactor private final ImmutableList dataColumnNames; private final ImmutableMap formatOptions; + private final Predicate[] pushedPredicates; - public VortexPartitionReaderFactory(List dataColumnNames, java.util.Map formatOptions) { + public VortexPartitionReaderFactory( + List dataColumnNames, Map formatOptions, Predicate[] pushedPredicates) { this.dataColumnNames = ImmutableList.copyOf(dataColumnNames); this.formatOptions = ImmutableMap.copyOf(formatOptions); + this.pushedPredicates = pushedPredicates == null ? new Predicate[0] : pushedPredicates.clone(); } @Override @@ -40,8 +46,9 @@ public PartitionReader createReader(InputPartition partition) { @Override public PartitionReader createColumnarReader(InputPartition partition) { + NativeRuntime.setWorkerThreads(Integer.parseInt(formatOptions.getOrDefault("vortex.workerThreads", "4"))); VortexFilePartition spark = (VortexFilePartition) partition; - return new VortexPartitionReader(spark, dataColumnNames, formatOptions); + return new VortexPartitionReader(spark, dataColumnNames, formatOptions, pushedPredicates); } @Override diff --git a/java/vortex-spark/src/main/java/dev/vortex/spark/read/VortexScan.java b/java/vortex-spark/src/main/java/dev/vortex/spark/read/VortexScan.java index c6ec03eef80..d5949b57a4d 100644 --- a/java/vortex-spark/src/main/java/dev/vortex/spark/read/VortexScan.java +++ b/java/vortex-spark/src/main/java/dev/vortex/spark/read/VortexScan.java @@ -3,10 +3,12 @@ package dev.vortex.spark.read; +import java.util.Arrays; import java.util.List; import java.util.Map; import org.apache.spark.sql.connector.catalog.CatalogV2Util; import org.apache.spark.sql.connector.catalog.Column; +import org.apache.spark.sql.connector.expressions.filter.Predicate; import org.apache.spark.sql.connector.read.Batch; import org.apache.spark.sql.connector.read.Scan; import org.apache.spark.sql.types.StructType; @@ -17,6 +19,7 @@ public final class VortexScan implements Scan { private final List paths; private final List readColumns; private final Map formatOptions; + private final Predicate[] pushedPredicates; /** * Creates a new VortexScan for the specified file paths and columns. The caller is responsible for passing @@ -24,11 +27,17 @@ public final class VortexScan implements Scan { * * @param paths the list of Vortex file paths to scan * @param readColumns the list of columns to read from the files + * @param pushedPredicates predicates pushed down by Spark; {@code null} or empty means no pushdown */ - public VortexScan(List paths, List readColumns, Map formatOptions) { + public VortexScan( + List paths, + List readColumns, + Map formatOptions, + Predicate[] pushedPredicates) { this.paths = paths; this.readColumns = readColumns; this.formatOptions = formatOptions; + this.pushedPredicates = pushedPredicates == null ? new Predicate[0] : pushedPredicates.clone(); } /** @@ -46,7 +55,9 @@ public StructType readSchema() { /** Logging-friendly readable description of the scan source. */ @Override public String description() { - return String.format("VortexScan{paths=%s, columns=%s}", paths, readColumns); + return String.format( + "VortexScan{paths=%s, columns=%s, pushedPredicates=%s}", + paths, readColumns, Arrays.toString(pushedPredicates)); } /** @@ -58,7 +69,7 @@ public String description() { */ @Override public Batch toBatch() { - return new VortexBatchExec(paths, readColumns, formatOptions); + return new VortexBatchExec(paths, readColumns, formatOptions, pushedPredicates); } /** diff --git a/java/vortex-spark/src/main/java/dev/vortex/spark/read/VortexScanBuilder.java b/java/vortex-spark/src/main/java/dev/vortex/spark/read/VortexScanBuilder.java index a53472bc33b..94990432b45 100644 --- a/java/vortex-spark/src/main/java/dev/vortex/spark/read/VortexScanBuilder.java +++ b/java/vortex-spark/src/main/java/dev/vortex/spark/read/VortexScanBuilder.java @@ -6,27 +6,54 @@ import static com.google.common.base.Preconditions.checkState; import com.google.common.collect.ImmutableList; +import com.google.common.collect.Maps; import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Set; +import org.apache.spark.sql.connector.catalog.CatalogV2Util; import org.apache.spark.sql.connector.catalog.Column; +import org.apache.spark.sql.connector.expressions.NamedReference; +import org.apache.spark.sql.connector.expressions.Transform; +import org.apache.spark.sql.connector.expressions.filter.Predicate; import org.apache.spark.sql.connector.read.Scan; import org.apache.spark.sql.connector.read.ScanBuilder; import org.apache.spark.sql.connector.read.SupportsPushDownRequiredColumns; -import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.connector.read.SupportsPushDownV2Filters; +import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.StructType; /** Spark V2 {@link ScanBuilder} for table scans over Vortex files. */ -public final class VortexScanBuilder implements ScanBuilder, SupportsPushDownRequiredColumns { +public final class VortexScanBuilder + implements ScanBuilder, SupportsPushDownRequiredColumns, SupportsPushDownV2Filters { private final ImmutableList.Builder paths; private final List columns; private final Map formatOptions; + private final Set partitionColumnNames; + private Predicate[] pushedPredicates = new Predicate[0]; /** Creates a new VortexScanBuilder with empty paths and columns. */ public VortexScanBuilder(Map formatOptions) { + this(formatOptions, new Transform[0]); + } + + /** + * Creates a new VortexScanBuilder with empty paths and columns and the supplied partition transforms. Filters that + * reference partition columns are not pushed down, since the partition columns are not stored inside the Vortex + * files. + */ + public VortexScanBuilder(Map formatOptions, Transform[] partitionTransforms) { this.paths = ImmutableList.builder(); this.columns = new ArrayList<>(); - this.formatOptions = Map.copyOf(formatOptions); + Map options = Maps.newHashMap(); + options.put("vortex.workerThreads", "4"); + options.putAll(formatOptions); + this.formatOptions = options; + this.partitionColumnNames = collectPartitionColumnNames(partitionTransforms); } /** @@ -89,7 +116,7 @@ public Scan build() { // Allow empty columns for operations like count() that don't need actual column data // If no columns are specified, we'll read the minimal schema needed - return new VortexScan(paths, List.copyOf(this.columns), this.formatOptions); + return new VortexScan(paths, List.copyOf(this.columns), this.formatOptions, pushedPredicates); } /** @@ -102,10 +129,59 @@ public Scan build() { */ @Override public void pruneColumns(StructType requiredSchema) { - // TODO(aduffy): support deeply nested schema prunes columns.clear(); - for (StructField field : requiredSchema.fields()) { - columns.add(Column.create(field.name(), field.dataType())); + columns.addAll(Arrays.asList(CatalogV2Util.structTypeToV2Columns(requiredSchema))); + } + + /** + * Splits the supplied predicates into pushed and not-pushed sets. + * + *

A predicate is pushed when it references only data columns (not partition columns) and uses operators and + * literal types that {@link SparkPredicateToVortexExpression} can map to Vortex expressions. Predicates that + * reference partition columns or use unsupported features are returned to Spark for post-scan evaluation. + * + * @return the predicates that Spark must still evaluate + */ + @Override + public Predicate[] pushPredicates(Predicate[] predicates) { + Map dataColumnTypes = new HashMap<>(); + for (Column column : columns) { + if (!partitionColumnNames.contains(column.name())) { + dataColumnTypes.put(column.name(), column.dataType()); + } + } + List pushed = new ArrayList<>(); + List postScan = new ArrayList<>(); + for (Predicate predicate : predicates) { + if (SparkPredicateToVortexExpression.isPushable(predicate, dataColumnTypes)) { + pushed.add(predicate); + } else { + postScan.add(predicate); + } + } + this.pushedPredicates = pushed.toArray(new Predicate[0]); + return postScan.toArray(new Predicate[0]); + } + + /** Returns the predicates this scan promises to apply. */ + @Override + public Predicate[] pushedPredicates() { + return Arrays.copyOf(pushedPredicates, pushedPredicates.length); + } + + private static Set collectPartitionColumnNames(Transform[] transforms) { + if (transforms == null || transforms.length == 0) { + return Collections.emptySet(); + } + Set names = new HashSet<>(); + for (Transform transform : transforms) { + for (NamedReference ref : transform.references()) { + String[] parts = ref.fieldNames(); + if (parts.length == 1) { + names.add(parts[0]); + } + } } + return names; } } diff --git a/java/vortex-spark/src/test/java/dev/vortex/spark/VortexFilterPushdownTest.java b/java/vortex-spark/src/test/java/dev/vortex/spark/VortexFilterPushdownTest.java new file mode 100644 index 00000000000..61087837042 --- /dev/null +++ b/java/vortex-spark/src/test/java/dev/vortex/spark/VortexFilterPushdownTest.java @@ -0,0 +1,589 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +package dev.vortex.spark; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.io.IOException; +import java.math.BigDecimal; +import java.nio.file.Files; +import java.nio.file.Path; +import java.sql.Date; +import java.sql.Timestamp; +import java.util.Arrays; +import java.util.Comparator; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SaveMode; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.execution.QueryExecution; +import org.apache.spark.sql.execution.SparkPlan; +import org.apache.spark.sql.functions; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructType; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.api.io.TempDir; + +/** + * Tests that Spark predicate pushdown into the Vortex datasource produces correct results. + * + *

The tests write a Vortex dataset and then read it back applying various Spark filters. The + * {@code VortexScanBuilder.pushFilters} path attempts to translate each filter to a Vortex {@code Expression}; filters + * it cannot translate (or that reference partition columns) are returned to Spark for post-scan evaluation. Either way + * the final result must match the same query against the original DataFrame. + */ +@TestInstance(TestInstance.Lifecycle.PER_CLASS) +public final class VortexFilterPushdownTest { + + private SparkSession spark; + + @TempDir + Path tempDir; + + @BeforeAll + public void setUp() { + spark = SparkSession.builder() + .appName("VortexFilterPushdownTest") + .master("local[2]") + .config("spark.driver.host", "127.0.0.1") + .config("spark.sql.shuffle.partitions", "2") + .config("spark.sql.adaptive.enabled", "false") + .config("spark.ui.enabled", "false") + .getOrCreate(); + } + + @AfterAll + public void tearDown() { + if (spark != null) { + spark.stop(); + } + } + + @Test + @DisplayName("Equality, comparison, IS NULL, IN, AND/OR/NOT all return correct rows after pushdown") + public void testFilterPushdownCorrectness() throws IOException { + Dataset df = spark.createDataFrame( + Arrays.asList( + RowFactory.create(1, "alpha", 10L, true), + RowFactory.create(2, "beta", 20L, false), + RowFactory.create(3, "gamma", 30L, true), + RowFactory.create(4, "delta", null, false), + RowFactory.create(5, null, 50L, true)), + DataTypes.createStructType(Arrays.asList( + DataTypes.createStructField("id", DataTypes.IntegerType, false), + DataTypes.createStructField("name", DataTypes.StringType, true), + DataTypes.createStructField("amount", DataTypes.LongType, true), + DataTypes.createStructField("flag", DataTypes.BooleanType, false)))); + + Path outputPath = tempDir.resolve("pushdown_basic"); + df.write() + .format("vortex") + .option("path", outputPath.toUri().toString()) + .mode(SaveMode.Overwrite) + .save(); + + Dataset readDf = spark.read() + .format("vortex") + .option("path", outputPath.toUri().toString()) + .load(); + + assertEquals( + List.of(2), idsOf(readDf.filter(readDf.col("id").equalTo(2)).orderBy("id"))); + + assertEquals( + List.of(3, 4, 5), idsOf(readDf.filter(readDf.col("id").gt(2)).orderBy("id"))); + + assertEquals(List.of(1, 2), idsOf(readDf.filter(readDf.col("id").leq(2)).orderBy("id"))); + + // != + assertEquals( + List.of(1, 3, 4, 5), + idsOf(readDf.filter(readDf.col("id").notEqual(2)).orderBy("id"))); + + assertEquals( + List.of(1, 3), + idsOf(readDf.filter(readDf.col("name").isin("alpha", "gamma")).orderBy("id"))); + + assertEquals( + List.of(4), idsOf(readDf.filter(readDf.col("amount").isNull()).orderBy("id"))); + + assertEquals( + List.of(1, 2, 3, 5), + idsOf(readDf.filter(readDf.col("amount").isNotNull()).orderBy("id"))); + + assertEquals( + List.of(1, 3), + idsOf(readDf.filter(readDf.col("flag") + .equalTo(true) + .and(readDf.col("amount").lt(40L))) + .orderBy("id"))); + + assertEquals( + List.of(1, 4, 5), + idsOf(readDf.filter(readDf.col("id") + .equalTo(1) + .or(readDf.col("amount").isNull()) + .or(readDf.col("name").isNull())) + .orderBy("id"))); + + // NOT around a pushed predicate. + assertEquals( + List.of(2, 3, 4), + idsOf(readDf.filter(functions.not(readDf.col("name").startsWith("a"))) + .orderBy("id"))); + } + + @Test + @DisplayName("STARTS_WITH / ENDS_WITH / CONTAINS push down via LIKE with metachar escaping") + public void testStringPredicatePushdown() throws IOException { + Dataset df = spark.createDataFrame( + Arrays.asList( + RowFactory.create(1, "alpha_one"), + RowFactory.create(2, "alpha_two"), + RowFactory.create(3, "beta_one"), + RowFactory.create(4, "ab%cd"), // contains a literal % + RowFactory.create(5, null)), + DataTypes.createStructType(Arrays.asList( + DataTypes.createStructField("id", DataTypes.IntegerType, false), + DataTypes.createStructField("label", DataTypes.StringType, true)))); + + Path outputPath = tempDir.resolve("pushdown_strings"); + df.write() + .format("vortex") + .option("path", outputPath.toUri().toString()) + .mode(SaveMode.Overwrite) + .save(); + + Dataset readDf = spark.read() + .format("vortex") + .option("path", outputPath.toUri().toString()) + .load(); + + // STARTS_WITH + assertEquals( + List.of(1, 2), + idsOf(readDf.filter(readDf.col("label").startsWith("alpha")).orderBy("id"))); + + // ENDS_WITH + assertEquals( + List.of(1, 3), + idsOf(readDf.filter(readDf.col("label").endsWith("_one")).orderBy("id"))); + + // CONTAINS (literal underscore must match the underscore character). + assertEquals( + List.of(1, 3), + idsOf(readDf.filter(readDf.col("label").contains("_o")).orderBy("id"))); + + // CONTAINS with no special chars — verify standard substring search works. + assertEquals( + List.of(1, 2), + idsOf(readDf.filter(readDf.col("label").contains("alpha")).orderBy("id"))); + + // Literal "%" should not act as a LIKE wildcard; only id=4 contains it. + assertEquals( + List.of(4), + idsOf(readDf.filter(readDf.col("label").contains("%")).orderBy("id"))); + + // STARTS_WITH on an underscore should match the literal underscore character. + assertEquals( + List.of(), + idsOf(readDf.filter(readDf.col("label").startsWith("_")).orderBy("id"))); + } + + @Test + @DisplayName("Date, timestamp, and decimal literals push down through equality and range comparisons") + public void testTemporalAndDecimalPushdown() throws IOException { + StructType schema = DataTypes.createStructType(Arrays.asList( + DataTypes.createStructField("id", DataTypes.IntegerType, false), + DataTypes.createStructField("d", DataTypes.DateType, true), + DataTypes.createStructField("ts", DataTypes.TimestampType, true), + DataTypes.createStructField("amt", DataTypes.createDecimalType(10, 2), true))); + + Dataset df = spark.createDataFrame( + Arrays.asList( + RowFactory.create( + 1, + Date.valueOf("2020-01-01"), + Timestamp.valueOf("2020-01-01 00:00:00"), + new BigDecimal("1.23")), + RowFactory.create( + 2, + Date.valueOf("2021-06-15"), + Timestamp.valueOf("2021-06-15 12:30:00"), + new BigDecimal("99.99")), + RowFactory.create( + 3, + Date.valueOf("2022-12-31"), + Timestamp.valueOf("2022-12-31 23:59:59"), + new BigDecimal("-5.00"))), + schema); + + Path outputPath = tempDir.resolve("pushdown_temporal"); + df.write() + .format("vortex") + .option("path", outputPath.toUri().toString()) + .mode(SaveMode.Overwrite) + .save(); + + Dataset readDf = spark.read() + .format("vortex") + .option("path", outputPath.toUri().toString()) + .load(); + + // Date equality + assertEquals( + List.of(2), + idsOf(readDf.filter(readDf.col("d").equalTo(Date.valueOf("2021-06-15"))) + .orderBy("id"))); + + // Date range + assertEquals( + List.of(2, 3), + idsOf(readDf.filter(readDf.col("d").gt(Date.valueOf("2020-06-01"))) + .orderBy("id"))); + + // Timestamp range + assertEquals( + List.of(1, 2), + idsOf(readDf.filter(readDf.col("ts").lt(Timestamp.valueOf("2022-01-01 00:00:00"))) + .orderBy("id"))); + + // Decimal equality + assertEquals( + List.of(2), + idsOf(readDf.filter(readDf.col("amt").equalTo(new BigDecimal("99.99"))) + .orderBy("id"))); + + // Decimal range + assertEquals( + List.of(3), + idsOf(readDf.filter(readDf.col("amt").lt(new BigDecimal("0.00"))) + .orderBy("id"))); + } + + @Test + @DisplayName("Filters on nested struct fields push down") + public void testNestedFieldPushdown() throws IOException { + StructType inner = DataTypes.createStructType(Arrays.asList( + DataTypes.createStructField("category", DataTypes.StringType, true), + DataTypes.createStructField("score", DataTypes.IntegerType, true))); + StructType schema = DataTypes.createStructType(Arrays.asList( + DataTypes.createStructField("id", DataTypes.IntegerType, false), + DataTypes.createStructField("info", inner, true))); + + Dataset df = spark.createDataFrame( + Arrays.asList( + RowFactory.create(1, RowFactory.create("apple", 10)), + RowFactory.create(2, RowFactory.create("banana", 20)), + RowFactory.create(3, RowFactory.create("cherry", 30)), + RowFactory.create(4, RowFactory.create("apple", 40))), + schema); + + Path outputPath = tempDir.resolve("pushdown_nested"); + df.write() + .format("vortex") + .option("path", outputPath.toUri().toString()) + .mode(SaveMode.Overwrite) + .save(); + + Dataset readDf = spark.read() + .format("vortex") + .option("path", outputPath.toUri().toString()) + .load(); + + assertEquals( + List.of(1, 4), + idsOf(readDf.filter(readDf.col("info.category").equalTo("apple")) + .orderBy("id"))); + + assertEquals( + List.of(3, 4), + idsOf(readDf.filter(readDf.col("info.score").gt(20)).orderBy("id"))); + } + + @Test + @DisplayName("Filters on partition columns yield correct results without pushdown") + public void testFilterOnPartitionColumn() throws IOException { + Dataset df = spark.createDataFrame( + Arrays.asList( + RowFactory.create(1, "alpha", "A"), + RowFactory.create(2, "beta", "B"), + RowFactory.create(3, "gamma", "A"), + RowFactory.create(4, "delta", "B")), + DataTypes.createStructType(Arrays.asList( + DataTypes.createStructField("id", DataTypes.IntegerType, false), + DataTypes.createStructField("name", DataTypes.StringType, true), + DataTypes.createStructField("group", DataTypes.StringType, true)))); + + Path outputPath = tempDir.resolve("pushdown_partitioned"); + df.write() + .format("vortex") + .partitionBy("group") + .option("path", outputPath.toUri().toString()) + .mode(SaveMode.Overwrite) + .save(); + + Dataset readDf = spark.read() + .format("vortex") + .option("path", outputPath.toUri().toString()) + .load(); + + assertEquals( + List.of(1, 3), + idsOf(readDf.filter(readDf.col("group").equalTo("A")).orderBy("id"))); + + // Predicate spanning partition + data columns must still produce the right answer. + assertEquals( + List.of(3), + idsOf(readDf.filter(readDf.col("group") + .equalTo("A") + .and(readDf.col("id").gt(1))) + .orderBy("id"))); + } + + @Test + @DisplayName("Pushed filters appear in the executed scan node") + public void testPushedFiltersInPlan() throws IOException { + Dataset df = spark.createDataFrame( + Arrays.asList(RowFactory.create(1, "x"), RowFactory.create(2, "y"), RowFactory.create(3, "z")), + DataTypes.createStructType(Arrays.asList( + DataTypes.createStructField("id", DataTypes.IntegerType, false), + DataTypes.createStructField("label", DataTypes.StringType, true)))); + + Path outputPath = tempDir.resolve("pushdown_plan"); + df.write() + .format("vortex") + .option("path", outputPath.toUri().toString()) + .mode(SaveMode.Overwrite) + .save(); + + Dataset readDf = spark.read() + .format("vortex") + .option("path", outputPath.toUri().toString()) + .load(); + + Dataset filtered = readDf.filter(readDf.col("id").gt(1)); + QueryExecution qe = filtered.queryExecution(); + SparkPlan plan = qe.executedPlan(); + String planString = plan.toString(); + assertTrue( + planString.contains("id > 1"), + "Expected pushed predicate for id > 1 in the executed plan: " + planString); + } + + @Test + @DisplayName("Deep nesting of AND/OR/NOT pushes down correctly") + public void testDeeplyNestedLogicalPushdown() throws IOException { + Dataset df = spark.createDataFrame( + Arrays.asList( + RowFactory.create(1, 10, "alpha"), + RowFactory.create(2, 20, "beta"), + RowFactory.create(3, 30, "gamma"), + RowFactory.create(4, 40, "delta"), + RowFactory.create(5, 50, "epsilon")), + DataTypes.createStructType(Arrays.asList( + DataTypes.createStructField("id", DataTypes.IntegerType, false), + DataTypes.createStructField("val", DataTypes.IntegerType, false), + DataTypes.createStructField("label", DataTypes.StringType, true)))); + + Path outputPath = tempDir.resolve("pushdown_deep"); + df.write() + .format("vortex") + .option("path", outputPath.toUri().toString()) + .mode(SaveMode.Overwrite) + .save(); + + Dataset readDf = spark.read() + .format("vortex") + .option("path", outputPath.toUri().toString()) + .load(); + + // ((id > 1 AND val < 50) OR label = 'epsilon') AND NOT(label = 'beta') + assertEquals( + List.of(3, 4, 5), + idsOf(readDf.filter(readDf.col("id") + .gt(1) + .and(readDf.col("val").lt(50)) + .or(readDf.col("label").equalTo("epsilon")) + .and(functions.not(readDf.col("label").equalTo("beta")))) + .orderBy("id"))); + } + + @Test + @DisplayName("STARTS_WITH/ENDS_WITH/CONTAINS escape LIKE meta-characters in the literal substring") + public void testStringPredicateEscapeRegression() throws IOException { + // Cover every LIKE meta-character (`%`, `_`, `\\`) as well as a "safe" string to ensure + // ordinary substrings still pass through unchanged. Each fixture row carries the literal + // we will later search for using STARTS_WITH/ENDS_WITH/CONTAINS, plus a "decoy" row that + // would only match if the meta-character were interpreted as a wildcard. + Dataset df = spark.createDataFrame( + Arrays.asList( + RowFactory.create(1, "%pct%"), // contains literal % + RowFactory.create(2, "no-percent"), + RowFactory.create(3, "abc%def"), + RowFactory.create(4, "abXdef"), // would match `%_%` if `_` were a wildcard + RowFactory.create(5, "_under"), + RowFactory.create(6, "no-under"), + RowFactory.create(7, "a\\b"), // single literal backslash between a and b + RowFactory.create(8, "ab"), // would match `a\b` if `\` were stripped + RowFactory.create(9, "trail\\"), // ends with literal backslash + RowFactory.create(10, "%front")), // starts with literal % + DataTypes.createStructType(Arrays.asList( + DataTypes.createStructField("id", DataTypes.IntegerType, false), + DataTypes.createStructField("label", DataTypes.StringType, true)))); + + Path outputPath = tempDir.resolve("pushdown_string_escapes"); + df.write() + .format("vortex") + .option("path", outputPath.toUri().toString()) + .mode(SaveMode.Overwrite) + .save(); + + Dataset readDf = spark.read() + .format("vortex") + .option("path", outputPath.toUri().toString()) + .load(); + + // CONTAINS("%") -- must NOT match rows without a literal `%`. + assertEquals( + List.of(1, 3, 10), + idsOf(readDf.filter(readDf.col("label").contains("%")).orderBy("id"))); + + // STARTS_WITH("%") -- only rows that start with a literal `%`. + assertEquals( + List.of(1, 10), + idsOf(readDf.filter(readDf.col("label").startsWith("%")).orderBy("id"))); + + // ENDS_WITH("%") -- only rows that end with a literal `%`. + assertEquals( + List.of(1), + idsOf(readDf.filter(readDf.col("label").endsWith("%")).orderBy("id"))); + + // CONTAINS("_") -- must NOT match every row; only those with a literal `_`. + assertEquals( + List.of(5), + idsOf(readDf.filter(readDf.col("label").contains("_")).orderBy("id"))); + + // STARTS_WITH("_") -- only rows that start with a literal `_`. + assertEquals( + List.of(5), + idsOf(readDf.filter(readDf.col("label").startsWith("_")).orderBy("id"))); + + // CONTAINS("\\") -- must match rows with a single literal backslash. The Java string + // literal "\\" is a 1-char string containing just `\`. + assertEquals( + List.of(7, 9), + idsOf(readDf.filter(readDf.col("label").contains("\\")).orderBy("id"))); + + // ENDS_WITH("\\") -- only rows ending with a literal backslash. + assertEquals( + List.of(9), + idsOf(readDf.filter(readDf.col("label").endsWith("\\")).orderBy("id"))); + + // CONTAINS("abc%def") -- treat `%` literally; only row 3 has the exact substring. + assertEquals( + List.of(3), + idsOf(readDf.filter(readDf.col("label").contains("abc%def")).orderBy("id"))); + + // Sanity check: a non-meta substring still works. + assertEquals( + List.of(2, 6), + idsOf(readDf.filter(readDf.col("label").contains("no-")).orderBy("id"))); + } + + @Test + @DisplayName("Binary literals push down through equality") + public void testBinaryLiteralPushdown() throws IOException { + StructType schema = DataTypes.createStructType(Arrays.asList( + DataTypes.createStructField("id", DataTypes.IntegerType, false), + DataTypes.createStructField("payload", DataTypes.BinaryType, true))); + + Dataset df = spark.createDataFrame( + Arrays.asList( + RowFactory.create(1, new byte[] {0x01, 0x02, 0x03}), + RowFactory.create(2, new byte[] {0x04, 0x05}), + RowFactory.create(3, new byte[] {0x01, 0x02, 0x03})), + schema); + + Path outputPath = tempDir.resolve("pushdown_binary"); + df.write() + .format("vortex") + .option("path", outputPath.toUri().toString()) + .mode(SaveMode.Overwrite) + .save(); + + Dataset readDf = spark.read() + .format("vortex") + .option("path", outputPath.toUri().toString()) + .load(); + + assertEquals( + List.of(1, 3), + idsOf(readDf.filter(readDf.col("payload").equalTo(new byte[] {0x01, 0x02, 0x03})) + .orderBy("id"))); + } + + @Test + @DisplayName("Bare boolean column reference (e.g. WHERE bool_col) pushes down") + public void testBareBooleanColumnPushdown() throws IOException { + Dataset df = spark.createDataFrame( + Arrays.asList( + RowFactory.create(1, true), + RowFactory.create(2, false), + RowFactory.create(3, true), + RowFactory.create(4, false)), + DataTypes.createStructType(Arrays.asList( + DataTypes.createStructField("id", DataTypes.IntegerType, false), + DataTypes.createStructField("flag", DataTypes.BooleanType, false)))); + + Path outputPath = tempDir.resolve("pushdown_bool"); + df.write() + .format("vortex") + .option("path", outputPath.toUri().toString()) + .mode(SaveMode.Overwrite) + .save(); + + Dataset readDf = spark.read() + .format("vortex") + .option("path", outputPath.toUri().toString()) + .load(); + + // WHERE flag + assertEquals(List.of(1, 3), idsOf(readDf.filter(readDf.col("flag")).orderBy("id"))); + + // WHERE NOT flag + assertEquals( + List.of(2, 4), + idsOf(readDf.filter(functions.not(readDf.col("flag"))).orderBy("id"))); + } + + private static List idsOf(Dataset df) { + return df.collectAsList().stream().map(row -> row.getInt(0)).collect(Collectors.toList()); + } + + @AfterEach + public void cleanupTempFiles() throws IOException { + if (tempDir != null && Files.exists(tempDir)) { + try (Stream paths = Files.walk(tempDir)) { + paths.sorted(Comparator.reverseOrder()).forEach(path -> { + try { + Files.deleteIfExists(path); + } catch (IOException e) { + System.err.println("Failed to delete: " + path); + } + }); + } + } + } +} diff --git a/java/vortex-spark/src/test/java/dev/vortex/spark/read/SparkPredicateToVortexExpressionTest.java b/java/vortex-spark/src/test/java/dev/vortex/spark/read/SparkPredicateToVortexExpressionTest.java new file mode 100644 index 00000000000..b7feafd322c --- /dev/null +++ b/java/vortex-spark/src/test/java/dev/vortex/spark/read/SparkPredicateToVortexExpressionTest.java @@ -0,0 +1,114 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +package dev.vortex.spark.read; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.Map; +import org.apache.spark.sql.connector.expressions.Expression; +import org.apache.spark.sql.connector.expressions.LiteralValue; +import org.apache.spark.sql.connector.expressions.NamedReference; +import org.apache.spark.sql.connector.expressions.filter.Predicate; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructType; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; + +/** Unit tests for {@link SparkPredicateToVortexExpression#isPushable(Predicate, Map)}. */ +final class SparkPredicateToVortexExpressionTest { + + private static final StructType ADDRESS = DataTypes.createStructType(new org.apache.spark.sql.types.StructField[] { + DataTypes.createStructField("city", DataTypes.StringType, true), + DataTypes.createStructField("zip", DataTypes.IntegerType, true) + }); + + private static final StructType PROFILE = DataTypes.createStructType(new org.apache.spark.sql.types.StructField[] { + DataTypes.createStructField("email", DataTypes.StringType, true), + DataTypes.createStructField("address", ADDRESS, true) + }); + + private static final Map SCHEMA = + Map.of("id", DataTypes.IntegerType, "name", DataTypes.StringType, "profile", PROFILE); + + @Test + @DisplayName("Top-level column reference is pushable when present in the schema") + void topLevelColumnIsPushable() { + Predicate equality = equality(ref("id"), literal(42)); + assertTrue(SparkPredicateToVortexExpression.isPushable(equality, SCHEMA)); + } + + @Test + @DisplayName("Top-level column reference is not pushable when absent from the schema") + void unknownTopLevelColumnIsNotPushable() { + Predicate equality = equality(ref("missing"), literal(0)); + assertFalse(SparkPredicateToVortexExpression.isPushable(equality, SCHEMA)); + } + + @Test + @DisplayName("Nested field reference is pushable when every part resolves under struct types") + void nestedFieldThatExistsIsPushable() { + Predicate equality = equality(ref("profile", "email"), literal("a@b.com")); + assertTrue(SparkPredicateToVortexExpression.isPushable(equality, SCHEMA)); + } + + @Test + @DisplayName("Doubly nested field reference resolves through multiple struct levels") + void doublyNestedFieldIsPushable() { + Predicate equality = equality(ref("profile", "address", "zip"), literal(12345)); + assertTrue(SparkPredicateToVortexExpression.isPushable(equality, SCHEMA)); + } + + @Test + @DisplayName("Nested field that does not exist in the struct is not pushable") + void nestedFieldThatDoesNotExistIsNotPushable() { + Predicate equality = equality(ref("profile", "phone"), literal("555")); + assertFalse(SparkPredicateToVortexExpression.isPushable(equality, SCHEMA)); + } + + @Test + @DisplayName("Descending past a leaf (non-struct) field is not pushable") + void descendingPastLeafFieldIsNotPushable() { + // `name` is a String, not a struct — `name.first` cannot resolve. + Predicate equality = equality(ref("name", "first"), literal("alice")); + assertFalse(SparkPredicateToVortexExpression.isPushable(equality, SCHEMA)); + } + + @Test + @DisplayName("Empty named reference is not pushable") + void emptyReferenceIsNotPushable() { + Predicate equality = equality(ref(), literal(1)); + assertFalse(SparkPredicateToVortexExpression.isPushable(equality, SCHEMA)); + } + + private static Predicate equality(Expression left, Expression right) { + return new Predicate("=", new Expression[] {left, right}); + } + + private static NamedReference ref(String... parts) { + return new TestNamedReference(parts); + } + + private static LiteralValue literal(int value) { + return new LiteralValue<>(value, DataTypes.IntegerType); + } + + private static LiteralValue literal(String value) { + return new LiteralValue<>(org.apache.spark.unsafe.types.UTF8String.fromString(value), DataTypes.StringType); + } + + private static final class TestNamedReference implements NamedReference { + private final String[] fieldNames; + + TestNamedReference(String[] fieldNames) { + this.fieldNames = fieldNames; + } + + @Override + public String[] fieldNames() { + return fieldNames; + } + } +} diff --git a/vortex-array/src/scalar_fn/fns/like/mod.rs b/vortex-array/src/scalar_fn/fns/like/mod.rs index 2e908ca17bd..baaaa278287 100644 --- a/vortex-array/src/scalar_fn/fns/like/mod.rs +++ b/vortex-array/src/scalar_fn/fns/like/mod.rs @@ -241,6 +241,13 @@ enum LikeVariant<'a> { impl<'a> LikeVariant<'a> { /// Parse a LIKE pattern string into its relevant variant fn from_str(string: &str) -> Option> { + // We don't unescape SQL LIKE meta-characters, so fall back when the pattern uses them. + // Returning `None` here disables stat pruning for the predicate, which is sound — the + // LIKE evaluation itself runs and produces the correct answer. + if string.contains('\\') { + return None; + } + let Some(wildcard_pos) = string.find(['%', '_']) else { return Some(LikeVariant::Exact(string)); }; @@ -303,6 +310,60 @@ mod tests { assert_eq!(expr2.to_string(), "$ not ilike \"test*\""); } + #[test] + #[allow(deprecated, reason = "to_bool is fine for one-shot regression test")] + fn ends_with_percent_pattern() { + // The Spark filter pushdown translator emits `%\%` for `endsWith("%")`. Verify the full + // expression Spark sends (IS_NOT_NULL AND LIKE) matches only strings actually ending in `%`. + use crate::ToCanonical; + use crate::arrays::StructArray; + use crate::arrays::VarBinViewArray; + use crate::arrays::bool::BoolArrayExt; + use crate::dtype::FieldNames; + use crate::expr::and; + use crate::expr::is_not_null; + use crate::validity::Validity; + + let pattern = "%\\%"; + let arr = VarBinViewArray::from_iter_str(vec![ + "%pct%", + "no-percent", + "abc%def", + "abXdef", + "_under", + "no-under", + "a\\b", + "ab", + "trail\\", + "%front", + ]); + + let label_expr = get_item("label", root()); + let expr = and( + is_not_null(label_expr.clone()), + like(label_expr, lit(pattern)), + ); + let struct_arr = StructArray::try_new( + FieldNames::from(["label"]), + vec![arr.into_array()], + 10, + Validity::NonNullable, + ) + .unwrap(); + let result = struct_arr.into_array().apply(&expr).unwrap(); + let bools = result.to_bool(); + let bits = bools.to_bit_buffer(); + + let actual: Vec = (0..bits.len()).map(|i| bits.value(i)).collect(); + assert_eq!( + actual, + vec![ + true, false, false, false, false, false, false, false, false, false + ], + "endsWith(\"%\") only matches strings actually ending in `%`", + ); + } + #[test] fn test_like_variant() { // Supported patterns @@ -322,6 +383,12 @@ mod tests { // Unsupported patterns assert_eq!(LikeVariant::from_str("%suffix"), None); assert_eq!(LikeVariant::from_str("_pattern"), None); + + // Patterns containing the LIKE escape character disable pruning to avoid misinterpreting + // `\%` or `\_` as literal wildcards. + assert_eq!(LikeVariant::from_str("%\\%"), None); + assert_eq!(LikeVariant::from_str("\\_%"), None); + assert_eq!(LikeVariant::from_str("foo\\%bar"), None); } #[test] diff --git a/vortex-jni/src/expression.rs b/vortex-jni/src/expression.rs index aaed86efa75..00fa2ad4bc1 100644 --- a/vortex-jni/src/expression.rs +++ b/vortex-jni/src/expression.rs @@ -11,6 +11,7 @@ use std::sync::Arc; use jni::EnvUnowned; +use jni::objects::JByteArray; use jni::objects::JClass; use jni::objects::JLongArray; use jni::objects::JObjectArray; @@ -23,19 +24,34 @@ use jni::sys::jfloat; use jni::sys::jint; use jni::sys::jlong; use jni::sys::jshort; +use vortex::dtype::DType; +use vortex::dtype::DecimalDType; use vortex::dtype::FieldName; +use vortex::dtype::Nullability; +use vortex::dtype::PType; use vortex::expr::Expression; use vortex::expr::and_collect; +use vortex::expr::between; use vortex::expr::get_item; +use vortex::expr::is_not_null; use vortex::expr::is_null; use vortex::expr::lit; use vortex::expr::not; use vortex::expr::or_collect; use vortex::expr::root; use vortex::expr::select; +use vortex::extension::datetime::Date; +use vortex::extension::datetime::TimeUnit; +use vortex::extension::datetime::Timestamp; +use vortex::scalar::DecimalValue; use vortex::scalar::Scalar; +use vortex::scalar::ScalarValue; use vortex::scalar_fn::ScalarFnVTableExt; +use vortex::scalar_fn::fns::between::BetweenOptions; +use vortex::scalar_fn::fns::between::StrictComparison; use vortex::scalar_fn::fns::binary::Binary; +use vortex::scalar_fn::fns::like::Like; +use vortex::scalar_fn::fns::like::LikeOptions; use vortex::scalar_fn::fns::operators::Operator; use crate::errors::JNIError; @@ -69,6 +85,11 @@ fn parse_op(op: jbyte) -> Result { }) } +/// Parse a Vortex [`TimeUnit`] from the wire-encoded byte tag. +fn parse_time_unit(tag: jbyte) -> Result { + TimeUnit::try_from(tag as u8).map_err(JNIError::from) +} + #[unsafe(no_mangle)] pub extern "system" fn Java_dev_vortex_jni_NativeExpression_free( _env: EnvUnowned, @@ -201,6 +222,70 @@ pub extern "system" fn Java_dev_vortex_jni_NativeExpression_isNull( into_raw(is_null(child)) } +#[unsafe(no_mangle)] +pub extern "system" fn Java_dev_vortex_jni_NativeExpression_isNotNull( + _env: EnvUnowned, + _class: JClass, + child: jlong, +) -> jlong { + let child = unsafe { expr_ref(child) }.clone(); + into_raw(is_not_null(child)) +} + +#[unsafe(no_mangle)] +pub extern "system" fn Java_dev_vortex_jni_NativeExpression_like( + _env: EnvUnowned, + _class: JClass, + child: jlong, + pattern: jlong, + negated: jboolean, + case_insensitive: jboolean, +) -> jlong { + let child = unsafe { expr_ref(child) }.clone(); + let pattern = unsafe { expr_ref(pattern) }.clone(); + into_raw(Like.new_expr( + LikeOptions { + negated, + case_insensitive, + }, + [child, pattern], + )) +} + +#[unsafe(no_mangle)] +pub extern "system" fn Java_dev_vortex_jni_NativeExpression_between( + mut env: EnvUnowned, + _class: JClass, + value: jlong, + lower: jlong, + upper: jlong, + lower_strict: jboolean, + upper_strict: jboolean, +) -> jlong { + try_or_throw(&mut env, |_| { + let value = unsafe { expr_ref(value) }.clone(); + let lower = unsafe { expr_ref(lower) }.clone(); + let upper = unsafe { expr_ref(upper) }.clone(); + Ok(into_raw(between( + value, + lower, + upper, + BetweenOptions { + lower_strict: strict_from_bool(lower_strict), + upper_strict: strict_from_bool(upper_strict), + }, + ))) + }) +} + +fn strict_from_bool(value: jboolean) -> StrictComparison { + if value { + StrictComparison::Strict + } else { + StrictComparison::NonStrict + } +} + #[unsafe(no_mangle)] pub extern "system" fn Java_dev_vortex_jni_NativeExpression_literalBool( _env: EnvUnowned, @@ -259,3 +344,209 @@ pub extern "system" fn Java_dev_vortex_jni_NativeExpression_literalString( Ok(into_raw(lit(s))) }) } + +#[unsafe(no_mangle)] +pub extern "system" fn Java_dev_vortex_jni_NativeExpression_literalBinary( + mut env: EnvUnowned, + _class: JClass, + value: JByteArray, +) -> jlong { + try_or_throw(&mut env, |env| { + if value.is_null() { + let scalar = Scalar::null_native::(); + return Ok(into_raw(lit(scalar))); + } + let bytes: Vec = env.convert_byte_array(&value)?; + Ok(into_raw(lit(bytes.as_slice()))) + }) +} + +/// Build a decimal literal from a two's-complement big-endian byte representation of the +/// unscaled value (the format produced by Java's `BigInteger.toByteArray()`). +#[unsafe(no_mangle)] +pub extern "system" fn Java_dev_vortex_jni_NativeExpression_literalDecimal( + mut env: EnvUnowned, + _class: JClass, + unscaled_big_endian: JByteArray, + precision: jint, + scale: jint, + is_null_flag: jboolean, +) -> jlong { + try_or_throw(&mut env, |env| { + let precision = u8::try_from(precision).map_err(|_| { + vortex::error::vortex_err!("decimal precision out of range: {precision}") + })?; + let scale = i8::try_from(scale) + .map_err(|_| vortex::error::vortex_err!("decimal scale out of range: {scale}"))?; + let decimal_dtype = DecimalDType::try_new(precision, scale)?; + if is_null_flag { + return Ok(into_raw(lit(Scalar::null(DType::Decimal( + decimal_dtype, + Nullability::Nullable, + ))))); + } + let bytes = env.convert_byte_array(&unscaled_big_endian)?; + let decimal_value = decimal_value_from_be_bytes(&bytes, &decimal_dtype)?; + let scalar = Scalar::try_new( + DType::Decimal(decimal_dtype, Nullability::NonNullable), + Some(ScalarValue::from(decimal_value)), + )?; + Ok(into_raw(lit(scalar))) + }) +} + +/// Decode a two's-complement big-endian byte array (Java `BigInteger.toByteArray()` format) +/// into the smallest [`DecimalValue`] variant that can hold the precision. +fn decimal_value_from_be_bytes( + bytes: &[u8], + dtype: &DecimalDType, +) -> Result { + if bytes.is_empty() { + throw_runtime!("decimal unscaled value must have at least one byte"); + } + let value = i256_from_twos_complement_be(bytes); + // Pick the narrowest backing integer that fits the dtype's precision. + let required_bits = dtype.required_bit_width(); + if required_bits <= 8 { + let v = value + .maybe_i128() + .and_then(|v| i8::try_from(v).ok()) + .ok_or_else(|| vortex::error::vortex_err!("decimal value does not fit in i8"))?; + Ok(DecimalValue::I8(v)) + } else if required_bits <= 16 { + let v = value + .maybe_i128() + .and_then(|v| i16::try_from(v).ok()) + .ok_or_else(|| vortex::error::vortex_err!("decimal value does not fit in i16"))?; + Ok(DecimalValue::I16(v)) + } else if required_bits <= 32 { + let v = value + .maybe_i128() + .and_then(|v| i32::try_from(v).ok()) + .ok_or_else(|| vortex::error::vortex_err!("decimal value does not fit in i32"))?; + Ok(DecimalValue::I32(v)) + } else if required_bits <= 64 { + let v = value + .maybe_i128() + .and_then(|v| i64::try_from(v).ok()) + .ok_or_else(|| vortex::error::vortex_err!("decimal value does not fit in i64"))?; + Ok(DecimalValue::I64(v)) + } else if required_bits <= 128 { + let v = value + .maybe_i128() + .ok_or_else(|| vortex::error::vortex_err!("decimal value does not fit in i128"))?; + Ok(DecimalValue::I128(v)) + } else { + Ok(DecimalValue::I256(value)) + } +} + +/// Sign-extend a two's-complement big-endian byte slice into an `i256`. +fn i256_from_twos_complement_be(bytes: &[u8]) -> vortex::dtype::i256 { + let mut le = [0u8; 32]; + let len = bytes.len().min(32); + // Most significant byte comes first in big-endian; copy lowest 32 bytes reversed into LE. + for (i, b) in bytes.iter().rev().take(len).enumerate() { + le[i] = *b; + } + // If the original value is negative (high bit of the most-significant byte is set), + // sign-extend the remaining high bytes with 0xff. + if !bytes.is_empty() && (bytes[0] & 0x80) != 0 { + for byte in &mut le[len..] { + *byte = 0xff; + } + } + vortex::dtype::i256::from_le_bytes(le) +} + +#[unsafe(no_mangle)] +pub extern "system" fn Java_dev_vortex_jni_NativeExpression_literalDate( + mut env: EnvUnowned, + _class: JClass, + value: jlong, + time_unit_tag: jbyte, + is_null_flag: jboolean, +) -> jlong { + try_or_throw(&mut env, |_| { + let unit = parse_time_unit(time_unit_tag)?; + let nullability = if is_null_flag { + Nullability::Nullable + } else { + Nullability::NonNullable + }; + let ext = Date::try_new(unit, nullability)?; + let dtype = DType::Extension(ext.erased()); + if is_null_flag { + return Ok(into_raw(lit(Scalar::null(dtype)))); + } + let storage_value = match unit { + TimeUnit::Days => ScalarValue::from(i32::try_from(value).map_err(|_| { + vortex::error::vortex_err!("date value does not fit in i32 days: {value}") + })?), + TimeUnit::Milliseconds => ScalarValue::from(value), + other => throw_runtime!("date does not support time unit {other}"), + }; + Ok(into_raw(lit(Scalar::try_new(dtype, Some(storage_value))?))) + }) +} + +#[unsafe(no_mangle)] +pub extern "system" fn Java_dev_vortex_jni_NativeExpression_literalTimestamp( + mut env: EnvUnowned, + _class: JClass, + value: jlong, + time_unit_tag: jbyte, + timezone: JString, + is_null_flag: jboolean, +) -> jlong { + try_or_throw(&mut env, |env| { + let unit = parse_time_unit(time_unit_tag)?; + let tz: Option> = if timezone.is_null() { + None + } else { + let s: String = timezone.try_to_string(env)?; + Some(Arc::::from(s.as_str())) + }; + let nullability = if is_null_flag { + Nullability::Nullable + } else { + Nullability::NonNullable + }; + let ext = Timestamp::new_with_tz(unit, tz, nullability); + let dtype = DType::Extension(ext.erased()); + if is_null_flag { + return Ok(into_raw(lit(Scalar::null(dtype)))); + } + Ok(into_raw(lit(Scalar::try_new( + dtype, + Some(ScalarValue::from(value)), + )?))) + }) +} + +/// Build a typed null literal whose nullable dtype is selected by `dtype_tag`. +/// +/// Tag values intentionally do not overlap with [`parse_time_unit`]. +/// See `dev.vortex.api.Expression.DType` on the Java side for the source of truth. +#[unsafe(no_mangle)] +pub extern "system" fn Java_dev_vortex_jni_NativeExpression_literalNull( + mut env: EnvUnowned, + _class: JClass, + dtype_tag: jbyte, +) -> jlong { + try_or_throw(&mut env, |_| { + let dtype = match dtype_tag { + 0 => DType::Bool(Nullability::Nullable), + 1 => DType::Primitive(PType::I8, Nullability::Nullable), + 2 => DType::Primitive(PType::I16, Nullability::Nullable), + 3 => DType::Primitive(PType::I32, Nullability::Nullable), + 4 => DType::Primitive(PType::I64, Nullability::Nullable), + 5 => DType::Primitive(PType::F32, Nullability::Nullable), + 6 => DType::Primitive(PType::F64, Nullability::Nullable), + 7 => DType::Utf8(Nullability::Nullable), + 8 => DType::Binary(Nullability::Nullable), + other => throw_runtime!("unknown null dtype tag: {other}"), + }; + Ok(into_raw(lit(Scalar::null(dtype)))) + }) +}