[JAX] Improve JAX tutorial documentation#2976

Open
jberchtold-nvidia wants to merge 7 commits into
NVIDIA:mainfrom
jberchtold-nvidia:jberchtold/improve-jax-tutorial

Conversation

@jberchtold-nvidia
Collaborator

@jberchtold-nvidia jberchtold-nvidia commented May 11, 2026

Description

Reworks the tutorial to focus on individual operations and their usage and performance. This makes the impact of each operation clearer to users, who can then try operations out one at a time depending on which are bottlenecks in their models.

Additionally, this switches from .ipynb notebooks to .rst pages with separate .py files, enabling easier testing in CI so our docs do not go stale and always work with the latest TE version.

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Rework existing tutorial and replace with new Dense-specific tutorial
  • Placeholders for Attention and MoE
  • Refactor .ipynb notebooks into .rst and .py files: the rendered docs look similar, but the .py files can be run directly in CI for better testability
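The .rst-plus-.py split described above relies on the tutorial source doubling as a test module. A minimal sketch of that pattern (function names like `run_dense_example` are illustrative, not taken from the PR):

```python
# Sketch of a .py tutorial file that works both standalone and under pytest.
# The .rst page would literalinclude this source; CI collects the test_* functions.


def run_dense_example(batch: int = 4, features: int = 8) -> int:
    """Stand-in for a tutorial step; a real file would run a TE/JAX Dense layer.

    Returns the number of output elements, as a trivially checkable result.
    """
    return batch * features


def test_single_gpu():
    # Collected by pytest in CI; also callable directly when run standalone.
    assert run_dense_example() == 32


if __name__ == "__main__":
    # Standalone mode: `python3 dense.py > dense.out` captures printed output
    # for the docs, while pytest ignores this block entirely.
    print("output elements:", run_dense_example())
    test_single_gpu()
```

The same file thus serves three audiences: readers of the rendered docs, users running it by hand, and the CI tiers invoking pytest.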

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@greptile-apps
Contributor

greptile-apps Bot commented May 11, 2026

Greptile Summary

This PR replaces the existing JAX integration notebook (te_jax_integration.ipynb) with a restructured tutorial suite consisting of RST hub pages and separate .py source files, making documentation testable in CI via pytest.

  • Adds a focused Dense GEMM tutorial (dense.rst + dense.py) covering single-GPU quantization, multi-GPU (DP+TP) sharding, and a placeholder for Collective GEMM; placeholder stubs are also added for Attention and MoE.
  • Introduces a conftest.py that correctly resolves the quickstart_jax_utils import path for pytest, and wires the new .py examples into both single-GPU (L0) and distributed (L1) CI tiers.
  • Replaces the old notebook-based toctree entry in docs/index.rst with the new .rst hub page.

Confidence Score: 5/5

Documentation-only refactor with no functional code path changes; safe to merge.

All changes are documentation and CI scaffolding. The RST files and Python tutorial are well-structured, conftest.py correctly handles path resolution for pytest, and the CI shell scripts add non-breaking test steps with appropriate skip guards.

No files require special attention beyond the section numbering nit in dense.rst.

Important Files Changed

Filename Overview
docs/examples/jax_examples/dense.rst New RST tutorial for Dense GEMMs. Section numbering jumps from §3 to §6 with no placeholder stubs for §4 and §5 (unlike §7 which has a Coming soon stub).
docs/examples/jax_examples/dense.py New tutorial source file with pytest test functions guarded by requires_mxfp8 skip marks.
docs/examples/jax_examples/conftest.py New pytest conftest that correctly adds docs/examples/ to sys.path using an absolute path derived from file.
docs/examples/te_jax_integration.rst New RST hub/landing page replacing the deleted te_jax_integration.ipynb.
qa/L0_jax_unittest/test.sh Adds a pytest run of docs/examples/jax_examples/ to the single-GPU CI tier.
qa/L1_jax_distributed_unittest/test.sh Adds a -k multi_gpu filtered pytest run of the tutorial to the distributed CI tier.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[docs/index.rst] --> B[te_jax_integration.rst hub page]
    B --> C[jax_examples/dense.rst Dense GEMM tutorial]
    B --> D[jax_examples/attention.rst Coming soon]
    B --> E[jax_examples/moe.rst Coming soon]
    C -->|literalinclude| F[dense.py companion source]
    C -->|literalinclude| G[dense.out captured output]
    F -->|pytest via conftest.py| H[CI: L0_jax_unittest single-GPU]
    F -->|pytest -k multi_gpu| I[CI: L1_jax_distributed 4-GPU]
    J[conftest.py sys.path fix] -.->|loaded before import| F

Reviews (3): Last reviewed commit: "Merge branch 'main' into jberchtold/impr..."

Comment thread docs/examples/jax_examples/dense.ipynb Outdated
@@ -0,0 +1,446 @@
{
Contributor


P1 Missing sections 4 and 5 — numbering jumps from § 3 to § 6

The notebook's section headings run `## 1`, `## 2`, `## 3. Single-GPU performance`, `## 6. Multi-GPU`, `## 7. Collective GEMM (placeholder)`, with no trace of sections 4 or 5. Unlike § 7, there are no "Coming soon" placeholders for them either. A reader following the numbered flow will assume content was accidentally deleted. If these sections are planned but not yet written, add stub cells similar to the ## 7 placeholder; if the numbering was simply mis-applied, renumber the existing headings to be consecutive (6 → 4 for Multi-GPU, 7 → 5 for Collective GEMM).

Comment thread docs/examples/jax_examples/dense.ipynb Outdated
@@ -0,0 +1,446 @@
{
Contributor


P2 Unused warnings import

import warnings is imported in the setup cell but never referenced anywhere in the notebook. Remove it to keep the imports clean.

"\n",
"**TODO — Coming soon.**\n",
"\n",
"[← Back to the JAX integration overview](../te_jax_integration.ipynb)"
Collaborator Author


I will remove this placeholder before merging, but let me know if you have suggestions for the main hub notebook and for how these sub-tutorials should be organized.

Comment thread docs/examples/jax_examples/dense.ipynb Outdated
"source": [
"## 7. Collective GEMM (placeholder)\n",
"\n",
"*Coming soon.*"
Collaborator Author


I will remove this placeholder before merging, but let me know if you have feedback on where it fits w.r.t. the rest of the tutorial.

Comment thread docs/examples/jax_examples/moe.ipynb Outdated
"\n",
"**TODO — Coming soon.**\n",
"\n",
"This notebook will cover TE's `MoEBlock` layer which utilizes TE's optimized routing, permutation and grouped GEMM\n",
Collaborator Author


I will remove this placeholder before merging, but let me know if you have suggestions for the main hub notebook and for how these sub-tutorials should be organized.

Comment thread docs/examples/jax_examples/dense.ipynb Outdated
{
"cell_type": "markdown",
"id": "intro-md",
"metadata": {},
Collaborator Author


Reworking the existing getting-started tutorials, which are merged with the PyTorch tutorials, will be a follow-up PR.

model_apply_fn=te_model.apply,
variables=te_vars,
input=x,
output_grad=dy,
Contributor


P1 sys.path.append("..") breaks standalone execution from the repo root

dense.out documents that the canonical way to regenerate captured output is python3 docs/examples/jax_examples/dense.py > dense.out, run from the repo root. At that point ".." resolves to the parent of the repo root, not to docs/examples/ where quickstart_jax_utils lives, so import quickstart_jax_utils as utils raises ModuleNotFoundError. In pytest mode conftest.py inserts the correct absolute path before the module is imported, masking the problem in CI, but the standalone invocation (and the snippet shown to users in the tutorial) breaks. conftest.py already shows the right pattern — use os.path.dirname(os.path.abspath(__file__)) to construct an absolute path instead.

@jberchtold-nvidia
Collaborator Author

/te-ci L1 L0

Collaborator

@KshitijLakhani KshitijLakhani left a comment


Thanks for adding this skeleton.
I like the modular approach, concise explanation and benchmarking.

In general it looks good. There may be some reworking needed on item placement, but I think that's going to be an evolving process.

Comment on lines +1 to +11
..
Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

See LICENSE for license information.

JAX: Attention with TransformerEngine
=====================================

**TODO — Coming soon.**

`← Back to the JAX integration overview <../te_jax_integration.html>`_
Collaborator


Unrelated to attention, but it looks like you are naming the dir examples/jax_examples, whereas I think the PyTorch side is examples/pytorch?
I think we could stick with examples/jax. Thoughts?

`Haiku/Flax interop
<https://dm-haiku.readthedocs.io/en/latest/notebooks/flax.html>`_ if you're on
a different stack.)
* **Baseline dtype.** bf16 for inputs and parameters.
Collaborator


Should we add the GB200 (arch) details here rather than in the example module, or is that by choice?
I think there's value in having all examples run on the same arch for consistency.

JAX: Attention with TransformerEngine
=====================================

**TODO — Coming soon.**
Collaborator


Thanks for the placeholder.
cc: @cyanguwa I'll be adding in TE JAX attn examples in here

#
# See LICENSE for license information.

"""Pytest conftest for docs/examples/jax_examples.
Collaborator


I agree with the usage of pytest in general; however, I think examples/mnist currently uses the built-in Python unittest module for its test example.
@phu0ngng and @tdophung it might be good to standardize and use pytest in there too — thoughts?

@@ -0,0 +1,22 @@
# Numbers below are illustrative (captured on a GB200). Regenerate with:
# python3 docs/examples/jax_examples/dense.py > dense.out
# after substantial code changes.
Collaborator


"after substantial code changes" ?

and your performance comparison will not be accurate.


6. Multi-GPU: DP=2 / TP=2 on a single Dense
Collaborator


  1. Single GPU performance
    4, 5?
  2. Multi-GPU: DP=2 / TP=2 on a single Dense

export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
NVTE_JAX_CUSTOM_CALLS="false" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder_without_custom_call.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py without custom calls"

# Exercise the docs/examples/jax_examples tutorials. The multi-GPU tests are
Collaborator


Thanks for adding these here.
I think they are pretty lightweight, so there's no problem running them often via L0.
