
Conversation

@ptrendx
Member

@ptrendx ptrendx commented Dec 1, 2025

Description

This PR includes a few performance optimizations targeting CPU overhead. The code, perf numbers, etc. are WIP. The code gets kind of ugly though :-(.

For the prepare_forward changes I did not touch attention (@cyanguwa FYI), since it has multiple exit points from the forward and I was worried that I would miss something there - it would be great if we could refactor that part first to have a single return statement instead.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@ptrendx
Member Author

ptrendx commented Dec 1, 2025

/te-ci pytorch

Comment on lines 644 to 646

def fast_set_attr(self, name: str, value: Any) -> None:
    self.__dict__[name] = value
Collaborator

I assume we are separating out this function so we can manually avoid the overhead from __setattr__ and write to __dict__ directly? Doing some benchmarking:

  • dict read: 9 ns
  • dict write: 13 ns
  • dict in: 9 ns
  • dict.get: 14 ns
  • Function call: 9 ns
  • Class attr read: 3 ns
  • Class attr write: 5 ns
  • Class custom getattr: 101 ns
  • Class custom setattr: 134 ns
Benchmarking script

I ran the following on a GB200 node. For the dict times, I subtracted out the overhead from list reads. For the class getattr/setattr times, I subtracted out the overhead from the range loop.

import contextlib
import time

class Timer:
    """Measure time interval."""

    def __init__(self) -> None:
        self._start = None
        self._end = None

    def time(self) -> float:
	"""CPU time interval in seconds."""
        return self._end - self._start

    @contextlib.contextmanager
    def context(self):
        """Context manager to capture time interval."""
        self._start = time.perf_counter()
        yield
        self._end = time.perf_counter()

def main() -> None:

    # Options
    iters = 1024 * 1024

    # Timer
    timer = Timer()

    # Dummy data
    str_list = ["lorem", "ipsum", "dolor", "sit", "amet", "consectetur", "adipiscing", "elit"]
    str_list = [str_list[i % len(str_list)] for i in range(iters)]
    str_dict = {s: len(s) for s in str_list}
    class PlainClass:
        def __init__(self) -> None:
            self.attr = 1
    class CustomGetattrSetattrClass:
        def __init__(self) -> None:
            self.attr = 1
        def __getattribute__(self, name):
            return super().__getattribute__(name)
        def __setattr__(self, name, val):
            super().__setattr__(name, val)

    # Timer overhead
    with timer.context():
        pass
    print(f"Timer overhead: {timer.time() * 1e9 / iters} ns/iter")

    # Range loop
    with timer.context():
        for _ in range(iters):
            pass
    print(f"Range loop: {timer.time() * 1e9 / iters} ns/iter")

    # List loop
    with timer.context():
        for _ in str_list:
            pass
    print(f"List loop: {timer.time() * 1e9 / iters} ns/iter")

    # Empty range+enumerate loop
    with timer.context():
        for i, j in enumerate(range(iters)):
            pass
    print(f"Range+enumerate loop: {timer.time() * 1e9 / iters} ns/iter")

    # Empty list+enumerate loop
    with timer.context():
        for i, s in enumerate(str_list):
            pass
    print(f"List+enumerate loop: {timer.time() * 1e9 / iters} ns/iter")

    # List reads
    with timer.context():
        for i in range(iters):
            str_list[i]
    print(f"List reads: {timer.time() * 1e9 / iters} ns/iter")

    # Dict reads
    with timer.context():
        for i in range(iters):
            str_dict[str_list[i]]
    print(f"Dict reads: {timer.time() * 1e9 / iters} ns/iter")

    # Dict get
    with timer.context():
        for i in range(iters):
            str_dict.get(str_list[i], None)
    print(f"Dict gets: {timer.time() * 1e9 / iters} ns/iter")

    # Dict writes
    with timer.context():
        for i in range(iters):
            str_dict[str_list[i]] = i
    print(f"Dict writes: {timer.time() * 1e9 / iters} ns/iter")

    # Dict membership
    with timer.context():
        for i in range(iters):
            str_list[i] in str_dict
    print(f"Dict membership: {timer.time() * 1e9 / iters} ns/iter")

    # Function call
    def func() -> None:
        pass
    with timer.context():
        for _ in range(iters):
            func()
    print(f"Function call: {timer.time() * 1e9 / iters} ns/iter")

    # Lambda call
    func = lambda: None
    with timer.context():
        for _ in range(iters):
            func()
    print(f"Lambda call: {timer.time() * 1e9 / iters} ns/iter")

    # Class attr read
    myobj = PlainClass()
    with timer.context():
        for _ in range(iters):
            _ = myobj.attr
    print(f"Class attr read: {timer.time() * 1e9 / iters} ns/iter")

    # Class attr write
    myobj = PlainClass()
    with timer.context():
        for i in range(iters):
            myobj.attr = i
    print(f"Class attr write: {timer.time() * 1e9 / iters} ns/iter")

    # getattr
    myobj = PlainClass()
    with timer.context():
        for _ in range(iters):
            getattr(myobj, "attr", None)
    print(f"getattr: {timer.time() * 1e9 / iters} ns/iter")

    # setattr
    myobj = PlainClass()
    with timer.context():
        for i in range(iters):
            setattr(myobj, "attr", i)
    print(f"setattr: {timer.time() * 1e9 / iters} ns/iter")

    # Class custom getattr
    myobj = CustomGetattrSetattrClass()
    with timer.context():
        for _ in range(iters):
            _ = myobj.attr
    print(f"Class custom getattr: {timer.time() * 1e9 / iters} ns/iter")

    # Class custom setattr
    myobj = CustomGetattrSetattrClass()
    with timer.context():
        for i in range(iters):
            myobj.attr = i
    print(f"Class custom setattr: {timer.time() * 1e9 / iters} ns/iter")

if __name__ == "__main__":
    main()

How much perf difference do you observe from fast_set_attr? I could see how it could save us ~1 us of overhead, but it would be good to make sure before making the code messier.

Member Author

I don't want to comment too much on the perf results yet, since so far they all come from my machine and not a real cluster, but the anecdotal evidence is that a small test which just runs a BF16 Linear layer forward for many iterations goes from 9.2 s to 7.7 s with the proposed code changes. The fast_set_attr change alone brought it to ~8.4 s.
I will test it properly and report the timings in the PR description.
Now, about introducing the separate function - since this is ultimately an optimization you came up with at some point, there was already machinery in place to avoid the expensive Module.__setattr__ for some attributes. The problem I see is discoverability - if people do not study that code very carefully, they will not realize that they should not just do self.something = something. Therefore I think we should go a more explicit route and, in the __setattr__ of the TE module, error out with a message to either use fast_set_attr for the things we are sure are just small values (using the __dict__ directly has some problems BTW, since it e.g. bypasses properties and the like), or a new function, let's call it just set_attr, for anything where we need the full machinery.

Collaborator

@timmoon10 timmoon10 Dec 2, 2025

I'd prefer not to ban self.something = something. I think readability and safety are more important for non-performance-critical things like initialization and checkpointing. It would be better to make this function an advanced internal implementation with a name like _fast_setattr.

Member Author

How would we then make sure that this does not resurface in the future?

Member Author

I went with the explicit setattr calls and a warning issued when the regular __setattr__ path is used. That way users can still use regular attribute assignment if they want, but for internal development we make sure during testing that the warning does not trigger. To make the code less ugly, we only turn on the warning after the constructor is finished - that way we can still use the nice syntax during construction (where most of the occurrences are), since we do not care about the speed there.
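
For reference, a minimal sketch of the pattern described above (simplified and illustrative; the class name and exact checks are placeholders, and the real helper implementation is quoted later in this thread):

import warnings
from typing import Any

import torch


class TEModuleSketch(torch.nn.Module):
    """Illustrative only: warn on the slow __setattr__ path after construction."""

    def __init__(self) -> None:
        super().__init__()
        self.some_state = 0  # plain assignment is fine during construction
        self.fast_setattr("_initialized", True)  # enable the warning from here on

    def fast_setattr(self, name: str, value: Any) -> None:
        # Bypass torch.nn.Module.__setattr__ for plain (non-parameter) attributes.
        self.__dict__[name] = value

    def __setattr__(self, name: str, value: Any) -> None:
        if self.__dict__.get("_initialized", False):
            warnings.warn(
                "Use fast_setattr for plain attributes after construction",
                RuntimeWarning,
            )
        super().__setattr__(name, value)

After construction, module.fast_setattr("some_state", 1) stays on the fast path, while module.some_state = 1 still works but emits the warning.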

@ptrendx ptrendx force-pushed the pr_python_cpu_optimization branch from 5eefe3e to 1c7d896 Compare December 2, 2025 22:45
@ptrendx ptrendx force-pushed the pr_python_cpu_optimization branch 3 times, most recently from 948747b to c4e380f Compare December 16, 2025 21:20
@ptrendx
Member Author

ptrendx commented Jan 10, 2026

/te-ci pytorch

@ptrendx ptrendx marked this pull request as ready for review January 10, 2026 00:48
@greptile-apps
Contributor

greptile-apps bot commented Jan 10, 2026

Greptile Summary

This PR introduces several CPU performance optimizations targeting overhead reduction in the PyTorch modules.

Key Changes:

  • Replaced prepare_forward context manager with explicit prepare_forward()/end_forward() method calls for Linear, LayerNormLinear, LayerNormMLP, and GroupedLinear modules (DotProductAttention retains context manager via prepare_forward_ctx due to multiple exit points)
  • Added fast_setattr() and module_setattr() helper methods to bypass PyTorch's expensive __setattr__ for regular attribute assignments
  • Moved module name validation from forward pass to __init__ to reduce per-forward overhead
  • C++ optimization: early-return null check before mutex lock acquisition in tensor allocator Free methods
  • Updated attribute assignments throughout codebase to use fast_setattr() where applicable

Architecture Notes:

  • Forward methods now use try-finally blocks to ensure end_forward() is called even on exceptions, maintaining NVTX range stack integrity (see the sketch after this list)
  • The prepare_forward_ctx context manager is preserved for modules with complex control flow (attention)
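
To make the pattern concrete, here is a rough, self-contained sketch of the control flow described above. The prepare_forward()/end_forward() names follow the summary; the class, the print calls, and _compute() are placeholders, not the actual Linear implementation:

class ForwardPatternSketch:
    """Illustrative only: explicit prepare/end calls guarded by try-finally."""

    def prepare_forward(self, inp):
        print("nvtx_range_push")  # stands in for FP8 metadata init + NVTX push
        return inp

    def end_forward(self):
        print("nvtx_range_pop")  # stands in for FP8 meta restore + NVTX pop

    def _compute(self, inp):
        return inp  # placeholder for the real GEMM path

    def forward(self, inp):
        inp = self.prepare_forward(inp)
        try:
            return self._compute(inp)
        finally:
            # Runs on both normal return and exceptions, so the NVTX range
            # stack stays balanced.
            self.end_forward()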

Confidence Score: 4/5

  • This PR is safe to merge with low risk - the refactoring maintains correctness with proper exception handling.
  • The PR implements well-structured CPU optimizations. Exception safety is properly handled via try-finally blocks in all forward methods. The C++ mutex optimization is safe. One minor redundancy exists in end_forward but it doesn't affect correctness.
  • transformer_engine/pytorch/module/base.py - core changes to prepare_forward/end_forward pattern

Important Files Changed

Filename Overview
transformer_engine/pytorch/module/base.py Refactored prepare_forward from context manager to explicit method + end_forward, added fast_setattr/module_setattr helpers to reduce __setattr__ overhead, name validation moved to init. Try-finally wrapping is used properly in prepare_forward_ctx.
transformer_engine/pytorch/module/linear.py Forward pass uses try-finally for exception safety with prepare_forward/end_forward. Name passed to parent class init. mark_not_offload consolidated within if cpu_offloading: block.
transformer_engine/pytorch/module/layernorm_linear.py Forward pass uses try-finally for exception safety with prepare_forward/end_forward. Name typing fixed to Optional[str] and passed to parent class init.
transformer_engine/pytorch/module/layernorm_mlp.py Forward pass uses try-finally for exception safety with prepare_forward/end_forward. Name typing fixed to Optional[str] and passed to parent class init. Uses fast_setattr for bias_gelu_nvfusion.
transformer_engine/pytorch/module/grouped_linear.py Forward pass uses try-finally for exception safety with prepare_forward/end_forward. Name passed to parent class init.
transformer_engine/common/transformer_engine.cpp Moved early-return null check before mutex lock in Free methods to avoid unnecessary lock acquisition.

Sequence Diagram

sequenceDiagram
    participant User as User Code
    participant Linear as Linear.forward()
    participant Base as TransformerEngineBaseModule
    participant NVTX as NVTX Profiler

    User->>Linear: forward(inp)
    Linear->>Base: prepare_forward(inp)
    Base->>Base: init_fp8_metadata()
    Base->>NVTX: nvtx_range_push("Linear forward")
    Base-->>Linear: return inp
    
    Note over Linear: try block
    Linear->>Linear: _get_weight_and_bias_tensors()
    Linear->>Linear: _get_quantizers()
    Linear->>Linear: linear_fn(*args)
    
    Note over Linear: finally block
    Linear->>Base: end_forward()
    Base->>Base: restore_fp8_meta_tensors (if needed)
    Base->>NVTX: nvtx_range_pop()
    
    Linear-->>User: return out

Contributor

@greptile-apps greptile-apps bot left a comment

5 files reviewed, 5 comments


@greptile-apps
Contributor

greptile-apps bot commented Jan 10, 2026

Greptile Overview

Greptile Summary

This PR attempts to optimize CPU overhead by:

  1. Replacing context managers with manual prepare_forward()/end_forward() calls
  2. Introducing fast_setattr() to bypass PyTorch's __setattr__ overhead
  3. Adding an __setattr__ override that warns when the slow path is used
  4. Optimizing C++ mutex locking by moving null checks before lock acquisition
  5. Configuring pytest to treat RuntimeWarnings as errors

Critical Issues Found

1. NVTX Range Imbalance on Exceptions (HIGH SEVERITY)

The refactoring from context managers to manual prepare_forward()/end_forward() calls breaks exception safety. If an exception occurs between these calls (e.g., shape mismatch, CUDA OOM, assertion failure), nvtx_range_pop() is never called, corrupting the NVTX stack. This affects all forward methods in Linear, LayerNormLinear, LayerNormMLP, and GroupedLinear modules.

2. setattr Doesn't Actually Optimize (HIGH SEVERITY)

The new __setattr__ override still calls super().__setattr__(name, value) after emitting a warning, meaning every attribute assignment still goes through the slow PyTorch path. This defeats the purpose of the optimization.

3. Multiple RuntimeWarning Violations (CRITICAL SEVERITY)

Six locations in base.py use direct attribute assignment after initialization (lines 965, 966, 1558, 1559, 1565, 1581, 1608). Since pytest.ini now treats RuntimeWarnings as errors, all tests will fail.

Positive Aspects

  • C++ mutex optimization is correct and beneficial
  • Attention module correctly uses prepare_forward_ctx context manager
  • All module subclasses properly set _initialized flag
  • Test scripts correctly updated to use pytest.ini

Recommendation

This PR cannot be merged in its current state due to the RuntimeWarning violations that will cause all tests to fail. The NVTX exception safety issue is also critical for production use.

Confidence Score: 0/5

  • This PR is not safe to merge - it will cause all tests to fail due to RuntimeWarning violations
  • Score of 0 reflects critical issues that will break the build: (1) Six direct attribute assignments trigger RuntimeWarnings which pytest.ini treats as errors, causing all tests to fail immediately; (2) NVTX range imbalance on exceptions will corrupt profiling; (3) setattr optimization doesn't actually work as intended
  • transformer_engine/pytorch/module/base.py requires immediate attention - must fix all direct attribute assignments (lines 965, 966, 1558, 1559, 1565, 1581, 1608) and address exception safety for NVTX ranges

Important Files Changed

File Analysis

Filename Score Overview
transformer_engine/pytorch/module/base.py 1/5 Critical bugs: NVTX range imbalance on exceptions, setattr doesn't optimize, multiple RuntimeWarning-triggering assignments
transformer_engine/pytorch/module/linear.py 2/5 Changed prepare_forward context manager to manual calls, adds _initialized flag, affected by base.py NVTX bug
transformer_engine/pytorch/module/layernorm_linear.py 2/5 Changed prepare_forward context manager to manual calls, adds _initialized flag, affected by base.py NVTX bug
transformer_engine/pytorch/module/layernorm_mlp.py 2/5 Changed prepare_forward context manager to manual calls, adds _initialized flag, affected by base.py NVTX bug
transformer_engine/pytorch/module/grouped_linear.py 2/5 Changed prepare_forward context manager to manual calls, adds _initialized flag, affected by base.py NVTX bug
tests/pytorch/pytest.ini 3/5 New file that treats RuntimeWarning as error - exposes issues with setattr implementation in base.py

Sequence Diagram

sequenceDiagram
    participant User
    participant Module as Linear/LayerNorm Module
    participant Base as TransformerEngineBaseModule
    participant NVTX as NVTX Range Stack
    participant FP8 as FP8GlobalStateManager

    User->>Module: forward(inp)
    Module->>Base: prepare_forward(inp)
    Base->>Base: init_fp8_metadata()
    Base->>Base: get_amp_dtype()
    Base->>FP8: Copy FP8 metadata (if recompute)
    Base->>NVTX: nvtx_range_push("Module forward")
    Note over NVTX: Range pushed - needs cleanup!
    Base-->>Module: returns processed inp
    
    alt Exception occurs during forward
        Module->>Module: Compute quantizers/weights
        Module--xModule: Exception raised!
        Note over NVTX: ⚠️ BUG: nvtx_range_pop() never called!
        Note over NVTX: Stack becomes imbalanced
    else Normal execution
        Module->>Module: Compute quantizers/weights
        Module->>Module: Call linear_fn()
        Module->>Base: end_forward()
        Base->>FP8: Restore FP8 metadata (if recompute)
        Base->>NVTX: nvtx_range_pop()
        Note over NVTX: Range properly cleaned up
        Base-->>Module: return
        Module-->>User: return output
    end

Contributor

@greptile-apps greptile-apps bot left a comment

7 files reviewed, 7 comments


Comment on lines 651 to 661
def __setattr__(self, name: str, value: Any) -> None:
if name in TransformerEngineBaseModule._fast_setattr_names:
# torch.nn.Module has a custom __setattr__ that handles
# modules, parameters, and buffers. This is unnecessary
# overhead when setting plain attrs.
self.__dict__[name] = value
else:
# Default case
super().__setattr__(name, value)
if "_initialized" in self.__dict__ and self._initialized:
warnings.warn(
"""The default implementation of torch.nn.Module introduces significant CPU overhead
when setting attributes and is therefore not recommended. Please use the explicit
calls (fast_setattr for setting regular values and module_setattr for setting
parameters, children modules and buffers).""",
RuntimeWarning,
)
super().__setattr__(name, value)
Contributor

The __setattr__ override doesn't actually optimize anything - it still calls super().__setattr__(name, value) after emitting the warning. This means every attribute assignment still goes through the slow PyTorch Module path, defeating the purpose of the optimization.

The warning is also emitted on EVERY attribute assignment after initialization, which would be extremely noisy and unhelpful for debugging.

Additionally, there are multiple places in the codebase that still use direct attribute assignment after initialization:

  • Line 965: self.tp_group = tp_group in set_tensor_parallel_group()
  • Line 966: self.tp_group_initialized = True
  • Lines 1558, 1559, 1565, 1581, 1608: Various debug-related attribute assignments

Since pytest.ini now treats RuntimeWarnings as errors, all of these will cause test failures.

Recommendation: Either:

  1. Remove the __setattr__ override entirely and rely on developers remembering to use fast_setattr, or
  2. Update all the remaining direct assignments to use fast_setattr or module_setattr

@greptile-apps
Contributor

greptile-apps bot commented Jan 10, 2026

Additional Comments (5)

transformer_engine/pytorch/module/base.py
Direct attribute assignments after initialization will trigger the RuntimeWarning defined in __setattr__, which pytest.ini now treats as an error. This will cause test failures.

        self.fast_setattr("tp_group", tp_group)
        self.fast_setattr("tp_group_initialized", True)

transformer_engine/pytorch/module/base.py
Direct attribute assignments after initialization will trigger the RuntimeWarning defined in __setattr__, which pytest.ini now treats as an error. This will cause test failures.

            self.fast_setattr("debug_last_iteration", TEDebugState.get_iteration())
            self.fast_setattr("debug_enabled_in_this_iteration", debug)

transformer_engine/pytorch/module/base.py
Direct attribute assignment after initialization will trigger the RuntimeWarning defined in __setattr__, which pytest.ini now treats as an error. This will cause test failures.

        self.fast_setattr("debug_last_iteration", TEDebugState.get_iteration())

transformer_engine/pytorch/module/base.py
Direct attribute assignment after initialization will trigger the RuntimeWarning defined in __setattr__, which pytest.ini now treats as an error. This will cause test failures.

        self.fast_setattr("next_iter_when_debug_should_be_run", next_iter_when_debug_should_be_run(quantizers))

transformer_engine/pytorch/module/base.py
Direct attribute assignment after initialization will trigger the RuntimeWarning defined in __setattr__, which pytest.ini now treats as an error. This will cause test failures.

            self.fast_setattr("name", f"Layer_{TEDebugState.get_layer_count()}")

Contributor

@greptile-apps greptile-apps bot left a comment

8 files reviewed, 8 comments


Comment on lines 651 to 661
def __setattr__(self, name: str, value: Any) -> None:
if name in TransformerEngineBaseModule._fast_setattr_names:
# torch.nn.Module has a custom __setattr__ that handles
# modules, parameters, and buffers. This is unnecessary
# overhead when setting plain attrs.
self.__dict__[name] = value
else:
# Default case
super().__setattr__(name, value)
if "_initialized" in self.__dict__ and self._initialized:
warnings.warn(
"""The default implementation of torch.nn.Module introduces significant CPU overhead
when setting attributes and is therefore not recommended. Please use the explicit
calls (fast_setattr for setting regular values and module_setattr for setting
parameters, children modules and buffers).""",
RuntimeWarning,
)
super().__setattr__(name, value)
Contributor

The __setattr__ override emits a RuntimeWarning every time it's called after initialization, but it still calls super().__setattr__() to perform the assignment. This means:

  1. The warning is emitted but the "slow" operation still happens - there's no actual optimization here
  2. With pytest.ini treating RuntimeWarning as error, any accidental direct attribute assignment will cause test failures
  3. The warning message suggests using fast_setattr but doesn't prevent the slow path

If the goal is CPU optimization, this should either:

  • Raise an exception instead of warning (fail-fast), OR
  • Actually use the fast path by default after initialization

The current implementation warns about performance but doesn't fix it.

@greptile-apps
Contributor

greptile-apps bot commented Jan 12, 2026

Additional Comments (5)

transformer_engine/pytorch/module/base.py
The set_tensor_parallel_group() method directly assigns to self.tp_group and self.tp_group_initialized using the standard = operator. After initialization (when _initialized=True), this will trigger the RuntimeWarning from the overridden __setattr__ method, which pytest.ini now treats as an error.

This method is called after module initialization in user code, so it will fail. The assignments should use fast_setattr():

        self.fast_setattr("tp_group", tp_group)
        self.fast_setattr("tp_group_initialized", True)

transformer_engine/pytorch/module/base.py
The is_debug_iter() method directly assigns to self.debug_last_iteration and self.debug_enabled_in_this_iteration. Since this is called during forward passes (after _initialized=True), these assignments will trigger RuntimeWarnings that pytest treats as errors.

These should use fast_setattr():

            self.fast_setattr("debug_last_iteration", TEDebugState.get_iteration())
            self.fast_setattr("debug_enabled_in_this_iteration", debug)

transformer_engine/pytorch/module/base.py
Another instance where self.debug_last_iteration is assigned directly during forward pass execution, which will trigger RuntimeWarning treated as error by pytest.

        self.fast_setattr("debug_last_iteration", TEDebugState.get_iteration())

transformer_engine/pytorch/module/base.py
The no_debug_features_active() method assigns to self.next_iter_when_debug_should_be_run, which is called during forward execution. This will trigger the RuntimeWarning.

        self.fast_setattr("next_iter_when_debug_should_be_run", next_iter_when_debug_should_be_run(quantizers))

transformer_engine/pytorch/module/base.py
The _validate_name() method assigns self.name directly during forward execution if no name was provided. This will trigger RuntimeWarning.

            self.fast_setattr("name", f"Layer_{TEDebugState.get_layer_count()}")

Contributor

@greptile-apps greptile-apps bot left a comment

6 files reviewed, 6 comments


Comment on lines 637 to 660
def fast_setattr(self, name: str, value: Any) -> None:
"""
Fast version of the Module's set attribute function.
Should be used for regular attributes, but not properties nor parameters/buffers.
"""
self.__dict__[name] = value

def module_setattr(self, name: str, value: Any) -> None:
"""
Regular version of the Module's set attribute function.
Should be used only when the fast version cannot be used - for the properties,
parameters and buffers.
"""
super().__setattr__(name, value)

def __setattr__(self, name: str, value: Any) -> None:
if name in TransformerEngineBaseModule._fast_setattr_names:
# torch.nn.Module has a custom __setattr__ that handles
# modules, parameters, and buffers. This is unnecessary
# overhead when setting plain attrs.
self.__dict__[name] = value
else:
# Default case
super().__setattr__(name, value)
if "_initialized" in self.__dict__ and self._initialized:
warnings.warn(
"""The default implementation of torch.nn.Module introduces significant CPU overhead
when setting attributes and is therefore not recommended. Please use the explicit
calls (fast_setattr for setting regular values and module_setattr for setting
parameters, children modules and buffers).""",
RuntimeWarning,
)
Contributor

Design Issue: The setattr override doesn't actually optimize anything

The new __setattr__ implementation still calls super().__setattr__(), which means it still incurs the full PyTorch Module overhead. The warning is issued but the slow path is always taken.

To actually optimize, the code should:

  1. Skip the warning during __init__ (check if "_initialized" not in self.__dict__)
  2. After init, require explicit use of fast_setattr() or module_setattr()
  3. Consider making __setattr__ raise an error instead of a warning to enforce the pattern

Currently, this just adds overhead (checking _initialized + issuing warning) on top of the existing slow path.

Comment on lines 1111 to 1113
def train(self, mode: bool = True) -> Self:
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", ".*CPU overhead.*")
        return super().train(mode)
Contributor

Workaround Needed: train() method silences the RuntimeWarning

This method filters out the "CPU overhead" warning because PyTorch's Module.train() internally uses __setattr__ to set the training flag on all child modules. This is a workaround for the fact that the base implementation can't be avoided in some cases.

This reveals a potential issue: other PyTorch internal methods that use __setattr__ will also trigger warnings. Consider:

  • to() device transfers
  • half(), float(), bfloat16() dtype conversions
  • register_module(), register_buffer(), register_parameter()

These may all need similar workarounds.

Comment on lines 652 to 661
def __setattr__(self, name: str, value: Any) -> None:
if name in TransformerEngineBaseModule._fast_setattr_names:
# torch.nn.Module has a custom __setattr__ that handles
# modules, parameters, and buffers. This is unnecessary
# overhead when setting plain attrs.
self.__dict__[name] = value
else:
# Default case
super().__setattr__(name, value)
if "_initialized" in self.__dict__ and self._initialized:
warnings.warn(
"""The default implementation of torch.nn.Module introduces significant CPU overhead
when setting attributes and is therefore not recommended. Please use the explicit
calls (fast_setattr for setting regular values and module_setattr for setting
parameters, children modules and buffers).""",
RuntimeWarning,
)
super().__setattr__(name, value)
Contributor

[P1] setattr warning defeats the optimization: The __setattr__ override emits a RuntimeWarning when _initialized=True, which pytest.ini treats as an error. However, __setattr__ still calls super().__setattr__(), meaning it doesn't actually prevent or optimize anything - it just warns.

If the goal is CPU optimization, the warning should either:

  1. Be accompanied by raising an exception to force using fast_setattr(), or
  2. Implement the fast path directly in __setattr__ when appropriate

Currently, any test that triggers the warning will fail, but production code will just emit warnings with no actual optimization benefit from this implementation.

@greptile-apps
Contributor

greptile-apps bot commented Jan 12, 2026

Additional Comments (4)

transformer_engine/pytorch/module/base.py
CRITICAL: Direct attribute assignment will trigger RuntimeWarning treated as error

These direct assignments using self.attribute = will trigger the RuntimeWarning from the overridden __setattr__ method (line 652-660), which pytest.ini now treats as an error.

This method is called after module initialization (when _initialized=True), so these assignments will fail tests.

        self.fast_setattr("tp_group", tp_group)
        self.fast_setattr("tp_group_initialized", True)

transformer_engine/pytorch/module/base.py
CRITICAL: Direct attribute assignment will trigger RuntimeWarning treated as error

These direct assignments will trigger RuntimeWarning since is_debug_iter() is called during forward passes (after _initialized=True).

            self.fast_setattr("debug_last_iteration", TEDebugState.get_iteration())
            self.fast_setattr("debug_enabled_in_this_iteration", debug)

transformer_engine/pytorch/module/base.py
CRITICAL: Direct attribute assignment will trigger RuntimeWarning treated as error

This assignment will also trigger the RuntimeWarning.

        self.fast_setattr("debug_last_iteration", TEDebugState.get_iteration())

transformer_engine/common/transformer_engine.cpp
Good optimization: Moving the early return check (if (index == 0) return;) before acquiring the mutex lock avoids unnecessary synchronization overhead when freeing null pointers. This is a correct and safe optimization that reduces contention on the allocator mutex.
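
For illustration, a Python analogue of the check-before-lock pattern described for the C++ Free methods (hypothetical class and field names; the real change lives in transformer_engine.cpp):

import threading


class AllocatorSketch:
    """Illustrative only: handle the trivial case before taking the lock."""

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._live = {}  # handle index -> buffer

    def free(self, index: int) -> None:
        if index == 0:
            # Null handle: nothing to release, so skip the lock entirely.
            return
        with self._lock:
            self._live.pop(index, None)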

@greptile-apps
Contributor

greptile-apps bot commented Jan 12, 2026

Additional Comments (6)

transformer_engine/pytorch/module/base.py, line 1096
[P0] NVTX range imbalance on exceptions: The refactoring from context manager to explicit prepare_forward()/end_forward() calls breaks exception safety. If an exception occurs between prepare_forward() (which calls nvtx_range_push) and end_forward() (which calls nvtx_range_pop), the NVTX range stack will be corrupted.

This affects all forward methods in Linear, LayerNormLinear, LayerNormMLP, and GroupedLinear modules. Any exception (shape mismatch, CUDA OOM, assertion failure, etc.) will leave an unclosed NVTX range.

The original context manager pattern guaranteed cleanup via __exit__. The new pattern requires all code paths to call end_forward(), but exceptions bypass this.


transformer_engine/pytorch/module/base.py, line 661
[P1] setattr warning defeats the optimization: The __setattr__ override emits a RuntimeWarning when _initialized=True, which pytest.ini treats as an error. However, __setattr__ still calls super().__setattr__(), meaning it doesn't actually prevent or optimize anything - it just warns.

If the goal is CPU optimization, the warning should either:

  1. Be accompanied by raising an exception to force using fast_setattr(), or
  2. Implement the fast path directly in __setattr__ when appropriate

Currently, any test that triggers the warning will fail, but production code will just emit warnings with no actual optimization benefit from this implementation.


transformer_engine/pytorch/module/base.py, line 968
[P0] set_tensor_parallel_group uses slow setattr after init: This method is called after module initialization (when _initialized=True), directly assigning to self.tp_group and self.tp_group_initialized using =, which triggers the RuntimeWarning that pytest treats as an error.

        self.fast_setattr("tp_group", tp_group)
        self.fast_setattr("tp_group_initialized", True)

transformer_engine/pytorch/module/base.py, line 1578
[P0] is_debug_iter uses slow setattr after init: The is_debug_iter() method directly assigns to self.debug_last_iteration and self.debug_enabled_in_this_iteration at lines 1564-1565 and 1571. Since this is called during forward passes (after _initialized=True), these assignments will trigger RuntimeWarning that pytest treats as errors.

            self.fast_setattr("debug_last_iteration", TEDebugState.get_iteration())
            self.fast_setattr("debug_enabled_in_this_iteration", debug)

And at line 1571:

        self.fast_setattr("debug_last_iteration", TEDebugState.get_iteration())

transformer_engine/pytorch/module/base.py, line 1115
[P3] train() method hides warnings inappropriately: The train() method suppresses CPU overhead warnings with a catch_warnings context. This defeats the purpose of the pytest.ini configuration that treats RuntimeWarning as errors, allowing slow paths to be used silently during .train() / .eval() calls.

If the warnings are genuinely problematic here, the underlying issue should be fixed rather than suppressed.


transformer_engine/pytorch/module/layernorm_mlp.py, line 2093
[P0] Attribute assignment will trigger RuntimeWarning: Direct assignment to self.bias_gelu_nvfusion uses the slow __setattr__ path after initialization, which triggers a RuntimeWarning that pytest treats as an error.

            self.fast_setattr("bias_gelu_nvfusion", False)

Contributor

@greptile-apps greptile-apps bot left a comment

4 files reviewed, 4 comments


Comment on lines 652 to 661
def __setattr__(self, name: str, value: Any) -> None:
if name in TransformerEngineBaseModule._fast_setattr_names:
# torch.nn.Module has a custom __setattr__ that handles
# modules, parameters, and buffers. This is unnecessary
# overhead when setting plain attrs.
self.__dict__[name] = value
else:
# Default case
super().__setattr__(name, value)
if "_initialized" in self.__dict__ and self._initialized:
warnings.warn(
"""The default implementation of torch.nn.Module introduces significant CPU overhead
when setting attributes and is therefore not recommended. Please use the explicit
calls (fast_setattr for setting regular values and module_setattr for setting
parameters, children modules and buffers).""",
RuntimeWarning,
)
super().__setattr__(name, value)
Contributor

WARNING: setattr override defeats the optimization it's meant to enforce

The __setattr__ override emits a RuntimeWarning when _initialized=True, but then still calls super().__setattr__(name, value) which goes through PyTorch's slow path anyway. This means:

  1. You pay the cost of the warning check
  2. You still pay the full cost of PyTorch's __setattr__
  3. The warning is treated as an error by pytest, so any code path using it will fail tests

This doesn't actually optimize anything - it just breaks code that uses standard attribute assignment. The optimization only works when code explicitly calls fast_setattr().

Implications:

  • Any code path (including external libraries or user code) that does module.attribute = value will fail in tests
  • The train() method workaround (lines 1111-1114) shows this is already a problem
  • Properties, descriptors, and any dynamic attribute assignment will trigger this warning

Recommended approach:

  • Remove the warning from __setattr__ - it creates more problems than it solves
  • Keep fast_setattr() as an opt-in optimization
  • Add comments/documentation encouraging use of fast_setattr() where performance matters

@greptile-apps
Contributor

greptile-apps bot commented Jan 12, 2026

Additional Comments (2)

transformer_engine/pytorch/module/base.py
CRITICAL: Direct attribute assignments will trigger RuntimeWarning treated as test failure

These direct assignments using self.attribute = will trigger the RuntimeWarning from the overridden __setattr__ method (lines 652-661), which pytest.ini now treats as an error. This method is called after module initialization (when _initialized=True), so these assignments will cause tests to fail.

        self.fast_setattr("tp_group", tp_group)
        self.fast_setattr("tp_group_initialized", True)

transformer_engine/pytorch/module/base.py
CRITICAL: Direct attribute assignments will trigger RuntimeWarning treated as test failure

These direct assignments to self.debug_last_iteration and self.debug_enabled_in_this_iteration will trigger RuntimeWarning since is_debug_iter() is called during forward passes (after _initialized=True). The pytest.ini configuration now treats RuntimeWarnings as errors, causing test failures.

            self.fast_setattr("debug_last_iteration", TEDebugState.get_iteration())
            self.fast_setattr("debug_enabled_in_this_iteration", debug)
        else:
            # If this is the same iteration as previous invocation of the module,
            # we use the debug value from the first invocation in the iteration.
            debug = self.debug_enabled_in_this_iteration

        self.fast_setattr("debug_last_iteration", TEDebugState.get_iteration())

Contributor

@greptile-apps greptile-apps bot left a comment

3 files reviewed, 3 comments


@greptile-apps
Contributor

greptile-apps bot commented Jan 12, 2026

Additional Comments (3)

transformer_engine/pytorch/module/base.py
CRITICAL: Direct attribute assignments will trigger RuntimeWarning treated as test failure

The set_tensor_parallel_group() method uses direct attribute assignment (self.tp_group = tp_group and self.tp_group_initialized = True). This method is called after module initialization (when _initialized=True), so these assignments will trigger the RuntimeWarning from the overridden __setattr__ method (lines 652-661), which pytest.ini now treats as an error.

        self.fast_setattr("tp_group", tp_group)
        self.fast_setattr("tp_group_initialized", True)

transformer_engine/pytorch/module/base.py
CRITICAL: Direct attribute assignments will trigger RuntimeWarning treated as test failure

The is_debug_iter() method uses direct attribute assignment for self.debug_last_iteration and self.debug_enabled_in_this_iteration. Since this method is called during forward passes (after _initialized=True), these assignments will trigger RuntimeWarnings that pytest.ini treats as errors.

            self.fast_setattr("debug_last_iteration", TEDebugState.get_iteration())
            self.fast_setattr("debug_enabled_in_this_iteration", debug)

transformer_engine/pytorch/module/base.py
CRITICAL: Direct attribute assignment will trigger RuntimeWarning treated as test failure

This direct assignment to self.debug_last_iteration in the is_debug_iter() method will trigger a RuntimeWarning when called during forward passes (after _initialized=True).

        self.fast_setattr("debug_last_iteration", TEDebugState.get_iteration())

Contributor

@greptile-apps greptile-apps bot left a comment

9 files reviewed, 9 comments


Comment on lines 1111 to 1113
def train(self, mode: bool = True) -> Self:
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", ".*CPU overhead.*")
        return super().train(mode)
Contributor

[P2] The train() method suppresses the CPU overhead warning by filtering it out. However, this means PyTorch's __setattr__ is still being called for all the child module state changes during train()/eval() mode switching. This defeats the optimization goal since the expensive __setattr__ path is still taken - it just doesn't warn about it.

The proper fix would be to override train() to use fast_setattr for the training attribute on all child modules, rather than suppressing the warning.
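
For reference, a hedged sketch of what that alternative could look like (not part of this PR; it assumes writing the training flag straight into each submodule's __dict__ is acceptable):

import torch


class FastTrainToggleSketch(torch.nn.Module):
    """Illustrative only: toggle train/eval without Module.__setattr__ overhead."""

    def train(self, mode: bool = True) -> "FastTrainToggleSketch":
        if not isinstance(mode, bool):
            raise ValueError("training mode is expected to be boolean")
        for module in self.modules():  # includes self and all submodules
            # torch.nn.Module.train() does `module.training = mode`, which goes
            # through __setattr__; writing to __dict__ skips that machinery.
            module.__dict__["training"] = mode
        return self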

Comment on lines 652 to 661
def __setattr__(self, name: str, value: Any) -> None:
if name in TransformerEngineBaseModule._fast_setattr_names:
# torch.nn.Module has a custom __setattr__ that handles
# modules, parameters, and buffers. This is unnecessary
# overhead when setting plain attrs.
self.__dict__[name] = value
else:
# Default case
super().__setattr__(name, value)
if "_initialized" in self.__dict__ and self._initialized:
warnings.warn(
"""The default implementation of torch.nn.Module introduces significant CPU overhead
when setting attributes and is therefore not recommended. Please use the explicit
calls (fast_setattr for setting regular values and module_setattr for setting
parameters, children modules and buffers).""",
RuntimeWarning,
)
super().__setattr__(name, value)
Contributor

[P2] The __setattr__ override still calls super().__setattr__(name, value) after emitting the warning, meaning the expensive PyTorch __setattr__ is always executed. This doesn't optimize anything - it only adds overhead by checking _initialized and potentially emitting a warning before doing the expensive operation anyway.

If the goal is CPU optimization, the __setattr__ should use fast_setattr for appropriate attributes instead of just warning. The current implementation only serves as a detection mechanism, not an optimization.

@greptile-apps
Contributor

greptile-apps bot commented Jan 12, 2026

Additional Comments (3)

transformer_engine/pytorch/module/base.py
[P0] These direct attribute assignments will trigger the RuntimeWarning defined in __setattr__ (lines 652-661), which pytest.ini now treats as an error. The set_tensor_parallel_group() method is called after module initialization (when _initialized=True), causing tests to fail.

        self.fast_setattr("tp_group", tp_group)
        self.fast_setattr("tp_group_initialized", True)

transformer_engine/pytorch/module/base.py
[P0] These direct attribute assignments will trigger RuntimeWarning since is_debug_iter() is called during forward passes (after _initialized=True). With pytest.ini treating RuntimeWarnings as errors, this will cause test failures.

            self.fast_setattr("debug_last_iteration", TEDebugState.get_iteration())
            self.fast_setattr("debug_enabled_in_this_iteration", debug)

transformer_engine/pytorch/module/base.py
[P0] This direct assignment will trigger the RuntimeWarning. Called during forward passes when _initialized=True.

        self.fast_setattr("debug_last_iteration", TEDebugState.get_iteration())

Contributor

@greptile-apps greptile-apps bot left a comment

7 files reviewed, 7 comments


@greptile-apps
Contributor

greptile-apps bot commented Jan 12, 2026

Additional Comments (3)

transformer_engine/pytorch/module/base.py
[P0] Direct attribute assignments will trigger RuntimeWarning treated as test failure. This method is called post-initialization (when _initialized=True), so these assignments trigger the __setattr__ override (line 653) that emits RuntimeWarning. With pytest.ini now treating RuntimeWarnings as errors, all tests calling this method will fail.

        self.fast_setattr("tp_group", tp_group)
        self.fast_setattr("tp_group_initialized", True)

transformer_engine/pytorch/module/base.py
[P0] Direct attribute assignments will trigger RuntimeWarning treated as test failure. The is_debug_iter() method is called during forward passes (post-initialization), so these assignments trigger the __setattr__ override that emits RuntimeWarning, causing pytest to fail.

            self.fast_setattr("debug_last_iteration", TEDebugState.get_iteration())
            self.fast_setattr("debug_enabled_in_this_iteration", debug)

transformer_engine/pytorch/module/base.py
[P0] Direct attribute assignment will trigger RuntimeWarning treated as test failure. This line is reached during forward passes when _initialized=True, triggering the __setattr__ override.

        self.fast_setattr("debug_last_iteration", TEDebugState.get_iteration())

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
@ptrendx ptrendx force-pushed the pr_python_cpu_optimization branch from f96abbd to 2fb6ee3 Compare January 13, 2026 20:05
Contributor

@greptile-apps greptile-apps bot left a comment

Additional Comments (1)

  1. transformer_engine/pytorch/module/base.py, line 1570 (link)

    syntax: CRITICAL: RuntimeWarning will cause test failures

    This direct attribute assignment will also trigger the RuntimeWarning.

17 files reviewed, 1 comment


ptrendx and others added 4 commits January 13, 2026 16:30
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
@greptile-apps
Contributor

greptile-apps bot commented Jan 14, 2026

Greptile's behavior is changing!

From now on, if a review finishes with no comments, we will not post an additional "statistics" comment to confirm that our review found nothing to comment on. However, you can confirm that we reviewed your changes in the status check section.

This feature can be toggled off in your Code Review Settings by deselecting "Create a status check for each PR".

ptrendx and others added 3 commits January 13, 2026 16:55
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Contributor

@greptile-apps greptile-apps bot left a comment

Additional Comments (1)

  1. transformer_engine/pytorch/module/base.py, line 967-968 (link)

    logic: set_tensor_parallel_group uses direct attribute assignment instead of fast_setattr. This public method is documented to be called after module initialization, which will trigger RuntimeWarning (now treated as error by pytest.ini).

19 files reviewed, 1 comment


Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
way

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
@ptrendx
Member Author

ptrendx commented Jan 15, 2026

/te-ci L1 pytorch

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Contributor

@greptile-apps greptile-apps bot left a comment

12 files reviewed, 1 comment


Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
@ptrendx
Member Author

ptrendx commented Jan 16, 2026

/te-ci L1 pytorch

@ptrendx ptrendx added the 2.12.0 label Jan 17, 2026
@ptrendx ptrendx force-pushed the pr_python_cpu_optimization branch from 2ee3442 to cd11e67 Compare January 20, 2026 17:44
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
@ptrendx ptrendx force-pushed the pr_python_cpu_optimization branch from cd11e67 to 028d03f Compare January 20, 2026 17:50
@ptrendx
Member Author

ptrendx commented Jan 20, 2026

/te-ci pytorch L1

