Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions backends/transforms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,15 @@ def create_constant_placeholder(

target = name

# If a placeholder with this target already exists, return it to avoid
# duplicate parameter names in the generated function signature which would
# cause a SyntaxError on recompile. This can happen when multiple pattern
# replacements independently create placeholders for a shared weight.
if name in exp_program.state_dict or name in exp_program.constants:
for n in graph.nodes:
if n.op == "placeholder" and n.target == name:
return n

# Add data to state_dict/ constants
match kind:
case InputKind.PARAMETER:
Expand Down
99 changes: 99 additions & 0 deletions backends/vulkan/custom_ops_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -685,6 +685,105 @@ def q8ta_conv2d_dw(
lib.impl(name, q8ta_conv2d_dw, "CompositeExplicitAutograd")
conv2d_q8ta_q8csw_dw_op = getattr(getattr(torch.ops, namespace), name)


def q8ta_conv2d_transposed(
x: torch.Tensor,
input_scale: float,
input_zero_point: int,
weights: torch.Tensor,
weight_sums: torch.Tensor,
weight_scales: torch.Tensor,
output_scale: float,
output_zero_point: int,
bias: Optional[torch.Tensor],
kernel_size: list,
stride: list,
padding: list,
output_padding: list,
dilation: list,
groups: int,
activation: str,
):
x = torch.ops.quantized_decomposed.dequantize_per_tensor(
x, input_scale, input_zero_point, -128, 127, x.dtype
)

OC = weights.shape[0]
IC_per_group = int(x.shape[1] / groups)
K_h, K_w = kernel_size[0], kernel_size[1]

orig_weight_K_dim = K_h * K_w * IC_per_group
if weights.shape[-1] > orig_weight_K_dim:
weights = weights[:, :orig_weight_K_dim]

if weight_scales.shape[0] > OC:
weight_scales = weight_scales[:OC]
if bias is not None:
bias = bias[:OC]

# Reshape to (OC, IC_per_group, K_h, K_w) then transpose to
# (IC_per_group * groups, OC_per_group, K_h, K_w) for conv_transpose2d
weights = weights.view(OC, IC_per_group, K_h, K_w)
OC_per_group = OC // groups
weights = (
weights.view(groups, OC_per_group, IC_per_group, K_h, K_w)
.permute(0, 2, 1, 3, 4)
.contiguous()
.view(IC_per_group * groups, OC_per_group, K_h, K_w)
)

weight_zeros = torch.zeros_like(weight_scales, dtype=torch.int32)
# Dequantize per OC channel. For transposed weight (IC, OC_per_group, KH, KW),
# OC is at axis=1.
weights = torch.ops.quantized_decomposed.dequantize_per_channel(
weights,
weight_scales[:OC_per_group].repeat(groups) if groups > 1 else weight_scales,
weight_zeros[:OC_per_group].repeat(groups) if groups > 1 else weight_zeros,
1,
-127,
127,
torch.int8,
)

out = torch.nn.functional.conv_transpose2d(
x, weights, bias, stride, padding, output_padding, groups, dilation
)

if activation == "relu":
out = torch.nn.functional.relu(out)

out = torch.ops.quantized_decomposed.quantize_per_tensor(
out, output_scale, output_zero_point, -128, 127, torch.int8
)

return out


# Register the schema and reference implementation for the quantized
# transposed conv2d custom op, mirroring the q8ta_conv2d_dw registration above.
name = "q8ta_conv2d_transposed"
# Schema: runtime tensors and quantization scalars first, then the prepacked
# weight data (weights / weight_sums / weight_scales), optional bias, and the
# transposed-convolution geometry (kernel_size through groups) plus the fused
# activation selector.
lib.define(
f"""
{name}(
Tensor x,
float input_scale,
int input_zero_point,
Tensor weights,
Tensor weight_sums,
Tensor weight_scales,
float output_scale,
int output_zero_point,
Tensor? bias,
SymInt[] kernel_size,
SymInt[] stride,
SymInt[] padding,
SymInt[] output_padding,
SymInt[] dilation,
SymInt groups,
str activation) -> Tensor
"""
)
# Bind the eager Python reference above as the CompositeExplicitAutograd impl,
# used for tracing/validation outside the Vulkan delegate.
lib.impl(name, q8ta_conv2d_transposed, "CompositeExplicitAutograd")
# Convenience handle for the registered op (namespace is defined earlier in
# this file — presumably the delegate's op namespace; confirm at top of file).
q8ta_conv2d_transposed_op = getattr(getattr(torch.ops, namespace), name)

######################
## apply_rotary_emb ##
######################
Expand Down
35 changes: 35 additions & 0 deletions backends/vulkan/op_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,7 @@ def register_quantize_per_tensor():
outputs_storage=[
utils.PACKED_INT8_BUFFER,
],
supports_highdim=True,
)


Expand All @@ -499,6 +500,7 @@ def register_dequantize_per_tensor():
outputs_storage=[
utils.CHANNELS_PACKED_TEXTURE_OR_CONTIGUOUS_BUFFER,
],
supports_highdim=True,
)


Expand Down Expand Up @@ -863,6 +865,39 @@ def register_q8ta_conv2d_ops():
)


@update_features(
    [
        exir_ops.edge.et_vk.q8ta_conv2d_transposed.default,
    ]
)
def register_q8ta_conv2d_transposed_op():
    """Declare storage features for the q8ta_conv2d_transposed Vulkan op."""
    # Only the activation input is a runtime tensor that needs a storage
    # layout. The remaining 15 arguments — quantization scalars, prepacked
    # weight/bias tensors, and the convolution geometry lists — carry no
    # runtime storage requirement:
    #   input_scale, input_zero_point, weight, weight_sums, weight_scales,
    #   output_scale, output_zero_point, bias, kernel_size, stride, padding,
    #   output_padding, dilation, groups, activation
    input_storage = [utils.PACKED_INT8_CONV2D_BUFFER]
    input_storage.extend([utils.NO_STORAGE] * 15)
    return OpFeatures(
        inputs_storage=input_storage,
        outputs_storage=[utils.PACKED_INT8_CHANNELS_PACKED_BUFFER],
        supports_resize=False,
        supports_prepacking=True,
    )


# =============================================================================
# Q8taLinear.cpp
# =============================================================================
Expand Down
101 changes: 78 additions & 23 deletions backends/vulkan/patterns/quantized_convolution.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional
from typing import cast, List, Optional

import executorch.backends.vulkan.utils as utils

Expand All @@ -28,17 +28,32 @@


class QuantizedConvolutionMatch(PatternMatch):
def __init__(self, conv_node: torch.fx.Node) -> None:

Check warning on line 31 in backends/vulkan/patterns/quantized_convolution.py

View workflow job for this annotation

GitHub Actions / lintrunner

FLAKE8 C901

'QuantizedConvolutionMatch.__init__' is too complex (17) See https://www.flake8rules.com/rules/C901.html.
self.anchor_node = conv_node
self.match_found = False
self.all_nodes = [self.anchor_node]

# Determine if this is a transposed convolution
self.transposed = False
self.output_padding = [0, 0]
if conv_node.target == exir_ops.edge.aten.convolution.default:
transposed_flag = conv_node.args[6] if len(conv_node.args) > 6 else False
if transposed_flag:
self.transposed = True
self.output_padding = (
cast(List[int], conv_node.args[7]) if len(conv_node.args) > 7 else [0, 0]
)

# Extract convolution parameters
self.stride = conv_node.args[3] if len(conv_node.args) > 3 else [1, 1]
self.padding = conv_node.args[4] if len(conv_node.args) > 4 else [0, 0]
self.dilation = conv_node.args[5] if len(conv_node.args) > 5 else [1, 1]
self.groups = conv_node.args[8] if len(conv_node.args) > 8 else 1

# Transposed conv only supported with dilation=[1,1]
if self.transposed and cast(List[int], self.dilation) != [1, 1]:
return

const_node, arg_chain = utils.trace_args_until_placeholder(
self.anchor_node.args[1]
)
Expand All @@ -60,6 +75,16 @@
self.dequantize_weight_node = dequantize_weight_node
self.all_nodes.extend(arg_chain)

# For transposed conv, verify per-channel quantization is on the OC dimension.
# Transposed weight shape is (IC, OC_per_group, KH, KW), so per-OC quantization
# should be on axis=1. If axis=0, that's per-IC which is not supported.
if self.transposed and utils.is_dequant_per_channel_node(
self.dequantize_weight_node
):
quant_axis = self.dequantize_weight_node.args[3]
if quant_axis != 1:
return

# Identify weight quantization parameter nodes
self.weight_scales_node, arg_chain = utils.trace_args_until_placeholder(
self.dequantize_weight_node.args[1]
Expand Down Expand Up @@ -177,9 +202,30 @@
bias_tensor = get_param_tensor(ep, match.bias_node)
assert bias_tensor is not None

OC, IC_per_group, H, W = weight_tensor.shape
if match.transposed:
# Transposed conv weight shape: (IC, OC_per_group, H, W)
IC, OC_per_group, H, W = weight_tensor.shape
OC = OC_per_group * match.groups
IC_per_group = IC // match.groups
# Reshape to (OC, H*W*IC_per_group) matrix format for Im2Col-based
# transposed convolution.
# (IC, OC_per_group, H, W) ->
# (groups, IC_per_group, OC_per_group, H, W) ->
# (groups, OC_per_group, H, W, IC_per_group) ->
# (OC, H*W*IC_per_group)
weight_tensor = (
weight_tensor.reshape(match.groups, IC_per_group, OC_per_group, H, W)
.permute(0, 2, 3, 4, 1)
.contiguous()
.reshape(OC, H * W * IC_per_group)
.contiguous()
)
else:
OC, IC_per_group, H, W = weight_tensor.shape

is_depthwise_conv = IC_per_group == 1 and match.groups == OC
is_depthwise_conv = (
not match.transposed and IC_per_group == 1 and match.groups == OC
)

if is_depthwise_conv:
assert OC % 4 == 0, "depthwise conv requires that OC is divisible by 4"
Expand All @@ -188,7 +234,7 @@
weight_tensor = (
weight_tensor.permute(2, 3, 1, 0).contiguous().view(H, W, OC).contiguous()
)
else:
elif not match.transposed:
# Reshape weight tensor from (OC, IC_per_group, H, W) to (OC, H * W * IC_per_group)
# (i.e. matrix format). This prepares the weights for Im2Col-based convolution.
weight_tensor = (
Expand Down Expand Up @@ -257,32 +303,41 @@
)

with graph_module.graph.inserting_before(match.output_node):
op_target = exir_ops.edge.et_vk.q8ta_conv2d.default
if is_depthwise_conv:
if match.transposed:
op_target = exir_ops.edge.et_vk.q8ta_conv2d_transposed.default
elif is_depthwise_conv:
op_target = exir_ops.edge.et_vk.q8ta_conv2d_dw.default
elif is_pointwise_conv:
op_target = exir_ops.edge.et_vk.q8ta_conv2d_pw.default
else:
op_target = exir_ops.edge.et_vk.q8ta_conv2d.default

op_args = (
match.quantize_input_node,
match.input_scales_node,
match.input_zeros_node,
match.weight_node,
weight_sums_node,
match.weight_scales_node,
match.output_scales_node,
match.output_zeros_node,
match.bias_node,
[H, W],
match.stride,
match.padding,
)
if match.transposed:
op_args = op_args + (match.output_padding,)
op_args = op_args + (
match.dilation,
match.groups,
"relu" if match.relu_node is not None else "none",
)

qconv_node = graph_module.graph.create_node(
"call_function",
op_target,
args=(
match.quantize_input_node,
match.input_scales_node,
match.input_zeros_node,
match.weight_node,
weight_sums_node,
match.weight_scales_node,
match.output_scales_node,
match.output_zeros_node,
match.bias_node, # Add bias after weight_scales
[H, W], # Pass kernel size information before stride
match.stride,
match.padding,
match.dilation,
match.groups,
"relu" if match.relu_node is not None else "none",
),
args=op_args,
)

qconv_node.meta["val"] = match.output_node.meta["val"]
Expand Down
15 changes: 15 additions & 0 deletions backends/vulkan/patterns/quantized_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,11 @@ def __init__(self, mm_node: torch.fx.Node) -> None: # noqa: C901
# Identify output node
self.output_node = self.anchor_node

# bmm with batch dim > 1 is not supported
is_bmm = self.anchor_node.target == exir_ops.edge.aten.bmm.default
if is_bmm and self.output_node.meta["val"].shape[0] != 1:
return

# Identify primary input node of the anchor. Due to decomposition of aten.linear
# there may be a view_copy node between the original input tensor to the linear
# op and the actual linear op node.
Expand Down Expand Up @@ -174,12 +179,21 @@ def __init__(self, mm_node: torch.fx.Node) -> None: # noqa: C901

# Check if the output is also quantized (q → dq → linear → q pattern)
# Also handle fused linear+relu (q → dq → linear → relu → q pattern)
# Due to decomposition of aten.linear for 3D+ inputs, there may be a
# view_copy between the mm output and the quantize node.
self.quantize_output_node = None
self.output_scales_node = None
self.output_zeros_node = None
self.relu_node = None
self.output_view_copy_node = None
if len(self.output_node.users) == 1:
cur_node = list(self.output_node.users)[0]
# Skip potential view_copy between linear and output quantize
if utils.is_view_copy_node(cur_node) and len(cur_node.users) == 1:
self.output_view_copy_node = cur_node
self.all_nodes.append(self.output_view_copy_node)
self.output_node = self.output_view_copy_node
cur_node = list(cur_node.users)[0]
if cur_node.target == exir_ops.edge.aten.relu.default:
self.relu_node = cur_node
if len(cur_node.users) == 1:
Expand Down Expand Up @@ -259,6 +273,7 @@ def is_input_dynamic_perchannel_quantized(self) -> bool:
exir_ops.edge.aten.linear.default,
exir_ops.edge.aten.mm.default,
exir_ops.edge.aten.addmm.default,
exir_ops.edge.aten.bmm.default,
}


Expand Down
Loading
Loading