[JAX] Improve JAX tutorial documentation#2976
Conversation
Greptile SummaryThis PR replaces the existing JAX integration notebook (
Confidence Score: 5/5Documentation-only refactor with no functional code path changes; safe to merge. All changes are documentation and CI scaffolding. The RST files and Python tutorial are well-structured, conftest.py correctly handles path resolution for pytest, and the CI shell scripts add non-breaking test steps with appropriate skip guards. No files require special attention beyond the section numbering nit in dense.rst. Important Files Changed
Flowchart%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[docs/index.rst] --> B[te_jax_integration.rst hub page]
B --> C[jax_examples/dense.rst Dense GEMM tutorial]
B --> D[jax_examples/attention.rst Coming soon]
B --> E[jax_examples/moe.rst Coming soon]
C -->|literalinclude| F[dense.py companion source]
C -->|literalinclude| G[dense.out captured output]
F -->|pytest via conftest.py| H[CI: L0_jax_unittest single-GPU]
F -->|pytest -k multi_gpu| I[CI: L1_jax_distributed 4-GPU]
J[conftest.py sys.path fix] -.->|loaded before import| F
Reviews (3): Last reviewed commit: "Merge branch 'main' into jberchtold/impr..." | Re-trigger Greptile |
| @@ -0,0 +1,446 @@ | |||
| { | |||
There was a problem hiding this comment.
Missing sections 4 and 5 — numbering jumps from § 3 to § 6
The notebook's section headings go ## 1 → ## 2 → ## 3. Single-GPU performance → ## 6. Multi-GPU → ## 7. Collective GEMM (placeholder), with no trace of sections 4 or 5. Unlike § 7, there are no "Coming soon" placeholders for them either. A reader following the numbered flow will assume content was accidentally deleted. If these sections are planned but not yet written, add stub cells similar to the ## 7 placeholder; if the numbering was simply mis-applied, renumber the existing headings to be consecutive (3 → 4 for Multi-GPU, 4 → 5 for Collective GEMM).
| @@ -0,0 +1,446 @@ | |||
| { | |||
| "\n", | ||
| "**TODO — Coming soon.**\n", | ||
| "\n", | ||
| "[← Back to the JAX integration overview](../te_jax_integration.ipynb)" |
There was a problem hiding this comment.
I will remove this placeholder before merging, but let me know if you have suggestions for the main hub notebook and how these sub-tutorials will be organized
| "source": [ | ||
| "## 7. Collective GEMM (placeholder)\n", | ||
| "\n", | ||
| "*Coming soon.*" |
There was a problem hiding this comment.
I will remove this placeholder before merging, but let me know if you have feedback on where it fits w.r.t the rest of the tutorial
| "\n", | ||
| "**TODO — Coming soon.**\n", | ||
| "\n", | ||
| "This notebook will cover TE's `MoEBlock` layer which utilizes TE's optimized routing, permutation and grouped GEMM\n", |
There was a problem hiding this comment.
I will remove this placeholder before merging, but let me know if you have suggestions for the main hub notebook and how these sub-tutorials will be organized
| { | ||
| "cell_type": "markdown", | ||
| "id": "intro-md", | ||
| "metadata": {}, |
There was a problem hiding this comment.
Reworking the existing getting started tutorials that are merged with PyTorch tutorials will be a follow-up PR
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
for more information, see https://pre-commit.ci
| model_apply_fn=te_model.apply, | ||
| variables=te_vars, | ||
| input=x, | ||
| output_grad=dy, |
There was a problem hiding this comment.
sys.path.append("..") breaks standalone execution from the repo root
dense.out documents that the canonical way to regenerate captured output is python3 docs/examples/jax_examples/dense.py > dense.out, run from the repo root. At that point ".." resolves to the parent of the repo root, not to docs/examples/ where quickstart_jax_utils lives, so import quickstart_jax_utils as utils raises ModuleNotFoundError. In pytest mode conftest.py inserts the correct absolute path before the module is imported, masking the problem in CI, but the standalone invocation (and the snippet shown to users in the tutorial) breaks. conftest.py already shows the right pattern — use os.path.dirname(os.path.abspath(__file__)) to construct an absolute path instead.
|
/te-ci L1 L0 |
KshitijLakhani
left a comment
There was a problem hiding this comment.
Thanks for adding this skeleton.
I like the modular approach, concise explanation and benchmarking.
In general it looks good there might be some working around needed on item placements but I think that's going to be an evolving process.
| .. | ||
| Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
|
|
||
| See LICENSE for license information. | ||
|
|
||
| JAX: Attention with TransformerEngine | ||
| ===================================== | ||
|
|
||
| **TODO — Coming soon.** | ||
|
|
||
| `← Back to the JAX integration overview <../te_jax_integration.html>`_ |
There was a problem hiding this comment.
Unrelated to attention but looks like you are renaming the dir to examples/jax_examples whereas I think the pytorch side is examples/pytorch ?
I think we could stick with examples/jax - thoughts ?
| `Haiku/Flax interop | ||
| <https://dm-haiku.readthedocs.io/en/latest/notebooks/flax.html>`_ if you're on | ||
| a different stack.) | ||
| * **Baseline dtype.** bf16 for inputs and parameters. |
There was a problem hiding this comment.
Should we add GB200 (arch) details here rather than adding it in the example module or is that by choice ?
I think there's value in having all examples run on the same arch for consistency.
| JAX: Attention with TransformerEngine | ||
| ===================================== | ||
|
|
||
| **TODO — Coming soon.** |
There was a problem hiding this comment.
Thanks for the place holder.
cc: @cyanguwa I'll be adding in TE JAX attn examples in here
| # | ||
| # See LICENSE for license information. | ||
|
|
||
| """Pytest conftest for docs/examples/jax_examples. |
| @@ -0,0 +1,22 @@ | |||
| # Numbers below are illustrative (captured on a GB200). Regenerate with: | |||
| # python3 docs/examples/jax_examples/dense.py > dense.out | |||
| # after substantial code changes. | |||
There was a problem hiding this comment.
"after substantial code changes" ?
| and your performance comparison will not be accurate. | ||
|
|
||
|
|
||
| 6. Multi-GPU: DP=2 / TP=2 on a single Dense |
There was a problem hiding this comment.
- Single GPU performane
4,5 ? - Multi-GPU: DP=2 / TP=2 on a single Dense
| export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops" | ||
| NVTE_JAX_CUSTOM_CALLS="false" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder_without_custom_call.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py without custom calls" | ||
|
|
||
| # Exercise the docs/examples/jax_examples tutorials. The multi-GPU tests are |
There was a problem hiding this comment.
Thanks for adding these here.
I think they are pretty light weight so no problem running often via L0
Description
Reworks tutorial to focus on individual operations and their usage+performance. This will make it clearer to users the impact of each operation and they can focus on trying them out one-at-a-time depending on which are bottlenecks in their models.
Additionally, this switches from notebook
.ipynbfiles to.rstand separate.pyfiles for easier testing in CI to ensure our docs do not become stale and always work with the latest TE version.Type of change
Changes
Checklist: