diff --git a/docs/examples/jax_examples/attention.rst b/docs/examples/jax_examples/attention.rst
new file mode 100644
index 0000000000..c9f84da634
--- /dev/null
+++ b/docs/examples/jax_examples/attention.rst
@@ -0,0 +1,11 @@
+..
+ Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+
+ See LICENSE for license information.
+
+JAX: Attention with TransformerEngine
+=====================================
+
+**TODO — Coming soon.**
+
+`← Back to the JAX integration overview <../te_jax_integration.html>`_
diff --git a/docs/examples/jax_examples/conftest.py b/docs/examples/jax_examples/conftest.py
new file mode 100644
index 0000000000..a584e7392e
--- /dev/null
+++ b/docs/examples/jax_examples/conftest.py
@@ -0,0 +1,17 @@
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+"""Pytest conftest for docs/examples/jax_examples.
+
+Adds ``docs/examples/`` to ``sys.path`` so the example modules can do
+``import quickstart_jax_utils`` regardless of the directory pytest was invoked
+from.
+"""
+import os
+import sys
+
+_HERE = os.path.dirname(os.path.abspath(__file__))
+_EXAMPLES_ROOT = os.path.dirname(_HERE)
+if _EXAMPLES_ROOT not in sys.path:
+ sys.path.insert(0, _EXAMPLES_ROOT)
diff --git a/docs/examples/jax_examples/dense.out b/docs/examples/jax_examples/dense.out
new file mode 100644
index 0000000000..d6628b2f82
--- /dev/null
+++ b/docs/examples/jax_examples/dense.out
@@ -0,0 +1,22 @@
+# Numbers below are illustrative (captured on a GB200). Regenerate with:
+# python3 docs/examples/jax_examples/dense.py > dense.out
+# after substantial code changes.
+
+# SINGLE_GPU_OUTPUT_START
+Variable collections: ['params']
+{'params': {'Dense_0': {'kernel': ((4096, 16384), dtype('float32'))}}}
+
+bf16 baseline:
+Mean time: 4.126 ms
+
+TE MXFP8BlockScaling:
+Mean time: 1.690 ms
+# SINGLE_GPU_OUTPUT_END
+
+# MULTI_GPU_OUTPUT_START
+bf16 DP=2/TP=2:
+Mean time: 1.726 ms
+
+TE MXFP8BlockScaling DP=2/TP=2:
+Mean time: 0.969 ms
+# MULTI_GPU_OUTPUT_END
diff --git a/docs/examples/jax_examples/dense.py b/docs/examples/jax_examples/dense.py
new file mode 100644
index 0000000000..bba341e074
--- /dev/null
+++ b/docs/examples/jax_examples/dense.py
@@ -0,0 +1,238 @@
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+"""JAX: Dense GEMMs with TransformerEngine.
+
+Companion source for ``dense.rst``. Code blocks between ``# DENSE_*_START`` /
+``# DENSE_*_END`` markers are pulled into the RST via ``literalinclude``.
+
+Run as a pytest module to exercise the example end-to-end:
+
+ pytest -v docs/examples/jax_examples/dense.py
+
+The multi-GPU section auto-skips when fewer than 4 GPUs are visible.
+"""
+
+# DENSE_IMPORTS_START
+import os
+import sys
+
+# Make quickstart_jax_utils (in docs/examples/) importable regardless of the
+# directory this script is launched from; under pytest, conftest.py does this.
+sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))
+
+import jax
+import jax.numpy as jnp
+from flax import linen as nn
+
+import quickstart_jax_utils as utils
+
+# DENSE_IMPORTS_END
+
+
+# DENSE_BASELINE_MODEL_START
+class FlaxDenseBlock(nn.Module):
+ """One linear layer. ``dot_general_cls`` lets us swap the GEMM impl."""
+
+ features: int
+ dot_general_cls: callable = lambda: None
+
+ @nn.compact
+ def __call__(self, x):
+ return nn.Dense(
+ features=self.features,
+ use_bias=False,
+ dot_general=self.dot_general_cls(),
+ )(x)
+
+
+# DENSE_BASELINE_MODEL_END
+
+
+# DENSE_INPUTS_SETUP_START
+batch, seq, hidden, out_features = 4, 2048, 4096, 16384
+dtype = jnp.bfloat16
+
+key = jax.random.PRNGKey(0)
+k_init, k_x, k_dy = jax.random.split(key, 3)
+x = jax.random.normal(k_x, (batch, seq, hidden)).astype(dtype)
+dy = jax.random.normal(k_dy, (batch, seq, out_features)).astype(dtype)
+
+baseline = FlaxDenseBlock(features=out_features)
+baseline_vars = baseline.init(k_init, x)
+# DENSE_INPUTS_SETUP_END
+
+
+# DENSE_TE_SETUP_START
+from transformer_engine.jax import flax as te_flax
+from transformer_engine.common.recipe import MXFP8BlockScaling
+
+recipe = MXFP8BlockScaling()
+te_dot_general_cls = te_flax.make_dot_general_cls(recipe)
+
+te_model = FlaxDenseBlock(features=out_features, dot_general_cls=te_dot_general_cls)
+te_vars = te_model.init(k_init, x)
+
+print("Variable collections:", list(te_vars.keys()))
+print(jax.tree_util.tree_map(lambda a: (a.shape, a.dtype), te_vars))
+# DENSE_TE_SETUP_END
+
+
+# DENSE_SINGLE_GPU_BENCH_START
+def run_single_gpu_bench():
+ print("bf16 baseline:")
+ utils.speedometer(
+ model_apply_fn=baseline.apply,
+ variables=baseline_vars,
+ input=x,
+ output_grad=dy,
+ )
+
+ print(f"\nTE {type(recipe).__name__}:")
+ utils.speedometer(
+ model_apply_fn=te_model.apply,
+ variables=te_vars,
+ input=x,
+ output_grad=dy,
+ )
+
+
+# DENSE_SINGLE_GPU_BENCH_END
+
+
+# DENSE_MULTI_GPU_MESH_SETUP_START
+from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
+from jax.experimental import mesh_utils
+from transformer_engine.jax.sharding import MeshResource, global_shard_guard
+
+
+def build_dp_tp_mesh():
+ # 2x2 mesh: DP on one axis, TP on the other.
+ devices = mesh_utils.create_device_mesh((2, 2))
+ mesh = Mesh(devices, axis_names=("dp", "tp"))
+
+ # Tell TE which mesh axis is which. This is a *global* setting, established
+ # outside JIT, so TE's GEMM primitives can plan comms accordingly.
+ mesh_resource = MeshResource(dp_resource="dp", tp_resource="tp")
+ return mesh, mesh_resource
+
+
+# DENSE_MULTI_GPU_MESH_SETUP_END
+
+
+# DENSE_MULTI_GPU_SHARD_SETUP_START
+def shard_variables(mesh, variables_dict):
+ kernel_sharding = NamedSharding(mesh, P(None, "tp"))
+
+ def _shard(variables):
+ params = variables["params"]
+ sharded = jax.device_put(params["Dense_0"]["kernel"], kernel_sharding)
+ return {
+ **variables,
+ "params": {
+ **params,
+ "Dense_0": {**params["Dense_0"], "kernel": sharded},
+ },
+ }
+
+ input_sharding = NamedSharding(mesh, P("dp", None, None))
+ output_grad_sharding = NamedSharding(mesh, P("dp", None, "tp"))
+
+ return {
+ "x": jax.device_put(x, input_sharding),
+ "dy": jax.device_put(dy, output_grad_sharding),
+ **{name: _shard(vars_) for name, vars_ in variables_dict.items()},
+ }
+
+
+# DENSE_MULTI_GPU_SHARD_SETUP_END
+
+
+# DENSE_MULTI_GPU_BENCH_START
+def run_multi_gpu_bench():
+ mesh, mesh_resource = build_dp_tp_mesh()
+ sharded = shard_variables(mesh, {"baseline": baseline_vars, "te": te_vars})
+
+ with jax.set_mesh(mesh), global_shard_guard(mesh_resource):
+ print("bf16 DP=2/TP=2:")
+ utils.speedometer(
+ model_apply_fn=baseline.apply,
+ variables=sharded["baseline"],
+ input=sharded["x"],
+ output_grad=sharded["dy"],
+ )
+
+ print(f"\nTE {type(recipe).__name__} DP=2/TP=2:")
+ utils.speedometer(
+ model_apply_fn=te_model.apply,
+ variables=sharded["te"],
+ input=sharded["x"],
+ output_grad=sharded["dy"],
+ )
+
+
+# DENSE_MULTI_GPU_BENCH_END
+
+
+# -----------------------------------------------------------------------------
+# Pytest entry points (not pulled into docs).
+#
+# These run the same code shown in the snippets above and add numeric / smoke
+# assertions so CI catches regressions.
+# -----------------------------------------------------------------------------
+
+import pytest
+from transformer_engine.jax.quantize import is_scaling_mode_supported, ScalingMode
+
+_mxfp8_supported, _mxfp8_reason = is_scaling_mode_supported(ScalingMode.MXFP8_1D_SCALING)
+requires_mxfp8 = pytest.mark.skipif(
+ not _mxfp8_supported, reason=f"MXFP8 not supported on this device: {_mxfp8_reason}"
+)
+
+
+def test_baseline_runs():
+ out = baseline.apply(baseline_vars, x)
+ assert out.shape == (batch, seq, out_features)
+ assert out.dtype == dtype
+
+
+@requires_mxfp8
+def test_te_dense_runs():
+ out = te_model.apply(te_vars, x)
+ assert out.shape == (batch, seq, out_features)
+
+
+@requires_mxfp8
+def test_te_matches_baseline():
+ """TE quantized Dense should match the bf16 baseline within MXFP8 tolerance."""
+ diffs = utils.compare_fwd_bwd(
+ baseline.apply,
+ baseline_vars,
+ te_model.apply,
+ te_vars,
+ input=x,
+ output_grad=dy,
+ )
+ # MXFP8 quantizes activations / weights, so we accept noticeable rel diff vs bf16.
+ # Tune these in follow-ups once we have real CI numbers.
+ assert diffs["y"]["max_rel"] < 0.20, diffs
+ assert diffs["dx"]["max_rel"] < 0.20, diffs
+ assert diffs["dW"]["max_rel"] < 0.30, diffs
+
+
+@requires_mxfp8
+def test_single_gpu_benchmark():
+ run_single_gpu_bench()
+
+
+@requires_mxfp8
+@pytest.mark.skipif(len(jax.devices()) < 4, reason="needs 4 GPUs for DP=2/TP=2")
+def test_multi_gpu_benchmark():
+ run_multi_gpu_bench()
+
+
+if __name__ == "__main__":
+ run_single_gpu_bench()
+ if len(jax.devices()) >= 4:
+ print()
+ run_multi_gpu_bench()
+ else:
+ print("\n[skipped multi-GPU section: <4 devices visible]")
diff --git a/docs/examples/jax_examples/dense.rst b/docs/examples/jax_examples/dense.rst
new file mode 100644
index 0000000000..e04326c5f5
--- /dev/null
+++ b/docs/examples/jax_examples/dense.rst
@@ -0,0 +1,175 @@
+..
+ Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+
+ See LICENSE for license information.
+
+JAX: Dense GEMMs with TransformerEngine
+=======================================
+
+This document walks through replacing a plain ``flax.linen.Dense``'s GEMM with
+TransformerEngine's quantized GEMM.
+
+**Recipe.** This tutorial uses ``MXFP8BlockScaling``. Both ``MXFP8BlockScaling``
+and ``NVFP4BlockScaling`` require a Blackwell-class GPU; on Hopper, swap in
+``DelayedScaling`` or ``Float8CurrentScaling``.
+
+`← Back to the JAX integration overview <../te_jax_integration.html>`_
+
+1. Baseline: a plain Flax Dense block
+-------------------------------------
+
+We isolate the optimization to a single linear layer so it's clear what's
+changing. ``dot_general_cls`` is exposed as a constructor argument so we can swap
+in TE later without touching the model definition; when it returns ``None``,
+``nn.Dense`` falls back to Flax's default GEMM, ``jax.lax.dot_general``.
+
+.. literalinclude:: dense.py
+ :language: python
+ :start-after: # DENSE_BASELINE_MODEL_START
+ :end-before: # DENSE_BASELINE_MODEL_END
+
+.. literalinclude:: dense.py
+ :language: python
+ :start-after: # DENSE_INPUTS_SETUP_START
+ :end-before: # DENSE_INPUTS_SETUP_END
+
+
+2. Quantized Dense via ``make_dot_general_cls``
+-----------------------------------------------
+
+TE exposes a helper, ``te_flax.make_dot_general_cls(recipe)``, that returns a Flax
+module class you pass directly to ``nn.Dense(..., dot_general=...)``.
+
+With this API, TE doesn't create the ``kernel`` params; it only wraps the GEMM.
+All your initialization, sharding annotations, and optimizer state stay where
+they were.
+
+.. literalinclude:: dense.py
+ :language: python
+ :start-after: # DENSE_TE_SETUP_START
+ :end-before: # DENSE_TE_SETUP_END
+
+.. note::
+
+ **What about DelayedScaling state?**
+
+ Most recipes are stateless — scaling factors are computed from each tensor
+ as it flows through the GEMM, so there is nothing to persist across steps.
+ However, if you swap in ``DelayedScaling`` instead, ``init`` will produce a
+ second variable collection, ``_overwrite_with_gradient``, holding
+ ``kernel_amax_history``, ``kernel_scale``, ``x_amax_history``, ``x_scale``,
+ etc. These are **not** model parameters — they are Flax variables that TE
+ updates each step to compute per-tensor scales from a rolling amax window.
+
+ If you use ``DelayedScaling``, you must thread the *entire* ``var_collect``
+ through your training loop (not just ``params``) so the history persists
+ across steps. ``MXFP8BlockScaling``, ``NVFP4BlockScaling``, and
+ ``Float8CurrentScaling`` do not require this.
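+
+A minimal sketch of that threading, assuming ``DelayedScaling`` and assuming the
+collection follows Flax's usual overwrite-with-gradient convention (the
+"gradient" of those variables carries their updated values out of the backward
+pass). The SGD update and loss are placeholders; ``te_model``, ``k_init``,
+``x``, and ``dy`` are reused from this page:
+
+.. code-block:: python
+
+   var_collect = te_model.init(k_init, x)  # {'params', '_overwrite_with_gradient'}
+
+   def train_step(var_collect, x, dy):
+       def loss_fn(vc):
+           y = te_model.apply(vc, x)
+           return jnp.sum(y * dy)  # placeholder loss
+
+       grads = jax.grad(loss_fn)(var_collect)
+
+       # Ordinary SGD on the real parameters...
+       new_params = jax.tree_util.tree_map(
+           lambda p, g: p - 1e-4 * g, var_collect["params"], grads["params"]
+       )
+       # ...while the TE variables are overwritten with their "gradients",
+       # which carry the refreshed amax history into the next step.
+       return {**grads, "params": new_params}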
+
+
+3. Single-GPU performance
+-------------------------
+
+``speedometer`` runs a JIT-compiled forward+backward loop with warmup, on the
+same input for both models.
+
+.. literalinclude:: dense.py
+ :language: python
+ :start-after: # DENSE_SINGLE_GPU_BENCH_START
+ :end-before: # DENSE_SINGLE_GPU_BENCH_END
+
+Output:
+
+.. container:: program-output
+
+ .. literalinclude:: dense.out
+ :language: text
+ :start-after: # SINGLE_GPU_OUTPUT_START
+ :end-before: # SINGLE_GPU_OUTPUT_END
+
+On a single GB200, that's roughly **2.5× faster** for the fwd+bwd of one large
+Dense — and the only code change was passing ``dot_general=te_dot_general_cls()``
+into ``nn.Dense``.
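+
+That is, the entire migration for this layer fits in one argument:
+
+.. code-block:: python
+
+   # before
+   nn.Dense(features=out_features, use_bias=False)(x)
+   # after
+   nn.Dense(features=out_features, use_bias=False, dot_general=te_dot_general_cls())(x)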
+
+The speedup depends on shape: large GEMMs benefit most. Very small GEMMs may
+not benefit at all because the cast + scale overhead can dominate.
+
+.. warning::
+
+ **Remat / activation checkpointing.** If your training loop uses
+ ``jax.checkpoint_policies.checkpoint_dots`` (or any policy that matches
+ ``jax.lax.dot_general``), swap it for
+ ``transformer_engine.jax.checkpoint_policies.checkpoint_dots_and_te_gemms``.
+ Otherwise TE's quantized GEMM primitives won't be checkpointed correctly
+ and your performance comparison will not be accurate.
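+
+A minimal sketch of the swap (``loss_fn`` is a hypothetical function containing
+TE Dense layers; only the ``policy`` argument changes):
+
+.. code-block:: python
+
+   import jax
+   from transformer_engine.jax import checkpoint_policies as te_cp
+
+   # Before: jax.checkpoint(loss_fn, policy=jax.checkpoint_policies.checkpoint_dots)
+   # would rematerialize TE's quantized GEMMs.
+   remat_loss_fn = jax.checkpoint(
+       loss_fn, policy=te_cp.checkpoint_dots_and_te_gemms
+   )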
+
+
+4. Multi-GPU: DP=2 / TP=2 on a single Dense
+-------------------------------------------
+
+**Prerequisite:** this section requires four GPUs.
+
+Keeping the same ``FlaxDenseBlock`` from the rest of the document, we run it on
+a 2×2 mesh with **data parallelism** on one axis and **tensor parallelism**
+(column-parallel: shard the kernel's output dim) on the other.
+
+Two pieces wire this up:
+
+1. A ``jax.sharding.Mesh`` you build once at module scope (outside JIT).
+2. TE's ``MeshResource``, set globally via ``global_shard_guard``, which tells
+ TE which mesh axes are DP and TP.
+
+.. literalinclude:: dense.py
+ :language: python
+ :start-after: # DENSE_MULTI_GPU_MESH_SETUP_START
+ :end-before: # DENSE_MULTI_GPU_MESH_SETUP_END
+
+**Sharding plan:**
+
+.. csv-table::
+ :header: "Tensor", "Shape", "PartitionSpec"
+ :widths: 30, 40, 30
+
+ "Kernel (column-parallel)", "``(hidden, out_features)``", "``P(None, 'tp')``"
+ "Input activations", "``(batch, seq, hidden)``", "``P('dp', None, None)``"
+ "Gradient on output", "``(batch, seq, out_features)``", "``P('dp', None, 'tp')``"
+
+.. literalinclude:: dense.py
+ :language: python
+ :start-after: # DENSE_MULTI_GPU_SHARD_SETUP_START
+ :end-before: # DENSE_MULTI_GPU_SHARD_SETUP_END
+
+.. literalinclude:: dense.py
+ :language: python
+ :start-after: # DENSE_MULTI_GPU_BENCH_START
+ :end-before: # DENSE_MULTI_GPU_BENCH_END
+
+Output:
+
+.. container:: program-output
+
+ .. literalinclude:: dense.out
+ :language: text
+ :start-after: # MULTI_GPU_OUTPUT_START
+ :end-before: # MULTI_GPU_OUTPUT_END
+
+
+5. Collective GEMM (placeholder)
+--------------------------------
+
+*Coming soon.*
+
+
+Next steps
+----------
+
+* `Attention <attention.html>`_
+* `Mixture of Experts <moe.html>`_
+* `← Hub <../te_jax_integration.html>`_
diff --git a/docs/examples/jax_examples/moe.rst b/docs/examples/jax_examples/moe.rst
new file mode 100644
index 0000000000..fb1c8496ba
--- /dev/null
+++ b/docs/examples/jax_examples/moe.rst
@@ -0,0 +1,17 @@
+..
+ Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+
+ See LICENSE for license information.
+
+JAX: Mixture of Experts with TransformerEngine
+==============================================
+
+**TODO — Coming soon.**
+
+This document will cover TE's ``MoEBlock`` layer, which uses TE's optimized
+routing, permutation, and grouped-GEMM kernels:
+
+* single-GPU ``MoEBlock`` usage vs. ``jax.lax.ragged_dot``
+* expert-parallel sharding considerations
+
+`← Back to the JAX integration overview <../te_jax_integration.html>`_
diff --git a/docs/examples/te_jax_integration.ipynb b/docs/examples/te_jax_integration.ipynb
deleted file mode 100644
index 66d16ed52f..0000000000
--- a/docs/examples/te_jax_integration.ipynb
+++ /dev/null
@@ -1,462 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "id": "962d87bb",
- "metadata": {},
- "source": [
- "\n",
- "\n",
- "# JAX: Integrating TE into an existing framework\n",
- "\n",
- "This tutorial will cover how to integrate TransformerEngine into an existing JAX model framework, such as [MaxText's TE integration](https://github.com/AI-Hypercomputer/maxtext/blob/ed517cf80d9aa81f76e236c5516dacebfe39e96d/src/MaxText/layers/quantizations.py#L753) or your own model framework. \n"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "b36876bb",
- "metadata": {},
- "source": [
- "Let's start with a standard JAX+Flax Transformer layer"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "id": "d5284a38",
- "metadata": {},
- "outputs": [],
- "source": [
- "import jax\n",
- "import jax.numpy as jnp\n",
- "from flax import linen as nn\n",
- "import quickstart_jax_utils as utils\n",
- "from typing import Optional"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 11,
- "id": "a4d1cfdc",
- "metadata": {},
- "outputs": [],
- "source": [
- "class FlaxMLP(nn.Module):\n",
- " \"\"\"Feed-forward network in Transformer layer\n",
- " Built with plain Flax modules.\n",
- " \"\"\"\n",
- " hidden_size: int\n",
- " ffn_hidden_size: int\n",
- " dot_general_cls: callable = lambda: None\n",
- "\n",
- " @nn.compact\n",
- " def __call__(self, x: jnp.ndarray) -> jnp.ndarray:\n",
- " x = nn.Dense(features=self.ffn_hidden_size, use_bias=True, dot_general=self.dot_general_cls())(x)\n",
- " x = nn.gelu(x, approximate=True) # equivalent to tanh approximation\n",
- " x = nn.Dense(features=self.hidden_size, use_bias=True, dot_general=self.dot_general_cls())(x)\n",
- " return x\n",
- "\n",
- "class FlaxTransformerLayer(nn.Module):\n",
- " \"\"\"Basic Transformer layer using plain Flax modules\"\"\"\n",
- " hidden_size: int\n",
- " ffn_hidden_size: int\n",
- " num_attention_heads: int\n",
- " layernorm_eps: float = 1e-5\n",
- " attention_dropout: float = 0.1\n",
- " dot_general_cls: callable = lambda: None\n",
- " \n",
- " def setup(self):\n",
- " self.kv_channels = self.hidden_size // self.num_attention_heads\n",
- "\n",
- " @nn.compact\n",
- " def __call__(\n",
- " self, \n",
- " x: jnp.ndarray, \n",
- " attention_mask: Optional[jnp.ndarray] = None,\n",
- " deterministic: bool = False\n",
- " ) -> jnp.ndarray:\n",
- " # Create causal mask if not provided\n",
- " if attention_mask is None:\n",
- " attention_mask = nn.make_causal_mask(x[..., 0], dtype=jnp.bool_)\n",
- " \n",
- " res = x\n",
- " x = nn.LayerNorm(epsilon=self.layernorm_eps)(x)\n",
- " \n",
- " # Fused QKV projection\n",
- " qkv = nn.Dense(features=3 * self.hidden_size, use_bias=True, dot_general=self.dot_general_cls())(x)\n",
- " qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], self.num_attention_heads, 3 * self.kv_channels)\n",
- " q, k, v = jnp.split(qkv, 3, axis=3)\n",
- " \n",
- " # q, k, v now have shape [batch, seq_len, num_heads, kv_channels]\n",
- " # which is the correct format for dot_product_attention\n",
- " \n",
- " # Apply dot product attention\n",
- " # Note: dot_product_attention expects mask to be broadcastable to \n",
- " # [batch, num_heads, q_length, kv_length], but attention_mask from \n",
- " # nn.make_causal_mask has shape [batch, 1, seq_len, seq_len]\n",
- " \n",
- " # Generate dropout RNG key when needed (not deterministic and dropout_rate > 0)\n",
- " dropout_rng = None\n",
- " if not deterministic and self.attention_dropout > 0:\n",
- " dropout_rng = self.make_rng('dropout')\n",
- " \n",
- " # See quickstart_jax.ipynb for details on using TE's faster fused attention\n",
- " x = nn.dot_product_attention(\n",
- " query=q,\n",
- " key=k,\n",
- " value=v,\n",
- " mask=attention_mask,\n",
- " dropout_rng=dropout_rng,\n",
- " dropout_rate=self.attention_dropout,\n",
- " deterministic=deterministic,\n",
- " broadcast_dropout=True,\n",
- " )\n",
- " \n",
- " # Reshape output from [batch, seq_len, num_heads, kv_channels] to [batch, seq_len, hidden_size]\n",
- " x = x.reshape(x.shape[0], x.shape[1], self.hidden_size)\n",
- "\n",
- " # Output projection\n",
- " x = nn.Dense(features=self.hidden_size, use_bias=True, dot_general=self.dot_general_cls())(x)\n",
- " \n",
- " x = res + x\n",
- " \n",
- " # Second residual connection\n",
- " res = x\n",
- " x = nn.LayerNorm(epsilon=self.layernorm_eps)(x)\n",
- " \n",
- " # MLP\n",
- " mlp = FlaxMLP(\n",
- " hidden_size=self.hidden_size,\n",
- " ffn_hidden_size=self.ffn_hidden_size,\n",
- " dot_general_cls=self.dot_general_cls,\n",
- " )\n",
- " x = mlp(x)\n",
- " \n",
- " return x + res\n"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "db16bf70",
- "metadata": {},
- "source": [
- "We've exposed `dot_general_cls` here so we can test out different GEMM implementations later. By default, Flax's `nn.Dense` will use JAX's GEMM `jax.lax.dot_general` when `dot_general` is `None`."
- ]
- },
- {
- "cell_type": "markdown",
- "id": "fbc3510b",
- "metadata": {},
- "source": [
- "## Testing Performance\n",
- "\n",
- "Now let's test the performance of our FlaxTransformerLayer:\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "id": "8b44649d",
- "metadata": {},
- "outputs": [],
- "source": [
- "# Layer configuration\n",
- "hidden_size = 4096\n",
- "sequence_length = 2048\n",
- "batch_size = 4\n",
- "ffn_hidden_size = 16384\n",
- "num_attention_heads = 32\n",
- "dtype = jnp.bfloat16\n",
- "\n",
- "# Synthetic data\n",
- "key, dropout_key = jax.random.split(jax.random.PRNGKey(42))\n",
- "x = jax.random.normal(key, (batch_size, sequence_length, hidden_size)).astype(dtype)\n",
- "dy = jax.random.normal(key, (batch_size, sequence_length, hidden_size)).astype(dtype)\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "id": "e44ed26d",
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Pure Flax FlaxTransformerLayer initialized successfully!\n",
- "Parameter shapes: {'params': {'Dense_0': {'bias': (12288,), 'kernel': (4096, 12288)}, 'Dense_1': {'bias': (4096,), 'kernel': (4096, 4096)}, 'FlaxMLP_0': {'Dense_0': {'bias': (16384,), 'kernel': (4096, 16384)}, 'Dense_1': {'bias': (4096,), 'kernel': (16384, 4096)}}, 'LayerNorm_0': {'bias': (4096,), 'scale': (4096,)}, 'LayerNorm_1': {'bias': (4096,), 'scale': (4096,)}}}\n"
- ]
- }
- ],
- "source": [
- "# Initialize the FlaxTransformerLayer\n",
- "flax_transformer = FlaxTransformerLayer(\n",
- " hidden_size=hidden_size,\n",
- " ffn_hidden_size=ffn_hidden_size,\n",
- " num_attention_heads=num_attention_heads,\n",
- ")\n",
- "\n",
- "# Initialize parameters\n",
- "params = flax_transformer.init(key, x, attention_mask=None, deterministic=False)\n",
- "\n",
- "print(\"Pure Flax FlaxTransformerLayer initialized successfully!\")\n",
- "print(f\"Parameter shapes: {jax.tree_util.tree_map(lambda x: x.shape, params)}\")\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "id": "de91af7a",
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Input shape: (4, 2048, 4096)\n",
- "Output shape: (4, 2048, 4096)\n",
- "Output dtype: float32\n",
- "Forward pass completed successfully!\n"
- ]
- }
- ],
- "source": [
- "# Example usage of forward pass\n",
- "y = flax_transformer.apply(params, x, attention_mask=None, deterministic=True)\n",
- "print(f\"Input shape: {x.shape}\")\n",
- "print(f\"Output shape: {y.shape}\")\n",
- "print(f\"Output dtype: {y.dtype}\")\n",
- "print(\"Forward pass completed successfully!\")\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "id": "037bc8d9",
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Mean time: 18.83516788482666 ms\n"
- ]
- }
- ],
- "source": [
- "import importlib\n",
- "import quickstart_jax_utils\n",
- "importlib.reload(quickstart_jax_utils)\n",
- "\n",
- "utils.speedometer(\n",
- " model_apply_fn=flax_transformer.apply,\n",
- " variables=params,\n",
- " input=x,\n",
- " output_grad=dy,\n",
- " forward_kwargs={\"attention_mask\": None, \"deterministic\": False},\n",
- " rngs={\"dropout\": dropout_key},\n",
- ")"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "5e9310c9",
- "metadata": {},
- "source": [
- "## Transformer Engine"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "1f8e213e",
- "metadata": {},
- "source": [
- "TransformerEngine/JAX is currently using Flax Linen. However, it is easily compatible with Flax NNX or Haiku.\n",
- "* [Use Flax NNX and Linen together](https://flax.readthedocs.io/en/latest/guides/bridge_guide.html)\n",
- "* [Haiku and Flax interop](https://dm-haiku.readthedocs.io/en/latest/notebooks/flax.html)\n",
- "\n",
- "Additionally, with the tutorial below, no model parameters need to be managed by TransformerEngine. You can keep all your existing model parameters, initialization, and sharding the same. The only change required is to call TE's dot_general_cls instead of the default Dense dot_general implementation. TE's dot_general_cls is a small module that performs a quantized dense VJP and stores some small recipe-specific state."
- ]
- },
- {
- "cell_type": "markdown",
- "id": "4477d4e9",
- "metadata": {},
- "source": [
- "Now we'll select a recipe. `DelayedScaling` and `CurrentScaling` use per-tensor scaling and are supported on Hopper and Blackwell. `MXFP8BlockScaling` and `NVFP4BlockScaling` use block scaling or a combination of both per-tensor and block scaling and are supported on Blackwell.\n",
- "\n",
- "If you would like to customize the recipe further, various options can be changed by passing args to the recipe's constructor."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "id": "5ddf41e7",
- "metadata": {},
- "outputs": [],
- "source": [
- "from transformer_engine.common.recipe import DelayedScaling, Float8CurrentScaling, MXFP8BlockScaling, NVFP4BlockScaling\n",
- "from transformer_engine.jax import flax as te_flax \n",
- "\n",
- "# Choose a quantization recipe. This can be modified to any of the recipes imported above.\n",
- "quantization_recipe = DelayedScaling()\n",
- "\n",
- "te_dot_general_cls = te_flax.make_dot_general_cls(quantization_recipe)\n",
- "\n",
- "rngs = {'dropout': dropout_key}\n",
- "if isinstance(quantization_recipe, NVFP4BlockScaling):\n",
- " # The NVFP4 recipe requires a Flax RNG for stochastic rounding\n",
- " rngs['sr_rng'] = jax.random.PRNGKey(0)\n"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "c8769655",
- "metadata": {},
- "source": [
- "Now using this quantized dense in our model is as simple as passing in `dot_general_fn=te_dot_general`. Let's try it out!\n",
- "\n",
- "\n",
- "\n",
- "Important: Remat Policy\n",
- "\n",
- "TE's quantization uses specialized TE quantized GEMM primitives. If you are using any built-in JAX checkpoint policies that look for JAX GEMMs (dots), such as `jax.checkpoint_policies.checkpoint_dots`, please replace the policy with `transformer_engine.jax.checkpoint_policies.checkpoint_dots_and_te_gemms` or similar policies to ensure TE's quantized GEMM primitives are checkpointed correctly.\n",
- "\n",
- "If this is not performed, TE GEMMs will be rematerialized introducing an incorrect performance comparison.\n",
- "\n",
- "
"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 12,
- "id": "8407d2ea",
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Pure Flax FlaxTransformerLayer initialized successfully!\n",
- "Parameter shapes: {'Dense_0': {'bias': (12288,), 'kernel': (4096, 12288)}, 'Dense_1': {'bias': (4096,), 'kernel': (4096, 4096)}, 'FlaxMLP_0': {'Dense_0': {'bias': (16384,), 'kernel': (4096, 16384)}, 'Dense_1': {'bias': (4096,), 'kernel': (16384, 4096)}}, 'LayerNorm_0': {'bias': (4096,), 'scale': (4096,)}, 'LayerNorm_1': {'bias': (4096,), 'scale': (4096,)}}\n",
- "Additional state: {'_overwrite_with_gradient': {'FlaxMLP_0': {'TEWrapper_dot_general_0': {'grad_amax_history': (1024,), 'grad_scale': (1,), 'kernel_amax_history': (1024,), 'kernel_scale': (1,), 'x_amax_history': (1024,), 'x_scale': (1,)}, 'TEWrapper_dot_general_1': {'grad_amax_history': (1024,), 'grad_scale': (1,), 'kernel_amax_history': (1024,), 'kernel_scale': (1,), 'x_amax_history': (1024,), 'x_scale': (1,)}}, 'TEWrapper_dot_general_0': {'grad_amax_history': (1024,), 'grad_scale': (1,), 'kernel_amax_history': (1024,), 'kernel_scale': (1,), 'x_amax_history': (1024,), 'x_scale': (1,)}, 'TEWrapper_dot_general_1': {'grad_amax_history': (1024,), 'grad_scale': (1,), 'kernel_amax_history': (1024,), 'kernel_scale': (1,), 'x_amax_history': (1024,), 'x_scale': (1,)}}}\n"
- ]
- }
- ],
- "source": [
- "# Initialize the FlaxTransformerLayer\n",
- "flax_transformer = FlaxTransformerLayer(\n",
- " hidden_size=hidden_size,\n",
- " ffn_hidden_size=ffn_hidden_size,\n",
- " num_attention_heads=num_attention_heads,\n",
- " dot_general_cls=te_dot_general_cls,\n",
- ")\n",
- "\n",
- "# Initialize parameters\n",
- "var_collect = flax_transformer.init(key, x, attention_mask=None, deterministic=False)\n",
- "\n",
- "print(\"Pure Flax FlaxTransformerLayer initialized successfully!\")\n",
- "print(f\"Parameter shapes: {jax.tree_util.tree_map(lambda x: x.shape, var_collect['params'])}\")\n",
- "print(f\"Additional state: {jax.tree_util.tree_map(lambda x: x.shape, {k: v for k, v in var_collect.items() if k != 'params'})}\")"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "abe27237",
- "metadata": {},
- "source": [
- "If using a recipe that stores additional state, such as `DelayedScaling`, you'll see this additional state stored as Flax variables. It is important to maintain and pass the whole state of Flax variables `var_collect` across training steps, not just the model params, for proper usage of stateful recipes like `DelayedScaling`.\n",
- "\n",
- "For example, above inside `Additional state: ` you'll see the `amax_history` of each quantization which is used to compute the per-tensor scale in the `DelayedScaling` recipe."
- ]
- },
- {
- "cell_type": "markdown",
- "id": "5ab72935",
- "metadata": {},
- "source": [
- "The reason we need `te_dot_general_cls` as a Flax module instead of a module-less function like `jax.lax.dot_general` is for some quantization recipes to track internal state separate from model parameters.\n",
- "\n",
- "Flax modules can manage 3 things:\n",
- "1. Model parameters/weights, e.g. your Dense \"kernel\", \"bias\", etc.\n",
- "2. RNGs for dropout, stochastic rounding, etc.\n",
- "3. Flax variables. These are additional state variables that are used across training steps but are distinct from model params in that you don't take gradients or optimize them. Currently, we only use this for DelayedScaling's amax_history state\n",
- "\n",
- "With the simplest quantization integration shown in this tutorial, we want users to keep their existing model param setup so they don't need to worry about preserving the sharding, init distribution, etc.. So we don't need point 1 since we don't do model param creation in this codepath with dot_general_cls, but we still do need `te_dot_general_cls()` to produce a Flax module since we potentially need to do points 2 or 3 which need to be in a Flax module."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 13,
- "id": "3b6b344b",
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Input shape: (4, 2048, 4096)\n",
- "Output shape: (4, 2048, 4096)\n",
- "Output dtype: float32\n",
- "Forward pass completed successfully!\n"
- ]
- }
- ],
- "source": [
- "# Example usage of forward pass\n",
- "y = flax_transformer.apply(var_collect, x, attention_mask=None, deterministic=True, rngs=rngs)\n",
- "print(f\"Input shape: {x.shape}\")\n",
- "print(f\"Output shape: {y.shape}\")\n",
- "print(f\"Output dtype: {y.dtype}\")\n",
- "print(\"Forward pass completed successfully!\")\n"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "d178f247",
- "metadata": {},
- "source": [
- "Now let's measure the performance!"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 14,
- "id": "5cc6c2a7",
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Mean time: 10.553865432739258 ms\n"
- ]
- }
- ],
- "source": [
- "import importlib\n",
- "import quickstart_jax_utils\n",
- "importlib.reload(quickstart_jax_utils)\n",
- "\n",
- "utils.speedometer(\n",
- " model_apply_fn=flax_transformer.apply,\n",
- " variables=var_collect,\n",
- " input=x,\n",
- " output_grad=dy,\n",
- " forward_kwargs={\"attention_mask\": None, \"deterministic\": False},\n",
- " rngs=rngs,\n",
- ")"
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3 (ipykernel)",
- "language": "python",
- "name": "python3"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 5
-}
diff --git a/docs/examples/te_jax_integration.rst b/docs/examples/te_jax_integration.rst
new file mode 100644
index 0000000000..a6dd0d401e
--- /dev/null
+++ b/docs/examples/te_jax_integration.rst
@@ -0,0 +1,91 @@
+..
+ Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+
+ See LICENSE for license information.
+
+JAX: Integrating TransformerEngine into an existing framework
+=============================================================
+
+This is the landing page for a series of focused documents on bringing
+TransformerEngine into a JAX+Flax codebase one optimization at a time. Each
+linked page isolates a single feature, so you can see exactly what changes are
+required and what performance benefits to expect.
+
+Pick a topic
+------------
+
+.. list-table::
+ :header-rows: 1
+ :widths: 25, 15, 60
+
+ * - Document
+ - Status
+ - Covers
+   * - `Dense GEMMs <jax_examples/dense.html>`_
+     - **Available**
+     - ``nn.Dense`` → quantized GEMM; single-GPU speedup; multi-GPU speedup;
+       Collective GEMM
+   * - `Attention <jax_examples/attention.html>`_
+     - *Coming soon*
+     -
+   * - `Mixture of Experts <jax_examples/moe.html>`_
+     - *Coming soon*
+     -
+
+
+Quantization recipes at a glance
+--------------------------------
+
+TE exposes its quantization choices as **recipes**. Please see
+`Low-precision Training
+`_
+for a more detailed description of each recipe.
+
+.. list-table::
+ :header-rows: 1
+ :widths: 25, 15, 30, 30
+
+ * - Recipe
+ - Hardware
+ - State
+ - When to use
+   * - ``DelayedScaling``
+     - Hopper+
+     - amax history (Flax variables)
+     - Per-tensor FP8 with scales derived from a rolling amax history
+   * - ``Float8CurrentScaling``
+     - Hopper+
+     - none
+     - Per-tensor FP8 with scales computed from the current tensor
+   * - ``MXFP8BlockScaling``
+     - Blackwell+
+     - none
+     - Block-scaled FP8 (32-element blocks)
+   * - ``NVFP4BlockScaling``
+     - Blackwell+
+     - Flax RNG ``sr_rng`` (stochastic rounding)
+     - FP4 with 2D block scaling and stochastic rounding
+
+Import them from ``transformer_engine.common.recipe``.
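+
+A minimal setup sketch (the ``rngs`` dict is whatever you already pass to your
+model's ``apply``):
+
+.. code-block:: python
+
+   import jax
+   from transformer_engine.common.recipe import (
+       DelayedScaling,
+       Float8CurrentScaling,
+       MXFP8BlockScaling,
+       NVFP4BlockScaling,
+   )
+   from transformer_engine.jax import flax as te_flax
+
+   recipe = DelayedScaling()  # any recipe above; constructor args customize it further
+   te_dot_general_cls = te_flax.make_dot_general_cls(recipe)
+
+   rngs = {}
+   if isinstance(recipe, NVFP4BlockScaling):
+       # NVFP4 requires a Flax RNG for stochastic rounding.
+       rngs["sr_rng"] = jax.random.PRNGKey(0)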
+
+
+Conventions used across these documents
+---------------------------------------
+
+* **Framework.** Flax Linen. (TE/JAX uses Linen; see
+ `Flax NNX/Linen interop
+ `_ and
+ `Haiku/Flax interop
+ `_ if you're on
+ a different stack.)
+* **Baseline dtype.** bf16 for inputs and parameters.
+* **Benchmarking.** ``quickstart_jax_utils.speedometer`` runs a JIT-compiled
+  fwd+bwd loop with warmup; a typical invocation is sketched below.
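+
+A typical invocation (``model`` and the tensors are whatever the page under
+discussion defines):
+
+.. code-block:: python
+
+   import quickstart_jax_utils as utils
+
+   utils.speedometer(
+       model_apply_fn=model.apply,
+       variables=variables,
+       input=x,
+       output_grad=dy,
+   )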
+
+
+.. toctree::
+ :hidden:
+
+ jax_examples/dense
+ jax_examples/attention
+ jax_examples/moe
diff --git a/docs/index.rst b/docs/index.rst
index 7389553679..53c4b0e37e 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -57,7 +57,7 @@ Transformer Engine documentation
examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb
examples/te_gemma/tutorial_generation_gemma_with_te.ipynb
examples/onnx/onnx_export.ipynb
- examples/te_jax_integration.ipynb
+ examples/te_jax_integration.rst
examples/op_fuser/op_fuser.rst
.. toctree::
diff --git a/qa/L0_jax_unittest/test.sh b/qa/L0_jax_unittest/test.sh
index 3453e35d2c..9cd171f896 100644
--- a/qa/L0_jax_unittest/test.sh
+++ b/qa/L0_jax_unittest/test.sh
@@ -42,6 +42,11 @@ python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/py
export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
NVTE_JAX_CUSTOM_CALLS="false" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder_without_custom_call.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py without custom calls"
+# Exercise the docs/examples/jax_examples tutorials. The multi-GPU tests are
+# skipped at runtime when fewer than 4 devices are visible, so this is safe on
+# single-GPU runners.
+python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_docs_examples_jax.xml $TE_PATH/docs/examples/jax_examples/ || test_fail "docs/examples/jax_examples"
+
if [ $RET -ne 0 ]; then
echo "Error: some sub-tests failed: $FAILED_CASES"
exit 1
diff --git a/qa/L1_jax_distributed_unittest/test.sh b/qa/L1_jax_distributed_unittest/test.sh
index 4f92d1c783..ea33828f53 100644
--- a/qa/L1_jax_distributed_unittest/test.sh
+++ b/qa/L1_jax_distributed_unittest/test.sh
@@ -37,6 +37,10 @@ XLA_FLAGS="$XLA_FLAGS --xla_gpu_enable_nccl_comm_splitting=false" python3 -m pyt
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_dist_fused_attn.xml $TE_PATH/tests/jax/test_distributed_fused_attn.py || test_fail "test_distributed_fused_attn.py"
+# Exercise the multi-GPU tutorial in docs/examples/jax_examples (needs >= 4 GPUs;
+# auto-skips otherwise).
+python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_docs_examples_jax_distributed.xml -k multi_gpu $TE_PATH/docs/examples/jax_examples/ || test_fail "docs/examples/jax_examples (multi-GPU)"
+
# TODO(Phuong): add this test back after it is verified
# SCRIPT_NAME=$TE_PATH/tests/jax/test_multi_process_distributed_grouped_gemm.py bash $TE_PATH/tests/jax/multi_process_launch.sh || test_fail "test_multi_process_distributed_grouped_gemm.py"