Conversation

yaox12 (Member) commented Jan 29, 2026

Description

  1. Fuse scaling and unscaling of BF16 momentums into the kernels to avoid explicit FP32 copies, which reduces the peak memory footprint.
  2. Enable CUDA Graphs for BF16 momentums.

This PR enables this feature only for BF16 momentums, because BF16 doesn't require real scaling and unscaling, only a type conversion.
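
For intuition, here is a minimal sketch (not from the PR) of why BF16 needs only a cast: BF16 keeps FP32's 8-bit exponent, so converting a moment to BF16 and back loses mantissa precision but never dynamic range, and no multiplicative scale factor is required.

import torch

m_fp32 = torch.randn(1024)

# BF16 shares FP32's exponent width, so the round trip is a pure dtype
# cast: mantissa bits are dropped, but the value range is preserved.
m_bf16 = m_fp32.to(torch.bfloat16)
restored = m_bf16.to(torch.float32)

assert torch.allclose(m_fp32, restored, rtol=1e-2, atol=1e-3)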

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Xin Yao <xiny@nvidia.com>
yaox12 requested a review from timmoon10 January 29, 2026 07:22
yaox12 (Member, Author) commented Jan 29, 2026

@kunlunl Can you review as well?

cc @nanz-nv This should resolve the peak memory issue, but I haven't verified it.

This also enables the capturable mode for BF16 momentums.

greptile-apps bot (Contributor) commented Jan 29, 2026

Greptile Overview

Greptile Summary

This PR optimizes BF16 momentum handling in FusedAdam by fusing scaling/unscaling operations directly into the CUDA kernels, eliminating explicit FP32 copies and reducing peak memory footprint. It also enables CUDA Graphs (capturable mode) for BF16 momentums.

Key changes:

  • Added MOMENT_T template parameter to Adam kernel functors to support both FP32 and BF16 moment storage
  • Updated kernel validation to allow both moments (exp_avg/exp_avg_sq) to be either FP32 or BF16 (must match)
  • Added fuse_unscale flag in Python code that skips explicit BF16→FP32→BF16 conversions when moments are BF16
  • Modified get_unscaled_state() to optionally skip unscaling for BF16, letting kernel handle type conversion
  • Updated capturable mode validation to allow BF16 moments alongside FP32
  • Fixed store_param_remainders validation ordering issue in latest commit
  • Added test coverage for BF16 momentums in non-capturable mode
  • Updated deprecated torch.cuda.amp.GradScaler to torch.amp.GradScaler

Implementation approach:
The optimization leverages the fact that BF16 doesn't require true scaling/unscaling, just a type conversion. The kernel loads BF16 moments, performs the Adam math in FP32, then stores the results back as BF16, avoiding intermediate FP32 buffers in Python. This also keeps tensor pointers stable for CUDA Graph capture.
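
A rough PyTorch emulation of that flow (hypothetical names mirroring the summary above; in the real kernel the FP32 values live in registers, so no intermediate FP32 tensor is ever allocated):

import torch

def get_unscaled_state(state: torch.Tensor, skip_unscale: bool) -> torch.Tensor:
    # Emulates the fuse_unscale path: with skip_unscale, the BF16 moment
    # is handed to the kernel as-is instead of being copied to FP32 first.
    return state if skip_unscale else state.float()

def adam_moment_update(exp_avg: torch.Tensor, grad: torch.Tensor, beta1: float):
    # Emulates the kernel: load BF16, do the Adam math in FP32, store BF16
    # in place, so the moment tensor's storage never moves.
    m = exp_avg.float()
    m.mul_(beta1).add_(grad, alpha=1.0 - beta1)
    exp_avg.copy_(m)

exp_avg = torch.zeros(8, dtype=torch.bfloat16)
grad = torch.randn(8)

fuse_unscale = exp_avg.dtype == torch.bfloat16
state = get_unscaled_state(exp_avg, skip_unscale=fuse_unscale)
assert state is exp_avg  # no FP32 copy was made

ptr = exp_avg.data_ptr()
adam_moment_update(state, grad, beta1=0.9)
assert exp_avg.data_ptr() == ptr  # storage is stable: safe for CUDA Graphs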

Confidence Score: 4/5

  • Safe to merge with minor test coverage gap for capturable mode
  • Implementation is well-designed with proper type safety through template parameters and runtime validation. The fuse_unscale logic is sound and the latest commit fixed the store_param_remainders ordering bug. Score not 5 due to missing test coverage for the capturable mode with BF16 momentums (a key feature mentioned in PR description).
  • tests/pytorch/test_fused_optimizer.py - add capturable mode test for BF16 momentums

Important Files Changed

  • tests/pytorch/test_fused_optimizer.py: Added a test for BF16 momentums and fixed GradScaler API usage
  • transformer_engine/common/multi_tensor/adam.cu: Added MOMENT_T template parameter to support BF16 momentums; updated validation to allow BF16
  • transformer_engine/pytorch/optimizers/fused_adam.py: Updated capturable mode validation to allow BF16 moments; added fuse_unscale flag; modified get_unscaled_state to skip unscaling for BF16

Sequence Diagram

sequenceDiagram
    participant User
    participant FusedAdam as FusedAdam (Python)
    participant AdamKernel as Adam CUDA Kernel
    participant Moments as BF16 Moments (exp_avg/exp_avg_sq)

    Note over User,Moments: BF16 Moments with Fused Scaling Flow

    User->>FusedAdam: step() with BF16 moments
    FusedAdam->>FusedAdam: Check fuse_unscale flag (BF16 moments)
    FusedAdam->>FusedAdam: get_unscaled_state(skip_unscale=True)
    Note over FusedAdam: Skip explicit BF16→FP32 conversion
    FusedAdam->>AdamKernel: Call kernel with BF16 moments directly
    AdamKernel->>AdamKernel: Load BF16 moments, cast to FP32 for math
    AdamKernel->>AdamKernel: Perform Adam update in FP32
    AdamKernel->>Moments: Store updated moments as BF16
    Note over AdamKernel,Moments: Fused scaling: FP32→BF16 cast in kernel
    AdamKernel-->>FusedAdam: Return
    FusedAdam->>FusedAdam: Skip explicit scaling (fuse_unscale=True)
    Note over FusedAdam: Avoid explicit FP32 copy & scaling
    FusedAdam-->>User: Complete
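
Since stable pointers are what make the flow above graph-safe, here is a hypothetical usage sketch of capturing the step (assumes a CUDA device, that FusedAdam is importable from transformer_engine.pytorch.optimizers, and that it accepts the keyword arguments shown; verify against fused_adam.py):

import torch
from transformer_engine.pytorch.optimizers import FusedAdam

# Hypothetical sketch: capture an optimizer step with BF16 moments.
params = [torch.nn.Parameter(torch.randn(1024, device="cuda"))]
opt = FusedAdam(
    params,
    lr=1e-3,
    capturable=True,                 # required for CUDA Graph capture
    exp_avg_dtype=torch.bfloat16,    # BF16 moments, newly graph-capturable
    exp_avg_sq_dtype=torch.bfloat16,
)

# Warm-up step so optimizer state is allocated before capture.
for p in params:
    p.grad = torch.zeros_like(p)
opt.step()

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    opt.step()  # BF16 moments keep stable pointers, so capture is valid

graph.replay()  # replays the captured optimizer step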

greptile-apps bot (Contributor) left a comment

3 files reviewed, 1 comment

yaox12 added the 2.13.0 label Jan 29, 2026
Signed-off-by: Xin Yao <xiny@nvidia.com>
greptile-apps bot (Contributor) left a comment

3 files reviewed, no comments

yaox12 and others added 2 commits January 29, 2026 07:47
Signed-off-by: Xin Yao <xiny@nvidia.com>
greptile-apps bot (Contributor) left a comment

2 files reviewed, 1 comment

Comment on lines +410 to +422
@pytest.mark.skipif(not is_bf16_available(), reason="bf16 is not supported")
def test_bf16_exp_avg_and_exp_avg_sq(self):
    self.gen_precision_aware_test(
        use_fp8_params=False,
        param_dtype=torch.bfloat16,
        use_master_weights=True,
        master_weight_dtype=torch.float32,
        grad_dtype=torch.float32,
        exp_avg_dtype=torch.bfloat16,
        exp_avg_sq_dtype=torch.bfloat16,
        master_rtol=2e-3,
        master_atol=2e-3,
    )
greptile-apps bot (Contributor) commented:

Consider adding a test for capturable mode (CUDA Graphs) with BF16 momentums, since the PR description mentions "Enable CUDA Graphs for BF16 momentums" as a key feature. The current test only covers non-capturable mode.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

Signed-off-by: Xin Yao <xiny@nvidia.com>
greptile-apps bot (Contributor) left a comment

3 files reviewed, 1 comment

Comment on lines +410 to +422
@pytest.mark.skipif(not is_bf16_available(), reason="bf16 is not supported")
def test_bf16_exp_avg_and_exp_avg_sq(self):
    self.gen_precision_aware_test(
        use_fp8_params=False,
        param_dtype=torch.bfloat16,
        use_master_weights=True,
        master_weight_dtype=torch.float32,
        grad_dtype=torch.float32,
        exp_avg_dtype=torch.bfloat16,
        exp_avg_sq_dtype=torch.bfloat16,
        master_rtol=2e-3,
        master_atol=2e-3,
    )
greptile-apps bot (Contributor) commented:

Test only covers the non-capturable mode. Add a test for capturable mode with BF16 momentums, since the PR enables CUDA Graphs for this case.

Note: Check that gen_precision_aware_test supports a capturable parameter, or create a separate test method.
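
A hypothetical shape for that follow-up test, assuming gen_precision_aware_test grows a capturable parameter (it may not have one today, per the note above):

@pytest.mark.skipif(not is_bf16_available(), reason="bf16 is not supported")
def test_bf16_exp_avg_and_exp_avg_sq_capturable(self):
    # Hypothetical: relies on a `capturable` kwarg that may need to be
    # added to gen_precision_aware_test first.
    self.gen_precision_aware_test(
        use_fp8_params=False,
        param_dtype=torch.bfloat16,
        use_master_weights=True,
        master_weight_dtype=torch.float32,
        grad_dtype=torch.float32,
        exp_avg_dtype=torch.bfloat16,
        exp_avg_sq_dtype=torch.bfloat16,
        capturable=True,
        master_rtol=2e-3,
        master_atol=2e-3,
    )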
