Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions transformer_engine/pytorch/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -853,12 +853,22 @@ def functionalized(*user_args, **user_kwargs):
return functionalized

def make_graphed_attribute_functions(graph_idx):
# Get te modules for current graph
te_modules = visited_te_modules.get(graph_idx, set())

# Attach backward_dw as an attribute to the graphed callable.
def backward_dw():
if need_bwd_dw_graph.get(graph_idx, False):
bwd_dw_graphs[graph_idx].replay()

# Trigger the grad accumulation hook for wgrad graphs.
for module in te_modules:
if (
isinstance(module, TransformerEngineBaseModule)
and module.need_backward_dw()
):
module._trigger_wgrad_accumulation_and_reduce_hooks()
Comment thread
timmoon10 marked this conversation as resolved.

# Attach reset as an attribute to the graphed callable.
def reset():
fwd_graphs[graph_idx].reset()
Expand Down
10 changes: 8 additions & 2 deletions transformer_engine/pytorch/module/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1526,8 +1526,14 @@ def backward_dw(self):
bias_tensor.grad = bgrad.to(bias_tensor.dtype)
del wgrad
del bgrad
for wgrad_accumulation_and_reduce_hook in self.wgrad_accumulation_and_reduce_hooks:
wgrad_accumulation_and_reduce_hook()
self._trigger_wgrad_accumulation_and_reduce_hooks()

def _trigger_wgrad_accumulation_and_reduce_hooks(self):
"""
Trigger the wgrad accumulation and reduce hooks.
"""
for wgrad_accumulation_and_reduce_hook in self.wgrad_accumulation_and_reduce_hooks:
wgrad_accumulation_and_reduce_hook()
Comment on lines +1531 to +1536
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unconditional hook trigger
TransformerEngineBaseModule.backward_dw() now always calls _trigger_wgrad_accumulation_and_reduce_hooks() (transformer_engine/pytorch/module/base.py:1529-1536). In the fuse_wgrad_accumulation=True path the module does not materialize weight_tensor.grad (base.py:1520-1522), so these hooks may run without the expected grads present. If any hook assumes .grad exists (common for grad-reduce hooks), this will raise at runtime specifically when delayed wgrad compute is enabled with fused wgrad accumulation.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These hooks are specific to Megatron-LM. They are delicate interfaces for expert users.


def is_debug_iter(self) -> bool:
"""
Expand Down
3 changes: 1 addition & 2 deletions transformer_engine/pytorch/module/grouped_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -873,8 +873,7 @@ def backward_dw(self):
del grad_biases_
del wgrad_list
del tensor_list
for wgrad_accumulation_and_reduce_hook in self.wgrad_accumulation_and_reduce_hooks:
wgrad_accumulation_and_reduce_hook()
self._trigger_wgrad_accumulation_and_reduce_hooks()

def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None:
"""Customize quantizers based on current scaling recipe + linear."""
Expand Down
3 changes: 1 addition & 2 deletions transformer_engine/pytorch/module/layernorm_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2506,5 +2506,4 @@ def backward_dw(self):
del fc2_wgrad
del fc1_wgrad
del fc1_bias_grad
for wgrad_accumulation_and_reduce_hook in self.wgrad_accumulation_and_reduce_hooks:
wgrad_accumulation_and_reduce_hook()
self._trigger_wgrad_accumulation_and_reduce_hooks()
Loading