From 9a1d617b25f932c8dc3a657b01c4055b828a6f79 Mon Sep 17 00:00:00 2001 From: Ricardo Torres Date: Thu, 14 May 2026 20:48:58 -0700 Subject: [PATCH] The `fully_commutes_with_sum` property was incorrectly using `sum()` on the flattened commuting structure, which does not yield a boolean. Changed to use `all()` to ensure the property returns True only if all sub-encoders fully commute with sum. Added type assertions in tests to verify the return type is boolean. PiperOrigin-RevId: 915763952 --- .../core/internal/tensor_encoding/core/gather_encoder.py | 2 +- .../core/internal/tensor_encoding/core/gather_encoder_test.py | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/tensorflow_model_optimization/python/core/internal/tensor_encoding/core/gather_encoder.py b/tensorflow_model_optimization/python/core/internal/tensor_encoding/core/gather_encoder.py index 38b8b2d81..1ce87307c 100644 --- a/tensorflow_model_optimization/python/core/internal/tensor_encoding/core/gather_encoder.py +++ b/tensorflow_model_optimization/python/core/internal/tensor_encoding/core/gather_encoder.py @@ -370,7 +370,7 @@ def input_tensorspec(self): @property def fully_commutes_with_sum(self): # If any element is not True, the whole thing does not fully commute. - return sum(tf.nest.flatten(self._commuting_structure)) + return all(tf.nest.flatten(self._commuting_structure)) @property def state_update_aggregation_modes(self): diff --git a/tensorflow_model_optimization/python/core/internal/tensor_encoding/core/gather_encoder_test.py b/tensorflow_model_optimization/python/core/internal/tensor_encoding/core/gather_encoder_test.py index 26fbed255..737518a68 100644 --- a/tensorflow_model_optimization/python/core/internal/tensor_encoding/core/gather_encoder_test.py +++ b/tensorflow_model_optimization/python/core/internal/tensor_encoding/core/gather_encoder_test.py @@ -198,12 +198,14 @@ def test_full_commutativity_with_sum(self): core_encoder.EncoderComposer(test_utils.TimesTwoEncodingStage()).make(), spec) self.assertTrue(encoder.fully_commutes_with_sum) + self.assertIsInstance(encoder.fully_commutes_with_sum, bool) encoder = gather_encoder.GatherEncoder.from_encoder( core_encoder.EncoderComposer( test_utils.TimesTwoEncodingStage()).add_parent( test_utils.TimesTwoEncodingStage(), T2_VALS).make(), spec) self.assertTrue(encoder.fully_commutes_with_sum) + self.assertIsInstance(encoder.fully_commutes_with_sum, bool) encoder = core_encoder.EncoderComposer( test_utils.SignIntFloatEncodingStage()) @@ -213,6 +215,7 @@ def test_full_commutativity_with_sum(self): test_utils.PlusOneOverNEncodingStage(), T2_VALS) encoder = gather_encoder.GatherEncoder.from_encoder(encoder.make(), spec) self.assertFalse(encoder.fully_commutes_with_sum) + self.assertIsInstance(encoder.fully_commutes_with_sum, bool) @tf_test_util.run_all_in_graph_and_eager_modes def test_state_aggregation_modes(self):