From b76503573d81c9e608570d3453af7626c2c80498 Mon Sep 17 00:00:00 2001 From: Oscar Andersson Date: Thu, 30 Apr 2026 14:31:14 +0200 Subject: [PATCH] Arm backend: Support permute-removal for TOSA ops Adds a TOSA-specific variant of RemovePermutesAroundElementwiseOps that ensures elementwise TOSA backend dialect operators are also covered by this pass. Currently this includes TABLE and RESCALE. Signed-off-by: Oscar Andersson Change-Id: Ia834a5b641af33419d1210be95b5ad7566e857b3 --- backends/arm/_passes/__init__.py | 3 + backends/arm/_passes/arm_pass_manager.py | 6 +- ...ve_permutes_around_elementwise_tosa_ops.py | 17 ++++++ ...ve_permutes_around_elementwise_tosa_ops.py | 59 +++++++++++++++++++ 4 files changed, 81 insertions(+), 4 deletions(-) create mode 100644 backends/arm/_passes/remove_permutes_around_elementwise_tosa_ops.py create mode 100644 backends/arm/test/passes/test_remove_permutes_around_elementwise_tosa_ops.py diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py index 3f2cde5adef..d969af1bedf 100644 --- a/backends/arm/_passes/__init__.py +++ b/backends/arm/_passes/__init__.py @@ -143,6 +143,9 @@ from .remove_getitem_pass import RemoveGetItemPass # noqa from .remove_graph_asserts_pass import RemoveGraphAssertsPass # noqa from .remove_noop_pass import RemoveNoopPass # noqa +from .remove_permutes_around_elementwise_tosa_ops import ( # noqa + RemovePermutesAroundElementwiseTosaOps, +) from .replace_scalar_with_tensor_pass import ( # noqa ReplaceScalarWithTensorByProfilePass, ) diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index e39d8d605f4..589b056b9fb 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -124,6 +124,7 @@ RemoveGetItemPass, RemoveGraphAssertsPass, RemoveNoopPass, + RemovePermutesAroundElementwiseTosaOps, ReplaceInfAndLimitValuesPass, ReplaceScalarWithTensorByProfilePass, RewriteAvgPool2dPass, @@ -163,9 +164,6
@@ PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView, ) -from executorch.backends.transforms.remove_permutes_around_elementwise_ops import ( - RemovePermutesAroundElementwiseOps, -) from executorch.exir import ExportedProgram from executorch.exir.pass_base import ExportPass from executorch.exir.pass_manager import PassManager @@ -538,7 +536,7 @@ def _tosa_pipeline( RewritePadPass(), RewriteSlicePass(), FuseViewCopyTransformPass(), - RemovePermutesAroundElementwiseOps(), + RemovePermutesAroundElementwiseTosaOps(), PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView(), FuseCascadedTransposeOrPermuteOps(), ConvertPermuteSingletonToViewPass(), diff --git a/backends/arm/_passes/remove_permutes_around_elementwise_tosa_ops.py b/backends/arm/_passes/remove_permutes_around_elementwise_tosa_ops.py new file mode 100644 index 00000000000..bc03ebacd81 --- /dev/null +++ b/backends/arm/_passes/remove_permutes_around_elementwise_tosa_ops.py @@ -0,0 +1,17 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from executorch.backends.transforms.remove_permutes_around_elementwise_ops import ( + RemovePermutesAroundElementwiseOps, +) +from executorch.exir.dialects._ops import ops as exir_ops + + +class RemovePermutesAroundElementwiseTosaOps(RemovePermutesAroundElementwiseOps): + permutable_ops = { + *RemovePermutesAroundElementwiseOps.permutable_ops, + exir_ops.backend.tosa.RESCALE.default, + exir_ops.backend.tosa.TABLE.default, + } diff --git a/backends/arm/test/passes/test_remove_permutes_around_elementwise_tosa_ops.py b/backends/arm/test/passes/test_remove_permutes_around_elementwise_tosa_ops.py new file mode 100644 index 00000000000..341d985134e --- /dev/null +++ b/backends/arm/test/passes/test_remove_permutes_around_elementwise_tosa_ops.py @@ -0,0 +1,59 @@ +# Copyright 2026 Arm Limited and/or its affiliates. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from executorch.backends.arm._passes.remove_permutes_around_elementwise_tosa_ops import ( + RemovePermutesAroundElementwiseTosaOps, +) +from executorch.backends.arm.tosa.specification import ( + TosaLoweringContext, + TosaSpecification, +) +from executorch.exir.dialects._ops import ops as exir_ops + +TOSA_INT_SPEC = TosaSpecification.create_from_string("TOSA-1.0+INT") +PERMUTE_TARGET = exir_ops.edge.aten.permute_copy.default +RESCALE_TARGET = exir_ops.backend.tosa.RESCALE.default +TABLE_TARGET = exir_ops.backend.tosa.TABLE.default + + +def _count_nodes(graph_module: torch.fx.GraphModule, target) -> int: + return sum( + 1 + for node in graph_module.graph.nodes + if node.op == "call_function" and node.target == target + ) + + +def test_remove_permutes_around_rescale_tosa_INT() -> None: + graph = torch.fx.Graph() + x = graph.placeholder("x") + x.meta["val"] = torch.randn(1, 3, 4, 5) + + permute_in = graph.create_node( + "call_function", + PERMUTE_TARGET, + args=(x, [0, 2, 3, 1]), + ) + rescale = graph.create_node( + "call_function", + RESCALE_TARGET, + args=(permute_in, torch.int8, [1.0], 0, 0), + ) + permute_out = graph.create_node( + "call_function", + PERMUTE_TARGET, + args=(rescale, [0, 3, 1, 2]), + ) + graph.output(permute_out) + + graph_module = torch.fx.GraphModule({}, graph) + + with TosaLoweringContext(TOSA_INT_SPEC): + result = RemovePermutesAroundElementwiseTosaOps().call(graph_module) + + assert result.modified + assert _count_nodes(result.graph_module, PERMUTE_TARGET) == 0 + assert _count_nodes(result.graph_module, RESCALE_TARGET) == 1