From fb6ee476664aad8095c4c5ecc7a30b31d5bcea7f Mon Sep 17 00:00:00 2001 From: wyfEmma Date: Thu, 28 May 2026 15:22:16 +0000 Subject: [PATCH] fix librispeech ctc_loss incompatibility with jit --- .../librispeech_jax/workload.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py index 8eb37e25b..cb1933dd3 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py @@ -325,7 +325,14 @@ def greedy_decode( jax_sharding_utils.get_replicate_sharding(), # model_state jax_sharding_utils.get_replicate_sharding(), # rng ), - out_shardings=jax_sharding_utils.get_batch_dim_sharding(), + out_shardings={ + 'loss_per_example': jax_sharding_utils.get_batch_dim_sharding(), + 'decoded': jax_sharding_utils.get_batch_dim_sharding(), + 'decoded_paddings': jax_sharding_utils.get_batch_dim_sharding(), + 'targets': jax_sharding_utils.get_batch_dim_sharding(), + 'target_paddings': jax_sharding_utils.get_batch_dim_sharding(), + 'n_valid_examples': jax_sharding_utils.get_replicate_sharding(), + }, static_argnums=(0,), ) def _eval_step( @@ -354,8 +361,7 @@ def _eval_step( 'decoded_paddings': decoded_paddings, 'targets': targets, 'target_paddings': target_paddings, - 'n_valid_examples': jnp.zeros((len(jax.devices()), 1)) - + loss['n_valid_examples'], + 'n_valid_examples': loss['n_valid_examples'], } return metrics_dict