diff --git a/docs/data/sql_functions.yml b/docs/data/sql_functions.yml index 83bc11abe4451..96e3e49bab52d 100644 --- a/docs/data/sql_functions.yml +++ b/docs/data/sql_functions.yml @@ -1167,6 +1167,16 @@ variant: parser will keep the last occurrence of all fields with the same key, otherwise when `allowDuplicateKeys` is false it will throw an error. The default value of `allowDuplicateKeys` is false. + - sql: variant '[' INT ']' + table: VARIANT.at(INT) + description: | + If the VARIANT is an ARRAY value, returns a VARIANT whose value is the element at + the specified index. Otherwise, NULL is returned. + - sql: variant '[' STRING ']' + table: VARIANT.at(STRING) + description: | + If the VARIANT is a MAP value that has an element with this key, a VARIANT holding + the associated value is returned. Otherwise, NULL is returned. valueconstruction: - sql: | diff --git a/docs/data/sql_functions_zh.yml b/docs/data/sql_functions_zh.yml index 8a5c772d21505..821294824b13e 100644 --- a/docs/data/sql_functions_zh.yml +++ b/docs/data/sql_functions_zh.yml @@ -1241,7 +1241,6 @@ variant: 同键的字段,否则当 allowDuplicateKeys 为 false 时,它会抛出一个错误。默认情况下, allowDuplicateKeys 的值为 false。 - - sql: TRY_PARSE_JSON(json_string[, allow_duplicate_keys]) description: | 尽可能将 JSON 字符串解析为 Variant。如果 JSON 字符串无效,则返回 NULL。如果希望抛出错误而不是返回 NULL, @@ -1251,6 +1250,16 @@ variant: 同键的字段,否则当 allowDuplicateKeys 为 false 时,它会抛出一个错误。默认情况下, allowDuplicateKeys 的值为 false。 + - sql: variant '[' INT ']' + table: VARIANT.at(INT) + description: | + 如果这是一个 ARRAY 类型的 VARIANT,则返回一个 VARIANT,其值为指定索引处的元素。否则返回 NULL。 + + - sql: variant '[' STRING ']' + table: VARIANT.at(STRING) + description: | + 如果这是一个 MAP 类型的 VARIANT,则返回一个 VARIANT,其值为与指定键关联的值。否则返回 NULL。 + valueconstruction: - sql: | diff --git a/flink-table/flink-sql-parser/src/main/java/org/apache/calcite/sql/fun/SqlItemOperator.java b/flink-table/flink-sql-parser/src/main/java/org/apache/calcite/sql/fun/SqlItemOperator.java new file mode 100644 index 0000000000000..7f30f97ac105f --- /dev/null +++ b/flink-table/flink-sql-parser/src/main/java/org/apache/calcite/sql/fun/SqlItemOperator.java @@ -0,0 +1,218 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.calcite.sql.fun; + +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.rel.type.RelDataTypeField; +import org.apache.calcite.sql.SqlCall; +import org.apache.calcite.sql.SqlCallBinding; +import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.SqlNode; +import org.apache.calcite.sql.SqlOperandCountRange; +import org.apache.calcite.sql.SqlOperatorBinding; +import org.apache.calcite.sql.SqlSpecialOperator; +import org.apache.calcite.sql.SqlWriter; +import org.apache.calcite.sql.parser.SqlParserPos; +import org.apache.calcite.sql.type.OperandTypes; +import org.apache.calcite.sql.type.SqlOperandCountRanges; +import org.apache.calcite.sql.type.SqlSingleOperandTypeChecker; +import org.apache.calcite.sql.type.SqlTypeFamily; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.sql.type.SqlTypeUtil; + +import java.util.Arrays; + +import static java.util.Objects.requireNonNull; +import static org.apache.calcite.sql.type.NonNullableAccessors.getComponentTypeOrThrow; +import static org.apache.calcite.sql.validate.SqlNonNullableAccessors.getOperandLiteralValueOrThrow; + +/** + * The item operator {@code [ ... ]}, used to access a given element of an array, map or struct. For + * example, {@code myArray[3]}, {@code "myMap['foo']"}, {@code myStruct[2]} or {@code + * myStruct['fieldName']}. + * + *

This class was copied over from Calcite 1.39.0 version to support access variant + * (FLINK-37924). + * + *

Line 148, CALCITE-7325, should be removed after upgrading Calcite to 1.42.0. + */ +public class SqlItemOperator extends SqlSpecialOperator { + public final int offset; + public final boolean safe; + + public SqlItemOperator( + String name, SqlSingleOperandTypeChecker operandTypeChecker, int offset, boolean safe) { + super(name, SqlKind.ITEM, 100, true, null, null, operandTypeChecker); + this.offset = offset; + this.safe = safe; + } + + @Override + public ReduceResult reduceExpr(int ordinal, TokenSequence list) { + SqlNode left = list.node(ordinal - 1); + SqlNode right = list.node(ordinal + 1); + return new ReduceResult( + ordinal - 1, + ordinal + 2, + createCall( + SqlParserPos.sum( + Arrays.asList( + requireNonNull(left, "left").getParserPosition(), + requireNonNull(right, "right").getParserPosition(), + list.pos(ordinal))), + left, + right)); + } + + @Override + public void unparse(SqlWriter writer, SqlCall call, int leftPrec, int rightPrec) { + call.operand(0).unparse(writer, leftPrec, 0); + final SqlWriter.Frame frame = writer.startList("[", "]"); + if (!this.getName().equals("ITEM")) { + final SqlWriter.Frame offsetFrame = writer.startFunCall(this.getName()); + call.operand(1).unparse(writer, 0, 0); + writer.endFunCall(offsetFrame); + } else { + call.operand(1).unparse(writer, 0, 0); + } + writer.endList(frame); + } + + @Override + public SqlOperandCountRange getOperandCountRange() { + return SqlOperandCountRanges.of(2); + } + + @Override + public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFailure) { + final SqlNode left = callBinding.operand(0); + final SqlNode right = callBinding.operand(1); + if (!getOperandTypeChecker().checkSingleOperandType(callBinding, left, 0, throwOnFailure)) { + return false; + } + final SqlSingleOperandTypeChecker checker = getChecker(callBinding); + return checker.checkSingleOperandType(callBinding, right, 0, throwOnFailure); + } + + @Override + public SqlSingleOperandTypeChecker getOperandTypeChecker() { + return (SqlSingleOperandTypeChecker) + requireNonNull(super.getOperandTypeChecker(), "operandTypeChecker"); + } + + private static SqlSingleOperandTypeChecker getChecker(SqlCallBinding callBinding) { + final RelDataType operandType = callBinding.getOperandType(0); + switch (operandType.getSqlTypeName()) { + case ARRAY: + return OperandTypes.family(SqlTypeFamily.INTEGER); + case MAP: + RelDataType keyType = + requireNonNull(operandType.getKeyType(), "operandType.getKeyType()"); + SqlTypeName sqlTypeName = keyType.getSqlTypeName(); + return OperandTypes.family( + requireNonNull( + sqlTypeName.getFamily(), + () -> + "keyType.getSqlTypeName().getFamily() null, type is " + + sqlTypeName)); + case ROW: + case ANY: + case DYNAMIC_STAR: + case VARIANT: + return OperandTypes.family(SqlTypeFamily.INTEGER) + .or(OperandTypes.family(SqlTypeFamily.CHARACTER)); + default: + throw callBinding.newValidationSignatureError(); + } + } + + @Override + public String getAllowedSignatures(String name) { + if (name.equals("ITEM")) { + return "[]\n" + + "[]\n" + + "[|]\n" + + "[|]"; + } else { + return "[" + name + "()]"; + } + } + + @Override + public RelDataType inferReturnType(SqlOperatorBinding opBinding) { + final RelDataTypeFactory typeFactory = opBinding.getTypeFactory(); + final RelDataType operandType = opBinding.getOperandType(0); + switch (operandType.getSqlTypeName()) { + case VARIANT: + // Return type is always nullable VARIANT + return typeFactory.createTypeWithNullability(operandType, true); + case ARRAY: + return typeFactory.createTypeWithNullability( + getComponentTypeOrThrow(operandType), true); + case MAP: + return typeFactory.createTypeWithNullability( + requireNonNull( + operandType.getValueType(), + () -> "operandType.getValueType() is null for " + operandType), + true); + case ROW: + RelDataType fieldType; + RelDataType indexType = opBinding.getOperandType(1); + + if (SqlTypeUtil.isString(indexType)) { + final String fieldName = + getOperandLiteralValueOrThrow(opBinding, 1, String.class); + RelDataTypeField field = operandType.getField(fieldName, false, false); + if (field == null) { + throw new AssertionError( + "Cannot infer type of field '" + + fieldName + + "' within ROW type: " + + operandType); + } else { + fieldType = field.getType(); + } + } else if (SqlTypeUtil.isIntType(indexType)) { + Integer index = opBinding.getOperandLiteralValue(1, Integer.class); + if (index == null || index < 1 || index > operandType.getFieldCount()) { + throw new AssertionError( + "Cannot infer type of field at position " + + index + + " within ROW type: " + + operandType); + } else { + fieldType = + operandType.getFieldList().get(index - 1).getType(); // 1 indexed + } + } else { + throw new AssertionError( + "Unsupported field identifier type: '" + indexType + "'"); + } + if (operandType.isNullable()) { + fieldType = typeFactory.createTypeWithNullability(fieldType, true); + } + return fieldType; + case ANY: + case DYNAMIC_STAR: + return typeFactory.createTypeWithNullability( + typeFactory.createSqlType(SqlTypeName.ANY), true); + default: + throw new AssertionError(); + } + } +} diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java index cea16b08a8e66..02f9b45dec4ca 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java @@ -2442,7 +2442,8 @@ ANY, and(logical(LogicalTypeRoot.BOOLEAN), LITERAL) sequence( or( logical(LogicalTypeRoot.ARRAY), - logical(LogicalTypeRoot.MAP)), + logical(LogicalTypeRoot.MAP), + logical(LogicalTypeRoot.VARIANT)), InputTypeStrategies.ITEM_AT_INDEX)) .outputTypeStrategy(SpecificTypeStrategies.ITEM_AT) .build(); diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/ItemAtIndexArgumentTypeStrategy.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/ItemAtIndexArgumentTypeStrategy.java index 82b360ccc4354..d35a92ab9ebad 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/ItemAtIndexArgumentTypeStrategy.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/ItemAtIndexArgumentTypeStrategy.java @@ -41,7 +41,10 @@ * or {@link LogicalTypeRoot#MULTISET} * *

the type to be equal to the key type of {@link LogicalTypeRoot#MAP} if the first argument is a - * map. + * map + * + *

a {@link LogicalTypeFamily#NUMERIC} type or {LogicalTypeFamily.CHARACTER_STRING} type if the + * first argument is a {@link LogicalTypeRoot#VARIANT}. */ @Internal public final class ItemAtIndexArgumentTypeStrategy implements ArgumentTypeStrategy { @@ -86,12 +89,36 @@ public Optional inferArgumentType( } } + if (collectionType.is(LogicalTypeRoot.VARIANT)) { + if (indexType.getLogicalType().is(LogicalTypeFamily.INTEGER_NUMERIC)) { + + if (callContext.isArgumentLiteral(1)) { + Optional literalVal = callContext.getArgumentValue(1, Integer.class); + if (literalVal.isPresent() && literalVal.get() <= 0) { + return callContext.fail( + throwOnFailure, + "The provided index must be a valid SQL index starting from 1, but was '%s'", + literalVal.get()); + } + } + + return Optional.of(indexType); + } else if (indexType.getLogicalType().is(LogicalTypeFamily.CHARACTER_STRING)) { + return Optional.of(indexType); + } else { + return callContext.fail( + throwOnFailure, + "Incorrect type %s supplied for the variant value. Variant values can only be accessed with a CHARACTER STRING map key or an INTEGER NUMERIC array index.", + indexType.getLogicalType().toString()); + } + } + return Optional.empty(); } @Override public Signature.Argument getExpectedArgument( FunctionDefinition functionDefinition, int argumentPos) { - return Signature.Argument.of("[ | ]"); + return Signature.Argument.of("[ | | ]"); } } diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/ItemAtTypeStrategy.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/ItemAtTypeStrategy.java index dc5a3ddebf8f4..ba781ee4875bb 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/ItemAtTypeStrategy.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/ItemAtTypeStrategy.java @@ -20,6 +20,7 @@ import org.apache.flink.annotation.Internal; import org.apache.flink.table.functions.BuiltInFunctionDefinitions; +import org.apache.flink.table.types.AtomicDataType; import org.apache.flink.table.types.CollectionDataType; import org.apache.flink.table.types.DataType; import org.apache.flink.table.types.KeyValueDataType; @@ -33,27 +34,30 @@ /** * An output type strategy for {@link BuiltInFunctionDefinitions#AT}. * - *

Returns either the element of an {@link LogicalTypeFamily#COLLECTION} type or the value of - * {@link LogicalTypeRoot#MAP}. + *

Returns either the element of an {@link LogicalTypeFamily#COLLECTION} type, the value of + * {@link LogicalTypeRoot#MAP}, or another {@link LogicalTypeRoot#VARIANT} value obtained by + * accessing the input {@link LogicalTypeRoot#VARIANT}. */ @Internal public final class ItemAtTypeStrategy implements TypeStrategy { @Override public Optional inferType(CallContext callContext) { - DataType arrayOrMapType = callContext.getArgumentDataTypes().get(0); + DataType containerType = callContext.getArgumentDataTypes().get(0); final Optional legacyArrayElement = - StrategyUtils.extractLegacyArrayElement(arrayOrMapType); + StrategyUtils.extractLegacyArrayElement(containerType); if (legacyArrayElement.isPresent()) { return legacyArrayElement; } - if (arrayOrMapType.getLogicalType().is(LogicalTypeRoot.ARRAY)) { + if (containerType.getLogicalType().is(LogicalTypeRoot.ARRAY)) { return Optional.of( - ((CollectionDataType) arrayOrMapType).getElementDataType().nullable()); - } else if (arrayOrMapType instanceof KeyValueDataType) { - return Optional.of(((KeyValueDataType) arrayOrMapType).getValueDataType().nullable()); + ((CollectionDataType) containerType).getElementDataType().nullable()); + } else if (containerType instanceof KeyValueDataType) { + return Optional.of(((KeyValueDataType) containerType).getValueDataType().nullable()); + } else if (containerType.getLogicalType().is(LogicalTypeRoot.VARIANT)) { + return Optional.of(((AtomicDataType) containerType).nullable()); } return Optional.empty(); diff --git a/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/inference/ItemAtIndexArgumentTypeStrategyTest.java b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/inference/ItemAtIndexArgumentTypeStrategyTest.java index d38ebd011f347..1a5bf23b7a0f6 100644 --- a/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/inference/ItemAtIndexArgumentTypeStrategyTest.java +++ b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/inference/ItemAtIndexArgumentTypeStrategyTest.java @@ -42,7 +42,7 @@ protected Stream testData() { DataTypes.ARRAY(DataTypes.STRING().notNull()), DataTypes.SMALLINT().notNull()) .expectSignature( - "f([ | ], [ | ])") + "f([ | | ], [ | | ])") .expectArgumentTypes( DataTypes.ARRAY(DataTypes.STRING().notNull()), DataTypes.SMALLINT().notNull()), @@ -58,7 +58,7 @@ protected Stream testData() { DataTypes.MAP(DataTypes.BIGINT(), DataTypes.STRING().notNull()), DataTypes.SMALLINT()) .expectSignature( - "f([ | ], [ | ])") + "f([ | | ], [ | | ])") .expectArgumentTypes( DataTypes.MAP(DataTypes.BIGINT(), DataTypes.STRING().notNull()), DataTypes.BIGINT()), @@ -67,11 +67,36 @@ protected Stream testData() { DataTypes.MAP(DataTypes.BIGINT(), DataTypes.STRING().notNull()), DataTypes.STRING()) .expectErrorMessage("Expected index for a MAP to be of type: BIGINT"), - TestSpec.forStrategy("Validate incorrect index", ITEM_AT_INPUT_STRATEGY) + TestSpec.forStrategy( + "Validate incorrect index for an array", ITEM_AT_INPUT_STRATEGY) .calledWithArgumentTypes( DataTypes.ARRAY(DataTypes.BIGINT()), DataTypes.INT().notNull()) .calledWithLiteralAt(1, 0) .expectErrorMessage( - "The provided index must be a valid SQL index starting from 1, but was '0'")); + "The provided index must be a valid SQL index starting from 1, but was '0'"), + TestSpec.forStrategy("Validate integer index for a variant", ITEM_AT_INPUT_STRATEGY) + .calledWithArgumentTypes( + DataTypes.VARIANT(), DataTypes.SMALLINT().notNull()) + .expectSignature( + "f([ | | ], [ | | ])") + .expectArgumentTypes(DataTypes.VARIANT(), DataTypes.SMALLINT().notNull()), + TestSpec.forStrategy( + "Validate incorrect index for a variant", ITEM_AT_INPUT_STRATEGY) + .calledWithArgumentTypes( + DataTypes.VARIANT(), DataTypes.SMALLINT().notNull()) + .calledWithLiteralAt(1, 0) + .expectErrorMessage( + "The provided index must be a valid SQL index starting from 1, but was '0'"), + TestSpec.forStrategy("Validate string key for a variant", ITEM_AT_INPUT_STRATEGY) + .calledWithArgumentTypes(DataTypes.VARIANT(), DataTypes.STRING().notNull()) + .expectSignature( + "f([ | | ], [ | | ])") + .expectArgumentTypes(DataTypes.VARIANT(), DataTypes.STRING().notNull()), + TestSpec.forStrategy( + "Validate incorrect variant key for a variant", + ITEM_AT_INPUT_STRATEGY) + .calledWithArgumentTypes(DataTypes.VARIANT(), DataTypes.DOUBLE().notNull()) + .expectErrorMessage( + "Incorrect type DOUBLE NOT NULL supplied for the variant value. Variant values can only be accessed with a CHARACTER STRING map key or an INTEGER NUMERIC array index.")); } } diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ExprCodeGenerator.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ExprCodeGenerator.scala index c625bfd89b34f..7154aa09f0a6e 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ExprCodeGenerator.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ExprCodeGenerator.scala @@ -743,7 +743,11 @@ class ExprCodeGenerator(ctx: CodeGeneratorContext, nullableInput: Boolean) case LogicalTypeRoot.ROW | LogicalTypeRoot.STRUCTURED_TYPE => generateDot(ctx, operands) - case _ => throw new CodeGenException("Expect an array, a map or a row.") + case LogicalTypeRoot.VARIANT => + val key = operands(1) + generateVariantGet(ctx, operands.head, key) + + case _ => throw new CodeGenException("Expect an array, a map, a row or a variant.") } case CARDINALITY => diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/ScalarOperatorGens.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/ScalarOperatorGens.scala index 249e73fc4ca12..fe6b1d960475f 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/ScalarOperatorGens.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/ScalarOperatorGens.scala @@ -1668,6 +1668,100 @@ object ScalarOperatorGens { generateUnaryOperatorIfNotNull(ctx, resultType, map)(_ => s"${map.resultTerm}.size()") } + def generateVariantGet( + ctx: CodeGeneratorContext, + variant: GeneratedExpression, + key: GeneratedExpression): GeneratedExpression = { + val Seq(resultTerm, nullTerm) = newNames(ctx, "result", "isNull") + val tmpValue = newName(ctx, "tmpValue") + + val variantType = variant.resultType.asInstanceOf[VariantType] + val variantTerm = variant.resultTerm + val variantTypeTerm = primitiveTypeTermForType(variantType) + val variantDefaultTerm = primitiveDefaultValue(variantType) + + val keyTerm = key.resultTerm + val keyType = key.resultType + + val accessCode = if (isInteger(keyType)) { + generateIntegerKeyAccess( + variantTerm, + variantTypeTerm, + keyTerm, + resultTerm, + nullTerm, + tmpValue + ) + } else if (isCharacterString(keyType)) { + val fieldName = key.literalValue.get.toString + generateCharacterStringKeyAccess( + variantTerm, + variantTypeTerm, + fieldName, + resultTerm, + nullTerm, + tmpValue + ) + } else { + throw new CodeGenException(s"Unsupported key type for variant: $keyType") + } + + val finalCode = + s""" + |${variant.code} + |${key.code} + |boolean $nullTerm = (${variant.nullTerm} || ${key.nullTerm}); + |$variantTypeTerm $resultTerm = $variantDefaultTerm; + |if (!$nullTerm) { + | $accessCode + |} + """.stripMargin + + GeneratedExpression(resultTerm, nullTerm, finalCode, variantType) + } + + private def generateCharacterStringKeyAccess( + variantTerm: String, + variantTypeTerm: String, + fieldName: String, + resultTerm: String, + nullTerm: String, + tmpValue: String): String = { + s""" + | if ($variantTerm.isObject()){ + | $variantTypeTerm $tmpValue = $variantTerm.getField("$fieldName"); + | if ($tmpValue == null) { + | $nullTerm = true; + | } else { + | $resultTerm = $tmpValue; + | } + | } else { + | throw new org.apache.flink.table.api.TableRuntimeException("String key access on variant requires an object variant, but a non-object variant was provided."); + | } + """.stripMargin + } + + private def generateIntegerKeyAccess( + variantTerm: String, + variantTypeTerm: String, + keyTerm: String, + resultTerm: String, + nullTerm: String, + tmpValue: String): String = { + s""" + | if ($variantTerm.isArray()){ + | $variantTypeTerm $tmpValue = $variantTerm.getElement($keyTerm - 1); + | if ($tmpValue == null) { + | $nullTerm = true; + | } else { + | $resultTerm = $tmpValue; + | } + | } else { + | throw new org.apache.flink.table.api.TableRuntimeException("Integer index access on variant requires an array variant, but a non-array variant was provided."); + | } + """.stripMargin + } + // ---------------------------------------------------------------------------------------- // private generate utils // ---------------------------------------------------------------------------------------- diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/VariantSemanticTest.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/VariantSemanticTest.java index 6815e7b2528fd..7b7b39a1fc380 100644 --- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/VariantSemanticTest.java +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/VariantSemanticTest.java @@ -20,6 +20,7 @@ import org.apache.flink.table.api.DataTypes; import org.apache.flink.table.api.TableRuntimeException; +import org.apache.flink.table.api.ValidationException; import org.apache.flink.table.functions.AggregateFunction; import org.apache.flink.table.functions.ScalarFunction; import org.apache.flink.table.planner.plan.nodes.exec.testutils.SemanticTestBase; @@ -33,7 +34,10 @@ import java.util.List; +import static org.apache.flink.table.api.Expressions.$; + /** Semantic tests for {@link DataTypes#VARIANT()} type. */ +@SuppressWarnings("checkstyle:LocalFinalVariableName") public class VariantSemanticTest extends SemanticTestBase { static final VariantBuilder BUILDER = Variant.newBuilder(); @@ -253,6 +257,195 @@ public class VariantSemanticTest extends SemanticTestBase { .runSql("INSERT INTO sink_t SELECT k, SUM(v) AS total FROM t GROUP BY k") .build(); + static final TableTestProgram VARIANT_ARRAY_ACCESS; + + static final TableTestProgram VARIANT_OBJECT_ACCESS; + + static final TableTestProgram VARIANT_NESTED_ACCESS; + + static final TableTestProgram VARIANT_ARRAY_ERROR_ACCESS; + + static final TableTestProgram VARIANT_OBJECT_ERROR_ACCESS; + + public static final SourceTestStep VARIANT_ARRAY_SOURCE = + SourceTestStep.newBuilder("t") + .addSchema("v VARIANT") + .producedValues( + Row.of( + BUILDER.array() + .add(BUILDER.of(1)) + .add(BUILDER.of("hello")) + .add(BUILDER.of(3.14)) + .build()), + Row.of( + BUILDER.array() + .add(BUILDER.of(10)) + .add(BUILDER.of("world")) + .build()), + new Row(1)) + .build(); + + public static final SinkTestStep VARIANT_ARRAY_SINK = + SinkTestStep.newBuilder("sink_t") + .addSchema("v1 VARIANT", "v2 VARIANT", "v3 VARIANT") + .consumedValues( + Row.of(BUILDER.of(1), BUILDER.of("hello"), BUILDER.of(3.14)), + Row.of(BUILDER.of(10), BUILDER.of("world"), null), + new Row(3)) + .build(); + + public static final SourceTestStep VARIANT_OBJECT_SOURCE = + SourceTestStep.newBuilder("t") + .addSchema("v VARIANT") + .producedValues( + Row.of( + BUILDER.object() + .add("name", BUILDER.of("Alice")) + .add("age", BUILDER.of(30)) + .add("city", BUILDER.of("NYC")) + .build()), + Row.of( + BUILDER.object() + .add("name", BUILDER.of("Bob")) + .add("age", BUILDER.of(25)) + .build()), + new Row(1)) + .build(); + + public static final SinkTestStep VARIANT_OBJECT_SINK = + SinkTestStep.newBuilder("sink_t") + .addSchema("name VARIANT", "age VARIANT", "city VARIANT") + .consumedValues( + Row.of(BUILDER.of("Alice"), BUILDER.of(30), BUILDER.of("NYC")), + Row.of(BUILDER.of("Bob"), BUILDER.of(25), null), + new Row(3)) + .build(); + + public static final SourceTestStep VARIANT_NESTED_SOURCE = + SourceTestStep.newBuilder("t") + .addSchema("v VARIANT") + .producedValues( + Row.of( + BUILDER.object() + .add( + "users", + BUILDER.array() + .add( + BUILDER.object() + .add( + "id", + BUILDER.of(1)) + .add( + "name", + BUILDER.of( + "Alice")) + .build()) + .build()) + .build()), + new Row(1)) + .build(); + + public static final SinkTestStep VARIANT_NESTED_SINK = + SinkTestStep.newBuilder("sink_t") + .addSchema("user_id VARIANT", "user_name VARIANT") + .consumedValues(Row.of(BUILDER.of(1), BUILDER.of("Alice")), new Row(2)) + .build(); + + static { + VARIANT_ARRAY_ACCESS = + TableTestProgram.of( + "variant-array-access", + "validates variant array access using [] operator in sql and at() in table api") + .setupTableSource(VARIANT_ARRAY_SOURCE) + .setupTableSink(VARIANT_ARRAY_SINK) + .runSql("INSERT INTO sink_t SELECT v[1], v[2], v[3] FROM t") + .runTableApi( + t -> + t.from("t") + .select( + $("v").at(1).as("v1"), + $("v").at(2).as("v2"), + $("v").at(3).as("v3")), + "sink_t") + .build(); + + VARIANT_OBJECT_ACCESS = + TableTestProgram.of( + "variant-object-access", + "validates variant object field access using [] operator in sql and at() in table api") + .setupTableSource(VARIANT_OBJECT_SOURCE) + .setupTableSink(VARIANT_OBJECT_SINK) + .runSql("INSERT INTO sink_t SELECT v['name'], v['age'], v['city'] FROM t") + .runTableApi( + t -> + t.from("t") + .select( + $("v").at("name").as("name"), + $("v").at("age").as("age"), + $("v").at("city").as("city")), + "sink_t") + .build(); + + VARIANT_NESTED_ACCESS = + TableTestProgram.of( + "variant-nested-access", + "validates variant nested access using [] operator in sql and at() in table api") + .setupTableSource(VARIANT_NESTED_SOURCE) + .setupTableSink(VARIANT_NESTED_SINK) + .runSql( + "INSERT INTO sink_t SELECT v['users'][1]['id'], v['users'][1]['name'] FROM t") + .runTableApi( + t -> + t.from("t") + .select( + $("v").at("users") + .at(1) + .at("id") + .as("user_id"), + $("v").at("users") + .at(1) + .at("name") + .as("user_name")), + "sink_t") + .build(); + + VARIANT_ARRAY_ERROR_ACCESS = + TableTestProgram.of( + "variant-array-error-access", + "validates variant array access using [] operator in sql and at() in table api with string") + .setupTableSource(VARIANT_ARRAY_SOURCE) + .runFailingSql( + "SELECT v['1'], v['2'], v['3'] FROM t", + TableRuntimeException.class, + "String key access on variant requires an object variant, but a non-object variant was provided.") + .runFailingSql( + "SELECT v[1.5], v[4.2], v[3.3] FROM t", + ValidationException.class, + "Cannot apply 'ITEM' to arguments of type 'ITEM(, )'. Supported form(s): []\n" + + "[]\n" + + "[|]\n" + + "[|]") + .build(); + + VARIANT_OBJECT_ERROR_ACCESS = + TableTestProgram.of( + "variant-object-error-access", + "validates variant object field access using [] operator in sql and at() in table api") + .setupTableSource(VARIANT_OBJECT_SOURCE) + .runFailingSql( + "SELECT v[1], v[2], v[3] FROM t", + TableRuntimeException.class, + "Integer index access on variant requires an array variant, but a non-array variant was provided.") + .runFailingSql( + "SELECT v[1.5], v[4.2], v[3.3] FROM t", + ValidationException.class, + "Cannot apply 'ITEM' to arguments of type 'ITEM(, )'. Supported form(s): []\n" + + "[]\n" + + "[|]\n" + + "[|]") + .build(); + } + @Override public List programs() { return List.of( @@ -264,7 +457,12 @@ public List programs() { BUILTIN_AGG_WITH_RETRACTION, VARIANT_AS_UDF_ARG, VARIANT_AS_UDAF_ARG, - VARIANT_AS_AGG_KEY); + VARIANT_AS_AGG_KEY, + VARIANT_ARRAY_ACCESS, + VARIANT_OBJECT_ACCESS, + VARIANT_NESTED_ACCESS, + VARIANT_ARRAY_ERROR_ACCESS, + VARIANT_OBJECT_ERROR_ACCESS); } public static class MyUdf extends ScalarFunction {