More detailed documentation for recipes #2343
base: main
Conversation
Force-pushed from 79ed6d7 to 9649cd8
Force-pushed from 3053170 to 7905a74
…m low precision training

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
… add GPU checks

Changes:
- Remove optimizer code from all recipe examples (keep only forward/backward)
- Fix Format imports (use Format.E4M3 instead of string 'E4M3')
- Fix params_dtype for PyTorch examples (add params_dtype=torch.bfloat16)
- Add GPU capability assertions before START blocks for blockwise/mxfp8/nvfp4
- Fix JAX imports (Float8CurrentScaling from common.recipe, NVFP4BlockScaling)
- Add global_shard_guard for TransformerLayer examples in JAX
- Fix fused_layers_jax.py return tuple unpacking
- Update memory_usage JAX examples with dynamic GPU measurement
- Remove memory_usage_3_jax (JAX doesn't support FP8 weight storage)
- Update performance_considerations.rst for JAX differences
- Delete unused .out files and fp8_autocast_jax.py

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
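As context for these fixes, here is a minimal sketch of the pattern the commit describes: a GPU capability assertion ahead of an MXFP8 example, using the `Format` enum rather than a string. The exact guards in the docs may differ; this is illustrative only.

```python
import torch
from transformer_engine.common.recipe import Format, MXFP8BlockScaling

# Guard the example on hardware support, mirroring the "GPU capability
# assertions before START blocks" above: MXFP8 needs Blackwell (SM >= 10.0).
major, minor = torch.cuda.get_device_capability()
assert (major, minor) >= (10, 0), "MXFP8 examples require SM >= 10.0"

# Use the Format enum instead of the string 'E4M3', per the import fix above.
recipe = MXFP8BlockScaling(fp8_format=Format.E4M3)
```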
for more information, see https://pre-commit.ci
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Greptile Summary

This PR adds comprehensive documentation for Transformer Engine's low precision training capabilities, covering FP8, MXFP8, and NVFP4 quantization recipes for both PyTorch and JAX frameworks.

Documentation Quality

The documentation is well-structured, with a clear progression from basic concepts to advanced optimization. Technical explanations are thorough, the diagrams effectively illustrate complex concepts, and the code examples cover both single-GPU and distributed scenarios.

All previously identified issues from earlier review rounds have been addressed.

Confidence Score: 5/5
Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
    participant User as User/Developer
    participant Docs as Documentation
    participant Intro as Introduction
    participant Recipes as Recipe Docs<br/>(FP8/MXFP8/NVFP4)
    participant Perf as Performance Guide
    participant Examples as Code Examples
    User->>Docs: Read low precision training docs
    Docs->>Intro: 1. Learn mixed precision basics
    Note over Intro: BF16/FP16 concepts<br/>Master weights<br/>Autocast usage
    Intro->>Recipes: 2. Choose quantization recipe
    Note over Recipes: FP8 Current Scaling<br/>FP8 Delayed Scaling<br/>FP8 Block Scaling<br/>MXFP8<br/>NVFP4
    Recipes->>Examples: 3. Review framework examples
    Note over Examples: PyTorch examples<br/>JAX examples<br/>Distributed training
    Examples->>Perf: 4. Optimize performance
    Note over Perf: Transpose handling<br/>Memory optimization<br/>Fused layers
    Perf->>User: 5. Implement in production
    Note over User: Apply recipe with<br/>te.autocast(recipe=...)
```
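The final step in the diagram maps to a short usage pattern. A hedged sketch, assuming the `te.autocast` API shown above and the `Float8CurrentScaling` recipe named in the commit messages:

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Float8CurrentScaling

# TE layer with BF16 params, as in the PR's examples.
layer = te.Linear(1024, 1024, params_dtype=torch.bfloat16)
inp = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda")

# Step 5 of the diagram: apply a recipe via te.autocast(recipe=...).
with te.autocast(enabled=True, recipe=Float8CurrentScaling()):
    out = layer(inp)
out.sum().backward()  # backward runs under the same recipe's quantization
```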
46 files reviewed, no comments
Review threads (outdated, resolved) on docs/features/low_precision_training/introduction/introduction.rst (10 threads).
2 files reviewed, 2 comments
Review threads (outdated, resolved):
- docs/features/low_precision_training/fp8_delayed_scaling/fp8_delayed_scaling.rst
- docs/features/low_precision_training/fp8_current_scaling/img/fp8_tensor_core.svg
for more information, see https://pre-commit.ci
1 file reviewed, 1 comment
Review thread (outdated, resolved): docs/features/low_precision_training/fp8_delayed_scaling/fp8_delayed_scaling.rst
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
2 files reviewed, 2 comments
Review threads (resolved):
- docs/features/low_precision_training/fp8_delayed_scaling/fp8_delayed_scaling.rst
- docs/features/low_precision_training/introduction/introduction.rst (outdated)
3 files reviewed, 3 comments
Review threads (resolved):
- docs/features/low_precision_training/fp8_delayed_scaling/fp8_delayed_scaling.rst
- docs/features/low_precision_training/introduction/introduction.rst (outdated)
- ...atures/low_precision_training/fp8_delayed_scaling/jax_delayed_scaling_distributed_example.py
1 file reviewed, 1 comment
Review thread (resolved): docs/features/low_precision_training/introduction/bf16_fp16_training_jax.py
1 file reviewed, 1 comment
No files reviewed, no comments
3 files reviewed, 3 comments
Review threads (resolved):
- docs/features/low_precision_training/introduction/bf16_fp16_training_pytorch.py
- docs/features/low_precision_training/performance_considerations/memory_usage_2_pytorch.py
jberchtold-nvidia left a comment:
Fantastic work! The diagrams and explanations are great and, along with the code examples, will hopefully make things a lot clearer to users.
Review threads (outdated, resolved):
- docs/features/low_precision_training/fp8_delayed_scaling/jax_delayed_scaling_example.py
- docs/features/low_precision_training/introduction/autocast_jax.py
- docs/features/low_precision_training/performance_considerations/performance_considerations.rst
- docs/features/low_precision_training/fp8_blockwise_scaling/jax_blockwise_scaling_example.py
- docs/features/low_precision_training/nvfp4/jax_nvfp4_example.py
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
for more information, see https://pre-commit.ci
Diff excerpt (hardware support table):

| Hopper (SM 9.0) |
| Blackwell and later (SM >= 10.0) – recipe is emulated with MXFP8. Note that this is done mainly for compatibility; MXFP8 is the preferred recipe on Blackwell. |
- emulated = power of 2 scaling factor only
- remove the compatibility part.
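To make the emulation note concrete, here is a sketch of recipe selection by compute capability (the recipe class names come from the commit messages earlier in this PR; treat this as illustrative, not the documented API):

```python
import torch
from transformer_engine.common.recipe import DelayedScaling, MXFP8BlockScaling

# Pick a recipe based on the support matrix above.
major, _ = torch.cuda.get_device_capability()
if major >= 10:
    # Blackwell and later: MXFP8 is the preferred recipe. Per-tensor FP8
    # recipes are emulated with MXFP8 (power-of-2 scale factors only).
    recipe = MXFP8BlockScaling()
else:
    # Hopper (SM 9.0): FP8 delayed scaling is supported natively.
    recipe = DelayedScaling()
```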
Diff excerpt (hardware support table; fixes missing newline at end of file):

| Blackwell and later (SM 10.0+) |
Comment about SM 12
5 files reviewed, no comments
5 files reviewed, 2 comments
Diff excerpt (nvfp4.rst):

NVFP4 is the first 4-bit recipe introduced in Transformer Engine –
please refer to the `NVFP4 paper <https://arxiv.org/abs/2509.25149>`__ for more details.
arXiv ID 2509.25149 uses prefix 2509 (September 2025), which is in the future. Verify this is the correct reference or use a placeholder format like [arXiv link pending] until the paper is published.
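For orientation, a minimal sketch of enabling NVFP4, assuming the `NVFP4BlockScaling` recipe class named in the commit messages and the `te.autocast` API shown elsewhere in this PR; the documented example may differ:

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import NVFP4BlockScaling

# NVFP4 is Blackwell-only; assert capability as the recipe examples do.
major, _ = torch.cuda.get_device_capability()
assert major >= 10, "NVFP4 requires SM >= 10.0"

layer = te.Linear(1024, 1024, params_dtype=torch.bfloat16)
inp = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda")
with te.autocast(enabled=True, recipe=NVFP4BlockScaling()):
    out = layer(inp)
```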
Diff excerpt (PyTorch example):

layer = te.Linear(1024, 1024, params_dtype=torch.bfloat16)
inp = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda")
with te.autocast(enabled=True):
Missing recipe parameter. te.autocast(enabled=True) uses a default recipe, but for documentation clarity, explicitly specify which recipe is being used (e.g., recipe=DelayedScaling()) to match other examples.
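Applied to the snippet above, the suggestion would look roughly like this (`DelayedScaling` is an assumed choice here; any explicitly named recipe would satisfy the comment):

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

layer = te.Linear(1024, 1024, params_dtype=torch.bfloat16)
inp = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda")

# Explicit recipe instead of relying on the implicit default:
with te.autocast(enabled=True, recipe=DelayedScaling()):
    out = layer(inp)
```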
Description
This PR adds detailed documentation for the Low Precision Training feature in Transformer Engine, covering FP8, MXFP8, NVFP4, and other quantization recipes for both PyTorch and JAX frameworks.
Type of change
Checklist: