Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,12 @@ def decode_fn(encoded_structure):
encoded_structure = py_utils.merge_dicts(encoded_structure,
encoded_py_structure['flat_py'])
encoded_structure = tf.nest.pack_sequence_as(
encoded_py_structure['full'], tf.nest.flatten(encoded_structure))
encoded_py_structure['full'],
[
encoded_structure[k] for k, _ in
py_utils.flatten_with_joined_string_paths(
encoded_py_structure['full'])
])
return encoder.decode(encoded_structure[_TENSORS],
encoded_structure[_PARAMS],
encoded_structure[_SHAPES])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,35 @@ def test_python_constants_not_exposed(self):
self.assertAllClose(x, decoded_x_tf)
self.assertAllClose(x, decoded_x_py)

@tf_test_util.run_all_in_graph_and_eager_modes
def test_interleaved_py_tf_parameters(self):
"""Tests that interleaved Python and TF parameters are decoded correctly."""
class SwappedKeysStage(test_utils.SimpleLinearEncodingStage):
def get_params(self, name=None):
# Return parameters in a specific order to test if encoding/decoding
# respects this order or sorts keys.
params = {self.B_PARAM_KEY: self._b, self.A_PARAM_KEY: self._a}
return params, params

x = tf.constant(2.0)
tensorspec = tf.TensorSpec.from_tensor(x)

# Use one Python constant and one TF Variable to ensure they are split and merged.
b_var = tf.compat.v1.get_variable('b_var_interleaved', initializer=3.0)

encoder = simple_encoder.SimpleEncoder(
core_encoder.EncoderComposer(
SwappedKeysStage(2.0, b_var)).make(),
tensorspec)

state = encoder.initial_state()
iteration = _make_iteration_function(encoder)

self.evaluate(tf.compat.v1.global_variables_initializer())

x_val, _, decoded_x, _ = self.evaluate(iteration(x, state))
self.assertAllClose(x_val, decoded_x)

@tf_test_util.run_all_in_graph_and_eager_modes
def test_decode_needs_input_shape_static(self):
"""Tests that mechanism for passing input shape works with static shape."""
Expand Down
Loading