[PyTorch] Fix fuser so it releases tensors properly #2750
Open
kainzhong wants to merge 2 commits into NVIDIA:main
Conversation
Contributor
Greptile Summary: This PR fixes a memory leak in `OperationFuser`: tensor objects attached directly to the autograd context (`func_ctx.tensor_objects`) were held until the next training iteration instead of being released right after backward.
Confidence Score: 5/5
Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
    participant FW as forward()
    participant PyTorch as PyTorch Autograd
    participant BW as backward()
    FW->>FW: prepare_for_saving(*to_save)<br/>→ tensors_to_save, tensor_objects
    FW->>PyTorch: save_for_backward(*tensors_to_save)<br/>(cleared automatically by PyTorch)
    FW->>PyTorch: func_ctx.tensor_objects = tensor_objects<br/>(NOT cleared automatically)
    Note over PyTorch: Forward pass completes.<br/>func_ctx kept alive by PyTorch<br/>until next iteration.
    PyTorch->>BW: backward() invoked
    BW->>BW: restore_from_saved(func_ctx.tensor_objects,<br/>func_ctx.saved_tensors)
    Note over BW: saved_tensors are released<br/>by PyTorch automatically ✅
    BW->>BW: func_ctx.tensor_objects = None<br/>(PR fix: manually release ✅)
    Note over PyTorch: tensor_objects can now be<br/>garbage-collected immediately,<br/>rather than waiting until<br/>the next training iteration.
```
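The same flow, as a minimal runnable sketch (assumption: `_FuserSketch` is a simplified stand-in, not TransformerEngine's actual `_OperationFuserAutogradFunction`, and `tensor_objects` here just aliases the saved tensors):

```python
import torch

class _FuserSketch(torch.autograd.Function):
    """Simplified stand-in for _OperationFuserAutogradFunction."""

    @staticmethod
    def forward(func_ctx, x, weight):
        # Tensors handed to save_for_backward are cleared by PyTorch
        # itself once backward() has consumed them.
        func_ctx.save_for_backward(x, weight)
        # Objects attached directly to the ctx are NOT cleared by
        # PyTorch; they live as long as the ctx does, i.e. until the
        # next training iteration in the scenario above.
        func_ctx.tensor_objects = [x, weight]  # illustrative stand-in
        return x @ weight.t()

    @staticmethod
    def backward(func_ctx, grad_out):
        x, weight = func_ctx.saved_tensors  # released by PyTorch afterwards
        # The PR's fix: drop the direct reference so the tensors can be
        # garbage-collected right away.
        func_ctx.tensor_objects = None
        return grad_out @ weight, grad_out.t() @ x
```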
Last reviewed commit: 8bbc047
Member
/te-ci pytorch
ptrendx previously approved these changes (Mar 10, 2026)
Member
@kainzhong Could you take a look at the failed CI?
Signed-off-by: Kaining Zhong <kainingz@nvidia.com>
Force-pushed from e945732 to 491779c
Collaborator
Author
OK, I think the mocked fuser backward in the tests was what broke CI; I've removed the incompatible `BasicLinear` change from the PR (see Description).
Collaborator
Author
/te-ci pytorch
Description
When training llama3 405B we observed OOMs when passing `use_te_op_fuser=True` in Megatron. This is because `OperationFuser` fails to properly release tensors that are saved for backward.

Type of change
Changes
Please list the changes introduced in this PR:
In `_OperationFuserAutogradFunction`'s backward, we use `restore_from_saved` to restore the saved tensors for backward. However, we didn't detach `tensor_objects` from `func_ctx`, and `func_ctx` is not released until the next iteration due to PyTorch's internal mechanism (PyTorch clears tensors saved via `save_for_backward`, but not tensors attached to `ctx` directly). Therefore the attached tensors are not released in time, causing higher memory usage. After manually setting `func_ctx.tensor_objects` to `None`, these tensors can be properly freed. A runnable toy repro of this pattern is sketched below.

~~In addition, in `BasicLinear`'s backward, I manually called `clear_tensor_data` on the weight tensor if it's allocated by the quantizer (`w is not self.weight`). This is not strictly necessary since, with the previous fix, the weight tensor will still be released after the fuser's backward, so it's more of an optimization.~~ Removed this from the PR because it's not compatible with the mocked fuser backward in tests.
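To make the leak concrete, here is a small, self-contained repro of the pattern (assumption: `Leaky` and its `workspace` buffer are illustrative stand-ins, not the real fuser code):

```python
import weakref

import torch

class Leaky(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        workspace = x.detach().clone()    # stands in for a quantizer buffer
        ctx.save_for_backward(workspace)  # cleared by PyTorch after backward
        ctx.tensor_objects = [workspace]  # NOT cleared by PyTorch
        return x * 2

    @staticmethod
    def backward(ctx, grad_out):
        (workspace,) = ctx.saved_tensors  # the "restore" step
        # Uncommenting the next line is the fix this PR makes:
        # ctx.tensor_objects = None
        return grad_out * 2

x = torch.randn(4, requires_grad=True)
y = Leaky.apply(x)
ref = weakref.ref(y.grad_fn.tensor_objects[0])  # watch the workspace tensor
y.sum().backward()
# PyTorch has cleared saved_tensors, but as long as the ctx is alive
# (here via y.grad_fn), tensor_objects still pins the workspace:
print(ref() is not None)  # True without the fix, False with it
```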
(Tested on Megatron-LM's llama3 8B example with manually setting `use_te_op_fuser=True` in `gpt_builders`.)

Checklist: