diff --git a/common/expr_factory.h b/common/expr_factory.h
index 5607d8deb..c1770716b 100644
--- a/common/expr_factory.h
+++ b/common/expr_factory.h
@@ -376,7 +376,7 @@ class ExprFactory {
     return expr;
   }
 
- private:
+ protected:
   friend class MacroExprFactory;
   friend class ParserMacroExprFactory;
   friend class OptimizerExprFactory;
diff --git a/tools/BUILD b/tools/BUILD
index ceb2befc5..7727470f1 100644
--- a/tools/BUILD
+++ b/tools/BUILD
@@ -204,6 +204,55 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "proto_to_predicate",
+    srcs = ["proto_to_predicate.cc"],
+    hdrs = ["proto_to_predicate.h"],
+    deps = [
+        "//common:ast",
+        "//common:expr",
+        "//common:expr_factory",
+        "//common:operators",
+        "//internal:status_macros",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:span",
+        "@com_google_protobuf//:protobuf",
+    ],
+)
+
+cc_test(
+    name = "proto_to_predicate_test",
+    srcs = ["proto_to_predicate_test.cc"],
+    deps = [
+        ":cel_unparser",
+        ":proto_to_predicate",
+        "//common:ast",
+        "//common:ast_proto",
+        "//common:value",
+        "//env:config",
+        "//env:env_runtime",
+        "//env:env_yaml",
+        "//env:runtime_std_extensions",
+        "//eval/testutil:test_message_cc_proto",
+        "//extensions/protobuf:value",
+        "//internal:status_macros",
+        "//internal:testing",
+        "//internal:testing_descriptor_pool",
+        "//parser",
+        "//runtime",
+        "//runtime:activation",
+        "//tools/testdata:policy_cc_proto",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:status_matchers",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings:string_view",
+        "@com_google_absl//absl/types:span",
+        "@com_google_protobuf//:protobuf",
+    ],
+)
+
 cc_test(
     name = "descriptor_pool_builder_test",
     srcs = ["descriptor_pool_builder_test.cc"],
diff --git a/tools/proto_to_predicate.cc b/tools/proto_to_predicate.cc
new file mode 100644
index 000000000..2f28baa80
--- /dev/null
+++ b/tools/proto_to_predicate.cc
@@ -0,0 +1,463 @@
+// Copyright 2026 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "tools/proto_to_predicate.h"
+
+#include <algorithm>
+#include <cstddef>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "google/protobuf/descriptor.pb.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_split.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "common/ast.h"
+#include "common/expr.h"
+#include "common/expr_factory.h"
+#include "common/operators.h"
+#include "internal/status_macros.h"
+#include "google/protobuf/descriptor.h"
+#include "google/protobuf/message.h"
+#include "google/protobuf/reflection.h"
+
+namespace cel::tools {
+namespace {
+
+using ::google::api::expr::common::CelOperator;
+using ::google::protobuf::FieldDescriptor;
+using ::google::protobuf::Message;
+using ::google::protobuf::Reflection;
+
+// Walks the set fields of one or more messages via reflection and assembles
+// the equivalent CEL predicate. Expression ids are handed out sequentially,
+// starting at 1. A builder is single-use: `Build` moves `source_info_` into
+// the returned AST.
+class PredicateBuilder : public ExprFactory {
+ public:
+  // `input_name` is the identifier every generated field path is rooted at.
+  // Only a view is stored, so the referenced characters must outlive the
+  // builder.
+  explicit PredicateBuilder(absl::string_view input_name)
+      : ExprFactory(), input_name_(input_name), id_(1) {}
+
+  // Builds the conjunction (`&&`) of the checks generated for `message`, or
+  // the constant `true` when the walk produces no checks.
+  absl::StatusOr<Ast> Build(const Message& message) {
+    std::vector<Expr> predicates;
+    Expr base_expr = NewIdent(NextId(), input_name_);
+
+    CEL_RETURN_IF_ERROR(Walk(message, base_expr, predicates));
+    Expr root = ConjoinPredicates(predicates);
+    return Ast(std::move(root), std::move(source_info_));
+  }
+
+  // Builds one conjunction per message and joins them with `||`. An empty
+  // span yields the constant `true`. Entries are dereferenced unconditionally
+  // and therefore must not be null.
+  absl::StatusOr<Ast> Build(absl::Span<const Message* const> messages) {
+    if (messages.empty()) {
+      return Ast(NewBoolConst(NextId(), true), std::move(source_info_));
+    }
+
+    std::vector<Expr> message_asts;
+    message_asts.reserve(messages.size());
+    for (const auto* message : messages) {
+      std::vector<Expr> predicates;
+      Expr base_expr = NewIdent(NextId(), input_name_);
+
+      CEL_RETURN_IF_ERROR(Walk(*message, base_expr, predicates));
+      message_asts.push_back(ConjoinPredicates(predicates));
+    }
+
+    return Ast(FoldBinaryOp(CelOperator::LOGICAL_OR, message_asts),
+               std::move(source_info_));
+  }
+
+ private:
+  // Joins `predicates` with `&&`; an empty vector becomes the constant `true`.
+  // The elements of `predicates` are moved from.
+  Expr ConjoinPredicates(std::vector<Expr>& predicates) {
+    if (predicates.empty()) {
+      return NewBoolConst(NextId(), true);
+    }
+    return FoldBinaryOp(CelOperator::LOGICAL_AND, predicates);
+  }
+
+  // Retrieves the "match_path" string option from the field options if
+  // defined, returning an empty string otherwise.
+  std::string GetMatchPath(const ::google::protobuf::FieldDescriptor* field) {
+    const ::google::protobuf::Message& options = field->options();
+    const ::google::protobuf::Reflection* refl = options.GetReflection();
+    std::vector<const ::google::protobuf::FieldDescriptor*> fields;
+    refl->ListFields(options, &fields);
+    for (const auto* f : fields) {
+      if (f->name() != "match_path") {
+        continue;
+      }
+      // Reflection::GetString() is only valid on singular string fields.
+      // Skip an option that merely shares the name instead of aborting on a
+      // type mismatch.
+      if (f->is_repeated() ||
+          f->cpp_type() != FieldDescriptor::CPPTYPE_STRING) {
+        continue;
+      }
+      return refl->GetString(options, f);
+    }
+    return "";
+  }
+
+  // Parses a dot-separated string representation of a path (e.g. "dest.region")
+  // and builds a corresponding select chain AST.
+  Expr ParseAndBuildPath(absl::string_view path_str) {
+    // StrSplit always yields at least one piece, so parts[0] is valid.
+    std::vector<std::string> parts = absl::StrSplit(path_str, '.');
+    Expr e = NewIdent(NextId(), parts[0]);
+    for (size_t i = 1; i < parts.size(); ++i) {
+      e = NewSelect(NextId(), std::move(e), parts[i]);
+    }
+    return e;
+  }
+
+  // Returns the next unused expression id.
+  ExprId NextId() { return id_++; }
+
+  // ---------------------------------------------------------------------------
+  // Field value extraction
+  // ---------------------------------------------------------------------------
+
+  // Converts a singular field value to a CEL constant expression.
+  Expr PrimitiveToExpr(ExprId expr_id, const Message& message,
+                       const Reflection* reflection,
+                       const FieldDescriptor* field) {
+    const auto kind = field->cpp_type();
+
+    // `string` and `bytes` fields share one C++ type; the declared proto type
+    // picks the CEL literal kind.
+    if (kind == FieldDescriptor::CPPTYPE_STRING) {
+      std::string payload = reflection->GetString(message, field);
+      if (field->type() != FieldDescriptor::TYPE_BYTES) {
+        return NewStringConst(expr_id, std::move(payload));
+      }
+      return NewBytesConst(expr_id, std::move(payload));
+    }
+
+    switch (kind) {
+      // Signed integers and enum numbers become CEL `int` literals.
+      case FieldDescriptor::CPPTYPE_INT32:
+        return NewIntConst(expr_id, reflection->GetInt32(message, field));
+      case FieldDescriptor::CPPTYPE_INT64:
+        return NewIntConst(expr_id, reflection->GetInt64(message, field));
+      case FieldDescriptor::CPPTYPE_ENUM:
+        return NewIntConst(expr_id, reflection->GetEnumValue(message, field));
+      // Unsigned integers become CEL `uint` literals.
+      case FieldDescriptor::CPPTYPE_UINT32:
+        return NewUintConst(expr_id, reflection->GetUInt32(message, field));
+      case FieldDescriptor::CPPTYPE_UINT64:
+        return NewUintConst(expr_id, reflection->GetUInt64(message, field));
+      // Both floating point widths become CEL `double` literals.
+      case FieldDescriptor::CPPTYPE_FLOAT:
+        return NewDoubleConst(expr_id, reflection->GetFloat(message, field));
+      case FieldDescriptor::CPPTYPE_DOUBLE:
+        return NewDoubleConst(expr_id, reflection->GetDouble(message, field));
+      case FieldDescriptor::CPPTYPE_BOOL:
+        return NewBoolConst(expr_id, reflection->GetBool(message, field));
+      default:
+        // Message is handled elsewhere; anything left maps to `null`.
+        return NewNullConst(expr_id);
+    }
+  }
+
+  // Same conversion as above, using a freshly allocated expression id.
+  Expr PrimitiveToExpr(const Message& message, const Reflection* reflection,
+                       const FieldDescriptor* field) {
+    const ExprId fresh_id = NextId();
+    return PrimitiveToExpr(fresh_id, message, reflection, field);
+  }
+
+  // Converts a repeated field element to a CEL constant expression.
+  Expr RepeatedPrimitiveToExpr(const Message& message,
+                               const Reflection* reflection,
+                               const FieldDescriptor* field, int index) {
+    // The id is allocated up front, before the element is read.
+    const ExprId id = NextId();
+    const auto kind = field->cpp_type();
+
+    // `string` and `bytes` elements share one C++ type; the declared proto
+    // type picks the CEL literal kind.
+    if (kind == FieldDescriptor::CPPTYPE_STRING) {
+      std::string payload =
+          reflection->GetRepeatedString(message, field, index);
+      if (field->type() != FieldDescriptor::TYPE_BYTES) {
+        return NewStringConst(id, std::move(payload));
+      }
+      return NewBytesConst(id, std::move(payload));
+    }
+
+    switch (kind) {
+      // Signed integers and enum numbers become CEL `int` literals.
+      case FieldDescriptor::CPPTYPE_INT32:
+        return NewIntConst(id,
+                           reflection->GetRepeatedInt32(message, field, index));
+      case FieldDescriptor::CPPTYPE_INT64:
+        return NewIntConst(id,
+                           reflection->GetRepeatedInt64(message, field, index));
+      case FieldDescriptor::CPPTYPE_ENUM:
+        return NewIntConst(
+            id, reflection->GetRepeatedEnumValue(message, field, index));
+      // Unsigned integers become CEL `uint` literals.
+      case FieldDescriptor::CPPTYPE_UINT32:
+        return NewUintConst(
+            id, reflection->GetRepeatedUInt32(message, field, index));
+      case FieldDescriptor::CPPTYPE_UINT64:
+        return NewUintConst(
+            id, reflection->GetRepeatedUInt64(message, field, index));
+      // Both floating point widths become CEL `double` literals.
+      case FieldDescriptor::CPPTYPE_FLOAT:
+        return NewDoubleConst(
+            id, reflection->GetRepeatedFloat(message, field, index));
+      case FieldDescriptor::CPPTYPE_DOUBLE:
+        return NewDoubleConst(
+            id, reflection->GetRepeatedDouble(message, field, index));
+      case FieldDescriptor::CPPTYPE_BOOL:
+        return NewBoolConst(id,
+                            reflection->GetRepeatedBool(message, field, index));
+      default:
+        // Any other C++ type maps to `null`.
+        return NewNullConst(id);
+    }
+  }
+
+  // ---------------------------------------------------------------------------
+  // Expression construction helpers
+  // ---------------------------------------------------------------------------
+
+  // Creates a binary operator call: `lhs <op> rhs`.
+  Expr ConstructBinaryOp(absl::string_view op, Expr lhs, Expr rhs) {
+    std::vector<Expr> args;
+    args.reserve(2);
+    args.push_back(std::move(lhs));
+    args.push_back(std::move(rhs));
+    // The call id is allocated after both operands have been built.
+    return NewCall(NextId(), op, std::move(args));
+  }
+
+  // Creates an equality check: `lhs == rhs`.
+  Expr ConstructEquality(Expr lhs, Expr rhs) {
+    return ConstructBinaryOp(CelOperator::EQUALS, std::move(lhs),
+                             std::move(rhs));
+  }
+
+  // Left-folds a vector of expressions with a binary operator.
+  // Requires: `exprs` is non-empty. The elements of `exprs` are moved from.
+  Expr FoldBinaryOp(absl::string_view op, std::vector<Expr>& exprs) {
+    Expr root = std::move(exprs[0]);
+    for (size_t i = 1; i < exprs.size(); ++i) {
+      root = ConstructBinaryOp(op, std::move(root), std::move(exprs[i]));
+    }
+    return root;
+  }
+
+  // ---------------------------------------------------------------------------
+  // Map field predicate (extracted from Walk for readability)
+  // ---------------------------------------------------------------------------
+
+  // Builds the predicate for a map field to assert that all key-value pairs
+  // specified in the policy are present in the input map field:
+  // "key" in input.map && input.map["key"] == value
+  //
+  // `size` is the number of map entries to read through repeated-field
+  // reflection. Entries are emitted in ascending key order. The combined
+  // check is appended to `predicates`; nothing is appended when `size` is not
+  // positive.
+  absl::Status WalkMapField(const Reflection* reflection,
+                            const Message& message,
+                            const FieldDescriptor* field, const Expr& base_expr,
+                            int size, std::vector<Expr>& predicates) {
+    // FoldBinaryOp() below requires at least one entry check.
+    if (size <= 0) {
+      return absl::OkStatus();
+    }
+
+    const FieldDescriptor* const key_field =
+        field->message_type()->FindFieldByName("key");
+    const FieldDescriptor* const value_field =
+        field->message_type()->FindFieldByName("value");
+
+    Expr map_path = NewSelect(NextId(), base_expr, field->name());
+
+    std::vector<const Message*> entries;
+    entries.reserve(size);
+    for (int i = 0; i < size; ++i) {
+      entries.push_back(&reflection->GetRepeatedMessage(message, field, i));
+    }
+
+    // Order the entries by key. Key types without a case below compare as
+    // equivalent and are left unordered.
+    const Reflection* const entry_ref = entries[0]->GetReflection();
+    std::sort(entries.begin(), entries.end(),
+              [entry_ref, key_field](const Message* a, const Message* b) {
+                switch (key_field->cpp_type()) {
+                  case FieldDescriptor::CPPTYPE_INT32:
+                    return entry_ref->GetInt32(*a, key_field) <
+                           entry_ref->GetInt32(*b, key_field);
+                  case FieldDescriptor::CPPTYPE_INT64:
+                    return entry_ref->GetInt64(*a, key_field) <
+                           entry_ref->GetInt64(*b, key_field);
+                  case FieldDescriptor::CPPTYPE_UINT32:
+                    return entry_ref->GetUInt32(*a, key_field) <
+                           entry_ref->GetUInt32(*b, key_field);
+                  case FieldDescriptor::CPPTYPE_UINT64:
+                    return entry_ref->GetUInt64(*a, key_field) <
+                           entry_ref->GetUInt64(*b, key_field);
+                  case FieldDescriptor::CPPTYPE_BOOL:
+                    // `false` sorts before `true`.
+                    return !entry_ref->GetBool(*a, key_field) &&
+                           entry_ref->GetBool(*b, key_field);
+                  case FieldDescriptor::CPPTYPE_STRING:
+                    return entry_ref->GetString(*a, key_field) <
+                           entry_ref->GetString(*b, key_field);
+                  default:
+                    return false;
+                }
+              });
+
+    std::vector<Expr> map_checks;
+    map_checks.reserve(entries.size());
+    for (const Message* entry : entries) {
+      const Message& entry_msg = *entry;
+      const Reflection* const entry_reflection = entry_msg.GetReflection();
+
+      Expr key_expr = PrimitiveToExpr(entry_msg, entry_reflection, key_field);
+
+      // Represents `"key" in input.map` to assert the key exists.
+      Expr in_check = ConstructBinaryOp(CelOperator::IN, key_expr, map_path);
+      // Represents `input.map["key"]` to lookup the value. This is the last
+      // use of `key_expr`, so it is moved rather than copied.
+      Expr lookup_path = ConstructBinaryOp(CelOperator::INDEX, map_path,
+                                           std::move(key_expr));
+
+      if (value_field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
+        const Message& value_msg =
+            entry_reflection->GetMessage(entry_msg, value_field);
+        std::vector<Expr> val_predicates;
+        CEL_RETURN_IF_ERROR(Walk(value_msg, lookup_path, val_predicates));
+
+        if (!val_predicates.empty()) {
+          Expr nested_check =
+              FoldBinaryOp(CelOperator::LOGICAL_AND, val_predicates);
+          // Represents `"key" in input.map && (nested message fields check...)`
+          Expr entry_check =
+              ConstructBinaryOp(CelOperator::LOGICAL_AND, std::move(in_check),
+                                std::move(nested_check));
+          map_checks.push_back(std::move(entry_check));
+        } else {
+          // Represents `"key" in input.map` if nested message is empty.
+          map_checks.push_back(std::move(in_check));
+        }
+      } else {
+        Expr value_expr =
+            PrimitiveToExpr(entry_msg, entry_reflection, value_field);
+        // Represents `input.map["key"] == value`
+        Expr eq_check =
+            ConstructEquality(std::move(lookup_path), std::move(value_expr));
+
+        // Represents `"key" in input.map && input.map["key"] == value`
+        Expr entry_check = ConstructBinaryOp(
+            CelOperator::LOGICAL_AND, std::move(in_check), std::move(eq_check));
+        map_checks.push_back(std::move(entry_check));
+      }
+    }
+
+    predicates.push_back(FoldBinaryOp(CelOperator::LOGICAL_AND, map_checks));
+    return absl::OkStatus();
+  }
+
+  // ---------------------------------------------------------------------------
+  // Repeated field predicate (extracted from Walk for readability)
+  // ---------------------------------------------------------------------------
+
+  // Builds predicates for a repeated field:
+  //   - Repeated Messages are mapped to a logical OR (||) of the generated
+  //     predicates for each message.
+  //   - Repeated Primitives are mapped either to:
+  //     - `lhs in [values]` if a "match_path" option is specified.
+  //     - `value in input.field` conjoined with && for each value otherwise.
+  absl::Status WalkRepeatedField(const Reflection* reflection,
+                                 const Message& message,
+                                 const FieldDescriptor* field,
+                                 const Expr& base_expr, int size,
+                                 std::vector<Expr>& predicates) {
+    // FoldBinaryOp() below requires at least one operand.
+    if (size <= 0) {
+      return absl::OkStatus();
+    }
+
+    if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
+      // NOTE(review): every element is walked against the same
+      // `base_expr.field` path, so the generated checks select fields directly
+      // on the repeated field (e.g. `input.message_list.int32_value == 42`).
+      // The accompanying tests skip evaluation for this shape; confirm the
+      // intended runtime semantics.
+      std::vector<Expr> message_asts;
+      message_asts.reserve(size);
+      for (int i = 0; i < size; ++i) {
+        const Message& sub_message =
+            reflection->GetRepeatedMessage(message, field, i);
+        std::vector<Expr> sub_predicates;
+        Expr sub_base = NewSelect(NextId(), base_expr, field->name());
+        CEL_RETURN_IF_ERROR(Walk(sub_message, sub_base, sub_predicates));
+        message_asts.push_back(ConjoinPredicates(sub_predicates));
+      }
+      // Represents alternate message predicates conjoined with OR: `msg_1 ||
+      // msg_2 || ...`
+      predicates.push_back(FoldBinaryOp(CelOperator::LOGICAL_OR, message_asts));
+      return absl::OkStatus();
+    }
+
+    const std::string match_path_val = GetMatchPath(field);
+    if (!match_path_val.empty()) {
+      // The literal list is only needed for this form, so it is built here
+      // rather than unconditionally.
+      std::vector<ListExprElement> elements;
+      elements.reserve(size);
+      for (int i = 0; i < size; ++i) {
+        elements.push_back(NewListElement(
+            RepeatedPrimitiveToExpr(message, reflection, field, i)));
+      }
+      Expr literal_list = NewList(NextId(), std::move(elements));
+
+      Expr lhs = ParseAndBuildPath(match_path_val);
+      // Represents `lhs in [values]` check (e.g. `dest.region in ["us-east",
+      // "us-west"]`).
+      predicates.push_back(ConstructBinaryOp(CelOperator::IN, std::move(lhs),
+                                             std::move(literal_list)));
+      return absl::OkStatus();
+    }
+
+    Expr list_path = NewSelect(NextId(), base_expr, field->name());
+    std::vector<Expr> element_checks;
+    element_checks.reserve(size);
+    for (int i = 0; i < size; ++i) {
+      Expr elem_expr = RepeatedPrimitiveToExpr(message, reflection, field, i);
+      // Represents `value in input.field` check.
+      element_checks.push_back(
+          ConstructBinaryOp(CelOperator::IN, std::move(elem_expr), list_path));
+    }
+    // Represents `"val1" in input.list && "val2" in input.list && ...`
+    predicates.push_back(
+        FoldBinaryOp(CelOperator::LOGICAL_AND, element_checks));
+
+    return absl::OkStatus();
+  }
+
+  // ---------------------------------------------------------------------------
+  // Recursive message walk
+  // ---------------------------------------------------------------------------
+
+  // Appends to `predicates` one check per field that reflection lists as set
+  // on `message`, with every field path rooted at `base_expr`. Singular
+  // message fields are walked recursively and add their checks to the same
+  // vector.
+  absl::Status Walk(const Message& message, const Expr& base_expr,
+                    std::vector<Expr>& predicates) {
+    const Reflection* const reflection = message.GetReflection();
+    std::vector<const FieldDescriptor*> fields;
+    reflection->ListFields(message, &fields);
+
+    for (const auto* field : fields) {
+      // Map fields also report is_repeated(), so they are tested first.
+      if (field->is_map()) {
+        const int size = reflection->FieldSize(message, field);
+        if (size > 0) {
+          CEL_RETURN_IF_ERROR(WalkMapField(reflection, message, field,
+                                           base_expr, size, predicates));
+        }
+      } else if (field->is_repeated()) {
+        const int size = reflection->FieldSize(message, field);
+        if (size > 0) {
+          CEL_RETURN_IF_ERROR(WalkRepeatedField(reflection, message, field,
+                                                base_expr, size, predicates));
+        }
+      } else if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
+        const Message& sub_message = reflection->GetMessage(message, field);
+        Expr field_path = NewSelect(NextId(), base_expr, field->name());
+        CEL_RETURN_IF_ERROR(Walk(sub_message, field_path, predicates));
+      } else {
+        // Primitive field: base_expr.field == <value>
+        Expr field_path = NewSelect(NextId(), base_expr, field->name());
+        predicates.push_back(
+            ConstructEquality(std::move(field_path),
+                              PrimitiveToExpr(message, reflection, field)));
+      }
+    }
+    return absl::OkStatus();
+  }
+
+  // Name of the root identifier. This is a non-owning view; the referenced
+  // characters must outlive the builder.
+  absl::string_view input_name_;
+  // Next expression id returned by NextId().
+  ExprId id_;
+  // Moved into the AST returned by Build().
+  SourceInfo source_info_;
+};
+
+}  // namespace
+
+absl::StatusOr<Ast> ProtocolBufferToPredicateAst(
+    const ::google::protobuf::Message& message, absl::string_view input_name) {
+  PredicateBuilder builder(input_name);
+  return builder.Build(message);
+}
+
+absl::StatusOr<Ast> ProtocolBufferToPredicateAst(
+    absl::Span<const ::google::protobuf::Message* const> messages,
+    absl::string_view input_name) {
+  PredicateBuilder builder(input_name);
+  return builder.Build(messages);
+}
+
+}  // namespace cel::tools
diff --git a/tools/proto_to_predicate.h b/tools/proto_to_predicate.h
new file mode 100644
index 000000000..22796b23f
--- /dev/null
+++ b/tools/proto_to_predicate.h
@@ -0,0 +1,55 @@
+// Copyright 2026 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef THIRD_PARTY_CEL_CPP_TOOLS_PROTO_TO_PREDICATE_H_
+#define THIRD_PARTY_CEL_CPP_TOOLS_PROTO_TO_PREDICATE_H_
+
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "common/ast.h"
+#include "google/protobuf/message.h"
+
+namespace cel::tools {
+
+// Translates a Protocol Buffer message into a CEL AST representing a predicate.
+// Each field set on `message` contributes a check against the CEL variable
+// named `input_name`, and the checks are conjoined with `&&`. A message with
+// no set fields yields the constant `true`.
+//
+// NOTE: The protocol message schemas used for policy definition should use
+// `proto2` or `editions` (and not `proto3` implicit presence) to ensure correct
+// behavior, as this library relies on field presence (via reflection) to
+// identify which fields are explicitly set by the policy.
+absl::StatusOr<Ast> ProtocolBufferToPredicateAst(
+    const ::google::protobuf::Message& message, absl::string_view input_name);
+
+// Translates a list of Protocol Buffer messages into a single CEL AST. Each
+// message is translated as in the single-message overload, and the resulting
+// per-message predicates are combined as alternatives with `||`. An empty
+// list yields the constant `true`.
+//
+// `messages` must not contain null pointers.
+//
+// NOTE: The protocol message schemas used for policy definition should use
+// `proto2` or `editions` (and not `proto3` implicit presence) to ensure correct
+// behavior, as this library relies on field presence (via reflection) to
+// identify which fields are explicitly set by the policy.
+absl::StatusOr<Ast> ProtocolBufferToPredicateAst(
+    absl::Span<const ::google::protobuf::Message* const> messages,
+    absl::string_view input_name);
+
+}  // namespace cel::tools
+
+#endif  // THIRD_PARTY_CEL_CPP_TOOLS_PROTO_TO_PREDICATE_H_
diff --git a/tools/proto_to_predicate_test.cc b/tools/proto_to_predicate_test.cc
new file mode 100644
index 000000000..02442984c
--- /dev/null
+++ b/tools/proto_to_predicate_test.cc
@@ -0,0 +1,594 @@
+// Copyright 2024 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "tools/proto_to_predicate.h" + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/status_matchers.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "common/ast.h" +#include "common/ast_proto.h" +#include "common/value.h" +#include "env/config.h" +#include "env/env_runtime.h" +#include "env/env_yaml.h" +#include "env/runtime_std_extensions.h" +#include "eval/testutil/test_message.pb.h" +#include "extensions/protobuf/value.h" +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "parser/parser.h" +#include "runtime/activation.h" +#include "runtime/runtime.h" +#include "tools/cel_unparser.h" +#include "tools/testdata/policy.pb.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/json/json.h" +#include "google/protobuf/message.h" +#include "google/protobuf/text_format.h" + +namespace cel::tools { +namespace { + +using ::absl_testing::IsOk; +using ::google::api::expr::runtime::TestMessage; + +constexpr absl::string_view kEnvYaml = R"( +name: "test" +extensions: + - name: "bindings" + - name: "optional" +variables: + - name: "input" + type: "google.api.expr.runtime.TestMessage" +)"; + +TestMessage ParseTestMessage(absl::string_view textproto) { + TestMessage msg; + google::protobuf::TextFormat::ParseFromString(textproto, &msg); + return msg; +} + +absl::StatusOr EvaluatePredicate(const cel::Ast& ast, + const TestMessage& input) { + auto descriptor_pool = cel::internal::GetSharedTestingDescriptorPool(); + + CEL_ASSIGN_OR_RETURN(cel::Config config, + cel::EnvConfigFromYaml(std::string(kEnvYaml))); + + cel::EnvRuntime env_runtime; + env_runtime.SetDescriptorPool(descriptor_pool); + cel::RegisterStandardExtensions(env_runtime); + env_runtime.SetConfig(config); + + CEL_ASSIGN_OR_RETURN(std::unique_ptr runtime, + env_runtime.NewRuntime()); + 
CEL_ASSIGN_OR_RETURN(std::unique_ptr program, + runtime->CreateProgram(std::make_unique(ast))); + + google::protobuf::Arena arena; + cel::Activation activation; + CEL_ASSIGN_OR_RETURN( + cel::Value val, cel::extensions::ProtoMessageToValue( + input, descriptor_pool.get(), + google::protobuf::MessageFactory::generated_factory(), &arena)); + activation.InsertOrAssignValue("input", val); + + CEL_ASSIGN_OR_RETURN(cel::Value result, + program->Evaluate(&arena, activation)); + if (!result.IsBool()) { + return absl::InvalidArgumentError( + "Predicate evaluate result must be a boolean value."); + } + return result.GetBool(); +} + +struct TestCase { + std::string name; + std::vector input_textprotos; + std::string expected_unparsed; + std::string eval_textproto; + bool expected_eval_result = true; + // If true, skip the eval step of the test. This is useful for tests where + // the expected expression does not share the same type structure as the + // input proto, such as empty messages. + bool skip_eval = false; +}; + +class ProtoToPredicateTest : public ::testing::TestWithParam {}; + +TEST_P(ProtoToPredicateTest, ConformanceTests) { + const TestCase& param = GetParam(); + + std::vector input_messages; + input_messages.reserve(param.input_textprotos.size()); + for (const auto& proto_str : param.input_textprotos) { + input_messages.push_back(ParseTestMessage(proto_str)); + } + + std::vector ptr_messages; + ptr_messages.reserve(input_messages.size()); + for (const auto& msg : input_messages) { + ptr_messages.push_back(&msg); + } + + absl::StatusOr ast_or; + if (input_messages.size() == 1) { + ast_or = ProtocolBufferToPredicateAst(input_messages[0], "input"); + } else { + ast_or = + ProtocolBufferToPredicateAst(absl::MakeSpan(ptr_messages), "input"); + } + + ASSERT_THAT(ast_or, IsOk()); + cel::Ast ast = std::move(*ast_or); + + cel::expr::ParsedExpr parsed_expr; + ASSERT_THAT(cel::AstToParsedExpr(ast, &parsed_expr), IsOk()); + ASSERT_OK_AND_ASSIGN(auto unparsed, 
google::api::expr::Unparse(parsed_expr)); + + EXPECT_EQ(unparsed, param.expected_unparsed); + + if (!param.skip_eval) { + TestMessage eval_msg = ParseTestMessage(param.eval_textproto); + ASSERT_OK_AND_ASSIGN(bool eval_result, EvaluatePredicate(ast, eval_msg)); + EXPECT_EQ(eval_result, param.expected_eval_result); + } +} + +INSTANTIATE_TEST_SUITE_P( + ProtoToPredicateSubCases, ProtoToPredicateTest, + testing::Values( + TestCase{ + .name = "EmptyMessageTest", + .input_textprotos = {""}, + .expected_unparsed = "true", + .eval_textproto = "", + }, + TestCase{ + .name = "EmptyMessagesListTest", + .input_textprotos = {}, + .expected_unparsed = "true", + .eval_textproto = "", + }, + TestCase{ + .name = "PrimitivesTest", + .input_textprotos = {R"pb( + int32_value: 42 string_value: "hello" + )pb"}, + .expected_unparsed = + "input.int32_value == 42 && input.string_value == \"hello\"", + .eval_textproto = R"pb( + int32_value: 42 string_value: "hello" + )pb", + }, + TestCase{ + .name = "AllPrimitivesTest", + .input_textprotos = {R"pb( + int32_value: 42 + int64_value: 43 + uint32_value: 44 + uint64_value: 45 + float_value: 46.5 + double_value: 47.5 + bool_value: true + enum_value: TEST_ENUM_1 + string_value: "hello" + bytes_value: "world" + )pb"}, + .expected_unparsed = + "input.int32_value == 42 && input.int64_value == 43 && " + "input.uint32_value == 44u && input.uint64_value == 45u && " + "input.float_value == 46.5 && input.double_value == 47.5 && " + "input.string_value == \"hello\" && " + "input.bytes_value == b\"world\" && " + "input.bool_value == true && " + "input.enum_value == 1", + .eval_textproto = R"pb( + int32_value: 42 + int64_value: 43 + uint32_value: 44 + uint64_value: 45 + float_value: 46.5 + double_value: 47.5 + bool_value: true + enum_value: TEST_ENUM_1 + string_value: "hello" + bytes_value: "world" + )pb", + }, + TestCase{ + .name = "NestedMessageTest", + .input_textprotos = {R"pb( + message_value: { int32_value: 42 } + )pb"}, + .expected_unparsed = 
"input.message_value.int32_value == 42", + .eval_textproto = R"pb( + message_value: { int32_value: 42 } + )pb", + }, + TestCase{ + .name = "RepeatedFieldTest", + .input_textprotos = {R"pb( + int32_list: [ 1, 2 ] + )pb"}, + .expected_unparsed = + "1 in input.int32_list && 2 in input.int32_list", + .eval_textproto = R"pb( + int32_list: [ 1, 2 ] + )pb", + }, + TestCase{ + .name = "RepeatedFieldSingleElementTest", + .input_textprotos = {R"pb( + int32_list: [ 42 ] + )pb"}, + .expected_unparsed = "42 in input.int32_list", + .eval_textproto = R"pb( + int32_list: [ 42 ] + )pb", + }, + TestCase{ + .name = "RepeatedFieldEmptyTest", + .input_textprotos = {R"pb( + int32_list: [] + )pb"}, + .expected_unparsed = "true", + .eval_textproto = R"pb( + int32_list: [] + )pb", + }, + TestCase{ + .name = "ListFieldEvalNegative", + .input_textprotos = {R"pb( + int32_list: [ 1, 2 ] + )pb"}, + .expected_unparsed = + "1 in input.int32_list && 2 in input.int32_list", + .eval_textproto = R"pb( + int32_list: [ 1, 3 ] + )pb", + .expected_eval_result = false, + }, + TestCase{ + .name = "SingleRepeatedFieldAllPrimitivesTest", + .input_textprotos = {R"pb( + int32_list: [ 42 ] + int64_list: [ 43 ] + uint32_list: [ 44 ] + uint64_list: [ 45 ] + float_list: [ 46.5 ] + double_list: [ 47.5 ] + bool_list: [ true ] + enum_list: [ TEST_ENUM_1 ] + string_list: [ "hello" ] + bytes_list: [ "world" ] + )pb"}, + .expected_unparsed = "42 in input.int32_list && " + "43 in input.int64_list && " + "44u in input.uint32_list && " + "45u in input.uint64_list && " + "46.5 in input.float_list && " + "47.5 in input.double_list && " + "\"hello\" in input.string_list && " + "b\"world\" in input.bytes_list && " + "true in input.bool_list && " + "1 in input.enum_list", + .eval_textproto = R"pb( + int32_list: [ 42 ] + int64_list: [ 43 ] + uint32_list: [ 44 ] + uint64_list: [ 45 ] + float_list: [ 46.5 ] + double_list: [ 47.5 ] + bool_list: [ true ] + enum_list: [ TEST_ENUM_1 ] + string_list: [ "hello" ] + bytes_list: [ "world" 
] + )pb", + }, + TestCase{ + .name = "MultipleRepeatedFieldAllPrimitivesTest", + .input_textprotos = {R"pb( + int32_list: [ 42, 142 ] + int64_list: [ 43, 143 ] + uint32_list: [ 44, 144 ] + uint64_list: [ 45, 145 ] + float_list: [ 46.5, 146.5 ] + double_list: [ 47.5, 147.5 ] + bool_list: [ true, false ] + enum_list: [ TEST_ENUM_1, TEST_ENUM_2 ] + string_list: [ "hello", "universe" ] + bytes_list: [ "world", "space" ] + )pb"}, + .expected_unparsed = + "42 in input.int32_list && 142 in input.int32_list && " + "43 in input.int64_list && 143 in input.int64_list && " + "44u in input.uint32_list && 144u in input.uint32_list && " + "45u in input.uint64_list && 145u in input.uint64_list && " + "46.5 in input.float_list && 146.5 in input.float_list && " + "47.5 in input.double_list && 147.5 in input.double_list && " + "\"hello\" in input.string_list && \"universe\" in " + "input.string_list && " + "b\"world\" in input.bytes_list && b\"space\" in " + "input.bytes_list && " + "true in input.bool_list && false in input.bool_list && " + "1 in input.enum_list && 2 in input.enum_list", + .eval_textproto = R"pb( + int32_list: [ 42, 142 ] + int64_list: [ 43, 143 ] + uint32_list: [ 44, 144 ] + uint64_list: [ 45, 145 ] + float_list: [ 46.5, 146.5 ] + double_list: [ 47.5, 147.5 ] + bool_list: [ true, false ] + enum_list: [ TEST_ENUM_1, TEST_ENUM_2 ] + string_list: [ "hello", "universe" ] + bytes_list: [ "world", "space" ] + )pb", + }, + TestCase{ + .name = "MapFieldTest", + .input_textprotos = {R"pb( + string_int32_map: { key: "foo" value: 1 } + string_int32_map: { key: "bar" value: 2 } + )pb"}, + .expected_unparsed = "\"bar\" in input.string_int32_map && " + "input.string_int32_map[\"bar\"] == 2 && " + "\"foo\" in input.string_int32_map && " + "input.string_int32_map[\"foo\"] == 1", + .eval_textproto = R"pb( + string_int32_map: { key: "foo" value: 1 } + string_int32_map: { key: "bar" value: 2 } + )pb", + }, + TestCase{ + .name = "MapFieldEvalNegativeVal", + .input_textprotos = {R"pb( 
+ string_int32_map: { key: "foo" value: 1 } + string_int32_map: { key: "bar" value: 2 } + )pb"}, + .expected_unparsed = "\"bar\" in input.string_int32_map && " + "input.string_int32_map[\"bar\"] == 2 && " + "\"foo\" in input.string_int32_map && " + "input.string_int32_map[\"foo\"] == 1", + .eval_textproto = R"pb( + string_int32_map: { key: "foo" value: 1 } + string_int32_map: { key: "bar" value: 3 } + )pb", + .expected_eval_result = false, + }, + TestCase{ + .name = "MapFieldEvalNegativeNoKey", + .input_textprotos = {R"pb( + string_int32_map: { key: "foo" value: 1 } + string_int32_map: { key: "bar" value: 2 } + )pb"}, + .expected_unparsed = "\"bar\" in input.string_int32_map && " + "input.string_int32_map[\"bar\"] == 2 && " + "\"foo\" in input.string_int32_map && " + "input.string_int32_map[\"foo\"] == 1", + .eval_textproto = R"pb( + string_int32_map: { key: "foo" value: 1 } + )pb", + .expected_eval_result = false, + }, + TestCase{ + .name = "MapFieldIntKeySortingTest", + .input_textprotos = {R"pb( + int32_int32_map: { key: 10 value: 100 } + int32_int32_map: { key: 5 value: 50 } + int32_int32_map: { key: 8 value: 80 } + )pb"}, + .expected_unparsed = "5 in input.int32_int32_map && " + "input.int32_int32_map[5] == 50 && " + "8 in input.int32_int32_map && " + "input.int32_int32_map[8] == 80 && " + "10 in input.int32_int32_map && " + "input.int32_int32_map[10] == 100", + .eval_textproto = R"pb( + int32_int32_map: { key: 5 value: 50 } + int32_int32_map: { key: 8 value: 80 } + int32_int32_map: { key: 10 value: 100 } + )pb", + }, + TestCase{ + .name = "MultipleMessagesTest", + .input_textprotos = {R"pb( + int32_value: 42 + )pb", + R"pb( + int32_value: 41 string_value: "hello" + )pb"}, + .expected_unparsed = + "input.int32_value == 42 || input.int32_value == 41 && " + "input.string_value == \"hello\"", + .eval_textproto = R"pb( + int32_value: 41 string_value: "hello" + )pb", + }, + TestCase{ + .name = "RepeatedMessageFieldTest", + .input_textprotos = {R"pb( + message_list: 
+ [ { int32_value: 42 } + , { int32_value: 43 }] + )pb"}, + .expected_unparsed = "input.message_list.int32_value == 42 || " + "input.message_list.int32_value == 43", + .skip_eval = true, + }, + TestCase{ + .name = "RepeatedMessageSingleElementTest", + .input_textprotos = {R"pb( + message_list: + [ { int32_value: 42 }] + )pb"}, + .expected_unparsed = "input.message_list.int32_value == 42", + .skip_eval = true, + })); + +struct PolicyTestCase { + std::string name; + std::string json_input; + std::string expected_unparsed; +}; + +class PolicyJsonTest : public ::testing::TestWithParam {}; + +TEST_P(PolicyJsonTest, Conformance) { + const PolicyTestCase& param = GetParam(); + + cel::cpp::tools::Policy policy; + proto2::json::ParseOptions options; + options.ignore_unknown_fields = true; + auto status = + proto2::json::JsonStringToMessage(param.json_input, &policy, options); + ASSERT_THAT(status, IsOk()) << "Failed to parse JSON: " << param.json_input; + + absl::StatusOr ast_or; + std::vector ptr_messages; + ptr_messages.reserve(policy.destinations_size()); + for (const auto& dest : policy.destinations()) { + ptr_messages.push_back(&dest); + } + + if (ptr_messages.empty()) { + auto parsed_expr_or = google::api::expr::parser::Parse("false"); + ASSERT_THAT(parsed_expr_or, IsOk()); + auto ast_ptr_or = cel::CreateAstFromParsedExpr(*parsed_expr_or); + ASSERT_THAT(ast_ptr_or, IsOk()); + ast_or = std::move(**ast_ptr_or); + } else if (ptr_messages.size() == 1) { + ast_or = ProtocolBufferToPredicateAst(*ptr_messages[0], "dest"); + } else { + ast_or = ProtocolBufferToPredicateAst(absl::MakeSpan(ptr_messages), "dest"); + } + + ASSERT_THAT(ast_or, IsOk()); + cel::Ast ast = std::move(*ast_or); + + cel::expr::ParsedExpr parsed_expr; + ASSERT_THAT(cel::AstToParsedExpr(ast, &parsed_expr), IsOk()); + ASSERT_OK_AND_ASSIGN(auto unparsed, google::api::expr::Unparse(parsed_expr)); + + EXPECT_EQ(unparsed, param.expected_unparsed); +} + +INSTANTIATE_TEST_SUITE_P( + PolicyJsonSubCases, 
PolicyJsonTest, + testing::Values( + PolicyTestCase{ + .name = "SimpleMatch", + .json_input = + R"({ "destinations": [ { "agent": { "id": "agent-007" } } ] })", + .expected_unparsed = "dest.agent.name == \"agent-007\"", + }, + PolicyTestCase{ + .name = "MultipleFields", + .json_input = + R"({ "destinations": [ { + "tool": { + "name": "admin_tool", + "annotations": { + "read_only_hint": false + } + } + } + ] })", + .expected_unparsed = + "dest.tool.name == \"admin_tool\" && " + "dest.tool.annotations.read_only_hint == false", + }, + PolicyTestCase{ + .name = "RepeatedMessages", + .json_input = + R"({ "destinations": [ + { "agent": { "id": "worker-1" } }, + { "agent": { "id": "worker-2" } }, + ] })", + .expected_unparsed = "dest.agent.name == \"worker-1\" || " + "dest.agent.name == \"worker-2\"", + }, + PolicyTestCase{ + .name = "RepeatedPrimitiveArraySingleElement", + .json_input = + R"({ "destinations": [ { + "tool": { + "role_members": { + "admin": { + "principals": ["alice"] + } + } + } + } ] })", + .expected_unparsed = + "\"admin\" in dest.tool.role_members && " + "\"alice\" in dest.tool.role_members[\"admin\"].principals", + }, + PolicyTestCase{ + .name = "RepeatedArrayEmpty", + .json_input = R"({ "destinations": [ { "tool": { } } ] })", + .expected_unparsed = "true", + }, + PolicyTestCase{ + .name = "MapEquality", + .json_input = + R"({ "destinations": [ + { "tool": { + "name": "shell", + "labels": { + "cluster": "us-central1", + "project": "dev" + } + } + } ] })", + .expected_unparsed = + "dest.tool.name == \"shell\" && \"cluster\" in " + "dest.tool.labels && dest.tool.labels[\"cluster\"] == " + "\"us-central1\" && \"project\" in dest.tool.labels && " + "dest.tool.labels[\"project\"] == \"dev\"", + }, + PolicyTestCase{ + .name = "NestedMapEquality", + .json_input = + R"({ "destinations": [ + { "tool": { + "role_members": { + "admin": { + "all_users": true + } + } + } } + ] })", + .expected_unparsed = + "\"admin\" in dest.tool.role_members && " + 
"dest.tool.role_members[\"admin\"].all_users == true", + }, + PolicyTestCase{ + .name = "EmptyPolicy", + .json_input = "{}", + .expected_unparsed = "false", + })); + +} // namespace +} // namespace cel::tools diff --git a/tools/testdata/BUILD b/tools/testdata/BUILD index 493f0ff2f..035be4758 100644 --- a/tools/testdata/BUILD +++ b/tools/testdata/BUILD @@ -17,6 +17,14 @@ load( "flatbuffer_library_public", ) load("@rules_cc//cc:cc_library.bzl", "cc_library") +load( + "//third_party/protobuf/bazel:cc_proto_library.bzl", + "cc_proto_library", +) +load( + "//third_party/protobuf/bazel:proto_library.bzl", + "proto_library", +) licenses(["notice"]) @@ -46,3 +54,17 @@ cc_library( linkstatic = True, deps = ["@com_github_google_flatbuffers//:runtime_cc"], ) + +proto_library( + name = "policy_proto", + srcs = ["policy.proto"], + compatible_with = ["//buildenv/target:non_prod"], + visibility = ["//tools:__subpackages__"], +) + +cc_proto_library( + name = "policy_cc_proto", + compatible_with = ["//buildenv/target:non_prod"], + visibility = ["//tools:__subpackages__"], + deps = [":policy_proto"], +) diff --git a/tools/testdata/policy.proto b/tools/testdata/policy.proto new file mode 100644 index 000000000..a556242ba --- /dev/null +++ b/tools/testdata/policy.proto @@ -0,0 +1,73 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// policy.proto defines the test schema representing client-configured policies. 
+// It is used by the `proto_to_predicate` tool to translate Protobuf policies
+// into CEL predicates.
+edition = "2023";
+
+package cel.cpp.tools;
+
+option cc_enable_arenas = true;
+
+// Represents the targeted client agent.
+message Agent {
+  string name = 1 [json_name = "id"];  // JSON key is "id", not "name".
+}
+
+// Specifies additional metadata tool annotations.
+message ToolAnnotations {
+  bool read_only_hint = 1;
+}
+
+// Value type of the `Tool.role_members` map: the members of one role.
+message Members {
+  repeated string principals = 1;
+
+  repeated string regions = 2;
+
+  bool all_users = 3;
+
+  bool all_authenticated_users = 4;
+}
+
+// Represents a metadata tool block.
+message Tool {
+  // The name of the tool.
+  string name = 1;
+
+  // Additional metadata annotations for the tool.
+  ToolAnnotations annotations = 2;
+
+  // A string-to-string map, transpiled as conjoined existence and equality
+  // checks.
+  map<string, string> labels = 3;
+
+  // A map with string keys representing roles and Members instances as values.
+  map<string, Members> role_members = 4;
+}
+
+// Represents a policy mapping destination block.
+message Target {
+  oneof kind {  // At most one of the following is set per destination.
+    Agent agent = 1;
+    Tool tool = 2;
+  }
+}
+
+// Represents the top-level policy containing multiple alternate destination
+// rules.
+message Policy {
+  repeated Target destinations = 1;
+}