[PyTorch] Fix fuser so it releases tensors properly#2750

Open
kainzhong wants to merge 2 commits into NVIDIA:main from kainzhong:fix/release_tensor_for_fuser

Conversation

@kainzhong
Collaborator

@kainzhong kainzhong commented Mar 10, 2026

Description

When training Llama3 405B with use_te_op_fuser=True in Megatron, we observe OOMs. This is because OperationFuser fails to properly release tensors that are saved for backward.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

In _OperationFuserAutogradFunction's backward, we use

saved_tensors = restore_from_saved(func_ctx.tensor_objects, func_ctx.saved_tensors)

to restore the saved tensors. However, we never detached tensor_objects from func_ctx, and func_ctx is not released until the next iteration due to PyTorch's internal mechanism (PyTorch clears tensors registered via save_for_backward, but not tensors attached to ctx directly). The attached tensors were therefore not released in time, causing higher memory usage. After manually setting func_ctx.tensor_objects to None, these tensors can be freed promptly.
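The fix follows the usual reference discipline for custom autograd functions. A minimal sketch (FuserFn and the tensor handling here are illustrative stand-ins, not the actual TE code, which restores QuantizedTensorStorage objects via restore_from_saved):

```python
import torch


class FuserFn(torch.autograd.Function):
    """Minimal sketch of the fuser's saved-tensor bookkeeping."""

    @staticmethod
    def forward(ctx, x):
        # Tensors registered via save_for_backward are released by
        # PyTorch automatically once backward has run.
        ctx.save_for_backward(x)
        # Attributes attached directly to ctx are NOT released
        # automatically; they live as long as ctx itself does.
        ctx.tensor_objects = [x]
        return x * 2

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        tensor_objects = ctx.tensor_objects  # read once, as in the fuser
        # The fix: drop the direct reference so these objects can be
        # freed as soon as backward returns, instead of lingering
        # until ctx itself is released on the next iteration.
        ctx.tensor_objects = None
        del tensor_objects
        return grad_out * 2
```

Here x is a leaf held by the caller, so the sketch only illustrates the reference discipline, not the memory savings themselves.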

[Memory profile screenshots: before this fix vs. after this fix]

In addition, in BasicLinear's backward, I manually called clear_tensor_data on the weight tensor when it is allocated by the quantizer (i.e., w is not self.weight). This is not strictly necessary given the fix above, since the weight tensor is still released after the fuser's backward, so it is more of an optimization. I removed it from this PR because it is not compatible with the mocked fuser backward in the tests.
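The removed optimization amounts to a guarded early release. A hedged sketch with stand-in names (clear_tensor_data and SketchLinear below are stubs for illustration, not TE's actual helper or module):

```python
import torch


def clear_tensor_data(t):
    # Stand-in for a helper that drops a tensor's underlying storage
    # so its memory can be reclaimed early.
    t.data = torch.empty(0)


class SketchLinear:
    """Toy stand-in for BasicLinear's backward-side weight handling."""

    def __init__(self):
        self.weight = torch.randn(4, 4)

    def release_workspace_weight(self, w):
        # Free w only when it is a temporary allocated by the
        # quantizer -- never the module's own parameter.
        if w is not self.weight:
            clear_tensor_data(w)


lin = SketchLinear()
quantized_w = lin.weight.clone()  # stands in for a quantizer output
lin.release_workspace_weight(quantized_w)
```

The identity check (`is not`) rather than an equality check is what distinguishes the module's own parameter from a quantizer-allocated copy.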

(Tested on Megatron-LM's Llama3 8B example with use_te_op_fuser=True set manually in gpt_builders.)

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

ksivaman
ksivaman previously approved these changes Mar 10, 2026
Member

@ksivaman ksivaman left a comment

LGTM

@greptile-apps
Contributor

greptile-apps bot commented Mar 10, 2026

Greptile Summary

This PR fixes a memory leak in _OperationFuserAutogradFunction that caused OOMs during large-model training (e.g., LLaMA3 405B) when use_te_op_fuser=True. The single-line fix explicitly sets func_ctx.tensor_objects = None after restoring saved tensors in the backward pass.

  • Root cause: tensor_objects (a list of QuantizedTensorStorage instances) was saved directly on func_ctx rather than via save_for_backward. PyTorch automatically clears tensors registered through save_for_backward, but does NOT do so for attributes attached directly to the context object. As a result, tensor_objects was held alive until func_ctx itself was released — which PyTorch defers until the next iteration — causing temporarily elevated peak memory usage.
  • Fix: After calling restore_from_saved(func_ctx.tensor_objects, func_ctx.saved_tensors), the PR immediately nullifies the reference with func_ctx.tensor_objects = None, allowing the stored QuantizedTensorStorage objects to be garbage-collected promptly at the end of the backward pass.
  • Scope: The change is minimal, targeted, and does not alter any functional logic. No new warnings are introduced.
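The lifetime difference described above can be observed directly with a toy Function (names illustrative). For custom autograd Functions, y.grad_fn is the same object as the ctx seen in forward/backward, so we can inspect it after the fact:

```python
import gc
import weakref

import torch


class Saver(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        t = x * 3                 # intermediate tensor
        ctx.save_for_backward(t)  # released by PyTorch after backward
        ctx.tensor_objects = [t]  # direct attribute: NOT auto-released
        return t.clone()

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out * 3


x = torch.ones(4, requires_grad=True)
y = Saver.apply(x)
ctx = y.grad_fn  # for custom Functions this is the ctx object itself
y.sum().backward()

# After backward, the save_for_backward storage is gone, but the
# direct attribute still pins the intermediate tensor:
pinned = weakref.ref(ctx.tensor_objects[0])
assert pinned() is not None

ctx.tensor_objects = None  # the PR's fix, applied by hand
gc.collect()
assert pinned() is None    # the intermediate is now actually freed
```

In the real fuser the pinned objects are QuantizedTensorStorage instances and ctx survives until the next iteration, so the pinned memory scales with model size.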

Confidence Score: 5/5

  • This PR is safe to merge — the change is a single-line, well-motivated memory fix with no logic alterations.
  • The fix is minimal (one line), logically sound, and consistent with PyTorch's documented behavior around autograd function contexts. func_ctx.tensor_objects is only read once (in restore_from_saved), so nullifying it immediately after has no functional side effects. The before/after memory profiles in the PR description further validate the fix empirically.
  • No files require special attention.

Important Files Changed

Filename Overview
transformer_engine/pytorch/ops/fuser.py Single-line fix: sets func_ctx.tensor_objects = None in backward after restoring saved tensors, allowing PyTorch to release non-standard QuantizedTensorStorage objects that would otherwise be held until the next iteration.

Sequence Diagram

sequenceDiagram
    participant FW as forward()
    participant PyTorch as PyTorch Autograd
    participant BW as backward()

    FW->>FW: prepare_for_saving(*to_save)<br/>→ tensors_to_save, tensor_objects
    FW->>PyTorch: save_for_backward(*tensors_to_save)<br/>(cleared automatically by PyTorch)
    FW->>PyTorch: func_ctx.tensor_objects = tensor_objects<br/>(NOT cleared automatically)

    Note over PyTorch: Forward pass completes.<br/>func_ctx kept alive by PyTorch<br/>until next iteration.

    PyTorch->>BW: backward() invoked

    BW->>BW: restore_from_saved(func_ctx.tensor_objects,<br/>func_ctx.saved_tensors)
    Note over BW: saved_tensors are released<br/>by PyTorch automatically ✅
    BW->>BW: func_ctx.tensor_objects = None<br/>(PR fix: manually release ✅)

    Note over PyTorch: tensor_objects can now be<br/>garbage-collected immediately,<br/>rather than waiting until<br/>the next training iteration.

Last reviewed commit: 8bbc047

@ptrendx
Member

ptrendx commented Mar 10, 2026

/te-ci pytorch

ptrendx
ptrendx previously approved these changes Mar 10, 2026
@ksivaman
Member

@kainzhong Could you take a look at the failed CI?

Signed-off-by: Kaining Zhong <kainingz@nvidia.com>
@kainzhong kainzhong dismissed stale reviews from ptrendx and ksivaman via 491779c March 11, 2026 00:32
@kainzhong kainzhong force-pushed the fix/release_tensor_for_fuser branch from e945732 to 491779c on March 11, 2026 00:32
@kainzhong
Collaborator Author

OK, I think the mocked fuser_backward method doesn't like me setting w = None in op_forward, since it later does device = w.device. I'll drop this optimization so we can merge the PR first and handle it later; as long as we detach tensor_objects from func_ctx, the memory should still be deallocated in time when the fuser's backward finishes.

@kainzhong
Collaborator Author

/te-ci pytorch
