
Add EliminateRescaleBeforeMulPass to remove redundant pre-MUL rescales#17999

Open
Ninja91 wants to merge 2 commits into pytorch:main from Ninja91:export-D95243636

Conversation


@Ninja91 Ninja91 commented Mar 8, 2026

Summary:
Eliminate redundant INT32->INT32 RESCALE ops that feed exclusively into
elementwise MUL by absorbing their scale factor into the downstream
output RESCALE.

After InsertRescaleInt32Pass and FuseConsecutiveRescalesPass, the graph
may contain residual INT32->INT32 RESCALE nodes between consecutive
elementwise ops. For MUL, input scales do not need alignment -- the
output scale is the product of input scales (S_out = S_0 * S_1). A
RESCALE adjusting scale before MUL is therefore mathematically
redundant: the scale change can be absorbed into the downstream output
RESCALE as new_out_scale = old_out_scale * removed_scale.
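The compensation identity can be checked numerically. The sketch below uses a toy `rescale` helper and made-up integer values (illustrative only, not the exact TOSA RESCALE fixed-point semantics):

```python
# Toy numeric check of the identity behind the pass: a pre-MUL RESCALE
# by factor s_r can be folded into the downstream output RESCALE
# (new_out_scale = old_out_scale * s_r), because MUL is linear in each
# integer operand.

def rescale(q: int, scale: float) -> int:
    """Toy INT32 RESCALE: multiply by a real scale, round to nearest."""
    return round(q * scale)

q_a, q_b = 1200, -345   # INT32 operands (illustrative values)
s_r = 0.5               # scale of the redundant pre-MUL RESCALE
s_out = 0.001           # original downstream output RESCALE scale

# Original graph: RESCALE -> MUL -> RESCALE
before = rescale(rescale(q_a, s_r) * q_b, s_out)

# Optimized graph: MUL -> RESCALE with compensated scale
after = rescale(q_a * q_b, s_out * s_r)

print(before, after)  # -207 -207
```

The two paths agree up to one unit of rounding, which is why the pre-MUL RESCALE is mathematically redundant inside the INT32 region.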

The optimization is restricted to MUL nodes whose downstream RESCALEs
remain in INT32 (staying within the INT32 computation region). RESCALEs
that convert to INT8/INT16 define quantization boundaries where TABLE
ops (exp, log, sigmoid, etc.) build lookup tables from the quantization
annotation -- modifying such RESCALEs would create a mismatch between
actual integer values and annotations, producing incorrect LUT entries.

This optimization does NOT apply to ADD/SUB (which require aligned
input scales for correct integer arithmetic) or Conv2D/MatMul boundaries
(where removing rescales empirically degrades Vela NPU compiler
instruction scheduling).
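The ADD/SUB restriction can be seen with a small toy example (scales and integer values below are hypothetical):

```python
# Why the pass skips ADD/SUB: integer addition only represents the real
# sum when both operands share one scale. MUL has no such requirement.

s_a, s_b = 0.1, 0.5   # mismatched input scales (hypothetical)
q_a, q_b = 30, 4      # real values 3.0 and 2.0

# Adding raw integers across different scales gives a meaningless result:
wrong = (q_a + q_b) * s_a   # ~3.4, not the real sum 5.0

# Align q_b to scale s_a first (what InsertRescaleInt32Pass guarantees):
q_b_aligned = round(q_b * (s_b / s_a))   # 20
right = (q_a + q_b_aligned) * s_a        # 5.0

print(wrong, right)
```

Removing the alignment RESCALE before an ADD/SUB would therefore change the computed value, not just its scale annotation.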

Stacked on D94483331 (FuseConsecutiveRescalesPass).

Differential Revision: D95243636

Ninja91 added 2 commits March 8, 2026 16:02
…rch#17830)

Summary:

TOSA requires INT32 arithmetic for add/sub/mul ops. `InsertRescaleInt32Pass` wraps each such op with input RESCALEs (INT8→INT32) and output RESCALE (INT32→INT8). When two such ops are chained, the output RESCALE of op1 feeds directly into the input RESCALE of op2, creating a redundant INT32→INT8→INT32 round-trip that wastes NPU cycles and loses precision.

`FuseConsecutiveRescalesPass` detects these pairs and either:
- Removes both if the composed scale is ~1.0 (identity)
- Replaces both with a single INT32→INT32 RESCALE with composed scale

Handles multi-user R1 nodes (e.g., residual connections, LayerNorm branching) by fusing each R1→R2 pair individually while preserving R1 for non-RESCALE users.
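The fuse-or-remove decision can be sketched with toy types (the `Rescale` dataclass and `fuse` helper are illustrative stand-ins, not the executorch pass API):

```python
import math
from dataclasses import dataclass

# Minimal sketch of the fusion rule for an R1 (INT32->INT8/INT16)
# RESCALE feeding an R2 (->INT32) RESCALE.

@dataclass
class Rescale:
    scale: float
    out_dtype: str   # "int8", "int16", or "int32"

def fuse(r1: Rescale, r2: Rescale):
    """Return None if the composed scale is ~1.0 (drop both nodes),
    else a single INT32->INT32 RESCALE with the composed scale."""
    composed = r1.scale * r2.scale
    if math.isclose(composed, 1.0, rel_tol=1e-6):
        return None                       # identity round-trip: remove pair
    return Rescale(composed, "int32")     # single replacement RESCALE

print(fuse(Rescale(0.25, "int8"), Rescale(4.0, "int32")))  # None
print(fuse(Rescale(0.25, "int8"), Rescale(8.0, "int32")))  # scale=2.0
```

In the multi-user case, this decision is applied per R1→R2 edge, with R1 kept alive for any remaining non-RESCALE users.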

## Context

Each unnecessary RESCALE is decomposed by Vela into Add+Mul NPU instructions (~1,130 cycles each on Ethos-U55-128). In meta-internal quantized models, RESCALE overhead accounts for 25-50% of total NPU cycles. This pass eliminates consecutive pairs at op boundaries, with multi-user handling catching additional pairs from branching patterns (LayerNorm's sub feeding both mul_square and mul_normalize).

This diff also adds a `ResidualConvBlock` toy model and pass-level unit tests.

Reviewed By: 3l1

Differential Revision: D94483331
Copilot AI review requested due to automatic review settings March 8, 2026 23:02
@Ninja91 Ninja91 requested a review from digantdesai as a code owner March 8, 2026 23:02

pytorch-bot bot commented Mar 8, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/17999

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 Awaiting Approval, 6 New Failures, 1 Cancelled Job

As of commit 9226722 with merge base 122fdef:


This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Mar 8, 2026

meta-codesync bot commented Mar 8, 2026

@Ninja91 has exported this pull request. If you are a Meta employee, you can view the originating Diff in D95243636.


github-actions bot commented Mar 8, 2026

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Copilot AI left a comment

Pull request overview

Adds new ARM TOSA graph-cleanup passes to reduce redundant RESCALE ops introduced by INT32 wrapping, aiming to keep more of the elementwise region in INT32 and reduce unnecessary quantization round-trips.

Changes:

  • Introduces FuseConsecutiveRescalesPass to fuse/remove INT32→INT8/INT16→INT32 consecutive RESCALE pairs.
  • Introduces EliminateRescaleBeforeMulPass to remove redundant INT32→INT32 RESCALEs feeding exclusively into aten.mul, compensating by adjusting downstream INT32 RESCALE scales.
  • Wires both passes into the default TOSA INT pipeline and adds new pass/model tests (currently focused on the consecutive-rescale fusion).

Reviewed changes

Copilot reviewed 7 out of 7 changed files in this pull request and generated 3 comments.

| File | Description |
| --- | --- |
| backends/arm/_passes/fuse_consecutive_rescales_pass.py | New pass to fuse consecutive RESCALE pairs (round-trip removal / composition). |
| backends/arm/_passes/eliminate_rescale_before_mul_pass.py | New pass to eliminate redundant pre-MUL INT32→INT32 RESCALEs and adjust downstream scales. |
| backends/arm/_passes/arm_pass_manager.py | Inserts the new passes into the TOSA INT lowering pipeline. |
| backends/arm/_passes/init.py | Exposes the new passes from the _passes package. |
| backends/arm/test/passes/test_rescale_optimization.py | Adds tests validating consecutive RESCALE patterns and fusion behavior. |
| backends/arm/test/passes/test_fuse_quantized_activation_pass.py | Adds a focused test for quantized activation fusion (ReLU-after-conv). |
| backends/arm/test/models/test_residual_conv_block.py | Adds an end-to-end residual-block style model test for ARM backend pipelines. |


Comment on lines +117 to +125
# Adjust the downstream output RESCALE scale for each MUL user
for mul_user in list(node.users):
    for mul_output_user in list(mul_user.users):
        old_scale = float(mul_output_user.args[2][0])
        new_scale = old_scale * removed_scale
        args = list(mul_output_user.args)
        args[2] = [new_scale]
        mul_output_user.args = tuple(args)

Copilot AI Mar 8, 2026

Elimination scale compensation is computed per MUL user of the RESCALE, but Node.users does not capture how many times the RESCALE output is consumed by that MUL. If the same RESCALE feeds both MUL operands (e.g. mul(r, r) due to shared operand), removing it changes the MUL output by removed_scale^2, yet this code only multiplies downstream RESCALE scale by removed_scale once. Adjust the compensation factor based on the number of occurrences of the RESCALE node in each MUL's inputs (typically 1 or 2 for mul).

Comment on lines +115 to +124
removed_scale = float(node.args[2][0])

# Adjust the downstream output RESCALE scale for each MUL user
for mul_user in list(node.users):
    for mul_output_user in list(mul_user.users):
        old_scale = float(mul_output_user.args[2][0])
        new_scale = old_scale * removed_scale
        args = list(mul_output_user.args)
        args[2] = [new_scale]
        mul_output_user.args = tuple(args)
Copilot AI Mar 8, 2026

This pass assumes RESCALE scales is a single-element list (args[2][0]) for both the removed RESCALE and the downstream compensation RESCALEs. The backend also creates per-channel RESCALEs (multi-element scales lists); if such a node ever matches the structural guards here, this logic would silently adjust only the first channel. Add an explicit len(scales)==1 guard (or handle vector scales correctly) before reading/writing args[2][0].
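A minimal sketch of the suggested guard (the helper name and arg layout are assumptions following the snippet above):

```python
# Guard sketch: only treat a RESCALE as eligible when it carries a
# single per-tensor scale; bail out on per-channel (multi-element)
# scales instead of silently rewriting only the first channel.

def is_per_tensor(rescale_args) -> bool:
    scales = rescale_args[2]
    return isinstance(scales, (list, tuple)) and len(scales) == 1

print(is_per_tensor((None, None, [0.5])))        # True
print(is_per_tensor((None, None, [0.5, 0.25])))  # False (per-channel)
```

Applying this check to both the removed RESCALE and every downstream compensation RESCALE would make the `args[2][0]` reads and writes safe.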

Comment on lines 268 to +270
InsertRescaleInt32Pass(),
FuseConsecutiveRescalesPass(),
EliminateRescaleBeforeMulPass(),
Copilot AI Mar 8, 2026

EliminateRescaleBeforeMulPass is added to the default TOSA INT pass pipeline, but there are no unit tests exercising its behavior (e.g., verifying the redundant INT32->INT32 RESCALE is removed and the downstream INT32 RESCALE scale is updated, including the shared-operand mul(x, x) case). Add a focused PassPipeline test similar to test_rescale_optimization.py to prevent regressions.


Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. fb-exported meta-exported
