From b3e891d57d130377ccb298d6c09c4acbaa9fd213 Mon Sep 17 00:00:00 2001
From: jethroqti
Date: Fri, 9 Jan 2026 18:43:09 -0800
Subject: [PATCH 1/2] Qualcomm AI Engine Direct - add pass for extra padding
 then maxpool2d

Summary:
The padding value used in max_pool2d operations differs between PyTorch and
QNN: PyTorch pads with negative infinity, while QNN pads with zero. To keep
max_pool2d output consistent across both frameworks, we pad the tensor with a
constant value up front and then run max_pool2d without any implicit padding.

Test plan:
python backends/qualcomm/tests/test_qnn_delegate.py TestQNNQuantizedOperator.test_qnn_backend_max_pool2d -b build-android -H ${HOST} -s ${SN} -m ${CHIPID}
python backends/qualcomm/tests/test_qnn_delegate.py TestQNNFloatingPointOperator.test_qnn_backend_max_pool2d -b build-android -H ${HOST} -s ${SN} -m ${CHIPID}
---
 backends/qualcomm/_passes/__init__.py         |   2 +
 backends/qualcomm/_passes/qnn_pass_manager.py |   4 +-
 .../_passes/recompose_pad_maxpool2d.py        | 147 ++++++++++++++++++
 backends/qualcomm/_passes/utils.py            |   2 +
 backends/qualcomm/tests/models.py             |  10 +-
 backends/qualcomm/tests/test_qnn_delegate.py  |  31 +++-
 6 files changed, 185 insertions(+), 11 deletions(-)
 create mode 100644 backends/qualcomm/_passes/recompose_pad_maxpool2d.py

diff --git a/backends/qualcomm/_passes/__init__.py b/backends/qualcomm/_passes/__init__.py
index 49449fe2190..83e4e9bad37 100644
--- a/backends/qualcomm/_passes/__init__.py
+++ b/backends/qualcomm/_passes/__init__.py
@@ -40,6 +40,7 @@ from .insert_reshape_for_reduce_ops import InsertReshapeForReduceOps
 from .layout_transform import LayoutTransform
 from .lift_constant_scalar_operands import LiftConstantScalarOperands
+from .recompose_pad_maxpool2d import RecomposePadMaxPool2d
 from .recompose_pixel_unshuffle import RecomposePixelUnshuffle
 from .recompose_rms_norm import RecomposeRmsNorm
 from .reduce_dynamic_range import ReduceDynamicRange
@@ -87,6 +88,7 @@
     InsertRequantize,
     LayoutTransform,
     LiftConstantScalarOperands,
+    RecomposePadMaxPool2d,
     RecomposePixelUnshuffle,
     RecomposeRmsNorm,
     ReduceDynamicRange,
diff --git a/backends/qualcomm/_passes/qnn_pass_manager.py b/backends/qualcomm/_passes/qnn_pass_manager.py
index 46a1dfb0970..5f4168c1770 100644
--- a/backends/qualcomm/_passes/qnn_pass_manager.py
+++ b/backends/qualcomm/_passes/qnn_pass_manager.py
@@ -45,6 +45,7 @@
     InsertReshapeForReduceOps,
     LayoutTransform,
     LiftConstantScalarOperands,
+    RecomposePadMaxPool2d,
     RecomposePixelUnshuffle,
     RecomposeRmsNorm,
     ReduceDynamicRange,
@@ -93,13 +94,14 @@ def get_capture_program_passes():
         (ConvertBmmToMatmul, False),
         (DecomposeAny, True),
         (DecomposeColIm, True),
+        (DecomposeMaxPool3d, True),
         (DecomposeMinMaxDim, True),
         (ExpandBroadcastTensorShape, True),
         (FixedLinearKeepDim, True),
         (FoldQDQ, True),
         (I64toI32, True),
         (LayoutTransform, True),
-        (DecomposeMaxPool3d, True),
+        (RecomposePadMaxPool2d, True),
         (RecomposePixelUnshuffle, True),
         (RecomposeRmsNorm, True),
         (Remove0DTensor, True),
diff --git a/backends/qualcomm/_passes/recompose_pad_maxpool2d.py b/backends/qualcomm/_passes/recompose_pad_maxpool2d.py
new file mode 100644
index 00000000000..4762d2beb53
--- /dev/null
+++ b/backends/qualcomm/_passes/recompose_pad_maxpool2d.py
@@ -0,0 +1,147 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
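+
+# Background (illustrative sketch, not executed by this pass): PyTorch's
+# max_pool2d pads implicitly with -inf, while QNN pads with zero, so the two
+# disagree wherever a pooling window overlaps the border of an all-negative
+# input:
+#
+#   import torch
+#   import torch.nn.functional as F
+#   x = torch.full((1, 1, 2, 2), -3.0)
+#   F.max_pool2d(x, 2, 1, padding=1)            # every entry is -3.0
+#   F.max_pool2d(F.pad(x, [1, 1, 1, 1]), 2, 1)  # border entries become 0.0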
+
+import operator
+from typing import cast, List
+
+import torch
+
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass, PassResult
+
+from torch._subclasses.fake_tensor import FakeTensorMode
+
+
+def add_fake_tensor_to_node(padding_node, input_shape, padding_args, dtype):
+    # Attach a FakeTensor with the padded output shape so later passes can
+    # read shape/dtype metadata from the newly created pad node.
+    fake_mode = FakeTensorMode()
+
+    with fake_mode:
+        batch, channels, height, width = input_shape
+        pad_left, pad_right, pad_top, pad_bottom = padding_args
+        output_shape = (
+            batch,
+            channels,
+            height + pad_top + pad_bottom,
+            width + pad_left + pad_right,
+        )
+        fake_output = torch.empty(output_shape, dtype=dtype)
+        if not hasattr(padding_node, "meta"):
+            padding_node.meta = {}
+        padding_node.meta["val"] = fake_output
+
+    return fake_output
+
+
+class RecomposePadMaxPool2d(ExportPass):
+    """
+    The padding value used in max_pool2d operations differs between PyTorch
+    and QNN: PyTorch pads with negative infinity, while QNN pads with zero.
+    To keep max_pool2d output consistent across both frameworks, this pass
+    pads the tensor with a constant value up front and then runs max_pool2d
+    without any implicit padding. Note that for the quantization flow,
+    quant_min is used as the padding value; if a runtime value falls below
+    quant_min, accuracy may drop.
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.getitem = operator.getitem
+        self.max_pool2d = exir_ops.edge.aten.max_pool2d_with_indices.default
+        self.pad_op = exir_ops.edge.aten.constant_pad_nd.default
+
+    def call(self, graph_module: torch.fx.GraphModule):
+        graph = graph_module.graph
+        for node in graph.nodes:
+            num_args = len(node.args)
+            if (
+                node.op == "call_function"
+                and node.target == self.max_pool2d
+                and num_args > 3
+            ):
+                # Copy the arg lists before normalizing them so we never
+                # mutate the lists stored on the original node.
+                padding = list(cast(List[int], node.args[3]))
+                if len(padding) == 1:
+                    padding *= 2
+                if padding[0] == 0 and padding[1] == 0:
+                    continue
+                # constant_pad_nd expects (left, right, top, bottom)
+                padding = [padding[1], padding[1], padding[0], padding[0]]
+
+                input_node = node.args[0]
+                # kernel info
+                filter_size = list(cast(List[int], node.args[1]))
+                if len(filter_size) == 1:
+                    filter_size *= 2
+                # stride info
+                stride = list(cast(List[int], node.args[2]))
+                if len(stride) == 1:
+                    stride *= 2
+                # dilation info
+                dilation = [1, 1]
+                if num_args > 4:
+                    dilation = list(cast(List[int], node.args[4]))
+                    if len(dilation) == 1:
+                        dilation *= 2
+
+                ceil_mode = node.args[5] if num_args > 5 else False
+
+                # We need to know the minimum value of the max_pool2d input tensor.
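+                # For float graphs -inf matches PyTorch's implicit padding
+                # exactly. For quantized graphs quant_min is the smallest
+                # encodable level, so its dequantized value, i.e.
+                # (quant_min - zero_point) * scale, is the closest available
+                # stand-in for -inf.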
+                pad_value = float("-inf")
+                if quant_attrs := node.meta.get("quant_attrs"):
+                    pad_value = (
+                        quant_attrs["quant_min"] - quant_attrs["zero_point"]
+                    ) * quant_attrs["scale"]
+
+                with graph_module.graph.inserting_after(input_node):
+                    padding_node = graph.create_node(
+                        "call_function",
+                        self.pad_op,
+                        (
+                            input_node,
+                            padding,
+                            pad_value,
+                        ),
+                    )
+                    add_fake_tensor_to_node(
+                        padding_node,
+                        input_node.meta["val"].shape,
+                        padding,
+                        input_node.meta["val"].dtype,
+                    )
+                    if quant_attrs:
+                        padding_node.meta["quant_attrs"] = node.meta["quant_attrs"]
+
+                with graph_module.graph.inserting_after(padding_node):
+                    # max_pool2d without implicit padding
+                    maxpool2d_args = (
+                        padding_node,
+                        filter_size,
+                        stride,
+                        (0, 0),
+                        dilation,
+                        ceil_mode,
+                    )
+                    maxpool2d_node_tuple = graph.create_node(
+                        "call_function",
+                        self.max_pool2d,
+                        maxpool2d_args,
+                    )
+                    if quant_attrs:
+                        maxpool2d_node_tuple.meta["quant_attrs"] = node.meta[
+                            "quant_attrs"
+                        ]
+                    maxpool2d_node_tuple.meta["val"] = [None, None]
+                    maxpool2d_node_tuple.meta["val"][0] = padding_node.meta["val"]
+
+                for user in node.users.copy():
+                    user.replace_input_with(node, maxpool2d_node_tuple)
+
+        graph.eliminate_dead_code()
+        graph_module.recompile()
+        return PassResult(graph_module, True)
diff --git a/backends/qualcomm/_passes/utils.py b/backends/qualcomm/_passes/utils.py
index e395511e438..72749a29544 100755
--- a/backends/qualcomm/_passes/utils.py
+++ b/backends/qualcomm/_passes/utils.py
@@ -75,6 +75,7 @@ def get_passes_dependency_for_capture_program():
         FoldQDQ,
         I64toI32,
         LayoutTransform,
+        RecomposePadMaxPool2d,
         RecomposePixelUnshuffle,
         RecomposeRmsNorm,
         RemoveRedundancy,
@@ -105,6 +106,7 @@
             ExpandBroadcastTensorShape,
             FixedLinearKeepDim,
         ],
+        RecomposePadMaxPool2d: [DecomposeMaxPool3d, FoldQDQ],
         RecomposePixelUnshuffle: [RemoveRedundancy],
         RecomposeRmsNorm: [RemoveRedundancy],
         TagQuantIO: [LayoutTransform],
diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py
index 006827a2785..931b9ee2731 100644
--- a/backends/qualcomm/tests/models.py
+++ b/backends/qualcomm/tests/models.py
@@ -1463,14 +1463,14 @@ def forward(self, x):
 
 
 class MaxPool2d(torch.nn.Module):
-    def __init__(self):
+    def __init__(self, kernel_size=3, stride=1, padding=1, ceil_mode=True):
         super().__init__()
         self.max_pool2d = torch.nn.MaxPool2d(
-            kernel_size=3,
-            stride=1,
-            padding=1,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
             dilation=1,
-            ceil_mode=True,
+            ceil_mode=ceil_mode,
         )
 
     def forward(self, x):
diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py
index 381524135a9..dab8c127c55 100644
--- a/backends/qualcomm/tests/test_qnn_delegate.py
+++ b/backends/qualcomm/tests/test_qnn_delegate.py
@@ -1403,9 +1403,16 @@ def test_qnn_backend_max_dim(self):
         self.lower_module_and_test_output(module, sample_input)
 
     def test_qnn_backend_max_pool2d(self):
-        module = MaxPool2d()  # noqa: F405
+        modules = [
+            MaxPool2d(3, 1, 0, True),  # noqa: F405
+            MaxPool2d(3, 1, 0, False),  # noqa: F405
+            MaxPool2d(3, 1, 1, True),  # noqa: F405
+            MaxPool2d(3, 1, 1, False),  # noqa: F405
+        ]
         sample_input = (torch.randn(4, 3, 24, 24),)
-        self.lower_module_and_test_output(module, sample_input)
+        for i, module in enumerate(modules):
+            with self.subTest(i=i):
+                self.lower_module_and_test_output(module, sample_input)
 
     def test_qnn_backend_max_pool3d(self):
         # NOTE: The pad should be at most half of effective kernel size.
@@ -3661,10 +3668,24 @@ def test_qnn_backend_max_dim(self):
         self.lower_module_and_test_output(module, sample_input)
 
     def test_qnn_backend_max_pool2d(self):
-        module = MaxPool2d()  # noqa: F405
+        modules = [
+            MaxPool2d(3, 1, 0, True),  # noqa: F405
+            MaxPool2d(3, 1, 0, False),  # noqa: F405
+            MaxPool2d(3, 1, 1, True),  # noqa: F405
+            MaxPool2d(3, 1, 1, False),  # noqa: F405
+        ]
+        test_quants = [QuantDtype.use_8a8w, QuantDtype.use_16a4w, QuantDtype.use_16a8w]
         sample_input = (torch.randn(4, 3, 24, 24),)
-        module = self.get_qdq_module(module, sample_input)
-        self.lower_module_and_test_output(module, sample_input)
+        test_pairs = [
+            (module, quant_type)
+            for module, quant_type in itertools.product(modules, test_quants)
+        ]
+        for i, (test_module, qtype) in enumerate(test_pairs):
+            with self.subTest(i=i):
+                qdq_module = self.get_qdq_module(
+                    test_module, sample_input, quant_dtype=qtype
+                )
+                self.lower_module_and_test_output(qdq_module, sample_input)
 
     def test_qnn_backend_max_pool3d(self):
         # NOTE: The pad should be at most half of effective kernel size.

From f3c142de491effcb8cf860594cf112006300d113 Mon Sep 17 00:00:00 2001
From: jethroqti
Date: Wed, 14 Jan 2026 23:25:39 -0800
Subject: [PATCH 2/2] Suppress flake8 C901 (too complex) warning on
 RecomposePadMaxPool2d.call

---
 backends/qualcomm/_passes/recompose_pad_maxpool2d.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/backends/qualcomm/_passes/recompose_pad_maxpool2d.py b/backends/qualcomm/_passes/recompose_pad_maxpool2d.py
index 4762d2beb53..287daa4f399 100644
--- a/backends/qualcomm/_passes/recompose_pad_maxpool2d.py
+++ b/backends/qualcomm/_passes/recompose_pad_maxpool2d.py
@@ -50,7 +50,7 @@ def __init__(self):
         self.max_pool2d = exir_ops.edge.aten.max_pool2d_with_indices.default
         self.pad_op = exir_ops.edge.aten.constant_pad_nd.default
 
-    def call(self, graph_module: torch.fx.GraphModule):
+    def call(self, graph_module: torch.fx.GraphModule):  # noqa: C901
         graph = graph_module.graph
         for node in graph.nodes:
             num_args = len(node.args)
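
The rewrite relies on the identity sketched below (a minimal runnable sketch
for the float case; the tensor and variable names are illustrative only, not
part of the patch):

    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 3, 24, 24)
    pooled = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
    # Pre-padding with -inf and pooling without implicit padding matches
    # PyTorch's built-in padded pooling exactly.
    padded = F.pad(x, [1, 1, 1, 1], value=float("-inf"))
    recomposed = F.max_pool2d(padded, kernel_size=3, stride=1, padding=0)
    assert torch.equal(pooled, recomposed)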