diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index f2b0b07fed..a9df8a1bb6 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -42,6 +42,7 @@ python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_gqa.xml $TE_PATH python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_optimizer.xml $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_multi_tensor.xml $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/test_fusible_ops.py || test_fail "test_fusible_ops.py" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_backward_mode.xml $TE_PATH/tests/pytorch/test_backward_mode.py || test_fail "test_backward_mode.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_permutation.xml $TE_PATH/tests/pytorch/test_permutation.py || test_fail "test_permutation.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entropy.xml $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py" diff --git a/tests/pytorch/test_backward_mode.py b/tests/pytorch/test_backward_mode.py new file mode 100644 index 0000000000..3aa47e1166 --- /dev/null +++ b/tests/pytorch/test_backward_mode.py @@ -0,0 +1,1810 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +from __future__ import annotations + +from contextlib import nullcontext +import math +from typing import Optional + +import pytest +import torch + +import transformer_engine.pytorch as te +import transformer_engine.pytorch.ops as te_ops +from transformer_engine.common import recipe +from transformer_engine.pytorch.cpp_extensions import general_gemm, layernorm_bwd +from transformer_engine.pytorch.quantization import FP8GlobalStateManager +from transformer_engine.pytorch.ops.fused import ( + BackwardActivationBias, + ForwardLinearBiasActivation, + ForwardLinearBiasAdd, + ForwardLinearScaleAdd, + UserbuffersForwardLinear, +) +from transformer_engine.pytorch.quantized_tensor import restore_from_saved + +from utils import ( + assert_close, + make_recipe, + reset_rng_states, + skip_unsupported_backward_mode, +) + + +# -------------------------- +# Mode and capability config +# -------------------------- + +_NON_QUANT_BACKWARD_MODES = ("unquant", "dequant") + +fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) +mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True) +fp8_block_scaling_available, reason_for_no_fp8_block_scaling = te.is_fp8_block_scaling_available( + return_reason=True +) +nvfp4_available, reason_for_no_nvfp4 = te.is_nvfp4_available(return_reason=True) +bf16_available, reason_for_no_bf16 = te.is_bf16_available(return_reason=True) + +_core_dtypes = [torch.float16, torch.float32] +_fused_dtypes = [torch.float16] +if bf16_available: + _core_dtypes.insert(1, torch.bfloat16) + _fused_dtypes.insert(1, torch.bfloat16) + + +@pytest.fixture(autouse=True) +def _reset_global_fp8_state(): + """Avoid global FP8-state leakage between parametrized cases.""" + yield + FP8GlobalStateManager.reset() + + +@pytest.fixture(params=_NON_QUANT_BACKWARD_MODES, ids=lambda mode: f"mode_{mode}") +def backward_mode(request: pytest.FixtureRequest) -> str: + """Backward mode under test.""" + return request.param + + +# 
-------------------------- +# Shared helpers +# -------------------------- + + +def _restore_saved_operands(output: torch.Tensor) -> list[Optional[torch.Tensor]]: + if output.grad_fn is None: + raise RuntimeError("Output tensor has no grad_fn; cannot inspect saved operands") + if not hasattr(output.grad_fn, "tensor_objects"): + raise RuntimeError("grad_fn does not expose tensor_objects for saved operand restoration") + return restore_from_saved(output.grad_fn.tensor_objects, list(output.grad_fn.saved_tensors)) + + +def _extract_linear_saved_operands( + saved_operands: list[Optional[torch.Tensor]], + *, + context: str, +) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: + if len(saved_operands) < 2: + raise RuntimeError( + f"Insufficient saved operands for {context} dequant reference " + f"(got {len(saved_operands)}, expected at least 2)." + ) + return saved_operands[0], saved_operands[1] + + +def _dequantize_saved_operand( + saved_operand: Optional[torch.Tensor], + dtype: torch.dtype, +) -> torch.Tensor: + if saved_operand is None: + raise RuntimeError("Expected saved operand but got None") + # In dequant mode we must consume the fprop-saved quantized payload directly. + # If row-wise payload is missing, the tensor was retargeted to a transpose-only + # layout and no longer represents the original fprop operand. + if ( + not isinstance(saved_operand, torch.Tensor) + and hasattr(saved_operand, "_rowwise_data") + and getattr(saved_operand, "_rowwise_data") is None + ): + raise RuntimeError( + "Saved dequant operand lost row-wise fprop payload (likely usage retarget)." 
+ ) + if isinstance(saved_operand, torch.Tensor): + return saved_operand.to(dtype) + if not hasattr(saved_operand, "dequantize"): + raise RuntimeError(f"Unsupported saved operand type: {type(saved_operand)}") + return saved_operand.dequantize(dtype=dtype) + + +def _assert_saved_quantized_operand_uses_rowwise_only( + saved_operand: Optional[torch.Tensor], + *, + name: str, +) -> None: + if saved_operand is None: + raise RuntimeError(f"Expected quantized saved {name} operand but got None") + if isinstance(saved_operand, torch.Tensor): + raise RuntimeError( + f"Dequant reference expects quantized saved {name} operand, got torch.Tensor." + ) + if not hasattr(saved_operand, "dequantize"): + raise RuntimeError(f"Unsupported saved {name} operand type: {type(saved_operand)}") + if hasattr(saved_operand, "_rowwise_data") and getattr(saved_operand, "_rowwise_data") is None: + raise RuntimeError( + f"Saved dequant {name} operand lost row-wise fprop payload (likely usage retarget)." + ) + if ( + hasattr(saved_operand, "_columnwise_data") + and getattr(saved_operand, "_columnwise_data") is not None + ): + raise RuntimeError( + f"Saved dequant {name} operand unexpectedly carries column-wise payload." 
+ ) + + +def _snapshot_saved_quantized_operand_layout( + saved_operand: Optional[torch.Tensor], + *, + name: str, +) -> dict[str, object]: + _assert_saved_quantized_operand_uses_rowwise_only(saved_operand, name=name) + rowwise_present = None + columnwise_present = None + rowwise_obj_id = None + if hasattr(saved_operand, "_rowwise_data"): + rowwise_data = getattr(saved_operand, "_rowwise_data") + rowwise_present = rowwise_data is not None + if rowwise_data is not None: + rowwise_obj_id = id(rowwise_data) + if hasattr(saved_operand, "_columnwise_data"): + columnwise_present = getattr(saved_operand, "_columnwise_data") is not None + return { + "name": name, + "saved_operand": saved_operand, + "rowwise_present": rowwise_present, + "columnwise_present": columnwise_present, + "rowwise_obj_id": rowwise_obj_id, + } + + +def _assert_saved_quantized_operand_layout_unchanged(snapshot: dict[str, object]) -> None: + name = snapshot.get("name") + if not isinstance(name, str): + raise RuntimeError(f"Invalid saved operand snapshot name: {name!r}") + saved_operand = snapshot.get("saved_operand") + _assert_saved_quantized_operand_uses_rowwise_only(saved_operand, name=name) + + rowwise_present = snapshot.get("rowwise_present") + if isinstance(rowwise_present, bool): + rowwise_data_now = getattr(saved_operand, "_rowwise_data", None) + rowwise_now = rowwise_data_now is not None + if rowwise_now != rowwise_present: + raise RuntimeError( + f"Saved dequant {name} operand row-wise payload presence changed " + f"from {rowwise_present} to {rowwise_now}." + ) + # Guard against hidden requantization that swaps in a new row-wise payload. + rowwise_obj_id = snapshot.get("rowwise_obj_id") + if ( + isinstance(rowwise_obj_id, int) + and rowwise_now + and id(rowwise_data_now) != rowwise_obj_id + ): + raise RuntimeError( + f"Saved dequant {name} operand row-wise payload identity changed " + "(likely rewritten/requantized)." 
+ ) + + columnwise_present = snapshot.get("columnwise_present") + if isinstance(columnwise_present, bool): + columnwise_now = getattr(saved_operand, "_columnwise_data", None) is not None + if columnwise_now != columnwise_present: + raise RuntimeError( + f"Saved dequant {name} operand column-wise payload presence changed " + f"from {columnwise_present} to {columnwise_now}." + ) + + +def _snapshot_layout_invariants( + guard_operands: list[tuple[str, Optional[torch.Tensor]]], +) -> list[dict[str, object]]: + """Capture saved-operand layout invariants before backward runs.""" + return [ + _snapshot_saved_quantized_operand_layout(saved_operand, name=name) + for name, saved_operand in guard_operands + ] + + +def _assert_layout_invariants_unchanged(layout_invariants: list[dict[str, object]]) -> None: + """Validate saved-operand layout invariants after backward runs.""" + for layout_invariant in layout_invariants: + _assert_saved_quantized_operand_layout_unchanged(layout_invariant) + + +def _raise_if_ref_failed(ref_exc: Optional[Exception]) -> None: + """Re-raise deferred reference exceptions after layout checks.""" + if ref_exc is not None: + raise ref_exc + + +def _compute_linear_backward_reference_from_saved_operands( + saved_input: Optional[torch.Tensor], + saved_weight: Optional[torch.Tensor], + dy: torch.Tensor, + *, + dequant_dtype: torch.dtype, + out_dtype: torch.dtype, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # Dequant reference path: + # 1) use the exact operands saved by quantized forward, + # 2) dequantize them to the active high-precision compute dtype, + # 3) run backward GEMMs in high precision and compare exactly. + for name, saved_operand in (("input", saved_input), ("weight", saved_weight)): + _assert_saved_quantized_operand_uses_rowwise_only(saved_operand, name=name) + dy_mat = dy.reshape(-1, dy.shape[-1]) + + # Empty-token chunks can happen in grouped/fused paths. Reference should be zeros. 
+ if dy_mat.shape[0] == 0: + out_features = dy_mat.shape[-1] + if saved_input is None: + raise RuntimeError("Expected saved input operand for empty-chunk dequant reference.") + in_features = saved_input.size(-1) + dx_ref = torch.zeros(*dy.shape[:-1], in_features, dtype=out_dtype, device=dy.device) + dw_ref = torch.zeros(out_features, in_features, dtype=out_dtype, device=dy.device) + db_ref = torch.zeros(out_features, dtype=out_dtype, device=dy.device) + return dx_ref, dw_ref, db_ref + + x_ref_full = _dequantize_saved_operand(saved_input, dequant_dtype) + x_ref = x_ref_full.reshape(-1, x_ref_full.shape[-1]) + w_ref = _dequantize_saved_operand(saved_weight, dequant_dtype) + + dx_ref_2d, *_ = general_gemm( + w_ref, + dy_mat, + out_dtype=out_dtype, + layout="NN", + grad=True, + ) + # Derive db from the same GEMM primitive used by runtime wgrad. This avoids + # tiny reduction-order drift vs. a standalone dy.sum() path in FP32 cases. + db_seed = torch.empty(dy_mat.shape[-1], dtype=out_dtype, device=dy_mat.device) + dw_ref, db_ref, *_ = general_gemm( + x_ref, + dy_mat, + out_dtype=out_dtype, + layout="NT", + grad=True, + bias=db_seed, + ) + if db_ref is None: + db_ref = dy_mat.sum(dim=0).to(out_dtype) + dx_ref = dx_ref_2d.view(*dy.shape[:-1], dx_ref_2d.shape[-1]) + return dx_ref, dw_ref, db_ref + + +_quantized_numerics_recipe_list = [ + pytest.param( + "fp8_current_scaling", + marks=pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8), + id="Float8CurrentScaling", + ), + pytest.param( + "mxfp8", + marks=pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8), + id="MXFP8BlockScaling", + ), + pytest.param( + "fp8_block_scaling", + marks=pytest.mark.skipif( + not fp8_block_scaling_available, + reason=reason_for_no_fp8_block_scaling, + ), + id="Float8BlockScaling", + ), + pytest.param( + "nvfp4", + marks=pytest.mark.skipif(not nvfp4_available, reason=reason_for_no_nvfp4), + id="NVFP4BlockScaling", + ), +] + +_shape_test_cases = [ + pytest.param((1, 
64), 64, id="2d_m1_k64_n64"), + pytest.param((32, 64), 64, id="2d_m32_k64_n64"), + pytest.param((32, 96), 96, id="2d_m32_k96_n96"), + pytest.param((32, 1, 64), 64, id="3d_m32_s1_k64_n64"), + pytest.param((8, 4, 64), 128, id="3d_m32_k64_n128"), + pytest.param((16, 2, 128), 64, id="3d_m32_k128_n64"), + pytest.param((160, 64), 64, id="2d_m160_k64_n64"), + pytest.param((5, 64, 64), 64, id="3d_m320_k64_n64"), + pytest.param((3, 5, 32, 64), 96, id="4d_m480_k64_n96"), + pytest.param((2, 5, 16, 128), 64, id="4d_m160_k128_n64"), + # Intentionally unaligned token dimensions to exercise skip/support logic. + pytest.param((3, 64), 64, id="2d_m3_k64_n64_unaligned"), + pytest.param((3, 10, 64), 64, id="3d_m30_k64_n64_unaligned"), + pytest.param((3, 10, 96), 96, id="3d_m30_k96_n96_unaligned"), +] + +_bias_activation_shape_cases = [ + pytest.param((32, 64), id="2d_m32_k64"), + pytest.param((32, 96), id="2d_m32_k96"), + pytest.param((8, 4, 64), id="3d_m32_k64"), + pytest.param((160, 64), id="2d_m160_k64"), + pytest.param((5, 64, 64), id="3d_m320_k64"), + pytest.param((3, 5, 32, 64), id="4d_m480_k64"), + # Intentionally unaligned token dimensions to exercise skip/support logic. 
+ pytest.param((3, 64), id="2d_m3_k64_unaligned"), + pytest.param((3, 10, 64), id="3d_m30_k64_unaligned"), + pytest.param((3, 10, 96), id="3d_m30_k96_unaligned"), +] + +_grouped_m_split_cases = [ + pytest.param([32, 32, 32, 32], id="uniform_splits"), + pytest.param([64, 0, 32, 32], id="with_empty_split"), + pytest.param([1, 31, 0, 96], id="small_and_empty_splits"), +] + +_linear_feature_cases = [ + pytest.param(64, 64, id="k64_n64"), + pytest.param(64, 128, id="k64_n128"), + pytest.param(128, 64, id="k128_n64"), + pytest.param(96, 96, id="k96_n96"), + pytest.param(64, 96, id="k64_n96"), + pytest.param(96, 64, id="k96_n64"), + pytest.param(128, 96, id="k128_n96"), + pytest.param(96, 128, id="k96_n128"), +] + +_output_feature_cases = [ + pytest.param(64, id="n64"), + pytest.param(96, id="n96"), + pytest.param(128, id="n128"), +] + + +def _copy_named_parameters(src_module: torch.nn.Module, dst_module: torch.nn.Module) -> None: + src_params = dict(src_module.named_parameters()) + with torch.no_grad(): + for name, dst_param in dst_module.named_parameters(): + if name not in src_params: + raise RuntimeError(f"Parameter {name} missing in source module") + dst_param.copy_(src_params[name]) + + +def _maybe_skip_recipe_dtype(recipe_name: str, dtype: torch.dtype) -> None: + if dtype == torch.bfloat16 and not bf16_available: + pytest.skip(reason_for_no_bf16) + if recipe_name == "nvfp4" and dtype != torch.bfloat16: + pytest.skip("NVFP4 is only supported with BF16 in this test") + + +def _make_linear_like_module( + module_type: str, + in_features: int, + out_features: int, + dtype: torch.dtype, + *, + bias: bool, +) -> torch.nn.Module: + if module_type == "linear": + return te.Linear( + in_features, + out_features, + bias=bias, + params_dtype=dtype, + device="cuda", + ) + if module_type == "layernorm_linear": + return te.LayerNormLinear( + in_features, + out_features, + bias=bias, + params_dtype=dtype, + device="cuda", + ) + if module_type == "ops_linear": + return 
te_ops.Linear( + in_features, + out_features, + bias=bias, + dtype=dtype, + device="cuda", + ) + raise ValueError(f"Unsupported module type: {module_type}") + + +def _maybe_skip_unsupported_recipe_module_combo(recipe_name: str, module_type: str) -> None: + if module_type == "ops_linear" and recipe_name == "fp8_block_scaling": + pytest.skip("Fusible ops (te_ops.Linear) do not support Float8BlockScaling recipe") + + +def _maybe_skip_unsupported_recipe_shape( + recipe_name: str, + input_shape: tuple[int, ...], + module_type: str, +) -> None: + flat_first_dim = math.prod(input_shape[:-1]) + last_dim = input_shape[-1] + + if module_type in ("linear", "layernorm_linear"): + if flat_first_dim % 8 != 0 or last_dim % 16 != 0: + pytest.skip( + "Linear/LayerNormLinear FP8 execution requires prod(shape[:-1]) divisible by 8 " + "and shape[-1] divisible by 16." + ) + return + + if module_type == "ops_linear": + if recipe_name == "mxfp8" and (flat_first_dim % 32 != 0 or last_dim % 32 != 0): + pytest.skip( + "te_ops.Linear + MXFP8 requires prod(shape[:-1]) and shape[-1] divisible by 32." + ) + if recipe_name == "nvfp4" and (flat_first_dim % 16 != 0 or last_dim % 16 != 0): + pytest.skip( + "te_ops.Linear + NVFP4 requires prod(shape[:-1]) and shape[-1] divisible by 16." + ) + + +def _maybe_skip_unsupported_grouped_splits(recipe_name: str, m_splits: list[int]) -> None: + non_empty_splits = [m for m in m_splits if m > 0] + if recipe_name == "mxfp8" and any(m % 32 != 0 for m in non_empty_splits): + pytest.skip("GroupedLinear + MXFP8 requires each non-empty m_split divisible by 32.") + if recipe_name == "fp8_block_scaling" and any(m % 4 != 0 for m in non_empty_splits): + pytest.skip( + "GroupedLinear + Float8BlockScaling requires each non-empty m_split divisible by 4." 
+ ) + + +def _run_single_step( + module: torch.nn.Module, + x: torch.Tensor, + dy: torch.Tensor, + fp8_recipe: Optional[recipe.Recipe], +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + module.zero_grad(set_to_none=True) + x_run = x.detach().clone().requires_grad_(True) + autocast_ctx = ( + te.autocast(enabled=True, recipe=fp8_recipe) if fp8_recipe is not None else nullcontext() + ) + with autocast_ctx: + y = module(x_run) + if isinstance(y, tuple): + y = y[0] + y.backward(dy) + assert x_run.grad is not None + assert module.weight.grad is not None + bias = getattr(module, "bias", None) + bgrad = None if bias is None or bias.grad is None else bias.grad.detach().clone() + return ( + y.detach().clone(), + x_run.grad.detach().clone(), + module.weight.grad.detach().clone(), + bgrad, + ) + + +def _run_single_step_with_saved_operands( + module: torch.nn.Module, + x: torch.Tensor, + fp8_recipe: recipe.Recipe, +) -> tuple[ + torch.Tensor, + torch.Tensor, + list[Optional[torch.Tensor]], +]: + module.zero_grad(set_to_none=True) + x_run = x.detach().clone().requires_grad_(True) + with te.autocast(enabled=True, recipe=fp8_recipe): + y = module(x_run) + if isinstance(y, tuple): + y = y[0] + saved_operands = _restore_saved_operands(y) + return y, x_run, saved_operands + + +def _run_grouped_linear_single_step( + module: te.GroupedLinear, + x: torch.Tensor, + m_splits: list[int], + dy: torch.Tensor, + fp8_recipe: Optional[recipe.Recipe], +) -> tuple[torch.Tensor, torch.Tensor, list[torch.Tensor], list[Optional[torch.Tensor]]]: + module.zero_grad(set_to_none=True) + x_run = x.detach().clone().requires_grad_(True) + autocast_ctx = ( + te.autocast(enabled=True, recipe=fp8_recipe) if fp8_recipe is not None else nullcontext() + ) + with autocast_ctx: + y = module(x_run, m_splits) + y.backward(dy) + assert x_run.grad is not None + + dw = [getattr(module, f"weight{i}").grad.detach().clone() for i in range(module.num_gemms)] + db: list[Optional[torch.Tensor]] 
= [] + for i in range(module.num_gemms): + if module.use_bias: + db.append(getattr(module, f"bias{i}").grad.detach().clone()) + else: + db.append(None) + return y.detach().clone(), x_run.grad.detach().clone(), dw, db + + +def _run_grouped_linear_step_with_saved_operands( + module: te.GroupedLinear, + x: torch.Tensor, + m_splits: list[int], + fp8_recipe: recipe.Recipe, +) -> tuple[ + torch.Tensor, + torch.Tensor, + list[Optional[torch.Tensor]], +]: + module.zero_grad(set_to_none=True) + x_run = x.detach().clone().requires_grad_(True) + with te.autocast(enabled=True, recipe=fp8_recipe): + y = module(x_run, m_splits) + saved_operands = _restore_saved_operands(y) + return y, x_run, saved_operands + + +def _make_fused_model( + pattern: str, + in_features: int, + out_features: int, + dtype: torch.dtype, + *, + scale: float = 0.5, +) -> te_ops.Sequential: + if pattern == "bias_activation": + return te_ops.Sequential( + te_ops.Linear(in_features, out_features, bias=True, device="cuda", dtype=dtype), + te_ops.ReLU(), + ) + if pattern == "bias_add": + return te_ops.Sequential( + te_ops.Linear(in_features, out_features, bias=True, device="cuda", dtype=dtype), + te_ops.AddExtraInput(in_place=True), + ) + if pattern == "scale_add": + return te_ops.Sequential( + te_ops.Linear(in_features, out_features, bias=False, device="cuda", dtype=dtype), + te_ops.ConstantScale(scale), + te_ops.AddExtraInput(in_place=True), + ) + raise ValueError(f"Unsupported fused test pattern: {pattern}") + + +def _run_fused_single_step( + pattern: str, + model: te_ops.Sequential, + x1: torch.Tensor, + dy: torch.Tensor, + fp8_recipe: Optional[recipe.Recipe], + *, + x2: Optional[torch.Tensor] = None, +) -> tuple[ + torch.Tensor, + torch.Tensor, + Optional[torch.Tensor], + torch.Tensor, + Optional[torch.Tensor], +]: + model.zero_grad(set_to_none=True) + x1_run = x1.detach().clone().requires_grad_(True) + x2_run = x2.detach().clone().requires_grad_(True) if x2 is not None else None + autocast_ctx = ( + 
te.autocast(enabled=True, recipe=fp8_recipe) if fp8_recipe is not None else nullcontext() + ) + with autocast_ctx: + if pattern in ("bias_add", "scale_add"): + assert x2_run is not None + y = model(x1_run, x2_run) + else: + y = model(x1_run) + y.backward(dy) + assert x1_run.grad is not None + + dw = model[0].weight.grad.detach().clone() + db = None + if getattr(model[0], "bias", None) is not None and model[0].bias.grad is not None: + db = model[0].bias.grad.detach().clone() + dx2 = x2_run.grad.detach().clone() if x2_run is not None and x2_run.grad is not None else None + return y.detach().clone(), x1_run.grad.detach().clone(), dx2, dw, db + + +def _run_fused_single_step_with_saved_operands( + pattern: str, + model: te_ops.Sequential, + x1: torch.Tensor, + fp8_recipe: recipe.Recipe, + *, + x2: Optional[torch.Tensor] = None, +) -> tuple[ + torch.Tensor, + torch.Tensor, + Optional[torch.Tensor], + list[Optional[torch.Tensor]], +]: + model.zero_grad(set_to_none=True) + x1_run = x1.detach().clone().requires_grad_(True) + x2_run = x2.detach().clone().requires_grad_(True) if x2 is not None else None + with te.autocast(enabled=True, recipe=fp8_recipe): + if pattern in ("bias_add", "scale_add"): + assert x2_run is not None + y = model(x1_run, x2_run) + else: + y = model(x1_run) + saved_operands = _restore_saved_operands(y) + return y, x1_run, x2_run, saved_operands + + +def _run_quantize_op_single_step( + model: te_ops.Sequential, + x: torch.Tensor, + dy: torch.Tensor, + fp8_recipe: Optional[recipe.Recipe], +) -> tuple[torch.Tensor, torch.Tensor]: + x_run = x.detach().clone().requires_grad_(True) + autocast_ctx = ( + te.autocast(enabled=True, recipe=fp8_recipe) if fp8_recipe is not None else nullcontext() + ) + with autocast_ctx: + y = model(x_run) + y.backward(dy) + assert x_run.grad is not None + return y.detach().clone(), x_run.grad.detach().clone() + + +def _snapshot_backward_ctx_state( + output: torch.Tensor, +) -> tuple[str, bool, object, bool]: + if output.grad_fn is 
None: + raise RuntimeError("Output tensor has no grad_fn; cannot inspect backward context state.") + required_attrs = ( + "backward_mode", + "fp8", + "grad_output_quantizer", + "reduce_and_update_bwd_fp8_tensors", + ) + missing_attrs = [attr for attr in required_attrs if not hasattr(output.grad_fn, attr)] + if missing_attrs: + raise RuntimeError( + "grad_fn does not expose required backward context attributes: " + f"{', '.join(missing_attrs)}." + ) + return ( + getattr(output.grad_fn, "backward_mode"), + bool(getattr(output.grad_fn, "fp8")), + getattr(output.grad_fn, "grad_output_quantizer"), + bool(getattr(output.grad_fn, "reduce_and_update_bwd_fp8_tensors")), + ) + + +def _run_single_step_with_ctx_state( + module: torch.nn.Module, + x: torch.Tensor, + dy: torch.Tensor, + fp8_recipe: Optional[recipe.Recipe], +) -> tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + Optional[torch.Tensor], + tuple[str, bool, object, bool], +]: + module.zero_grad(set_to_none=True) + x_run = x.detach().clone().requires_grad_(True) + autocast_ctx = ( + te.autocast(enabled=True, recipe=fp8_recipe) if fp8_recipe is not None else nullcontext() + ) + with autocast_ctx: + y = module(x_run) + if isinstance(y, tuple): + y = y[0] + ctx_state = _snapshot_backward_ctx_state(y) + y.backward(dy) + assert x_run.grad is not None + assert module.weight.grad is not None + bias = getattr(module, "bias", None) + bgrad = None if bias is None or bias.grad is None else bias.grad.detach().clone() + return ( + y.detach().clone(), + x_run.grad.detach().clone(), + module.weight.grad.detach().clone(), + bgrad, + ctx_state, + ) + + +def _run_grouped_linear_single_step_with_ctx_state( + module: te.GroupedLinear, + x: torch.Tensor, + m_splits: list[int], + dy: torch.Tensor, + fp8_recipe: Optional[recipe.Recipe], +) -> tuple[ + torch.Tensor, + torch.Tensor, + list[torch.Tensor], + list[Optional[torch.Tensor]], + tuple[str, bool, bool], +]: + module.zero_grad(set_to_none=True) + x_run = 
x.detach().clone().requires_grad_(True) + autocast_ctx = ( + te.autocast(enabled=True, recipe=fp8_recipe) if fp8_recipe is not None else nullcontext() + ) + with autocast_ctx: + y = module(x_run, m_splits) + if y.grad_fn is None: + raise RuntimeError( + "Output tensor has no grad_fn; cannot inspect grouped backward state." + ) + required_attrs = ( + "backward_mode", + "fp8", + "reduce_and_update_bwd_fp8_tensors", + ) + missing_attrs = [attr for attr in required_attrs if not hasattr(y.grad_fn, attr)] + if missing_attrs: + raise RuntimeError( + "Grouped grad_fn does not expose required backward context attributes: " + f"{', '.join(missing_attrs)}." + ) + ctx_state = ( + getattr(y.grad_fn, "backward_mode"), + bool(getattr(y.grad_fn, "fp8")), + bool(getattr(y.grad_fn, "reduce_and_update_bwd_fp8_tensors")), + ) + y.backward(dy) + assert x_run.grad is not None + + dw = [getattr(module, f"weight{i}").grad.detach().clone() for i in range(module.num_gemms)] + db: list[Optional[torch.Tensor]] = [] + for i in range(module.num_gemms): + if module.use_bias: + db.append(getattr(module, f"bias{i}").grad.detach().clone()) + else: + db.append(None) + return y.detach().clone(), x_run.grad.detach().clone(), dw, db, ctx_state + + +# -------------------------- +# Tests +# -------------------------- + + +@pytest.mark.parametrize("recipe_name", _quantized_numerics_recipe_list) +def test_backward_mode_recipe_matches_requested_mode( + recipe_name: str, + backward_mode: str, +) -> None: + mode_recipe = make_recipe(recipe_name, backward_mode=backward_mode) + quant_recipe = make_recipe(recipe_name, backward_mode="default") + assert mode_recipe.backward_mode == backward_mode + assert quant_recipe.backward_mode == "default" + + +@pytest.mark.parametrize("recipe_name", _quantized_numerics_recipe_list) +@pytest.mark.parametrize("module_type", ("linear", "layernorm_linear", "ops_linear")) +@pytest.mark.parametrize("input_shape,out_features", _shape_test_cases) +@pytest.mark.parametrize("use_bias", 
(False, True), ids=("no_bias", "bias")) +@pytest.mark.parametrize("dtype", _core_dtypes, ids=str) +def test_linear_like_backward_mode_matches_reference( + recipe_name: str, + module_type: str, + input_shape: tuple[int, ...], + out_features: int, + use_bias: bool, + dtype: torch.dtype, + backward_mode: str, +) -> None: + reset_rng_states() + _maybe_skip_recipe_dtype(recipe_name, dtype) + _maybe_skip_unsupported_recipe_module_combo(recipe_name, module_type) + _maybe_skip_unsupported_recipe_shape(recipe_name, input_shape, module_type) + + in_features = input_shape[-1] + quantized_ref_recipe = make_recipe(recipe_name, backward_mode="default") + mode_recipe = make_recipe(recipe_name, backward_mode=backward_mode) + skip_unsupported_backward_mode(module_type, mode_recipe, backward_mode) + + module_quantized_ref = _make_linear_like_module( + module_type, + in_features, + out_features, + dtype, + bias=use_bias, + ) + module_bwd_mode = _make_linear_like_module( + module_type, + in_features, + out_features, + dtype, + bias=use_bias, + ) + _copy_named_parameters(module_quantized_ref, module_bwd_mode) + + output_shape = input_shape[:-1] + (out_features,) + x = torch.randn(*input_shape, dtype=dtype, device="cuda") + dy = torch.randn(*output_shape, dtype=dtype, device="cuda") + + y_quantized_ref, _, _, _ = _run_single_step(module_quantized_ref, x, dy, quantized_ref_recipe) + if backward_mode == "unquant": + # Unquant reference path: compare against a plain high-precision backward run + # (no fp8/autocast), starting from the same params and inputs. 
+ module_unquantized_ref = _make_linear_like_module( + module_type, + in_features, + out_features, + dtype, + bias=use_bias, + ) + _copy_named_parameters(module_quantized_ref, module_unquantized_ref) + y_bwd_mode, dx_bwd_mode, dw_bwd_mode, db_bwd_mode = _run_single_step( + module_bwd_mode, + x, + dy, + mode_recipe, + ) + _, dx_ref, dw_ref, db_ref = _run_single_step( + module_unquantized_ref, + x, + dy, + None, + ) + else: + # Dequant reference path: capture saved forward operands from the real dequant-mode + # execution, then rebuild backward reference from those saved operands. + y_bwd_mode, x_bwd_mode, saved_operands = _run_single_step_with_saved_operands( + module_bwd_mode, x, mode_recipe + ) + y_bwd_mode_detached = y_bwd_mode.detach().clone() + + dx_ref: Optional[torch.Tensor] = None + dw_ref: Optional[torch.Tensor] = None + db_ref: Optional[torch.Tensor] = None + layout_invariants: list[dict[str, object]] = [] + guard_operands: list[tuple[str, Optional[torch.Tensor]]] = [] + ref_exc: Optional[Exception] = None + try: + if module_type == "layernorm_linear": + # LayerNormLinear dequant reference: + # 1) Compute d(ln_out), dw, db from linear backward with saved operands. + # 2) Compute exact dx via layernorm_bwd with saved norm statistics. + # _LayerNormLinear forward saves operands as: + # [inputmat, weightmat, origin_weight, bias, ln_weight, ln_out, mu, rsigma, ...] + if len(saved_operands) < 8: + raise RuntimeError( + "Insufficient saved operands for layernorm_linear dequant reference " + f"(got {len(saved_operands)}, expected at least 8)." 
+ ) + saved_input = saved_operands[0] + saved_weight = saved_operands[1] + saved_ln_weight = saved_operands[4] + saved_ln_out = saved_operands[5] + saved_mu = saved_operands[6] + saved_rsigma = saved_operands[7] + guard_operands.extend( + [ + ("layernorm_linear_ln_out", saved_ln_out), + ("layernorm_linear_weight", saved_weight), + ] + ) + d_ln_out_ref, dw_ref, db_ref = ( + _compute_linear_backward_reference_from_saved_operands( + saved_ln_out, + saved_weight, + dy, + dequant_dtype=dtype, + out_dtype=dtype, + ) + ) + input_ref = _dequantize_saved_operand(saved_input, dtype) + input_ref_2d = input_ref.reshape(-1, input_ref.shape[-1]) + ln_weight_ref = _dequantize_saved_operand(saved_ln_weight, dtype).view(-1) + if saved_mu is None or saved_rsigma is None: + raise RuntimeError("Missing LayerNorm statistics in saved operands") + if not isinstance(saved_mu, torch.Tensor) or not isinstance( + saved_rsigma, torch.Tensor + ): + raise RuntimeError("LayerNorm statistics must be Tensor objects") + dx_ref, *_ = layernorm_bwd( + d_ln_out_ref.reshape(input_ref_2d.shape), + input_ref_2d, + saved_mu, + saved_rsigma, + ln_weight_ref, + module_bwd_mode.bwd_ln_sm_margin, + module_bwd_mode.zero_centered_gamma, + ) + dx_ref = dx_ref.view_as(x_bwd_mode) + else: + saved_input, saved_weight = _extract_linear_saved_operands( + saved_operands, + context=f"{module_type}", + ) + guard_operands.extend( + [ + (f"{module_type}_input", saved_input), + (f"{module_type}_weight", saved_weight), + ] + ) + dx_ref, dw_ref, db_ref = _compute_linear_backward_reference_from_saved_operands( + saved_input, + saved_weight, + dy, + dequant_dtype=dtype, + out_dtype=dtype, + ) + if module_type == "ops_linear" and use_bias: + # te_ops bias grad is reduced by the Bias op from incoming dy. 
+ db_ref = dy.reshape(-1, dy.shape[-1]).sum(dim=0).to(dtype) + except Exception as exc: + ref_exc = exc + + layout_invariants = _snapshot_layout_invariants(guard_operands) + + y_bwd_mode.backward(dy) + assert x_bwd_mode.grad is not None + assert module_bwd_mode.weight.grad is not None + dx_bwd_mode = x_bwd_mode.grad.detach().clone() + dw_bwd_mode = module_bwd_mode.weight.grad.detach().clone() + bias = getattr(module_bwd_mode, "bias", None) + db_bwd_mode = None if bias is None or bias.grad is None else bias.grad.detach().clone() + y_bwd_mode = y_bwd_mode_detached + + _assert_layout_invariants_unchanged(layout_invariants) + _raise_if_ref_failed(ref_exc) + assert dx_ref is not None and dw_ref is not None and db_ref is not None + + assert_close(y_bwd_mode, y_quantized_ref, rtol=0, atol=0, check_dtype=True) + assert_close(dx_bwd_mode, dx_ref, rtol=0, atol=0, check_dtype=True) + assert_close(dw_bwd_mode, dw_ref, rtol=0, atol=0, check_dtype=True) + if use_bias: + assert db_bwd_mode is not None + assert db_ref is not None + assert_close(db_bwd_mode, db_ref, rtol=0, atol=0, check_dtype=True) + + +@pytest.mark.parametrize("recipe_name", _quantized_numerics_recipe_list) +@pytest.mark.parametrize("in_features,out_features", _linear_feature_cases) +@pytest.mark.parametrize("use_bias", (False, True), ids=("no_bias", "bias")) +@pytest.mark.parametrize("m_splits", _grouped_m_split_cases) +@pytest.mark.parametrize("dtype", _core_dtypes, ids=str) +def test_grouped_linear_backward_mode_matches_reference( + recipe_name: str, + in_features: int, + out_features: int, + use_bias: bool, + m_splits: list[int], + dtype: torch.dtype, + backward_mode: str, +) -> None: + if recipe_name == "nvfp4": + pytest.skip("NVFP4 not supported for grouped linear") + + reset_rng_states() + _maybe_skip_recipe_dtype(recipe_name, dtype) + _maybe_skip_unsupported_grouped_splits(recipe_name, m_splits) + num_gemms = len(m_splits) + num_tokens = sum(m_splits) + + quantized_ref_recipe = make_recipe(recipe_name, 
backward_mode="default") + mode_recipe = make_recipe(recipe_name, backward_mode=backward_mode) + + module_quantized_ref = te.GroupedLinear( + num_gemms, + in_features, + out_features, + bias=use_bias, + params_dtype=dtype, + device="cuda", + ) + module_bwd_mode = te.GroupedLinear( + num_gemms, + in_features, + out_features, + bias=use_bias, + params_dtype=dtype, + device="cuda", + ) + _copy_named_parameters(module_quantized_ref, module_bwd_mode) + + x = torch.randn(num_tokens, in_features, dtype=dtype, device="cuda") + dy = torch.randn(num_tokens, out_features, dtype=dtype, device="cuda") + + y_quantized_ref, _, _, _ = _run_grouped_linear_single_step( + module_quantized_ref, + x, + m_splits, + dy, + quantized_ref_recipe, + ) + if backward_mode == "unquant": + # Unquant reference path: grouped module in plain high precision. + module_unquantized_ref = te.GroupedLinear( + num_gemms, + in_features, + out_features, + bias=use_bias, + params_dtype=dtype, + device="cuda", + ) + _copy_named_parameters(module_quantized_ref, module_unquantized_ref) + y_bwd_mode, dx_bwd_mode, dw_bwd_mode, db_bwd_mode = _run_grouped_linear_single_step( + module_bwd_mode, + x, + m_splits, + dy, + mode_recipe, + ) + _, dx_ref, dw_ref, db_ref = _run_grouped_linear_single_step( + module_unquantized_ref, + x, + m_splits, + dy, + None, + ) + else: + # Dequant reference path for grouped GEMMs: + # each GEMM restores its own saved input/weight pair and computes its own ref grads. 
+ y_bwd_mode, x_bwd_mode, saved_operands = _run_grouped_linear_step_with_saved_operands( + module_bwd_mode, x, m_splits, mode_recipe + ) + y_bwd_mode_detached = y_bwd_mode.detach().clone() + + dx_ref: Optional[torch.Tensor] = None + dw_ref: list[torch.Tensor] = [] + db_ref: list[Optional[torch.Tensor]] = [] + layout_invariants: list[dict[str, object]] = [] + guard_operands: list[tuple[str, Optional[torch.Tensor]]] = [] + ref_exc: Optional[Exception] = None + try: + if len(saved_operands) < 2 * num_gemms: + raise RuntimeError( + "Insufficient saved operands for GroupedLinear dequant reference " + f"(got {len(saved_operands)}, expected at least {2 * num_gemms})." + ) + + saved_inputs = saved_operands[:num_gemms] + saved_weights = saved_operands[num_gemms : 2 * num_gemms] + for i, (saved_input, saved_weight) in enumerate(zip(saved_inputs, saved_weights)): + guard_operands.extend( + [ + (f"grouped_input{i}", saved_input), + (f"grouped_weight{i}", saved_weight), + ] + ) + dy_chunks = torch.split(dy, m_splits) + + dx_chunks = [] + dw_ref = [] + db_ref = [] + for dy_chunk, saved_input, saved_weight in zip(dy_chunks, saved_inputs, saved_weights): + dx_i, dw_i, db_i = _compute_linear_backward_reference_from_saved_operands( + saved_input, + saved_weight, + dy_chunk, + dequant_dtype=dtype, + out_dtype=dtype, + ) + dx_chunks.append(dx_i) + dw_ref.append(dw_i) + db_ref.append(db_i if use_bias else None) + dx_ref = torch.cat(dx_chunks, dim=0) + except Exception as exc: + ref_exc = exc + + layout_invariants = _snapshot_layout_invariants(guard_operands) + + y_bwd_mode.backward(dy) + assert x_bwd_mode.grad is not None + dx_bwd_mode = x_bwd_mode.grad.detach().clone() + dw_bwd_mode = [ + getattr(module_bwd_mode, f"weight{i}").grad.detach().clone() + for i in range(module_bwd_mode.num_gemms) + ] + db_bwd_mode = [] + for i in range(module_bwd_mode.num_gemms): + if module_bwd_mode.use_bias: + db_bwd_mode.append(getattr(module_bwd_mode, f"bias{i}").grad.detach().clone()) + else: + 
db_bwd_mode.append(None) + y_bwd_mode = y_bwd_mode_detached + + _assert_layout_invariants_unchanged(layout_invariants) + _raise_if_ref_failed(ref_exc) + assert dx_ref is not None + + assert_close(y_bwd_mode, y_quantized_ref, rtol=0, atol=0, check_dtype=True) + assert_close(dx_bwd_mode, dx_ref, rtol=0, atol=0, check_dtype=True) + for test_dw, ref_dw in zip(dw_bwd_mode, dw_ref): + assert_close(test_dw, ref_dw, rtol=0, atol=0, check_dtype=True) + if use_bias: + for test_db, ref_db_i in zip(db_bwd_mode, db_ref): + assert test_db is not None + assert ref_db_i is not None + assert_close(test_db, ref_db_i, rtol=0, atol=0, check_dtype=True) + + +@pytest.mark.parametrize("recipe_name", _quantized_numerics_recipe_list) +@pytest.mark.parametrize("module_type", ("linear", "layernorm_linear")) +@pytest.mark.parametrize("input_shape,out_features", _shape_test_cases) +@pytest.mark.parametrize("use_bias", (False, True), ids=("no_bias", "bias")) +@pytest.mark.parametrize("dtype", _core_dtypes, ids=str) +def test_linear_like_runtime_backward_mode_switch_updates_ctx( + recipe_name: str, + module_type: str, + input_shape: tuple[int, ...], + out_features: int, + use_bias: bool, + dtype: torch.dtype, + backward_mode: str, +) -> None: + reset_rng_states() + _maybe_skip_recipe_dtype(recipe_name, dtype) + _maybe_skip_unsupported_recipe_module_combo(recipe_name, module_type) + _maybe_skip_unsupported_recipe_shape(recipe_name, input_shape, module_type) + + module = _make_linear_like_module( + module_type, + input_shape[-1], + out_features, + dtype, + bias=use_bias, + ) + x = torch.randn(*input_shape, dtype=dtype, device="cuda") + dy = torch.randn(*input_shape[:-1], out_features, dtype=dtype, device="cuda") + + default_recipe = make_recipe(recipe_name, backward_mode="default") + mode_recipe = make_recipe(recipe_name, backward_mode=backward_mode) + skip_unsupported_backward_mode(module_type, mode_recipe, backward_mode) + + *_, default_ctx = _run_single_step_with_ctx_state(module, x, dy, 
default_recipe)
+    (
+        default_mode,
+        default_fp8,
+        default_grad_output_quantizer,
+        default_reduce_and_update,
+    ) = default_ctx
+    assert default_mode == "default"
+    assert default_fp8
+    assert default_grad_output_quantizer is not None
+    assert default_reduce_and_update
+
+    # Switch to the non-quantized backward mode on the same module instance and
+    # verify the autograd ctx state flips accordingly.
+    *_, switched_ctx = _run_single_step_with_ctx_state(module, x, dy, mode_recipe)
+    switched_mode, switched_fp8, switched_grad_output_quantizer, switched_reduce_and_update = (
+        switched_ctx
+    )
+    assert switched_mode == backward_mode
+    assert not switched_fp8
+    assert switched_grad_output_quantizer is None
+    assert not switched_reduce_and_update
+
+    # Switch back to the default recipe: ctx state must be fully restored.
+    *_, default_ctx_after = _run_single_step_with_ctx_state(module, x, dy, default_recipe)
+    (
+        default_mode_after,
+        default_fp8_after,
+        default_grad_output_quantizer_after,
+        default_reduce_and_update_after,
+    ) = default_ctx_after
+    assert default_mode_after == "default"
+    assert default_fp8_after
+    assert default_grad_output_quantizer_after is not None
+    assert default_reduce_and_update_after
+
+
+@pytest.mark.parametrize("recipe_name", _quantized_numerics_recipe_list)
+@pytest.mark.parametrize("in_features,out_features", _linear_feature_cases)
+@pytest.mark.parametrize("m_splits", _grouped_m_split_cases)
+@pytest.mark.parametrize("use_bias", (False, True), ids=("no_bias", "bias"))
+@pytest.mark.parametrize("dtype", _core_dtypes, ids=str)
+def test_grouped_linear_runtime_backward_mode_switch_updates_ctx(
+    recipe_name: str,
+    in_features: int,
+    out_features: int,
+    m_splits: list[int],
+    use_bias: bool,
+    dtype: torch.dtype,
+    backward_mode: str,
+) -> None:
+    """Verify GroupedLinear ctx state follows runtime default->mode->default recipe switches."""
+    if recipe_name == "nvfp4":
+        pytest.skip("NVFP4 not supported for grouped linear")
+
+    reset_rng_states()
+    _maybe_skip_recipe_dtype(recipe_name, dtype)
+    _maybe_skip_unsupported_grouped_splits(recipe_name, m_splits)
+
+    num_tokens = sum(m_splits)
+    module = te.GroupedLinear(
+        len(m_splits),
+        in_features,
+        out_features,
+        bias=use_bias,
+        params_dtype=dtype,
+        device="cuda",
+    )
+    x = torch.randn(num_tokens, in_features, dtype=dtype, device="cuda")
+    dy = torch.randn(num_tokens, out_features, dtype=dtype, device="cuda")
+
+    # NOTE(review): unlike the linear-like variant of this test, there is no
+    # skip_unsupported_backward_mode guard here — confirm every backward mode is
+    # supported for GroupedLinear.
+    default_recipe = make_recipe(recipe_name, backward_mode="default")
+    mode_recipe = make_recipe(recipe_name, backward_mode=backward_mode)
+
+    *_, default_ctx = _run_grouped_linear_single_step_with_ctx_state(
+        module,
+        x,
+        m_splits,
+        dy,
+        default_recipe,
+    )
+    default_mode, default_fp8, default_reduce_and_update = default_ctx
+    assert default_mode == "default"
+    assert default_fp8
+    assert default_reduce_and_update
+
+    # Same module, non-quantized backward mode: ctx flags must flip off.
+    *_, switched_ctx = _run_grouped_linear_single_step_with_ctx_state(
+        module,
+        x,
+        m_splits,
+        dy,
+        mode_recipe,
+    )
+    switched_mode, switched_fp8, switched_reduce_and_update = switched_ctx
+    assert switched_mode == backward_mode
+    assert not switched_fp8
+    assert not switched_reduce_and_update
+
+    # Switch back to the default recipe: ctx state must be fully restored.
+    *_, default_ctx_after = _run_grouped_linear_single_step_with_ctx_state(
+        module,
+        x,
+        m_splits,
+        dy,
+        default_recipe,
+    )
+    default_mode_after, default_fp8_after, default_reduce_and_update_after = default_ctx_after
+    assert default_mode_after == "default"
+    assert default_fp8_after
+    assert default_reduce_and_update_after
+
+
+@pytest.mark.parametrize("recipe_name", _quantized_numerics_recipe_list)
+@pytest.mark.parametrize(
+    "fused_pattern,expected_fused_op",
+    (
+        ("bias_add", ForwardLinearBiasAdd),
+        ("scale_add", ForwardLinearScaleAdd),
+    ),
+)
+@pytest.mark.parametrize("in_features,out_features", _linear_feature_cases)
+@pytest.mark.parametrize("m", (1, 32), ids=("m1", "m32"))
+@pytest.mark.parametrize("dtype", _fused_dtypes, ids=str)
+def test_fused_linear_paths_match_backward_mode_reference(
+    recipe_name: str,
+    fused_pattern: str,
+    expected_fused_op: type,
+    in_features: int,
+    out_features: int,
+    m: int,
+    dtype: torch.dtype,
+    backward_mode: str,
+) -> None:
+    _maybe_skip_recipe_dtype(recipe_name, dtype)
+    _maybe_skip_unsupported_recipe_module_combo(recipe_name, "ops_linear")
+
_maybe_skip_unsupported_recipe_shape(recipe_name, (m, in_features), "ops_linear") + + reset_rng_states() + + quantized_ref_recipe = make_recipe(recipe_name, backward_mode="default") + mode_recipe = make_recipe(recipe_name, backward_mode=backward_mode) + skip_unsupported_backward_mode("ops_linear", mode_recipe, backward_mode) + + model_quantized_ref = _make_fused_model(fused_pattern, in_features, out_features, dtype) + model_bwd_mode = _make_fused_model(fused_pattern, in_features, out_features, dtype) + _copy_named_parameters(model_quantized_ref, model_bwd_mode) + + x1 = torch.randn(m, in_features, dtype=dtype, device="cuda") + x2 = None + if fused_pattern in ("bias_add", "scale_add"): + x2 = torch.randn(m, out_features, dtype=dtype, device="cuda") + dy = torch.randn(m, out_features, dtype=dtype, device="cuda") + + y_quantized_ref, _, _, _, _ = _run_fused_single_step( + fused_pattern, + model_quantized_ref, + x1, + dy, + quantized_ref_recipe, + x2=x2, + ) + + if backward_mode == "unquant": + # Unquant reference path: replay the same fused model structure in plain + # high precision and compare backward outputs exactly. + model_unquantized_ref = _make_fused_model(fused_pattern, in_features, out_features, dtype) + _copy_named_parameters(model_quantized_ref, model_unquantized_ref) + + y_bwd_mode, dx1_bwd_mode, dx2_bwd_mode, dw_bwd_mode, db_bwd_mode = _run_fused_single_step( + fused_pattern, + model_bwd_mode, + x1, + dy, + mode_recipe, + x2=x2, + ) + _, dx1_ref, dx2_ref, dw_ref, db_ref = _run_fused_single_step( + fused_pattern, + model_unquantized_ref, + x1, + dy, + None, + x2=x2, + ) + else: + # Dequant reference path: compute backward reference from saved quantized + # linear operands (with branch-specific dy handling for fused epilogues). 
+ y_bwd_mode, x1_bwd_mode, x2_bwd_mode_ref, saved_operands = ( + _run_fused_single_step_with_saved_operands( + fused_pattern, + model_bwd_mode, + x1, + mode_recipe, + x2=x2, + ) + ) + y_bwd_mode_detached = y_bwd_mode.detach().clone() + dx1_ref: Optional[torch.Tensor] = None + dx2_ref: Optional[torch.Tensor] = None + dw_ref: Optional[torch.Tensor] = None + db_ref: Optional[torch.Tensor] = None + layout_invariants: list[dict[str, object]] = [] + guard_operands: list[tuple[str, Optional[torch.Tensor]]] = [] + ref_exc: Optional[Exception] = None + try: + saved_input, saved_weight = _extract_linear_saved_operands( + saved_operands, + context=f"fused_{fused_pattern}", + ) + guard_operands.extend( + [ + (f"fused_{fused_pattern}_input", saved_input), + (f"fused_{fused_pattern}_weight", saved_weight), + ] + ) + dy_for_linear = dy * 0.5 if fused_pattern == "scale_add" else dy + dx1_ref, dw_ref, db_ref = _compute_linear_backward_reference_from_saved_operands( + saved_input, + saved_weight, + dy_for_linear, + dequant_dtype=dtype, + out_dtype=dtype, + ) + dx2_ref = dy if x2 is not None else None + except Exception as exc: + ref_exc = exc + + layout_invariants = _snapshot_layout_invariants(guard_operands) + + y_bwd_mode.backward(dy) + assert x1_bwd_mode.grad is not None + dx1_bwd_mode = x1_bwd_mode.grad.detach().clone() + dx2_bwd_mode = ( + x2_bwd_mode_ref.grad.detach().clone() + if x2_bwd_mode_ref is not None and x2_bwd_mode_ref.grad is not None + else None + ) + dw_bwd_mode = model_bwd_mode[0].weight.grad.detach().clone() + db_bwd_mode = None + if ( + getattr(model_bwd_mode[0], "bias", None) is not None + and model_bwd_mode[0].bias.grad is not None + ): + db_bwd_mode = model_bwd_mode[0].bias.grad.detach().clone() + y_bwd_mode = y_bwd_mode_detached + + _assert_layout_invariants_unchanged(layout_invariants) + _raise_if_ref_failed(ref_exc) + assert dx1_ref is not None and dw_ref is not None + + fused_ops = model_bwd_mode._module_groups[0]._forward_ops + assert len(fused_ops) >= 1 
+ assert isinstance(fused_ops[0][0], expected_fused_op) + + assert_close(y_bwd_mode, y_quantized_ref, rtol=0, atol=0, check_dtype=True) + assert_close(dx1_bwd_mode, dx1_ref, rtol=0, atol=0, check_dtype=True) + assert_close(dw_bwd_mode, dw_ref, rtol=0, atol=0, check_dtype=True) + if dx2_bwd_mode is not None and dx2_ref is not None: + assert_close(dx2_bwd_mode, dx2_ref, rtol=0, atol=0, check_dtype=True) + if db_bwd_mode is not None and db_ref is not None: + assert_close(db_bwd_mode, db_ref, rtol=0, atol=0, check_dtype=True) + + +@pytest.mark.parametrize("recipe_name", _quantized_numerics_recipe_list) +@pytest.mark.parametrize("input_shape", _bias_activation_shape_cases) +@pytest.mark.parametrize("out_features", _output_feature_cases) +@pytest.mark.parametrize("dtype", _fused_dtypes, ids=str) +def test_fused_bias_activation_matches_masked_linear_backward( + recipe_name: str, + input_shape: tuple[int, ...], + out_features: int, + dtype: torch.dtype, + backward_mode: str, +) -> None: + _maybe_skip_recipe_dtype(recipe_name, dtype) + _maybe_skip_unsupported_recipe_module_combo(recipe_name, "ops_linear") + _maybe_skip_unsupported_recipe_shape(recipe_name, input_shape, "ops_linear") + + reset_rng_states() + in_features = input_shape[-1] + + quantized_ref_recipe = make_recipe(recipe_name, backward_mode="default") + mode_recipe = make_recipe(recipe_name, backward_mode=backward_mode) + skip_unsupported_backward_mode("ops_linear", mode_recipe, backward_mode) + + model_quantized_ref = _make_fused_model("bias_activation", in_features, out_features, dtype) + model_bwd_mode = _make_fused_model("bias_activation", in_features, out_features, dtype) + _copy_named_parameters(model_quantized_ref, model_bwd_mode) + + x1 = torch.randn(*input_shape, dtype=dtype, device="cuda") + dy = torch.randn(*((*x1.shape[:-1], out_features)), dtype=dtype, device="cuda") + + y_quantized_ref, _, _, _, _ = _run_fused_single_step( + "bias_activation", + model_quantized_ref, + x1, + dy, + 
quantized_ref_recipe, + ) + + if backward_mode == "unquant": + # Unquant reference path: build a plain linear reference and apply the + # same activation mask (from quantized forward output) before backward. + linear_unquantized_ref = _make_linear_like_module( + "ops_linear", + in_features, + out_features, + dtype, + bias=True, + ) + _copy_named_parameters(model_bwd_mode[0], linear_unquantized_ref) + + y_bwd_mode, dx1_bwd_mode, _, dw_bwd_mode, db_bwd_mode = _run_fused_single_step( + "bias_activation", + model_bwd_mode, + x1, + dy, + mode_recipe, + ) + dy_after_activation = dy * (y_bwd_mode > 0).to(dy.dtype) + _, dx1_ref, dw_ref, db_ref = _run_single_step( + linear_unquantized_ref, + x1, + dy_after_activation, + None, + ) + else: + # Dequant reference path: restore saved linear operands from fused forward, + # apply the same activation mask, then run linear backward reference. + y_bwd_mode, x1_bwd_mode, _, saved_operands = _run_fused_single_step_with_saved_operands( + "bias_activation", + model_bwd_mode, + x1, + mode_recipe, + ) + y_bwd_mode_detached = y_bwd_mode.detach().clone() + dy_after_activation = dy * (y_bwd_mode > 0).to(dy.dtype) + dx1_ref: Optional[torch.Tensor] = None + dw_ref: Optional[torch.Tensor] = None + db_ref: Optional[torch.Tensor] = None + layout_invariants: list[dict[str, object]] = [] + guard_operands: list[tuple[str, Optional[torch.Tensor]]] = [] + ref_exc: Optional[Exception] = None + try: + saved_input, saved_weight = _extract_linear_saved_operands( + saved_operands, + context="fused_bias_activation", + ) + guard_operands.extend( + [ + ("fused_bias_activation_input", saved_input), + ("fused_bias_activation_weight", saved_weight), + ] + ) + dx1_ref, dw_ref, db_ref = _compute_linear_backward_reference_from_saved_operands( + saved_input, + saved_weight, + dy_after_activation, + dequant_dtype=dtype, + out_dtype=dtype, + ) + except Exception as exc: + ref_exc = exc + + layout_invariants = _snapshot_layout_invariants(guard_operands) + + 
y_bwd_mode.backward(dy) + assert x1_bwd_mode.grad is not None + dx1_bwd_mode = x1_bwd_mode.grad.detach().clone() + dw_bwd_mode = model_bwd_mode[0].weight.grad.detach().clone() + db_bwd_mode = ( + model_bwd_mode[0].bias.grad.detach().clone() + if model_bwd_mode[0].bias.grad is not None + else None + ) + y_bwd_mode = y_bwd_mode_detached + + _assert_layout_invariants_unchanged(layout_invariants) + _raise_if_ref_failed(ref_exc) + assert dx1_ref is not None and dw_ref is not None and db_ref is not None + + fused_ops = model_bwd_mode._module_groups[0]._forward_ops + assert len(fused_ops) >= 1 + assert isinstance(fused_ops[0][0], ForwardLinearBiasActivation) + + # In unquant/dequant modes, backward-activation+bias fusion should be disabled. + bwd_mode_backward_ops = model_bwd_mode._module_groups[0]._backward_ops + assert not any(isinstance(op, BackwardActivationBias) for op, _ in bwd_mode_backward_ops) + + # Quantized reference should still use fused backward path. + quantized_ref_backward_ops = model_quantized_ref._module_groups[0]._backward_ops + assert any(isinstance(op, BackwardActivationBias) for op, _ in quantized_ref_backward_ops) + + assert_close(y_bwd_mode, y_quantized_ref, rtol=0, atol=0, check_dtype=True) + assert_close(dx1_bwd_mode, dx1_ref, rtol=0, atol=0, check_dtype=True) + assert_close(dw_bwd_mode, dw_ref, rtol=0, atol=0, check_dtype=True) + assert db_bwd_mode is not None + assert db_ref is not None + assert_close(db_bwd_mode, db_ref, rtol=0, atol=0, check_dtype=True) + + +@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) +@pytest.mark.parametrize("recipe_name", _quantized_numerics_recipe_list) +@pytest.mark.parametrize("in_features,out_features", _linear_feature_cases) +@pytest.mark.parametrize("dtype", _core_dtypes, ids=str) +def test_operation_fuser_rebuilds_userbuffers_fusion_on_backward_mode_switch( + recipe_name: str, + in_features: int, + out_features: int, + dtype: torch.dtype, + backward_mode: str, + monkeypatch: pytest.MonkeyPatch, 
+) -> None: + # Simulate a distributed setup to exercise Userbuffers fusion eligibility + # without launching a multi-rank job. + monkeypatch.setattr(torch.distributed, "is_initialized", lambda: True) + monkeypatch.setattr(torch.distributed, "get_world_size", lambda *_args, **_kwargs: 2) + + # Use a mutable recipe holder so we can switch fusion behavior on the same + # fuser object and verify that the cached fusion plan is refreshed. + current_recipe = {"value": make_recipe(recipe_name, backward_mode="default")} + monkeypatch.setattr(FP8GlobalStateManager, "get_fp8_recipe", lambda: current_recipe["value"]) + + reset_rng_states() + _maybe_skip_unsupported_recipe_module_combo(recipe_name, "ops_linear") + + # Build a Userbuffers-eligible fuser and representative inputs. + linear = te_ops.BasicLinear( + in_features, + out_features, + device="cuda", + dtype=dtype, + userbuffers_options={"comm_name": "qkv"}, + ) + linear.tensor_parallel_mode = "column" + linear.tensor_parallel_size = 2 + linear.sequence_parallel = True + bias = te_ops.Bias(out_features, device="cuda", dtype=dtype) + model = te_ops.Sequential(linear, bias) + model._module_groups = model._make_module_groups(model._modules.values()) + fuser = model._module_groups[0] + x = torch.randn(32, in_features, dtype=dtype, device="cuda", requires_grad=True) + extra_inputs = [() for _ in range(fuser._num_basic_ops)] + + quant_recipe = make_recipe(recipe_name, backward_mode="default") + skip_unsupported_backward_mode("ops_linear", quant_recipe, backward_mode) + fuser.maybe_fuse_ops( + is_grad_enabled=True, + recipe=quant_recipe, + input_=x, + extra_inputs=extra_inputs, + ) + assert any(isinstance(op, UserbuffersForwardLinear) for op, _ in fuser._forward_ops) + + non_quant_recipe = make_recipe(recipe_name, backward_mode=backward_mode) + skip_unsupported_backward_mode("ops_linear", non_quant_recipe, backward_mode) + current_recipe["value"] = non_quant_recipe + fuser.maybe_fuse_ops( + is_grad_enabled=True, + 
        recipe=non_quant_recipe,
+        input_=x,
+        extra_inputs=extra_inputs,
+    )
+    # After the recipe switch, the Userbuffers fusion must be dropped from the plan.
+    assert not any(isinstance(op, UserbuffersForwardLinear) for op, _ in fuser._forward_ops)
+
+
+@pytest.mark.parametrize("recipe_name", _quantized_numerics_recipe_list)
+@pytest.mark.parametrize("dtype", _core_dtypes, ids=str)
+def test_quantize_op_respects_backward_mode(
+    recipe_name: str,
+    dtype: torch.dtype,
+    backward_mode: str,
+) -> None:
+    """Quantize(backward=True) must match Quantize(backward=False) in non-quantized modes."""
+    _maybe_skip_recipe_dtype(recipe_name, dtype)
+    _maybe_skip_unsupported_recipe_module_combo(recipe_name, "ops_linear")
+    reset_rng_states()
+
+    x = torch.randn(32, 64, dtype=dtype, device="cuda")
+    dy = torch.randn(32, 64, dtype=dtype, device="cuda")
+
+    # Only the backward-quantization flag differs between the two models.
+    model_override = te_ops.Sequential(te_ops.Quantize(forward=True, backward=True))
+    model_ref = te_ops.Sequential(te_ops.Quantize(forward=True, backward=False))
+
+    mode_recipe = make_recipe(recipe_name, backward_mode=backward_mode)
+    skip_unsupported_backward_mode("ops_linear", mode_recipe, backward_mode)
+
+    y_override, dx_override = _run_quantize_op_single_step(model_override, x, dy, mode_recipe)
+    y_ref, dx_ref = _run_quantize_op_single_step(model_ref, x, dy, mode_recipe)
+
+    # With a non-quantized backward mode active, the backward override is expected
+    # to be a no-op: outputs and input grads must match bit-for-bit.
+    assert_close(y_override, y_ref, rtol=0, atol=0, check_dtype=True)
+    assert_close(dx_override, dx_ref, rtol=0, atol=0, check_dtype=True)
+
+
+@pytest.mark.parametrize("recipe_name", _quantized_numerics_recipe_list)
+@pytest.mark.parametrize("module_type", ("linear", "layernorm_linear"))
+def test_backward_mode_memory_peak_report(
+    recipe_name: str,
+    module_type: str,
+) -> None:
+    """Diagnostic-only memory report for default/unquant/dequant backward modes."""
+    reset_rng_states()
+    dtype = torch.bfloat16
+    input_shape = (2048, 2048)
+    out_features = 2048 * 4
+    in_features = input_shape[-1]
+    use_bias = True
+
+    _maybe_skip_recipe_dtype(recipe_name, dtype)
+    _maybe_skip_unsupported_recipe_module_combo(recipe_name, module_type)
+    _maybe_skip_unsupported_recipe_shape(recipe_name, input_shape, module_type)
+
+    base_module =
_make_linear_like_module( + module_type, + in_features, + out_features, + dtype, + bias=use_bias, + ) + + x = torch.randn(*input_shape, dtype=dtype, device="cuda") + dy = torch.randn(*input_shape[:-1], out_features, dtype=dtype, device="cuda") + + modes = ("default", "unquant", "dequant") + mode_results: dict[str, dict[str, float] | str] = {} + + for mode in modes: + try: + mode_recipe = make_recipe(recipe_name, backward_mode=mode) + + # Keep params identical across modes for a cleaner apples-to-apples read. + module = _make_linear_like_module( + module_type, + in_features, + out_features, + dtype, + bias=use_bias, + ) + _copy_named_parameters(base_module, module) + + # Warmup run to reduce first-use kernel setup noise. + _run_single_step(module, x, dy, mode_recipe) + + module.zero_grad(set_to_none=True) + x_run = x.detach().clone().requires_grad_(True) + autocast_ctx = te.autocast(enabled=True, recipe=mode_recipe) + + torch.cuda.synchronize() + torch.cuda.reset_peak_memory_stats() + fwd_start_mem = torch.cuda.memory_allocated() + with autocast_ctx: + y = module(x_run) + if isinstance(y, tuple): + y = y[0] + torch.cuda.synchronize() + fwd_peak_alloc = float(torch.cuda.max_memory_allocated() - fwd_start_mem) + fwd_peak_reserved = float(torch.cuda.max_memory_reserved()) + + torch.cuda.reset_peak_memory_stats() + bwd_start_mem = torch.cuda.memory_allocated() + y.backward(dy) + torch.cuda.synchronize() + bwd_peak_alloc = float(torch.cuda.max_memory_allocated() - bwd_start_mem) + bwd_peak_reserved = float(torch.cuda.max_memory_reserved()) + + module.zero_grad(set_to_none=True) + x_run = x.detach().clone().requires_grad_(True) + autocast_ctx = te.autocast(enabled=True, recipe=mode_recipe) + + torch.cuda.synchronize() + torch.cuda.reset_peak_memory_stats() + e2e_start_mem = torch.cuda.memory_allocated() + with autocast_ctx: + y = module(x_run) + if isinstance(y, tuple): + y = y[0] + y.backward(dy) + torch.cuda.synchronize() + e2e_peak_alloc = 
float(torch.cuda.max_memory_allocated() - e2e_start_mem) + e2e_peak_reserved = float(torch.cuda.max_memory_reserved()) + + mode_results[mode] = { + "fwd_peak_alloc_mb": fwd_peak_alloc / (1024**2), + "fwd_peak_reserved_mb": fwd_peak_reserved / (1024**2), + "bwd_peak_alloc_mb": bwd_peak_alloc / (1024**2), + "bwd_peak_reserved_mb": bwd_peak_reserved / (1024**2), + "e2e_peak_alloc_mb": e2e_peak_alloc / (1024**2), + "e2e_peak_reserved_mb": e2e_peak_reserved / (1024**2), + } + except Exception as exc: # pragma: no cover - diagnostic reporting path + mode_results[mode] = f"{type(exc).__name__}: {exc}" + + print( + "\n[backward_mode_memory_peak_report] " + f"recipe={recipe_name} module_type={module_type} " + f"dtype={dtype} input_shape={input_shape} out_features={out_features}" + ) + print(" units=MB") + metric_col_width = 9 + delta_col_width = 18 + columns = ( + ("mode", metric_col_width), + ("fwd_alloc", metric_col_width), + ("bwd_alloc", metric_col_width), + ("e2e_alloc", metric_col_width), + ("fwd_resrv", metric_col_width), + ("bwd_resrv", metric_col_width), + ("e2e_resrv", metric_col_width), + ("delta_fwd", delta_col_width), + ("delta_bwd", delta_col_width), + ("delta_e2e", delta_col_width), + ) + print(" | ".join(f"{name:>{width}}" for name, width in columns)) + print("-+-".join("-" * width for _, width in columns)) + + def _format_delta_with_pct(delta: float, base: float) -> str: + if math.isclose(base, 0.0, abs_tol=1e-12): + return f"{delta:+.2f} (n/a)" + pct = 100.0 * delta / base + return f"{delta:+.2f} ({pct:+.2f}%)" + + default_metrics = mode_results.get("default") + for mode in modes: + metrics = mode_results[mode] + if isinstance(metrics, str): + print(f"{mode:>{metric_col_width}} | ERROR: {metrics}") + continue + + if isinstance(default_metrics, dict): + delta_fwd = metrics["fwd_peak_alloc_mb"] - default_metrics["fwd_peak_alloc_mb"] + delta_bwd = metrics["bwd_peak_alloc_mb"] - default_metrics["bwd_peak_alloc_mb"] + delta_e2e = metrics["e2e_peak_alloc_mb"] - 
default_metrics["e2e_peak_alloc_mb"] + delta_fwd_str = _format_delta_with_pct(delta_fwd, default_metrics["fwd_peak_alloc_mb"]) + delta_bwd_str = _format_delta_with_pct(delta_bwd, default_metrics["bwd_peak_alloc_mb"]) + delta_e2e_str = _format_delta_with_pct(delta_e2e, default_metrics["e2e_peak_alloc_mb"]) + else: + delta_fwd_str = "n/a" + delta_bwd_str = "n/a" + delta_e2e_str = "n/a" + + print( + f"{mode:>{metric_col_width}} | " + f"{metrics['fwd_peak_alloc_mb']:{metric_col_width}.2f} | " + f"{metrics['bwd_peak_alloc_mb']:{metric_col_width}.2f} | " + f"{metrics['e2e_peak_alloc_mb']:{metric_col_width}.2f} | " + f"{metrics['fwd_peak_reserved_mb']:{metric_col_width}.2f} | " + f"{metrics['bwd_peak_reserved_mb']:{metric_col_width}.2f} | " + f"{metrics['e2e_peak_reserved_mb']:{metric_col_width}.2f} | " + f"{delta_fwd_str:>{delta_col_width}} | " + f"{delta_bwd_str:>{delta_col_width}} | " + f"{delta_e2e_str:>{delta_col_width}}" + ) diff --git a/tests/pytorch/test_cpu_offloading.py b/tests/pytorch/test_cpu_offloading.py index 7da8dcf863..af8e3b884e 100644 --- a/tests/pytorch/test_cpu_offloading.py +++ b/tests/pytorch/test_cpu_offloading.py @@ -6,6 +6,7 @@ import contextlib import pytest import os +import copy import torch from typing import Optional, List from transformer_engine.pytorch.cpu_offload import ( @@ -18,7 +19,7 @@ from transformer_engine.pytorch.fp8 import FP8GlobalStateManager import transformer_engine.pytorch as te from transformer_engine.common import recipe -from utils import ModelConfig +from utils import ModelConfig, skip_unsupported_backward_mode import transformer_engine_torch as tex # Check supported quantization schemes @@ -416,9 +417,14 @@ def test_multiple_tensor_offload(self, recipe): class TestTELayers: @pytest.mark.parametrize("layer_type", Utils.get_layer_names()) @pytest.mark.parametrize("recipe", quantization_recipes) - def test_sanity(self, layer_type, recipe): + @pytest.mark.parametrize("backward_mode", ["default", "unquant", "dequant"]) + def 
test_sanity(self, layer_type, recipe, backward_mode): Utils.memory_leak_check() + skip_unsupported_backward_mode(layer_type, recipe, backward_mode) + if recipe is not None: + recipe = copy.deepcopy(recipe) + recipe.backward_mode = backward_mode # Skip ops-based layers with Float8BlockScaling recipe if ( layer_type in ["linear_op", "layernorm_mlp_ops"] @@ -458,9 +464,15 @@ def test_sanity(self, layer_type, recipe): @pytest.mark.parametrize("layer_type", Utils.get_layer_names()) @pytest.mark.parametrize("recipe", quantization_recipes) - def test_memory(self, layer_type, recipe): + @pytest.mark.parametrize("backward_mode", ["default", "unquant", "dequant"]) + def test_memory(self, layer_type, recipe, backward_mode): Utils.memory_leak_check() + skip_unsupported_backward_mode(layer_type, recipe, backward_mode) + if recipe is not None: + recipe = copy.deepcopy(recipe) + recipe.backward_mode = backward_mode + # Skip ops-based layers with Float8BlockScaling recipe if ( layer_type in ["linear_op", "layernorm_mlp_ops"] @@ -537,9 +549,15 @@ def test_memory(self, layer_type, recipe): @pytest.mark.parametrize("layer_type", Utils.get_layer_names()) @pytest.mark.parametrize("recipe", quantization_recipes) - def test_manual_synchronization(self, recipe, layer_type): + @pytest.mark.parametrize("backward_mode", ["default", "unquant", "dequant"]) + def test_manual_synchronization(self, recipe, layer_type, backward_mode): Utils.memory_leak_check() + skip_unsupported_backward_mode(layer_type, recipe, backward_mode) + if recipe is not None: + recipe = copy.deepcopy(recipe) + recipe.backward_mode = backward_mode + # Skip ops-based layers with Float8BlockScaling recipe if ( layer_type in ["linear_op", "layernorm_mlp_ops"] @@ -600,6 +618,7 @@ def test_manual_synchronization(self, recipe, layer_type): out_2.sum().backward() @pytest.mark.parametrize("recipe", quantization_recipes) + @pytest.mark.parametrize("backward_mode", ["default", "unquant", "dequant"]) 
@pytest.mark.parametrize("layer_type", Utils.get_layer_names()) @pytest.mark.parametrize("use_cuda_graphs", [True, False]) @pytest.mark.parametrize("retain_pinned_cpu_buffers", [True, False]) @@ -607,11 +626,17 @@ def test_manual_synchronization(self, recipe, layer_type): def test_numerics( self, recipe, + backward_mode, layer_type, use_cuda_graphs, backend, retain_pinned_cpu_buffers, ): + skip_unsupported_backward_mode(layer_type, recipe, backward_mode) + if recipe is not None: + recipe = copy.deepcopy(recipe) + recipe.backward_mode = backward_mode + # Skip ops-based layers with Float8BlockScaling recipe if ( layer_type in ["linear_op", "layernorm_mlp_ops"] diff --git a/tests/pytorch/test_cuda_graphs.py b/tests/pytorch/test_cuda_graphs.py index 1b9e11792e..bf304dc240 100644 --- a/tests/pytorch/test_cuda_graphs.py +++ b/tests/pytorch/test_cuda_graphs.py @@ -4,6 +4,7 @@ from typing import Callable, Dict, Iterable, List, Tuple, Union import pytest +import copy import torch from transformer_engine.pytorch import ( @@ -24,7 +25,7 @@ from transformer_engine.pytorch.quantization import FP8GlobalStateManager import transformer_engine.pytorch.ops as te_ops from transformer_engine.common import recipe -from utils import ModelConfig, reset_rng_states +from utils import ModelConfig, reset_rng_states, skip_unsupported_backward_mode # Check if FP8 is supported. 
fp8_available = is_fp8_available() @@ -360,6 +361,7 @@ def _test_cuda_graphs( @pytest.mark.parametrize("dtype", dtypes) @pytest.mark.parametrize("fp8_params", (False, True)) @pytest.mark.parametrize("fp8_recipe", fp8_recipes + [None], ids=lambda r: type(r).__name__) +@pytest.mark.parametrize("backward_mode", ("default", "unquant", "dequant")) def test_make_graphed_callables( *, module: str, @@ -368,10 +370,17 @@ def test_make_graphed_callables( dtype: torch.dtype, fp8_params: bool, fp8_recipe: recipe.Recipe, + backward_mode: str, fp8_weight_caching: bool = False, ) -> None: fp8 = fp8_recipe is not None + + skip_unsupported_backward_mode(module, fp8_recipe, backward_mode) + if fp8: + fp8_recipe = copy.deepcopy(fp8_recipe) + fp8_recipe.backward_mode = backward_mode + if fp8_params and not fp8: pytest.skip("FP8 needed for FP8 parameters.") if fp8_weight_caching and not fp8: @@ -440,18 +449,21 @@ def test_make_graphed_callables( @pytest.mark.parametrize("dtype", dtypes) @pytest.mark.parametrize("fp8_params", (False, True)) @pytest.mark.parametrize("fp8_recipe", fp8_recipes, ids=lambda r: type(r).__name__) +@pytest.mark.parametrize("backward_mode", ("default", "unquant", "dequant")) def test_make_graphed_callables_with_fp8_weight_caching( *, module: str, dtype: torch.dtype, fp8_params: bool, fp8_recipe: recipe.Recipe, + backward_mode: str, ) -> None: test_make_graphed_callables( module=module, dtype=dtype, fp8_params=fp8_params, fp8_recipe=fp8_recipe, + backward_mode=backward_mode, fp8_weight_caching=True, ) diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index 384b6774f6..fd82996cbb 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -7,6 +7,7 @@ import torch import pytest import os +import copy import transformer_engine import transformer_engine.pytorch as te @@ -37,7 +38,7 @@ import transformer_engine_torch as tex from transformer_engine.pytorch.cpp_extensions import general_gemm from 
transformer_engine.pytorch.tensor.utils import replace_raw_data -from utils import ModelConfig +from utils import ModelConfig, skip_unsupported_backward_mode # Only run FP8 tests on supported devices. fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) @@ -383,6 +384,7 @@ def test_sanity_normalization_amp(dtype, model, skip_wgrad, skip_dgrad, normaliz @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("fp8_recipe", fp8_recipes) +@pytest.mark.parametrize("backward_mode", ["default", "unquant", "dequant"]) @pytest.mark.parametrize("model", ["small", "weird"]) @pytest.mark.parametrize("skip_wgrad", all_boolean) @pytest.mark.parametrize("zero_centered_gamma", all_boolean) @@ -392,6 +394,7 @@ def test_sanity_normalization_amp(dtype, model, skip_wgrad, skip_dgrad, normaliz def test_sanity_layernorm_linear( dtype, fp8_recipe, + backward_mode, model, skip_wgrad, zero_centered_gamma, @@ -401,6 +404,11 @@ def test_sanity_layernorm_linear( ): config = model_configs[model] + skip_unsupported_backward_mode("layernorm_linear", fp8_recipe, backward_mode) + if fp8_recipe is not None: + fp8_recipe = copy.deepcopy(fp8_recipe) + fp8_recipe.backward_mode = backward_mode + if fp8_recipe is not None: if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") @@ -424,13 +432,21 @@ def test_sanity_layernorm_linear( @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("fp8_recipe", fp8_recipes) +@pytest.mark.parametrize("backward_mode", ["default", "unquant", "dequant"]) @pytest.mark.parametrize("model", ["small", "weird"]) @pytest.mark.parametrize("skip_wgrad", all_boolean) @pytest.mark.parametrize("skip_dgrad", all_boolean) @pytest.mark.parametrize("microbatching", all_boolean) -def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad, microbatching): +def test_sanity_linear( + dtype, fp8_recipe, backward_mode, model, skip_wgrad, skip_dgrad, microbatching +): config = model_configs[model] 
+ skip_unsupported_backward_mode("linear", fp8_recipe, backward_mode) + if fp8_recipe is not None: + fp8_recipe = copy.deepcopy(fp8_recipe) + fp8_recipe.backward_mode = backward_mode + if fp8_recipe is not None: if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") @@ -454,13 +470,21 @@ def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad, microba @pytest.mark.parametrize("bs", batch_sizes_with_zero) @pytest.mark.parametrize("model", ["small", "weird"]) @pytest.mark.parametrize("fp8_recipe", fp8_recipes) +@pytest.mark.parametrize("backward_mode", ["default", "unquant", "dequant"]) @pytest.mark.parametrize("fp8_model_params", all_boolean) @pytest.mark.parametrize("use_bias", all_boolean) -def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_params, use_bias): +def test_sanity_linear_with_zero_tokens( + dtype, bs, model, fp8_recipe, backward_mode, fp8_model_params, use_bias +): config = model_configs[model] ffn_hidden_size = 4 * config.hidden_size num_tokens = bs * config.max_seqlen_q + skip_unsupported_backward_mode("linear", fp8_recipe, backward_mode) + if fp8_recipe is not None: + fp8_recipe = copy.deepcopy(fp8_recipe) + fp8_recipe.backward_mode = backward_mode + if fp8_recipe is not None: if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") @@ -487,6 +511,7 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_ @pytest.mark.parametrize("bs", batch_sizes_with_zero) @pytest.mark.parametrize("model", ["small", "weird"]) @pytest.mark.parametrize("fp8_recipe", fp8_recipes) +@pytest.mark.parametrize("backward_mode", ["default", "unquant", "dequant"]) @pytest.mark.parametrize("fp8_model_params", all_boolean) @pytest.mark.parametrize("use_bias", all_boolean) @pytest.mark.parametrize("single_param", all_boolean) @@ -497,6 +522,7 @@ def test_sanity_grouped_linear( bs, model, fp8_recipe, + backward_mode, fp8_model_params, use_bias, 
single_param, @@ -509,6 +535,11 @@ def test_sanity_grouped_linear( bs = bs * 16 num_tokens = bs * config.max_seqlen_q * (num_gemms - 1) + skip_unsupported_backward_mode("grouped_linear", fp8_recipe, backward_mode) + if fp8_recipe is not None: + fp8_recipe = copy.deepcopy(fp8_recipe) + fp8_recipe.backward_mode = backward_mode + if fp8_recipe is not None: if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index 317240fb78..830ca6eecc 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -10,6 +10,7 @@ from typing import Optional, Tuple, Dict, Any, List from packaging.version import Version as PkgVersion +import pytest import torch import transformer_engine @@ -117,7 +118,7 @@ def quantization_tols(name: str) -> dict[str, float]: raise ValueError(f"Unsupported quantization scheme ({name})") -def make_recipe(name: Optional[str]) -> Optional[Recipe]: +def make_recipe(name: Optional[str], **recipe_kwargs: Any) -> Optional[Recipe]: """Make recipe for quantization scheme""" if name is None: return None @@ -125,26 +126,52 @@ def make_recipe(name: Optional[str]) -> Optional[Recipe]: return transformer_engine.common.recipe.DelayedScaling( fp8_format=transformer_engine.common.recipe.Format.E4M3, amax_history_len=8, + **recipe_kwargs, ) if name == "fp8_current_scaling": return transformer_engine.common.recipe.Float8CurrentScaling( fp8_format=transformer_engine.common.recipe.Format.E4M3, + **recipe_kwargs, ) if name == "mxfp8": return transformer_engine.common.recipe.MXFP8BlockScaling( fp8_format=transformer_engine.common.recipe.Format.E4M3, + **recipe_kwargs, ) if name == "fp8_block_scaling": - return transformer_engine.common.recipe.Float8BlockScaling() + return transformer_engine.common.recipe.Float8BlockScaling(**recipe_kwargs) if name == "nvfp4": return transformer_engine.common.recipe.NVFP4BlockScaling( disable_rht=True, disable_stochastic_rounding=True, 
disable_2d_quantization=True, + **recipe_kwargs, ) raise ValueError(f"Unsupported quantization scheme ({name})") + +def skip_unsupported_backward_mode( + layer_type: str, + quant_recipe: Optional[Recipe], + backward_mode: str, +) -> None: + """Skip known unsupported layer/recipe/backward-mode combinations used in tests.""" + if backward_mode is None or backward_mode == "default": + return + if quant_recipe is None and backward_mode in ("unquant", "dequant"): + pytest.skip(f"Not a quantized recipe, cannot use backward mode {backward_mode}.") + if quant_recipe.delayed() and backward_mode in ("unquant", "dequant"): + pytest.skip(f"Delayed scaling does not support backward mode {backward_mode}.") + if layer_type in ( + "layernorm_mlp", + "layernorm_mlp_nocheckpoint", + "layernorm_mlp_checkpoint", + "transformer", + "transformer_layer", + ): + pytest.skip(f"{layer_type} does not support NVTE_BACKWARD_MODE={backward_mode}.") + + # Cached RNG state _rng_states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 18577b0eb4..9058f155c4 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -11,6 +11,20 @@ from pydantic.dataclasses import dataclass +_BACKWARD_MODES = ("default", "unquant", "dequant") + + +def _resolve_backward_mode(mode: Optional[str] = None) -> str: + """Return validated backward mode from argument or NVTE_BACKWARD_MODE env.""" + if mode is None: + mode = os.getenv("NVTE_BACKWARD_MODE", "default") + mode = mode.lower() + assert ( + mode in _BACKWARD_MODES + ), f"Invalid NVTE_BACKWARD_MODE value {mode!r}. Supported values are: default|unquant|dequant." + return mode + + class _FormatHelper(NamedTuple): """ Stores max FP8 values for fprop and bprop a `Format`. @@ -188,6 +202,8 @@ def scaling_factor_compute(amax: Tensor, `LayerNormLinear (BF16 output) -> (cast to FP8 ) FP8 DPA (cast to BF16) -> Linear`. 
When `fp8_mha = True, fp8_dpa = True`, it becomes `LayerNormLinear (FP8 output) -> FP8 DPA -> Linear`. + backward_mode : {'default', 'unquant', 'dequant'}, default = 'default' + Backward precision mode. Delayed scaling only supports `default`. Notes ----- @@ -211,9 +227,14 @@ def scaling_factor_compute(amax: Tensor, reduce_amax: bool = True fp8_dpa: bool = False fp8_mha: bool = False + backward_mode: str = field(default_factory=_resolve_backward_mode) def __post_init__(self) -> None: + self.backward_mode = _resolve_backward_mode(self.backward_mode) assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." + assert ( + self.backward_mode == "default" + ), "Delayed scaling only supports backward_mode=default." def __repr__(self) -> str: return ( @@ -223,7 +244,8 @@ def __repr__(self) -> str: f"amax_history_len={self.amax_history_len}, " f"reduce_amax={self.reduce_amax}, " f"fp8_dpa={self.fp8_dpa}, " - f"fp8_mha={self.fp8_mha}" + f"fp8_mha={self.fp8_mha}, " + f"backward_mode={self.backward_mode}" ) @@ -237,6 +259,11 @@ class Float8CurrentScaling(Recipe): fp8_format : {Format.E4M3, Format.HYBRID}, default = Format.HYBRID Controls the FP8 data format used during forward and backward pass. + backward_mode : {'default', 'unquant', 'dequant'}, default = 'default' + Backward precision mode. `default` performs quantized backward, + `unquant` keeps original high-precision operands for backward, + and `dequant` dequantizes saved operands to the active high-precision + compute dtype (e.g. BF16/FP16/FP32) for backward. 
""" use_power_2_scales: bool = os.getenv("NVTE_FP8_CURRENT_SCALING_POWER_2_SCALES", "0") == "1" @@ -249,8 +276,10 @@ class Float8CurrentScaling(Recipe): fp8_gemm_wgrad: MMParams = MMParams(use_split_accumulator=True) fp8_dpa: bool = False fp8_mha: bool = False + backward_mode: str = field(default_factory=_resolve_backward_mode) def __post_init__(self) -> None: + self.backward_mode = _resolve_backward_mode(self.backward_mode) assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." def __repr__(self) -> str: @@ -264,7 +293,8 @@ def __repr__(self) -> str: f"fp8_gemm_dgrad={self.fp8_gemm_dgrad}, " f"fp8_gemm_wgrad={self.fp8_gemm_wgrad}, " f"fp8_dpa={self.fp8_dpa}, " - f"fp8_mha={self.fp8_mha}" + f"fp8_mha={self.fp8_mha}, " + f"backward_mode={self.backward_mode}" ) @@ -291,21 +321,29 @@ class MXFP8BlockScaling(Recipe): fp8_format : {Format.E4M3, Format.HYBRID}, default = Format.E4M3 Controls the FP8 data format used during forward and backward pass. + backward_mode : {'default', 'unquant', 'dequant'}, default = 'default' + Backward precision mode. `default` performs quantized backward, + `unquant` keeps original high-precision operands for backward, + and `dequant` dequantizes saved operands to the active high-precision + compute dtype (e.g. BF16/FP16/FP32) for backward. """ margin: int = 0 fp8_format: Format = Format.E4M3 fp8_dpa: bool = False fp8_mha: bool = False + backward_mode: str = field(default_factory=_resolve_backward_mode) def __post_init__(self) -> None: + self.backward_mode = _resolve_backward_mode(self.backward_mode) assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." 
def __repr__(self) -> str: return ( f"recipe_type={self.__class__.__name__}, " f"margin={self.margin}, " - f"format={str(self.fp8_format).split('.')[1]}" + f"format={str(self.fp8_format).split('.')[1]}, " + f"backward_mode={self.backward_mode}" ) @@ -334,6 +372,11 @@ class Float8BlockScaling(Recipe): fp8_format : {Format.E4M3, Format.HYBRID}, default = Format.E4M3 Controls the FP8 data format used during forward and backward pass. + backward_mode : {'default', 'unquant', 'dequant'}, default = 'default' + Backward precision mode. `default` performs quantized backward, + `unquant` keeps original high-precision operands for backward, + and `dequant` dequantizes saved operands to the active high-precision + compute dtype (e.g. BF16/FP16/FP32) for backward. """ use_f32_scales: bool = os.getenv("NVTE_FP8_BLOCK_SCALING_FP32_SCALES", "0") == "1" @@ -350,8 +393,10 @@ class Float8BlockScaling(Recipe): fp8_gemm_wgrad: MMParams = MMParams(use_split_accumulator=True) fp8_dpa: bool = False fp8_mha: bool = False + backward_mode: str = field(default_factory=_resolve_backward_mode) def __post_init__(self) -> None: + self.backward_mode = _resolve_backward_mode(self.backward_mode) assert self.x_block_scaling_dim in [1, 2], "Only 1D or 2D blocks supported for x" assert self.w_block_scaling_dim in [1, 2], "Only 1D or 2D blocks supported for w" assert self.grad_block_scaling_dim in [1, 2], "Only 1D or 2D blocks supported for grad" @@ -386,7 +431,8 @@ def __repr__(self) -> str: f"fp8_gemm_dgrad={self.fp8_gemm_dgrad}, " f"fp8_gemm_wgrad={self.fp8_gemm_wgrad}, " f"fp8_dpa={self.fp8_dpa}, " - f"fp8_mha={self.fp8_mha}" + f"fp8_mha={self.fp8_mha}, " + f"backward_mode={self.backward_mode}" ) @@ -435,6 +481,11 @@ class NVFP4BlockScaling(Recipe): If set to `True`, stochastic rounding is disabled during quantization for all tensors. disable_2d_quantization : bool, default = False If set to `True`, 1D block scaling with block size 16 is used for all tensors. 
+ backward_mode : {'default', 'unquant', 'dequant'}, default = 'default' + Backward precision mode. `default` performs quantized backward, + `unquant` keeps original high-precision operands for backward, + and `dequant` dequantizes saved operands to the active high-precision + compute dtype (e.g. BF16/FP16/FP32) for backward. """ # Configuration envvars @@ -450,8 +501,10 @@ class NVFP4BlockScaling(Recipe): # Not applying quantization to attention for now fp8_dpa: bool = False fp8_mha: bool = False + backward_mode: str = field(default_factory=_resolve_backward_mode) def __post_init__(self) -> None: + self.backward_mode = _resolve_backward_mode(self.backward_mode) assert self.fp4_format == Format.E2M1, "Only E2M1 is supported for NVFP4 scaling" assert self.fp8_format == Format.E4M3, "Only E4M3 is supported for NVFP4 scaling" @@ -481,6 +534,7 @@ def __repr__(self) -> str: f"fp8_format={str(self.fp8_format).split('.')[1]}, " f"fp8_dpa={self.fp8_dpa}, " f"fp8_mha={self.fp8_mha}, " + f"backward_mode={self.backward_mode}, " f"fp4_quant_fwd_inp={self.fp4_quant_fwd_inp}, " f"fp4_quant_fwd_weight={self.fp4_quant_fwd_weight}, " f"fp4_quant_bwd_grad={self.fp4_quant_bwd_grad}, " @@ -512,12 +566,25 @@ class CustomRecipe(Recipe): - forward: "linear_input", "linear_weight", "linear_output" - backward: "linear_grad_output", "linear_grad_input" + backward_mode : {'default', 'unquant', 'dequant'}, default = 'default' + Backward precision mode. `default` performs quantized backward, + `unquant` keeps original high-precision operands for backward, + and `dequant` dequantizes saved operands to the active high-precision + compute dtype (e.g. BF16/FP16/FP32) for backward. 
""" qfactory: Callable[..., Any] fp8_dpa: bool = False fp8_mha: bool = False + backward_mode: str = field(default_factory=_resolve_backward_mode) + + def __post_init__(self) -> None: + self.backward_mode = _resolve_backward_mode(self.backward_mode) def __repr__(self) -> str: - return f"recipe_type={self.__class__.__name__}, qfactory={self.qfactory}" + return ( + f"recipe_type={self.__class__.__name__}, " + f"qfactory={self.qfactory}, " + f"backward_mode={self.backward_mode}" + ) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 28da4873f0..2ca1f1ace2 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -1184,9 +1184,10 @@ def grad_output_preprocess( grad_output = grad_output.reshape((-1, grad_output.shape[-1])) grad_output = grad_output.contiguous() gather_grad_output = row_parallel_mode and ctx.sequence_parallel + use_fp8_bwd = ctx.fp8 and ctx.backward_mode == "default" # Non-FP8 case: bgrad is fused with wgrad for this case. 
- if not ctx.fp8 and not ctx.debug: + if not use_fp8_bwd and not ctx.debug: if gather_grad_output: if not ctx.ub_overlap_ag: # Perform NCCL all-gather grad_output, _ = gather_along_first_dim(grad_output, ctx.tp_group) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index fade2957d5..ef97e72ce2 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -97,6 +97,12 @@ def forward( save_original_input, debug, ) = non_tensor_args + if fp8: + backward_mode = FP8GlobalStateManager.get_fp8_recipe().backward_mode + else: + backward_mode = "default" + if backward_mode == "unquant": + save_original_input = True num_gemms = len(m_splits) weights = weights_and_biases[:num_gemms] @@ -112,10 +118,15 @@ def forward( input_quantizer.set_usage( rowwise=True, columnwise=( - is_grad_enabled and weight_requires_grad and not save_original_input + is_grad_enabled + and weight_requires_grad + and not save_original_input + and backward_mode == "default" ), ) columnwise_usage = is_grad_enabled and inp.requires_grad + if backward_mode in ("unquant", "dequant"): + columnwise_usage = False if not columnwise_usage: columnwise_usage = ( is_fp8_activation_recompute_enabled() @@ -240,7 +251,12 @@ def forward( else: for inputmat in inputmats: if isinstance(inputmat, QuantizedTensorStorage): - inputmat.update_usage(rowwise_usage=False, columnwise_usage=True) + if backward_mode in ("unquant", "dequant"): + # In dequant mode we should dequantize directly from + # fprop quantized layouts without retargeting usage. 
+ inputmat.update_usage(rowwise_usage=True, columnwise_usage=False) + else: + inputmat.update_usage(rowwise_usage=False, columnwise_usage=True) else: inputmats = [None] * num_gemms @@ -291,6 +307,7 @@ def forward( ctx.activation_dtype = activation_dtype ctx.fp8 = fp8 ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None + ctx.backward_mode = backward_mode ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation ctx.cpu_offloading = cpu_offloading ctx.is_first_microbatch = is_first_microbatch @@ -309,6 +326,18 @@ def forward( ctx.save_original_input = save_original_input ctx.input_quantizers = input_quantizers + # Non-quantized backward mode overrides + if backward_mode in ("unquant", "dequant"): + ctx.fp8 = False + ctx.ub_overlap_ag = False + ctx.ub_overlap_rs_dgrad = False + ctx.ub_bulk_dgrad = False + ctx.ub_bulk_wgrad = False + ctx.grad_input_quantizer = None + ctx.grad_weight_quantizer = None + ctx.grad_output_quantizer = None + ctx.reduce_and_update_bwd_fp8_tensors = False + # [*, in_features] -> [*, out_features] except first dimension changes for SP return out.view(-1, *inp.shape[1:-1], out.shape[-1]) @@ -403,13 +432,32 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], dtype=ctx.activation_dtype, device=ctx.device, ) + weights_for_dgrad = weights + if ctx.backward_mode == "dequant": + weights_for_dgrad = [ + ( + weight.dequantize(dtype=ctx.activation_dtype) + if isinstance(weight, QuantizedTensorStorage) + else cast_if_needed(weight, ctx.activation_dtype) + ) + for weight in weights + ] + elif ctx.backward_mode == "unquant": + weights_for_dgrad = [ + ( + weight.dequantize(dtype=ctx.activation_dtype) + if isinstance(weight, QuantizedTensorStorage) + else cast_if_needed(weight, ctx.activation_dtype) + ) + for weight in origin_weights + ] # Make sure weights are available in column-wise format # for dgrad computation. 
- for weight in weights: + for weight in weights_for_dgrad: if isinstance(weight, QuantizedTensorStorage): weight.update_usage(columnwise_usage=True) general_grouped_gemm( - weights, + weights_for_dgrad, grad_output, [dgrad], ctx.grad_input_quantizers, @@ -464,6 +512,30 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], inputmats = torch.split( cast_if_needed(inp_view, ctx.activation_dtype), ctx.m_splits ) + elif ctx.backward_mode == "dequant": + inputmats_dequant = [] + for m_split, inputmat in zip(ctx.m_splits, inputmats): + if isinstance(inputmat, QuantizedTensorStorage): + if m_split == 0: + # Dequant kernels for some quantized storage formats + # (e.g. MXFP8/Float8BlockScaling) do not accept empty + # M-dimension inputs. For empty grouped splits, materialize + # an explicit empty high-precision matrix instead of invoking + # dequantize(). + inputmats_dequant.append( + torch.empty( + (0, ctx.weights_shape_1), + dtype=ctx.activation_dtype, + device=ctx.device, + ) + ) + else: + inputmats_dequant.append( + inputmat.dequantize(dtype=ctx.activation_dtype) + ) + else: + inputmats_dequant.append(cast_if_needed(inputmat, ctx.activation_dtype)) + inputmats = inputmats_dequant grouped_gemm_wgrad = functools.partial( general_grouped_gemm, quantization_params=ctx.grad_weight_quantizers, @@ -1073,6 +1145,13 @@ def _get_quantizers(self): for i in range(self.num_gemms): grad_output_quantizers[i].internal = True grad_output_quantizers[i].optimize_for_gemm = True + fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() + if fp8_recipe.backward_mode == "dequant" and (fp8_recipe.mxfp8() or fp8_recipe.nvfp4()): + for input_quantizer in input_quantizers: + input_quantizer.optimize_for_gemm = False + if torch.is_grad_enabled(): + for grad_output_quantizer in grad_output_quantizers: + grad_output_quantizer.optimize_for_gemm = False return ( input_quantizers, weight_quantizers, diff --git a/transformer_engine/pytorch/module/layernorm_linear.py 
b/transformer_engine/pytorch/module/layernorm_linear.py index a90105477c..31f65a5239 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -140,6 +140,10 @@ def forward( symmetric_ar_type, debug, ) = non_tensor_args + if fp8: + backward_mode = FP8GlobalStateManager.get_fp8_recipe().backward_mode + else: + backward_mode = "default" # NVTX label for profiling nvtx_label = "transformer_engine._LayerNormLinear.forward" @@ -198,7 +202,10 @@ def forward( if fp8: if input_quantizer is None: raise ValueError("Missing quantizer for input tensor") - input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input) + input_quantizer.set_usage( + rowwise=True, + columnwise=backward_needs_input and backward_mode == "default", + ) if with_input_all_gather and input_quantizer.supports_only_rowwise_all_gather(): # All-gather is not supported with FP8 column-wise data input_quantizer.set_usage(columnwise=False) @@ -211,6 +218,7 @@ def forward( and not debug and not return_layernorm_output and not return_layernorm_output_gathered + and backward_mode == "default" and not custom # TODO(negvet): and not FP8GlobalStateManager.get_fp8_recipe().custom() ) @@ -234,6 +242,7 @@ def forward( ln_out_return = None if return_layernorm_output or return_layernorm_output_gathered: ln_out_return = ln_out + ln_out_hp = ln_out if backward_mode == "unquant" else None # ------------------------------------------------------ # Prepare GEMM input tensor @@ -295,7 +304,10 @@ def forward( if is_weight_param_quantized and not debug: weight_quantizer = weight._quantizer elif weight_quantizer is not None: - weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled) + weight_quantizer.set_usage( + rowwise=True, + columnwise=is_grad_enabled and backward_mode == "default", + ) # Get quantized weight update_workspace = is_first_microbatch is None or is_first_microbatch @@ -408,13 +420,16 @@ def forward( # 
------------------------------------------------------ if is_grad_enabled: + ln_out_to_save = ln_out + if backward_mode == "unquant": + ln_out_to_save = ln_out_hp ctx.weight_quantizer = weight_quantizer ctx.ln_out_needs_gather = ( weight.requires_grad and parallel_mode == "column" and sequence_parallel ) # Input with column-wise usage is needed for wgrad GEMM. - if backward_needs_input: + if backward_needs_input and backward_mode == "default": if isinstance(ln_out, QuantizedTensorStorage): # For sequence parallel in vanilla FP8, rowwise data is # to gather the input. For MXFP8, columnwise only data @@ -426,7 +441,7 @@ def forward( ln_out.update_usage(rowwise_usage=False) if cpu_offloading: - mark_activation_offload(inputmat, mu, rsigma, ln_out) + mark_activation_offload(inputmat, mu, rsigma, ln_out_to_save) # Scatter intermediate/activation tensors saved for the backward pass # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already @@ -438,7 +453,7 @@ def forward( mu, rsigma, weightmat if fp8 and not is_weight_param_quantized else None, - ln_out if weight.requires_grad else None, + ln_out_to_save if weight.requires_grad else None, ) nvtx_range_pop(f"{nvtx_label}.fsdp_scatter") @@ -465,7 +480,7 @@ def forward( weight, bias, ln_weight, - ln_out, + ln_out_to_save, mu, rsigma, ) @@ -492,6 +507,7 @@ def forward( ctx.activation_dtype = activation_dtype ctx.fp8 = fp8 ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None + ctx.backward_mode = backward_mode ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation ctx.cpu_offloading = cpu_offloading ctx.is_first_microbatch = is_first_microbatch @@ -522,6 +538,18 @@ def forward( ctx.wgrad_store = wgrad_store ctx.debug = debug + # Non-quantized backward mode overrides + if backward_mode in ("unquant", "dequant"): + ctx.fp8 = False + ctx.ub_overlap_ag = False + ctx.ub_overlap_rs_dgrad = False + ctx.ub_bulk_dgrad = False + ctx.ub_bulk_wgrad = False + ctx.grad_input_quantizer = None + 
ctx.grad_weight_quantizer = None + ctx.grad_output_quantizer = None + ctx.reduce_and_update_bwd_fp8_tensors = False + # ------------------------------------------------------ # Cached state for backward pass is ready... # ------------------------------------------------------ @@ -662,9 +690,14 @@ def backward( # -------------------------------------------------- ln_out_total = None ln_out_total_work = None + if ctx.backward_mode == "dequant": + if isinstance(ln_out, QuantizedTensorStorage): + ln_out = ln_out.dequantize(dtype=ctx.activation_dtype) + else: + ln_out = cast_if_needed(ln_out, ctx.activation_dtype) if ctx.ln_out_needs_gather: quantizer = None - if ctx.input_quantizer is not None: + if ctx.input_quantizer is not None and ctx.fp8: quantizer = ctx.input_quantizer if quantizer.supports_only_rowwise_all_gather(): # If data is in FP8, we compute FP8 transposes manually @@ -702,7 +735,11 @@ def backward( # Make sure required data is available if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(rowwise_usage=True) - if ctx.weight_quantizer is not None and isinstance(weight, QuantizedTensorStorage): + if ( + ctx.fp8 + and ctx.weight_quantizer is not None + and isinstance(weight, QuantizedTensorStorage) + ): weight.update_usage(columnwise_usage=True) # Choose whether to use GEMM kernel with split accumulator @@ -729,8 +766,18 @@ def backward( # dgrad GEMM # Note: dx = dy * w nvtx_range_push(f"{nvtx_label}.dgrad_gemm") + weight_for_dgrad = weight + if ctx.backward_mode == "dequant": + if isinstance(weight_for_dgrad, QuantizedTensorStorage): + weight_for_dgrad = weight_for_dgrad.dequantize(dtype=ctx.activation_dtype) + else: + weight_for_dgrad = cast_if_needed(weight_for_dgrad, ctx.activation_dtype) + elif ctx.backward_mode == "unquant": + weight_for_dgrad = origin_weight + if isinstance(weight_for_dgrad, QuantizedTensorStorage): + weight_for_dgrad = weight_for_dgrad.dequantize(dtype=ctx.activation_dtype) gemm_out, *_, reduce_scatter_out = 
general_gemm( - weight, + weight_for_dgrad, grad_output, layout="NN", grad=True, @@ -1627,6 +1674,11 @@ def _get_quantizers(self, fp8_output, fp8_grad, is_grad_enabled): grad_output_quantizer.optimize_for_gemm = True if fp8_grad: grad_input_quantizer = self.quantizers["scaling_bwd"][FP8BwdTensorIdx.GRAD_INPUT1] + fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() + if fp8_recipe.backward_mode == "dequant" and (fp8_recipe.mxfp8() or fp8_recipe.nvfp4()): + input_quantizer.optimize_for_gemm = False + if grad_output_quantizer is not None: + grad_output_quantizer.optimize_for_gemm = False return ( input_quantizer, diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 037fb6c858..4f206c866e 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -234,6 +234,14 @@ def _forward( debug, recompute_for_bwd, ) = non_tensor_args + if fp8: + backward_mode = FP8GlobalStateManager.get_fp8_recipe().backward_mode + else: + backward_mode = "default" + assert backward_mode == "default", ( + "NVTE_BACKWARD_MODE=unquant/dequant is not implemented in LayerNormMLP. " + "Replace LayerNormMLP with LayerNormLinear + Linear to enable unquant/dequant backward." 
+ ) # if grad is enabled and this is not the bwd stage, we must save this so bwd knows which path to take if is_grad_enabled and not recompute_for_bwd: @@ -780,6 +788,7 @@ def _forward( ctx.fc2_main_grad_func = lambda: fc2_weight.main_grad ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None + ctx.backward_mode = backward_mode ctx.fc1_grad_input_quantizer = fc1_grad_input_quantizer ctx.fc1_grad_weight_quantizer = fc1_grad_weight_quantizer ctx.fc1_grad_output_quantizer = fc1_grad_output_quantizer diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 1e3eadc405..6e54f9de5d 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -128,6 +128,12 @@ def forward( save_original_input, debug, ) = non_tensor_args + if fp8: + backward_mode = FP8GlobalStateManager.get_fp8_recipe().backward_mode + else: + backward_mode = "default" + if backward_mode == "unquant": + save_original_input = True # NVTX label for profiling nvtx_label = "transformer_engine._Linear.forward" @@ -187,7 +193,10 @@ def forward( raise ValueError("Missing quantizer for input tensor") if not isinstance(inputmat, QuantizedTensorStorage) and not custom: own_quantized_input = True - input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input) + input_quantizer.set_usage( + rowwise=True, + columnwise=backward_needs_input and backward_mode == "default", + ) if isinstance( input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer) ): @@ -229,7 +238,12 @@ def forward( if input_quantizer is None: raise ValueError("Missing quantizer for input tensor") input_quantizer.set_usage( - rowwise=True, columnwise=backward_needs_input and not save_original_input + rowwise=True, + columnwise=( + backward_needs_input + and not save_original_input + and backward_mode == "default" + ), ) inputmat = input_quantizer(inputmat) own_quantized_input = True @@ -254,6 +268,8 @@ def forward( # 
for debug mode we create quantizer every iteration, thus we need to set the quantizer states if weight_quantizer is not None and (not isinstance(weight, QuantizedTensor) or debug): columnwise_usage = is_grad_enabled and inp.requires_grad + if backward_mode in ("unquant", "dequant"): + columnwise_usage = False if not columnwise_usage: columnwise_usage = ( is_fp8_activation_recompute_enabled() @@ -387,7 +403,11 @@ def forward( and own_quantized_input and isinstance(inputmat, QuantizedTensorStorage) ): - if ( + if backward_mode in ("unquant", "dequant"): + # In dequant mode we should dequantize directly from the + # fprop quantized tensor layout without retargeting usage. + inputmat.update_usage(rowwise_usage=True, columnwise_usage=False) + elif ( ctx.backward_input_needs_gather and weight_quantizer.supports_only_rowwise_all_gather() ): @@ -442,6 +462,7 @@ def forward( ctx.activation_dtype = activation_dtype ctx.fp8 = fp8 ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None + ctx.backward_mode = backward_mode ctx.input_quantizer = input_quantizer ctx.grad_input_quantizer = grad_input_quantizer ctx.grad_weight_quantizer = grad_weight_quantizer @@ -485,6 +506,18 @@ def forward( FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module ctx.wgrad_store = wgrad_store + # Non-quantized backward mode overrides + if backward_mode in ("unquant", "dequant"): + ctx.fp8 = False + ctx.ub_overlap_ag = False + ctx.ub_overlap_rs_dgrad = False + ctx.ub_bulk_dgrad = False + ctx.ub_bulk_wgrad = False + ctx.grad_input_quantizer = None + ctx.grad_weight_quantizer = None + ctx.grad_output_quantizer = None + ctx.reduce_and_update_bwd_fp8_tensors = False + # ------------------------------------------------------ # Cached state for backward pass is ready... 
# ------------------------------------------------------ @@ -689,8 +722,10 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Make sure required data is available if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(rowwise_usage=True) - if ctx.weight_quantizer is not None and isinstance( - weight_fp8, QuantizedTensorStorage + if ( + ctx.fp8 + and ctx.weight_quantizer is not None + and isinstance(weight_fp8, QuantizedTensorStorage) ): weight_fp8.update_usage(columnwise_usage=True) @@ -719,8 +754,18 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Note: dx = dy * w nvtx_range_push(f"{nvtx_label}.dgrad_gemm") + weight_for_dgrad = weight_fp8 + if ctx.backward_mode == "dequant": + if isinstance(weight_for_dgrad, QuantizedTensorStorage): + weight_for_dgrad = weight_for_dgrad.dequantize(dtype=ctx.activation_dtype) + else: + weight_for_dgrad = cast_if_needed(weight_for_dgrad, ctx.activation_dtype) + elif ctx.backward_mode == "unquant": + weight_for_dgrad = weight + if isinstance(weight_for_dgrad, QuantizedTensorStorage): + weight_for_dgrad = weight_for_dgrad.dequantize(dtype=ctx.activation_dtype) gemm_out, *_, reduce_scatter_out = general_gemm( - weight_fp8, + weight_for_dgrad, grad_output, layout="NN", grad=True, @@ -1495,6 +1540,11 @@ def _get_quantizers(self, fp8_output, fp8_grad, is_grad_enabled): grad_output_quantizer.optimize_for_gemm = True if fp8_grad: grad_input_quantizer = self.quantizers["scaling_bwd"][FP8BwdTensorIdx.GRAD_INPUT1] + fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() + if fp8_recipe.backward_mode == "dequant" and (fp8_recipe.mxfp8() or fp8_recipe.nvfp4()): + input_quantizer.optimize_for_gemm = False + if grad_output_quantizer is not None: + grad_output_quantizer.optimize_for_gemm = False return ( input_quantizer, weight_quantizer, diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py 
index 48376a297f..f2e03fa087 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -332,12 +332,15 @@ def pre_fuser_forward(self, *, requires_grad: bool) -> None: # Note: We cache the quantized input for backward pass, # but discard the quantized weights. weight_requires_grad = requires_grad and self.weight.requires_grad + columnwise_usage = weight_requires_grad + if FP8GlobalStateManager.get_fp8_recipe().backward_mode in ("unquant", "dequant"): + columnwise_usage = False input_quantizer = self.get_quantizer("forward", 0) weight_quantizer = self.get_quantizer("forward", 1) grad_output_quantizer = self.get_quantizer("backward", 0) - input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + input_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) weight_quantizer.set_usage(rowwise=True, columnwise=False) - grad_output_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + grad_output_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None: super().reset_recipe_state(recipe=recipe) @@ -355,6 +358,15 @@ def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None: grad_output_quantizer.internal = True if not (self.tensor_parallel_mode == "row" and self.sequence_parallel): grad_output_quantizer.optimize_for_gemm = True + if FP8GlobalStateManager.is_fp8_enabled(): + fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() + if fp8_recipe.backward_mode in ("unquant", "dequant") and ( + fp8_recipe.mxfp8() or fp8_recipe.nvfp4() + ): + if input_quantizer is not None: + input_quantizer.optimize_for_gemm = False + if grad_output_quantizer is not None: + grad_output_quantizer.optimize_for_gemm = False # Configure weight quantizer # Note: This function may be called in base class constructor, @@ -420,6 +432,7 @@ def _functional_forward( tensor_parallel_group: 
Optional[torch.distributed.ProcessGroup] = None, sequence_parallel: bool = False, with_quantized_compute: bool = False, + backward_mode: str = "default", input_quantizer: Optional[Quantizer] = None, weight_quantizer: Optional[Quantizer] = None, output_quantizer: Optional[Quantizer] = None, @@ -459,6 +472,8 @@ def _functional_forward( distributing along inner dimension (embedding dim) with_quantized_compute: bool, default = False Whether to perform compute with quantized data. + backward_mode: {`"default"`, `"unquant"`, `"dequant"`}, default = `"default"` + Backward-mode policy for quantized compute. input_quantizer: Quantizer, optional Builder class for quantized input tensor. weight_quantizer: Quantizer, optional @@ -510,7 +525,10 @@ def _functional_forward( if with_quantized_compute: if input_quantizer is None: raise ValueError("Missing quantizer for input tensor") - input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + input_quantizer.set_usage( + rowwise=True, + columnwise=weight_requires_grad and backward_mode == "default", + ) if with_x_all_gather: input_quantizer.set_usage(columnwise=False) x, x_async = gather_along_first_dim( @@ -542,7 +560,10 @@ def _functional_forward( elif with_quantized_compute and not is_quantized_tensor(w): if weight_quantizer is None: raise ValueError("Missing quantizer for weight tensor") - weight_quantizer.set_usage(rowwise=True, columnwise=input_requires_grad) + weight_quantizer.set_usage( + rowwise=True, + columnwise=input_requires_grad and backward_mode == "default", + ) w = weight_quantizer(w) # Check output tensor @@ -611,14 +632,23 @@ def _functional_forward( # Prepare weight tensor for backward pass if input_requires_grad: - if w is not weight and with_quantized_compute and is_quantized_tensor(w): + if ( + w is not weight + and with_quantized_compute + and is_quantized_tensor(w) + and backward_mode == "default" + ): w.update_usage(rowwise_usage=False, columnwise_usage=True) else: w = None # Prepare input 
tensor for backward pass if weight_requires_grad: - if with_quantized_compute and is_quantized_tensor(x_local): + if ( + with_quantized_compute + and is_quantized_tensor(x_local) + and backward_mode == "default" + ): if not (isinstance(x_local, Float8TensorStorage) and with_x_all_gather): # FP8 does not support all-gather of transpose data x_local.update_usage(rowwise_usage=False, columnwise_usage=True) @@ -968,6 +998,10 @@ def op_forward( grad_output_quantizer = self.get_quantizer("backward", 0) grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() + if with_quantized_compute: + backward_mode = FP8GlobalStateManager.get_fp8_recipe().backward_mode + else: + backward_mode = "default" # Get autocast dtype if needed if torch.is_autocast_enabled(): @@ -984,6 +1018,7 @@ def op_forward( tensor_parallel_group=self.tensor_parallel_group, sequence_parallel=self.sequence_parallel, with_quantized_compute=with_quantized_compute, + backward_mode=backward_mode, input_quantizer=input_quantizer, weight_quantizer=weight_quantizer, output_quantizer=output_quantizer, @@ -993,10 +1028,13 @@ def op_forward( # Save state for backward pass if ctx.requires_grad: + saved_input = input_ if backward_mode == "unquant" else x_local + saved_weight = self.weight if backward_mode == "unquant" else w if is_cpu_offload_enabled(): - mark_activation_offload(x_local) - ctx.save_for_backward(x_local, w) - ctx.with_quantized_compute = with_quantized_compute + mark_activation_offload(saved_input) + ctx.save_for_backward(saved_input, saved_weight) + ctx.with_quantized_compute = with_quantized_compute and backward_mode == "default" + ctx.backward_mode = backward_mode ctx.input_quantizer = input_quantizer ctx.weight_quantizer = weight_quantizer ctx.grad_output_quantizer = grad_output_quantizer diff --git a/transformer_engine/pytorch/ops/basic/bias.py b/transformer_engine/pytorch/ops/basic/bias.py index d580f84866..ad147a8d85 100644 --- 
a/transformer_engine/pytorch/ops/basic/bias.py +++ b/transformer_engine/pytorch/ops/basic/bias.py @@ -10,6 +10,7 @@ import torch import transformer_engine_torch as tex +from ...quantization import FP8GlobalStateManager from ..op import BasicOperation, OperationContext from ...utils import canonicalize_device, canonicalize_dtype from ...tensor import Quantizer @@ -124,6 +125,10 @@ def op_forward( if ctx.requires_grad: ctx.grad_input_quantizer = prev_op_grad_output_quantizer + if FP8GlobalStateManager.is_fp8_enabled(): + fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() + if fp8_recipe.backward_mode in ("unquant", "dequant"): + ctx.grad_input_quantizer = None return x + b diff --git a/transformer_engine/pytorch/ops/basic/quantize.py b/transformer_engine/pytorch/ops/basic/quantize.py index fa3efc3807..c5474c18a0 100644 --- a/transformer_engine/pytorch/ops/basic/quantize.py +++ b/transformer_engine/pytorch/ops/basic/quantize.py @@ -59,6 +59,11 @@ def op_forward( quantize_forward = fp8_enabled and self._quantize_forward quantize_backward = fp8_enabled and self._quantize_backward + # Backward quantization is controlled by recipe backward mode. + if fp8_enabled: + recipe = FP8GlobalStateManager.get_fp8_recipe() + quantize_backward = quantize_backward and recipe.backward_mode == "default" + # Quantize if needed out = input_ if quantize_forward and not is_quantized_tensor(out): diff --git a/transformer_engine/pytorch/ops/fused/backward_activation_bias.py b/transformer_engine/pytorch/ops/fused/backward_activation_bias.py index 4ab082d32b..7b3025c03e 100644 --- a/transformer_engine/pytorch/ops/fused/backward_activation_bias.py +++ b/transformer_engine/pytorch/ops/fused/backward_activation_bias.py @@ -104,8 +104,9 @@ def fuse_backward_ops( """ - # Check if recipe supports bias activation fusion - if recipe is None: + # Check if recipe supports bias activation fusion. + # unquant/dequant backward modes should use unfused backward ops. 
+ if recipe is None or recipe.backward_mode in ("unquant", "dequant"): return ops # Scan through ops, fusing if possible diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py index dfc11a19e7..7584891384 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py @@ -92,6 +92,10 @@ def fuser_forward( grad_output_quantizer = linear_op.get_quantizer("backward", 0) grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() + if with_quantized_compute: + backward_mode = FP8GlobalStateManager.get_fp8_recipe().backward_mode + else: + backward_mode = "default" # Get autocast dtype if needed if torch.is_autocast_enabled(): @@ -109,6 +113,7 @@ def fuser_forward( tensor_parallel_group=linear_op.tensor_parallel_group, sequence_parallel=linear_op.sequence_parallel, with_quantized_compute=with_quantized_compute, + backward_mode=backward_mode, input_quantizer=input_quantizer, weight_quantizer=weight_quantizer, output_quantizer=output_quantizer, @@ -118,10 +123,16 @@ def fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: + saved_input, saved_weight = x_local, w + if backward_mode == "unquant": + saved_input, saved_weight = input_, linear_op.weight if is_cpu_offload_enabled(): - mark_activation_offload(x_local) - linear_op_ctx.save_for_backward(x_local, w) - linear_op_ctx.with_quantized_compute = with_quantized_compute + mark_activation_offload(saved_input) + linear_op_ctx.save_for_backward(saved_input, saved_weight) + linear_op_ctx.with_quantized_compute = ( + with_quantized_compute and backward_mode == "default" + ) + linear_op_ctx.backward_mode = backward_mode linear_op_ctx.input_quantizer = input_quantizer linear_op_ctx.weight_quantizer = weight_quantizer linear_op_ctx.grad_output_quantizer = 
grad_output_quantizer @@ -131,6 +142,8 @@ def fuser_forward( linear_op_ctx.weight_requires_grad = weight_requires_grad if bias_op is not None and bias_op_ctx.requires_grad: bias_op_ctx.grad_input_quantizer = linear_op.get_grad_output_quantizer() + if backward_mode in ("unquant", "dequant"): + bias_op_ctx.grad_input_quantizer = None return output, [() for _ in range(len(self.basic_ops))] diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py index 2dfc0566b7..6935330f4e 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py @@ -86,6 +86,10 @@ def fuser_forward( grad_output_quantizer = linear_op.get_quantizer("backward", 0) grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() + if with_quantized_compute: + backward_mode = FP8GlobalStateManager.get_fp8_recipe().backward_mode + else: + backward_mode = "default" # Get autocast dtype if needed if torch.is_autocast_enabled(): @@ -106,6 +110,7 @@ def fuser_forward( tensor_parallel_group=linear_op.tensor_parallel_group, sequence_parallel=linear_op.sequence_parallel, with_quantized_compute=with_quantized_compute, + backward_mode=backward_mode, input_quantizer=input_quantizer, weight_quantizer=weight_quantizer, output_quantizer=output_quantizer, @@ -115,10 +120,16 @@ def fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: + saved_input, saved_weight = x_local, w + if backward_mode == "unquant": + saved_input, saved_weight = input_, linear_op.weight if is_cpu_offload_enabled(): - mark_activation_offload(x_local) - linear_op_ctx.save_for_backward(x_local, w) - linear_op_ctx.with_quantized_compute = with_quantized_compute + mark_activation_offload(saved_input) + linear_op_ctx.save_for_backward(saved_input, saved_weight) + linear_op_ctx.with_quantized_compute = ( + 
with_quantized_compute and backward_mode == "default" + ) + linear_op_ctx.backward_mode = backward_mode linear_op_ctx.input_quantizer = input_quantizer linear_op_ctx.weight_quantizer = weight_quantizer linear_op_ctx.grad_output_quantizer = grad_output_quantizer @@ -127,7 +138,9 @@ def fuser_forward( linear_op_ctx.input_requires_grad = input_requires_grad linear_op_ctx.weight_requires_grad = weight_requires_grad if bias_op is not None and bias_op_ctx.requires_grad: - bias_op_ctx.grad_input_quantizer = linear_op.get_grad_output_quantizer() + bias_op_ctx.grad_input_quantizer = ( + None if backward_mode != "default" else linear_op.get_grad_output_quantizer() + ) return output, [() for _ in range(len(self.basic_ops))] diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py index ae4bdd4b19..2358140c88 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py @@ -65,6 +65,10 @@ def fuser_forward( grad_output_quantizer = linear_op.get_quantizer("backward", 0) grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() + if with_quantized_compute: + backward_mode = FP8GlobalStateManager.get_fp8_recipe().backward_mode + else: + backward_mode = "default" # Get extra input tensor for add operation extra_input = basic_op_extra_inputs[2][0] @@ -87,6 +91,7 @@ def fuser_forward( tensor_parallel_group=linear_op.tensor_parallel_group, sequence_parallel=linear_op.sequence_parallel, with_quantized_compute=with_quantized_compute, + backward_mode=backward_mode, input_quantizer=input_quantizer, weight_quantizer=weight_quantizer, output_quantizer=output_quantizer, @@ -96,10 +101,16 @@ def fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: + saved_input, saved_weight = x_local, w + if backward_mode == "unquant": + saved_input, 
saved_weight = input_, linear_op.weight if is_cpu_offload_enabled(): - mark_activation_offload(x_local) - linear_op_ctx.save_for_backward(x_local, w) - linear_op_ctx.with_quantized_compute = with_quantized_compute + mark_activation_offload(saved_input) + linear_op_ctx.save_for_backward(saved_input, saved_weight) + linear_op_ctx.with_quantized_compute = ( + with_quantized_compute and backward_mode == "default" + ) + linear_op_ctx.backward_mode = backward_mode linear_op_ctx.input_quantizer = input_quantizer linear_op_ctx.weight_quantizer = weight_quantizer linear_op_ctx.grad_output_quantizer = grad_output_quantizer diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py index 0d3e1d0416..54411f650d 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py @@ -388,6 +388,19 @@ def fuse_forward_ops( """ + # Disable Userbuffers for non-quantized backward modes. + # In unquant/dequant modes we want to avoid all UB-specific overlap + # paths and run through the standard non-UB operator sequence instead. 
+ recipe = unused.get("recipe", None) + if recipe is not None: + backward_mode = recipe.backward_mode + elif FP8GlobalStateManager.is_fp8_enabled(): + backward_mode = FP8GlobalStateManager.get_fp8_recipe().backward_mode + else: + backward_mode = "default" + if backward_mode in ("unquant", "dequant"): + return ops + # Return immediately if environment is not distributed if not torch.distributed.is_initialized() or torch.distributed.get_world_size() == 1: return ops diff --git a/transformer_engine/pytorch/ops/fuser.py b/transformer_engine/pytorch/ops/fuser.py index 80386db2d9..616c075ad8 100644 --- a/transformer_engine/pytorch/ops/fuser.py +++ b/transformer_engine/pytorch/ops/fuser.py @@ -339,6 +339,7 @@ def __init__( # Cache and detect change of state relevant for fusing operations self.recipe_type = None self.first_op_requiring_backward = 0 + self.backward_mode = "default" self._last_amax_history_len = 0 # Flatten list of parameters @@ -415,9 +416,14 @@ def maybe_fuse_ops( # Early exit if fusion parameters haven't changed need_reset = False recipe_type = type(recipe) - fusion_params = (recipe_type, first_op_requiring_backward) - if fusion_params != (self.recipe_type, self.first_op_requiring_backward): - # Recipe type or grad requirmenets have changed + backward_mode = recipe.backward_mode if recipe is not None else "default" + fusion_params = (recipe_type, first_op_requiring_backward, backward_mode) + if fusion_params != ( + self.recipe_type, + self.first_op_requiring_backward, + self.backward_mode, + ): + # Recipe type, backward mode, or grad requirements have changed need_reset = True elif ( recipe is not None @@ -451,7 +457,7 @@ def maybe_fuse_ops( ) # Save current fusion params - self.recipe_type, self.first_op_requiring_backward = fusion_params + self.recipe_type, self.first_op_requiring_backward, self.backward_mode = fusion_params # Save amax history length if isinstance(recipe, DelayedScaling): diff --git 
a/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py index 52e292125e..2cede7e832 100644 --- a/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py @@ -222,6 +222,10 @@ def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: """ if dtype is None: dtype = self._dtype + + if 0 in self.size(): + return torch.empty(self.size(), dtype=dtype, device=self.device) + block_len = 128 if not self._is_2D_scaled: return self._dequantize_vectorwise(dtype=dtype) diff --git a/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py index 7bbe809c9d..34d507dd7e 100644 --- a/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py @@ -182,6 +182,8 @@ def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: """Dequantize to a higher precision.""" if dtype is None: dtype = self._dtype + if self._rowwise_data is not None and 0 in self._rowwise_data.size(): + return torch.empty(self.size(), dtype=dtype, device=self.device) return _FromMXFP8Func.forward(None, self, dtype) def size(self, *args, **kwargs): diff --git a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py index fb163c9032..9fdbe8d595 100644 --- a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py @@ -213,6 +213,8 @@ def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: """Dequantize to a higher precision.""" if dtype is None: dtype = self._dtype + if self._rowwise_data is not None and 0 in self._rowwise_data.size(): + return 
torch.empty(self.size(), dtype=dtype, device=self.device) return _FromNVFP4Func.forward(None, self, dtype) def size(self, dim: Optional[int] = None) -> Union[torch.Size, int]: