diff --git a/intermediate_source/torch_compile_tutorial.py b/intermediate_source/torch_compile_tutorial.py index 75fab8c3f3..7fc27cf96c 100644 --- a/intermediate_source/torch_compile_tutorial.py +++ b/intermediate_source/torch_compile_tutorial.py @@ -397,6 +397,99 @@ def false_branch(y): # Check out our `section on graph breaks in the torch.compile programming model `__ # for tips on how to work around graph breaks. +###################################################################### +# Escape Hatches +# --------------- +# +# When ``torch.compile`` cannot trace through a piece of code, you have +# several escape hatches depending on your needs: +# +# .. code-block:: text +# +# Applying torch.compile +# │ +# ├─ Don't want to spend time fixing graph breaks? +# │ └─ fullgraph=False (default) +# │ +# ├─ No need to preserve side effects? +# │ └─ nonstrict_trace +# │ +# ├─ Don't need control over backward / don't need a custom op? +# │ └─ leaf_function +# │ +# └─ Otherwise +# └─ custom ops (torch.library.custom_op) +# +# **fullgraph=False** (default): ``torch.compile`` inserts graph breaks +# around untraceable code and continues compiling the rest. This is the +# simplest option — you don't change your code at all — but each graph +# break is a lost optimization opportunity. +# +# **nonstrict_trace**: Dynamo does not trace into the decorated function, +# but AOT Autograd *does*. This means the function's operations are still +# compiled and optimized. Use this when the function body is valid +# PyTorch but triggers graph breaks due to unsupported Python constructs. +# The tradeoff: side effects (mutations to globals, printing) may not be +# preserved. 
+ +torch._dynamo.reset() + + +@torch._dynamo.nonstrict_trace +def my_complex_fn(x): + torch._dynamo.graph_break() + return torch.sin(x) + 1 + + +@torch.compile(fullgraph=True) +def use_nonstrict(x): + return my_complex_fn(x) + + +print(use_nonstrict(torch.randn(3, 3))) + +###################################################################### +# **leaf_function**: Neither Dynamo nor AOT Autograd traces into it. +# The original Python code runs eagerly at runtime. Use this for code +# with side effects (logging, external library calls) or code that is +# fundamentally untraceable. You must provide a ``register_fake`` +# implementation for shape inference. + +from torch._dynamo.decorators import leaf_function + +torch._dynamo.reset() + + +@leaf_function +def log_and_compute(x): + stats = x.mean(dim=1) + print(f"Per-sample means: {stats}") + return (stats,) + + +@log_and_compute.register_fake +def log_and_compute_fake(x): + return (x.new_empty(x.shape[0]),) + + +@torch.compile(backend="aot_eager", fullgraph=True) +def use_leaf(x): + return log_and_compute(x) + + +print(use_leaf(torch.randn(3, 4))) + + +###################################################################### +# **torch.library.custom_op**: The most heavyweight escape hatch. +# Custom ops give you full control over forward *and* backward behavior, +# work with ``torch.export``, and are visible as first-class ops in the +# graph. Use this when you need a custom autograd formula, need to wrap +# a non-PyTorch kernel (e.g. CUDA/Triton), or need export compatibility. +# +# See the `custom ops tutorial `__ +# for details. + ###################################################################### # Troubleshooting # --------------- @@ -413,7 +506,8 @@ def false_branch(y): # # In this tutorial, we introduced ``torch.compile`` by covering # basic usage, demonstrating speedups over eager mode, comparing to TorchScript, -# and briefly describing graph breaks. 
+# graph breaks, and escape hatches (``nonstrict_trace``, ``leaf_function``, +# and custom ops) for code that cannot be traced. # # For an end-to-end example on a real model, check out our `end-to-end torch.compile tutorial `__. #