[Common][PyTorch] Fuse scaling and unscaling of bf16 momentums into kernels #2632
base: main
Conversation
Signed-off-by: Xin Yao <xiny@nvidia.com>
Greptile Overview

Greptile Summary

This PR optimizes BF16 momentum handling in FusedAdam by fusing scaling/unscaling operations directly into the CUDA kernels, eliminating explicit FP32 copies and reducing peak memory footprint. It also enables CUDA Graphs (capturable mode) for BF16 momentums.

Key changes:

Implementation approach:

Confidence Score: 4/5

Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
    participant User
    participant FusedAdam as FusedAdam (Python)
    participant AdamKernel as Adam CUDA Kernel
    participant Moments as BF16 Moments (exp_avg/exp_avg_sq)
    Note over User,Moments: BF16 Moments with Fused Scaling Flow
    User->>FusedAdam: step() with BF16 moments
    FusedAdam->>FusedAdam: Check fuse_unscale flag (BF16 moments)
    FusedAdam->>FusedAdam: get_unscaled_state(skip_unscale=True)
    Note over FusedAdam: Skip explicit BF16→FP32 conversion
    FusedAdam->>AdamKernel: Call kernel with BF16 moments directly
    AdamKernel->>AdamKernel: Load BF16 moments, cast to FP32 for math
    AdamKernel->>AdamKernel: Perform Adam update in FP32
    AdamKernel->>Moments: Store updated moments as BF16
    Note over AdamKernel,Moments: Fused scaling: FP32→BF16 cast in kernel
    AdamKernel-->>FusedAdam: Return
    FusedAdam->>FusedAdam: Skip explicit scaling (fuse_unscale=True)
    Note over FusedAdam: Avoid explicit FP32 copy & scaling
    FusedAdam-->>User: Complete
```
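To make the diagram concrete, here is a minimal PyTorch sketch of the per-step logic, assuming FP32 master weights and grads as in the tests below. The function name `adam_step_bf16_state` is hypothetical; in the actual PR this logic lives inside the fused CUDA kernel, not in Python.

```python
import torch

def adam_step_bf16_state(param, grad, exp_avg, exp_avg_sq,
                         lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, step=1):
    """Sketch: Adam update with BF16 moments, cast to FP32 only transiently."""
    # Load BF16 state and cast to FP32 "inside the kernel" -- no persistent
    # FP32 copies of the moments are ever materialized.
    m = exp_avg.float()
    v = exp_avg_sq.float()

    # Standard Adam math in FP32 (param/grad are FP32 master tensors here).
    m.mul_(beta1).add_(grad, alpha=1 - beta1)
    v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    m_hat = m / (1 - beta1 ** step)
    v_hat = v / (1 - beta2 ** step)
    param.add_(m_hat / (v_hat.sqrt() + eps), alpha=-lr)

    # Store back as BF16: this FP32->BF16 cast is the "scaling" that used to
    # be a separate pass over the state tensors.
    exp_avg.copy_(m)
    exp_avg_sq.copy_(v)
```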
3 files reviewed, 1 comment
Signed-off-by: Xin Yao <xiny@nvidia.com>
3 files reviewed, no comments
for more information, see https://pre-commit.ci
2 files reviewed, 1 comment
```python
@pytest.mark.skipif(not is_bf16_available(), reason="bf16 is not supported")
def test_bf16_exp_avg_and_exp_avg_sq(self):
    self.gen_precision_aware_test(
        use_fp8_params=False,
        param_dtype=torch.bfloat16,
        use_master_weights=True,
        master_weight_dtype=torch.float32,
        grad_dtype=torch.float32,
        exp_avg_dtype=torch.bfloat16,
        exp_avg_sq_dtype=torch.bfloat16,
        master_rtol=2e-3,
        master_atol=2e-3,
    )
```
Consider adding a test for capturable mode (CUDA Graphs) with BF16 momentums, since the PR description mentions "Enable CUDA Graphs for BF16 momentums" as a key feature. The current test only covers non-capturable mode.
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
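A hedged sketch of what the suggested capturable-mode test could look like. It assumes `gen_precision_aware_test` accepts a `capturable` keyword, which may not exist; if it does not, a separate test method that drives CUDA graph capture directly would be needed.

```python
@pytest.mark.skipif(not is_bf16_available(), reason="bf16 is not supported")
def test_bf16_exp_avg_and_exp_avg_sq_capturable(self):
    self.gen_precision_aware_test(
        use_fp8_params=False,
        param_dtype=torch.bfloat16,
        use_master_weights=True,
        master_weight_dtype=torch.float32,
        grad_dtype=torch.float32,
        exp_avg_dtype=torch.bfloat16,
        exp_avg_sq_dtype=torch.bfloat16,
        master_rtol=2e-3,
        master_atol=2e-3,
        capturable=True,  # hypothetical parameter; verify it exists first
    )
```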
Signed-off-by: Xin Yao <xiny@nvidia.com>
3 files reviewed, 1 comment
```python
@pytest.mark.skipif(not is_bf16_available(), reason="bf16 is not supported")
def test_bf16_exp_avg_and_exp_avg_sq(self):
    self.gen_precision_aware_test(
        use_fp8_params=False,
        param_dtype=torch.bfloat16,
        use_master_weights=True,
        master_weight_dtype=torch.float32,
        grad_dtype=torch.float32,
        exp_avg_dtype=torch.bfloat16,
        exp_avg_sq_dtype=torch.bfloat16,
        master_rtol=2e-3,
        master_atol=2e-3,
    )
```
Test only covers non-capturable mode. Add test for capturable mode with BF16 momentums since PR enables CUDA Graphs for this case.
Note: Check that gen_precision_aware_test supports a capturable parameter, or create a separate test method.
Description
This PR enables this feature only for BF16 momentums, because BF16 requires no real scaling or unscaling, only a type conversion.
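A minimal sketch of why BF16 is the easy case, using only standard PyTorch casts. BF16 shares FP32's exponent range, so "unscaling" is a pure dtype cast, whereas an FP16 state tensor (unchanged by this PR) needs a genuine scale factor to stay within its narrower exponent range.

```python
import torch

m_bf16 = torch.randn(4, dtype=torch.bfloat16)
m_fp32 = m_bf16.float()                        # BF16 "unscale": pure type conversion
assert torch.equal(m_fp32.bfloat16(), m_bf16)  # BF16 -> FP32 -> BF16 round-trips exactly

scale = 2.0 ** 8                               # FP16 path: state kept pre-scaled
m_fp16 = (torch.randn(4) * scale).half()
m_unscaled = m_fp16.float() / scale            # a genuine multiply, not just a cast
```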