diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py index f11a91a600c..b021a819b22 100644 --- a/backends/arm/_passes/__init__.py +++ b/backends/arm/_passes/__init__.py @@ -95,7 +95,11 @@ FoldAndAnnotateQParamsPass, QuantizeClampArgumentsPass, ) +from .eliminate_rescale_before_mul_pass import ( # noqa + EliminateRescaleBeforeMulPass, +) from .fuse_batch_norm2d_pass import FuseBatchNorm2dPass # noqa +from .fuse_consecutive_rescales_pass import FuseConsecutiveRescalesPass # noqa from .fuse_constant_ops_pass import ( # noqa ComputeConstantOpsAOTPass, FuseConstantArgsPass, diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index 9cedc6851c8..9f724c49015 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -91,8 +91,10 @@ DecomposeUnfoldToGatherPass, DecomposeVarPass, DecorateFp32toInt32CastingPass, + EliminateRescaleBeforeMulPass, FoldAndAnnotateQParamsPass, FuseBatchNorm2dPass, + FuseConsecutiveRescalesPass, FuseConstantArgsPass, FuseDuplicateUsersPass, FuseEqualPlaceholdersPass, @@ -264,6 +266,8 @@ def _tosa_pipeline( # Ticket: MLETORCH-1539 DecomposeLinearPass(), InsertRescaleInt32Pass(), + FuseConsecutiveRescalesPass(), + EliminateRescaleBeforeMulPass(), InsertControlFlowRescalesPass(), DecomposeQuantNodesPass(), ] diff --git a/backends/arm/_passes/eliminate_rescale_before_mul_pass.py b/backends/arm/_passes/eliminate_rescale_before_mul_pass.py new file mode 100644 index 00000000000..8ee76080302 --- /dev/null +++ b/backends/arm/_passes/eliminate_rescale_before_mul_pass.py @@ -0,0 +1,162 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
from typing import cast, Set, Type

import torch
from executorch.backends.arm._passes.arm_pass import ArmPass
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass
from torch.fx import GraphModule, Node
from torch.fx.passes.infra.pass_base import PassResult


class EliminateRescaleBeforeMulPass(ArmPass):
    """Remove redundant INT32->INT32 RESCALE ops whose only consumers are MULs.

    After InsertRescaleInt32Pass and FuseConsecutiveRescalesPass the graph can
    contain INT32->INT32 RESCALE nodes between consecutive elementwise ops.
    A RESCALE feeding exclusively into MULs is mathematically redundant: MUL's
    output scale is the product of its input scales (S_out = S_0 * S_1)
    regardless of what the input scales are, so the eliminated scale factor
    can be folded into each downstream output RESCALE instead
    (new_out_scale = old_out_scale * removed_scale).

    The same is NOT true for ADD/SUB: there InsertRescaleInt32Pass rescales
    both operands to a common scale (2 * max(lhs, rhs) / (1 << shift_bits))
    and the input RESCALE is required for operand alignment.  See
    InsertRescaleInt32Pass._get_inputs_rescaled_qparams() for the scale
    arithmetic distinction.

    RESCALEs at Conv2D/MatMul boundaries are deliberately left alone:
    eliminating them was measured to regress Vela NPU instruction schedules
    (+12.9% CC, +16.1% Detector cycles) — the INT32->INT8->INT32 round-trips
    there act as natural scheduling breaks for Vela's register allocator.

    When several eligible RESCALEs feed the same MUL (e.g. both inputs carry
    INT32->INT32 RESCALEs), each is eliminated in turn; the downstream scale
    adjustments compose multiplicatively
    (new_out_scale = old_out_scale * S_a * S_b), which is exactly the correct
    compensation.
    """

    _passes_required_after: Set[Type[ExportPass]] = set()

    def call(self, graph_module: GraphModule) -> PassResult:
        """Fold every removable RESCALE's scale downstream and drop the node."""
        graph = graph_module.graph
        changed = False
        dead = []

        for candidate in list(graph.nodes):
            candidate = cast(Node, candidate)
            if not _is_rescale(candidate):
                continue
            if not self._is_removable(candidate):
                continue

            folded_scale = float(candidate.args[2][0])

            # Compensate: the output RESCALE of every MUL consumer absorbs
            # the scale of the RESCALE being removed.
            for mul_consumer in list(candidate.users):
                for out_rescale in list(mul_consumer.users):
                    new_args = list(out_rescale.args)
                    new_args[2] = [float(out_rescale.args[2][0]) * folded_scale]
                    out_rescale.args = tuple(new_args)

            # Bypass the RESCALE entirely; it is erased below once unused.
            candidate.replace_all_uses_with(candidate.args[0])
            dead.append(candidate)
            changed = True

        for stale in dead:
            if len(stale.users) == 0:
                graph.erase_node(stale)

        if changed:
            graph_module = super().call(graph_module).graph_module
            graph_module.recompile()

        return PassResult(graph_module, changed)

    @staticmethod
    def _is_removable(rescale: Node) -> bool:
        """Return True when ``rescale`` can be safely absorbed downstream.

        Eligibility: INT32->INT32 with both zero points 0 (the shape
        InsertRescaleInt32Pass produces), at least one user, every user a
        MUL whose outputs all feed INT32-producing RESCALEs, and a producer
        that already emits INT32.
        """
        if rescale.args[1] != torch.int32:
            return False
        # zp=0 on both sides is what InsertRescaleInt32Pass emits for
        # INT32->INT32 rescales; anything else is out of scope.
        if rescale.args[3] != 0 or rescale.args[4] != 0:
            return False
        if len(rescale.users) == 0:
            return False
        for consumer in rescale.users:
            # Every consumer must be a MUL (the only op for which the input
            # scale adjustment is redundant — see class docstring).
            if (
                consumer.op != "call_function"
                or consumer.target != exir_ops.edge.aten.mul.Tensor
            ):
                return False
            # Every MUL output must feed RESCALEs (so compensation has a
            # place to go) that stay within the INT32 computation region.
            # An INT8/INT16 output would be a quantization boundary whose
            # annotated scale must match the actual integer values (TABLE
            # ops, Conv/MatMul boundaries) — modifying it would be unsound.
            if len(consumer.users) == 0:
                return False
            for post in consumer.users:
                if not _is_rescale(post) or post.args[1] != torch.int32:
                    return False
        # The preceding node must already produce INT32 values.
        return _produces_int32(rescale.args[0])


def _is_rescale(node: Node) -> bool:
    """True for TOSA backend RESCALE call_function nodes."""
    return (
        node.op == "call_function"
        and node.target == exir_ops.backend.tosa.RESCALE.default
    )


def _produces_int32(node: Node) -> bool:
    """Best-effort check that ``node`` produces an INT32 tensor."""
    if not isinstance(node, Node):
        return False
    if _is_rescale(node):
        # A RESCALE carries its output dtype explicitly in the args.
        return node.args[1] == torch.int32
    # Otherwise fall back on the (fake) tensor metadata, when present.
    val = node.meta.get("val")
    if isinstance(val, torch.Tensor) and val.dtype == torch.int32:
        return True
    return hasattr(val, "dtype") and val.dtype == torch.int32
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import cast, Set, Type

import torch
from executorch.backends.arm._passes.arm_pass import ArmPass
from executorch.backends.arm._passes.arm_pass_utils import create_node
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass
from torch.fx import GraphModule, Node
from torch.fx.passes.infra.pass_base import PassResult


class FuseConsecutiveRescalesPass(ArmPass):
    """Fuse RESCALE(INT32->INT8/INT16) -> RESCALE(INT8/INT16->INT32) pairs.

    InsertRescaleInt32Pass wraps each add/mul/sub with input rescales
    (INT8/INT16->INT32) and an output rescale (INT32->INT8/INT16).  Chaining
    two such ops (e.g. add1 -> add2) therefore puts the output rescale of the
    first directly in front of an input rescale of the second — a redundant,
    precision-losing INT32->INT8/INT16->INT32 round trip.

    Each detected pair is either:
      * removed entirely, when the composed scale is ~1.0 and the outer zero
        points match (identity case), or
      * replaced by a single INT32->INT32 RESCALE carrying the composed scale.

    A first rescale with several users is handled per-user: every
    rescale-to-rescale pair is fused individually, and the first rescale
    survives as long as any non-RESCALE user still needs it.
    """

    _passes_required_after: Set[Type[ExportPass]] = set()

    def call(self, graph_module: GraphModule) -> PassResult:
        """Find and fuse consecutive RESCALE pairs across the whole graph."""
        graph = graph_module.graph
        changed = False
        removal_queue = []

        for first in list(graph.nodes):
            first = cast(Node, first)
            if not _is_rescale(first):
                continue

            # `first` must be an output rescale: INT32 -> INT8/INT16.
            if first.args[1] not in (torch.int8, torch.int16):
                continue

            src = first.args[0]
            src_zp = first.args[3]
            mid_zp = first.args[4]
            first_scale = float(first.args[2][0])

            # Visit users one at a time so a multi-user `first` is handled.
            for second in list(first.users):
                if not _is_rescale(second):
                    continue
                # `second` must be an input rescale: INT8/INT16 -> INT32.
                if second.args[1] != torch.int32:
                    continue
                # The zero point on the narrow intermediate dtype must agree
                # on both sides; otherwise composing the scales would
                # silently drop the (mid_zp - second_input_zp) offset term.
                if second.args[3] != mid_zp:
                    continue

                composed = first_scale * float(second.args[2][0])
                dst_zp = second.args[4]

                if abs(composed - 1.0) < 1e-6 and src_zp == dst_zp:
                    # Identity round trip: wire `second`'s users straight to
                    # `first`'s input; both rescales become dead.
                    second.replace_all_uses_with(src)
                else:
                    # Collapse the pair into one INT32->INT32 rescale with
                    # the composed scale.
                    with graph.inserting_before(second):
                        fused = create_node(
                            graph,
                            exir_ops.backend.tosa.RESCALE.default,
                            (src, torch.int32, [composed], src_zp, dst_zp),
                            from_node=second,
                        )
                    second.replace_all_uses_with(fused)

                removal_queue.append(second)
                changed = True

            # `first` is queued unconditionally; the guard below only erases
            # it when fusing has left it without users.
            removal_queue.append(first)

        for stale in removal_queue:
            if len(stale.users) == 0:
                graph.erase_node(stale)

        if changed:
            graph_module = super().call(graph_module).graph_module
            graph_module.recompile()

        return PassResult(graph_module, changed)


def _is_rescale(node: Node) -> bool:
    """True for TOSA backend RESCALE call_function nodes."""
    return (
        node.op == "call_function"
        and node.target == exir_ops.backend.tosa.RESCALE.default
    )
+ +""" + +from typing import Tuple + +import torch +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.test_pipeline import ( + EthosU55PipelineINT, + EthosU85PipelineINT, + TosaPipelineFP, + TosaPipelineINT, + VgfPipeline, +) + + +class ResidualConvBlock(torch.nn.Module): + """Residual conv block with batchnorm and permute operations. + + Architecture: conv->bn->relu->add (residual) -> permute -> + conv->bn->relu->add. When quantized, each residual add is + wrapped with INT32 RESCALEs by InsertRescaleInt32Pass. Stacked + blocks create consecutive RESCALE pairs (INT32->INT8->INT32) + between adjacent adds that FuseConsecutiveRescalesPass + eliminates. + + """ + + def __init__(self): + super().__init__() + self.conv1 = torch.nn.Conv2d(3, 3, 3, padding=1) + self.bn1 = torch.nn.BatchNorm2d(3) + self.relu1 = torch.nn.ReLU() + self.conv2 = torch.nn.Conv2d(3, 3, 3, padding=1) + self.bn2 = torch.nn.BatchNorm2d(3) + self.relu2 = torch.nn.ReLU() + + def forward(self, x): + # Block 1: conv → batchnorm → relu → residual add + out = self.relu1(self.bn1(self.conv1(x))) + out = out + x # residual add 1 + + # Channel reordering (common in signal processing models) + out = out.permute(0, 1, 3, 2) + + # Block 2: conv → batchnorm → relu → residual add + out2 = self.relu2(self.bn2(self.conv2(out))) + out2 = out2 + out # residual add 2 + return out2 + + +model = ResidualConvBlock().eval() +model_inputs = (torch.randn(1, 3, 8, 8),) +input_t = Tuple[torch.Tensor] + + +def test_residual_conv_block_tosa_FP(): + pipeline = TosaPipelineFP[input_t]( + model, + model_inputs, + aten_op=[], + exir_op=[], + use_to_edge_transform_and_lower=True, + ) + pipeline.run() + + +def test_residual_conv_block_tosa_INT(): + pipeline = TosaPipelineINT[input_t]( + model, + model_inputs, + aten_op=[], + exir_op=[], + use_to_edge_transform_and_lower=True, + atol=0.25, + qtol=1, + frobenius_threshold=None, + cosine_threshold=None, + ) + pipeline.run() + + 
+@common.XfailIfNoCorstone300 +def test_residual_conv_block_u55_INT(): + pipeline = EthosU55PipelineINT[input_t]( + model, + model_inputs, + aten_ops=[], + exir_ops=[], + use_to_edge_transform_and_lower=True, + ) + pipeline.run() + + +@common.XfailIfNoCorstone320 +def test_residual_conv_block_u85_INT(): + pipeline = EthosU85PipelineINT[input_t]( + model, + model_inputs, + aten_ops=[], + exir_ops=[], + use_to_edge_transform_and_lower=True, + ) + pipeline.run() + + +@common.SkipIfNoModelConverter +def test_residual_conv_block_vgf_quant(): + pipeline = VgfPipeline[input_t]( + model, + model_inputs, + aten_op=[], + exir_op=[], + use_to_edge_transform_and_lower=True, + quantize=True, + ) + pipeline.run() + + +@common.SkipIfNoModelConverter +def test_residual_conv_block_vgf_no_quant(): + pipeline = VgfPipeline[input_t]( + model, + model_inputs, + aten_op=[], + exir_op=[], + use_to_edge_transform_and_lower=True, + quantize=False, + ) + pipeline.run() diff --git a/backends/arm/test/passes/test_fuse_quantized_activation_pass.py b/backends/arm/test/passes/test_fuse_quantized_activation_pass.py new file mode 100644 index 00000000000..f140549e762 --- /dev/null +++ b/backends/arm/test/passes/test_fuse_quantized_activation_pass.py @@ -0,0 +1,49 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
from typing import Tuple

import torch
from executorch.backends.arm._passes.fuse_quantized_activation_pass import (
    FuseQuantizedActivationPass,
)
from executorch.backends.arm.test.tester.test_pipeline import PassPipeline

input_t = Tuple[torch.Tensor]


class ConvRelu(torch.nn.Module):
    """3x3 Conv2d with a trailing ReLU — the canonical fuseable pattern."""

    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 3, 3, padding=1)
        self.relu = torch.nn.ReLU()

    def get_inputs(self) -> input_t:
        return (torch.randn(1, 3, 8, 8),)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.relu(self.conv(x))


def test_fuse_relu_after_conv_quantized() -> None:
    """A ReLU following a conv is folded away in the quantized graph."""
    module = ConvRelu()
    # Edge-dialect name of aten.relu; present before the pass, gone after.
    relu_edge_op = "executorch_exir_dialects_edge__ops_aten_relu_default"
    pipeline = PassPipeline[input_t](
        module,
        module.get_inputs(),
        quantize=True,
        ops_before_pass={relu_edge_op: 1},
        ops_not_after_pass=[relu_edge_op],
        pass_list=[FuseQuantizedActivationPass],
    )
    pipeline.pop_stage("run_method_and_compare_outputs")
    pipeline.run()
+ +""" + +from typing import Tuple + +import torch +from executorch.backends.arm._passes import ( + FoldAndAnnotateQParamsPass, + FuseConsecutiveRescalesPass, + InsertRescaleInt32Pass, +) +from executorch.backends.arm.test.tester.test_pipeline import PassPipeline + +RESCALE_OP = "executorch_exir_dialects_backend__ops_tosa_RESCALE_default" + + +# ============================================================================ +# Toy Models +# ============================================================================ + + +class AddChain(torch.nn.Module): + """Two cascaded adds: (x + y) + z.""" + + input_t = Tuple[torch.Tensor, torch.Tensor, torch.Tensor] + + def forward(self, x, y, z): + return (x + y) + z + + @staticmethod + def get_test_inputs(): + return ( + torch.randn(1, 3, 8, 8), + torch.randn(1, 3, 8, 8), + torch.randn(1, 3, 8, 8), + ) + + +class BranchingAdd(torch.nn.Module): + """Multi-user R1: (x + y) feeds two downstream adds. + + After InsertRescaleInt32Pass, add1's output RESCALE (R1) feeds + into both add2's and add3's input RESCALEs (R2, R3). The pass + must fuse each R1->R2 and R1->R3 pair individually, removing R1 + only when all RESCALE users are fused away. 
+ + """ + + input_t = Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] + + def forward(self, x, y, z, w): + a = x + y # add1: output RESCALE R1 has two RESCALE users + b = a + z # add2: input RESCALE R2 consumes R1 + c = a + w # add3: input RESCALE R3 consumes R1 + return b + c + + @staticmethod + def get_test_inputs(): + return ( + torch.randn(1, 3, 8, 8), + torch.randn(1, 3, 8, 8), + torch.randn(1, 3, 8, 8), + torch.randn(1, 3, 8, 8), + ) + + +# ============================================================================ +# Assertion Functions (used via pass_functions, not pass_list) +# ============================================================================ + + +def _is_rescale(node): + """Check if a graph node is a TOSA RESCALE op.""" + return node.op == "call_function" and "RESCALE" in str(node.target) + + +def assert_consecutive_rescales_exist(exported_program): + """Assert at least one RESCALE->RESCALE adjacency exists.""" + graph_module = exported_program.graph_module + rescale_info = [] + has_consecutive = False + for node in graph_module.graph.nodes: + if not _is_rescale(node): + continue + user_names = [u.name for u in node.users if _is_rescale(u)] + rescale_info.append(f"{node.name} -> users: {user_names}") + if user_names: + has_consecutive = True + + assert has_consecutive, ( + "RESCALE nodes exist but no consecutive pattern found.\n" + "RESCALE edges:\n" + "\n".join(f" {info}" for info in rescale_info) + ) + return exported_program + + +def assert_no_consecutive_rescales(exported_program): + """Assert no RESCALE->RESCALE adjacency remains after fusion.""" + graph_module = exported_program.graph_module + for node in graph_module.graph.nodes: + if not _is_rescale(node): + continue + rescale_users = [u for u in node.users if _is_rescale(u)] + assert not rescale_users, ( + f"Consecutive RESCALE pair still exists: " + f"{node.name} -> {[u.name for u in rescale_users]}" + ) + return exported_program + + +def 
assert_rescale_count_reduced(exported_program): + """Assert fusion reduced RESCALE count below pre-fusion level. + + Two chained adds produce 6 RESCALEs before fusion. After fusion, each + consecutive pair is either removed (identity) or replaced by a single + composed RESCALE, so the count must be strictly less than 6. + + """ + graph_module = exported_program.graph_module + count = sum(1 for n in graph_module.graph.nodes if _is_rescale(n)) + assert count < 6, f"Expected fewer than 6 RESCALEs after fusion, got {count}" + return exported_program + + +def assert_no_int8_to_int32_via_int8(exported_program): + """Assert no INT32->INT8->INT32 round-trip patterns remain. + + After fusion, no RESCALE outputting INT8 should feed a RESCALE outputting + INT32 (the specific pattern this pass eliminates). + + """ + graph_module = exported_program.graph_module + for node in graph_module.graph.nodes: + if not _is_rescale(node): + continue + if node.args[1] not in (torch.int8, torch.int16): + continue + for user in node.users: + if _is_rescale(user) and user.args[1] == torch.int32: + raise AssertionError( + f"INT8/INT16->INT32 round-trip still exists: " + f"{node.name} (->INT8/INT16) -> {user.name} (->INT32)" + ) + return exported_program + + +def assert_identity_fusion_no_int32_to_int32(exported_program): + """Assert identity fusion removed both RESCALEs. + + When composed scale is ~1.0 and zero points match, the pass takes + the identity path: both R1 and R2 are removed entirely. No + INT32->INT32 RESCALE should exist since a composed node is only + created on the non-identity path. 
+ + """ + graph_module = exported_program.graph_module + for node in graph_module.graph.nodes: + if not _is_rescale(node): + continue + output_dtype = node.args[1] + if output_dtype == torch.int32: + input_node = node.args[0] + if _is_rescale(input_node) and input_node.args[1] == torch.int32: + raise AssertionError( + f"INT32->INT32 composed RESCALE found ({node.name}), " + f"expected identity fusion to remove both nodes" + ) + return exported_program + + +def assert_exact_rescale_count(expected_count): + """Return assertion function that checks exact RESCALE count after + fusion. + """ + + def _assert(exported_program): + graph_module = exported_program.graph_module + count = sum(1 for n in graph_module.graph.nodes if _is_rescale(n)) + assert ( + count == expected_count + ), f"Expected exactly {expected_count} RESCALEs after fusion, got {count}" + return exported_program + + return _assert + + +# ============================================================================ +# Tests +# ============================================================================ + + +def test_add_chain_rescale_count(): + """Two cascaded adds produce expected RESCALEs. + + Each add has 2 INT8 inputs (need INT8->INT32) and 1 INT32 output (need + INT32->INT8), giving 3 RESCALEs per add = 6 total for two chained adds. + + """ + model = AddChain() + pipeline = PassPipeline[AddChain.input_t]( + model, + model.get_test_inputs(), + quantize=True, + ops_not_before_pass={RESCALE_OP}, + ops_after_pass={RESCALE_OP: 6}, + pass_list=[FoldAndAnnotateQParamsPass, InsertRescaleInt32Pass], + ) + pipeline.pop_stage("run_method_and_compare_outputs") + pipeline.run() + + +def test_add_chain_consecutive_rescales(): + """Consecutive RESCALE->RESCALE pattern exists between adds. + + add1's output RESCALE (INT32->INT8) feeds directly into add2's input RESCALE + (INT8->INT32), creating a redundant round-trip. 
+ + """ + model = AddChain() + pipeline = PassPipeline[AddChain.input_t]( + model, + model.get_test_inputs(), + quantize=True, + pass_list=[FoldAndAnnotateQParamsPass, InsertRescaleInt32Pass], + pass_functions=[assert_consecutive_rescales_exist], + ) + pipeline.pop_stage("run_method_and_compare_outputs") + pipeline.run() + + +def test_fuse_consecutive_rescales(): + """FuseConsecutiveRescalesPass eliminates consecutive pairs. + + After InsertRescaleInt32Pass, chained adds produce RESCALEs with consecutive + INT32->INT8->INT32 pairs. FuseConsecutiveRescalesPass merges each pair into + a single composed RESCALE, eliminating all consecutive adjacencies. + + """ + model = AddChain() + pipeline = PassPipeline[AddChain.input_t]( + model, + model.get_test_inputs(), + quantize=True, + pass_list=[ + FoldAndAnnotateQParamsPass, + InsertRescaleInt32Pass, + FuseConsecutiveRescalesPass, + ], + pass_functions=[assert_no_consecutive_rescales], + ) + pipeline.pop_stage("run_method_and_compare_outputs") + pipeline.run() + + +def test_fuse_identity_reduces_rescale_count(): + """Identity fusion removes both RESCALEs rather than composing. + + Two chained adds produce 6 RESCALEs. The consecutive pair between + add1's output and add2's input has composed scale ~1.0 (symmetric + requantize then dequantize), so the pass takes the identity path: + both R1 (INT32->INT8) and R2 (INT8->INT32) are removed entirely. + This is verified by: + 1. RESCALE count drops from 6 to 4 (one pair removed) + 2. No INT32->INT32 RESCALE exists (identity path, not composed) + 3. 
No INT8->INT32 round-trip remains + + """ + model = AddChain() + pipeline = PassPipeline[AddChain.input_t]( + model, + model.get_test_inputs(), + quantize=True, + pass_list=[ + FoldAndAnnotateQParamsPass, + InsertRescaleInt32Pass, + FuseConsecutiveRescalesPass, + ], + pass_functions=[ + assert_rescale_count_reduced, + assert_identity_fusion_no_int32_to_int32, + assert_no_int8_to_int32_via_int8, + ], + ) + pipeline.pop_stage("run_method_and_compare_outputs") + pipeline.run() + + +def test_fuse_branching_add_multi_user(): + """Multi-user R1: add1's output RESCALE feeds two adds. + + BranchingAdd creates (x+y) which feeds both (a+z) and (a+w). + After InsertRescaleInt32Pass, add1's output RESCALE R1 has two + RESCALE users (R2 for add2, R3 for add3). The pass must fuse + each pair individually and only remove R1 when all its RESCALE + users have been fused. + + """ + model = BranchingAdd() + pipeline = PassPipeline[BranchingAdd.input_t]( + model, + model.get_test_inputs(), + quantize=True, + pass_list=[FoldAndAnnotateQParamsPass, InsertRescaleInt32Pass], + pass_functions=[assert_consecutive_rescales_exist], + ) + pipeline.pop_stage("run_method_and_compare_outputs") + pipeline.run() + + +def test_fuse_branching_add_eliminates_pairs(): + """Multi-user fusion eliminates all consecutive pairs. + + After FuseConsecutiveRescalesPass, no RESCALE->RESCALE adjacencies should + remain, even when R1 originally had multiple RESCALE users. + + """ + model = BranchingAdd() + pipeline = PassPipeline[BranchingAdd.input_t]( + model, + model.get_test_inputs(), + quantize=True, + pass_list=[ + FoldAndAnnotateQParamsPass, + InsertRescaleInt32Pass, + FuseConsecutiveRescalesPass, + ], + pass_functions=[ + assert_no_consecutive_rescales, + assert_no_int8_to_int32_via_int8, + ], + ) + pipeline.pop_stage("run_method_and_compare_outputs") + pipeline.run() + + +def test_fuse_consecutive_rescales_output_correctness(): + """End-to-end correctness: fused graph matches original. 
+ + Keeps run_method_and_compare_outputs enabled to verify that the + RESCALE fusion does not change numerical results. + + """ + model = AddChain() + pipeline = PassPipeline[AddChain.input_t]( + model, + model.get_test_inputs(), + quantize=True, + pass_list=[ + FoldAndAnnotateQParamsPass, + InsertRescaleInt32Pass, + FuseConsecutiveRescalesPass, + ], + pass_functions=[assert_no_consecutive_rescales], + ) + pipeline.run() + + +def test_fuse_branching_add_output_correctness(): + """Multi-user end-to-end correctness: fused branching graph + matches original. + + Keeps run_method_and_compare_outputs enabled to verify that + multi-user R1 fusion (where R1 feeds multiple downstream + RESCALEs) does not change numerical results. + + """ + model = BranchingAdd() + pipeline = PassPipeline[BranchingAdd.input_t]( + model, + model.get_test_inputs(), + quantize=True, + pass_list=[ + FoldAndAnnotateQParamsPass, + InsertRescaleInt32Pass, + FuseConsecutiveRescalesPass, + ], + pass_functions=[ + assert_no_consecutive_rescales, + assert_no_int8_to_int32_via_int8, + ], + ) + pipeline.run()