Skip to content

fix 4326: bound exported dynamic shapes and normalize symbolic slice …#4341

Open
micwill755 wants to merge 3 commits into
pytorch:mainfrom
micwill755:fix/zoomasr-dynamic-shape-bounds
Open

fix 4326: bound exported dynamic shapes and normalize symbolic slice …#4341
micwill755 wants to merge 3 commits into
pytorch:mainfrom
micwill755:fix/zoomasr-dynamic-shape-bounds

Conversation

@micwill755

@micwill755 micwill755 commented Jun 12, 2026

Copy link
Copy Markdown
Collaborator

Description

This change fixes three dynamic-shape issues exposed by the ZoomASR repro. Together, these fixes make Torch-TensorRT respect user-provided input bounds, lower negative symbolic slicing into TensorRT-friendly shape math, and clean up symbolic expressions that otherwise survive into runtime or conversion.

  1. Dynamic shape bounds fix

A previous working branch had a helper called _apply_dynamic_shape_bounds in Torch-TensorRT _compiler.py. That function propagated user-provided torch_tensorrt.Input min/max bounds into the FX shape_env, which explains why the old path used the intended ZoomASR feature length bound of 3000 instead of the broad exported range [9, 16384].

  1. Negative symbolic slice normalization

PyTorch 2.12 can emit slices with negative symbolic bounds, such as x[-n:] or x[:-n]. This change adds a lowering pass that rewrites those bounds to dim_size - n for both slice start and stop, making the graph explicit and TensorRT-friendly.

  1. No-op sym_min cleanup

The issue discussion mentioned removing torch.sym_min(x, INT64_MAX) as a possible needed patch. This change adds that as a separate lowering pass; it removes the no-op sym_min expression before runtime, preventing failures where torch.sym_min receives tensor-like inputs.

Fixes #4326

Type of change

  • Bug fix (non-breaking change which fixes an issue)

Checklist:

  • My code follows the style guidelines of this project
  • I performed a self-review of my own code
  • I commented my code, particularly in hard-to-understand areas and hacks
  • I made corresponding changes to the documentation
  • I added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I added the relevant labels to my PR so that relevant reviewers are notified

@micwill755 micwill755 requested a review from narendasan June 12, 2026 20:50
@meta-cla meta-cla Bot added the cla signed label Jun 12, 2026
@github-actions github-actions Bot added component: tests Issues re: Tests component: lowering Issues re: The lowering / preprocessing passes component: conversion Issues re: Conversion stage component: core Issues re: The core compiler component: api [Python] Issues re: Python API component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths labels Jun 12, 2026
@github-actions github-actions Bot requested a review from cehongwang June 12, 2026 20:50
@narendasan

Copy link
Copy Markdown
Collaborator

@micwill755 precommit/prek (https://github.com/j178/prek) should be able to lint your code for you


with gm.graph.inserting_before(node):
dim_size = gm.graph.call_function(
torch.ops.aten.sym_size.int, args=(input_node, dim)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice yeah I think is exactly what we are looking for.

If you want something a bit cleaner to read, we have this utility:
https://github.com/pytorch/TensorRT/blob/main/py/torch_tensorrt/dynamo/lowering/_SubgraphBuilder.py

e.g. https://github.com/pytorch/TensorRT/blob/main/py/torch_tensorrt/dynamo/lowering/passes/complex_graph_rewrite.py#L456

but its more style than functional

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool, let me review and edit accordingly.

partitioned_module.recompile()


def _apply_dynamic_shape_bounds(

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@apbose please review / comment on if this overlaps with your work

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was adapted from a previous fix in Wenbing’s branch: _apply_dynamic_shape_bounds.

It fixes the case where ZoomASR is exported with broad or unbounded Dim.DYNAMIC ranges by applying the user-provided torch_tensorrt.Input min/max bounds during Torch-TensorRT compilation. Without this compiler-side fix, I had to work around the issue in the ZoomASR export code by changing the dynamic shape annotations from:

"features": {0: Dim.DYNAMIC, 1: Dim.DYNAMIC, 2: Dim.STATIC},
"feature_lengths": {0: Dim.DYNAMIC},
to explicitly bounded dimensions:

batch_dim = Dim("batch", min=1, max=max_batch_size)
feature_len_dim = Dim("feature_len", min=1, max=3000)

dynamic_shapes={
"features": {0: batch_dim, 1: feature_len_dim, 2: Dim.STATIC},
"feature_lengths": {0: batch_dim},
}

With _apply_dynamic_shape_bounds, the model export can keep using Dim.DYNAMIC, while Torch-TensorRT still constructs the intended bounded TensorRT optimization profile from the provided input specs.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My only concern with this code and Wenbings prior version of this is im not sure it properly inserts the updated shapes. Like I would expect some form of shape prop done here as well after constraining the placeholders. If we can solve this in export with explicit ranges, it might be better to split this into its own PR that fully handles constraining an already exported exported program. It should still unblock Shane without and the explicit shape in export right?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed and yes this will unblock Shane.

@narendasan narendasan requested a review from apbose June 12, 2026 21:08
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

cla signed component: api [Python] Issues re: Python API component: conversion Issues re: Conversion stage component: core Issues re: The core compiler component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths component: lowering Issues re: The lowering / preprocessing passes component: tests Issues re: Tests

Projects

None yet

Development

Successfully merging this pull request may close these issues.

✨[Feature] Support negative dimension slicing in the dynamic shape case

2 participants