2 changes: 2 additions & 0 deletions backends/qualcomm/_passes/__init__.py
@@ -40,6 +40,7 @@
from .insert_reshape_for_reduce_ops import InsertReshapeForReduceOps
from .layout_transform import LayoutTransform
from .lift_constant_scalar_operands import LiftConstantScalarOperands
+from .recompose_pad_maxpool2d import RecomposePadMaxPool2d
from .recompose_pixel_unshuffle import RecomposePixelUnshuffle
from .recompose_rms_norm import RecomposeRmsNorm
from .reduce_dynamic_range import ReduceDynamicRange
@@ -87,6 +88,7 @@
InsertRequantize,
LayoutTransform,
LiftConstantScalarOperands,
+    RecomposePadMaxPool2d,
RecomposePixelUnshuffle,
RecomposeRmsNorm,
ReduceDynamicRange,
4 changes: 3 additions & 1 deletion backends/qualcomm/_passes/qnn_pass_manager.py
@@ -45,6 +45,7 @@
InsertReshapeForReduceOps,
LayoutTransform,
LiftConstantScalarOperands,
+    RecomposePadMaxPool2d,
RecomposePixelUnshuffle,
RecomposeRmsNorm,
ReduceDynamicRange,
@@ -93,13 +94,14 @@ def get_capture_program_passes():
(ConvertBmmToMatmul, False),
(DecomposeAny, True),
(DecomposeColIm, True),
+        (DecomposeMaxPool3d, True),
(DecomposeMinMaxDim, True),
(ExpandBroadcastTensorShape, True),
(FixedLinearKeepDim, True),
(FoldQDQ, True),
(I64toI32, True),
(LayoutTransform, True),
-        (DecomposeMaxPool3d, True),
+        (RecomposePadMaxPool2d, True),
(RecomposePixelUnshuffle, True),
(RecomposeRmsNorm, True),
(Remove0DTensor, True),
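For context, the tuples above are (pass, enabled-by-default) entries consumed by get_capture_program_passes(). A hedged sketch of toggling one default before capture; the dict-of-flags shape and the "activate" key are illustrative assumptions, not the verified API:

from executorch.backends.qualcomm._passes import RecomposePadMaxPool2d
from executorch.backends.qualcomm._passes.qnn_pass_manager import (
    get_capture_program_passes,
)

passes = get_capture_program_passes()
# Hypothetical flag key; the real constant lives in the QNN constants module.
passes[RecomposePadMaxPool2d]["activate"] = False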
147 changes: 147 additions & 0 deletions backends/qualcomm/_passes/recompose_pad_maxpool2d.py
@@ -0,0 +1,147 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import operator
from typing import cast, List

import torch

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult

from torch._subclasses.fake_tensor import FakeTensorMode


def add_fake_tensor_to_node(padding_node, input_shape, padding_args, dtype):
    # Attach a FakeTensor to the new pad node's meta so downstream passes can
    # read the padded output's shape and dtype without executing the graph.
    fake_mode = FakeTensorMode()

with fake_mode:
batch, channels, height, width = input_shape
pad_left, pad_right, pad_top, pad_bottom = padding_args
output_shape = (
batch,
channels,
height + pad_top + pad_bottom,
width + pad_left + pad_right,
)
fake_output = torch.empty(output_shape, dtype=dtype)
if not hasattr(padding_node, "meta"):
padding_node.meta = {}
padding_node.meta["val"] = fake_output

return fake_output


class RecomposePadMaxPool2d(ExportPass):
"""
The padding value used in max_pool2d operations differs between PyTorch and QNN implementations.
PyTorch uses negative infinity, while QNN uses zero. To ensure consistent max_pool2d output across both frameworks,
we handle this by padding tensor with constant in advance then doing max_pool2d without constant padding.
Note that for the quantization flow, we set quant_min as the padding value. If, at runtime, there is a value smaller than quant_min,
it could result in an accuracy drop.
"""

def __init__(self):
super(RecomposePadMaxPool2d, self).__init__()
self.getitem = operator.getitem
self.max_pool2d = exir_ops.edge.aten.max_pool2d_with_indices.default
self.pad_op = exir_ops.edge.aten.constant_pad_nd.default

    def call(self, graph_module: torch.fx.GraphModule):  # noqa: C901
graph = graph_module.graph
for node in graph.nodes:
num_args = len(node.args)
if (
node.op == "call_function"
and node.target == self.max_pool2d
and num_args > 3
):
                padding = cast(List[int], node.args[3])
                if len(padding) == 1:
                    padding = padding * 2  # avoid *=, which would mutate node.args
                if padding[0] == 0 and padding[1] == 0:
                    continue
                # convert (pad_h, pad_w) into the constant_pad_nd layout:
                # (left, right, top, bottom)
                padding = [padding[1], padding[1], padding[0], padding[0]]

input_node = node.args[0]
# kernel info
                filter_size = cast(List[int], node.args[1])
                if len(filter_size) == 1:
                    filter_size = filter_size * 2
                # stride info
                stride = cast(List[int], node.args[2])
                if len(stride) == 1:
                    stride = stride * 2
                # dilation info
                dilation = [1, 1]
                if num_args > 4:
                    dilation = cast(List[int], node.args[4])
                    if len(dilation) == 1:
                        dilation = dilation * 2

ceil_mode = node.args[5] if num_args > 5 else False

                # Use the minimum value of the max_pool2d input as the pad
                # value so a padded position can never win the max.
                padding_value = float("-inf")
                if quant_attrs := node.meta.get("quant_attrs"):
                    padding_value = quant_attrs.get("quant_min")
                pad_value = padding_value
                if quant_attrs:
                    # Dequantize quant_min so the constant lives in the
                    # floating-point domain that constant_pad_nd expects.
                    pad_value = (
                        padding_value - quant_attrs["zero_point"]
                    ) * quant_attrs["scale"]
with graph_module.graph.inserting_after(input_node):
padding_node = graph.create_node(
"call_function",
self.pad_op,
(
input_node,
padding,
pad_value,
),
)
add_fake_tensor_to_node(
padding_node,
input_node.meta["val"].shape,
padding,
input_node.meta["val"].dtype,
)
if quant_attrs:
padding_node.meta["quant_attrs"] = node.meta["quant_attrs"]

with graph_module.graph.inserting_after(padding_node):
# max_pool2d
maxpool2d_args = (
padding_node,
filter_size,
stride,
(0, 0),
dilation,
ceil_mode,
)
maxpool2d_node_tuple = graph.create_node(
"call_function",
self.max_pool2d,
maxpool2d_args,
)
if quant_attrs:
maxpool2d_node_tuple.meta["quant_attrs"] = node.meta[
"quant_attrs"
]
maxpool2d_node_tuple.meta["val"] = [None, None]
maxpool2d_node_tuple.meta["val"][0] = padding_node.meta["val"]

for user in node.users.copy():
user.replace_input_with(node, maxpool2d_node_tuple)

graph.eliminate_dead_code()
graph_module.recompile()
return PassResult(graph_module, True)
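A standalone sketch (not part of the PR; plain PyTorch only) illustrating the mismatch this pass works around: PyTorch's max_pool2d pads implicitly with negative infinity, so padded positions can never win the max, while QNN-style zero padding dominates any all-negative window.

import torch
import torch.nn.functional as F

x = torch.full((1, 1, 2, 2), -4.0)  # all-negative input

# PyTorch semantics: implicit -inf padding, every output value stays -4.0.
ref = F.max_pool2d(x, kernel_size=2, stride=1, padding=1)

# QNN-style zero padding: border windows now contain 0.0 and return it.
qnn_like = F.max_pool2d(F.pad(x, (1, 1, 1, 1), value=0.0), kernel_size=2, stride=1)

assert ref.max().item() == -4.0
assert qnn_like.max().item() == 0.0

Padding with the minimum representable value up front avoids this. In the quantized flow that value is quant_min dequantized via (quant_min - zero_point) * scale; for example, quant_min=0, zero_point=128, scale=0.1 gives a pad constant of (0 - 128) * 0.1 = -12.8.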
2 changes: 2 additions & 0 deletions backends/qualcomm/_passes/utils.py
@@ -75,6 +75,7 @@ def get_passes_dependency_for_capture_program():
FoldQDQ,
I64toI32,
LayoutTransform,
+    RecomposePadMaxPool2d,
RecomposePixelUnshuffle,
RecomposeRmsNorm,
RemoveRedundancy,
@@ -105,6 +106,7 @@ def get_passes_dependency_for_capture_program():
ExpandBroadcastTensorShape,
FixedLinearKeepDim,
],
+    RecomposePadMaxPool2d: [DecomposeMaxPool3d, FoldQDQ],
RecomposePixelUnshuffle: [RemoveRedundancy],
RecomposeRmsNorm: [RemoveRedundancy],
TagQuantIO: [LayoutTransform],
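The mapping above reads as pass -> passes that must run before it. A minimal illustrative sketch (assuming nothing about the real QnnPassManager scheduling) of turning such a dict into a legal execution order with Kahn's algorithm:

from collections import deque

def order_passes(deps):
    # deps maps each pass to the passes that must run before it; passes that
    # appear only as prerequisites get an implicit empty dependency set.
    everything = set(deps) | {p for prereqs in deps.values() for p in prereqs}
    indegree = {p: 0 for p in everything}
    dependents = {p: [] for p in everything}
    for node, prereqs in deps.items():
        for pre in prereqs:
            dependents[pre].append(node)
            indegree[node] += 1
    ready = deque(p for p in everything if indegree[p] == 0)
    order = []
    while ready:
        current = ready.popleft()
        order.append(current)
        for follower in dependents[current]:
            indegree[follower] -= 1
            if indegree[follower] == 0:
                ready.append(follower)
    return order

With the new entry, any such ordering places RecomposePadMaxPool2d after both DecomposeMaxPool3d and FoldQDQ.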
10 changes: 5 additions & 5 deletions backends/qualcomm/tests/models.py
@@ -1463,14 +1463,14 @@ def forward(self, x):


class MaxPool2d(torch.nn.Module):
-    def __init__(self):
+    def __init__(self, kernel_size=3, stride=1, padding=1, ceil_mode=True):
super().__init__()
self.max_pool2d = torch.nn.MaxPool2d(
-            kernel_size=3,
-            stride=1,
-            padding=1,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
dilation=1,
-            ceil_mode=True,
+            ceil_mode=ceil_mode,
)

def forward(self, x):
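A quick shape check for the parameterized module (hypothetical snippet, not part of the test suite):

import torch

module = MaxPool2d(kernel_size=3, stride=1, padding=0, ceil_mode=False)
out = module(torch.randn(4, 3, 24, 24))
print(out.shape)  # torch.Size([4, 3, 22, 22]); (24 - 3) // 1 + 1 = 22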
31 changes: 26 additions & 5 deletions backends/qualcomm/tests/test_qnn_delegate.py
@@ -1403,9 +1403,16 @@ def test_qnn_backend_max_dim(self):
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_max_pool2d(self):
-        module = MaxPool2d()  # noqa: F405
+        modules = [
+            MaxPool2d(3, 1, 0, True),  # noqa: F405
+            MaxPool2d(3, 1, 0, False),  # noqa: F405
+            MaxPool2d(3, 1, 1, True),  # noqa: F405
+            MaxPool2d(3, 1, 1, False),  # noqa: F405
+        ]
sample_input = (torch.randn(4, 3, 24, 24),)
-        self.lower_module_and_test_output(module, sample_input)
+        for i, module in enumerate(modules):
+            with self.subTest(i=i):
+                self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_max_pool3d(self):
# NOTE: The pad should be at most half of effective kernel size.
@@ -3661,10 +3668,24 @@ def test_qnn_backend_max_dim(self):
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_max_pool2d(self):
-        module = MaxPool2d()  # noqa: F405
+        modules = [
+            MaxPool2d(3, 1, 0, True),  # noqa: F405
+            MaxPool2d(3, 1, 0, False),  # noqa: F405
+            MaxPool2d(3, 1, 1, True),  # noqa: F405
+            MaxPool2d(3, 1, 1, False),  # noqa: F405
+        ]
+        test_quants = [QuantDtype.use_8a8w, QuantDtype.use_16a4w, QuantDtype.use_16a8w]
sample_input = (torch.randn(4, 3, 24, 24),)
-        module = self.get_qdq_module(module, sample_input)
-        self.lower_module_and_test_output(module, sample_input)
+        test_pairs = [
+            (module, quant_type)  # noqa: F405
+            for module, quant_type in itertools.product(modules, test_quants)
+        ]
+        for i, (test_module, qtype) in enumerate(test_pairs):
+            with self.subTest(i=i):
+                qdq_module = self.get_qdq_module(
+                    test_module, sample_input, quant_dtype=qtype
+                )
+                self.lower_module_and_test_output(qdq_module, sample_input)

def test_qnn_backend_max_pool3d(self):
# NOTE: The pad should be at most half of effective kernel size.