diff --git a/backends/qualcomm/_passes/__init__.py b/backends/qualcomm/_passes/__init__.py index fc59ce3d262..f82157d3cf0 100644 --- a/backends/qualcomm/_passes/__init__.py +++ b/backends/qualcomm/_passes/__init__.py @@ -32,6 +32,7 @@ from .decompose_remainder import DecomposeRemainder from .decompose_roll import DecomposeRoll from .decompose_silu import DecomposeSilu +from .decompose_tan import DecomposeTan from .decompose_threshold import DecomposeThreshold from .decompose_triu import DecomposeTriu from .decompose_trunc import DecomposeTrunc @@ -88,6 +89,7 @@ DecomposeRemainder, DecomposeRoll, DecomposeSilu, + DecomposeTan, DecomposeThreshold, DecomposeTriu, DecomposeTrunc, diff --git a/backends/qualcomm/_passes/decompose_tan.py b/backends/qualcomm/_passes/decompose_tan.py new file mode 100644 index 00000000000..b75cf9ff2df --- /dev/null +++ b/backends/qualcomm/_passes/decompose_tan.py @@ -0,0 +1,71 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.dialects.edge._ops import EdgeOpOverload +from executorch.exir.pass_base import ExportPass, PassResult + +from .utils import copy_meta + + +class DecomposeTan(ExportPass): + """ + Decompose tan(x) = sin(x) / cos(x) + """ + + def __init__(self): + super(DecomposeTan, self).__init__() + self.targets = { + torch.ops.aten.tan.default, + exir_ops.edge.aten.tan.default, + } + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + graph = graph_module.graph + + for node in list(graph.nodes): + if node.op == "call_function" and node.target in self.targets: + is_edge = isinstance(node.target, EdgeOpOverload) + + sin_op = ( + exir_ops.edge.aten.sin.default + if is_edge + else torch.ops.aten.sin.default + ) + cos_op = ( + exir_ops.edge.aten.cos.default + if is_edge + else torch.ops.aten.cos.default + ) + div_op = ( + exir_ops.edge.aten.div.Tensor + if is_edge + else torch.ops.aten.div.Tensor + ) + + with graph.inserting_before(node): + sin_node = graph.create_node( + "call_function", sin_op, (node.args[0],) + ) + sin_node.meta = copy_meta(node.meta) + + cos_node = graph.create_node( + "call_function", cos_op, (node.args[0],) + ) + cos_node.meta = copy_meta(node.meta) + + div_node = graph.create_node( + "call_function", div_op, (sin_node, cos_node) + ) + div_node.meta = copy_meta(node.meta) + + for user in node.users.copy(): + user.replace_input_with(node, div_node) + + graph.eliminate_dead_code() + graph_module.recompile() + return PassResult(graph_module, True) diff --git a/backends/qualcomm/_passes/qnn_pass_manager.py b/backends/qualcomm/_passes/qnn_pass_manager.py index 57354af11de..b0913bbefd9 100644 --- a/backends/qualcomm/_passes/qnn_pass_manager.py +++ b/backends/qualcomm/_passes/qnn_pass_manager.py @@ -37,6 +37,7 @@ DecomposeRemainder, DecomposeRoll, DecomposeSilu, + DecomposeTan, DecomposeThreshold, DecomposeTriu, DecomposeTrunc, @@ -112,6 +113,7 @@ def get_capture_program_passes(): (DecomposeMinMaxDim, True), (DecomposePad, True), (DecomposeRemainder, True), + (DecomposeTan, True), (DecomposeTrunc, True), (ExpandBroadcastTensorShape, True), (FixedLinearKeepDim, True), @@ -236,6 +238,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule): self.add_pass(DecomposeScaledDotProductAttention()) self.add_pass(DecomposeRoll()) self.add_pass(DecomposeSilu()) + self.add_pass(DecomposeTan()) self.add_pass(DecomposeThreshold()) self.add_pass(DecomposeTriu()) self.add_pass(DecomposeTrunc()) diff --git a/backends/qualcomm/_passes/utils.py b/backends/qualcomm/_passes/utils.py index 04371d61e1c..542fa1115a6 100755 --- a/backends/qualcomm/_passes/utils.py +++ b/backends/qualcomm/_passes/utils.py @@ -74,6 +74,7 @@ def get_passes_dependency_for_capture_program(): DecomposeMaxPool3d, DecomposePad, DecomposeRemainder, + DecomposeTan, DecomposeTrunc, ExpandBroadcastTensorShape, FixedLinearKeepDim, @@ -107,6 +108,7 @@ def get_passes_dependency_for_capture_program(): DecomposeMaxPool3d: [RemoveRedundancy], DecomposePad: [RemoveRedundancy], DecomposeRemainder: [RemoveRedundancy], + DecomposeTan: [RemoveRedundancy], DecomposeTrunc: [RemoveRedundancy], ExpandBroadcastTensorShape: [FoldQDQ], FixedLinearKeepDim: [FoldQDQ], diff --git a/backends/qualcomm/builders/README.md b/backends/qualcomm/builders/README.md index d71aeb27be6..611bd7ed7aa 100644 --- a/backends/qualcomm/builders/README.md +++ b/backends/qualcomm/builders/README.md @@ -368,7 +368,7 @@ Please help update following table if you are contributing new operators: + 🚫 = Deprecated, supported with other QNN Ops -| Operators | HTP - 98/119 Enabled | +| Operators | HTP - 99/119 Enabled | |-----------|---------| | Argmax | ✓ | | Argmin | ✓ | @@ -472,7 +472,7 @@ Please help update following table if you are contributing new operators: | ResizeNearestNeighbor | ✓ | | RoiAlign | ✗ | | RmsNorm | ✓ | -| ScatterElements | ✗ | +| ScatterElements | ✓ | | ScatterNd | ✓ | | Sigmoid | ✓ | | Softmax | ✓ | @@ -517,6 +517,7 @@ The following PyTorch operators are supported through decomposition or annotatio | `aten.remainder.Scalar`, `aten.remainder.Tensor` | `DecomposeRemainder` | | `aten.roll` | `DecomposeRoll` | | `aten.silu` | `DecomposeSilu` | +| `aten.tan` | `DecomposeTan` | | `aten.threshold` | `DecomposeThreshold` | | `aten.triu` | `DecomposeTriu` | | `aten.trunc` | `DecomposeTrunc` | diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py index 12d5e0902db..bcb4b1f631c 100644 --- a/backends/qualcomm/tests/models.py +++ b/backends/qualcomm/tests/models.py @@ -2425,6 +2425,14 @@ def forward(self, x): return torch.swapaxes(x, axis0=self.axis0, axis1=self.axis1) +class Tan(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.tan(x) + + class Tanh(torch.nn.Module): def __init__(self): super().__init__() diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index d76e3ea1df7..b0d9943810d 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -2063,6 +2063,11 @@ def test_qnn_backend_swapaxes(self): sample_input = (torch.randn([1, 2, 3, 4]),) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_tan(self): + module = Tan() # noqa: F405 + sample_input = (torch.rand(2, 5, 1, 3) * 2 - 1,) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_tanh(self): module = Tanh() # noqa: F405 sample_input = (torch.randn(2, 5, 1, 3),) @@ -4667,6 +4672,12 @@ def test_qnn_backend_swapaxes(self): module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_tan(self): + module = Tan() # noqa: F405 + sample_input = (torch.rand(2, 5, 1, 3) * 2 - 1,) + module = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_tanh(self): module = Tanh() # noqa: F405 sample_input = (torch.randn(2, 5, 1, 3),)