diff --git a/backends/transforms/utils.py b/backends/transforms/utils.py index ef19a937b0b..cb3e8c9469a 100644 --- a/backends/transforms/utils.py +++ b/backends/transforms/utils.py @@ -111,6 +111,15 @@ def create_constant_placeholder( target = name + # If a placeholder with this target already exists, return it to avoid + # duplicate parameter names in the generated function signature which would + # cause a SyntaxError on recompile. This can happen when multiple pattern + # replacements independently create placeholders for a shared weight. + if name in exp_program.state_dict or name in exp_program.constants: + for n in graph.nodes: + if n.op == "placeholder" and n.target == name: + return n + # Add data to state_dict/ constants match kind: case InputKind.PARAMETER: diff --git a/backends/vulkan/custom_ops_lib.py b/backends/vulkan/custom_ops_lib.py index 87506f0b773..7f687bb10f4 100644 --- a/backends/vulkan/custom_ops_lib.py +++ b/backends/vulkan/custom_ops_lib.py @@ -685,6 +685,105 @@ def q8ta_conv2d_dw( lib.impl(name, q8ta_conv2d_dw, "CompositeExplicitAutograd") conv2d_q8ta_q8csw_dw_op = getattr(getattr(torch.ops, namespace), name) + +def q8ta_conv2d_transposed( + x: torch.Tensor, + input_scale: float, + input_zero_point: int, + weights: torch.Tensor, + weight_sums: torch.Tensor, + weight_scales: torch.Tensor, + output_scale: float, + output_zero_point: int, + bias: Optional[torch.Tensor], + kernel_size: list, + stride: list, + padding: list, + output_padding: list, + dilation: list, + groups: int, + activation: str, +): + x = torch.ops.quantized_decomposed.dequantize_per_tensor( + x, input_scale, input_zero_point, -128, 127, x.dtype + ) + + OC = weights.shape[0] + IC_per_group = int(x.shape[1] / groups) + K_h, K_w = kernel_size[0], kernel_size[1] + + orig_weight_K_dim = K_h * K_w * IC_per_group + if weights.shape[-1] > orig_weight_K_dim: + weights = weights[:, :orig_weight_K_dim] + + if weight_scales.shape[0] > OC: + weight_scales = weight_scales[:OC] + if bias 
is not None: + bias = bias[:OC] + + # Reshape to (OC, IC_per_group, K_h, K_w) then transpose to + # (IC_per_group * groups, OC_per_group, K_h, K_w) for conv_transpose2d + weights = weights.view(OC, IC_per_group, K_h, K_w) + OC_per_group = OC // groups + weights = ( + weights.view(groups, OC_per_group, IC_per_group, K_h, K_w) + .permute(0, 2, 1, 3, 4) + .contiguous() + .view(IC_per_group * groups, OC_per_group, K_h, K_w) + ) + + weight_zeros = torch.zeros_like(weight_scales, dtype=torch.int32) + # Dequantize per OC channel. For transposed weight (IC, OC_per_group, KH, KW), + # OC is at axis=1. + weights = torch.ops.quantized_decomposed.dequantize_per_channel( + weights, + weight_scales[:OC_per_group].repeat(groups) if groups > 1 else weight_scales, + weight_zeros[:OC_per_group].repeat(groups) if groups > 1 else weight_zeros, + 1, + -127, + 127, + torch.int8, + ) + + out = torch.nn.functional.conv_transpose2d( + x, weights, bias, stride, padding, output_padding, groups, dilation + ) + + if activation == "relu": + out = torch.nn.functional.relu(out) + + out = torch.ops.quantized_decomposed.quantize_per_tensor( + out, output_scale, output_zero_point, -128, 127, torch.int8 + ) + + return out + + +name = "q8ta_conv2d_transposed" +lib.define( + f""" + {name}( + Tensor x, + float input_scale, + int input_zero_point, + Tensor weights, + Tensor weight_sums, + Tensor weight_scales, + float output_scale, + int output_zero_point, + Tensor? 
bias, + SymInt[] kernel_size, + SymInt[] stride, + SymInt[] padding, + SymInt[] output_padding, + SymInt[] dilation, + SymInt groups, + str activation) -> Tensor + """ +) +lib.impl(name, q8ta_conv2d_transposed, "CompositeExplicitAutograd") +q8ta_conv2d_transposed_op = getattr(getattr(torch.ops, namespace), name) + ###################### ## apply_rotary_emb ## ###################### diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index bb7c0562bad..af2389d72f9 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -482,6 +482,7 @@ def register_quantize_per_tensor(): outputs_storage=[ utils.PACKED_INT8_BUFFER, ], + supports_highdim=True, ) @@ -499,6 +500,7 @@ def register_dequantize_per_tensor(): outputs_storage=[ utils.CHANNELS_PACKED_TEXTURE_OR_CONTIGUOUS_BUFFER, ], + supports_highdim=True, ) @@ -863,6 +865,39 @@ def register_q8ta_conv2d_ops(): ) +@update_features( + [ + exir_ops.edge.et_vk.q8ta_conv2d_transposed.default, + ] +) +def register_q8ta_conv2d_transposed_op(): + return OpFeatures( + inputs_storage=[ + utils.PACKED_INT8_CONV2D_BUFFER, # input + utils.NO_STORAGE, # input_scale (non tensor) + utils.NO_STORAGE, # input_zero_point (non tensor) + utils.NO_STORAGE, # weight (prepacked) + utils.NO_STORAGE, # weight_sums (prepacked) + utils.NO_STORAGE, # weight_scales (prepacked) + utils.NO_STORAGE, # output_scale (non tensor) + utils.NO_STORAGE, # output_zero_point (non tensor) + utils.NO_STORAGE, # bias (prepacked) + utils.NO_STORAGE, # kernel_size (non tensor) + utils.NO_STORAGE, # stride (non tensor) + utils.NO_STORAGE, # padding (non tensor) + utils.NO_STORAGE, # output_padding (non tensor) + utils.NO_STORAGE, # dilation (non tensor) + utils.NO_STORAGE, # groups (non tensor) + utils.NO_STORAGE, # activation (non tensor) + ], + outputs_storage=[ + utils.PACKED_INT8_CHANNELS_PACKED_BUFFER, + ], + supports_resize=False, + supports_prepacking=True, + ) + + # 
============================================================================= # Q8taLinear.cpp # ============================================================================= diff --git a/backends/vulkan/patterns/quantized_convolution.py b/backends/vulkan/patterns/quantized_convolution.py index 12ebbd1a382..d291d4009b7 100644 --- a/backends/vulkan/patterns/quantized_convolution.py +++ b/backends/vulkan/patterns/quantized_convolution.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Optional +from typing import cast, List, Optional import executorch.backends.vulkan.utils as utils @@ -33,12 +33,27 @@ def __init__(self, conv_node: torch.fx.Node) -> None: self.match_found = False self.all_nodes = [self.anchor_node] + # Determine if this is a transposed convolution + self.transposed = False + self.output_padding = [0, 0] + if conv_node.target == exir_ops.edge.aten.convolution.default: + transposed_flag = conv_node.args[6] if len(conv_node.args) > 6 else False + if transposed_flag: + self.transposed = True + self.output_padding = ( + cast(List[int], conv_node.args[7]) if len(conv_node.args) > 7 else [0, 0] + ) + # Extract convolution parameters self.stride = conv_node.args[3] if len(conv_node.args) > 3 else [1, 1] self.padding = conv_node.args[4] if len(conv_node.args) > 4 else [0, 0] self.dilation = conv_node.args[5] if len(conv_node.args) > 5 else [1, 1] self.groups = conv_node.args[8] if len(conv_node.args) > 8 else 1 + # Transposed conv only supported with dilation=[1,1] + if self.transposed and cast(List[int], self.dilation) != [1, 1]: + return + const_node, arg_chain = utils.trace_args_until_placeholder( self.anchor_node.args[1] ) @@ -60,6 +75,16 @@ def __init__(self, conv_node: torch.fx.Node) -> None: self.dequantize_weight_node = dequantize_weight_node self.all_nodes.extend(arg_chain) + # For transposed conv, verify per-channel quantization is on 
the OC dimension. + # Transposed weight shape is (IC, OC_per_group, KH, KW), so per-OC quantization + # should be on axis=1. If axis=0, that's per-IC which is not supported. + if self.transposed and utils.is_dequant_per_channel_node( + self.dequantize_weight_node + ): + quant_axis = self.dequantize_weight_node.args[3] + if quant_axis != 1: + return + # Identify weight quantization parameter nodes self.weight_scales_node, arg_chain = utils.trace_args_until_placeholder( self.dequantize_weight_node.args[1] @@ -177,9 +202,30 @@ def make_q8ta_conv2d_custom_op( bias_tensor = get_param_tensor(ep, match.bias_node) assert bias_tensor is not None - OC, IC_per_group, H, W = weight_tensor.shape + if match.transposed: + # Transposed conv weight shape: (IC, OC_per_group, H, W) + IC, OC_per_group, H, W = weight_tensor.shape + OC = OC_per_group * match.groups + IC_per_group = IC // match.groups + # Reshape to (OC, H*W*IC_per_group) matrix format for Im2Col-based + # transposed convolution. + # (IC, OC_per_group, H, W) -> + # (groups, IC_per_group, OC_per_group, H, W) -> + # (groups, OC_per_group, H, W, IC_per_group) -> + # (OC, H*W*IC_per_group) + weight_tensor = ( + weight_tensor.reshape(match.groups, IC_per_group, OC_per_group, H, W) + .permute(0, 2, 3, 4, 1) + .contiguous() + .reshape(OC, H * W * IC_per_group) + .contiguous() + ) + else: + OC, IC_per_group, H, W = weight_tensor.shape - is_depthwise_conv = IC_per_group == 1 and match.groups == OC + is_depthwise_conv = ( + not match.transposed and IC_per_group == 1 and match.groups == OC + ) if is_depthwise_conv: assert OC % 4 == 0, "depthwise conv requires that OC is divisible by 4" @@ -188,7 +234,7 @@ def make_q8ta_conv2d_custom_op( weight_tensor = ( weight_tensor.permute(2, 3, 1, 0).contiguous().view(H, W, OC).contiguous() ) - else: + elif not match.transposed: # Reshape weight tensor from (OC, IC_per_group, H, W) to (OC, H * W * IC_per_group) # (i.e. matrix format). This prepares the weights for Im2Col-based convolution. 
weight_tensor = ( @@ -257,32 +303,41 @@ def make_q8ta_conv2d_custom_op( ) with graph_module.graph.inserting_before(match.output_node): - op_target = exir_ops.edge.et_vk.q8ta_conv2d.default - if is_depthwise_conv: + if match.transposed: + op_target = exir_ops.edge.et_vk.q8ta_conv2d_transposed.default + elif is_depthwise_conv: op_target = exir_ops.edge.et_vk.q8ta_conv2d_dw.default elif is_pointwise_conv: op_target = exir_ops.edge.et_vk.q8ta_conv2d_pw.default + else: + op_target = exir_ops.edge.et_vk.q8ta_conv2d.default + + op_args = ( + match.quantize_input_node, + match.input_scales_node, + match.input_zeros_node, + match.weight_node, + weight_sums_node, + match.weight_scales_node, + match.output_scales_node, + match.output_zeros_node, + match.bias_node, + [H, W], + match.stride, + match.padding, + ) + if match.transposed: + op_args = op_args + (match.output_padding,) + op_args = op_args + ( + match.dilation, + match.groups, + "relu" if match.relu_node is not None else "none", + ) qconv_node = graph_module.graph.create_node( "call_function", op_target, - args=( - match.quantize_input_node, - match.input_scales_node, - match.input_zeros_node, - match.weight_node, - weight_sums_node, - match.weight_scales_node, - match.output_scales_node, - match.output_zeros_node, - match.bias_node, # Add bias after weight_scales - [H, W], # Pass kernel size information before stride - match.stride, - match.padding, - match.dilation, - match.groups, - "relu" if match.relu_node is not None else "none", - ), + args=op_args, ) qconv_node.meta["val"] = match.output_node.meta["val"] diff --git a/backends/vulkan/patterns/quantized_linear.py b/backends/vulkan/patterns/quantized_linear.py index df80749e72f..b9b307e14f1 100644 --- a/backends/vulkan/patterns/quantized_linear.py +++ b/backends/vulkan/patterns/quantized_linear.py @@ -90,6 +90,11 @@ def __init__(self, mm_node: torch.fx.Node) -> None: # noqa: C901 # Identify output node self.output_node = self.anchor_node + # bmm with batch dim > 
1 is not supported + is_bmm = self.anchor_node.target == exir_ops.edge.aten.bmm.default + if is_bmm and self.output_node.meta["val"].shape[0] != 1: + return + # Identify primary input node of the anchor. Due to decomposition of aten.linear # there may be a view_copy node between the original input tensor to the linear # op and the actual linear op node. @@ -174,12 +179,21 @@ def __init__(self, mm_node: torch.fx.Node) -> None: # noqa: C901 # Check if the output is also quantized (q → dq → linear → q pattern) # Also handle fused linear+relu (q → dq → linear → relu → q pattern) + # Due to decomposition of aten.linear for 3D+ inputs, there may be a + # view_copy between the mm output and the quantize node. self.quantize_output_node = None self.output_scales_node = None self.output_zeros_node = None self.relu_node = None + self.output_view_copy_node = None if len(self.output_node.users) == 1: cur_node = list(self.output_node.users)[0] + # Skip potential view_copy between linear and output quantize + if utils.is_view_copy_node(cur_node) and len(cur_node.users) == 1: + self.output_view_copy_node = cur_node + self.all_nodes.append(self.output_view_copy_node) + self.output_node = self.output_view_copy_node + cur_node = list(cur_node.users)[0] if cur_node.target == exir_ops.edge.aten.relu.default: self.relu_node = cur_node if len(cur_node.users) == 1: @@ -259,6 +273,7 @@ def is_input_dynamic_perchannel_quantized(self) -> bool: exir_ops.edge.aten.linear.default, exir_ops.edge.aten.mm.default, exir_ops.edge.aten.addmm.default, + exir_ops.edge.aten.bmm.default, } diff --git a/backends/vulkan/runtime/graph/ops/glsl/block_indexing.glslh b/backends/vulkan/runtime/graph/ops/glsl/block_indexing.glslh index e7e64a601ef..d6693d7ff26 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/block_indexing.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/block_indexing.glslh @@ -109,6 +109,78 @@ TensorIndex4D contiguous_block_idx_to_tensor4d_idx_with_block_config( return tidx; } +/* + * 8D 
version: decomposes a contiguous block index into a TensorIndex with up to + * 8 dimensions. The block config encodes dim_order for the first 4 dims. Dims + * 4-7 are implicit outer (non-block) batch dimensions in ascending order. + */ +TensorIndex contiguous_block_idx_to_tensor_idx_with_block_config( + const BufferMetadata meta, + const uint block_idx, + const int block_config) { + TensorIndex tidx; + initialize(tidx); + + const int inner_dim = get_block_inner_dim(block_config); + const int outer_dim = get_block_outer_dim(block_config); + const int nonblock_dim_1 = get_block_dim_order_2(block_config); + const int nonblock_dim_2 = get_block_dim_order_3(block_config); + const uint inner_bs = uint(get_block_inner_dim_block_size(block_config)); + const uint outer_bs = uint(get_block_outer_dim_block_size(block_config)); + + // Compute block strides for each dim, inner→outer ordering: + // inner_dim, outer_dim, nonblock1, nonblock2, 4, 5, 6, 7 + uint stride = 1; + + // Inner block dim + stride = div_up(size_at(meta, inner_dim), inner_bs); + + // Outer block dim + uint stride_outer = stride; + stride *= div_up(size_at(meta, outer_dim), outer_bs); + + // First non-block dim + uint stride_nb1 = stride; + stride *= size_at(meta, nonblock_dim_1); + + // Second non-block dim + uint stride_nb2 = stride; + stride *= size_at(meta, nonblock_dim_2); + + // Extra batch dims (4-7): always non-block, ascending order + uint batch_strides[4]; + [[unroll]] for (int i = 0; i < 4; ++i) { + batch_strides[i] = stride; + stride *= size_at(meta, 4 + i); + } + + // Decompose from outermost to innermost + uint remaining = block_idx; + + // Dims 7, 6, 5, 4 (outermost batch dims first) + [[unroll]] for (int i = 3; i >= 0; --i) { + tidx.data[1][i] = remaining / batch_strides[i]; + remaining %= batch_strides[i]; + } + + // Second non-block dim + tidx.data[0][nonblock_dim_2] = remaining / stride_nb2; + remaining %= stride_nb2; + + // First non-block dim + tidx.data[0][nonblock_dim_1] = remaining / 
stride_nb1; + remaining %= stride_nb1; + + // Outer block dim (multiply by block size to get element index) + tidx.data[0][outer_dim] = mul_4(remaining / stride_outer); + remaining %= stride_outer; + + // Inner block dim (multiply by block size to get element index) + tidx.data[0][inner_dim] = mul_4(remaining); + + return tidx; +} + // // TextureMetadata variants of block indexing // diff --git a/backends/vulkan/runtime/graph/ops/glsl/block_int8x4_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/block_int8x4_load.glslh index 6ea636a0a17..d889276082a 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/block_int8x4_load.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/block_int8x4_load.glslh @@ -71,4 +71,55 @@ return block; \ } +// +// 8D version (TensorIndex) +// + +#define define_load_int8x4_buffer_8d_fns(buffer_name) \ + \ + ivec4 load_int8x4_block_from_##buffer_name( \ + const BufferMetadata meta, \ + const TensorIndex tidx_base, \ + const int hashed_layout, \ + const int block_outer_dim) { \ + const int outer_packed_dim = get_outer_packed_dim(hashed_layout); \ + const int outer_block_size = \ + get_outer_packed_dim_block_size(hashed_layout); \ + \ + /* Compute base packed index using 8D block-based indexing */ \ + const uint block_idx = \ + tensor_idx_to_block_idx(meta, tidx_base, hashed_layout); \ + const uint texels_per_block = div_4(get_block_numel(hashed_layout)); \ + uint buf_idx = block_idx * texels_per_block; \ + \ + /* Fast path: contiguous texels when iterating along outer_packed_dim */ \ + if (outer_block_size == 4) { \ + if (block_outer_dim == outer_packed_dim) { \ + return ivec4( \ + buffer_name[buf_idx], \ + buffer_name[buf_idx + 1], \ + buffer_name[buf_idx + 2], \ + buffer_name[buf_idx + 3]); \ + } \ + else { \ + buf_idx += mod_4(int(tidx_base.data[0][outer_packed_dim])); \ + } \ + } \ + \ + /* General path: use stride for non-contiguous access */ \ + const uint outer_stride = \ + stride_at(meta, block_outer_dim) * texels_per_block; \ + const 
uint outer_size = size_at(meta, block_outer_dim); \ + const int base_outer_idx = int(tidx_base.data[0][block_outer_dim]); \ + \ + ivec4 block = ivec4(0); \ + [[unroll]] for (int block_y = 0; block_y < 4; ++block_y) { \ + if (base_outer_idx + block_y < int(outer_size)) { \ + block[block_y] = buffer_name[buf_idx]; \ + } \ + buf_idx += outer_stride; \ + } \ + return block; \ + } + #endif // BLOCK_INT8X4_LOAD_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/block_int8x4_store.glslh b/backends/vulkan/runtime/graph/ops/glsl/block_int8x4_store.glslh index 2a0e037c291..425df205317 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/block_int8x4_store.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/block_int8x4_store.glslh @@ -71,4 +71,54 @@ } \ } +// +// 8D version (TensorIndex) +// + +#define define_store_int8x4_buffer_8d_fns(buffer_name) \ + \ + void store_int8x4_block_to_##buffer_name( \ + const BufferMetadata meta, \ + const TensorIndex tidx_base, \ + const int hashed_layout, \ + const int block_outer_dim, \ + const ivec4 block) { \ + const int outer_packed_dim = get_outer_packed_dim(hashed_layout); \ + const int outer_block_size = \ + get_outer_packed_dim_block_size(hashed_layout); \ + \ + /* Compute base packed index using 8D block-based indexing */ \ + const uint block_idx = \ + tensor_idx_to_block_idx(meta, tidx_base, hashed_layout); \ + const uint texels_per_block = div_4(get_block_numel(hashed_layout)); \ + uint buf_idx = block_idx * texels_per_block; \ + \ + /* Fast path: contiguous texels when iterating along outer_packed_dim */ \ + if (outer_block_size == 4) { \ + if (block_outer_dim == outer_packed_dim) { \ + buffer_name[buf_idx] = block[0]; \ + buffer_name[buf_idx + 1] = block[1]; \ + buffer_name[buf_idx + 2] = block[2]; \ + buffer_name[buf_idx + 3] = block[3]; \ + return; \ + } \ + else { \ + buf_idx += mod_4(int(tidx_base.data[0][outer_packed_dim])); \ + } \ + } \ + \ + /* General path: use stride for non-contiguous access */ \ + const uint 
outer_stride = \ + stride_at(meta, block_outer_dim) * texels_per_block; \ + const uint outer_size = size_at(meta, block_outer_dim); \ + const int base_outer_idx = int(tidx_base.data[0][block_outer_dim]); \ + \ + [[unroll]] for (int block_y = 0; block_y < 4; ++block_y) { \ + if (base_outer_idx + block_y < int(outer_size)) { \ + buffer_name[buf_idx] = block[block_y]; \ + } \ + buf_idx += outer_stride; \ + } \ + } + #endif // BLOCK_INT8X4_STORE_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/block_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/block_load.glslh index d72a176aa0e..3719cbd3460 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/block_load.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/block_load.glslh @@ -102,4 +102,79 @@ return block; \ } +// +// 8D buffer load functions (TensorIndex version) +// + +#define define_load_buffer_8d_fns(buffer_name) \ + \ + mat4 load_fp_block_from_##buffer_name( \ + const BufferMetadata meta, \ + const TensorIndex tidx_base, \ + const int hashed_layout, \ + const int block_outer_dim) { \ + const int block_inner_dim = get_packed_dim(hashed_layout); \ + \ + /* Compute base buffer index using 8D strides */ \ + const uint base_idx = \ + tensor_idx_to_buf_idx(meta, tidx_base, hashed_layout); \ + const uint outer_stride = stride_at(meta, block_outer_dim); \ + /* Inner stride is 1 since packed_dim == block_inner_dim */ \ + \ + /* Pre-compute bounds (block dims always < 4, so use data[0]) */ \ + const uint outer_size = size_at(meta, block_outer_dim); \ + const uint inner_size = size_at(meta, block_inner_dim); \ + const int base_outer_idx = int(tidx_base.data[0][block_outer_dim]); \ + const int base_inner_idx = int(tidx_base.data[0][block_inner_dim]); \ + \ + mat4 block; \ + [[unroll]] for (int block_y = 0; block_y < 4; ++block_y) { \ + if (base_outer_idx + block_y < int(outer_size)) { \ + const uint row_idx = base_idx + block_y * outer_stride; \ + [[unroll]] for (int block_x = 0; block_x < 4; ++block_x) { \ + if 
(base_inner_idx + block_x < int(inner_size)) { \ + block[block_y][block_x] = float(buffer_name[row_idx + block_x]); \ + } else { \ + block[block_y][block_x] = 0.0; \ + } \ + } \ + } else { \ + block[block_y] = vec4(0.0); \ + } \ + } \ + return block; \ + } + +// +// 8D texture load functions (TensorIndex version) +// Converts to TensorIndex4D internally for texture operations. +// + +#define define_load_texture_8d_fns(texture_name) \ + \ + mat4 load_fp_block_from_##texture_name( \ + const TextureMetadata meta, \ + const TensorIndex tidx_base, \ + const int hashed_layout, \ + const int block_outer_dim) { \ + /* Convert to 4D for texture operations (textures are always <= 4D) */ \ + TensorIndex4D tidx4d; \ + tidx4d.data = ivec4(tidx_base.data[0]); \ + ivec3 tex_pos = tensor4d_idx_to_texel_pos_simple(meta, tidx4d); \ + const int tex_outer_dim = mod_4(block_outer_dim); \ + const int outer_size = meta.sizes[block_outer_dim]; \ + const int base_outer_idx = tidx4d.data[block_outer_dim]; \ + \ + mat4 block; \ + [[unroll]] for (int block_y = 0; block_y < 4; ++block_y) { \ + if (base_outer_idx + block_y < outer_size) { \ + block[block_y] = vec4(texelFetch(texture_name, tex_pos, 0)); \ + } else { \ + block[block_y] = vec4(0.0); \ + } \ + tex_pos[tex_outer_dim]++; \ + } \ + return block; \ + } + #endif // BLOCK_LOAD_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/block_store.glslh b/backends/vulkan/runtime/graph/ops/glsl/block_store.glslh index 66e9ab9fa2b..aa302290229 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/block_store.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/block_store.glslh @@ -94,4 +94,72 @@ } \ } +// +// 8D buffer store functions (TensorIndex version) +// + +#define define_store_buffer_8d_fns(buffer_name, scalar_type) \ + \ + void store_fp_block_to_##buffer_name( \ + const BufferMetadata meta, \ + const TensorIndex tidx_base, \ + const int hashed_layout, \ + const int block_outer_dim, \ + const mat4 block) { \ + const int block_inner_dim = 
get_packed_dim(hashed_layout); \ + \ + /* Compute base buffer index using 8D strides */ \ + const uint base_idx = \ + tensor_idx_to_buf_idx(meta, tidx_base, hashed_layout); \ + const uint outer_stride = stride_at(meta, block_outer_dim); \ + /* Inner stride is 1 since packed_dim == block_inner_dim */ \ + \ + /* Pre-compute bounds (block dims always < 4, so use data[0]) */ \ + const uint outer_size = size_at(meta, block_outer_dim); \ + const uint inner_size = size_at(meta, block_inner_dim); \ + const int base_outer_idx = int(tidx_base.data[0][block_outer_dim]);\ + const int base_inner_idx = int(tidx_base.data[0][block_inner_dim]);\ + \ + [[unroll]] for (int block_y = 0; block_y < 4; ++block_y) { \ + if (base_outer_idx + block_y < int(outer_size)) { \ + const uint row_idx = base_idx + block_y * outer_stride; \ + [[unroll]] for (int block_x = 0; block_x < 4; ++block_x) { \ + if (base_inner_idx + block_x < int(inner_size)) { \ + buffer_name[row_idx + block_x] = \ + scalar_type(block[block_y][block_x]); \ + } \ + } \ + } \ + } \ + } + +// +// 8D texture store functions (TensorIndex version) +// Converts to TensorIndex4D internally for texture operations. 
+// + +#define define_store_texture_8d_fns(texture_name, vec4_type) \ + \ + void store_fp_block_to_##texture_name( \ + const TextureMetadata meta, \ + const TensorIndex tidx_base, \ + const int hashed_layout, \ + const int block_outer_dim, \ + const mat4 block) { \ + /* Convert to 4D for texture operations (textures are always <= 4D) */ \ + TensorIndex4D tidx4d; \ + tidx4d.data = ivec4(tidx_base.data[0]); \ + ivec3 tex_pos = tensor4d_idx_to_texel_pos_simple(meta, tidx4d); \ + const int tex_outer_dim = mod_4(block_outer_dim); \ + const int outer_size = meta.sizes[block_outer_dim]; \ + const int base_outer_idx = tidx4d.data[block_outer_dim]; \ + \ + [[unroll]] for (int block_y = 0; block_y < 4; ++block_y) { \ + if (base_outer_idx + block_y < outer_size) { \ + imageStore(texture_name, tex_pos, vec4_type(block[block_y])); \ + } \ + tex_pos[tex_outer_dim]++; \ + } \ + } + #endif // BLOCK_STORE_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/indexing.glslh b/backends/vulkan/runtime/graph/ops/glsl/indexing.glslh index 51cda9a3d1d..c576128700e 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/indexing.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/indexing.glslh @@ -200,6 +200,10 @@ bool out_of_bounds(const TensorIndex tidx, const BufferMetadata meta) { any(greaterThanEqual(tidx.data[1], meta.sizes[1])); } +bool out_of_bounds(const TensorIndex tidx, const TextureMetadata meta) { + return any(greaterThanEqual(ivec4(tidx.data[0]), meta.sizes)); +} + // // TensorIndex4D (useful for texture backed tensors) // @@ -624,6 +628,74 @@ int tensor4d_idx_to_texel_idx( return block_idx; } +// +// 8D block-packed tensor indexing (TensorIndex versions) +// +// These functions extend the 4D versions above to handle tensors with up to 8 +// dimensions. Packed dims are guaranteed < 4, so block operations only affect +// data[0]. Dims 4-7 are non-block batch dimensions. +// + +/* + * 8D version of tensor4d_idx_to_block_idx. 
+ * Packed dims (< 4) are divided by their block sizes, then all 8 dims + * contribute to the linear block index via their strides. + */ +int tensor_idx_to_block_idx( + const BufferMetadata meta, + TensorIndex tidx, + const int hashed_layout) { + const int inner_dim = get_packed_dim(hashed_layout); + const int outer_dim = get_outer_packed_dim(hashed_layout); + const int inner_block_size = get_packed_dim_block_size(hashed_layout); + const int outer_block_size = get_outer_packed_dim_block_size(hashed_layout); + + // Convert packed dims to block-space coordinates (packed dims always < 4) + tidx.data[0][inner_dim] = tidx.data[0][inner_dim] / uint(inner_block_size); + tidx.data[0][outer_dim] = tidx.data[0][outer_dim] / uint(outer_block_size); + + // Compute block-space linear index over all 8 dims + int block_idx = 0; + [[unroll]] for (int d = 0; d < 4; ++d) { + block_idx += int(meta.strides[0][d]) * int(tidx.data[0][d]); + } + [[unroll]] for (int d = 0; d < 4; ++d) { + block_idx += int(meta.strides[1][d]) * int(tidx.data[1][d]); + } + return block_idx; +} + +/* + * 8D version of tensor4d_idx_to_intra_block_idx. + * Packed dims are always < 4, so only data[0] is accessed. + */ +int tensor_idx_to_intra_block_idx( + const TensorIndex tidx, + const int hashed_layout) { + const int inner_dim = get_packed_dim(hashed_layout); + const int outer_dim = get_outer_packed_dim(hashed_layout); + const int inner_block_size = get_packed_dim_block_size(hashed_layout); + const int outer_block_size = get_outer_packed_dim_block_size(hashed_layout); + + const int inner_offset = int(tidx.data[0][inner_dim]) % inner_block_size; + const int outer_offset = int(tidx.data[0][outer_dim]) % outer_block_size; + return outer_offset * inner_block_size + inner_offset; +} + +/* + * 8D version of tensor4d_idx_to_buf_idx. 
+ */ +int tensor_idx_to_buf_idx( + const BufferMetadata meta, + const TensorIndex tidx, + const int hashed_layout) { + const int block_idx = tensor_idx_to_block_idx(meta, tidx, hashed_layout); + const int intra_block_idx = + tensor_idx_to_intra_block_idx(tidx, hashed_layout); + const int block_numel = get_block_numel(hashed_layout); + return block_idx * block_numel + intra_block_idx; +} + // // Debug utilities // diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_transposed.glsl b/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_transposed.glsl new file mode 100644 index 00000000000..efed2e3a95b --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_transposed.glsl @@ -0,0 +1,254 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +${define_required_extensions("buffer", DTYPE)} + +#define USE_INT8_DOT_PRODUCT_EXT ${USE_INT8_DOT_PRODUCT_EXT} + +#extension GL_EXT_control_flow_attributes : require +$if USE_INT8_DOT_PRODUCT_EXT == 1: + #extension GL_EXT_integer_dot_product : require + +#define PRECISION ${PRECISION} +#define VEC4_T ${texel_load_type(DTYPE, "buffer")} +#define T ${texel_load_component_type(DTYPE, "buffer")} + +${define_active_storage_type("buffer")} + +layout(std430) buffer; + +#include "indexing.glslh" +#include "common.glslh" +#include "conv2d_common.glslh" + +${layout_declare_tensor(B, "w", "t_packed_int8_output", "int", "buffer", is_scalar_array=True)} +${layout_declare_tensor(B, "r", "t_packed_int8_input", "int", "buffer", is_scalar_array=True)} +${layout_declare_tensor(B, "r", "t_packed_int8_weight", "int", "texture2d", is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_weight_sums", "int", "buffer", is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_weight_scales", DTYPE, "buffer", 
is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_bias", DTYPE, "buffer", is_scalar_array=False)} + +// Metadata for input/output tensors (memory layout agnostic) +${layout_declare_ubo(B, "BufferMetadata", "outp")} +${layout_declare_ubo(B, "BufferMetadata", "inp")} +${layout_declare_ubo(B, "Conv2DParams", "conv2d_params")} + +layout(push_constant) uniform restrict Block { + float input_scale; + int input_zp; + float output_inv_scale; + int output_zp; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +${layout_declare_spec_const(C, "int", "apply_bias", "1")} +${layout_declare_spec_const(C, "int", "activation_type", "0")} + +// Layout specialization constants +${layout_declare_spec_const(C, "int", "inp_layout", "CONTIG_LAYOUT_INT")} +${layout_declare_spec_const(C, "int", "outp_layout", "CONTIG_LAYOUT_INT")} + +// Load weight block for a given (ic4, kx, ky, oc4) position. +// Weight texture layout (from pack_q8_conv2d_weights.glsl): +// block_x = oc4 * K_w + kx +// block_y = ky * IC4 + ic4 +// Each texel ivec4 has 4 components (4 output channels), each component is +// a packed int32 containing 4 int8 values for 4 consecutive input channels. 
// Fetch one weight block: an ivec4 of 4 output channels, each component a
// packed int32 holding 4 int8 weights for 4 consecutive input channels.
// Addressing follows pack_q8_conv2d_weights.glsl:
//   block_x = oc4 * K_w + kx, block_y = ky * IC4 + ic4.
ivec4 load_weight_block(int ic4, int kx, int ky, int oc4, int IC4, int KW) {
  const int block_x = oc4 * KW + kx;
  const int block_y = ky * IC4 + ic4;
  return texelFetch(t_packed_int8_weight, ivec2(block_x, block_y), 0);
}

// Quantize a float texel to int8 range [-128, 127] with round-to-nearest.
ivec4 quantize(const vec4 texel, const float inv_scale, const int zp) {
  vec4 quantized = round(texel * inv_scale) + zp;
  return clamp(ivec4(quantized), -128, 127);
}

// Entry point: each invocation computes a 4-wide (W) x 4-deep (OC) tile of
// the int8 transposed-convolution output, accumulating in int32 via packed
// int8 dot products and re-quantizing at the end.
void main() {
  // Thread mapping: same as q8ta_conv2d
  // Each thread handles a 4W x 4C tile of output
  int oc4 = int(gl_GlobalInvocationID.z);
  int w4 = int(gl_GlobalInvocationID.x);

  // Initialize output tensor index (WHCN order)
  TensorIndex4D outp_tidx;
  outp_tidx.data[0] = w4 * 4;
  outp_tidx.data[1] = int(gl_GlobalInvocationID.y);
  outp_tidx.data[2] = oc4 * 4;
  outp_tidx.data[3] = 0;

  const int W = int(outp.sizes[0][0]);
  const int OC = int(outp.sizes[0][2]);

  // Bounds check: the tile origin must lie inside the output tensor.
  if (any(greaterThanEqual(outp_tidx.data, ivec4(outp.sizes[0])))) {
    return;
  }

  // Input dimensions
  const int inp_W = int(inp.sizes[0][0]);
  const int inp_H = int(inp.sizes[0][1]);
  const int IC = int(inp.sizes[0][2]);

  // Compute channels per group
  const int OC_per_group = OC / conv2d_params.groups;
  const int IC_per_group = IC / conv2d_params.groups;
  const int IC4_per_group = div_up_4(IC_per_group);

  // Determine which group this output channel block belongs to
  const int group_idx = outp_tidx.data[2] / OC_per_group;
  const int ic_group_start = group_idx * IC_per_group;

  // Get strides for efficient indexing
  const int inp_w_stride = int(inp.strides[0][0]);
  const int inp_h_stride = int(inp.strides[0][1]);

  // Create packed input zero point (4 copies of input_zp packed into int32)
  const int input_zp_packed = pack_into_int32(ivec4(input_zp));

  // Initialize accumulators for 4 width positions x 4 output channels each
  ivec4 acc[4];
  [[unroll]] for (int i = 0; i < 4; ++i) {
    acc[i] = ivec4(0);
  }

  // Transposed convolution loop structure:
  // Iterate over all kernel positions (ky, kx, ic4). For each position,
  // compute the corresponding input position. If the input position is valid
  // (in bounds and on a stride-aligned position), load the actual input;
  // otherwise use input_zp_packed. This ensures weight_sums correction is
  // consistent with the accumulation (all kernel positions are accounted for).
  //
  // For transposed convolution, the input position for a given (output, kernel)
  // is: input = (output + padding - kernel) / stride, which must be exact
  // (remainder == 0) and within [0, input_size).

  const int KH = conv2d_params.kernel_size.y;
  const int KW = conv2d_params.kernel_size.x;
  const int stride_x = conv2d_params.stride.x;
  const int stride_y = conv2d_params.stride.y;
  const int pad_x = conv2d_params.padding.x;
  const int pad_y = conv2d_params.padding.y;

  for (int ky = 0; ky < KH; ky++) {
    // Check if this kernel row maps to a valid input row
    const int in_y_numer = outp_tidx.data[1] + pad_y - ky;
    const bool y_stride_valid = (in_y_numer >= 0) && ((in_y_numer % stride_y) == 0);
    const int iy = y_stride_valid ? (in_y_numer / stride_y) : 0;
    const bool h_in_bounds = y_stride_valid && (iy < inp_H);

    // Loop order: ic4 before kx for better weight texture cache locality.
    // Consecutive ic4 values at fixed kx access consecutive y-rows in the
    // weight texture, but consecutive kx at fixed ic4 access consecutive
    // x-coordinates (same row) which is the fast dimension for texture
    // cache lines. By iterating ic4 in the outer loop and kx in the inner
    // loop, each ic4 iteration sweeps kx across a texture row.
    for (int ic4 = 0; ic4 < IC4_per_group; ic4++) {
      for (int kx = 0; kx < KW; kx++) {
        // Load weight block: 4 output channels x 4 input channels
        const ivec4 weight_block = load_weight_block(
            ic4, kx, ky, oc4, IC4_per_group, KW);

        // Process 4 adjacent width positions
        [[unroll]] for (int subtile_w = 0; subtile_w < 4; ++subtile_w) {
          const int ow = outp_tidx.data[0] + subtile_w;
          const int in_x_numer = ow + pad_x - kx;

          // Load packed input, or use zero point if out of bounds /
          // not stride-aligned (keeps weight_sums correction exact).
          int packed_input = input_zp_packed;
          if (h_in_bounds && in_x_numer >= 0 && (in_x_numer % stride_x) == 0) {
            const int ix = in_x_numer / stride_x;
            if (ix < inp_W) {
              TensorIndex4D inp_tidx;
              inp_tidx.data[0] = ix;
              inp_tidx.data[1] = iy;
              inp_tidx.data[2] = ic_group_start + ic4 * 4;
              inp_tidx.data[3] = 0;

              // Two addressing modes depending on the input's packed layout
              // (outer packed-dim block size 1 vs. blocked 4W4C layout).
              int inp_texel_idx;
              if (get_outer_packed_dim_block_size(inp_layout) == 1) {
                inp_texel_idx = tensor4d_idx_to_texel_idx(inp, inp_tidx, inp_layout);
              } else {
                const int w4_inp = div_4(ix);
                const int inp_c4 = div_4(inp_tidx.data[2]);
                inp_texel_idx = (iy * inp_h_stride + w4_inp * inp_w_stride + inp_c4) * 4 + mod_4(ix);
              }
              packed_input = t_packed_int8_input[inp_texel_idx];
            }
          }

          // Accumulate using packed int8 dot product for each output channel
          [[unroll]] for (int oc_offset = 0; oc_offset < 4; ++oc_offset) {
            acc[subtile_w][oc_offset] = dotPacked4x8AccSat(
                packed_input,
                weight_block[oc_offset],
                acc[subtile_w][oc_offset]);
          }
        }
      }
    }
  }

  // Apply input zero point correction via weight_sums
  const vec4 weight_sums = vec4(t_weight_sums[oc4]);
  const vec4 weight_scales = vec4(t_weight_scales[oc4]);

  // Convert to float, apply dequantization, and optionally add bias
  vec4 facc[4];
  [[unroll]] for (int subtile_w = 0; subtile_w < 4; ++subtile_w) {
    facc[subtile_w] = vec4(acc[subtile_w]);
    facc[subtile_w] -= weight_sums * input_zp;
    facc[subtile_w] *= weight_scales * input_scale;
  }

  // Apply bias if enabled
  if (apply_bias > 0) {
    const vec4 bias = vec4(t_bias[oc4]);
    [[unroll]] for (int subtile_w = 0; subtile_w < 4; ++subtile_w) {
      facc[subtile_w] += bias;
    }
  }

  // Apply ReLU if enabled
  if (activation_type > 0) {
    [[unroll]] for (int subtile_w = 0; subtile_w < 4; ++subtile_w) {
      facc[subtile_w] = max(facc[subtile_w], vec4(0.0));
    }
  }

  // Compute base output texel index (for subtile_w=0)
  const int base_outp_texel_idx = tensor4d_idx_to_texel_idx(outp, outp_tidx, outp_layout);
  const int out_w_stride = int(outp.strides[0][0]);

  // Quantize and store outputs using stride offsets
  [[unroll]] for (int subtile_w = 0; subtile_w < 4; ++subtile_w) {
    // Skip out-of-bounds width positions.
    // NOTE(review): once data[0] reaches W it only grows, so every later
    // iteration also skips; `break` would be behaviorally equivalent here.
    if (outp_tidx.data[0] >= W) {
      continue;
    }

    const ivec4 quantized_out = quantize(facc[subtile_w], output_inv_scale, output_zp);
    const int packed_out = pack_into_int32(quantized_out);

    // Store using stride offset from base
    int outp_texel_idx;
    if (get_outer_packed_dim_block_size(outp_layout) == 1) {
      outp_texel_idx = base_outp_texel_idx + subtile_w * out_w_stride;
    } else {
      outp_texel_idx = base_outp_texel_idx + subtile_w;
    }

    t_packed_int8_output[outp_texel_idx] = packed_out;

    outp_tidx.data[0] += 1;
  }
}
diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_transposed.yaml b/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_transposed.yaml
new file mode 100644
index 00000000000..69469fabd95
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_transposed.yaml
@@ -0,0 +1,17 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+ +q8ta_conv2d_transposed: + parameter_names_with_default_values: + DTYPE: float + USE_INT8_DOT_PRODUCT_EXT: 1 + generate_variant_forall: + DTYPE: + - VALUE: float + shader_variants: + - NAME: q8ta_conv2d_transposed + - NAME: q8ta_conv2d_transposed_fallback + USE_INT8_DOT_PRODUCT_EXT: 0 diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_dequantize.glsl b/backends/vulkan/runtime/graph/ops/glsl/q8ta_dequantize.glsl index 6989dc2d87d..88089627911 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/q8ta_dequantize.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/q8ta_dequantize.glsl @@ -53,13 +53,13 @@ ${layout_declare_spec_const(C, "int", "inp_block_config", "0")} #include "block_int8x4_load.glslh" #include "block_store.glslh" -// Generate loading functions for t_inp buffer -define_load_int8x4_buffer_fns(t_inp) -// Generate storing functions for t_outp +// Generate 8D loading functions for t_inp buffer +define_load_int8x4_buffer_8d_fns(t_inp) +// Generate 8D storing functions for t_outp $if OUTPUT_STORAGE == "buffer": - define_store_buffer_fns(t_outp, T) + define_store_buffer_8d_fns(t_outp, T) $else: - define_store_texture_fns(t_outp, VEC4_T) + define_store_texture_8d_fns(t_outp, VEC4_T) mat4 dequantize_int8x4_block( const ivec4 block, const float scale, const int zp) { @@ -81,18 +81,20 @@ mat4 dequantize_int8x4_block( } void main() { - TensorIndex4D tidx; + TensorIndex tidx; #ifdef USING_BUFFER - // Buffer storage: use linear dispatch + // Buffer storage: use linear dispatch (supports up to 8D) const uint contig_block_idx = gl_GlobalInvocationID.x; - tidx = contiguous_block_idx_to_tensor4d_idx_with_block_config( + tidx = contiguous_block_idx_to_tensor_idx_with_block_config( inp, contig_block_idx, inp_block_config); #else - // Texture storage: use 3D extents dispatch + // Texture storage: use 3D extents dispatch (limited to 4D) const uvec3 thread_idx = gl_GlobalInvocationID; - tidx = block_idx_3d_to_tensor4d_idx_with_block_config( + TensorIndex4D tidx4d = 
block_idx_3d_to_tensor4d_idx_with_block_config( inp, thread_idx, inp_block_config); + initialize(tidx); + tidx.data[0] = uvec4(tidx4d.data); #endif if (out_of_bounds(tidx, inp)) { diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_quantize.glsl b/backends/vulkan/runtime/graph/ops/glsl/q8ta_quantize.glsl index c0c2a3a914a..2f458054a32 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/q8ta_quantize.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/q8ta_quantize.glsl @@ -52,15 +52,15 @@ ${layout_declare_spec_const(C, "int", "outp_block_config", "0")} #include "block_indexing.glslh" #include "block_load.glslh" $if INPUT_STORAGE == "buffer": - // Generate loading functions for t_inp buffer - define_load_buffer_fns(t_inp) + // Generate 8D loading functions for t_inp buffer + define_load_buffer_8d_fns(t_inp) $else: - // Generate loading functions for t_inp texture - define_load_texture_fns(t_inp) + // Generate 8D loading functions for t_inp texture + define_load_texture_8d_fns(t_inp) #include "block_int8x4_store.glslh" -// Generate storing functions for t_outp buffer -define_store_int8x4_buffer_fns(t_outp) +// Generate 8D storing functions for t_outp buffer +define_store_int8x4_buffer_8d_fns(t_outp) ivec4 quantize_fp_block( const mat4 block, const float inv_scale, const int zp) { @@ -79,18 +79,20 @@ ivec4 quantize_fp_block( } void main() { - TensorIndex4D tidx; + TensorIndex tidx; #ifdef USING_BUFFER - // Buffer storage: use linear dispatch + // Buffer storage: use linear dispatch (supports up to 8D) const uint contig_block_idx = gl_GlobalInvocationID.x; - tidx = contiguous_block_idx_to_tensor4d_idx_with_block_config( + tidx = contiguous_block_idx_to_tensor_idx_with_block_config( inp, contig_block_idx, inp_block_config); #else - // Texture storage: use 3D extents dispatch + // Texture storage: use 3D extents dispatch (limited to 4D) const uvec3 thread_idx = gl_GlobalInvocationID; - tidx = block_idx_3d_to_tensor4d_idx_with_block_config( + TensorIndex4D tidx4d = 
block_idx_3d_to_tensor4d_idx_with_block_config( inp, thread_idx, inp_block_config); + initialize(tidx); + tidx.data[0] = uvec4(tidx4d.data); #endif if (out_of_bounds(tidx, inp)) { diff --git a/backends/vulkan/runtime/graph/ops/impl/Common.cpp b/backends/vulkan/runtime/graph/ops/impl/Common.cpp index d052882afde..40ca0510383 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Common.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Common.cpp @@ -252,6 +252,10 @@ utils::uvec3 pick_linear_global_wg_with_block_config( utils::safe_downcast(utils::val_at(-1 - d, sizes)); } } + // Include extra batch dimensions beyond 4D WHCN space + for (int32_t d = 4; d < static_cast(sizes.size()); ++d) { + num_planes *= utils::safe_downcast(utils::val_at(-1 - d, sizes)); + } // Return linear workgroup size: {total_blocks, 1u, 1u} const uint32_t total_blocks = @@ -285,7 +289,11 @@ utils::uvec3 pick_extents_global_wg_with_block_config( const int64_t W = utils::val_at(-1, sizes); const int64_t H = utils::val_at(-2, sizes); const int64_t C = utils::val_at(-3, sizes); - const int64_t N = utils::val_at(-4, sizes); + // N is the product of all batch dimensions (WHCN dims 3+) + int64_t N = 1; + for (int64_t i = 4; i <= static_cast(sizes.size()); ++i) { + N *= utils::val_at(-i, sizes); + } // Dispatch structure: {x_threads, y_threads, z_threads} // - x corresponds to W dimension diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.h b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.h index 6da98fbef74..f463589c50a 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.h +++ b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.h @@ -9,6 +9,7 @@ #pragma once #include +#include #include namespace vkcompute { @@ -145,4 +146,10 @@ void add_q8ta_im2col_node( void q8ta_conv2d_im2col(ComputeGraph& graph, const std::vector& args); +// Transposed convolution + +void q8ta_conv2d_transposed( + ComputeGraph& graph, + const std::vector& args); + } // namespace vkcompute diff --git 
a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dTransposed.cpp b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dTransposed.cpp
new file mode 100644
index 00000000000..bdbdaa14fec
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dTransposed.cpp
@@ -0,0 +1,267 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

// NOTE(review): the `<...>` payloads of these #include directives (and all
// template arguments below, e.g. extract_scalar<float>, size_at<uint32_t>,
// std::vector<ValueRef>) were stripped by the text extraction — restore them
// against the upstream file before building.
#include

#include

#include
#include
#include
#include

namespace vkcompute {

// Dedicated workgroup size functions for transposed convolution.
// Unlike regular conv2d, transposed conv with stride > 1 causes branch
// divergence along the height dimension (different rows have different
// stride-alignment patterns). Keeping local_y=1 ensures all threads in a
// workgroup process the same height row, maximizing branch coherence.

// Global workgroup: {ceil(W/4), H, ceil(C/4)} over the output tensor —
// one thread per 4-wide x 4-channel output tile (matches the shader).
utils::uvec3 pick_q8ta_conv2d_transposed_global_wg_size(
    ComputeGraph* graph,
    const vkapi::ShaderInfo& shader,
    const std::vector& args,
    const std::vector& resize_args) {
  (void)shader;
  (void)resize_args;

  // args.at(0) is the write-arg group; its first ref is the output tensor.
  const ValueRef output = args.at(0).refs.at(0);

  const uint32_t W = graph->size_at(-1, output);
  const uint32_t H = graph->size_at(-2, output);
  const uint32_t C = graph->size_at(-3, output);

  const uint32_t W4 = utils::div_up_4(W);
  const uint32_t C4 = utils::div_up_4(C);

  return {W4, H, C4};
}

// Local workgroup: always 64 invocations, shaped by how wide/deep the
// dispatch is, with local_y fixed at 1 (see divergence note above).
utils::uvec3 pick_q8ta_conv2d_transposed_local_wg_size(
    ComputeGraph* graph,
    const vkapi::ShaderInfo& shader,
    const utils::uvec3& global_workgroup_size,
    const std::vector& args,
    const std::vector& resize_args) {
  (void)shader;
  (void)resize_args;
  (void)graph;
  (void)args;

  // Always keep local_y=1 to avoid branch divergence between height rows.
  if (global_workgroup_size[0u] >= 6 && global_workgroup_size[2u] >= 6) {
    return {8u, 1u, 8u};
  }
  if (global_workgroup_size[0u] < 2u) {
    return {1u, 1u, 64u};
  }
  if (global_workgroup_size[2u] < 2u) {
    return {64u, 1u, 1u};
  }
  return {16u, 1u, 4u};
}

// Validates inputs, extracts quantization scalars, and appends the
// q8ta_conv2d_transposed dispatch node to the graph. All tensor arguments
// are already prepacked; bias_data is only inspected for none-ness here
// (packed_bias carries the actual binding, possibly a dummy).
void add_q8ta_conv2d_transposed_node(
    ComputeGraph& graph,
    const ValueRef packed_int8_input,
    const ValueRef input_scale,
    const ValueRef input_zp,
    const ValueRef packed_weight,
    const ValueRef packed_weight_sums,
    const ValueRef packed_weight_scales,
    const ValueRef output_scale,
    const ValueRef output_zp,
    const ValueRef bias_data,
    const ValueRef packed_bias,
    const ValueRef kernel_size,
    const ValueRef stride,
    const ValueRef padding,
    const ValueRef dilation,
    const ValueRef groups,
    const uint32_t activation_type,
    const ValueRef packed_int8_output) {
  Conv2DParams conv_params = create_conv2d_params(
      graph,
      packed_int8_input,
      packed_int8_output,
      kernel_size,
      stride,
      padding,
      dilation,
      groups);

  // Transposed convolution only supports dilation=1
  VK_CHECK_COND(
      conv_params.dilation[0] == 1 && conv_params.dilation[1] == 1,
      "q8ta_conv2d_transposed only supports dilation=1");

  // The implementation requires that for grouped convolutions, the input
  // channels per group is a multiple of 4.
  if (conv_params.groups > 1) {
    VK_CHECK_COND(conv_params.in_channels_per_group % 4 == 0);
  }

  // Validate packed dim info for input and output tensors
  VK_CHECK_COND(q8ta_conv2d_check_packed_dim_info(
      graph.packed_dim_info_of(packed_int8_input)));
  VK_CHECK_COND(q8ta_conv2d_check_packed_dim_info(
      graph.packed_dim_info_of(packed_int8_output)));

  // Validate dtype is kInt8x4
  VK_CHECK_COND(graph.dtype_of(packed_int8_input) == vkapi::kInt8x4);
  VK_CHECK_COND(graph.dtype_of(packed_int8_output) == vkapi::kInt8x4);

  float input_scale_val = graph.extract_scalar(input_scale);
  int32_t input_zp_val = graph.extract_scalar(input_zp);

  // The shader consumes 1/output_scale so it can multiply when quantizing.
  float output_inv_scale_val = 1.0f / graph.extract_scalar(output_scale);
  int32_t output_zp_val = graph.extract_scalar(output_zp);

  uint32_t apply_bias = 1;
  if (graph.val_is_none(bias_data)) {
    apply_bias = 0;
  }

  std::vector push_constants = {
      PushConstantDataInfo(&input_scale_val, sizeof(input_scale_val)),
      PushConstantDataInfo(&input_zp_val, sizeof(input_zp_val)),
      PushConstantDataInfo(&output_inv_scale_val, sizeof(output_inv_scale_val)),
      PushConstantDataInfo(&output_zp_val, sizeof(output_zp_val)),
  };

  // Select the hardware int8-dot-product variant when the adapter supports it.
  const bool use_hw_dot =
      graph.context()->adapter_ptr()->supports_int8_dot_product();
  std::string kernel_name =
      use_hw_dot ? "q8ta_conv2d_transposed" : "q8ta_conv2d_transposed_fallback";
  add_dtype_suffix(kernel_name, graph.dtype_of(packed_weight_scales));

  // Pass metadata for both output and input tensors
  vkapi::ParamsBindList param_buffers = {
      graph.buffer_meta_ubo(packed_int8_output),
      graph.buffer_meta_ubo(packed_int8_input),
      graph.create_params_buffer(conv_params)};

  // Build spec constants: apply_bias, activation_type + layout constants
  vkapi::SpecVarList spec_constants = {
      apply_bias,
      activation_type,
      // Layout specialization constants
      graph.hashed_layout_of(packed_int8_input),
      graph.hashed_layout_of(packed_int8_output),
  };

  graph.execute_nodes().emplace_back(new DynamicDispatchNode(
      graph,
      VK_KERNEL_FROM_STR(kernel_name),
      pick_q8ta_conv2d_transposed_global_wg_size,
      pick_q8ta_conv2d_transposed_local_wg_size,
      // Inputs and Outputs
      {{packed_int8_output, vkapi::kWrite},
       {{packed_int8_input,
         packed_weight,
         packed_weight_sums,
         packed_weight_scales,
         packed_bias},
        vkapi::kRead}},
      // Shader params buffers
      param_buffers,
      // Push Constants
      push_constants,
      // Specialization Constants
      spec_constants,
      // Resize args
      {}));
}

// Operator entry point registered as et_vk.q8ta_conv2d_transposed.default.
// Unpacks the positional args, prepacks weight/sums/scales/bias, then
// delegates to add_q8ta_conv2d_transposed_node.
void q8ta_conv2d_transposed(
    ComputeGraph& graph,
    const std::vector& args) {
  int32_t idx = 0;
  const ValueRef packed_int8_input = args.at(idx++);
  const ValueRef input_scale = args.at(idx++);
  const ValueRef input_zp = args.at(idx++);
  const ValueRef weight_data = args.at(idx++);
  const ValueRef weight_sums_data = args.at(idx++);
  const ValueRef weight_scales_data = args.at(idx++);
  const ValueRef output_scale = args.at(idx++);
  const ValueRef output_zp = args.at(idx++);
  const ValueRef bias_data = args.at(idx++);
  const ValueRef kernel_size = args.at(idx++);
  const ValueRef stride = args.at(idx++);
  const ValueRef padding = args.at(idx++);
  args.at(idx++); // output_padding: only affects output size, not shader
  const ValueRef dilation = args.at(idx++);
  const ValueRef groups = args.at(idx++);
  const ValueRef activation = args.at(idx++);
  const ValueRef packed_int8_output = args.at(idx++);

  uint32_t activation_type_val = static_cast(
      activation_type_from_string(graph.extract_string(activation)));

  QuantizationConfig weight_quant_config(8, kPerChannel, {});

  // Reuse the conv2d weight packing (after the pattern matcher reshapes the
  // transposed weight to (OC, KH*KW*IC_per_group), the weight layout is
  // identical to regular conv2d)
  ValueRef packed_weight = prepack_quantized_conv2d_weight(
      graph,
      weight_quant_config,
      weight_data,
      packed_int8_input,
      packed_int8_output,
      groups,
      kernel_size);

  ValueRef packed_weight_sums = prepack_standard(
      graph, weight_sums_data, utils::kBuffer, utils::kWidthPacked);

  ValueRef packed_weight_scales = prepack_standard(
      graph, weight_scales_data, utils::kBuffer, utils::kWidthPacked);

  // Create a dummy tensor to fill the binding slot of the bias tensor if it is
  // not provided
  TmpTensor dummy_bias(
      &graph,
      {},
      graph.dtype_of(weight_scales_data),
      utils::kBuffer,
      utils::kWidthPacked);

  ValueRef packed_bias = dummy_bias.vref;
  if (graph.val_is_not_none(bias_data)) {
    packed_bias =
        prepack_standard(graph, bias_data, utils::kBuffer, utils::kWidthPacked);
  }

  add_q8ta_conv2d_transposed_node(
      graph,
      packed_int8_input,
      input_scale,
      input_zp,
      packed_weight,
      packed_weight_sums,
      packed_weight_scales,
      output_scale,
      output_zp,
      bias_data,
      packed_bias,
      kernel_size,
      stride,
      padding,
      dilation,
      groups,
      activation_type_val,
      packed_int8_output);
}

REGISTER_OPERATORS {
  VK_REGISTER_OP(et_vk.q8ta_conv2d_transposed.default, q8ta_conv2d_transposed);
}

} // namespace vkcompute
diff --git a/backends/vulkan/test/custom_ops/impl/TestQ8taConv2dTransposed.cpp b/backends/vulkan/test/custom_ops/impl/TestQ8taConv2dTransposed.cpp
new file mode 100644
index 00000000000..894ce71fed9
--- /dev/null
+++
b/backends/vulkan/test/custom_ops/impl/TestQ8taConv2dTransposed.cpp
@@ -0,0 +1,87 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

// NOTE(review): the `<...>` payloads of these #include directives and the
// template arguments below (extract_scalar<int32_t>, static_cast<...>,
// std::vector<ValueRef>) were stripped by the text extraction.
#include

#include
#include
#include

namespace vkcompute {

// Test harness operator: wraps the q8ta transposed conv op with a quantize
// node on the fp input and a dequantize node on the int8 output, so the test
// runner can feed and compare float tensors end-to-end.
void test_q8ta_conv2d_transposed(
    ComputeGraph& graph,
    const std::vector& args) {
  int32_t idx = 0;
  const ValueRef fp_input = args.at(idx++);
  const ValueRef input_scale = args.at(idx++);
  const ValueRef input_zp = args.at(idx++);
  const ValueRef weight_data = args.at(idx++);
  const ValueRef weight_sums_data = args.at(idx++);
  const ValueRef weight_scales_data = args.at(idx++);
  const ValueRef output_scale = args.at(idx++);
  const ValueRef output_zp = args.at(idx++);
  const ValueRef bias_data = args.at(idx++);
  const ValueRef kernel_size = args.at(idx++);
  const ValueRef stride = args.at(idx++);
  const ValueRef padding = args.at(idx++);
  const ValueRef output_padding = args.at(idx++);
  const ValueRef dilation = args.at(idx++);
  const ValueRef groups = args.at(idx++);
  const ValueRef activation = args.at(idx++);
  const ValueRef layout_int = args.at(idx++);
  const ValueRef fp_output = args.at(idx++);

  // The int8 memory layout under test is passed in as an integer scalar.
  int32_t layout_value = graph.extract_scalar(layout_int);
  utils::GPUMemoryLayout layout =
      static_cast(layout_value);

  // Staging int8 tensors bracketing the op under test.
  TmpTensor packed_int8_input(
      &graph, graph.sizes_of(fp_input), vkapi::kInt8x4, utils::kBuffer, layout);

  TmpTensor packed_int8_output(
      &graph,
      graph.sizes_of(fp_output),
      vkapi::kInt8x4,
      utils::kBuffer,
      layout);

  add_q8ta_quantize_node(
      graph, fp_input, input_scale, input_zp, packed_int8_input);

  std::vector conv_args = {
      packed_int8_input,
      input_scale,
      input_zp,
      weight_data,
      weight_sums_data,
      weight_scales_data,
      output_scale,
      output_zp,
      bias_data,
      kernel_size,
      stride,
      padding,
      output_padding,
      dilation,
      groups,
      activation,
      packed_int8_output};
  VK_GET_OP_FN("et_vk.q8ta_conv2d_transposed.default")(graph, conv_args);

  add_q8ta_dequantize_node(
      graph, packed_int8_output, output_scale, output_zp, fp_output);
}

REGISTER_OPERATORS {
  VK_REGISTER_OP(
      test_etvk.test_q8ta_conv2d_transposed.default,
      test_q8ta_conv2d_transposed);
}

} // namespace vkcompute
diff --git a/backends/vulkan/test/custom_ops/targets.bzl b/backends/vulkan/test/custom_ops/targets.bzl
index badba5666fa..ba4873af603 100644
--- a/backends/vulkan/test/custom_ops/targets.bzl
+++ b/backends/vulkan/test/custom_ops/targets.bzl
@@ -98,3 +98,4 @@ def define_common_targets(is_fbcode = False):
     define_custom_op_test_binary("test_q8ta_conv2d_pw")
     define_custom_op_test_binary("test_q8ta_conv2d_dw")
     define_custom_op_test_binary("test_q8ta_linear")
+    define_custom_op_test_binary("test_q8ta_conv2d_transposed")
diff --git a/backends/vulkan/test/custom_ops/test_q8ta_conv2d_transposed.cpp b/backends/vulkan/test/custom_ops/test_q8ta_conv2d_transposed.cpp
new file mode 100644
index 00000000000..903a9c678b1
--- /dev/null
+++ b/backends/vulkan/test/custom_ops/test_q8ta_conv2d_transposed.cpp
@@ -0,0 +1,601 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.

// NOTE(review): the `<...>` payloads of these #include directives and the
// template arguments below (std::vector<TestCase>, static_cast<int32_t>, etc.)
// were stripped by the text extraction — restore against upstream.
#include
#include

#include
#include

#include

#include "conv2d_utils.h"
#include "utils.h"

// #define DEBUG_MODE

using namespace executorch::vulkan::prototyping;

using namespace vkcompute;

// Dimensions above this run only on the GPU path; the CPU reference impl
// refuses them (see the size guard in the reference implementation).
static constexpr int64_t kRefDimSizeLimit = 100;

// Transposed convolution output size formula:
// H_out = (H_in - 1) * stride_h - 2 * pad_h + dilation_h * (K_h - 1)
// + output_pad_h + 1
static int64_t get_transpose_output_height(
    const Conv2dConfig& config,
    int32_t output_pad_h) {
  return (config.input_size.h - 1) * config.stride.h - 2 * config.padding.h +
      config.dilation.h * (config.kernel.h - 1) + output_pad_h + 1;
}

// Width analogue of get_transpose_output_height (same formula along W).
static int64_t get_transpose_output_width(
    const Conv2dConfig& config,
    int32_t output_pad_w) {
  return (config.input_size.w - 1) * config.stride.w - 2 * config.padding.w +
      config.dilation.w * (config.kernel.w - 1) + output_pad_w + 1;
}

// Utility function to create a test case from a Conv2dConfig for transposed
// convolution. Builds every input ValueSpec (quantized weight, scales, sums,
// bias, conv hyper-parameters) in the exact positional order expected by
// test_etvk.test_q8ta_conv2d_transposed.default.
static TestCase create_test_case_from_config(
    const Conv2dConfig& config,
    int32_t output_pad_h,
    int32_t output_pad_w,
    vkapi::ScalarType input_dtype,
    utils::StorageType fp_storage_type,
    utils::GPUMemoryLayout int8_memory_layout) {
  TestCase test_case;

  int64_t H_out = get_transpose_output_height(config, output_pad_h);
  int64_t W_out = get_transpose_output_width(config, output_pad_w);

  // Input tensor (float/half) - [1, C_in, H_in, W_in] (batch size always 1)
  // For transposed conv, C_in is typically larger (downsampled channels)
  std::vector input_size = {
      1, config.channels.in, config.input_size.h, config.input_size.w};

  utils::GPUMemoryLayout fp_memory_layout = fp_storage_type == utils::kBuffer
      ? utils::kWidthPacked
      : utils::kChannelsPacked;

  // Create test case name
  std::string prefix = config.test_case_name.substr(0, 4);
  std::string test_name = prefix + " " + std::to_string(config.channels.in) +
      "->" + std::to_string(config.channels.out) + " " +
      "I=" + std::to_string(config.input_size.h) + "," +
      std::to_string(config.input_size.w) + " " +
      "g=" + std::to_string(config.groups) + " " +
      "k=" + std::to_string(config.kernel.h) + " " +
      "op=" + std::to_string(output_pad_h) + "," +
      std::to_string(output_pad_w) + " " +
      repr_str(utils::kBuffer, int8_memory_layout);
  test_case.set_name(test_name);

  test_case.set_operator_name("test_etvk.test_q8ta_conv2d_transposed.default");

  ValueSpec input_tensor(
      input_size,
      input_dtype,
      fp_storage_type,
      fp_memory_layout,
      DataGenType::RANDOM);

  if (debugging()) {
    print_valuespec_data(input_tensor, "input_tensor");
  }

  float input_scale_val = 0.008123;
  ValueSpec input_scale(input_scale_val);

  int32_t input_zero_point_val = 2;
  ValueSpec input_zero_point(input_zero_point_val);

  // Quantized weight tensor (int8) - [C_out, align_up_4(C_in_per_group * K_h *
  // K_w)] After the pattern matcher reshapes, the transposed conv weight has
  // the same layout as regular conv2d
  const int64_t in_channels_per_group = config.channels.in / config.groups;
  const int64_t in_features = utils::align_up_4(
      in_channels_per_group * config.kernel.h * config.kernel.w);
  std::vector weight_size = {config.channels.out, in_features};
  ValueSpec quantized_weight(
      weight_size,
      vkapi::kChar,
      fp_storage_type,
      utils::kWidthPacked,
      DataGenType::RANDINT8);
  quantized_weight.set_constant(true);

  if (debugging()) {
    print_valuespec_data(quantized_weight, "weight_tensor");
  }

  const int64_t aligned_out_channels = utils::align_up_4(config.channels.out);

  ValueSpec weight_scales(
      {aligned_out_channels},
      input_dtype,
      fp_storage_type,
      utils::kWidthPacked,
      DataGenType::RANDOM_SCALES);
  weight_scales.set_constant(true);

  // weight_sums starts as zeros and is filled by compute_weight_sums below
  // (per-output-channel sum of int8 weights, used for zero-point correction).
  ValueSpec weight_sums(
      {aligned_out_channels},
      vkapi::kInt,
      fp_storage_type,
      utils::kWidthPacked,
      DataGenType::ZEROS);
  weight_sums.set_constant(true);

  compute_weight_sums(
      weight_sums, quantized_weight, config.channels.out, in_features);

  ValueSpec bias(
      {aligned_out_channels},
      input_dtype,
      fp_storage_type,
      utils::kWidthPacked,
      DataGenType::ZEROS);
  bias.set_constant(true);

  float output_scale_val = 0.05314;
  ValueSpec output_scale(output_scale_val);

  int32_t output_zero_point_val = -1;
  ValueSpec output_zero_point(output_zero_point_val);

  ValueSpec stride({config.stride.h, config.stride.w});
  ValueSpec padding({config.padding.h, config.padding.w});
  ValueSpec output_padding({output_pad_h, output_pad_w});
  ValueSpec dilation({config.dilation.h, config.dilation.w});
  ValueSpec groups(config.groups);
  ValueSpec kernel_size({config.kernel.h, config.kernel.w});

  // Output tensor - [1, C_out, H_out, W_out]
  ValueSpec output(
      {1, config.channels.out, H_out, W_out},
      input_dtype,
      fp_storage_type,
      fp_memory_layout,
      DataGenType::ZEROS);

  // Add all specs to test case (order matches the operator's arg parsing)
  test_case.add_input_spec(input_tensor);
  test_case.add_input_spec(input_scale);
  test_case.add_input_spec(input_zero_point);
  test_case.add_input_spec(quantized_weight);
  test_case.add_input_spec(weight_sums);
  test_case.add_input_spec(weight_scales);
  test_case.add_input_spec(output_scale);
  test_case.add_input_spec(output_zero_point);
  test_case.add_input_spec(bias);
  test_case.add_input_spec(kernel_size);
  test_case.add_input_spec(stride);
  test_case.add_input_spec(padding);
  test_case.add_input_spec(output_padding);
  test_case.add_input_spec(dilation);
  test_case.add_input_spec(groups);

  ValueSpec activation = ValueSpec::make_string("none");
  test_case.add_input_spec(activation);

  ValueSpec layout_int(static_cast(int8_memory_layout));
  test_case.add_input_spec(layout_int);

  test_case.add_output_spec(output);

  // Tolerance of one output quantization step (plus epsilon).
  test_case.set_abs_tolerance(output_scale_val + 1e-4f);

  test_case.set_shader_filter({
      "nchw_to",
      "to_nchw",
      "q8ta_quantize",
      "q8ta_dequantize",
  });

  return test_case;
}

// Generate easy test cases for debugging
std::vector generate_quantized_conv2d_transposed_easy_cases() {
  std::vector test_cases;

  Conv2dConfig config = {
      OutInChannels(16, 32),
      InputSize2D(8, 8),
      KernelSize(3, 3),
      Stride(2, 2),
      Padding(1, 1),
      Dilation(1, 1),
      1,
  };

  std::vector int8_memory_layouts = {
      utils::kPackedInt8_4C1W, utils::kPackedInt8_4W4C, utils::kPackedInt8_4C};

  for (const utils::GPUMemoryLayout int8_memory_layout : int8_memory_layouts) {
    config.test_case_name =
        make_test_case_name(config, false, utils::kTexture3D, utils::kBuffer);
    test_cases.push_back(create_test_case_from_config(
        config,
        /*output_pad_h=*/1,
        /*output_pad_w=*/1,
        vkapi::kFloat,
        utils::kTexture3D,
        int8_memory_layout));
  }

  return test_cases;
}

// Generate test cases for quantized transposed conv2d
static std::vector generate_quantized_conv2d_transposed_test_cases() {
  std::vector test_cases;
  // Hardware int8 dot product is required by the main shader variant.
  if (!vkcompute::api::context()->adapter_ptr()->supports_int8_dot_product()) {
    return test_cases;
  }

  // Each entry: {config, output_pad_h, output_pad_w}
  struct TransposedConvTestConfig {
    Conv2dConfig config;
    int32_t output_pad_h;
    int32_t output_pad_w;
  };

  std::vector configs = {
      // Basic transposed conv (stride=2, common in decoder networks)
      {{OutInChannels(16, 32),
        InputSize2D(8, 8),
        KernelSize(3, 3),
        Stride(2, 2),
        Padding(1, 1),
        Dilation(1, 1),
        1},
       1,
       1},
      {{OutInChannels(32, 64),
        InputSize2D(4, 4),
        KernelSize(3, 3),
        Stride(2, 2),
        Padding(1, 1),
        Dilation(1, 1),
        1},
       1,
       1},
      // No output padding
      {{OutInChannels(16, 32),
        InputSize2D(8, 8),
        KernelSize(4, 4),
        Stride(2, 2),
        Padding(1, 1),
        Dilation(1, 1),
        1},
       0,
       0},
      // Stride=1 (degenerate case)
      {{OutInChannels(16, 16),
        InputSize2D(8, 8),
        KernelSize(3, 3),
        Stride(1, 1),
        Padding(1, 1),
        Dilation(1, 1),
        1},
       0,
       0},
      // Grouped transposed conv
      {{OutInChannels(32, 64),
        InputSize2D(8, 8),
        KernelSize(3, 3),
        Stride(2, 2),
        Padding(1, 1),
        Dilation(1, 1),
        2},
       1,
       1},
      // Larger spatial
      {{OutInChannels(64, 128),
        InputSize2D(16, 16),
        KernelSize(4, 4),
        Stride(2, 2),
        Padding(1, 1),
        Dilation(1, 1),
        1},
       0,
       0},
      // Performance cases
      {{OutInChannels(64, 128),
        InputSize2D(32, 32),
        KernelSize(3, 3),
        Stride(2, 2),
        Padding(1, 1),
        Dilation(1, 1),
        1},
       1,
       1},
      {{OutInChannels(128, 256),
        InputSize2D(16, 16),
        KernelSize(4, 4),
        Stride(2, 2),
        Padding(1, 1),
        Dilation(1, 1),
        1},
       0,
       0},
  };

  std::vector int8_memory_layouts = {
      utils::kPackedInt8_4C1W, utils::kPackedInt8_4W4C, utils::kPackedInt8_4C};

  for (auto& tc : configs) {
    auto& config = tc.config;
    // Cases exceeding kRefDimSizeLimit are benchmark-only (no CPU reference).
    bool is_performance = config.channels.out > kRefDimSizeLimit ||
        config.channels.in > kRefDimSizeLimit ||
        config.input_size.h > kRefDimSizeLimit ||
        config.input_size.w > kRefDimSizeLimit;

    for (const utils::GPUMemoryLayout int8_memory_layout :
         int8_memory_layouts) {
      config.test_case_name = make_test_case_name(
          config, is_performance, utils::kTexture3D, utils::kBuffer);

      test_cases.push_back(create_test_case_from_config(
          config,
          tc.output_pad_h,
          tc.output_pad_w,
          vkapi::kFloat,
          utils::kTexture3D,
          int8_memory_layout));
    }
  }

  return test_cases;
}

// Reference implementation for quantized transposed conv2d
static void conv2d_transposed_q8ta_reference_impl(TestCase& test_case) {
  int32_t idx = 0;
  const ValueSpec& input_spec = test_case.inputs()[idx++];
  const ValueSpec& input_scale_spec = test_case.inputs()[idx++];
  const ValueSpec& input_zeros_spec = test_case.inputs()[idx++];
  const ValueSpec& weight_spec = test_case.inputs()[idx++];
  const ValueSpec&
weight_sums_spec = test_case.inputs()[idx++]; + (void)weight_sums_spec; + const ValueSpec& weight_scales_spec = test_case.inputs()[idx++]; + const ValueSpec& output_scale_spec = test_case.inputs()[idx++]; + const ValueSpec& output_zeros_spec = test_case.inputs()[idx++]; + const ValueSpec& bias_spec = test_case.inputs()[idx++]; + const ValueSpec& kernel_size_spec = test_case.inputs()[idx++]; + const ValueSpec& stride_spec = test_case.inputs()[idx++]; + const ValueSpec& padding_spec = test_case.inputs()[idx++]; + const ValueSpec& output_padding_spec = test_case.inputs()[idx++]; + (void)output_padding_spec; // output_padding only affects output size + const ValueSpec& dilation_spec = test_case.inputs()[idx++]; + const ValueSpec& groups_spec = test_case.inputs()[idx++]; + const ValueSpec& activation_spec = test_case.inputs()[idx++]; + (void)activation_spec; + const ValueSpec& layout_spec = test_case.inputs()[idx++]; + (void)layout_spec; + + ValueSpec& output_spec = test_case.outputs()[0]; + + auto input_sizes = input_spec.get_tensor_sizes(); + auto output_sizes = output_spec.get_tensor_sizes(); + + int64_t N = input_sizes[0]; + int64_t C_in = input_sizes[1]; + int64_t H_in = input_sizes[2]; + int64_t W_in = input_sizes[3]; + int64_t C_out = output_sizes[1]; + int64_t H_out = output_sizes[2]; + int64_t W_out = output_sizes[3]; + + auto kernel_size_data = kernel_size_spec.get_int32_data(); + int64_t K_h = kernel_size_data[0]; + int64_t K_w = kernel_size_data[1]; + + auto stride_data = stride_spec.get_int32_data(); + auto padding_data = padding_spec.get_int32_data(); + auto dilation_data = dilation_spec.get_int32_data(); + int64_t stride_h = stride_data[0]; + int64_t stride_w = stride_data[1]; + int64_t pad_h = padding_data[0]; + int64_t pad_w = padding_data[1]; + int64_t dilation_h = dilation_data[0]; + int64_t dilation_w = dilation_data[1]; + int64_t groups = groups_spec.get_int_value(); + + if (N > kRefDimSizeLimit || C_in > kRefDimSizeLimit || + H_in > 
kRefDimSizeLimit || W_in > kRefDimSizeLimit || + C_out > kRefDimSizeLimit) { + throw std::invalid_argument( + "One or more dimensions exceed the allowed limit for reference implementation."); + } + + if (input_spec.dtype != vkapi::kFloat) { + throw std::invalid_argument("Unsupported dtype"); + } + + auto& input_data = input_spec.get_float_data(); + const float input_scale = input_scale_spec.get_float_value(); + const int32_t input_zero_point = input_zeros_spec.get_int_value(); + + auto& weight_data = weight_spec.get_int8_data(); + auto& weight_scales_data = weight_scales_spec.get_float_data(); + auto& bias_data = bias_spec.get_float_data(); + + const float output_scale = output_scale_spec.get_float_value(); + const int32_t output_zero_point = output_zeros_spec.get_int_value(); + + int64_t C_in_per_group = C_in / groups; + int64_t C_out_per_group = C_out / groups; + + int64_t num_output_elements = N * C_out * H_out * W_out; + + auto& ref_data = output_spec.get_ref_float_data(); + ref_data.resize(num_output_elements); + + const int64_t in_features = utils::align_up_4(C_in_per_group * K_h * K_w); + + // Transposed convolution reference implementation. + // For transposed conv, we scatter each input element across the output + // rather than gather. But for the reference we compute it by iterating + // over output positions and finding which input positions contribute. + // + // For each output position (oh, ow), an input position (iy, ix) contributes + // via kernel position (kh, kw) if: + // oh + pad_h - kh * dilation_h == iy * stride_h + // ow + pad_w - kw * dilation_w == ix * stride_w + // i.e., (oh + pad_h - kh * dilation_h) must be divisible by stride_h + // and the quotient must be a valid input index. 
+ for (int64_t n = 0; n < N; ++n) { + for (int64_t out_c = 0; out_c < C_out; ++out_c) { + int64_t group_idx = out_c / C_out_per_group; + int64_t in_c_start = group_idx * C_in_per_group; + + for (int64_t out_h = 0; out_h < H_out; ++out_h) { + for (int64_t out_w = 0; out_w < W_out; ++out_w) { + int32_t int_sum = 0; + int32_t weight_sum = 0; + + for (int64_t kh = 0; kh < K_h; ++kh) { + int64_t h_offset = out_h + pad_h - kh * dilation_h; + if (h_offset < 0 || h_offset % stride_h != 0) { + continue; + } + int64_t iy = h_offset / stride_h; + if (iy >= H_in) { + continue; + } + + for (int64_t kw = 0; kw < K_w; ++kw) { + int64_t w_offset = out_w + pad_w - kw * dilation_w; + if (w_offset < 0 || w_offset % stride_w != 0) { + continue; + } + int64_t ix = w_offset / stride_w; + if (ix >= W_in) { + continue; + } + + for (int64_t ic_local = 0; ic_local < C_in_per_group; + ++ic_local) { + int64_t in_c = in_c_start + ic_local; + + int64_t input_idx = n * (C_in * H_in * W_in) + + in_c * (H_in * W_in) + iy * W_in + ix; + + float quant_input_f = + std::round(input_data[input_idx] / input_scale) + + input_zero_point; + quant_input_f = + std::min(std::max(quant_input_f, -128.0f), 127.0f); + int8_t quantized_input = static_cast(quant_input_f); + + // Weight layout: [C_out, align_up_4(C_in_per_group * K_h * + // K_w)] Inner dimension order: kh, kw, ic_local + int64_t weight_idx = out_c * in_features + + (kh * (K_w * C_in_per_group) + kw * C_in_per_group + + ic_local); + int8_t quantized_weight = weight_data[weight_idx]; + + int_sum += static_cast(quantized_input) * + static_cast(quantized_weight); + + weight_sum += static_cast(quantized_weight); + } + } + } + + int32_t zero_point_correction = input_zero_point * weight_sum; + int32_t accum_adjusted = int_sum - zero_point_correction; + float float_result = + accum_adjusted * input_scale * weight_scales_data[out_c]; + + float_result += bias_data[out_c]; + + float quant_output_f = + std::round(float_result / output_scale) + 
output_zero_point; + quant_output_f = std::min(std::max(quant_output_f, -128.0f), 127.0f); + int8_t quantized_output = static_cast(quant_output_f); + + float dequant_output = + (static_cast(quantized_output) - output_zero_point) * + output_scale; + + int64_t output_idx = n * (C_out * H_out * W_out) + + out_c * (H_out * W_out) + out_h * W_out + out_w; + ref_data[output_idx] = dequant_output; + } + } + } + } +} + +static void reference_impl(TestCase& test_case) { + conv2d_transposed_q8ta_reference_impl(test_case); +} + +static int64_t quantized_conv2d_transposed_flop_calculator( + const TestCase& test_case) { + int kernel_idx = 9; + + const auto& input_sizes = test_case.inputs()[0].get_tensor_sizes(); + const auto& output_sizes = test_case.outputs()[0].get_tensor_sizes(); + const auto& kernel_sizes = test_case.inputs()[kernel_idx].get_int32_data(); + + int64_t N = input_sizes[0]; + int64_t C_in = input_sizes[1]; + int64_t C_out = output_sizes[1]; + int64_t K_h = kernel_sizes[0]; + int64_t K_w = kernel_sizes[1]; + int64_t H_out = output_sizes[2]; + int64_t W_out = output_sizes[3]; + + int64_t output_elements = N * C_out * H_out * W_out; + int64_t ops_per_output = C_in * K_h * K_w; + + return output_elements * ops_per_output; +} + +int main(int argc, char* argv[]) { + set_debugging(false); + set_print_output(false); +#ifdef DEBUG_MODE + set_print_latencies(true); +#else + set_print_latencies(false); +#endif + set_use_gpu_timestamps(true); + + print_performance_header(); + std::cout + << "Quantized Transposed Conv2d Operation with Output Quantization Prototyping Framework" + << std::endl; + print_separator(); + + ReferenceComputeFunc ref_fn = reference_impl; + + auto results = execute_test_cases( +#ifdef DEBUG_MODE + generate_quantized_conv2d_transposed_easy_cases, +#else + generate_quantized_conv2d_transposed_test_cases, +#endif + quantized_conv2d_transposed_flop_calculator, + "QuantizedTransposedConv2d", +#ifdef DEBUG_MODE + 0, + 1, +#else + 3, + 10, +#endif + 
ref_fn); + + return 0; +} diff --git a/backends/vulkan/test/custom_ops/test_q8ta_qdq.cpp b/backends/vulkan/test/custom_ops/test_q8ta_qdq.cpp index a3ff8c42f86..9c8dbce6501 100644 --- a/backends/vulkan/test/custom_ops/test_q8ta_qdq.cpp +++ b/backends/vulkan/test/custom_ops/test_q8ta_qdq.cpp @@ -108,6 +108,8 @@ std::vector generate_q_dq_8bit_easy_cases() { {1, 16, 16, 16}, // 4D: [N, C, H, W] {1, 144}, // 2D: exercises block config with ndim < 4 {1, 90}, // 2D: matches skin_seg model's keypoint/bbox tensor sizes + {2, 1, 3, 16, 16}, // 5D: exercises high-dim batch support + {2, 1, 1, 4, 8, 16}, // 6D: exercises high-dim batch support }; // FP memory layouts to test @@ -199,6 +201,14 @@ std::vector generate_q_dq_8bit_test_cases() { {1, 64, 128, 128}, {1, 32, 64, 64}, {1, 128, 56, 56}, + + // 5D tensors (high-dim batch support) + {2, 1, 3, 16, 16}, + {1, 2, 8, 8, 8}, + + // 6D tensors (high-dim batch support) + {2, 1, 1, 4, 8, 16}, + {1, 1, 2, 3, 8, 8}, }; // FP memory layouts to test @@ -234,9 +244,15 @@ std::vector generate_q_dq_8bit_test_cases() { } } + // Skip texture3d for high-dim tensors (textures are limited to 4D) + const bool is_highdim = shape.size() > 4; + const auto& effective_storage_types = is_highdim + ? 
std::vector{utils::kBuffer} + : storage_types; + for (const auto& fp_layout : fp_layouts) { for (const auto& quant_layout : quant_layouts) { - for (const auto& storage_type : storage_types) { + for (const auto& storage_type : effective_storage_types) { QDQ8BitConfig config; config.shape = shape; config.test_case_name = prefix; @@ -244,7 +260,8 @@ std::vector generate_q_dq_8bit_test_cases() { test_cases.push_back(create_test_case_from_config( config, storage_type, vkapi::kFloat, fp_layout, quant_layout)); // For 4W4C layout, also test with legacy implementation - if (fp_layout == utils::kChannelsPacked && + // (legacy path doesn't support high-dim tensors) + if (!is_highdim && fp_layout == utils::kChannelsPacked && quant_layout == utils::kPackedInt8_4W4C) { test_cases.push_back(create_test_case_from_config( config,