diff --git a/parquet/src/main/java/org/apache/iceberg/parquet/MessageTypeToType.java b/parquet/src/main/java/org/apache/iceberg/parquet/MessageTypeToType.java index 841777152ee8..22f69605be56 100644 --- a/parquet/src/main/java/org/apache/iceberg/parquet/MessageTypeToType.java +++ b/parquet/src/main/java/org/apache/iceberg/parquet/MessageTypeToType.java @@ -247,6 +247,11 @@ public Optional visit(LogicalTypeAnnotation.JsonLogicalTypeAnnotation json public Optional visit(LogicalTypeAnnotation.BsonLogicalTypeAnnotation bsonType) { return Optional.of(Types.BinaryType.get()); } + + @Override + public Optional visit(LogicalTypeAnnotation.UUIDLogicalTypeAnnotation uuidType) { + return Optional.of(Types.UUIDType.get()); + } } private void addAlias(String name, int fieldId) { diff --git a/parquet/src/main/java/org/apache/iceberg/parquet/Parquet.java b/parquet/src/main/java/org/apache/iceberg/parquet/Parquet.java index f02974d6e79c..f0548003a8c0 100644 --- a/parquet/src/main/java/org/apache/iceberg/parquet/Parquet.java +++ b/parquet/src/main/java/org/apache/iceberg/parquet/Parquet.java @@ -1559,13 +1559,12 @@ public CloseableIterable build() { } catch (IOException e) { throw new RuntimeIOException(e); } - Schema fileSchema = ParquetSchemaUtil.convert(type); builder .useStatsFilter() .useDictionaryFilter() .useRecordFilter(filterRecords) .useBloomFilter() - .withFilter(ParquetFilters.convert(fileSchema, filter, caseSensitive)); + .withFilter(ParquetFilters.convert(type, filter, caseSensitive)); } else { // turn off filtering builder diff --git a/parquet/src/main/java/org/apache/iceberg/parquet/ParquetFilters.java b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetFilters.java index fc6febe19438..1479e21d3eae 100644 --- a/parquet/src/main/java/org/apache/iceberg/parquet/ParquetFilters.java +++ b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetFilters.java @@ -18,7 +18,11 @@ */ package org.apache.iceberg.parquet; +import java.math.BigDecimal; +import 
java.math.RoundingMode; import java.nio.ByteBuffer; +import java.util.Map; +import java.util.UUID; import org.apache.iceberg.Schema; import org.apache.iceberg.expressions.BoundPredicate; import org.apache.iceberg.expressions.BoundReference; @@ -29,19 +33,31 @@ import org.apache.iceberg.expressions.Expressions; import org.apache.iceberg.expressions.Literal; import org.apache.iceberg.expressions.UnboundPredicate; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.DecimalUtil; +import org.apache.iceberg.util.UUIDUtil; +import org.apache.parquet.column.ColumnDescriptor; import org.apache.parquet.filter2.compat.FilterCompat; import org.apache.parquet.filter2.predicate.FilterApi; import org.apache.parquet.filter2.predicate.FilterPredicate; import org.apache.parquet.filter2.predicate.Operators; import org.apache.parquet.io.api.Binary; +import org.apache.parquet.schema.MessageType; +import org.apache.parquet.schema.PrimitiveType; class ParquetFilters { private ParquetFilters() {} - static FilterCompat.Filter convert(Schema schema, Expression expr, boolean caseSensitive) { + static FilterCompat.Filter convert( + MessageType parquetSchema, Expression expr, boolean caseSensitive) { + Schema schema = ParquetSchemaUtil.convert(parquetSchema); FilterPredicate pred = - ExpressionVisitors.visit(expr, new ConvertFilterToParquet(schema, caseSensitive)); + ExpressionVisitors.visit( + expr, + new ConvertFilterToParquet( + schema, primitiveTypesById(parquetSchema, schema), caseSensitive)); // TODO: handle AlwaysFalse.INSTANCE if (pred != null && pred != AlwaysTrue.INSTANCE) { // FilterCompat will apply LogicalInverseRewriter @@ -51,12 +67,30 @@ static FilterCompat.Filter convert(Schema schema, Expression expr, boolean caseS } } + private static Map primitiveTypesById( + MessageType parquetSchema, Schema schema) { + Map primitiveTypesById = Maps.newHashMap(); + + for (ColumnDescriptor desc : 
parquetSchema.getColumns()) { + PrimitiveType primitiveType = parquetSchema.getType(desc.getPath()).asPrimitiveType(); + Integer fieldId = schema.aliasToId(String.join(".", desc.getPath())); + if (fieldId != null) { + primitiveTypesById.put(fieldId, primitiveType); + } + } + + return primitiveTypesById; + } + private static class ConvertFilterToParquet extends ExpressionVisitor { private final Schema schema; + private final Map primitiveTypesById; private final boolean caseSensitive; - private ConvertFilterToParquet(Schema schema, boolean caseSensitive) { + private ConvertFilterToParquet( + Schema schema, Map primitiveTypesById, boolean caseSensitive) { this.schema = schema; + this.primitiveTypesById = primitiveTypesById; this.caseSensitive = caseSensitive; } @@ -149,11 +183,18 @@ public FilterPredicate predicate(BoundPredicate pred) { case DOUBLE: return pred(op, FilterApi.doubleColumn(path), getParquetPrimitive(lit)); case STRING: - case UUID: case FIXED: case BINARY: - case DECIMAL: return pred(op, FilterApi.binaryColumn(path), getParquetPrimitive(lit)); + case UUID: + return pred(op, FilterApi.binaryColumn(path), getParquetUUID(lit)); + case DECIMAL: + return decimalPred( + op, + path, + primitiveTypesById.get(ref.fieldId()), + (Types.DecimalType) ref.type().asPrimitiveType(), + lit); } throw new UnsupportedOperationException("Cannot convert to Parquet filter: " + pred); @@ -173,6 +214,42 @@ public FilterPredicate predicate(UnboundPredicate pred) { } } + private static FilterPredicate decimalPred( + Operation op, + String path, + PrimitiveType primitiveType, + Types.DecimalType decimalType, + Literal lit) { + if (primitiveType == null) { + return AlwaysTrue.INSTANCE; + } + + BigDecimal decimal = decimalValue(decimalType, lit); + if (lit != null && decimal == null) { + return AlwaysTrue.INSTANCE; + } + + try { + switch (primitiveType.getPrimitiveTypeName()) { + case INT32: + return pred(op, FilterApi.intColumn(path), getDecimalAsInt(decimal)); + case INT64: + 
return pred(op, FilterApi.longColumn(path), getDecimalAsLong(decimal)); + case FIXED_LEN_BYTE_ARRAY: + return pred( + op, + FilterApi.binaryColumn(path), + getDecimalAsFixed(decimalType, primitiveType.getTypeLength(), decimal)); + case BINARY: + return pred(op, FilterApi.binaryColumn(path), getDecimalAsBinary(decimal)); + default: + return AlwaysTrue.INSTANCE; + } + } catch (ArithmeticException e) { + return AlwaysTrue.INSTANCE; + } + } + @SuppressWarnings("checkstyle:MethodTypeParameterName") private static , COL extends Operators.Column & Operators.SupportsLtGt> FilterPredicate pred(Operation op, COL col, C value) { @@ -214,13 +291,69 @@ FilterPredicate pred(Operation op, COL col, C value) { } } + private static Integer getDecimalAsInt(BigDecimal decimal) { + if (decimal == null) { + return null; + } + + return decimal.unscaledValue().intValueExact(); + } + + private static Long getDecimalAsLong(BigDecimal decimal) { + if (decimal == null) { + return null; + } + + return decimal.unscaledValue().longValueExact(); + } + + private static Binary getDecimalAsFixed(Types.DecimalType type, int length, BigDecimal decimal) { + if (decimal == null) { + return null; + } + + byte[] bytes = + DecimalUtil.toReusedFixLengthBytes( + type.precision(), type.scale(), decimal, new byte[length]); + return Binary.fromConstantByteArray(bytes); + } + + private static Binary getDecimalAsBinary(BigDecimal decimal) { + if (decimal == null) { + return null; + } + + return Binary.fromConstantByteArray(decimal.unscaledValue().toByteArray()); + } + + private static BigDecimal decimalValue(Types.DecimalType type, Literal lit) { + if (lit == null) { + return null; + } + + BigDecimal decimal = (BigDecimal) lit.value(); + try { + BigDecimal scaled = decimal.setScale(type.scale(), RoundingMode.UNNECESSARY); + return scaled.precision() <= type.precision() ? 
scaled : null; + } catch (ArithmeticException e) { + return null; + } + } + + private static Binary getParquetUUID(Literal lit) { + if (lit == null) { + return null; + } + + return Binary.fromConstantByteArray(UUIDUtil.convert((UUID) lit.value())); + } + @SuppressWarnings("unchecked") private static > C getParquetPrimitive(Literal lit) { if (lit == null) { return null; } - // TODO: this needs to convert to handle BigDecimal and UUID Object value = lit.value(); if (value instanceof Number) { return (C) lit.value(); diff --git a/parquet/src/test/java/org/apache/iceberg/parquet/TestParquetFilters.java b/parquet/src/test/java/org/apache/iceberg/parquet/TestParquetFilters.java new file mode 100644 index 000000000000..7a531a4870e7 --- /dev/null +++ b/parquet/src/test/java/org/apache/iceberg/parquet/TestParquetFilters.java @@ -0,0 +1,188 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.parquet; + +import static org.apache.iceberg.expressions.Expressions.equal; +import static org.apache.iceberg.types.Types.NestedField.optional; +import static org.assertj.core.api.Assertions.assertThat; + +import java.math.BigDecimal; +import java.util.UUID; +import org.apache.iceberg.Schema; +import org.apache.iceberg.expressions.Expression; +import org.apache.iceberg.types.TypeUtil; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.DecimalUtil; +import org.apache.iceberg.util.UUIDUtil; +import org.apache.parquet.filter2.compat.FilterCompat; +import org.apache.parquet.filter2.compat.FilterCompat.FilterPredicateCompat; +import org.apache.parquet.filter2.predicate.FilterPredicate; +import org.apache.parquet.filter2.predicate.Operators; +import org.apache.parquet.io.api.Binary; +import org.apache.parquet.schema.LogicalTypeAnnotation; +import org.apache.parquet.schema.MessageType; +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName; +import org.junit.jupiter.api.Test; + +class TestParquetFilters { + private static final Schema TABLE_SCHEMA = + new Schema( + optional(1, "decimal_int", Types.DecimalType.of(9, 2)), + optional(2, "decimal_long", Types.DecimalType.of(18, 2)), + optional(3, "decimal_fixed", Types.DecimalType.of(19, 2)), + optional(4, "uuid_col", Types.UUIDType.get())); + + private static final MessageType PARQUET_SCHEMA = + ParquetSchemaUtil.convert(TABLE_SCHEMA, "table"); + + private static final MessageType BINARY_DECIMAL_PARQUET_SCHEMA = + org.apache.parquet.schema.Types.buildMessage() + .optional(PrimitiveTypeName.BINARY) + .as(LogicalTypeAnnotation.decimalType(2, 9)) + .named("decimal_int") + .named("table"); + + private static final MessageType EXTENDED_FIXED_DECIMAL_PARQUET_SCHEMA = + org.apache.parquet.schema.Types.buildMessage() + .optional(PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY) + .length(16) + .as(LogicalTypeAnnotation.decimalType(2, 19)) + .named("decimal_fixed") + 
.named("table"); + + @Test + void convertsIntDecimalLiteral() { + BigDecimal decimal = new BigDecimal("12.34"); + + Operators.Eq predicate = predicate(equal("decimal_int", decimal), Operators.Eq.class); + + assertThat(predicate.getColumn().getColumnType()).isEqualTo(Integer.class); + assertThat(predicate.getValue()).isEqualTo(1234); + } + + @Test + void convertsLongDecimalLiteral() { + BigDecimal decimal = new BigDecimal("1234567890123456.78"); + + Operators.Eq predicate = predicate(equal("decimal_long", decimal), Operators.Eq.class); + + assertThat(predicate.getColumn().getColumnType()).isEqualTo(Long.class); + assertThat(predicate.getValue()).isEqualTo(123456789012345678L); + } + + @Test + void convertsFixedDecimalLiteral() { + Types.DecimalType decimalType = Types.DecimalType.of(19, 2); + BigDecimal decimal = new BigDecimal("12345678901234567.89"); + + Operators.Eq predicate = predicate(equal("decimal_fixed", decimal), Operators.Eq.class); + + byte[] expected = + DecimalUtil.toReusedFixLengthBytes( + decimalType.precision(), + decimalType.scale(), + decimal, + new byte[TypeUtil.decimalRequiredBytes(decimalType.precision())]); + + assertThat(predicate.getColumn().getColumnType()).isEqualTo(Binary.class); + assertThat(((Binary) predicate.getValue()).getBytes()).isEqualTo(expected); + } + + @Test + void usesParquetPhysicalTypeForBinaryDecimal() { + BigDecimal decimal = new BigDecimal("12.34"); + + Operators.Eq predicate = + predicate(BINARY_DECIMAL_PARQUET_SCHEMA, equal("decimal_int", decimal), Operators.Eq.class); + + assertThat(predicate.getColumn().getColumnType()).isEqualTo(Binary.class); + assertThat(((Binary) predicate.getValue()).getBytes()) + .isEqualTo(decimal.unscaledValue().toByteArray()); + } + + @Test + void usesParquetTypeLengthForFixedDecimal() { + Types.DecimalType decimalType = Types.DecimalType.of(19, 2); + BigDecimal decimal = new BigDecimal("-12345678901234567.89"); + + Operators.Eq predicate = + predicate( + 
EXTENDED_FIXED_DECIMAL_PARQUET_SCHEMA, + equal("decimal_fixed", decimal), + Operators.Eq.class); + + byte[] expected = + DecimalUtil.toReusedFixLengthBytes( + decimalType.precision(), decimalType.scale(), decimal, new byte[16]); + + assertThat(predicate.getColumn().getColumnType()).isEqualTo(Binary.class); + assertThat(((Binary) predicate.getValue()).getBytes()).isEqualTo(expected); + } + + @Test + void skipsDecimalLiteralWithIncompatibleScale() { + FilterCompat.Filter filter = + ParquetFilters.convert( + PARQUET_SCHEMA, equal("decimal_int", new BigDecimal("12.345")), true); + + assertThat(filter).isSameAs(FilterCompat.NOOP); + } + + @Test + void skipsDecimalLiteralThatExceedsPrecision() { + FilterCompat.Filter filter = + ParquetFilters.convert( + PARQUET_SCHEMA, equal("decimal_int", new BigDecimal("123456789.12")), true); + + assertThat(filter).isSameAs(FilterCompat.NOOP); + } + + @Test + void convertsDecimalLiteralWithTrailingZeros() { + Operators.Eq predicate = + predicate(equal("decimal_int", new BigDecimal("12.340")), Operators.Eq.class); + + assertThat(predicate.getColumn().getColumnType()).isEqualTo(Integer.class); + assertThat(predicate.getValue()).isEqualTo(1234); + } + + @Test + void convertsUuidLiteral() { + UUID uuid = UUID.fromString("f24f9b64-81fa-49d1-b74e-8c09a6e31c56"); + + Operators.Eq predicate = predicate(equal("uuid_col", uuid), Operators.Eq.class); + + assertThat(predicate.getColumn().getColumnType()).isEqualTo(Binary.class); + assertThat(((Binary) predicate.getValue()).getBytes()).isEqualTo(UUIDUtil.convert(uuid)); + } + + private static

<P extends FilterPredicate> P predicate(Expression expression, Class<P> type) { + return predicate(PARQUET_SCHEMA, expression, type); + } + + private static <P extends FilterPredicate> P predicate( + MessageType parquetSchema, Expression expression, Class<P>
type) { + FilterCompat.Filter filter = ParquetFilters.convert(parquetSchema, expression, true); + assertThat(filter).isInstanceOf(FilterPredicateCompat.class); + + FilterPredicate predicate = ((FilterPredicateCompat) filter).getFilterPredicate(); + assertThat(predicate).isInstanceOf(type); + return type.cast(predicate); + } +} diff --git a/parquet/src/test/java/org/apache/iceberg/parquet/TestParquetSchemaUtil.java b/parquet/src/test/java/org/apache/iceberg/parquet/TestParquetSchemaUtil.java index 1df904f13c7b..c50d77a25fe5 100644 --- a/parquet/src/test/java/org/apache/iceberg/parquet/TestParquetSchemaUtil.java +++ b/parquet/src/test/java/org/apache/iceberg/parquet/TestParquetSchemaUtil.java @@ -160,6 +160,22 @@ public void testAssignIdsToVariantTypesByNameMapping() { assertThat(messageTypeWithIdsFromNameMapping).isEqualTo(messageTypeWithIds); } + @Test + public void testConvertUuidLogicalType() { + MessageType messageType = + org.apache.parquet.schema.Types.buildMessage() + .optional(PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY) + .length(16) + .as(LogicalTypeAnnotation.uuidType()) + .id(1) + .named("uuid_col") + .named("test"); + + Schema actualSchema = ParquetSchemaUtil.convert(messageType); + + assertThat(actualSchema.findType(1)).isEqualTo(Types.UUIDType.get()); + } + @Test public void testSchemaConversionWithoutAssigningIds() { MessageType messageType =