Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions py/torch_tensorrt/dynamo/_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -898,6 +898,67 @@ def _insert_complex_io_adapters(
partitioned_module.recompile()


def _apply_dynamic_shape_bounds(

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@apbose please review / comment on if this overlaps with your work

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was adapted from a previous fix in Wenbing’s branch: _apply_dynamic_shape_bounds.

It fixes the case where ZoomASR is exported with broad or unbounded Dim.DYNAMIC ranges by applying the user-provided torch_tensorrt.Input min/max bounds during Torch-TensorRT compilation. Without this compiler-side fix, I had to work around the issue in the ZoomASR export code by changing the dynamic shape annotations from:

"features": {0: Dim.DYNAMIC, 1: Dim.DYNAMIC, 2: Dim.STATIC},
"feature_lengths": {0: Dim.DYNAMIC},
to explicitly bounded dimensions:

batch_dim = Dim("batch", min=1, max=max_batch_size)
feature_len_dim = Dim("feature_len", min=1, max=3000)

dynamic_shapes={
"features": {0: batch_dim, 1: feature_len_dim, 2: Dim.STATIC},
"feature_lengths": {0: batch_dim},
}

With _apply_dynamic_shape_bounds, the model export can keep using Dim.DYNAMIC, while Torch-TensorRT still constructs the intended bounded TensorRT optimization profile from the provided input specs.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My only concern with this code and Wenbings prior version of this is im not sure it properly inserts the updated shapes. Like I would expect some form of shape prop done here as well after constraining the placeholders. If we can solve this in export with explicit ranges, it might be better to split this into its own PR that fully handles constraining an already exported exported program. It should still unblock Shane without and the explicit shape in export right?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed and yes this will unblock Shane.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah PR 4213 addresses this, validating the input bounds against the exporter bounds.

  1. One issue against changing the var_to_range is that it tracks the exporter bounds, and if we change the Input bounds for a later compile, it conflates the Input bound of TRT profile with the export bound, and can lead to error. Also it does not do the shape prop. So something like this fails in above
ep = torch.export.export(model, (sample,),
    dynamic_shapes={"x": {0: Dim.DYNAMIC}})
# var_to_range[s0] = [2, int_oo]  -- model valid for any size

Now two separate compilations:

# User A: small deployment, max batch 100
trt_a = torchtrt.dynamo.compile(ep, inputs=[Input(max_shape=(100, 8))])
# _apply_dynamic_shape_bounds mutates var_to_range[s0] → [2, 100]

# User B: larger deployment, max batch 500
trt_b = torchtrt.dynamo.compile(ep, inputs=[Input(max_shape=(500, 8))])
# _apply_dynamic_shape_bounds now sees var_to_range[s0] = [2, 100]
# computes min(100, 500) = 100
# User B silently gets max=100 instead of 500
  1. Second type which the other PR addresses is — if the export used explicit Dim("batch", min=10, max=20) and the user passes Input(min_shape=2) this method will clamp input bound to 10 ignoring 2.

gm: torch.fx.GraphModule,
sample_arg_inputs: Sequence[Input],
sample_kwarg_inputs: dict[Any, Any],
) -> None:
"""Propagate user Input min/max bounds into the FX shape_env.

This lets explicit torch_tensorrt.Input bounds constrain exported programs
that otherwise carry broad Dim.DYNAMIC ranges.
"""
from torch.utils._sympy.value_ranges import ValueRanges

placeholders = [n for n in gm.graph.nodes if n.op == "placeholder"]

sample_by_name: dict[str, Input] = {}
for i, node in enumerate(placeholders):
if i < len(sample_arg_inputs):
inp = sample_arg_inputs[i]
if isinstance(inp, Input) and inp.shape_mode == Input._ShapeMode.DYNAMIC:
sample_by_name[node.target] = inp

for name, inp in sample_kwarg_inputs.items():
if isinstance(inp, Input) and inp.shape_mode == Input._ShapeMode.DYNAMIC:
sample_by_name[name] = inp

if not sample_by_name:
return

updated_syms: set = set()
for node in placeholders:
if node.target not in sample_by_name:
continue

sample_input = sample_by_name[node.target]
fake_val = node.meta.get("val")
if not isinstance(fake_val, torch.Tensor):
continue

min_shape = sample_input.shape["min_shape"]
max_shape = sample_input.shape["max_shape"]

for d, dim in enumerate(fake_val.size()):
if not isinstance(dim, torch.SymInt) or d >= len(min_shape):
continue

expr = dim.node.expr
if expr in updated_syms:
continue

shape_env = dim.node.shape_env
if expr not in shape_env.var_to_range:
continue

old_range = shape_env.var_to_range[expr]
lower = max(old_range.lower, min_shape[d])
upper = min(old_range.upper, max_shape[d])
shape_env.var_to_range[expr] = ValueRanges(lower=lower, upper=upper)
updated_syms.add(expr)
logger.debug("Updated shape_env range for %s: [%s, %s]", expr, lower, upper)


@fn_supports_debugger # type: ignore[misc]
def compile_module(
gm: torch.fx.GraphModule,
Expand Down Expand Up @@ -929,6 +990,8 @@ def compile_module(
if sample_kwarg_inputs is None:
sample_kwarg_inputs = {}

_apply_dynamic_shape_bounds(gm, sample_arg_inputs, sample_kwarg_inputs)

# Configure user compilation settings to converters.
CONVERTERS.set_compilation_settings(settings)

Expand Down
2 changes: 2 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -816,6 +816,7 @@ def aten_ops_rsqrt(
)


@dynamo_tensorrt_converter(operator.neg, supports_dynamic_shapes=True)
@dynamo_tensorrt_converter(torch.ops.aten.neg.default, supports_dynamic_shapes=True)
def aten_ops_neg(
ctx: ConversionContext,
Expand Down Expand Up @@ -2223,6 +2224,7 @@ def aten_ops_maximum(
)


@dynamo_tensorrt_converter(torch.sym_min, supports_dynamic_shapes=True)
@dynamo_tensorrt_converter(torch.ops.aten.minimum.default, supports_dynamic_shapes=True)
def aten_ops_minimum(
ctx: ConversionContext,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from .complex_graph_rewrite import complex_graph_detection
from .constant_folding import constant_fold
from .eliminate_sym_min_int64_max import eliminate_sym_min_int64_max
from .force_causal_efficient_attention import force_causal_efficient_attention
from .fuse_prims_broadcast import fuse_prims_broadcast
from .pass_manager import DynamoPassManager
Expand All @@ -23,6 +24,7 @@
from .replace_fused_rms_norm import replace_fused_rms_norm
from .replace_max_pool_with_indices import replace_max_pool_with_indices
from .rule_based_autocast import rule_based_autocast
from .normalize_negative_slice_stop import normalize_negative_slice_stop

pre_lowering_pass_list = [
remove_detach,
Expand All @@ -41,6 +43,8 @@
remove_num_users_is_0_nodes,
complex_graph_detection,
force_causal_efficient_attention,
eliminate_sym_min_int64_max,
normalize_negative_slice_stop,
]

if not is_tegra_platform():
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import sys

import torch
from torch.fx import GraphModule, Node

from .pass_utils import clean_up_graph_after_modifications


_INT64_MAX = 2**63 - 1
_SYM_MIN = getattr(torch, "sym_min", None)


def _is_int64_max(x: object) -> bool:
return isinstance(x, int) and x in (sys.maxsize, _INT64_MAX)


def eliminate_sym_min_int64_max(
gm: GraphModule, settings: object = None
) -> GraphModule:
"""Remove no-op sym_min nodes where one operand is INT64_MAX.

torch.export may emit sym_min(sym, INT64_MAX) for an effectively unbounded
symbolic value. That expression is equivalent to sym, and leaving it in the
graph can produce runtime calls to torch.sym_min with Tensor inputs.
"""
if _SYM_MIN is None:
return gm

modified = False
for node in list(gm.graph.nodes):
if (
node.op != "call_function"
or node.target is not _SYM_MIN
or len(node.args) < 2
):
continue

lhs, rhs = node.args[:2]
if _is_int64_max(rhs) and isinstance(lhs, Node):
passthrough = lhs
elif _is_int64_max(lhs) and isinstance(rhs, Node):
passthrough = rhs
else:
continue

node.replace_all_uses_with(passthrough)
gm.graph.erase_node(node)
modified = True

return clean_up_graph_after_modifications(gm) if modified else gm
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import operator
from typing import Optional

import torch
from torch.fx import GraphModule, Node

from .pass_utils import clean_up_graph_after_modifications


def _negative_symint_operand(x: object) -> Optional[object]:
# Return n for symbolic bounds represented as -n. The caller rewrites
# that bound to dim_size - n, matching Python's negative indexing rules.
if (
isinstance(x, Node)
and x.op == "call_function"
and x.target in (operator.neg, torch.ops.aten.neg.default)
and len(x.args) == 1
):
return x.args[0]
return None


def _rank(x: Node) -> Optional[int]:
val = x.meta.get("val")
if isinstance(val, torch.Tensor):
return val.dim()
if hasattr(val, "shape"):
return len(val.shape)
return None


def normalize_negative_slice_stop(
gm: GraphModule, settings: object = None
) -> GraphModule:
"""Normalize negative symbolic slice bounds to positive dim-relative bounds.

Python slicing accepts negative bounds such as x[-n:] or x[:-n]. TensorRT
shape expressions need the equivalent positive bound, dim_size - n.
"""
modified = False

for node in list(gm.graph.nodes):
if node.op != "call_function" or node.target != torch.ops.aten.slice.Tensor:
continue

args = list(node.args)
if len(args) < 3:
continue

input_node, dim = args[:2]
if not isinstance(input_node, Node) or not isinstance(dim, int):
continue

rank = _rank(input_node)
if rank is not None:
# Match PyTorch dim normalization for negative dims.
dim = dim % rank

rewritten = False
# aten.slice.Tensor can appear as (input, dim, start) or
# (input, dim, start, stop, ...). Normalize either symbolic bound.
for bound_index in (2, 3):
if len(args) <= bound_index:
continue

bound = args[bound_index]
positive_offset = _negative_symint_operand(bound)
if positive_offset is None:
continue

with gm.graph.inserting_before(node):
dim_size = gm.graph.call_function(
torch.ops.aten.sym_size.int, args=(input_node, dim)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice yeah I think is exactly what we are looking for.

If you want something a bit cleaner to read, we have this utility:
https://github.com/pytorch/TensorRT/blob/main/py/torch_tensorrt/dynamo/lowering/_SubgraphBuilder.py

e.g. https://github.com/pytorch/TensorRT/blob/main/py/torch_tensorrt/dynamo/lowering/passes/complex_graph_rewrite.py#L456

but its more style than functional

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool, let me review and edit accordingly.

)
# A negative symbolic bound -n becomes dim_size - n.
normalized_bound = gm.graph.call_function(
operator.sub, args=(dim_size, positive_offset)
)

args[bound_index] = normalized_bound
rewritten = True

if rewritten:
args[1] = dim
node.args = tuple(args)
modified = True

return clean_up_graph_after_modifications(gm) if modified else gm
99 changes: 99 additions & 0 deletions tests/py/dynamo/lowering/test_aten_lowering_passes.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import operator
import sys

import torch
import torch_tensorrt
from torch.testing._internal.common_utils import TestCase, run_tests
Expand Down Expand Up @@ -278,6 +281,102 @@ def forward(self, x: torch.Tensor):
self.assertTrue(True)


class TestNormalizeNegativeSliceStop(TestCase):
def test_normalizes_negative_symbolic_start_bound(self):
from torch_tensorrt.dynamo.lowering.passes.normalize_negative_slice_stop import (
normalize_negative_slice_stop,
)

graph = torch.fx.Graph()
x = graph.placeholder("x")
x.meta["val"] = torch.empty(2, 5, 3)
n = graph.placeholder("n")
neg = graph.call_function(operator.neg, args=(n,))
sliced = graph.call_function(torch.ops.aten.slice.Tensor, args=(x, -2, neg))
graph.output(sliced)

gm = torch.fx.GraphModule({}, graph)
gm = normalize_negative_slice_stop(gm)

slice_node = next(
node
for node in gm.graph.nodes
if node.op == "call_function" and node.target == torch.ops.aten.slice.Tensor
)
self.assertEqual(slice_node.args[1], 1)

normalized_start = slice_node.args[2]
self.assertEqual(normalized_start.op, "call_function")
self.assertEqual(normalized_start.target, operator.sub)

dim_size, offset = normalized_start.args
self.assertEqual(dim_size.target, torch.ops.aten.sym_size.int)
self.assertEqual(dim_size.args[0], x)
self.assertEqual(dim_size.args[1], 1)
self.assertEqual(offset, n)

def test_normalizes_negative_symbolic_stop_bound(self):
from torch_tensorrt.dynamo.lowering.passes.normalize_negative_slice_stop import (
normalize_negative_slice_stop,
)

graph = torch.fx.Graph()
x = graph.placeholder("x")
x.meta["val"] = torch.empty(2, 5, 3)
n = graph.placeholder("n")
neg = graph.call_function(torch.ops.aten.neg.default, args=(n,))
sliced = graph.call_function(torch.ops.aten.slice.Tensor, args=(x, 1, 0, neg))
graph.output(sliced)

gm = torch.fx.GraphModule({}, graph)
gm = normalize_negative_slice_stop(gm)

slice_node = next(
node
for node in gm.graph.nodes
if node.op == "call_function" and node.target == torch.ops.aten.slice.Tensor
)

normalized_stop = slice_node.args[3]
self.assertEqual(normalized_stop.op, "call_function")
self.assertEqual(normalized_stop.target, operator.sub)

dim_size, offset = normalized_stop.args
self.assertEqual(dim_size.target, torch.ops.aten.sym_size.int)
self.assertEqual(dim_size.args[0], x)
self.assertEqual(dim_size.args[1], 1)
self.assertEqual(offset, n)


class TestEliminateSymMinInt64Max(TestCase):
def test_eliminates_noop_sym_min_int64_max(self):
if not hasattr(torch, "sym_min"):
self.skipTest("torch.sym_min is not available")

from torch_tensorrt.dynamo.lowering.passes.eliminate_sym_min_int64_max import (
eliminate_sym_min_int64_max,
)

graph = torch.fx.Graph()
x = graph.placeholder("x")
rhs_int64_max = graph.call_function(torch.sym_min, args=(x, sys.maxsize))
lhs_int64_max = graph.call_function(torch.sym_min, args=(2**63 - 1, x))
graph.output((rhs_int64_max, lhs_int64_max))

gm = torch.fx.GraphModule({}, graph)
gm = eliminate_sym_min_int64_max(gm)

self.assertFalse(
any(
node.op == "call_function" and node.target is torch.sym_min
for node in gm.graph.nodes
)
)

output_node = next(node for node in gm.graph.nodes if node.op == "output")
self.assertEqual(output_node.args[0], (x, x))


class TestRewriteEfficientAttention(TestCase):
def test_force_causal_efficient_attention(self):
class RewriteEfficientAttention(torch.nn.Module):
Expand Down
Loading