
Add NVTE_BACKWARD_MODE=default|unquant|dequant #2644

Open
zianglih wants to merge 51 commits into NVIDIA:main from zianglih:keep-bwd

Conversation


@zianglih zianglih commented Feb 3, 2026

Description

Add an NVTE_BACKWARD_MODE=default|unquant|dequant env var for quantized fprop with high-precision wgrad & dgrad (this supersedes the earlier NVTE_KEEP_BACKWARD_UNQUANTIZED name):

  • default: existing default quantization behavior
  • unquant: quantized fprop + high-precision wgrad & dgrad using the unquantized activation and weight
  • dequant: quantized fprop + high-precision wgrad & dgrad using the activation and weight dequantized directly from the quantized values produced in fprop

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes


greptile-apps bot commented Feb 3, 2026

Greptile Summary

This PR introduces NVTE_BACKWARD_MODE=default|unquant|dequant, an opt-in env var and recipe field that allows quantized forward passes to be paired with high-precision gradient computation: unquant uses the original activations/weights, while dequant dequantizes the saved quantized tensors at backward time. The feature is implemented across Linear, LayerNormLinear, GroupedLinear, and all te_ops fusible-op paths. LayerNormMLP explicitly asserts that these modes are unsupported, with a message redirecting users to LayerNormLinear + Linear.

Key changes:

  • backward_mode field added to all recipe dataclasses, resolved from NVTE_BACKWARD_MODE env var or explicit argument; DelayedScaling enforces default only.
  • Forward paths suppress columnwise quantization when backward_mode != "default", saving memory by not storing FP8 transpose layouts.
  • Backward paths explicitly dequantize saved FP8 tensors (dequant mode) or reference the original pre-quantization tensors (unquant mode) for high-precision dgrad/wgrad GEMMs.
  • Userbuffers overlap and fused backward activation ops are correctly disabled for unquant/dequant.
  • Empty-tensor guards added to .dequantize() in all three storage types to handle zero-M-dimension grouped splits.
  • A comprehensive 1810-line test file covers all layer types, recipes, shapes, and reference comparisons.
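The empty-tensor guard mentioned in the list above can be illustrated with a stub storage class. All names here are hypothetical stand-ins; the real code operates on TE's quantized-tensor storage types and torch tensors:

```python
from dataclasses import dataclass
from typing import List

@dataclass
class Float8Storage:
    # Hypothetical stand-in for a TE quantized-tensor storage type.
    data: List[float]   # quantized payload; empty for a zero-M grouped split
    scale_inv: float    # dequantization scale

    def dequantize(self) -> List[float]:
        # Guard: a zero-M split in a grouped GEMM produces an empty payload,
        # so return an empty result instead of scaling missing data.
        if not self.data:
            return []
        return [v * self.scale_inv for v in self.data]
```

Without the guard, a zero-M split from GroupedLinear would hit the scaling path with no payload; with it, dequantize() degrades gracefully to an empty result.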

Notable issues:

  • fuser.py (te_ops path): reduce_and_update_fp8_tensors(forward=False) is not suppressed for unquant/dequant modes in _OperationFuserAutogradFunction.backward(), unlike the module path where ctx.reduce_and_update_bwd_fp8_tensors = False is explicitly set. This inconsistency could corrupt backward FP8 scaling state when mixing modes or when using te_ops.Linear.
  • linear.py / layernorm_linear.py / grouped_linear.py: New additions to _get_quantizers() call FP8GlobalStateManager.get_fp8_recipe() without a null-safety guard, creating a latent AttributeError risk if the call site invariants change.
  • basic_linear.py (te_ops path): The dequant mode backward relies on maybe_dequantize implicitly handling QuantizedTensorStorage objects, an implicit contract that should be documented.

Confidence Score: 3/5

  • The feature is functionally correct for the module path and well-tested, but the te_ops (fusible ops) path has a behavioral inconsistency with FP8 amax update that needs resolution before merging.
  • The PR adds a well-structured opt-in feature with extensive tests (1810-line test file), correct handling in all module paths, and appropriate unsupported-path guards. However, fuser.py's backward still unconditionally triggers reduce_and_update_fp8_tensors(forward=False) for the te_ops path in unquant/dequant modes — inconsistent with the module path — which can corrupt backward FP8 scaling state. The get_fp8_recipe() null-guard issue in _get_quantizers() across three files also represents a latent crash risk.
  • transformer_engine/pytorch/ops/fuser.py (FP8 amax update inconsistency for te_ops path) and transformer_engine/pytorch/module/linear.py, layernorm_linear.py, grouped_linear.py (null-guard on get_fp8_recipe() in _get_quantizers()).

Important Files Changed

  • transformer_engine/common/recipe/__init__.py: Adds backward_mode field to all recipe dataclasses (DelayedScaling, Float8CurrentScaling, MXFP8BlockScaling, Float8BlockScaling, NVFP4BlockScaling, CustomRecipe). Uses a _resolve_backward_mode() helper to read from the NVTE_BACKWARD_MODE env var or validate an explicit argument. DelayedScaling.__post_init__ enforces backward_mode == "default". Implementation is clean and consistent across recipes.
  • transformer_engine/pytorch/module/linear.py: Adds backward_mode detection in _Linear.forward, implements unquant/dequant input/weight handling, overrides ctx.fp8=False and clears quantizers for non-default modes, and adds explicit weight dequantization in the dgrad backward. The new _get_quantizers() call at line 1543 calls get_fp8_recipe() without a null guard, which could cause an AttributeError if called outside an FP8 context.
  • transformer_engine/pytorch/ops/basic/basic_linear.py: Adds a backward_mode parameter to _functional_forward, correctly adjusts columnwise_usage and update_usage calls for unquant/dequant modes, and stores the appropriate tensors for backward. For dequant mode the te_ops path implicitly relies on maybe_dequantize handling QuantizedTensorStorage; this works today but the contract is not made explicit. The unquant mode always saves both the full-precision input_ and weight regardless of which gradients are actually needed.
  • transformer_engine/pytorch/ops/fuser.py: Adds backward_mode to the fusion cache key so that changing the mode triggers re-fusion. However, _OperationFuserAutogradFunction.backward() still unconditionally calls reduce_and_update_fp8_tensors(forward=False) when is_first_module is True, even for unquant/dequant modes where no FP8 backward GEMMs ran, creating an inconsistency with the module path, which explicitly disables this update.
  • transformer_engine/pytorch/module/layernorm_linear.py: Adds backward_mode handling to _LayerNormLinear: correctly saves high-precision ln_out_hp for unquant mode, dequantizes ln_out and weight explicitly in the dequant backward, and overrides ctx flags for non-default modes. The same get_fp8_recipe() null-guard concern exists at the end of _get_quantizers().
  • transformer_engine/pytorch/module/layernorm_mlp.py: Correctly gates the new feature with a clear assertion: assert backward_mode == "default". This prevents LayerNormMLP from being used with unquant/dequant modes, and the error message directs users to use LayerNormLinear + Linear instead.
  • transformer_engine/pytorch/module/grouped_linear.py: Adds backward_mode handling to _GroupedLinear: adjusts columnwise_usage, handles empty-split dequantization explicitly with a zero-tensor fallback, and overrides ctx flags in backward. The new _get_quantizers() additions share the same get_fp8_recipe() null-guard concern as linear.py.
  • tests/pytorch/test_backward_mode.py: New comprehensive test file (1810 lines) covering unquant and dequant backward modes for Linear, LayerNormLinear, ops.Linear, GroupedLinear, and all fused-ops patterns across all supported quantization recipes. Tests verify both forward output equality and backward gradient correctness against well-defined references.
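The recipe-level pattern described for transformer_engine/common/recipe can be sketched with plain dataclasses. This is a simplified illustration, not the actual TE classes; in TE the field is resolved from NVTE_BACKWARD_MODE rather than hard-defaulted:

```python
from dataclasses import dataclass

@dataclass
class Recipe:
    # Simplified stand-in for the TE recipe base; in TE, backward_mode is
    # resolved from the NVTE_BACKWARD_MODE env var or an explicit argument.
    backward_mode: str = "default"

@dataclass
class DelayedScaling(Recipe):
    def __post_init__(self):
        # Per the review summary, DelayedScaling enforces the default mode,
        # failing loudly instead of silently ignoring the user's setting.
        if self.backward_mode != "default":
            raise ValueError(
                "DelayedScaling supports only backward_mode='default', got "
                f"{self.backward_mode!r}"
            )
```

Validating in __post_init__ means the unsupported combination is rejected at recipe construction time, before any forward pass runs.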

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[FP8 Forward Pass] --> B{backward_mode?}
    B -->|default| C[Quantize input & weight\ncolumnwise=True for backward\nSave quantized tensors]
    B -->|unquant| D[Quantize input & weight\ncolumnwise=False\nSave ORIGINAL high-precision\ninput_ and self.weight]
    B -->|dequant| E[Quantize input & weight\ncolumnwise=False\nSave QUANTIZED FP8 tensors\nrowwise only]

    C --> F[Backward: FP8 GEMMs\nwith quantized grad_output\ndgrad uses columnwise weight\nwgrad uses columnwise input]
    D --> G[Backward: High-precision GEMMs\ndgrad uses original weight\nwgrad uses original input\nno dequantize needed]
    E --> H[Backward: High-precision GEMMs\nexplicit .dequantize on saved tensors\nfor dgrad weight and wgrad input]

    F --> I[Update FP8 amax\nfor backward quantizers]
    G --> J[ctx.fp8=False\nSkip FP8 amax update\nin module path]
    H --> J

    J --> K{te_ops path?}
    K -->|Yes| L[⚠️ fuser.py still calls\nreduce_and_update_fp8_tensors\nforward=False]
    K -->|No| M[✅ Correctly skipped]

Comments Outside Diff (3)

  1. transformer_engine/pytorch/ops/fuser.py, line 290-291 (link)

    FP8 amax update not suppressed for te_ops path in unquant/dequant mode

    In the _OperationFuserAutogradFunction.backward(), reduce_and_update_fp8_tensors(forward=False) still fires whenever func_ctx.is_first_module is True — even in unquant/dequant modes where no FP8 backward GEMMs ran and no amax data was accumulated for the backward quantizers.

    This is inconsistent with the module path (_Linear, _GroupedLinear, _LayerNormLinear), where the PR explicitly sets ctx.reduce_and_update_bwd_fp8_tensors = False in unquant/dequant modes. Without a similar guard here, the backward FP8 scaling factors for grad_output quantizers will be updated with stale amax values from the previous iteration. If a user switches back to default mode later, the backward quantizers will start from an incorrectly-scaled state.

    Consider gating the update on the backward mode, e.g.:

    # Update FP8 scaling factors only when FP8 backward GEMMs actually ran
    backward_mode = getattr(func_ctx, "backward_mode", "default")
    if func_ctx.is_first_module and not _is_graph_capturing() and backward_mode == "default":
        FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)

    This assumes backward_mode is stored in func_ctx during the forward pass (e.g. func_ctx.backward_mode = self.backward_mode) so that backward can check it here.

  2. transformer_engine/pytorch/module/linear.py, line 1543-1547 (link)

    Potential None dereference on get_fp8_recipe()

    FP8GlobalStateManager.get_fp8_recipe() can return None (e.g., when FP8 is enabled but no recipe object has been installed yet). Calling .backward_mode on None would produce an AttributeError. While _get_quantizers() is currently only reached when self.fp8 is True, a defensive guard is cheap and prevents a hard-to-diagnose crash if the invariant is ever violated.

    The same pattern appears in layernorm_linear.py and grouped_linear.py at equivalent locations.

  3. transformer_engine/pytorch/ops/basic/basic_linear.py, line 1029-1035 (link)

    dequant backward for te_ops path uses implicit maybe_dequantize contract

    In dequant mode, op_forward stores quantized (QuantizedTensorStorage) tensors as saved_input and saved_weight, then sets ctx.with_quantized_compute = False. In op_backward, these pass to _functional_backward with with_quantized_compute=False, which silently calls maybe_dequantize(x_local, dtype) and maybe_dequantize(w, dtype).

    Unlike the module path (_Linear, _LayerNormLinear) where explicit if ctx.backward_mode == "dequant": ... .dequantize(...) calls are added in backward, the te_ops path relies on maybe_dequantize transparently handling QuantizedTensorStorage objects. This implicit contract works today (the storage types all have dequantize() methods), but it is fragile — any new storage subclass that doesn't implement dequantize() correctly will silently fall through to a wrong result rather than failing with a clear error.

    Consider adding an explicit check or comment to make the contract explicit:

    # Save state for backward pass
    if ctx.requires_grad:
        saved_input = input_ if backward_mode == "unquant" else x_local
        saved_weight = self.weight if backward_mode == "unquant" else w
        # In dequant mode, saved_input/saved_weight may be QuantizedTensorStorage;
        # _functional_backward with with_quantized_compute=False will call
        # maybe_dequantize() which dispatches to .dequantize() on those objects.
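For illustration, an explicit version of that dispatch contract might look like the sketch below. The names are hypothetical stand-ins; the real maybe_dequantize lives in TE's tensor utilities and also handles dtype conversion:

```python
class FakeFP8Storage:
    # Hypothetical storage type that fulfills the contract discussed above.
    def __init__(self, data, scale_inv):
        self.data, self.scale_inv = data, scale_inv

    def dequantize(self):
        return [v * self.scale_inv for v in self.data]

def maybe_dequantize(x):
    # Explicit form of the implicit contract: quantized storages must expose
    # dequantize(); plain high-precision values pass through unchanged.
    if hasattr(x, "dequantize"):
        return x.dequantize()
    return x
```

Making the dispatch explicit (or at least documenting it next to the save-for-backward code, as the comment suggests) keeps a new storage subclass from silently producing wrong gradients.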

Last reviewed commit: a2b5250

@greptile-apps greptile-apps bot left a comment

17 files reviewed, 3 comments



zianglih commented Feb 3, 2026

I'll work on potential unit test breakage.


 # Note: dgrad GEMM requires row-wise usage, wgrad GEMM
 # requires column-wise usage
-if ctx.grad_output_quantizer is not None:
+if ctx.grad_output_quantizer is not None and use_fp8_bwd:
Collaborator commented:
this line seems redundant since you already skip the quantization step in base.py grad_output_preprocess?

not ctx.use_bias
and not ctx.requires_wgrad
and ctx.grad_output_quantizer is not None
and use_fp8_bwd
Collaborator commented:
same comment as above

recipe = cls.get_fp8_recipe()
if recipe is not None and recipe.delayed():
# Ignore NVTE_KEEP_BACKWARD_UNQUANTIZED when delayed scaling is used
return False
Collaborator commented:
Maybe it's better to assert an error for delayed scaling? Okay with both.

Collaborator commented:
I agree. If the user specifies an unsupported combination, I think it's better to fail loudly than to secretly disobey their instructions.

 # Note: dgrad GEMM requires row-wise usage, wgrad GEMM
 # requires column-wise usage
-if ctx.grad_output_quantizer is not None:
+if ctx.grad_output_quantizer is not None and use_fp8_bwd:
Collaborator commented:
this seems redundant too if we skip quant in grad_output_preprocess

zianglih and others added 26 commits March 11, 2026 20:42
Signed-off-by: Ziang Li <ziangli@umich.edu>

negvet commented Mar 12, 2026

/te-ci pytorch L1


Labels

community-contribution PRs from external contributor outside the core maintainers, representing community-driven work.


5 participants