From ba458d7d8d34da92fc1d6b239808da82e1716631 Mon Sep 17 00:00:00 2001 From: yehe Date: Wed, 27 May 2026 21:01:59 +0800 Subject: [PATCH] [spark] Support AlwaysTrue/AlwaysFalse and handle NaN equality in SparkFilterConverter * Add AlwaysTrue / AlwaysFalse to the list of supported Spark filters and translate them to the corresponding Paimon predicates. * Intercept EqualTo with a NaN Float/Double literal and push it down as an IsNaN predicate, since Spark's EqualTo treats NaN as equal to NaN while Paimon's equality predicate would not match any row. * Drop the two obsolete TODOs covered by the changes above. --- .../paimon/spark/SparkFilterConverter.java | 25 ++++++++++++-- .../spark/SparkFilterConverterTest.java | 34 +++++++++++++++++++ 2 files changed, 56 insertions(+), 3 deletions(-) diff --git a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkFilterConverter.java b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkFilterConverter.java index c0b8cfd66be1..31e4c8aec92a 100644 --- a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkFilterConverter.java +++ b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkFilterConverter.java @@ -23,6 +23,8 @@ import org.apache.paimon.types.DataType; import org.apache.paimon.types.RowType; +import org.apache.spark.sql.sources.AlwaysFalse; +import org.apache.spark.sql.sources.AlwaysTrue; import org.apache.spark.sql.sources.And; import org.apache.spark.sql.sources.EqualNullSafe; import org.apache.spark.sql.sources.EqualTo; @@ -54,6 +56,8 @@ public class SparkFilterConverter { public static final List SUPPORT_FILTERS = Arrays.asList( + "AlwaysTrue", + "AlwaysFalse", "EqualTo", "EqualNullSafe", "GreaterThan", @@ -97,10 +101,16 @@ public Predicate convert(Filter filter, boolean ignoreFailure) { } public Predicate convert(Filter filter) { - if (filter instanceof EqualTo) { + if (filter instanceof AlwaysTrue) { + return PredicateBuilder.alwaysTrue(); + } else if (filter instanceof AlwaysFalse) { + return PredicateBuilder.alwaysFalse(); + } else if (filter instanceof EqualTo) { EqualTo eq = (EqualTo) filter; - // TODO deal with isNaN int index = fieldIndex(eq.attribute()); + if (isNaN(eq.value())) { + return builder.isNaN(index); + } Object literal = convertLiteral(index, eq.value()); return builder.equal(index, literal); } else if (filter instanceof EqualNullSafe) { @@ -173,11 +183,20 @@ public Predicate convert(Filter filter) { return builder.contains(index, literal); } - // TODO: AlwaysTrue, AlwaysFalse throw new UnsupportedOperationException( filter + " is unsupported. Support Filters: " + SUPPORT_FILTERS); } + private static boolean isNaN(Object value) { + if (value instanceof Float) { + return Float.isNaN((Float) value); + } + if (value instanceof Double) { + return Double.isNaN((Double) value); + } + return false; + } + public Object convertLiteral(String field, Object value) { return convertLiteral(fieldIndex(field), value); } diff --git a/paimon-spark/paimon-spark-ut/src/test/java/org/apache/paimon/spark/SparkFilterConverterTest.java b/paimon-spark/paimon-spark-ut/src/test/java/org/apache/paimon/spark/SparkFilterConverterTest.java index 4248d07d769f..8b5457c9dff6 100644 --- a/paimon-spark/paimon-spark-ut/src/test/java/org/apache/paimon/spark/SparkFilterConverterTest.java +++ b/paimon-spark/paimon-spark-ut/src/test/java/org/apache/paimon/spark/SparkFilterConverterTest.java @@ -26,11 +26,15 @@ import org.apache.paimon.types.CharType; import org.apache.paimon.types.DataField; import org.apache.paimon.types.DateType; +import org.apache.paimon.types.DoubleType; +import org.apache.paimon.types.FloatType; import org.apache.paimon.types.IntType; import org.apache.paimon.types.RowType; import org.apache.paimon.types.TimestampType; import org.apache.paimon.types.VarCharType; +import org.apache.spark.sql.sources.AlwaysFalse; +import org.apache.spark.sql.sources.AlwaysTrue; import org.apache.spark.sql.sources.EqualNullSafe; import org.apache.spark.sql.sources.EqualTo; import org.apache.spark.sql.sources.GreaterThan; @@ -229,6 +233,36 @@ public void testDate() { assertThat(localDateExpression).isEqualTo(rawExpression); } + @Test + public void testAlwaysTrueFalse() { + RowType rowType = + new RowType(Collections.singletonList(new DataField(0, "id", new IntType()))); + SparkFilterConverter converter = new SparkFilterConverter(rowType); + + assertThat(converter.convert(new AlwaysTrue())).isEqualTo(PredicateBuilder.alwaysTrue()); + assertThat(converter.convert(new AlwaysFalse())).isEqualTo(PredicateBuilder.alwaysFalse()); + } + + @Test + public void testEqualToNaN() { + RowType rowType = + new RowType( + Arrays.asList( + new DataField(0, "f", new FloatType()), + new DataField(1, "d", new DoubleType()))); + SparkFilterConverter converter = new SparkFilterConverter(rowType); + PredicateBuilder builder = new PredicateBuilder(rowType); + + EqualTo eqNaNFloat = EqualTo.apply("f", Float.NaN); + assertThat(converter.convert(eqNaNFloat)).isEqualTo(builder.isNaN(0)); + + EqualTo eqNaNDouble = EqualTo.apply("d", Double.NaN); + assertThat(converter.convert(eqNaNDouble)).isEqualTo(builder.isNaN(1)); + + EqualTo eqFloat = EqualTo.apply("f", 1.0f); + assertThat(converter.convert(eqFloat)).isEqualTo(builder.equal(0, 1.0f)); + } + @Test public void testIgnoreFailure() { List dataFields = new ArrayList<>();