Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 95 additions & 1 deletion intermediate_source/torch_compile_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,99 @@ def false_branch(y):
# Check out our `section on graph breaks in the torch.compile programming model <https://docs.pytorch.org/docs/stable/user_guide/torch_compiler/compile/programming_model.graph_breaks_index.html>`__
# for tips on how to work around graph breaks.

######################################################################
# Escape Hatches
# ---------------
#
# When ``torch.compile`` cannot trace through a piece of code, you have
# several escape hatches depending on your needs:
#
# .. code-block:: text
#
# Applying torch.compile
# │
# ├─ Don't want to spend time fixing graph breaks?
# │ └─ fullgraph=False (default)
# │
# ├─ No need to preserve side effects?
# │ └─ nonstrict_trace
# │
# ├─ Don't need control over backward / don't need a custom op?
# │ └─ leaf_function
# │
# └─ Otherwise
# └─ custom ops (torch.library.custom_op)
#
# **fullgraph=False** (default): ``torch.compile`` inserts graph breaks
# around untraceable code and continues compiling the rest. This is the
# simplest option — you don't change your code at all — but each graph
# break is a lost optimization opportunity.
#
# **nonstrict_trace**: Dynamo does not trace into the decorated function,
# but AOT Autograd *does*. This means the function's operations are still
# compiled and optimized. Use this when the function body is valid
# PyTorch but triggers graph breaks due to unsupported Python constructs.
# The tradeoff: side effects (mutations to globals, printing) may not be
# preserved.

# Clear any prior Dynamo compilation state so this example starts fresh.
torch._dynamo.reset()


# Demonstrates ``nonstrict_trace``: per the explanation above, Dynamo does not
# trace into the decorated body, so the deliberate ``graph_break()`` below does
# not cause a Dynamo graph break — yet AOT Autograd still traces the tensor
# ops, so ``torch.sin(x) + 1`` is compiled rather than run eagerly.
@torch._dynamo.nonstrict_trace
def my_complex_fn(x):
    torch._dynamo.graph_break()
    return torch.sin(x) + 1


# ``fullgraph=True`` would raise on any graph break; this compiles successfully
# because ``my_complex_fn`` is traced non-strictly instead of by Dynamo.
@torch.compile(fullgraph=True)
def use_nonstrict(x):
    return my_complex_fn(x)


print(use_nonstrict(torch.randn(3, 3)))

######################################################################
# **leaf_function**: Neither Dynamo nor AOT Autograd trace into it.
# The original Python code runs eagerly at runtime. Use this for code
# with side effects (logging, external library calls) or code that is
# fundamentally untraceable. You must provide a ``register_fake``
# implementation for shape inference.

# ``leaf_function`` lives in Dynamo's decorators module.
# NOTE(review): this is a ``torch._dynamo`` internal — presumably an
# experimental API; confirm it is public/stable for the tutorial's
# minimum supported PyTorch version.
from torch._dynamo.decorators import leaf_function

# Clear compilation state left over from the previous example.
torch._dynamo.reset()


# Demonstrates ``leaf_function``: per the explanation above, neither Dynamo nor
# AOT Autograd trace into the body — the original Python runs eagerly at
# runtime, so the ``print`` side effect is preserved even under
# ``fullgraph=True``.
@leaf_function
def log_and_compute(x):
    stats = x.mean(dim=1)
    print(f"Per-sample means: {stats}")
    return (stats,)


# Fake implementation used during tracing for shape inference (required by
# ``leaf_function``, as noted above). It must mirror the real output
# structure: a 1-tuple holding one value per row of ``x``.
@log_and_compute.register_fake
def log_and_compute_fake(x):
    return (x.new_empty(x.shape[0]),)


@torch.compile(backend="aot_eager", fullgraph=True)
def use_leaf(x):
    return log_and_compute(x)


print(use_leaf(torch.randn(3, 4)))


######################################################################
# **torch.library.custom_op**: The most heavyweight escape hatch.
# Custom ops give you full control over forward *and* backward behavior,
# work with ``torch.export``, and are visible as first-class ops in the
# graph. Use this when you need a custom autograd formula, need to wrap
# a non-PyTorch kernel (e.g. CUDA/Triton), or need export compatibility.
#
# See the `custom ops tutorial <https://docs.pytorch.org/tutorials/advanced/python_custom_ops.html>`__
# for details.

######################################################################
# Troubleshooting
# ---------------
Expand All @@ -413,7 +506,8 @@ def false_branch(y):
#
# In this tutorial, we introduced ``torch.compile`` by covering
# basic usage, demonstrating speedups over eager mode, comparing to TorchScript,
# and briefly describing graph breaks.
# describing graph breaks, and introducing escape hatches
# (``nonstrict_trace``, ``leaf_function``, and custom ops) for code that cannot be traced.
#
# For an end-to-end example on a real model, check out our `end-to-end torch.compile tutorial <https://pytorch.org/tutorials/intermediate/torch_compile_full_example.html>`__.
#
Expand Down
Loading