From 1974b049a31604fc384faf4a12025ca719be34d8 Mon Sep 17 00:00:00 2001 From: Dmitri Plotnikov Date: Mon, 1 Jun 2026 16:34:03 -0700 Subject: [PATCH] Refactor type signature generation and parsing to use TypeSpec. Refactors the internal signature generation and parsing logic to operate on `cel::TypeSpec` instead of `cel::Type`. Adds a utility to convert `cel::Type` to `cel::TypeSpec`. Updates the YAML environment configuration parser to accept type signatures as strings for variable types and function overload definitions, in addition to the existing structured format. PiperOrigin-RevId: 924976939 --- common/BUILD | 1 + common/ast/metadata.h | 4 + common/internal/signature.cc | 268 +++++++++++++++++------------- common/internal/signature.h | 19 +++ common/internal/signature_test.cc | 166 +++++++----------- common/type_spec_resolver.cc | 105 ++++++++++++ common/type_spec_resolver.h | 3 + common/type_spec_resolver_test.cc | 27 +++ 8 files changed, 374 insertions(+), 219 deletions(-) diff --git a/common/BUILD b/common/BUILD index a016d2cb5..01710329b 100644 --- a/common/BUILD +++ b/common/BUILD @@ -53,6 +53,7 @@ cc_library( deps = [ ":ast", ":type", + ":type_kind", "//internal:status_macros", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", diff --git a/common/ast/metadata.h b/common/ast/metadata.h index 197790ff3..1a69b5b50 100644 --- a/common/ast/metadata.h +++ b/common/ast/metadata.h @@ -573,6 +573,10 @@ class TypeSpec { TypeSpecKind& mutable_type_kind() { return type_kind_; } + bool is_specified() const { + return !absl::holds_alternative(type_kind_); + } + bool has_dyn() const { return absl::holds_alternative(type_kind_); } diff --git a/common/internal/signature.cc b/common/internal/signature.cc index 5c75225f9..5f4af2b14 100644 --- a/common/internal/signature.cc +++ b/common/internal/signature.cc @@ -64,125 +64,144 @@ void AppendEscaped(std::string* result, std::string_view str, bool escape_dot) { } } -absl::Status AppendTypeParameters(std::string* result, const Type& type); +absl::Status AppendTypeDesc(std::string* result, const TypeSpec& type_spec); -// Recursively appends a string representation of the given `type` to `result`. -// Type parameters are enclosed in angle brackets and separated by commas. -// -// Grammar: -// TypeDesc = NamespaceIdentifier [ "<" TypeList ">" ] ; -// NamespaceIdentifier = [ "." ] Identifier { "." Identifier } ; -// TypeList = TypeElem { "," TypeElem } ; -// TypeElem = TypeDesc | TypeParam -// TypeParam = "~" Alpha ; -// Identifier = ( Alpha | "_" ) { AlphaNumeric | "_" } ; -// (* Terminals *) -// Alpha = "a"..."z" | "A"..."Z" ; -// Digit = "0"..."9" ; -// AlphaNumeric = Alpha | Digit ; -// -// For compatibility, the implementation allows unexpected characters in -// type names and parameters and escapes them with a backslash. -absl::Status AppendTypeDesc(std::string* result, const Type& type) { - switch (type.kind()) { - case TypeKind::kNull: - absl::StrAppend(result, "null"); - break; - case TypeKind::kBool: - absl::StrAppend(result, "bool"); - break; - case TypeKind::kInt: - absl::StrAppend(result, "int"); - break; - case TypeKind::kUint: - absl::StrAppend(result, "uint"); - break; - case TypeKind::kDouble: - absl::StrAppend(result, "double"); - break; - case TypeKind::kString: - absl::StrAppend(result, "string"); - break; - case TypeKind::kBytes: - absl::StrAppend(result, "bytes"); - break; - case TypeKind::kDuration: - absl::StrAppend(result, "duration"); - break; - case TypeKind::kTimestamp: - absl::StrAppend(result, "timestamp"); - break; - case TypeKind::kAny: - absl::StrAppend(result, "any"); - break; - case TypeKind::kDyn: - absl::StrAppend(result, "dyn"); - break; - case TypeKind::kBoolWrapper: - absl::StrAppend(result, "bool_wrapper"); - break; - case TypeKind::kIntWrapper: - absl::StrAppend(result, "int_wrapper"); - break; - case TypeKind::kUintWrapper: - absl::StrAppend(result, "uint_wrapper"); - break; - case TypeKind::kDoubleWrapper: - absl::StrAppend(result, "double_wrapper"); - break; - case TypeKind::kStringWrapper: - absl::StrAppend(result, "string_wrapper"); - break; - case TypeKind::kBytesWrapper: - absl::StrAppend(result, "bytes_wrapper"); - break; - case TypeKind::kList: - absl::StrAppend(result, "list"); - CEL_RETURN_IF_ERROR(AppendTypeParameters(result, type)); - break; - case TypeKind::kMap: - absl::StrAppend(result, "map"); - CEL_RETURN_IF_ERROR(AppendTypeParameters(result, type)); - break; - case TypeKind::kFunction: - absl::StrAppend(result, "function"); - CEL_RETURN_IF_ERROR(AppendTypeParameters(result, type)); - break; - case TypeKind::kType: - absl::StrAppend(result, "type"); - CEL_RETURN_IF_ERROR(AppendTypeParameters(result, type)); - break; - case TypeKind::kTypeParam: - absl::StrAppend(result, "~"); - AppendEscaped(result, type.GetTypeParam().name(), /*escape_dot=*/true); - break; - case TypeKind::kOpaque: - AppendEscaped(result, type.name(), /*escape_dot=*/false); - CEL_RETURN_IF_ERROR(AppendTypeParameters(result, type)); - break; - case TypeKind::kStruct: - AppendEscaped(result, type.name(), /*escape_dot=*/false); - CEL_RETURN_IF_ERROR(AppendTypeParameters(result, type)); - break; - default: - return absl::InvalidArgumentError( - absl::StrFormat("Type kind: %s is not supported in CEL declarations", - type.DebugString())); +absl::Status AppendTypeSpecList(std::string* result, + const std::vector& params) { + if (!params.empty()) { + result->push_back('<'); + for (size_t i = 0; i < params.size(); ++i) { + CEL_RETURN_IF_ERROR(AppendTypeDesc(result, params[i])); + if (i < params.size() - 1) { + result->push_back(','); + } + } + result->push_back('>'); } return absl::OkStatus(); } -absl::Status AppendTypeParameters(std::string* result, const Type& type) { - const auto& parameters = type.GetParameters(); - if (!parameters.empty()) { - result->push_back('<'); - for (size_t i = 0; i < parameters.size(); ++i) { - CEL_RETURN_IF_ERROR(AppendTypeDesc(result, parameters[i])); - if (i < parameters.size() - 1) { +absl::Status AppendTypeDesc(std::string* result, const TypeSpec& type_spec) { + if (type_spec.has_null()) { + absl::StrAppend(result, "null"); + } else if (type_spec.has_dyn()) { + absl::StrAppend(result, "dyn"); + } else if (type_spec.has_primitive()) { + switch (type_spec.primitive()) { + case PrimitiveType::kBool: + absl::StrAppend(result, "bool"); + break; + case PrimitiveType::kInt64: + absl::StrAppend(result, "int"); + break; + case PrimitiveType::kUint64: + absl::StrAppend(result, "uint"); + break; + case PrimitiveType::kDouble: + absl::StrAppend(result, "double"); + break; + case PrimitiveType::kString: + absl::StrAppend(result, "string"); + break; + case PrimitiveType::kBytes: + absl::StrAppend(result, "bytes"); + break; + default: + return absl::InvalidArgumentError("Unsupported primitive type"); + } + } else if (type_spec.has_well_known()) { + switch (type_spec.well_known()) { + case WellKnownTypeSpec::kAny: + absl::StrAppend(result, "any"); + break; + case WellKnownTypeSpec::kTimestamp: + absl::StrAppend(result, "timestamp"); + break; + case WellKnownTypeSpec::kDuration: + absl::StrAppend(result, "duration"); + break; + default: + return absl::InvalidArgumentError("Unsupported well-known type"); + } + } else if (type_spec.has_wrapper()) { + switch (type_spec.wrapper()) { + case PrimitiveType::kBool: + absl::StrAppend(result, "bool_wrapper"); + break; + case PrimitiveType::kInt64: + absl::StrAppend(result, "int_wrapper"); + break; + case PrimitiveType::kUint64: + absl::StrAppend(result, "uint_wrapper"); + break; + case PrimitiveType::kDouble: + absl::StrAppend(result, "double_wrapper"); + break; + case PrimitiveType::kString: + absl::StrAppend(result, "string_wrapper"); + break; + case PrimitiveType::kBytes: + absl::StrAppend(result, "bytes_wrapper"); + break; + default: + return absl::InvalidArgumentError("Unsupported wrapper type"); + } + } else if (type_spec.has_list_type()) { + absl::StrAppend(result, "list"); + if (type_spec.list_type().has_elem_type()) { + result->push_back('<'); + CEL_RETURN_IF_ERROR( + AppendTypeDesc(result, type_spec.list_type().elem_type())); + result->push_back('>'); + } + } else if (type_spec.has_map_type()) { + absl::StrAppend(result, "map"); + if (type_spec.map_type().has_key_type() && + type_spec.map_type().has_value_type()) { + result->push_back('<'); + CEL_RETURN_IF_ERROR( + AppendTypeDesc(result, type_spec.map_type().key_type())); + result->push_back(','); + CEL_RETURN_IF_ERROR( + AppendTypeDesc(result, type_spec.map_type().value_type())); + result->push_back('>'); + } + } else if (type_spec.has_function()) { + absl::StrAppend(result, "function"); + if (type_spec.function().has_result_type() || + !type_spec.function().arg_types().empty()) { + result->push_back('<'); + if (type_spec.function().has_result_type()) { + CEL_RETURN_IF_ERROR( + AppendTypeDesc(result, type_spec.function().result_type())); + } else { + absl::StrAppend(result, "dyn"); + } + for (const auto& arg : type_spec.function().arg_types()) { result->push_back(','); + CEL_RETURN_IF_ERROR(AppendTypeDesc(result, arg)); } + result->push_back('>'); } + } else if (type_spec.has_type()) { + absl::StrAppend(result, "type"); + result->push_back('<'); + CEL_RETURN_IF_ERROR(AppendTypeDesc(result, type_spec.type())); result->push_back('>'); + } else if (type_spec.has_type_param()) { + absl::StrAppend(result, "~"); + AppendEscaped(result, type_spec.type_param().type(), /*escape_dot=*/true); + } else if (type_spec.has_abstract_type()) { + AppendEscaped(result, type_spec.abstract_type().name(), + /*escape_dot=*/false); + CEL_RETURN_IF_ERROR(AppendTypeSpecList( + result, type_spec.abstract_type().parameter_types())); + } else if (type_spec.has_message_type()) { + AppendEscaped(result, type_spec.message_type().type(), + /*escape_dot=*/false); + } else { + return absl::InvalidArgumentError(absl::StrCat( + "Unsupported type in signature: ", FormatTypeSpec(type_spec))); } return absl::OkStatus(); } @@ -190,13 +209,32 @@ absl::Status AppendTypeParameters(std::string* result, const Type& type) { absl::StatusOr MakeTypeSignature(const Type& type) { std::string result; - CEL_RETURN_IF_ERROR(AppendTypeDesc(&result, type)); + CEL_ASSIGN_OR_RETURN(TypeSpec type_spec, ConvertTypeToTypeSpec(type)); + CEL_RETURN_IF_ERROR(AppendTypeDesc(&result, type_spec)); + return result; +} + +absl::StatusOr MakeTypeSpecSignature(const TypeSpec& type_spec) { + std::string result; + CEL_RETURN_IF_ERROR(AppendTypeDesc(&result, type_spec)); return result; } absl::StatusOr MakeOverloadSignature( std::string_view function_name, const std::vector& args, bool is_member) { + std::vector arg_type_specs; + arg_type_specs.reserve(args.size()); + for (const auto& arg : args) { + CEL_ASSIGN_OR_RETURN(TypeSpec type_spec, ConvertTypeToTypeSpec(arg)); + arg_type_specs.push_back(type_spec); + } + return MakeOverloadSignature(function_name, arg_type_specs, is_member); +} + +absl::StatusOr MakeOverloadSignature( + std::string_view function_name, const std::vector& args, + bool is_member) { std::string result; if (is_member) { if (!args.empty()) { @@ -589,10 +627,14 @@ absl::StatusOr ParseFunctionSignature( return out; } +absl::StatusOr ParseTypeSpec(std::string_view signature) { + std::string stripped_sig = StripUnescapedWhitespace(signature); + return ParseTypeSignature(stripped_sig); +} + absl::StatusOr ParseType(std::string_view signature, google::protobuf::Arena* arena, const google::protobuf::DescriptorPool& pool) { - std::string stripped_sig = StripUnescapedWhitespace(signature); - CEL_ASSIGN_OR_RETURN(auto type_spec, ParseTypeSignature(stripped_sig)); + CEL_ASSIGN_OR_RETURN(auto type_spec, ParseTypeSpec(signature)); return cel::ConvertTypeSpecToType(type_spec, arena, pool); } diff --git a/common/internal/signature.h b/common/internal/signature.h index 3fdba4b2e..79e963760 100644 --- a/common/internal/signature.h +++ b/common/internal/signature.h @@ -37,6 +37,16 @@ namespace cel::common_internal { // - `list>` absl::StatusOr MakeTypeSignature(const Type& type); +// Generates an signature for a `cel::TypeSpec`, which is a string +// representation of the type. +// +// Examples: +// +// - `int` +// - `list` +// - `list>` +absl::StatusOr MakeTypeSpecSignature(const TypeSpec& type_spec); + // Generates an identifier for a function overload based on the function name // and the types of the arguments. If `is_member` is true, the first argument // type is used as the receiver and is prepended to the function name, followed @@ -59,6 +69,15 @@ absl::StatusOr MakeOverloadSignature( std::string_view function_name, const std::vector& args, bool is_member); +// Generates an identifier for a function overload based on the function name +// and the types of the arguments. See above for more details. +absl::StatusOr MakeOverloadSignature( + std::string_view function_name, const std::vector& args, + bool is_member); + +// Parses a string type signature directly into a `cel::TypeSpec`. +absl::StatusOr ParseTypeSpec(std::string_view signature); + // Parses a string type signature directly into a `cel::Type`. absl::StatusOr ParseType(std::string_view signature, google::protobuf::Arena* arena, const google::protobuf::DescriptorPool& pool); diff --git a/common/internal/signature_test.cc b/common/internal/signature_test.cc index 765055f75..3399a8f78 100644 --- a/common/internal/signature_test.cc +++ b/common/internal/signature_test.cc @@ -14,6 +14,7 @@ // limitations under the License. #include +#include #include #include @@ -42,82 +43,9 @@ google::protobuf::Arena* GetTestArena() { return &*arena; } -void VerifyParsedMatchesType(const TypeSpec& parsed, const Type& original) { - switch (original.kind()) { - case TypeKind::kDyn: - EXPECT_TRUE(parsed.has_dyn()); - break; - case TypeKind::kNull: - EXPECT_TRUE(parsed.has_null()); - break; - case TypeKind::kBool: - EXPECT_EQ(parsed.primitive(), PrimitiveType::kBool); - break; - case TypeKind::kInt: - EXPECT_EQ(parsed.primitive(), PrimitiveType::kInt64); - break; - case TypeKind::kUint: - EXPECT_EQ(parsed.primitive(), PrimitiveType::kUint64); - break; - case TypeKind::kDouble: - EXPECT_EQ(parsed.primitive(), PrimitiveType::kDouble); - break; - case TypeKind::kString: - EXPECT_EQ(parsed.primitive(), PrimitiveType::kString); - break; - case TypeKind::kBytes: - EXPECT_EQ(parsed.primitive(), PrimitiveType::kBytes); - break; - case TypeKind::kAny: - EXPECT_EQ(parsed.well_known(), WellKnownTypeSpec::kAny); - break; - case TypeKind::kTimestamp: - EXPECT_EQ(parsed.well_known(), WellKnownTypeSpec::kTimestamp); - break; - case TypeKind::kDuration: - EXPECT_EQ(parsed.well_known(), WellKnownTypeSpec::kDuration); - break; - case TypeKind::kList: - EXPECT_TRUE(parsed.has_list_type()); - if (!original.GetParameters().empty()) { - VerifyParsedMatchesType(parsed.list_type().elem_type(), - original.GetParameters()[0]); - } - break; - case TypeKind::kMap: - EXPECT_TRUE(parsed.has_map_type()); - if (!original.GetParameters().empty()) { - VerifyParsedMatchesType(parsed.map_type().key_type(), - original.GetParameters()[0]); - } - if (original.GetParameters().size() > 1) { - VerifyParsedMatchesType(parsed.map_type().value_type(), - original.GetParameters()[1]); - } - break; - case TypeKind::kBoolWrapper: - case TypeKind::kIntWrapper: - case TypeKind::kUintWrapper: - case TypeKind::kDoubleWrapper: - case TypeKind::kStringWrapper: - case TypeKind::kBytesWrapper: - EXPECT_TRUE(parsed.has_wrapper()); - break; - case TypeKind::kType: - EXPECT_TRUE(parsed.has_type()); - if (!original.GetParameters().empty()) { - VerifyParsedMatchesType(parsed.type(), original.GetParameters()[0]); - } - break; - case TypeKind::kTypeParam: - EXPECT_TRUE(parsed.has_type_param()); - break; - default: - EXPECT_TRUE(parsed.has_abstract_type()); - break; - } +void VerifyParsedMatchesType(const TypeSpec& parsed, const TypeSpec& expected) { + EXPECT_EQ(parsed, expected); } - void VerifyTypesEqual(const Type& lhs, const Type& rhs) { EXPECT_EQ(lhs.kind(), rhs.kind()); if (lhs.kind() != rhs.kind()) return; @@ -238,11 +166,11 @@ std::vector GetTypeSignatureTestCases() { TEST(TypeSignatureTest, UnsupportedTypes) { EXPECT_THAT(common_internal::MakeTypeSignature(UnknownType{}), StatusIs(absl::StatusCode::kInvalidArgument, - HasSubstr("Type kind: *unknown* is not supported"))); + HasSubstr("Unsupported Type kind: *unknown*"))); EXPECT_THAT(common_internal::MakeTypeSignature(ErrorType{}), StatusIs(absl::StatusCode::kInvalidArgument, - HasSubstr("Type kind: *error* is not supported"))); + HasSubstr("Unsupported type in signature: *error*"))); } INSTANTIATE_TEST_SUITE_P(TypeIdTest, TypeSignatureTest, @@ -260,7 +188,7 @@ TEST_P(TypeSignatureTest, ParseTypeCheck) { struct OverloadSignatureTestCase { std::string function_name = "hello"; - std::vector args; + std::vector args; bool is_member = false; std::string expected_signature; std::string expected_error; @@ -285,98 +213,109 @@ TEST_P(OverloadSignatureTest, OverloadSignature) { std::vector GetOverloadSignatureTestCases() { return { { - .args = {StringType{}}, + .args = {TypeSpec(PrimitiveType::kString)}, .expected_signature = "hello(string)", }, { - .args = {IntType{}, UintType{}}, + .args = {TypeSpec(PrimitiveType::kInt64), + TypeSpec(PrimitiveType::kUint64)}, .expected_signature = "hello(int,uint)", }, { - .args = {ListType(GetTestArena(), StringType{})}, + .args = {TypeSpec(ListTypeSpec( + std::make_unique(PrimitiveType::kString)))}, .expected_signature = "hello(list)", }, { - .args = {ListType(GetTestArena(), TypeParamType("A"))}, + .args = {TypeSpec( + ListTypeSpec(std::make_unique(ParamTypeSpec("A"))))}, .expected_signature = "hello(list<~A>)", }, { - .args = {MapType(GetTestArena(), IntType{}, DynType{})}, + .args = {TypeSpec( + MapTypeSpec(std::make_unique(PrimitiveType::kInt64), + std::make_unique(DynTypeSpec())))}, .expected_signature = "hello(map)", }, { - .args = {MapType(GetTestArena(), TypeParamType("B"), - TypeParamType("C"))}, + .args = {TypeSpec( + MapTypeSpec(std::make_unique(ParamTypeSpec("B")), + std::make_unique(ParamTypeSpec("C"))))}, .expected_signature = "hello(map<~B,~C>)", }, { - .args = {OpaqueType( - GetTestArena(), "bar", - {FunctionType(GetTestArena(), TypeParamType("D"), {})})}, + .args = {TypeSpec(AbstractType( + "bar", + {TypeSpec(FunctionTypeSpec( + std::make_unique(ParamTypeSpec("D")), {}))}))}, .expected_signature = "hello(bar>)", }, { - .args = {AnyType{}}, + .args = {TypeSpec(WellKnownTypeSpec::kAny)}, .expected_signature = "hello(any)", }, { - .args = {DurationType{}}, + .args = {TypeSpec(WellKnownTypeSpec::kDuration)}, .expected_signature = "hello(duration)", }, { - .args = {TimestampType{}}, + .args = {TypeSpec(WellKnownTypeSpec::kTimestamp)}, .expected_signature = "hello(timestamp)", }, { - .args = {BoolWrapperType{}}, + .args = {TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kBool))}, .expected_signature = "hello(bool_wrapper)", }, { - .args = {IntWrapperType{}}, + .args = {TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kInt64))}, .expected_signature = "hello(int_wrapper)", }, { - .args = {UintWrapperType{}}, + .args = {TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kUint64))}, .expected_signature = "hello(uint_wrapper)", }, { - .args = {MessageType( - GetTestingDescriptorPool()->FindMessageTypeByName( - "cel.expr.conformance.proto3.TestAllTypes"))}, + .args = {TypeSpec( + AbstractType("cel.expr.conformance.proto3.TestAllTypes", {}))}, .expected_signature = "hello(cel.expr.conformance.proto3.TestAllTypes)", }, { - .args = {StringType{}}, + .args = {TypeSpec(PrimitiveType::kString)}, .is_member = true, .expected_signature = "string.hello()", }, { - .args = {StringType{}, ListType(GetTestArena(), BoolType{})}, + .args = {TypeSpec(PrimitiveType::kString), + TypeSpec(ListTypeSpec( + std::make_unique(PrimitiveType::kBool)))}, .is_member = true, .expected_signature = "string.hello(list)", }, { - .args = {StringType{}, BoolType{}, DynType{}}, + .args = {TypeSpec(PrimitiveType::kString), + TypeSpec(PrimitiveType::kBool), TypeSpec(DynTypeSpec())}, .is_member = true, .expected_signature = "string.hello(bool,dyn)", }, { .function_name = "hello", - .args = {OpaqueType(GetTestArena(), "bar", - {TypeParamType("dummy.type")})}, + .args = {TypeSpec( + AbstractType("bar", {TypeSpec(ParamTypeSpec("dummy.type"))}))}, .is_member = true, .expected_signature = R"(bar<~dummy\.type>.hello())", }, { .function_name = "inspect", - .args = {Type(TypeType(GetTestArena(), StringType{}))}, + .args = {TypeSpec( + std::make_unique(PrimitiveType::kString))}, .expected_signature = "inspect(type)", }, { .function_name = R"(h.(e),l\o)", - .args = {StringType{}, - ListType(GetTestArena(), TypeParamType(R"(a,b..(d)\e)"))}, + .args = {TypeSpec(PrimitiveType::kString), + TypeSpec(ListTypeSpec(std::make_unique( + ParamTypeSpec(R"(a,b..(d)\e)"))))}, .is_member = true, .expected_signature = R"(string.h\.\(e\)\,l\\\o(list<~a\,b\.\\.\(d\)\\e>))", @@ -385,7 +324,8 @@ std::vector GetOverloadSignatureTestCases() { } TEST(OverloadSignatureTest, MemberFunctionNoReceiverError) { - auto signature = common_internal::MakeOverloadSignature("hello", {}, true); + auto signature = common_internal::MakeOverloadSignature( + "hello", std::vector{}, true); EXPECT_THAT(signature, StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Member function with no receiver"))); @@ -700,5 +640,19 @@ TEST(ParseSignatureTest, EmptyOrWhitespaceErrors) { HasSubstr("Empty type signature"))); } +TEST(OverloadSignatureTest, TypeArgumentArray) { + std::vector args; + args.push_back(Type(IntType())); + args.push_back(Type(StringType())); + args.push_back(Type(ListType(GetTestArena(), IntType()))); + args.push_back( + Type(MessageType(GetTestingDescriptorPool()->FindMessageTypeByName( + "cel.expr.conformance.proto3.TestAllTypes")))); + args.push_back(Type(OpaqueType(GetTestArena(), "Foo", {TypeParamType("T")}))); + ASSERT_OK_AND_ASSIGN(auto sig, MakeOverloadSignature("foo", args, false)); + EXPECT_EQ(sig, + "foo(int,string,list,cel.expr.conformance.proto3.TestAllTypes," + "Foo<~T>)"); +} } // namespace } // namespace cel::common_internal diff --git a/common/type_spec_resolver.cc b/common/type_spec_resolver.cc index 97451f390..98df66b77 100644 --- a/common/type_spec_resolver.cc +++ b/common/type_spec_resolver.cc @@ -14,6 +14,7 @@ #include "common/type_spec_resolver.h" +#include #include #include #include @@ -22,8 +23,12 @@ #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "common/ast.h" #include "common/type.h" +#include "common/type_kind.h" #include "internal/status_macros.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" namespace cel { @@ -179,4 +184,104 @@ absl::StatusOr ConvertTypeSpecToType(const TypeSpec& type_spec, return absl::InvalidArgumentError("Unknown TypeSpec kind"); } +absl::StatusOr ConvertTypeToTypeSpec(const Type& type) { + switch (type.kind()) { + case TypeKind::kNull: + return TypeSpec(NullTypeSpec{}); + case TypeKind::kDyn: + return TypeSpec(DynTypeSpec{}); + case TypeKind::kBool: + return TypeSpec(PrimitiveType::kBool); + case TypeKind::kInt: + return TypeSpec(PrimitiveType::kInt64); + case TypeKind::kUint: + return TypeSpec(PrimitiveType::kUint64); + case TypeKind::kDouble: + return TypeSpec(PrimitiveType::kDouble); + case TypeKind::kString: + return TypeSpec(PrimitiveType::kString); + case TypeKind::kBytes: + return TypeSpec(PrimitiveType::kBytes); + case TypeKind::kAny: + return TypeSpec(WellKnownTypeSpec::kAny); + case TypeKind::kTimestamp: + return TypeSpec(WellKnownTypeSpec::kTimestamp); + case TypeKind::kDuration: + return TypeSpec(WellKnownTypeSpec::kDuration); + case TypeKind::kBoolWrapper: + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kBool)); + case TypeKind::kIntWrapper: + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kInt64)); + case TypeKind::kUintWrapper: + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kUint64)); + case TypeKind::kDoubleWrapper: + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kDouble)); + case TypeKind::kStringWrapper: + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kString)); + case TypeKind::kBytesWrapper: + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kBytes)); + case TypeKind::kList: { + CEL_ASSIGN_OR_RETURN(auto elem_type, + ConvertTypeToTypeSpec(type.GetList().element())); + return TypeSpec( + ListTypeSpec(std::make_unique(std::move(elem_type)))); + } + case TypeKind::kMap: { + CEL_ASSIGN_OR_RETURN(auto key_type, + ConvertTypeToTypeSpec(type.GetMap().key())); + CEL_ASSIGN_OR_RETURN(auto value_type, + ConvertTypeToTypeSpec(type.GetMap().value())); + return TypeSpec( + MapTypeSpec(std::make_unique(std::move(key_type)), + std::make_unique(std::move(value_type)))); + } + case TypeKind::kFunction: { + auto func_type = type.GetFunction(); + CEL_ASSIGN_OR_RETURN(auto result_type, + ConvertTypeToTypeSpec(func_type.result())); + std::vector arg_types; + arg_types.reserve(func_type.args().size()); + for (const auto& arg : func_type.args()) { + CEL_ASSIGN_OR_RETURN(auto arg_type, ConvertTypeToTypeSpec(arg)); + arg_types.push_back(std::move(arg_type)); + } + return TypeSpec( + FunctionTypeSpec(std::make_unique(std::move(result_type)), + std::move(arg_types))); + } + case TypeKind::kTypeParam: + return TypeSpec(ParamTypeSpec(std::string(type.GetTypeParam().name()))); + case TypeKind::kStruct: { + if (type.IsMessage()) { + return TypeSpec(MessageTypeSpec(std::string(type.GetMessage().name()))); + } + return absl::InvalidArgumentError("Unsupported struct type"); + } + case TypeKind::kOpaque: { + auto opaque_type = type.GetOpaque(); + std::vector params; + params.reserve(opaque_type.GetParameters().size()); + for (const auto& param : opaque_type.GetParameters()) { + CEL_ASSIGN_OR_RETURN(auto param_type, ConvertTypeToTypeSpec(param)); + params.push_back(std::move(param_type)); + } + return TypeSpec( + AbstractType(std::string(opaque_type.name()), std::move(params))); + } + case TypeKind::kType: { + CEL_ASSIGN_OR_RETURN(auto nested_type, + ConvertTypeToTypeSpec(type.GetType().GetType())); + return TypeSpec(std::make_unique(std::move(nested_type))); + } + case TypeKind::kError: + return TypeSpec(ErrorTypeSpec::kValue); + case TypeKind::kEnum: + return TypeSpec( + AbstractType(std::string(type.GetEnum().name()), /*params=*/{})); + default: + return absl::InvalidArgumentError(absl::StrCat( + "Unsupported Type kind: ", TypeKindToString(type.kind()))); + } +} + } // namespace cel diff --git a/common/type_spec_resolver.h b/common/type_spec_resolver.h index 44e1e088f..edbfa3bde 100644 --- a/common/type_spec_resolver.h +++ b/common/type_spec_resolver.h @@ -32,6 +32,9 @@ absl::StatusOr ConvertTypeSpecToType(const TypeSpec& type_spec, google::protobuf::Arena* arena, const google::protobuf::DescriptorPool& pool); +// Resolves a `cel::Type` to a `cel::TypeSpec`. +absl::StatusOr ConvertTypeToTypeSpec(const Type& type); + } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_TYPE_SPEC_RESOLVER_H_ diff --git a/common/type_spec_resolver_test.cc b/common/type_spec_resolver_test.cc index c7fbb2cf8..1cda7280f 100644 --- a/common/type_spec_resolver_test.cc +++ b/common/type_spec_resolver_test.cc @@ -23,6 +23,7 @@ #include "absl/base/no_destructor.h" #include "absl/status/status.h" #include "absl/status/status_matchers.h" +#include "common/ast.h" #include "common/type.h" #include "common/type_kind.h" #include "internal/testing.h" @@ -33,6 +34,7 @@ namespace cel { namespace { using ::absl_testing::IsOk; +using ::absl_testing::IsOkAndHolds; using ::absl_testing::StatusIs; using ::cel::internal::GetTestingDescriptorPool; using ::testing::HasSubstr; @@ -67,6 +69,7 @@ TEST_P(ConversionTest, TestTypeSpecConversion) { auto t, ConvertTypeSpecToType(std::get<0>(GetParam()), GetTestArena(), *GetTestingDescriptorPool())); EXPECT_EQ(t.kind(), std::get<1>(GetParam())); + EXPECT_THAT(ConvertTypeToTypeSpec(t), IsOkAndHolds(std::get<0>(GetParam()))); } INSTANTIATE_TEST_SUITE_P( @@ -104,6 +107,8 @@ TEST(TypeSpecResolverTest, ListTypeConversion) { ASSERT_THAT(t, IsOk()); EXPECT_TRUE(t->IsList()); EXPECT_TRUE(t->GetList().element().IsInt()); + ASSERT_OK_AND_ASSIGN(auto spec2, ConvertTypeToTypeSpec(*t)); + EXPECT_EQ(spec2, spec); } TEST(TypeSpecResolverTest, MapTypeConversion) { @@ -116,6 +121,8 @@ TEST(TypeSpecResolverTest, MapTypeConversion) { EXPECT_TRUE(t->IsMap()); EXPECT_TRUE(t->GetMap().key().IsString()); EXPECT_TRUE(t->GetMap().value().IsBytes()); + ASSERT_OK_AND_ASSIGN(auto spec2, ConvertTypeToTypeSpec(*t)); + EXPECT_EQ(spec2, spec); } TEST(TypeSpecResolverTest, FunctionTypeConversion) { @@ -129,6 +136,8 @@ TEST(TypeSpecResolverTest, FunctionTypeConversion) { EXPECT_TRUE(t->IsFunction()); EXPECT_EQ(t->GetFunction().args().size(), 1); EXPECT_TRUE(t->GetFunction().result().IsBool()); + ASSERT_OK_AND_ASSIGN(auto spec2, ConvertTypeToTypeSpec(*t)); + EXPECT_EQ(spec2, spec); } TEST(TypeSpecResolverTest, TypeParamConversion) { @@ -138,6 +147,8 @@ TEST(TypeSpecResolverTest, TypeParamConversion) { ASSERT_THAT(t, IsOk()); EXPECT_TRUE(t->IsTypeParam()); EXPECT_EQ(t->GetTypeParam().name(), "T"); + ASSERT_OK_AND_ASSIGN(auto spec2, ConvertTypeToTypeSpec(*t)); + EXPECT_EQ(spec2, spec); } TEST(TypeSpecResolverTest, MessageTypeConversion) { @@ -148,6 +159,10 @@ TEST(TypeSpecResolverTest, MessageTypeConversion) { ASSERT_THAT(t, IsOk()); EXPECT_TRUE(t->IsMessage()); EXPECT_EQ(t->name(), "cel.expr.conformance.proto3.TestAllTypes"); + ASSERT_OK_AND_ASSIGN(auto spec2, ConvertTypeToTypeSpec(*t)); + EXPECT_EQ( + spec2, + TypeSpec(MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes"))); } TEST(TypeSpecResolverTest, MessageTypeWithParamsError) { @@ -172,6 +187,8 @@ TEST(TypeSpecResolverTest, UnresolvedAbstractTypeFallbackToOpaque) { EXPECT_EQ(t->name(), "my.custom.OpaqueType"); EXPECT_EQ(t->GetParameters().size(), 1); EXPECT_TRUE(t->GetParameters()[0].IsInt()); + ASSERT_OK_AND_ASSIGN(auto spec2, ConvertTypeToTypeSpec(*t)); + EXPECT_EQ(spec2, spec); } TEST(TypeSpecResolverTest, OptionalType) { @@ -186,6 +203,8 @@ TEST(TypeSpecResolverTest, OptionalType) { EXPECT_EQ(t->GetParameters().size(), 1); EXPECT_TRUE(t->GetParameters()[0].IsInt()); EXPECT_TRUE(t->IsOptional()); + ASSERT_OK_AND_ASSIGN(auto spec2, ConvertTypeToTypeSpec(*t)); + EXPECT_EQ(spec2, spec); } TEST(TypeSpecResolverTest, TypeTypeConversion) { @@ -196,6 +215,8 @@ TEST(TypeSpecResolverTest, TypeTypeConversion) { ASSERT_THAT(t, IsOk()); EXPECT_TRUE(t->IsType()); EXPECT_TRUE(t->GetType().GetType().IsInt()); + ASSERT_OK_AND_ASSIGN(auto spec2, ConvertTypeToTypeSpec(*t)); + EXPECT_EQ(spec2, spec); } TEST(TypeSpecResolverTest, ErrorTypeConversion) { @@ -204,6 +225,8 @@ TEST(TypeSpecResolverTest, ErrorTypeConversion) { ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); ASSERT_THAT(t, IsOk()); EXPECT_TRUE(t->IsError()); + ASSERT_OK_AND_ASSIGN(auto spec2, ConvertTypeToTypeSpec(*t)); + EXPECT_EQ(spec2, spec); } TEST(TypeSpecResolverTest, MessageTypeSpecConversion) { @@ -213,6 +236,8 @@ TEST(TypeSpecResolverTest, MessageTypeSpecConversion) { ASSERT_THAT(t, IsOk()); EXPECT_TRUE(t->IsMessage()); EXPECT_EQ(t->name(), "cel.expr.conformance.proto3.TestAllTypes"); + ASSERT_OK_AND_ASSIGN(auto spec2, ConvertTypeToTypeSpec(*t)); + EXPECT_EQ(spec2, spec); } TEST(TypeSpecResolverTest, MessageTypeSpecNotFoundError) { @@ -231,6 +256,8 @@ TEST(TypeSpecResolverTest, EnumTypeConversion) { ASSERT_THAT(t, IsOk()); EXPECT_TRUE(t->IsEnum()); EXPECT_EQ(t->name(), "cel.expr.conformance.proto3.TestAllTypes.NestedEnum"); + ASSERT_OK_AND_ASSIGN(auto spec2, ConvertTypeToTypeSpec(*t)); + EXPECT_EQ(spec2, spec); } TEST(TypeSpecResolverTest, EnumTypeWithParamsError) {