diff --git a/tensorflow_model_optimization/python/core/internal/tensor_encoding/core/gather_encoder.py b/tensorflow_model_optimization/python/core/internal/tensor_encoding/core/gather_encoder.py index 38b8b2d8..1ce87307 100644 --- a/tensorflow_model_optimization/python/core/internal/tensor_encoding/core/gather_encoder.py +++ b/tensorflow_model_optimization/python/core/internal/tensor_encoding/core/gather_encoder.py @@ -370,7 +370,7 @@ def input_tensorspec(self): @property def fully_commutes_with_sum(self): # If any element is not True, the whole thing does not fully commute. - return sum(tf.nest.flatten(self._commuting_structure)) + return all(tf.nest.flatten(self._commuting_structure)) @property def state_update_aggregation_modes(self): diff --git a/tensorflow_model_optimization/python/core/internal/tensor_encoding/core/gather_encoder_test.py b/tensorflow_model_optimization/python/core/internal/tensor_encoding/core/gather_encoder_test.py index 26fbed25..737518a6 100644 --- a/tensorflow_model_optimization/python/core/internal/tensor_encoding/core/gather_encoder_test.py +++ b/tensorflow_model_optimization/python/core/internal/tensor_encoding/core/gather_encoder_test.py @@ -198,12 +198,14 @@ def test_full_commutativity_with_sum(self): core_encoder.EncoderComposer(test_utils.TimesTwoEncodingStage()).make(), spec) self.assertTrue(encoder.fully_commutes_with_sum) + self.assertIsInstance(encoder.fully_commutes_with_sum, bool) encoder = gather_encoder.GatherEncoder.from_encoder( core_encoder.EncoderComposer( test_utils.TimesTwoEncodingStage()).add_parent( test_utils.TimesTwoEncodingStage(), T2_VALS).make(), spec) self.assertTrue(encoder.fully_commutes_with_sum) + self.assertIsInstance(encoder.fully_commutes_with_sum, bool) encoder = core_encoder.EncoderComposer( test_utils.SignIntFloatEncodingStage()) @@ -213,6 +215,7 @@ def test_full_commutativity_with_sum(self): test_utils.PlusOneOverNEncodingStage(), T2_VALS) encoder = gather_encoder.GatherEncoder.from_encoder(encoder.make(), spec) self.assertFalse(encoder.fully_commutes_with_sum) + self.assertIsInstance(encoder.fully_commutes_with_sum, bool) @tf_test_util.run_all_in_graph_and_eager_modes def test_state_aggregation_modes(self):