feat: Support configurable mesh locality and device count in TPU benchmarks #125
Thanks @simrankaurb. Have we verified the functionality and metrics correctness of this change? Also, should we cover HBM and H2D/D2H as well?
Thanks for the comment, @linamy85! I have tested the change and pasted all logs in go/tpu-single-host-analysis. HBM and H2D/D2H seem to be working on local devices only. Please take a look and let me know if you have any concerns.
Force-pushed from 6b24c1a to c1da258.
      return value
    case ShardingStrategy.SHARDING_ON_ALL_DEVICES_WITH_M:
-     return value // jax.device_count()
+     return value // jax.local_device_count()
If we change this to local_device_count(), the full node-pool scenario might not work as intended. Should we use a parameter to control this switch? For example:
- In benchmark_utils.py, we can accept device_count as a parameter.
- In benchmark_gemm.py, we can add a new parameter (e.g. run_on_local_node, defaulting to False).
WDYT?
Thanks! Please let me know if the design proposed in this one pager sits well with the expected behaviour: https://docs.google.com/document/d/1y0K2VJT0BLYV2-ZqHlm0RDV6WD8jHvJ5eIr_KBbjohA
Implementation completed. PTAL!
Testing: Ran gemm for single host and full slice. Here are the configs and values recorded: https://paste.googleplex.com/6202599689814016
The ratio of TFLOPs/s to TFLOPs/s/device gives the number of devices in each case (8 and 16).
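A minimal sketch of that sanity check, assuming the benchmark reports both an aggregate TFLOPs/s figure and a TFLOPs/s/device figure: dividing the first by the second recovers the device count the metrics were computed with. The function name and the throughput numbers are hypothetical placeholders, not values from the recorded runs.

```python
def implied_device_count(total_tflops_per_s: float,
                         per_device_tflops_per_s: float) -> int:
    """Infer how many devices an aggregate throughput figure was summed over."""
    if per_device_tflops_per_s <= 0:
        raise ValueError("per-device throughput must be positive")
    return round(total_tflops_per_s / per_device_tflops_per_s)


# Placeholder throughputs (NOT measured values), chosen only so that they
# reproduce the two device counts mentioned in the comment above.
first_run = implied_device_count(800.0, 100.0)
second_run = implied_device_count(1600.0, 100.0)
print(first_run, second_run)
```

If the implied count does not match the number of devices in the mesh that was actually used, the metrics were probably divided by the wrong device count.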
Force-pushed from 0cba43b to c3852dc.
Force-pushed from c3852dc to 946541b.
This PR introduces dynamic parameterization for mesh locality and device counts. It allows benchmarks to configure whether they run on the local host or the full slice, while preserving the original full-slice behavior as the default.
Changes
1. Utility Changes (Ironwood/src/benchmark_utils.py)
- Updated create_mesh to accept a local_mesh boolean parameter (defaults to False). If True, it restricts mesh creation to jax.local_devices().
- Updated the helper functions (handle_per_device_based_on_sharding and handle_all_devices_based_on_sharding) to accept an explicit device_count parameter.
- Updated handle_based_on_sharding to accept an optional device_count parameter (falls back to jax.device_count() when omitted), passing it down to the helpers.
2. Benchmark API Changes (Ironwood/src/benchmark_gemm.py)
- Added run_on_local_node: bool = False in all 5 GEMM benchmark signatures and their metrics calculation functions.
- Passed local_mesh=run_on_local_node to create_mesh.
- Computed device_count (jax.local_device_count() if run_on_local_node is True, else jax.device_count()) and passed it to handle_based_on_sharding to ensure accurate FLOPs/throughput calculation.
Backward Compatibility & Safety
- Other benchmarks (benchmark_compute.py, benchmark_inference_compute.py, etc.) remain untouched.
- Because local_mesh defaults to False and device_count defaults to None (falling back to the global jax.device_count()), all untouched benchmarks will continue to execute and report metrics across the full slice out of the box.