[Pytorch] Add get_backward_dw_params api for TE module #2614
base: main
Conversation
Greptile Overview

Greptile Summary

This PR fixes a bug where weight gradient computation hooks were being discarded when CUDA graphs are used with Megatron-LM, causing parameters to skip gradient reduction.

Changes: The implementation follows the same pattern that was previously reverted (commit d04c008), but with a cleaner function signature that leverages closure scope access.

Confidence Score: 5/5
Important Files Changed
Sequence Diagram

sequenceDiagram
participant MegatronLM
participant GraphedCallable
participant backward_dw
participant TEModule
participant Hooks
Note over MegatronLM,Hooks: Wgrad CUDA Graph Execution Flow
MegatronLM->>GraphedCallable: call backward_dw()
GraphedCallable->>backward_dw: execute backward_dw()
alt need_bwd_dw_graph[graph_idx] == True
backward_dw->>backward_dw: replay wgrad graph
loop for each module in te_modules
backward_dw->>TEModule: check need_backward_dw()
alt module needs backward_dw
backward_dw->>TEModule: _trigger_wgrad_accumulation_and_reduce_hooks()
loop for each hook in wgrad_accumulation_and_reduce_hooks
TEModule->>Hooks: execute hook()
Note over Hooks: Performs grad accumulation<br/>and reduction
end
end
end
end
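A minimal Python sketch of the flow in the diagram above, for orientation only. It reuses the names shown in the diagram (need_bwd_dw_graph, te_modules, need_backward_dw, _trigger_wgrad_accumulation_and_reduce_hooks); the wgrad_graphs list and the function signature are assumptions, not the actual TransformerEngine graphed-callable implementation.

```python
from typing import List, Sequence

import torch


def backward_dw(
    graph_idx: int,
    need_bwd_dw_graph: List[bool],
    wgrad_graphs: List[torch.cuda.CUDAGraph],  # assumed: one captured wgrad graph per index
    te_modules: Sequence[torch.nn.Module],
) -> None:
    """Replay the captured wgrad graph, then fire each module's wgrad hooks."""
    if not need_bwd_dw_graph[graph_idx]:
        return
    # Replaying the graph recomputes weight gradients outside of autograd...
    wgrad_graphs[graph_idx].replay()
    # ...so the grad accumulation/reduce hooks must be triggered manually
    # for every TE module that still needs its backward_dw pass.
    for module in te_modules:
        if module.need_backward_dw():
            module._trigger_wgrad_accumulation_and_reduce_hooks()
```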
1 file reviewed, 1 comment
Get the parameters for the backward weight gradient computation.
"""
params = []
params.append(noop_cat(self._get_weight_tensors()))
logic: in backward_dw() (lines 1520-1522), weight tensors are only accessed when not self.fuse_wgrad_accumulation, but this method unconditionally returns weight parameters. Depending on Megatron-LM's usage, this could cause hooks to be registered on parameters that shouldn't have them when fuse_wgrad_accumulation=True.
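For illustration, a guarded variant along the lines this comment suggests could look like the sketch below. noop_cat, _get_weight_tensors, and fuse_wgrad_accumulation are taken from the snippet and comment above; the guard itself is an assumption, not the code merged in this PR.

```python
def get_backward_dw_params(self):
    """Get the parameters for the backward weight gradient computation."""
    params = []
    # Assumed guard: mirror backward_dw(), which only touches the weight
    # tensors when wgrad accumulation is not fused into the GEMM.
    if not self.fuse_wgrad_accumulation:
        params.append(noop_cat(self._get_weight_tensors()))
    return params
```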
commit content reverted.
No files reviewed, no comments
No files reviewed, no comments
No files reviewed, no comments
2 files reviewed, no comments
Description
This PR adds get_backward_dw_params for TE modules, which helps manage the hooks of parameters. For Megatron-LM, get_backward_dw_params will be called once the wgrad CUDA graph is executed. Currently the backward_post_hook of the wgrad computation is discarded, which causes parameters to skip gradient reduction.

Type of change
Changes
Please list the changes introduced in this PR:

- Add a get_backward_dw_params API to TE modules that returns the parameters involved in the backward weight-gradient computation, so Megatron-LM can manage their gradient accumulation/reduce hooks when the wgrad CUDA graph is replayed (see the sketch after this list).
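A hedged illustration of the intended consumer-side usage referenced above: after the wgrad CUDA graph replays, a caller such as Megatron-LM could walk get_backward_dw_params() and re-run each parameter's hook. run_hook is a hypothetical callback standing in for the framework's actual grad accumulation/reduce hook; nothing here is taken from Megatron-LM's real code.

```python
from typing import Callable, Iterable

import torch


def run_wgrad_hooks_after_graph_replay(
    te_modules: Iterable[torch.nn.Module],
    run_hook: Callable[[torch.Tensor], None],  # hypothetical accumulation/reduce hook
) -> None:
    """Run wgrad hooks for parameters whose backward_post_hook the CUDA graph discarded."""
    for module in te_modules:
        # get_backward_dw_params() is the API added by this PR; it returns the
        # (noop-concatenated) tensors written by the backward wgrad computation.
        for param in module.get_backward_dw_params():
            run_hook(param)
```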
Checklist: