Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions backends/arm/_passes/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright 2025-2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
Expand Down Expand Up @@ -95,7 +95,11 @@
FoldAndAnnotateQParamsPass,
QuantizeClampArgumentsPass,
)
from .eliminate_rescale_before_mul_pass import ( # noqa
EliminateRescaleBeforeMulPass,
)
from .fuse_batch_norm2d_pass import FuseBatchNorm2dPass # noqa
from .fuse_consecutive_rescales_pass import FuseConsecutiveRescalesPass # noqa
from .fuse_constant_ops_pass import ( # noqa
ComputeConstantOpsAOTPass,
FuseConstantArgsPass,
Expand Down
4 changes: 4 additions & 0 deletions backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,10 @@
DecomposeUnfoldToGatherPass,
DecomposeVarPass,
DecorateFp32toInt32CastingPass,
EliminateRescaleBeforeMulPass,
FoldAndAnnotateQParamsPass,
FuseBatchNorm2dPass,
FuseConsecutiveRescalesPass,
FuseConstantArgsPass,
FuseDuplicateUsersPass,
FuseEqualPlaceholdersPass,
Expand Down Expand Up @@ -264,6 +266,8 @@ def _tosa_pipeline(
# Ticket: MLETORCH-1539
DecomposeLinearPass(),
InsertRescaleInt32Pass(),
FuseConsecutiveRescalesPass(),
EliminateRescaleBeforeMulPass(),
Comment on lines 268 to +270
Copy link

Copilot AI Mar 8, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

EliminateRescaleBeforeMulPass is added to the default TOSA INT pass pipeline, but there are no unit tests exercising its behavior (e.g., verifying the redundant INT32->INT32 RESCALE is removed and the downstream INT32 RESCALE scale is updated, including the shared-operand mul(x, x) case). Add a focused PassPipeline test similar to test_rescale_optimization.py to prevent regressions.

Copilot uses AI. Check for mistakes.
InsertControlFlowRescalesPass(),
DecomposeQuantNodesPass(),
]
Expand Down
162 changes: 162 additions & 0 deletions backends/arm/_passes/eliminate_rescale_before_mul_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import cast, Set, Type

import torch
from executorch.backends.arm._passes.arm_pass import ArmPass
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass
from torch.fx import GraphModule, Node
from torch.fx.passes.infra.pass_base import PassResult


class EliminateRescaleBeforeMulPass(ArmPass):
    """Eliminate redundant INT32->INT32 RESCALE ops feeding exclusively into MUL.

    After InsertRescaleInt32Pass and FuseConsecutiveRescalesPass, the graph may
    contain INT32->INT32 RESCALE nodes between consecutive elementwise ops.
    When such a RESCALE feeds exclusively into MUL ops, it is computationally
    redundant and can be removed with a compensating scale adjustment on the
    downstream output RESCALE.

    Why only MUL (not ADD/SUB):
        For ADD/SUB, InsertRescaleInt32Pass rescales both inputs to a common
        scale (2 * max(lhs, rhs) / (1 << shift_bits)) to ensure correct
        integer arithmetic — the input RESCALE is required for operand
        alignment. For MUL, input scales remain unchanged because the output
        scale is the product of input scales (S_out = S_0 * S_1), regardless
        of what the input scales are. A RESCALE adjusting scale before MUL is
        therefore mathematically redundant: the adjustment can be absorbed
        into the downstream output RESCALE as
        new_out_scale = old_out_scale * removed_scale.
        See InsertRescaleInt32Pass._get_inputs_rescaled_qparams() for the
        scale arithmetic distinction.

    Why not Conv2D/MatMul boundaries:
        Empirically, eliminating RESCALE ops at Conv2D/MatMul boundaries
        causes the Vela NPU compiler to generate worse instruction schedules.
        The INT32->INT8->INT32 round-trips at those boundaries provide natural
        scheduling breaks that help Vela's register allocator. Removing them
        caused +12.9% (CC) and +16.1% (Detector) cycle regressions.

    When multiple eligible RESCALEs feed the same MUL (e.g., both inputs have
    INT32->INT32 RESCALEs), each is eliminated sequentially. The downstream
    scale adjustments compose correctly because MUL's output scale is
    multiplicative: removing RESCALE_A (scale S_a) then RESCALE_B (scale S_b)
    yields new_out_scale = old_out_scale * S_a * S_b, which is correct.

    If the same RESCALE feeds both operands of one MUL (mul(r, r)), removing
    it changes the MUL output by removed_scale ** 2, so the compensation
    factor is raised to the number of operand occurrences.

    Only per-tensor RESCALEs (single-element scale lists) are handled; a
    per-channel RESCALE is skipped rather than partially adjusted.
    """

    _passes_required_after: Set[Type[ExportPass]] = set()

    def call(self, graph_module: GraphModule) -> PassResult:
        """Remove eligible INT32->INT32 RESCALEs and fold their scales
        into the downstream output RESCALEs of their MUL consumers."""
        graph = graph_module.graph
        modified = False
        nodes_to_erase = []

        for node in list(graph.nodes):
            node = cast(Node, node)
            if not self._is_eliminable_rescale(node):
                continue

            rescale_input = node.args[0]
            removed_scale = float(node.args[2][0])

            # Adjust the downstream output RESCALE scale for each MUL user.
            # If this RESCALE appears in both operand slots of the same MUL
            # (shared operand, e.g. mul(r, r)), its removal changes the MUL
            # output by removed_scale ** 2 — compensate once per occurrence.
            for mul_user in list(node.users):
                occurrences = sum(1 for arg in mul_user.args if arg is node)
                factor = removed_scale**occurrences
                for mul_output_user in list(mul_user.users):
                    old_scale = float(mul_output_user.args[2][0])
                    args = list(mul_output_user.args)
                    args[2] = [old_scale * factor]
                    mul_output_user.args = tuple(args)

            # Replace the RESCALE with its input
            node.replace_all_uses_with(rescale_input)
            nodes_to_erase.append(node)
            modified = True

        for n in nodes_to_erase:
            if len(n.users) == 0:
                graph.erase_node(n)

        if modified:
            graph_module = super().call(graph_module).graph_module
            graph_module.recompile()

        return PassResult(graph_module, modified)

    def _is_eliminable_rescale(self, node: Node) -> bool:
        """Return True if ``node`` is a RESCALE satisfying every structural
        guard required for safe elimination.

        Guards:
        - TOSA RESCALE producing INT32 with zero-points of 0 (INT32->INT32
          rescales from InsertRescaleInt32Pass always have zp=0).
        - Per-tensor scale (single-element scale list); per-channel scales
          cannot be compensated by the scalar arithmetic in ``call``.
        - Every user is a MUL op, and every user of each such MUL is itself
          a per-tensor RESCALE producing INT32. Without the RESCALE-user
          guard, non-RESCALE consumers of MUL would receive incorrectly
          scaled values. The INT32-output guard keeps us inside the INT32
          computation region: an INT8/INT16-producing RESCALE defines a
          quantization boundary where the annotated scale must match the
          actual integer values — modifying it would break TABLE ops (exp,
          log, sigmoid, etc.) that build lookup tables from the quantization
          annotation, and would also affect Conv/MatMul boundaries where
          Vela relies on precise scaling.
        - The RESCALE's input already produces INT32 (either another RESCALE
          with INT32 output, or an elementwise op wrapped by
          InsertRescaleInt32Pass).
        """
        if not _is_rescale(node):
            return False

        # Must be INT32 output
        if node.args[1] != torch.int32:
            return False

        # Must have zero points of 0
        if node.args[3] != 0 or node.args[4] != 0:
            return False

        # Per-tensor only: a multi-element (per-channel) scale list cannot
        # be folded with scalar compensation arithmetic.
        if len(node.args[2]) != 1:
            return False

        # All users must be MUL ops
        if len(node.users) == 0:
            return False
        if not all(
            u.op == "call_function"
            and u.target == exir_ops.edge.aten.mul.Tensor
            for u in node.users
        ):
            return False

        # All downstream users of each MUL must be per-tensor RESCALEs
        # producing INT32, so the removed scale can be compensated there.
        for mul_out in node.users:
            if not mul_out.users:
                return False
            for mul_output_user in mul_out.users:
                if not _is_rescale(mul_output_user):
                    return False
                if mul_output_user.args[1] != torch.int32:
                    return False
                if len(mul_output_user.args[2]) != 1:
                    return False

        # The preceding node must also produce INT32.
        return _produces_int32(node.args[0])


def _is_rescale(node: Node) -> bool:
    """Return True if ``node`` is a call to the TOSA RESCALE backend op."""
    if node.op != "call_function":
        return False
    return node.target == exir_ops.backend.tosa.RESCALE.default


def _produces_int32(node: Node) -> bool:
    """Check if a node produces INT32 output."""
    if not isinstance(node, Node):
        # Non-Node args (e.g. constants) never count as INT32 producers.
        return False
    # A RESCALE encodes its output dtype as the second positional arg.
    if _is_rescale(node):
        return node.args[1] == torch.int32
    # Otherwise fall back to the fake-tensor metadata, when present.
    val = node.meta.get("val")
    if isinstance(val, torch.Tensor):
        return val.dtype == torch.int32
    return getattr(val, "dtype", None) == torch.int32
124 changes: 124 additions & 0 deletions backends/arm/_passes/fuse_consecutive_rescales_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import cast, Set, Type

import torch
from executorch.backends.arm._passes.arm_pass import ArmPass
from executorch.backends.arm._passes.arm_pass_utils import create_node
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass
from torch.fx import GraphModule, Node
from torch.fx.passes.infra.pass_base import PassResult


class FuseConsecutiveRescalesPass(ArmPass):
    """Fuse consecutive RESCALE(INT32->INT8/INT16) -> RESCALE(INT8/INT16->INT32)
    pairs.

    InsertRescaleInt32Pass wraps each add/mul/sub with input rescales
    (INT8/INT16->INT32) and an output rescale (INT32->INT8/INT16). When
    two such ops are chained (e.g., add1 -> add2), the output rescale
    of add1 feeds directly into an input rescale of add2, creating a
    redundant INT32->INT8/INT16->INT32 round-trip that loses precision.

    This pass detects such pairs and either:
    - Removes both if the composed scale is ~1.0 and zero points match
    - Replaces both with a single INT32->INT32 RESCALE with composed
      scale

    Handles multi-user R1 nodes: when R1 feeds both RESCALE and
    non-RESCALE users, each R1->R2 RESCALE pair is fused individually
    while preserving R1 for its non-RESCALE users.

    Only per-tensor RESCALEs (single-element scale lists) are fused;
    per-channel pairs are left untouched rather than composed on the
    first channel only.
    """

    _passes_required_after: Set[Type[ExportPass]] = set()

    def call(self, graph_module: GraphModule) -> PassResult:
        """Scan for fusible RESCALE pairs and rewrite them in place."""
        graph = graph_module.graph
        modified = False
        nodes_to_erase = []

        for node in list(graph.nodes):
            node = cast(Node, node)
            if not _is_rescale(node):
                continue

            # R1 = node: output rescale (INT32 -> INT8/INT16)
            r1_output_dtype = node.args[1]
            if r1_output_dtype not in (torch.int8, torch.int16):
                continue

            # Per-tensor only: composing per-channel (multi-element) scale
            # lists element-wise is not handled by this pass.
            if len(node.args[2]) != 1:
                continue

            r1_input = node.args[0]
            r1_input_zp = node.args[3]
            r1_output_zp = node.args[4]
            r1_scale = float(node.args[2][0])

            # Check each user individually (handles multi-user R1)
            for user in list(node.users):
                if not _is_rescale(user):
                    continue

                # R2 = user: input rescale (INT8/INT16 -> INT32)
                r2_output_dtype = user.args[1]
                if r2_output_dtype != torch.int32:
                    continue

                # Per-tensor only (see guard on R1 above).
                if len(user.args[2]) != 1:
                    continue

                r2_input_zp = user.args[3]

                # Guard: intermediate zero points must match for correct
                # composition. Without this, the offset term
                # (r1_output_zp - r2_input_zp) * r2_scale is silently lost.
                if r1_output_zp != r2_input_zp:
                    continue

                r2_scale = float(user.args[2][0])
                composed_scale = r1_scale * r2_scale
                r2_output_zp = user.args[4]

                if abs(composed_scale - 1.0) < 1e-6 and r1_input_zp == r2_output_zp:
                    # Identity: wire R1's input directly to R2's users
                    user.replace_all_uses_with(r1_input)
                    nodes_to_erase.append(user)
                else:
                    # Non-identity: replace with single INT32->INT32 RESCALE
                    with graph.inserting_before(user):
                        composed_node = create_node(
                            graph,
                            exir_ops.backend.tosa.RESCALE.default,
                            (
                                r1_input,
                                r2_output_dtype,
                                [composed_scale],
                                r1_input_zp,
                                r2_output_zp,
                            ),
                            from_node=user,
                        )
                    user.replace_all_uses_with(composed_node)
                    nodes_to_erase.append(user)

                modified = True

            # Always consider R1 for removal; actual erasure is guarded below
            nodes_to_erase.append(node)

        for node in nodes_to_erase:
            if len(node.users) == 0:
                graph.erase_node(node)

        if modified:
            graph_module = super().call(graph_module).graph_module
            graph_module.recompile()

        return PassResult(graph_module, modified)


def _is_rescale(node: Node) -> bool:
    """True when ``node`` invokes the TOSA RESCALE backend operator."""
    is_call = node.op == "call_function"
    return is_call and node.target == exir_ops.backend.tosa.RESCALE.default
Loading
Loading