Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,7 +625,6 @@ def _tosa_pipeline(
DecomposePermuteForU55Pass(),
RewriteSlicePass(),
InsertConstShapesPass(),
ExirToTosaPass(exported_program),
]
)

Expand All @@ -634,6 +633,7 @@ def _tosa_pipeline(
[
CastInt64BuffersToInt32Pass(exported_program),
FuseEqualPlaceholdersPass(exported_program),
ExirToTosaPass(exported_program),
SymbolicToTosaShapesPass(),
InsertDynamicPaddingPass(),
FuseConsecutiveConcatShapesPass(),
Expand Down
44 changes: 44 additions & 0 deletions backends/arm/_passes/aten_to_tosa_tensor_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,47 @@ def rewrite_argmax(node: Node, pass_: AtenToDialectPass) -> DialectNodeSpec:
(input_node, dim),
{},
)


def rewrite_binary_operator(
node: Node, pass_: AtenToDialectPass
) -> DialectNodeSpec | None:
match node.target:
case exir_ops.edge.aten.add.Tensor:
target = exir_ops.backend.tosa.ADD.default
case exir_ops.edge.aten.bitwise_and.Tensor:
target = exir_ops.backend.tosa.BITWISE_AND.default
case exir_ops.edge.aten.bitwise_left_shift.Tensor:
target = exir_ops.backend.tosa.LOGICAL_LEFT_SHIFT.default
case exir_ops.edge.aten.bitwise_or.Tensor:
target = exir_ops.backend.tosa.BITWISE_OR.default
case exir_ops.edge.aten.bitwise_right_shift.Tensor:
target = exir_ops.backend.tosa.ARITHMETIC_RIGHT_SHIFT.default
case exir_ops.edge.aten.bitwise_xor.Tensor:
target = exir_ops.backend.tosa.BITWISE_XOR.default
case exir_ops.edge.aten.eq.Tensor:
target = exir_ops.backend.tosa.EQUAL.default
case exir_ops.edge.aten.ge.Tensor:
target = exir_ops.backend.tosa.GREATER_EQUAL.default
case exir_ops.edge.aten.gt.Tensor:
target = exir_ops.backend.tosa.GREATER.default
case exir_ops.edge.aten.logical_and.default:
target = exir_ops.backend.tosa.LOGICAL_AND.default
case exir_ops.edge.aten.logical_or.default:
target = exir_ops.backend.tosa.LOGICAL_OR.default
case exir_ops.edge.aten.logical_xor.default:
target = exir_ops.backend.tosa.LOGICAL_XOR.default
case exir_ops.edge.aten.maximum.default:
target = exir_ops.backend.tosa.MAXIMUM.default
case exir_ops.edge.aten.minimum.default:
target = exir_ops.backend.tosa.MINIMUM.default
case exir_ops.edge.aten.mul.Tensor:
target = exir_ops.backend.tosa.MUL.default
case exir_ops.edge.aten.pow.Tensor_Tensor:
target = exir_ops.backend.tosa.POW.default
case exir_ops.edge.aten.sub.Tensor:
target = exir_ops.backend.tosa.SUB.default
case _:
return None

return DialectNodeSpec(target, node.args, dict(node.kwargs))
55 changes: 50 additions & 5 deletions backends/arm/_passes/exir_to_tosa_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,24 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from collections.abc import Callable

import executorch.backends.arm.tosa.dialect # noqa: F401
from executorch.backends.arm._passes.aten_to_tosa_activation_functions import (
get_activation_replacement,
)
from executorch.backends.arm._passes.aten_to_tosa_tensor_operators import rewrite_argmax
from executorch.backends.arm._passes.aten_to_tosa_tensor_operators import (
rewrite_argmax,
rewrite_binary_operator,
)
from executorch.backends.transforms.aten_to_dialect_pass import (
AtenToDialectPass,
DialectNodeSpec,
SubstitutionFn,
)
from executorch.exir.dialects._ops import ops as exir_ops
from torch.fx import Node
from torch.fx.node import Target


class ExirToTosaPass(AtenToDialectPass):
Expand All @@ -25,17 +32,55 @@ class ExirToTosaPass(AtenToDialectPass):
"""


def register_dialect_substitutions(
*targets: Target,
) -> Callable[[SubstitutionFn], SubstitutionFn]:
def decorator(func: SubstitutionFn) -> SubstitutionFn:
for target in targets:
ExirToTosaPass.register_dialect_substitution(target)(func)
return func

return decorator


@ExirToTosaPass.register_dialect_substitution(exir_ops.edge.aten.argmax.default)
def _get_tensor_operators_replacement(
node: Node, pass_: AtenToDialectPass
) -> DialectNodeSpec:
return rewrite_argmax(node, pass_)


@ExirToTosaPass.register_dialect_substitution(exir_ops.edge.aten.clamp.default)
@ExirToTosaPass.register_dialect_substitution(exir_ops.edge.aten.erf.default)
@ExirToTosaPass.register_dialect_substitution(exir_ops.edge.aten.sigmoid.default)
@ExirToTosaPass.register_dialect_substitution(exir_ops.edge.aten.tanh.default)
@register_dialect_substitutions(
exir_ops.edge.aten.add.Tensor,
exir_ops.edge.aten.bitwise_and.Tensor,
exir_ops.edge.aten.bitwise_left_shift.Tensor,
exir_ops.edge.aten.bitwise_or.Tensor,
exir_ops.edge.aten.bitwise_right_shift.Tensor,
exir_ops.edge.aten.bitwise_xor.Tensor,
exir_ops.edge.aten.eq.Tensor,
exir_ops.edge.aten.ge.Tensor,
exir_ops.edge.aten.gt.Tensor,
exir_ops.edge.aten.logical_and.default,
exir_ops.edge.aten.logical_or.default,
exir_ops.edge.aten.logical_xor.default,
exir_ops.edge.aten.maximum.default,
exir_ops.edge.aten.minimum.default,
exir_ops.edge.aten.mul.Tensor,
exir_ops.edge.aten.pow.Tensor_Tensor,
exir_ops.edge.aten.sub.Tensor,
)
def _get_binary_operator_replacement(
node: Node, pass_: AtenToDialectPass
) -> DialectNodeSpec | None:
return rewrite_binary_operator(node, pass_)


@register_dialect_substitutions(
exir_ops.edge.aten.clamp.default,
exir_ops.edge.aten.erf.default,
exir_ops.edge.aten.sigmoid.default,
exir_ops.edge.aten.tanh.default,
)
def _get_activation_replacement(
node: Node, pass_: AtenToDialectPass
) -> DialectNodeSpec | None:
Expand Down
20 changes: 7 additions & 13 deletions backends/arm/_passes/promote_bool_operands_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# The TOSA BITWISE_AND, BITWISE_OR, and BITWISE_XOR don't handle bool inputs.
# When a targeted op receives boolean tensors, we promote them to an integer type before
# invocation and cast the result back to the expected dtype afterwards.
# Some TOSA ops don't handle bool inputs. When a targeted op receives boolean
# tensors, we promote them to an integer type before invocation and cast the
# result back to the expected dtype afterwards.

from typing import Set, Type

Expand All @@ -23,10 +23,9 @@ class PromoteBoolOperandsPass(ArmOpTargetedPass):

_passes_required_after: Set[Type[ExportPass]] = set()

# Bool bitwise ops are handled by RewriteBoolBitwiseToLogicalPass. Promoting
# them here would hide the bool dtype and prevent that rewrite.
target_ops = {
exir_ops.edge.aten.bitwise_and.Tensor,
exir_ops.edge.aten.bitwise_or.Tensor,
exir_ops.edge.aten.bitwise_xor.Tensor,
exir_ops.edge.aten.mul.Tensor,
}

Expand All @@ -41,14 +40,9 @@ def call_operator(self, op, args, kwargs, meta):
# select the first non-bool dtype, or None if all bool
promoted_dtype = next((dt for dt in original_dtypes if dt != torch.bool), None)

# if we don't have a dtype specified by the op, promote to default choice for the op
# If all operands are bool, promote mul to int32.
if promoted_dtype is None:
if op == exir_ops.edge.aten.mul.Tensor:
# mul as int32
promoted_dtype = torch.int32
else:
# bitwise ops can be int8
promoted_dtype = torch.int8
promoted_dtype = torch.int32

target_dtypes = []
for dt in original_dtypes:
Expand Down
52 changes: 51 additions & 1 deletion backends/arm/operator_support/tosa_supported_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,15 @@ def _is_quantized_constant(node: torch.fx.Node) -> bool:
return len(users) > 0


def _floating_profile_negative_checks(
tosa_spec: TosaSpecification, reporter: WhyNoPartitionReporter
) -> list[OperatorSupportBase]:
checks: list[OperatorSupportBase] = [CheckMixedFloatingInputs(reporter)]
if not tosa_spec.support_integer():
checks.append(CheckInt32ComparisonInputs(reporter))
return checks


def is_quantized(node: torch.fx.Node) -> bool:
"""Checks if the node is quantized.

Expand Down Expand Up @@ -341,7 +350,7 @@ def _negative_checks(
checks.extend(_wrapped_additional_checks(additional_checks, reporter))

if tosa_spec.support_float():
checks.append(CheckMixedFloatingInputs(reporter))
checks.extend(_floating_profile_negative_checks(tosa_spec, reporter))
else:
checks.append(CheckArmQuantized(reporter))
checks.append(CheckProperQuantization(reporter))
Expand Down Expand Up @@ -995,6 +1004,47 @@ def is_node_supported(
return True


class CheckInt32ComparisonInputs(OperatorSupportBase):
"""Reject int32 comparisons under the FP profile."""

target_ops = {
exir_ops.edge.aten.eq.Tensor,
exir_ops.edge.aten.eq.Scalar,
exir_ops.edge.aten.ge.Tensor,
exir_ops.edge.aten.ge.Scalar,
exir_ops.edge.aten.gt.Tensor,
exir_ops.edge.aten.gt.Scalar,
exir_ops.edge.aten.le.Tensor,
exir_ops.edge.aten.le.Scalar,
exir_ops.edge.aten.lt.Tensor,
exir_ops.edge.aten.lt.Scalar,
}

def __init__(self, reporter: WhyNoPartitionReporter) -> None:
self.reporter = reporter
super().__init__()

def is_node_supported(
self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
) -> bool:
if node.target not in self.target_ops:
return True

for input_node in (
input_node
for input_node in node.all_input_nodes
if input_node.op != "get_attr"
):
if get_first_fake_tensor(input_node).dtype == torch.int32:
self.reporter.report_reject(
node,
"FP profile does not support int32 comparison inputs.",
)
return False

return True


class RankCheck(OperatorSupportBase):
"""Reject nodes with rank greater than ``max_rank``."""

Expand Down
22 changes: 11 additions & 11 deletions backends/arm/operators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from . import ( # noqa
node_visitor,
op_abs,
op_add,
op_amax,
op_amin,
op_any,
Expand All @@ -21,56 +20,57 @@
op_ceil,
op_cond_if,
op_cos,
op_eq,
op_exp,
op_floor,
op_ge,
op_gt,
op_log,
op_logical_not,
op_maximum,
op_minimum,
op_mul,
op_neg,
op_permute,
op_pow,
op_reciprocal,
op_repeat,
op_rshift_tensor,
op_rsqrt,
op_sin,
op_sub,
op_sum,
op_to_dim_order_copy,
op_tosa_add,
op_tosa_argmax,
op_tosa_avg_pool2d,
op_tosa_avg_pool2d_adaptive,
op_tosa_binary_ops,
op_tosa_cast_to_block_scaled,
op_tosa_clamp,
op_tosa_conv2d,
op_tosa_conv2d_block_scaled,
op_tosa_conv3d,
op_tosa_custom,
op_tosa_depthwise_conv2d,
op_tosa_eq,
op_tosa_erf,
op_tosa_gather,
op_tosa_ge,
op_tosa_gt,
op_tosa_identity,
op_tosa_matmul,
op_tosa_matmul_t_block_scaled,
op_tosa_max_pool2d,
op_tosa_max_pool2d_adaptive,
op_tosa_maximum,
op_tosa_minimum,
op_tosa_mul,
op_tosa_pad,
op_tosa_pow,
op_tosa_rescale,
op_tosa_resize,
op_tosa_rshift_tensor,
op_tosa_scatter,
op_tosa_shapes,
op_tosa_sigmoid,
op_tosa_slice,
op_tosa_sub,
op_tosa_table,
op_tosa_tanh,
op_tosa_transpose_conv2d,
op_view,
op_where,
op_while,
ops_binary,
)
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

@register_node_visitor
class AddVisitor(SimpleNodeVisitor):
target = "aten.add.Tensor"
target = "tosa.ADD.default"

@classmethod
def get_config(cls) -> SimpleNodeVisitorConfig:
Expand Down
Loading
Loading