diff --git a/common/BUILD b/common/BUILD index a016d2cb5..01710329b 100644 --- a/common/BUILD +++ b/common/BUILD @@ -53,6 +53,7 @@ cc_library( deps = [ ":ast", ":type", + ":type_kind", "//internal:status_macros", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", diff --git a/common/ast/metadata.h b/common/ast/metadata.h index 197790ff3..1a69b5b50 100644 --- a/common/ast/metadata.h +++ b/common/ast/metadata.h @@ -573,6 +573,10 @@ class TypeSpec { TypeSpecKind& mutable_type_kind() { return type_kind_; } + bool is_specified() const { + return !absl::holds_alternative(type_kind_); + } + bool has_dyn() const { return absl::holds_alternative(type_kind_); } diff --git a/common/internal/signature.cc b/common/internal/signature.cc index 5c75225f9..5f4af2b14 100644 --- a/common/internal/signature.cc +++ b/common/internal/signature.cc @@ -64,125 +64,144 @@ void AppendEscaped(std::string* result, std::string_view str, bool escape_dot) { } } -absl::Status AppendTypeParameters(std::string* result, const Type& type); +absl::Status AppendTypeDesc(std::string* result, const TypeSpec& type_spec); -// Recursively appends a string representation of the given `type` to `result`. -// Type parameters are enclosed in angle brackets and separated by commas. -// -// Grammar: -// TypeDesc = NamespaceIdentifier [ "<" TypeList ">" ] ; -// NamespaceIdentifier = [ "." ] Identifier { "." Identifier } ; -// TypeList = TypeElem { "," TypeElem } ; -// TypeElem = TypeDesc | TypeParam -// TypeParam = "~" Alpha ; -// Identifier = ( Alpha | "_" ) { AlphaNumeric | "_" } ; -// (* Terminals *) -// Alpha = "a"..."z" | "A"..."Z" ; -// Digit = "0"..."9" ; -// AlphaNumeric = Alpha | Digit ; -// -// For compatibility, the implementation allows unexpected characters in -// type names and parameters and escapes them with a backslash. -absl::Status AppendTypeDesc(std::string* result, const Type& type) { - switch (type.kind()) { - case TypeKind::kNull: - absl::StrAppend(result, "null"); - break; - case TypeKind::kBool: - absl::StrAppend(result, "bool"); - break; - case TypeKind::kInt: - absl::StrAppend(result, "int"); - break; - case TypeKind::kUint: - absl::StrAppend(result, "uint"); - break; - case TypeKind::kDouble: - absl::StrAppend(result, "double"); - break; - case TypeKind::kString: - absl::StrAppend(result, "string"); - break; - case TypeKind::kBytes: - absl::StrAppend(result, "bytes"); - break; - case TypeKind::kDuration: - absl::StrAppend(result, "duration"); - break; - case TypeKind::kTimestamp: - absl::StrAppend(result, "timestamp"); - break; - case TypeKind::kAny: - absl::StrAppend(result, "any"); - break; - case TypeKind::kDyn: - absl::StrAppend(result, "dyn"); - break; - case TypeKind::kBoolWrapper: - absl::StrAppend(result, "bool_wrapper"); - break; - case TypeKind::kIntWrapper: - absl::StrAppend(result, "int_wrapper"); - break; - case TypeKind::kUintWrapper: - absl::StrAppend(result, "uint_wrapper"); - break; - case TypeKind::kDoubleWrapper: - absl::StrAppend(result, "double_wrapper"); - break; - case TypeKind::kStringWrapper: - absl::StrAppend(result, "string_wrapper"); - break; - case TypeKind::kBytesWrapper: - absl::StrAppend(result, "bytes_wrapper"); - break; - case TypeKind::kList: - absl::StrAppend(result, "list"); - CEL_RETURN_IF_ERROR(AppendTypeParameters(result, type)); - break; - case TypeKind::kMap: - absl::StrAppend(result, "map"); - CEL_RETURN_IF_ERROR(AppendTypeParameters(result, type)); - break; - case TypeKind::kFunction: - absl::StrAppend(result, "function"); - CEL_RETURN_IF_ERROR(AppendTypeParameters(result, type)); - break; - case TypeKind::kType: - absl::StrAppend(result, "type"); - CEL_RETURN_IF_ERROR(AppendTypeParameters(result, type)); - break; - case TypeKind::kTypeParam: - absl::StrAppend(result, "~"); - AppendEscaped(result, type.GetTypeParam().name(), /*escape_dot=*/true); - break; - case TypeKind::kOpaque: - AppendEscaped(result, type.name(), /*escape_dot=*/false); - CEL_RETURN_IF_ERROR(AppendTypeParameters(result, type)); - break; - case TypeKind::kStruct: - AppendEscaped(result, type.name(), /*escape_dot=*/false); - CEL_RETURN_IF_ERROR(AppendTypeParameters(result, type)); - break; - default: - return absl::InvalidArgumentError( - absl::StrFormat("Type kind: %s is not supported in CEL declarations", - type.DebugString())); +absl::Status AppendTypeSpecList(std::string* result, + const std::vector& params) { + if (!params.empty()) { + result->push_back('<'); + for (size_t i = 0; i < params.size(); ++i) { + CEL_RETURN_IF_ERROR(AppendTypeDesc(result, params[i])); + if (i < params.size() - 1) { + result->push_back(','); + } + } + result->push_back('>'); } return absl::OkStatus(); } -absl::Status AppendTypeParameters(std::string* result, const Type& type) { - const auto& parameters = type.GetParameters(); - if (!parameters.empty()) { - result->push_back('<'); - for (size_t i = 0; i < parameters.size(); ++i) { - CEL_RETURN_IF_ERROR(AppendTypeDesc(result, parameters[i])); - if (i < parameters.size() - 1) { +absl::Status AppendTypeDesc(std::string* result, const TypeSpec& type_spec) { + if (type_spec.has_null()) { + absl::StrAppend(result, "null"); + } else if (type_spec.has_dyn()) { + absl::StrAppend(result, "dyn"); + } else if (type_spec.has_primitive()) { + switch (type_spec.primitive()) { + case PrimitiveType::kBool: + absl::StrAppend(result, "bool"); + break; + case PrimitiveType::kInt64: + absl::StrAppend(result, "int"); + break; + case PrimitiveType::kUint64: + absl::StrAppend(result, "uint"); + break; + case PrimitiveType::kDouble: + absl::StrAppend(result, "double"); + break; + case PrimitiveType::kString: + absl::StrAppend(result, "string"); + break; + case PrimitiveType::kBytes: + absl::StrAppend(result, "bytes"); + break; + default: + return absl::InvalidArgumentError("Unsupported primitive type"); + } + } else if (type_spec.has_well_known()) { + switch (type_spec.well_known()) { + case WellKnownTypeSpec::kAny: + absl::StrAppend(result, "any"); + break; + case WellKnownTypeSpec::kTimestamp: + absl::StrAppend(result, "timestamp"); + break; + case WellKnownTypeSpec::kDuration: + absl::StrAppend(result, "duration"); + break; + default: + return absl::InvalidArgumentError("Unsupported well-known type"); + } + } else if (type_spec.has_wrapper()) { + switch (type_spec.wrapper()) { + case PrimitiveType::kBool: + absl::StrAppend(result, "bool_wrapper"); + break; + case PrimitiveType::kInt64: + absl::StrAppend(result, "int_wrapper"); + break; + case PrimitiveType::kUint64: + absl::StrAppend(result, "uint_wrapper"); + break; + case PrimitiveType::kDouble: + absl::StrAppend(result, "double_wrapper"); + break; + case PrimitiveType::kString: + absl::StrAppend(result, "string_wrapper"); + break; + case PrimitiveType::kBytes: + absl::StrAppend(result, "bytes_wrapper"); + break; + default: + return absl::InvalidArgumentError("Unsupported wrapper type"); + } + } else if (type_spec.has_list_type()) { + absl::StrAppend(result, "list"); + if (type_spec.list_type().has_elem_type()) { + result->push_back('<'); + CEL_RETURN_IF_ERROR( + AppendTypeDesc(result, type_spec.list_type().elem_type())); + result->push_back('>'); + } + } else if (type_spec.has_map_type()) { + absl::StrAppend(result, "map"); + if (type_spec.map_type().has_key_type() && + type_spec.map_type().has_value_type()) { + result->push_back('<'); + CEL_RETURN_IF_ERROR( + AppendTypeDesc(result, type_spec.map_type().key_type())); + result->push_back(','); + CEL_RETURN_IF_ERROR( + AppendTypeDesc(result, type_spec.map_type().value_type())); + result->push_back('>'); + } + } else if (type_spec.has_function()) { + absl::StrAppend(result, "function"); + if (type_spec.function().has_result_type() || + !type_spec.function().arg_types().empty()) { + result->push_back('<'); + if (type_spec.function().has_result_type()) { + CEL_RETURN_IF_ERROR( + AppendTypeDesc(result, type_spec.function().result_type())); + } else { + absl::StrAppend(result, "dyn"); + } + for (const auto& arg : type_spec.function().arg_types()) { result->push_back(','); + CEL_RETURN_IF_ERROR(AppendTypeDesc(result, arg)); } + result->push_back('>'); } + } else if (type_spec.has_type()) { + absl::StrAppend(result, "type"); + result->push_back('<'); + CEL_RETURN_IF_ERROR(AppendTypeDesc(result, type_spec.type())); result->push_back('>'); + } else if (type_spec.has_type_param()) { + absl::StrAppend(result, "~"); + AppendEscaped(result, type_spec.type_param().type(), /*escape_dot=*/true); + } else if (type_spec.has_abstract_type()) { + AppendEscaped(result, type_spec.abstract_type().name(), + /*escape_dot=*/false); + CEL_RETURN_IF_ERROR(AppendTypeSpecList( + result, type_spec.abstract_type().parameter_types())); + } else if (type_spec.has_message_type()) { + AppendEscaped(result, type_spec.message_type().type(), + /*escape_dot=*/false); + } else { + return absl::InvalidArgumentError(absl::StrCat( + "Unsupported type in signature: ", FormatTypeSpec(type_spec))); } return absl::OkStatus(); } @@ -190,13 +209,32 @@ absl::Status AppendTypeParameters(std::string* result, const Type& type) { absl::StatusOr MakeTypeSignature(const Type& type) { std::string result; - CEL_RETURN_IF_ERROR(AppendTypeDesc(&result, type)); + CEL_ASSIGN_OR_RETURN(TypeSpec type_spec, ConvertTypeToTypeSpec(type)); + CEL_RETURN_IF_ERROR(AppendTypeDesc(&result, type_spec)); + return result; +} + +absl::StatusOr MakeTypeSpecSignature(const TypeSpec& type_spec) { + std::string result; + CEL_RETURN_IF_ERROR(AppendTypeDesc(&result, type_spec)); return result; } absl::StatusOr MakeOverloadSignature( std::string_view function_name, const std::vector& args, bool is_member) { + std::vector arg_type_specs; + arg_type_specs.reserve(args.size()); + for (const auto& arg : args) { + CEL_ASSIGN_OR_RETURN(TypeSpec type_spec, ConvertTypeToTypeSpec(arg)); + arg_type_specs.push_back(type_spec); + } + return MakeOverloadSignature(function_name, arg_type_specs, is_member); +} + +absl::StatusOr MakeOverloadSignature( + std::string_view function_name, const std::vector& args, + bool is_member) { std::string result; if (is_member) { if (!args.empty()) { @@ -589,10 +627,14 @@ absl::StatusOr ParseFunctionSignature( return out; } +absl::StatusOr ParseTypeSpec(std::string_view signature) { + std::string stripped_sig = StripUnescapedWhitespace(signature); + return ParseTypeSignature(stripped_sig); +} + absl::StatusOr ParseType(std::string_view signature, google::protobuf::Arena* arena, const google::protobuf::DescriptorPool& pool) { - std::string stripped_sig = StripUnescapedWhitespace(signature); - CEL_ASSIGN_OR_RETURN(auto type_spec, ParseTypeSignature(stripped_sig)); + CEL_ASSIGN_OR_RETURN(auto type_spec, ParseTypeSpec(signature)); return cel::ConvertTypeSpecToType(type_spec, arena, pool); } diff --git a/common/internal/signature.h b/common/internal/signature.h index 3fdba4b2e..79e963760 100644 --- a/common/internal/signature.h +++ b/common/internal/signature.h @@ -37,6 +37,16 @@ namespace cel::common_internal { // - `list>` absl::StatusOr MakeTypeSignature(const Type& type); +// Generates an signature for a `cel::TypeSpec`, which is a string +// representation of the type. +// +// Examples: +// +// - `int` +// - `list` +// - `list>` +absl::StatusOr MakeTypeSpecSignature(const TypeSpec& type_spec); + // Generates an identifier for a function overload based on the function name // and the types of the arguments. If `is_member` is true, the first argument // type is used as the receiver and is prepended to the function name, followed @@ -59,6 +69,15 @@ absl::StatusOr MakeOverloadSignature( std::string_view function_name, const std::vector& args, bool is_member); +// Generates an identifier for a function overload based on the function name +// and the types of the arguments. See above for more details. +absl::StatusOr MakeOverloadSignature( + std::string_view function_name, const std::vector& args, + bool is_member); + +// Parses a string type signature directly into a `cel::TypeSpec`. +absl::StatusOr ParseTypeSpec(std::string_view signature); + // Parses a string type signature directly into a `cel::Type`. absl::StatusOr ParseType(std::string_view signature, google::protobuf::Arena* arena, const google::protobuf::DescriptorPool& pool); diff --git a/common/internal/signature_test.cc b/common/internal/signature_test.cc index 765055f75..3399a8f78 100644 --- a/common/internal/signature_test.cc +++ b/common/internal/signature_test.cc @@ -14,6 +14,7 @@ // limitations under the License. #include +#include #include #include @@ -42,82 +43,9 @@ google::protobuf::Arena* GetTestArena() { return &*arena; } -void VerifyParsedMatchesType(const TypeSpec& parsed, const Type& original) { - switch (original.kind()) { - case TypeKind::kDyn: - EXPECT_TRUE(parsed.has_dyn()); - break; - case TypeKind::kNull: - EXPECT_TRUE(parsed.has_null()); - break; - case TypeKind::kBool: - EXPECT_EQ(parsed.primitive(), PrimitiveType::kBool); - break; - case TypeKind::kInt: - EXPECT_EQ(parsed.primitive(), PrimitiveType::kInt64); - break; - case TypeKind::kUint: - EXPECT_EQ(parsed.primitive(), PrimitiveType::kUint64); - break; - case TypeKind::kDouble: - EXPECT_EQ(parsed.primitive(), PrimitiveType::kDouble); - break; - case TypeKind::kString: - EXPECT_EQ(parsed.primitive(), PrimitiveType::kString); - break; - case TypeKind::kBytes: - EXPECT_EQ(parsed.primitive(), PrimitiveType::kBytes); - break; - case TypeKind::kAny: - EXPECT_EQ(parsed.well_known(), WellKnownTypeSpec::kAny); - break; - case TypeKind::kTimestamp: - EXPECT_EQ(parsed.well_known(), WellKnownTypeSpec::kTimestamp); - break; - case TypeKind::kDuration: - EXPECT_EQ(parsed.well_known(), WellKnownTypeSpec::kDuration); - break; - case TypeKind::kList: - EXPECT_TRUE(parsed.has_list_type()); - if (!original.GetParameters().empty()) { - VerifyParsedMatchesType(parsed.list_type().elem_type(), - original.GetParameters()[0]); - } - break; - case TypeKind::kMap: - EXPECT_TRUE(parsed.has_map_type()); - if (!original.GetParameters().empty()) { - VerifyParsedMatchesType(parsed.map_type().key_type(), - original.GetParameters()[0]); - } - if (original.GetParameters().size() > 1) { - VerifyParsedMatchesType(parsed.map_type().value_type(), - original.GetParameters()[1]); - } - break; - case TypeKind::kBoolWrapper: - case TypeKind::kIntWrapper: - case TypeKind::kUintWrapper: - case TypeKind::kDoubleWrapper: - case TypeKind::kStringWrapper: - case TypeKind::kBytesWrapper: - EXPECT_TRUE(parsed.has_wrapper()); - break; - case TypeKind::kType: - EXPECT_TRUE(parsed.has_type()); - if (!original.GetParameters().empty()) { - VerifyParsedMatchesType(parsed.type(), original.GetParameters()[0]); - } - break; - case TypeKind::kTypeParam: - EXPECT_TRUE(parsed.has_type_param()); - break; - default: - EXPECT_TRUE(parsed.has_abstract_type()); - break; - } +void VerifyParsedMatchesType(const TypeSpec& parsed, const TypeSpec& expected) { + EXPECT_EQ(parsed, expected); } - void VerifyTypesEqual(const Type& lhs, const Type& rhs) { EXPECT_EQ(lhs.kind(), rhs.kind()); if (lhs.kind() != rhs.kind()) return; @@ -238,11 +166,11 @@ std::vector GetTypeSignatureTestCases() { TEST(TypeSignatureTest, UnsupportedTypes) { EXPECT_THAT(common_internal::MakeTypeSignature(UnknownType{}), StatusIs(absl::StatusCode::kInvalidArgument, - HasSubstr("Type kind: *unknown* is not supported"))); + HasSubstr("Unsupported Type kind: *unknown*"))); EXPECT_THAT(common_internal::MakeTypeSignature(ErrorType{}), StatusIs(absl::StatusCode::kInvalidArgument, - HasSubstr("Type kind: *error* is not supported"))); + HasSubstr("Unsupported type in signature: *error*"))); } INSTANTIATE_TEST_SUITE_P(TypeIdTest, TypeSignatureTest, @@ -260,7 +188,7 @@ TEST_P(TypeSignatureTest, ParseTypeCheck) { struct OverloadSignatureTestCase { std::string function_name = "hello"; - std::vector args; + std::vector args; bool is_member = false; std::string expected_signature; std::string expected_error; @@ -285,98 +213,109 @@ TEST_P(OverloadSignatureTest, OverloadSignature) { std::vector GetOverloadSignatureTestCases() { return { { - .args = {StringType{}}, + .args = {TypeSpec(PrimitiveType::kString)}, .expected_signature = "hello(string)", }, { - .args = {IntType{}, UintType{}}, + .args = {TypeSpec(PrimitiveType::kInt64), + TypeSpec(PrimitiveType::kUint64)}, .expected_signature = "hello(int,uint)", }, { - .args = {ListType(GetTestArena(), StringType{})}, + .args = {TypeSpec(ListTypeSpec( + std::make_unique(PrimitiveType::kString)))}, .expected_signature = "hello(list)", }, { - .args = {ListType(GetTestArena(), TypeParamType("A"))}, + .args = {TypeSpec( + ListTypeSpec(std::make_unique(ParamTypeSpec("A"))))}, .expected_signature = "hello(list<~A>)", }, { - .args = {MapType(GetTestArena(), IntType{}, DynType{})}, + .args = {TypeSpec( + MapTypeSpec(std::make_unique(PrimitiveType::kInt64), + std::make_unique(DynTypeSpec())))}, .expected_signature = "hello(map)", }, { - .args = {MapType(GetTestArena(), TypeParamType("B"), - TypeParamType("C"))}, + .args = {TypeSpec( + MapTypeSpec(std::make_unique(ParamTypeSpec("B")), + std::make_unique(ParamTypeSpec("C"))))}, .expected_signature = "hello(map<~B,~C>)", }, { - .args = {OpaqueType( - GetTestArena(), "bar", - {FunctionType(GetTestArena(), TypeParamType("D"), {})})}, + .args = {TypeSpec(AbstractType( + "bar", + {TypeSpec(FunctionTypeSpec( + std::make_unique(ParamTypeSpec("D")), {}))}))}, .expected_signature = "hello(bar>)", }, { - .args = {AnyType{}}, + .args = {TypeSpec(WellKnownTypeSpec::kAny)}, .expected_signature = "hello(any)", }, { - .args = {DurationType{}}, + .args = {TypeSpec(WellKnownTypeSpec::kDuration)}, .expected_signature = "hello(duration)", }, { - .args = {TimestampType{}}, + .args = {TypeSpec(WellKnownTypeSpec::kTimestamp)}, .expected_signature = "hello(timestamp)", }, { - .args = {BoolWrapperType{}}, + .args = {TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kBool))}, .expected_signature = "hello(bool_wrapper)", }, { - .args = {IntWrapperType{}}, + .args = {TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kInt64))}, .expected_signature = "hello(int_wrapper)", }, { - .args = {UintWrapperType{}}, + .args = {TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kUint64))}, .expected_signature = "hello(uint_wrapper)", }, { - .args = {MessageType( - GetTestingDescriptorPool()->FindMessageTypeByName( - "cel.expr.conformance.proto3.TestAllTypes"))}, + .args = {TypeSpec( + AbstractType("cel.expr.conformance.proto3.TestAllTypes", {}))}, .expected_signature = "hello(cel.expr.conformance.proto3.TestAllTypes)", }, { - .args = {StringType{}}, + .args = {TypeSpec(PrimitiveType::kString)}, .is_member = true, .expected_signature = "string.hello()", }, { - .args = {StringType{}, ListType(GetTestArena(), BoolType{})}, + .args = {TypeSpec(PrimitiveType::kString), + TypeSpec(ListTypeSpec( + std::make_unique(PrimitiveType::kBool)))}, .is_member = true, .expected_signature = "string.hello(list)", }, { - .args = {StringType{}, BoolType{}, DynType{}}, + .args = {TypeSpec(PrimitiveType::kString), + TypeSpec(PrimitiveType::kBool), TypeSpec(DynTypeSpec())}, .is_member = true, .expected_signature = "string.hello(bool,dyn)", }, { .function_name = "hello", - .args = {OpaqueType(GetTestArena(), "bar", - {TypeParamType("dummy.type")})}, + .args = {TypeSpec( + AbstractType("bar", {TypeSpec(ParamTypeSpec("dummy.type"))}))}, .is_member = true, .expected_signature = R"(bar<~dummy\.type>.hello())", }, { .function_name = "inspect", - .args = {Type(TypeType(GetTestArena(), StringType{}))}, + .args = {TypeSpec( + std::make_unique(PrimitiveType::kString))}, .expected_signature = "inspect(type)", }, { .function_name = R"(h.(e),l\o)", - .args = {StringType{}, - ListType(GetTestArena(), TypeParamType(R"(a,b..(d)\e)"))}, + .args = {TypeSpec(PrimitiveType::kString), + TypeSpec(ListTypeSpec(std::make_unique( + ParamTypeSpec(R"(a,b..(d)\e)"))))}, .is_member = true, .expected_signature = R"(string.h\.\(e\)\,l\\\o(list<~a\,b\.\\.\(d\)\\e>))", @@ -385,7 +324,8 @@ std::vector GetOverloadSignatureTestCases() { } TEST(OverloadSignatureTest, MemberFunctionNoReceiverError) { - auto signature = common_internal::MakeOverloadSignature("hello", {}, true); + auto signature = common_internal::MakeOverloadSignature( + "hello", std::vector{}, true); EXPECT_THAT(signature, StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Member function with no receiver"))); @@ -700,5 +640,19 @@ TEST(ParseSignatureTest, EmptyOrWhitespaceErrors) { HasSubstr("Empty type signature"))); } +TEST(OverloadSignatureTest, TypeArgumentArray) { + std::vector args; + args.push_back(Type(IntType())); + args.push_back(Type(StringType())); + args.push_back(Type(ListType(GetTestArena(), IntType()))); + args.push_back( + Type(MessageType(GetTestingDescriptorPool()->FindMessageTypeByName( + "cel.expr.conformance.proto3.TestAllTypes")))); + args.push_back(Type(OpaqueType(GetTestArena(), "Foo", {TypeParamType("T")}))); + ASSERT_OK_AND_ASSIGN(auto sig, MakeOverloadSignature("foo", args, false)); + EXPECT_EQ(sig, + "foo(int,string,list,cel.expr.conformance.proto3.TestAllTypes," + "Foo<~T>)"); +} } // namespace } // namespace cel::common_internal diff --git a/common/type_spec_resolver.cc b/common/type_spec_resolver.cc index 97451f390..98df66b77 100644 --- a/common/type_spec_resolver.cc +++ b/common/type_spec_resolver.cc @@ -14,6 +14,7 @@ #include "common/type_spec_resolver.h" +#include #include #include #include @@ -22,8 +23,12 @@ #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "common/ast.h" #include "common/type.h" +#include "common/type_kind.h" #include "internal/status_macros.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" namespace cel { @@ -179,4 +184,104 @@ absl::StatusOr ConvertTypeSpecToType(const TypeSpec& type_spec, return absl::InvalidArgumentError("Unknown TypeSpec kind"); } +absl::StatusOr ConvertTypeToTypeSpec(const Type& type) { + switch (type.kind()) { + case TypeKind::kNull: + return TypeSpec(NullTypeSpec{}); + case TypeKind::kDyn: + return TypeSpec(DynTypeSpec{}); + case TypeKind::kBool: + return TypeSpec(PrimitiveType::kBool); + case TypeKind::kInt: + return TypeSpec(PrimitiveType::kInt64); + case TypeKind::kUint: + return TypeSpec(PrimitiveType::kUint64); + case TypeKind::kDouble: + return TypeSpec(PrimitiveType::kDouble); + case TypeKind::kString: + return TypeSpec(PrimitiveType::kString); + case TypeKind::kBytes: + return TypeSpec(PrimitiveType::kBytes); + case TypeKind::kAny: + return TypeSpec(WellKnownTypeSpec::kAny); + case TypeKind::kTimestamp: + return TypeSpec(WellKnownTypeSpec::kTimestamp); + case TypeKind::kDuration: + return TypeSpec(WellKnownTypeSpec::kDuration); + case TypeKind::kBoolWrapper: + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kBool)); + case TypeKind::kIntWrapper: + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kInt64)); + case TypeKind::kUintWrapper: + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kUint64)); + case TypeKind::kDoubleWrapper: + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kDouble)); + case TypeKind::kStringWrapper: + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kString)); + case TypeKind::kBytesWrapper: + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kBytes)); + case TypeKind::kList: { + CEL_ASSIGN_OR_RETURN(auto elem_type, + ConvertTypeToTypeSpec(type.GetList().element())); + return TypeSpec( + ListTypeSpec(std::make_unique(std::move(elem_type)))); + } + case TypeKind::kMap: { + CEL_ASSIGN_OR_RETURN(auto key_type, + ConvertTypeToTypeSpec(type.GetMap().key())); + CEL_ASSIGN_OR_RETURN(auto value_type, + ConvertTypeToTypeSpec(type.GetMap().value())); + return TypeSpec( + MapTypeSpec(std::make_unique(std::move(key_type)), + std::make_unique(std::move(value_type)))); + } + case TypeKind::kFunction: { + auto func_type = type.GetFunction(); + CEL_ASSIGN_OR_RETURN(auto result_type, + ConvertTypeToTypeSpec(func_type.result())); + std::vector arg_types; + arg_types.reserve(func_type.args().size()); + for (const auto& arg : func_type.args()) { + CEL_ASSIGN_OR_RETURN(auto arg_type, ConvertTypeToTypeSpec(arg)); + arg_types.push_back(std::move(arg_type)); + } + return TypeSpec( + FunctionTypeSpec(std::make_unique(std::move(result_type)), + std::move(arg_types))); + } + case TypeKind::kTypeParam: + return TypeSpec(ParamTypeSpec(std::string(type.GetTypeParam().name()))); + case TypeKind::kStruct: { + if (type.IsMessage()) { + return TypeSpec(MessageTypeSpec(std::string(type.GetMessage().name()))); + } + return absl::InvalidArgumentError("Unsupported struct type"); + } + case TypeKind::kOpaque: { + auto opaque_type = type.GetOpaque(); + std::vector params; + params.reserve(opaque_type.GetParameters().size()); + for (const auto& param : opaque_type.GetParameters()) { + CEL_ASSIGN_OR_RETURN(auto param_type, ConvertTypeToTypeSpec(param)); + params.push_back(std::move(param_type)); + } + return TypeSpec( + AbstractType(std::string(opaque_type.name()), std::move(params))); + } + case TypeKind::kType: { + CEL_ASSIGN_OR_RETURN(auto nested_type, + ConvertTypeToTypeSpec(type.GetType().GetType())); + return TypeSpec(std::make_unique(std::move(nested_type))); + } + case TypeKind::kError: + return TypeSpec(ErrorTypeSpec::kValue); + case TypeKind::kEnum: + return TypeSpec( + AbstractType(std::string(type.GetEnum().name()), /*params=*/{})); + default: + return absl::InvalidArgumentError(absl::StrCat( + "Unsupported Type kind: ", TypeKindToString(type.kind()))); + } +} + } // namespace cel diff --git a/common/type_spec_resolver.h b/common/type_spec_resolver.h index 44e1e088f..edbfa3bde 100644 --- a/common/type_spec_resolver.h +++ b/common/type_spec_resolver.h @@ -32,6 +32,9 @@ absl::StatusOr ConvertTypeSpecToType(const TypeSpec& type_spec, google::protobuf::Arena* arena, const google::protobuf::DescriptorPool& pool); +// Resolves a `cel::Type` to a `cel::TypeSpec`. +absl::StatusOr ConvertTypeToTypeSpec(const Type& type); + } // namespace cel #endif // THIRD_PARTY_CEL_CPP_COMMON_TYPE_SPEC_RESOLVER_H_ diff --git a/common/type_spec_resolver_test.cc b/common/type_spec_resolver_test.cc index c7fbb2cf8..1cda7280f 100644 --- a/common/type_spec_resolver_test.cc +++ b/common/type_spec_resolver_test.cc @@ -23,6 +23,7 @@ #include "absl/base/no_destructor.h" #include "absl/status/status.h" #include "absl/status/status_matchers.h" +#include "common/ast.h" #include "common/type.h" #include "common/type_kind.h" #include "internal/testing.h" @@ -33,6 +34,7 @@ namespace cel { namespace { using ::absl_testing::IsOk; +using ::absl_testing::IsOkAndHolds; using ::absl_testing::StatusIs; using ::cel::internal::GetTestingDescriptorPool; using ::testing::HasSubstr; @@ -67,6 +69,7 @@ TEST_P(ConversionTest, TestTypeSpecConversion) { auto t, ConvertTypeSpecToType(std::get<0>(GetParam()), GetTestArena(), *GetTestingDescriptorPool())); EXPECT_EQ(t.kind(), std::get<1>(GetParam())); + EXPECT_THAT(ConvertTypeToTypeSpec(t), IsOkAndHolds(std::get<0>(GetParam()))); } INSTANTIATE_TEST_SUITE_P( @@ -104,6 +107,8 @@ TEST(TypeSpecResolverTest, ListTypeConversion) { ASSERT_THAT(t, IsOk()); EXPECT_TRUE(t->IsList()); EXPECT_TRUE(t->GetList().element().IsInt()); + ASSERT_OK_AND_ASSIGN(auto spec2, ConvertTypeToTypeSpec(*t)); + EXPECT_EQ(spec2, spec); } TEST(TypeSpecResolverTest, MapTypeConversion) { @@ -116,6 +121,8 @@ TEST(TypeSpecResolverTest, MapTypeConversion) { EXPECT_TRUE(t->IsMap()); EXPECT_TRUE(t->GetMap().key().IsString()); EXPECT_TRUE(t->GetMap().value().IsBytes()); + ASSERT_OK_AND_ASSIGN(auto spec2, ConvertTypeToTypeSpec(*t)); + EXPECT_EQ(spec2, spec); } TEST(TypeSpecResolverTest, FunctionTypeConversion) { @@ -129,6 +136,8 @@ TEST(TypeSpecResolverTest, FunctionTypeConversion) { EXPECT_TRUE(t->IsFunction()); EXPECT_EQ(t->GetFunction().args().size(), 1); EXPECT_TRUE(t->GetFunction().result().IsBool()); + ASSERT_OK_AND_ASSIGN(auto spec2, ConvertTypeToTypeSpec(*t)); + EXPECT_EQ(spec2, spec); } TEST(TypeSpecResolverTest, TypeParamConversion) { @@ -138,6 +147,8 @@ TEST(TypeSpecResolverTest, TypeParamConversion) { ASSERT_THAT(t, IsOk()); EXPECT_TRUE(t->IsTypeParam()); EXPECT_EQ(t->GetTypeParam().name(), "T"); + ASSERT_OK_AND_ASSIGN(auto spec2, ConvertTypeToTypeSpec(*t)); + EXPECT_EQ(spec2, spec); } TEST(TypeSpecResolverTest, MessageTypeConversion) { @@ -148,6 +159,10 @@ TEST(TypeSpecResolverTest, MessageTypeConversion) { ASSERT_THAT(t, IsOk()); EXPECT_TRUE(t->IsMessage()); EXPECT_EQ(t->name(), "cel.expr.conformance.proto3.TestAllTypes"); + ASSERT_OK_AND_ASSIGN(auto spec2, ConvertTypeToTypeSpec(*t)); + EXPECT_EQ( + spec2, + TypeSpec(MessageTypeSpec("cel.expr.conformance.proto3.TestAllTypes"))); } TEST(TypeSpecResolverTest, MessageTypeWithParamsError) { @@ -172,6 +187,8 @@ TEST(TypeSpecResolverTest, UnresolvedAbstractTypeFallbackToOpaque) { EXPECT_EQ(t->name(), "my.custom.OpaqueType"); EXPECT_EQ(t->GetParameters().size(), 1); EXPECT_TRUE(t->GetParameters()[0].IsInt()); + ASSERT_OK_AND_ASSIGN(auto spec2, ConvertTypeToTypeSpec(*t)); + EXPECT_EQ(spec2, spec); } TEST(TypeSpecResolverTest, OptionalType) { @@ -186,6 +203,8 @@ TEST(TypeSpecResolverTest, OptionalType) { EXPECT_EQ(t->GetParameters().size(), 1); EXPECT_TRUE(t->GetParameters()[0].IsInt()); EXPECT_TRUE(t->IsOptional()); + ASSERT_OK_AND_ASSIGN(auto spec2, ConvertTypeToTypeSpec(*t)); + EXPECT_EQ(spec2, spec); } TEST(TypeSpecResolverTest, TypeTypeConversion) { @@ -196,6 +215,8 @@ TEST(TypeSpecResolverTest, TypeTypeConversion) { ASSERT_THAT(t, IsOk()); EXPECT_TRUE(t->IsType()); EXPECT_TRUE(t->GetType().GetType().IsInt()); + ASSERT_OK_AND_ASSIGN(auto spec2, ConvertTypeToTypeSpec(*t)); + EXPECT_EQ(spec2, spec); } TEST(TypeSpecResolverTest, ErrorTypeConversion) { @@ -204,6 +225,8 @@ TEST(TypeSpecResolverTest, ErrorTypeConversion) { ConvertTypeSpecToType(spec, GetTestArena(), *GetTestingDescriptorPool()); ASSERT_THAT(t, IsOk()); EXPECT_TRUE(t->IsError()); + ASSERT_OK_AND_ASSIGN(auto spec2, ConvertTypeToTypeSpec(*t)); + EXPECT_EQ(spec2, spec); } TEST(TypeSpecResolverTest, MessageTypeSpecConversion) { @@ -213,6 +236,8 @@ TEST(TypeSpecResolverTest, MessageTypeSpecConversion) { ASSERT_THAT(t, IsOk()); EXPECT_TRUE(t->IsMessage()); EXPECT_EQ(t->name(), "cel.expr.conformance.proto3.TestAllTypes"); + ASSERT_OK_AND_ASSIGN(auto spec2, ConvertTypeToTypeSpec(*t)); + EXPECT_EQ(spec2, spec); } TEST(TypeSpecResolverTest, MessageTypeSpecNotFoundError) { @@ -231,6 +256,8 @@ TEST(TypeSpecResolverTest, EnumTypeConversion) { ASSERT_THAT(t, IsOk()); EXPECT_TRUE(t->IsEnum()); EXPECT_EQ(t->name(), "cel.expr.conformance.proto3.TestAllTypes.NestedEnum"); + ASSERT_OK_AND_ASSIGN(auto spec2, ConvertTypeToTypeSpec(*t)); + EXPECT_EQ(spec2, spec); } TEST(TypeSpecResolverTest, EnumTypeWithParamsError) { diff --git a/env/BUILD b/env/BUILD index 41ffc1723..f3a8988a4 100644 --- a/env/BUILD +++ b/env/BUILD @@ -28,6 +28,7 @@ cc_library( "type_info.h", ], deps = [ + "//common:ast", "//common:constant", "//common:type", "//common:type_kind", @@ -120,7 +121,9 @@ cc_library( features = ["-use_header_modules"], deps = [ ":config", + "//common:ast", "//common:constant", + "//common/internal:signature", "//internal:status_macros", "//internal:strings", "@com_google_absl//absl/algorithm:container", @@ -178,6 +181,7 @@ cc_test( ":config", "//common:type", "//common:type_proto", + "//common/ast:metadata", "//internal:proto_matchers", "//internal:testing", "//internal:testing_descriptor_pool", @@ -201,7 +205,6 @@ cc_test( "//common:type", "//common:value", "//compiler", - "//internal:proto_matchers", "//internal:status_macros", "//internal:testing", "//internal:testing_descriptor_pool", @@ -241,7 +244,6 @@ cc_test( "//common:value", "//compiler", "//extensions:math_ext", - "//internal:status_macros", "//internal:testing", "//internal:testing_descriptor_pool", "//runtime", diff --git a/env/env_yaml.cc b/env/env_yaml.cc index 159786598..0d7a5f1fc 100644 --- a/env/env_yaml.cc +++ b/env/env_yaml.cc @@ -35,8 +35,11 @@ #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "absl/time/time.h" +#include "common/ast.h" #include "common/constant.h" +#include "common/internal/signature.h" #include "env/config.h" +#include "env/type_info.h" #include "internal/status_macros.h" #include "internal/strings.h" #include "yaml-cpp/emitter.h" @@ -135,6 +138,18 @@ absl::StatusOr GetBool(absl::string_view yaml, absl::string_view key, } } +// Returns the key in the map `node` that has the given `value_node` as its +// value. If no such key exists, returns `value_node` itself. +YAML::Node GetContextNodeForKeyValue(const YAML::Node& node, + const YAML::Node& value_node) { + for (const auto& kv : node) { + if (kv.second.IsDefined() && kv.second.is(value_node)) { + return kv.first; + } + } + return value_node; +} + absl::Status ParseName(Config& config, absl::string_view yaml, const YAML::Node& root) { const YAML::Node name = root["name"]; @@ -407,7 +422,23 @@ absl::Status ParseStandardLibraryConfig(Config& config, absl::string_view yaml, absl::StatusOr ParseTypeInfo(const YAML::Node& node, absl::string_view yaml) { Config::TypeInfo type_config; + const YAML::Node type = node["type"]; const YAML::Node type_name = node["type_name"]; + if (type.IsDefined() && type_name.IsDefined()) { + return YamlError(yaml, GetContextNodeForKeyValue(node, type_name), + "Node 'type' and 'type_name' are mutually exclusive"); + } + + if (type.IsDefined()) { + if (!type.IsScalar()) { + return YamlError(yaml, type, "Node 'type' is not a string"); + } + CEL_ASSIGN_OR_RETURN(auto type_spec, + common_internal::ParseTypeSpec(GetString(yaml, type))); + CEL_ASSIGN_OR_RETURN(auto type_config, TypeSpecToTypeInfo(type_spec)); + return type_config; + } + if (!type_name.IsDefined()) { return type_config; } @@ -627,7 +658,8 @@ absl::Status ParseVariableConfigs(Config& config, absl::string_view yaml, } absl::StatusOr ParseFunctionOverloadConfig( - absl::string_view yaml, const YAML::Node& overload) { + absl::string_view yaml, const YAML::Node& overload, + absl::string_view function_name) { Config::FunctionOverloadConfig overload_config; if (!overload || !overload.IsMap()) { return YamlError(yaml, overload, "Function overload is not a map"); @@ -654,40 +686,89 @@ absl::StatusOr ParseFunctionOverloadConfig( } } + const YAML::Node signature_node = overload["signature"]; const YAML::Node target = overload["target"]; - if (target.IsDefined()) { - if (!target.IsMap()) { - return YamlError(yaml, target, "Function overload target is not a map"); + const YAML::Node args = overload["args"]; + if (signature_node.IsDefined()) { + if (!signature_node.IsScalar()) { + return YamlError(yaml, signature_node, + "Function overload signature is not a string"); } - CEL_ASSIGN_OR_RETURN(Config::TypeInfo type_info, - ParseTypeInfo(target, yaml)); - overload_config.is_member_function = true; - overload_config.parameters.push_back(type_info); - } - const YAML::Node args = overload["args"]; - if (args.IsDefined()) { - if (!args.IsSequence()) { - return YamlError(yaml, args, "Function overload args is not a sequence"); + if (target.IsDefined()) { + return YamlError(yaml, GetContextNodeForKeyValue(overload, target), + "Function overload signature and target are mutually " + "exclusive"); + } + if (args.IsDefined()) { + return YamlError(yaml, GetContextNodeForKeyValue(overload, args), + "Function overload signature and args are mutually " + "exclusive"); + } + + std::string signature = GetString(yaml, signature_node); + CEL_ASSIGN_OR_RETURN( + common_internal::ParsedFunctionOverload parsed_signature, + common_internal::ParseFunctionSignature(signature)); + if (parsed_signature.function_name != function_name) { + return YamlError(yaml, signature_node, + absl::StrCat("Function overload name \"", + parsed_signature.function_name, + "\" does not match function name \"", + function_name, "\"")); + } + overload_config.is_member_function = parsed_signature.is_member; + if (!parsed_signature.signature_type.has_function()) { + return absl::InternalError(absl::StrCat( + "Function overload signature has no function type: ", signature)); } - for (const YAML::Node& arg : args) { - if (!arg.IsMap()) { - return YamlError(yaml, arg, "Function overload arg is not a map"); + const FunctionTypeSpec& function_type_spec = + parsed_signature.signature_type.function(); + for (const auto& arg : function_type_spec.arg_types()) { + CEL_ASSIGN_OR_RETURN(auto type_info, TypeSpecToTypeInfo(arg)); + overload_config.parameters.push_back(std::move(type_info)); + } + } else { + if (target.IsDefined()) { + if (!target.IsMap()) { + return YamlError(yaml, target, "Function overload target is not a map"); } CEL_ASSIGN_OR_RETURN(Config::TypeInfo type_info, - ParseTypeInfo(arg, yaml)); + ParseTypeInfo(target, yaml)); + overload_config.is_member_function = true; overload_config.parameters.push_back(type_info); } - } + if (args.IsDefined()) { + if (!args.IsSequence()) { + return YamlError(yaml, args, + "Function overload args is not a sequence"); + } + for (const YAML::Node& arg : args) { + if (!arg.IsMap()) { + return YamlError(yaml, arg, "Function overload arg is not a map"); + } + CEL_ASSIGN_OR_RETURN(Config::TypeInfo type_info, + ParseTypeInfo(arg, yaml)); + overload_config.parameters.push_back(type_info); + } + } + } const YAML::Node return_type = overload["return"]; if (return_type.IsDefined()) { - if (!return_type.IsMap()) { - return YamlError(yaml, return_type, - "Function overload return type is not a map"); + if (return_type.IsScalar()) { + CEL_ASSIGN_OR_RETURN(auto type_spec, common_internal::ParseTypeSpec( + GetString(yaml, return_type))); + CEL_ASSIGN_OR_RETURN(overload_config.return_type, + TypeSpecToTypeInfo(type_spec)); + } else if (return_type.IsMap()) { + CEL_ASSIGN_OR_RETURN(overload_config.return_type, + ParseTypeInfo(return_type, yaml)); + } else { + return YamlError( + yaml, return_type, + "Function overload return type is neither a string nor a map"); } - CEL_ASSIGN_OR_RETURN(overload_config.return_type, - ParseTypeInfo(return_type, yaml)); } return overload_config; } @@ -728,8 +809,9 @@ absl::Status ParseFunctionConfigs(Config& config, absl::string_view yaml, } for (const YAML::Node& overload : overloads) { - CEL_ASSIGN_OR_RETURN(Config::FunctionOverloadConfig overload_config, - ParseFunctionOverloadConfig(yaml, overload)); + CEL_ASSIGN_OR_RETURN( + Config::FunctionOverloadConfig overload_config, + ParseFunctionOverloadConfig(yaml, overload, function_config.name)); function_config.overload_configs.push_back(std::move(overload_config)); } } @@ -893,26 +975,43 @@ void EmitStandardLibraryConfig(const Config& env_config, YAML::Emitter& out) { out << YAML::EndMap; } -void EmitTypeInfo(const Config::TypeInfo& type_info, YAML::Emitter& out) { +void EmitTypeInfo(const Config::TypeInfo& type_info, YAML::Emitter& out, + const EnvConfigToYamlOptions& options) { // Note: the map is already started when this is called, so we don't emit // BeginMap here or EndMap at the end. - out << YAML::Key << "type_name"; - out << YAML::Value << YAML::DoubleQuoted << type_info.name; - if (type_info.is_type_param) { - out << YAML::Key << "is_type_param" << YAML::Value << true; - } - if (!type_info.params.empty()) { - out << YAML::Key << "params" << YAML::Value << YAML::BeginSeq; - for (const Config::TypeInfo& param : type_info.params) { - out << YAML::BeginMap; - EmitTypeInfo(param, out); - out << YAML::EndMap; + bool signature_generated = false; + if (options.use_type_signatures) { + absl::StatusOr type_spec = TypeInfoToTypeSpec(type_info); + if (type_spec.ok()) { + absl::StatusOr signature = + common_internal::MakeTypeSpecSignature(*type_spec); + if (signature.ok()) { + out << YAML::Key << "type"; + out << YAML::Value << YAML::DoubleQuoted << *signature; + signature_generated = true; + } + } + } + if (!signature_generated) { + out << YAML::Key << "type_name"; + out << YAML::Value << YAML::DoubleQuoted << type_info.name; + if (type_info.is_type_param) { + out << YAML::Key << "is_type_param" << YAML::Value << true; + } + if (!type_info.params.empty()) { + out << YAML::Key << "params" << YAML::Value << YAML::BeginSeq; + for (const Config::TypeInfo& param : type_info.params) { + out << YAML::BeginMap; + EmitTypeInfo(param, out, options); + out << YAML::EndMap; + } + out << YAML::EndSeq; } - out << YAML::EndSeq; } } -void EmitVariableConfigs(const Config& env_config, YAML::Emitter& out) { +void EmitVariableConfigs(const Config& env_config, YAML::Emitter& out, + const EnvConfigToYamlOptions& options) { const auto& variable_configs = env_config.GetVariableConfigs(); if (variable_configs.empty()) { return; @@ -936,7 +1035,7 @@ void EmitVariableConfigs(const Config& env_config, YAML::Emitter& out) { out << YAML::Key << "description"; out << YAML::Value << YAML::DoubleQuoted << variable_config.description; } - EmitTypeInfo(variable_config.type_info, out); + EmitTypeInfo(variable_config.type_info, out, options); if (variable_config.value.has_value()) { const Constant& constant = variable_config.value; switch (constant.kind_case()) { @@ -991,51 +1090,97 @@ void EmitVariableConfigs(const Config& env_config, YAML::Emitter& out) { } void EmitFunctionOverloadConfig( - const Config::FunctionOverloadConfig& overload_config, YAML::Emitter& out) { + absl::string_view function_name, + const Config::FunctionOverloadConfig& overload_config, YAML::Emitter& out, + const EnvConfigToYamlOptions& options) { out << YAML::BeginMap; - out << YAML::Key << "id"; - out << YAML::Value << YAML::DoubleQuoted << overload_config.overload_id; - if (overload_config.is_member_function) { - out << YAML::Key << "target" << YAML::Value; - out << YAML::BeginMap; - if (overload_config.parameters.empty()) { - // This should never happen, but if it does, emit a dynamic type. - EmitTypeInfo({.name = "dyn"}, out); - } else { - EmitTypeInfo(overload_config.parameters[0], out); + if (!overload_config.overload_id.empty()) { + out << YAML::Key << "id"; + out << YAML::Value << YAML::DoubleQuoted << overload_config.overload_id; + } + bool signature_generated = false; + if (options.use_type_signatures) { + bool param_type_spec_generated = true; + std::vector params; + params.reserve(overload_config.parameters.size()); + for (const auto& parameter : overload_config.parameters) { + absl::StatusOr type_spec = TypeInfoToTypeSpec(parameter); + if (!type_spec.ok()) { + param_type_spec_generated = false; + break; + } + params.push_back(std::move(*type_spec)); } - out << YAML::EndMap; - if (overload_config.parameters.size() > 1) { - out << YAML::Key << "args"; - out << YAML::Value << YAML::BeginSeq; - for (size_t i = 1; i < overload_config.parameters.size(); ++i) { - out << YAML::BeginMap; - EmitTypeInfo(overload_config.parameters[i], out); - out << YAML::EndMap; + if (param_type_spec_generated) { + absl::StatusOr signature = + common_internal::MakeOverloadSignature( + function_name, params, overload_config.is_member_function); + if (signature.ok()) { + out << YAML::Key << "signature"; + out << YAML::Value << YAML::DoubleQuoted << *signature; + signature_generated = true; } - out << YAML::EndSeq; } - } else { - if (!overload_config.parameters.empty()) { - out << YAML::Key << "args"; - out << YAML::Value << YAML::BeginSeq; - for (const Config::TypeInfo& parameter : overload_config.parameters) { - out << YAML::BeginMap; - EmitTypeInfo(parameter, out); - out << YAML::EndMap; + } + if (!signature_generated) { + if (overload_config.is_member_function) { + out << YAML::Key << "target" << YAML::Value; + out << YAML::BeginMap; + if (overload_config.parameters.empty()) { + // This should never happen, but if it does, emit a dynamic type. + EmitTypeInfo({.name = "dyn"}, out, options); + } else { + EmitTypeInfo(overload_config.parameters[0], out, options); + } + out << YAML::EndMap; + if (overload_config.parameters.size() > 1) { + out << YAML::Key << "args"; + out << YAML::Value << YAML::BeginSeq; + for (size_t i = 1; i < overload_config.parameters.size(); ++i) { + out << YAML::BeginMap; + EmitTypeInfo(overload_config.parameters[i], out, options); + out << YAML::EndMap; + } + out << YAML::EndSeq; + } + } else { + if (!overload_config.parameters.empty()) { + out << YAML::Key << "args"; + out << YAML::Value << YAML::BeginSeq; + for (const Config::TypeInfo& parameter : overload_config.parameters) { + out << YAML::BeginMap; + EmitTypeInfo(parameter, out, options); + out << YAML::EndMap; + } + out << YAML::EndSeq; } - out << YAML::EndSeq; } } - out << YAML::Key << "return"; - out << YAML::Value << YAML::BeginMap; - EmitTypeInfo(overload_config.return_type, out); - out << YAML::EndMap; - + bool return_type_signature_generated = false; + if (options.use_type_signatures) { + absl::StatusOr type_spec = + TypeInfoToTypeSpec(overload_config.return_type); + if (type_spec.ok()) { + absl::StatusOr signature = + common_internal::MakeTypeSpecSignature(*type_spec); + if (signature.ok()) { + out << YAML::Key << "return"; + out << YAML::Value << YAML::DoubleQuoted << *signature; + return_type_signature_generated = true; + } + } + } + if (!return_type_signature_generated) { + out << YAML::Key << "return"; + out << YAML::Value << YAML::BeginMap; + EmitTypeInfo(overload_config.return_type, out, options); + out << YAML::EndMap; + } out << YAML::EndMap; } -void EmitFunctionConfigs(const Config& env_config, YAML::Emitter& out) { +void EmitFunctionConfigs(const Config& env_config, YAML::Emitter& out, + const EnvConfigToYamlOptions& options) { const std::vector& function_configs = env_config.GetFunctionConfigs(); if (function_configs.empty()) { @@ -1085,7 +1230,8 @@ void EmitFunctionConfigs(const Config& env_config, YAML::Emitter& out) { out << YAML::Key << "overloads" << YAML::Value << YAML::BeginSeq; for (const Config::FunctionOverloadConfig& overload_config : sorted_overloads) { - EmitFunctionOverloadConfig(overload_config, out); + EmitFunctionOverloadConfig(function_config.name, overload_config, out, + options); } out << YAML::EndSeq; } @@ -1116,7 +1262,8 @@ absl::StatusOr EnvConfigFromYaml(const std::string& yaml) { return config; } -void EnvConfigToYaml(const Config& env_config, std::ostream& os) { +void EnvConfigToYaml(const Config& env_config, std::ostream& os, + const EnvConfigToYamlOptions& options) { YAML::Emitter out(os); out.SetIndent(2); out << YAML::BeginMap; @@ -1127,8 +1274,8 @@ void EnvConfigToYaml(const Config& env_config, std::ostream& os) { EmitContainerConfig(env_config, out); EmitExtensionConfigs(env_config, out); EmitStandardLibraryConfig(env_config, out); - EmitVariableConfigs(env_config, out); - EmitFunctionConfigs(env_config, out); + EmitVariableConfigs(env_config, out, options); + EmitFunctionConfigs(env_config, out, options); out << YAML::EndMap; } diff --git a/env/env_yaml.h b/env/env_yaml.h index c96b45933..b3c4c6210 100644 --- a/env/env_yaml.h +++ b/env/env_yaml.h @@ -31,8 +31,42 @@ namespace cel { // expensive expressions. absl::StatusOr EnvConfigFromYaml(const std::string& yaml); +struct EnvConfigToYamlOptions { + // Whether to use type and overload signatures instead of arg/return types in + // the output YAML. + // Example of type signature: "map>" vs + // type_name: "map" + // params: + // - type_name: "int" + // - type_name: "A" + // params: + // - type_name: "B" + // is_type_param: true + // + // Example of overload signature config: + // name: "foo" + // overloads: + // - signature: "timestamp.foo(A<~B>)" + // return: "int" + // vs + // name: "foo" + // overloads: + // - id: "foo_id" + // target: + // type_name: "timestamp" + // args: + // - type_name: "A" + // params: + // - type_name: "B" + // is_type_param: true + // return: + // type_name: "int" + bool use_type_signatures = true; +}; + // EnvConfigToYaml serializes an environment configuration as a YAML string. -void EnvConfigToYaml(const Config& env_config, std::ostream& os); +void EnvConfigToYaml(const Config& env_config, std::ostream& os, + const EnvConfigToYamlOptions& options = {}); } // namespace cel diff --git a/env/env_yaml_test.cc b/env/env_yaml_test.cc index d19c0dbfb..a21649b2e 100644 --- a/env/env_yaml_test.cc +++ b/env/env_yaml_test.cc @@ -195,6 +195,28 @@ TEST(EnvYamlTest, ParseVariableConfigs) { } TEST(EnvYamlTest, ParseVariableConfigWithTypeParams) { + ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(R"yaml( + variables: + - name: "dict" + type: "map" + )yaml")); + + const Config::VariableConfig& variable_config = + config.GetVariableConfigs()[0]; + EXPECT_EQ(variable_config.name, "dict"); + const auto& type_info = variable_config.type_info; + EXPECT_EQ(type_info.name, "map"); + EXPECT_FALSE(type_info.is_type_param); + EXPECT_THAT(type_info.params, SizeIs(2)); + EXPECT_EQ(type_info.params[0].name, "string"); + EXPECT_FALSE(type_info.params[0].is_type_param); + EXPECT_THAT(type_info.params[0].params, IsEmpty()); + EXPECT_EQ(type_info.params[1].name, "A"); + EXPECT_TRUE(type_info.params[1].is_type_param); + EXPECT_THAT(type_info.params[1].params, IsEmpty()); +} + +TEST(EnvYamlTest, ParseVariableConfigWithTypeParamsLegacySyntax) { ASSERT_OK_AND_ASSIGN(Config config, EnvConfigFromYaml(R"yaml( variables: - name: "dict" @@ -221,7 +243,7 @@ TEST(EnvYamlTest, ParseVariableConfigWithTypeParams) { } struct ParseConstantTestCase { - std::string type_name; + std::string type; std::string value; std::string expected_error; // Empty if no error. Constant expected_constant; @@ -236,10 +258,10 @@ TEST_P(EnvYamlParseConstantTest, EnvYamlParseConstant) { R"yaml( variables: - name: "const" - type_name: "%s" + type: "%s" value: %s )yaml", - param.type_name, param.value); + param.type, param.value); absl::StatusOr status_or_config = EnvConfigFromYaml(yaml); if (!param.expected_error.empty()) { EXPECT_THAT(status_or_config, StatusIs(absl::StatusCode::kInvalidArgument, @@ -251,8 +273,7 @@ TEST_P(EnvYamlParseConstantTest, EnvYamlParseConstant) { const Config::VariableConfig& variable_config = config.GetVariableConfigs()[0]; EXPECT_EQ(variable_config.name, "const"); - EXPECT_EQ(variable_config.type_info.name, param.type_name) - << " yaml: " << yaml; + EXPECT_EQ(variable_config.type_info.name, param.type) << " yaml: " << yaml; EXPECT_EQ(variable_config.value, param.expected_constant) << " yaml: " << yaml; } @@ -260,119 +281,119 @@ TEST_P(EnvYamlParseConstantTest, EnvYamlParseConstant) { std::vector GetParseConstantTestCases() { return { ParseConstantTestCase{ - .type_name = "null", + .type = "null", .value = "\"\"", .expected_constant = Constant(nullptr), }, ParseConstantTestCase{ - .type_name = "null", + .type = "null", .value = "anything", .expected_error = "Failed to parse null constant", }, ParseConstantTestCase{ - .type_name = "bool", + .type = "bool", .value = "TRUE", .expected_constant = Constant(true), }, ParseConstantTestCase{ - .type_name = "bool", + .type = "bool", .value = "false", .expected_constant = Constant(false), }, ParseConstantTestCase{ - .type_name = "bool", + .type = "bool", .value = "yes", .expected_error = "Failed to parse bool constant", }, ParseConstantTestCase{ - .type_name = "int", + .type = "int", .value = "42", .expected_constant = Constant(int64_t{42}), }, ParseConstantTestCase{ - .type_name = "int", + .type = "int", .value = "41.999", .expected_error = "Failed to parse int constant", }, ParseConstantTestCase{ - .type_name = "uint", + .type = "uint", .value = "42", .expected_constant = Constant(uint64_t{42}), }, ParseConstantTestCase{ - .type_name = "uint", + .type = "uint", .value = "42u", .expected_constant = Constant(uint64_t{42}), }, ParseConstantTestCase{ - .type_name = "uint", + .type = "uint", .value = "-1", .expected_error = "Failed to parse uint constant", }, ParseConstantTestCase{ - .type_name = "double", + .type = "double", .value = "42.42", .expected_constant = Constant(42.42), }, ParseConstantTestCase{ - .type_name = "double", + .type = "double", .value = "abc", .expected_error = "Failed to parse double constant", }, ParseConstantTestCase{ - .type_name = "bytes", + .type = "bytes", .value = "abc", .expected_constant = Constant(BytesConstant("abc")), }, ParseConstantTestCase{ - .type_name = "bytes", + .type = "bytes", .value = "b\"\\xFF\\x00\\x01\"", .expected_constant = Constant(BytesConstant(absl::string_view("\xff\x00\x01", 3))), }, ParseConstantTestCase{ - .type_name = "bytes", + .type = "bytes", .value = "!!binary /wAB", .expected_constant = Constant(BytesConstant(absl::string_view("\xff\x00\x01", 3))), }, ParseConstantTestCase{ - .type_name = "bytes", + .type = "bytes", .value = "!!binary YWJj=", .expected_error = "Node 'YWJj=' is not a valid Base64 encoded binary", }, ParseConstantTestCase{ - .type_name = "bytes", + .type = "bytes", .value = "abc", .expected_constant = Constant(BytesConstant("abc")), }, ParseConstantTestCase{ - .type_name = "string", + .type = "string", .value = "abc", .expected_constant = Constant(StringConstant("abc")), }, ParseConstantTestCase{ - .type_name = "string", + .type = "string", .value = "\"\\\"abc\\\"\"", .expected_constant = Constant(StringConstant("\"abc\"")), }, ParseConstantTestCase{ - .type_name = "duration", + .type = "duration", .value = "1s", .expected_constant = Constant(absl::Seconds(1)), }, ParseConstantTestCase{ - .type_name = "duration", + .type = "duration", .value = "abc", .expected_error = "Failed to parse duration constant", }, ParseConstantTestCase{ - .type_name = "timestamp", + .type = "timestamp", .value = "2023-01-01T00:00:00Z", .expected_constant = Constant(absl::FromUnixSeconds(1672531200)), }, ParseConstantTestCase{ - .type_name = "timestamp", + .type = "timestamp", .value = "abc", .expected_error = "Failed to parse timestamp constant", }, @@ -439,6 +460,50 @@ TEST_P(EnvYamlParseFunctionTest, EnvYamlParseFunction) { std::vector GetParseFunctionTestCases() { return { + ParseFunctionTestCase{ + .yaml = R"yaml( + functions: + - name: "isEmpty" + description: |- + determines whether a list is empty, + or a string has no characters + overloads: + - signature: "google.protobuf.StringValue.isEmpty()" + examples: + - "''.isEmpty() // true" + return: "bool" + - signature: "list<~T>.isEmpty()" + examples: + - "[].isEmpty() // true" + - "[1].isEmpty() // false" + return: "bool" + )yaml", + .expected_function_config = + { + .name = "isEmpty", + .description = "determines whether a list is empty,\nor a " + "string has no characters", + .overload_configs = + { + Config::FunctionOverloadConfig{ + .examples = {"''.isEmpty() // true"}, + .is_member_function = true, + .parameters = {{.name = "string_wrapper"}}, + .return_type = {.name = "bool"}, + }, + Config::FunctionOverloadConfig{ + .examples = {"[].isEmpty() // true", + "[1].isEmpty() // false"}, + .is_member_function = true, + .parameters = {{.name = "list", + .params = {{.name = "T", + .is_type_param = + true}}}}, + .return_type = {.name = "bool"}, + }, + }, + }, + }, ParseFunctionTestCase{ .yaml = R"yaml( functions: @@ -495,6 +560,34 @@ std::vector GetParseFunctionTestCases() { }, }, }, + ParseFunctionTestCase{ + .yaml = R"yaml( + functions: + - name: "contains" + overloads: + - signature: "contains(list<~T>, ~T)" + examples: + - "contains([1, 2, 3], 2) // true" + return: "bool" + )yaml", + .expected_function_config = + { + .name = "contains", + .overload_configs = + { + Config::FunctionOverloadConfig{ + .examples = {"contains([1, 2, 3], 2) // true"}, + .is_member_function = false, + .parameters = + {{.name = "list", + .params = {{.name = "T", + .is_type_param = true}}}, + {.name = "T", .is_type_param = true}}, + .return_type = {.name = "bool"}, + }, + }, + }, + }, ParseFunctionTestCase{ .yaml = R"yaml( functions: @@ -865,6 +958,18 @@ INSTANTIATE_TEST_SUITE_P( "| is_type_param: maybe\n" "| ^", }, + ParseTestCase{ + .yaml = R"yaml( + variables: + - name: "foo" + type_name: "opaque" + type: "opaque" + )yaml", + .expected_error = "4:19: Node 'type' and 'type_name'" + " are mutually exclusive\n" + "| type_name: \"opaque\"\n" + "| ^", + }, ParseTestCase{ .yaml = R"yaml( variables: @@ -965,12 +1070,25 @@ INSTANTIATE_TEST_SUITE_P( - name: "foo" overloads: - id: "foo_int64" - return: "to sender" + return: [1] )yaml", .expected_error = "6:31: Function overload return type" - " is not a map\n" - "| return: \"to sender\"\n" + " is neither a string nor a map\n" + "| return: [1]\n" "| ^", + }, + ParseTestCase{ + .yaml = R"yaml( + functions: + - name: "foo" + overloads: + - id: "foo_int64" + signature: "bar()" + )yaml", + .expected_error = "6:34: Function overload name \"bar\" " + "does not match function name \"foo\"\n" + "| signature: \"bar()\"\n" + "| ^", })); std::string Unindent(std::string_view yaml) { @@ -999,6 +1117,7 @@ std::string Unindent(std::string_view yaml) { struct ExportTestCase { absl::StatusOr config; std::string expected_yaml; + std::string expected_alt_yaml; }; class EnvYamlExportTest : public testing::TestWithParam {}; @@ -1011,6 +1130,14 @@ TEST_P(EnvYamlExportTest, EnvYamlExport) { std::string yaml_output = Unindent(ss.str()); std::string expected_yaml = Unindent(param.expected_yaml); EXPECT_EQ(yaml_output, expected_yaml); + + if (!param.expected_alt_yaml.empty()) { + std::stringstream alt_ss; + EnvConfigToYaml(config, alt_ss, {.use_type_signatures = false}); + std::string alt_yaml_output = Unindent(alt_ss.str()); + std::string expected_alt_yaml = Unindent(param.expected_alt_yaml); + EXPECT_EQ(alt_yaml_output, expected_alt_yaml); + } } std::vector GetExportTestCases() { @@ -1211,7 +1338,7 @@ std::vector GetExportTestCases() { .expected_yaml = R"yaml( variables: - name: "foo" - type_name: "null" + type: "null" )yaml", }, ExportTestCase{ @@ -1224,6 +1351,12 @@ std::vector GetExportTestCases() { return config; }(), .expected_yaml = R"yaml( + variables: + - name: "foo" + type: "bool" + value: true + )yaml", + .expected_alt_yaml = R"yaml( variables: - name: "foo" type_name: "bool" @@ -1240,6 +1373,12 @@ std::vector GetExportTestCases() { return config; }(), .expected_yaml = R"yaml( + variables: + - name: "foo" + type: "int" + value: 42 + )yaml", + .expected_alt_yaml = R"yaml( variables: - name: "foo" type_name: "int" @@ -1258,7 +1397,7 @@ std::vector GetExportTestCases() { .expected_yaml = R"yaml( variables: - name: "foo" - type_name: "uint" + type: "uint" value: 777 )yaml", }, @@ -1274,7 +1413,7 @@ std::vector GetExportTestCases() { .expected_yaml = R"yaml( variables: - name: "foo" - type_name: "double" + type: "double" value: 0.75 )yaml", }, @@ -1291,7 +1430,7 @@ std::vector GetExportTestCases() { .expected_yaml = R"yaml( variables: - name: "foo" - type_name: "bytes" + type: "bytes" value: b"\xff\x00\x01" )yaml", }, @@ -1309,7 +1448,7 @@ std::vector GetExportTestCases() { .expected_yaml = R"yaml( variables: - name: "foo" - type_name: "string" + type: "string" value: "'single' \"double\"" )yaml", }, @@ -1324,6 +1463,12 @@ std::vector GetExportTestCases() { return config; }(), .expected_yaml = R"yaml( + variables: + - name: "foo" + type: "duration" + value: 1h2m3s + )yaml", + .expected_alt_yaml = R"yaml( variables: - name: "foo" type_name: "duration" @@ -1340,6 +1485,12 @@ std::vector GetExportTestCases() { return config; }(), .expected_yaml = R"yaml( + variables: + - name: "foo" + type: "timestamp" + value: 2026-01-02T03:04:05Z + )yaml", + .expected_alt_yaml = R"yaml( variables: - name: "foo" type_name: "timestamp" @@ -1358,7 +1509,7 @@ std::vector GetExportTestCases() { .expected_yaml = R"yaml( variables: - name: "foo" - type_name: "google.expr.proto3.test.TestAllTypes" + type: "google.expr.proto3.test.TestAllTypes" )yaml", }, ExportTestCase{ @@ -1373,6 +1524,11 @@ std::vector GetExportTestCases() { return config; }(), .expected_yaml = R"yaml( + variables: + - name: "foo" + type: "A" + )yaml", + .expected_alt_yaml = R"yaml( variables: - name: "foo" type_name: "A" @@ -1402,12 +1558,22 @@ std::vector GetExportTestCases() { {.overload_id = "foo_overload_id", .is_member_function = true, .parameters = {{.name = "timestamp"}, - {.name = "A", .params = {{.name = "B"}}}}, + {.name = "A", + .params = {{.name = "B", + .is_type_param = true}}}}, .return_type = {.name = "int"}}, }})); return config; }(), .expected_yaml = R"yaml( + functions: + - name: "foo" + overloads: + - id: "foo_overload_id" + signature: "timestamp.foo(A<~B>)" + return: "int" + )yaml", + .expected_alt_yaml = R"yaml( functions: - name: "foo" overloads: @@ -1418,6 +1584,7 @@ std::vector GetExportTestCases() { - type_name: "A" params: - type_name: "B" + is_type_param: true return: type_name: "int" )yaml", @@ -1440,6 +1607,17 @@ std::vector GetExportTestCases() { return config; }(), .expected_yaml = R"yaml( + functions: + - name: "foo" + overloads: + - id: "foo_overload_b" + signature: "foo(double,A)" + return: "string" + - id: "foo_overload_a" + signature: "foo(timestamp)" + return: "list" + )yaml", + .expected_alt_yaml = R"yaml( functions: - name: "foo" overloads: @@ -1539,30 +1717,30 @@ std::vector GetRoundTripTestCases() { R"yaml( variables: - name: "a" - type_name: "null" + type: "null" - name: "b" - type_name: "bool" + type: "bool" value: true - name: "c" - type_name: "int" + type: "int" value: 42 - name: "d" - type_name: "uint" + type: "uint" value: 777 - name: "e" - type_name: "double" + type: "double" value: 0.75 - name: "f" - type_name: "bytes" + type: "bytes" value: b"\xff\x00\x01" - name: "g" - type_name: "string" + type: "string" value: "plain 'single' \"double\"" - name: "h" - type_name: "duration" + type: "duration" value: 1h2m3s - name: "i" - type_name: "timestamp" + type: "timestamp" value: 2026-01-02T03:04:05Z )yaml", R"yaml( @@ -1575,29 +1753,16 @@ std::vector GetRoundTripTestCases() { - name: "foo" overloads: - id: "foo_overload_id" - target: - type_name: "timestamp" - args: - - type_name: "A" - params: - - type_name: "B" - return: - type_name: "int" + signature: "timestamp.foo(A<~B>)" + return: "int" )yaml", R"yaml( functions: - name: "foo" overloads: - id: "foo_overload_id" - args: - - type_name: "timestamp" - - type_name: "A" - params: - - type_name: "B" - return: - type_name: "list" - params: - - type_name: "int" + signature: "foo(timestamp,A<~B>)" + return: "list" )yaml", }; } diff --git a/env/type_info.cc b/env/type_info.cc index a5b47b6f1..7896a92e2 100644 --- a/env/type_info.cc +++ b/env/type_info.cc @@ -15,12 +15,15 @@ #include "env/type_info.h" #include +#include #include #include "absl/base/no_destructor.h" #include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "common/ast.h" #include "common/type.h" #include "common/type_kind.h" #include "env/config.h" @@ -180,5 +183,225 @@ absl::StatusOr TypeInfoToType( return DynType(); } } +absl::StatusOr TypeInfoToTypeSpec(const Config::TypeInfo& type_info) { + if (type_info.is_type_param) { + return TypeSpec(ParamTypeSpec(type_info.name)); + } + + std::optional type_kind = TypeNameToTypeKind(type_info.name); + if (!type_kind.has_value()) { + if (type_info.params.empty()) { + return TypeSpec(MessageTypeSpec(type_info.name)); + } else { + std::vector param_specs; + for (const Config::TypeInfo& param : type_info.params) { + CEL_ASSIGN_OR_RETURN(TypeSpec param_spec, TypeInfoToTypeSpec(param)); + param_specs.push_back(std::move(param_spec)); + } + return TypeSpec(AbstractType(type_info.name, std::move(param_specs))); + } + } + + switch (*type_kind) { + case TypeKind::kNull: + return TypeSpec(NullTypeSpec()); + case TypeKind::kBool: + return TypeSpec(PrimitiveType::kBool); + case TypeKind::kInt: + return TypeSpec(PrimitiveType::kInt64); + case TypeKind::kUint: + return TypeSpec(PrimitiveType::kUint64); + case TypeKind::kDouble: + return TypeSpec(PrimitiveType::kDouble); + case TypeKind::kString: + return TypeSpec(PrimitiveType::kString); + case TypeKind::kBytes: + return TypeSpec(PrimitiveType::kBytes); + case TypeKind::kTimestamp: + return TypeSpec(WellKnownTypeSpec::kTimestamp); + case TypeKind::kDuration: + return TypeSpec(WellKnownTypeSpec::kDuration); + case TypeKind::kList: { + if (!type_info.params.empty()) { + CEL_ASSIGN_OR_RETURN(TypeSpec elem_type, + TypeInfoToTypeSpec(type_info.params[0])); + return TypeSpec( + ListTypeSpec(std::make_unique(std::move(elem_type)))); + } else { + return TypeSpec( + ListTypeSpec(std::make_unique(DynTypeSpec()))); + } + } + case TypeKind::kMap: { + TypeSpec key_type(DynTypeSpec{}); + TypeSpec value_type(DynTypeSpec{}); + if (!type_info.params.empty()) { + CEL_ASSIGN_OR_RETURN(key_type, TypeInfoToTypeSpec(type_info.params[0])); + } + if (type_info.params.size() > 1) { + CEL_ASSIGN_OR_RETURN(value_type, + TypeInfoToTypeSpec(type_info.params[1])); + } + return TypeSpec( + MapTypeSpec(std::make_unique(std::move(key_type)), + std::make_unique(std::move(value_type)))); + } + case TypeKind::kDyn: + return TypeSpec(DynTypeSpec()); + case TypeKind::kAny: + return TypeSpec(WellKnownTypeSpec::kAny); + case TypeKind::kBoolWrapper: + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kBool)); + case TypeKind::kIntWrapper: + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kInt64)); + case TypeKind::kUintWrapper: + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kUint64)); + case TypeKind::kDoubleWrapper: + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kDouble)); + case TypeKind::kStringWrapper: + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kString)); + case TypeKind::kBytesWrapper: + return TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kBytes)); + case TypeKind::kType: { + if (type_info.params.empty()) { + return TypeSpec(std::make_unique(DynTypeSpec())); + } + CEL_ASSIGN_OR_RETURN(TypeSpec type_param, + TypeInfoToTypeSpec(type_info.params[0])); + return TypeSpec(std::make_unique(std::move(type_param))); + } + default: + return TypeSpec(DynTypeSpec()); + } +} + +absl::StatusOr TypeSpecToTypeInfo(const TypeSpec& type_spec) { + Config::TypeInfo type_info; + + if (type_spec.has_dyn()) { + type_info.name = "dyn"; + } else if (type_spec.has_null()) { + type_info.name = "null"; + } else if (type_spec.has_primitive()) { + switch (type_spec.primitive()) { + case PrimitiveType::kBool: + type_info.name = "bool"; + break; + case PrimitiveType::kInt64: + type_info.name = "int"; + break; + case PrimitiveType::kUint64: + type_info.name = "uint"; + break; + case PrimitiveType::kDouble: + type_info.name = "double"; + break; + case PrimitiveType::kString: + type_info.name = "string"; + break; + case PrimitiveType::kBytes: + type_info.name = "bytes"; + break; + default: + return absl::InvalidArgumentError("Unspecified primitive type"); + } + } else if (type_spec.has_wrapper()) { + switch (type_spec.wrapper()) { + case PrimitiveType::kBool: + type_info.name = "bool_wrapper"; + break; + case PrimitiveType::kInt64: + type_info.name = "int_wrapper"; + break; + case PrimitiveType::kUint64: + type_info.name = "uint_wrapper"; + break; + case PrimitiveType::kDouble: + type_info.name = "double_wrapper"; + break; + case PrimitiveType::kString: + type_info.name = "string_wrapper"; + break; + case PrimitiveType::kBytes: + type_info.name = "bytes_wrapper"; + break; + default: + return absl::InvalidArgumentError("Unspecified wrapper type"); + } + } else if (type_spec.has_well_known()) { + switch (type_spec.well_known()) { + case WellKnownTypeSpec::kAny: + type_info.name = "any"; + break; + case WellKnownTypeSpec::kTimestamp: + type_info.name = "timestamp"; + break; + case WellKnownTypeSpec::kDuration: + type_info.name = "duration"; + break; + default: + return absl::InvalidArgumentError("Unspecified well known type"); + } + } else if (type_spec.has_list_type()) { + type_info.name = "list"; + const ListTypeSpec& list_type = type_spec.list_type(); + if (list_type.has_elem_type() && list_type.elem_type().is_specified()) { + CEL_ASSIGN_OR_RETURN(Config::TypeInfo param, + TypeSpecToTypeInfo(list_type.elem_type())); + type_info.params.push_back(std::move(param)); + } + } else if (type_spec.has_map_type()) { + type_info.name = "map"; + const MapTypeSpec& map_type = type_spec.map_type(); + bool has_key = + map_type.has_key_type() && map_type.key_type().is_specified(); + bool has_value = + map_type.has_value_type() && map_type.value_type().is_specified(); + if (has_key || has_value) { + if (has_key) { + CEL_ASSIGN_OR_RETURN(Config::TypeInfo param, + TypeSpecToTypeInfo(map_type.key_type())); + type_info.params.push_back(std::move(param)); + } else { + type_info.params.push_back(Config::TypeInfo{.name = "dyn"}); + } + if (has_value) { + CEL_ASSIGN_OR_RETURN(Config::TypeInfo param_value, + TypeSpecToTypeInfo(map_type.value_type())); + type_info.params.push_back(std::move(param_value)); + } else { + type_info.params.push_back(Config::TypeInfo{.name = "dyn"}); + } + } + } else if (type_spec.has_message_type()) { + type_info.name = type_spec.message_type().type(); + } else if (type_spec.has_type_param()) { + type_info.name = type_spec.type_param().type(); + type_info.is_type_param = true; + } else if (type_spec.has_type()) { + type_info.name = "type"; + CEL_ASSIGN_OR_RETURN(Config::TypeInfo param, + TypeSpecToTypeInfo(type_spec.type())); + type_info.params.push_back(std::move(param)); + } else if (type_spec.has_abstract_type()) { + type_info.name = type_spec.abstract_type().name(); + for (const TypeSpec& param_spec : + type_spec.abstract_type().parameter_types()) { + CEL_ASSIGN_OR_RETURN(Config::TypeInfo param, + TypeSpecToTypeInfo(param_spec)); + type_info.params.push_back(std::move(param)); + } + } else if (type_spec.has_error()) { + return absl::InvalidArgumentError( + "ErrorType cannot be converted to TypeInfo"); + } else if (type_spec.has_function()) { + return absl::InvalidArgumentError( + "FunctionType cannot be converted to TypeInfo"); + } else { + return absl::InvalidArgumentError("Unknown TypeSpec kind"); + } + + return type_info; +} } // namespace cel diff --git a/env/type_info.h b/env/type_info.h index bb3cfde43..3f802ce1a 100644 --- a/env/type_info.h +++ b/env/type_info.h @@ -16,6 +16,7 @@ #define THIRD_PARTY_CEL_CPP_ENV_TYPE_INFO_H_ #include "absl/status/statusor.h" +#include "common/ast.h" #include "common/type.h" #include "env/config.h" #include "google/protobuf/arena.h" @@ -30,6 +31,12 @@ absl::StatusOr TypeInfoToType( const Config::TypeInfo& type_info, const google::protobuf::DescriptorPool* descriptor_pool, google::protobuf::Arena* arena); +// Converts a Config::TypeInfo to a cel::TypeSpec. +absl::StatusOr TypeInfoToTypeSpec(const Config::TypeInfo& type_info); + +// Converts a cel::TypeSpec to a Config::TypeInfo. +absl::StatusOr TypeSpecToTypeInfo(const TypeSpec& type_spec); + } // namespace cel #endif // THIRD_PARTY_CEL_CPP_ENV_TYPE_INFO_H_ diff --git a/env/type_info_test.cc b/env/type_info_test.cc index 015d8a928..b8db2e425 100644 --- a/env/type_info_test.cc +++ b/env/type_info_test.cc @@ -14,9 +14,12 @@ #include "env/type_info.h" +#include +#include #include #include +#include "common/ast/metadata.h" #include "common/type.h" #include "common/type_proto.h" #include "env/config.h" @@ -127,5 +130,103 @@ std::vector GetTestCases() { INSTANTIATE_TEST_SUITE_P(TypeInfoTest, TypeInfoTest, ValuesIn(GetTestCases())); +void ExpectTypeInfoEq(const Config::TypeInfo& actual, + const Config::TypeInfo& expected) { + EXPECT_EQ(actual.name, expected.name); + EXPECT_EQ(actual.is_type_param, expected.is_type_param); + ASSERT_EQ(actual.params.size(), expected.params.size()); + for (size_t i = 0; i < actual.params.size(); ++i) { + ExpectTypeInfoEq(actual.params[i], expected.params[i]); + } +} + +struct TypeSpecTestCase { + TypeSpec type_spec; + Config::TypeInfo expected_type_info; +}; + +using TypeSpecToTypeInfoTest = testing::TestWithParam; + +TEST_P(TypeSpecToTypeInfoTest, Convert) { + const TypeSpecTestCase& param = GetParam(); + ASSERT_OK_AND_ASSIGN(Config::TypeInfo actual_type_info, + TypeSpecToTypeInfo(param.type_spec)); + ExpectTypeInfoEq(actual_type_info, param.expected_type_info); +} + +std::vector GetTypeSpecTestCases() { + return { + TypeSpecTestCase{ + .type_spec = TypeSpec(PrimitiveType::kInt64), + .expected_type_info = {.name = "int"}, + }, + TypeSpecTestCase{ + .type_spec = TypeSpec( + ListTypeSpec(std::make_unique(PrimitiveType::kInt64))), + .expected_type_info = {.name = "list", + .params = {Config::TypeInfo{.name = "int"}}}, + }, + TypeSpecTestCase{ + .type_spec = TypeSpec(ListTypeSpec()), + .expected_type_info = {.name = "list"}, + }, + TypeSpecTestCase{ + .type_spec = TypeSpec( + MapTypeSpec(std::make_unique(PrimitiveType::kString), + std::make_unique(PrimitiveType::kInt64))), + .expected_type_info = {.name = "map", + .params = {Config::TypeInfo{.name = "string"}, + Config::TypeInfo{.name = "int"}}}, + }, + TypeSpecTestCase{ + .type_spec = TypeSpec(MapTypeSpec()), + .expected_type_info = {.name = "map"}, + }, + TypeSpecTestCase{ + .type_spec = TypeSpec( + MessageTypeSpec("cel.expr.conformance.proto2.TestAllTypes")), + .expected_type_info = + {.name = "cel.expr.conformance.proto2.TestAllTypes"}, + }, + TypeSpecTestCase{ + .type_spec = + TypeSpec(AbstractType("A", {TypeSpec(ParamTypeSpec("B"))})), + .expected_type_info = {.name = "A", + .params = {Config::TypeInfo{ + .name = "B", .is_type_param = true}}}, + }, + TypeSpecTestCase{ + .type_spec = TypeSpec(WellKnownTypeSpec::kAny), + .expected_type_info = {.name = "any"}, + }, + TypeSpecTestCase{ + .type_spec = TypeSpec(WellKnownTypeSpec::kTimestamp), + .expected_type_info = {.name = "timestamp"}, + }, + TypeSpecTestCase{ + .type_spec = TypeSpec(PrimitiveTypeWrapper(PrimitiveType::kDouble)), + .expected_type_info = {.name = "double_wrapper"}, + }, + TypeSpecTestCase{ + .type_spec = TypeSpec( + std::make_unique(WellKnownTypeSpec::kDuration)), + .expected_type_info = {.name = "type", + .params = {Config::TypeInfo{.name = + "duration"}}}, + }, + TypeSpecTestCase{ + .type_spec = TypeSpec(DynTypeSpec{}), + .expected_type_info = {.name = "dyn"}, + }, + TypeSpecTestCase{ + .type_spec = TypeSpec(NullTypeSpec{}), + .expected_type_info = {.name = "null"}, + }, + }; +} + +INSTANTIATE_TEST_SUITE_P(TypeSpecToTypeInfoTest, TypeSpecToTypeInfoTest, + ValuesIn(GetTypeSpecTestCases())); + } // namespace } // namespace cel