diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index fedca6eb65b..a9e93e61b8f 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -625,7 +625,6 @@ def _tosa_pipeline( DecomposePermuteForU55Pass(), RewriteSlicePass(), InsertConstShapesPass(), - ExirToTosaPass(exported_program), ] ) @@ -634,6 +633,7 @@ def _tosa_pipeline( [ CastInt64BuffersToInt32Pass(exported_program), FuseEqualPlaceholdersPass(exported_program), + ExirToTosaPass(exported_program), SymbolicToTosaShapesPass(), InsertDynamicPaddingPass(), FuseConsecutiveConcatShapesPass(), diff --git a/backends/arm/_passes/aten_to_tosa_tensor_operators.py b/backends/arm/_passes/aten_to_tosa_tensor_operators.py index 140aa87615f..d793628ce45 100644 --- a/backends/arm/_passes/aten_to_tosa_tensor_operators.py +++ b/backends/arm/_passes/aten_to_tosa_tensor_operators.py @@ -24,3 +24,47 @@ def rewrite_argmax(node: Node, pass_: AtenToDialectPass) -> DialectNodeSpec: (input_node, dim), {}, ) + + +def rewrite_binary_operator( + node: Node, pass_: AtenToDialectPass +) -> DialectNodeSpec | None: + match node.target: + case exir_ops.edge.aten.add.Tensor: + target = exir_ops.backend.tosa.ADD.default + case exir_ops.edge.aten.bitwise_and.Tensor: + target = exir_ops.backend.tosa.BITWISE_AND.default + case exir_ops.edge.aten.bitwise_left_shift.Tensor: + target = exir_ops.backend.tosa.LOGICAL_LEFT_SHIFT.default + case exir_ops.edge.aten.bitwise_or.Tensor: + target = exir_ops.backend.tosa.BITWISE_OR.default + case exir_ops.edge.aten.bitwise_right_shift.Tensor: + target = exir_ops.backend.tosa.ARITHMETIC_RIGHT_SHIFT.default + case exir_ops.edge.aten.bitwise_xor.Tensor: + target = exir_ops.backend.tosa.BITWISE_XOR.default + case exir_ops.edge.aten.eq.Tensor: + target = exir_ops.backend.tosa.EQUAL.default + case exir_ops.edge.aten.ge.Tensor: + target = exir_ops.backend.tosa.GREATER_EQUAL.default + case exir_ops.edge.aten.gt.Tensor: + target = exir_ops.backend.tosa.GREATER.default + case exir_ops.edge.aten.logical_and.default: + target = exir_ops.backend.tosa.LOGICAL_AND.default + case exir_ops.edge.aten.logical_or.default: + target = exir_ops.backend.tosa.LOGICAL_OR.default + case exir_ops.edge.aten.logical_xor.default: + target = exir_ops.backend.tosa.LOGICAL_XOR.default + case exir_ops.edge.aten.maximum.default: + target = exir_ops.backend.tosa.MAXIMUM.default + case exir_ops.edge.aten.minimum.default: + target = exir_ops.backend.tosa.MINIMUM.default + case exir_ops.edge.aten.mul.Tensor: + target = exir_ops.backend.tosa.MUL.default + case exir_ops.edge.aten.pow.Tensor_Tensor: + target = exir_ops.backend.tosa.POW.default + case exir_ops.edge.aten.sub.Tensor: + target = exir_ops.backend.tosa.SUB.default + case _: + return None + + return DialectNodeSpec(target, node.args, dict(node.kwargs)) diff --git a/backends/arm/_passes/exir_to_tosa_pass.py b/backends/arm/_passes/exir_to_tosa_pass.py index c0c6efb1a6c..91814541217 100644 --- a/backends/arm/_passes/exir_to_tosa_pass.py +++ b/backends/arm/_passes/exir_to_tosa_pass.py @@ -3,17 +3,24 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from collections.abc import Callable + import executorch.backends.arm.tosa.dialect # noqa: F401 from executorch.backends.arm._passes.aten_to_tosa_activation_functions import ( get_activation_replacement, ) -from executorch.backends.arm._passes.aten_to_tosa_tensor_operators import rewrite_argmax +from executorch.backends.arm._passes.aten_to_tosa_tensor_operators import ( + rewrite_argmax, + rewrite_binary_operator, +) from executorch.backends.transforms.aten_to_dialect_pass import ( AtenToDialectPass, DialectNodeSpec, + SubstitutionFn, ) from executorch.exir.dialects._ops import ops as exir_ops from torch.fx import Node +from torch.fx.node import Target class ExirToTosaPass(AtenToDialectPass): @@ -25,6 +32,17 @@ class ExirToTosaPass(AtenToDialectPass): """ +def register_dialect_substitutions( + *targets: Target, +) -> Callable[[SubstitutionFn], SubstitutionFn]: + def decorator(func: SubstitutionFn) -> SubstitutionFn: + for target in targets: + ExirToTosaPass.register_dialect_substitution(target)(func) + return func + + return decorator + + @ExirToTosaPass.register_dialect_substitution(exir_ops.edge.aten.argmax.default) def _get_tensor_operators_replacement( node: Node, pass_: AtenToDialectPass @@ -32,10 +50,37 @@ def _get_tensor_operators_replacement( return rewrite_argmax(node, pass_) -@ExirToTosaPass.register_dialect_substitution(exir_ops.edge.aten.clamp.default) -@ExirToTosaPass.register_dialect_substitution(exir_ops.edge.aten.erf.default) -@ExirToTosaPass.register_dialect_substitution(exir_ops.edge.aten.sigmoid.default) -@ExirToTosaPass.register_dialect_substitution(exir_ops.edge.aten.tanh.default) +@register_dialect_substitutions( + exir_ops.edge.aten.add.Tensor, + exir_ops.edge.aten.bitwise_and.Tensor, + exir_ops.edge.aten.bitwise_left_shift.Tensor, + exir_ops.edge.aten.bitwise_or.Tensor, + exir_ops.edge.aten.bitwise_right_shift.Tensor, + exir_ops.edge.aten.bitwise_xor.Tensor, + exir_ops.edge.aten.eq.Tensor, + exir_ops.edge.aten.ge.Tensor, + exir_ops.edge.aten.gt.Tensor, + exir_ops.edge.aten.logical_and.default, + exir_ops.edge.aten.logical_or.default, + exir_ops.edge.aten.logical_xor.default, + exir_ops.edge.aten.maximum.default, + exir_ops.edge.aten.minimum.default, + exir_ops.edge.aten.mul.Tensor, + exir_ops.edge.aten.pow.Tensor_Tensor, + exir_ops.edge.aten.sub.Tensor, +) +def _get_binary_operator_replacement( + node: Node, pass_: AtenToDialectPass +) -> DialectNodeSpec | None: + return rewrite_binary_operator(node, pass_) + + +@register_dialect_substitutions( + exir_ops.edge.aten.clamp.default, + exir_ops.edge.aten.erf.default, + exir_ops.edge.aten.sigmoid.default, + exir_ops.edge.aten.tanh.default, +) def _get_activation_replacement( node: Node, pass_: AtenToDialectPass ) -> DialectNodeSpec | None: diff --git a/backends/arm/_passes/promote_bool_operands_pass.py b/backends/arm/_passes/promote_bool_operands_pass.py index 8e162ded1bd..774eb14dead 100644 --- a/backends/arm/_passes/promote_bool_operands_pass.py +++ b/backends/arm/_passes/promote_bool_operands_pass.py @@ -3,9 +3,9 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# The TOSA BITWISE_AND, BITWISE_OR, and BITWISE_XOR don't handle bool inputs. -# When a targeted op receives boolean tensors, we promote them to an integer type before -# invocation and cast the result back to the expected dtype afterwards. +# Some TOSA ops don't handle bool inputs. When a targeted op receives boolean +# tensors, we promote them to an integer type before invocation and cast the +# result back to the expected dtype afterwards. from typing import Set, Type @@ -23,10 +23,9 @@ class PromoteBoolOperandsPass(ArmOpTargetedPass): _passes_required_after: Set[Type[ExportPass]] = set() + # Bool bitwise ops are handled by RewriteBoolBitwiseToLogicalPass. Promoting + # them here would hide the bool dtype and prevent that rewrite. target_ops = { - exir_ops.edge.aten.bitwise_and.Tensor, - exir_ops.edge.aten.bitwise_or.Tensor, - exir_ops.edge.aten.bitwise_xor.Tensor, exir_ops.edge.aten.mul.Tensor, } @@ -41,14 +40,9 @@ def call_operator(self, op, args, kwargs, meta): # select the first non-bool dtype, or None if all bool promoted_dtype = next((dt for dt in original_dtypes if dt != torch.bool), None) - # if we don't have a dtype specified by the op, promote to default choice for the op + # If all operands are bool, promote mul to int32. if promoted_dtype is None: - if op == exir_ops.edge.aten.mul.Tensor: - # mul as int32 - promoted_dtype = torch.int32 - else: - # bitwise ops can be int8 - promoted_dtype = torch.int8 + promoted_dtype = torch.int32 target_dtypes = [] for dt in original_dtypes: diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py index 82a529d62a2..469c9c0fb07 100644 --- a/backends/arm/operator_support/tosa_supported_operators.py +++ b/backends/arm/operator_support/tosa_supported_operators.py @@ -173,6 +173,15 @@ def _is_quantized_constant(node: torch.fx.Node) -> bool: return len(users) > 0 +def _floating_profile_negative_checks( + tosa_spec: TosaSpecification, reporter: WhyNoPartitionReporter +) -> list[OperatorSupportBase]: + checks: list[OperatorSupportBase] = [CheckMixedFloatingInputs(reporter)] + if not tosa_spec.support_integer(): + checks.append(CheckInt32ComparisonInputs(reporter)) + return checks + + def is_quantized(node: torch.fx.Node) -> bool: """Checks if the node is quantized. @@ -341,7 +350,7 @@ def _negative_checks( checks.extend(_wrapped_additional_checks(additional_checks, reporter)) if tosa_spec.support_float(): - checks.append(CheckMixedFloatingInputs(reporter)) + checks.extend(_floating_profile_negative_checks(tosa_spec, reporter)) else: checks.append(CheckArmQuantized(reporter)) checks.append(CheckProperQuantization(reporter)) @@ -995,6 +1004,47 @@ def is_node_supported( return True +class CheckInt32ComparisonInputs(OperatorSupportBase): + """Reject int32 comparisons under the FP profile.""" + + target_ops = { + exir_ops.edge.aten.eq.Tensor, + exir_ops.edge.aten.eq.Scalar, + exir_ops.edge.aten.ge.Tensor, + exir_ops.edge.aten.ge.Scalar, + exir_ops.edge.aten.gt.Tensor, + exir_ops.edge.aten.gt.Scalar, + exir_ops.edge.aten.le.Tensor, + exir_ops.edge.aten.le.Scalar, + exir_ops.edge.aten.lt.Tensor, + exir_ops.edge.aten.lt.Scalar, + } + + def __init__(self, reporter: WhyNoPartitionReporter) -> None: + self.reporter = reporter + super().__init__() + + def is_node_supported( + self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node + ) -> bool: + if node.target not in self.target_ops: + return True + + for input_node in ( + input_node + for input_node in node.all_input_nodes + if input_node.op != "get_attr" + ): + if get_first_fake_tensor(input_node).dtype == torch.int32: + self.reporter.report_reject( + node, + "FP profile does not support int32 comparison inputs.", + ) + return False + + return True + + class RankCheck(OperatorSupportBase): """Reject nodes with rank greater than ``max_rank``.""" diff --git a/backends/arm/operators/__init__.py b/backends/arm/operators/__init__.py index 1acaf4e65ef..1e0a671e1ae 100644 --- a/backends/arm/operators/__init__.py +++ b/backends/arm/operators/__init__.py @@ -12,7 +12,6 @@ from . import ( # noqa node_visitor, op_abs, - op_add, op_amax, op_amin, op_any, @@ -21,30 +20,23 @@ op_ceil, op_cond_if, op_cos, - op_eq, op_exp, op_floor, - op_ge, - op_gt, op_log, op_logical_not, - op_maximum, - op_minimum, - op_mul, op_neg, op_permute, - op_pow, op_reciprocal, op_repeat, - op_rshift_tensor, op_rsqrt, op_sin, - op_sub, op_sum, op_to_dim_order_copy, + op_tosa_add, op_tosa_argmax, op_tosa_avg_pool2d, op_tosa_avg_pool2d_adaptive, + op_tosa_binary_ops, op_tosa_cast_to_block_scaled, op_tosa_clamp, op_tosa_conv2d, @@ -52,25 +44,33 @@ op_tosa_conv3d, op_tosa_custom, op_tosa_depthwise_conv2d, + op_tosa_eq, op_tosa_erf, op_tosa_gather, + op_tosa_ge, + op_tosa_gt, op_tosa_identity, op_tosa_matmul, op_tosa_matmul_t_block_scaled, op_tosa_max_pool2d, op_tosa_max_pool2d_adaptive, + op_tosa_maximum, + op_tosa_minimum, + op_tosa_mul, op_tosa_pad, + op_tosa_pow, op_tosa_rescale, op_tosa_resize, + op_tosa_rshift_tensor, op_tosa_scatter, op_tosa_shapes, op_tosa_sigmoid, op_tosa_slice, + op_tosa_sub, op_tosa_table, op_tosa_tanh, op_tosa_transpose_conv2d, op_view, op_where, op_while, - ops_binary, ) diff --git a/backends/arm/operators/op_add.py b/backends/arm/operators/op_tosa_add.py similarity index 96% rename from backends/arm/operators/op_add.py rename to backends/arm/operators/op_tosa_add.py index 7f95dde94f8..a9cee733874 100644 --- a/backends/arm/operators/op_add.py +++ b/backends/arm/operators/op_tosa_add.py @@ -14,7 +14,7 @@ @register_node_visitor class AddVisitor(SimpleNodeVisitor): - target = "aten.add.Tensor" + target = "tosa.ADD.default" @classmethod def get_config(cls) -> SimpleNodeVisitorConfig: diff --git a/backends/arm/operators/op_tosa_binary_ops.py b/backends/arm/operators/op_tosa_binary_ops.py new file mode 100644 index 00000000000..4485fa514d8 --- /dev/null +++ b/backends/arm/operators/op_tosa_binary_ops.py @@ -0,0 +1,93 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, List + +import tosa_serializer as ts + +from executorch.backends.arm.operators.node_visitor import register_node_visitor +from executorch.backends.arm.operators.simple_node_visitor import ( + SimpleNodeVisitor, + SimpleNodeVisitorConfig, +) + + +def binary_operator_factory( + target: str, + tosa_op, + attr_method: str, + valid_dtypes: List[Any], +): + operator_target = target + + class BinaryOperator(SimpleNodeVisitor): + target = operator_target + + @classmethod + def get_config(cls) -> SimpleNodeVisitorConfig: + return SimpleNodeVisitorConfig( + tosa_op=tosa_op, + attr_method=attr_method, + num_inputs=2, + input_dtypes=valid_dtypes, + ) + + register_node_visitor(BinaryOperator) + + +binary_operator_factory( + "tosa.BITWISE_AND.default", + ts.Op.BITWISE_AND, + "BitwiseAndAttribute", + [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32], +) +binary_operator_factory( + "tosa.BITWISE_OR.default", + ts.Op.BITWISE_OR, + "BitwiseOrAttribute", + [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32], +) +binary_operator_factory( + "tosa.BITWISE_XOR.default", + ts.Op.BITWISE_XOR, + "BitwiseXorAttribute", + [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32], +) +binary_operator_factory( + "tosa.INTDIV.default", + ts.Op.INTDIV, + "IntDivAttribute", + [ts.DType.INT32], +) +binary_operator_factory( + "tosa.LOGICAL_AND.default", + ts.Op.LOGICAL_AND, + "LogicalAndAttribute", + [ts.DType.BOOL], +) +binary_operator_factory( + "tosa.LOGICAL_LEFT_SHIFT.default", + ts.Op.LOGICAL_LEFT_SHIFT, + "LogicalLeftShiftAttribute", + [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32], +) +binary_operator_factory( + "tosa.LOGICAL_OR.default", + ts.Op.LOGICAL_OR, + "LogicalOrAttribute", + [ts.DType.BOOL], +) +binary_operator_factory( + "tosa.LOGICAL_RIGHT_SHIFT.default", + ts.Op.LOGICAL_RIGHT_SHIFT, + "LogicalRightShiftAttribute", + [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32], +) +binary_operator_factory( + "tosa.LOGICAL_XOR.default", + ts.Op.LOGICAL_XOR, + "LogicalXorAttribute", + [ts.DType.BOOL], +) diff --git a/backends/arm/operators/op_eq.py b/backends/arm/operators/op_tosa_eq.py similarity index 96% rename from backends/arm/operators/op_eq.py rename to backends/arm/operators/op_tosa_eq.py index 7d5472d07ed..7b961765942 100644 --- a/backends/arm/operators/op_eq.py +++ b/backends/arm/operators/op_tosa_eq.py @@ -14,7 +14,7 @@ @register_node_visitor class EqualVisitor(SimpleNodeVisitor): - target = "aten.eq.Tensor" + target = "tosa.EQUAL.default" @classmethod def get_config(cls) -> SimpleNodeVisitorConfig: diff --git a/backends/arm/operators/op_ge.py b/backends/arm/operators/op_tosa_ge.py similarity index 95% rename from backends/arm/operators/op_ge.py rename to backends/arm/operators/op_tosa_ge.py index aec8dc96044..0b4e9358870 100644 --- a/backends/arm/operators/op_ge.py +++ b/backends/arm/operators/op_tosa_ge.py @@ -16,7 +16,7 @@ @register_node_visitor class GreaterEqualVisitor(SimpleNodeVisitor): - target = "aten.ge.Tensor" + target = "tosa.GREATER_EQUAL.default" @classmethod def get_config(cls) -> SimpleNodeVisitorConfig: diff --git a/backends/arm/operators/op_gt.py b/backends/arm/operators/op_tosa_gt.py similarity index 96% rename from backends/arm/operators/op_gt.py rename to backends/arm/operators/op_tosa_gt.py index f7b05889d1d..5ffb22879f3 100644 --- a/backends/arm/operators/op_gt.py +++ b/backends/arm/operators/op_tosa_gt.py @@ -16,7 +16,7 @@ @register_node_visitor class GreaterThanVisitor(SimpleNodeVisitor): - target = "aten.gt.Tensor" + target = "tosa.GREATER.default" @classmethod def get_config(cls) -> SimpleNodeVisitorConfig: diff --git a/backends/arm/operators/op_maximum.py b/backends/arm/operators/op_tosa_maximum.py similarity index 96% rename from backends/arm/operators/op_maximum.py rename to backends/arm/operators/op_tosa_maximum.py index de62834fc65..d291c5173b7 100644 --- a/backends/arm/operators/op_maximum.py +++ b/backends/arm/operators/op_tosa_maximum.py @@ -14,7 +14,7 @@ @register_node_visitor class MaximumVisitor(SimpleNodeVisitor): - target = "aten.maximum.default" + target = "tosa.MAXIMUM.default" @classmethod def get_config(cls) -> SimpleNodeVisitorConfig: diff --git a/backends/arm/operators/op_minimum.py b/backends/arm/operators/op_tosa_minimum.py similarity index 96% rename from backends/arm/operators/op_minimum.py rename to backends/arm/operators/op_tosa_minimum.py index fd1348a4405..602b112c624 100644 --- a/backends/arm/operators/op_minimum.py +++ b/backends/arm/operators/op_tosa_minimum.py @@ -14,7 +14,7 @@ @register_node_visitor class MinimumVisitor(SimpleNodeVisitor): - target = "aten.minimum.default" + target = "tosa.MINIMUM.default" @classmethod def get_config(cls) -> SimpleNodeVisitorConfig: diff --git a/backends/arm/operators/op_mul.py b/backends/arm/operators/op_tosa_mul.py similarity index 97% rename from backends/arm/operators/op_mul.py rename to backends/arm/operators/op_tosa_mul.py index e442ce10c1f..65521eab88d 100644 --- a/backends/arm/operators/op_mul.py +++ b/backends/arm/operators/op_tosa_mul.py @@ -3,11 +3,9 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. - from typing import Any, List -import torch - +import torch.fx import tosa_serializer as ts from executorch.backends.arm.operators.node_visitor import ( @@ -24,7 +22,7 @@ @register_node_visitor class MulVisitor(NodeVisitor): - target = "aten.mul.Tensor" + target = "tosa.MUL.default" def define_node( self, diff --git a/backends/arm/operators/op_pow.py b/backends/arm/operators/op_tosa_pow.py similarity index 96% rename from backends/arm/operators/op_pow.py rename to backends/arm/operators/op_tosa_pow.py index 6de49449743..49d47c97f1b 100644 --- a/backends/arm/operators/op_pow.py +++ b/backends/arm/operators/op_tosa_pow.py @@ -17,7 +17,7 @@ @register_node_visitor class PowVisitor(SimpleNodeVisitor): - target = "aten.pow.Tensor_Tensor" + target = "tosa.POW.default" tosa_specs = FP_SPECS @classmethod diff --git a/backends/arm/operators/op_rshift_tensor.py b/backends/arm/operators/op_tosa_rshift_tensor.py similarity index 96% rename from backends/arm/operators/op_rshift_tensor.py rename to backends/arm/operators/op_tosa_rshift_tensor.py index e700c9a6db0..607a74858be 100644 --- a/backends/arm/operators/op_rshift_tensor.py +++ b/backends/arm/operators/op_tosa_rshift_tensor.py @@ -3,11 +3,9 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. - from typing import Any, List -import torch - +import torch.fx import tosa_serializer as ts from executorch.backends.arm.operators.node_visitor import ( @@ -24,7 +22,7 @@ @register_node_visitor class RshiftVisitor(NodeVisitor): - target = "aten.bitwise_right_shift.Tensor" + target = "tosa.ARITHMETIC_RIGHT_SHIFT.default" def define_node( self, @@ -49,7 +47,6 @@ def define_node( # TODO MLETORCH-525 Emulate round == False with different decomposition round = True attr.ArithmeticRightShiftAttribute(round=round) - self._serialize_operator( node, tosa_graph, diff --git a/backends/arm/operators/op_sub.py b/backends/arm/operators/op_tosa_sub.py similarity index 96% rename from backends/arm/operators/op_sub.py rename to backends/arm/operators/op_tosa_sub.py index a9786027e33..5c95919e0e5 100644 --- a/backends/arm/operators/op_sub.py +++ b/backends/arm/operators/op_tosa_sub.py @@ -14,7 +14,7 @@ @register_node_visitor class SubVisitor(SimpleNodeVisitor): - target = "aten.sub.Tensor" + target = "tosa.SUB.default" @classmethod def get_config(cls) -> SimpleNodeVisitorConfig: diff --git a/backends/arm/operators/ops_binary.py b/backends/arm/operators/ops_binary.py deleted file mode 100644 index 37a8fd226ff..00000000000 --- a/backends/arm/operators/ops_binary.py +++ /dev/null @@ -1,113 +0,0 @@ -# Copyright 2025-2026 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - - -from typing import Any, Callable, List - -import torch -import torch.fx - -import tosa_serializer as ts - -from executorch.backends.arm.operators.node_visitor import ( - NodeVisitor, - register_node_visitor, -) -from executorch.backends.arm.operators.operator_validation_utils import ( - validate_num_inputs, - validate_same_dtype, - validate_valid_dtype, -) -from executorch.backends.arm.tosa.mapping import TosaArg - - -def binary_operator_factory( - bw_target: str, tosa_op, attr_builder: Callable[[Any], None] -): - """Creates and registers NodeVisitors for operators that have two inputs and - map directly to a TOSA op. - """ - - class BinaryOperator(NodeVisitor): - target = bw_target - - def define_node( - self, - node: torch.fx.Node, - tosa_graph: Any, - inputs: List[TosaArg], - output: TosaArg, - ) -> None: - validate_num_inputs(self.target, inputs, 2) - validate_same_dtype(self.target, [*inputs, output], ts) - - if self.target in [ - "aten.bitwise_and.Tensor", - "aten.bitwise_xor.Tensor", - "aten.bitwise_or.Tensor", - "aten.bitwise_left_shift.Tensor", - ]: - validate_valid_dtype( - self.target, - [*inputs, output], - [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32], - self.tosa_spec, - ) - if self.target in [ - "aten.logical_and.default", - "aten.logical_xor.defaul", - "aten.logical_or.default", - ]: - validate_valid_dtype( - self.target, - [*inputs, output], - [ts.DType.BOOL], - self.tosa_spec, - ) - attr = ts.TosaSerializerAttribute() - attr_builder(attr) - self._serialize_operator( - node, - tosa_graph, - tosa_op, - [inputs[0].name, inputs[1].name], - [output.name], - attr, - ) - - register_node_visitor(BinaryOperator) - - -binary_operator_factory( - "aten.bitwise_and.Tensor", - ts.Op.BITWISE_AND, - lambda attr: attr.BitwiseAndAttribute(), -) -binary_operator_factory( - "aten.bitwise_xor.Tensor", - ts.Op.BITWISE_XOR, - lambda attr: attr.BitwiseXorAttribute(), -) -binary_operator_factory( - "aten.bitwise_or.Tensor", ts.Op.BITWISE_OR, lambda attr: attr.BitwiseOrAttribute() -) -binary_operator_factory( - "aten.logical_and.default", - ts.Op.LOGICAL_AND, - lambda attr: attr.LogicalAndAttribute(), -) -binary_operator_factory( - "aten.logical_xor.default", - ts.Op.LOGICAL_XOR, - lambda attr: attr.LogicalXorAttribute(), -) -binary_operator_factory( - "aten.logical_or.default", ts.Op.LOGICAL_OR, lambda attr: attr.LogicalOrAttribute() -) -binary_operator_factory( - "aten.bitwise_left_shift.Tensor", - ts.Op.LOGICAL_LEFT_SHIFT, - lambda attr: attr.LogicalLeftShiftAttribute(), -) diff --git a/backends/arm/test/misc/tosa_dialect/test_tosa_dialect_binary_ops.py b/backends/arm/test/misc/tosa_dialect/test_tosa_dialect_binary_ops.py index f886e29c834..920454f5a9b 100644 --- a/backends/arm/test/misc/tosa_dialect/test_tosa_dialect_binary_ops.py +++ b/backends/arm/test/misc/tosa_dialect/test_tosa_dialect_binary_ops.py @@ -295,7 +295,7 @@ def test_bitwise_and_rejects_int64_without_extension() -> None: pytest.param( exir_ops.backend.tosa.ARITHMETIC_RIGHT_SHIFT.default, "TOSA-1.1+FP", - False, + True, id="arithmetic_right_shift_fp", ), pytest.param( @@ -324,7 +324,7 @@ def test_bitwise_and_rejects_int64_without_extension() -> None: ), ], ) -def test_tosa_integer_shift_and_bitwise_ops_registered_for_int_profile_only( +def test_tosa_integer_shift_and_bitwise_ops_profile_registration( op, spec: str, expected: bool, diff --git a/backends/arm/test/ops/test_where.py b/backends/arm/test/ops/test_where.py index 40b0847b838..234203f1cc4 100644 --- a/backends/arm/test/ops/test_where.py +++ b/backends/arm/test/ops/test_where.py @@ -189,8 +189,11 @@ def scalar_condition(input: torch.Tensor): } test_modules_FP_unsupported_dtype = { - "float32_tensor_cond_tuple_dtype": lambda: float32_tensor_cond_tuple_dtype, - "int32_scalar_cond": lambda: int32_scalar_cond, + "float32_tensor_cond_tuple_dtype": lambda: ( + float32_tensor_cond_tuple_dtype, + 1, + ), + "int32_scalar_cond": lambda: (int32_scalar_cond, 0), } test_modules_INT = { @@ -215,11 +218,12 @@ def test_where_self_tosa_FP(test_module): @common.parametrize("test_module", test_modules_FP_unsupported_dtype) def test_where_self_tosa_FP_unsupported_dtype(test_module): + module, n_expected_delegates = test_module() pipeline = OpNotSupportedPipeline[input_t]( - test_module(), - test_module().get_inputs(), + module, + module.get_inputs(), {exir_op: 1}, - n_expected_delegates=1, # condition can be delegated + n_expected_delegates=n_expected_delegates, ) pipeline.run() diff --git a/backends/arm/test/passes/test_convert_int64_const_ops_to_int32.py b/backends/arm/test/passes/test_convert_int64_const_ops_to_int32.py index d39c4527f7e..d46651d2093 100644 --- a/backends/arm/test/passes/test_convert_int64_const_ops_to_int32.py +++ b/backends/arm/test/passes/test_convert_int64_const_ops_to_int32.py @@ -58,7 +58,6 @@ def test_convert_int64_const_ops_to_int32_tosa_FP_arange_default( "torch.ops.aten.view.default", ] exir_ops_checks = [ - "executorch_exir_dialects_edge__ops_aten_lt_Tensor", "executorch_exir_dialects_edge__ops_aten_view_copy_default", ] pipeline = TosaPipelineFP[input_t1]( @@ -119,7 +118,6 @@ def test_convert_int64_const_ops_to_int32_tosa_FP_arange_start( "torch.ops.aten.view.default", ] exir_ops_checks = [ - "executorch_exir_dialects_edge__ops_aten_lt_Tensor", "executorch_exir_dialects_edge__ops_aten_view_copy_default", ] pipeline = TosaPipelineFP[input_t1]( @@ -180,7 +178,6 @@ def test_convert_int64_const_ops_to_int32_tosa_FP_arange_start_step( "torch.ops.aten.view.default", ] exir_ops_checks = [ - "executorch_exir_dialects_edge__ops_aten_lt_Tensor", "executorch_exir_dialects_edge__ops_aten_view_copy_default", ] pipeline = TosaPipelineFP[input_t1]( @@ -342,7 +339,6 @@ def test_convert_int64_const_ops_to_int32_tosa_FP_full( "executorch_exir_dialects_edge__ops_aten_add_Tensor", "executorch_exir_dialects_edge__ops_aten_view_copy_default", "executorch_exir_dialects_edge__ops_aten_mul_Tensor", - "executorch_exir_dialects_edge__ops_aten_lt_Tensor", ] pipeline = TosaPipelineFP[input_t2]( module, diff --git a/backends/arm/test/passes/test_promote_bool_operands_pass.py b/backends/arm/test/passes/test_promote_bool_operands_pass.py index 61be6d43efb..911ed54074b 100644 --- a/backends/arm/test/passes/test_promote_bool_operands_pass.py +++ b/backends/arm/test/passes/test_promote_bool_operands_pass.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -64,7 +64,6 @@ def test_promote_bool_operands_tosa_FP_all_bool(test_data: tensor_pair_t) -> Non } ops_after_pass = { "executorch_exir_dialects_edge__ops_aten_bitwise_and_Tensor": 1, - "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 3, } pipeline = PassPipeline[tensor_pair_t]( module, @@ -76,8 +75,7 @@ def test_promote_bool_operands_tosa_FP_all_bool(test_data: tensor_pair_t) -> Non ) pipeline.run() cast_dtypes = _collect_cast_dtypes(pipeline) - assert cast_dtypes.count(torch.int8) == 2 - assert cast_dtypes.count(torch.bool) == 1 + assert cast_dtypes == [] @common.parametrize("test_data", MixedMulModule.test_data) diff --git a/backends/arm/tosa/dialect/ops/binary_elementwise.py b/backends/arm/tosa/dialect/ops/binary_elementwise.py index 0b62cc49867..1a3f7222419 100644 --- a/backends/arm/tosa/dialect/ops/binary_elementwise.py +++ b/backends/arm/tosa/dialect/ops/binary_elementwise.py @@ -145,7 +145,7 @@ def ADD(input1: torch.Tensor, input2: torch.Tensor) -> torch.Tensor: @register_fake_tosa_op( "ARITHMETIC_RIGHT_SHIFT(Tensor input1, Tensor input2, *, bool round=False) -> Tensor", - INT_SPECS, + TosaSpecification.all_versions_and_profiles(), ) def ARITHMETIC_RIGHT_SHIFT( input1: torch.Tensor, @@ -153,7 +153,7 @@ def ARITHMETIC_RIGHT_SHIFT( *, round: bool = False, ) -> torch.Tensor: - _validate_int_dtype(input1.dtype, "ARITHMETIC_RIGHT_SHIFT") + _validate_any_profile_int_dtype(input1.dtype, "ARITHMETIC_RIGHT_SHIFT") return _binary_meta(input1, input2, "ARITHMETIC_RIGHT_SHIFT")