fix 4326: bound exported dynamic shapes and normalize symbolic slice … #4341
base: main
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,50 @@ | ||
| import sys | ||
|
|
||
| import torch | ||
| from torch.fx import GraphModule, Node | ||
|
|
||
| from .pass_utils import clean_up_graph_after_modifications | ||
|
|
||
|
|
||
| _INT64_MAX = 2**63 - 1 | ||
| _SYM_MIN = getattr(torch, "sym_min", None) | ||
|
|
||
|
|
||
| def _is_int64_max(x: object) -> bool: | ||
| return isinstance(x, int) and x in (sys.maxsize, _INT64_MAX) | ||
|
|
||
|
|
||
| def eliminate_sym_min_int64_max( | ||
| gm: GraphModule, settings: object = None | ||
| ) -> GraphModule: | ||
| """Remove no-op sym_min nodes where one operand is INT64_MAX. | ||
|
|
||
| torch.export may emit sym_min(sym, INT64_MAX) for an effectively unbounded | ||
| symbolic value. That expression is equivalent to sym, and leaving it in the | ||
| graph can produce runtime calls to torch.sym_min with Tensor inputs. | ||
| """ | ||
| if _SYM_MIN is None: | ||
| return gm | ||
|
|
||
| modified = False | ||
| for node in list(gm.graph.nodes): | ||
| if ( | ||
| node.op != "call_function" | ||
| or node.target is not _SYM_MIN | ||
| or len(node.args) < 2 | ||
| ): | ||
| continue | ||
|
|
||
| lhs, rhs = node.args[:2] | ||
| if _is_int64_max(rhs) and isinstance(lhs, Node): | ||
| passthrough = lhs | ||
| elif _is_int64_max(lhs) and isinstance(rhs, Node): | ||
| passthrough = rhs | ||
| else: | ||
| continue | ||
|
|
||
| node.replace_all_uses_with(passthrough) | ||
| gm.graph.erase_node(node) | ||
| modified = True | ||
|
|
||
| return clean_up_graph_after_modifications(gm) if modified else gm |
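The identity this pass relies on can be checked without torch. The following is a minimal sketch on a toy tuple-based expression, not on an FX graph; `fold_min` and the `("min", a, b)` encoding are hypothetical and exist only for illustration. As in the pass, the fold only applies when the other operand is symbolic (a string or nested expression here, a `Node` in the pass).

```python
import sys

INT64_MAX = 2**63 - 1


def is_int64_max(x):
    # Same check as _is_int64_max in the diff above.
    return isinstance(x, int) and x in (sys.maxsize, INT64_MAX)


def fold_min(expr):
    """Fold ("min", a, b) to the symbolic operand when the other is INT64_MAX.

    Toy encoding (an assumption for this sketch): symbols are strings,
    constants are ints, and min expressions are 3-tuples.
    """
    if not (isinstance(expr, tuple) and len(expr) == 3 and expr[0] == "min"):
        return expr
    lhs, rhs = fold_min(expr[1]), fold_min(expr[2])
    if is_int64_max(rhs) and not isinstance(lhs, int):
        return lhs
    if is_int64_max(lhs) and not isinstance(rhs, int):
        return rhs
    return ("min", lhs, rhs)


# min(s, INT64_MAX) == s for every value in the int64 range.
assert all(min(s, INT64_MAX) == s for s in (-(2**63), 0, 1, INT64_MAX))

folded = fold_min(("min", "s0", INT64_MAX))
```

A real bound such as `("min", "s0", 4096)` is left untouched, which mirrors the `else: continue` branch in the pass.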
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,88 @@ | ||
| import operator | ||
| from typing import Optional | ||
|
|
||
| import torch | ||
| from torch.fx import GraphModule, Node | ||
|
|
||
| from .pass_utils import clean_up_graph_after_modifications | ||
|
|
||
|
|
||
| def _negative_symint_operand(x: object) -> Optional[object]: | ||
| # Return n for symbolic bounds represented as -n. The caller rewrites | ||
| # that bound to dim_size - n, matching Python's negative indexing rules. | ||
| if ( | ||
| isinstance(x, Node) | ||
| and x.op == "call_function" | ||
| and x.target in (operator.neg, torch.ops.aten.neg.default) | ||
| and len(x.args) == 1 | ||
| ): | ||
| return x.args[0] | ||
| return None | ||
|
|
||
|
|
||
| def _rank(x: Node) -> Optional[int]: | ||
| val = x.meta.get("val") | ||
| if isinstance(val, torch.Tensor): | ||
| return val.dim() | ||
| if hasattr(val, "shape"): | ||
| return len(val.shape) | ||
| return None | ||
|
|
||
|
|
||
| def normalize_negative_slice_stop( | ||
| gm: GraphModule, settings: object = None | ||
| ) -> GraphModule: | ||
| """Normalize negative symbolic slice bounds to positive dim-relative bounds. | ||
|
|
||
| Python slicing accepts negative bounds such as x[-n:] or x[:-n]. TensorRT | ||
| shape expressions need the equivalent positive bound, dim_size - n. | ||
| """ | ||
| modified = False | ||
|
|
||
| for node in list(gm.graph.nodes): | ||
| if node.op != "call_function" or node.target != torch.ops.aten.slice.Tensor: | ||
| continue | ||
|
|
||
| args = list(node.args) | ||
| if len(args) < 3: | ||
| continue | ||
|
|
||
| input_node, dim = args[:2] | ||
| if not isinstance(input_node, Node) or not isinstance(dim, int): | ||
| continue | ||
|
|
||
| rank = _rank(input_node) | ||
| if rank is not None: | ||
| # Match PyTorch dim normalization for negative dims. | ||
| dim = dim % rank | ||
|
|
||
| rewritten = False | ||
| # aten.slice.Tensor can appear as (input, dim, start) or | ||
| # (input, dim, start, stop, ...). Normalize either symbolic bound. | ||
| for bound_index in (2, 3): | ||
| if len(args) <= bound_index: | ||
| continue | ||
|
|
||
| bound = args[bound_index] | ||
| positive_offset = _negative_symint_operand(bound) | ||
| if positive_offset is None: | ||
| continue | ||
|
|
||
| with gm.graph.inserting_before(node): | ||
| dim_size = gm.graph.call_function( | ||
| torch.ops.aten.sym_size.int, args=(input_node, dim) | ||
|
Collaborator
Nice, yeah, I think this is exactly what we are looking for. If you want something a bit cleaner to read, we have this utility: but it's more style than functional.
Collaborator
Author
Cool, let me review and edit accordingly.
||
| ) | ||
| # A negative symbolic bound -n becomes dim_size - n. | ||
| normalized_bound = gm.graph.call_function( | ||
| operator.sub, args=(dim_size, positive_offset) | ||
| ) | ||
|
|
||
| args[bound_index] = normalized_bound | ||
| rewritten = True | ||
|
|
||
| if rewritten: | ||
| args[1] = dim | ||
| node.args = tuple(args) | ||
| modified = True | ||
|
|
||
| return clean_up_graph_after_modifications(gm) if modified else gm | ||
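The rewrite `-n` to `dim_size - n` can be sanity-checked against plain Python slicing. This is a sketch on a list rather than a tensor; `rewritten_bound` and `normalize_dim` are hypothetical helpers that mirror the arithmetic in the pass, not code from the PR.

```python
def normalize_dim(dim, rank):
    # Mirrors `dim = dim % rank`: a negative dim counts from the last axis.
    return dim % rank


def rewritten_bound(n, dim_size):
    # Mirrors the pass: a bound written as -n becomes dim_size - n.
    return dim_size - n


data = list(range(10))
size = len(data)

# For 1 <= n <= dim_size the rewritten bound selects the same elements.
for n in range(1, size + 1):
    assert data[-n:] == data[rewritten_bound(n, size):]
    assert data[:-n] == data[:rewritten_bound(n, size)]

# Edge case: -0 is just 0, so x[:-0] is empty while x[:dim_size - 0] is
# the whole sequence. The two forms only agree when n >= 1.
empty_stop = data[:-0]
full_stop = data[:rewritten_bound(0, size)]
```

The equivalence holds for `1 <= n <= dim_size`. Whether a symbolic `n` can be 0 in an exported graph is not something the diff states, so the last two lines are only a note on where the two forms diverge.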
@apbose please review and comment on whether this overlaps with your work.
This was adapted from a previous fix in Wenbing’s branch: _apply_dynamic_shape_bounds.
It fixes the case where ZoomASR is exported with broad or unbounded Dim.DYNAMIC ranges by applying the user-provided torch_tensorrt.Input min/max bounds during Torch-TensorRT compilation. Without this compiler-side fix, I had to work around the issue in the ZoomASR export code by changing the dynamic shape annotations from:
"features": {0: Dim.DYNAMIC, 1: Dim.DYNAMIC, 2: Dim.STATIC},
"feature_lengths": {0: Dim.DYNAMIC},
to explicitly bounded dimensions:
batch_dim = Dim("batch", min=1, max=max_batch_size)
feature_len_dim = Dim("feature_len", min=1, max=3000)
dynamic_shapes={
"features": {0: batch_dim, 1: feature_len_dim, 2: Dim.STATIC},
"feature_lengths": {0: batch_dim},
}
With _apply_dynamic_shape_bounds, the model export can keep using Dim.DYNAMIC, while Torch-TensorRT still constructs the intended bounded TensorRT optimization profile from the provided input specs.
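To make the intent concrete, here is a rough sketch of how an unbounded exported dimension could be narrowed to the bounds from an input spec. It is plain Python and does not reproduce `_apply_dynamic_shape_bounds`; `bound_dim`, its arguments, and the `(min, max)` tuples are all assumptions made for illustration.

```python
from typing import Optional, Tuple

INT64_MAX = 2**63 - 1


def bound_dim(
    export_range: Tuple[int, Optional[int]], input_min: int, input_max: int
) -> Tuple[int, int]:
    """Return the profile range for one dimension (hypothetical helper).

    export_range is the (min, max) recorded at export time, with max=None
    standing in for an effectively unbounded dimension. The user-provided
    bounds must be well ordered and must fit inside the exported range.
    """
    lo, hi = export_range
    hi = INT64_MAX if hi is None else hi
    if input_min > input_max:
        raise ValueError(f"min {input_min} exceeds max {input_max}")
    if input_min < lo or input_max > hi:
        raise ValueError(
            f"input range [{input_min}, {input_max}] is outside "
            f"exported range [{lo}, {hi}]"
        )
    return (input_min, input_max)


# An unbounded feature_len dimension narrowed to the 1..3000 range
# used in the export-side workaround above.
feature_len_range = bound_dim((1, None), input_min=1, input_max=3000)
```

The check that the input range sits inside the exported range is one possible policy; the actual validation rules are up to the compiler-side implementation.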
My only concern with this code, and with Wenbing's prior version of it, is that I'm not sure it properly inserts the updated shapes. I would expect some form of shape propagation to be done here as well, after constraining the placeholders. If we can solve this in export with explicit ranges, it might be better to split this into its own PR that fully handles constraining an already exported program. It should still unblock Shane without this, given the explicit shapes in export, right?
Agreed and yes this will unblock Shane.
Yeah, PR 4213 addresses this by validating the input bounds against the exporter bounds.