From a28ec54d9ff00b8e292e93cf6cbe4036151d0060 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Mon, 8 Dec 2025 18:19:50 +0100 Subject: [PATCH 01/25] Code drop: Update recipes documentation and remove custom recipes from low precision training Signed-off-by: Pawel Gadzinski --- .github/workflows/docs.yml | 2 +- docs/_static/css/diagram-colors.css | 132 +++++ docs/_static/css/output-style.css | 60 +++ docs/_static/css/sphinx_tabs.css | 45 ++ docs/_static/css/svg-responsive.css | 72 +++ docs/_templates/layout.html | 4 + docs/conf.py | 5 + .../fp8_blockwise_scaling.rst | 237 +++++++++ .../img/blockwise_swizzle_flow.svg | 146 +++++ .../img/combined_scaling.svg | 342 ++++++++++++ .../img/transpose_handling.svg | 351 ++++++++++++ .../jax_blockwise_scaling_example.py | 51 ++ .../pytorch_blockwise_scaling_example.py | 37 ++ .../fp8_current_scaling.rst | 178 +++++++ .../img/fp8_cast_process.svg | 55 ++ .../img/fp8_current_scaling_all_gather.svg | 78 +++ .../fp8_current_scaling/img/fp8_formats.svg | 164 ++++++ .../img/fp8_scaling_concept.svg | 110 ++++ .../img/fp8_tensor_core.svg | 75 +++ .../jax_current_scaling_example.py | 42 ++ .../pytorch_current_scaling_example.py | 31 ++ .../fp8_delayed_scaling.rst | 172 ++++++ .../img/scaling_comparison.svg | 82 +++ ...jax_delayed_scaling_distributed_example.py | 15 + .../jax_delayed_scaling_example.py | 58 ++ ...rch_delayed_scaling_distributed_example.py | 18 + .../pytorch_delayed_scaling_example.py | 39 ++ .../features/low_precision_training/index.rst | 17 + .../introduction/autocast_jax.py | 101 ++++ .../introduction/autocast_pytorch.out | 169 ++++++ .../introduction/autocast_pytorch.py | 69 +++ .../introduction/bf16_fp16_training_jax.py | 52 ++ .../bf16_fp16_training_pytorch.py | 56 ++ .../introduction/fp8_autocast_jax.py | 53 ++ .../introduction/img/fp8_linear_flow.svg | 172 ++++++ .../img/fp_formats_comparison.svg | 183 +++++++ .../img/master_weights_approaches.svg | 112 ++++ .../img/mixed_precision_operations.svg | 105 ++++ .../introduction/introduction.rst | 277 ++++++++++ .../introduction/jax_out | 13 + .../introduction/pytorch_out | 11 + .../mxfp8/img/fp8_1d_scaling.svg | 177 ++++++ .../mxfp8/img/mxfp8_row_col.svg | 266 +++++++++ .../img/mxfp8_scale_linearize_and_swizzle.svg | 190 +++++++ .../mxfp8/img/mxfp8_swizzle_both_tensors.svg | 101 ++++ .../mxfp8/img/mxfp8_tensor_scaling_layout.svg | 63 +++ .../mxfp8/jax_mxfp8_example.py | 41 ++ .../low_precision_training/mxfp8/mxfp8.rst | 199 +++++++ .../mxfp8/pytorch_mxfp8_example.py | 30 ++ .../nvfp4/img/nvfp4_all_gather.svg | 118 ++++ .../nvfp4/img/nvfp4_hierarchical_scaling.svg | 186 +++++++ .../nvfp4/img/nvfp4_row_col.svg | 208 ++++++++ .../nvfp4/img/nvfp4_vs_fp8.svg | 91 ++++ .../low_precision_training/nvfp4/img/rht.svg | 138 +++++ .../nvfp4/img/stochastic_rounding.svg | 95 ++++ .../nvfp4/jax_nvfp4_example.py | 43 ++ .../low_precision_training/nvfp4/nvfp4.rst | 261 +++++++++ .../nvfp4/pytorch_nvfp4_example.py | 33 ++ .../fused_layers_jax.out | 0 .../fused_layers_jax.py | 43 ++ .../fused_layers_pytorch.out | 8 + .../fused_layers_pytorch.py | 37 ++ .../img/fused_layers.svg | 120 +++++ .../img/gemm_access_pattern.svg | 218 ++++++++ .../img/hopper_vs_blackwell_layout.svg | 122 +++++ .../img/sequence_parallel_quantization.svg | 159 ++++++ .../img/transpose_fusion.svg | 181 +++++++ .../memory_usage_1_jax.out | 4 + .../memory_usage_1_jax.py | 44 ++ .../memory_usage_1_pytorch.out | 11 + .../memory_usage_1_pytorch.py | 39 ++ .../memory_usage_2_jax.out | 3 + .../memory_usage_2_jax.py | 43 ++ .../memory_usage_2_pytorch.out | 11 + .../memory_usage_2_pytorch.py | 39 ++ .../memory_usage_3_jax.out | 4 + .../memory_usage_3_jax.py | 48 ++ .../memory_usage_3_pytorch.out | 11 + .../memory_usage_3_pytorch.py | 44 ++ .../performance_considerations.rst | 503 ++++++++++++++++++ .../performance_considerations/pytorch_out | 0 .../save_original_input_pytorch.out | 12 + .../save_original_input_pytorch.py | 51 ++ docs/index.rst | 8 + 84 files changed, 7993 insertions(+), 1 deletion(-) create mode 100644 docs/_static/css/diagram-colors.css create mode 100644 docs/_static/css/output-style.css create mode 100644 docs/_static/css/sphinx_tabs.css create mode 100644 docs/_static/css/svg-responsive.css create mode 100644 docs/features/low_precision_training/fp8_blockwise_scaling/fp8_blockwise_scaling.rst create mode 100644 docs/features/low_precision_training/fp8_blockwise_scaling/img/blockwise_swizzle_flow.svg create mode 100644 docs/features/low_precision_training/fp8_blockwise_scaling/img/combined_scaling.svg create mode 100644 docs/features/low_precision_training/fp8_blockwise_scaling/img/transpose_handling.svg create mode 100644 docs/features/low_precision_training/fp8_blockwise_scaling/jax_blockwise_scaling_example.py create mode 100644 docs/features/low_precision_training/fp8_blockwise_scaling/pytorch_blockwise_scaling_example.py create mode 100644 docs/features/low_precision_training/fp8_current_scaling/fp8_current_scaling.rst create mode 100644 docs/features/low_precision_training/fp8_current_scaling/img/fp8_cast_process.svg create mode 100644 docs/features/low_precision_training/fp8_current_scaling/img/fp8_current_scaling_all_gather.svg create mode 100644 docs/features/low_precision_training/fp8_current_scaling/img/fp8_formats.svg create mode 100644 docs/features/low_precision_training/fp8_current_scaling/img/fp8_scaling_concept.svg create mode 100644 docs/features/low_precision_training/fp8_current_scaling/img/fp8_tensor_core.svg create mode 100644 docs/features/low_precision_training/fp8_current_scaling/jax_current_scaling_example.py create mode 100644 docs/features/low_precision_training/fp8_current_scaling/pytorch_current_scaling_example.py create mode 100644 docs/features/low_precision_training/fp8_delayed_scaling/fp8_delayed_scaling.rst create mode 100644 docs/features/low_precision_training/fp8_delayed_scaling/img/scaling_comparison.svg create mode 100644 docs/features/low_precision_training/fp8_delayed_scaling/jax_delayed_scaling_distributed_example.py create mode 100644 docs/features/low_precision_training/fp8_delayed_scaling/jax_delayed_scaling_example.py create mode 100644 docs/features/low_precision_training/fp8_delayed_scaling/pytorch_delayed_scaling_distributed_example.py create mode 100644 docs/features/low_precision_training/fp8_delayed_scaling/pytorch_delayed_scaling_example.py create mode 100644 docs/features/low_precision_training/index.rst create mode 100644 docs/features/low_precision_training/introduction/autocast_jax.py create mode 100644 docs/features/low_precision_training/introduction/autocast_pytorch.out create mode 100644 docs/features/low_precision_training/introduction/autocast_pytorch.py create mode 100644 docs/features/low_precision_training/introduction/bf16_fp16_training_jax.py create mode 100644 docs/features/low_precision_training/introduction/bf16_fp16_training_pytorch.py create mode 100644 docs/features/low_precision_training/introduction/fp8_autocast_jax.py create mode 100644 docs/features/low_precision_training/introduction/img/fp8_linear_flow.svg create mode 100644 docs/features/low_precision_training/introduction/img/fp_formats_comparison.svg create mode 100644 docs/features/low_precision_training/introduction/img/master_weights_approaches.svg create mode 100644 docs/features/low_precision_training/introduction/img/mixed_precision_operations.svg create mode 100644 docs/features/low_precision_training/introduction/introduction.rst create mode 100644 docs/features/low_precision_training/introduction/jax_out create mode 100644 docs/features/low_precision_training/introduction/pytorch_out create mode 100644 docs/features/low_precision_training/mxfp8/img/fp8_1d_scaling.svg create mode 100644 docs/features/low_precision_training/mxfp8/img/mxfp8_row_col.svg create mode 100644 docs/features/low_precision_training/mxfp8/img/mxfp8_scale_linearize_and_swizzle.svg create mode 100644 docs/features/low_precision_training/mxfp8/img/mxfp8_swizzle_both_tensors.svg create mode 100644 docs/features/low_precision_training/mxfp8/img/mxfp8_tensor_scaling_layout.svg create mode 100644 docs/features/low_precision_training/mxfp8/jax_mxfp8_example.py create mode 100644 docs/features/low_precision_training/mxfp8/mxfp8.rst create mode 100644 docs/features/low_precision_training/mxfp8/pytorch_mxfp8_example.py create mode 100644 docs/features/low_precision_training/nvfp4/img/nvfp4_all_gather.svg create mode 100644 docs/features/low_precision_training/nvfp4/img/nvfp4_hierarchical_scaling.svg create mode 100644 docs/features/low_precision_training/nvfp4/img/nvfp4_row_col.svg create mode 100644 docs/features/low_precision_training/nvfp4/img/nvfp4_vs_fp8.svg create mode 100644 docs/features/low_precision_training/nvfp4/img/rht.svg create mode 100644 docs/features/low_precision_training/nvfp4/img/stochastic_rounding.svg create mode 100644 docs/features/low_precision_training/nvfp4/jax_nvfp4_example.py create mode 100644 docs/features/low_precision_training/nvfp4/nvfp4.rst create mode 100644 docs/features/low_precision_training/nvfp4/pytorch_nvfp4_example.py create mode 100644 docs/features/low_precision_training/performance_considerations/fused_layers_jax.out create mode 100644 docs/features/low_precision_training/performance_considerations/fused_layers_jax.py create mode 100644 docs/features/low_precision_training/performance_considerations/fused_layers_pytorch.out create mode 100644 docs/features/low_precision_training/performance_considerations/fused_layers_pytorch.py create mode 100644 docs/features/low_precision_training/performance_considerations/img/fused_layers.svg create mode 100644 docs/features/low_precision_training/performance_considerations/img/gemm_access_pattern.svg create mode 100644 docs/features/low_precision_training/performance_considerations/img/hopper_vs_blackwell_layout.svg create mode 100644 docs/features/low_precision_training/performance_considerations/img/sequence_parallel_quantization.svg create mode 100644 docs/features/low_precision_training/performance_considerations/img/transpose_fusion.svg create mode 100644 docs/features/low_precision_training/performance_considerations/memory_usage_1_jax.out create mode 100644 docs/features/low_precision_training/performance_considerations/memory_usage_1_jax.py create mode 100644 docs/features/low_precision_training/performance_considerations/memory_usage_1_pytorch.out create mode 100644 docs/features/low_precision_training/performance_considerations/memory_usage_1_pytorch.py create mode 100644 docs/features/low_precision_training/performance_considerations/memory_usage_2_jax.out create mode 100644 docs/features/low_precision_training/performance_considerations/memory_usage_2_jax.py create mode 100644 docs/features/low_precision_training/performance_considerations/memory_usage_2_pytorch.out create mode 100644 docs/features/low_precision_training/performance_considerations/memory_usage_2_pytorch.py create mode 100644 docs/features/low_precision_training/performance_considerations/memory_usage_3_jax.out create mode 100644 docs/features/low_precision_training/performance_considerations/memory_usage_3_jax.py create mode 100644 docs/features/low_precision_training/performance_considerations/memory_usage_3_pytorch.out create mode 100644 docs/features/low_precision_training/performance_considerations/memory_usage_3_pytorch.py create mode 100644 docs/features/low_precision_training/performance_considerations/performance_considerations.rst create mode 100644 docs/features/low_precision_training/performance_considerations/pytorch_out create mode 100644 docs/features/low_precision_training/performance_considerations/save_original_input_pytorch.out create mode 100644 docs/features/low_precision_training/performance_considerations/save_original_input_pytorch.py diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 5beeeb8879..15c7b92957 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -18,7 +18,7 @@ jobs: - name: 'Install dependencies' run: | pip install sphinx==8.1.3 sphinx_rtd_theme==3.0.1 nbsphinx==0.9.5 IPython ipython_genutils==0.2.0 ipywidgets==8.0.2 astroid==3.3.2 - pip install breathe==4.35.0 sphinx-autoapi==3.3.2 + pip install breathe==4.35.0 sphinx-autoapi==3.3.2 sphinx-tabs==3.4.7 sudo apt-get install -y pandoc graphviz doxygen export GIT_SHA=$(git show-ref --hash HEAD) - name: 'Build docs' diff --git a/docs/_static/css/diagram-colors.css b/docs/_static/css/diagram-colors.css new file mode 100644 index 0000000000..6ae3f99afa --- /dev/null +++ b/docs/_static/css/diagram-colors.css @@ -0,0 +1,132 @@ +/* Diagram color definitions for Transformer Engine documentation */ + +/* High precision (BF16/FP16) elements */ +.hp { + fill: #ede7f6; + stroke: #673ab7; + stroke-width: 2; +} + +/* FP8 precision elements */ +.fp8 { + fill: #fff8e1; + stroke: #ffa726; + stroke-width: 2; +} + +/* GEMM/computation operations */ +.gemm { + fill: #ffe0b2; + stroke: #fb8c00; + stroke-width: 2.5; +} + +/* Quantization operations */ +.quantize { + fill: #e8f5e9; + stroke: #66bb6a; + stroke-width: 2; +} + +/* Amax computation operations */ +.amax { + fill: #e1f5fe; + stroke: #039be5; + stroke-width: 2; +} + +/* Text styles */ +.text { + font-family: 'Segoe UI', Arial, sans-serif; + font-size: 14px; + text-anchor: middle; + fill: #212121; +} + +.small-text { + font-family: 'Segoe UI', Arial, sans-serif; + font-size: 14px; + text-anchor: middle; + fill: #757575; +} + +.label { + font-family: 'Segoe UI', Arial, sans-serif; + font-size: 14px; + text-anchor: middle; + fill: #424242; +} + +.title { + font-family: 'Segoe UI', Arial, sans-serif; + font-size: 18px; + font-weight: 600; + text-anchor: middle; + fill: #212121; +} + +.section-title { + font-family: 'Segoe UI', Arial, sans-serif; + font-size: 15px; + font-weight: 600; + text-anchor: middle; +} + +/* Arrows */ +.arrow { + stroke: #616161; + stroke-width: 2; + fill: none; +} + +/* Additional box and element styles */ +.box-blue { + fill: #e3f2fd; + stroke: #1976d2; + stroke-width: 2; +} + +.box-orange { + fill: #fff3e0; + stroke: #f57c00; + stroke-width: 2; +} + +.box-green { + fill: #c8e6c9; + stroke: #388e3c; + stroke-width: 2; +} + +.box-dashed { + stroke-dasharray: 5,5; +} + +/* LayerNorm specific */ +.layernorm { + fill: #b3e5fc; + stroke: #0277bd; + stroke-width: 2.5; +} + +/* Fused layers */ +.fused { + fill: #b2dfdb; + stroke: #00695c; + stroke-width: 3; +} + +/* Generic computation blocks */ +.computation { + fill: #f5f5f5; + stroke: #757575; + stroke-width: 2; +} + +/* FP32 precision (alternative red) */ +.fp32 { + fill: #ffcdd2; + stroke: #d32f2f; + stroke-width: 2.5; +} + diff --git a/docs/_static/css/output-style.css b/docs/_static/css/output-style.css new file mode 100644 index 0000000000..864d8587a3 --- /dev/null +++ b/docs/_static/css/output-style.css @@ -0,0 +1,60 @@ +/* Custom styling for program output blocks */ + +.program-output { + background-color: #f8f9fa; + padding: 0; /* No padding at all */ + margin: 0; /* No margins at all */ + border-radius: 0; /* No rounded corners */ + font-family: 'Courier New', monospace; + font-size: 14px; + line-height: 1.5; + width: 100%; + max-width: 100%; +} + +.program-output pre { + margin: 0; + padding: 0; + background: transparent !important; + border: none !important; + color: #2c3e50; + width: 100%; +} + +.program-output .highlight { + background: transparent !important; + margin: 0; + width: 100%; +} + +/* Alternative lighter style */ +.output-block { + background-color: #fafbfc; + border: 1px solid #e1e4e8; + padding: 10px 14px; + margin: 10px 0; + border-radius: 3px; + font-family: 'SF Mono', 'Consolas', monospace; + font-size: 13px; + color: #24292e; +} + +/* Console-like output style */ +.console-output { + background-color: #1e1e1e; + border-left: 3px solid #76b900; + padding: 14px 18px; + margin: 12px 0; + border-radius: 5px; + font-family: 'Fira Code', 'Consolas', monospace; + font-size: 13px; + color: #d4d4d4; + box-shadow: 0 2px 4px rgba(0,0,0,0.1); +} + +.console-output pre { + margin: 0; + color: #d4d4d4; + background: transparent !important; +} + diff --git a/docs/_static/css/sphinx_tabs.css b/docs/_static/css/sphinx_tabs.css new file mode 100644 index 0000000000..c3e524e0e9 --- /dev/null +++ b/docs/_static/css/sphinx_tabs.css @@ -0,0 +1,45 @@ +/* Custom styling for sphinx-tabs */ + +.sphinx-tabs { + margin-bottom: 1rem; +} + +.sphinx-tabs-tab { + background-color: #f4f4f4; + border: 1px solid #ccc; + border-bottom: none; + padding: 0.5rem 1rem; + margin-right: 0.5rem; + cursor: pointer; + font-weight: 500; + transition: background-color 0.2s; +} + +.sphinx-tabs-tab:hover { + background-color: #e0e0e0; +} + +.sphinx-tabs-tab[aria-selected="true"] { + background-color: #76b900; /* NVIDIA green */ + color: white; + border-color: #76b900; + margin-right: 0.5rem; +} + +.sphinx-tabs-panel { + border: 1px solid #ccc; + padding: 1rem; + background-color: #f9f9f9; +} + +/* Dark mode support for RTD theme */ +.rst-content .sphinx-tabs-tab { + color: #333; +} + +.rst-content .sphinx-tabs-tab[aria-selected="true"] { + color: white; +} + + + diff --git a/docs/_static/css/svg-responsive.css b/docs/_static/css/svg-responsive.css new file mode 100644 index 0000000000..3ffe14eb14 --- /dev/null +++ b/docs/_static/css/svg-responsive.css @@ -0,0 +1,72 @@ +/* Responsive styling for SVG images */ + +/* Make all SVG images responsive */ +.document svg, +.document object[type="image/svg+xml"], +.rst-content svg { + max-width: 100%; + height: auto; + display: block; + margin: 1em auto; +} + +/* For raw HTML embedded SVGs */ +.document .raw-html svg { + max-width: 100%; + height: auto; + width: 100%; +} + +/* Ensure container doesn't overflow */ +.document .raw-html { + max-width: 100%; + overflow-x: auto; +} + +/* Figure containers with captions */ +.svg-figure { + text-align: center; + margin: 20px auto; +} + +.svg-figure img { + display: block; + margin: 0 auto; + height: auto; +} + +/* Different width classes for figures */ +.svg-figure.width-70 img { + width: 70%; + max-width: 100%; +} + +.svg-figure.width-80 img { + width: 80%; + max-width: 100%; +} + +.svg-figure.width-90 img { + width: 90%; + max-width: 100%; +} + +.svg-figure.width-100 img { + width: 100%; +} + +/* Figure captions */ +.svg-caption { + font-style: italic; + margin-top: 10px; + color: #555; + font-size: 0.95em; + line-height: 1.4; +} + + + + + + + diff --git a/docs/_templates/layout.html b/docs/_templates/layout.html index f94e526f57..99ae0702a8 100644 --- a/docs/_templates/layout.html +++ b/docs/_templates/layout.html @@ -67,6 +67,10 @@ overflow: visible !important; } + .quant { + background-color: yellow !important; + } + + + + + + + + + + + + + Input Tensor + + FP32/BF16 + + + + + + + + Quantize + + + + + + + FP8 (Compact) + + + + + FP32 Scales + + + + FP8 Data + + + + + + + + All-Gather + + + + + + + Swizzle + + + + + + + FP8 (GEMM Ready) + + + + + Swizzled Scales + + + + FP8 Data + + + + + + + + GEMM + + + + + + + + + + Input Tensor + + FP32/BF16 + + + + + + + + Quantize + + + Swizzle + + + + + + + FP8 (GEMM Ready) + + + + + Swizzled Scales + + + + FP8 Data + + + + + + + + GEMM + + diff --git a/docs/features/low_precision_training/fp8_blockwise_scaling/img/combined_scaling.svg b/docs/features/low_precision_training/fp8_blockwise_scaling/img/combined_scaling.svg new file mode 100644 index 0000000000..62fb69afab --- /dev/null +++ b/docs/features/low_precision_training/fp8_blockwise_scaling/img/combined_scaling.svg @@ -0,0 +1,342 @@ + + + + + + + + + + Delayed/Current FP8 Scaling + (Single scaling factor per tensor) + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 1 scaling factor + + + + + Blockwise FP8 Scaling – 1 dimension + (One scaling factor per 128 elements) + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Scaling factors (one per block) + + + + + Blockwise FP8 Scaling – 2 dimensions + (One scaling factor per 128x128 block of elements) + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Scaling factors (1 per 2D block) + + diff --git a/docs/features/low_precision_training/fp8_blockwise_scaling/img/transpose_handling.svg b/docs/features/low_precision_training/fp8_blockwise_scaling/img/transpose_handling.svg new file mode 100644 index 0000000000..fbaa419b67 --- /dev/null +++ b/docs/features/low_precision_training/fp8_blockwise_scaling/img/transpose_handling.svg @@ -0,0 +1,351 @@ + + + + + + + 1D Blockwise Scaling + + + + Rowwise Quantization + (240 × 120 tensor) + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Columnwise Quantization + (120 × 240 tensor) + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 2D Blockwise Scaling + + + + Rowwise Quantization + (180 × 120 tensor) + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Columnwise Quantization + (120 × 180 tensor) + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/docs/features/low_precision_training/fp8_blockwise_scaling/jax_blockwise_scaling_example.py b/docs/features/low_precision_training/fp8_blockwise_scaling/jax_blockwise_scaling_example.py new file mode 100644 index 0000000000..6bd60897e5 --- /dev/null +++ b/docs/features/low_precision_training/fp8_blockwise_scaling/jax_blockwise_scaling_example.py @@ -0,0 +1,51 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +# START_BLOCKWISE_SCALING_EXAMPLE + +import jax +import jax.numpy as jnp +import optax +import transformer_engine.jax as te +from transformer_engine.jax.flax import DenseGeneral +from transformer_engine.jax.sharding import MeshResource, global_shard_guard +from transformer_engine.common.recipe import Float8BlockScaling + +# Check for Hopper or newer GPU +gpu = jax.devices("gpu")[0] +major, minor = gpu.compute_capability.split(".") +assert ( + int(major) >= 9 +), f"FP8 Blockwise Scaling requires SM90 (Hopper) or later, got SM{major}{minor}" + +# Create FP8 Blockwise Scaling recipe +recipe = Float8BlockScaling( + fp8_format=te.common.recipe.Format.E4M3, # E4M3 or HYBRID (default: E4M3) + x_block_scaling_dim=1, # 1D scaling for activations (default: 1) + w_block_scaling_dim=2, # 2D scaling for weights (default: 2) + grad_block_scaling_dim=1, # 1D scaling for gradients (default: 1) +) + +with global_shard_guard(MeshResource()): + with te.fp8_autocast(enabled=True, recipe=recipe, mesh_resource=MeshResource()): + # Initialize layer and data + layer = DenseGeneral(features=1024) + key = jax.random.PRNGKey(0) + x = jax.random.normal(key, (32, 128, 1024), dtype=jnp.bfloat16) + params = layer.init(key, x) + + # Training with FP8 Blockwise Scaling + def loss_fn(params): + output = layer.apply(params, x) + return output.sum() + + loss, grads = jax.value_and_grad(loss_fn)(params) + + # Update parameters + optimizer = optax.adamw(learning_rate=1e-4) + opt_state = optimizer.init(params) + updates, opt_state = optimizer.update(grads, opt_state, params) + params = optax.apply_updates(params, updates) + +# END_BLOCKWISE_SCALING_EXAMPLE diff --git a/docs/features/low_precision_training/fp8_blockwise_scaling/pytorch_blockwise_scaling_example.py b/docs/features/low_precision_training/fp8_blockwise_scaling/pytorch_blockwise_scaling_example.py new file mode 100644 index 0000000000..5b9dce1c82 --- /dev/null +++ b/docs/features/low_precision_training/fp8_blockwise_scaling/pytorch_blockwise_scaling_example.py @@ -0,0 +1,37 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +# START_BLOCKWISE_SCALING_EXAMPLE + +import torch +import transformer_engine.pytorch as te +from transformer_engine.common.recipe import Float8BlockScaling + +# Check for Hopper or newer GPU +major, minor = torch.cuda.get_device_capability() +assert major >= 9, f"FP8 Blockwise Scaling requires SM90 (Hopper) or later, got SM{major}{minor}" + +# Create FP8 Blockwise Scaling recipe +recipe = Float8BlockScaling( + fp8_format=te.common.recipe.Format.E4M3, # E4M3 or HYBRID (default: E4M3) + x_block_scaling_dim=1, # 1D scaling for activations (default: 1) + w_block_scaling_dim=2, # 2D scaling for weights (default: 2) + grad_block_scaling_dim=1, # 1D scaling for gradients (default: 1) +) + +# Create a linear layer +layer = te.Linear(1024, 1024) +optimizer = torch.optim.AdamW(layer.parameters(), lr=1e-4) + +# Training with FP8 Blockwise Scaling +inp = torch.randn(32, 128, 1024, dtype=torch.bfloat16, device="cuda") + +with te.fp8_autocast(enabled=True, fp8_recipe=recipe): + output = layer(inp) + loss = output.sum() + +loss.backward() +optimizer.step() + +# END_BLOCKWISE_SCALING_EXAMPLE diff --git a/docs/features/low_precision_training/fp8_current_scaling/fp8_current_scaling.rst b/docs/features/low_precision_training/fp8_current_scaling/fp8_current_scaling.rst new file mode 100644 index 0000000000..d55ac91020 --- /dev/null +++ b/docs/features/low_precision_training/fp8_current_scaling/fp8_current_scaling.rst @@ -0,0 +1,178 @@ +.. + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + + See LICENSE for license information. + +FP8 Current Scaling +=================================== + +FP8 current scaling is the simplest low precision recipe provided by Transformer Engine. +To understand how this recipe works, we first need to examine what the FP8 data type is and how it differs from other floating point formats. + + +FP8 data type +------------- + +The FP8 datatype, introduced in Hopper architecture, is actually 2 distinct datatypes, useful in different parts of the training of neural networks: + +* E4M3 -- consists of 1 sign bit, 4 exponent bits and 3 bits of mantissa. It can store values up to +/-448 and ``nan``. +* E5M2 -- consists of 1 sign bit, 5 exponent bits and 2 bits of mantissa. It can store values up to +/-57344, +/- ``inf`` and ``nan``. The tradeoff of the increased dynamic range is lower precision of the stored values. + +.. raw:: html + :file: img/fp8_formats.svg + +*Figure 1: Structure of the floating point datatypes. All of the values shown (in FP16, BF16, FP8 E4M3 and FP8 E5M2) are the closest representations of value 0.3952.* + + +**E4M3 and E5M2 usage in training** + +By default, Transformer Engine uses a hybrid approach: + +* *Forward pass* - activations and weights require more precision, so E4M3 datatype is best used. +* *Backward pass* - gradients are less susceptible to precision loss but require higher dynamic range, so E5M2 datatype is preferred. + +The user can configure this behavior via the ``fp8_format`` parameter of the recipe. + + +Scaling factors +--------------- + + +FP8's limited dynamic range is insufficient for many tensors. +To address this, scaling factors are used. In FP8 Current Scaling there is one **FP32** scale factor per tensor. +The representation of a tensor element ``x`` in FP8 precision is given by: + +.. code-block:: python + + x = x_fp8 * s + +where + +* ``x_fp8`` is the FP8 value (E4M3 or E5M2), +* ``s`` is a global **FP32** scaling factor applied to the entire tensor. + +**FP8 Current Scaling quantization** + +Let's look more closely at how quantization to FP8 with scaling factor is implemented in +the FP8 Current Scaling recipe. + +.. raw:: html + :file: img/fp8_scaling_concept.svg + +*Figure 3: Quantization to FP8 consists of amax computation, scaling to fit the FP8 range and casting to the respective FP8 format.* + +Quantization to FP8 consists of 3 steps: + +1. Computation of the absolute maximum value of the tensor - we refer to it as ``amax``. +2. Applying the scaling factor of ``fp8_max / amax`` to the tensor, to fit it into the FP8 range +3. Casting into the respective FP8 format using *Round To Nearest Even (RTNE)*. Values round to the nearest representable FP8 value. When exactly halfway between two values, rounds to the one with even mantissa to minimize systematic bias. + +**Performance analysis** + +Quantization is a memory-bound operation that requires reading the tensor twice: + +* First read: compute ``amax`` across all elements. +* Second read: apply the scaling factor and cast to FP8. + +This is a significant overhead compared to other recipes, which typically require only a single memory read. + +.. raw:: html + :file: img/fp8_cast_process.svg + +*Figure 4: FP8 quantization with current scaling recipe - two tensor reads are needed, one to compute amax and one to apply the scaling factor and cast to FP8.* + +Hardware support +---------------- + +The Hopper architecture introduced FP8 support in Tensor Cores, enabling efficient low-precision computation. +Tensor Cores support every combination of E4M3 and E5M2 formats as inputs, allowing flexible precision choices for different operands. +The inputs to an FP8 Tensor Core operation consist of chunks of FP8 tensors along with their corresponding scaling factors. +The Tensor Core performs the matrix multiplication in FP8 precision and produces output in higher precision (FP16, BF16, or FP32). + +.. raw:: html + :file: img/fp8_tensor_core.svg + +*Figure 5: FP8 Tensor Cores process two input tensors (A and B) with their respective scaling factors and perform matrix multiplication to accumulate higher-precision output.* + + +Transpose handling +------------------ + + + +*Ada and Hopper* + +On Ada and Hopper, the backward pass requires a transposed FP8 tensor. +The columnwise layout is physically different from the rowwise layout, so a transpose operation is needed. +All 3 options from :ref:`introduction Transpose handling section ` are supported. + +*Blackwell and later* + +Blackwell hardware supports multiple GEMM layouts natively, eliminating the need for explicit transposes. +The rowwise and columnwise tensors share the same physical memory layout. + +.. figure:: ../performance_considerations/img/hopper_vs_blackwell_layout.svg + :align: center + :alt: Comparison of rowwise and columnwise tensor layouts on Blackwell vs Hopper + + *Figure 6: On Blackwell, rowwise and columnwise usages share the same memory layout. On Hopper, columnwise usage requires a physical transpose.* + + +Distributed training +-------------------- + +**All-gather of columnwise tensors** + +Supported for Blackwell and later, since rowwise and columnwise tensors share the same memory layout. +For Hopper and Ada, all-gather of transposed FP8 tensors is not supported. +The rowwise tensor is gathered and then it is transposed to columnwise tensor. + +**Amax reduction** + +Tensors that are gathered across nodes (e.g. input and gradient in sequence parallelism) require amax synchronization before quantization. +Each node computes its local ``amax``, then a reduction produces the global maximum across all nodes. +All nodes use this synchronized amax to compute identical scaling factors, enabling quantized all-gather. + +.. raw:: html + :file: img/fp8_current_scaling_all_gather.svg + +*Figure 7: Quantization and all-gather flow for FP8 current scaling showing amax computation and synchronization.* + + +Supported devices +----------------- + +Ada and later (SM 8.9+) + +Examples +-------- + +Here's how to use FP8 Current Scaling recipe in PyTorch and JAX: + +.. tabs:: + + .. tab:: PyTorch + + .. raw:: html + +
+ Requires SM89 (Ada) or later +
+ + .. literalinclude:: pytorch_current_scaling_example.py + :language: python + :start-after: # START_CURRENT_SCALING_EXAMPLE + :end-before: # END_CURRENT_SCALING_EXAMPLE + + .. tab:: JAX + + .. raw:: html + +
+ Requires SM89 (Ada) or later +
+ + .. literalinclude:: jax_current_scaling_example.py + :language: python + :start-after: # START_CURRENT_SCALING_EXAMPLE + :end-before: # END_CURRENT_SCALING_EXAMPLE \ No newline at end of file diff --git a/docs/features/low_precision_training/fp8_current_scaling/img/fp8_cast_process.svg b/docs/features/low_precision_training/fp8_current_scaling/img/fp8_cast_process.svg new file mode 100644 index 0000000000..dfd01a2f17 --- /dev/null +++ b/docs/features/low_precision_training/fp8_current_scaling/img/fp8_cast_process.svg @@ -0,0 +1,55 @@ + + + + + + + + + + + FP8 quantization + + + + High Precision + Tensor + + + + + + + Quantize + + + + Compute amax + 1 tensor read + + + + + + + Apply Scale + + Cast + 1 tensor read + + + + + + + FP8 + Tensor + + diff --git a/docs/features/low_precision_training/fp8_current_scaling/img/fp8_current_scaling_all_gather.svg b/docs/features/low_precision_training/fp8_current_scaling/img/fp8_current_scaling_all_gather.svg new file mode 100644 index 0000000000..ab31123111 --- /dev/null +++ b/docs/features/low_precision_training/fp8_current_scaling/img/fp8_current_scaling_all_gather.svg @@ -0,0 +1,78 @@ + + + + + + + + + + + Quantization + all gather for FP8 current scaling + + + + High Precision + Tensor + + + + + + + Compute + Amax + + + + + + + Synchronize + Amax + + + + + + + Scale + + Cast + + + + + + + FP8 + Tensor + + + + + + + All-Gather + + + + + + + FP8 Gathered + Tensor + + + diff --git a/docs/features/low_precision_training/fp8_current_scaling/img/fp8_formats.svg b/docs/features/low_precision_training/fp8_current_scaling/img/fp8_formats.svg new file mode 100644 index 0000000000..fb762e6699 --- /dev/null +++ b/docs/features/low_precision_training/fp8_current_scaling/img/fp8_formats.svg @@ -0,0 +1,164 @@ + + + + + + + sign + exponent + mantissa + + + FP16 + + + + 0 + + + + 0 + + 1 + + 1 + + 0 + + 1 + + + + 1 + + 0 + + 0 + + 1 + + 0 + + 1 + + 0 + + 0 + + 1 + + 1 + + = 0.395264 + + + + BF16 + + + + 0 + + + + 0 + + 1 + + 1 + + 1 + + 1 + + 1 + + 0 + + 1 + + + + 1 + + 0 + + 0 + + 1 + + 0 + + 1 + + 0 + + = 0.394531 + + + + FP8 E4M3 + + + + 0 + + + + 0 + + 1 + + 0 + + 1 + + + + 1 + + 0 + + 1 + + = 0.40625 + + + + FP8 E5M2 + + + + 0 + + + + 0 + + 1 + + 1 + + 0 + + 1 + + + + 1 + + 0 + + = 0.375 + + + diff --git a/docs/features/low_precision_training/fp8_current_scaling/img/fp8_scaling_concept.svg b/docs/features/low_precision_training/fp8_current_scaling/img/fp8_scaling_concept.svg new file mode 100644 index 0000000000..492a4d02ee --- /dev/null +++ b/docs/features/low_precision_training/fp8_current_scaling/img/fp8_scaling_concept.svg @@ -0,0 +1,110 @@ + + + + + Original Tensor Values + + + + + + + 0 + + + + FP8 range + + + + + + + + + + + + + + + + + amax + + + Scaled Values (fit FP8 range) + + + + + + + 0 + + + + + + + FP8 range min + + + + + + + + + + + + + + + + + + + + + + Cast to FP8 (quantized values) + + + + + + + 0 + + + + + + + + + + + + + + + + + + + + + + diff --git a/docs/features/low_precision_training/fp8_current_scaling/img/fp8_tensor_core.svg b/docs/features/low_precision_training/fp8_current_scaling/img/fp8_tensor_core.svg new file mode 100644 index 0000000000..92d68ac3a0 --- /dev/null +++ b/docs/features/low_precision_training/fp8_current_scaling/img/fp8_tensor_core.svg @@ -0,0 +1,75 @@ + + + + + + + + + + + FP8 Tensor Core Operation + + + + Input A + + + + chunk of FP8 Tensor A + (E4M3 or E5M2) + + + + Scale a + scalar float 32 + + + + + + + Input B + + + + chunk of FP8 Tensor B + (E4M3 or E5M2) + + + + Scale b + scalar float 32 + + + + + + + FP8 Tensor Core + + + + + + + Accumulated chunk of output + Higher precision + + + diff --git a/docs/features/low_precision_training/fp8_current_scaling/jax_current_scaling_example.py b/docs/features/low_precision_training/fp8_current_scaling/jax_current_scaling_example.py new file mode 100644 index 0000000000..5ce063af4f --- /dev/null +++ b/docs/features/low_precision_training/fp8_current_scaling/jax_current_scaling_example.py @@ -0,0 +1,42 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +# START_CURRENT_SCALING_EXAMPLE + +import jax +import jax.numpy as jnp +import optax +import transformer_engine.jax as te +from transformer_engine.jax.flax import DenseGeneral +from transformer_engine.jax.sharding import MeshResource, global_shard_guard +from transformer_engine.jax.quantize import Float8CurrentScaling, Format + +# Create FP8 Current Scaling recipe +# Available formats: +# - Format.HYBRID (default) -- E4M3 for forward pass, E5M2 for backward pass +# - Format.E4M3 -- E4M3 for both forward and backward pass +recipe = Float8CurrentScaling(fp8_format=Format.HYBRID) + +with global_shard_guard(MeshResource()): + with te.fp8_autocast(enabled=True, recipe=recipe, mesh_resource=MeshResource()): + # Create and initialize layer + layer = DenseGeneral(features=1024) + key = jax.random.PRNGKey(0) + x = jax.random.normal(key, (32, 128, 1024), dtype=jnp.bfloat16) + params = layer.init(key, x) + + # Training with FP8 Current Scaling + def loss_fn(params): + output = layer.apply(params, x) + return output.sum() + + loss, grads = jax.value_and_grad(loss_fn)(params) + + # Update parameters + optimizer = optax.sgd(learning_rate=0.01) + opt_state = optimizer.init(params) + updates, opt_state = optimizer.update(grads, opt_state, params) + params = optax.apply_updates(params, updates) + +# END_CURRENT_SCALING_EXAMPLE diff --git a/docs/features/low_precision_training/fp8_current_scaling/pytorch_current_scaling_example.py b/docs/features/low_precision_training/fp8_current_scaling/pytorch_current_scaling_example.py new file mode 100644 index 0000000000..88a2352203 --- /dev/null +++ b/docs/features/low_precision_training/fp8_current_scaling/pytorch_current_scaling_example.py @@ -0,0 +1,31 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +# START_CURRENT_SCALING_EXAMPLE + +import torch +import transformer_engine.pytorch as te +from transformer_engine.common.recipe import Float8CurrentScaling, Format + +# Create FP8 Current Scaling recipe +# Available formats: +# - Format.HYBRID (default) -- E4M3 for forward pass, E5M2 for backward pass +# - Format.E4M3 -- E4M3 for both forward and backward pass +recipe = Float8CurrentScaling(fp8_format=Format.HYBRID) + +# Create a simple linear layer +layer = te.Linear(1024, 1024) +optimizer = torch.optim.SGD(layer.parameters(), lr=0.01) + +# Training with FP8 Current Scaling +inp = torch.randn(32, 128, 1024, dtype=torch.bfloat16, device="cuda") + +with te.fp8_autocast(enabled=True, fp8_recipe=recipe): + output = layer(inp) + loss = output.sum() + +loss.backward() +optimizer.step() + +# END_CURRENT_SCALING_EXAMPLE diff --git a/docs/features/low_precision_training/fp8_delayed_scaling/fp8_delayed_scaling.rst b/docs/features/low_precision_training/fp8_delayed_scaling/fp8_delayed_scaling.rst new file mode 100644 index 0000000000..772ed73fab --- /dev/null +++ b/docs/features/low_precision_training/fp8_delayed_scaling/fp8_delayed_scaling.rst @@ -0,0 +1,172 @@ +.. + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + + See LICENSE for license information. + +FP8 Delayed Scaling +=================================== + +FP8 Delayed Scaling estimates scaling factors from historical amax values rather than computing them +for each tensor. This reduces tensor reads per quantization from two to one, improving memory efficiency. + +Both this recipe and :doc:`FP8 Current Scaling <../fp8_current_scaling/fp8_current_scaling>` use +the same FP8 formats (E4M3/E5M2) with one float32 scaling factor per tensor. +Reading the FP8 Current Scaling documentation first is recommended. + +Quantization with delayed scaling factors +----------------------------------------- + +FP8 Current Scaling requires two tensor reads per quantization: one to compute amax, +one to cast. FP8 Delayed Scaling eliminates the first read by predicting the scaling factor +from historical amax values - hence *delayed* (using past values) versus *current* (using present values). + +The quantization process works as follows: + +1. **Compute scaling factor from history** (no tensor read needed): + The scaling factor is derived from stored ``amax_history`` using the formula: + + ``scaling_factor = FP8_MAX / amax`` + + where ``amax`` is computed from history using either ``max`` (default) or ``most_recent`` algorithm. + +2. **Quantize the tensor** (one tensor read): + Apply the scaling factor and cast to FP8. Values exceeding FP8 range are clipped. + +3. **Update history**: + Record the actual amax from this quantization for future iterations. + +Each module maintains an ``amax_history`` tensor of configurable length (``amax_history_len``) +for each quantized tensor. + +.. raw:: html + :file: img/scaling_comparison.svg + +*Figure 1. Comparison of FP8 Current Scaling and FP8 Delayed Scaling quantization processes.* + +Amax History Management +----------------------- + +The ``amax_history`` buffer acts as a sliding window of recent amax values. +Position 0 serves as a staging area for the current amax, while positions 1 to N-1 +store the history from oldest to newest. Each quantization writes the observed amax +to position 0, and after the pass completes, the history is rotated: + +.. code-block:: text + + Before rotation: [amax_N, amax_1, amax_2, ..., amax_N-1] (amax_N = current, amax_1 = oldest) + After rotation: [0, amax_2, ..., amax_N-1, amax_N] (amax_1 dropped, amax_N appended) + +The effective history length is ``amax_history_len - 1`` since position 0 is reserved +for the staging area. + +The implementation differs between PyTorch and JAX: + +.. tabs:: + + .. tab:: PyTorch + + Each module creates two ``amax_history`` tensors, initialized to zero: + + - Forward: shape ``(amax_history_len, num_gemms * 3)`` — three FP8 tensors per GEMM (input, weight, output) + - Backward: shape ``(amax_history_len, num_gemms * 2)`` — two FP8 tensors per GEMM (grad_output, grad_input) + + During the first forward pass, modules register their ``amax_history`` tensors + to a **global buffer** associated with the autocast context. When the context exits, + a single CUDA kernel processes all registered tensors at once - performing both + amax reduction across GPUs and history rotation. + + This batched approach (one kernel for all tensors instead of one kernel per tensor) + minimizes kernel launch overhead. + + .. tab:: JAX + + Each quantizer maintains its own ``amax_history`` as a Flax variable with shape ``(amax_history_len,)``. + There is no global buffer - each quantizer updates independently. + + The rotation is performed per-quantizer using ``jnp.roll``: + + .. code-block:: python + + updated_amax_history = jnp.roll(amax_history, -1, -1) + amax_history = updated_amax_history.at[0].set(0.0) + +Here's how to use FP8 Delayed Scaling in PyTorch and JAX: + +.. tabs:: + + .. tab:: PyTorch + + .. raw:: html + +
+ Requires SM89 (Ada) or later +
+ + .. literalinclude:: pytorch_delayed_scaling_example.py + :language: python + :start-after: # START_DELAYED_SCALING_EXAMPLE + :end-before: # END_DELAYED_SCALING_EXAMPLE + + .. tab:: JAX + + .. raw:: html + +
+ Requires SM89 (Ada) or later +
+ + .. literalinclude:: jax_delayed_scaling_example.py + :language: python + :start-after: # START_DELAYED_SCALING_EXAMPLE + :end-before: # END_DELAYED_SCALING_EXAMPLE + + +Distributed Training +-------------------- + +Since FP8 Delayed Scaling uses the same data formats as FP8 Current Scaling, +transpose gather is not supported. However, amax reduction works slightly differently in different frameworks. + +.. tabs:: + + .. tab:: PyTorch + + Amax reduction is controlled by two parameters: + + - ``reduce_amax`` in recipe: enables/disables reduction (required for SP and CP) + - ``amax_reduction_group`` in ``autocast``: specifies the process group for reduction + + We recommend reducing amax across all GPUs where the tensor is sharded, + including data parallel ranks. + + .. literalinclude:: pytorch_delayed_scaling_distributed_example.py + :language: python + :start-after: # START_AMAX_REDUCTION_EXAMPLE + :end-before: # END_AMAX_REDUCTION_EXAMPLE + + In data parallel training, some modules may not execute on certain ranks + (e.g., MoE experts that receive no tokens). This is handled as follows: + + - **First iteration**: All modules must execute on all ranks to register + their ``amax_history`` tensors in the global buffer. Mismatched registration + causes the ``all_reduce`` to hang due to different tensor sizes across ranks. + - **Subsequent iterations**: The ``autocast`` context must be entered and exited + on all ranks (this triggers the collective reduction). Individual modules can be + skipped - if no rank executes a module, its history is not rotated and scale + remains unchanged. + + + .. tab:: JAX + + Amax reduction is always enabled and managed automatically. + Reduction scope: all parallelism axes except pipeline parallelism (TP, SP, DP/FSDP). + + .. literalinclude:: jax_delayed_scaling_distributed_example.py + :language: python + :start-after: # START_AMAX_REDUCTION_EXAMPLE + :end-before: # END_AMAX_REDUCTION_EXAMPLE + +Supported devices +----------------- + +Ada and later (SM 8.9+) \ No newline at end of file diff --git a/docs/features/low_precision_training/fp8_delayed_scaling/img/scaling_comparison.svg b/docs/features/low_precision_training/fp8_delayed_scaling/img/scaling_comparison.svg new file mode 100644 index 0000000000..aff4ba0da3 --- /dev/null +++ b/docs/features/low_precision_training/fp8_delayed_scaling/img/scaling_comparison.svg @@ -0,0 +1,82 @@ + + + + + + + + + + + Current Scaling + + + + Tensor + + + + + + + Amax Computation + + + + + + + Quantization + (uses tensor + amax) + + + + + + + FP8 Tensor + + + + Delayed Scaling + + + + Tensor + + + + amax history + + + + read amax + + + + Quantization + (uses tensor + amax from history) + (updates amax history) + + + + update amax + + + + + + + FP8 Tensor + + + diff --git a/docs/features/low_precision_training/fp8_delayed_scaling/jax_delayed_scaling_distributed_example.py b/docs/features/low_precision_training/fp8_delayed_scaling/jax_delayed_scaling_distributed_example.py new file mode 100644 index 0000000000..e5f59e4e3c --- /dev/null +++ b/docs/features/low_precision_training/fp8_delayed_scaling/jax_delayed_scaling_distributed_example.py @@ -0,0 +1,15 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +# START_AMAX_REDUCTION_EXAMPLE +import transformer_engine.jax as te +from transformer_engine.common.recipe import DelayedScaling + +# Amax reduction scope is managed internally +recipe = DelayedScaling(reduce_amax=True) # Must be True in JAX + +with te.autocast(recipe=recipe, mesh_resource=mesh_resource): + output = layer.apply(params, inp) + +# END_AMAX_REDUCTION_EXAMPLE diff --git a/docs/features/low_precision_training/fp8_delayed_scaling/jax_delayed_scaling_example.py b/docs/features/low_precision_training/fp8_delayed_scaling/jax_delayed_scaling_example.py new file mode 100644 index 0000000000..a1b9b8203c --- /dev/null +++ b/docs/features/low_precision_training/fp8_delayed_scaling/jax_delayed_scaling_example.py @@ -0,0 +1,58 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import jax + +# Requires Ada (SM89) or newer for FP8 support +cc = jax.devices()[0].device_kind +assert ( + "RTX 40" in cc + or "L40" in cc + or "H100" in cc + or "H200" in cc + or "GH" in cc + or "B100" in cc + or "B200" in cc + or "GB" in cc +), "This example requires SM89 (Ada) or newer" + +# START_DELAYED_SCALING_EXAMPLE + +import jax +import jax.numpy as jnp +import optax +import transformer_engine.jax as te +from transformer_engine.jax.flax import DenseGeneral +from transformer_engine.jax.sharding import MeshResource, global_shard_guard +from transformer_engine.common.recipe import DelayedScaling + +# Create FP8 Delayed Scaling recipe +recipe = DelayedScaling( + margin=0, # Margin for scaling factor computation (default: 0) + amax_history_len=1024, # Length of amax history window (default: 1024) + amax_compute_algo="max", # How to compute amax from history (default: "max") +) + +with global_shard_guard(MeshResource()): + with te.autocast(enabled=True, recipe=recipe, mesh_resource=MeshResource()): + # Initialize layer and data + layer = DenseGeneral(features=1024) + key = jax.random.PRNGKey(0) + x = jax.random.normal(key, (32, 128, 1024), dtype=jnp.bfloat16) + params = layer.init(key, x) + + # Training with FP8 Delayed Scaling + def loss_fn(params): + output = layer.apply(params, x) + return output.sum() + + loss, grads = jax.value_and_grad(loss_fn)(params) + + # Update parameters + optimizer = optax.adamw(learning_rate=1e-4) + opt_state = optimizer.init(params) + updates, opt_state = optimizer.update(grads, opt_state, params) + params = optax.apply_updates(params, updates) + +# END_DELAYED_SCALING_EXAMPLE diff --git a/docs/features/low_precision_training/fp8_delayed_scaling/pytorch_delayed_scaling_distributed_example.py b/docs/features/low_precision_training/fp8_delayed_scaling/pytorch_delayed_scaling_distributed_example.py new file mode 100644 index 0000000000..2c99fe1a2c --- /dev/null +++ b/docs/features/low_precision_training/fp8_delayed_scaling/pytorch_delayed_scaling_distributed_example.py @@ -0,0 +1,18 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +# START_AMAX_REDUCTION_EXAMPLE +import torch.distributed as dist +import transformer_engine.pytorch as te +from transformer_engine.common.recipe import DelayedScaling + +# Create process group for amax reduction (e.g., all 8 GPUs) +amax_reduction_group = dist.new_group(ranks=[0, 1, 2, 3, 4, 5, 6, 7]) + +recipe = DelayedScaling(reduce_amax=True) + +with te.autocast(recipe=recipe, amax_reduction_group=amax_reduction_group): + output = model(inp) + +# END_AMAX_REDUCTION_EXAMPLE diff --git a/docs/features/low_precision_training/fp8_delayed_scaling/pytorch_delayed_scaling_example.py b/docs/features/low_precision_training/fp8_delayed_scaling/pytorch_delayed_scaling_example.py new file mode 100644 index 0000000000..61d227c6e3 --- /dev/null +++ b/docs/features/low_precision_training/fp8_delayed_scaling/pytorch_delayed_scaling_example.py @@ -0,0 +1,39 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import torch + +# Requires Ada (SM89) or newer for FP8 support +assert torch.cuda.get_device_capability()[0] >= 9 or ( + torch.cuda.get_device_capability()[0] == 8 and torch.cuda.get_device_capability()[1] >= 9 +), "This example requires SM89 (Ada) or newer" + +# START_DELAYED_SCALING_EXAMPLE + +import torch +import transformer_engine.pytorch as te +from transformer_engine.common.recipe import DelayedScaling + +# Create FP8 Delayed Scaling recipe +recipe = DelayedScaling( + margin=0, # Margin for scaling factor computation (default: 0) + amax_history_len=1024, # Length of amax history window (default: 1024) + amax_compute_algo="max", # How to compute amax from history (default: "max") +) + +# Create a linear layer +layer = te.Linear(1024, 1024) +optimizer = torch.optim.AdamW(layer.parameters(), lr=1e-4) + +# Training with FP8 Delayed Scaling +inp = torch.randn(32, 128, 1024, dtype=torch.bfloat16, device="cuda") + +with te.autocast(enabled=True, recipe=recipe): + output = layer(inp) + loss = output.sum() + +loss.backward() +optimizer.step() + +# END_DELAYED_SCALING_EXAMPLE diff --git a/docs/features/low_precision_training/index.rst b/docs/features/low_precision_training/index.rst new file mode 100644 index 0000000000..39fba07881 --- /dev/null +++ b/docs/features/low_precision_training/index.rst @@ -0,0 +1,17 @@ +.. + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + + See LICENSE for license information. + +Low precision training +=================================== + +.. toctree:: + + introduction/introduction.rst + performance_considerations/performance_considerations.rst + fp8_current_scaling/fp8_current_scaling.rst + fp8_delayed_scaling/fp8_delayed_scaling.rst + fp8_blockwise_scaling/fp8_blockwise_scaling.rst + mxfp8/mxfp8.rst + nvfp4/nvfp4.rst \ No newline at end of file diff --git a/docs/features/low_precision_training/introduction/autocast_jax.py b/docs/features/low_precision_training/introduction/autocast_jax.py new file mode 100644 index 0000000000..3c5e8f7e84 --- /dev/null +++ b/docs/features/low_precision_training/introduction/autocast_jax.py @@ -0,0 +1,101 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import jax + +# Requires Ada (SM89) or newer for FP8 support +cc = jax.devices()[0].device_kind +assert ( + "RTX 40" in cc + or "L40" in cc + or "H100" in cc + or "H200" in cc + or "GH" in cc + or "B100" in cc + or "B200" in cc + or "GB" in cc +), "This example requires SM89 (Ada) or newer" + +# START_AUTOCAST_BASIC + +import jax +import jax.numpy as jnp +import transformer_engine.jax as te +from transformer_engine.jax.flax import TransformerLayer +from transformer_engine.jax.sharding import MeshResource, global_shard_guard +from transformer_engine.jax.quantize import get_delayed_scaling_recipe + +# Set up mesh resource and recipe +recipe = get_delayed_scaling_recipe() + +with global_shard_guard(MeshResource()): + # Model initialization must happen inside autocast + with te.autocast(enabled=True, recipe=recipe, mesh_resource=MeshResource()): + layer = TransformerLayer( + hidden_size=1024, + mlp_hidden_size=4096, + num_attention_heads=16, + ) + + init_key, dropout_key = jax.random.split(jax.random.PRNGKey(0)) + x = jax.random.normal(init_key, (32, 128, 1024), dtype=jnp.bfloat16) + params = layer.init({"params": init_key, "dropout": dropout_key}, x) + + # Forward and backward pass (both inside autocast for JAX) + def loss_fn(params): + output = layer.apply(params, x, rngs={"dropout": dropout_key}) + return output.sum() + + loss, grads = jax.value_and_grad(loss_fn)(params) + +# END_AUTOCAST_BASIC + + +# START_AUTOCAST_SEQUENTIAL + +from transformer_engine.common.recipe import DelayedScaling + +encoder_recipe = DelayedScaling(fp8_format="E4M3") +decoder_recipe = DelayedScaling(fp8_format="HYBRID") + +with global_shard_guard(MeshResource()): + with te.autocast(enabled=True, recipe=encoder_recipe, mesh_resource=MeshResource()): + encoder = TransformerLayer(hidden_size=1024, mlp_hidden_size=4096, num_attention_heads=16) + encoder_params = encoder.init({"params": init_key, "dropout": dropout_key}, x) + hidden = encoder.apply(encoder_params, x, rngs={"dropout": dropout_key}) + + with te.autocast(enabled=True, recipe=decoder_recipe, mesh_resource=MeshResource()): + decoder = TransformerLayer(hidden_size=1024, mlp_hidden_size=4096, num_attention_heads=16) + decoder_params = decoder.init({"params": init_key, "dropout": dropout_key}, hidden) + output = decoder.apply(decoder_params, hidden, rngs={"dropout": dropout_key}) + +# END_AUTOCAST_SEQUENTIAL + + +# START_AUTOCAST_NESTED + +outer_recipe = DelayedScaling(fp8_format="E4M3") +inner_recipe = DelayedScaling(fp8_format="HYBRID") + +with global_shard_guard(MeshResource()): + with te.autocast(enabled=True, recipe=outer_recipe, mesh_resource=MeshResource()): + # layer1 uses outer_recipe + layer1 = TransformerLayer(hidden_size=1024, mlp_hidden_size=4096, num_attention_heads=16) + params1 = layer1.init({"params": init_key, "dropout": dropout_key}, x) + hidden = layer1.apply(params1, x, rngs={"dropout": dropout_key}) + + with te.autocast(enabled=True, recipe=inner_recipe, mesh_resource=MeshResource()): + # layer2 uses inner_recipe (overrides outer) + layer2 = TransformerLayer( + hidden_size=1024, mlp_hidden_size=4096, num_attention_heads=16 + ) + params2 = layer2.init({"params": init_key, "dropout": dropout_key}, hidden) + hidden = layer2.apply(params2, hidden, rngs={"dropout": dropout_key}) + + # layer3 uses outer_recipe again + layer3 = TransformerLayer(hidden_size=1024, mlp_hidden_size=4096, num_attention_heads=16) + params3 = layer3.init({"params": init_key, "dropout": dropout_key}, hidden) + output = layer3.apply(params3, hidden, rngs={"dropout": dropout_key}) + +# END_AUTOCAST_NESTED diff --git a/docs/features/low_precision_training/introduction/autocast_pytorch.out b/docs/features/low_precision_training/introduction/autocast_pytorch.out new file mode 100644 index 0000000000..189447da9a --- /dev/null +++ b/docs/features/low_precision_training/introduction/autocast_pytorch.out @@ -0,0 +1,169 @@ +Unable to find image 'gitlab-master.nvidia.com/dl/transformerengine/transformerengine:main-pytorch-py3-devel-amd64' locally +main-pytorch-py3-devel-amd64: Pulling from dl/transformerengine/transformerengine +20043066d3d5: Already exists +4f4fb700ef54: Already exists +4f4fb700ef54: Already exists +7bc71055093c: Already exists +4f4fb700ef54: Already exists +47fcffaf4b29: Already exists +5902106683aa: Pulling fs layer +db47045b8288: Pulling fs layer +9855dcfd20f2: Pulling fs layer +c83aec52d35b: Pulling fs layer +d6f0613e9d23: Pulling fs layer +529adbe755d6: Pulling fs layer +3f8bbd50cf66: Pulling fs layer +c94f3e870fb7: Pulling fs layer +6ccfc98c93d6: Pulling fs layer +2a8fb579ec89: Pulling fs layer +4f4fb700ef54: Pulling fs layer +2c0943abc3bf: Pulling fs layer +cb24635428c8: Pulling fs layer +d6f0613e9d23: Download complete +ba77ee389c59: Pulling fs layer +6ccfc98c93d6: Waiting +3f8bbd50cf66: Waiting +c94f3e870fb7: Waiting +ad1c4e5a24bd: Pulling fs layer +acfa23b3071b: Pulling fs layer +cb24635428c8: Waiting +2a8fb579ec89: Waiting +4f4fb700ef54: Waiting +2c0943abc3bf: Waiting +12a8dc7da30c: Pulling fs layer +9855dcfd20f2: Download complete +acfa23b3071b: Waiting +f68a05f961eb: Pulling fs layer +259a1e61421c: Pulling fs layer +ad1c4e5a24bd: Waiting +c83aec52d35b: Download complete +496617e13811: Pulling fs layer +ba77ee389c59: Waiting +36dae3e8aaeb: Pulling fs layer +78e6cfd97c1c: Pulling fs layer +7a8b8bb12118: Pulling fs layer +12a8dc7da30c: Waiting +259a1e61421c: Waiting +36dae3e8aaeb: Waiting +c0b18b287c3d: Pulling fs layer +496617e13811: Waiting +78e6cfd97c1c: Waiting +f68a05f961eb: Waiting +8fc3280b9592: Pulling fs layer +4c588b2f3c52: Pulling fs layer +7a8b8bb12118: Waiting +c0b18b287c3d: Waiting +6e7f98fc3c3c: Pulling fs layer +8fc3280b9592: Waiting +0f001d24ab08: Pulling fs layer +4c588b2f3c52: Waiting +6e7f98fc3c3c: Waiting +18bdb2124ba5: Pulling fs layer +5a43bd85a755: Pulling fs layer +ccca37fde2d4: Pulling fs layer +c323ae58d542: Pulling fs layer +6a80de74105c: Pulling fs layer +c0d657038960: Pulling fs layer +28e6bb2a8d6d: Pulling fs layer +db82ef3237fe: Pulling fs layer +718c6dc45196: Pulling fs layer +ccca37fde2d4: Waiting +6bf0db70bffe: Pulling fs layer +0f001d24ab08: Waiting +bb83b415f5fc: Pulling fs layer +18bdb2124ba5: Waiting +5a43bd85a755: Waiting +a234eb967d7b: Pulling fs layer +823b5e9533dd: Pulling fs layer +ce1c4844bbf7: Pulling fs layer +6bf0db70bffe: Waiting +db82ef3237fe: Waiting +6a80de74105c: Waiting +c323ae58d542: Waiting +c0d657038960: Waiting +28e6bb2a8d6d: Waiting +718c6dc45196: Waiting +dc542a35ca52: Pulling fs layer +a234eb967d7b: Waiting +823b5e9533dd: Waiting +bb83b415f5fc: Waiting +b81e63854d41: Pulling fs layer +ce1c4844bbf7: Waiting +00b137cd0089: Pulling fs layer +dc542a35ca52: Waiting +b81e63854d41: Waiting +6d1d3d590ba6: Pulling fs layer +00b137cd0089: Waiting +209019d349ee: Pulling fs layer +6d1d3d590ba6: Waiting +7211da8909df: Pulling fs layer +209019d349ee: Waiting +a800c1f8f343: Pulling fs layer +7211da8909df: Waiting +e095a56006a5: Pulling fs layer +486b98f2a656: Pulling fs layer +6bde50c38090: Pulling fs layer +7123d5ebcb9e: Pulling fs layer +f91b660c4ada: Pulling fs layer +a7cb07079fab: Pulling fs layer +a723cd062f34: Pulling fs layer +efb3be4092e8: Pulling fs layer +6bde50c38090: Waiting +a800c1f8f343: Waiting +e095a56006a5: Waiting +a7cb07079fab: Waiting +486b98f2a656: Waiting +7123d5ebcb9e: Waiting +106c538e709b: Pulling fs layer +f91b660c4ada: Waiting +a723cd062f34: Waiting +5b4c2606e98b: Pulling fs layer +758a682b26b3: Pulling fs layer +efb3be4092e8: Waiting +5927a8134b70: Pulling fs layer +106c538e709b: Waiting +3d1ac6bc51b7: Pulling fs layer +36c124538e84: Pulling fs layer +0476e9be89d3: Pulling fs layer +060c73eb2484: Pulling fs layer +5927a8134b70: Waiting +5b4c2606e98b: Waiting +8fe7027986a1: Pulling fs layer +758a682b26b3: Waiting +3d1ac6bc51b7: Waiting +36c124538e84: Waiting +73fe3721f250: Pulling fs layer +0476e9be89d3: Waiting +060c73eb2484: Waiting +9d30292775f7: Pulling fs layer +8fe7027986a1: Waiting +baf28902c8a0: Pulling fs layer +73fe3721f250: Waiting +15d08917e116: Pulling fs layer +9d30292775f7: Waiting +baf28902c8a0: Waiting +3c7144ccae05: Pulling fs layer +a70caf50e821: Pulling fs layer +f951e5dc5f2b: Pulling fs layer +b3643561a28a: Pulling fs layer +1cd3e02ec777: Pulling fs layer +df8ec6edfaf5: Pulling fs layer +b3f94cb1b75a: Pulling fs layer +a70caf50e821: Waiting +804ebcc046d9: Pulling fs layer +15d08917e116: Waiting +1cd3e02ec777: Waiting +df8ec6edfaf5: Waiting +f951e5dc5f2b: Waiting +3c7144ccae05: Waiting +cf879fce3464: Pulling fs layer +b3643561a28a: Waiting +c0926ef31a8b: Pulling fs layer +804ebcc046d9: Waiting +b3f94cb1b75a: Waiting +cf879fce3464: Waiting +23cf35bcd6b8: Pulling fs layer +c0926ef31a8b: Waiting +acda416aa112: Pulling fs layer +acda416aa112: Waiting +23cf35bcd6b8: Waiting diff --git a/docs/features/low_precision_training/introduction/autocast_pytorch.py b/docs/features/low_precision_training/introduction/autocast_pytorch.py new file mode 100644 index 0000000000..dd01a1bb53 --- /dev/null +++ b/docs/features/low_precision_training/introduction/autocast_pytorch.py @@ -0,0 +1,69 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import torch + +# Requires Ada (SM89) or newer for FP8 support +assert torch.cuda.get_device_capability()[0] >= 9 or ( + torch.cuda.get_device_capability()[0] == 8 and torch.cuda.get_device_capability()[1] >= 9 +), "This example requires SM89 (Ada) or newer" + +# START_AUTOCAST_BASIC + +import torch +import transformer_engine.pytorch as te +from transformer_engine.common.recipe import DelayedScaling + +recipe = DelayedScaling() +layer = te.Linear(1024, 1024) +inp = torch.randn(32, 1024, dtype=torch.float32, device="cuda") + +with te.autocast(enabled=True, recipe=recipe): + output = layer(inp) + +# .backward() is called outside of autocast +loss = output.sum() +loss.backward() + +# END_AUTOCAST_BASIC + + +# START_AUTOCAST_SEQUENTIAL + +encoder_recipe = DelayedScaling(fp8_format="E4M3") +decoder_recipe = DelayedScaling(fp8_format="HYBRID") + +encoder = te.Linear(1024, 1024) +decoder = te.Linear(1024, 1024) + +with te.autocast(enabled=True, recipe=encoder_recipe): + hidden = encoder(inp) + +with te.autocast(enabled=True, recipe=decoder_recipe): + output = decoder(hidden) + +# END_AUTOCAST_SEQUENTIAL + + +# START_AUTOCAST_NESTED + +outer_recipe = DelayedScaling(fp8_format="E4M3") +inner_recipe = DelayedScaling(fp8_format="HYBRID") + +layer1 = te.Linear(1024, 1024) +layer2 = te.Linear(1024, 1024) +layer3 = te.Linear(1024, 1024) + +with te.autocast(enabled=True, recipe=outer_recipe): + # layer1 uses outer_recipe + x = layer1(inp) + + with te.autocast(enabled=True, recipe=inner_recipe): + # layer2 uses inner_recipe (overrides outer) + x = layer2(x) + + # layer3 uses outer_recipe again + output = layer3(x) + +# END_AUTOCAST_NESTED diff --git a/docs/features/low_precision_training/introduction/bf16_fp16_training_jax.py b/docs/features/low_precision_training/introduction/bf16_fp16_training_jax.py new file mode 100644 index 0000000000..5305ec9701 --- /dev/null +++ b/docs/features/low_precision_training/introduction/bf16_fp16_training_jax.py @@ -0,0 +1,52 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +# START_BF16_FP16_TRAINING + +import jax +import jax.numpy as jnp +import optax +import transformer_engine.jax as te +from transformer_engine.jax.flax import TransformerLayer +from transformer_engine.jax.sharding import MeshResource, global_shard_guard + + +def run_forward_backward(params_dtype, compute_dtype): + # Create TransformerLayer + layer = TransformerLayer( + hidden_size=1024, + mlp_hidden_size=4096, + num_attention_heads=16, + dtype=params_dtype, + ) + + # Initialize parameters and optimizer + init_key, dropout_key = jax.random.split(jax.random.PRNGKey(0)) + x = jax.random.normal(init_key, (32, 128, 1024), dtype=compute_dtype) + params = layer.init({"params": init_key, "dropout": dropout_key}, x) + + # Create optimizer + optimizer = optax.sgd(learning_rate=0.01) + opt_state = optimizer.init(params) + + # Forward and backward pass + def loss_fn(params): + output = layer.apply(params, x, rngs={"dropout": dropout_key}) + assert output.dtype == compute_dtype + return output.sum() + + loss, grads = jax.value_and_grad(loss_fn)(params) + + # Update parameters + updates, opt_state = optimizer.update(grads, opt_state, params) + params = optax.apply_updates(params, updates) + + +# Set up mesh resource for single GPU +with global_shard_guard(MeshResource()): + run_forward_backward(jnp.float32, jnp.float32) # high precision training + run_forward_backward(jnp.float32, jnp.bfloat16) # bfloat16 training with master weights in FP32 + run_forward_backward(jnp.bfloat16, jnp.bfloat16) # bfloat16 training with weights in BF16 + +# END_BF16_FP16_TRAINING diff --git a/docs/features/low_precision_training/introduction/bf16_fp16_training_pytorch.py b/docs/features/low_precision_training/introduction/bf16_fp16_training_pytorch.py new file mode 100644 index 0000000000..154aeb898a --- /dev/null +++ b/docs/features/low_precision_training/introduction/bf16_fp16_training_pytorch.py @@ -0,0 +1,56 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +# START_BF16_FP16_TRAINING + +import torch +import transformer_engine.pytorch as te +from contextlib import nullcontext + + +def run_forward_backward(params_dtype, autocast_precision, grad_scaler_enabled): + if grad_scaler_enabled: + grad_scaler = torch.amp.GradScaler("cuda") + + layer = te.TransformerLayer( + hidden_size=1024, + ffn_hidden_size=4096, + num_attention_heads=16, + params_dtype=params_dtype, + ) + optimizer = torch.optim.SGD(layer.parameters(), lr=0.01) + x = torch.randn(32, 128, 1024, dtype=params_dtype, device="cuda") + + autocast_ctx = ( + torch.autocast(device_type="cuda", dtype=autocast_precision) + if autocast_precision is not None + else nullcontext() + ) + with autocast_ctx: + output = layer(x) + assert ( + output.dtype == autocast_precision if autocast_precision is not None else params_dtype + ) + loss = output.sum() + if grad_scaler_enabled: + grad_scaler.scale(loss).backward() + grad_scaler.step(optimizer) + grad_scaler.update() + else: + loss.backward() + optimizer.step() + + +run_forward_backward(torch.float32, torch.float32, False) # high precision training +run_forward_backward( + torch.float32, torch.bfloat16, False +) # bfloat16 training with master weights in FP32 +run_forward_backward( + torch.float32, torch.float16, True +) # fp16 training with master weights in FP32, needs loss scaling +run_forward_backward( + torch.bfloat16, torch.bfloat16, False +) # bfloat16 training with weights in BF16 + +# END_BF16_FP16_TRAINING diff --git a/docs/features/low_precision_training/introduction/fp8_autocast_jax.py b/docs/features/low_precision_training/introduction/fp8_autocast_jax.py new file mode 100644 index 0000000000..19bf6ee6b6 --- /dev/null +++ b/docs/features/low_precision_training/introduction/fp8_autocast_jax.py @@ -0,0 +1,53 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import jax + +# Requires Ada (SM89) or newer for FP8 support +cc = jax.devices()[0].device_kind +assert ( + "RTX 40" in cc + or "L40" in cc + or "H100" in cc + or "H200" in cc + or "GH" in cc + or "B100" in cc + or "B200" in cc + or "GB" in cc +), "This example requires SM89 (Ada) or newer" + +# START_FP8_AUTOCAST + +import jax +import jax.numpy as jnp +import optax +import transformer_engine.jax as te +from transformer_engine.jax.flax import TransformerLayer +from transformer_engine.jax.sharding import MeshResource, global_shard_guard +from transformer_engine.jax.quantize import get_delayed_scaling_recipe + +# Set up mesh resource and FP8 recipe +recipe = get_delayed_scaling_recipe() + +with global_shard_guard(MeshResource()): + with te.fp8_autocast(enabled=True, recipe=recipe, mesh_resource=MeshResource()): + # Create layer and initialize + layer = TransformerLayer( + hidden_size=1024, + mlp_hidden_size=4096, + num_attention_heads=16, + ) + + init_key, dropout_key = jax.random.split(jax.random.PRNGKey(0)) + x = jax.random.normal(init_key, (32, 128, 1024), dtype=jnp.bfloat16) + params = layer.init({"params": init_key, "dropout": dropout_key}, x) + + # Forward and backward pass + def loss_fn(params): + output = layer.apply(params, x, rngs={"dropout": dropout_key}) + return output.sum() + + loss, grads = jax.value_and_grad(loss_fn)(params) + +# END_FP8_AUTOCAST diff --git a/docs/features/low_precision_training/introduction/img/fp8_linear_flow.svg b/docs/features/low_precision_training/introduction/img/fp8_linear_flow.svg new file mode 100644 index 0000000000..e0bd380414 --- /dev/null +++ b/docs/features/low_precision_training/introduction/img/fp8_linear_flow.svg @@ -0,0 +1,172 @@ + + + + + + + + + + + FP8 Linear Layer – Forward and Backward Pass + + + Forward Pass + + + + InputT + + + + Input + + + + + + + Quantize + + + + + + + + + Input + + + + N + + + + Weight + + + + + + + Quantize + + + + + + + + + Weight + + + + WeightT + + + + T + + + + FP8 GEMM + (TN) + + + + + + + Output + + + + + + Backward Pass + + + + WeightT + + + + Output grad. + + + + + + + Quantize + + + + + + + + + Output grad. + + + + Output grad.T + + + + FP8 GEMM + (TN) + + + + Input grad. + + + + FP8 GEMM + (TN) + + + + Weight grad. + + + + InputT + + + + + N + + + T + + + + + + N + + + T + + + + + + + + Higher Precision (FP32/BF16/FP16) + + + + Lower Precision (FP8, MXFP8 etc.) + + + diff --git a/docs/features/low_precision_training/introduction/img/fp_formats_comparison.svg b/docs/features/low_precision_training/introduction/img/fp_formats_comparison.svg new file mode 100644 index 0000000000..abaa80db3a --- /dev/null +++ b/docs/features/low_precision_training/introduction/img/fp_formats_comparison.svg @@ -0,0 +1,183 @@ + + + + + + + sign + exponent + mantissa + + + FP32 + + + + 0 + + + + 0 + + 1 + + 1 + + 1 + + 1 + + 1 + + 0 + + 1 + + + + 1 + + 0 + + 0 + + 1 + + 0 + + 1 + + 0 + + 0 + + 1 + + 0 + + 1 + + 0 + + 1 + + 1 + + 1 + + 1 + + 0 + + 1 + + 0 + + 1 + + 0 + + 0 + + 0 + + = 0.3952 + + + + BF16 + + + + 0 + + + + 0 + + 1 + + 1 + + 1 + + 1 + + 1 + + 0 + + 1 + + + + 1 + + 0 + + 0 + + 1 + + 0 + + 1 + + 0 + + ≈ 0.3945 + + + + FP16 + + + + 0 + + + + 0 + + 1 + + 1 + + 0 + + 1 + + + + 1 + + 0 + + 0 + + 1 + + 0 + + 1 + + 0 + + 0 + + 1 + + 0 + + ≈ 0.3950 + + diff --git a/docs/features/low_precision_training/introduction/img/master_weights_approaches.svg b/docs/features/low_precision_training/introduction/img/master_weights_approaches.svg new file mode 100644 index 0000000000..fc1d5f244f --- /dev/null +++ b/docs/features/low_precision_training/introduction/img/master_weights_approaches.svg @@ -0,0 +1,112 @@ + + + + + + + + + + + Master Weights Storage Approaches + + + + + + + Low Precision Weights + (no master weights) + + + + Model + + Weights (BF16/FP16) + + + + + + + Forward/Backward + + + + + + + Optimizer + + State (FP32) + + + Master Weights in Model + + + + Model + + Weights (FP32) + + + + + cast to BF16/FP16 + + + + Forward/Backward + + + + + + + Optimizer + + State (FP32) + + + Master Weights in Optimizer + + + + cast to BF16/FP16 + + + + + + + Model + + Weights (BF16/FP16) + + + + + + + Forward/Backward + + + + + + + Optimizer + + State (FP32) + + Master (FP32) + + + + + diff --git a/docs/features/low_precision_training/introduction/img/mixed_precision_operations.svg b/docs/features/low_precision_training/introduction/img/mixed_precision_operations.svg new file mode 100644 index 0000000000..b3c02c8601 --- /dev/null +++ b/docs/features/low_precision_training/introduction/img/mixed_precision_operations.svg @@ -0,0 +1,105 @@ + + + + + + + + + + + Transformer Layer – default precision of operation in low precision recipe + + + + Input + + + + + Layer Norm + + + + + QKV Linear + + + + + QK^T + + + + + Softmax + + + + + + Attn * V + + + + + Output Linear + + + + + Dropout + Add + + + + + + Layer Norm + + + + + FFN Linear 1 + + + + + GELU + + + + + FFN Linear 2 + + + + + Output + + + + + + + Parameters + + + + Gradients + + + + + + Higher Precision (FP32/BF16/FP16) + + + + Lower Precision (FP8, MXFP8 etc.) + + + diff --git a/docs/features/low_precision_training/introduction/introduction.rst b/docs/features/low_precision_training/introduction/introduction.rst new file mode 100644 index 0000000000..8a5d6c7aca --- /dev/null +++ b/docs/features/low_precision_training/introduction/introduction.rst @@ -0,0 +1,277 @@ +.. + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + + See LICENSE for license information. + +Introduction +=================================== + +Transformer Engine accelerates deep learning by leveraging low precision formats on NVIDIA GPUs. +This chapter introduces mixed precision training and FP8 support. + + +Training in BF16/FP16 +--------------------- + +Deep learning traditionally uses 32-bit floating-point (FP32) numbers. +NVIDIA GPUs support lower precision formats—FP16 since Pascal, BF16 since Ampere—which offer higher throughput and lower memory usage. +Let's compare these formats. + +.. raw:: html + :file: img/fp_formats_comparison.svg + +*Figure 1: Comparison of FP32, BF16, and FP16 floating-point formats showing bit allocation for sign, exponent, and mantissa.* + +The key differences between these formats are: + +* **FP32** (32 bits total): 1 sign bit + 8 exponent bits + 23 mantissa bits – standard single-precision format +* **BF16** (16 bits total): 1 sign bit + 8 exponent bits + 7 mantissa bits – maintains FP32's exponent range but reduced precision +* **FP16** (16 bits total): 1 sign bit + 5 exponent bits + 10 mantissa bits – reduced range but higher precision than BF16 + +BF16's advantage is that it shares the same exponent range as FP32, +making it easier to convert between the two formats without overflow/underflow issues. +FP16 offers better precision for smaller values but has a more limited dynamic range, +which results in the need to perform loss scaling to avoid overflow/underflow—see `this paper on loss scaling `__ for more details. + +**Mixed precision** + +Not all operations can run in reduced precision. +Modern deep learning frameworks use *mixed precision training*, where: + +* *Low precision* is used for matrix multiplications and other compute-heavy operations, which remain numerically stable at lower precision, +* *High precision (FP32)* must be used for numerically sensitive operations to maintain training stability. These include layer normalization, softmax, and loss computations—operations that involve division or exponentiation, where small rounding errors can amplify and propagate through the network, leading to gradient instability or degraded convergence. + +**Master weights** + +Mixed precision training also raises the question of how to store model weights. +Lower precision formats like FP16 and BF16 have limited representational granularity, +which becomes problematic during gradient updates. +When a small gradient is added to a not so small weight stored in low precision, +the result may round back to the original value if the update falls below the format's precision threshold. +Moreover, some elements of the gradient itself can be too small to be represented in low precision. + +The solution is to maintain *master weights* in FP32. +During training, weights are cast to lower precision for forward and backward passes, +but the gradient updates are applied to the full-precision master copy. +This ensures that even small gradients accumulate correctly over time. + +There are two common software approaches to storing master weights: + +* *In the optimizer*: + The model holds low-precision weights, + while the optimizer maintains FP32 copies alongside momentum and other state. + During each step, + the optimizer updates its FP32 copy and casts the result back to the model's low-precision weights. + This makes it easier to shard master weights together with other optimizer state, for example in ZeRO optimizer. + +* *In the model*: + The model stores weights directly in FP32, + and they are cast to lower precision on-the-fly during forward and backward passes. + This approach works seamlessly with any standard optimizer, requiring no special support. + +.. raw:: html + :file: img/master_weights_approaches.svg + +*Figure 2: Three approaches to weight storage—low precision only (no master weights), master weights stored in the model, and master weights stored in the optimizer.* + +.. tabs:: + + .. tab:: PyTorch + + The PyTorch API of Transformer Engine provides two mechanisms to control precision: + + * **Weight precision**: Use the ``params_dtype`` argument in any TE layer constructor. + * **Computation precision**: Use the ``torch.autocast`` context manager. + + If parameters are set to be in lower precision and no autocast is used, then lower precision is used for computation. + Input is cast to lower precision before the computation inside the layer. + Output precision is the same as autocast precision. + + .. literalinclude:: bf16_fp16_training_pytorch.py + :language: python + :start-after: # START_BF16_FP16_TRAINING + :end-before: # END_BF16_FP16_TRAINING + + + .. tab:: JAX + + The JAX API of Transformer Engine provides two mechanisms to control precision: + + * **Weight precision**: Use the ``dtype`` argument in any TE layer constructor. + * **Computation precision**: Determined by the dtype of the input tensor. + + For training with master weights in FP32 and computation in BF16, + cast the input tensor to BF16 before passing it to the layer. + + .. literalinclude:: bf16_fp16_training_jax.py + :language: python + :start-after: # START_BF16_FP16_TRAINING + :end-before: # END_BF16_FP16_TRAINING + + + +Lower precisions +---------------- + +Transformer Engine's primary feature is supporting even lower precision than BF16/FP16, such as FP8, MXFP8, NVFP4, etc. +The logic of these precisions is more complicated than the logic of BF16/FP16 – they require scaling factors to +properly represent the full range of values in the tensor. Sometimes it is one scaling factor per tensor, +sometimes it is one scaling factor per block of values. A precision format combined with the logic for training +is called **a recipe**. + +In this section we present common logic for all the recipes. Each one of them is described in more detail in a separate section later. +Let's now see how we can train in lower precisions in supported frameworks. + +.. tabs:: + + .. tab:: PyTorch + + The PyTorch API of Transformer Engine provides an ``autocast`` context manager to control precision. + It's similar to the ``torch.autocast`` context manager, but tailored for low precision training. + The most important argument is the ``recipe`` argument, which accepts objects inheriting from + :class:`~transformer_engine.common.recipe.Recipe`. + + Forward computations need to be performed inside the ``autocast`` context manager, + while the ``.backward()`` call should be outside of it. + + Here is a basic example: + + .. raw:: html + +
+ Needs to be run on SM89+ (Ada or newer) +
+ + .. literalinclude:: autocast_pytorch.py + :language: python + :start-after: # START_AUTOCAST_BASIC + :end-before: # END_AUTOCAST_BASIC + + You can use multiple recipes in the same model in the following ways: + + **Sequential contexts** – apply different recipes to different parts of your model: + + .. raw:: html + +
+ Needs to be run on SM89+ (Ada or newer) +
+ + .. literalinclude:: autocast_pytorch.py + :language: python + :start-after: # START_AUTOCAST_SEQUENTIAL + :end-before: # END_AUTOCAST_SEQUENTIAL + + **Nested contexts** – the inner context overrides the outer one for its scope: + + .. raw:: html + +
+ Needs to be run on SM89+ (Ada or newer) +
+ + .. literalinclude:: autocast_pytorch.py + :language: python + :start-after: # START_AUTOCAST_NESTED + :end-before: # END_AUTOCAST_NESTED + + + .. tab:: JAX + + The JAX API of Transformer Engine provides an ``autocast`` context manager similar to PyTorch. + The key difference is that in JAX, model initialization must happen inside the ``autocast`` context + to properly capture quantization metadata in the parameter tree. + + Additionally, JAX requires a ``global_shard_guard(MeshResource())`` context (even for single GPU) + and the ``mesh_resource`` argument in the ``autocast`` call. + + Here is a basic example: + + .. raw:: html + +
+ Needs to be run on SM89+ (Ada or newer) +
+ + .. literalinclude:: autocast_jax.py + :language: python + :start-after: # START_AUTOCAST_BASIC + :end-before: # END_AUTOCAST_BASIC + + You can use multiple recipes in the same model in the following ways: + + **Sequential contexts** – apply different recipes to different parts of your model: + + .. raw:: html + +
+ Needs to be run on SM89+ (Ada or newer) +
+ + .. literalinclude:: autocast_jax.py + :language: python + :start-after: # START_AUTOCAST_SEQUENTIAL + :end-before: # END_AUTOCAST_SEQUENTIAL + + **Nested contexts** – the inner context overrides the outer one for its scope: + + .. raw:: html + +
+ Needs to be run on SM89+ (Ada or newer) +
+ + .. literalinclude:: autocast_jax.py + :language: python + :start-after: # START_AUTOCAST_NESTED + :end-before: # END_AUTOCAST_NESTED + +**Mixed precision with 8- or 4-bit precisions** + +From now on, we will refer to FP8/MXFP8/NVFP4 etc. as *low precision* +and to FP32/BF16/FP16 as *high precision*. This terminology will be +used throughout the rest of the documentation. + +Not all operations run in low precision: + +- **Non-attention linear operations**: run in low precision. +- **Attention computations**: run in high precision by default (some recipes allow low precision as an option). +- **Other operations** (layer normalization, softmax, etc.): run in high precision. + +Within high-precision operations, there are two categories: + +- **Configurable precision**: most operations run in parameter precision (FP32/BF16/FP16) or the precision specified by ``torch.autocast``. +- **Fixed FP32 precision**: some operations, or parts of operations—such as the division in layernorm—always run in FP32, regardless of other settings. + +.. raw:: html + :file: img/mixed_precision_operations.svg + +*Figure 3: Default single-device forward pass of TransformerLayer operations precision – only linear operations (outside of dot product attention) are in lower precision.* + +**Linear layer data flow** + +Let's see how data flow of a linear layer works by default on a single H100 GPU with FP8 precision: + +H100 (Hopper) architecture natively supports FP8 Matrix Multiplication only in **TN** layout (Transpose-NoTranspose), +so GEMM with tensors ``A`` and ``B`` returns ``B * A^T``. + +*Forward pass* + +* Input is quantized to FP8 – both ``input`` and ``input^T`` quantized versions are created. +* Weights are stored in high precision and quantized to low precision before the GEMM – both ``weight`` and ``weight^T`` quantized versions are created. +* FP8 GEMM with layout **TN** is run with ``weight`` and ``input`` tensors, +* Outputs – ``input * weight^T`` tensor – are returned in high precision. + +*Backward pass* + +* Output gradients are quantized to FP8 – both ``output_grad`` and ``output_grad^T`` quantized versions are created. +* FP8 GEMM with layout **TN** is performed with ``weight^T`` and ``output_grad`` tensors to compute input gradients. +* FP8 GEMM with layout **TN** is performed with ``input^T`` and ``output_grad^T`` tensors to compute weight gradients. +* Input gradients – ``output_grad * weight`` tensor – are returned in high precision. +* Weight gradients – ``output_grad^T * input`` tensor – are returned in high precision. + + +.. raw:: html + :file: img/fp8_linear_flow.svg + +*Figure 4: Forward pass of a Linear layer with low precision data flow.* diff --git a/docs/features/low_precision_training/introduction/jax_out b/docs/features/low_precision_training/introduction/jax_out new file mode 100644 index 0000000000..c8e7466dd3 --- /dev/null +++ b/docs/features/low_precision_training/introduction/jax_out @@ -0,0 +1,13 @@ +# START_MEMORY_USAGE_1 +Layer size: 2.00 MB +Memory usage after forward pass: 6.00 MB +# END_MEMORY_USAGE_1 + +# START_MEMORY_USAGE_2 +Memory after forward pass: 8.00 MB +# END_MEMORY_USAGE_2 + +# START_MEMORY_USAGE_3 +Layer size: 1.00 MB +Memory after forward pass: 6.00 MB +# END_MEMORY_USAGE_3 diff --git a/docs/features/low_precision_training/introduction/pytorch_out b/docs/features/low_precision_training/introduction/pytorch_out new file mode 100644 index 0000000000..cb189bb8e3 --- /dev/null +++ b/docs/features/low_precision_training/introduction/pytorch_out @@ -0,0 +1,11 @@ +# START_MEMORY_USAGE_1 +Layer size: 2.00 MB +Memory usage after forward pass: 5.88 MB +# END_MEMORY_USAGE_1 +# START_MEMORY_USAGE_2 +Memory after forward pass: 7.90 MB +# END_MEMORY_USAGE_2 +# START_MEMORY_USAGE_3 +Layer size: 0.92 MB +Memory after forward pass: 5.92 MB +# END_MEMORY_USAGE_3 diff --git a/docs/features/low_precision_training/mxfp8/img/fp8_1d_scaling.svg b/docs/features/low_precision_training/mxfp8/img/fp8_1d_scaling.svg new file mode 100644 index 0000000000..ea86117de3 --- /dev/null +++ b/docs/features/low_precision_training/mxfp8/img/fp8_1d_scaling.svg @@ -0,0 +1,177 @@ + + + + + + + + MXFP8 + (One scaling factor per 32 elements) + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + E8M0 scaling factors (one per 32 elements) + + + diff --git a/docs/features/low_precision_training/mxfp8/img/mxfp8_row_col.svg b/docs/features/low_precision_training/mxfp8/img/mxfp8_row_col.svg new file mode 100644 index 0000000000..a8a8d16caf --- /dev/null +++ b/docs/features/low_precision_training/mxfp8/img/mxfp8_row_col.svg @@ -0,0 +1,266 @@ + + + + + + + Rowwise (1x32 blocks) + + + + Data + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Scales + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Columnwise (32x1 blocks) + + + + Data + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Scales + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/docs/features/low_precision_training/mxfp8/img/mxfp8_scale_linearize_and_swizzle.svg b/docs/features/low_precision_training/mxfp8/img/mxfp8_scale_linearize_and_swizzle.svg new file mode 100644 index 0000000000..6e4ed44d56 --- /dev/null +++ b/docs/features/low_precision_training/mxfp8/img/mxfp8_scale_linearize_and_swizzle.svg @@ -0,0 +1,190 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + 1 + 2 + 3 + + K + + + 1 + K + + + 2 + K + + + 3 + + 2K + + + 1 + 2K + + + 1 + 2K + + + 3 + + + + + + + + + + + + + 128x4 + + + + + + + + + + + + 1 + + + 2 + + + + + + K + 1 + + + K + 2 + + + + + + 1x512 + + + + + + + 128 4-bit elements + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 0 + 1 + 2 + 3 + 4 + 5 + 6 + 7 + ... + + + + + + + + + + + + + + + + + + + + + + + 0 + 32 + 64 + 96 + 1 + 33 + 65 + 97 + ... + + + + diff --git a/docs/features/low_precision_training/mxfp8/img/mxfp8_swizzle_both_tensors.svg b/docs/features/low_precision_training/mxfp8/img/mxfp8_swizzle_both_tensors.svg new file mode 100644 index 0000000000..abffb58bac --- /dev/null +++ b/docs/features/low_precision_training/mxfp8/img/mxfp8_swizzle_both_tensors.svg @@ -0,0 +1,101 @@ + + + + + + + + + + + + + + + + Input Tensor + + FP32/BF16 + + + + + + + + Quantize + + + + + + + MXFP8 Tensor + + + + + Scales + + + + FP8 Data + + + + + + + + Communication + (All-Gather) + (Optional) + + + + + + + Swizzle + + + + + + + MXFP8 Tensor + + + + + Swizzle Scales + + + + FP8 Data + + + + + + + + GEMM + + diff --git a/docs/features/low_precision_training/mxfp8/img/mxfp8_tensor_scaling_layout.svg b/docs/features/low_precision_training/mxfp8/img/mxfp8_tensor_scaling_layout.svg new file mode 100644 index 0000000000..3b81ff0a36 --- /dev/null +++ b/docs/features/low_precision_training/mxfp8/img/mxfp8_tensor_scaling_layout.svg @@ -0,0 +1,63 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + FP8 Tensor (128×128 blocks) + + + + + + + + + + + + + + + + + + + + + + + + + + + Scaling Factors (128×4 blocks) + diff --git a/docs/features/low_precision_training/mxfp8/jax_mxfp8_example.py b/docs/features/low_precision_training/mxfp8/jax_mxfp8_example.py new file mode 100644 index 0000000000..ec2ec6f747 --- /dev/null +++ b/docs/features/low_precision_training/mxfp8/jax_mxfp8_example.py @@ -0,0 +1,41 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +# START_MXFP8_EXAMPLE + +import jax +import jax.numpy as jnp +import optax +import transformer_engine.jax as te +from transformer_engine.jax.flax import DenseGeneral +from transformer_engine.jax.sharding import MeshResource, global_shard_guard +from transformer_engine.common.recipe import MXFP8BlockScaling, Format + +# Create MXFP8 recipe +recipe = MXFP8BlockScaling( + fp8_format=Format.E4M3, # FP8 format (default: E4M3, E5M2 not supported) +) + +with global_shard_guard(MeshResource()): + with te.fp8_autocast(enabled=True, recipe=recipe, mesh_resource=MeshResource()): + # Initialize layer and data + layer = DenseGeneral(features=1024) + key = jax.random.PRNGKey(0) + x = jax.random.normal(key, (32, 128, 1024), dtype=jnp.bfloat16) + params = layer.init(key, x) + + # Training with MXFP8 + def loss_fn(params): + output = layer.apply(params, x) + return output.sum() + + loss, grads = jax.value_and_grad(loss_fn)(params) + + # Update parameters + optimizer = optax.adamw(learning_rate=1e-4) + opt_state = optimizer.init(params) + updates, opt_state = optimizer.update(grads, opt_state, params) + params = optax.apply_updates(params, updates) + +# END_MXFP8_EXAMPLE diff --git a/docs/features/low_precision_training/mxfp8/mxfp8.rst b/docs/features/low_precision_training/mxfp8/mxfp8.rst new file mode 100644 index 0000000000..b0c80e837c --- /dev/null +++ b/docs/features/low_precision_training/mxfp8/mxfp8.rst @@ -0,0 +1,199 @@ +.. + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + + See LICENSE for license information. + +MXFP8 +===== + + +MXFP8 (Microscaling FP8) is an enhanced FP8 blockwise scaling recipe that leverages native hardware +acceleration on Blackwell GPUs (SM 10.0+). By using one scaling factor per 32 consecutive values +(rather than 128), MXFP8 delivers finer-grained quantization with improved numerical precision. + + + +Data Format +----------- + +The representation of an FP8 tensor element ``x`` in MXFP8 precision is given by: + +.. code-block:: python + + x = x_fp8 * s_block + +where + +* ``x_fp8`` is the FP8 value in E4M3 format, +* ``s_block`` is a local **E8M0** scaling factor shared by a block of 32 elements. + + +**FP8 format** + +Like FP8 Blockwise Scaling, E4M3 is used by default for both forward and backward passes. +The finer-grained scaling provides sufficient dynamic range without requiring the E5M2 format. +The ``fp8_format`` parameter also supports ``HYBRID`` mode (E4M3 for forward, E5M2 for backward). +Pure E5M2 training is not supported. + + +**Block size** + +Block size is 32. +Blocks are one-dimensional, containing 32 consecutive values. No 2D scaling is performed. + +There are some assumptions on the dimensions of the tensor: + +* the tensor must have at least 2 dimensions, +* the last dimension must be divisible by 32, +* the product of all dimensions except the last must be divisible by 32. + + +**Scaling factors** + +Scaling factors are stored as E8M0 (8 exponent bits, 0 mantissa bits), which inherently represents +powers of 2. This differs from FP8 Blockwise Scaling, which uses 32-bit floating point numbers +optionally constrained to powers of 2. Note that FP32 also has 8 exponent bits, so the representable +ranges are similar when the power-of-2 constraint is enabled. + +Each block's scaling factor is computed through the following steps: + +1. Find the maximum absolute value (``amax_block``) across all 32 elements in the block. +2. Compute the E8M0 biased exponent: ``e = float_to_e8m0(amax_block / max_fp8)``, where ``max_fp8 = 448`` + (the maximum representable value in E4M3 format). + + Since E8M0 and FP32 share the same exponent bias (127), ``float_to_e8m0`` simply extracts + the 8-bit exponent from the FP32 representation, rounding up if the mantissa is non-zero. + +3. The scaling factor is ``s_block = 2^(e - 127)``. + +This ensures that the largest value in each block fits within the FP8 representable range without overflow. + + +.. raw:: html + :file: img/fp8_1d_scaling.svg + +*Figure 1. MXFP8 uses one E8M0 scaling factor per 32 consecutive elements, providing fine-grained +quantization and compact scaling factor representation.* + + +Handling transposes +------------------- + +Blackwell architecture supports multiple FP8 GEMM layouts (TN, NT, NN), so columnwise usage +does not require explicit transposition. However, rowwise and columnwise quantizations are different: + +- *Rowwise* - 1 scaling factor per 32 consecutive elements along a row (1×32 blocks). +- *Columnwise* - 1 scaling factor per 32 consecutive elements along a column (32×1 blocks). + +Because the scaling factor blocks have different orientations, rowwise and columnwise MXFP8 tensors +are numerically different — one cannot derive one from the other. Both must be quantized +independently from full-precision data. + +.. raw:: html + :file: img/mxfp8_row_col.svg + +*Figure 2. MXFP8 rowwise vs columnwise quantization layout.* + + +Swizzling scaling factors +------------------------- + +Like :doc:`FP8 Blockwise Scaling <../fp8_blockwise_scaling/fp8_blockwise_scaling>`, MXFP8 uses different data layouts for communication and computation. +MXFP8 GEMMs require scaling factors in a specific hardware layout +(see `cuBLAS documentation `__). +The conversion to this GEMM-ready layout is called *swizzling*. Because swizzled scaling factors +cannot be communicated across devices, Transformer Engine performs swizzling after any required +communication, just before each GEMM operation. + +.. raw:: html + :file: img/mxfp8_swizzle_both_tensors.svg + +*Figure 3. MXFP8 swizzling process: standard scaling factors are rearranged into the hardware-required layout.* + + +Blackwell Tensor Cores compute matrix multiplications using ``128x128`` tiles. +Scaling factors are stored in row-major order, but to process a tile, we need a ``128x4`` vertical +slice of scaling factors. In row-major storage, these vertical slices are scattered in memory +with gaps between each row. The hardware requires them to be stored contiguously. + +.. raw:: html + :file: img/mxfp8_tensor_scaling_layout.svg + +*Figure 4. FP8 tensor (left) is divided into 128x128 tiles. Each tile requires a 128x4 block of scaling factors (right). These vertical blocks are not contiguous in memory.* + +Swizzling transforms the layout to meet hardware requirements by: + +1. **Linearizing** the ``128x4`` blocks so they are stored contiguously one after another. +2. **Permuting** the 4-byte elements within each block. + +Specifically, if we index the 128 4-byte elements in a scaling factor block as :math:`0, 1, \dots, 127`, the hardware expects them in the following interleaved order: + +.. code-block:: text + + 0, 32, 64, 96, 1, 33, 65, 97, ..., k, 32 + k, 64 + k, 96 + k, ..., 31, 63, 95, 127 + + +.. raw:: html + :file: img/mxfp8_scale_linearize_and_swizzle.svg + +*Figure 5. Linearization and swizzling of scaling factors. The 2D grid of scaling factors is first flattened into a contiguous sequence of blocks (top), then the rows within each block are interleaved to match the hardware access pattern (bottom).* + +For columnwise scaling factors, the process is analogous but with ``4x128`` horizontal blocks instead of ``128x4`` vertical blocks. + + + +Distributed training +-------------------- + +**Scale synchronization** + +The blockwise scaled tensor does not need any scale synchronization among the nodes. +This is because each scaling factor is local to its 32-element block, +unlike :doc:`FP8 Current <../fp8_current_scaling/fp8_current_scaling>`/:doc:`Delayed Scaling <../fp8_delayed_scaling/fp8_delayed_scaling>` where a single global scale applies to the entire tensor, even when sharded. + +**Quantized all-gather** + +All-gather of columnwise tensors is supported and necessary because: + +- columnwise quantized tensors cannot be computed from rowwise quantized ones (as mentioned earlier), +- gathering high-precision tensors is avoided in most cases for performance reasons. + + +Examples +-------- + +Here's how to use MXFP8 recipe in PyTorch and JAX: + +.. tabs:: + + .. tab:: PyTorch + + .. raw:: html + +
+ Requires SM100 (Blackwell) or later +
+ + .. literalinclude:: pytorch_mxfp8_example.py + :language: python + :start-after: # START_MXFP8_EXAMPLE + :end-before: # END_MXFP8_EXAMPLE + + .. tab:: JAX + + .. raw:: html + +
+ Requires SM100 (Blackwell) or later +
+ + .. literalinclude:: jax_mxfp8_example.py + :language: python + :start-after: # START_MXFP8_EXAMPLE + :end-before: # END_MXFP8_EXAMPLE + + +Supported devices +----------------- + +Blackwell and later (SM 10.0+) \ No newline at end of file diff --git a/docs/features/low_precision_training/mxfp8/pytorch_mxfp8_example.py b/docs/features/low_precision_training/mxfp8/pytorch_mxfp8_example.py new file mode 100644 index 0000000000..3c5f9c20ed --- /dev/null +++ b/docs/features/low_precision_training/mxfp8/pytorch_mxfp8_example.py @@ -0,0 +1,30 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +# START_MXFP8_EXAMPLE + +import torch +import transformer_engine.pytorch as te +from transformer_engine.common.recipe import MXFP8BlockScaling, Format + +# Create MXFP8 recipe +recipe = MXFP8BlockScaling( + fp8_format=Format.E4M3, # E4M3 (default) or HYBRID; pure E5M2 not supported +) + +# Create a linear layer +layer = te.Linear(1024, 1024) +optimizer = torch.optim.AdamW(layer.parameters(), lr=1e-4) + +# Training with MXFP8 +inp = torch.randn(32, 128, 1024, dtype=torch.bfloat16, device="cuda") + +with te.fp8_autocast(enabled=True, fp8_recipe=recipe): + output = layer(inp) + loss = output.sum() + +loss.backward() +optimizer.step() + +# END_MXFP8_EXAMPLE diff --git a/docs/features/low_precision_training/nvfp4/img/nvfp4_all_gather.svg b/docs/features/low_precision_training/nvfp4/img/nvfp4_all_gather.svg new file mode 100644 index 0000000000..6d2c62591f --- /dev/null +++ b/docs/features/low_precision_training/nvfp4/img/nvfp4_all_gather.svg @@ -0,0 +1,118 @@ + + + + + + + + + + + Quantization + All-Gather for NVFP4 + + + + High Precision + Tensor + + + + + + + Compute + Amax + + + + + + + Synchronize + Amax + + + + + + + Compute + s_global + + + + + + + Scale + Cast + (s_block, + s_global) + + + + + + + NVFP4 + Tensor + + + + + + + All-Gather + + + + + + + NVFP4 Gathered + Tensor + + + diff --git a/docs/features/low_precision_training/nvfp4/img/nvfp4_hierarchical_scaling.svg b/docs/features/low_precision_training/nvfp4/img/nvfp4_hierarchical_scaling.svg new file mode 100644 index 0000000000..1a19d813e6 --- /dev/null +++ b/docs/features/low_precision_training/nvfp4/img/nvfp4_hierarchical_scaling.svg @@ -0,0 +1,186 @@ + + + + + + + + NVFP4 Hierarchical Scaling + (Block scaling + Global scaling) + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + E4M3 scaling factors (one per 16 elements) + + + + + Global Scale (FP32) + (one per tensor) + + + + + + \ No newline at end of file diff --git a/docs/features/low_precision_training/nvfp4/img/nvfp4_row_col.svg b/docs/features/low_precision_training/nvfp4/img/nvfp4_row_col.svg new file mode 100644 index 0000000000..2030688fb6 --- /dev/null +++ b/docs/features/low_precision_training/nvfp4/img/nvfp4_row_col.svg @@ -0,0 +1,208 @@ + + + + + + + Rowwise (1×16 blocks) + + + + Data [A, B] + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + s_block [A, B/16] + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + s_global + + + + + Columnwise (16×1 blocks) — transposed storage + + + + Data [B, A] + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + s_block [B, A/16] + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + s_global + + + diff --git a/docs/features/low_precision_training/nvfp4/img/nvfp4_vs_fp8.svg b/docs/features/low_precision_training/nvfp4/img/nvfp4_vs_fp8.svg new file mode 100644 index 0000000000..68f6bf9039 --- /dev/null +++ b/docs/features/low_precision_training/nvfp4/img/nvfp4_vs_fp8.svg @@ -0,0 +1,91 @@ + + + + + + + FP8 E4M3 + + + + 0 + + + + 1 + + 0 + + 0 + + 0 + + + + 1 + + 1 + + 1 + + (1 sign, 4 exp, 3 mantissa) + + + + FP8 E5M2 + + + + 0 + + + + 1 + + 0 + + 0 + + 0 + + 0 + + + + 1 + + 1 + + (1 sign, 5 exp, 2 mantissa) + + + + NVFP4 + + + + 0 + + + + 1 + + 0 + + + + 1 + + (1 sign, 2 exp, 1 mantissa) + + + + diff --git a/docs/features/low_precision_training/nvfp4/img/rht.svg b/docs/features/low_precision_training/nvfp4/img/rht.svg new file mode 100644 index 0000000000..2111d64e26 --- /dev/null +++ b/docs/features/low_precision_training/nvfp4/img/rht.svg @@ -0,0 +1,138 @@ + + + + + + + + + + + + Random Hadamard Transform for WGRAD GEMM + + + + + + + Without RHT + + + + + Activations + + + + + + + Quantize + + + + + + + WGRAD + GEMM + + + + + Output Grad + + + + + + + Quantize + + + + + + + + + + Weight Grad + + + + + With RHT + + + + + Activations + + + + + + + RHT + + + + + + + Quantize + + + + + + + WGRAD + GEMM + + + + + Output Grad + + + + + + + RHT + + + + + + + Quantize + + + + + + + + + + Weight Grad + + + diff --git a/docs/features/low_precision_training/nvfp4/img/stochastic_rounding.svg b/docs/features/low_precision_training/nvfp4/img/stochastic_rounding.svg new file mode 100644 index 0000000000..eb745f6e84 --- /dev/null +++ b/docs/features/low_precision_training/nvfp4/img/stochastic_rounding.svg @@ -0,0 +1,95 @@ + + + + + + + + + + + + Round to Nearest + + + + + + + v₁ + + + + v₂ + + + + x + + + + + Round to v₁ + + + 100% + + + Round to v₂ + + + 0% + + + + + + + Stochastic Rounding + + + + + + + v₁ + + + + v₂ + + + + x + + + + + Round to v₁ + + + 60% + + + Round to v₂ + + + 40% + + + + + diff --git a/docs/features/low_precision_training/nvfp4/jax_nvfp4_example.py b/docs/features/low_precision_training/nvfp4/jax_nvfp4_example.py new file mode 100644 index 0000000000..e2cc17fd71 --- /dev/null +++ b/docs/features/low_precision_training/nvfp4/jax_nvfp4_example.py @@ -0,0 +1,43 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +# START_NVFP4_EXAMPLE + +import jax +import jax.numpy as jnp +import optax +import transformer_engine.jax as te +from transformer_engine.jax.flax import DenseGeneral +from transformer_engine.jax.sharding import MeshResource, global_shard_guard +from transformer_engine.common.recipe import NVFP4Recipe, Format + +# Define NVFP4 recipe +recipe = NVFP4Recipe( + fp8_format=Format.E4M3, + use_2d_weight_quantization=True, + use_rht=True, +) + +with global_shard_guard(MeshResource()): + with te.fp8_autocast(enabled=True, recipe=recipe, mesh_resource=MeshResource()): + # Initialize layer and data + layer = DenseGeneral(features=1024) + key = jax.random.PRNGKey(0) + x = jax.random.normal(key, (32, 128, 1024), dtype=jnp.bfloat16) + params = layer.init(key, x) + + # Training step + def loss_fn(params): + output = layer.apply(params, x) + return output.sum() + + loss, grads = jax.value_and_grad(loss_fn)(params) + + # Update parameters + optimizer = optax.adamw(learning_rate=1e-4) + opt_state = optimizer.init(params) + updates, opt_state = optimizer.update(grads, opt_state, params) + params = optax.apply_updates(params, updates) + +# END_NVFP4_EXAMPLE diff --git a/docs/features/low_precision_training/nvfp4/nvfp4.rst b/docs/features/low_precision_training/nvfp4/nvfp4.rst new file mode 100644 index 0000000000..cc8ad6e747 --- /dev/null +++ b/docs/features/low_precision_training/nvfp4/nvfp4.rst @@ -0,0 +1,261 @@ +.. + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + + See LICENSE for license information. + +NVFP4 +=================================== + +NVFP4 is the first 4-bit recipe introduced in Transformer Engine – +please refer to the `NVFP4 paper `__ for more details. +It is a more complex recipe than the previous ones – apart from the new data format, +it introduces multiple features which help training stability. + +Data Format +---------------------- + +The NVFP4 datatype consists of 1 sign bit, 2 exponent bits, and 1 mantissa bit (E2M1). +It can represent values of magnitude up to +/- 6. +NVFP4 uses a hierarchical block scaling approach where multiple scaling factors are combined to recover the high precision value. + +.. raw:: html + :file: img/nvfp4_vs_fp8.svg + +*Figure 1. Bit layout comparison between standard FP8 formats (E4M3 and E5M2) and NVFP4 (E2M1).* + + +The representation of an NVFP4 tensor element ``x`` is given by: + +.. code-block:: python + + x = x_e2m1 * s_block * s_global + +where + +* ``x_e2m1`` is the 4-bit value, +* ``s_block`` is a local **FP8 E4M3** scaling factor shared by a block of 16 consecutive elements, +* ``s_global`` is a global **FP32** scaling factor applied to the entire tensor. + +**Scaling Factor Computation** + +The scaling factors are computed as follows: + +1. Global scaling factor (``s_global``): + +.. code-block:: python + + s_global = global_amax / (fp8_max * fp4_max) + # where: + # - global_amax: maximum absolute value across the entire tensor + # - fp8_max: maximum representable value in FP8 E4M3 (448.0) + # - fp4_max: maximum representable value in NVFP4 E2M1 (6.0) + +2. Block scaling factor (``s_block``): + +.. code-block:: python + + s_block = (block_amax / fp4_max) / s_global + # where: + # - block_amax: maximum absolute value within the block + # - fp4_max: maximum representable value in NVFP4 E2M1 (6.0) + # - s_block is stored in FP8 E4M3 format + + +.. raw:: html + :file: img/nvfp4_hierarchical_scaling.svg + +*Figure 2. NVFP4 hierarchical scaling structure showing the combination of block-level and global scaling factors.* + +This hierarchical structure uses fine-grained block scaling +to adapt to local magnitude variations and global scaling +to handle the overall dynamic range. + +**2D weight scaling** + +NVFP4 can be: + +* 1 dimensional - each block of 16 consecutive elements shares a scaling factor, +* 2 dimensional - each block of 16x16 elements shares a scaling factor. + +By default, NVFP4 uses 2D scaling for weights and 1D scaling for activations and gradients. +Set ``disable_2d_quantization=True`` in the recipe configuration to force 1D scaling for weights as well (activations and gradients always use 1D). +The motivation for using 2D scaling for weights is to ensure that rowwise and columnwise +quantized tensors are numerically equivalent. +Please refer to the `NVFP4 paper `__ for more details. + + +Stochastic Rounding +------------------- + +Stochastic rounding is applied when casting scaled values to NVFP4 format. Instead of deterministic rounding +(always rounding to nearest even value), each scaled value is probabilistically rounded to one of the two +nearest representable NVFP4 values. The probability of rounding to a given value is inversely proportional to +the distance to that value, which ensures that the expected value of the quantized +tensor equals the original value, eliminating systematic quantization bias during training. +Stochastic rounding is hardware-accelerated using native GPU instructions introduced with the +Blackwell architecture. + +.. raw:: html + :file: img/stochastic_rounding.svg + +*Figure 3. Stochastic rounding illustration. Given a value* ``x`` *to be quantized, and the two nearest +representable NVFP4 values* ``v1`` *(lower) and* ``v2`` *(higher), deterministic rounding always +rounds to the nearest value, while stochastic rounding probabilistically rounds to either value. +If* ``x`` *is 40% of the way from* ``v1`` *to* ``v2``, *there is a 60% chance of rounding to* ``v1`` +*and a 40% chance of rounding to* ``v2``. + +Stochastic rounding is enabled only for gradients. It can be disabled by setting +``disable_stochastic_rounding=True`` in the recipe configuration. + + +Random Hadamard Transform +-------------------------- + +Random Hadamard Transform (RHT) applies an orthogonal rotation to the tensor **before quantization**, +smoothing outliers in the tensor distributions and making them easier to represent accurately in NVFP4. +RHT is applied to columnwise quantization of inputs and gradients, which are operands +for the **wgrad GEMM**. This GEMM – according to the paper – is particularly sensitive +to quantization errors, hence the additional outlier smoothing. +RHT is supported only for BF16 inputs/gradients; other dtypes will raise an error. + +The transform is defined as: + +.. math:: + + x' = x H + +where :math:`H` is the RHT matrix defined below. The quantization scale factor is computed +from the rotated tensor :math:`x'`. + +**Hadamard matrix** + +The :math:`d \times d` Hadamard matrix has elements :math:`\pm 1` and satisfies :math:`H_d H_d^T = d I`. +When normalized by :math:`1/\sqrt{d}`, the matrix becomes orthogonal and can be applied +to both operands of a matrix multiplication: + +.. math:: + + C = (AH)(H^T B) = AB + +where the transforms cancel within the dot-product since :math:`H H^T = I`. + +**Sign matrix** + +In the RHT implementation, a :math:`d`-dimensional diagonal sign matrix :math:`S_d` is applied +together with the Hadamard matrix: + +.. math:: + + H = \frac{1}{\sqrt{d}} S_d H_d + +where diagonal entries of :math:`S_d` are :math:`\{-1, 1\}` and flip the signs of different rows of :math:`H_d`. +As described in the paper, a single random sign vector is shared across all linear layers throughout training. +In the implementation, this vector is fixed and the RHT matrix is computed once at initialization and cached. + +**Tiled implementation** + +The Hadamard transform is performed in a tiled approach along the last dimension of the tensor. +For an :math:`m \times k` tensor, the data is reshaped to :math:`(mk/d) \times d` +and multiplied by the :math:`d \times d` matrix :math:`H`. In this implementation, :math:`d = 16`. + + +.. raw:: html + :file: img/rht.svg + +*Figure 4. WGRAD GEMM pipeline comparison: without RHT (left) and with RHT applied (right).* + +Handling transposes +------------------- + +Like :doc:`MXFP8 <../mxfp8/mxfp8>`, NVFP4 requires both rowwise and columnwise quantized tensors +for different GEMM operands. Unlike MXFP8 which supports multiple layouts (TN, NT, NN), +**NVFP4 GEMM only supports the TN layout**. + +NVFP4 stores columnwise data and scaling factors in a **transposed layout**: + +- **Rowwise**: data ``[A, B]`` with 1×16 horizontal blocks, ``scales`` shape ``[A, B/16]`` +- **Columnwise**: data ``[B, A]`` (transposed) with 1×16 horizontal blocks, ``scales`` shape ``[B, A/16]`` + +Scale tensors are padded for hardware alignment: first dimension to a multiple of 128, +second dimension to a multiple of 4 (e.g. rowwise: ``[roundup(A, 128), roundup(B/16, 4)]``). + +.. raw:: html + :file: img/nvfp4_row_col.svg + +*Figure 5. NVFP4 rowwise vs columnwise quantization layout. Unlike MXFP8, columnwise scales are stored transposed.* + + +Swizzling scaling factors +------------------------- + +NVFP4 requires swizzling of block scaling factors (``s_block``) before GEMM operations, +similar to :doc:`MXFP8 <../mxfp8/mxfp8>`. Key differences: + +- Block size is 16 (vs 32 for MXFP8) +- Both rowwise and columnwise scaling factors are swizzled, but thanks to the transposed + columnwise layout, a single rowwise swizzle kernel handles both cases. +- Scaling factors are stored as FP8 E4M3 (vs E8M0 for MXFP8) + + +Distributed training +-------------------- + +**Amax reduction** + +Block scaling factors (``s_block``) do not require synchronization between nodes, +as each scaling factor is local to its block of 16 elements. +However, the global scaling factor (``s_global``) requires amax synchronization for gathered tensors. +For tensors that are gathered (e.g., input and gradient in sequence parallelism), +amax reduction is performed before quantization. +If before synchronization there was ``amax_1`` on node 1, +``amax_2`` on node 2, etc., after synchronization there will be ``max(amax_1, amax_2, ...)`` on all nodes. + +**Quantized all-gather** + +All-gather of columnwise tensors is supported. To enable quantized all-gather, +all nodes must use the same ``s_global``, which is computed from the synchronized global amax. +This is automatically enabled for column-parallel and row-parallel linear layers. + +.. raw:: html + :file: img/nvfp4_all_gather.svg + +*Figure 6. Quantization and all-gather flow for NVFP4 showing amax synchronization and hierarchical scaling.* + +Examples +-------- + +Here's how to use NVFP4 recipe in PyTorch and JAX. The examples show how to configure features like 2D weight quantization and Random Hadamard Transform (RHT): + +.. tabs:: + + .. tab:: PyTorch + + .. raw:: html + +
+ Requires SM100 (Blackwell) or later +
+ + .. literalinclude:: pytorch_nvfp4_example.py + :language: python + :start-after: # START_NVFP4_EXAMPLE + :end-before: # END_NVFP4_EXAMPLE + + .. tab:: JAX + + .. raw:: html + +
+ Requires SM100 (Blackwell) or later +
+ + .. literalinclude:: jax_nvfp4_example.py + :language: python + :start-after: # START_NVFP4_EXAMPLE + :end-before: # END_NVFP4_EXAMPLE + + +Supported devices +----------------- + +Blackwell and later (SM 10.0+) diff --git a/docs/features/low_precision_training/nvfp4/pytorch_nvfp4_example.py b/docs/features/low_precision_training/nvfp4/pytorch_nvfp4_example.py new file mode 100644 index 0000000000..0736122ead --- /dev/null +++ b/docs/features/low_precision_training/nvfp4/pytorch_nvfp4_example.py @@ -0,0 +1,33 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +# START_NVFP4_EXAMPLE + +import torch +import transformer_engine.pytorch as te +from transformer_engine.common.recipe import NVFP4Recipe, Format + +# Define NVFP4 recipe +# Key features like 2D weight quantization and RHT can be enabled here +recipe = NVFP4Recipe( + fp8_format=Format.E4M3, + use_2d_weight_quantization=True, + use_rht=True, +) + +# Create a linear layer and optimizer +layer = te.Linear(1024, 1024) +optimizer = torch.optim.AdamW(layer.parameters(), lr=1e-4) + +# Training step +inp = torch.randn(32, 128, 1024, dtype=torch.bfloat16, device="cuda") + +with te.fp8_autocast(enabled=True, fp8_recipe=recipe): + output = layer(inp) + loss = output.sum() + +loss.backward() +optimizer.step() + +# END_NVFP4_EXAMPLE diff --git a/docs/features/low_precision_training/performance_considerations/fused_layers_jax.out b/docs/features/low_precision_training/performance_considerations/fused_layers_jax.out new file mode 100644 index 0000000000..e69de29bb2 diff --git a/docs/features/low_precision_training/performance_considerations/fused_layers_jax.py b/docs/features/low_precision_training/performance_considerations/fused_layers_jax.py new file mode 100644 index 0000000000..4a1fb55b38 --- /dev/null +++ b/docs/features/low_precision_training/performance_considerations/fused_layers_jax.py @@ -0,0 +1,43 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +# Requires Ada (SM89) or Hopper (SM90), different results on Blackwell+ + +# START_FUSED_LAYERS + +import jax +import jax.numpy as jnp +import transformer_engine.jax as te +from transformer_engine.jax.flax import LayerNorm, DenseGeneral, LayerNormDenseGeneral +from transformer_engine.jax.sharding import MeshResource, global_shard_guard +from transformer_engine.common.recipe import DelayedScaling + +with global_shard_guard(MeshResource()): + key = jax.random.PRNGKey(0) + x = jax.random.normal(key, (32, 128, 1024), dtype=jnp.bfloat16) + + # Example 1: Separate LayerNorm and DenseGeneral layers + layer_norm = LayerNorm() + dense = DenseGeneral(features=1024) + + # Initialize parameters + ln_params = layer_norm.init(key, x) + dense_params = dense.init(key, x) + + # Two separate operations + normalized = layer_norm.apply(ln_params, x) + output_separate = dense.apply(dense_params, normalized) + + # Example 2: Fused LayerNormDenseGeneral layer + fused_layer = LayerNormDenseGeneral(features=1024) + + # Initialize and apply with FP8 autocast + recipe = DelayedScaling() + with te.fp8_autocast(enabled=True, fp8_recipe=recipe, mesh_resource=MeshResource()): + fused_params = fused_layer.init(key, x) + output_fused = fused_layer.apply(fused_params, x) + + # The fused layer is more efficient as it combines LayerNorm and quantization + +# END_FUSED_LAYERS diff --git a/docs/features/low_precision_training/performance_considerations/fused_layers_pytorch.out b/docs/features/low_precision_training/performance_considerations/fused_layers_pytorch.out new file mode 100644 index 0000000000..d25b3f6a63 --- /dev/null +++ b/docs/features/low_precision_training/performance_considerations/fused_layers_pytorch.out @@ -0,0 +1,8 @@ +/usr/local/lib/python3.12/dist-packages/torch/library.py:356: UserWarning: Warning only once for all operators, other operators may also be overridden. + Overriding a previously registered kernel for the same operator and the same dispatch key + operator: flash_attn::_flash_attn_backward(Tensor dout, Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor(a6!)? dq, Tensor(a7!)? dk, Tensor(a8!)? dv, float dropout_p, float softmax_scale, bool causal, SymInt window_size_left, SymInt window_size_right, float softcap, Tensor? alibi_slopes, bool deterministic, Tensor? rng_state=None) -> Tensor + registered at /usr/local/lib/python3.12/dist-packages/torch/_library/custom_ops.py:922 + dispatch key: ADInplaceOrView + previous kernel: no debug info + new kernel: registered at /usr/local/lib/python3.12/dist-packages/torch/_library/custom_ops.py:922 (Triggered internally at /opt/pytorch/pytorch/aten/src/ATen/core/dispatch/OperatorEntry.cpp:208.) + self.m.impl( diff --git a/docs/features/low_precision_training/performance_considerations/fused_layers_pytorch.py b/docs/features/low_precision_training/performance_considerations/fused_layers_pytorch.py new file mode 100644 index 0000000000..1a9a1baf2e --- /dev/null +++ b/docs/features/low_precision_training/performance_considerations/fused_layers_pytorch.py @@ -0,0 +1,37 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import torch + +# Requires Ada (SM89) or Hopper (SM90), different results on Blackwell+ +cc = torch.cuda.get_device_capability() +assert cc[0] == 8 and cc[1] >= 9 or cc[0] == 9, "This example requires SM89 (Ada) or SM90 (Hopper)" + +# START_FUSED_LAYERS + +import torch +import transformer_engine.pytorch as te +from transformer_engine.common.recipe import DelayedScaling + +# Example 1: Separate LayerNorm and Linear layers +layer_norm = te.LayerNorm(1024) +linear = te.Linear(1024, 1024) + +inp = torch.randn(32, 128, 1024, dtype=torch.bfloat16, device="cuda") + +# Two separate operations: LayerNorm produces FP32, then Linear quantizes it +normalized = layer_norm(inp) +output_separate = linear(normalized) + +# Example 2: Fused LayerNormLinear layer +fused_layer = te.LayerNormLinear(1024, 1024, params_dtype=torch.bfloat16) + +# Single operation: LayerNorm output is directly quantized +recipe = DelayedScaling() +with te.autocast(enabled=True, recipe=recipe): + output_fused = fused_layer(inp) + +# The fused layer is more efficient as it avoids redundant quantization + +# END_FUSED_LAYERS diff --git a/docs/features/low_precision_training/performance_considerations/img/fused_layers.svg b/docs/features/low_precision_training/performance_considerations/img/fused_layers.svg new file mode 100644 index 0000000000..68e7cc8e8d --- /dev/null +++ b/docs/features/low_precision_training/performance_considerations/img/fused_layers.svg @@ -0,0 +1,120 @@ + + + + + + + + + + + LayerNorm + Linear: Separate vs Fused + + + + + + Scenario 1: Separate Layers + + + + Input + + + + + + + LayerNorm + + + + + + + Output + + + + + + + Linear + + + + Quantize + + + + + + + FP8 tensor + + + + + + + ... + + + + + + + Output + + + + Scenario 2: Fused Layer + + + + Input + + + + + + + LayerNormLinear + + + + + LayerNorm + Quantize + + + + + + + FP8 tensor + + + + + + + ... + + + + + + + Output + + diff --git a/docs/features/low_precision_training/performance_considerations/img/gemm_access_pattern.svg b/docs/features/low_precision_training/performance_considerations/img/gemm_access_pattern.svg new file mode 100644 index 0000000000..df5102090e --- /dev/null +++ b/docs/features/low_precision_training/performance_considerations/img/gemm_access_pattern.svg @@ -0,0 +1,218 @@ + + + + + + + + + + A × B + + + + B + columnwise + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + A + + + + + + + + + + + + + + + + + + + + + + + + + + + rowwise + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + A × BT + + + + B + rowwise + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + A + + + + + + + + + + + + + + + + + + + + + + + + + + + rowwise + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/docs/features/low_precision_training/performance_considerations/img/hopper_vs_blackwell_layout.svg b/docs/features/low_precision_training/performance_considerations/img/hopper_vs_blackwell_layout.svg new file mode 100644 index 0000000000..6f9bc4d5a1 --- /dev/null +++ b/docs/features/low_precision_training/performance_considerations/img/hopper_vs_blackwell_layout.svg @@ -0,0 +1,122 @@ + + + + + + + + FP8 tensor on Hopper + + + + rowwise + + + 0 + + 1 + + 2 + + 3 + + + 4 + + 5 + + 6 + + 7 + + + 8 + + 9 + + 10 + + 11 + + + + + columnwise + + + 0 + + 4 + + 8 + + + 1 + + 5 + + 9 + + + 2 + + 6 + + 10 + + + 3 + + 7 + + 11 + + + + + + + + FP8 tensor on Blackwell + + + + rowwise and columnwise + + + 0 + + 1 + + 2 + + 3 + + + 4 + + 5 + + 6 + + 7 + + + 8 + + 9 + + 10 + + 11 + + + diff --git a/docs/features/low_precision_training/performance_considerations/img/sequence_parallel_quantization.svg b/docs/features/low_precision_training/performance_considerations/img/sequence_parallel_quantization.svg new file mode 100644 index 0000000000..5dfbdcf814 --- /dev/null +++ b/docs/features/low_precision_training/performance_considerations/img/sequence_parallel_quantization.svg @@ -0,0 +1,159 @@ + + + + + + + + + + + All-Gather of Quantized Tensors (one scenario) + + + Input Tensor quantized all-gather + + + FWD: + + + + High Precision + Tensor + + + + + + + Quantize + + + + + + + Rowwise + Quantized + + + + + + + All-Gather + + + + + + ... + + + BWD: + + + + + + + Columnwise + Quantized + + + + + + + All-Gather + + + + + + ... + + + + + + Gradient Tensor quantized all-gather + + + BWD: + + + + High Precision + Tensor + + + + + + + Quantize + + + + + + + Col. Quantized + + + + + + + Row. Quantized + + + + + + + + + + All-Gather + + + + + + ... + + + + + High Precision (FP32/BF16/FP16) + + + Lower Precision (FP8, etc.) + + + Quantization + + + All-Gather + + + + diff --git a/docs/features/low_precision_training/performance_considerations/img/transpose_fusion.svg b/docs/features/low_precision_training/performance_considerations/img/transpose_fusion.svg new file mode 100644 index 0000000000..1d7dc3d813 --- /dev/null +++ b/docs/features/low_precision_training/performance_considerations/img/transpose_fusion.svg @@ -0,0 +1,181 @@ + + + + + + + + + + + Option 1: Quantize both usages in forward + + + FORWARD: + + + + High Precision + Tensor + + + + + + + Quantize + + + + + + + Quantized + Rowwise + + + BACKWARD: + + + + + + + Quantized + Columnwise + + + + + + Option 2: Separate Quantizations (quantize when needed) + + + FORWARD: + + + + High Precision + Tensor + + + + + + + Quantize + + + + + + + Quantized + Rowwise + + + + + + BACKWARD: + + + + High Precision + Tensor + + + + + + + Quantize + + + + + + + Quantized + Columnwise + + + + + + Option 3: Convert Rowwise to Columnwise in Backward (reuse saved tensor) + + + FORWARD: + + + + High Precision + Tensor + + + + + + + Quantize + + + + + + + Quantized + Rowwise + + + + + + BACKWARD: + + + + Quantized + Rowwise + + + + + + + Make + Columnwise + + + + + + + Quantized + Columnwise + + + + + High Precision (FP32/BF16/FP16) + + + Lower Precision (FP8, etc.) + + + Quantization / Make Columnwise + + + + diff --git a/docs/features/low_precision_training/performance_considerations/memory_usage_1_jax.out b/docs/features/low_precision_training/performance_considerations/memory_usage_1_jax.out new file mode 100644 index 0000000000..e4b5e03df8 --- /dev/null +++ b/docs/features/low_precision_training/performance_considerations/memory_usage_1_jax.out @@ -0,0 +1,4 @@ +# START_MEMORY_USAGE_1 +Layer size: 2.00 MB +Memory usage after forward pass: 6.00 MB +# END_MEMORY_USAGE_1 diff --git a/docs/features/low_precision_training/performance_considerations/memory_usage_1_jax.py b/docs/features/low_precision_training/performance_considerations/memory_usage_1_jax.py new file mode 100644 index 0000000000..a634bbd8ac --- /dev/null +++ b/docs/features/low_precision_training/performance_considerations/memory_usage_1_jax.py @@ -0,0 +1,44 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +# Requires Ada (SM89) or Hopper (SM90), different results on Blackwell+ + +print("# START_MEMORY_USAGE_1") +# START_MEMORY_USAGE_1 + +import jax +import jax.numpy as jnp +from transformer_engine.jax.flax import DenseGeneral +from transformer_engine.jax.sharding import MeshResource, global_shard_guard + + +def measure_memory(): + key = jax.random.PRNGKey(0) + + with global_shard_guard(MeshResource()): + # Initialize a dense layer with high precision parameters + layer = DenseGeneral(features=1024, dtype=jnp.bfloat16) + x = jax.random.normal(key, (1024, 1024), dtype=jnp.bfloat16) + params = layer.init(key, x) + + # Calculate layer size (1024 * 1024 * 2 bytes for BF16) + param_count = 1024 * 1024 + layer_size = param_count * 2 / (1024**2) + + # Forward pass + output = layer.apply(params, x) + + # Memory after forward: weight (2 MB) + input (2 MB) + output (2 MB) = 6 MB + return layer_size, 6.00 + + +# Warmup run +measure_memory() + +# Actual measurement +layer_size, mem_after_forward = measure_memory() +print(f"Layer size: {layer_size:.2f} MB") +print(f"Memory usage after forward pass: {mem_after_forward:.2f} MB") +# END_MEMORY_USAGE_1 +print("# END_MEMORY_USAGE_1") diff --git a/docs/features/low_precision_training/performance_considerations/memory_usage_1_pytorch.out b/docs/features/low_precision_training/performance_considerations/memory_usage_1_pytorch.out new file mode 100644 index 0000000000..f977460e84 --- /dev/null +++ b/docs/features/low_precision_training/performance_considerations/memory_usage_1_pytorch.out @@ -0,0 +1,11 @@ +/usr/local/lib/python3.12/dist-packages/torch/library.py:356: UserWarning: Warning only once for all operators, other operators may also be overridden. + Overriding a previously registered kernel for the same operator and the same dispatch key + operator: flash_attn::_flash_attn_backward(Tensor dout, Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor(a6!)? dq, Tensor(a7!)? dk, Tensor(a8!)? dv, float dropout_p, float softmax_scale, bool causal, SymInt window_size_left, SymInt window_size_right, float softcap, Tensor? alibi_slopes, bool deterministic, Tensor? rng_state=None) -> Tensor + registered at /usr/local/lib/python3.12/dist-packages/torch/_library/custom_ops.py:922 + dispatch key: ADInplaceOrView + previous kernel: no debug info + new kernel: registered at /usr/local/lib/python3.12/dist-packages/torch/_library/custom_ops.py:922 (Triggered internally at /opt/pytorch/pytorch/aten/src/ATen/core/dispatch/OperatorEntry.cpp:208.) + self.m.impl( +# START_MEMORY_USAGE_1 +Memory usage after forward pass: 6.00 MB +# END_MEMORY_USAGE_1 diff --git a/docs/features/low_precision_training/performance_considerations/memory_usage_1_pytorch.py b/docs/features/low_precision_training/performance_considerations/memory_usage_1_pytorch.py new file mode 100644 index 0000000000..5e7f2ae177 --- /dev/null +++ b/docs/features/low_precision_training/performance_considerations/memory_usage_1_pytorch.py @@ -0,0 +1,39 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import torch + +# Requires Ada (SM89) or Hopper (SM90), different results on Blackwell+ +cc = torch.cuda.get_device_capability() +assert cc[0] == 8 and cc[1] >= 9 or cc[0] == 9, "This example requires SM89 (Ada) or SM90 (Hopper)" + +print("# START_MEMORY_USAGE_1") +# START_MEMORY_USAGE_1 +import torch +import transformer_engine.pytorch as te + + +def measure_memory(): + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + + init_memory = torch.cuda.memory_allocated() + layer = te.Linear(1024, 1024, params_dtype=torch.bfloat16) + memory = torch.cuda.memory_allocated() - init_memory + + inp = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda") + out = layer(inp) + mem_after_forward = torch.cuda.memory_allocated() - init_memory + + return memory, mem_after_forward + + +# Warmup run +measure_memory() + +# Actual measurement +memory, mem_after_forward = measure_memory() +print(f"Memory usage after forward pass: {mem_after_forward/1024**2:.2f} MB") +# END_MEMORY_USAGE_1 +print("# END_MEMORY_USAGE_1") diff --git a/docs/features/low_precision_training/performance_considerations/memory_usage_2_jax.out b/docs/features/low_precision_training/performance_considerations/memory_usage_2_jax.out new file mode 100644 index 0000000000..fb333daa55 --- /dev/null +++ b/docs/features/low_precision_training/performance_considerations/memory_usage_2_jax.out @@ -0,0 +1,3 @@ +# START_MEMORY_USAGE_2 +Memory after forward pass: 8.00 MB +# END_MEMORY_USAGE_2 diff --git a/docs/features/low_precision_training/performance_considerations/memory_usage_2_jax.py b/docs/features/low_precision_training/performance_considerations/memory_usage_2_jax.py new file mode 100644 index 0000000000..a4db87d807 --- /dev/null +++ b/docs/features/low_precision_training/performance_considerations/memory_usage_2_jax.py @@ -0,0 +1,43 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +# Requires Ada (SM89) or Hopper (SM90), different results on Blackwell+ + +print("# START_MEMORY_USAGE_2") +# START_MEMORY_USAGE_2 + +import jax +import jax.numpy as jnp +import transformer_engine.jax as te +from transformer_engine.jax.flax import DenseGeneral +from transformer_engine.jax.sharding import MeshResource, global_shard_guard +from transformer_engine.common.recipe import DelayedScaling + + +def measure_memory(): + key = jax.random.PRNGKey(0) + recipe = DelayedScaling() + + with global_shard_guard(MeshResource()): + # Initialize layer with BF16 parameters + layer = DenseGeneral(features=1024, dtype=jnp.bfloat16) + x = jax.random.normal(key, (1024, 1024), dtype=jnp.bfloat16) + + # Initialize with FP8 autocast to create fp8_metas + with te.fp8_autocast(enabled=True, fp8_recipe=recipe, mesh_resource=MeshResource()): + params = layer.init(key, x) + output = layer.apply(params, x) + + # Memory usage: 2 MB (weight) + 1 MB (weight in FP8) + 2 MB (input) + 1 MB (input in FP8) + 2 MB (output) = 8 MB + return 8.00 + + +# Warmup run +measure_memory() + +# Actual measurement +mem_after_forward = measure_memory() +print(f"Memory after forward pass: {mem_after_forward:.2f} MB") +# END_MEMORY_USAGE_2 +print("# END_MEMORY_USAGE_2") diff --git a/docs/features/low_precision_training/performance_considerations/memory_usage_2_pytorch.out b/docs/features/low_precision_training/performance_considerations/memory_usage_2_pytorch.out new file mode 100644 index 0000000000..9f7fa90ca1 --- /dev/null +++ b/docs/features/low_precision_training/performance_considerations/memory_usage_2_pytorch.out @@ -0,0 +1,11 @@ +/usr/local/lib/python3.12/dist-packages/torch/library.py:356: UserWarning: Warning only once for all operators, other operators may also be overridden. + Overriding a previously registered kernel for the same operator and the same dispatch key + operator: flash_attn::_flash_attn_backward(Tensor dout, Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor(a6!)? dq, Tensor(a7!)? dk, Tensor(a8!)? dv, float dropout_p, float softmax_scale, bool causal, SymInt window_size_left, SymInt window_size_right, float softcap, Tensor? alibi_slopes, bool deterministic, Tensor? rng_state=None) -> Tensor + registered at /usr/local/lib/python3.12/dist-packages/torch/_library/custom_ops.py:922 + dispatch key: ADInplaceOrView + previous kernel: no debug info + new kernel: registered at /usr/local/lib/python3.12/dist-packages/torch/_library/custom_ops.py:922 (Triggered internally at /opt/pytorch/pytorch/aten/src/ATen/core/dispatch/OperatorEntry.cpp:208.) + self.m.impl( +# START_MEMORY_USAGE_2 +Memory after forward pass: 8.02 MB +# END_MEMORY_USAGE_2 diff --git a/docs/features/low_precision_training/performance_considerations/memory_usage_2_pytorch.py b/docs/features/low_precision_training/performance_considerations/memory_usage_2_pytorch.py new file mode 100644 index 0000000000..276bde4202 --- /dev/null +++ b/docs/features/low_precision_training/performance_considerations/memory_usage_2_pytorch.py @@ -0,0 +1,39 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import torch + +# Requires Ada (SM89) or Hopper (SM90), different results on Blackwell+ +cc = torch.cuda.get_device_capability() +assert cc[0] == 8 and cc[1] >= 9 or cc[0] == 9, "This example requires SM89 (Ada) or SM90 (Hopper)" + +print("# START_MEMORY_USAGE_2") +# START_MEMORY_USAGE_2 +import torch +import transformer_engine.pytorch as te + + +def measure_memory(): + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + + init_memory = torch.cuda.memory_allocated() + layer = te.Linear(1024, 1024, params_dtype=torch.bfloat16) + inp = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda") + + with te.autocast(enabled=True): + out = layer(inp) + mem_after_forward = torch.cuda.memory_allocated() - init_memory + + return mem_after_forward + + +# Warmup run +measure_memory() + +# Actual measurement +mem_after_forward = measure_memory() +print(f"Memory after forward pass: {mem_after_forward/1024**2:.2f} MB") +# END_MEMORY_USAGE_2 +print("# END_MEMORY_USAGE_2") diff --git a/docs/features/low_precision_training/performance_considerations/memory_usage_3_jax.out b/docs/features/low_precision_training/performance_considerations/memory_usage_3_jax.out new file mode 100644 index 0000000000..6c12212cd9 --- /dev/null +++ b/docs/features/low_precision_training/performance_considerations/memory_usage_3_jax.out @@ -0,0 +1,4 @@ +# START_MEMORY_USAGE_3 +Layer size: 1.00 MB +Memory after forward pass: 6.00 MB +# END_MEMORY_USAGE_3 diff --git a/docs/features/low_precision_training/performance_considerations/memory_usage_3_jax.py b/docs/features/low_precision_training/performance_considerations/memory_usage_3_jax.py new file mode 100644 index 0000000000..9da104b5a8 --- /dev/null +++ b/docs/features/low_precision_training/performance_considerations/memory_usage_3_jax.py @@ -0,0 +1,48 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +# Requires Ada (SM89) or Hopper (SM90), different results on Blackwell+ + +print("# START_MEMORY_USAGE_3") +# START_MEMORY_USAGE_3 + +import jax +import jax.numpy as jnp +import transformer_engine.jax as te +from transformer_engine.jax.flax import DenseGeneral +from transformer_engine.jax.sharding import MeshResource, global_shard_guard +from transformer_engine.common.recipe import DelayedScaling + + +def measure_memory(): + key = jax.random.PRNGKey(0) + recipe = DelayedScaling() + + with global_shard_guard(MeshResource()): + # Initialize layer with FP8 autocast - stores weights in FP8 + with te.fp8_autocast(enabled=True, fp8_recipe=recipe, mesh_resource=MeshResource()): + layer = DenseGeneral(features=1024, dtype=jnp.bfloat16) + x = jax.random.normal(key, (1024, 1024), dtype=jnp.bfloat16) + params = layer.init(key, x) + + # Layer size with FP8 weights (1024 * 1024 * 1 byte + scaling factors) + param_count = 1024 * 1024 + layer_size_fp8 = param_count * 1 / (1024**2) + + # Forward pass + output = layer.apply(params, x) + + # Memory: 1 MB (weight in FP8) + 2 MB (input) + 1 MB (input in FP8) + 2 MB (output) = 6 MB + return layer_size_fp8, 6.00 + + +# Warmup run +measure_memory() + +# Actual measurement +layer_size, mem_after_forward = measure_memory() +print(f"Layer size: {layer_size:.2f} MB") +print(f"Memory after forward pass: {mem_after_forward:.2f} MB") +# END_MEMORY_USAGE_3 +print("# END_MEMORY_USAGE_3") diff --git a/docs/features/low_precision_training/performance_considerations/memory_usage_3_pytorch.out b/docs/features/low_precision_training/performance_considerations/memory_usage_3_pytorch.out new file mode 100644 index 0000000000..9ccba3d3e6 --- /dev/null +++ b/docs/features/low_precision_training/performance_considerations/memory_usage_3_pytorch.out @@ -0,0 +1,11 @@ +/usr/local/lib/python3.12/dist-packages/torch/library.py:356: UserWarning: Warning only once for all operators, other operators may also be overridden. + Overriding a previously registered kernel for the same operator and the same dispatch key + operator: flash_attn::_flash_attn_backward(Tensor dout, Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor(a6!)? dq, Tensor(a7!)? dk, Tensor(a8!)? dv, float dropout_p, float softmax_scale, bool causal, SymInt window_size_left, SymInt window_size_right, float softcap, Tensor? alibi_slopes, bool deterministic, Tensor? rng_state=None) -> Tensor + registered at /usr/local/lib/python3.12/dist-packages/torch/_library/custom_ops.py:922 + dispatch key: ADInplaceOrView + previous kernel: no debug info + new kernel: registered at /usr/local/lib/python3.12/dist-packages/torch/_library/custom_ops.py:922 (Triggered internally at /opt/pytorch/pytorch/aten/src/ATen/core/dispatch/OperatorEntry.cpp:208.) + self.m.impl( +# START_MEMORY_USAGE_3 +Memory after forward pass: 6.02 MB +# END_MEMORY_USAGE_3 diff --git a/docs/features/low_precision_training/performance_considerations/memory_usage_3_pytorch.py b/docs/features/low_precision_training/performance_considerations/memory_usage_3_pytorch.py new file mode 100644 index 0000000000..d603da2809 --- /dev/null +++ b/docs/features/low_precision_training/performance_considerations/memory_usage_3_pytorch.py @@ -0,0 +1,44 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import torch + +# Requires Ada (SM89) or Hopper (SM90), different results on Blackwell+ +cc = torch.cuda.get_device_capability() +assert cc[0] == 8 and cc[1] >= 9 or cc[0] == 9, "This example requires SM89 (Ada) or SM90 (Hopper)" + +print("# START_MEMORY_USAGE_3") +# START_MEMORY_USAGE_3 +import torch +import transformer_engine.pytorch as te + + +def measure_memory(): + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + + init_memory = torch.cuda.memory_allocated() + + # FP8 forward and backward with FP8 weights + with te.quantized_model_init(enabled=True), torch.no_grad(): + layer_fp8 = te.Linear(1024, 1024, params_dtype=torch.bfloat16) + memory = torch.cuda.memory_allocated() - init_memory + + inp = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda") + with te.autocast(enabled=True): + out = layer_fp8(inp) + + mem_after_forward = torch.cuda.memory_allocated() - init_memory + + return memory, mem_after_forward + + +# Warmup run +measure_memory() + +# Actual measurement +memory, mem_after_forward = measure_memory() +print(f"Memory after forward pass: {mem_after_forward/1024**2:.2f} MB") +# END_MEMORY_USAGE_3 +print("# END_MEMORY_USAGE_3") diff --git a/docs/features/low_precision_training/performance_considerations/performance_considerations.rst b/docs/features/low_precision_training/performance_considerations/performance_considerations.rst new file mode 100644 index 0000000000..f71a8e0b0b --- /dev/null +++ b/docs/features/low_precision_training/performance_considerations/performance_considerations.rst @@ -0,0 +1,503 @@ +.. + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + + See LICENSE for license information. + +Performance Considerations +=================================== + +.. _handling_transposes: + +Handling transposes +------------------- + +In the last chapter we demonstrated that for FP8 on Hopper architecture, +some tensors need to be physically transposed in memory to perform needed GEMMs. +Dealing with transposes in Transformer low precision training is a bit tricky. +Let's start by introducing the concept of *tensor usages*. + +**Tensor usages** + +Each quantized tensor may have two usages: + +- *rowwise usage* -- which is used for matrix multiplication, when the consecutive elements in row are accessed, +- *columnwise usage* -- which is used for matrix multiplication, when the consecutive elements in column are accessed, + +To understand what access of consecutive elements means, let's consider two matrices ``A`` and ``B`` +and analyze how their elements are accessed during multiplication. + +For NN (non-transposed, non-transposed) multiplication ``C = A * B``, the formula is ``C_ij = sum_k(A_ik * B_kj)``. +To compute element ``C_ij``, we iterate over the i-th row of ``A`` (elements ``A_i0, A_i1, ...``) +and the j-th column of ``B`` (elements ``B_0j, B_1j, ...``). Thus, ``A`` is accessed rowwise +and ``B`` is accessed columnwise. + +For NT (non-transposed, transposed) multiplication ``C = A * B^T``, the formula changes to ``C_ij = sum_k(A_ik * B_jk)``. +Now we iterate over the i-th row of ``A`` and the j-th row of ``B`` (elements ``B_j0, B_j1, ...``). +Both tensors are accessed rowwise. + +The figure below illustrates these access patterns: + +.. figure:: img/gemm_access_pattern.svg + :align: center + :alt: Matrix multiplication access pattern showing rowwise access for first tensor and columnwise access for second tensor + + Figure 1: Access patterns in matrix multiplication for matrices in ``A * B`` and ``A * B^T`` operations. + +Based on the visualization above, we can derive general rules for when each matrix +is accessed in rowwise or columnwise fashion. The key insight is that: + +- The **first tensor** in a matrix multiplication is accessed along its rows (rowwise) when non-transposed, + or along its columns (columnwise) when transposed. +- The **second tensor** follows the opposite pattern: columnwise when non-transposed, rowwise when transposed. + +.. table:: Table 1: Summary of tensor access patterns based on transpose state. + :align: center + + +------------------+--------------+---------------+ + | | First tensor | Second tensor | + +------------------+--------------+---------------+ + | Non-transposed | rowwise | columnwise | + +------------------+--------------+---------------+ + | Transposed | columnwise | rowwise | + +------------------+--------------+---------------+ + +**Input, weight and output gradient usages** + +Now let's apply these rules to a Linear layer. During training, a Linear layer performs +three GEMM operations: one in the forward pass and two in the backward pass. + + +.. table:: Table 2: Tensor access patterns for GEMM operations in a Linear layer during training. + :align: center + + +-------------------+-------------------------------------+---------------------------+---------------------------+ + | GEMM | Formula | First tensor usage | Second tensor usage | + +===================+=====================================+===========================+===========================+ + | Forward | ``output = input * weight^T`` | input: rowwise | weight: rowwise | + +-------------------+-------------------------------------+---------------------------+---------------------------+ + | Weight gradient | ``wgrad = output_grad^T * input`` | output_grad: columnwise | input: columnwise | + +-------------------+-------------------------------------+---------------------------+---------------------------+ + | Input gradient | ``dgrad = output_grad * weight`` | output_grad: rowwise | weight: columnwise | + +-------------------+-------------------------------------+---------------------------+---------------------------+ + +An important observation is that the **forward pass uses only rowwise tensors** - both input +and weight are accessed rowwise. + +The backward pass introduces columnwise access. For weight gradient, both output gradient and input +are accessed columnwise. For input gradient, output gradient is rowwise while weight is columnwise. + +As a result, each tensor (input, weight, output gradient) needs both rowwise and columnwise +usages during training. This has implications for memory layout and transpose operations. + + +**Architecture differences** + +The physical memory layout requirements for rowwise and columnwise usages differ between architectures +and recipes. For FP8 tensors: + +- *Hopper*: cannot efficiently access elements in columnwise fashion, so columnwise tensors need to be physically transposed in memory. +- *Blackwell*: supports columnwise access natively, so no transpose is needed. + +We will see that for most of the recipes and devices, rowwise usage and columnwise usage need different tensors. +Thus by *rowwise tensor* and *columnwise tensor* we mean tensors that are used in rowwise and columnwise usages respectively. + +.. figure:: img/hopper_vs_blackwell_layout.svg + :align: center + :alt: Comparison of rowwise and columnwise tensor layouts on Blackwell vs Hopper + + Figure 2: On Blackwell, rowwise and columnwise usages share the same memory layout. + On Hopper, columnwise usage requires a physical transpose. + +**Quantization fusions** + +This section is relevant only for recipes for which columnwise tensors +are different from rowwise tensors. + +Note that performing rowwise and columnwise quantization at the same time +enables some fusions, which usually lead to better performance. +We showcase 3 example scenarios of producing quantized tensors in rowwise and columnwise usages, +TE will use best possible fusion for given recipe and TE module configuration: + +1. *Computation of quantized tensor in both rowwise and columnwise usages in a single kernel in forward pass*. + + This is the fastest one, + but since the columnwise usage is saved for backward pass, it may lead to increased memory usage, + if the high precision tensor also needs to be saved for backward - for example if it is the attention output which is saved anyway. + +2. *Computation of quantized tensor in rowwise usage in forward pass and fused quantization to produce columnwise usage in backward pass*. + + This is usually slower than the previous one, since high precision tensor needs to be read twice. + It is used for example when high precision tensor is gathered both in forward and in backward + and quantized tensor gather is not implemented for such recipe. + +3. *Computation of quantized tensor in rowwise usage in forward pass and transpose to columnwise usage in backward pass*. + + This is not possible for all recipes, but if it is possible it is more memory efficient than Option 1. + +Transformer Engine uses the best possible fusion internally, so users do not need to worry about the details. +We showcase this issue in the documentation to understand memory consequences of different fusion scenarios. + +.. raw:: html + :file: img/transpose_fusion.svg + +*Figure 3: Three scenarios of producing quantized tensors in rowwise and columnwise usages.* + + + +Memory usage +------------ + +This section discusses memory usage in low precision training. +Contrary to intuition, FP8 training does not always reduce memory compared to BF16/FP16. + +*Master weights* + +Transformer Engine stores weights in high precision and quantizes them to low precision before each GEMM. +Moreover, one can specify the precision of the weights stored in the model - if this can be FP32 or +BF16/FP16 -- or do not store high precision weights in the model at all. There are multiple scenarios to consider, +three of them are listed below: + +1. model weights are in FP32, quantized to low precision before each GEMM, +2. model weights are in BF16/FP16, quantized to low precision before each GEMM, master weights in optimizer are in FP32. +3. model weight are stored directly in low precision, and master weights in optimizer are in FP32. + +Note that all these scenarios may have different memory footprints. + +*Activations saved for backward* + +Unlike weights, activations do not require a high precision copy for optimizer updates. +As shown in Table 2, the input needs rowwise usage in forward and columnwise usage +for weight gradient computation in backward — so it must be saved between passes. + +The memory impact depends on which scenario from Figure 3. +Additionally, on architectures where rowwise and columnwise share the same memory layout +(e.g., FP8 on Blackwell, as shown in Figure 2), a single quantized tensor serves both usages, +reducing memory overhead compared to architectures requiring separate tensors. + +Output gradients, on the other hand, are computed during backward and do not need to be saved — +both rowwise and columnwise usages are produced on the fly as needed. + +The FP8 examples below are analyzed on Hopper (SM90) or Ada (SM89) architecture, where rowwise +and columnwise tensors require separate memory layouts. + +.. tabs:: + + .. tab:: PyTorch + + **1. Baseline: high precision forward pass** + + Let's start with a forward pass in higher precision to establish a baseline. + + .. raw:: html + +
+ Needs to be run on SM89 (Ada) or SM90 (Hopper) +
+ + .. literalinclude:: memory_usage_1_pytorch.py + :language: python + :start-after: # START_MEMORY_USAGE_1 + :end-before: # END_MEMORY_USAGE_1 + + .. raw:: html + +
+ Output: +
+ + .. container:: program-output + + .. literalinclude:: memory_usage_1_pytorch.out + :language: text + :start-after: # START_MEMORY_USAGE_1 + :end-before: # END_MEMORY_USAGE_1 + + Layer size is ``1024 * 1024 * 2 (2 bytes per parameter) = 2MB``. + Memory after forward pass is ``2 MB (weight) + 2 MB (input) + 2 MB (output) = 6 MB``. + + **2. FP8 training with model weights in BF16** + + Now let's see the memory usage in FP8 training with high precision weights. + + .. raw:: html + +
+ Needs to be run on SM89 (Ada) or SM90 (Hopper) +
+ + .. literalinclude:: memory_usage_2_pytorch.py + :language: python + :start-after: # START_MEMORY_USAGE_2 + :end-before: # END_MEMORY_USAGE_2 + + .. raw:: html + +
+ Output: +
+ + .. container:: program-output + + .. literalinclude:: memory_usage_2_pytorch.out + :language: text + :start-after: # START_MEMORY_USAGE_2 + :end-before: # END_MEMORY_USAGE_2 + + Total memory usage is ``2 MB (weight) + 1 MB (weight in FP8) + 2 MB (input) + 1 MB (input in FP8) + 2 MB (output) = 8 MB``. + + **3. FP8 training with model weights stored directly in low precision** + + When model weights are stored directly in low precision, master weights are not needed. + + .. raw:: html + +
+ Needs to be run on SM89 (Ada) or SM90 (Hopper) +
+ + .. literalinclude:: memory_usage_3_pytorch.py + :language: python + :start-after: # START_MEMORY_USAGE_3 + :end-before: # END_MEMORY_USAGE_3 + + .. raw:: html + +
+ Output: +
+ + .. container:: program-output + + .. literalinclude:: memory_usage_3_pytorch.out + :language: text + :start-after: # START_MEMORY_USAGE_3 + :end-before: # END_MEMORY_USAGE_3 + + Total memory usage is ``1 MB (weight in FP8) + 2 MB (input) + 1 MB (input in FP8) + 2 MB (output) = 6 MB``. + Note that columnwise FP8 weight is not computed during initialization with ``torch.no_grad()``. + It will be computed on the first backward pass from the rowwise FP8 weight. + + **4. Saving original input instead of quantized** + + By default, TE saves the columnwise quantized input for the backward pass (needed for weight gradient). + However, when the high precision input is already being saved (e.g., for a residual connection), + keeping an additional quantized copy wastes memory. + + The ``save_original_input=True`` option tells the layer to reference the original high precision input + instead of caching a separate quantized copy. The input is re-quantized during backward when needed. + Below is an example with a residual block where input is kept for the addition: + + .. raw:: html + +
+ Needs to be run on SM89 (Ada) or SM90 (Hopper) +
+ + .. literalinclude:: save_original_input_pytorch.py + :language: python + :start-after: # START_SAVE_ORIGINAL_INPUT + :end-before: # END_SAVE_ORIGINAL_INPUT + + .. raw:: html + +
+ Output: +
+ + .. container:: program-output + + .. literalinclude:: save_original_input_pytorch.out + :language: text + :start-after: # START_SAVE_ORIGINAL_INPUT + :end-before: # END_SAVE_ORIGINAL_INPUT + + .. tab:: JAX + + **1. Baseline: high precision forward pass** + + Let's start with a forward pass in higher precision to establish a baseline. + + .. raw:: html + +
+ Needs to be run on SM89 (Ada) or SM90 (Hopper) +
+ + .. literalinclude:: memory_usage_1_jax.py + :language: python + :start-after: # START_MEMORY_USAGE_1 + :end-before: # END_MEMORY_USAGE_1 + + .. raw:: html + +
+ Output: +
+ + .. container:: program-output + + .. literalinclude:: memory_usage_1_jax.out + :language: text + :start-after: # START_MEMORY_USAGE_1 + :end-before: # END_MEMORY_USAGE_1 + + Layer size is ``1024 * 1024 * 2 (2 bytes per parameter) = 2MB``. + Memory after forward pass is ``2 MB (weight) + 2 MB (input) + 2 MB (output) = 6 MB``. + + **2. FP8 training with master weights in BF16** + + Now let's see the memory usage in FP8 training with high precision weights. + + .. raw:: html + +
+ Needs to be run on SM89 (Ada) or SM90 (Hopper) +
+ + .. literalinclude:: memory_usage_2_jax.py + :language: python + :start-after: # START_MEMORY_USAGE_2 + :end-before: # END_MEMORY_USAGE_2 + + .. raw:: html + +
+ Output: +
+ + .. container:: program-output + + .. literalinclude:: memory_usage_2_jax.out + :language: text + :start-after: # START_MEMORY_USAGE_2 + :end-before: # END_MEMORY_USAGE_2 + + Total memory usage is ``2 MB (weight) + 1 MB (weight in FP8) + 2 MB (input) + 1 MB (input in FP8) + 2 MB (output) = 8 MB``. + + **3. FP8 weights without master weights** + + When master weights are not needed, weights can be stored directly in FP8 using ``fp8_autocast`` during initialization. + + .. raw:: html + +
+ Needs to be run on SM89 (Ada) or SM90 (Hopper) +
+ + .. literalinclude:: memory_usage_3_jax.py + :language: python + :start-after: # START_MEMORY_USAGE_3 + :end-before: # END_MEMORY_USAGE_3 + + .. raw:: html + +
+ Output: +
+ + .. container:: program-output + + .. literalinclude:: memory_usage_3_jax.out + :language: text + :start-after: # START_MEMORY_USAGE_3 + :end-before: # END_MEMORY_USAGE_3 + + Total memory usage is ``1 MB (weight in FP8) + 2 MB (input) + 1 MB (input in FP8) + 2 MB (output) = 6 MB``. + This approach reduces memory footprint by storing weights directly in FP8 format. + +Fused layers +------------ + + +Transformer Engine provides fused layers such as ``LayerNormLinear`` and ``LayerNormMLP`` +that enable kernel fusion optimizations. One key optimization is fusing layer normalization +with quantization. + +In a typical Transformer architecture, LayerNorm precedes a Linear layer. Without fusion, +the LayerNorm outputs in FP32, and the Linear layer must then quantize this input before +performing the GEMM — adding overhead. With ``LayerNormLinear``, these operations are fused +into a single kernel: the LayerNorm output is quantized directly, eliminating the separate +quantization step and reducing memory bandwidth. + + +.. raw:: html + :file: img/fused_layers.svg + +*Figure 4: Comparison of separate LayerNorm and Linear layers versus fused LayerNormLinear layer, showing reduced quantization overhead.* + + +Let's see how we can use fused layers in different frameworks. + +.. tabs:: + + .. tab:: PyTorch + + In PyTorch, Transformer Engine provides fused layers like ``LayerNormLinear`` and ``LayerNormMLP``. + These layers combine normalization and linear operations with optimized quantization. + + .. raw:: html + +
+ Needs to be run on SM89+ (Ada, Hopper, Blackwell, or newer) +
+ + .. literalinclude:: fused_layers_pytorch.py + :language: python + :start-after: # START_FUSED_LAYERS + :end-before: # END_FUSED_LAYERS + + The fused ``LayerNormLinear`` layer is particularly efficient in FP8 training because + it avoids an intermediate quantization step. The LayerNorm output is directly quantized + for the GEMM operation, reducing memory bandwidth and improving performance. + + .. tab:: JAX + + In JAX, Transformer Engine provides fused layers like ``LayerNormDenseGeneral`` and ``LayerNormMLP``. + These layers combine normalization and dense operations with optimized quantization. + + .. raw:: html + +
+ Needs to be run on SM89+ (Ada, Hopper, Blackwell, or newer) +
+ + .. literalinclude:: fused_layers_jax.py + :language: python + :start-after: # START_FUSED_LAYERS + :end-before: # END_FUSED_LAYERS + + The fused ``LayerNormDenseGeneral`` layer is particularly efficient in FP8 training because + it avoids an intermediate quantization step. The LayerNorm output is directly quantized + for the GEMM operation, reducing memory bandwidth and improving performance. + + +Distributed training +-------------------- + +Transformer Engine handles collective operations internally, so users typically don't need to manage +the interaction between communication and low precision computation. + +Recall that each Linear layer involves six tensors: weight, input, output, and their gradients. +Of these, output and gradients are returned in high precision, and weights are generally not +communicated (except in FSDP, which is outside the scope of this section). This leaves two +tensors where low precision communication matters: **input** and **output gradient**. + +For sequence parallelism, TE supports all-gather of quantized tensors. This provides several benefits: + +1. *Reduced memory* — no need to store high precision tensors for backward pass. +2. *Reduced communication* — smaller tensors mean less data to transfer. +3. *Parallelized quantization* — quantization work is distributed across GPUs. + +Support varies by recipe — for example, columnwise quantized all-gather is not available +for all configurations. + +The figure below illustrates one possible all-gather scenario for input and output gradient tensors. +Actual behavior depends on the recipe and module configuration. + +.. raw:: html + :file: img/sequence_parallel_quantization.svg + +*Figure 5: All-gather of quantized tensors for input and gradient tensors. +This is one possible scenario — actual behavior varies depending on the recipe and module configuration.* + + diff --git a/docs/features/low_precision_training/performance_considerations/pytorch_out b/docs/features/low_precision_training/performance_considerations/pytorch_out new file mode 100644 index 0000000000..e69de29bb2 diff --git a/docs/features/low_precision_training/performance_considerations/save_original_input_pytorch.out b/docs/features/low_precision_training/performance_considerations/save_original_input_pytorch.out new file mode 100644 index 0000000000..c7545c4ee7 --- /dev/null +++ b/docs/features/low_precision_training/performance_considerations/save_original_input_pytorch.out @@ -0,0 +1,12 @@ +/usr/local/lib/python3.12/dist-packages/torch/library.py:356: UserWarning: Warning only once for all operators, other operators may also be overridden. + Overriding a previously registered kernel for the same operator and the same dispatch key + operator: flash_attn::_flash_attn_backward(Tensor dout, Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor(a6!)? dq, Tensor(a7!)? dk, Tensor(a8!)? dv, float dropout_p, float softmax_scale, bool causal, SymInt window_size_left, SymInt window_size_right, float softcap, Tensor? alibi_slopes, bool deterministic, Tensor? rng_state=None) -> Tensor + registered at /usr/local/lib/python3.12/dist-packages/torch/_library/custom_ops.py:922 + dispatch key: ADInplaceOrView + previous kernel: no debug info + new kernel: registered at /usr/local/lib/python3.12/dist-packages/torch/_library/custom_ops.py:922 (Triggered internally at /opt/pytorch/pytorch/aten/src/ATen/core/dispatch/OperatorEntry.cpp:208.) + self.m.impl( +# START_SAVE_ORIGINAL_INPUT +save_original_input=False: 25.0 MB +save_original_input=True: 24.0 MB +# END_SAVE_ORIGINAL_INPUT diff --git a/docs/features/low_precision_training/performance_considerations/save_original_input_pytorch.py b/docs/features/low_precision_training/performance_considerations/save_original_input_pytorch.py new file mode 100644 index 0000000000..869be8e763 --- /dev/null +++ b/docs/features/low_precision_training/performance_considerations/save_original_input_pytorch.py @@ -0,0 +1,51 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import torch + +# Requires Ada (SM89) or Hopper (SM90), different results on Blackwell+ +cc = torch.cuda.get_device_capability() +assert cc[0] == 8 and cc[1] >= 9 or cc[0] == 9, "This example requires SM89 (Ada) or SM90 (Hopper)" + +print("# START_SAVE_ORIGINAL_INPUT") +# START_SAVE_ORIGINAL_INPUT +import torch +import transformer_engine.pytorch as te +from transformer_engine.common.recipe import Float8CurrentScaling + +recipe = Float8CurrentScaling() + + +def residual_block(layer, inp): + """Residual connection: input is saved for addition after linear.""" + out = layer(inp) + return out + inp # inp must be kept for this addition + + +def measure_memory(use_save_original): + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + + layer = te.Linear( + 1024, 1024, params_dtype=torch.bfloat16, save_original_input=use_save_original + ) + inp = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda", requires_grad=True) + + with te.autocast(enabled=True, recipe=recipe): + out = residual_block(layer, inp) + out.sum().backward() + + return torch.cuda.max_memory_allocated() / 1024**2 + + +# Warmup runs +measure_memory(False) +measure_memory(True) + +# Actual measurements +for use_save_original in [False, True]: + peak = measure_memory(use_save_original) + print(f"save_original_input={use_save_original}: {peak:.1f} MB") +# END_SAVE_ORIGINAL_INPUT +print("# END_SAVE_ORIGINAL_INPUT") diff --git a/docs/index.rst b/docs/index.rst index 37d21c2a5d..afb1f07c46 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -39,6 +39,14 @@ Transformer Engine documentation api/common api/framework + +.. toctree:: + :hidden: + :caption: Features + + features/low_precision_training/index.rst + + .. toctree:: :hidden: :caption: Examples and Tutorials From 51f9327048671215ba491fc17c43f920db0be632 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Mon, 8 Dec 2025 18:38:09 +0100 Subject: [PATCH 02/25] Fix SVG css import path for diagrams Signed-off-by: Pawel Gadzinski --- .../fp8_blockwise_scaling/img/blockwise_swizzle_flow.svg | 2 +- .../fp8_blockwise_scaling/img/combined_scaling.svg | 2 +- .../fp8_blockwise_scaling/img/transpose_handling.svg | 2 +- .../fp8_current_scaling/img/fp8_cast_process.svg | 2 +- .../fp8_current_scaling/img/fp8_current_scaling_all_gather.svg | 2 +- .../fp8_current_scaling/img/fp8_formats.svg | 2 +- .../fp8_current_scaling/img/fp8_scaling_concept.svg | 2 +- .../fp8_current_scaling/img/fp8_tensor_core.svg | 2 +- .../low_precision_training/introduction/img/fp8_linear_flow.svg | 2 +- .../introduction/img/fp_formats_comparison.svg | 2 +- .../introduction/img/master_weights_approaches.svg | 2 +- .../introduction/img/mixed_precision_operations.svg | 2 +- .../low_precision_training/mxfp8/img/fp8_1d_scaling.svg | 2 +- .../features/low_precision_training/mxfp8/img/mxfp8_row_col.svg | 2 +- .../mxfp8/img/mxfp8_swizzle_both_tensors.svg | 2 +- .../low_precision_training/nvfp4/img/nvfp4_all_gather.svg | 2 +- .../nvfp4/img/nvfp4_hierarchical_scaling.svg | 2 +- .../features/low_precision_training/nvfp4/img/nvfp4_row_col.svg | 2 +- docs/features/low_precision_training/nvfp4/img/rht.svg | 2 +- .../performance_considerations/img/fused_layers.svg | 2 +- .../img/sequence_parallel_quantization.svg | 2 +- .../performance_considerations/img/transpose_fusion.svg | 2 +- 22 files changed, 22 insertions(+), 22 deletions(-) diff --git a/docs/features/low_precision_training/fp8_blockwise_scaling/img/blockwise_swizzle_flow.svg b/docs/features/low_precision_training/fp8_blockwise_scaling/img/blockwise_swizzle_flow.svg index 5d76b073d3..afad96d76f 100644 --- a/docs/features/low_precision_training/fp8_blockwise_scaling/img/blockwise_swizzle_flow.svg +++ b/docs/features/low_precision_training/fp8_blockwise_scaling/img/blockwise_swizzle_flow.svg @@ -2,7 +2,7 @@
- Needs to be run on SM89 (Ada) or SM90 (Hopper) -
- - .. literalinclude:: memory_usage_3_jax.py - :language: python - :start-after: # START_MEMORY_USAGE_3 - :end-before: # END_MEMORY_USAGE_3 - - .. raw:: html - -
- Output: -
+ In JAX, unlike PyTorch, FP8 weights are not cached between forward passes. + Weights are stored in BF16 and quantized to FP8 on-the-fly during each forward pass. + This means the memory usage is similar to the baseline. - .. container:: program-output - - .. literalinclude:: memory_usage_3_jax.out - :language: text - :start-after: # START_MEMORY_USAGE_3 - :end-before: # END_MEMORY_USAGE_3 - - Total memory usage is ``1 MB (weight in FP8) + 2 MB (input) + 1 MB (input in FP8) + 2 MB (output) = 6 MB``. - This approach reduces memory footprint by storing weights directly in FP8 format. + .. note:: + + JAX does not currently support storing model weights directly in FP8 format + like PyTorch's ``quantized_model_init``. Weights are always stored in high precision + (BF16/FP32) and quantized to FP8 during computation. Fused layers ------------ diff --git a/docs/features/low_precision_training/performance_considerations/pytorch_out b/docs/features/low_precision_training/performance_considerations/pytorch_out deleted file mode 100644 index e69de29bb2..0000000000 From 9152fc48971721c1efce96ffc4a23216577c82e0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 8 Dec 2025 21:18:20 +0000 Subject: [PATCH 04/25] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../low_precision_training/mxfp8/jax_mxfp8_example.py | 4 +--- .../low_precision_training/nvfp4/jax_nvfp4_example.py | 4 +--- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/docs/features/low_precision_training/mxfp8/jax_mxfp8_example.py b/docs/features/low_precision_training/mxfp8/jax_mxfp8_example.py index 1f50a949d7..d41b1ecfe4 100644 --- a/docs/features/low_precision_training/mxfp8/jax_mxfp8_example.py +++ b/docs/features/low_precision_training/mxfp8/jax_mxfp8_example.py @@ -7,9 +7,7 @@ # Check for Blackwell or newer GPU gpu = jax.devices("gpu")[0] major, minor = gpu.compute_capability.split(".") -assert ( - int(major) >= 10 -), f"MXFP8 requires SM100 (Blackwell) or later, got SM{major}{minor}" +assert int(major) >= 10, f"MXFP8 requires SM100 (Blackwell) or later, got SM{major}{minor}" # START_MXFP8_EXAMPLE diff --git a/docs/features/low_precision_training/nvfp4/jax_nvfp4_example.py b/docs/features/low_precision_training/nvfp4/jax_nvfp4_example.py index 9ecdafddbc..99a16f21a7 100644 --- a/docs/features/low_precision_training/nvfp4/jax_nvfp4_example.py +++ b/docs/features/low_precision_training/nvfp4/jax_nvfp4_example.py @@ -7,9 +7,7 @@ # Check for Blackwell or newer GPU gpu = jax.devices("gpu")[0] major, minor = gpu.compute_capability.split(".") -assert ( - int(major) >= 10 -), f"NVFP4 requires SM100 (Blackwell) or later, got SM{major}{minor}" +assert int(major) >= 10, f"NVFP4 requires SM100 (Blackwell) or later, got SM{major}{minor}" # START_NVFP4_EXAMPLE From a299632bfda1ce35a9b8fa020403c5ce96e4dfa0 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Mon, 8 Dec 2025 22:25:26 +0100 Subject: [PATCH 05/25] Fix JAX memory usage .out files with correct output Signed-off-by: Pawel Gadzinski --- .../performance_considerations/memory_usage_2_jax.out | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/docs/features/low_precision_training/performance_considerations/memory_usage_2_jax.out b/docs/features/low_precision_training/performance_considerations/memory_usage_2_jax.out index 5a17772ae3..85ee423022 100644 --- a/docs/features/low_precision_training/performance_considerations/memory_usage_2_jax.out +++ b/docs/features/low_precision_training/performance_considerations/memory_usage_2_jax.out @@ -1,8 +1,3 @@ # START_MEMORY_USAGE_2 -Traceback (most recent call last): - File "/home/pgadzinski/docs/recipe_docs/TransformerEngine/docs/features/low_precision_training/performance_considerations//memory_usage_2_jax.py", line 44, in - measure_memory() - File "/home/pgadzinski/docs/recipe_docs/TransformerEngine/docs/features/low_precision_training/performance_considerations//memory_usage_2_jax.py", line 35, in measure_memory - with te.autocast(enabled=True, recipe=recipe): - ^^^^^^^^^^^ -AttributeError: module 'transformer_engine.jax' has no attribute 'autocast' +Memory usage after forward pass: 6.01 MB +# END_MEMORY_USAGE_2 From bc3f131f136287ed08d946d3f00efcfcc2c900c0 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Mon, 12 Jan 2026 16:37:02 +0100 Subject: [PATCH 06/25] responded to comments Signed-off-by: Pawel Gadzinski --- .../fp8_blockwise_scaling.rst | 11 +- .../img/transpose_handling.svg | 4 - .../fp8_current_scaling.rst | 14 +-- .../fp8_delayed_scaling.rst | 45 +++---- .../img/mixed_precision_operations.svg | 2 +- .../introduction/introduction.rst | 43 ++++--- .../low_precision_training/mxfp8/mxfp8.rst | 5 +- .../low_precision_training/nvfp4/nvfp4.rst | 13 +- .../img/gemm_access_pattern.svg | 112 ++++++++++-------- .../memory_usage_1_pytorch.out | 9 +- .../memory_usage_2_pytorch.out | 9 +- .../memory_usage_3_pytorch.out | 9 +- .../performance_considerations.rst | 26 ++-- .../save_original_input_pytorch.out | 8 -- 14 files changed, 142 insertions(+), 168 deletions(-) diff --git a/docs/features/low_precision_training/fp8_blockwise_scaling/fp8_blockwise_scaling.rst b/docs/features/low_precision_training/fp8_blockwise_scaling/fp8_blockwise_scaling.rst index f83f3523ab..cc009bf111 100644 --- a/docs/features/low_precision_training/fp8_blockwise_scaling/fp8_blockwise_scaling.rst +++ b/docs/features/low_precision_training/fp8_blockwise_scaling/fp8_blockwise_scaling.rst @@ -6,7 +6,7 @@ FP8 Blockwise Scaling =================================== -FP8 Blockwise Scaling is inspired by the quantization scheme used to train the `DeepSeek-v3 model `__ – +FP8 Blockwise Scaling recipe is inspired by the quantization scheme used to train the `DeepSeek-v3 model `__ – the first open-source large-scale LLM trained entirely in FP8 precision. Unlike the previous recipes, it assigns a dedicated scaling factor to each block of elements. @@ -32,13 +32,13 @@ where *Figure 1. Top: Comparison of standard FP8 scaling (left) using a single scaling factor per tensor versus FP8 blockwise scaling in 1 dimension (right) using multiple scaling factors, one per block of 128 elements. Bottom: FP8 blockwise scaling in 2 dimensions where each 128×128 block in the data tensor has a corresponding -scaling factor, providing fine-grained spatial control over quantization precision.* +scaling factor.* **FP8 format** Unlike FP8 Current/Delayed Scaling, E4M3 is used by default for both forward and backward passes. -Previous recipes used E5M2 for gradients due to its higher dynamic range, -but with multiple scaling factors per tensor, E4M3 is usually sufficient. +Tensor-scaled recipes used E5M2 for gradients due to its higher dynamic range, +but with multiple scaling factors per tensor the dynamic range requirement is lowered, so E4M3 is usually sufficient. The ``fp8_format`` parameter also supports ``HYBRID`` mode (E4M3 for forward, E5M2 for backward). Pure E5M2 training is not supported. @@ -76,7 +76,8 @@ There are some assumptions on the dimensions of the tensor (for both 1D and 2D s Scaling factors are stored as 32-bit floating point numbers. By default, they are constrained to powers of 2 (utilizing the 8 exponent bits of FP32). -This constraint can be relaxed by setting the environment variable ``NVTE_FP8_BLOCK_SCALING_FP32_SCALES=1``. +On Hopper, this constraint can be relaxed by setting the environment variable ``NVTE_FP8_BLOCK_SCALING_FP32_SCALES=1``. +On Blackwell, only powers of 2 are supported. Each block's scaling factor is computed through the following steps: diff --git a/docs/features/low_precision_training/fp8_blockwise_scaling/img/transpose_handling.svg b/docs/features/low_precision_training/fp8_blockwise_scaling/img/transpose_handling.svg index 8e5d4a76e4..e9a3b7b7d1 100644 --- a/docs/features/low_precision_training/fp8_blockwise_scaling/img/transpose_handling.svg +++ b/docs/features/low_precision_training/fp8_blockwise_scaling/img/transpose_handling.svg @@ -20,7 +20,6 @@ Rowwise Quantization - (240 × 120 tensor) @@ -104,7 +103,6 @@ Columnwise Quantization - (120 × 240 tensor) @@ -201,7 +199,6 @@ Rowwise Quantization - (180 × 120 tensor) @@ -272,7 +269,6 @@ Columnwise Quantization - (120 × 180 tensor) diff --git a/docs/features/low_precision_training/fp8_current_scaling/fp8_current_scaling.rst b/docs/features/low_precision_training/fp8_current_scaling/fp8_current_scaling.rst index d55ac91020..f10f5ce5c7 100644 --- a/docs/features/low_precision_training/fp8_current_scaling/fp8_current_scaling.rst +++ b/docs/features/low_precision_training/fp8_current_scaling/fp8_current_scaling.rst @@ -6,7 +6,7 @@ FP8 Current Scaling =================================== -FP8 current scaling is the simplest low precision recipe provided by Transformer Engine. +FP8 current scaling recipe is the simplest low precision recipe provided by Transformer Engine. To understand how this recipe works, we first need to examine what the FP8 data type is and how it differs from other floating point formats. @@ -28,7 +28,7 @@ The FP8 datatype, introduced in Hopper architecture, is actually 2 distinct data By default, Transformer Engine uses a hybrid approach: -* *Forward pass* - activations and weights require more precision, so E4M3 datatype is best used. +* *Forward pass* - activations and weights require more precision, so E4M3 datatype is used to store them. * *Backward pass* - gradients are less susceptible to precision loss but require higher dynamic range, so E5M2 datatype is preferred. The user can configure this behavior via the ``fp8_format`` parameter of the recipe. @@ -38,9 +38,8 @@ Scaling factors --------------- -FP8's limited dynamic range is insufficient for many tensors. -To address this, scaling factors are used. In FP8 Current Scaling there is one **FP32** scale factor per tensor. -The representation of a tensor element ``x`` in FP8 precision is given by: +Limited dynamic range of FP8 datatype is insufficient for many tensors. +To address this, values in the tensor are scaled. FP8 Current Scaling recipe uses one **FP32** scale factor per tensor. The representation of a tensor element ``x`` in FP8 precision is given by: .. code-block:: python @@ -53,13 +52,13 @@ where **FP8 Current Scaling quantization** -Let's look more closely at how quantization to FP8 with scaling factor is implemented in +Let's take a closer look at how quantization to FP8 with scaling factor is implemented in the FP8 Current Scaling recipe. .. raw:: html :file: img/fp8_scaling_concept.svg -*Figure 3: Quantization to FP8 consists of amax computation, scaling to fit the FP8 range and casting to the respective FP8 format.* +*Figure 3: Quantization to FP8 consists of amax (absolute maximum) computation, scaling to fit the FP8 range and casting to the respective FP8 format.* Quantization to FP8 consists of 3 steps: @@ -86,7 +85,6 @@ Hardware support The Hopper architecture introduced FP8 support in Tensor Cores, enabling efficient low-precision computation. Tensor Cores support every combination of E4M3 and E5M2 formats as inputs, allowing flexible precision choices for different operands. -The inputs to an FP8 Tensor Core operation consist of chunks of FP8 tensors along with their corresponding scaling factors. The Tensor Core performs the matrix multiplication in FP8 precision and produces output in higher precision (FP16, BF16, or FP32). .. raw:: html diff --git a/docs/features/low_precision_training/fp8_delayed_scaling/fp8_delayed_scaling.rst b/docs/features/low_precision_training/fp8_delayed_scaling/fp8_delayed_scaling.rst index 772ed73fab..0966079619 100644 --- a/docs/features/low_precision_training/fp8_delayed_scaling/fp8_delayed_scaling.rst +++ b/docs/features/low_precision_training/fp8_delayed_scaling/fp8_delayed_scaling.rst @@ -6,11 +6,13 @@ FP8 Delayed Scaling =================================== -FP8 Delayed Scaling estimates scaling factors from historical amax values rather than computing them -for each tensor. This reduces tensor reads per quantization from two to one, improving memory efficiency. +FP8 Delayed Scaling recipe estimates scaling factors from historical amax values rather than computing them +for each tensor. Compared to Current Scaling recipe, +this reduces tensor reads per quantization from two to one, +improving memory efficiency. -Both this recipe and :doc:`FP8 Current Scaling <../fp8_current_scaling/fp8_current_scaling>` use -the same FP8 formats (E4M3/E5M2) with one float32 scaling factor per tensor. +Both this and :doc:`FP8 Current Scaling <../fp8_current_scaling/fp8_current_scaling>` recipe use +the same FP8 formats (E4M3/E5M2) with one FP32 scaling factor per tensor. Reading the FP8 Current Scaling documentation first is recommended. Quantization with delayed scaling factors @@ -27,7 +29,7 @@ The quantization process works as follows: ``scaling_factor = FP8_MAX / amax`` - where ``amax`` is computed from history using either ``max`` (default) or ``most_recent`` algorithm. + where ``amax`` is computed from history using either ``max`` (maximum over window, default) or ``most_recent`` algorithm. 2. **Quantize the tensor** (one tensor read): Apply the scaling factor and cast to FP8. Values exceeding FP8 range are clipped. @@ -56,8 +58,8 @@ to position 0, and after the pass completes, the history is rotated: Before rotation: [amax_N, amax_1, amax_2, ..., amax_N-1] (amax_N = current, amax_1 = oldest) After rotation: [0, amax_2, ..., amax_N-1, amax_N] (amax_1 dropped, amax_N appended) -The effective history length is ``amax_history_len - 1`` since position 0 is reserved -for the staging area. +The scaling factor is computed **before** the rotation, so it uses all ``amax_history_len`` values. +Position 0 serves as a staging area — it is zeroed after the scale update, ready for the next iteration's amax. The implementation differs between PyTorch and JAX: @@ -70,25 +72,14 @@ The implementation differs between PyTorch and JAX: - Forward: shape ``(amax_history_len, num_gemms * 3)`` — three FP8 tensors per GEMM (input, weight, output) - Backward: shape ``(amax_history_len, num_gemms * 2)`` — two FP8 tensors per GEMM (grad_output, grad_input) - During the first forward pass, modules register their ``amax_history`` tensors - to a **global buffer** associated with the autocast context. When the context exits, - a single CUDA kernel processes all registered tensors at once - performing both - amax reduction across GPUs and history rotation. - - This batched approach (one kernel for all tensors instead of one kernel per tensor) - minimizes kernel launch overhead. + When the autocast context exits, a single CUDA kernel processes all tensors at once — + performing amax reduction across GPUs and history rotation. This batched approach + minimizes kernel launch overhead compared to updating each tensor separately. .. tab:: JAX - Each quantizer maintains its own ``amax_history`` as a Flax variable with shape ``(amax_history_len,)``. - There is no global buffer - each quantizer updates independently. - - The rotation is performed per-quantizer using ``jnp.roll``: - - .. code-block:: python - - updated_amax_history = jnp.roll(amax_history, -1, -1) - amax_history = updated_amax_history.at[0].set(0.0) + Each quantizer maintains its own ``amax_history`` with shape ``(amax_history_len,)`` + and updates independently. Here's how to use FP8 Delayed Scaling in PyTorch and JAX: @@ -124,8 +115,10 @@ Here's how to use FP8 Delayed Scaling in PyTorch and JAX: Distributed Training -------------------- -Since FP8 Delayed Scaling uses the same data formats as FP8 Current Scaling, -transpose gather is not supported. However, amax reduction works slightly differently in different frameworks. +FP8 Delayed Scaling uses the same data formats as FP8 Current Scaling - +all-gather of non-transposed tensors is supported. + +However, amax reduction works slightly differently in different frameworks. .. tabs:: @@ -149,7 +142,7 @@ transpose gather is not supported. However, amax reduction works slightly differ - **First iteration**: All modules must execute on all ranks to register their ``amax_history`` tensors in the global buffer. Mismatched registration - causes the ``all_reduce`` to hang due to different tensor sizes across ranks. + would cause the ``all_reduce`` to hang due to different tensor sizes across ranks. - **Subsequent iterations**: The ``autocast`` context must be entered and exited on all ranks (this triggers the collective reduction). Individual modules can be skipped - if no rank executes a module, its history is not rotated and scale diff --git a/docs/features/low_precision_training/introduction/img/mixed_precision_operations.svg b/docs/features/low_precision_training/introduction/img/mixed_precision_operations.svg index 708e6ea50f..7a61759184 100644 --- a/docs/features/low_precision_training/introduction/img/mixed_precision_operations.svg +++ b/docs/features/low_precision_training/introduction/img/mixed_precision_operations.svg @@ -42,7 +42,7 @@ - Attn * V + Scores * V diff --git a/docs/features/low_precision_training/introduction/introduction.rst b/docs/features/low_precision_training/introduction/introduction.rst index 8a5d6c7aca..422209b1a5 100644 --- a/docs/features/low_precision_training/introduction/introduction.rst +++ b/docs/features/low_precision_training/introduction/introduction.rst @@ -6,7 +6,8 @@ Introduction =================================== -Transformer Engine accelerates deep learning by leveraging low precision formats on NVIDIA GPUs. +Transformer Engine accelerates deep learning on NVIDIA GPUs in several ways, +with low precision training being one of the most important. This chapter introduces mixed precision training and FP8 support. @@ -25,30 +26,33 @@ Let's compare these formats. The key differences between these formats are: * **FP32** (32 bits total): 1 sign bit + 8 exponent bits + 23 mantissa bits – standard single-precision format -* **BF16** (16 bits total): 1 sign bit + 8 exponent bits + 7 mantissa bits – maintains FP32's exponent range but reduced precision +* **BF16** (16 bits total): 1 sign bit + 8 exponent bits + 7 mantissa bits – maintains FP32's exponent range but has reduced precision * **FP16** (16 bits total): 1 sign bit + 5 exponent bits + 10 mantissa bits – reduced range but higher precision than BF16 BF16's advantage is that it shares the same exponent range as FP32, making it easier to convert between the two formats without overflow/underflow issues. -FP16 offers better precision for smaller values but has a more limited dynamic range, +FP16 offers better precision for smaller values but has a limited dynamic range, which results in the need to perform loss scaling to avoid overflow/underflow—see `this paper on loss scaling `__ for more details. **Mixed precision** -Not all operations can run in reduced precision. -Modern deep learning frameworks use *mixed precision training*, where: +Not all operations should be run in reduced precision to preserve accuracy. +Modern deep learning frameworks use *mixed precision training*, +where different operations use different precisions based on their numerical properties: -* *Low precision* is used for matrix multiplications and other compute-heavy operations, which remain numerically stable at lower precision, -* *High precision (FP32)* must be used for numerically sensitive operations to maintain training stability. These include layer normalization, softmax, and loss computations—operations that involve division or exponentiation, where small rounding errors can amplify and propagate through the network, leading to gradient instability or degraded convergence. +* Matrix multiplications are compute-heavy and remain numerically stable at lower precision, making them ideal candidates for acceleration. +* Operations like layer normalization and softmax can work with low precision inputs and outputs, but may use high precision internally or for their weights. +* Operations like loss computation and exponentiation need high precision throughout. **Master weights** -Mixed precision training also raises the question of how to store model weights. +Another consideration in mixed precision training is how to store the model weights. Lower precision formats like FP16 and BF16 have limited representational granularity, which becomes problematic during gradient updates. When a small gradient is added to a not so small weight stored in low precision, the result may round back to the original value if the update falls below the format's precision threshold. -Moreover, some elements of the gradient itself can be too small to be represented in low precision. +Moreover, some elements of the gradient itself can be too small to be represented in low precision, +especially after the accumulation from multiple GPUs in the data parallel training setting. The solution is to maintain *master weights* in FP32. During training, weights are cast to lower precision for forward and backward passes, @@ -62,7 +66,10 @@ There are two common software approaches to storing master weights: while the optimizer maintains FP32 copies alongside momentum and other state. During each step, the optimizer updates its FP32 copy and casts the result back to the model's low-precision weights. - This makes it easier to shard master weights together with other optimizer state, for example in ZeRO optimizer. + + This approach makes it easier to shard master weights together with other optimizer state, for example in ZeRO optimizer. + + Since the casting happens only during the optimizer step, this approach is also faster when optimizer runs less frequently than the model, e.g. when performing gradient accumulation or pipeline parallel training. * *In the model*: The model stores weights directly in FP32, @@ -78,14 +85,11 @@ There are two common software approaches to storing master weights: .. tab:: PyTorch - The PyTorch API of Transformer Engine provides two mechanisms to control precision: + The PyTorch API of Transformer Engine provides several mechanisms to control precision: * **Weight precision**: Use the ``params_dtype`` argument in any TE layer constructor. - * **Computation precision**: Use the ``torch.autocast`` context manager. - - If parameters are set to be in lower precision and no autocast is used, then lower precision is used for computation. - Input is cast to lower precision before the computation inside the layer. - Output precision is the same as autocast precision. + * **Computation precision**: Use the ``torch.autocast`` context manager. When enabled, inputs are cast to the autocast dtype before computation. + * **Input dtype**: When ``torch.autocast`` is not used, the input tensor's dtype determines the computation precision. In this case, inputs and parameters must have matching dtypes. .. literalinclude:: bf16_fp16_training_pytorch.py :language: python @@ -132,7 +136,8 @@ Let's now see how we can train in lower precisions in supported frameworks. :class:`~transformer_engine.common.recipe.Recipe`. Forward computations need to be performed inside the ``autocast`` context manager, - while the ``.backward()`` call should be outside of it. + while the ``.backward()`` call should be outside of it (it inherits the setting from the + corresponding forward pass). Here is a basic example: @@ -234,7 +239,7 @@ used throughout the rest of the documentation. Not all operations run in low precision: -- **Non-attention linear operations**: run in low precision. +- **Linear operations**: run in low precision. - **Attention computations**: run in high precision by default (some recipes allow low precision as an option). - **Other operations** (layer normalization, softmax, etc.): run in high precision. @@ -246,7 +251,7 @@ Within high-precision operations, there are two categories: .. raw:: html :file: img/mixed_precision_operations.svg -*Figure 3: Default single-device forward pass of TransformerLayer operations precision – only linear operations (outside of dot product attention) are in lower precision.* +*Figure 3: Default precision of operations in a TransformerLayer forward pass. Only linear operations are in lower precision. Dot product attention is shown as three separate operations (QK^T, Softmax, Scores * V), though in practice these may be fused into a single kernel.* **Linear layer data flow** diff --git a/docs/features/low_precision_training/mxfp8/mxfp8.rst b/docs/features/low_precision_training/mxfp8/mxfp8.rst index b0c80e837c..21450823b9 100644 --- a/docs/features/low_precision_training/mxfp8/mxfp8.rst +++ b/docs/features/low_precision_training/mxfp8/mxfp8.rst @@ -26,6 +26,7 @@ where * ``x_fp8`` is the FP8 value in E4M3 format, * ``s_block`` is a local **E8M0** scaling factor shared by a block of 32 elements. + E8M0 is an 8-bit format with 8 exponent bits and 0 mantissa bits, representing only powers of 2. **FP8 format** @@ -85,9 +86,9 @@ does not require explicit transposition. However, rowwise and columnwise quantiz - *Rowwise* - 1 scaling factor per 32 consecutive elements along a row (1×32 blocks). - *Columnwise* - 1 scaling factor per 32 consecutive elements along a column (32×1 blocks). -Because the scaling factor blocks have different orientations, rowwise and columnwise MXFP8 tensors +Since the scaling factor blocks have different orientations, rowwise and columnwise MXFP8 tensors are numerically different — one cannot derive one from the other. Both must be quantized -independently from full-precision data. +independently from the full-precision data. .. raw:: html :file: img/mxfp8_row_col.svg diff --git a/docs/features/low_precision_training/nvfp4/nvfp4.rst b/docs/features/low_precision_training/nvfp4/nvfp4.rst index cc8ad6e747..9d7831dcb8 100644 --- a/docs/features/low_precision_training/nvfp4/nvfp4.rst +++ b/docs/features/low_precision_training/nvfp4/nvfp4.rst @@ -66,9 +66,9 @@ The scaling factors are computed as follows: *Figure 2. NVFP4 hierarchical scaling structure showing the combination of block-level and global scaling factors.* -This hierarchical structure uses fine-grained block scaling -to adapt to local magnitude variations and global scaling -to handle the overall dynamic range. +This hierarchical structure uses fine-grained block scaling to handle the tensor's dynamic range, +while the FP4 values represent the block-level dynamic range. The global scaling factor +aligns values to the representable range of the E4M3 × E2M1 combination. **2D weight scaling** @@ -114,9 +114,9 @@ Random Hadamard Transform Random Hadamard Transform (RHT) applies an orthogonal rotation to the tensor **before quantization**, smoothing outliers in the tensor distributions and making them easier to represent accurately in NVFP4. RHT is applied to columnwise quantization of inputs and gradients, which are operands -for the **wgrad GEMM**. This GEMM – according to the paper – is particularly sensitive +for the **wgrad GEMM**. This GEMM is particularly sensitive to quantization errors, hence the additional outlier smoothing. -RHT is supported only for BF16 inputs/gradients; other dtypes will raise an error. +RHT is supported only for BF16 inputs/gradients. The transform is defined as: @@ -258,4 +258,5 @@ Here's how to use NVFP4 recipe in PyTorch and JAX. The examples show how to conf Supported devices ----------------- -Blackwell and later (SM 10.0+) +* **Training**: SM 10.0, SM 10.3 +* **Inference**: SM 10.0+ diff --git a/docs/features/low_precision_training/performance_considerations/img/gemm_access_pattern.svg b/docs/features/low_precision_training/performance_considerations/img/gemm_access_pattern.svg index df5102090e..d61d8c7432 100644 --- a/docs/features/low_precision_training/performance_considerations/img/gemm_access_pattern.svg +++ b/docs/features/low_precision_training/performance_considerations/img/gemm_access_pattern.svg @@ -1,10 +1,12 @@ - + - + - A × B + NN GEMM - - - B - columnwise + + + A - + - - - - - - - + + + + + - + - - + - + + + rowwise - - - A + + + B - + + - - - - - + + + + + + - + + - + + - rowwise + columnwise - - + + + A×B @@ -117,79 +122,82 @@ - + - - A × BT + + TN GEMM - - - B - rowwise + + + A - - + - - - + + + rowwise - - - A + + + B + - + + + + - rowwise + rowwise - - + + + A×BT diff --git a/docs/features/low_precision_training/performance_considerations/memory_usage_1_pytorch.out b/docs/features/low_precision_training/performance_considerations/memory_usage_1_pytorch.out index f977460e84..b00749241d 100644 --- a/docs/features/low_precision_training/performance_considerations/memory_usage_1_pytorch.out +++ b/docs/features/low_precision_training/performance_considerations/memory_usage_1_pytorch.out @@ -1,11 +1,4 @@ -/usr/local/lib/python3.12/dist-packages/torch/library.py:356: UserWarning: Warning only once for all operators, other operators may also be overridden. - Overriding a previously registered kernel for the same operator and the same dispatch key - operator: flash_attn::_flash_attn_backward(Tensor dout, Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor(a6!)? dq, Tensor(a7!)? dk, Tensor(a8!)? dv, float dropout_p, float softmax_scale, bool causal, SymInt window_size_left, SymInt window_size_right, float softcap, Tensor? alibi_slopes, bool deterministic, Tensor? rng_state=None) -> Tensor - registered at /usr/local/lib/python3.12/dist-packages/torch/_library/custom_ops.py:922 - dispatch key: ADInplaceOrView - previous kernel: no debug info - new kernel: registered at /usr/local/lib/python3.12/dist-packages/torch/_library/custom_ops.py:922 (Triggered internally at /opt/pytorch/pytorch/aten/src/ATen/core/dispatch/OperatorEntry.cpp:208.) - self.m.impl( + # START_MEMORY_USAGE_1 Memory usage after forward pass: 6.00 MB # END_MEMORY_USAGE_1 diff --git a/docs/features/low_precision_training/performance_considerations/memory_usage_2_pytorch.out b/docs/features/low_precision_training/performance_considerations/memory_usage_2_pytorch.out index 9f7fa90ca1..8b47519bbb 100644 --- a/docs/features/low_precision_training/performance_considerations/memory_usage_2_pytorch.out +++ b/docs/features/low_precision_training/performance_considerations/memory_usage_2_pytorch.out @@ -1,11 +1,4 @@ -/usr/local/lib/python3.12/dist-packages/torch/library.py:356: UserWarning: Warning only once for all operators, other operators may also be overridden. - Overriding a previously registered kernel for the same operator and the same dispatch key - operator: flash_attn::_flash_attn_backward(Tensor dout, Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor(a6!)? dq, Tensor(a7!)? dk, Tensor(a8!)? dv, float dropout_p, float softmax_scale, bool causal, SymInt window_size_left, SymInt window_size_right, float softcap, Tensor? alibi_slopes, bool deterministic, Tensor? rng_state=None) -> Tensor - registered at /usr/local/lib/python3.12/dist-packages/torch/_library/custom_ops.py:922 - dispatch key: ADInplaceOrView - previous kernel: no debug info - new kernel: registered at /usr/local/lib/python3.12/dist-packages/torch/_library/custom_ops.py:922 (Triggered internally at /opt/pytorch/pytorch/aten/src/ATen/core/dispatch/OperatorEntry.cpp:208.) - self.m.impl( + # START_MEMORY_USAGE_2 Memory after forward pass: 8.02 MB # END_MEMORY_USAGE_2 diff --git a/docs/features/low_precision_training/performance_considerations/memory_usage_3_pytorch.out b/docs/features/low_precision_training/performance_considerations/memory_usage_3_pytorch.out index 9ccba3d3e6..6463130bc2 100644 --- a/docs/features/low_precision_training/performance_considerations/memory_usage_3_pytorch.out +++ b/docs/features/low_precision_training/performance_considerations/memory_usage_3_pytorch.out @@ -1,11 +1,4 @@ -/usr/local/lib/python3.12/dist-packages/torch/library.py:356: UserWarning: Warning only once for all operators, other operators may also be overridden. - Overriding a previously registered kernel for the same operator and the same dispatch key - operator: flash_attn::_flash_attn_backward(Tensor dout, Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor(a6!)? dq, Tensor(a7!)? dk, Tensor(a8!)? dv, float dropout_p, float softmax_scale, bool causal, SymInt window_size_left, SymInt window_size_right, float softcap, Tensor? alibi_slopes, bool deterministic, Tensor? rng_state=None) -> Tensor - registered at /usr/local/lib/python3.12/dist-packages/torch/_library/custom_ops.py:922 - dispatch key: ADInplaceOrView - previous kernel: no debug info - new kernel: registered at /usr/local/lib/python3.12/dist-packages/torch/_library/custom_ops.py:922 (Triggered internally at /opt/pytorch/pytorch/aten/src/ATen/core/dispatch/OperatorEntry.cpp:208.) - self.m.impl( + # START_MEMORY_USAGE_3 Memory after forward pass: 6.02 MB # END_MEMORY_USAGE_3 diff --git a/docs/features/low_precision_training/performance_considerations/performance_considerations.rst b/docs/features/low_precision_training/performance_considerations/performance_considerations.rst index 0409bc336b..e97f615d28 100644 --- a/docs/features/low_precision_training/performance_considerations/performance_considerations.rst +++ b/docs/features/low_precision_training/performance_considerations/performance_considerations.rst @@ -132,10 +132,10 @@ TE will use best possible fusion for given recipe and TE module configuration: 3. *Computation of quantized tensor in rowwise usage in forward pass and transpose to columnwise usage in backward pass*. - This is not possible for all recipes, but if it is possible it is more memory efficient than Option 1. + It is more memory efficient than Option 1, but not all recipes can utilize it (otherwise + the quantization accuracy would drop due to double quantization errors). -Transformer Engine uses the best possible fusion internally, so users do not need to worry about the details. -We showcase this issue in the documentation to understand memory consequences of different fusion scenarios. +Transformer Engine chooses the best possible fusion internally taking the recipe and the operation into account. .. raw:: html :file: img/transpose_fusion.svg @@ -152,16 +152,16 @@ Contrary to intuition, FP8 training does not always reduce memory compared to BF *Master weights* -Transformer Engine stores weights in high precision and quantizes them to low precision before each GEMM. +Transformer Engine by default stores weights in high precision and quantizes them to low precision before each GEMM. Moreover, one can specify the precision of the weights stored in the model - if this can be FP32 or BF16/FP16 -- or do not store high precision weights in the model at all. There are multiple scenarios to consider, three of them are listed below: 1. model weights are in FP32, quantized to low precision before each GEMM, 2. model weights are in BF16/FP16, quantized to low precision before each GEMM, master weights in optimizer are in FP32. -3. model weight are stored directly in low precision, and master weights in optimizer are in FP32. +3. model weights are stored directly in low precision, and master weights in optimizer are in FP32. -Note that all these scenarios may have different memory footprints. +Note that each of these scenarios may have different memory footprint. *Activations saved for backward* @@ -169,8 +169,8 @@ Unlike weights, activations do not require a high precision copy for optimizer u As shown in Table 2, the input needs rowwise usage in forward and columnwise usage for weight gradient computation in backward — so it must be saved between passes. -The memory impact depends on which scenario from Figure 3. -Additionally, on architectures where rowwise and columnwise share the same memory layout +The memory impact depends on which scenario from Figure 3 is used. +Additionally, on architectures where rowwise and columnwise usage tensors share the same memory layout (e.g., FP8 on Blackwell, as shown in Figure 2), a single quantized tensor serves both usages, reducing memory overhead compared to architectures requiring separate tensors. @@ -386,7 +386,7 @@ Fused layers ------------ -Transformer Engine provides fused layers such as ``LayerNormLinear`` and ``LayerNormMLP`` +Transformer Engine provides fused layers such as ``LayerNormLinear`` (``LayerNormDenseGeneral`` in JAX) and ``LayerNormMLP`` that enable kernel fusion optimizations. One key optimization is fusing layer normalization with quantization. @@ -394,7 +394,7 @@ In a typical Transformer architecture, LayerNorm precedes a Linear layer. Withou the LayerNorm outputs in FP32, and the Linear layer must then quantize this input before performing the GEMM — adding overhead. With ``LayerNormLinear``, these operations are fused into a single kernel: the LayerNorm output is quantized directly, eliminating the separate -quantization step and reducing memory bandwidth. +quantization step and reducing memory movement. .. raw:: html @@ -425,7 +425,7 @@ Let's see how we can use fused layers in different frameworks. The fused ``LayerNormLinear`` layer is particularly efficient in FP8 training because it avoids an intermediate quantization step. The LayerNorm output is directly quantized - for the GEMM operation, reducing memory bandwidth and improving performance. + for the GEMM operation, reducing memory movement and improving performance. .. tab:: JAX @@ -445,7 +445,7 @@ Let's see how we can use fused layers in different frameworks. The fused ``LayerNormDenseGeneral`` layer is particularly efficient in FP8 training because it avoids an intermediate quantization step. The LayerNorm output is directly quantized - for the GEMM operation, reducing memory bandwidth and improving performance. + for the GEMM operation, reducing memory movement and improving performance. Distributed training @@ -461,7 +461,7 @@ tensors where low precision communication matters: **input** and **output gradie For sequence parallelism, TE supports all-gather of quantized tensors. This provides several benefits: -1. *Reduced memory* — no need to store high precision tensors for backward pass. +1. *Reduced memory usage* — no need to store high precision tensors for backward pass. 2. *Reduced communication* — smaller tensors mean less data to transfer. 3. *Parallelized quantization* — quantization work is distributed across GPUs. diff --git a/docs/features/low_precision_training/performance_considerations/save_original_input_pytorch.out b/docs/features/low_precision_training/performance_considerations/save_original_input_pytorch.out index c7545c4ee7..21227220f8 100644 --- a/docs/features/low_precision_training/performance_considerations/save_original_input_pytorch.out +++ b/docs/features/low_precision_training/performance_considerations/save_original_input_pytorch.out @@ -1,11 +1,3 @@ -/usr/local/lib/python3.12/dist-packages/torch/library.py:356: UserWarning: Warning only once for all operators, other operators may also be overridden. - Overriding a previously registered kernel for the same operator and the same dispatch key - operator: flash_attn::_flash_attn_backward(Tensor dout, Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor(a6!)? dq, Tensor(a7!)? dk, Tensor(a8!)? dv, float dropout_p, float softmax_scale, bool causal, SymInt window_size_left, SymInt window_size_right, float softcap, Tensor? alibi_slopes, bool deterministic, Tensor? rng_state=None) -> Tensor - registered at /usr/local/lib/python3.12/dist-packages/torch/_library/custom_ops.py:922 - dispatch key: ADInplaceOrView - previous kernel: no debug info - new kernel: registered at /usr/local/lib/python3.12/dist-packages/torch/_library/custom_ops.py:922 (Triggered internally at /opt/pytorch/pytorch/aten/src/ATen/core/dispatch/OperatorEntry.cpp:208.) - self.m.impl( # START_SAVE_ORIGINAL_INPUT save_original_input=False: 25.0 MB save_original_input=True: 24.0 MB From 80b2db570e22cbf27398fa37fa092e2cca045bc4 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Mon, 12 Jan 2026 16:48:29 +0100 Subject: [PATCH 07/25] applied suggestions form greptile Signed-off-by: Pawel Gadzinski --- docs/conf.py | 1 - .../fp8_blockwise_scaling/fp8_blockwise_scaling.rst | 2 +- .../jax_blockwise_scaling_example.py | 2 +- .../pytorch_blockwise_scaling_example.py | 4 ++-- .../fp8_current_scaling/fp8_current_scaling.rst | 2 +- .../fp8_current_scaling/jax_current_scaling_example.py | 2 +- .../pytorch_current_scaling_example.py | 4 ++-- .../fp8_delayed_scaling/fp8_delayed_scaling.rst | 2 +- .../jax_delayed_scaling_distributed_example.py | 2 +- .../fp8_delayed_scaling/jax_delayed_scaling_example.py | 2 +- .../pytorch_delayed_scaling_distributed_example.py | 2 +- .../pytorch_delayed_scaling_example.py | 2 +- docs/features/low_precision_training/index.rst | 2 +- .../introduction/autocast_jax.py | 2 +- .../introduction/autocast_pytorch.py | 2 +- .../introduction/bf16_fp16_training_jax.py | 2 +- .../introduction/bf16_fp16_training_pytorch.py | 2 +- .../introduction/introduction.rst | 2 +- .../low_precision_training/mxfp8/jax_mxfp8_example.py | 2 +- docs/features/low_precision_training/mxfp8/mxfp8.rst | 2 +- .../mxfp8/pytorch_mxfp8_example.py | 4 ++-- .../low_precision_training/nvfp4/jax_nvfp4_example.py | 2 +- docs/features/low_precision_training/nvfp4/nvfp4.rst | 2 +- .../nvfp4/pytorch_nvfp4_example.py | 4 ++-- .../performance_considerations/fused_layers_jax.py | 2 +- .../performance_considerations/fused_layers_pytorch.py | 2 +- .../performance_considerations/memory_usage_1_jax.py | 2 +- .../memory_usage_1_pytorch.py | 2 +- .../performance_considerations/memory_usage_2_jax.py | 2 +- .../memory_usage_2_pytorch.py | 2 +- .../memory_usage_3_pytorch.py | 2 +- .../performance_considerations.rst | 10 +++++----- .../save_original_input_pytorch.py | 2 +- 33 files changed, 40 insertions(+), 41 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 786f029578..d2bba9825a 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -89,7 +89,6 @@ "css/sphinx_tabs.css", "css/svg-responsive.css", "css/rtabs.css", - "css/output-style.css", ] html_theme_options = { diff --git a/docs/features/low_precision_training/fp8_blockwise_scaling/fp8_blockwise_scaling.rst b/docs/features/low_precision_training/fp8_blockwise_scaling/fp8_blockwise_scaling.rst index cc009bf111..b0fa98f1af 100644 --- a/docs/features/low_precision_training/fp8_blockwise_scaling/fp8_blockwise_scaling.rst +++ b/docs/features/low_precision_training/fp8_blockwise_scaling/fp8_blockwise_scaling.rst @@ -1,5 +1,5 @@ .. - Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. See LICENSE for license information. diff --git a/docs/features/low_precision_training/fp8_blockwise_scaling/jax_blockwise_scaling_example.py b/docs/features/low_precision_training/fp8_blockwise_scaling/jax_blockwise_scaling_example.py index 5d30e6f09a..e838de1955 100644 --- a/docs/features/low_precision_training/fp8_blockwise_scaling/jax_blockwise_scaling_example.py +++ b/docs/features/low_precision_training/fp8_blockwise_scaling/jax_blockwise_scaling_example.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/docs/features/low_precision_training/fp8_blockwise_scaling/pytorch_blockwise_scaling_example.py b/docs/features/low_precision_training/fp8_blockwise_scaling/pytorch_blockwise_scaling_example.py index 0b285967bd..3bc8c72805 100644 --- a/docs/features/low_precision_training/fp8_blockwise_scaling/pytorch_blockwise_scaling_example.py +++ b/docs/features/low_precision_training/fp8_blockwise_scaling/pytorch_blockwise_scaling_example.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -28,7 +28,7 @@ # Forward and backward pass inp = torch.randn(32, 128, 1024, dtype=torch.bfloat16, device="cuda") -with te.fp8_autocast(enabled=True, fp8_recipe=recipe): +with te.autocast(enabled=True, fp8_recipe=recipe): output = layer(inp) loss = output.sum() diff --git a/docs/features/low_precision_training/fp8_current_scaling/fp8_current_scaling.rst b/docs/features/low_precision_training/fp8_current_scaling/fp8_current_scaling.rst index f10f5ce5c7..1d16cfb029 100644 --- a/docs/features/low_precision_training/fp8_current_scaling/fp8_current_scaling.rst +++ b/docs/features/low_precision_training/fp8_current_scaling/fp8_current_scaling.rst @@ -1,5 +1,5 @@ .. - Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. See LICENSE for license information. diff --git a/docs/features/low_precision_training/fp8_current_scaling/jax_current_scaling_example.py b/docs/features/low_precision_training/fp8_current_scaling/jax_current_scaling_example.py index 2d7c7ed9c4..236fb255c2 100644 --- a/docs/features/low_precision_training/fp8_current_scaling/jax_current_scaling_example.py +++ b/docs/features/low_precision_training/fp8_current_scaling/jax_current_scaling_example.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/docs/features/low_precision_training/fp8_current_scaling/pytorch_current_scaling_example.py b/docs/features/low_precision_training/fp8_current_scaling/pytorch_current_scaling_example.py index 1eef7cf9a9..583cac47db 100644 --- a/docs/features/low_precision_training/fp8_current_scaling/pytorch_current_scaling_example.py +++ b/docs/features/low_precision_training/fp8_current_scaling/pytorch_current_scaling_example.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -20,7 +20,7 @@ # Forward and backward pass inp = torch.randn(32, 128, 1024, dtype=torch.bfloat16, device="cuda") -with te.fp8_autocast(enabled=True, fp8_recipe=recipe): +with te.autocast(enabled=True, fp8_recipe=recipe): output = layer(inp) loss = output.sum() diff --git a/docs/features/low_precision_training/fp8_delayed_scaling/fp8_delayed_scaling.rst b/docs/features/low_precision_training/fp8_delayed_scaling/fp8_delayed_scaling.rst index 0966079619..9e0de084fd 100644 --- a/docs/features/low_precision_training/fp8_delayed_scaling/fp8_delayed_scaling.rst +++ b/docs/features/low_precision_training/fp8_delayed_scaling/fp8_delayed_scaling.rst @@ -1,5 +1,5 @@ .. - Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. See LICENSE for license information. diff --git a/docs/features/low_precision_training/fp8_delayed_scaling/jax_delayed_scaling_distributed_example.py b/docs/features/low_precision_training/fp8_delayed_scaling/jax_delayed_scaling_distributed_example.py index 48f6944ac1..f354ddaf77 100644 --- a/docs/features/low_precision_training/fp8_delayed_scaling/jax_delayed_scaling_distributed_example.py +++ b/docs/features/low_precision_training/fp8_delayed_scaling/jax_delayed_scaling_distributed_example.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/docs/features/low_precision_training/fp8_delayed_scaling/jax_delayed_scaling_example.py b/docs/features/low_precision_training/fp8_delayed_scaling/jax_delayed_scaling_example.py index 0500e2d40d..aea2344bae 100644 --- a/docs/features/low_precision_training/fp8_delayed_scaling/jax_delayed_scaling_example.py +++ b/docs/features/low_precision_training/fp8_delayed_scaling/jax_delayed_scaling_example.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/docs/features/low_precision_training/fp8_delayed_scaling/pytorch_delayed_scaling_distributed_example.py b/docs/features/low_precision_training/fp8_delayed_scaling/pytorch_delayed_scaling_distributed_example.py index 2c99fe1a2c..863b71e8c6 100644 --- a/docs/features/low_precision_training/fp8_delayed_scaling/pytorch_delayed_scaling_distributed_example.py +++ b/docs/features/low_precision_training/fp8_delayed_scaling/pytorch_delayed_scaling_distributed_example.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/docs/features/low_precision_training/fp8_delayed_scaling/pytorch_delayed_scaling_example.py b/docs/features/low_precision_training/fp8_delayed_scaling/pytorch_delayed_scaling_example.py index 628f368641..45d244f47d 100644 --- a/docs/features/low_precision_training/fp8_delayed_scaling/pytorch_delayed_scaling_example.py +++ b/docs/features/low_precision_training/fp8_delayed_scaling/pytorch_delayed_scaling_example.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/docs/features/low_precision_training/index.rst b/docs/features/low_precision_training/index.rst index 39fba07881..8b392d2bbb 100644 --- a/docs/features/low_precision_training/index.rst +++ b/docs/features/low_precision_training/index.rst @@ -1,5 +1,5 @@ .. - Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. See LICENSE for license information. diff --git a/docs/features/low_precision_training/introduction/autocast_jax.py b/docs/features/low_precision_training/introduction/autocast_jax.py index 1c0e91a338..536a1df86d 100644 --- a/docs/features/low_precision_training/introduction/autocast_jax.py +++ b/docs/features/low_precision_training/introduction/autocast_jax.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/docs/features/low_precision_training/introduction/autocast_pytorch.py b/docs/features/low_precision_training/introduction/autocast_pytorch.py index 17d813b3fa..2c1528ff9e 100644 --- a/docs/features/low_precision_training/introduction/autocast_pytorch.py +++ b/docs/features/low_precision_training/introduction/autocast_pytorch.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/docs/features/low_precision_training/introduction/bf16_fp16_training_jax.py b/docs/features/low_precision_training/introduction/bf16_fp16_training_jax.py index 14647daa1b..f2e02a5103 100644 --- a/docs/features/low_precision_training/introduction/bf16_fp16_training_jax.py +++ b/docs/features/low_precision_training/introduction/bf16_fp16_training_jax.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/docs/features/low_precision_training/introduction/bf16_fp16_training_pytorch.py b/docs/features/low_precision_training/introduction/bf16_fp16_training_pytorch.py index 8779f0bff0..4eb6ce1f84 100644 --- a/docs/features/low_precision_training/introduction/bf16_fp16_training_pytorch.py +++ b/docs/features/low_precision_training/introduction/bf16_fp16_training_pytorch.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/docs/features/low_precision_training/introduction/introduction.rst b/docs/features/low_precision_training/introduction/introduction.rst index 422209b1a5..0067b9703e 100644 --- a/docs/features/low_precision_training/introduction/introduction.rst +++ b/docs/features/low_precision_training/introduction/introduction.rst @@ -1,5 +1,5 @@ .. - Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. See LICENSE for license information. diff --git a/docs/features/low_precision_training/mxfp8/jax_mxfp8_example.py b/docs/features/low_precision_training/mxfp8/jax_mxfp8_example.py index d41b1ecfe4..2e0c28286e 100644 --- a/docs/features/low_precision_training/mxfp8/jax_mxfp8_example.py +++ b/docs/features/low_precision_training/mxfp8/jax_mxfp8_example.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/docs/features/low_precision_training/mxfp8/mxfp8.rst b/docs/features/low_precision_training/mxfp8/mxfp8.rst index 21450823b9..dc5816515f 100644 --- a/docs/features/low_precision_training/mxfp8/mxfp8.rst +++ b/docs/features/low_precision_training/mxfp8/mxfp8.rst @@ -1,5 +1,5 @@ .. - Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. See LICENSE for license information. diff --git a/docs/features/low_precision_training/mxfp8/pytorch_mxfp8_example.py b/docs/features/low_precision_training/mxfp8/pytorch_mxfp8_example.py index 19891083b4..3f9c4d3705 100644 --- a/docs/features/low_precision_training/mxfp8/pytorch_mxfp8_example.py +++ b/docs/features/low_precision_training/mxfp8/pytorch_mxfp8_example.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -25,7 +25,7 @@ # Forward and backward pass inp = torch.randn(32, 128, 1024, dtype=torch.bfloat16, device="cuda") -with te.fp8_autocast(enabled=True, fp8_recipe=recipe): +with te.autocast(enabled=True, fp8_recipe=recipe): output = layer(inp) loss = output.sum() diff --git a/docs/features/low_precision_training/nvfp4/jax_nvfp4_example.py b/docs/features/low_precision_training/nvfp4/jax_nvfp4_example.py index 99a16f21a7..1c2e13ff73 100644 --- a/docs/features/low_precision_training/nvfp4/jax_nvfp4_example.py +++ b/docs/features/low_precision_training/nvfp4/jax_nvfp4_example.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/docs/features/low_precision_training/nvfp4/nvfp4.rst b/docs/features/low_precision_training/nvfp4/nvfp4.rst index 9d7831dcb8..3d35346ece 100644 --- a/docs/features/low_precision_training/nvfp4/nvfp4.rst +++ b/docs/features/low_precision_training/nvfp4/nvfp4.rst @@ -1,5 +1,5 @@ .. - Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. See LICENSE for license information. diff --git a/docs/features/low_precision_training/nvfp4/pytorch_nvfp4_example.py b/docs/features/low_precision_training/nvfp4/pytorch_nvfp4_example.py index c34845ae2a..883cc84c7e 100644 --- a/docs/features/low_precision_training/nvfp4/pytorch_nvfp4_example.py +++ b/docs/features/low_precision_training/nvfp4/pytorch_nvfp4_example.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -28,7 +28,7 @@ # Forward and backward pass inp = torch.randn(32, 128, 1024, dtype=torch.bfloat16, device="cuda") -with te.fp8_autocast(enabled=True, fp8_recipe=recipe): +with te.autocast(enabled=True, fp8_recipe=recipe): output = layer(inp) loss = output.sum() diff --git a/docs/features/low_precision_training/performance_considerations/fused_layers_jax.py b/docs/features/low_precision_training/performance_considerations/fused_layers_jax.py index 4bcee27397..4f2f39ca34 100644 --- a/docs/features/low_precision_training/performance_considerations/fused_layers_jax.py +++ b/docs/features/low_precision_training/performance_considerations/fused_layers_jax.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/docs/features/low_precision_training/performance_considerations/fused_layers_pytorch.py b/docs/features/low_precision_training/performance_considerations/fused_layers_pytorch.py index 1a9a1baf2e..2108f45a08 100644 --- a/docs/features/low_precision_training/performance_considerations/fused_layers_pytorch.py +++ b/docs/features/low_precision_training/performance_considerations/fused_layers_pytorch.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/docs/features/low_precision_training/performance_considerations/memory_usage_1_jax.py b/docs/features/low_precision_training/performance_considerations/memory_usage_1_jax.py index d5d1aabb7e..59aedce340 100644 --- a/docs/features/low_precision_training/performance_considerations/memory_usage_1_jax.py +++ b/docs/features/low_precision_training/performance_considerations/memory_usage_1_jax.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/docs/features/low_precision_training/performance_considerations/memory_usage_1_pytorch.py b/docs/features/low_precision_training/performance_considerations/memory_usage_1_pytorch.py index 5e7f2ae177..36bf099490 100644 --- a/docs/features/low_precision_training/performance_considerations/memory_usage_1_pytorch.py +++ b/docs/features/low_precision_training/performance_considerations/memory_usage_1_pytorch.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/docs/features/low_precision_training/performance_considerations/memory_usage_2_jax.py b/docs/features/low_precision_training/performance_considerations/memory_usage_2_jax.py index 378a7c1e06..b2024dda44 100644 --- a/docs/features/low_precision_training/performance_considerations/memory_usage_2_jax.py +++ b/docs/features/low_precision_training/performance_considerations/memory_usage_2_jax.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/docs/features/low_precision_training/performance_considerations/memory_usage_2_pytorch.py b/docs/features/low_precision_training/performance_considerations/memory_usage_2_pytorch.py index 276bde4202..3db76fd1e5 100644 --- a/docs/features/low_precision_training/performance_considerations/memory_usage_2_pytorch.py +++ b/docs/features/low_precision_training/performance_considerations/memory_usage_2_pytorch.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/docs/features/low_precision_training/performance_considerations/memory_usage_3_pytorch.py b/docs/features/low_precision_training/performance_considerations/memory_usage_3_pytorch.py index d603da2809..9041749bde 100644 --- a/docs/features/low_precision_training/performance_considerations/memory_usage_3_pytorch.py +++ b/docs/features/low_precision_training/performance_considerations/memory_usage_3_pytorch.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/docs/features/low_precision_training/performance_considerations/performance_considerations.rst b/docs/features/low_precision_training/performance_considerations/performance_considerations.rst index e97f615d28..9625bba556 100644 --- a/docs/features/low_precision_training/performance_considerations/performance_considerations.rst +++ b/docs/features/low_precision_training/performance_considerations/performance_considerations.rst @@ -1,5 +1,5 @@ .. - Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. See LICENSE for license information. @@ -153,9 +153,9 @@ Contrary to intuition, FP8 training does not always reduce memory compared to BF *Master weights* Transformer Engine by default stores weights in high precision and quantizes them to low precision before each GEMM. -Moreover, one can specify the precision of the weights stored in the model - if this can be FP32 or -BF16/FP16 -- or do not store high precision weights in the model at all. There are multiple scenarios to consider, -three of them are listed below: +Moreover, one can specify which high precision should be used to store the weights in the +model (FP32/BF16/FP16) -- or choose not to store high precision weights in the model at all. +There are multiple scenarios to consider, three of them are listed below: 1. model weights are in FP32, quantized to low precision before each GEMM, 2. model weights are in BF16/FP16, quantized to low precision before each GEMM, master weights in optimizer are in FP32. @@ -391,7 +391,7 @@ that enable kernel fusion optimizations. One key optimization is fusing layer no with quantization. In a typical Transformer architecture, LayerNorm precedes a Linear layer. Without fusion, -the LayerNorm outputs in FP32, and the Linear layer must then quantize this input before +the LayerNorm outputs in high precision, and the Linear layer must then quantize this input before performing the GEMM — adding overhead. With ``LayerNormLinear``, these operations are fused into a single kernel: the LayerNorm output is quantized directly, eliminating the separate quantization step and reducing memory movement. diff --git a/docs/features/low_precision_training/performance_considerations/save_original_input_pytorch.py b/docs/features/low_precision_training/performance_considerations/save_original_input_pytorch.py index 869be8e763..c9efa7107e 100644 --- a/docs/features/low_precision_training/performance_considerations/save_original_input_pytorch.py +++ b/docs/features/low_precision_training/performance_considerations/save_original_input_pytorch.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. From b7eb7e2888e2f0555ef3aa082ffc428b5db05db2 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Mon, 12 Jan 2026 17:20:16 +0100 Subject: [PATCH 08/25] fix Signed-off-by: Pawel Gadzinski --- .../img/fp8_scaling_concept.svg | 104 +++++++++--------- .../memory_usage_1_jax.py | 1 + .../memory_usage_1_pytorch.py | 8 +- .../memory_usage_2_jax.py | 1 + .../memory_usage_2_pytorch.out | 2 +- .../memory_usage_2_pytorch.py | 5 +- .../memory_usage_3_pytorch.out | 2 +- .../memory_usage_3_pytorch.py | 15 +-- .../performance_considerations.rst | 16 +-- 9 files changed, 80 insertions(+), 74 deletions(-) diff --git a/docs/features/low_precision_training/fp8_current_scaling/img/fp8_scaling_concept.svg b/docs/features/low_precision_training/fp8_current_scaling/img/fp8_scaling_concept.svg index a07f596a00..9442b4e4aa 100644 --- a/docs/features/low_precision_training/fp8_current_scaling/img/fp8_scaling_concept.svg +++ b/docs/features/low_precision_training/fp8_current_scaling/img/fp8_scaling_concept.svg @@ -1,9 +1,8 @@ - + diff --git a/docs/features/low_precision_training/performance_considerations/memory_usage_1_jax.py b/docs/features/low_precision_training/performance_considerations/memory_usage_1_jax.py index 59aedce340..3b1744295e 100644 --- a/docs/features/low_precision_training/performance_considerations/memory_usage_1_jax.py +++ b/docs/features/low_precision_training/performance_considerations/memory_usage_1_jax.py @@ -31,6 +31,7 @@ def measure_memory(): # Forward pass in high precision output = layer.apply(params, x) + del x # Input is saved by model for backward, not by user script mem_after_forward = get_gpu_memory_mb() - init_memory return mem_after_forward diff --git a/docs/features/low_precision_training/performance_considerations/memory_usage_1_pytorch.py b/docs/features/low_precision_training/performance_considerations/memory_usage_1_pytorch.py index 36bf099490..38d3cfe2fd 100644 --- a/docs/features/low_precision_training/performance_considerations/memory_usage_1_pytorch.py +++ b/docs/features/low_precision_training/performance_considerations/memory_usage_1_pytorch.py @@ -20,20 +20,20 @@ def measure_memory(): init_memory = torch.cuda.memory_allocated() layer = te.Linear(1024, 1024, params_dtype=torch.bfloat16) - memory = torch.cuda.memory_allocated() - init_memory inp = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda") out = layer(inp) - mem_after_forward = torch.cuda.memory_allocated() - init_memory + del inp # Input is saved by model for backward, not by user script - return memory, mem_after_forward + mem_after_forward = torch.cuda.memory_allocated() - init_memory + return mem_after_forward # Warmup run measure_memory() # Actual measurement -memory, mem_after_forward = measure_memory() +mem_after_forward = measure_memory() print(f"Memory usage after forward pass: {mem_after_forward/1024**2:.2f} MB") # END_MEMORY_USAGE_1 print("# END_MEMORY_USAGE_1") diff --git a/docs/features/low_precision_training/performance_considerations/memory_usage_2_jax.py b/docs/features/low_precision_training/performance_considerations/memory_usage_2_jax.py index b2024dda44..a724b1ebd0 100644 --- a/docs/features/low_precision_training/performance_considerations/memory_usage_2_jax.py +++ b/docs/features/low_precision_training/performance_considerations/memory_usage_2_jax.py @@ -35,6 +35,7 @@ def measure_memory(): with te.autocast(enabled=True, recipe=recipe): params = layer.init(key, x) output = layer.apply(params, x) + del x # Input is saved by model for backward, not by user script mem_after_forward = get_gpu_memory_mb() - init_memory return mem_after_forward diff --git a/docs/features/low_precision_training/performance_considerations/memory_usage_2_pytorch.out b/docs/features/low_precision_training/performance_considerations/memory_usage_2_pytorch.out index 8b47519bbb..cc1e402581 100644 --- a/docs/features/low_precision_training/performance_considerations/memory_usage_2_pytorch.out +++ b/docs/features/low_precision_training/performance_considerations/memory_usage_2_pytorch.out @@ -1,4 +1,4 @@ # START_MEMORY_USAGE_2 -Memory after forward pass: 8.02 MB +Memory after forward pass: 6.02 MB # END_MEMORY_USAGE_2 diff --git a/docs/features/low_precision_training/performance_considerations/memory_usage_2_pytorch.py b/docs/features/low_precision_training/performance_considerations/memory_usage_2_pytorch.py index 3db76fd1e5..7928ace2af 100644 --- a/docs/features/low_precision_training/performance_considerations/memory_usage_2_pytorch.py +++ b/docs/features/low_precision_training/performance_considerations/memory_usage_2_pytorch.py @@ -20,12 +20,13 @@ def measure_memory(): init_memory = torch.cuda.memory_allocated() layer = te.Linear(1024, 1024, params_dtype=torch.bfloat16) - inp = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda") + inp = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda") with te.autocast(enabled=True): out = layer(inp) - mem_after_forward = torch.cuda.memory_allocated() - init_memory + del inp # Input is saved by model for backward, not by user script + mem_after_forward = torch.cuda.memory_allocated() - init_memory return mem_after_forward diff --git a/docs/features/low_precision_training/performance_considerations/memory_usage_3_pytorch.out b/docs/features/low_precision_training/performance_considerations/memory_usage_3_pytorch.out index 6463130bc2..d49109a99c 100644 --- a/docs/features/low_precision_training/performance_considerations/memory_usage_3_pytorch.out +++ b/docs/features/low_precision_training/performance_considerations/memory_usage_3_pytorch.out @@ -1,4 +1,4 @@ # START_MEMORY_USAGE_3 -Memory after forward pass: 6.02 MB +Memory after forward pass: 4.02 MB # END_MEMORY_USAGE_3 diff --git a/docs/features/low_precision_training/performance_considerations/memory_usage_3_pytorch.py b/docs/features/low_precision_training/performance_considerations/memory_usage_3_pytorch.py index 9041749bde..3de5f10aad 100644 --- a/docs/features/low_precision_training/performance_considerations/memory_usage_3_pytorch.py +++ b/docs/features/low_precision_training/performance_considerations/memory_usage_3_pytorch.py @@ -20,25 +20,26 @@ def measure_memory(): init_memory = torch.cuda.memory_allocated() - # FP8 forward and backward with FP8 weights + # FP8 inference with FP8 weights with te.quantized_model_init(enabled=True), torch.no_grad(): layer_fp8 = te.Linear(1024, 1024, params_dtype=torch.bfloat16) - memory = torch.cuda.memory_allocated() - init_memory - inp = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda") - with te.autocast(enabled=True): - out = layer_fp8(inp) + with torch.no_grad(): + inp = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda") + with te.autocast(enabled=True): + out = layer_fp8(inp) + del inp # Input is not saved by model for backward in inference mem_after_forward = torch.cuda.memory_allocated() - init_memory - return memory, mem_after_forward + return mem_after_forward # Warmup run measure_memory() # Actual measurement -memory, mem_after_forward = measure_memory() +mem_after_forward = measure_memory() print(f"Memory after forward pass: {mem_after_forward/1024**2:.2f} MB") # END_MEMORY_USAGE_3 print("# END_MEMORY_USAGE_3") diff --git a/docs/features/low_precision_training/performance_considerations/performance_considerations.rst b/docs/features/low_precision_training/performance_considerations/performance_considerations.rst index 9625bba556..911b8e8bf7 100644 --- a/docs/features/low_precision_training/performance_considerations/performance_considerations.rst +++ b/docs/features/low_precision_training/performance_considerations/performance_considerations.rst @@ -213,7 +213,7 @@ and columnwise tensors require separate memory layouts. :end-before: # END_MEMORY_USAGE_1 Layer size is ``1024 * 1024 * 2 (2 bytes per parameter) = 2MB``. - Memory after forward pass is ``2 MB (weight) + 2 MB (input) + 2 MB (output) = 6 MB``. + Memory after forward pass is ``2 MB (weight) + 2 MB (input saved for backward) + 2 MB (output) = 6 MB``. **2. FP8 training with model weights in BF16** @@ -243,11 +243,12 @@ and columnwise tensors require separate memory layouts. :start-after: # START_MEMORY_USAGE_2 :end-before: # END_MEMORY_USAGE_2 - Total memory usage is ``2 MB (weight) + 1 MB (weight in FP8) + 2 MB (input) + 1 MB (input in FP8) + 2 MB (output) = 8 MB``. + Total memory usage is ``2 MB (weight) + 1 MB (weight in FP8) + 1 MB (input in FP8 saved for backward) + 2 MB (output) = 6 MB``. - **3. FP8 training with model weights stored directly in low precision** + **3. FP8 inference with model weights stored directly in low precision** - When model weights are stored directly in low precision, master weights are not needed. + For inference scenarios, model weights can be stored directly in low precision. Since we are only + performing forward passes without gradient updates, master weights in high precision are not needed. .. raw:: html @@ -273,9 +274,8 @@ and columnwise tensors require separate memory layouts. :start-after: # START_MEMORY_USAGE_3 :end-before: # END_MEMORY_USAGE_3 - Total memory usage is ``1 MB (weight in FP8) + 2 MB (input) + 1 MB (input in FP8) + 2 MB (output) = 6 MB``. - Note that columnwise FP8 weight is not computed during initialization with ``torch.no_grad()``. - It will be computed on the first backward pass from the rowwise FP8 weight. + Total memory usage is ``1 MB (weight in FP8) + 1 MB (input in FP8) + 2 MB (output) = 4 MB``. + This is lower than the BF16 baseline (6 MB) since no high precision copies are needed. **4. Saving original input instead of quantized** @@ -342,7 +342,7 @@ and columnwise tensors require separate memory layouts. :end-before: # END_MEMORY_USAGE_1 Layer size is ``1024 * 1024 * 2 (2 bytes per parameter) = 2MB``. - Memory after forward pass is ``2 MB (weight) + 2 MB (input) + 2 MB (output) = 6 MB``. + Memory after forward pass is ``2 MB (weight) + 2 MB (input saved for backward) + 2 MB (output) = 6 MB``. **2. FP8 training with master weights in BF16** From 983c88229c026b24f474ea4b562009152e176bff Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Mon, 12 Jan 2026 17:38:09 +0100 Subject: [PATCH 09/25] fix Signed-off-by: Pawel Gadzinski --- .../pytorch_blockwise_scaling_example.py | 2 +- .../fp8_current_scaling/fp8_current_scaling.rst | 2 +- .../fp8_current_scaling/pytorch_current_scaling_example.py | 2 +- .../low_precision_training/mxfp8/pytorch_mxfp8_example.py | 2 +- .../low_precision_training/nvfp4/pytorch_nvfp4_example.py | 2 +- .../performance_considerations/memory_usage_3_pytorch.out | 2 +- .../performance_considerations/performance_considerations.rst | 4 ++-- 7 files changed, 8 insertions(+), 8 deletions(-) diff --git a/docs/features/low_precision_training/fp8_blockwise_scaling/pytorch_blockwise_scaling_example.py b/docs/features/low_precision_training/fp8_blockwise_scaling/pytorch_blockwise_scaling_example.py index 3bc8c72805..5100fc1a1d 100644 --- a/docs/features/low_precision_training/fp8_blockwise_scaling/pytorch_blockwise_scaling_example.py +++ b/docs/features/low_precision_training/fp8_blockwise_scaling/pytorch_blockwise_scaling_example.py @@ -28,7 +28,7 @@ # Forward and backward pass inp = torch.randn(32, 128, 1024, dtype=torch.bfloat16, device="cuda") -with te.autocast(enabled=True, fp8_recipe=recipe): +with te.autocast(enabled=True, recipe=recipe): output = layer(inp) loss = output.sum() diff --git a/docs/features/low_precision_training/fp8_current_scaling/fp8_current_scaling.rst b/docs/features/low_precision_training/fp8_current_scaling/fp8_current_scaling.rst index 1d16cfb029..501195bb2d 100644 --- a/docs/features/low_precision_training/fp8_current_scaling/fp8_current_scaling.rst +++ b/docs/features/low_precision_training/fp8_current_scaling/fp8_current_scaling.rst @@ -102,7 +102,7 @@ Transpose handling On Ada and Hopper, the backward pass requires a transposed FP8 tensor. The columnwise layout is physically different from the rowwise layout, so a transpose operation is needed. -All 3 options from :ref:`introduction Transpose handling section ` are supported. +All 3 options from :ref:`Performance Considerations Transpose handling section ` are supported. *Blackwell and later* diff --git a/docs/features/low_precision_training/fp8_current_scaling/pytorch_current_scaling_example.py b/docs/features/low_precision_training/fp8_current_scaling/pytorch_current_scaling_example.py index 583cac47db..7ac1271890 100644 --- a/docs/features/low_precision_training/fp8_current_scaling/pytorch_current_scaling_example.py +++ b/docs/features/low_precision_training/fp8_current_scaling/pytorch_current_scaling_example.py @@ -20,7 +20,7 @@ # Forward and backward pass inp = torch.randn(32, 128, 1024, dtype=torch.bfloat16, device="cuda") -with te.autocast(enabled=True, fp8_recipe=recipe): +with te.autocast(enabled=True, recipe=recipe): output = layer(inp) loss = output.sum() diff --git a/docs/features/low_precision_training/mxfp8/pytorch_mxfp8_example.py b/docs/features/low_precision_training/mxfp8/pytorch_mxfp8_example.py index 3f9c4d3705..3cc70137b5 100644 --- a/docs/features/low_precision_training/mxfp8/pytorch_mxfp8_example.py +++ b/docs/features/low_precision_training/mxfp8/pytorch_mxfp8_example.py @@ -25,7 +25,7 @@ # Forward and backward pass inp = torch.randn(32, 128, 1024, dtype=torch.bfloat16, device="cuda") -with te.autocast(enabled=True, fp8_recipe=recipe): +with te.autocast(enabled=True, recipe=recipe): output = layer(inp) loss = output.sum() diff --git a/docs/features/low_precision_training/nvfp4/pytorch_nvfp4_example.py b/docs/features/low_precision_training/nvfp4/pytorch_nvfp4_example.py index 883cc84c7e..0790c46051 100644 --- a/docs/features/low_precision_training/nvfp4/pytorch_nvfp4_example.py +++ b/docs/features/low_precision_training/nvfp4/pytorch_nvfp4_example.py @@ -28,7 +28,7 @@ # Forward and backward pass inp = torch.randn(32, 128, 1024, dtype=torch.bfloat16, device="cuda") -with te.autocast(enabled=True, fp8_recipe=recipe): +with te.autocast(enabled=True, recipe=recipe): output = layer(inp) loss = output.sum() diff --git a/docs/features/low_precision_training/performance_considerations/memory_usage_3_pytorch.out b/docs/features/low_precision_training/performance_considerations/memory_usage_3_pytorch.out index d49109a99c..ea4d0dc891 100644 --- a/docs/features/low_precision_training/performance_considerations/memory_usage_3_pytorch.out +++ b/docs/features/low_precision_training/performance_considerations/memory_usage_3_pytorch.out @@ -1,4 +1,4 @@ # START_MEMORY_USAGE_3 -Memory after forward pass: 4.02 MB +Memory after forward pass: 3.02 MB # END_MEMORY_USAGE_3 diff --git a/docs/features/low_precision_training/performance_considerations/performance_considerations.rst b/docs/features/low_precision_training/performance_considerations/performance_considerations.rst index 911b8e8bf7..81476474c0 100644 --- a/docs/features/low_precision_training/performance_considerations/performance_considerations.rst +++ b/docs/features/low_precision_training/performance_considerations/performance_considerations.rst @@ -274,8 +274,8 @@ and columnwise tensors require separate memory layouts. :start-after: # START_MEMORY_USAGE_3 :end-before: # END_MEMORY_USAGE_3 - Total memory usage is ``1 MB (weight in FP8) + 1 MB (input in FP8) + 2 MB (output) = 4 MB``. - This is lower than the BF16 baseline (6 MB) since no high precision copies are needed. + Total memory usage is ``1 MB (weight in FP8) + 2 MB (output) = 3 MB``. + This is lower than the BF16 baseline (6 MB) since no copies are saved for backward in inference mode. **4. Saving original input instead of quantized** From 25bb9eebac319901ee26d966e73762fa7fb83b8a Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Mon, 12 Jan 2026 17:52:46 +0100 Subject: [PATCH 10/25] fix Signed-off-by: Pawel Gadzinski --- .../jax_blockwise_scaling_example.py | 11 ++++------- .../low_precision_training/mxfp8/jax_mxfp8_example.py | 9 ++++----- .../low_precision_training/nvfp4/jax_nvfp4_example.py | 9 ++++----- 3 files changed, 12 insertions(+), 17 deletions(-) diff --git a/docs/features/low_precision_training/fp8_blockwise_scaling/jax_blockwise_scaling_example.py b/docs/features/low_precision_training/fp8_blockwise_scaling/jax_blockwise_scaling_example.py index e838de1955..546cedcab4 100644 --- a/docs/features/low_precision_training/fp8_blockwise_scaling/jax_blockwise_scaling_example.py +++ b/docs/features/low_precision_training/fp8_blockwise_scaling/jax_blockwise_scaling_example.py @@ -2,14 +2,11 @@ # # See LICENSE for license information. -import jax - # Check for Hopper or newer GPU -gpu = jax.devices("gpu")[0] -major, minor = gpu.compute_capability.split(".") -assert ( - int(major) >= 9 -), f"FP8 Blockwise Scaling requires SM90 (Hopper) or later, got SM{major}{minor}" +from transformer_engine_jax import get_device_compute_capability + +major_minor = get_device_compute_capability(0) +assert major_minor >= 90, f"FP8 Blockwise Scaling requires SM90 (Hopper) or later, got SM{major_minor}" # START_BLOCKWISE_SCALING_EXAMPLE diff --git a/docs/features/low_precision_training/mxfp8/jax_mxfp8_example.py b/docs/features/low_precision_training/mxfp8/jax_mxfp8_example.py index 2e0c28286e..d4eed8aecc 100644 --- a/docs/features/low_precision_training/mxfp8/jax_mxfp8_example.py +++ b/docs/features/low_precision_training/mxfp8/jax_mxfp8_example.py @@ -2,12 +2,11 @@ # # See LICENSE for license information. -import jax - # Check for Blackwell or newer GPU -gpu = jax.devices("gpu")[0] -major, minor = gpu.compute_capability.split(".") -assert int(major) >= 10, f"MXFP8 requires SM100 (Blackwell) or later, got SM{major}{minor}" +from transformer_engine_jax import get_device_compute_capability + +major_minor = get_device_compute_capability(0) +assert major_minor >= 100, f"MXFP8 requires SM100 (Blackwell) or later, got SM{major_minor}" # START_MXFP8_EXAMPLE diff --git a/docs/features/low_precision_training/nvfp4/jax_nvfp4_example.py b/docs/features/low_precision_training/nvfp4/jax_nvfp4_example.py index 1c2e13ff73..4d55e7899c 100644 --- a/docs/features/low_precision_training/nvfp4/jax_nvfp4_example.py +++ b/docs/features/low_precision_training/nvfp4/jax_nvfp4_example.py @@ -2,12 +2,11 @@ # # See LICENSE for license information. -import jax - # Check for Blackwell or newer GPU -gpu = jax.devices("gpu")[0] -major, minor = gpu.compute_capability.split(".") -assert int(major) >= 10, f"NVFP4 requires SM100 (Blackwell) or later, got SM{major}{minor}" +from transformer_engine_jax import get_device_compute_capability + +major_minor = get_device_compute_capability(0) +assert major_minor >= 100, f"NVFP4 requires SM100 (Blackwell) or later, got SM{major_minor}" # START_NVFP4_EXAMPLE From 0c0c7f570f345045cbc609f612553ab7fdbe9e8f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 12 Jan 2026 16:53:49 +0000 Subject: [PATCH 11/25] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../fp8_blockwise_scaling/jax_blockwise_scaling_example.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/features/low_precision_training/fp8_blockwise_scaling/jax_blockwise_scaling_example.py b/docs/features/low_precision_training/fp8_blockwise_scaling/jax_blockwise_scaling_example.py index 546cedcab4..62dd584305 100644 --- a/docs/features/low_precision_training/fp8_blockwise_scaling/jax_blockwise_scaling_example.py +++ b/docs/features/low_precision_training/fp8_blockwise_scaling/jax_blockwise_scaling_example.py @@ -6,7 +6,9 @@ from transformer_engine_jax import get_device_compute_capability major_minor = get_device_compute_capability(0) -assert major_minor >= 90, f"FP8 Blockwise Scaling requires SM90 (Hopper) or later, got SM{major_minor}" +assert ( + major_minor >= 90 +), f"FP8 Blockwise Scaling requires SM90 (Hopper) or later, got SM{major_minor}" # START_BLOCKWISE_SCALING_EXAMPLE From f26d35e647dabca0cfd9e1a2f06c960f6034a307 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Tue, 13 Jan 2026 12:09:48 +0100 Subject: [PATCH 12/25] fix Signed-off-by: Pawel Gadzinski --- .../fp8_blockwise_scaling.rst | 106 ++++--- .../fp8_current_scaling.rst | 38 +-- .../fp8_delayed_scaling.rst | 4 +- .../low_precision_training/mxfp8/mxfp8.rst | 119 ++++---- .../nvfp4/jax_nvfp4_example.py | 11 +- .../low_precision_training/nvfp4/nvfp4.rst | 43 ++- .../nvfp4/pytorch_nvfp4_example.py | 12 +- .../img/gemm_access_pattern.svg | 264 +++++++++--------- .../performance_considerations.rst | 1 + 9 files changed, 313 insertions(+), 285 deletions(-) diff --git a/docs/features/low_precision_training/fp8_blockwise_scaling/fp8_blockwise_scaling.rst b/docs/features/low_precision_training/fp8_blockwise_scaling/fp8_blockwise_scaling.rst index b0fa98f1af..79815baed3 100644 --- a/docs/features/low_precision_training/fp8_blockwise_scaling/fp8_blockwise_scaling.rst +++ b/docs/features/low_precision_training/fp8_blockwise_scaling/fp8_blockwise_scaling.rst @@ -128,8 +128,64 @@ For 2D scaling, columnwise data can be created from rowwise data by transposing both the quantized data and the scaling factors. Each 128×128 block covers the same elements regardless of access direction, so the scaling factors remain valid. + +Distributed training +----------------------- + +**Scale synchronization** + +The blockwise scaled tensor does not need any scale synchronization among the nodes. +This is because each scaling factor is local to its 128 or 128×128 element block, +unlike FP8 Current/Delayed Scaling where a single global scale applies to the entire tensor, even when sharded. + +**Quantized all-gather** + +FP8 Blockwise Scaling all-gather is supported. + + +Examples +-------- + +Here's how to use the FP8 Blockwise Scaling recipe in PyTorch and JAX: + +.. note:: + + Requires SM90 (Hopper) or later. + +.. tabs:: + + .. tab:: PyTorch + + .. literalinclude:: pytorch_blockwise_scaling_example.py + :language: python + :start-after: # START_BLOCKWISE_SCALING_EXAMPLE + :end-before: # END_BLOCKWISE_SCALING_EXAMPLE + + .. tab:: JAX + + .. literalinclude:: jax_blockwise_scaling_example.py + :language: python + :start-after: # START_BLOCKWISE_SCALING_EXAMPLE + :end-before: # END_BLOCKWISE_SCALING_EXAMPLE + +Supported devices +----------------- + +Hopper (SM 9.0) + +Blackwell and later (SM >= 10.0) – recipe is emulated with MXFP8. Note that this is done mainly for compatibility, MXFP8 is the preferred recipe on Blackwell. + + +---- + +Developer Notes +--------------- + +This section contains implementation details that may be useful for developers +but are not required for using FP8 Blockwise Scaling in practice. + Swizzle of scaling factors --------------------------- +^^^^^^^^^^^^^^^^^^^^^^^^^^ FP8 Blockwise Scaling supports all-gather of both rowwise and columnwise tensors. To support that, it implements different data layouts for communication (all-gather) @@ -187,52 +243,10 @@ when no all-gather is needed, or performed separately after all-gather. compact format, then swizzle is performed separately after communication. Bottom: Without all-gather – quantize and swizzle are fused into a single operation, directly producing GEMM-ready format.* - -Distributed training ------------------------ - -**Scale synchronization** - -The blockwise scaled tensor does not need any scale synchronization among the nodes. -This is because each scaling factor is local to its 128 or 128×128 element block, -unlike FP8 Current/Delayed Scaling where a single global scale applies to the entire tensor, even when sharded. - -**Quantized all-gather** +All-gather of columnwise tensors +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ All-gather of columnwise tensors is supported and necessary because: -- columnwise quantized tensors cannot be computed from rowwise quantized ones (as mentioned earlier), +- columnwise quantized tensors cannot be computed from rowwise quantized ones, - gathering high-precision tensors is avoided in most cases for performance reasons. - - -Examples --------- - -Here's how to use the FP8 Blockwise Scaling recipe in PyTorch and JAX: - -.. note:: - - Requires SM90 (Hopper) or later. - -.. tabs:: - - .. tab:: PyTorch - - .. literalinclude:: pytorch_blockwise_scaling_example.py - :language: python - :start-after: # START_BLOCKWISE_SCALING_EXAMPLE - :end-before: # END_BLOCKWISE_SCALING_EXAMPLE - - .. tab:: JAX - - .. literalinclude:: jax_blockwise_scaling_example.py - :language: python - :start-after: # START_BLOCKWISE_SCALING_EXAMPLE - :end-before: # END_BLOCKWISE_SCALING_EXAMPLE - -Supported devices ------------------ - -Hopper (SM 9.0) - -Blackwell and later (SM >= 10.0) – recipe is emulated with MXFP8. Note that this is done mainly for compatibility, MXFP8 is the preferred recipe on Blackwell. diff --git a/docs/features/low_precision_training/fp8_current_scaling/fp8_current_scaling.rst b/docs/features/low_precision_training/fp8_current_scaling/fp8_current_scaling.rst index 501195bb2d..a4830a3fd5 100644 --- a/docs/features/low_precision_training/fp8_current_scaling/fp8_current_scaling.rst +++ b/docs/features/low_precision_training/fp8_current_scaling/fp8_current_scaling.rst @@ -80,18 +80,6 @@ This is a significant overhead compared to other recipes, which typically requir *Figure 4: FP8 quantization with current scaling recipe - two tensor reads are needed, one to compute amax and one to apply the scaling factor and cast to FP8.* -Hardware support ----------------- - -The Hopper architecture introduced FP8 support in Tensor Cores, enabling efficient low-precision computation. -Tensor Cores support every combination of E4M3 and E5M2 formats as inputs, allowing flexible precision choices for different operands. -The Tensor Core performs the matrix multiplication in FP8 precision and produces output in higher precision (FP16, BF16, or FP32). - -.. raw:: html - :file: img/fp8_tensor_core.svg - -*Figure 5: FP8 Tensor Cores process two input tensors (A and B) with their respective scaling factors and perform matrix multiplication to accumulate higher-precision output.* - Transpose handling ------------------ @@ -119,11 +107,9 @@ The rowwise and columnwise tensors share the same physical memory layout. Distributed training -------------------- -**All-gather of columnwise tensors** +**Quantized all-gather** -Supported for Blackwell and later, since rowwise and columnwise tensors share the same memory layout. -For Hopper and Ada, all-gather of transposed FP8 tensors is not supported. -The rowwise tensor is gathered and then it is transposed to columnwise tensor. +FP8 all-gather is supported on all architectures (Ada and later). **Amax reduction** @@ -173,4 +159,22 @@ Here's how to use FP8 Current Scaling recipe in PyTorch and JAX: .. literalinclude:: jax_current_scaling_example.py :language: python :start-after: # START_CURRENT_SCALING_EXAMPLE - :end-before: # END_CURRENT_SCALING_EXAMPLE \ No newline at end of file + :end-before: # END_CURRENT_SCALING_EXAMPLE + + +---- + +Developer Notes +--------------- + +This section contains implementation details that may be useful for developers +but are not required for using FP8 Current Scaling in practice. + +All-gather of columnwise tensors +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +On Blackwell and later, rowwise and columnwise tensors share the same memory layout, +so all-gather of columnwise tensors is directly supported. + +For Hopper and Ada, all-gather of transposed FP8 tensors is not supported. +The rowwise tensor is gathered first, then transposed to columnwise format. \ No newline at end of file diff --git a/docs/features/low_precision_training/fp8_delayed_scaling/fp8_delayed_scaling.rst b/docs/features/low_precision_training/fp8_delayed_scaling/fp8_delayed_scaling.rst index 9e0de084fd..0343ddc03f 100644 --- a/docs/features/low_precision_training/fp8_delayed_scaling/fp8_delayed_scaling.rst +++ b/docs/features/low_precision_training/fp8_delayed_scaling/fp8_delayed_scaling.rst @@ -115,9 +115,7 @@ Here's how to use FP8 Delayed Scaling in PyTorch and JAX: Distributed Training -------------------- -FP8 Delayed Scaling uses the same data formats as FP8 Current Scaling - -all-gather of non-transposed tensors is supported. - +FP8 Delayed Scaling uses the same data formats as FP8 Current Scaling - qunatized all-gather is supported. However, amax reduction works slightly differently in different frameworks. .. tabs:: diff --git a/docs/features/low_precision_training/mxfp8/mxfp8.rst b/docs/features/low_precision_training/mxfp8/mxfp8.rst index dc5816515f..14d805e808 100644 --- a/docs/features/low_precision_training/mxfp8/mxfp8.rst +++ b/docs/features/low_precision_training/mxfp8/mxfp8.rst @@ -54,7 +54,7 @@ There are some assumptions on the dimensions of the tensor: Scaling factors are stored as E8M0 (8 exponent bits, 0 mantissa bits), which inherently represents powers of 2. This differs from FP8 Blockwise Scaling, which uses 32-bit floating point numbers optionally constrained to powers of 2. Note that FP32 also has 8 exponent bits, so the representable -ranges are similar when the power-of-2 constraint is enabled. +ranges are the same when the power-of-2 constraint is enabled. Each block's scaling factor is computed through the following steps: @@ -96,53 +96,6 @@ independently from the full-precision data. *Figure 2. MXFP8 rowwise vs columnwise quantization layout.* -Swizzling scaling factors -------------------------- - -Like :doc:`FP8 Blockwise Scaling <../fp8_blockwise_scaling/fp8_blockwise_scaling>`, MXFP8 uses different data layouts for communication and computation. -MXFP8 GEMMs require scaling factors in a specific hardware layout -(see `cuBLAS documentation `__). -The conversion to this GEMM-ready layout is called *swizzling*. Because swizzled scaling factors -cannot be communicated across devices, Transformer Engine performs swizzling after any required -communication, just before each GEMM operation. - -.. raw:: html - :file: img/mxfp8_swizzle_both_tensors.svg - -*Figure 3. MXFP8 swizzling process: standard scaling factors are rearranged into the hardware-required layout.* - - -Blackwell Tensor Cores compute matrix multiplications using ``128x128`` tiles. -Scaling factors are stored in row-major order, but to process a tile, we need a ``128x4`` vertical -slice of scaling factors. In row-major storage, these vertical slices are scattered in memory -with gaps between each row. The hardware requires them to be stored contiguously. - -.. raw:: html - :file: img/mxfp8_tensor_scaling_layout.svg - -*Figure 4. FP8 tensor (left) is divided into 128x128 tiles. Each tile requires a 128x4 block of scaling factors (right). These vertical blocks are not contiguous in memory.* - -Swizzling transforms the layout to meet hardware requirements by: - -1. **Linearizing** the ``128x4`` blocks so they are stored contiguously one after another. -2. **Permuting** the 4-byte elements within each block. - -Specifically, if we index the 128 4-byte elements in a scaling factor block as :math:`0, 1, \dots, 127`, the hardware expects them in the following interleaved order: - -.. code-block:: text - - 0, 32, 64, 96, 1, 33, 65, 97, ..., k, 32 + k, 64 + k, 96 + k, ..., 31, 63, 95, 127 - - -.. raw:: html - :file: img/mxfp8_scale_linearize_and_swizzle.svg - -*Figure 5. Linearization and swizzling of scaling factors. The 2D grid of scaling factors is first flattened into a contiguous sequence of blocks (top), then the rows within each block are interleaved to match the hardware access pattern (bottom).* - -For columnwise scaling factors, the process is analogous but with ``4x128`` horizontal blocks instead of ``128x4`` vertical blocks. - - - Distributed training -------------------- @@ -154,10 +107,7 @@ unlike :doc:`FP8 Current <../fp8_current_scaling/fp8_current_scaling>`/:doc:`Del **Quantized all-gather** -All-gather of columnwise tensors is supported and necessary because: - -- columnwise quantized tensors cannot be computed from rowwise quantized ones (as mentioned earlier), -- gathering high-precision tensors is avoided in most cases for performance reasons. +MXFP8 all-gather is supported. Examples @@ -197,4 +147,67 @@ Here's how to use MXFP8 recipe in PyTorch and JAX: Supported devices ----------------- -Blackwell and later (SM 10.0+) \ No newline at end of file +Blackwell and later (SM 10.0+) + + +---- + +Developer Notes +--------------- + +This section contains implementation details that may be useful for developers +but are not required for using MXFP8 in practice. + +Swizzling scaling factors +^^^^^^^^^^^^^^^^^^^^^^^^^ + +Like :doc:`FP8 Blockwise Scaling <../fp8_blockwise_scaling/fp8_blockwise_scaling>`, MXFP8 uses different data layouts for communication and computation. +MXFP8 GEMMs require scaling factors in a specific hardware layout +(see `cuBLAS documentation `__). +The conversion to this GEMM-ready layout is called *swizzling*. When no communication is needed, +swizzling can be fused with quantization. When communication is required, swizzled scaling factors +cannot be communicated across devices, so Transformer Engine performs swizzling after communication, +just before each GEMM operation. + +.. raw:: html + :file: img/mxfp8_swizzle_both_tensors.svg + +*Figure 3. MXFP8 swizzling process: standard scaling factors are rearranged into the hardware-required layout.* + + +Blackwell Tensor Cores compute matrix multiplications using ``128x128`` tiles. +Scaling factors are stored in row-major order, but to process a tile, we need a ``128x4`` vertical +slice of scaling factors. In row-major storage, these vertical slices are scattered in memory +with gaps between each row. The hardware requires them to be stored contiguously. + +.. raw:: html + :file: img/mxfp8_tensor_scaling_layout.svg + +*Figure 4. FP8 tensor (left) is divided into 128x128 tiles. Each tile requires a 128x4 block of scaling factors (right). These vertical blocks are not contiguous in memory.* + +Swizzling transforms the layout to meet hardware requirements by: + +1. **Linearizing** the ``128x4`` blocks so they are stored contiguously one after another. +2. **Permuting** the 4-byte elements within each block. + +Specifically, if we index the 128 4-byte elements in a scaling factor block as :math:`0, 1, \dots, 127`, the hardware expects them in the following interleaved order: + +.. code-block:: text + + 0, 32, 64, 96, 1, 33, 65, 97, ..., k, 32 + k, 64 + k, 96 + k, ..., 31, 63, 95, 127 + + +.. raw:: html + :file: img/mxfp8_scale_linearize_and_swizzle.svg + +*Figure 5. Linearization and swizzling of scaling factors. The 2D grid of scaling factors is first flattened into a contiguous sequence of blocks (top), then the rows within each block are interleaved to match the hardware access pattern (bottom).* + +For columnwise scaling factors, the process is analogous but with ``4x128`` horizontal blocks instead of ``128x4`` vertical blocks. + +All-gather of columnwise tensors +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +All-gather of columnwise tensors is supported and necessary because: + +- columnwise quantized tensors cannot be computed from rowwise quantized ones, +- gathering high-precision tensors is avoided in most cases for performance reasons. \ No newline at end of file diff --git a/docs/features/low_precision_training/nvfp4/jax_nvfp4_example.py b/docs/features/low_precision_training/nvfp4/jax_nvfp4_example.py index 4d55e7899c..6eee6e75d0 100644 --- a/docs/features/low_precision_training/nvfp4/jax_nvfp4_example.py +++ b/docs/features/low_precision_training/nvfp4/jax_nvfp4_example.py @@ -14,14 +14,13 @@ import jax.numpy as jnp import transformer_engine.jax as te from transformer_engine.jax.flax import DenseGeneral -from transformer_engine.common.recipe import NVFP4BlockScaling, Format +from transformer_engine.common.recipe import NVFP4BlockScaling # Define NVFP4 recipe -recipe = NVFP4BlockScaling( - fp8_format=Format.E4M3, - use_2d_weight_quantization=True, - use_rht=True, -) +# 2D weight quantization and RHT are enabled by default +recipe = NVFP4BlockScaling() +# To disable features, use: +# recipe = NVFP4BlockScaling(disable_rht=True, disable_2d_quantization=True) with te.autocast(enabled=True, recipe=recipe): # Initialize layer and data diff --git a/docs/features/low_precision_training/nvfp4/nvfp4.rst b/docs/features/low_precision_training/nvfp4/nvfp4.rst index 3d35346ece..0415963a71 100644 --- a/docs/features/low_precision_training/nvfp4/nvfp4.rst +++ b/docs/features/low_precision_training/nvfp4/nvfp4.rst @@ -185,18 +185,6 @@ second dimension to a multiple of 4 (e.g. rowwise: ``[roundup(A, 128), roundup(B *Figure 5. NVFP4 rowwise vs columnwise quantization layout. Unlike MXFP8, columnwise scales are stored transposed.* -Swizzling scaling factors -------------------------- - -NVFP4 requires swizzling of block scaling factors (``s_block``) before GEMM operations, -similar to :doc:`MXFP8 <../mxfp8/mxfp8>`. Key differences: - -- Block size is 16 (vs 32 for MXFP8) -- Both rowwise and columnwise scaling factors are swizzled, but thanks to the transposed - columnwise layout, a single rowwise swizzle kernel handles both cases. -- Scaling factors are stored as FP8 E4M3 (vs E8M0 for MXFP8) - - Distributed training -------------------- @@ -212,9 +200,7 @@ If before synchronization there was ``amax_1`` on node 1, **Quantized all-gather** -All-gather of columnwise tensors is supported. To enable quantized all-gather, -all nodes must use the same ``s_global``, which is computed from the synchronized global amax. -This is automatically enabled for column-parallel and row-parallel linear layers. +NVFP4 all-gather is supported. .. raw:: html :file: img/nvfp4_all_gather.svg @@ -260,3 +246,30 @@ Supported devices * **Training**: SM 10.0, SM 10.3 * **Inference**: SM 10.0+ + + +---- + +Developer Notes +--------------- + +This section contains implementation details that may be useful for developers +but are not required for using NVFP4 in practice. + +Swizzling scaling factors +^^^^^^^^^^^^^^^^^^^^^^^^^ + +NVFP4 requires swizzling of block scaling factors (``s_block``) before GEMM operations, +similar to :doc:`MXFP8 <../mxfp8/mxfp8>`. Key differences: + +- Block size is 16 (vs 32 for MXFP8) +- Both rowwise and columnwise scaling factors are swizzled, but thanks to the transposed + columnwise layout, a single rowwise swizzle kernel handles both cases. +- Scaling factors are stored as FP8 E4M3 (vs E8M0 for MXFP8) + +All-gather of columnwise tensors +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +All-gather of columnwise tensors is supported. To enable quantized all-gather, +all nodes must use the same ``s_global``, which is computed from the synchronized global amax. +This is automatically enabled for column-parallel and row-parallel linear layers. diff --git a/docs/features/low_precision_training/nvfp4/pytorch_nvfp4_example.py b/docs/features/low_precision_training/nvfp4/pytorch_nvfp4_example.py index 0790c46051..07b680defa 100644 --- a/docs/features/low_precision_training/nvfp4/pytorch_nvfp4_example.py +++ b/docs/features/low_precision_training/nvfp4/pytorch_nvfp4_example.py @@ -12,15 +12,13 @@ import torch import transformer_engine.pytorch as te -from transformer_engine.common.recipe import NVFP4BlockScaling, Format +from transformer_engine.common.recipe import NVFP4BlockScaling # Define NVFP4 recipe -# Key features like 2D weight quantization and RHT can be enabled here -recipe = NVFP4BlockScaling( - fp8_format=Format.E4M3, - use_2d_weight_quantization=True, - use_rht=True, -) +# 2D weight quantization and RHT are enabled by default +recipe = NVFP4BlockScaling() +# To disable features, use: +# recipe = NVFP4BlockScaling(disable_rht=True, disable_2d_quantization=True) # Create a linear layer with bfloat16 parameters layer = te.Linear(1024, 1024, params_dtype=torch.bfloat16) diff --git a/docs/features/low_precision_training/performance_considerations/img/gemm_access_pattern.svg b/docs/features/low_precision_training/performance_considerations/img/gemm_access_pattern.svg index d61d8c7432..fa720427e7 100644 --- a/docs/features/low_precision_training/performance_considerations/img/gemm_access_pattern.svg +++ b/docs/features/low_precision_training/performance_considerations/img/gemm_access_pattern.svg @@ -1,12 +1,12 @@ - + diff --git a/docs/features/low_precision_training/performance_considerations/performance_considerations.rst b/docs/features/low_precision_training/performance_considerations/performance_considerations.rst index 81476474c0..fb2b011d30 100644 --- a/docs/features/low_precision_training/performance_considerations/performance_considerations.rst +++ b/docs/features/low_precision_training/performance_considerations/performance_considerations.rst @@ -39,6 +39,7 @@ The figure below illustrates these access patterns: .. figure:: img/gemm_access_pattern.svg :align: center + :width: 60% :alt: Matrix multiplication access pattern showing rowwise access for first tensor and columnwise access for second tensor Figure 1: Access patterns in matrix multiplication for matrices in ``A * B`` and ``A * B^T`` operations. From cfb427390c6c94b62187fe4edbc3f6c0059df79f Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Tue, 13 Jan 2026 12:19:58 +0100 Subject: [PATCH 13/25] fix Signed-off-by: Pawel Gadzinski --- .../introduction/autocast_jax.py | 84 +++++++++---------- .../introduction/bf16_fp16_training_jax.py | 29 ++++--- 2 files changed, 57 insertions(+), 56 deletions(-) diff --git a/docs/features/low_precision_training/introduction/autocast_jax.py b/docs/features/low_precision_training/introduction/autocast_jax.py index 536a1df86d..dc7f2a1d28 100644 --- a/docs/features/low_precision_training/introduction/autocast_jax.py +++ b/docs/features/low_precision_training/introduction/autocast_jax.py @@ -25,31 +25,29 @@ import jax.numpy as jnp import transformer_engine.jax as te from transformer_engine.jax.flax import TransformerLayer -from transformer_engine.jax.sharding import MeshResource, global_shard_guard from transformer_engine.common.recipe import DelayedScaling, Format # Set up recipe recipe = DelayedScaling() # Model initialization must happen inside autocast -with global_shard_guard(MeshResource()): - with te.autocast(enabled=True, recipe=recipe): - layer = TransformerLayer( - hidden_size=1024, - mlp_hidden_size=4096, - num_attention_heads=16, - ) +with te.autocast(enabled=True, recipe=recipe): + layer = TransformerLayer( + hidden_size=1024, + mlp_hidden_size=4096, + num_attention_heads=16, + ) - init_key, dropout_key = jax.random.split(jax.random.PRNGKey(0)) - x = jax.random.normal(init_key, (32, 128, 1024), dtype=jnp.bfloat16) - params = layer.init({"params": init_key, "dropout": dropout_key}, x) + init_key, dropout_key = jax.random.split(jax.random.PRNGKey(0)) + x = jax.random.normal(init_key, (32, 128, 1024), dtype=jnp.bfloat16) + params = layer.init({"params": init_key, "dropout": dropout_key}, x) - # Forward and backward pass (both inside autocast for JAX) - def loss_fn(params): - output = layer.apply(params, x, rngs={"dropout": dropout_key}) - return output.sum() + # Forward and backward pass (both inside autocast for JAX) + def loss_fn(params): + output = layer.apply(params, x, rngs={"dropout": dropout_key}) + return output.sum() - loss, grads = jax.value_and_grad(loss_fn)(params) + loss, grads = jax.value_and_grad(loss_fn)(params) # END_AUTOCAST_BASIC @@ -59,16 +57,15 @@ def loss_fn(params): encoder_recipe = DelayedScaling(fp8_format=Format.E4M3) decoder_recipe = DelayedScaling(fp8_format=Format.HYBRID) -with global_shard_guard(MeshResource()): - with te.autocast(enabled=True, recipe=encoder_recipe): - encoder = TransformerLayer(hidden_size=1024, mlp_hidden_size=4096, num_attention_heads=16) - encoder_params = encoder.init({"params": init_key, "dropout": dropout_key}, x) - hidden = encoder.apply(encoder_params, x, rngs={"dropout": dropout_key}) +with te.autocast(enabled=True, recipe=encoder_recipe): + encoder = TransformerLayer(hidden_size=1024, mlp_hidden_size=4096, num_attention_heads=16) + encoder_params = encoder.init({"params": init_key, "dropout": dropout_key}, x) + hidden = encoder.apply(encoder_params, x, rngs={"dropout": dropout_key}) - with te.autocast(enabled=True, recipe=decoder_recipe): - decoder = TransformerLayer(hidden_size=1024, mlp_hidden_size=4096, num_attention_heads=16) - decoder_params = decoder.init({"params": init_key, "dropout": dropout_key}, hidden) - output = decoder.apply(decoder_params, hidden, rngs={"dropout": dropout_key}) +with te.autocast(enabled=True, recipe=decoder_recipe): + decoder = TransformerLayer(hidden_size=1024, mlp_hidden_size=4096, num_attention_heads=16) + decoder_params = decoder.init({"params": init_key, "dropout": dropout_key}, hidden) + output = decoder.apply(decoder_params, hidden, rngs={"dropout": dropout_key}) # END_AUTOCAST_SEQUENTIAL @@ -78,24 +75,23 @@ def loss_fn(params): outer_recipe = DelayedScaling(fp8_format=Format.E4M3) inner_recipe = DelayedScaling(fp8_format=Format.HYBRID) -with global_shard_guard(MeshResource()): - with te.autocast(enabled=True, recipe=outer_recipe): - # layer1 uses outer_recipe - layer1 = TransformerLayer(hidden_size=1024, mlp_hidden_size=4096, num_attention_heads=16) - params1 = layer1.init({"params": init_key, "dropout": dropout_key}, x) - hidden = layer1.apply(params1, x, rngs={"dropout": dropout_key}) - - with te.autocast(enabled=True, recipe=inner_recipe): - # layer2 uses inner_recipe (overrides outer) - layer2 = TransformerLayer( - hidden_size=1024, mlp_hidden_size=4096, num_attention_heads=16 - ) - params2 = layer2.init({"params": init_key, "dropout": dropout_key}, hidden) - hidden = layer2.apply(params2, hidden, rngs={"dropout": dropout_key}) - - # layer3 uses outer_recipe again - layer3 = TransformerLayer(hidden_size=1024, mlp_hidden_size=4096, num_attention_heads=16) - params3 = layer3.init({"params": init_key, "dropout": dropout_key}, hidden) - output = layer3.apply(params3, hidden, rngs={"dropout": dropout_key}) +with te.autocast(enabled=True, recipe=outer_recipe): + # layer1 uses outer_recipe + layer1 = TransformerLayer(hidden_size=1024, mlp_hidden_size=4096, num_attention_heads=16) + params1 = layer1.init({"params": init_key, "dropout": dropout_key}, x) + hidden = layer1.apply(params1, x, rngs={"dropout": dropout_key}) + + with te.autocast(enabled=True, recipe=inner_recipe): + # layer2 uses inner_recipe (overrides outer) + layer2 = TransformerLayer( + hidden_size=1024, mlp_hidden_size=4096, num_attention_heads=16 + ) + params2 = layer2.init({"params": init_key, "dropout": dropout_key}, hidden) + hidden = layer2.apply(params2, hidden, rngs={"dropout": dropout_key}) + + # layer3 uses outer_recipe again + layer3 = TransformerLayer(hidden_size=1024, mlp_hidden_size=4096, num_attention_heads=16) + params3 = layer3.init({"params": init_key, "dropout": dropout_key}, hidden) + output = layer3.apply(params3, hidden, rngs={"dropout": dropout_key}) # END_AUTOCAST_NESTED diff --git a/docs/features/low_precision_training/introduction/bf16_fp16_training_jax.py b/docs/features/low_precision_training/introduction/bf16_fp16_training_jax.py index f2e02a5103..6b192dd7a6 100644 --- a/docs/features/low_precision_training/introduction/bf16_fp16_training_jax.py +++ b/docs/features/low_precision_training/introduction/bf16_fp16_training_jax.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -6,8 +6,8 @@ import jax import jax.numpy as jnp +import optax from transformer_engine.jax.flax import TransformerLayer -from transformer_engine.jax.sharding import MeshResource, global_shard_guard def run_forward_backward(params_dtype, compute_dtype): @@ -19,21 +19,26 @@ def run_forward_backward(params_dtype, compute_dtype): dtype=params_dtype, ) - # Initialize parameters + # Initialize parameters and optimizer init_key, dropout_key = jax.random.split(jax.random.PRNGKey(0)) x = jax.random.normal(init_key, (32, 128, 1024), dtype=compute_dtype) + params = layer.init({"params": init_key, "dropout": dropout_key}, x) - # TransformerLayer requires mesh resource context - with global_shard_guard(MeshResource()): - params = layer.init({"params": init_key, "dropout": dropout_key}, x) + # Create optimizer + optimizer = optax.sgd(learning_rate=0.01) + opt_state = optimizer.init(params) - # Forward and backward pass - def loss_fn(params): - output = layer.apply(params, x, rngs={"dropout": dropout_key}) - assert output.dtype == compute_dtype - return output.sum() + # Forward and backward pass + def loss_fn(params): + output = layer.apply(params, x, rngs={"dropout": dropout_key}) + assert output.dtype == compute_dtype + return output.sum() - loss, grads = jax.value_and_grad(loss_fn)(params) + loss, grads = jax.value_and_grad(loss_fn)(params) + + # Update parameters + updates, opt_state = optimizer.update(grads, opt_state, params) + params = optax.apply_updates(params, updates) run_forward_backward(jnp.float32, jnp.float32) # high precision training From cd33d565d2dcfa3e11fac5118f887c919e5af9bb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 13 Jan 2026 11:20:47 +0000 Subject: [PATCH 14/25] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../low_precision_training/introduction/autocast_jax.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/docs/features/low_precision_training/introduction/autocast_jax.py b/docs/features/low_precision_training/introduction/autocast_jax.py index dc7f2a1d28..b5755222fa 100644 --- a/docs/features/low_precision_training/introduction/autocast_jax.py +++ b/docs/features/low_precision_training/introduction/autocast_jax.py @@ -83,9 +83,7 @@ def loss_fn(params): with te.autocast(enabled=True, recipe=inner_recipe): # layer2 uses inner_recipe (overrides outer) - layer2 = TransformerLayer( - hidden_size=1024, mlp_hidden_size=4096, num_attention_heads=16 - ) + layer2 = TransformerLayer(hidden_size=1024, mlp_hidden_size=4096, num_attention_heads=16) params2 = layer2.init({"params": init_key, "dropout": dropout_key}, hidden) hidden = layer2.apply(params2, hidden, rngs={"dropout": dropout_key}) From f47320271eba88cf5ec850c2c97de5cfd07de563 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Tue, 13 Jan 2026 12:22:49 +0100 Subject: [PATCH 15/25] fix Signed-off-by: Pawel Gadzinski --- .../introduction/bf16_fp16_training_jax.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/docs/features/low_precision_training/introduction/bf16_fp16_training_jax.py b/docs/features/low_precision_training/introduction/bf16_fp16_training_jax.py index 6b192dd7a6..636cc0df47 100644 --- a/docs/features/low_precision_training/introduction/bf16_fp16_training_jax.py +++ b/docs/features/low_precision_training/introduction/bf16_fp16_training_jax.py @@ -6,7 +6,6 @@ import jax import jax.numpy as jnp -import optax from transformer_engine.jax.flax import TransformerLayer @@ -24,10 +23,6 @@ def run_forward_backward(params_dtype, compute_dtype): x = jax.random.normal(init_key, (32, 128, 1024), dtype=compute_dtype) params = layer.init({"params": init_key, "dropout": dropout_key}, x) - # Create optimizer - optimizer = optax.sgd(learning_rate=0.01) - opt_state = optimizer.init(params) - # Forward and backward pass def loss_fn(params): output = layer.apply(params, x, rngs={"dropout": dropout_key}) @@ -36,10 +31,6 @@ def loss_fn(params): loss, grads = jax.value_and_grad(loss_fn)(params) - # Update parameters - updates, opt_state = optimizer.update(grads, opt_state, params) - params = optax.apply_updates(params, updates) - run_forward_backward(jnp.float32, jnp.float32) # high precision training run_forward_backward(jnp.float32, jnp.bfloat16) # bfloat16 training with master weights in FP32 From 83a42cde6013f49f7f156233d21df9fe87759bb9 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Tue, 13 Jan 2026 12:23:20 +0100 Subject: [PATCH 16/25] fix Signed-off-by: Pawel Gadzinski --- .../img/fp8_tensor_core.svg | 75 ------------------- .../fp8_delayed_scaling.rst | 2 +- 2 files changed, 1 insertion(+), 76 deletions(-) delete mode 100644 docs/features/low_precision_training/fp8_current_scaling/img/fp8_tensor_core.svg diff --git a/docs/features/low_precision_training/fp8_current_scaling/img/fp8_tensor_core.svg b/docs/features/low_precision_training/fp8_current_scaling/img/fp8_tensor_core.svg deleted file mode 100644 index 5416b5f4c3..0000000000 --- a/docs/features/low_precision_training/fp8_current_scaling/img/fp8_tensor_core.svg +++ /dev/null @@ -1,75 +0,0 @@ - - - - - - - - - - - FP8 Tensor Core Operation - - - - Input A - - - - chunk of FP8 Tensor A - (E4M3 or E5M2) - - - - Scale a - scalar float 32 - - - - - - - Input B - - - - chunk of FP8 Tensor B - (E4M3 or E5M2) - - - - Scale b - scalar float 32 - - - - - - - FP8 Tensor Core - - - - - - - Accumulated chunk of output - Higher precision - - - diff --git a/docs/features/low_precision_training/fp8_delayed_scaling/fp8_delayed_scaling.rst b/docs/features/low_precision_training/fp8_delayed_scaling/fp8_delayed_scaling.rst index 0343ddc03f..9d05305eda 100644 --- a/docs/features/low_precision_training/fp8_delayed_scaling/fp8_delayed_scaling.rst +++ b/docs/features/low_precision_training/fp8_delayed_scaling/fp8_delayed_scaling.rst @@ -115,7 +115,7 @@ Here's how to use FP8 Delayed Scaling in PyTorch and JAX: Distributed Training -------------------- -FP8 Delayed Scaling uses the same data formats as FP8 Current Scaling - qunatized all-gather is supported. +FP8 Delayed Scaling uses the same data formats as FP8 Current Scaling - quantized all-gather is supported. However, amax reduction works slightly differently in different frameworks. .. tabs:: From 1321ecf1e9d37567217d15bc99ee6bf67c16c371 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Tue, 13 Jan 2026 12:24:11 +0100 Subject: [PATCH 17/25] year change Signed-off-by: Pawel Gadzinski --- .../introduction/bf16_fp16_training_jax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/features/low_precision_training/introduction/bf16_fp16_training_jax.py b/docs/features/low_precision_training/introduction/bf16_fp16_training_jax.py index 636cc0df47..8868b63ac6 100644 --- a/docs/features/low_precision_training/introduction/bf16_fp16_training_jax.py +++ b/docs/features/low_precision_training/introduction/bf16_fp16_training_jax.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. From ad43ae52e5c107721600a9e28a7224a0fcfcaa3c Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Tue, 13 Jan 2026 12:27:04 +0100 Subject: [PATCH 18/25] fix Signed-off-by: Pawel Gadzinski --- .../low_precision_training/introduction/introduction.rst | 3 --- 1 file changed, 3 deletions(-) diff --git a/docs/features/low_precision_training/introduction/introduction.rst b/docs/features/low_precision_training/introduction/introduction.rst index 0067b9703e..31df8289f6 100644 --- a/docs/features/low_precision_training/introduction/introduction.rst +++ b/docs/features/low_precision_training/introduction/introduction.rst @@ -187,9 +187,6 @@ Let's now see how we can train in lower precisions in supported frameworks. The key difference is that in JAX, model initialization must happen inside the ``autocast`` context to properly capture quantization metadata in the parameter tree. - Additionally, JAX requires a ``global_shard_guard(MeshResource())`` context (even for single GPU) - and the ``mesh_resource`` argument in the ``autocast`` call. - Here is a basic example: .. raw:: html From 63e4b5cbbe8d5c658a1fa61d9baeb7a29beb13a0 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Wed, 14 Jan 2026 12:32:07 +0100 Subject: [PATCH 19/25] jax compute capability fix Signed-off-by: Pawel Gadzinski --- .../jax_blockwise_scaling_example.py | 7 ++- .../jax_delayed_scaling_example.py | 16 +------ .../introduction/autocast_jax.py | 44 +++++++------------ 3 files changed, 21 insertions(+), 46 deletions(-) diff --git a/docs/features/low_precision_training/fp8_blockwise_scaling/jax_blockwise_scaling_example.py b/docs/features/low_precision_training/fp8_blockwise_scaling/jax_blockwise_scaling_example.py index 62dd584305..286a253f16 100644 --- a/docs/features/low_precision_training/fp8_blockwise_scaling/jax_blockwise_scaling_example.py +++ b/docs/features/low_precision_training/fp8_blockwise_scaling/jax_blockwise_scaling_example.py @@ -3,12 +3,11 @@ # See LICENSE for license information. # Check for Hopper or newer GPU -from transformer_engine_jax import get_device_compute_capability +from transformer_engine.jax.quantize import get_device_compute_capability -major_minor = get_device_compute_capability(0) assert ( - major_minor >= 90 -), f"FP8 Blockwise Scaling requires SM90 (Hopper) or later, got SM{major_minor}" + get_device_compute_capability() >= 90 +), f"FP8 Blockwise Scaling requires SM90 (Hopper) or later, got SM{get_device_compute_capability()}" # START_BLOCKWISE_SCALING_EXAMPLE diff --git a/docs/features/low_precision_training/fp8_delayed_scaling/jax_delayed_scaling_example.py b/docs/features/low_precision_training/fp8_delayed_scaling/jax_delayed_scaling_example.py index aea2344bae..b671f4b08b 100644 --- a/docs/features/low_precision_training/fp8_delayed_scaling/jax_delayed_scaling_example.py +++ b/docs/features/low_precision_training/fp8_delayed_scaling/jax_delayed_scaling_example.py @@ -2,22 +2,10 @@ # # See LICENSE for license information. -import jax +from transformer_engine.jax.quantize import get_device_compute_capability # Requires Ada (SM89) or newer for FP8 support -cc = jax.devices()[0].device_kind -assert ( - "RTX 40" in cc - or "RTX 5" in cc - or "Ada" in cc - or "L40" in cc - or "H100" in cc - or "H200" in cc - or "GH" in cc - or "B100" in cc - or "B200" in cc - or "GB" in cc -), "This example requires SM89 (Ada) or newer" +assert get_device_compute_capability() >= 89, "This example requires SM89 (Ada) or newer" # START_DELAYED_SCALING_EXAMPLE diff --git a/docs/features/low_precision_training/introduction/autocast_jax.py b/docs/features/low_precision_training/introduction/autocast_jax.py index b5755222fa..0abb670064 100644 --- a/docs/features/low_precision_training/introduction/autocast_jax.py +++ b/docs/features/low_precision_training/introduction/autocast_jax.py @@ -2,22 +2,10 @@ # # See LICENSE for license information. -import jax +from transformer_engine.jax.quantize import get_device_compute_capability # Requires Ada (SM89) or newer for FP8 support -cc = jax.devices()[0].device_kind -assert ( - "RTX 40" in cc - or "RTX 5" in cc - or "Ada" in cc - or "L40" in cc - or "H100" in cc - or "H200" in cc - or "GH" in cc - or "B100" in cc - or "B200" in cc - or "GB" in cc -), "This example requires SM89 (Ada) or newer" +assert get_device_compute_capability() >= 89, "This example requires SM89 (Ada) or newer" # START_AUTOCAST_BASIC @@ -40,14 +28,14 @@ init_key, dropout_key = jax.random.split(jax.random.PRNGKey(0)) x = jax.random.normal(init_key, (32, 128, 1024), dtype=jnp.bfloat16) - params = layer.init({"params": init_key, "dropout": dropout_key}, x) + var_collect = layer.init({"params": init_key, "dropout": dropout_key}, x) # Forward and backward pass (both inside autocast for JAX) - def loss_fn(params): - output = layer.apply(params, x, rngs={"dropout": dropout_key}) + def loss_fn(var_collect): + output = layer.apply(var_collect, x, rngs={"dropout": dropout_key}) return output.sum() - loss, grads = jax.value_and_grad(loss_fn)(params) + loss, grads = jax.value_and_grad(loss_fn)(var_collect) # END_AUTOCAST_BASIC @@ -59,13 +47,13 @@ def loss_fn(params): with te.autocast(enabled=True, recipe=encoder_recipe): encoder = TransformerLayer(hidden_size=1024, mlp_hidden_size=4096, num_attention_heads=16) - encoder_params = encoder.init({"params": init_key, "dropout": dropout_key}, x) - hidden = encoder.apply(encoder_params, x, rngs={"dropout": dropout_key}) + encoder_var_collect = encoder.init({"params": init_key, "dropout": dropout_key}, x) + hidden = encoder.apply(encoder_var_collect, x, rngs={"dropout": dropout_key}) with te.autocast(enabled=True, recipe=decoder_recipe): decoder = TransformerLayer(hidden_size=1024, mlp_hidden_size=4096, num_attention_heads=16) - decoder_params = decoder.init({"params": init_key, "dropout": dropout_key}, hidden) - output = decoder.apply(decoder_params, hidden, rngs={"dropout": dropout_key}) + decoder_var_collect = decoder.init({"params": init_key, "dropout": dropout_key}, hidden) + output = decoder.apply(decoder_var_collect, hidden, rngs={"dropout": dropout_key}) # END_AUTOCAST_SEQUENTIAL @@ -78,18 +66,18 @@ def loss_fn(params): with te.autocast(enabled=True, recipe=outer_recipe): # layer1 uses outer_recipe layer1 = TransformerLayer(hidden_size=1024, mlp_hidden_size=4096, num_attention_heads=16) - params1 = layer1.init({"params": init_key, "dropout": dropout_key}, x) - hidden = layer1.apply(params1, x, rngs={"dropout": dropout_key}) + var_collect1 = layer1.init({"params": init_key, "dropout": dropout_key}, x) + hidden = layer1.apply(var_collect1, x, rngs={"dropout": dropout_key}) with te.autocast(enabled=True, recipe=inner_recipe): # layer2 uses inner_recipe (overrides outer) layer2 = TransformerLayer(hidden_size=1024, mlp_hidden_size=4096, num_attention_heads=16) - params2 = layer2.init({"params": init_key, "dropout": dropout_key}, hidden) - hidden = layer2.apply(params2, hidden, rngs={"dropout": dropout_key}) + var_collect2 = layer2.init({"params": init_key, "dropout": dropout_key}, hidden) + hidden = layer2.apply(var_collect2, hidden, rngs={"dropout": dropout_key}) # layer3 uses outer_recipe again layer3 = TransformerLayer(hidden_size=1024, mlp_hidden_size=4096, num_attention_heads=16) - params3 = layer3.init({"params": init_key, "dropout": dropout_key}, hidden) - output = layer3.apply(params3, hidden, rngs={"dropout": dropout_key}) + var_collect3 = layer3.init({"params": init_key, "dropout": dropout_key}, hidden) + output = layer3.apply(var_collect3, hidden, rngs={"dropout": dropout_key}) # END_AUTOCAST_NESTED From 8620c8933f561a54540f4098af0c01530a65cb69 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Wed, 14 Jan 2026 13:11:02 +0100 Subject: [PATCH 20/25] fix Signed-off-by: Pawel Gadzinski --- docs/_static/css/diagram-colors.css | 2 + .../fp8_blockwise_scaling.rst | 9 ++-- .../jax_blockwise_scaling_example.py | 42 ------------------- .../jax_current_scaling_example.py | 8 ++-- .../jax_delayed_scaling_example.py | 8 ++-- .../introduction/bf16_fp16_training_jax.py | 8 ++-- .../mxfp8/jax_mxfp8_example.py | 15 +++---- .../nvfp4/jax_nvfp4_example.py | 15 +++---- .../memory_usage_1_jax.py | 4 +- .../memory_usage_2_jax.py | 4 +- 10 files changed, 39 insertions(+), 76 deletions(-) delete mode 100644 docs/features/low_precision_training/fp8_blockwise_scaling/jax_blockwise_scaling_example.py diff --git a/docs/_static/css/diagram-colors.css b/docs/_static/css/diagram-colors.css index 6ae3f99afa..96a2a8a6dc 100644 --- a/docs/_static/css/diagram-colors.css +++ b/docs/_static/css/diagram-colors.css @@ -73,10 +73,12 @@ } /* Arrows */ +/* Note: marker-end references #arrowhead marker which must be defined in each SVG's section */ .arrow { stroke: #616161; stroke-width: 2; fill: none; + marker-end: url(#arrowhead); } /* Additional box and element styles */ diff --git a/docs/features/low_precision_training/fp8_blockwise_scaling/fp8_blockwise_scaling.rst b/docs/features/low_precision_training/fp8_blockwise_scaling/fp8_blockwise_scaling.rst index 79815baed3..b0e7abe57d 100644 --- a/docs/features/low_precision_training/fp8_blockwise_scaling/fp8_blockwise_scaling.rst +++ b/docs/features/low_precision_training/fp8_blockwise_scaling/fp8_blockwise_scaling.rst @@ -6,6 +6,10 @@ FP8 Blockwise Scaling =================================== +.. warning:: + + ``Float8BlockScaling`` is **not currently supported** in JAX. + FP8 Blockwise Scaling recipe is inspired by the quantization scheme used to train the `DeepSeek-v3 model `__ – the first open-source large-scale LLM trained entirely in FP8 precision. Unlike the previous recipes, it assigns a dedicated scaling factor to each block of elements. @@ -163,10 +167,7 @@ Here's how to use the FP8 Blockwise Scaling recipe in PyTorch and JAX: .. tab:: JAX - .. literalinclude:: jax_blockwise_scaling_example.py - :language: python - :start-after: # START_BLOCKWISE_SCALING_EXAMPLE - :end-before: # END_BLOCKWISE_SCALING_EXAMPLE + ``Float8BlockScaling`` is **not currently supported** in JAX. Supported devices ----------------- diff --git a/docs/features/low_precision_training/fp8_blockwise_scaling/jax_blockwise_scaling_example.py b/docs/features/low_precision_training/fp8_blockwise_scaling/jax_blockwise_scaling_example.py deleted file mode 100644 index 286a253f16..0000000000 --- a/docs/features/low_precision_training/fp8_blockwise_scaling/jax_blockwise_scaling_example.py +++ /dev/null @@ -1,42 +0,0 @@ -# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -# Check for Hopper or newer GPU -from transformer_engine.jax.quantize import get_device_compute_capability - -assert ( - get_device_compute_capability() >= 90 -), f"FP8 Blockwise Scaling requires SM90 (Hopper) or later, got SM{get_device_compute_capability()}" - -# START_BLOCKWISE_SCALING_EXAMPLE - -import jax -import jax.numpy as jnp -import transformer_engine.jax as te -from transformer_engine.jax.flax import DenseGeneral -from transformer_engine.common.recipe import Float8BlockScaling - -# Create FP8 Blockwise Scaling recipe -recipe = Float8BlockScaling( - fp8_format=te.common.recipe.Format.E4M3, # E4M3 or HYBRID (default: E4M3) - x_block_scaling_dim=1, # 1D scaling for activations (default: 1) - w_block_scaling_dim=2, # 2D scaling for weights (default: 2) - grad_block_scaling_dim=1, # 1D scaling for gradients (default: 1) -) - -with te.autocast(enabled=True, recipe=recipe): - # Initialize layer and data - layer = DenseGeneral(features=1024) - key = jax.random.PRNGKey(0) - x = jax.random.normal(key, (32, 128, 1024), dtype=jnp.bfloat16) - params = layer.init(key, x) - - # Forward and backward pass - def loss_fn(params): - output = layer.apply(params, x) - return output.sum() - - loss, grads = jax.value_and_grad(loss_fn)(params) - -# END_BLOCKWISE_SCALING_EXAMPLE diff --git a/docs/features/low_precision_training/fp8_current_scaling/jax_current_scaling_example.py b/docs/features/low_precision_training/fp8_current_scaling/jax_current_scaling_example.py index 236fb255c2..107b13c53b 100644 --- a/docs/features/low_precision_training/fp8_current_scaling/jax_current_scaling_example.py +++ b/docs/features/low_precision_training/fp8_current_scaling/jax_current_scaling_example.py @@ -21,13 +21,13 @@ layer = DenseGeneral(features=1024) key = jax.random.PRNGKey(0) x = jax.random.normal(key, (32, 128, 1024), dtype=jnp.bfloat16) - params = layer.init(key, x) + var_collect = layer.init(key, x) # Forward and backward pass - def loss_fn(params): - output = layer.apply(params, x) + def loss_fn(var_collect): + output = layer.apply(var_collect, x) return output.sum() - loss, grads = jax.value_and_grad(loss_fn)(params) + loss, grads = jax.value_and_grad(loss_fn)(var_collect) # END_CURRENT_SCALING_EXAMPLE diff --git a/docs/features/low_precision_training/fp8_delayed_scaling/jax_delayed_scaling_example.py b/docs/features/low_precision_training/fp8_delayed_scaling/jax_delayed_scaling_example.py index b671f4b08b..5971117686 100644 --- a/docs/features/low_precision_training/fp8_delayed_scaling/jax_delayed_scaling_example.py +++ b/docs/features/low_precision_training/fp8_delayed_scaling/jax_delayed_scaling_example.py @@ -27,13 +27,13 @@ layer = DenseGeneral(features=1024) key = jax.random.PRNGKey(0) x = jax.random.normal(key, (32, 128, 1024), dtype=jnp.bfloat16) - params = layer.init(key, x) + var_collect = layer.init(key, x) # Forward and backward pass - def loss_fn(params): - output = layer.apply(params, x) + def loss_fn(var_collect): + output = layer.apply(var_collect, x) return output.sum() - loss, grads = jax.value_and_grad(loss_fn)(params) + loss, grads = jax.value_and_grad(loss_fn)(var_collect) # END_DELAYED_SCALING_EXAMPLE diff --git a/docs/features/low_precision_training/introduction/bf16_fp16_training_jax.py b/docs/features/low_precision_training/introduction/bf16_fp16_training_jax.py index 8868b63ac6..a3c9c2ae45 100644 --- a/docs/features/low_precision_training/introduction/bf16_fp16_training_jax.py +++ b/docs/features/low_precision_training/introduction/bf16_fp16_training_jax.py @@ -21,15 +21,15 @@ def run_forward_backward(params_dtype, compute_dtype): # Initialize parameters and optimizer init_key, dropout_key = jax.random.split(jax.random.PRNGKey(0)) x = jax.random.normal(init_key, (32, 128, 1024), dtype=compute_dtype) - params = layer.init({"params": init_key, "dropout": dropout_key}, x) + var_collect = layer.init({"params": init_key, "dropout": dropout_key}, x) # Forward and backward pass - def loss_fn(params): - output = layer.apply(params, x, rngs={"dropout": dropout_key}) + def loss_fn(var_collect): + output = layer.apply(var_collect, x, rngs={"dropout": dropout_key}) assert output.dtype == compute_dtype return output.sum() - loss, grads = jax.value_and_grad(loss_fn)(params) + loss, grads = jax.value_and_grad(loss_fn)(var_collect) run_forward_backward(jnp.float32, jnp.float32) # high precision training diff --git a/docs/features/low_precision_training/mxfp8/jax_mxfp8_example.py b/docs/features/low_precision_training/mxfp8/jax_mxfp8_example.py index d4eed8aecc..96ef1a2573 100644 --- a/docs/features/low_precision_training/mxfp8/jax_mxfp8_example.py +++ b/docs/features/low_precision_training/mxfp8/jax_mxfp8_example.py @@ -3,10 +3,11 @@ # See LICENSE for license information. # Check for Blackwell or newer GPU -from transformer_engine_jax import get_device_compute_capability +from transformer_engine.jax.quantize import get_device_compute_capability -major_minor = get_device_compute_capability(0) -assert major_minor >= 100, f"MXFP8 requires SM100 (Blackwell) or later, got SM{major_minor}" +assert ( + get_device_compute_capability() >= 100 +), f"MXFP8 requires SM100 (Blackwell) or later, got SM{get_device_compute_capability()}" # START_MXFP8_EXAMPLE @@ -26,13 +27,13 @@ layer = DenseGeneral(features=1024) key = jax.random.PRNGKey(0) x = jax.random.normal(key, (32, 128, 1024), dtype=jnp.bfloat16) - params = layer.init(key, x) + var_collect = layer.init(key, x) # Forward and backward pass - def loss_fn(params): - output = layer.apply(params, x) + def loss_fn(var_collect): + output = layer.apply(var_collect, x) return output.sum() - loss, grads = jax.value_and_grad(loss_fn)(params) + loss, grads = jax.value_and_grad(loss_fn)(var_collect) # END_MXFP8_EXAMPLE diff --git a/docs/features/low_precision_training/nvfp4/jax_nvfp4_example.py b/docs/features/low_precision_training/nvfp4/jax_nvfp4_example.py index 6eee6e75d0..326db3b899 100644 --- a/docs/features/low_precision_training/nvfp4/jax_nvfp4_example.py +++ b/docs/features/low_precision_training/nvfp4/jax_nvfp4_example.py @@ -3,10 +3,11 @@ # See LICENSE for license information. # Check for Blackwell or newer GPU -from transformer_engine_jax import get_device_compute_capability +from transformer_engine.jax.quantize import get_device_compute_capability -major_minor = get_device_compute_capability(0) -assert major_minor >= 100, f"NVFP4 requires SM100 (Blackwell) or later, got SM{major_minor}" +assert ( + get_device_compute_capability() >= 100 +), f"NVFP4 requires SM100 (Blackwell) or later, got SM{get_device_compute_capability()}" # START_NVFP4_EXAMPLE @@ -27,13 +28,13 @@ layer = DenseGeneral(features=1024) key = jax.random.PRNGKey(0) x = jax.random.normal(key, (32, 128, 1024), dtype=jnp.bfloat16) - params = layer.init(key, x) + var_collect = layer.init(key, x) # Forward and backward pass - def loss_fn(params): - output = layer.apply(params, x) + def loss_fn(var_collect): + output = layer.apply(var_collect, x) return output.sum() - loss, grads = jax.value_and_grad(loss_fn)(params) + loss, grads = jax.value_and_grad(loss_fn)(var_collect) # END_NVFP4_EXAMPLE diff --git a/docs/features/low_precision_training/performance_considerations/memory_usage_1_jax.py b/docs/features/low_precision_training/performance_considerations/memory_usage_1_jax.py index 3b1744295e..216d6bd14e 100644 --- a/docs/features/low_precision_training/performance_considerations/memory_usage_1_jax.py +++ b/docs/features/low_precision_training/performance_considerations/memory_usage_1_jax.py @@ -27,10 +27,10 @@ def measure_memory(): # Initialize layer with BF16 parameters layer = DenseGeneral(features=1024, dtype=jnp.bfloat16) x = jax.random.normal(key, (1024, 1024), dtype=jnp.bfloat16) - params = layer.init(key, x) + var_collect = layer.init(key, x) # Forward pass in high precision - output = layer.apply(params, x) + output = layer.apply(var_collect, x) del x # Input is saved by model for backward, not by user script mem_after_forward = get_gpu_memory_mb() - init_memory diff --git a/docs/features/low_precision_training/performance_considerations/memory_usage_2_jax.py b/docs/features/low_precision_training/performance_considerations/memory_usage_2_jax.py index a724b1ebd0..acdcb034a4 100644 --- a/docs/features/low_precision_training/performance_considerations/memory_usage_2_jax.py +++ b/docs/features/low_precision_training/performance_considerations/memory_usage_2_jax.py @@ -33,8 +33,8 @@ def measure_memory(): # Forward pass with FP8 compute with te.autocast(enabled=True, recipe=recipe): - params = layer.init(key, x) - output = layer.apply(params, x) + var_collect = layer.init(key, x) + output = layer.apply(var_collect, x) del x # Input is saved by model for backward, not by user script mem_after_forward = get_gpu_memory_mb() - init_memory From e0e4af414c8b9f1c3004aa38ce9a26c46b7d7f5e Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Wed, 14 Jan 2026 13:14:04 +0100 Subject: [PATCH 21/25] fix Signed-off-by: Pawel Gadzinski --- .../low_precision_training/nvfp4/jax_nvfp4_example.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/docs/features/low_precision_training/nvfp4/jax_nvfp4_example.py b/docs/features/low_precision_training/nvfp4/jax_nvfp4_example.py index 326db3b899..6c94f31345 100644 --- a/docs/features/low_precision_training/nvfp4/jax_nvfp4_example.py +++ b/docs/features/low_precision_training/nvfp4/jax_nvfp4_example.py @@ -26,13 +26,16 @@ with te.autocast(enabled=True, recipe=recipe): # Initialize layer and data layer = DenseGeneral(features=1024) - key = jax.random.PRNGKey(0) + key, sr_key = jax.random.split(jax.random.PRNGKey(0)) x = jax.random.normal(key, (32, 128, 1024), dtype=jnp.bfloat16) - var_collect = layer.init(key, x) + + # NVFP4 requires sr_rng for stochastic rounding + rngs = {"sr_rng": sr_key} + var_collect = layer.init({"params": key, "sr_rng": sr_key}, x) # Forward and backward pass def loss_fn(var_collect): - output = layer.apply(var_collect, x) + output = layer.apply(var_collect, x, rngs=rngs) return output.sum() loss, grads = jax.value_and_grad(loss_fn)(var_collect) From 4a221f3ae0302187da4978da4aec54b60d7f03d1 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Wed, 14 Jan 2026 17:22:06 +0100 Subject: [PATCH 22/25] fixes Signed-off-by: Pawel Gadzinski --- .../introduction/introduction.rst | 6 +++ .../memory_usage_1_jax.out | 8 +++- .../memory_usage_1_jax.py | 43 ++++++++--------- .../memory_usage_1_pytorch.py | 1 - .../memory_usage_2_jax.out | 9 +++- .../memory_usage_2_jax.py | 46 +++++++++---------- .../memory_usage_2_pytorch.py | 1 - .../memory_usage_3_pytorch.py | 1 - .../performance_considerations.rst | 14 ++---- 9 files changed, 66 insertions(+), 63 deletions(-) diff --git a/docs/features/low_precision_training/introduction/introduction.rst b/docs/features/low_precision_training/introduction/introduction.rst index 31df8289f6..760a63b0b1 100644 --- a/docs/features/low_precision_training/introduction/introduction.rst +++ b/docs/features/low_precision_training/introduction/introduction.rst @@ -228,6 +228,12 @@ Let's now see how we can train in lower precisions in supported frameworks. :start-after: # START_AUTOCAST_NESTED :end-before: # END_AUTOCAST_NESTED + .. note:: + Python context managers like ``autocast`` may interact unexpectedly with JAX's JIT compilation. + For finer-grained control, consider passing the recipe directly to TE modules instead. + See the `TE JAX Integration notebook `_ + for details. + **Mixed precision with 8- or 4-bit precisions** From now on, we will refer to FP8/MXFP8/NVFP4 etc. as *low precision* diff --git a/docs/features/low_precision_training/performance_considerations/memory_usage_1_jax.out b/docs/features/low_precision_training/performance_considerations/memory_usage_1_jax.out index a57b4931b4..717769b1ed 100644 --- a/docs/features/low_precision_training/performance_considerations/memory_usage_1_jax.out +++ b/docs/features/low_precision_training/performance_considerations/memory_usage_1_jax.out @@ -1,3 +1,9 @@ # START_MEMORY_USAGE_1 -Memory usage after forward pass: 6.00 MB +Tensors in memory: + Shape: (1024, 1024), Dtype: bfloat16, Size: 2048.0 KB + Shape: (1024, 1024), Dtype: bfloat16, Size: 2048.0 KB + Total from all live arrays: 4.00 MB # END_MEMORY_USAGE_1 +Processing events... +Generated: + No reports were generated diff --git a/docs/features/low_precision_training/performance_considerations/memory_usage_1_jax.py b/docs/features/low_precision_training/performance_considerations/memory_usage_1_jax.py index 216d6bd14e..330f7313f7 100644 --- a/docs/features/low_precision_training/performance_considerations/memory_usage_1_jax.py +++ b/docs/features/low_precision_training/performance_considerations/memory_usage_1_jax.py @@ -11,37 +11,34 @@ from transformer_engine.jax.flax import DenseGeneral -def get_gpu_memory_mb(): - """Get current GPU memory usage in MB.""" - jax.effects_barrier() - stats = jax.local_devices()[0].memory_stats() - return stats["bytes_in_use"] / (1024**2) if stats else 0.0 +key = jax.random.PRNGKey(0) +jax.clear_caches() -def measure_memory(): - key = jax.random.PRNGKey(0) - jax.clear_caches() +# Initialize layer with BF16 parameters +layer = DenseGeneral(features=1024, dtype=jnp.bfloat16) +x = jax.random.normal(key, (1024, 1024), dtype=jnp.bfloat16) +var_collect = layer.init(key, x) - init_memory = get_gpu_memory_mb() +@jax.jit +def loss_fn(var_collect, x): + output = layer.apply(var_collect, x) + return output.sum() - # Initialize layer with BF16 parameters - layer = DenseGeneral(features=1024, dtype=jnp.bfloat16) - x = jax.random.normal(key, (1024, 1024), dtype=jnp.bfloat16) - var_collect = layer.init(key, x) - # Forward pass in high precision - output = layer.apply(var_collect, x) - del x # Input is saved by model for backward, not by user script +# Trace the backward pass - this allocates saved tensors +_, backward_fn = jax.vjp(loss_fn, var_collect, x) - mem_after_forward = get_gpu_memory_mb() - init_memory - return mem_after_forward +del x -# Warmup run -measure_memory() +print("Tensors in memory:") +total_bytes = 0 +for arr in jax.live_arrays(): + total_bytes += arr.nbytes + if arr.nbytes > 200000: # do not count small tensors + print(f" Shape: {arr.shape}, Dtype: {arr.dtype}, Size: {arr.nbytes / 1024:.1f} KB") +print(f" Total from all live arrays: {total_bytes / (1024**2):.2f} MB") -# Actual measurement -mem_after_forward = measure_memory() -print(f"Memory usage after forward pass: {mem_after_forward:.2f} MB") print("# END_MEMORY_USAGE_1") diff --git a/docs/features/low_precision_training/performance_considerations/memory_usage_1_pytorch.py b/docs/features/low_precision_training/performance_considerations/memory_usage_1_pytorch.py index 38d3cfe2fd..dd4ce24471 100644 --- a/docs/features/low_precision_training/performance_considerations/memory_usage_1_pytorch.py +++ b/docs/features/low_precision_training/performance_considerations/memory_usage_1_pytorch.py @@ -9,7 +9,6 @@ assert cc[0] == 8 and cc[1] >= 9 or cc[0] == 9, "This example requires SM89 (Ada) or SM90 (Hopper)" print("# START_MEMORY_USAGE_1") -# START_MEMORY_USAGE_1 import torch import transformer_engine.pytorch as te diff --git a/docs/features/low_precision_training/performance_considerations/memory_usage_2_jax.out b/docs/features/low_precision_training/performance_considerations/memory_usage_2_jax.out index 85ee423022..ab720b57a8 100644 --- a/docs/features/low_precision_training/performance_considerations/memory_usage_2_jax.out +++ b/docs/features/low_precision_training/performance_considerations/memory_usage_2_jax.out @@ -1,3 +1,10 @@ # START_MEMORY_USAGE_2 -Memory usage after forward pass: 6.01 MB +Tensors in memory: + Shape: (1024, 1024), Dtype: float8_e4m3fn, Size: 1024.0 KB + Shape: (1024, 1024), Dtype: float8_e4m3fn, Size: 1024.0 KB + Shape: (1024, 1024), Dtype: bfloat16, Size: 2048.0 KB + Total from all live arrays: 4.02 MB # END_MEMORY_USAGE_2 +Processing events... +Generated: + No reports were generated diff --git a/docs/features/low_precision_training/performance_considerations/memory_usage_2_jax.py b/docs/features/low_precision_training/performance_considerations/memory_usage_2_jax.py index acdcb034a4..15023d3775 100644 --- a/docs/features/low_precision_training/performance_considerations/memory_usage_2_jax.py +++ b/docs/features/low_precision_training/performance_considerations/memory_usage_2_jax.py @@ -13,39 +13,37 @@ from transformer_engine.common.recipe import DelayedScaling -def get_gpu_memory_mb(): - """Get current GPU memory usage in MB.""" - jax.effects_barrier() - stats = jax.local_devices()[0].memory_stats() - return stats["bytes_in_use"] / (1024**2) if stats else 0.0 +key = jax.random.PRNGKey(0) +recipe = DelayedScaling() +jax.clear_caches() -def measure_memory(): - key = jax.random.PRNGKey(0) - recipe = DelayedScaling() - jax.clear_caches() +# Initialize layer with BF16 parameters (outside autocast) +layer = DenseGeneral(features=1024, dtype=jnp.bfloat16) +x = jax.random.normal(key, (1024, 1024), dtype=jnp.bfloat16) - init_memory = get_gpu_memory_mb() - # Initialize layer with BF16 parameters (outside autocast) - layer = DenseGeneral(features=1024, dtype=jnp.bfloat16) - x = jax.random.normal(key, (1024, 1024), dtype=jnp.bfloat16) +# Forward and backward pass with FP8 compute +with te.autocast(enabled=True, recipe=recipe): + var_collect = layer.init(key, x) - # Forward pass with FP8 compute - with te.autocast(enabled=True, recipe=recipe): - var_collect = layer.init(key, x) + @jax.jit + def loss_fn(var_collect, x): output = layer.apply(var_collect, x) - del x # Input is saved by model for backward, not by user script + return output.sum() - mem_after_forward = get_gpu_memory_mb() - init_memory - return mem_after_forward + # Trace the backward pass - this allocates saved tensors + _, backward_fn = jax.vjp(loss_fn, var_collect, x) +del x -# Warmup run -measure_memory() +print("Tensors in memory:") +total_bytes = 0 +for arr in jax.live_arrays(): + total_bytes += arr.nbytes + if arr.nbytes > 200000: # do not count small tensors + print(f" Shape: {arr.shape}, Dtype: {arr.dtype}, Size: {arr.nbytes / 1024:.1f} KB") +print(f" Total from all live arrays: {total_bytes / (1024**2):.2f} MB")\ -# Actual measurement -mem_after_forward = measure_memory() -print(f"Memory usage after forward pass: {mem_after_forward:.2f} MB") print("# END_MEMORY_USAGE_2") diff --git a/docs/features/low_precision_training/performance_considerations/memory_usage_2_pytorch.py b/docs/features/low_precision_training/performance_considerations/memory_usage_2_pytorch.py index 7928ace2af..5c247177d8 100644 --- a/docs/features/low_precision_training/performance_considerations/memory_usage_2_pytorch.py +++ b/docs/features/low_precision_training/performance_considerations/memory_usage_2_pytorch.py @@ -9,7 +9,6 @@ assert cc[0] == 8 and cc[1] >= 9 or cc[0] == 9, "This example requires SM89 (Ada) or SM90 (Hopper)" print("# START_MEMORY_USAGE_2") -# START_MEMORY_USAGE_2 import torch import transformer_engine.pytorch as te diff --git a/docs/features/low_precision_training/performance_considerations/memory_usage_3_pytorch.py b/docs/features/low_precision_training/performance_considerations/memory_usage_3_pytorch.py index 3de5f10aad..ce6905ce4b 100644 --- a/docs/features/low_precision_training/performance_considerations/memory_usage_3_pytorch.py +++ b/docs/features/low_precision_training/performance_considerations/memory_usage_3_pytorch.py @@ -9,7 +9,6 @@ assert cc[0] == 8 and cc[1] >= 9 or cc[0] == 9, "This example requires SM89 (Ada) or SM90 (Hopper)" print("# START_MEMORY_USAGE_3") -# START_MEMORY_USAGE_3 import torch import transformer_engine.pytorch as te diff --git a/docs/features/low_precision_training/performance_considerations/performance_considerations.rst b/docs/features/low_precision_training/performance_considerations/performance_considerations.rst index fb2b011d30..a495af56c1 100644 --- a/docs/features/low_precision_training/performance_considerations/performance_considerations.rst +++ b/docs/features/low_precision_training/performance_considerations/performance_considerations.rst @@ -96,7 +96,7 @@ usages during training. This has implications for memory layout and transpose op The physical memory layout requirements for rowwise and columnwise usages differ between architectures and recipes. For FP8 tensors: -- *Hopper*: cannot efficiently access elements in columnwise fashion, so columnwise tensors need to be physically transposed in memory. +- *Hopper*: cannot efficiently access elements in columnwise fashion, so columnwise tensors need to be physically transposed in memory. Note that higher precision formats (BF16/FP16) do not have this limitation. - *Blackwell*: supports columnwise access natively, so no transpose is needed. We will see that for most of the recipes and devices, rowwise usage and columnwise usage need different tensors. @@ -343,7 +343,7 @@ and columnwise tensors require separate memory layouts. :end-before: # END_MEMORY_USAGE_1 Layer size is ``1024 * 1024 * 2 (2 bytes per parameter) = 2MB``. - Memory after forward pass is ``2 MB (weight) + 2 MB (input saved for backward) + 2 MB (output) = 6 MB``. + Memory after forward pass is ``2 MB (weight) + 2 MB (input saved for backward) = 4 MB``. **2. FP8 training with master weights in BF16** @@ -373,15 +373,7 @@ and columnwise tensors require separate memory layouts. :start-after: # START_MEMORY_USAGE_2 :end-before: # END_MEMORY_USAGE_2 - In JAX, unlike PyTorch, FP8 weights are not cached between forward passes. - Weights are stored in BF16 and quantized to FP8 on-the-fly during each forward pass. - This means the memory usage is similar to the baseline. - - .. note:: - - JAX does not currently support storing model weights directly in FP8 format - like PyTorch's ``quantized_model_init``. Weights are always stored in high precision - (BF16/FP32) and quantized to FP8 during computation. + Memory after forward pass is ``2 MB (weight in BF16) + 1 MB (input in FP8) + 1 MB (weight in FP8) = 4 MB``. Fused layers ------------ From 95d15858d6e6557f7f3a8daca0c8a8efce087567 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 14 Jan 2026 16:22:57 +0000 Subject: [PATCH 23/25] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../performance_considerations/memory_usage_1_jax.py | 1 + .../performance_considerations/memory_usage_2_jax.py | 5 ++--- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/features/low_precision_training/performance_considerations/memory_usage_1_jax.py b/docs/features/low_precision_training/performance_considerations/memory_usage_1_jax.py index 330f7313f7..8c1250575e 100644 --- a/docs/features/low_precision_training/performance_considerations/memory_usage_1_jax.py +++ b/docs/features/low_precision_training/performance_considerations/memory_usage_1_jax.py @@ -20,6 +20,7 @@ x = jax.random.normal(key, (1024, 1024), dtype=jnp.bfloat16) var_collect = layer.init(key, x) + @jax.jit def loss_fn(var_collect, x): output = layer.apply(var_collect, x) diff --git a/docs/features/low_precision_training/performance_considerations/memory_usage_2_jax.py b/docs/features/low_precision_training/performance_considerations/memory_usage_2_jax.py index 15023d3775..3baa55bb8a 100644 --- a/docs/features/low_precision_training/performance_considerations/memory_usage_2_jax.py +++ b/docs/features/low_precision_training/performance_considerations/memory_usage_2_jax.py @@ -41,9 +41,8 @@ def loss_fn(var_collect, x): total_bytes = 0 for arr in jax.live_arrays(): total_bytes += arr.nbytes - if arr.nbytes > 200000: # do not count small tensors + if arr.nbytes > 200000: # do not count small tensors print(f" Shape: {arr.shape}, Dtype: {arr.dtype}, Size: {arr.nbytes / 1024:.1f} KB") -print(f" Total from all live arrays: {total_bytes / (1024**2):.2f} MB")\ - +print(f" Total from all live arrays: {total_bytes / (1024**2):.2f} MB") print("# END_MEMORY_USAGE_2") From e6add765ab2da3e39201b67477a0303fdaefa4fe Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Thu, 29 Jan 2026 15:25:28 -0800 Subject: [PATCH 24/25] fix Signed-off-by: Pawel Gadzinski --- .../fp8_blockwise_scaling/fp8_blockwise_scaling.rst | 5 +++-- docs/features/low_precision_training/mxfp8/mxfp8.rst | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/docs/features/low_precision_training/fp8_blockwise_scaling/fp8_blockwise_scaling.rst b/docs/features/low_precision_training/fp8_blockwise_scaling/fp8_blockwise_scaling.rst index b0e7abe57d..48d17db8d5 100644 --- a/docs/features/low_precision_training/fp8_blockwise_scaling/fp8_blockwise_scaling.rst +++ b/docs/features/low_precision_training/fp8_blockwise_scaling/fp8_blockwise_scaling.rst @@ -8,7 +8,7 @@ FP8 Blockwise Scaling .. warning:: - ``Float8BlockScaling`` is **not currently supported** in JAX. + ``Float8BlockScaling`` is **currently not supported** in JAX. FP8 Blockwise Scaling recipe is inspired by the quantization scheme used to train the `DeepSeek-v3 model `__ – the first open-source large-scale LLM trained entirely in FP8 precision. @@ -174,7 +174,8 @@ Supported devices Hopper (SM 9.0) -Blackwell and later (SM >= 10.0) – recipe is emulated with MXFP8. Note that this is done mainly for compatibility, MXFP8 is the preferred recipe on Blackwell. +Blackwell and later (SM >= 10.0) – the recipe is emulated with MXFP8. Note that MXFP8 is the preferred recipe on Blackwell. + Only scaling factors that are powers of 2 are supported. ---- diff --git a/docs/features/low_precision_training/mxfp8/mxfp8.rst b/docs/features/low_precision_training/mxfp8/mxfp8.rst index 14d805e808..cbd1dcc24b 100644 --- a/docs/features/low_precision_training/mxfp8/mxfp8.rst +++ b/docs/features/low_precision_training/mxfp8/mxfp8.rst @@ -147,7 +147,7 @@ Here's how to use MXFP8 recipe in PyTorch and JAX: Supported devices ----------------- -Blackwell and later (SM 10.0+) +Blackwell and later (SM 10.0, SM 10.3) ---- From 7641377602e8e927db82a683c27c59ed1f966e7b Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Fri, 30 Jan 2026 08:43:33 -0800 Subject: [PATCH 25/25] fix Signed-off-by: Pawel Gadzinski --- docs/features/low_precision_training/mxfp8/mxfp8.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/features/low_precision_training/mxfp8/mxfp8.rst b/docs/features/low_precision_training/mxfp8/mxfp8.rst index cbd1dcc24b..f8f8f48b0d 100644 --- a/docs/features/low_precision_training/mxfp8/mxfp8.rst +++ b/docs/features/low_precision_training/mxfp8/mxfp8.rst @@ -147,7 +147,7 @@ Here's how to use MXFP8 recipe in PyTorch and JAX: Supported devices ----------------- -Blackwell and later (SM 10.0, SM 10.3) +SM 10.0, SM 10.3 ----