diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py
index eb7905bcd5..b29d1289f8 100644
--- a/tests/pytorch/attention/test_attention.py
+++ b/tests/pytorch/attention/test_attention.py
@@ -2751,7 +2751,7 @@ def forward(
         cu_seqlens,
         max_s,
     ) -> torch.Tensor:
-        with self.prepare_forward(inp, num_gemms=3) as inp:
+        with self.prepare_forward_ctx(inp, num_gemms=3) as inp:
             out = _custom_mha_fp8.apply(
                 inp,
                 self.qkv_weight,
diff --git a/tests/pytorch/debug/test_sanity.py b/tests/pytorch/debug/test_sanity.py
index aee5474e76..2bc4b35590 100644
--- a/tests/pytorch/debug/test_sanity.py
+++ b/tests/pytorch/debug/test_sanity.py
@@ -30,10 +30,17 @@
       stats: [min, max, mean, std, l1_norm, l2_norm, cur_amax, dynamic_range]
       start_step : 0
       end_step: 1
+""",
+    "log_fp8": """log_fp8:
+  layers:
+    layer_types: [linear]
+  enabled:
+    True
+  transformer_engine:
     LogFp8TensorStats:
       enabled: True
       tensors: [activation, gradient, weight]
-      stats: [underflows, overflows]
+      stats: [underflows%]
       start_step : 0
       end_step: 1
 """,
@@ -46,22 +53,26 @@
     FakeQuant:
       enabled: True
       gemms: [fprop, dgrad, wgrad]
+      tensors: [activation, weight, gradient]
       quant_format: FP8E5M2
 """,
 }
 
+# Configs that require FP8 to be enabled
+fp8_required_configs = {"log_fp8"}
+
 
 def _get_model(model_key):
     if model_key == "linear":
-        return te.Linear(D, D)
+        return te.Linear(D, D, name="layer")
     if model_key == "layernorm_linear":
-        return te.LayerNormLinear(D, D)
+        return te.LayerNormLinear(D, D, name="layer")
     if model_key == "layernorm_mlp":
-        return te.LayerNormMLP(D, D, D)
+        return te.LayerNormMLP(D, D, D, name="layer")
     if model_key == "mha_attention":
-        return te.MultiheadAttention(D, H)
+        return te.MultiheadAttention(D, H, name="layer")
     if model_key == "transformer_layer":
-        return te.TransformerLayer(D, D, H)
+        return te.TransformerLayer(D, D, H, name="layer")
 
 
 def _run_forward_backward(model, fp8):
@@ -95,4 +106,6 @@ def _run_test(model_key, fp8, config, feature_dirs, config_file, log_dir):
 def test_sanity_debug(model_key, fp8, config_key, feature_dirs):
     if fp8 and not fp8_available:
         pytest.skip(reason_for_no_fp8)
+    if not fp8 and config_key in fp8_required_configs:
+        pytest.skip(f"Config '{config_key}' requires FP8")
     _run_test(model_key, fp8, configs[config_key], feature_dirs)
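Note: the sanity test above now gives every debug-enabled module an explicit name and skips the new "log_fp8" config when FP8 is unavailable. A minimal sketch of the kind of run the test exercises; the hidden size, input shape, and direct use of te.fp8_autocast here are illustrative assumptions, not code from the test file:

    import torch
    import transformer_engine.pytorch as te

    # Explicit name, as in _get_model(); unnamed modules fall back to the
    # auto-generated "Layer_<count>" described in module/base.py further below.
    model = te.Linear(128, 128, name="layer")
    x = torch.randn(32, 128, device="cuda")

    # The "log_fp8" stats (underflows%) are only meaningful when FP8 is active,
    # hence the new fp8_required_configs skip for non-FP8 runs.
    with te.fp8_autocast(enabled=True):
        y = model(x)
    y.sum().backward()
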
diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp
index 6880dd560a..06971443dd 100644
--- a/transformer_engine/common/transformer_engine.cpp
+++ b/transformer_engine/common/transformer_engine.cpp
@@ -454,9 +454,9 @@ class TensorAllocator {
   }
 
   void Free(NVTETensor t) {
-    std::lock_guard<std::mutex> lock(mutex);
     uintptr_t index = reinterpret_cast<uintptr_t>(t);
     if (index == 0) return;
+    std::lock_guard<std::mutex> lock(mutex);
     NVTE_CHECK(index <= memory.size(), "Invalid tensor.");
     free_list.push_back(index);
     // Clean up
@@ -564,9 +564,9 @@ class GroupedTensorAllocator {
   }
 
   void Free(NVTEGroupedTensor t) {
-    std::lock_guard<std::mutex> lock(mutex);
     uintptr_t index = reinterpret_cast<uintptr_t>(t);
     if (index == 0) return;
+    std::lock_guard<std::mutex> lock(mutex);
     NVTE_CHECK(index <= memory.size(), "Invalid grouped tensor.");
     free_list.push_back(index);
     // Clean up
diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py
index 6e5a12a103..51ffbc2e48 100644
--- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py
+++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py
@@ -676,9 +676,9 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None:
             # assume attention uses the same fp8_group as GEMMs
             fp8_group = FP8GlobalStateManager.get_fp8_group()
 
-        self.fp8_parameters = FP8GlobalStateManager.with_fp8_parameters()
-        self.fp8 = FP8GlobalStateManager.is_fp8_enabled()
-        self.fp8_calibration = FP8GlobalStateManager.is_fp8_calibration()
+        self.fast_setattr("fp8_parameters", FP8GlobalStateManager.with_fp8_parameters())
+        self.fast_setattr("fp8", FP8GlobalStateManager.is_fp8_enabled())
+        self.fast_setattr("fp8_calibration", FP8GlobalStateManager.is_fp8_calibration())
         fp8_enabled = self.fp8 or self.fp8_calibration
         self.fp8_meta["fp8_checkpoint"] = self.fp8 or self.fp8_calibration
         if self.fp8_parameters or fp8_enabled:
@@ -703,7 +703,7 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None:
             )
         else:
             # If fp8 isn't enabled, turn off and return.
-            self.fp8_initialized = False
+            self.fast_setattr("fp8_initialized", False)
             return
 
         if self.fp8_parameters and not self.fp8_initialized:
@@ -721,7 +721,7 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None:
 
             # Allocate scales and amaxes
             self.init_fp8_meta_tensors(fp8_recipes)
-            self.fp8_initialized = True
+            self.fast_setattr("fp8_initialized", True)
 
             self.fp8_meta["recipe"] = fp8_recipe_dpa
             if fp8_recipe != fp8_recipe_dpa:
@@ -1000,7 +1000,7 @@ def forward(
             cases. It is ignored for other backends and when context parallelism is enabled.
         """
 
-        with self.prepare_forward(
+        with self.prepare_forward_ctx(
             query_layer,
             num_gemms=3,
             allow_non_contiguous=True,
@@ -1145,10 +1145,11 @@ def forward(
                 if attn_mask_type == "padding_causal":
                     attn_mask_type = attn_mask_type + "_bottom_right"
 
-            self.attention_type = "cross"
-            self.flash_attention.attention_type = self.attention_type
-            self.fused_attention.attention_type = self.attention_type
-            self.unfused_attention.attention_type = self.attention_type
+            if self.attention_type != "cross":
+                self.fast_setattr("attention_type", "cross")
+                self.flash_attention.attention_type = self.attention_type
+                self.fused_attention.attention_type = self.attention_type
+                self.unfused_attention.attention_type = self.attention_type
 
             query_layer, key_layer, value_layer = [
                 x.contiguous() if not x.is_contiguous() else x
diff --git a/transformer_engine/pytorch/attention/multi_head_attention.py b/transformer_engine/pytorch/attention/multi_head_attention.py
index f875fd1e0a..d813e7c8f1 100644
--- a/transformer_engine/pytorch/attention/multi_head_attention.py
+++ b/transformer_engine/pytorch/attention/multi_head_attention.py
@@ -8,7 +8,6 @@
 from typing import Callable, List, Optional, Tuple, Union
 
 import torch
-from transformer_engine.debug.pytorch.debug_state import TEDebugState
 from transformer_engine.pytorch.quantization import FP8GlobalStateManager
 from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
 from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
@@ -335,6 +334,7 @@ def __init__(
         self.hidden_size_kv = self.hidden_size_per_attention_head * self.num_gqa_groups
 
         self.name = name
+        TransformerEngineBaseModule._validate_name(self)
 
         common_gemm_kwargs = {
             "fuse_wgrad_accumulation": fuse_wgrad_accumulation,
@@ -739,9 +739,6 @@ def forward(
             core_attention_bias_type in AttnBiasTypes
         ), f"core_attention_bias_type {core_attention_bias_type} is not supported!"
 
-        if TEDebugState.debug_enabled:
-            TransformerEngineBaseModule._validate_name(self)
-
         # =================================================
         # Pre-allocate memory for key-value cache for inference
         # =================================================
diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py
index 004a04ab4c..f269e21b8c 100644
--- a/transformer_engine/pytorch/distributed.py
+++ b/transformer_engine/pytorch/distributed.py
@@ -729,8 +729,8 @@ def checkpoint(
     if isinstance(function, TransformerEngineBaseModule):
         # If this TE module is FSDP-wrapped, clear its FSDP group information because there's no need
         # to scatter/gather activations that we will recompute anyway.
-        setattr(function, "fsdp_wrapped", False)
-        setattr(function, "fsdp_group", None)
+        function.fast_setattr("fsdp_wrapped", False)
+        function.fast_setattr("fsdp_group", None)
 
     # Otherwise discard unused te.utils.checkpoint.checkpoint() arguments
     # and execute TE's own checkpointing
@@ -2022,7 +2022,7 @@ def prepare_te_modules_for_fsdp(fsdp_root: torch.nn.Module) -> None:
         )
     root_state = _get_module_fsdp_state(fsdp_root)
     assert root_state is not None, "Root module does not have a valid _FSDPState."
-    setattr(fsdp_root.module, "fsdp_group", root_state.process_group)
+    fsdp_root.module.fast_setattr("fsdp_group", root_state.process_group)
 
     # Iterate through all FSDP-wrapped submodules and inject FSDP information into TE modules
     fsdp_states, fsdp_modules = _get_fsdp_states_with_modules(fsdp_root)
@@ -2033,7 +2033,7 @@ def prepare_te_modules_for_fsdp(fsdp_root: torch.nn.Module) -> None:
                 "TE modules with primary weights in FP8 cannot be FSDP-wrapped. "
                 "Please initialize your model without the te.quantized_model_init(...) context."
             )
-        setattr(fsdp_module.module, "fsdp_group", state.process_group)
+        fsdp_module.module.fast_setattr("fsdp_group", state.process_group)
 
 
 class FullyShardedDataParallel(FSDP):
diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py
index 875d245a8f..841cdf04ca 100644
--- a/transformer_engine/pytorch/module/base.py
+++ b/transformer_engine/pytorch/module/base.py
@@ -10,9 +10,8 @@
 import warnings
 from enum import Enum
 from abc import ABC, abstractmethod
-from typing import Any, Dict, Generator, List, Optional, Set, Tuple, Union
+from typing import Any, Dict, Generator, List, Optional, Tuple, Union
 from contextlib import contextmanager
-import logging
 from types import MethodType
 
 import torch
@@ -50,6 +49,8 @@
     is_non_tn_fp8_gemm_supported,
     torch_get_autocast_gpu_dtype,
     get_nvtx_range_context,
+    nvtx_range_push,
+    nvtx_range_pop,
 )
 from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage
 from ...common.recipe import DelayedScaling, Recipe
@@ -605,10 +606,10 @@ def fill_userbuffers_buffer_for_all_gather(
 class TransformerEngineBaseModule(torch.nn.Module, ABC):
     """Base TE module."""
 
-    def __init__(self) -> None:
+    def __init__(self, name: Optional[str] = None) -> None:
         super().__init__()
         assert torch.cuda.is_available(), "TransformerEngine needs CUDA."
-        self.name = None
+        self.name = name
         self.next_iter_when_debug_should_be_run = 0
         self.fp8_initialized = False
         self.fp8 = False
@@ -633,26 +634,22 @@ def __init__(self) -> None:
         if not TEDebugState.debug_enabled:
             TEDebugState.initialize()
+        self._validate_name()
 
-    # Names of attributes that can be set quickly (see __setattr__
-    # method)
-    _fast_setattr_names: Set[str] = {
-        "activation_dtype",
-        "fp8",
-        "fp8_initialized",
-        "fp8_calibration",
-        "fp8_parameters",
-    }
+    def fast_setattr(self, name: str, value: Any) -> None:
+        """
+        Fast version of the Module's set attribute function.
+        Should be used for regular attributes, but not for properties, parameters, or buffers.
+        """
+        self.__dict__[name] = value
 
-    def __setattr__(self, name: str, value: Any) -> None:
-        if name in TransformerEngineBaseModule._fast_setattr_names:
-            # torch.nn.Module has a custom __setattr__ that handles
-            # modules, parameters, and buffers. This is unnecessary
-            # overhead when setting plain attrs.
-            self.__dict__[name] = value
-        else:
-            # Default case
-            super().__setattr__(name, value)
+    def module_setattr(self, name: str, value: Any) -> None:
+        """
+        Regular version of the Module's set attribute function.
+        Should be used only when the fast version cannot be used, i.e. for properties,
+        parameters, and buffers.
+        """
+        super().__setattr__(name, value)
 
     def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> None:
         """
@@ -773,7 +770,7 @@ def init_fp8_meta_tensors(self, recipe: Recipe) -> None:
         self.set_meta_tensor(True, recipe)
         self.set_meta_tensor(False, recipe)
 
-        self.fp8_meta_tensors_initialized = True
+        self.fast_setattr("fp8_meta_tensors_initialized", True)
 
     def get_fp8_meta_tensors(self) -> None:
         """Get scales and amaxes."""
@@ -930,7 +927,7 @@ def set_activation_dtype(self, inp: torch.Tensor) -> None:
         """Get activation data type for AMP."""
         # Native AMP (`torch.autocast`) gets highest priority
         if torch.is_autocast_enabled():
-            self.activation_dtype = torch_get_autocast_gpu_dtype()
+            self.fast_setattr("activation_dtype", torch_get_autocast_gpu_dtype())
             return
 
         # All checks after this have already been performed once, thus skip
@@ -945,7 +942,7 @@ def set_activation_dtype(self, inp: torch.Tensor) -> None:
                         "Data types for parameters must match when outside of autocasted region. "
                         f" Found input dtype: {dtype} and {name!r} dtype: {param.dtype}"
                     )
-        self.activation_dtype = dtype
+        self.fast_setattr("activation_dtype", dtype)
 
     def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
         """
@@ -957,8 +954,8 @@ def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> N
         tp_group : ProcessGroup, default = None
                   tensor parallel process group.
         """
-        self.tp_group = tp_group
-        self.tp_group_initialized = True
+        self.fast_setattr("tp_group", tp_group)
+        self.fast_setattr("tp_group_initialized", True)
 
     def _get_fp8_params(self) -> Union[List[torch.Tensor], None]:
         """returns the FP8 weights."""
@@ -974,48 +971,51 @@ def _get_fp8_params(self) -> Union[List[torch.Tensor], None]:
         # assume FP8 execution.
     def init_fp8_metadata(self, num_gemms: int = 1) -> None:
         """Initialize fp8 related metadata and tensors during fprop."""
-        _original_recipe = self.fp8_meta.get("recipe", None)
-
-        self.fp8_parameters = FP8GlobalStateManager.with_fp8_parameters()
-        self.fp8 = FP8GlobalStateManager.is_fp8_enabled()
-        self.fp8_calibration = FP8GlobalStateManager.is_fp8_calibration()
-        fp8_enabled = self.fp8 or self.fp8_calibration
-        self.fp8_meta["fp8_checkpoint"] = self.fp8 or self.fp8_calibration
-
-        if self.fp8_parameters or fp8_enabled:
-            if (
-                self.fp8_initialized
-                and FP8GlobalStateManager.get_fp8_recipe() == self.fp8_meta["recipe"]
-            ):
+        meta = self.fp8_meta
+
+        fp8 = FP8GlobalStateManager.is_fp8_enabled()
+        fp8_parameters = FP8GlobalStateManager.with_fp8_parameters()
+        fp8_calibration = FP8GlobalStateManager.is_fp8_calibration()
+        self.fast_setattr("fp8_parameters", fp8_parameters)
+        self.fast_setattr("fp8", fp8)
+        self.fast_setattr("fp8_calibration", fp8_calibration)
+        fp8_enabled = fp8 or fp8_calibration
+        meta["fp8_checkpoint"] = fp8_enabled
+
+        _original_recipe = None
+
+        if fp8_parameters or fp8_enabled:
+            _original_recipe = meta.get("recipe", None)
+            if self.fp8_initialized and FP8GlobalStateManager.get_fp8_recipe() == _original_recipe:
                 # FP8 init has already been run and recipe is the same, don't do anything.
                 return
-            self.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()
+            meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()
         else:
             # If fp8 isn't enabled, turn off and return.
-            self.fp8_initialized = False
+            self.fast_setattr("fp8_initialized", False)
             return
 
-        if self.fp8_parameters and not self.fp8_initialized:
-            self.fp8_meta["num_gemms"] = num_gemms
-            self.init_fp8_meta_tensors(self.fp8_meta["recipe"])
+        if fp8_parameters and not self.fp8_initialized:
+            meta["num_gemms"] = num_gemms
+            self.init_fp8_meta_tensors(meta["recipe"])
 
         if fp8_enabled:
             # Set FP8 and other FP8 metadata
-            self.fp8_meta["num_gemms"] = num_gemms
-            self.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group()
+            meta["num_gemms"] = num_gemms
+            meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group()
 
             # Set FP8_MAX per tensor according to recipe
-            if hasattr(self.fp8_meta["recipe"], "fp8_format"):
-                self.fp8_meta["fp8_max_fwd"] = self.fp8_meta["recipe"].fp8_format.value.max_fwd
-                self.fp8_meta["fp8_max_bwd"] = self.fp8_meta["recipe"].fp8_format.value.max_bwd
+            if hasattr(meta["recipe"], "fp8_format"):
+                meta["fp8_max_fwd"] = meta["recipe"].fp8_format.value.max_fwd
+                meta["fp8_max_bwd"] = meta["recipe"].fp8_format.value.max_bwd
 
             # Allocate scales and amaxes
-            self.init_fp8_meta_tensors(self.fp8_meta["recipe"])
-            self.fp8_initialized = True
+            self.init_fp8_meta_tensors(meta["recipe"])
+            self.fast_setattr("fp8_initialized", True)
 
-            self.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()
+            meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()
 
-        _current_recipe = self.fp8_meta["recipe"]
+        _current_recipe = meta["recipe"]
         if _original_recipe is not None and not (
             issubclass(_current_recipe.__class__, _original_recipe.__class__)
             or issubclass(_original_recipe.__class__, _current_recipe.__class__)
         ):
@@ -1028,22 +1028,18 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None:
             # Clear cached workspaces as they were created with the old recipe/quantizer type
             self._fp8_workspaces.clear()
 
-    @contextmanager
     def prepare_forward(
         self,
         inp: torch.Tensor,
         num_gemms: int = 1,
         allow_non_contiguous: bool = False,
         allow_different_data_and_param_types: bool = False,
-    ) -> Generator[torch.Tensor, None, None]:
"""Checks and prep for FWD. - The context manager is needed because there isn't a way for a module to know - if it's the last FP8 module in the forward autocast. It is useful - to setup the forward aggregated amax reduction for every module - just in case. The autocast exit will pick up the most recent one. - """ - self.allow_different_data_and_param_types = allow_different_data_and_param_types - self.forwarded_at_least_once = True + ) -> torch.Tensor: + """Checks and prepares for FWD execution.""" + self.fast_setattr( + "allow_different_data_and_param_types", allow_different_data_and_param_types + ) + self.fast_setattr("forwarded_at_least_once", True) # Activation recomputation is used and this is the second forward phase. if self.fp8 and in_fp8_activation_recompute_phase(): @@ -1074,13 +1070,37 @@ def prepare_forward( if self.training and is_fp8_activation_recompute_enabled(): FP8GlobalStateManager.copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta) - with get_nvtx_range_context(self.__class__.__name__ + " forward"): - if not allow_non_contiguous and not inp.is_contiguous(): - inp = inp.contiguous() - yield inp + nvtx_range_push(self.__class__.__name__ + " forward") + if not allow_non_contiguous and not inp.is_contiguous(): + inp = inp.contiguous() + return inp + def end_forward(self): + """ + Required to be called at the end of the forward function to properly handle + DelayedScaling metadata handling and the NVTX ranges. + """ + delayed_scaling_recipe = self.fp8 and self.fp8_meta["recipe"].delayed() if delayed_scaling_recipe and self.fp8 and in_fp8_activation_recompute_phase(): FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta) + nvtx_range_pop() + + @contextmanager + def prepare_forward_ctx( + self, + inp: torch.Tensor, + num_gemms: int = 1, + allow_non_contiguous: bool = False, + allow_different_data_and_param_types: bool = False, + ) -> Generator[torch.Tensor, None, None]: + """Checks and prepares for FWD execution.""" + inp = self.prepare_forward( + inp, num_gemms, allow_non_contiguous, allow_different_data_and_param_types + ) + try: + yield inp + finally: + self.end_forward() def set_nccl_overlap_warning_if_tp(self) -> None: """When using TP, the NCCL communication needs to be scheduled @@ -1315,9 +1335,9 @@ def clear(self): # Update the parameter based on its type if not is_dtensor: - setattr(self, name, param) + self.module_setattr(name, param) else: - setattr(self, name, dtensor_param) + self.module_setattr(name, dtensor_param) @abstractmethod def forward(self): @@ -1516,7 +1536,6 @@ def is_debug_iter(self) -> bool: debug = TEDebugState.debug_enabled if not debug: return False - self._validate_name() # If layer is run first time in new iteration, # we need to check if the debug should be enabled for this layer - @@ -1530,14 +1549,14 @@ def is_debug_iter(self) -> bool: debug = False else: debug = TEDebugState.get_iteration() >= self.next_iter_when_debug_should_be_run - self.debug_last_iteration = TEDebugState.get_iteration() - self.debug_enabled_in_this_iteration = debug + self.fast_setattr("debug_last_iteration", TEDebugState.get_iteration()) + self.fast_setattr("debug_enabled_in_this_iteration", debug) else: # If this is the same iteration as previous invocation of the module, # we use the debug value from the first invocation in the iteration. 
             debug = self.debug_enabled_in_this_iteration
-            self.debug_last_iteration = TEDebugState.get_iteration()
+            self.fast_setattr("debug_last_iteration", TEDebugState.get_iteration())
 
         if self.wgrad_store is not None:
             if debug and self.wgrad_store.delay_wgrad_compute():
@@ -1553,7 +1572,9 @@ def no_debug_features_active(self, quantizers):
 
         # Sometimes features inform that they will not be enabled for particular layer
         # for multiple next iterations.
-        self.next_iter_when_debug_should_be_run = next_iter_when_debug_should_be_run(quantizers)
+        self.fast_setattr(
+            "next_iter_when_debug_should_be_run", next_iter_when_debug_should_be_run(quantizers)
+        )
 
         if not run_current:
             return True
@@ -1565,22 +1586,13 @@ def no_debug_features_active(self, quantizers):
     def _validate_name(self):
         """
         Validate name passed to the module.
-        This is invoked in the forward() method as module names are assigned after Model is initialized in Megatron-LM.
-        If no name is assigned, it creates a default name with layer count as the variable.
+        If no name is assigned, it creates a default name based on the layer count,
+        which may later be changed by the user of the module.
         """
         if self.name is not None:
             return
-        assert TEDebugState.debug_enabled
-        import nvdlfw_inspect.api as debug_api
-
-        if self.name is None:
-            debug_api.log_message(
-                "Names are not provided to debug modules. ",
-                "Creating and using generic names. Pass names to debug modules for better"
-                " insight. ",
-                level=logging.WARNING,
-            )
-            self.name = f"Layer_{TEDebugState.get_layer_count()}"
+
+        self.name = f"Layer_{TEDebugState.get_layer_count()}"
 
     def _check_weight_tensor_recipe_correspondence(self) -> None:
         """
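Note: the base.py hunks above replace the old prepare_forward context manager with an explicit prepare_forward()/end_forward() pair; prepare_forward_ctx() wraps that pair for callers that still want a context manager, and fast_setattr()/module_setattr() split attribute writes into a plain-dict fast path and the regular torch.nn.Module path. A minimal sketch of how a subclass is expected to use the new hooks; MyModule and _compute are hypothetical placeholders for illustration, not part of the patch:

    from transformer_engine.pytorch.module.base import TransformerEngineBaseModule

    class MyModule(TransformerEngineBaseModule):
        def forward(self, inp):
            # Explicit pair, as used by Linear / GroupedLinear / LayerNormMLP below.
            inp = self.prepare_forward(inp, num_gemms=1)
            try:
                out = self._compute(inp)  # hypothetical helper
            finally:
                # Restores DelayedScaling metadata and pops the NVTX range.
                self.end_forward()
            return out

        def forward_with_ctx(self, inp):
            # Equivalent context-manager form via prepare_forward_ctx().
            with self.prepare_forward_ctx(inp, num_gemms=1) as inp:
                out = self._compute(inp)
            return out

    # fast_setattr() writes straight into self.__dict__, skipping
    # torch.nn.Module.__setattr__; it is only safe for plain attributes.
    # module_setattr() must still be used for parameters, buffers, and submodules.
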
diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py
index e6e69b3e4a..c9ceb714e3 100644
--- a/transformer_engine/pytorch/module/grouped_linear.py
+++ b/transformer_engine/pytorch/module/grouped_linear.py
@@ -614,7 +614,7 @@ def __init__(
         save_original_input: bool = False,
         name: Optional[str] = None,
     ) -> None:
-        super().__init__()
+        super().__init__(name)
 
         params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
         self.num_gemms = num_gemms
@@ -633,7 +633,6 @@ def __init__(
         ), "GroupedLinear doesn't support Userbuffer overlap."
         self.get_rng_state_tracker = get_rng_state_tracker
         self.rng_tracker_name = rng_tracker_name
-        self.name = name
 
         self.wgrad_store = WeightGradStore(delay_wgrad_compute)
 
@@ -789,7 +788,8 @@ def forward(
 
         is_grad_enabled = torch.is_grad_enabled()
 
-        with self.prepare_forward(inp, num_gemms=self.num_gemms) as inp:
+        inp = self.prepare_forward(inp, num_gemms=self.num_gemms)
+        try:
             weight_tensors = self._get_weight_tensors()
             bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)]
@@ -844,6 +844,9 @@ def forward(
             )
             out = linear_fn(*autograd_ctx, inp, non_tensor_args, *weight_tensors, *bias_tensors)
+        finally:
+            self.end_forward()
+
         if self.return_bias:
             return out, [cast_if_needed(b, self.activation_dtype) for b in bias_tensors]
         return out
diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py
index ca30ef9567..702916696b 100644
--- a/transformer_engine/pytorch/module/layernorm_linear.py
+++ b/transformer_engine/pytorch/module/layernorm_linear.py
@@ -1158,9 +1158,9 @@ def __init__(
         ub_name: Optional[str] = None,
         delay_wgrad_compute: bool = False,
         symmetric_ar_type: Optional[str] = None,
-        name: str = None,
+        name: Optional[str] = None,
     ) -> None:
-        super().__init__()
+        super().__init__(name)
 
         params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
         self.in_features = in_features
@@ -1179,7 +1179,6 @@ def __init__(
         self.symmetric_ar_type = symmetric_ar_type
 
         self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad)
-        self.name = name
 
         if tp_group is None:
             self.tp_size = tp_size
@@ -1508,10 +1507,11 @@ def forward(
             ).is_fp8_ubuf():
                 fp8_grad = True
 
-        with self.prepare_forward(
+        inp = self.prepare_forward(
             inp, allow_non_contiguous=False  # removed .contiguous from inside the layer
-        ) as inp:
+        )
+        try:
             # Get concatenated weight and bias tensors
             weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()
@@ -1590,6 +1590,9 @@ def forward(
                 non_tensor_args,
             )
+        finally:
+            self.end_forward()
+
         if self.return_layernorm_output:
             out, ln_out = out
diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py
index 35e4522138..bec6744518 100644
--- a/transformer_engine/pytorch/module/layernorm_mlp.py
+++ b/transformer_engine/pytorch/module/layernorm_mlp.py
@@ -1787,7 +1787,7 @@ def __init__(
         zero_centered_gamma: bool = False,
         device: Union[torch.device, str] = "cuda",
         ub_overlap_ag: bool = False,
-        name: str = None,
+        name: Optional[str] = None,
         ub_overlap_rs: bool = False,
         ub_overlap_rs_dgrad: bool = False,
         ub_bulk_dgrad: bool = False,
@@ -1796,7 +1796,7 @@ def __init__(
         symmetric_ar_type: Optional[str] = None,
         checkpoint: bool = False,
     ) -> None:
-        super().__init__()
+        super().__init__(name)
         params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
 
         self.fuse_wgrad_accumulation = fuse_wgrad_accumulation
@@ -1827,7 +1827,6 @@ def __init__(
                 for use_fp8 in [False, True]
             )
         )
-        self.name = name
 
         self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad)
 
@@ -2047,8 +2046,9 @@ def forward(
             if get_ub("fc2_fprop", FP8GlobalStateManager.is_fp8_enabled()).is_fp8_ubuf():
                 fp8_output = True
 
-        with self.prepare_forward(inp, num_gemms=2) as inp:
+        inp = self.prepare_forward(inp, num_gemms=2)
+        try:
             quantizers = (
                 self._get_quantizers(fp8_output, is_grad_enabled)
                 if not debug
@@ -2087,7 +2087,7 @@ def forward(
 
             # Disable bias_gelu_nvfusion for determinism checkpointing in non-reentrant mode
             if self.bias_gelu_nvfusion and not use_reentrant_activation_recompute():
-                self.bias_gelu_nvfusion = False
+                self.fast_setattr("bias_gelu_nvfusion", False)
 
             if is_grad_enabled:
                 fwd_fn = _LayerNormMLP.apply
@@ -2157,6 +2157,9 @@ def forward(
                 non_tensor_args,
             )
+        finally:
+            self.end_forward()
+
         if self.return_layernorm_output:
             out, ln_out = out
diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py
index 38104604d8..23ad8cacb0 100644
--- a/transformer_engine/pytorch/module/linear.py
+++ b/transformer_engine/pytorch/module/linear.py
@@ -428,8 +428,8 @@ def forward(
         # weights if weights are externally touched outside this module
         ctx.weight_object = weight
 
-        if cpu_offloading:
             mark_not_offload(weight, weightmat, bias)
+
         # TODO(ksivamani): Check memory usage
         tensors_to_save, tensor_objects = prepare_for_saving(
             saved_inputmat,
@@ -1098,7 +1098,7 @@ def __init__(
         save_original_input: bool = False,
         name: Optional[str] = None,
     ) -> None:
-        super().__init__()
+        super().__init__(name)
 
         params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
         self.in_features = in_features
@@ -1111,7 +1111,6 @@ def __init__(
         self.rng_tracker_name = rng_tracker_name
         self.symmetric_ar_type = symmetric_ar_type
         self.save_original_input = save_original_input
-        self.name = name
 
         self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad)
 
@@ -1395,11 +1394,8 @@ def forward(
             ).is_fp8_ubuf():
                 fp8_grad = True
 
-        with self.prepare_forward(
-            inp,
-            allow_non_contiguous=isinstance(inp, QuantizedTensor),
-        ) as inp:
-
+        inp = self.prepare_forward(inp, allow_non_contiguous=isinstance(inp, QuantizedTensor))
+        try:
             weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()
 
             quantizers = (
@@ -1470,6 +1466,8 @@ def forward(
                 bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None,
                 non_tensor_args,
             )
+        finally:
+            self.end_forward()
 
         if self.gemm_bias_unfused_add:
             out = out + cast_if_needed(bias_tensor, self.activation_dtype)
diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py
index 9b9ccc5185..7c3125a165 100644
--- a/transformer_engine/pytorch/transformer.py
+++ b/transformer_engine/pytorch/transformer.py
@@ -12,7 +12,6 @@
 from transformer_engine.pytorch.torch_version import torch_version
 
 from transformer_engine.pytorch.module import LayerNormMLP, LayerNorm, RMSNorm
-from transformer_engine.debug.pytorch.debug_state import TEDebugState
 from transformer_engine.pytorch.attention.multi_head_attention import MultiheadAttention
 from transformer_engine.pytorch.attention.inference import InferenceParams
 from transformer_engine.pytorch.jit import (
@@ -398,6 +397,7 @@ def __init__(
         self.softmax_type = softmax_type
 
         self.name = name
+        TransformerEngineBaseModule._validate_name(self)
 
         attention_args = (
             hidden_size,
@@ -446,7 +446,7 @@ def __init__(
             qk_norm_type=qk_norm_type,
             qk_norm_eps=qk_norm_eps,
             qk_norm_before_rope=qk_norm_before_rope,
-            name=name + ".self_attention" if name is not None else None,
+            name=self.name + ".self_attention" if self.name is not None else None,
         )
 
         if layer_type == "decoder":
@@ -463,7 +463,7 @@ def __init__(
             qk_norm_type=qk_norm_type,
             qk_norm_eps=qk_norm_eps,
             qk_norm_before_rope=qk_norm_before_rope,
-            name=name + ".inter_attention" if name is not None else None,
+            name=self.name + ".inter_attention" if self.name is not None else None,
         )
 
         # LayerNorm -> activation(Linear + Bias) -> Linear
@@ -499,7 +499,7 @@ def __init__(
             activation_params=activation_params,
             normalization=normalization,
             device=device,
-            name=name + ".layernorm_mlp" if name is not None else None,
+            name=self.name + ".layernorm_mlp" if self.name is not None else None,
         )
 
         self.hidden_dropout = hidden_dropout
@@ -768,9 +768,6 @@ def forward(
                 enc_dec_attn_mask[i].dtype == torch.bool
                 for i in range(len(enc_dec_attn_mask))
             ), "Encoder-decoder attention mask must be boolean tensor(s)"
-        if TEDebugState.debug_enabled:
-            TransformerEngineBaseModule._validate_name(self)
-
         # For AMP
         if torch.is_autocast_enabled():
             hidden_states = cast_if_needed(hidden_states, torch_get_autocast_gpu_dtype())
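Note: with the transformer.py changes above, TransformerLayer validates its name in __init__ and derives submodule names from self.name, so the name propagates even when it was auto-generated. A short illustrative sketch; the layer sizes are arbitrary assumptions, not values from the patch:

    import transformer_engine.pytorch as te

    layer = te.TransformerLayer(1024, 4096, 16, name="decoder.0")
    # Submodules inherit the parent name:
    #   layer.self_attention.name  == "decoder.0.self_attention"
    #   layer.layernorm_mlp.name   == "decoder.0.layernorm_mlp"

    unnamed = te.TransformerLayer(1024, 4096, 16)
    # Without an explicit name, _validate_name() assigns "Layer_<layer_count>",
    # and the submodules receive "Layer_<layer_count>.self_attention", etc.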