From 10c39922e8f2c1070986ac05adec24ec91306e81 Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Fri, 23 Jan 2026 13:22:40 +0000 Subject: [PATCH] POC: JsonFieldName support --- .bazelversion | 1 + .../src/test/java/dev/cel/bundle/BUILD.bazel | 1 + .../test/java/dev/cel/bundle/CelImplTest.java | 72 +++++++++ .../src/main/java/dev/cel/checker/BUILD.bazel | 1 + .../dev/cel/checker/CelCheckerLegacyImpl.java | 10 +- .../java/dev/cel/checker/ExprChecker.java | 17 +- .../main/java/dev/cel/common/CelOptions.java | 17 +- .../main/java/dev/cel/common/CelSource.java | 10 +- .../cel/common/types/ProtoMessageType.java | 26 ++- .../types/ProtoMessageTypeProvider.java | 148 ++++++++++++++---- .../java/dev/cel/common/types/BUILD.bazel | 1 + .../types/ProtoMessageTypeProviderTest.java | 20 +++ .../common/types/ProtoMessageTypeTest.java | 3 +- .../src/main/java/dev/cel/runtime/BUILD.bazel | 1 + .../runtime/DescriptorMessageProvider.java | 9 ++ .../src/test/java/dev/cel/runtime/BUILD.bazel | 6 + .../cel/runtime/CelLiteInterpreterTest.java | 26 +-- .../cel/runtime/CelValueInterpreterTest.java | 14 +- .../java/dev/cel/runtime/InterpreterTest.java | 9 +- .../src/main/java/dev/cel/testing/BUILD.bazel | 4 + .../dev/cel/testing/BaseInterpreterTest.java | 76 ++++++++- .../dev/cel/testing/BaselineTestCase.java | 3 +- .../test/resources/protos/single_file.proto | 1 + 23 files changed, 411 insertions(+), 65 deletions(-) create mode 100644 .bazelversion diff --git a/.bazelversion b/.bazelversion new file mode 100644 index 000000000..6d2890793 --- /dev/null +++ b/.bazelversion @@ -0,0 +1 @@ +8.5.0 diff --git a/bundle/src/test/java/dev/cel/bundle/BUILD.bazel b/bundle/src/test/java/dev/cel/bundle/BUILD.bazel index 26cbe392d..bc40f5014 100644 --- a/bundle/src/test/java/dev/cel/bundle/BUILD.bazel +++ b/bundle/src/test/java/dev/cel/bundle/BUILD.bazel @@ -67,6 +67,7 @@ java_library( "@maven//:com_google_truth_extensions_truth_proto_extension", "@maven//:junit_junit", "@maven//:org_jspecify_jspecify", + "//testing/protos:single_file_java_proto", ], ) diff --git a/bundle/src/test/java/dev/cel/bundle/CelImplTest.java b/bundle/src/test/java/dev/cel/bundle/CelImplTest.java index 13c3392d3..41c26688d 100644 --- a/bundle/src/test/java/dev/cel/bundle/CelImplTest.java +++ b/bundle/src/test/java/dev/cel/bundle/CelImplTest.java @@ -19,10 +19,13 @@ import static dev.cel.common.CelFunctionDecl.newFunctionDeclaration; import static dev.cel.common.CelOverloadDecl.newGlobalOverload; import static dev.cel.common.CelOverloadDecl.newMemberOverload; +import static dev.cel.common.CelSource.Extension; import static org.junit.Assert.assertThrows; +import dev.cel.common.CelValidationException; import dev.cel.expr.CheckedExpr; import dev.cel.expr.Constant; +import dev.cel.runtime.CelEvaluationException; import dev.cel.expr.Decl; import dev.cel.expr.Decl.FunctionDecl; import dev.cel.expr.Decl.FunctionDecl.Overload; @@ -112,6 +115,7 @@ import dev.cel.runtime.CelUnknownSet; import dev.cel.runtime.CelVariableResolver; import dev.cel.runtime.UnknownContext; +import dev.cel.testing.testdata.SingleFileProto.SingleFile; import dev.cel.testing.testdata.proto3.StandaloneGlobalEnum; import java.time.Instant; import java.util.ArrayList; @@ -2193,6 +2197,74 @@ public void toBuilder_isImmutable() { assertThat(newRuntimeBuilder).isNotEqualTo(celImpl.toRuntimeBuilder()); } + @Test + public void eval_withJsonFieldName() throws Exception { + Cel cel = + standardCelBuilderWithMacros() + .addVar("file", StructTypeReference.create(SingleFile.getDescriptor().getFullName())) + .addMessageTypes(SingleFile.getDescriptor()) + .setOptions(CelOptions.current().enableJsonFieldNames(true).build()) + .build(); + CelAbstractSyntaxTree ast = + cel.compile("file.camelCased").getAst(); + + Object result = cel.createProgram(ast).eval(ImmutableMap.of("file", SingleFile.newBuilder().setSnakeCased("foo").build())); + + assertThat(result).isEqualTo("foo"); + } + + @Test + public void eval_withJsonFieldName_runtimeOptionDisabled_throws() throws Exception { + CelCompiler celCompiler = + CelCompilerFactory.standardCelCompilerBuilder() + .addVar("file", StructTypeReference.create(SingleFile.getDescriptor().getFullName())) + .addMessageTypes(SingleFile.getDescriptor()) + .setOptions(CelOptions.current().enableJsonFieldNames(true).build()) + .build(); + CelRuntime celRuntime = + CelRuntimeFactory.standardCelRuntimeBuilder() + .addMessageTypes(SingleFile.getDescriptor()) + .setOptions(CelOptions.current().enableJsonFieldNames(false).build()) + .build(); + CelAbstractSyntaxTree ast = celCompiler.compile("file.camelCased").getAst(); + + CelEvaluationException e = + assertThrows( + CelEvaluationException.class, + () -> celRuntime.createProgram(ast).eval(ImmutableMap.of("file", SingleFile.getDefaultInstance()))); + assertThat(e).hasMessageThat().contains("field 'camelCased' is not declared in message 'dev.cel.testing.testdata.SingleFile"); + } + + @Test + public void compile_withJsonFieldName_astTagged() throws Exception { + Cel cel = + standardCelBuilderWithMacros() + .addVar("file", StructTypeReference.create(SingleFile.getDescriptor().getFullName())) + .addMessageTypes(SingleFile.getDescriptor()) + .setOptions(CelOptions.current().enableJsonFieldNames(true).build()) + .build(); + CelAbstractSyntaxTree ast = + cel.compile("file.camelCased").getAst(); + + assertThat(ast.getSource().getExtensions()).contains(Extension.create("json_name", Extension.Version.of(1L, 1L), Extension.Component.COMPONENT_RUNTIME)); + } + + @Test + public void compile_withJsonFieldName_protoFieldNameComparison_throws() throws Exception { + Cel cel = + standardCelBuilderWithMacros() + .addVar("file", StructTypeReference.create(SingleFile.getDescriptor().getFullName())) + .addMessageTypes(SingleFile.getDescriptor()) + .setOptions(CelOptions.current().enableJsonFieldNames(true).build()) + .build(); + + CelValidationException e = + assertThrows( + CelValidationException.class, + () -> cel.compile("file.camelCased == file.snake_cased").getAst()); + assertThat(e).hasMessageThat().contains("undefined field 'snake_cased'"); + } + private static TypeProvider aliasingProvider(ImmutableMap typeAliases) { return new TypeProvider() { @Override diff --git a/checker/src/main/java/dev/cel/checker/BUILD.bazel b/checker/src/main/java/dev/cel/checker/BUILD.bazel index 6c486bd92..91a409a00 100644 --- a/checker/src/main/java/dev/cel/checker/BUILD.bazel +++ b/checker/src/main/java/dev/cel/checker/BUILD.bazel @@ -177,6 +177,7 @@ java_library( ":standard_decl", "//:auto_value", "//common:cel_ast", + "//common:cel_source", "//common:compiler_common", "//common:container", "//common:operator", diff --git a/checker/src/main/java/dev/cel/checker/CelCheckerLegacyImpl.java b/checker/src/main/java/dev/cel/checker/CelCheckerLegacyImpl.java index 41d1ca073..b8a490e65 100644 --- a/checker/src/main/java/dev/cel/checker/CelCheckerLegacyImpl.java +++ b/checker/src/main/java/dev/cel/checker/CelCheckerLegacyImpl.java @@ -456,9 +456,13 @@ public CelCheckerLegacyImpl build() { } CelTypeProvider messageTypeProvider = - new ProtoMessageTypeProvider( - CelDescriptorUtil.getAllDescriptorsFromFileDescriptor( - fileTypeSet, celOptions.resolveTypeDependencies())); + ProtoMessageTypeProvider.newBuilder() + .setAllowJsonFieldNames(celOptions.enableJsonFieldNames()) + .setCelDescriptors( + CelDescriptorUtil.getAllDescriptorsFromFileDescriptor( + fileTypeSet, celOptions.resolveTypeDependencies())) + .build(); + if (celTypeProvider != null && fileTypeSet.isEmpty()) { messageTypeProvider = celTypeProvider; } else if (celTypeProvider != null) { diff --git a/checker/src/main/java/dev/cel/checker/ExprChecker.java b/checker/src/main/java/dev/cel/checker/ExprChecker.java index 37b692ecf..348428279 100644 --- a/checker/src/main/java/dev/cel/checker/ExprChecker.java +++ b/checker/src/main/java/dev/cel/checker/ExprChecker.java @@ -31,6 +31,7 @@ import dev.cel.common.CelFunctionDecl; import dev.cel.common.CelOverloadDecl; import dev.cel.common.CelProtoAbstractSyntaxTree; +import dev.cel.common.CelSource; import dev.cel.common.Operator; import dev.cel.common.annotations.Internal; import dev.cel.common.ast.CelConstant; @@ -44,9 +45,11 @@ import dev.cel.common.types.MapType; import dev.cel.common.types.OptionalType; import dev.cel.common.types.SimpleType; +import dev.cel.common.types.ProtoMessageType; import dev.cel.common.types.TypeType; import java.util.ArrayList; import java.util.HashSet; +import java.util.Set; import java.util.List; import java.util.Map; import org.jspecify.annotations.Nullable; @@ -139,7 +142,7 @@ public static CelAbstractSyntaxTree typecheck( Map typeMap = Maps.transformValues(env.getTypeMap(), checker.inferenceContext::finalize); - return CelAbstractSyntaxTree.newCheckedAst(expr, ast.getSource(), env.getRefMap(), typeMap); + return CelAbstractSyntaxTree.newCheckedAst(expr, ast.getSource().toBuilder().addAllExtensions(checker.extensions).build(), env.getRefMap(), typeMap); } private final Env env; @@ -150,6 +153,7 @@ public static CelAbstractSyntaxTree typecheck( private final boolean compileTimeOverloadResolution; private final boolean homogeneousLiterals; private final boolean namespacedDeclarations; + private final Set extensions = new HashSet<>(); private ExprChecker( Env env, @@ -729,6 +733,17 @@ private CelType visitSelectField( if (operandType.kind() == CelKind.STRUCT) { TypeProvider.FieldType fieldType = getFieldType(expr.id(), getPosition(expr), operandType, field); + CelType typeDef = operandType; + if (!(typeDef instanceof ProtoMessageType) && typeProvider.lookupCelType(operandType.name()).isPresent()) { + typeDef = typeProvider.lookupCelType(operandType.name()).get(); + if (typeDef.kind() == CelKind.TYPE && !typeDef.parameters().isEmpty()) { + typeDef = typeDef.parameters().get(0); + } + } + + if (typeDef instanceof ProtoMessageType && ((ProtoMessageType) typeDef).isJsonName(field)) { + extensions.add(CelSource.Extension.create("json_name", CelSource.Extension.Version.of(1, 1), CelSource.Extension.Component.COMPONENT_RUNTIME)); + } // Type of the field resultType = fieldType.celType(); } else if (operandType.kind() == CelKind.MAP) { diff --git a/common/src/main/java/dev/cel/common/CelOptions.java b/common/src/main/java/dev/cel/common/CelOptions.java index d39d53803..d48202ec9 100644 --- a/common/src/main/java/dev/cel/common/CelOptions.java +++ b/common/src/main/java/dev/cel/common/CelOptions.java @@ -83,6 +83,8 @@ public enum ProtoUnsetFieldOptions { public abstract boolean enableNamespacedDeclarations(); + public abstract boolean enableJsonFieldNames(); + // Evaluation related options public abstract boolean disableCelStandardEquality(); @@ -150,6 +152,7 @@ public static Builder newBuilder() { .enableTimestampEpoch(false) .enableHeterogeneousNumericComparisons(false) .enableNamespacedDeclarations(true) + .enableJsonFieldNames(false) // Evaluation options .disableCelStandardEquality(true) .evaluateCanonicalTypesToNativeValues(false) @@ -170,7 +173,8 @@ public static Builder newBuilder() { .enableStringConcatenation(true) .enableListConcatenation(true) .enableComprehension(true) - .maxRegexProgramSize(-1); + .maxRegexProgramSize(-1) + ; } /** @@ -529,6 +533,17 @@ public abstract static class Builder { */ public abstract Builder maxRegexProgramSize(int value); + /** + * Use the `json_name` field option on a protobuf message as the name of the field. + * + *

If enabled, the compiler will only accept the `json_name` and no longer recognize + * the original protobuf field name. Use with caution as this may break existing expressions + * during compilation. The runtime continues to support both names to facilitate evaluation of + * legacy ASTs. + */ + public abstract Builder enableJsonFieldNames(boolean value); + + public abstract CelOptions build(); } } diff --git a/common/src/main/java/dev/cel/common/CelSource.java b/common/src/main/java/dev/cel/common/CelSource.java index 2678a0a2c..5e77d2ebf 100644 --- a/common/src/main/java/dev/cel/common/CelSource.java +++ b/common/src/main/java/dev/cel/common/CelSource.java @@ -320,19 +320,19 @@ public CelSource build() { public abstract static class Extension { /** Identifier for the extension. Example: constant_folding */ - abstract String id(); + public abstract String id(); /** * Version info. May be skipped if it isn't meaningful for the extension. (for example * constant_folding might always be v0.0). */ - abstract Version version(); + public abstract Version version(); /** * If set, the listed components must understand the extension for the expression to evaluate * correctly. */ - abstract ImmutableList affectedComponents(); + public abstract ImmutableList affectedComponents(); /** Version of the extension */ @AutoValue @@ -343,13 +343,13 @@ public abstract static class Version { * Major version changes indicate different required support level from the required * components. */ - abstract long major(); + public abstract long major(); /** * Minor version changes must not change the observed behavior from existing implementations, * but may be provided informational. */ - abstract long minor(); + public abstract long minor(); /** Create a new instance of Version with the provided major and minor values. */ public static Version of(long major, long minor) { diff --git a/common/src/main/java/dev/cel/common/types/ProtoMessageType.java b/common/src/main/java/dev/cel/common/types/ProtoMessageType.java index 7ac3f4fd3..8cc1caee3 100644 --- a/common/src/main/java/dev/cel/common/types/ProtoMessageType.java +++ b/common/src/main/java/dev/cel/common/types/ProtoMessageType.java @@ -29,14 +29,17 @@ public final class ProtoMessageType extends StructType { private final StructType.FieldResolver extensionResolver; + private final JsonNameResolver jsonNameResolver; ProtoMessageType( String name, ImmutableSet fieldNames, StructType.FieldResolver fieldResolver, - StructType.FieldResolver extensionResolver) { + StructType.FieldResolver extensionResolver, + JsonNameResolver jsonNameResolver) { super(name, fieldNames, fieldResolver); this.extensionResolver = extensionResolver; + this.jsonNameResolver = jsonNameResolver; } /** Find an {@code Extension} by its fully-qualified {@code extensionName}. */ @@ -46,20 +49,35 @@ public Optional findExtension(String extensionName) { .map(type -> Extension.of(extensionName, type, this)); } + /** + * Returns true if the field name is a json name. + */ + public boolean isJsonName(String fieldName) { + return jsonNameResolver.isJsonName(fieldName); + } + /** * Create a new instance of the {@code ProtoMessageType} using the {@code visibleFields} set as a * mask of the fields from the backing proto. */ public ProtoMessageType withVisibleFields(ImmutableSet visibleFields) { - return new ProtoMessageType(name, visibleFields, fieldResolver, extensionResolver); + return new ProtoMessageType(name, visibleFields, fieldResolver, extensionResolver, jsonNameResolver); } public static ProtoMessageType create( String name, ImmutableSet fieldNames, FieldResolver fieldResolver, - FieldResolver extensionResolver) { - return new ProtoMessageType(name, fieldNames, fieldResolver, extensionResolver); + FieldResolver extensionResolver, + JsonNameResolver jsonNameResolver) { + return new ProtoMessageType(name, fieldNames, fieldResolver, extensionResolver, jsonNameResolver); + } + + /** Functional interface for resolving whether a field name is a json name. */ + @FunctionalInterface + @Immutable + public interface JsonNameResolver { + boolean isJsonName(String fieldName); } /** {@code Extension} contains the name, type, and target message type of the extension. */ diff --git a/common/src/main/java/dev/cel/common/types/ProtoMessageTypeProvider.java b/common/src/main/java/dev/cel/common/types/ProtoMessageTypeProvider.java index 4b97178d0..672895b86 100644 --- a/common/src/main/java/dev/cel/common/types/ProtoMessageTypeProvider.java +++ b/common/src/main/java/dev/cel/common/types/ProtoMessageTypeProvider.java @@ -14,11 +14,9 @@ package dev.cel.common.types; -import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; import com.google.common.collect.ImmutableCollection; -import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableMultimap; import com.google.common.collect.ImmutableSet; @@ -34,11 +32,11 @@ import dev.cel.common.CelDescriptorUtil; import dev.cel.common.CelDescriptors; import dev.cel.common.internal.FileDescriptorSetConverter; +import java.util.Arrays; import java.util.Collection; import java.util.HashMap; import java.util.Map; import java.util.Optional; -import java.util.function.Function; /** * The {@code ProtoMessageTypeProvider} implements the {@link CelTypeProvider} interface to provide @@ -68,35 +66,35 @@ public final class ProtoMessageTypeProvider implements CelTypeProvider { .buildOrThrow(); private final ImmutableMap allTypes; + private final boolean allowJsonFieldNames; + /** Returns a new builder for {@link ProtoMessageTypeProvider}. */ + public static Builder newBuilder() { + return new Builder(); + } + + @Deprecated public ProtoMessageTypeProvider() { - this(CelDescriptors.builder().build()); + this(CelDescriptors.builder().build(), false); } + @Deprecated public ProtoMessageTypeProvider(FileDescriptorSet descriptorSet) { this( CelDescriptorUtil.getAllDescriptorsFromFileDescriptor( - FileDescriptorSetConverter.convert(descriptorSet))); + FileDescriptorSetConverter.convert(descriptorSet)), false); } + @Deprecated public ProtoMessageTypeProvider(Iterable descriptors) { this( CelDescriptorUtil.getAllDescriptorsFromFileDescriptor( - ImmutableSet.copyOf(Iterables.transform(descriptors, Descriptor::getFile)))); + ImmutableSet.copyOf(Iterables.transform(descriptors, Descriptor::getFile))), false); } + @Deprecated public ProtoMessageTypeProvider(ImmutableSet fileDescriptors) { - this(CelDescriptorUtil.getAllDescriptorsFromFileDescriptor(fileDescriptors)); - } - - public ProtoMessageTypeProvider(CelDescriptors celDescriptors) { - this.allTypes = - ImmutableMap.builder() - .putAll(createEnumTypes(celDescriptors.enumDescriptors())) - .putAll( - createProtoMessageTypes( - celDescriptors.messageTypeDescriptors(), celDescriptors.extensionDescriptors())) - .buildOrThrow(); + this(CelDescriptorUtil.getAllDescriptorsFromFileDescriptor(fileDescriptors), false); } @Override @@ -120,8 +118,17 @@ private ImmutableMap createProtoMessageTypes( if (protoMessageTypes.containsKey(descriptor.getFullName())) { continue; } - ImmutableList fieldNames = - descriptor.getFields().stream().map(FieldDescriptor::getName).collect(toImmutableList()); + + ImmutableSet.Builder fieldNamesBuilder = ImmutableSet.builder() ; + ImmutableSet.Builder jsonNamesBuilder = ImmutableSet.builder(); + for (FieldDescriptor fd : descriptor.getFields()) { + fieldNamesBuilder.add(fd.getName()); + if (allowJsonFieldNames) { + fieldNamesBuilder.add(fd.getJsonName()); + jsonNamesBuilder.add(fd.getJsonName()); + } + } + ImmutableSet jsonNames = jsonNamesBuilder.build(); Map extensionFields = new HashMap<>(); for (FieldDescriptor extension : extensionMap.get(descriptor.getFullName())) { @@ -133,9 +140,10 @@ private ImmutableMap createProtoMessageTypes( descriptor.getFullName(), ProtoMessageType.create( descriptor.getFullName(), - ImmutableSet.copyOf(fieldNames), + fieldNamesBuilder.build(), new FieldResolver(this, descriptor)::findField, - new FieldResolver(this, extensions)::findField)); + new FieldResolver(this, extensions)::findField, + jsonNames::contains)); } return ImmutableMap.copyOf(protoMessageTypes); } @@ -158,19 +166,35 @@ private ImmutableMap createEnumTypes( } private static class FieldResolver { - private final CelTypeProvider celTypeProvider; + private final ProtoMessageTypeProvider protoMessageTypeProvider; private final ImmutableMap fields; - private FieldResolver(CelTypeProvider celTypeProvider, Descriptor descriptor) { + private static ImmutableMap collectFieldDescriptorMap(ProtoMessageTypeProvider protoMessageTypeProvider, Descriptor descriptor) { + ImmutableMap.Builder builder = ImmutableMap.builder(); + for (FieldDescriptor fd : descriptor.getFields()) { + if (protoMessageTypeProvider.allowJsonFieldNames && !fd.getJsonName().isEmpty()) { + builder.put(fd.getJsonName(), fd); + } else { + builder.put(fd.getName(), fd); + } } + + return builder.buildKeepingLast(); + } + + private FieldResolver( + ProtoMessageTypeProvider protoMessageTypeProvider, + Descriptor descriptor + ) { this( - celTypeProvider, - descriptor.getFields().stream() - .collect(toImmutableMap(FieldDescriptor::getName, Function.identity()))); + protoMessageTypeProvider, + collectFieldDescriptorMap(protoMessageTypeProvider, descriptor) + ); } private FieldResolver( - CelTypeProvider celTypeProvider, ImmutableMap fields) { - this.celTypeProvider = celTypeProvider; + ProtoMessageTypeProvider protoMessageTypeProvider, + ImmutableMap fields) { + this.protoMessageTypeProvider = protoMessageTypeProvider; this.fields = fields; } @@ -203,11 +227,11 @@ private Optional findFieldInternal(FieldDescriptor fieldDescriptor) { String messageName = descriptor.getFullName(); fieldType = CelTypes.getWellKnownCelType(messageName) - .orElse(celTypeProvider.findType(descriptor.getFullName()).orElse(null)); + .orElse(protoMessageTypeProvider.findType(descriptor.getFullName()).orElse(null)); break; case ENUM: EnumDescriptor enumDescriptor = fieldDescriptor.getEnumType(); - fieldType = celTypeProvider.findType(enumDescriptor.getFullName()).orElse(null); + fieldType = protoMessageTypeProvider.findType(enumDescriptor.getFullName()).orElse(null); break; default: fieldType = PROTO_TYPE_TO_CEL_TYPE.get(fieldDescriptor.getType()); @@ -222,4 +246,68 @@ private Optional findFieldInternal(FieldDescriptor fieldDescriptor) { return Optional.of(fieldType); } } + + /** Builder for {@link ProtoMessageTypeProvider}. */ + public static final class Builder { + private final ImmutableSet.Builder fileDescriptors = ImmutableSet.builder(); + private boolean allowJsonFieldNames; + private CelDescriptors celDescriptors; + + /** Adds a {@link FileDescriptor} to the provider. */ + public Builder addFileDescriptors(FileDescriptor... fileDescriptors) { + return addFileDescriptors(Arrays.asList(fileDescriptors)); + } + + /** Adds a collection of {@link FileDescriptor}s to the provider. */ + public Builder addFileDescriptors(Iterable fileDescriptors) { + this.fileDescriptors.addAll(fileDescriptors); + return this; + } + + /** Adds a collection of {@link Descriptor}s. The parent file of each descriptor is added. */ + public Builder addDescriptors(Iterable descriptors) { + this.fileDescriptors.addAll(Iterables.transform(descriptors, Descriptor::getFile)); + return this; + } + + /** Adds a {@link FileDescriptorSet} to the provider. */ + public Builder addFileDescriptorSet(FileDescriptorSet fileDescriptorSet) { + this.fileDescriptors.addAll(FileDescriptorSetConverter.convert(fileDescriptorSet)); + return this; + } + + public Builder setAllowJsonFieldNames(boolean allowJsonFieldNames) { + this.allowJsonFieldNames = allowJsonFieldNames; + return this; + } + + public Builder setCelDescriptors(CelDescriptors celDescriptors) { + this.celDescriptors = celDescriptors; + return this; + } + + /** Builds the {@link ProtoMessageTypeProvider}. */ + public ProtoMessageTypeProvider build() { + if (celDescriptors == null) { + celDescriptors = CelDescriptorUtil.getAllDescriptorsFromFileDescriptor(fileDescriptors.build()); + } + return new ProtoMessageTypeProvider(celDescriptors, allowJsonFieldNames); + } + + private Builder() {} + } + + private ProtoMessageTypeProvider( + CelDescriptors celDescriptors, + boolean allowJsonFieldNames + ) { + this.allowJsonFieldNames = allowJsonFieldNames; + this.allTypes = + ImmutableMap.builder() + .putAll(createEnumTypes(celDescriptors.enumDescriptors())) + .putAll( + createProtoMessageTypes( + celDescriptors.messageTypeDescriptors(), celDescriptors.extensionDescriptors())) + .buildOrThrow(); + } } diff --git a/common/src/test/java/dev/cel/common/types/BUILD.bazel b/common/src/test/java/dev/cel/common/types/BUILD.bazel index 600a63940..0c8121bbd 100644 --- a/common/src/test/java/dev/cel/common/types/BUILD.bazel +++ b/common/src/test/java/dev/cel/common/types/BUILD.bazel @@ -16,6 +16,7 @@ java_library( "//common/types:cel_types", "//common/types:message_type_provider", "//common/types:type_providers", + "//testing/protos:single_file_java_proto", "@cel_spec//proto/cel/expr:checked_java_proto", "@cel_spec//proto/cel/expr/conformance/proto2:test_all_types_java_proto", "@cel_spec//proto/cel/expr/conformance/proto3:test_all_types_java_proto", diff --git a/common/src/test/java/dev/cel/common/types/ProtoMessageTypeProviderTest.java b/common/src/test/java/dev/cel/common/types/ProtoMessageTypeProviderTest.java index d774903e1..ea19df5d1 100644 --- a/common/src/test/java/dev/cel/common/types/ProtoMessageTypeProviderTest.java +++ b/common/src/test/java/dev/cel/common/types/ProtoMessageTypeProviderTest.java @@ -20,8 +20,10 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import dev.cel.common.types.CelTypeProvider.CombinedCelTypeProvider; +import dev.cel.common.types.StructType.Field; import dev.cel.expr.conformance.proto2.TestAllTypes; import dev.cel.expr.conformance.proto2.TestAllTypesExtensions; +import dev.cel.testing.testdata.SingleFileProto.SingleFile; import java.util.Optional; import org.junit.Test; import org.junit.runner.RunWith; @@ -254,4 +256,22 @@ public void types_combinedDuplicateProviderIsSameAsFirst() { CombinedCelTypeProvider combined = new CombinedCelTypeProvider(proto3Provider, proto3Provider); assertThat(combined.types()).hasSize(proto3Provider.types().size()); } + + @Test + public void findField_withJsonNameOption() { + ProtoMessageTypeProvider typeProvider = + ProtoMessageTypeProvider.newBuilder() + .addFileDescriptors(SingleFile.getDescriptor().getFile()) + .setAllowJsonFieldNames(true) + .build(); + + ProtoMessageType msgType = (ProtoMessageType) typeProvider.findType(SingleFile.getDescriptor().getFullName()).get(); + + // Note that these are the same fields, with json_name option set + Optional snakeCasedField = msgType.findField("snake_cased"); + Optional jsonNameField = msgType.findField("camelCased"); + + assertThat(snakeCasedField).isEmpty(); + assertThat(jsonNameField).isPresent(); + } } diff --git a/common/src/test/java/dev/cel/common/types/ProtoMessageTypeTest.java b/common/src/test/java/dev/cel/common/types/ProtoMessageTypeTest.java index 3feab06fd..0155d0b9e 100644 --- a/common/src/test/java/dev/cel/common/types/ProtoMessageTypeTest.java +++ b/common/src/test/java/dev/cel/common/types/ProtoMessageTypeTest.java @@ -46,7 +46,8 @@ public void setUp() { "my.package.TestMessage", FIELD_MAP.keySet(), (field) -> Optional.ofNullable(FIELD_MAP.get(field)), - (extension) -> Optional.ofNullable(EXTENSION_MAP.get(extension))); + (extension) -> Optional.ofNullable(EXTENSION_MAP.get(extension)), + (field) -> false); } @Test diff --git a/runtime/src/main/java/dev/cel/runtime/BUILD.bazel b/runtime/src/main/java/dev/cel/runtime/BUILD.bazel index d253edba3..0b9635efd 100644 --- a/runtime/src/main/java/dev/cel/runtime/BUILD.bazel +++ b/runtime/src/main/java/dev/cel/runtime/BUILD.bazel @@ -828,6 +828,7 @@ java_library( "//:auto_value", "//common:cel_ast", "//common:cel_descriptor_util", + "//common:cel_source", "//common:cel_descriptors", "//common:options", "//common/annotations", diff --git a/runtime/src/main/java/dev/cel/runtime/DescriptorMessageProvider.java b/runtime/src/main/java/dev/cel/runtime/DescriptorMessageProvider.java index 2980aea96..ba0e442ec 100644 --- a/runtime/src/main/java/dev/cel/runtime/DescriptorMessageProvider.java +++ b/runtime/src/main/java/dev/cel/runtime/DescriptorMessageProvider.java @@ -182,6 +182,15 @@ private FieldDescriptor findField(Descriptor descriptor, String fieldName) { } } + if (fieldDescriptor == null && celOptions.enableJsonFieldNames()) { + for (FieldDescriptor fd : descriptor.getFields()) { + if (fd.getJsonName().equals(fieldName)) { + fieldDescriptor = fd; + break; + } + } + } + if (fieldDescriptor == null) { throw new IllegalArgumentException( String.format( diff --git a/runtime/src/test/java/dev/cel/runtime/BUILD.bazel b/runtime/src/test/java/dev/cel/runtime/BUILD.bazel index ed76fdbe7..eef5cadc9 100644 --- a/runtime/src/test/java/dev/cel/runtime/BUILD.bazel +++ b/runtime/src/test/java/dev/cel/runtime/BUILD.bazel @@ -20,6 +20,7 @@ java_library( exclude = [ "CelLiteInterpreterTest.java", "CelValueInterpreterTest.java", + "InterpreterJsonNameTest.java", "InterpreterTest.java", ] + ANDROID_TESTS, ), @@ -118,6 +119,8 @@ java_library( ], deps = [ # "//java/com/google/testing/testsize:annotations", + "//common:options", + "//runtime", "//testing:base_interpreter_test", "@maven//:junit_junit", "@maven//:com_google_testparameterinjector_test_parameter_injector", @@ -132,6 +135,8 @@ java_library( ], deps = [ # "//java/com/google/testing/testsize:annotations", + "//common:options", + "//runtime", "//testing:base_interpreter_test", "@maven//:junit_junit", "@maven//:com_google_testparameterinjector_test_parameter_injector", @@ -182,6 +187,7 @@ java_library( "CelLiteInterpreterTest.java", ], deps = [ + "//common:options", "//common/values:proto_message_lite_value_provider", "//extensions:optional_library", "//runtime", diff --git a/runtime/src/test/java/dev/cel/runtime/CelLiteInterpreterTest.java b/runtime/src/test/java/dev/cel/runtime/CelLiteInterpreterTest.java index 088a2d7b0..bce1aba07 100644 --- a/runtime/src/test/java/dev/cel/runtime/CelLiteInterpreterTest.java +++ b/runtime/src/test/java/dev/cel/runtime/CelLiteInterpreterTest.java @@ -14,7 +14,9 @@ package dev.cel.runtime; +import com.google.testing.junit.testparameterinjector.TestParameter; import com.google.testing.junit.testparameterinjector.TestParameterInjector; +import dev.cel.common.CelOptions; import dev.cel.common.values.ProtoMessageLiteValueProvider; import dev.cel.expr.conformance.proto3.TestAllTypesCelDescriptor; import dev.cel.extensions.CelOptionalLibrary; @@ -28,15 +30,21 @@ @RunWith(TestParameterInjector.class) public class CelLiteInterpreterTest extends BaseInterpreterTest { public CelLiteInterpreterTest() { - super( - CelRuntimeFactory.standardCelRuntimeBuilder() - .setValueProvider( - ProtoMessageLiteValueProvider.newInstance( - dev.cel.expr.conformance.proto2.TestAllTypesCelDescriptor.getDescriptor(), - TestAllTypesCelDescriptor.getDescriptor())) - .addLibraries(CelOptionalLibrary.INSTANCE) - .setOptions(newBaseCelOptions().toBuilder().enableCelValue(true).build()) - .build()); + this(newBaseCelOptions().toBuilder().enableCelValue(true).build()); + } + protected CelLiteInterpreterTest(CelOptions celOptions) { + super(celOptions); + } + + @Override + protected CelRuntimeBuilder getRuntimeBuilder(CelOptions celOptions) { + return CelRuntimeFactory.standardCelRuntimeBuilder() + .setValueProvider( + ProtoMessageLiteValueProvider.newInstance( + dev.cel.expr.conformance.proto2.TestAllTypesCelDescriptor.getDescriptor(), + TestAllTypesCelDescriptor.getDescriptor())) + .addLibraries(CelOptionalLibrary.INSTANCE) + .setOptions(celOptions.toBuilder().enableCelValue(true).build()); } @Override diff --git a/runtime/src/test/java/dev/cel/runtime/CelValueInterpreterTest.java b/runtime/src/test/java/dev/cel/runtime/CelValueInterpreterTest.java index ad3fae082..ec8ed4938 100644 --- a/runtime/src/test/java/dev/cel/runtime/CelValueInterpreterTest.java +++ b/runtime/src/test/java/dev/cel/runtime/CelValueInterpreterTest.java @@ -14,8 +14,9 @@ package dev.cel.runtime; +import com.google.testing.junit.testparameterinjector.TestParameter; import com.google.testing.junit.testparameterinjector.TestParameterInjector; -// import com.google.testing.testsize.MediumTest; +import dev.cel.common.CelOptions; import dev.cel.testing.BaseInterpreterTest; import org.junit.runner.RunWith; @@ -25,7 +26,16 @@ public class CelValueInterpreterTest extends BaseInterpreterTest { public CelValueInterpreterTest() { - super(newBaseCelOptions().toBuilder().enableCelValue(true).build()); + this(newBaseCelOptions().toBuilder().enableCelValue(true).build()); + } + + protected CelValueInterpreterTest(CelOptions celOptions) { + super(celOptions); + } + + @Override + protected CelRuntimeBuilder getRuntimeBuilder(CelOptions celOptions) { + return super.getRuntimeBuilder(celOptions.toBuilder().enableCelValue(true).build()); } @Override diff --git a/runtime/src/test/java/dev/cel/runtime/InterpreterTest.java b/runtime/src/test/java/dev/cel/runtime/InterpreterTest.java index a6342c6e4..2870d0e41 100644 --- a/runtime/src/test/java/dev/cel/runtime/InterpreterTest.java +++ b/runtime/src/test/java/dev/cel/runtime/InterpreterTest.java @@ -14,12 +14,17 @@ package dev.cel.runtime; +import com.google.testing.junit.testparameterinjector.TestParameter; import com.google.testing.junit.testparameterinjector.TestParameterInjector; -// import com.google.testing.testsize.MediumTest; +import dev.cel.common.CelOptions; import dev.cel.testing.BaseInterpreterTest; import org.junit.runner.RunWith; /** Tests for {@link Interpreter} and related functionality. */ // @MediumTest @RunWith(TestParameterInjector.class) -public class InterpreterTest extends BaseInterpreterTest {} +public class InterpreterTest extends BaseInterpreterTest { + public InterpreterTest() { + super(BaseInterpreterTest.newBaseCelOptions()); + } +} diff --git a/testing/src/main/java/dev/cel/testing/BUILD.bazel b/testing/src/main/java/dev/cel/testing/BUILD.bazel index d9d6a6564..f642637e1 100644 --- a/testing/src/main/java/dev/cel/testing/BUILD.bazel +++ b/testing/src/main/java/dev/cel/testing/BUILD.bazel @@ -76,6 +76,8 @@ java_library( ":cel_baseline_test_case", "//:java_truth", "//common:cel_ast", + "//common:cel_exception", + "//common:compiler_common", "//common:container", "//common:options", "//common:proto_ast", @@ -84,8 +86,10 @@ java_library( "//common/internal:proto_time_utils", "//common/resources/testdata/proto3:standalone_global_enum_java_proto", "//common/types", + "//common/types:message_type_provider", "//common/types:type_providers", "//common/values:cel_byte_string", + "//compiler", "//extensions:optional_library", "//runtime", "//runtime:function_binding", diff --git a/testing/src/main/java/dev/cel/testing/BaseInterpreterTest.java b/testing/src/main/java/dev/cel/testing/BaseInterpreterTest.java index 425271e1b..9ccbe4d4f 100644 --- a/testing/src/main/java/dev/cel/testing/BaseInterpreterTest.java +++ b/testing/src/main/java/dev/cel/testing/BaseInterpreterTest.java @@ -68,7 +68,10 @@ import dev.cel.common.types.SimpleType; import dev.cel.common.types.StructTypeReference; import dev.cel.common.types.TypeParamType; +import dev.cel.common.types.ProtoMessageTypeProvider; import dev.cel.common.values.CelByteString; +import dev.cel.common.CelValidationException; +import dev.cel.compiler.CelCompilerFactory; import dev.cel.expr.conformance.proto3.TestAllTypes; import dev.cel.expr.conformance.proto3.TestAllTypes.NestedEnum; import dev.cel.expr.conformance.proto3.TestAllTypes.NestedMessage; @@ -77,6 +80,7 @@ import dev.cel.runtime.CelFunctionBinding; import dev.cel.runtime.CelLateFunctionBindings; import dev.cel.runtime.CelRuntime; +import dev.cel.runtime.CelRuntimeBuilder; import dev.cel.runtime.CelRuntimeFactory; import dev.cel.runtime.CelUnknownSet; import dev.cel.runtime.CelVariableResolver; @@ -114,36 +118,74 @@ public abstract class BaseInterpreterTest extends CelBaselineTestCase { .comprehensionMaxIterations(1_000) .build(); private CelRuntime celRuntime; + protected CelOptions celOptions; protected BaseInterpreterTest() { - this(newRuntime(BASE_CEL_OPTIONS)); + this(BASE_CEL_OPTIONS); } protected BaseInterpreterTest(CelOptions celOptions) { - this(newRuntime(celOptions)); + this.celOptions = celOptions; + this.celRuntime = getRuntimeBuilder(celOptions).build(); } protected BaseInterpreterTest(CelRuntime celRuntime) { + this(celRuntime, BASE_CEL_OPTIONS); + } + + protected BaseInterpreterTest(CelRuntime celRuntime, CelOptions celOptions) { this.celRuntime = celRuntime; + this.celOptions = celOptions; } - private static CelRuntime newRuntime(CelOptions celOptions) { + protected CelRuntimeBuilder getRuntimeBuilder(CelOptions celOptions) { return CelRuntimeFactory.standardCelRuntimeBuilder() .addLibraries(CelOptionalLibrary.INSTANCE) .addFileTypes(TEST_FILE_DESCRIPTORS) - .setOptions(celOptions) - .build(); + .setOptions(celOptions); } protected static CelOptions newBaseCelOptions() { return BASE_CEL_OPTIONS; } + @Override + protected CelAbstractSyntaxTree prepareTest(List descriptors) { + return prepareTest( + ProtoMessageTypeProvider.newBuilder() + .addFileDescriptors(descriptors) + .setAllowJsonFieldNames(celOptions != null && celOptions.enableJsonFieldNames()) + .build()); + } + + protected CelAbstractSyntaxTree prepareTest(CelTypeProvider typeProvider) { + prepareCompiler(typeProvider); + + CelAbstractSyntaxTree ast; + try { + ast = celCompiler.parse(source, testSourceDescription()).getAst(); + } catch (CelValidationException e) { + printTestValidationError(e); + return null; + } + + try { + return celCompiler.check(ast).getAst(); + } catch (CelValidationException e) { + printTestValidationError(e); + return null; + } + } + @Override protected void prepareCompiler(CelTypeProvider typeProvider) { super.prepareCompiler(typeProvider); this.celCompiler = - celCompiler.toCompilerBuilder().addLibraries(CelOptionalLibrary.INSTANCE).build(); + celCompiler + .toCompilerBuilder() + .addLibraries(CelOptionalLibrary.INSTANCE) + .setOptions(celOptions) + .build(); } private CelAbstractSyntaxTree compileTestCase() { @@ -2482,4 +2524,26 @@ private static Descriptor getDeserializedTestAllTypeDescriptor() { throw new RuntimeException("Error loading TestAllTypes descriptor", e); } } + + @Test + public void jsonFieldNames() throws Exception { + // Reconfigure the runtime with json field names enabled + this.celOptions = celOptions.toBuilder().enableJsonFieldNames(true).build(); + this.celRuntime = getRuntimeBuilder(celOptions).build(); + // BaseInterpreterTest.prepareCompiler will use the updated celOptions + + TestAllTypes message = TestAllTypes.newBuilder().setSingleInt32(42).build(); + declareVariable("x", StructTypeReference.create(TestAllTypes.getDescriptor().getFullName())); + + // Access via json_name (camelCase) + source = "x.singleInt32 == 42"; + assertThat(runTest(ImmutableMap.of("x", message))).isEqualTo(true); + + // Struct construction using json_names + source = "TestAllTypes{singleInt32: 42}.singleInt32 == 42"; + container = CelContainer.ofName(TestAllTypes.getDescriptor().getFile().getPackage()); + assertThat(runTest()).isEqualTo(true); + + skipBaselineVerification(); + } } diff --git a/testing/src/main/java/dev/cel/testing/BaselineTestCase.java b/testing/src/main/java/dev/cel/testing/BaselineTestCase.java index 493106451..c4e230bee 100644 --- a/testing/src/main/java/dev/cel/testing/BaselineTestCase.java +++ b/testing/src/main/java/dev/cel/testing/BaselineTestCase.java @@ -154,7 +154,8 @@ protected void verify() { String expected = getExpected().trim(); LineDiffer.Diff lineDiff = LineDiffer.diffLines(expected, actual); if (!lineDiff.isEmpty()) { - String actualFileLocation = tryCreateNewBaseline(actual); +// String actualFileLocation = tryCreateNewBaseline(actual); + String actualFileLocation = "foo"; throw new BaselineComparisonError( testName.getMethodName(), baselineFileName(), actual, actualFileLocation, lineDiff); } diff --git a/testing/src/test/resources/protos/single_file.proto b/testing/src/test/resources/protos/single_file.proto index 0fcf270a1..b5ce518e0 100644 --- a/testing/src/test/resources/protos/single_file.proto +++ b/testing/src/test/resources/protos/single_file.proto @@ -26,4 +26,5 @@ message SingleFile { string name = 1; Path path = 2; + string snake_cased = 3 [json_name = "camelCased"]; }