diff --git a/tensorflow_model_optimization/python/core/internal/tensor_encoding/core/simple_encoder.py b/tensorflow_model_optimization/python/core/internal/tensor_encoding/core/simple_encoder.py index 03966569..09ff5844 100644 --- a/tensorflow_model_optimization/python/core/internal/tensor_encoding/core/simple_encoder.py +++ b/tensorflow_model_optimization/python/core/internal/tensor_encoding/core/simple_encoder.py @@ -134,7 +134,12 @@ def decode_fn(encoded_structure): encoded_structure = py_utils.merge_dicts(encoded_structure, encoded_py_structure['flat_py']) encoded_structure = tf.nest.pack_sequence_as( - encoded_py_structure['full'], tf.nest.flatten(encoded_structure)) + encoded_py_structure['full'], + [ + encoded_structure[k] for k, _ in + py_utils.flatten_with_joined_string_paths( + encoded_py_structure['full']) + ]) return encoder.decode(encoded_structure[_TENSORS], encoded_structure[_PARAMS], encoded_structure[_SHAPES]) diff --git a/tensorflow_model_optimization/python/core/internal/tensor_encoding/core/simple_encoder_test.py b/tensorflow_model_optimization/python/core/internal/tensor_encoding/core/simple_encoder_test.py index 27404cc0..2f504b3b 100644 --- a/tensorflow_model_optimization/python/core/internal/tensor_encoding/core/simple_encoder_test.py +++ b/tensorflow_model_optimization/python/core/internal/tensor_encoding/core/simple_encoder_test.py @@ -141,6 +141,35 @@ def test_python_constants_not_exposed(self): self.assertAllClose(x, decoded_x_tf) self.assertAllClose(x, decoded_x_py) + @tf_test_util.run_all_in_graph_and_eager_modes + def test_interleaved_py_tf_parameters(self): + """Tests that interleaved Python and TF parameters are decoded correctly.""" + class SwappedKeysStage(test_utils.SimpleLinearEncodingStage): + def get_params(self, name=None): + # Return parameters in a specific order to test if encoding/decoding + # respects this order or sorts keys. + params = {self.B_PARAM_KEY: self._b, self.A_PARAM_KEY: self._a} + return params, params + + x = tf.constant(2.0) + tensorspec = tf.TensorSpec.from_tensor(x) + + # Use one Python constant and one TF Variable to ensure they are split and merged. + b_var = tf.compat.v1.get_variable('b_var_interleaved', initializer=3.0) + + encoder = simple_encoder.SimpleEncoder( + core_encoder.EncoderComposer( + SwappedKeysStage(2.0, b_var)).make(), + tensorspec) + + state = encoder.initial_state() + iteration = _make_iteration_function(encoder) + + self.evaluate(tf.compat.v1.global_variables_initializer()) + + x_val, _, decoded_x, _ = self.evaluate(iteration(x, state)) + self.assertAllClose(x_val, decoded_x) + @tf_test_util.run_all_in_graph_and_eager_modes def test_decode_needs_input_shape_static(self): """Tests that mechanism for passing input shape works with static shape."""