Add EliminateRescaleBeforeMulPass to remove redundant pre-MUL rescales #17999
Ninja91 wants to merge 2 commits into pytorch:main
Conversation
…rch#17830)

Summary: TOSA requires INT32 arithmetic for add/sub/mul ops. `InsertRescaleInt32Pass` wraps each such op with input RESCALEs (INT8→INT32) and an output RESCALE (INT32→INT8). When two such ops are chained, the output RESCALE of op1 feeds directly into the input RESCALE of op2, creating a redundant INT32→INT8→INT32 round-trip that wastes NPU cycles and loses precision.

`FuseConsecutiveRescalesPass` detects these pairs and either:
- removes both if the composed scale is ~1.0 (identity), or
- replaces both with a single INT32→INT32 RESCALE carrying the composed scale.

It handles multi-user R1 nodes (e.g., residual connections, LayerNorm branching) by fusing each R1→R2 pair individually while preserving R1 for non-RESCALE users.

## Context

Each unnecessary RESCALE is decomposed by Vela into Add+Mul NPU instructions (~1,130 cycles each on Ethos-U55-128). In meta-internal quantized models, RESCALE overhead accounts for 25-50% of total NPU cycles. This pass eliminates consecutive pairs at op boundaries, with multi-user handling catching additional pairs from branching patterns (e.g., LayerNorm's sub feeding both mul_square and mul_normalize).

This diff also adds a `ResidualConvBlock` toy model and pass-level unit tests.

Reviewed By: 3l1

Differential Revision: D94483331
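The pair-fusion rule described above can be sketched as a small scale-composition helper. This is a sketch under assumed semantics, not the pass's actual API: `fuse_pair` and its tolerance argument are illustrative names.

```python
# Hypothetical sketch of the fusion decision for a consecutive RESCALE
# pair R1 -> R2, where each RESCALE multiplies values by its scale factor.
def fuse_pair(scale_r1: float, scale_r2: float, atol: float = 1e-6):
    """Return None to drop both RESCALEs (identity round-trip), else the
    composed scale for a single replacement INT32->INT32 RESCALE."""
    composed = scale_r1 * scale_r2
    if abs(composed - 1.0) <= atol:
        return None      # ~1.0: remove both nodes outright
    return composed      # replace the pair with one composed RESCALE

# An INT32->INT8 rescale followed by its exact inverse is an identity pair:
assert fuse_pair(0.25, 4.0) is None
# Otherwise a single composed INT32->INT32 RESCALE replaces the pair:
assert fuse_pair(0.25, 2.0) == 0.5
```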
Summary: Eliminate redundant INT32→INT32 RESCALE ops that feed exclusively into elementwise MUL by absorbing their scale factor into the downstream output RESCALE.

After InsertRescaleInt32Pass and FuseConsecutiveRescalesPass, the graph may contain residual INT32→INT32 RESCALE nodes between consecutive elementwise ops. For MUL, input scales do not need alignment -- the output scale is the product of the input scales (S_out = S_0 * S_1). A RESCALE adjusting scale before MUL is therefore mathematically redundant: the scale change can be absorbed into the downstream output RESCALE as `new_out_scale = old_out_scale * removed_scale`.

The optimization is restricted to MUL nodes whose downstream RESCALEs remain in INT32 (staying within the INT32 computation region). RESCALEs that convert to INT8/INT16 define quantization boundaries where TABLE ops (exp, log, sigmoid, etc.) build lookup tables from the quantization annotation -- modifying such RESCALEs would create a mismatch between the actual integer values and the annotations, producing incorrect LUT entries.

This optimization does NOT apply to ADD/SUB (which require aligned input scales for correct integer arithmetic) or to Conv2D/MatMul boundaries (where removing rescales empirically degrades Vela NPU compiler instruction scheduling).

Stacked on D94483331 (FuseConsecutiveRescalesPass).

Differential Revision: D95243636
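The redundancy argument above can be checked numerically: with real value = quantized value × scale, folding the removed RESCALE's factor into the downstream output RESCALE leaves the MUL result unchanged. The numbers below are made up for illustration, and rounding is ignored.

```python
# Toy check (made-up values): a pre-MUL RESCALE multiplies one operand by
# removed_scale; eliminating it and compensating in the output RESCALE via
# new_out_scale = old_out_scale * removed_scale yields the same result,
# because the MUL output scale is the product of input scales.
q0, q1 = 40.0, 10.0          # INT32 operands (as floats, ignoring rounding)
removed_scale = 0.5          # scale of the eliminated INT32->INT32 RESCALE
old_out_scale = 0.001        # scale of the downstream output RESCALE

# Original graph: RESCALE q1, then MUL, then output RESCALE
with_rescale = (q0 * (q1 * removed_scale)) * old_out_scale

# Optimized graph: MUL directly, compensated output RESCALE
new_out_scale = old_out_scale * removed_scale
without_rescale = (q0 * q1) * new_out_scale

assert with_rescale == without_rescale == 0.2
```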
CI status as of commit 9226722 with merge base 122fdef: ❌ 1 Awaiting Approval, 6 New Failures, 1 Cancelled Job.
Pull request overview
Adds new ARM TOSA graph-cleanup passes to reduce redundant RESCALE ops introduced by INT32 wrapping, aiming to keep more of the elementwise region in INT32 and reduce unnecessary quantization round-trips.
Changes:
- Introduces `FuseConsecutiveRescalesPass` to fuse/remove consecutive INT32→INT8/INT16→INT32 RESCALE pairs.
- Introduces `EliminateRescaleBeforeMulPass` to remove redundant INT32→INT32 RESCALEs feeding exclusively into `aten.mul`, compensating by adjusting downstream INT32 RESCALE scales.
- Wires both passes into the default TOSA INT pipeline and adds new pass/model tests (currently focused on the consecutive-rescale fusion).
Reviewed changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| backends/arm/_passes/fuse_consecutive_rescales_pass.py | New pass to fuse consecutive RESCALE pairs (round-trip removal / composition). |
| backends/arm/_passes/eliminate_rescale_before_mul_pass.py | New pass to eliminate redundant pre-MUL INT32→INT32 RESCALEs and adjust downstream scales. |
| backends/arm/_passes/arm_pass_manager.py | Inserts the new passes into the TOSA INT lowering pipeline. |
| backends/arm/_passes/__init__.py | Exposes the new passes from the _passes package. |
| backends/arm/test/passes/test_rescale_optimization.py | Adds tests validating consecutive RESCALE patterns and fusion behavior. |
| backends/arm/test/passes/test_fuse_quantized_activation_pass.py | Adds a focused test for quantized activation fusion (ReLU-after-conv). |
| backends/arm/test/models/test_residual_conv_block.py | Adds an end-to-end residual-block style model test for ARM backend pipelines. |
```python
# Adjust the downstream output RESCALE scale for each MUL user
for mul_user in list(node.users):
    for mul_output_user in list(mul_user.users):
        old_scale = float(mul_output_user.args[2][0])
        new_scale = old_scale * removed_scale
        args = list(mul_output_user.args)
        args[2] = [new_scale]
        mul_output_user.args = tuple(args)
```
Elimination scale compensation is computed per MUL user of the RESCALE, but Node.users does not capture how many times the RESCALE output is consumed by that MUL. If the same RESCALE feeds both MUL operands (e.g. mul(r, r) due to shared operand), removing it changes the MUL output by removed_scale^2, yet this code only multiplies downstream RESCALE scale by removed_scale once. Adjust the compensation factor based on the number of occurrences of the RESCALE node in each MUL's inputs (typically 1 or 2 for mul).
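The fix suggested in this comment can be sketched by counting occurrences of the RESCALE node among the MUL's operands. The helper name and the plain-object stand-ins are illustrative, not the real torch.fx graph API.

```python
# Hypothetical sketch: the compensation must account for the RESCALE node
# appearing multiple times among a MUL's operands, e.g. mul(r, r).
def compensation_factor(removed_scale, mul_args, rescale_node):
    n = sum(1 for a in mul_args if a is rescale_node)
    return removed_scale ** n

r, x = object(), object()    # stand-ins for fx graph nodes
assert compensation_factor(0.5, (r, x), r) == 0.5    # mul(r, x): applied once
assert compensation_factor(0.5, (r, r), r) == 0.25   # mul(r, r): squared
```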
```python
removed_scale = float(node.args[2][0])

# Adjust the downstream output RESCALE scale for each MUL user
for mul_user in list(node.users):
    for mul_output_user in list(mul_user.users):
        old_scale = float(mul_output_user.args[2][0])
        new_scale = old_scale * removed_scale
        args = list(mul_output_user.args)
        args[2] = [new_scale]
        mul_output_user.args = tuple(args)
```
This pass assumes the RESCALE scales argument is a single-element list (`args[2][0]`), for both the removed RESCALE and the downstream compensation RESCALEs. The backend also creates per-channel RESCALEs (multi-element scales lists); if such a node ever matches the structural guards here, this logic would silently adjust only the first channel. Add an explicit `len(scales) == 1` guard (or handle vector scales correctly) before reading/writing `args[2][0]`.
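The suggested guard might look like the following sketch; the helper name is hypothetical.

```python
# Illustrative guard: only per-tensor (single-element) scale lists are
# safe to read/write via args[2][0]; per-channel RESCALEs should be skipped.
def per_tensor_scale(scales):
    """Return the single per-tensor scale, or None for per-channel
    (multi-element) scale lists that this pass must not touch."""
    if len(scales) != 1:
        return None
    return float(scales[0])

assert per_tensor_scale([0.5]) == 0.5
assert per_tensor_scale([0.5, 0.25]) is None  # per-channel: skip this node
```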
```python
InsertRescaleInt32Pass(),
FuseConsecutiveRescalesPass(),
EliminateRescaleBeforeMulPass(),
```
EliminateRescaleBeforeMulPass is added to the default TOSA INT pass pipeline, but there are no unit tests exercising its behavior (e.g., verifying the redundant INT32->INT32 RESCALE is removed and the downstream INT32 RESCALE scale is updated, including the shared-operand mul(x, x) case). Add a focused PassPipeline test similar to test_rescale_optimization.py to prevent regressions.