From 084b4c05a89958996c11acfc4204aab653848f59 Mon Sep 17 00:00:00 2001
From: Arik Horodniceanu
Date: Tue, 14 Apr 2026 16:01:55 -0700
Subject: [PATCH] Qualcomm AI Engine Direct - Adding QNN backend support for
 scatter.src core ATen op

---
 backends/qualcomm/_passes/layout_transform.py |   1 +
 backends/qualcomm/builders/README.md          |   4 +-
 backends/qualcomm/builders/__init__.py        |   2 +
 .../qualcomm/builders/op_scatter_elements.py  | 103 ++++++++++++++++++
 backends/qualcomm/builders/qnn_constants.py   |  11 ++
 backends/qualcomm/partition/utils.py          |   1 +
 .../quantizer/annotators/htp_rules.py         |  39 +++++++
 .../quantizer/annotators/lpai_rules.py        |  39 +++++++
 backends/qualcomm/tests/models.py             |   9 ++
 backends/qualcomm/tests/test_qnn_delegate.py  |  85 +++++++++++++++
 10 files changed, 292 insertions(+), 2 deletions(-)
 create mode 100644 backends/qualcomm/builders/op_scatter_elements.py

diff --git a/backends/qualcomm/_passes/layout_transform.py b/backends/qualcomm/_passes/layout_transform.py
index 9422051addd..2c01a08e622 100644
--- a/backends/qualcomm/_passes/layout_transform.py
+++ b/backends/qualcomm/_passes/layout_transform.py
@@ -118,6 +118,7 @@ class LayoutTransform(ExportPass):
         exir_ops.edge.aten.repeat.default,
         exir_ops.edge.aten.relu.default,
         exir_ops.edge.aten.round.default,
+        exir_ops.edge.aten.scatter.src,
         exir_ops.edge.aten.sigmoid.default,
         exir_ops.edge.aten.sign.default,
         exir_ops.edge.aten.slice_copy.Tensor,
diff --git a/backends/qualcomm/builders/README.md b/backends/qualcomm/builders/README.md
index d71aeb27be6..51c4f05ed16 100644
--- a/backends/qualcomm/builders/README.md
+++ b/backends/qualcomm/builders/README.md
@@ -368,7 +368,7 @@ Please help update following table if you are contributing new operators:
 
 🚫 = Deprecated, supported with other QNN Ops
 
-| Operators | HTP - 98/119 Enabled |
+| Operators | HTP - 99/119 Enabled |
 |-----------|---------|
 | Argmax | ✓ |
 | Argmin | ✓ |
@@ -472,7 +472,7 @@
 | ResizeNearestNeighbor | ✓ |
 | RoiAlign | ✗ |
 | RmsNorm | ✓ |
-| ScatterElements | ✗ |
+| ScatterElements | ✓ |
 | ScatterNd | ✓ |
 | Sigmoid | ✓ |
 | Softmax | ✓ |
diff --git a/backends/qualcomm/builders/__init__.py b/backends/qualcomm/builders/__init__.py
index a897dfa53bd..28c6547a7fc 100644
--- a/backends/qualcomm/builders/__init__.py
+++ b/backends/qualcomm/builders/__init__.py
@@ -89,6 +89,7 @@
     op_round,
     op_rsqrt,
     op_scalar_tensor,
+    op_scatter_elements,
     op_select_copy,
     op_sigmoid,
     op_sign,
@@ -202,6 +203,7 @@
     op_round,
     op_rsqrt,
     op_scalar_tensor,
+    op_scatter_elements,
     op_select_copy,
     op_sigmoid,
     op_sign,
diff --git a/backends/qualcomm/builders/op_scatter_elements.py b/backends/qualcomm/builders/op_scatter_elements.py
new file mode 100644
index 00000000000..4bcf4572803
--- /dev/null
+++ b/backends/qualcomm/builders/op_scatter_elements.py
@@ -0,0 +1,103 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
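+
+# This visitor lowers the core ATen op aten.scatter.src onto QNN's
+# ScatterElements op: the input, index, and src tensors become the op's three
+# inputs, while the scatter dimension and the reduction mode ("none") are
+# passed as scalar params. PyTorch supplies int64 indices, which are cast to
+# int32 below before being handed to QNN.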
+
+from typing import Dict
+
+import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager
+
+import numpy as np
+import torch
+from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER, QCOM_DATA
+
+from .node_visitor import NodeVisitor
+from .node_visitor_manager import register_node_visitor
+from .qnn_constants import OpScatterElements, QNN_OP_PACKAGE_NAME_QTI_AISW
+
+
+@register_node_visitor
+class ScatterElements(NodeVisitor):
+    target = ["aten.scatter.src"]
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper],
+    ) -> PyQnnManager.PyQnnOpWrapper:
+        input_node = self.get_node(node.args[0])
+        input_tensor = self.get_tensor(input_node, node)
+        input_tensor_wrapper = self.define_tensor(
+            input_node,
+            node,
+            input_tensor,
+            PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+
+        index_node = self.get_node(node.args[2])
+        index_tensor = self.get_tensor(index_node, node)
+        index_tensor_wrapper = self.define_tensor(
+            index_node,
+            node,
+            index_tensor.to(torch.int32),
+            PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+
+        updates_node = self.get_node(node.args[3])
+        updates_tensor = self.get_tensor(updates_node, node)
+        updates_tensor_wrapper = self.define_tensor(
+            updates_node,
+            node,
+            updates_tensor,
+            PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+
+        output_tensor = self.get_tensor(node, node)
+        output_tensor_wrapper = self.define_tensor(
+            node,
+            node,
+            output_tensor,
+            PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+
+        dim = node.args[1]
+        if dim < 0:
+            dim = dim % len(input_tensor.shape)
+
+        if QCOM_AXIS_ORDER in node.meta:
+            dim = node.meta[QCOM_AXIS_ORDER].index(dim)
+
+        scatter_op = PyQnnManager.PyQnnOpWrapper(
+            node.name,
+            QNN_OP_PACKAGE_NAME_QTI_AISW,
+            OpScatterElements.op_name,
+        )
+        scatter_op.AddInputTensors(
+            [
+                input_tensor_wrapper,
+                index_tensor_wrapper,
+                updates_tensor_wrapper,
+            ]
+        )
+        scatter_op.AddOutputTensors([output_tensor_wrapper])
+
+        scatter_op.AddScalarParam(
+            OpScatterElements.param_axis,
+            PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
+            {QCOM_DATA: np.uint32(dim)},
+        )
+
+        scatter_op.AddScalarParam(
+            OpScatterElements.param_reduction,
+            PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
+            {QCOM_DATA: np.uint32(OpScatterElements.Reduction.NONE)},
+        )
+
+        return scatter_op
diff --git a/backends/qualcomm/builders/qnn_constants.py b/backends/qualcomm/builders/qnn_constants.py
index d7ec30fddc0..f8aad206037 100644
--- a/backends/qualcomm/builders/qnn_constants.py
+++ b/backends/qualcomm/builders/qnn_constants.py
@@ -587,6 +587,17 @@ class OpRmsNorm:
     param_axes: str = "axes"
 
 
+@dataclass(init=False, frozen=True)
+class OpScatterElements:
+    op_name: str = "ScatterElements"
+    param_axis: str = "axis"
+    param_reduction: str = "reduction"
+
+    @unique
+    class Reduction(IntEnum):
+        NONE = 0
+
+
 @dataclass(init=False, frozen=True)
 class OpScatterNd:
     op_name: str = "ScatterNd"
diff --git a/backends/qualcomm/partition/utils.py b/backends/qualcomm/partition/utils.py
index a83444a56b2..93f00d4e994 100644
--- a/backends/qualcomm/partition/utils.py
+++ b/backends/qualcomm/partition/utils.py
@@ -68,6 +68,7 @@ def get_skip_decomp_table() -> List[torch._ops.OperatorBase]:
         torch.ops.aten.reflection_pad2d.default,
         torch.ops.aten.rms_norm.default,
         torch.ops.aten._safe_softmax.default,
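+        # Keeping aten.scatter.src out of decomposition lets the partitioner
+        # hand the whole op to the ScatterElements builder rather than its
+        # decomposed form.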
+        torch.ops.aten.scatter.src,
         torch.ops.aten.stack.default,
         torch.ops.aten.upsample_bicubic2d.vec,
         # This request is ignored because it is in a blocklist. Refer to exir/program/_program.py
diff --git a/backends/qualcomm/quantizer/annotators/htp_rules.py b/backends/qualcomm/quantizer/annotators/htp_rules.py
index cd65d02c752..b7ac6d2c1d5 100644
--- a/backends/qualcomm/quantizer/annotators/htp_rules.py
+++ b/backends/qualcomm/quantizer/annotators/htp_rules.py
@@ -1379,6 +1379,45 @@ class ScaledDotProductAttention(GeneralOpDef):
     pass
 
 
+@register_annotator(
+    [torch.ops.aten.scatter.src],
+    qnn_op=None,
+)
+class ScatterElements(GeneralOpDef):
+    @staticmethod
+    def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
+        if _is_annotated([node]):
+            return
+
+        input_qspec_map = {}
+        input_act = None
+
+        if _is_float_tensor(node.args[0]):
+            input_act = node.args[0]
+            assert isinstance(input_act, Node)
+            input_qspec_map[input_act] = quantization_config.input_activation
+
+        if (
+            len(node.args) > 3
+            and isinstance(node.args[3], Node)
+            and _is_float_tensor(node.args[3])
+        ):
+            input_qspec_map[node.args[3]] = SharedQuantizationSpec((input_act, node))
+
+        output_act_qspec = (
+            SharedQuantizationSpec((input_act, node))
+            if _is_float_tensor(node)
+            else None
+        )
+
+        if len(input_qspec_map) > 0 or output_act_qspec is not None:
+            node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
+                input_qspec_map=input_qspec_map,
+                output_qspec=output_act_qspec,
+                _annotated=True,
+            )
+
+
 @register_annotator(
     [torch.ops.aten.sigmoid, torch.ops.aten.sigmoid.default],
     QnnConstants.OpSigmoid.op_name,
diff --git a/backends/qualcomm/quantizer/annotators/lpai_rules.py b/backends/qualcomm/quantizer/annotators/lpai_rules.py
index 60cebfcc5c0..4e1a4399aab 100644
--- a/backends/qualcomm/quantizer/annotators/lpai_rules.py
+++ b/backends/qualcomm/quantizer/annotators/lpai_rules.py
@@ -858,6 +858,45 @@ class ScaledDotProductAttention(GeneralOpDef):
     pass
 
 
+@register_annotator(
+    [torch.ops.aten.scatter.src],
+    qnn_op=None,
+)
+class ScatterElements(GeneralOpDef):
+    @staticmethod
+    def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
+        if _is_annotated([node]):
+            return
+
+        input_qspec_map = {}
+        input_act = None
+
+        if _is_float_tensor(node.args[0]):
+            input_act = node.args[0]
+            assert isinstance(input_act, Node)
+            input_qspec_map[input_act] = quantization_config.input_activation
+
+        if (
+            len(node.args) > 3
+            and isinstance(node.args[3], Node)
+            and _is_float_tensor(node.args[3])
+        ):
+            input_qspec_map[node.args[3]] = SharedQuantizationSpec((input_act, node))
+
+        output_act_qspec = (
+            SharedQuantizationSpec((input_act, node))
+            if _is_float_tensor(node)
+            else None
+        )
+
+        if len(input_qspec_map) > 0 or output_act_qspec is not None:
+            node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
+                input_qspec_map=input_qspec_map,
+                output_qspec=output_act_qspec,
+                _annotated=True,
+            )
+
+
 @register_annotator(
     [torch.ops.aten.sigmoid, torch.ops.aten.sigmoid.default],
     QnnConstants.OpSigmoid.op_name,
diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py
index 12d5e0902db..3fabed2458f 100644
--- a/backends/qualcomm/tests/models.py
+++ b/backends/qualcomm/tests/models.py
@@ -2168,6 +2168,15 @@ def forward(self, query_layer, key_layer, value_layer, attn_mask):
         return attn_output
 
 
+class ScatterSrc(torch.nn.Module):
+    def __init__(self, dim=1):
+        super().__init__()
+        self.dim = dim
+
+    def forward(self, data, index, src):
+        return torch.scatter(data, self.dim, index, src)
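+
+    # Reference semantics for aten.scatter.src as exercised by this module:
+    #   dim=0: out[index[i][j]][j] = src[i][j]
+    #   dim=1: out[i][index[i][j]] = src[i][j]
+    # Entries not named by `index` are copied through from `data`.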
+
+
 class SelectCopy(torch.nn.Module):
     def __init__(self):
         super().__init__()
diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py
index d76e3ea1df7..f53eef5743d 100644
--- a/backends/qualcomm/tests/test_qnn_delegate.py
+++ b/backends/qualcomm/tests/test_qnn_delegate.py
@@ -1941,6 +1941,52 @@ def test_qnn_backend_round(self):
         sample_input = (torch.randn([3, 4]),)
         self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_scatter_src(self):
+        test_comb = [
+            {
+                QCOM_MODULE: [ScatterSrc(dim=1)],  # noqa: F405
+                QCOM_SAMPLE_INPUTS: [
+                    (
+                        torch.zeros(3, 5),
+                        torch.tensor(
+                            [[0, 1, 2, 3, 4], [4, 3, 2, 1, 0], [1, 0, 3, 4, 2]],
+                            dtype=torch.int64,
+                        ),
+                        torch.rand(3, 5),
+                    ),
+                    (
+                        torch.zeros(3, 5, dtype=torch.float16),
+                        torch.tensor(
+                            [[0, 1, 2, 3, 4], [4, 3, 2, 1, 0], [1, 0, 3, 4, 2]],
+                            dtype=torch.int64,
+                        ),
+                        torch.rand(3, 5, dtype=torch.float16),
+                    ),
+                ],
+            },
+            {
+                QCOM_MODULE: [ScatterSrc(dim=0)],  # noqa: F405
+                QCOM_SAMPLE_INPUTS: [
+                    (
+                        torch.zeros(3, 5),
+                        torch.tensor(
+                            [[2, 1, 0, 1, 2], [0, 2, 1, 2, 0], [1, 0, 2, 0, 1]],
+                            dtype=torch.int64,
+                        ),
+                        torch.rand(3, 5),
+                    ),
+                ],
+            },
+        ]
+
+        index = 0
+        for comb in test_comb:
+            for module in comb[QCOM_MODULE]:
+                for sample_input in comb[QCOM_SAMPLE_INPUTS]:
+                    with self.subTest(i=index):
+                        index += 1
+                        self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_rsqrt(self):
         module = Rsqrt()  # noqa: F405
         sample_input = (torch.abs(torch.randn([3, 4])),)
         self.lower_module_and_test_output(module, sample_input)
@@ -4530,6 +4576,45 @@ def test_qnn_backend_rsqrt(self):
         module = self.get_qdq_module(module, sample_input)
         self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_scatter_src(self):
+        test_comb = [
+            {
+                QCOM_MODULE: [ScatterSrc(dim=1)],  # noqa: F405
+                QCOM_SAMPLE_INPUTS: [
+                    (
+                        torch.zeros(3, 5),
+                        torch.tensor(
+                            [[0, 1, 2, 3, 4], [4, 3, 2, 1, 0], [1, 0, 3, 4, 2]],
+                            dtype=torch.int64,
+                        ),
+                        torch.rand(3, 5),
+                    ),
+                ],
+            },
+            {
+                QCOM_MODULE: [ScatterSrc(dim=0)],  # noqa: F405
+                QCOM_SAMPLE_INPUTS: [
+                    (
+                        torch.zeros(3, 5),
+                        torch.tensor(
+                            [[2, 1, 0, 1, 2], [0, 2, 1, 2, 0], [1, 0, 2, 0, 1]],
+                            dtype=torch.int64,
+                        ),
+                        torch.rand(3, 5),
+                    ),
+                ],
+            },
+        ]
+
+        index = 0
+        for comb in test_comb:
+            for module in comb[QCOM_MODULE]:
+                for sample_input in comb[QCOM_SAMPLE_INPUTS]:
+                    with self.subTest(i=index):
+                        index += 1
+                        qdq_module = self.get_qdq_module(module, sample_input)
+                        self.lower_module_and_test_output(qdq_module, sample_input)
+
     def test_qnn_backend_sdpa(self):
         modules = [
             ScaledDotProductAttention(),  # noqa: F405
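
A quick eager-mode cross-check of the semantics being lowered may help when reviewing. The sketch below is a minimal sanity check using only public PyTorch APIs, independent of the QNN flow; it mirrors the dim=1 inputs from test_qnn_backend_scatter_src, where every row of `index` is a permutation of 0..4, so a gather along the same dim must reproduce `src` exactly:

    import torch

    data = torch.zeros(3, 5)
    index = torch.tensor([[0, 1, 2, 3, 4], [4, 3, 2, 1, 0], [1, 0, 3, 4, 2]])
    src = torch.rand(3, 5)

    # Same call ScatterSrc.forward makes with dim=1:
    # out[i][index[i][j]] = src[i][j]
    out = torch.scatter(data, 1, index, src)

    # Round-trip check. This relies on each index row having no duplicates;
    # with duplicates, later writes would overwrite earlier ones and the
    # gathered values would no longer match src.
    assert torch.equal(out.gather(1, index), src)

The same identity holds for the dim=0 test data above, whose index columns are likewise permutations.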