diff --git a/backends/test/harness/tester.py b/backends/test/harness/tester.py index bfcee310b69..52c661dd748 100644 --- a/backends/test/harness/tester.py +++ b/backends/test/harness/tester.py @@ -41,8 +41,12 @@ def __init__( example_inputs: Tuple[torch.Tensor], stage_classes: Dict[StageType, Callable] | None = None, dynamic_shapes: Optional[Tuple[Any]] = None, + training: bool = False, ): - module.eval() + if training: + module.train() + else: + module.eval() self.stage_classes = stage_classes or Tester.default_stage_classes() self.original_module = module diff --git a/backends/xnnpack/_passes/__init__.py b/backends/xnnpack/_passes/__init__.py index 4992d7a4abd..45560124f57 100644 --- a/backends/xnnpack/_passes/__init__.py +++ b/backends/xnnpack/_passes/__init__.py @@ -23,6 +23,7 @@ from executorch.backends.xnnpack._passes.convert_to_upsample_bilinear2d import ( ConvertToUpsampleBilinear2d, ) +from executorch.backends.xnnpack._passes.decompose_batch_norm import DecomposeBatchNorm from executorch.backends.xnnpack._passes.decompose_cat import DecomposeConcatenate from executorch.backends.xnnpack._passes.fuse_activation_pass import FuseActivationPass from executorch.backends.xnnpack._passes.fuse_batch_norm import FuseBatchNormPass @@ -76,6 +77,7 @@ def __init__( ConvertToSDPAPass, ConstPropPass, FuseBatchNormPass, + DecomposeBatchNorm, FuseActivationPass, DecomposeConcatenate, RemoveGetItemPass, diff --git a/backends/xnnpack/_passes/decompose_batch_norm.py b/backends/xnnpack/_passes/decompose_batch_norm.py new file mode 100644 index 00000000000..683682e9b23 --- /dev/null +++ b/backends/xnnpack/_passes/decompose_batch_norm.py @@ -0,0 +1,296 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import operator + +import torch +from executorch.backends.transforms.utils import create_constant_placeholder +from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass +from executorch.backends.xnnpack.utils.utils import ( + check_or_raise, + get_param_tensor, + get_tensor_name, + is_param_node, +) +from executorch.exir.backend.utils import WhyNoPartition +from executorch.exir.dialects._ops import ops as exir_ops +from torch.export.graph_signature import InputKind +from torch.fx.passes.infra.pass_base import PassResult + +logger = logging.getLogger(__name__) +logger.setLevel(logging.WARNING) + + +class DecomposeBatchNorm(XNNPACKPass): + """ + Decompose batchnorm operators into 1x1 depthwise convolution. + """ + + BATCH_NORM_OPS = { + exir_ops.edge.aten.native_batch_norm.default, + exir_ops.edge.aten._native_batch_norm_legit_no_training.default, + } + + @staticmethod + def can_decompose_batch_norm( # noqa: C901 + node: torch.fx.Node, + exported_program: torch.export.ExportedProgram, + why: WhyNoPartition | None = None, + ) -> bool: + """ + Determine whether the given batch norm node can be decomposed by this pass. + """ + + if ( + node.op != "call_function" + or node.target not in DecomposeBatchNorm.BATCH_NORM_OPS + ): + return False + + input_meta = node.args[0].meta["val"] + + # Since we're converting to conv and XNNPACK doesn't support conv3d, we can't + # handle BatchNorm3d. Validate the input dimension. We'll take NC, NCL, or NCHW. 
+ if input_meta.dim() not in (2, 3, 4): + if why: + why( + node, + f"Unsupported input rank {input_meta.dim()} for XNN batch norm operator.", + ) + return False + + # The batch norm node returns a tuple of output and other stuff we don't care about. + # All users must be getitem nodes that fetch the output (index 0). + # The partitioner should enforce this, but we'll check it here too. + for user in node.users: + if user.target != operator.getitem or user.args[1] != 0: + if why: + why(node, "Batch norm users must only access the output tensor.") + return False + + # Channel dimension and non-input args must be statically known. + if not isinstance(input_meta.shape[1], int): + if why: + why( + node, + f"Channel dimension must be statically known, but was {input_meta.shape[1]}.", + ) + return False + + if node.args[1] is not None and not is_param_node( + exported_program, node.args[1] + ): + if why: + why(node, "Batch norm affine weight must be static.") + return False + + if node.args[2] is not None and not is_param_node( + exported_program, node.args[2] + ): + if why: + why(node, "Batch norm affine bias must be static.") + return False + + if not is_param_node(exported_program, node.args[3]) or not is_param_node( + exported_program, node.args[4] + ): + if why: + why(node, "Batch norm running mean and variance must be static.") + return False + + if isinstance(node.args[-1], torch.fx.Node): + if why: + why(node, "Batch norm epsilon must be static.") + return False + + if ( + node.target == exir_ops.edge.aten.native_batch_norm.default + and node.args[5] is not False + ): + if why: + why(node, "Training batch norm is not supported.") + return False + + return True + + @staticmethod + def compute_w_and_b( + eps: float, + running_mean: torch.Tensor, # [C] + running_var: torch.Tensor, # [C] + gamma: torch.Tensor, # [C], learned weight + beta: torch.Tensor, # [C], learned bias + ) -> (torch.Tensor, torch.Tensor): + """ + Compute equivalent per-channel weight and bias to match the batch norm + computation with frozen values. + """ + + # See https://docs.pytorch.org/docs/stable/generated/torch.nn.BatchNorm1d.html + + # Do the math in double precision and convert back to the original dtype at the + # end. ATen kernels do this math in increased precision for float16. Note that + # all of the parameter dtypes must match, as per the ATen behavior. + + # Also note that gamma and beta can be None if affine=False. This is equivalent + # to gamma = 1 and beta = 0. + gamma_f64 = gamma.double() if gamma is not None else torch.Tensor([1]).double() + beta_f64 = beta.double() if beta is not None else torch.Tensor([0]).double() + running_mean_f64 = running_mean.double() + running_var_f64 = running_var.double() + + denom = torch.sqrt(running_var_f64 + torch.Tensor([eps])) + new_weight = gamma_f64 / denom + new_bias = -running_mean_f64 * gamma_f64 / denom + beta_f64 + + return new_weight.to(running_mean.dtype), new_bias.to(running_mean.dtype) + + def replace_bn_node_with_conv( + self, + bn_node: torch.fx.Node, + graph_module: torch.fx.GraphModule, + ) -> torch.fx.Node: + """ + Replace a BatchNorm with NCL or NCHW input with an equivalent depthwise + convolution. + """ + + # Compute the equivalent per-channel weights and biases. + # Note that the batch norm node args are + # (input, gamma, beta, running_mean, running_var, [training], momentum, eps). + # The training arg is not present in the _no_training variant. 
+        weight, bias = DecomposeBatchNorm.compute_w_and_b(
+            eps=bn_node.args[-1],
+            running_mean=get_param_tensor(self.exported_program, bn_node.args[3]),
+            running_var=get_param_tensor(self.exported_program, bn_node.args[4]),
+            gamma=get_param_tensor(self.exported_program, bn_node.args[1]),
+            beta=get_param_tensor(self.exported_program, bn_node.args[2]),
+        )
+
+        # Conv weights have shape [out_c, in_c/g, spatial...].
+        # For dw, in_c = g. The kernel is also 1x1 (or just 1, for 1d).
+        #
+        # BatchNorm weights have shape [in_c].
+        # So we just need to unsqueeze the [in_c] to [in_c, 1, 1, [1]].
+        input_meta = bn_node.args[0].meta["val"]
+        channel_count = input_meta.shape[1]
+        spatial_dims = max(
+            input_meta.dim() - 2, 1
+        )  # Min of 1 since 1d can be NC or NCL.
+        new_weight_shape = [weight.shape[0], 1] + [1] * spatial_dims
+        weight = weight.reshape(new_weight_shape)
+
+        # Generate names for the new weight and bias parameters based on the original
+        # batch norm gamma parameter name.
+        gamma_name = get_tensor_name(self.exported_program, bn_node.args[1])
+        weight_name = (gamma_name + "_decomposed_bn_weight").replace(".", "_")
+        bias_name = (gamma_name + "_decomposed_bn_bias").replace(".", "_")
+
+        # Insert the new weight and bias as constant placeholders in the graph.
+        with graph_module.graph.inserting_before(bn_node.args[1]):
+            weight_node = create_constant_placeholder(
+                exp_program=self.exported_program,
+                graph=graph_module.graph,
+                kind=InputKind.PARAMETER,
+                name=weight_name,
+                data=weight,
+            )
+            bias_node = create_constant_placeholder(
+                exp_program=self.exported_program,
+                graph=graph_module.graph,
+                kind=InputKind.PARAMETER,
+                name=bias_name,
+                data=bias,
+            )
+
+        with graph_module.graph.inserting_after(bn_node):
+            conv_node = graph_module.graph.call_function(
+                exir_ops.edge.aten.convolution.default,
+                args=(
+                    bn_node.args[0],  # Input
+                    weight_node,  # Weight
+                    bias_node,  # Bias
+                    [1] * spatial_dims,  # Stride
+                    [0] * spatial_dims,  # Padding
+                    [1] * spatial_dims,  # Dilation
+                    False,  # Transposed
+                    [0] * spatial_dims,  # Output_padding
+                    channel_count,  # Groups (depthwise, so groups=in_channels)
+                ),
+            )
+
+        # Find the getitem user nodes and replace them with the conv node.
+        # The decomp checks above enforce that the node is only used by getitem[0].
+        users = list(bn_node.users)
+        for user in users:
+            user.replace_all_uses_with(conv_node)
+            graph_module.graph.erase_node(user)
+
+        graph_module.graph.erase_node(bn_node)
+        return conv_node
+
+    def decompose_node(
+        self, node: torch.fx.Node, graph_module: torch.fx.GraphModule
+    ) -> None:
+        input_meta = node.args[0].meta["val"]
+
+        # These conditions are already checked by the partitioner and the caller,
+        # so these checks should never fail.
+        check_or_raise(
+            node.op == "call_function"
+            and node.target in DecomposeBatchNorm.BATCH_NORM_OPS,
+            f"Invalid batch norm operator {node.op}.",
+        )
+
+        check_or_raise(
+            input_meta.dim() in (2, 3, 4),
+            f"Unsupported input rank {input_meta.dim()} for XNN batch norm operator.",
+        )
+
+        channel_count = input_meta.shape[1]
+        check_or_raise(
+            isinstance(channel_count, int),
+            f"Channel dimension must be statically known, but was {channel_count}.",
+        )
+
+        # Create the convolution node.
+        conv_node = self.replace_bn_node_with_conv(node, graph_module)
+
+        # BatchNorm1d can be NC or NCL. Conv1d requires the L dim, so unsqueeze NC -> NCL.
+        if input_meta.dim() == 2:
+            with graph_module.graph.inserting_before(conv_node):
+                # Insert unsqueeze node before.
+                unsqueeze_node = graph_module.graph.call_function(
+                    exir_ops.edge.aten.unsqueeze_copy.default,
+                    args=(conv_node.args[0], 2),
+                )
+                conv_node.args = (unsqueeze_node, *conv_node.args[1:])
+
+            with graph_module.graph.inserting_after(conv_node):
+                # Insert squeeze node after.
+                squeeze_node = graph_module.graph.call_function(
+                    exir_ops.edge.aten.squeeze_copy.dim, args=(conv_node, 2)
+                )
+                conv_node.replace_all_uses_with(squeeze_node)
+                # replace_all_uses_with also rewrites the squeeze node's own input
+                # (conv_node -> squeeze_node), so restore the original argument here.
+                squeeze_node.args = (conv_node, *squeeze_node.args[1:])
+
+    # override
+    def call(self, graph_module: torch.fx.GraphModule):
+        # Find and transform all eligible batch norm nodes.
+        for node in graph_module.graph.nodes:
+            if node.op == "call_function" and node.target in self.BATCH_NORM_OPS:
+                if self.can_decompose_batch_norm(node, self.exported_program):
+                    self.decompose_node(node, graph_module)
+
+        graph_module.recompile()
+
+        # Propagate metadata and retrace the module.
+        graph_module = super().call(graph_module).graph_module
+
+        return PassResult(graph_module, True)
diff --git a/backends/xnnpack/_passes/fuse_batch_norm.py b/backends/xnnpack/_passes/fuse_batch_norm.py
index a51920ed5ad..76ecce91585 100644
--- a/backends/xnnpack/_passes/fuse_batch_norm.py
+++ b/backends/xnnpack/_passes/fuse_batch_norm.py
@@ -11,19 +11,17 @@
     create_constant_placeholder,
     delete_constant_placeholder,
 )
-
 from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass
-
 from executorch.backends.xnnpack.utils.utils import (
     get_param_tensor,
     get_tensor_name,
     is_param_node,
 )
 from executorch.exir import ExportedProgram
+from executorch.exir.backend.utils import WhyNoPartition
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import PassResult
 from torch.export.graph_signature import InputKind
-
 from torch.nn.utils.fusion import fuse_conv_bn_weights, fuse_linear_bn_weights
@@ -81,15 +79,27 @@ def call(self, graph_module: torch.fx.GraphModule):
         return PassResult(graph_module, True)

     @staticmethod
-    def can_fuse(
+    def can_fuse(  # noqa: C901
         input_node: torch.fx.Node,
         bn: torch.fx.Node,
         program: ExportedProgram,
+        why: WhyNoPartition | None = None,
     ) -> bool:
         """
         Determine whether a BatchNorm node can be fused with the preceding convolution
         or linear node.
         """
+        if input_node.op != "call_function":
+            return False
+
+        if input_node.target not in (
+            exir_ops.edge.aten.convolution.default,
+            exir_ops.edge.aten.linear.default,
+        ):
+            if why:
+                why(bn, "Input node must be a convolution or linear op.")
+            return False
+
         is_conv = input_node.target == exir_ops.edge.aten.convolution.default

         # All users of the batch_norm node must be getitem ops.
@@ -98,6 +108,8 @@ def can_fuse(
         if [
             (user.target == operator.getitem and user.args[1] == 0) for user in bn.users
         ].count(False):
+            if why:
+                why(bn, "Batch norm users must only access the output tensor.")
             return False

         input_node_weights = input_node.args[1]
@@ -107,11 +119,15 @@ def can_fuse(
         if not isinstance(input_node_weights, torch.fx.Node) or not isinstance(
             bn_weights, torch.fx.Node
         ):
+            if why:
+                why(bn, "Input node weights must be parameters.")
             return False

         if [
             is_param_node(program, node) for node in {input_node_weights, bn_weights}
         ].count(False):
+            if why:
+                why(bn, "Node weights must be static.")
             return False

         # Check the rank of the convolution input - only Conv1d and 2d are supported.
@@ -122,6 +138,8 @@ def can_fuse(
             or "val" not in conv_input.meta
             or len(conv_input.meta["val"].shape) not in (3, 4)
         ):
+            if why:
+                why(bn, "Convolution input must be rank 3 or 4.")
             return False

         return True
diff --git a/backends/xnnpack/operators/op_squeeze.py b/backends/xnnpack/operators/op_squeeze.py
index 7a21fe9e551..3fd5a692e0c 100644
--- a/backends/xnnpack/operators/op_squeeze.py
+++ b/backends/xnnpack/operators/op_squeeze.py
@@ -36,8 +36,9 @@ def define_node(
         debug_handle: int,
     ) -> None:

+        dim = cast(int, node.args[1])
         check_or_raise(
-            cast(int, node.args[1]) == -1,
+            dim == -1 or dim == len(node.args[0].meta["val"].shape) - 1,
             "XNNPACK currently only supports squeezing in last dimension",
         )

@@ -98,8 +99,9 @@ def define_node(
         debug_handle: int,
     ) -> None:

+        dim = cast(int, node.args[1])
         check_or_raise(
-            cast(int, node.args[1]) == -1,
+            dim == -1 or dim == len(node.args[0].meta["val"].shape),
             "XNNPACK currently only supports unsqueezing in last dimension",
         )

diff --git a/backends/xnnpack/partition/config/node_configs.py b/backends/xnnpack/partition/config/node_configs.py
index 4659ea05a0f..bffe87d30c5 100644
--- a/backends/xnnpack/partition/config/node_configs.py
+++ b/backends/xnnpack/partition/config/node_configs.py
@@ -9,15 +9,13 @@
 from typing import List, Optional

 import torch
+from executorch.backends.xnnpack._passes.decompose_batch_norm import DecomposeBatchNorm
 from executorch.backends.xnnpack._passes.fuse_batch_norm import FuseBatchNormPass
 from executorch.backends.xnnpack.partition.config.xnnpack_config import (
     ConfigPrecisionType,
     XNNPartitionerConfig,
 )
 from executorch.backends.xnnpack.utils.utils import is_param_node
-from executorch.exir.backend.canonical_partitioners.config_partitioner import (
-    format_target_name,
-)
 from executorch.exir.backend.utils import WhyNoPartition
 from torch.export import ExportedProgram

@@ -35,18 +33,11 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
         bn = node
         input_node = node.all_input_nodes[0]

-        if input_node.op != "call_function":
-            return False
-
-        input_name = format_target_name(input_node.target.__name__)  # pyre-ignore
-
-        if input_name not in ["convolution.default", "linear.default"]:
-            why(node, f"Invalid input target {input_name.split('.')[0]}")
-            return False
-
+        can_decompose = DecomposeBatchNorm.can_decompose_batch_norm(node, ep, why)
         can_fuse = FuseBatchNormPass.can_fuse(input_node, bn, ep)
-        if not can_fuse:
-            why(node, f"BatchNorm cannot be fused with {input_name.split('.')[0]}")
+
+        if not can_fuse and not can_decompose:
+            why(node, f"BatchNorm cannot be decomposed or fused with {input_node}")
             return False

         return True
diff --git a/backends/xnnpack/test/ops/test_batch_norm.py b/backends/xnnpack/test/ops/test_batch_norm.py
new file mode 100644
index 00000000000..2391d781d9a
--- /dev/null
+++ b/backends/xnnpack/test/ops/test_batch_norm.py
@@ -0,0 +1,380 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+import torch
+from executorch.backends.xnnpack.test.test_xnnpack_utils import randomize_bn
+from executorch.backends.xnnpack.test.tester import Tester
+
+
+class TestBatchNorm(unittest.TestCase):
+    """
+    End-to-end tests for standalone BatchNorm operators lowered to XNNPACK.
+ """ + + def setUp(self): + torch._dynamo.reset() + + class BatchNorm1dNC(torch.nn.Module): + """BatchNorm1d with NC input (batch, channels).""" + + def __init__(self, num_features: int): + super().__init__() + self.num_features = num_features + self.bn = torch.nn.BatchNorm1d(num_features) + + def forward(self, x): + return self.bn(x) + + def get_inputs(self): + return (torch.randn(2, self.num_features),) + + class BatchNorm1dNCL(torch.nn.Module): + """BatchNorm1d with NCL input (batch, channels, length).""" + + def __init__(self, num_features: int): + super().__init__() + self.num_features = num_features + self.bn = torch.nn.BatchNorm1d(num_features) + + def forward(self, x): + return self.bn(x) + + def get_inputs(self): + return (torch.randn(2, self.num_features, 8),) + + class BatchNorm2d(torch.nn.Module): + """BatchNorm2d with NCHW input (batch, channels, height, width).""" + + def __init__( + self, + num_features: int, + dtype: torch.dtype = torch.float, + affine: bool = True, + ): + super().__init__() + self.num_features = num_features + self.dtype = dtype + self.bn = torch.nn.BatchNorm2d(num_features, affine=affine).to(dtype) + + def forward(self, x): + return self.bn(x) + + def get_inputs(self): + return (torch.randn(2, self.num_features, 4, 4).to(self.dtype),) + + def _test_batch_norm(self, model: torch.nn.Module): + """ + Test that a standalone BatchNorm is lowered to XNNPACK via decomposition + to depthwise convolution. + """ + # Warm up batch norm running stats + model.eval() + with torch.no_grad(): + for _ in range(5): + model(*model.get_inputs()) + + ( + Tester(model, model.get_inputs()) + .export() + .to_edge_transform_and_lower() + .check_not( + [ + "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default" + ] + ) + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + .serialize() + .run_method_and_compare_outputs() + ) + + class LinearReluBatchNorm(torch.nn.Module): + """ + Linear followed by ReLU, BatchNorm, residual add, and a second Linear. + The BatchNorm is standalone (not fused) because ReLU breaks the fusion pattern. + """ + + def __init__(self, features: int): + super().__init__() + self.features = features + self.linear1 = torch.nn.Linear(features, features) + self.relu = torch.nn.ReLU() + self.bn = randomize_bn(features, dimensionality=1) + self.linear2 = torch.nn.Linear(features, features) + + def forward(self, x): + y = self.linear1(x) + y = self.relu(y) + y = self.bn(y) + y = y + x + y = self.linear2(y) + return y + + def get_inputs(self): + return (torch.randn(2, self.features),) + + def test_fp32_linear_relu_batch_norm(self): + """ + Test Linear + ReLU + BatchNorm where the BatchNorm is standalone (not fused + with linear) because ReLU breaks the fusion pattern. The standalone BatchNorm + should be decomposed to depthwise convolution. 
+ """ + model = self.LinearReluBatchNorm(features=8) + model.eval() + + ( + Tester(model, model.get_inputs()) + .export() + .to_edge_transform_and_lower() + # BatchNorm should be decomposed (not present in the graph) + .check_not( + [ + "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default" + ] + ) + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + .serialize() + .run_method_and_compare_outputs() + ) + + def test_fp32_batch_norm_nc(self): + """Test BatchNorm1d with NC input is lowered to XNNPACK.""" + self._test_batch_norm(self.BatchNorm1dNC(num_features=3)) + + def test_fp32_batch_norm_ncl(self): + """Test BatchNorm1d with NCL input is lowered to XNNPACK.""" + self._test_batch_norm(self.BatchNorm1dNCL(num_features=3)) + + def test_fp32_batch_norm_nchw(self): + """Test BatchNorm2d with NCHW input is lowered to XNNPACK.""" + self._test_batch_norm(self.BatchNorm2d(num_features=3)) + + def test_fp16_batch_norm_nchw(self): + """Test BatchNorm2d with fp16 NCHW input is lowered to XNNPACK.""" + self._test_batch_norm(self.BatchNorm2d(num_features=3, dtype=torch.float16)) + + def test_fp32_batch_norm_nchw_non_affine(self): + """Test non-affine BatchNorm2d with NCHW input is lowered to XNNPACK.""" + self._test_batch_norm(self.BatchNorm2d(num_features=3, affine=False)) + + class BatchNorm2dChannelsLast(torch.nn.Module): + """BatchNorm2d with channels_last memory format input.""" + + def __init__(self, num_features: int): + super().__init__() + self.num_features = num_features + self.bn = torch.nn.BatchNorm2d(num_features) + + def forward(self, x): + return self.bn(x) + + def get_inputs(self): + return ( + torch.randn(2, self.num_features, 4, 4).to( + memory_format=torch.channels_last + ), + ) + + def test_fp32_batch_norm_nchw_channels_last(self): + """Test BatchNorm2d with channels_last memory format input is lowered to XNNPACK.""" + self._test_batch_norm(self.BatchNorm2dChannelsLast(num_features=3)) + + class BatchNorm3d(torch.nn.Module): + """BatchNorm3d with NCDHW input (batch, channels, depth, height, width).""" + + def __init__(self, num_features: int): + super().__init__() + self.num_features = num_features + self.bn = torch.nn.BatchNorm3d(num_features) + + def forward(self, x): + return self.bn(x) + + def get_inputs(self): + return (torch.randn(2, self.num_features, 4, 4, 4),) + + def test_fp32_batch_norm3d_not_partitioned(self): + """Test that BatchNorm3d is NOT partitioned to XNNPACK (unsupported).""" + model = self.BatchNorm3d(num_features=3) + model.eval() + with torch.no_grad(): + for _ in range(5): + model(*model.get_inputs()) + + ( + Tester(model, model.get_inputs()) + .export() + .to_edge_transform_and_lower() + # BatchNorm3d should remain in the graph (not lowered to XNNPACK) + .check( + [ + "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default" + ] + ) + # No delegate call should be present since nothing was partitioned + .check_not(["torch.ops.higher_order.executorch_call_delegate"]) + .to_executorch() + .serialize() + .run_method_and_compare_outputs() + ) + + class Conv2dReluBatchNorm(torch.nn.Module): + """Conv2d followed by ReLU and then BatchNorm (standalone BN, not fused).""" + + def __init__(self, in_channels: int, out_channels: int): + super().__init__() + self.in_channels = in_channels + self.conv = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, padding=1 + ) + self.relu = torch.nn.ReLU() + self.bn = randomize_bn(out_channels) + + def forward(self, x): + x 
= self.conv(x) + x = self.relu(x) + x = self.bn(x) + return x + + def get_inputs(self): + return (torch.randn(2, self.in_channels, 8, 8),) + + def test_fp32_conv2d_relu_batch_norm(self): + """ + Test Conv2d + ReLU + BatchNorm where the BatchNorm is standalone (not fused + with conv) because ReLU breaks the fusion pattern. The standalone BatchNorm + should be decomposed to depthwise convolution. + """ + model = self.Conv2dReluBatchNorm(in_channels=3, out_channels=8) + model.eval() + + ( + Tester(model, model.get_inputs()) + .export() + .to_edge_transform_and_lower() + # BatchNorm should be decomposed (not present in the graph) + .check_not( + [ + "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default" + ] + ) + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + .serialize() + .run_method_and_compare_outputs() + ) + + class Conv2dBatchNorm(torch.nn.Module): + """Conv2d followed by BatchNorm (fuseable pattern).""" + + def __init__(self, in_channels: int, out_channels: int): + super().__init__() + self.in_channels = in_channels + self.conv = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, padding=1 + ) + self.bn = randomize_bn(out_channels) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + return x + + def get_inputs(self): + return (torch.randn(2, self.in_channels, 8, 8),) + + def test_fp32_conv2d_batch_norm_fused(self): + """ + Test Conv2d + BatchNorm where the BatchNorm is fused into the Conv2d. + This tests the existing fusion path (not decomposition). + """ + model = self.Conv2dBatchNorm(in_channels=3, out_channels=8) + model.eval() + + ( + Tester(model, model.get_inputs()) + .export() + .to_edge_transform_and_lower() + # BatchNorm should be fused into conv (not present in the graph) + .check_not( + [ + "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default" + ] + ) + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + .serialize() + .run_method_and_compare_outputs() + ) + + class Conv2dBatchNormChannelsLast(torch.nn.Module): + """Conv2d followed by BatchNorm (fuseable pattern) with channels_last input.""" + + def __init__(self, in_channels: int, out_channels: int): + super().__init__() + self.in_channels = in_channels + self.conv = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, padding=1 + ) + self.bn = randomize_bn(out_channels) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + return x + + def get_inputs(self): + return ( + torch.randn(2, self.in_channels, 8, 8).to( + memory_format=torch.channels_last + ), + ) + + def test_fp32_conv2d_batch_norm_fused_channels_last(self): + """ + Test Conv2d + BatchNorm with channels_last input where the BatchNorm is + fused into the Conv2d. 
+ """ + model = self.Conv2dBatchNormChannelsLast(in_channels=3, out_channels=8) + model.eval() + + ( + Tester(model, model.get_inputs()) + .export() + .to_edge_transform_and_lower() + # BatchNorm should be fused into conv (not present in the graph) + .check_not( + [ + "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default" + ] + ) + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + .serialize() + .run_method_and_compare_outputs() + ) + + def test_training_bn_not_partitioned(self): + """Test that training mode BatchNorm is not partitioned.""" + model = self.BatchNorm2d(num_features=3) + for _ in range(5): + model(*model.get_inputs()) + + ( + Tester(model, model.get_inputs(), training=True) + .export() + .to_edge_transform_and_lower() + .check( + [ + "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_functional" + ] + ) + .check_count({"torch.ops.higher_order.executorch_call_delegate": 0}) + .run_method_and_compare_outputs() + ) diff --git a/backends/xnnpack/test/passes/test_batch_norm_fusion.py b/backends/xnnpack/test/passes/test_batch_norm_fusion.py index 4e4cd065fa5..15c96035d91 100644 --- a/backends/xnnpack/test/passes/test_batch_norm_fusion.py +++ b/backends/xnnpack/test/passes/test_batch_norm_fusion.py @@ -104,30 +104,6 @@ def test_q8_conv_batch_norm_fusion(self): .run_method_and_compare_outputs() ) - def test_fp32_conv_batch_norm_no_fusion_doesnt_partition(self): - """ - We do not currently support standalone batch norms (i.e. batch norms that are - not fused with a conv). This is planned, but until implemented, this test ensures - that we do not partition the standalone batch norm and then fail to lower. - """ - - class BN(torch.nn.Module): - def __init__(self): - super().__init__() - self.bn = torch.nn.BatchNorm2d(2) - - def forward(self, x): - return self.bn(x) - - ( - Tester(BN(), (torch.randn(2, 2, 4, 4),)) - .export() - .to_edge() - .check_count({self.bn_name: 1}) - .partition() - .check_count({self.bn_name: 1}) - ) - def test_fp32_linear_batch_norm_fusion(self): for bias in [True, False]: ( @@ -137,32 +113,10 @@ def test_fp32_linear_batch_norm_fusion(self): ) .export() .to_edge_transform_and_lower() - .check_count({self.bn_name: 1}) + .check_count({self.bn_name: 0}) .run_method_and_compare_outputs() ) - def test_fp32_linear_batch_norm_no_fusion_doesnt_partition(self): - """ - We do not currently support standalone batch norms (i.e. batch norms that are - not fused with a linear). This is planned, but until implemented, this test ensures - that we do not partition the standalone batch norm and then fail to lower. - """ - - class BN(torch.nn.Module): - def __init__(self): - super().__init__() - self.bn = torch.nn.BatchNorm1d(2) - - def forward(self, x): - return self.bn(x) - - ( - Tester(BN(), (torch.randn(2, 2),)) - .export() - .to_edge_transform_and_lower() - .check_count({self.bn_name: 1}) - ) - def test_fp32_conv3d_batch_norm_doesnt_partition(self): """ Conv3d is not currently supported by XNNPACK. We also don't support standalone diff --git a/backends/xnnpack/test/passes/test_decompose_batch_norm.py b/backends/xnnpack/test/passes/test_decompose_batch_norm.py new file mode 100644 index 00000000000..d2897258fcd --- /dev/null +++ b/backends/xnnpack/test/passes/test_decompose_batch_norm.py @@ -0,0 +1,196 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +import torch +from executorch.backends.xnnpack._passes.decompose_batch_norm import DecomposeBatchNorm +from executorch.backends.xnnpack.test.tester import RunPasses, Tester +from executorch.exir import EdgeProgramManager +from executorch.exir.dialects._ops import ops as exir_ops + + +class TestDecomposeBatchNorm(unittest.TestCase): + PassStage = RunPasses([DecomposeBatchNorm]) + bn_name = "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default" + conv_name = "executorch_exir_dialects_edge__ops_aten_convolution_default" + + def setUp(self): + torch._dynamo.reset() + + class BatchNorm1dNC(torch.nn.Module): + """Simple BatchNorm1d module with NC input (no spatial dimension).""" + + def __init__(self, num_features: int): + super().__init__() + self.bn = torch.nn.BatchNorm1d(num_features) + # Run a forward pass to update the BN running stats. + self.forward(torch.randn(2, num_features) * 2 + 2) + + def forward(self, x): + return self.bn(x) + + class BatchNorm1dNCL(torch.nn.Module): + """Simple BatchNorm1d module with NCL input.""" + + def __init__(self, num_features: int): + super().__init__() + self.bn = torch.nn.BatchNorm1d(num_features) + # Run a forward pass to update the BN running stats. + self.forward(torch.randn(2, num_features, 4) * 2 + 2) + + def forward(self, x): + return self.bn(x) + + class BatchNorm2d(torch.nn.Module): + """Simple BatchNorm2d module with NCHW input.""" + + def __init__(self, num_features: int, affine: bool = True): + super().__init__() + self.bn = torch.nn.BatchNorm2d(num_features, affine=affine) + # Run a forward pass to update the BN running stats. 
+ self.forward(torch.randn(2, num_features, 4, 4) * 2 + 2) + + def forward(self, x): + return self.bn(x) + + def test_fp32_batch_norm_nc(self): + """Test that BatchNorm1d with NC input is decomposed to convolution.""" + model = self.BatchNorm1dNC(3).eval() + tester = ( + Tester( + model, + (torch.randn(2, 3),), + ) + .export() + .to_edge() + .check_count({self.bn_name: 1}) + .run_passes(self.PassStage) + .check_count({self.conv_name: 1}) + .check_not([self.bn_name]) + .run_method_and_compare_outputs() + ) + self._validate_decomposition(tester.get_artifact(), torch.float32, 3, 1) + + def test_fp32_batch_norm_ncl(self): + """Test that BatchNorm1d with NCL input is decomposed to convolution.""" + model = self.BatchNorm1dNCL(3).eval() + tester = ( + Tester( + model, + (torch.randn(2, 3, 4),), + ) + .export() + .to_edge() + .check_count({self.bn_name: 1}) + .run_passes(self.PassStage) + .check_count({self.conv_name: 1}) + .check_not([self.bn_name]) + .run_method_and_compare_outputs() + ) + self._validate_decomposition(tester.get_artifact(), torch.float32, 3, 1) + + def test_fp32_batch_norm_nchw(self): + """Test that BatchNorm2d with NCHW input is decomposed to convolution.""" + model = self.BatchNorm2d(3).eval() + tester = ( + Tester( + model, + (torch.randn(2, 3, 4, 4),), + ) + .export() + .to_edge() + .check_count({self.bn_name: 1}) + .run_passes(self.PassStage) + .check_count({self.conv_name: 1}) + .check_not([self.bn_name]) + .run_method_and_compare_outputs() + ) + self._validate_decomposition(tester.get_artifact(), torch.float32, 3, 2) + + def test_fp16_batch_norm_nchw(self): + """Test that BatchNorm2d with NCHW input is decomposed to convolution.""" + model = self.BatchNorm2d(3).to(torch.float16).eval() + tester = ( + Tester( + model, + (torch.randn(2, 3, 4, 4, dtype=torch.float16),), + ) + .export() + .to_edge() + .check_count({self.bn_name: 1}) + .run_passes(self.PassStage) + .check_count({self.conv_name: 1}) + .check_not([self.bn_name]) + .run_method_and_compare_outputs() + ) + self._validate_decomposition(tester.get_artifact(), torch.float16, 3, 2) + + def test_fp32_batch_norm_nchw_non_affine(self): + """Test that non-affine BatchNorm2d with NCHW input is decomposed to convolution.""" + model = self.BatchNorm2d(3, affine=False).eval() + tester = ( + Tester( + model, + (torch.randn(2, 3, 4, 4),), + ) + .export() + .to_edge() + .check_count({self.bn_name: 1}) + .run_passes(self.PassStage) + .check_count({self.conv_name: 1}) + .check_not([self.bn_name]) + .run_method_and_compare_outputs() + ) + self._validate_decomposition(tester.get_artifact(), torch.float32, 3, 2) + + def _validate_decomposition( + self, + edge_manager: EdgeProgramManager, + dtype: torch.dtype, + num_channels: int, + spatial_dims: int, + ): + # Verify that the graph contains a 1x1 depthwise convolution and that + # the transformed parameter dtypes match the original. + + conv_node = next( + n + for n in edge_manager.exported_program().graph.nodes + if n.target == exir_ops.edge.aten.convolution.default + ) + self.assertEqual(conv_node.meta["val"].dtype, dtype) + + self.assertEqual(len(conv_node.args), 9) + ( + _, + w_node, + b_node, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + ) = conv_node.args + + # Check the convolution parameters. It should be 1x1 depthwise convolution. 
+ self.assertEqual(stride, [1] * spatial_dims) + self.assertEqual(padding, [0] * spatial_dims) + self.assertEqual(dilation, [1] * spatial_dims) + self.assertEqual(transposed, False) + self.assertEqual(output_padding, [0] * spatial_dims) + self.assertEqual(groups, num_channels) + + w_meta = w_node.meta["val"] + b_meta = b_node.meta["val"] + + # Weight should be (out_c, in_c/g, kH, [kW]) + # Bias should be (out_c) + self.assertEqual(w_meta.shape, tuple([num_channels, 1] + [1] * spatial_dims)) + self.assertEqual(w_meta.dtype, dtype) + self.assertEqual(b_meta.shape, (num_channels,)) + self.assertEqual(b_meta.dtype, dtype) diff --git a/backends/xnnpack/test/tester/tester.py b/backends/xnnpack/test/tester/tester.py index 62eb504faa7..b3ba707ec96 100644 --- a/backends/xnnpack/test/tester/tester.py +++ b/backends/xnnpack/test/tester/tester.py @@ -107,6 +107,7 @@ def __init__( module: torch.nn.Module, example_inputs: Tuple[torch.Tensor], dynamic_shapes: Optional[Tuple[Any]] = None, + **kwargs, ): # Specialize for XNNPACK stage_classes = ( @@ -127,4 +128,5 @@ def __init__( stage_classes=stage_classes, example_inputs=example_inputs, dynamic_shapes=dynamic_shapes, + **kwargs, )
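
Note: the snippet below is a minimal, standalone eager-mode sketch (not part of the diff) of the rewrite that DecomposeBatchNorm.compute_w_and_b and replace_bn_node_with_conv implement: a frozen batch norm is equivalent to a 1x1 depthwise convolution whose per-channel weight is gamma / sqrt(running_var + eps) and whose bias is beta - running_mean * gamma / sqrt(running_var + eps). All names in the snippet are local to the example.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
channels = 3

# Populate non-trivial running stats, then freeze the module, mirroring the
# eval-mode batch norms that the pass targets.
bn = torch.nn.BatchNorm2d(channels)
bn.train()
for _ in range(5):
    bn(torch.randn(2, channels, 4, 4) * 2 + 2)
bn.eval()

# Equivalent per-channel scale and shift, as in compute_w_and_b.
denom = torch.sqrt(bn.running_var + bn.eps)
weight = (bn.weight / denom).reshape(channels, 1, 1, 1)  # [out_c, in_c/groups, kH, kW]
bias = bn.bias - bn.running_mean * bn.weight / denom

x = torch.randn(2, channels, 4, 4)
with torch.no_grad():
    expected = bn(x)
    # Depthwise 1x1 convolution: groups == channel count.
    actual = F.conv2d(x, weight, bias, stride=1, padding=0, groups=channels)
print(torch.allclose(expected, actual, atol=1e-5))  # True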