Changes from all commits
26 commits
c778208
PoC of the changes
ptrendx Dec 1, 2025
1b1c574
Early exit from the Free function for the empty tensor
ptrendx Dec 2, 2025
ac14119
Use the proper function for nvtx range
ptrendx Dec 2, 2025
616af19
Only do mark_not_offload when the cpu_offloading is enabled
ptrendx Dec 12, 2025
c0d2ccc
First pass on making the setattr issue not come back
ptrendx Dec 15, 2025
68b6f74
Actually add pytest.ini
ptrendx Dec 15, 2025
1c5434c
Changes to __init__
ptrendx Dec 16, 2025
778019d
A different way
ptrendx Dec 16, 2025
0fc2a62
WAR the fact that it is not possible to set __setattr__ dynamically
ptrendx Dec 16, 2025
2fb6ee3
Simpler solution and fixes
ptrendx Jan 10, 2026
4940724
Fix for the inference mode DPA
ptrendx Jan 14, 2026
8704a59
Start of debugging debug tools
ptrendx Jan 14, 2026
aa783a7
More fixes in debug
ptrendx Jan 14, 2026
fd158f6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 14, 2026
2601e5e
Speculative moving the validate_name to the constructor
ptrendx Jan 14, 2026
09a3a7c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 14, 2026
026b4b1
Fix
ptrendx Jan 14, 2026
e83e816
Making the debug tools names saner
ptrendx Jan 14, 2026
8843a46
Change the setattr usage in the tensor parallel group setting
ptrendx Jan 14, 2026
084847f
Adding try/finally - it does not seem to impact the time in an observable way
ptrendx Jan 14, 2026
d004747
Fixing lint issues and the thunder test
ptrendx Jan 15, 2026
523944a
Fix 1 of the debug tests
ptrendx Jan 15, 2026
9873235
Removed the warning and enforcement in the CI
ptrendx Jan 15, 2026
7b55639
try-finally in the context manager
ptrendx Jan 16, 2026
028d03f
Fixing the debug tests
ptrendx Jan 20, 2026
0d949ef
Merge branch 'main' into pr_python_cpu_optimization
ptrendx Jan 20, 2026
2 changes: 1 addition & 1 deletion tests/pytorch/attention/test_attention.py
@@ -2751,7 +2751,7 @@ def forward(
cu_seqlens,
max_s,
) -> torch.Tensor:
with self.prepare_forward(inp, num_gemms=3) as inp:
with self.prepare_forward_ctx(inp, num_gemms=3) as inp:
out = _custom_mha_fp8.apply(
inp,
self.qkv_weight,
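The only change in this test tracks the rename of the module's forward-preparation context manager. A rough usage sketch follows, assuming prepare_forward_ctx keeps the contract the test relies on (a context manager that yields the prepared input tensor); the surrounding module and shapes are illustrative, not taken from the PR.

# Sketch only: `prepare_forward_ctx` is assumed to behave like the previous
# `prepare_forward`, yielding the (possibly cast) input and restoring module
# state on exit.
import torch
import transformer_engine.pytorch as te

layer = te.Linear(1024, 1024)
inp = torch.randn(16, 1024, device="cuda")

with layer.prepare_forward_ctx(inp, num_gemms=1) as prepared_inp:
    # A custom autograd function would consume `prepared_inp` here,
    # just as `_custom_mha_fp8.apply` does in the test above.
    out = prepared_inp @ layer.weight.t()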
25 changes: 19 additions & 6 deletions tests/pytorch/debug/test_sanity.py
@@ -30,10 +30,17 @@
stats: [min, max, mean, std, l1_norm, l2_norm, cur_amax, dynamic_range]
start_step : 0
end_step: 1
""",
"log_fp8": """log_fp8:
layers:
layer_types: [linear]
enabled:
True
transformer_engine:
LogFp8TensorStats:
enabled: True
tensors: [activation, gradient, weight]
stats: [underflows, overflows]
stats: [underflows%]
start_step : 0
end_step: 1
""",
@@ -46,22 +53,26 @@
FakeQuant:
enabled: True
gemms: [fprop, dgrad, wgrad]
tensors: [activation, weight, gradient]
quant_format: FP8E5M2
""",
}

# Configs that require FP8 to be enabled
fp8_required_configs = {"log_fp8"}


def _get_model(model_key):
if model_key == "linear":
return te.Linear(D, D)
return te.Linear(D, D, name="layer")
if model_key == "layernorm_linear":
return te.LayerNormLinear(D, D)
return te.LayerNormLinear(D, D, name="layer")
if model_key == "layernorm_mlp":
return te.LayerNormMLP(D, D, D)
return te.LayerNormMLP(D, D, D, name="layer")
if model_key == "mha_attention":
return te.MultiheadAttention(D, H)
return te.MultiheadAttention(D, H, name="layer")
if model_key == "transformer_layer":
return te.TransformerLayer(D, D, H)
return te.TransformerLayer(D, D, H, name="layer")


def _run_forward_backward(model, fp8):
@@ -95,4 +106,6 @@ def _run_test(model_key, fp8, config, feature_dirs, config_file, log_dir):
def test_sanity_debug(model_key, fp8, config_key, feature_dirs):
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if not fp8 and config_key in fp8_required_configs:
pytest.skip(f"Config '{config_key}' requires FP8")
_run_test(model_key, fp8, configs[config_key], feature_dirs)
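The new fp8_required_configs gate follows a standard pytest skip pattern; a minimal self-contained sketch of the same idea is below (the parametrized config keys other than "log_fp8" are placeholders, not taken from the test file).

import pytest

# Configs that only produce meaningful output when FP8 execution is enabled.
fp8_required_configs = {"log_fp8"}

@pytest.mark.parametrize("fp8", [True, False])
@pytest.mark.parametrize("config_key", ["stats", "log_fp8", "fake_quant"])
def test_debug_config_gate(fp8, config_key):
    # Skip combinations that cannot collect FP8 statistics.
    if not fp8 and config_key in fp8_required_configs:
        pytest.skip(f"Config '{config_key}' requires FP8")
    # ... run the forward/backward sanity check with the selected config ...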
4 changes: 2 additions & 2 deletions transformer_engine/common/transformer_engine.cpp
@@ -454,9 +454,9 @@ class TensorAllocator {
}

void Free(NVTETensor t) {
std::lock_guard<std::mutex> lock(mutex);
uintptr_t index = reinterpret_cast<uintptr_t>(t);
if (index == 0) return;
std::lock_guard<std::mutex> lock(mutex);
NVTE_CHECK(index <= memory.size(), "Invalid tensor.");
free_list.push_back(index);
// Clean up
@@ -564,9 +564,9 @@ class GroupedTensorAllocator {
}

void Free(NVTEGroupedTensor t) {
std::lock_guard<std::mutex> lock(mutex);
uintptr_t index = reinterpret_cast<uintptr_t>(t);
if (index == 0) return;
std::lock_guard<std::mutex> lock(mutex);
NVTE_CHECK(index <= memory.size(), "Invalid grouped tensor.");
free_list.push_back(index);
// Clean up
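In both allocators the only change is ordering: the null-handle early exit now runs before the mutex is acquired, so freeing an empty tensor never contends on the allocator lock. A rough Python analogue of the pattern (the real allocators are the C++ classes above; this only illustrates the ordering):

import threading

class HandleAllocatorSketch:
    """Toy analogue of the C++ TensorAllocator free path."""

    def __init__(self):
        self._mutex = threading.Lock()
        self._memory = []      # slot storage; handle N maps to slot N - 1
        self._free_list = []   # recycled 1-based handles

    def free(self, handle: int) -> None:
        # Early exit for the null handle *before* taking the lock,
        # so freeing an empty tensor never blocks on the mutex.
        if handle == 0:
            return
        with self._mutex:
            assert handle <= len(self._memory), "Invalid handle."
            self._free_list.append(handle)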
@@ -676,9 +676,9 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None:
# assume attention uses the same fp8_group as GEMMs
fp8_group = FP8GlobalStateManager.get_fp8_group()

self.fp8_parameters = FP8GlobalStateManager.with_fp8_parameters()
self.fp8 = FP8GlobalStateManager.is_fp8_enabled()
self.fp8_calibration = FP8GlobalStateManager.is_fp8_calibration()
self.fast_setattr("fp8_parameters", FP8GlobalStateManager.with_fp8_parameters())
self.fast_setattr("fp8", FP8GlobalStateManager.is_fp8_enabled())
self.fast_setattr("fp8_calibration", FP8GlobalStateManager.is_fp8_calibration())
fp8_enabled = self.fp8 or self.fp8_calibration
self.fp8_meta["fp8_checkpoint"] = self.fp8 or self.fp8_calibration
if self.fp8_parameters or fp8_enabled:
@@ -703,7 +703,7 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None:
)
else:
# If fp8 isn't enabled, turn off and return.
self.fp8_initialized = False
self.fast_setattr("fp8_initialized", False)
return

if self.fp8_parameters and not self.fp8_initialized:
Expand All @@ -721,7 +721,7 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None:

# Allocate scales and amaxes
self.init_fp8_meta_tensors(fp8_recipes)
self.fp8_initialized = True
self.fast_setattr("fp8_initialized", True)

self.fp8_meta["recipe"] = fp8_recipe_dpa
if fp8_recipe != fp8_recipe_dpa:
@@ -1000,7 +1000,7 @@ def forward(
cases. It is ignored for other backends and when context parallelism is enabled.
"""

with self.prepare_forward(
with self.prepare_forward_ctx(
query_layer,
num_gemms=3,
allow_non_contiguous=True,
@@ -1145,10 +1145,11 @@ def forward(
if attn_mask_type == "padding_causal":
attn_mask_type = attn_mask_type + "_bottom_right"

self.attention_type = "cross"
self.flash_attention.attention_type = self.attention_type
self.fused_attention.attention_type = self.attention_type
self.unfused_attention.attention_type = self.attention_type
if self.attention_type != "cross":
self.fast_setattr("attention_type", "cross")
self.flash_attention.attention_type = self.attention_type
self.fused_attention.attention_type = self.attention_type
self.unfused_attention.attention_type = self.attention_type

query_layer, key_layer, value_layer = [
x.contiguous() if not x.is_contiguous() else x
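Several attribute writes in this file (and in distributed.py below) now go through fast_setattr instead of plain assignment or setattr. The helper itself is not part of this diff; a plausible sketch, assuming it simply bypasses the per-assignment bookkeeping in torch.nn.Module.__setattr__, is:

import torch

class FastSetattrSketch(torch.nn.Module):
    # Hypothetical illustration only; the actual helper on
    # TransformerEngineBaseModule may differ in its checks.
    def fast_setattr(self, name: str, value) -> None:
        # nn.Module.__setattr__ scans the parameter/buffer/submodule dicts on
        # every assignment. For plain Python attributes that bookkeeping is
        # unnecessary, so write directly into __dict__.
        object.__setattr__(self, name, value)

The guarded attention_type update follows the same CPU-overhead theme: the three child attention modules are only touched when the value actually changes.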
5 changes: 1 addition & 4 deletions transformer_engine/pytorch/attention/multi_head_attention.py
@@ -8,7 +8,6 @@
from typing import Callable, List, Optional, Tuple, Union
import torch

from transformer_engine.debug.pytorch.debug_state import TEDebugState
from transformer_engine.pytorch.quantization import FP8GlobalStateManager
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
@@ -335,6 +334,7 @@ def __init__(
self.hidden_size_kv = self.hidden_size_per_attention_head * self.num_gqa_groups

self.name = name
TransformerEngineBaseModule._validate_name(self)

common_gemm_kwargs = {
"fuse_wgrad_accumulation": fuse_wgrad_accumulation,
@@ -739,9 +739,6 @@ def forward(
core_attention_bias_type in AttnBiasTypes
), f"core_attention_bias_type {core_attention_bias_type} is not supported!"

if TEDebugState.debug_enabled:
TransformerEngineBaseModule._validate_name(self)

# =================================================
# Pre-allocate memory for key-value cache for inference
# =================================================
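With the name validated in __init__, the per-forward TEDebugState.debug_enabled check and its import are no longer needed, so validation runs once per module instead of on every call. A minimal sketch of that shape, with a purely hypothetical validation body:

class NameValidationSketch:
    # Illustration only: the real _validate_name lives on
    # TransformerEngineBaseModule and its exact checks are not shown in this diff.
    def __init__(self, name=None):
        self.name = name
        self._validate_name()  # runs once, at construction time

    def _validate_name(self):
        if self.name is not None and not isinstance(self.name, str):
            raise TypeError(f"Module name must be a string, got {type(self.name)}")

    def forward(self, x):
        # No per-call validation or debug-state lookup needed here.
        return x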
8 changes: 4 additions & 4 deletions transformer_engine/pytorch/distributed.py
@@ -729,8 +729,8 @@ def checkpoint(
if isinstance(function, TransformerEngineBaseModule):
# If this TE module is FSDP-wrapped, clear its FSDP group information because there's no need
# to scatter/gather activations that we will recompute anyway.
setattr(function, "fsdp_wrapped", False)
setattr(function, "fsdp_group", None)
function.fast_setattr("fsdp_wrapped", False)
function.fast_setattr("fsdp_group", None)

# Otherwise discard unused te.utils.checkpoint.checkpoint() arguments
# and execute TE's own checkpointing
@@ -2022,7 +2022,7 @@ def prepare_te_modules_for_fsdp(fsdp_root: torch.nn.Module) -> None:
)
root_state = _get_module_fsdp_state(fsdp_root)
assert root_state is not None, "Root module does not have a valid _FSDPState."
setattr(fsdp_root.module, "fsdp_group", root_state.process_group)
fsdp_root.module.fast_setattr("fsdp_group", root_state.process_group)

# Iterate through all FSDP-wrapped submodules and inject FSDP information into TE modules
fsdp_states, fsdp_modules = _get_fsdp_states_with_modules(fsdp_root)
@@ -2033,7 +2033,7 @@ def prepare_te_modules_for_fsdp(fsdp_root: torch.nn.Module) -> None:
"TE modules with primary weights in FP8 cannot be FSDP-wrapped. "
"Please initialize your model without the te.quantized_model_init(...) context."
)
setattr(fsdp_module.module, "fsdp_group", state.process_group)
fsdp_module.module.fast_setattr("fsdp_group", state.process_group)


class FullyShardedDataParallel(FSDP):