Skip to content

feat: Support configurable mesh locality and device count in TPU benchmarks#125

Merged
linamy85 merged 3 commits into
AI-Hypercomputer:chsfrom
simrankaurb:production-mesh-fix
Jun 9, 2026
Merged

feat: Support configurable mesh locality and device count in TPU benchmarks#125
linamy85 merged 3 commits into
AI-Hypercomputer:chsfrom
simrankaurb:production-mesh-fix

Conversation

@simrankaurb

@simrankaurb simrankaurb commented Jun 4, 2026

Copy link
Copy Markdown

This PR introduces dynamic parameterization for mesh locality and device counts. It allows benchmarks to configure whether they run on the local host or the full slice, while preserving the original full-slice behavior as the default.


Changes

1. Utility Changes (Ironwood/src/benchmark_utils.py)

  • Updated create_mesh to accept a local_mesh boolean parameter (defaults to False). If True, it restricts mesh creation to jax.local_devices().
  • Updated sharding helpers (handle_per_device_based_on_sharding and handle_all_devices_based_on_sharding) to accept an explicit device_count parameter.
  • Updated handle_based_on_sharding to accept an optional device_count parameter (defaults to jax.device_count()), passing it down to the helpers.

2. Benchmark API Changes (Ironwood/src/benchmark_gemm.py)

  • Exposed run_on_local_node: bool = False in all 5 GEMM benchmark signatures and their metrics calculation functions.
  • Passed local_mesh=run_on_local_node to create_mesh.
  • Dynamically determined the active device_count (jax.local_device_count() if run_on_local_node is True, else jax.device_count()) and passed it to handle_based_on_sharding to ensure accurate FLOPs/throughput calculation.

Backward Compatibility & Safety

  • All other compute benchmarks (benchmark_compute.py, benchmark_inference_compute.py, etc.) remain untouched.
  • Because local_mesh defaults to False and device_count defaults to None (falling back to global device_count), all untouched benchmarks will continue to execute and report metrics across the full slice out-of-the-box.

@linamy85

linamy85 commented Jun 5, 2026

Copy link
Copy Markdown
Collaborator

Thanks @simrankaurb , Have we verified the functionality and metrics correctness on this change?

Also, should we also cover the HBM and H2D/D2H?

@simrankaurb

Copy link
Copy Markdown
Author

Thanks for the comment @linamy85 ! I have tested the change and pasted all logs in go/tpu-single-host-analysis. HBM and H2D/D2H seem to be working on local devices only. Please take a look and let me know if we have any concerns.
Thanks!!

Thanks @simrankaurb , Have we verified the functionality and metrics correctness on this change?

Also, should we also cover the HBM and H2D/D2H?

@simrankaurb simrankaurb force-pushed the production-mesh-fix branch from 6b24c1a to c1da258 Compare June 8, 2026 14:18
Comment thread Ironwood/src/benchmark_utils.py Outdated
return value
case ShardingStrategy.SHARDING_ON_ALL_DEVICES_WITH_M:
return value // jax.device_count()
return value // jax.local_device_count()

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we change this to local_device_count(), the full node-pool scenario might not work as intended. Should we use parameter to control this switch? For example,

  1. In the benchmark_utils.py, we can accept device_count as parameter.
  2. In the benchmark_gemm.py, we can add new parameter (e.g. run_on_local_node, default to False)

WDYT?

cc @simonleesyuan30

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Please let me know if the design proposed in this one pager sits well with the expected behaviour: https://docs.google.com/document/d/1y0K2VJT0BLYV2-ZqHlm0RDV6WD8jHvJ5eIr_KBbjohA

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Implementation completed. PTAL!

Testing: Ran gemm for single host and full slice. Here are the configs and values recorded: https://paste.googleplex.com/6202599689814016
The difference between TFLOPs/s and TFLOPS/s/device can tell us the number of devices(8 and 16) in both cases.

@simrankaurb simrankaurb force-pushed the production-mesh-fix branch from 0cba43b to c3852dc Compare June 9, 2026 07:04
@simrankaurb simrankaurb changed the title Fix mesh creation to use local devices for single-host benchmarks feat: Support configurable mesh locality and device count in TPU benchmarks Jun 9, 2026
@simrankaurb simrankaurb force-pushed the production-mesh-fix branch from c3852dc to 946541b Compare June 9, 2026 07:17
@linamy85 linamy85 merged commit f010c92 into AI-Hypercomputer:chs Jun 9, 2026
2 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants