Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/qualcomm/_passes/layout_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ class LayoutTransform(ExportPass):
exir_ops.edge.aten.repeat.default,
exir_ops.edge.aten.relu.default,
exir_ops.edge.aten.round.default,
exir_ops.edge.aten.scatter.src,
exir_ops.edge.aten.sigmoid.default,
exir_ops.edge.aten.sign.default,
exir_ops.edge.aten.slice_copy.Tensor,
Expand Down
4 changes: 2 additions & 2 deletions backends/qualcomm/builders/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ Please help update following table if you are contributing new operators:
+ 🚫 = Deprecated, supported with other QNN Ops


| Operators | HTP - 98/119 Enabled |
| Operators | HTP - 99/119 Enabled |
|-----------|---------|
| Argmax | ✓ |
| Argmin | ✓ |
Expand Down Expand Up @@ -472,7 +472,7 @@ Please help update following table if you are contributing new operators:
| ResizeNearestNeighbor | ✓ |
| RoiAlign | ✗ |
| RmsNorm | ✓ |
| ScatterElements | ✗ |
| ScatterElements | ✓ |
| ScatterNd | ✓ |
| Sigmoid | ✓ |
| Softmax | ✓ |
Expand Down
2 changes: 2 additions & 0 deletions backends/qualcomm/builders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@
op_round,
op_rsqrt,
op_scalar_tensor,
op_scatter_elements,
op_select_copy,
op_sigmoid,
op_sign,
Expand Down Expand Up @@ -202,6 +203,7 @@
op_round,
op_rsqrt,
op_scalar_tensor,
op_scatter_elements,
op_select_copy,
op_sigmoid,
op_sign,
Expand Down
103 changes: 103 additions & 0 deletions backends/qualcomm/builders/op_scatter_elements.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict

import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager

import numpy as np
import torch
from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER, QCOM_DATA

from .node_visitor import NodeVisitor
from .node_visitor_manager import register_node_visitor
from .qnn_constants import OpScatterElements, QNN_OP_PACKAGE_NAME_QTI_AISW


@register_node_visitor
class ScatterElements(NodeVisitor):
target = ["aten.scatter.src"]

def __init__(self, *args) -> None:
super().__init__(*args)

def define_node(
self,
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper],
) -> PyQnnManager.PyQnnOpWrapper:
input_node = self.get_node(node.args[0])
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
node,
input_tensor,
PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
)

index_node = self.get_node(node.args[2])
index_tensor = self.get_tensor(index_node, node)
index_tensor_wrapper = self.define_tensor(
index_node,
node,
index_tensor.to(torch.int32),
PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
)

updates_node = self.get_node(node.args[3])
updates_tensor = self.get_tensor(updates_node, node)
updates_tensor_wrapper = self.define_tensor(
updates_node,
node,
updates_tensor,
PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
)

output_tensor = self.get_tensor(node, node)
output_tensor_wrapper = self.define_tensor(
node,
node,
output_tensor,
PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
)

dim = node.args[1]
if dim < 0:
dim = dim % len(input_tensor.shape)

if QCOM_AXIS_ORDER in node.meta:
dim = node.meta[QCOM_AXIS_ORDER].index(dim)

scatter_op = PyQnnManager.PyQnnOpWrapper(
node.name,
QNN_OP_PACKAGE_NAME_QTI_AISW,
OpScatterElements.op_name,
)
scatter_op.AddInputTensors(
[
input_tensor_wrapper,
index_tensor_wrapper,
updates_tensor_wrapper,
]
)
scatter_op.AddOutputTensors([output_tensor_wrapper])

scatter_op.AddScalarParam(
OpScatterElements.param_axis,
PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
{QCOM_DATA: np.uint32(dim)},
)

scatter_op.AddScalarParam(
OpScatterElements.param_reduction,
PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
{QCOM_DATA: np.uint32(OpScatterElements.Reduction.NONE)},
)

return scatter_op
11 changes: 11 additions & 0 deletions backends/qualcomm/builders/qnn_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,17 @@ class OpRmsNorm:
param_axes: str = "axes"


@dataclass(init=False, frozen=True)
class OpScatterElements:
op_name: str = "ScatterElements"
param_axis: str = "axis"
param_reduction: str = "reduction"

@unique
class Reduction(IntEnum):
NONE = 0


@dataclass(init=False, frozen=True)
class OpScatterNd:
op_name: str = "ScatterNd"
Expand Down
1 change: 1 addition & 0 deletions backends/qualcomm/partition/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def get_skip_decomp_table() -> List[torch._ops.OperatorBase]:
torch.ops.aten.reflection_pad2d.default,
torch.ops.aten.rms_norm.default,
torch.ops.aten._safe_softmax.default,
torch.ops.aten.scatter.src,
torch.ops.aten.stack.default,
torch.ops.aten.upsample_bicubic2d.vec,
# This request is ignored because it is in a blocklist. Refer to exir/program/_program.py
Expand Down
39 changes: 39 additions & 0 deletions backends/qualcomm/quantizer/annotators/htp_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -1379,6 +1379,45 @@ class ScaledDotProductAttention(GeneralOpDef):
pass


@register_annotator(
[torch.ops.aten.scatter.src],
qnn_op=None,
)
class ScatterElements(GeneralOpDef):
@staticmethod
def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
if _is_annotated([node]):
return

input_qspec_map = {}
input_act = None

if _is_float_tensor(node.args[0]):
input_act = node.args[0]
assert isinstance(input_act, Node)
input_qspec_map[input_act] = quantization_config.input_activation

if (
len(node.args) > 3
and isinstance(node.args[3], Node)
and _is_float_tensor(node.args[3])
):
input_qspec_map[node.args[3]] = SharedQuantizationSpec((input_act, node))

output_act_qspec = (
SharedQuantizationSpec((input_act, node))
if _is_float_tensor(node)
else None
)

if len(input_qspec_map) > 0 or output_act_qspec is not None:
node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
input_qspec_map=input_qspec_map,
output_qspec=output_act_qspec,
_annotated=True,
)


@register_annotator(
[torch.ops.aten.sigmoid, torch.ops.aten.sigmoid.default],
QnnConstants.OpSigmoid.op_name,
Expand Down
39 changes: 39 additions & 0 deletions backends/qualcomm/quantizer/annotators/lpai_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -858,6 +858,45 @@ class ScaledDotProductAttention(GeneralOpDef):
pass


@register_annotator(
[torch.ops.aten.scatter.src],
qnn_op=None,
)
class ScatterElements(GeneralOpDef):
@staticmethod
def annotate(node: Node, quantization_config: QuantizationConfig) -> None:
if _is_annotated([node]):
return

input_qspec_map = {}
input_act = None

if _is_float_tensor(node.args[0]):
input_act = node.args[0]
assert isinstance(input_act, Node)
input_qspec_map[input_act] = quantization_config.input_activation

if (
len(node.args) > 3
and isinstance(node.args[3], Node)
and _is_float_tensor(node.args[3])
):
input_qspec_map[node.args[3]] = SharedQuantizationSpec((input_act, node))

output_act_qspec = (
SharedQuantizationSpec((input_act, node))
if _is_float_tensor(node)
else None
)

if len(input_qspec_map) > 0 or output_act_qspec is not None:
node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
input_qspec_map=input_qspec_map,
output_qspec=output_act_qspec,
_annotated=True,
)


@register_annotator(
[torch.ops.aten.sigmoid, torch.ops.aten.sigmoid.default],
QnnConstants.OpSigmoid.op_name,
Expand Down
9 changes: 9 additions & 0 deletions backends/qualcomm/tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2168,6 +2168,15 @@ def forward(self, query_layer, key_layer, value_layer, attn_mask):
return attn_output


class ScatterSrc(torch.nn.Module):
def __init__(self, dim=1):
super().__init__()
self.dim = dim

def forward(self, data, index, src):
return torch.scatter(data, self.dim, index, src)


class SelectCopy(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down
85 changes: 85 additions & 0 deletions backends/qualcomm/tests/test_qnn_delegate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1941,6 +1941,52 @@ def test_qnn_backend_round(self):
sample_input = (torch.randn([3, 4]),)
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_scatter_src(self):
test_comb = [
{
QCOM_MODULE: [ScatterSrc(dim=1)], # noqa: F405
QCOM_SAMPLE_INPUTS: [
(
torch.zeros(3, 5),
torch.tensor(
[[0, 1, 2, 3, 4], [4, 3, 2, 1, 0], [1, 0, 3, 4, 2]],
dtype=torch.int64,
),
torch.rand(3, 5),
),
(
torch.zeros(3, 5, dtype=torch.float16),
torch.tensor(
[[0, 1, 2, 3, 4], [4, 3, 2, 1, 0], [1, 0, 3, 4, 2]],
dtype=torch.int64,
),
torch.rand(3, 5, dtype=torch.float16),
),
],
},
{
QCOM_MODULE: [ScatterSrc(dim=0)], # noqa: F405
QCOM_SAMPLE_INPUTS: [
(
torch.zeros(3, 5),
torch.tensor(
[[2, 1, 0, 1, 2], [0, 2, 1, 2, 0], [1, 0, 2, 0, 1]],
dtype=torch.int64,
),
torch.rand(3, 5),
),
],
},
]

index = 0
for comb in test_comb:
for module in comb[QCOM_MODULE]:
for sample_input in comb[QCOM_SAMPLE_INPUTS]:
with self.subTest(i=index):
index += 1
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_rsqrt(self):
module = Rsqrt() # noqa: F405
sample_input = (torch.abs(torch.randn([3, 4])),)
Expand Down Expand Up @@ -4530,6 +4576,45 @@ def test_qnn_backend_rsqrt(self):
module = self.get_qdq_module(module, sample_input)
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_scatter_src(self):
test_comb = [
{
QCOM_MODULE: [ScatterSrc(dim=1)], # noqa: F405
QCOM_SAMPLE_INPUTS: [
(
torch.zeros(3, 5),
torch.tensor(
[[0, 1, 2, 3, 4], [4, 3, 2, 1, 0], [1, 0, 3, 4, 2]],
dtype=torch.int64,
),
torch.rand(3, 5),
),
],
},
{
QCOM_MODULE: [ScatterSrc(dim=0)], # noqa: F405
QCOM_SAMPLE_INPUTS: [
(
torch.zeros(3, 5),
torch.tensor(
[[2, 1, 0, 1, 2], [0, 2, 1, 2, 0], [1, 0, 2, 0, 1]],
dtype=torch.int64,
),
torch.rand(3, 5),
),
],
},
]

index = 0
for comb in test_comb:
for module in comb[QCOM_MODULE]:
for sample_input in comb[QCOM_SAMPLE_INPUTS]:
with self.subTest(i=index):
index += 1
qdq_module = self.get_qdq_module(module, sample_input)
self.lower_module_and_test_output(qdq_module, sample_input)

def test_qnn_backend_sdpa(self):
modules = [
ScaledDotProductAttention(), # noqa: F405
Expand Down
Loading