From 3bbc37cd04303101c348f18532757a20764f6de2 Mon Sep 17 00:00:00 2001 From: ssjia Date: Mon, 9 Mar 2026 16:45:47 -0700 Subject: [PATCH 1/5] [ET-VK][ez] Fix duplicate placeholder target in create_constant_placeholder Pull Request resolved: https://github.com/pytorch/executorch/pull/18013 When multiple pattern replacements (e.g., quantized conv and quantized linear) share the same weight parameter, each independently calls create_constant_placeholder to create a _sums placeholder with the same name. torch.fx.Graph.create_node deduplicates node.name but not node.target, so the second call produces a placeholder with a unique name but a duplicate target. Since recompile() uses node.target for function parameter names, this causes a SyntaxError: duplicate argument in function definition. Fix by checking the state_dict/constants dicts (O(1) lookup) before creating the node. If the name already exists, find and return the existing placeholder node. ghstack-source-id: 349646647 @exported-using-ghexport Differential Revision: [D95807071](https://our.internmc.facebook.com/intern/diff/D95807071/) --- backends/transforms/utils.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/backends/transforms/utils.py b/backends/transforms/utils.py index ef19a937b0b..cb3e8c9469a 100644 --- a/backends/transforms/utils.py +++ b/backends/transforms/utils.py @@ -111,6 +111,15 @@ def create_constant_placeholder( target = name + # If a placeholder with this target already exists, return it to avoid + # duplicate parameter names in the generated function signature which would + # cause a SyntaxError on recompile. This can happen when multiple pattern + # replacements independently create placeholders for a shared weight. + if name in exp_program.state_dict or name in exp_program.constants: + for n in graph.nodes: + if n.op == "placeholder" and n.target == name: + return n + # Add data to state_dict/ constants match kind: case InputKind.PARAMETER: From b4dc7dc426b70fdbadcdf1b136901d929d597476 Mon Sep 17 00:00:00 2001 From: ssjia Date: Mon, 9 Mar 2026 16:45:49 -0700 Subject: [PATCH 2/5] [ET-VK][qlinear] Look through output view_copy when detecting output quantization Pull Request resolved: https://github.com/pytorch/executorch/pull/18014 When `aten.linear` has 3D+ inputs, it decomposes into `view_copy -> mm -> view_copy`. The output view_copy between mm and the subsequent quantize_per_tensor node was preventing the pattern matcher from detecting output quantization, causing the match to fall through to `linear_q8ta_q8csw` instead of `q8ta_linear_gemv`. This caused a dtype mismatch during FakeTensor re-tracing in FusePatternsPass because `linear_q8ta_q8csw`'s composite implementation does not dequantize its input, producing int8 output where float32 was expected. Mirror the existing input-side view_copy handling (lines 99-104) on the output side so the quantize node is found through the view_copy. ghstack-source-id: 349646653 @exported-using-ghexport Differential Revision: [D95807075](https://our.internmc.facebook.com/intern/diff/D95807075/) --- backends/vulkan/patterns/quantized_linear.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/backends/vulkan/patterns/quantized_linear.py b/backends/vulkan/patterns/quantized_linear.py index df80749e72f..85e3476cad3 100644 --- a/backends/vulkan/patterns/quantized_linear.py +++ b/backends/vulkan/patterns/quantized_linear.py @@ -174,12 +174,21 @@ def __init__(self, mm_node: torch.fx.Node) -> None: # noqa: C901 # Check if the output is also quantized (q → dq → linear → q pattern) # Also handle fused linear+relu (q → dq → linear → relu → q pattern) + # Due to decomposition of aten.linear for 3D+ inputs, there may be a + # view_copy between the mm output and the quantize node. self.quantize_output_node = None self.output_scales_node = None self.output_zeros_node = None self.relu_node = None + self.output_view_copy_node = None if len(self.output_node.users) == 1: cur_node = list(self.output_node.users)[0] + # Skip potential view_copy between linear and output quantize + if utils.is_view_copy_node(cur_node) and len(cur_node.users) == 1: + self.output_view_copy_node = cur_node + self.all_nodes.append(self.output_view_copy_node) + self.output_node = self.output_view_copy_node + cur_node = list(cur_node.users)[0] if cur_node.target == exir_ops.edge.aten.relu.default: self.relu_node = cur_node if len(cur_node.users) == 1: From 66f8920a004057e96f3d50217f75ee0a1c74dd3a Mon Sep 17 00:00:00 2001 From: ssjia Date: Mon, 9 Mar 2026 16:45:52 -0700 Subject: [PATCH 3/5] [ET-VK][qdq] Support high-dimensional tensors in quantize/dequantize per tensor Pull Request resolved: https://github.com/pytorch/executorch/pull/18015 The Q8ta quantize/dequantize ops were limited to 4D tensors because the GLSL block indexing infrastructure only handled 4 dimensions. This adds native 8D support to the block indexing helpers so that tensors with 5+ dimensions (common in models with batch dimensions) can be quantized/dequantized without falling back to CPU. The approach adds 8D versions of the block index functions (using TensorIndex instead of TensorIndex4D), 8D block load/store macros, and an 8D block decomposition function. Packed dim indices are guaranteed < 4, so the block dimensions always operate on data[0] while dims 4-7 are handled as implicit outer batch dimensions. The C++ dispatch functions are also updated to include extra batch dims in the thread count computation. This diff was authored with the assistance of an AI coding tool. ghstack-source-id: 349646650 @exported-using-ghexport Differential Revision: [D95807073](https://our.internmc.facebook.com/intern/diff/D95807073/) --- backends/vulkan/op_registry.py | 2 + .../graph/ops/glsl/block_indexing.glslh | 72 ++++++++++++++++++ .../graph/ops/glsl/block_int8x4_load.glslh | 51 +++++++++++++ .../graph/ops/glsl/block_int8x4_store.glslh | 50 +++++++++++++ .../runtime/graph/ops/glsl/block_load.glslh | 75 +++++++++++++++++++ .../runtime/graph/ops/glsl/block_store.glslh | 68 +++++++++++++++++ .../runtime/graph/ops/glsl/indexing.glslh | 72 ++++++++++++++++++ .../graph/ops/glsl/q8ta_dequantize.glsl | 22 +++--- .../runtime/graph/ops/glsl/q8ta_quantize.glsl | 24 +++--- .../vulkan/runtime/graph/ops/impl/Common.cpp | 10 ++- .../vulkan/test/custom_ops/test_q8ta_qdq.cpp | 21 +++++- 11 files changed, 443 insertions(+), 24 deletions(-) diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index bb7c0562bad..4bcdbadeea5 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -482,6 +482,7 @@ def register_quantize_per_tensor(): outputs_storage=[ utils.PACKED_INT8_BUFFER, ], + supports_highdim=True, ) @@ -499,6 +500,7 @@ def register_dequantize_per_tensor(): outputs_storage=[ utils.CHANNELS_PACKED_TEXTURE_OR_CONTIGUOUS_BUFFER, ], + supports_highdim=True, ) diff --git a/backends/vulkan/runtime/graph/ops/glsl/block_indexing.glslh b/backends/vulkan/runtime/graph/ops/glsl/block_indexing.glslh index e7e64a601ef..d6693d7ff26 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/block_indexing.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/block_indexing.glslh @@ -109,6 +109,78 @@ TensorIndex4D contiguous_block_idx_to_tensor4d_idx_with_block_config( return tidx; } +/* + * 8D version: decomposes a contiguous block index into a TensorIndex with up to + * 8 dimensions. The block config encodes dim_order for the first 4 dims. Dims + * 4-7 are implicit outer (non-block) batch dimensions in ascending order. + */ +TensorIndex contiguous_block_idx_to_tensor_idx_with_block_config( + const BufferMetadata meta, + const uint block_idx, + const int block_config) { + TensorIndex tidx; + initialize(tidx); + + const int inner_dim = get_block_inner_dim(block_config); + const int outer_dim = get_block_outer_dim(block_config); + const int nonblock_dim_1 = get_block_dim_order_2(block_config); + const int nonblock_dim_2 = get_block_dim_order_3(block_config); + const uint inner_bs = uint(get_block_inner_dim_block_size(block_config)); + const uint outer_bs = uint(get_block_outer_dim_block_size(block_config)); + + // Compute block strides for each dim, inner→outer ordering: + // inner_dim, outer_dim, nonblock1, nonblock2, 4, 5, 6, 7 + uint stride = 1; + + // Inner block dim + stride = div_up(size_at(meta, inner_dim), inner_bs); + + // Outer block dim + uint stride_outer = stride; + stride *= div_up(size_at(meta, outer_dim), outer_bs); + + // First non-block dim + uint stride_nb1 = stride; + stride *= size_at(meta, nonblock_dim_1); + + // Second non-block dim + uint stride_nb2 = stride; + stride *= size_at(meta, nonblock_dim_2); + + // Extra batch dims (4-7): always non-block, ascending order + uint batch_strides[4]; + [[unroll]] for (int i = 0; i < 4; ++i) { + batch_strides[i] = stride; + stride *= size_at(meta, 4 + i); + } + + // Decompose from outermost to innermost + uint remaining = block_idx; + + // Dims 7, 6, 5, 4 (outermost batch dims first) + [[unroll]] for (int i = 3; i >= 0; --i) { + tidx.data[1][i] = remaining / batch_strides[i]; + remaining %= batch_strides[i]; + } + + // Second non-block dim + tidx.data[0][nonblock_dim_2] = remaining / stride_nb2; + remaining %= stride_nb2; + + // First non-block dim + tidx.data[0][nonblock_dim_1] = remaining / stride_nb1; + remaining %= stride_nb1; + + // Outer block dim (multiply by block size to get element index) + tidx.data[0][outer_dim] = mul_4(remaining / stride_outer); + remaining %= stride_outer; + + // Inner block dim (multiply by block size to get element index) + tidx.data[0][inner_dim] = mul_4(remaining); + + return tidx; +} + // // TextureMetadata variants of block indexing // diff --git a/backends/vulkan/runtime/graph/ops/glsl/block_int8x4_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/block_int8x4_load.glslh index 6ea636a0a17..d889276082a 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/block_int8x4_load.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/block_int8x4_load.glslh @@ -71,4 +71,55 @@ return block; \ } +// +// 8D version (TensorIndex) +// + +#define define_load_int8x4_buffer_8d_fns(buffer_name) \ + \ + ivec4 load_int8x4_block_from_##buffer_name( \ + const BufferMetadata meta, \ + const TensorIndex tidx_base, \ + const int hashed_layout, \ + const int block_outer_dim) { \ + const int outer_packed_dim = get_outer_packed_dim(hashed_layout); \ + const int outer_block_size = \ + get_outer_packed_dim_block_size(hashed_layout); \ + \ + /* Compute base packed index using 8D block-based indexing */ \ + const uint block_idx = \ + tensor_idx_to_block_idx(meta, tidx_base, hashed_layout); \ + const uint texels_per_block = div_4(get_block_numel(hashed_layout)); \ + uint buf_idx = block_idx * texels_per_block; \ + \ + /* Fast path: contiguous texels when iterating along outer_packed_dim */ \ + if (outer_block_size == 4) { \ + if (block_outer_dim == outer_packed_dim) { \ + return ivec4( \ + buffer_name[buf_idx], \ + buffer_name[buf_idx + 1], \ + buffer_name[buf_idx + 2], \ + buffer_name[buf_idx + 3]); \ + } \ + else { \ + buf_idx += mod_4(int(tidx_base.data[0][outer_packed_dim])); \ + } \ + } \ + \ + /* General path: use stride for non-contiguous access */ \ + const uint outer_stride = \ + stride_at(meta, block_outer_dim) * texels_per_block; \ + const uint outer_size = size_at(meta, block_outer_dim); \ + const int base_outer_idx = int(tidx_base.data[0][block_outer_dim]); \ + \ + ivec4 block = ivec4(0); \ + [[unroll]] for (int block_y = 0; block_y < 4; ++block_y) { \ + if (base_outer_idx + block_y < int(outer_size)) { \ + block[block_y] = buffer_name[buf_idx]; \ + } \ + buf_idx += outer_stride; \ + } \ + return block; \ + } + #endif // BLOCK_INT8X4_LOAD_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/block_int8x4_store.glslh b/backends/vulkan/runtime/graph/ops/glsl/block_int8x4_store.glslh index 2a0e037c291..425df205317 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/block_int8x4_store.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/block_int8x4_store.glslh @@ -71,4 +71,54 @@ } \ } +// +// 8D version (TensorIndex) +// + +#define define_store_int8x4_buffer_8d_fns(buffer_name) \ + \ + void store_int8x4_block_to_##buffer_name( \ + const BufferMetadata meta, \ + const TensorIndex tidx_base, \ + const int hashed_layout, \ + const int block_outer_dim, \ + const ivec4 block) { \ + const int outer_packed_dim = get_outer_packed_dim(hashed_layout); \ + const int outer_block_size = \ + get_outer_packed_dim_block_size(hashed_layout); \ + \ + /* Compute base packed index using 8D block-based indexing */ \ + const uint block_idx = \ + tensor_idx_to_block_idx(meta, tidx_base, hashed_layout); \ + const uint texels_per_block = div_4(get_block_numel(hashed_layout)); \ + uint buf_idx = block_idx * texels_per_block; \ + \ + /* Fast path: contiguous texels when iterating along outer_packed_dim */ \ + if (outer_block_size == 4) { \ + if (block_outer_dim == outer_packed_dim) { \ + buffer_name[buf_idx] = block[0]; \ + buffer_name[buf_idx + 1] = block[1]; \ + buffer_name[buf_idx + 2] = block[2]; \ + buffer_name[buf_idx + 3] = block[3]; \ + return; \ + } \ + else { \ + buf_idx += mod_4(int(tidx_base.data[0][outer_packed_dim])); \ + } \ + } \ + \ + /* General path: use stride for non-contiguous access */ \ + const uint outer_stride = \ + stride_at(meta, block_outer_dim) * texels_per_block; \ + const uint outer_size = size_at(meta, block_outer_dim); \ + const int base_outer_idx = int(tidx_base.data[0][block_outer_dim]); \ + \ + [[unroll]] for (int block_y = 0; block_y < 4; ++block_y) { \ + if (base_outer_idx + block_y < int(outer_size)) { \ + buffer_name[buf_idx] = block[block_y]; \ + } \ + buf_idx += outer_stride; \ + } \ + } + #endif // BLOCK_INT8X4_STORE_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/block_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/block_load.glslh index d72a176aa0e..3719cbd3460 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/block_load.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/block_load.glslh @@ -102,4 +102,79 @@ return block; \ } +// +// 8D buffer load functions (TensorIndex version) +// + +#define define_load_buffer_8d_fns(buffer_name) \ + \ + mat4 load_fp_block_from_##buffer_name( \ + const BufferMetadata meta, \ + const TensorIndex tidx_base, \ + const int hashed_layout, \ + const int block_outer_dim) { \ + const int block_inner_dim = get_packed_dim(hashed_layout); \ + \ + /* Compute base buffer index using 8D strides */ \ + const uint base_idx = \ + tensor_idx_to_buf_idx(meta, tidx_base, hashed_layout); \ + const uint outer_stride = stride_at(meta, block_outer_dim); \ + /* Inner stride is 1 since packed_dim == block_inner_dim */ \ + \ + /* Pre-compute bounds (block dims always < 4, so use data[0]) */ \ + const uint outer_size = size_at(meta, block_outer_dim); \ + const uint inner_size = size_at(meta, block_inner_dim); \ + const int base_outer_idx = int(tidx_base.data[0][block_outer_dim]); \ + const int base_inner_idx = int(tidx_base.data[0][block_inner_dim]); \ + \ + mat4 block; \ + [[unroll]] for (int block_y = 0; block_y < 4; ++block_y) { \ + if (base_outer_idx + block_y < int(outer_size)) { \ + const uint row_idx = base_idx + block_y * outer_stride; \ + [[unroll]] for (int block_x = 0; block_x < 4; ++block_x) { \ + if (base_inner_idx + block_x < int(inner_size)) { \ + block[block_y][block_x] = float(buffer_name[row_idx + block_x]); \ + } else { \ + block[block_y][block_x] = 0.0; \ + } \ + } \ + } else { \ + block[block_y] = vec4(0.0); \ + } \ + } \ + return block; \ + } + +// +// 8D texture load functions (TensorIndex version) +// Converts to TensorIndex4D internally for texture operations. +// + +#define define_load_texture_8d_fns(texture_name) \ + \ + mat4 load_fp_block_from_##texture_name( \ + const TextureMetadata meta, \ + const TensorIndex tidx_base, \ + const int hashed_layout, \ + const int block_outer_dim) { \ + /* Convert to 4D for texture operations (textures are always <= 4D) */ \ + TensorIndex4D tidx4d; \ + tidx4d.data = ivec4(tidx_base.data[0]); \ + ivec3 tex_pos = tensor4d_idx_to_texel_pos_simple(meta, tidx4d); \ + const int tex_outer_dim = mod_4(block_outer_dim); \ + const int outer_size = meta.sizes[block_outer_dim]; \ + const int base_outer_idx = tidx4d.data[block_outer_dim]; \ + \ + mat4 block; \ + [[unroll]] for (int block_y = 0; block_y < 4; ++block_y) { \ + if (base_outer_idx + block_y < outer_size) { \ + block[block_y] = vec4(texelFetch(texture_name, tex_pos, 0)); \ + } else { \ + block[block_y] = vec4(0.0); \ + } \ + tex_pos[tex_outer_dim]++; \ + } \ + return block; \ + } + #endif // BLOCK_LOAD_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/block_store.glslh b/backends/vulkan/runtime/graph/ops/glsl/block_store.glslh index 66e9ab9fa2b..aa302290229 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/block_store.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/block_store.glslh @@ -94,4 +94,72 @@ } \ } +// +// 8D buffer store functions (TensorIndex version) +// + +#define define_store_buffer_8d_fns(buffer_name, scalar_type) \ + \ + void store_fp_block_to_##buffer_name( \ + const BufferMetadata meta, \ + const TensorIndex tidx_base, \ + const int hashed_layout, \ + const int block_outer_dim, \ + const mat4 block) { \ + const int block_inner_dim = get_packed_dim(hashed_layout); \ + \ + /* Compute base buffer index using 8D strides */ \ + const uint base_idx = \ + tensor_idx_to_buf_idx(meta, tidx_base, hashed_layout); \ + const uint outer_stride = stride_at(meta, block_outer_dim); \ + /* Inner stride is 1 since packed_dim == block_inner_dim */ \ + \ + /* Pre-compute bounds (block dims always < 4, so use data[0]) */ \ + const uint outer_size = size_at(meta, block_outer_dim); \ + const uint inner_size = size_at(meta, block_inner_dim); \ + const int base_outer_idx = int(tidx_base.data[0][block_outer_dim]);\ + const int base_inner_idx = int(tidx_base.data[0][block_inner_dim]);\ + \ + [[unroll]] for (int block_y = 0; block_y < 4; ++block_y) { \ + if (base_outer_idx + block_y < int(outer_size)) { \ + const uint row_idx = base_idx + block_y * outer_stride; \ + [[unroll]] for (int block_x = 0; block_x < 4; ++block_x) { \ + if (base_inner_idx + block_x < int(inner_size)) { \ + buffer_name[row_idx + block_x] = \ + scalar_type(block[block_y][block_x]); \ + } \ + } \ + } \ + } \ + } + +// +// 8D texture store functions (TensorIndex version) +// Converts to TensorIndex4D internally for texture operations. +// + +#define define_store_texture_8d_fns(texture_name, vec4_type) \ + \ + void store_fp_block_to_##texture_name( \ + const TextureMetadata meta, \ + const TensorIndex tidx_base, \ + const int hashed_layout, \ + const int block_outer_dim, \ + const mat4 block) { \ + /* Convert to 4D for texture operations (textures are always <= 4D) */ \ + TensorIndex4D tidx4d; \ + tidx4d.data = ivec4(tidx_base.data[0]); \ + ivec3 tex_pos = tensor4d_idx_to_texel_pos_simple(meta, tidx4d); \ + const int tex_outer_dim = mod_4(block_outer_dim); \ + const int outer_size = meta.sizes[block_outer_dim]; \ + const int base_outer_idx = tidx4d.data[block_outer_dim]; \ + \ + [[unroll]] for (int block_y = 0; block_y < 4; ++block_y) { \ + if (base_outer_idx + block_y < outer_size) { \ + imageStore(texture_name, tex_pos, vec4_type(block[block_y])); \ + } \ + tex_pos[tex_outer_dim]++; \ + } \ + } + #endif // BLOCK_STORE_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/indexing.glslh b/backends/vulkan/runtime/graph/ops/glsl/indexing.glslh index 51cda9a3d1d..c576128700e 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/indexing.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/indexing.glslh @@ -200,6 +200,10 @@ bool out_of_bounds(const TensorIndex tidx, const BufferMetadata meta) { any(greaterThanEqual(tidx.data[1], meta.sizes[1])); } +bool out_of_bounds(const TensorIndex tidx, const TextureMetadata meta) { + return any(greaterThanEqual(ivec4(tidx.data[0]), meta.sizes)); +} + // // TensorIndex4D (useful for texture backed tensors) // @@ -624,6 +628,74 @@ int tensor4d_idx_to_texel_idx( return block_idx; } +// +// 8D block-packed tensor indexing (TensorIndex versions) +// +// These functions extend the 4D versions above to handle tensors with up to 8 +// dimensions. Packed dims are guaranteed < 4, so block operations only affect +// data[0]. Dims 4-7 are non-block batch dimensions. +// + +/* + * 8D version of tensor4d_idx_to_block_idx. + * Packed dims (< 4) are divided by their block sizes, then all 8 dims + * contribute to the linear block index via their strides. + */ +int tensor_idx_to_block_idx( + const BufferMetadata meta, + TensorIndex tidx, + const int hashed_layout) { + const int inner_dim = get_packed_dim(hashed_layout); + const int outer_dim = get_outer_packed_dim(hashed_layout); + const int inner_block_size = get_packed_dim_block_size(hashed_layout); + const int outer_block_size = get_outer_packed_dim_block_size(hashed_layout); + + // Convert packed dims to block-space coordinates (packed dims always < 4) + tidx.data[0][inner_dim] = tidx.data[0][inner_dim] / uint(inner_block_size); + tidx.data[0][outer_dim] = tidx.data[0][outer_dim] / uint(outer_block_size); + + // Compute block-space linear index over all 8 dims + int block_idx = 0; + [[unroll]] for (int d = 0; d < 4; ++d) { + block_idx += int(meta.strides[0][d]) * int(tidx.data[0][d]); + } + [[unroll]] for (int d = 0; d < 4; ++d) { + block_idx += int(meta.strides[1][d]) * int(tidx.data[1][d]); + } + return block_idx; +} + +/* + * 8D version of tensor4d_idx_to_intra_block_idx. + * Packed dims are always < 4, so only data[0] is accessed. + */ +int tensor_idx_to_intra_block_idx( + const TensorIndex tidx, + const int hashed_layout) { + const int inner_dim = get_packed_dim(hashed_layout); + const int outer_dim = get_outer_packed_dim(hashed_layout); + const int inner_block_size = get_packed_dim_block_size(hashed_layout); + const int outer_block_size = get_outer_packed_dim_block_size(hashed_layout); + + const int inner_offset = int(tidx.data[0][inner_dim]) % inner_block_size; + const int outer_offset = int(tidx.data[0][outer_dim]) % outer_block_size; + return outer_offset * inner_block_size + inner_offset; +} + +/* + * 8D version of tensor4d_idx_to_buf_idx. + */ +int tensor_idx_to_buf_idx( + const BufferMetadata meta, + const TensorIndex tidx, + const int hashed_layout) { + const int block_idx = tensor_idx_to_block_idx(meta, tidx, hashed_layout); + const int intra_block_idx = + tensor_idx_to_intra_block_idx(tidx, hashed_layout); + const int block_numel = get_block_numel(hashed_layout); + return block_idx * block_numel + intra_block_idx; +} + // // Debug utilities // diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_dequantize.glsl b/backends/vulkan/runtime/graph/ops/glsl/q8ta_dequantize.glsl index 6989dc2d87d..88089627911 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/q8ta_dequantize.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/q8ta_dequantize.glsl @@ -53,13 +53,13 @@ ${layout_declare_spec_const(C, "int", "inp_block_config", "0")} #include "block_int8x4_load.glslh" #include "block_store.glslh" -// Generate loading functions for t_inp buffer -define_load_int8x4_buffer_fns(t_inp) -// Generate storing functions for t_outp +// Generate 8D loading functions for t_inp buffer +define_load_int8x4_buffer_8d_fns(t_inp) +// Generate 8D storing functions for t_outp $if OUTPUT_STORAGE == "buffer": - define_store_buffer_fns(t_outp, T) + define_store_buffer_8d_fns(t_outp, T) $else: - define_store_texture_fns(t_outp, VEC4_T) + define_store_texture_8d_fns(t_outp, VEC4_T) mat4 dequantize_int8x4_block( const ivec4 block, const float scale, const int zp) { @@ -81,18 +81,20 @@ mat4 dequantize_int8x4_block( } void main() { - TensorIndex4D tidx; + TensorIndex tidx; #ifdef USING_BUFFER - // Buffer storage: use linear dispatch + // Buffer storage: use linear dispatch (supports up to 8D) const uint contig_block_idx = gl_GlobalInvocationID.x; - tidx = contiguous_block_idx_to_tensor4d_idx_with_block_config( + tidx = contiguous_block_idx_to_tensor_idx_with_block_config( inp, contig_block_idx, inp_block_config); #else - // Texture storage: use 3D extents dispatch + // Texture storage: use 3D extents dispatch (limited to 4D) const uvec3 thread_idx = gl_GlobalInvocationID; - tidx = block_idx_3d_to_tensor4d_idx_with_block_config( + TensorIndex4D tidx4d = block_idx_3d_to_tensor4d_idx_with_block_config( inp, thread_idx, inp_block_config); + initialize(tidx); + tidx.data[0] = uvec4(tidx4d.data); #endif if (out_of_bounds(tidx, inp)) { diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_quantize.glsl b/backends/vulkan/runtime/graph/ops/glsl/q8ta_quantize.glsl index c0c2a3a914a..2f458054a32 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/q8ta_quantize.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/q8ta_quantize.glsl @@ -52,15 +52,15 @@ ${layout_declare_spec_const(C, "int", "outp_block_config", "0")} #include "block_indexing.glslh" #include "block_load.glslh" $if INPUT_STORAGE == "buffer": - // Generate loading functions for t_inp buffer - define_load_buffer_fns(t_inp) + // Generate 8D loading functions for t_inp buffer + define_load_buffer_8d_fns(t_inp) $else: - // Generate loading functions for t_inp texture - define_load_texture_fns(t_inp) + // Generate 8D loading functions for t_inp texture + define_load_texture_8d_fns(t_inp) #include "block_int8x4_store.glslh" -// Generate storing functions for t_outp buffer -define_store_int8x4_buffer_fns(t_outp) +// Generate 8D storing functions for t_outp buffer +define_store_int8x4_buffer_8d_fns(t_outp) ivec4 quantize_fp_block( const mat4 block, const float inv_scale, const int zp) { @@ -79,18 +79,20 @@ ivec4 quantize_fp_block( } void main() { - TensorIndex4D tidx; + TensorIndex tidx; #ifdef USING_BUFFER - // Buffer storage: use linear dispatch + // Buffer storage: use linear dispatch (supports up to 8D) const uint contig_block_idx = gl_GlobalInvocationID.x; - tidx = contiguous_block_idx_to_tensor4d_idx_with_block_config( + tidx = contiguous_block_idx_to_tensor_idx_with_block_config( inp, contig_block_idx, inp_block_config); #else - // Texture storage: use 3D extents dispatch + // Texture storage: use 3D extents dispatch (limited to 4D) const uvec3 thread_idx = gl_GlobalInvocationID; - tidx = block_idx_3d_to_tensor4d_idx_with_block_config( + TensorIndex4D tidx4d = block_idx_3d_to_tensor4d_idx_with_block_config( inp, thread_idx, inp_block_config); + initialize(tidx); + tidx.data[0] = uvec4(tidx4d.data); #endif if (out_of_bounds(tidx, inp)) { diff --git a/backends/vulkan/runtime/graph/ops/impl/Common.cpp b/backends/vulkan/runtime/graph/ops/impl/Common.cpp index d052882afde..40ca0510383 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Common.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Common.cpp @@ -252,6 +252,10 @@ utils::uvec3 pick_linear_global_wg_with_block_config( utils::safe_downcast(utils::val_at(-1 - d, sizes)); } } + // Include extra batch dimensions beyond 4D WHCN space + for (int32_t d = 4; d < static_cast(sizes.size()); ++d) { + num_planes *= utils::safe_downcast(utils::val_at(-1 - d, sizes)); + } // Return linear workgroup size: {total_blocks, 1u, 1u} const uint32_t total_blocks = @@ -285,7 +289,11 @@ utils::uvec3 pick_extents_global_wg_with_block_config( const int64_t W = utils::val_at(-1, sizes); const int64_t H = utils::val_at(-2, sizes); const int64_t C = utils::val_at(-3, sizes); - const int64_t N = utils::val_at(-4, sizes); + // N is the product of all batch dimensions (WHCN dims 3+) + int64_t N = 1; + for (int64_t i = 4; i <= static_cast(sizes.size()); ++i) { + N *= utils::val_at(-i, sizes); + } // Dispatch structure: {x_threads, y_threads, z_threads} // - x corresponds to W dimension diff --git a/backends/vulkan/test/custom_ops/test_q8ta_qdq.cpp b/backends/vulkan/test/custom_ops/test_q8ta_qdq.cpp index a3ff8c42f86..9c8dbce6501 100644 --- a/backends/vulkan/test/custom_ops/test_q8ta_qdq.cpp +++ b/backends/vulkan/test/custom_ops/test_q8ta_qdq.cpp @@ -108,6 +108,8 @@ std::vector generate_q_dq_8bit_easy_cases() { {1, 16, 16, 16}, // 4D: [N, C, H, W] {1, 144}, // 2D: exercises block config with ndim < 4 {1, 90}, // 2D: matches skin_seg model's keypoint/bbox tensor sizes + {2, 1, 3, 16, 16}, // 5D: exercises high-dim batch support + {2, 1, 1, 4, 8, 16}, // 6D: exercises high-dim batch support }; // FP memory layouts to test @@ -199,6 +201,14 @@ std::vector generate_q_dq_8bit_test_cases() { {1, 64, 128, 128}, {1, 32, 64, 64}, {1, 128, 56, 56}, + + // 5D tensors (high-dim batch support) + {2, 1, 3, 16, 16}, + {1, 2, 8, 8, 8}, + + // 6D tensors (high-dim batch support) + {2, 1, 1, 4, 8, 16}, + {1, 1, 2, 3, 8, 8}, }; // FP memory layouts to test @@ -234,9 +244,15 @@ std::vector generate_q_dq_8bit_test_cases() { } } + // Skip texture3d for high-dim tensors (textures are limited to 4D) + const bool is_highdim = shape.size() > 4; + const auto& effective_storage_types = is_highdim + ? std::vector{utils::kBuffer} + : storage_types; + for (const auto& fp_layout : fp_layouts) { for (const auto& quant_layout : quant_layouts) { - for (const auto& storage_type : storage_types) { + for (const auto& storage_type : effective_storage_types) { QDQ8BitConfig config; config.shape = shape; config.test_case_name = prefix; @@ -244,7 +260,8 @@ std::vector generate_q_dq_8bit_test_cases() { test_cases.push_back(create_test_case_from_config( config, storage_type, vkapi::kFloat, fp_layout, quant_layout)); // For 4W4C layout, also test with legacy implementation - if (fp_layout == utils::kChannelsPacked && + // (legacy path doesn't support high-dim tensors) + if (!is_highdim && fp_layout == utils::kChannelsPacked && quant_layout == utils::kPackedInt8_4W4C) { test_cases.push_back(create_test_case_from_config( config, From 0b178db9999daae983efc3ef03aba468af0f0428 Mon Sep 17 00:00:00 2001 From: ssjia Date: Mon, 9 Mar 2026 16:45:55 -0700 Subject: [PATCH 4/5] [ET-VK][qconv] Add q8ta_conv2d_transposed operator Pull Request resolved: https://github.com/pytorch/executorch/pull/18016 Implement quantized transposed 2D convolution for the Vulkan backend, enabling int8 transposed convolutions used in decoder/upsampling networks. The GLSL shader iterates over all kernel positions and derives valid input positions via (output + padding - kernel) / stride. Invalid positions use input_zp_packed so the precomputed weight_sums zero-point correction remains consistent. Reuses the existing q8ta_conv2d weight packing and workgroup size selection since, after the pattern matcher reshapes transposed weights from (IC, OC, KH, KW) to (OC, KH*KW*IC), the layout is identical to regular conv2d. Supports hardware int8 dot product with software fallback, grouped convolutions, optional bias and ReLU activation. Only dilation=1 is supported (matching the ATen conv_transpose2d constraint). This diff was authored with Claude. ghstack-source-id: 349646651 @exported-using-ghexport Differential Revision: [D95807070](https://our.internmc.facebook.com/intern/diff/D95807070/) --- backends/vulkan/custom_ops_lib.py | 99 +++ backends/vulkan/op_registry.py | 33 + .../vulkan/patterns/quantized_convolution.py | 101 ++- .../ops/glsl/q8ta_conv2d_transposed.glsl | 254 ++++++++ .../ops/glsl/q8ta_conv2d_transposed.yaml | 17 + .../runtime/graph/ops/impl/Q8taConv2d.h | 7 + .../graph/ops/impl/Q8taConv2dTransposed.cpp | 267 ++++++++ .../impl/TestQ8taConv2dTransposed.cpp | 87 +++ backends/vulkan/test/custom_ops/targets.bzl | 1 + .../test_q8ta_conv2d_transposed.cpp | 601 ++++++++++++++++++ 10 files changed, 1444 insertions(+), 23 deletions(-) create mode 100644 backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_transposed.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_transposed.yaml create mode 100644 backends/vulkan/runtime/graph/ops/impl/Q8taConv2dTransposed.cpp create mode 100644 backends/vulkan/test/custom_ops/impl/TestQ8taConv2dTransposed.cpp create mode 100644 backends/vulkan/test/custom_ops/test_q8ta_conv2d_transposed.cpp diff --git a/backends/vulkan/custom_ops_lib.py b/backends/vulkan/custom_ops_lib.py index 87506f0b773..7f687bb10f4 100644 --- a/backends/vulkan/custom_ops_lib.py +++ b/backends/vulkan/custom_ops_lib.py @@ -685,6 +685,105 @@ def q8ta_conv2d_dw( lib.impl(name, q8ta_conv2d_dw, "CompositeExplicitAutograd") conv2d_q8ta_q8csw_dw_op = getattr(getattr(torch.ops, namespace), name) + +def q8ta_conv2d_transposed( + x: torch.Tensor, + input_scale: float, + input_zero_point: int, + weights: torch.Tensor, + weight_sums: torch.Tensor, + weight_scales: torch.Tensor, + output_scale: float, + output_zero_point: int, + bias: Optional[torch.Tensor], + kernel_size: list, + stride: list, + padding: list, + output_padding: list, + dilation: list, + groups: int, + activation: str, +): + x = torch.ops.quantized_decomposed.dequantize_per_tensor( + x, input_scale, input_zero_point, -128, 127, x.dtype + ) + + OC = weights.shape[0] + IC_per_group = int(x.shape[1] / groups) + K_h, K_w = kernel_size[0], kernel_size[1] + + orig_weight_K_dim = K_h * K_w * IC_per_group + if weights.shape[-1] > orig_weight_K_dim: + weights = weights[:, :orig_weight_K_dim] + + if weight_scales.shape[0] > OC: + weight_scales = weight_scales[:OC] + if bias is not None: + bias = bias[:OC] + + # Reshape to (OC, IC_per_group, K_h, K_w) then transpose to + # (IC_per_group * groups, OC_per_group, K_h, K_w) for conv_transpose2d + weights = weights.view(OC, IC_per_group, K_h, K_w) + OC_per_group = OC // groups + weights = ( + weights.view(groups, OC_per_group, IC_per_group, K_h, K_w) + .permute(0, 2, 1, 3, 4) + .contiguous() + .view(IC_per_group * groups, OC_per_group, K_h, K_w) + ) + + weight_zeros = torch.zeros_like(weight_scales, dtype=torch.int32) + # Dequantize per OC channel. For transposed weight (IC, OC_per_group, KH, KW), + # OC is at axis=1. + weights = torch.ops.quantized_decomposed.dequantize_per_channel( + weights, + weight_scales[:OC_per_group].repeat(groups) if groups > 1 else weight_scales, + weight_zeros[:OC_per_group].repeat(groups) if groups > 1 else weight_zeros, + 1, + -127, + 127, + torch.int8, + ) + + out = torch.nn.functional.conv_transpose2d( + x, weights, bias, stride, padding, output_padding, groups, dilation + ) + + if activation == "relu": + out = torch.nn.functional.relu(out) + + out = torch.ops.quantized_decomposed.quantize_per_tensor( + out, output_scale, output_zero_point, -128, 127, torch.int8 + ) + + return out + + +name = "q8ta_conv2d_transposed" +lib.define( + f""" + {name}( + Tensor x, + float input_scale, + int input_zero_point, + Tensor weights, + Tensor weight_sums, + Tensor weight_scales, + float output_scale, + int output_zero_point, + Tensor? bias, + SymInt[] kernel_size, + SymInt[] stride, + SymInt[] padding, + SymInt[] output_padding, + SymInt[] dilation, + SymInt groups, + str activation) -> Tensor + """ +) +lib.impl(name, q8ta_conv2d_transposed, "CompositeExplicitAutograd") +q8ta_conv2d_transposed_op = getattr(getattr(torch.ops, namespace), name) + ###################### ## apply_rotary_emb ## ###################### diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index 4bcdbadeea5..af2389d72f9 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -865,6 +865,39 @@ def register_q8ta_conv2d_ops(): ) +@update_features( + [ + exir_ops.edge.et_vk.q8ta_conv2d_transposed.default, + ] +) +def register_q8ta_conv2d_transposed_op(): + return OpFeatures( + inputs_storage=[ + utils.PACKED_INT8_CONV2D_BUFFER, # input + utils.NO_STORAGE, # input_scale (non tensor) + utils.NO_STORAGE, # input_zero_point (non tensor) + utils.NO_STORAGE, # weight (prepacked) + utils.NO_STORAGE, # weight_sums (prepacked) + utils.NO_STORAGE, # weight_scales (prepacked) + utils.NO_STORAGE, # output_scale (non tensor) + utils.NO_STORAGE, # output_zero_point (non tensor) + utils.NO_STORAGE, # bias (prepacked) + utils.NO_STORAGE, # kernel_size (non tensor) + utils.NO_STORAGE, # stride (non tensor) + utils.NO_STORAGE, # padding (non tensor) + utils.NO_STORAGE, # output_padding (non tensor) + utils.NO_STORAGE, # dilation (non tensor) + utils.NO_STORAGE, # groups (non tensor) + utils.NO_STORAGE, # activation (non tensor) + ], + outputs_storage=[ + utils.PACKED_INT8_CHANNELS_PACKED_BUFFER, + ], + supports_resize=False, + supports_prepacking=True, + ) + + # ============================================================================= # Q8taLinear.cpp # ============================================================================= diff --git a/backends/vulkan/patterns/quantized_convolution.py b/backends/vulkan/patterns/quantized_convolution.py index 12ebbd1a382..d291d4009b7 100644 --- a/backends/vulkan/patterns/quantized_convolution.py +++ b/backends/vulkan/patterns/quantized_convolution.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Optional +from typing import cast, List, Optional import executorch.backends.vulkan.utils as utils @@ -33,12 +33,27 @@ def __init__(self, conv_node: torch.fx.Node) -> None: self.match_found = False self.all_nodes = [self.anchor_node] + # Determine if this is a transposed convolution + self.transposed = False + self.output_padding = [0, 0] + if conv_node.target == exir_ops.edge.aten.convolution.default: + transposed_flag = conv_node.args[6] if len(conv_node.args) > 6 else False + if transposed_flag: + self.transposed = True + self.output_padding = ( + cast(List[int], conv_node.args[7]) if len(conv_node.args) > 7 else [0, 0] + ) + # Extract convolution parameters self.stride = conv_node.args[3] if len(conv_node.args) > 3 else [1, 1] self.padding = conv_node.args[4] if len(conv_node.args) > 4 else [0, 0] self.dilation = conv_node.args[5] if len(conv_node.args) > 5 else [1, 1] self.groups = conv_node.args[8] if len(conv_node.args) > 8 else 1 + # Transposed conv only supported with dilation=[1,1] + if self.transposed and cast(List[int], self.dilation) != [1, 1]: + return + const_node, arg_chain = utils.trace_args_until_placeholder( self.anchor_node.args[1] ) @@ -60,6 +75,16 @@ def __init__(self, conv_node: torch.fx.Node) -> None: self.dequantize_weight_node = dequantize_weight_node self.all_nodes.extend(arg_chain) + # For transposed conv, verify per-channel quantization is on the OC dimension. + # Transposed weight shape is (IC, OC_per_group, KH, KW), so per-OC quantization + # should be on axis=1. If axis=0, that's per-IC which is not supported. + if self.transposed and utils.is_dequant_per_channel_node( + self.dequantize_weight_node + ): + quant_axis = self.dequantize_weight_node.args[3] + if quant_axis != 1: + return + # Identify weight quantization parameter nodes self.weight_scales_node, arg_chain = utils.trace_args_until_placeholder( self.dequantize_weight_node.args[1] @@ -177,9 +202,30 @@ def make_q8ta_conv2d_custom_op( bias_tensor = get_param_tensor(ep, match.bias_node) assert bias_tensor is not None - OC, IC_per_group, H, W = weight_tensor.shape + if match.transposed: + # Transposed conv weight shape: (IC, OC_per_group, H, W) + IC, OC_per_group, H, W = weight_tensor.shape + OC = OC_per_group * match.groups + IC_per_group = IC // match.groups + # Reshape to (OC, H*W*IC_per_group) matrix format for Im2Col-based + # transposed convolution. + # (IC, OC_per_group, H, W) -> + # (groups, IC_per_group, OC_per_group, H, W) -> + # (groups, OC_per_group, H, W, IC_per_group) -> + # (OC, H*W*IC_per_group) + weight_tensor = ( + weight_tensor.reshape(match.groups, IC_per_group, OC_per_group, H, W) + .permute(0, 2, 3, 4, 1) + .contiguous() + .reshape(OC, H * W * IC_per_group) + .contiguous() + ) + else: + OC, IC_per_group, H, W = weight_tensor.shape - is_depthwise_conv = IC_per_group == 1 and match.groups == OC + is_depthwise_conv = ( + not match.transposed and IC_per_group == 1 and match.groups == OC + ) if is_depthwise_conv: assert OC % 4 == 0, "depthwise conv requires that OC is divisible by 4" @@ -188,7 +234,7 @@ def make_q8ta_conv2d_custom_op( weight_tensor = ( weight_tensor.permute(2, 3, 1, 0).contiguous().view(H, W, OC).contiguous() ) - else: + elif not match.transposed: # Reshape weight tensor from (OC, IC_per_group, H, W) to (OC, H * W * IC_per_group) # (i.e. matrix format). This prepares the weights for Im2Col-based convolution. weight_tensor = ( @@ -257,32 +303,41 @@ def make_q8ta_conv2d_custom_op( ) with graph_module.graph.inserting_before(match.output_node): - op_target = exir_ops.edge.et_vk.q8ta_conv2d.default - if is_depthwise_conv: + if match.transposed: + op_target = exir_ops.edge.et_vk.q8ta_conv2d_transposed.default + elif is_depthwise_conv: op_target = exir_ops.edge.et_vk.q8ta_conv2d_dw.default elif is_pointwise_conv: op_target = exir_ops.edge.et_vk.q8ta_conv2d_pw.default + else: + op_target = exir_ops.edge.et_vk.q8ta_conv2d.default + + op_args = ( + match.quantize_input_node, + match.input_scales_node, + match.input_zeros_node, + match.weight_node, + weight_sums_node, + match.weight_scales_node, + match.output_scales_node, + match.output_zeros_node, + match.bias_node, + [H, W], + match.stride, + match.padding, + ) + if match.transposed: + op_args = op_args + (match.output_padding,) + op_args = op_args + ( + match.dilation, + match.groups, + "relu" if match.relu_node is not None else "none", + ) qconv_node = graph_module.graph.create_node( "call_function", op_target, - args=( - match.quantize_input_node, - match.input_scales_node, - match.input_zeros_node, - match.weight_node, - weight_sums_node, - match.weight_scales_node, - match.output_scales_node, - match.output_zeros_node, - match.bias_node, # Add bias after weight_scales - [H, W], # Pass kernel size information before stride - match.stride, - match.padding, - match.dilation, - match.groups, - "relu" if match.relu_node is not None else "none", - ), + args=op_args, ) qconv_node.meta["val"] = match.output_node.meta["val"] diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_transposed.glsl b/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_transposed.glsl new file mode 100644 index 00000000000..efed2e3a95b --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_transposed.glsl @@ -0,0 +1,254 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +${define_required_extensions("buffer", DTYPE)} + +#define USE_INT8_DOT_PRODUCT_EXT ${USE_INT8_DOT_PRODUCT_EXT} + +#extension GL_EXT_control_flow_attributes : require +$if USE_INT8_DOT_PRODUCT_EXT == 1: + #extension GL_EXT_integer_dot_product : require + +#define PRECISION ${PRECISION} +#define VEC4_T ${texel_load_type(DTYPE, "buffer")} +#define T ${texel_load_component_type(DTYPE, "buffer")} + +${define_active_storage_type("buffer")} + +layout(std430) buffer; + +#include "indexing.glslh" +#include "common.glslh" +#include "conv2d_common.glslh" + +${layout_declare_tensor(B, "w", "t_packed_int8_output", "int", "buffer", is_scalar_array=True)} +${layout_declare_tensor(B, "r", "t_packed_int8_input", "int", "buffer", is_scalar_array=True)} +${layout_declare_tensor(B, "r", "t_packed_int8_weight", "int", "texture2d", is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_weight_sums", "int", "buffer", is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_weight_scales", DTYPE, "buffer", is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_bias", DTYPE, "buffer", is_scalar_array=False)} + +// Metadata for input/output tensors (memory layout agnostic) +${layout_declare_ubo(B, "BufferMetadata", "outp")} +${layout_declare_ubo(B, "BufferMetadata", "inp")} +${layout_declare_ubo(B, "Conv2DParams", "conv2d_params")} + +layout(push_constant) uniform restrict Block { + float input_scale; + int input_zp; + float output_inv_scale; + int output_zp; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +${layout_declare_spec_const(C, "int", "apply_bias", "1")} +${layout_declare_spec_const(C, "int", "activation_type", "0")} + +// Layout specialization constants +${layout_declare_spec_const(C, "int", "inp_layout", "CONTIG_LAYOUT_INT")} +${layout_declare_spec_const(C, "int", "outp_layout", "CONTIG_LAYOUT_INT")} + +// Load weight block for a given (ic4, kx, ky, oc4) position. +// Weight texture layout (from pack_q8_conv2d_weights.glsl): +// block_x = oc4 * K_w + kx +// block_y = ky * IC4 + ic4 +// Each texel ivec4 has 4 components (4 output channels), each component is +// a packed int32 containing 4 int8 values for 4 consecutive input channels. +ivec4 load_weight_block(int ic4, int kx, int ky, int oc4, int IC4, int KW) { + const int block_x = oc4 * KW + kx; + const int block_y = ky * IC4 + ic4; + return texelFetch(t_packed_int8_weight, ivec2(block_x, block_y), 0); +} + +ivec4 quantize(const vec4 texel, const float inv_scale, const int zp) { + vec4 quantized = round(texel * inv_scale) + zp; + return clamp(ivec4(quantized), -128, 127); +} + +void main() { + // Thread mapping: same as q8ta_conv2d + // Each thread handles a 4W x 4C tile of output + int oc4 = int(gl_GlobalInvocationID.z); + int w4 = int(gl_GlobalInvocationID.x); + + // Initialize output tensor index (WHCN order) + TensorIndex4D outp_tidx; + outp_tidx.data[0] = w4 * 4; + outp_tidx.data[1] = int(gl_GlobalInvocationID.y); + outp_tidx.data[2] = oc4 * 4; + outp_tidx.data[3] = 0; + + const int W = int(outp.sizes[0][0]); + const int OC = int(outp.sizes[0][2]); + + // Bounds check + if (any(greaterThanEqual(outp_tidx.data, ivec4(outp.sizes[0])))) { + return; + } + + // Input dimensions + const int inp_W = int(inp.sizes[0][0]); + const int inp_H = int(inp.sizes[0][1]); + const int IC = int(inp.sizes[0][2]); + + // Compute channels per group + const int OC_per_group = OC / conv2d_params.groups; + const int IC_per_group = IC / conv2d_params.groups; + const int IC4_per_group = div_up_4(IC_per_group); + + // Determine which group this output channel block belongs to + const int group_idx = outp_tidx.data[2] / OC_per_group; + const int ic_group_start = group_idx * IC_per_group; + + // Get strides for efficient indexing + const int inp_w_stride = int(inp.strides[0][0]); + const int inp_h_stride = int(inp.strides[0][1]); + + // Create packed input zero point (4 copies of input_zp packed into int32) + const int input_zp_packed = pack_into_int32(ivec4(input_zp)); + + // Initialize accumulators for 4 width positions x 4 output channels each + ivec4 acc[4]; + [[unroll]] for (int i = 0; i < 4; ++i) { + acc[i] = ivec4(0); + } + + // Transposed convolution loop structure: + // Iterate over all kernel positions (ky, kx, ic4). For each position, + // compute the corresponding input position. If the input position is valid + // (in bounds and on a stride-aligned position), load the actual input; + // otherwise use input_zp_packed. This ensures weight_sums correction is + // consistent with the accumulation (all kernel positions are accounted for). + // + // For transposed convolution, the input position for a given (output, kernel) + // is: input = (output + padding - kernel) / stride, which must be exact + // (remainder == 0) and within [0, input_size). + + const int KH = conv2d_params.kernel_size.y; + const int KW = conv2d_params.kernel_size.x; + const int stride_x = conv2d_params.stride.x; + const int stride_y = conv2d_params.stride.y; + const int pad_x = conv2d_params.padding.x; + const int pad_y = conv2d_params.padding.y; + + for (int ky = 0; ky < KH; ky++) { + // Check if this kernel row maps to a valid input row + const int in_y_numer = outp_tidx.data[1] + pad_y - ky; + const bool y_stride_valid = (in_y_numer >= 0) && ((in_y_numer % stride_y) == 0); + const int iy = y_stride_valid ? (in_y_numer / stride_y) : 0; + const bool h_in_bounds = y_stride_valid && (iy < inp_H); + + // Loop order: ic4 before kx for better weight texture cache locality. + // Consecutive ic4 values at fixed kx access consecutive y-rows in the + // weight texture, but consecutive kx at fixed ic4 access consecutive + // x-coordinates (same row) which is the fast dimension for texture + // cache lines. By iterating ic4 in the outer loop and kx in the inner + // loop, each ic4 iteration sweeps kx across a texture row. + for (int ic4 = 0; ic4 < IC4_per_group; ic4++) { + for (int kx = 0; kx < KW; kx++) { + // Load weight block: 4 output channels x 4 input channels + const ivec4 weight_block = load_weight_block( + ic4, kx, ky, oc4, IC4_per_group, KW); + + // Process 4 adjacent width positions + [[unroll]] for (int subtile_w = 0; subtile_w < 4; ++subtile_w) { + const int ow = outp_tidx.data[0] + subtile_w; + const int in_x_numer = ow + pad_x - kx; + + // Load packed input, or use zero point if out of bounds + int packed_input = input_zp_packed; + if (h_in_bounds && in_x_numer >= 0 && (in_x_numer % stride_x) == 0) { + const int ix = in_x_numer / stride_x; + if (ix < inp_W) { + TensorIndex4D inp_tidx; + inp_tidx.data[0] = ix; + inp_tidx.data[1] = iy; + inp_tidx.data[2] = ic_group_start + ic4 * 4; + inp_tidx.data[3] = 0; + + int inp_texel_idx; + if (get_outer_packed_dim_block_size(inp_layout) == 1) { + inp_texel_idx = tensor4d_idx_to_texel_idx(inp, inp_tidx, inp_layout); + } else { + const int w4_inp = div_4(ix); + const int inp_c4 = div_4(inp_tidx.data[2]); + inp_texel_idx = (iy * inp_h_stride + w4_inp * inp_w_stride + inp_c4) * 4 + mod_4(ix); + } + packed_input = t_packed_int8_input[inp_texel_idx]; + } + } + + // Accumulate using packed int8 dot product for each output channel + [[unroll]] for (int oc_offset = 0; oc_offset < 4; ++oc_offset) { + acc[subtile_w][oc_offset] = dotPacked4x8AccSat( + packed_input, + weight_block[oc_offset], + acc[subtile_w][oc_offset]); + } + } + } + } + } + + // Apply input zero point correction via weight_sums + const vec4 weight_sums = vec4(t_weight_sums[oc4]); + const vec4 weight_scales = vec4(t_weight_scales[oc4]); + + // Convert to float, apply dequantization, and optionally add bias + vec4 facc[4]; + [[unroll]] for (int subtile_w = 0; subtile_w < 4; ++subtile_w) { + facc[subtile_w] = vec4(acc[subtile_w]); + facc[subtile_w] -= weight_sums * input_zp; + facc[subtile_w] *= weight_scales * input_scale; + } + + // Apply bias if enabled + if (apply_bias > 0) { + const vec4 bias = vec4(t_bias[oc4]); + [[unroll]] for (int subtile_w = 0; subtile_w < 4; ++subtile_w) { + facc[subtile_w] += bias; + } + } + + // Apply ReLU if enabled + if (activation_type > 0) { + [[unroll]] for (int subtile_w = 0; subtile_w < 4; ++subtile_w) { + facc[subtile_w] = max(facc[subtile_w], vec4(0.0)); + } + } + + // Compute base output texel index (for subtile_w=0) + const int base_outp_texel_idx = tensor4d_idx_to_texel_idx(outp, outp_tidx, outp_layout); + const int out_w_stride = int(outp.strides[0][0]); + + // Quantize and store outputs using stride offsets + [[unroll]] for (int subtile_w = 0; subtile_w < 4; ++subtile_w) { + // Skip out-of-bounds width positions + if (outp_tidx.data[0] >= W) { + continue; + } + + const ivec4 quantized_out = quantize(facc[subtile_w], output_inv_scale, output_zp); + const int packed_out = pack_into_int32(quantized_out); + + // Store using stride offset from base + int outp_texel_idx; + if (get_outer_packed_dim_block_size(outp_layout) == 1) { + outp_texel_idx = base_outp_texel_idx + subtile_w * out_w_stride; + } else { + outp_texel_idx = base_outp_texel_idx + subtile_w; + } + + t_packed_int8_output[outp_texel_idx] = packed_out; + + outp_tidx.data[0] += 1; + } +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_transposed.yaml b/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_transposed.yaml new file mode 100644 index 00000000000..69469fabd95 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/q8ta_conv2d_transposed.yaml @@ -0,0 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +q8ta_conv2d_transposed: + parameter_names_with_default_values: + DTYPE: float + USE_INT8_DOT_PRODUCT_EXT: 1 + generate_variant_forall: + DTYPE: + - VALUE: float + shader_variants: + - NAME: q8ta_conv2d_transposed + - NAME: q8ta_conv2d_transposed_fallback + USE_INT8_DOT_PRODUCT_EXT: 0 diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.h b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.h index 6da98fbef74..f463589c50a 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.h +++ b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2d.h @@ -9,6 +9,7 @@ #pragma once #include +#include #include namespace vkcompute { @@ -145,4 +146,10 @@ void add_q8ta_im2col_node( void q8ta_conv2d_im2col(ComputeGraph& graph, const std::vector& args); +// Transposed convolution + +void q8ta_conv2d_transposed( + ComputeGraph& graph, + const std::vector& args); + } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dTransposed.cpp b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dTransposed.cpp new file mode 100644 index 00000000000..bdbdaa14fec --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/Q8taConv2dTransposed.cpp @@ -0,0 +1,267 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#include +#include +#include +#include + +namespace vkcompute { + +// Dedicated workgroup size functions for transposed convolution. +// Unlike regular conv2d, transposed conv with stride > 1 causes branch +// divergence along the height dimension (different rows have different +// stride-alignment patterns). Keeping local_y=1 ensures all threads in a +// workgroup process the same height row, maximizing branch coherence. + +utils::uvec3 pick_q8ta_conv2d_transposed_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + (void)resize_args; + + const ValueRef output = args.at(0).refs.at(0); + + const uint32_t W = graph->size_at(-1, output); + const uint32_t H = graph->size_at(-2, output); + const uint32_t C = graph->size_at(-3, output); + + const uint32_t W4 = utils::div_up_4(W); + const uint32_t C4 = utils::div_up_4(C); + + return {W4, H, C4}; +} + +utils::uvec3 pick_q8ta_conv2d_transposed_local_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + (void)resize_args; + (void)graph; + (void)args; + + // Always keep local_y=1 to avoid branch divergence between height rows. + if (global_workgroup_size[0u] >= 6 && global_workgroup_size[2u] >= 6) { + return {8u, 1u, 8u}; + } + if (global_workgroup_size[0u] < 2u) { + return {1u, 1u, 64u}; + } + if (global_workgroup_size[2u] < 2u) { + return {64u, 1u, 1u}; + } + return {16u, 1u, 4u}; +} + +void add_q8ta_conv2d_transposed_node( + ComputeGraph& graph, + const ValueRef packed_int8_input, + const ValueRef input_scale, + const ValueRef input_zp, + const ValueRef packed_weight, + const ValueRef packed_weight_sums, + const ValueRef packed_weight_scales, + const ValueRef output_scale, + const ValueRef output_zp, + const ValueRef bias_data, + const ValueRef packed_bias, + const ValueRef kernel_size, + const ValueRef stride, + const ValueRef padding, + const ValueRef dilation, + const ValueRef groups, + const uint32_t activation_type, + const ValueRef packed_int8_output) { + Conv2DParams conv_params = create_conv2d_params( + graph, + packed_int8_input, + packed_int8_output, + kernel_size, + stride, + padding, + dilation, + groups); + + // Transposed convolution only supports dilation=1 + VK_CHECK_COND( + conv_params.dilation[0] == 1 && conv_params.dilation[1] == 1, + "q8ta_conv2d_transposed only supports dilation=1"); + + // The implementation requires that for grouped convolutions, the input + // channels per group is a multiple of 4. + if (conv_params.groups > 1) { + VK_CHECK_COND(conv_params.in_channels_per_group % 4 == 0); + } + + // Validate packed dim info for input and output tensors + VK_CHECK_COND(q8ta_conv2d_check_packed_dim_info( + graph.packed_dim_info_of(packed_int8_input))); + VK_CHECK_COND(q8ta_conv2d_check_packed_dim_info( + graph.packed_dim_info_of(packed_int8_output))); + + // Validate dtype is kInt8x4 + VK_CHECK_COND(graph.dtype_of(packed_int8_input) == vkapi::kInt8x4); + VK_CHECK_COND(graph.dtype_of(packed_int8_output) == vkapi::kInt8x4); + + float input_scale_val = graph.extract_scalar(input_scale); + int32_t input_zp_val = graph.extract_scalar(input_zp); + + float output_inv_scale_val = 1.0f / graph.extract_scalar(output_scale); + int32_t output_zp_val = graph.extract_scalar(output_zp); + + uint32_t apply_bias = 1; + if (graph.val_is_none(bias_data)) { + apply_bias = 0; + } + + std::vector push_constants = { + PushConstantDataInfo(&input_scale_val, sizeof(input_scale_val)), + PushConstantDataInfo(&input_zp_val, sizeof(input_zp_val)), + PushConstantDataInfo(&output_inv_scale_val, sizeof(output_inv_scale_val)), + PushConstantDataInfo(&output_zp_val, sizeof(output_zp_val)), + }; + + const bool use_hw_dot = + graph.context()->adapter_ptr()->supports_int8_dot_product(); + std::string kernel_name = + use_hw_dot ? "q8ta_conv2d_transposed" : "q8ta_conv2d_transposed_fallback"; + add_dtype_suffix(kernel_name, graph.dtype_of(packed_weight_scales)); + + // Pass metadata for both output and input tensors + vkapi::ParamsBindList param_buffers = { + graph.buffer_meta_ubo(packed_int8_output), + graph.buffer_meta_ubo(packed_int8_input), + graph.create_params_buffer(conv_params)}; + + // Build spec constants: apply_bias, activation_type + layout constants + vkapi::SpecVarList spec_constants = { + apply_bias, + activation_type, + // Layout specialization constants + graph.hashed_layout_of(packed_int8_input), + graph.hashed_layout_of(packed_int8_output), + }; + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + pick_q8ta_conv2d_transposed_global_wg_size, + pick_q8ta_conv2d_transposed_local_wg_size, + // Inputs and Outputs + {{packed_int8_output, vkapi::kWrite}, + {{packed_int8_input, + packed_weight, + packed_weight_sums, + packed_weight_scales, + packed_bias}, + vkapi::kRead}}, + // Shader params buffers + param_buffers, + // Push Constants + push_constants, + // Specialization Constants + spec_constants, + // Resize args + {})); +} + +void q8ta_conv2d_transposed( + ComputeGraph& graph, + const std::vector& args) { + int32_t idx = 0; + const ValueRef packed_int8_input = args.at(idx++); + const ValueRef input_scale = args.at(idx++); + const ValueRef input_zp = args.at(idx++); + const ValueRef weight_data = args.at(idx++); + const ValueRef weight_sums_data = args.at(idx++); + const ValueRef weight_scales_data = args.at(idx++); + const ValueRef output_scale = args.at(idx++); + const ValueRef output_zp = args.at(idx++); + const ValueRef bias_data = args.at(idx++); + const ValueRef kernel_size = args.at(idx++); + const ValueRef stride = args.at(idx++); + const ValueRef padding = args.at(idx++); + args.at(idx++); // output_padding: only affects output size, not shader + const ValueRef dilation = args.at(idx++); + const ValueRef groups = args.at(idx++); + const ValueRef activation = args.at(idx++); + const ValueRef packed_int8_output = args.at(idx++); + + uint32_t activation_type_val = static_cast( + activation_type_from_string(graph.extract_string(activation))); + + QuantizationConfig weight_quant_config(8, kPerChannel, {}); + + // Reuse the conv2d weight packing (after the pattern matcher reshapes the + // transposed weight to (OC, KH*KW*IC_per_group), the weight layout is + // identical to regular conv2d) + ValueRef packed_weight = prepack_quantized_conv2d_weight( + graph, + weight_quant_config, + weight_data, + packed_int8_input, + packed_int8_output, + groups, + kernel_size); + + ValueRef packed_weight_sums = prepack_standard( + graph, weight_sums_data, utils::kBuffer, utils::kWidthPacked); + + ValueRef packed_weight_scales = prepack_standard( + graph, weight_scales_data, utils::kBuffer, utils::kWidthPacked); + + // Create a dummy tensor to fill the binding slot of the bias tensor if it is + // not provided + TmpTensor dummy_bias( + &graph, + {}, + graph.dtype_of(weight_scales_data), + utils::kBuffer, + utils::kWidthPacked); + + ValueRef packed_bias = dummy_bias.vref; + if (graph.val_is_not_none(bias_data)) { + packed_bias = + prepack_standard(graph, bias_data, utils::kBuffer, utils::kWidthPacked); + } + + add_q8ta_conv2d_transposed_node( + graph, + packed_int8_input, + input_scale, + input_zp, + packed_weight, + packed_weight_sums, + packed_weight_scales, + output_scale, + output_zp, + bias_data, + packed_bias, + kernel_size, + stride, + padding, + dilation, + groups, + activation_type_val, + packed_int8_output); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(et_vk.q8ta_conv2d_transposed.default, q8ta_conv2d_transposed); +} + +} // namespace vkcompute diff --git a/backends/vulkan/test/custom_ops/impl/TestQ8taConv2dTransposed.cpp b/backends/vulkan/test/custom_ops/impl/TestQ8taConv2dTransposed.cpp new file mode 100644 index 00000000000..894ce71fed9 --- /dev/null +++ b/backends/vulkan/test/custom_ops/impl/TestQ8taConv2dTransposed.cpp @@ -0,0 +1,87 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include +#include + +namespace vkcompute { + +void test_q8ta_conv2d_transposed( + ComputeGraph& graph, + const std::vector& args) { + int32_t idx = 0; + const ValueRef fp_input = args.at(idx++); + const ValueRef input_scale = args.at(idx++); + const ValueRef input_zp = args.at(idx++); + const ValueRef weight_data = args.at(idx++); + const ValueRef weight_sums_data = args.at(idx++); + const ValueRef weight_scales_data = args.at(idx++); + const ValueRef output_scale = args.at(idx++); + const ValueRef output_zp = args.at(idx++); + const ValueRef bias_data = args.at(idx++); + const ValueRef kernel_size = args.at(idx++); + const ValueRef stride = args.at(idx++); + const ValueRef padding = args.at(idx++); + const ValueRef output_padding = args.at(idx++); + const ValueRef dilation = args.at(idx++); + const ValueRef groups = args.at(idx++); + const ValueRef activation = args.at(idx++); + const ValueRef layout_int = args.at(idx++); + const ValueRef fp_output = args.at(idx++); + + int32_t layout_value = graph.extract_scalar(layout_int); + utils::GPUMemoryLayout layout = + static_cast(layout_value); + + TmpTensor packed_int8_input( + &graph, graph.sizes_of(fp_input), vkapi::kInt8x4, utils::kBuffer, layout); + + TmpTensor packed_int8_output( + &graph, + graph.sizes_of(fp_output), + vkapi::kInt8x4, + utils::kBuffer, + layout); + + add_q8ta_quantize_node( + graph, fp_input, input_scale, input_zp, packed_int8_input); + + std::vector conv_args = { + packed_int8_input, + input_scale, + input_zp, + weight_data, + weight_sums_data, + weight_scales_data, + output_scale, + output_zp, + bias_data, + kernel_size, + stride, + padding, + output_padding, + dilation, + groups, + activation, + packed_int8_output}; + VK_GET_OP_FN("et_vk.q8ta_conv2d_transposed.default")(graph, conv_args); + + add_q8ta_dequantize_node( + graph, packed_int8_output, output_scale, output_zp, fp_output); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP( + test_etvk.test_q8ta_conv2d_transposed.default, + test_q8ta_conv2d_transposed); +} + +} // namespace vkcompute diff --git a/backends/vulkan/test/custom_ops/targets.bzl b/backends/vulkan/test/custom_ops/targets.bzl index badba5666fa..ba4873af603 100644 --- a/backends/vulkan/test/custom_ops/targets.bzl +++ b/backends/vulkan/test/custom_ops/targets.bzl @@ -98,3 +98,4 @@ def define_common_targets(is_fbcode = False): define_custom_op_test_binary("test_q8ta_conv2d_pw") define_custom_op_test_binary("test_q8ta_conv2d_dw") define_custom_op_test_binary("test_q8ta_linear") + define_custom_op_test_binary("test_q8ta_conv2d_transposed") diff --git a/backends/vulkan/test/custom_ops/test_q8ta_conv2d_transposed.cpp b/backends/vulkan/test/custom_ops/test_q8ta_conv2d_transposed.cpp new file mode 100644 index 00000000000..903a9c678b1 --- /dev/null +++ b/backends/vulkan/test/custom_ops/test_q8ta_conv2d_transposed.cpp @@ -0,0 +1,601 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include + +#include +#include + +#include + +#include "conv2d_utils.h" +#include "utils.h" + +// #define DEBUG_MODE + +using namespace executorch::vulkan::prototyping; + +using namespace vkcompute; + +static constexpr int64_t kRefDimSizeLimit = 100; + +// Transposed convolution output size formula: +// H_out = (H_in - 1) * stride_h - 2 * pad_h + dilation_h * (K_h - 1) +// + output_pad_h + 1 +static int64_t get_transpose_output_height( + const Conv2dConfig& config, + int32_t output_pad_h) { + return (config.input_size.h - 1) * config.stride.h - 2 * config.padding.h + + config.dilation.h * (config.kernel.h - 1) + output_pad_h + 1; +} + +static int64_t get_transpose_output_width( + const Conv2dConfig& config, + int32_t output_pad_w) { + return (config.input_size.w - 1) * config.stride.w - 2 * config.padding.w + + config.dilation.w * (config.kernel.w - 1) + output_pad_w + 1; +} + +// Utility function to create a test case from a Conv2dConfig for transposed +// convolution +static TestCase create_test_case_from_config( + const Conv2dConfig& config, + int32_t output_pad_h, + int32_t output_pad_w, + vkapi::ScalarType input_dtype, + utils::StorageType fp_storage_type, + utils::GPUMemoryLayout int8_memory_layout) { + TestCase test_case; + + int64_t H_out = get_transpose_output_height(config, output_pad_h); + int64_t W_out = get_transpose_output_width(config, output_pad_w); + + // Input tensor (float/half) - [1, C_in, H_in, W_in] (batch size always 1) + // For transposed conv, C_in is typically larger (downsampled channels) + std::vector input_size = { + 1, config.channels.in, config.input_size.h, config.input_size.w}; + + utils::GPUMemoryLayout fp_memory_layout = fp_storage_type == utils::kBuffer + ? utils::kWidthPacked + : utils::kChannelsPacked; + + // Create test case name + std::string prefix = config.test_case_name.substr(0, 4); + std::string test_name = prefix + " " + std::to_string(config.channels.in) + + "->" + std::to_string(config.channels.out) + " " + + "I=" + std::to_string(config.input_size.h) + "," + + std::to_string(config.input_size.w) + " " + + "g=" + std::to_string(config.groups) + " " + + "k=" + std::to_string(config.kernel.h) + " " + + "op=" + std::to_string(output_pad_h) + "," + + std::to_string(output_pad_w) + " " + + repr_str(utils::kBuffer, int8_memory_layout); + test_case.set_name(test_name); + + test_case.set_operator_name("test_etvk.test_q8ta_conv2d_transposed.default"); + + ValueSpec input_tensor( + input_size, + input_dtype, + fp_storage_type, + fp_memory_layout, + DataGenType::RANDOM); + + if (debugging()) { + print_valuespec_data(input_tensor, "input_tensor"); + } + + float input_scale_val = 0.008123; + ValueSpec input_scale(input_scale_val); + + int32_t input_zero_point_val = 2; + ValueSpec input_zero_point(input_zero_point_val); + + // Quantized weight tensor (int8) - [C_out, align_up_4(C_in_per_group * K_h * + // K_w)] After the pattern matcher reshapes, the transposed conv weight has + // the same layout as regular conv2d + const int64_t in_channels_per_group = config.channels.in / config.groups; + const int64_t in_features = utils::align_up_4( + in_channels_per_group * config.kernel.h * config.kernel.w); + std::vector weight_size = {config.channels.out, in_features}; + ValueSpec quantized_weight( + weight_size, + vkapi::kChar, + fp_storage_type, + utils::kWidthPacked, + DataGenType::RANDINT8); + quantized_weight.set_constant(true); + + if (debugging()) { + print_valuespec_data(quantized_weight, "weight_tensor"); + } + + const int64_t aligned_out_channels = utils::align_up_4(config.channels.out); + + ValueSpec weight_scales( + {aligned_out_channels}, + input_dtype, + fp_storage_type, + utils::kWidthPacked, + DataGenType::RANDOM_SCALES); + weight_scales.set_constant(true); + + ValueSpec weight_sums( + {aligned_out_channels}, + vkapi::kInt, + fp_storage_type, + utils::kWidthPacked, + DataGenType::ZEROS); + weight_sums.set_constant(true); + + compute_weight_sums( + weight_sums, quantized_weight, config.channels.out, in_features); + + ValueSpec bias( + {aligned_out_channels}, + input_dtype, + fp_storage_type, + utils::kWidthPacked, + DataGenType::ZEROS); + bias.set_constant(true); + + float output_scale_val = 0.05314; + ValueSpec output_scale(output_scale_val); + + int32_t output_zero_point_val = -1; + ValueSpec output_zero_point(output_zero_point_val); + + ValueSpec stride({config.stride.h, config.stride.w}); + ValueSpec padding({config.padding.h, config.padding.w}); + ValueSpec output_padding({output_pad_h, output_pad_w}); + ValueSpec dilation({config.dilation.h, config.dilation.w}); + ValueSpec groups(config.groups); + ValueSpec kernel_size({config.kernel.h, config.kernel.w}); + + // Output tensor - [1, C_out, H_out, W_out] + ValueSpec output( + {1, config.channels.out, H_out, W_out}, + input_dtype, + fp_storage_type, + fp_memory_layout, + DataGenType::ZEROS); + + // Add all specs to test case + test_case.add_input_spec(input_tensor); + test_case.add_input_spec(input_scale); + test_case.add_input_spec(input_zero_point); + test_case.add_input_spec(quantized_weight); + test_case.add_input_spec(weight_sums); + test_case.add_input_spec(weight_scales); + test_case.add_input_spec(output_scale); + test_case.add_input_spec(output_zero_point); + test_case.add_input_spec(bias); + test_case.add_input_spec(kernel_size); + test_case.add_input_spec(stride); + test_case.add_input_spec(padding); + test_case.add_input_spec(output_padding); + test_case.add_input_spec(dilation); + test_case.add_input_spec(groups); + + ValueSpec activation = ValueSpec::make_string("none"); + test_case.add_input_spec(activation); + + ValueSpec layout_int(static_cast(int8_memory_layout)); + test_case.add_input_spec(layout_int); + + test_case.add_output_spec(output); + + test_case.set_abs_tolerance(output_scale_val + 1e-4f); + + test_case.set_shader_filter({ + "nchw_to", + "to_nchw", + "q8ta_quantize", + "q8ta_dequantize", + }); + + return test_case; +} + +// Generate easy test cases for debugging +std::vector generate_quantized_conv2d_transposed_easy_cases() { + std::vector test_cases; + + Conv2dConfig config = { + OutInChannels(16, 32), + InputSize2D(8, 8), + KernelSize(3, 3), + Stride(2, 2), + Padding(1, 1), + Dilation(1, 1), + 1, + }; + + std::vector int8_memory_layouts = { + utils::kPackedInt8_4C1W, utils::kPackedInt8_4W4C, utils::kPackedInt8_4C}; + + for (const utils::GPUMemoryLayout int8_memory_layout : int8_memory_layouts) { + config.test_case_name = + make_test_case_name(config, false, utils::kTexture3D, utils::kBuffer); + test_cases.push_back(create_test_case_from_config( + config, + /*output_pad_h=*/1, + /*output_pad_w=*/1, + vkapi::kFloat, + utils::kTexture3D, + int8_memory_layout)); + } + + return test_cases; +} + +// Generate test cases for quantized transposed conv2d +static std::vector generate_quantized_conv2d_transposed_test_cases() { + std::vector test_cases; + if (!vkcompute::api::context()->adapter_ptr()->supports_int8_dot_product()) { + return test_cases; + } + + // Each entry: {config, output_pad_h, output_pad_w} + struct TransposedConvTestConfig { + Conv2dConfig config; + int32_t output_pad_h; + int32_t output_pad_w; + }; + + std::vector configs = { + // Basic transposed conv (stride=2, common in decoder networks) + {{OutInChannels(16, 32), + InputSize2D(8, 8), + KernelSize(3, 3), + Stride(2, 2), + Padding(1, 1), + Dilation(1, 1), + 1}, + 1, + 1}, + {{OutInChannels(32, 64), + InputSize2D(4, 4), + KernelSize(3, 3), + Stride(2, 2), + Padding(1, 1), + Dilation(1, 1), + 1}, + 1, + 1}, + // No output padding + {{OutInChannels(16, 32), + InputSize2D(8, 8), + KernelSize(4, 4), + Stride(2, 2), + Padding(1, 1), + Dilation(1, 1), + 1}, + 0, + 0}, + // Stride=1 (degenerate case) + {{OutInChannels(16, 16), + InputSize2D(8, 8), + KernelSize(3, 3), + Stride(1, 1), + Padding(1, 1), + Dilation(1, 1), + 1}, + 0, + 0}, + // Grouped transposed conv + {{OutInChannels(32, 64), + InputSize2D(8, 8), + KernelSize(3, 3), + Stride(2, 2), + Padding(1, 1), + Dilation(1, 1), + 2}, + 1, + 1}, + // Larger spatial + {{OutInChannels(64, 128), + InputSize2D(16, 16), + KernelSize(4, 4), + Stride(2, 2), + Padding(1, 1), + Dilation(1, 1), + 1}, + 0, + 0}, + // Performance cases + {{OutInChannels(64, 128), + InputSize2D(32, 32), + KernelSize(3, 3), + Stride(2, 2), + Padding(1, 1), + Dilation(1, 1), + 1}, + 1, + 1}, + {{OutInChannels(128, 256), + InputSize2D(16, 16), + KernelSize(4, 4), + Stride(2, 2), + Padding(1, 1), + Dilation(1, 1), + 1}, + 0, + 0}, + }; + + std::vector int8_memory_layouts = { + utils::kPackedInt8_4C1W, utils::kPackedInt8_4W4C, utils::kPackedInt8_4C}; + + for (auto& tc : configs) { + auto& config = tc.config; + bool is_performance = config.channels.out > kRefDimSizeLimit || + config.channels.in > kRefDimSizeLimit || + config.input_size.h > kRefDimSizeLimit || + config.input_size.w > kRefDimSizeLimit; + + for (const utils::GPUMemoryLayout int8_memory_layout : + int8_memory_layouts) { + config.test_case_name = make_test_case_name( + config, is_performance, utils::kTexture3D, utils::kBuffer); + + test_cases.push_back(create_test_case_from_config( + config, + tc.output_pad_h, + tc.output_pad_w, + vkapi::kFloat, + utils::kTexture3D, + int8_memory_layout)); + } + } + + return test_cases; +} + +// Reference implementation for quantized transposed conv2d +static void conv2d_transposed_q8ta_reference_impl(TestCase& test_case) { + int32_t idx = 0; + const ValueSpec& input_spec = test_case.inputs()[idx++]; + const ValueSpec& input_scale_spec = test_case.inputs()[idx++]; + const ValueSpec& input_zeros_spec = test_case.inputs()[idx++]; + const ValueSpec& weight_spec = test_case.inputs()[idx++]; + const ValueSpec& weight_sums_spec = test_case.inputs()[idx++]; + (void)weight_sums_spec; + const ValueSpec& weight_scales_spec = test_case.inputs()[idx++]; + const ValueSpec& output_scale_spec = test_case.inputs()[idx++]; + const ValueSpec& output_zeros_spec = test_case.inputs()[idx++]; + const ValueSpec& bias_spec = test_case.inputs()[idx++]; + const ValueSpec& kernel_size_spec = test_case.inputs()[idx++]; + const ValueSpec& stride_spec = test_case.inputs()[idx++]; + const ValueSpec& padding_spec = test_case.inputs()[idx++]; + const ValueSpec& output_padding_spec = test_case.inputs()[idx++]; + (void)output_padding_spec; // output_padding only affects output size + const ValueSpec& dilation_spec = test_case.inputs()[idx++]; + const ValueSpec& groups_spec = test_case.inputs()[idx++]; + const ValueSpec& activation_spec = test_case.inputs()[idx++]; + (void)activation_spec; + const ValueSpec& layout_spec = test_case.inputs()[idx++]; + (void)layout_spec; + + ValueSpec& output_spec = test_case.outputs()[0]; + + auto input_sizes = input_spec.get_tensor_sizes(); + auto output_sizes = output_spec.get_tensor_sizes(); + + int64_t N = input_sizes[0]; + int64_t C_in = input_sizes[1]; + int64_t H_in = input_sizes[2]; + int64_t W_in = input_sizes[3]; + int64_t C_out = output_sizes[1]; + int64_t H_out = output_sizes[2]; + int64_t W_out = output_sizes[3]; + + auto kernel_size_data = kernel_size_spec.get_int32_data(); + int64_t K_h = kernel_size_data[0]; + int64_t K_w = kernel_size_data[1]; + + auto stride_data = stride_spec.get_int32_data(); + auto padding_data = padding_spec.get_int32_data(); + auto dilation_data = dilation_spec.get_int32_data(); + int64_t stride_h = stride_data[0]; + int64_t stride_w = stride_data[1]; + int64_t pad_h = padding_data[0]; + int64_t pad_w = padding_data[1]; + int64_t dilation_h = dilation_data[0]; + int64_t dilation_w = dilation_data[1]; + int64_t groups = groups_spec.get_int_value(); + + if (N > kRefDimSizeLimit || C_in > kRefDimSizeLimit || + H_in > kRefDimSizeLimit || W_in > kRefDimSizeLimit || + C_out > kRefDimSizeLimit) { + throw std::invalid_argument( + "One or more dimensions exceed the allowed limit for reference implementation."); + } + + if (input_spec.dtype != vkapi::kFloat) { + throw std::invalid_argument("Unsupported dtype"); + } + + auto& input_data = input_spec.get_float_data(); + const float input_scale = input_scale_spec.get_float_value(); + const int32_t input_zero_point = input_zeros_spec.get_int_value(); + + auto& weight_data = weight_spec.get_int8_data(); + auto& weight_scales_data = weight_scales_spec.get_float_data(); + auto& bias_data = bias_spec.get_float_data(); + + const float output_scale = output_scale_spec.get_float_value(); + const int32_t output_zero_point = output_zeros_spec.get_int_value(); + + int64_t C_in_per_group = C_in / groups; + int64_t C_out_per_group = C_out / groups; + + int64_t num_output_elements = N * C_out * H_out * W_out; + + auto& ref_data = output_spec.get_ref_float_data(); + ref_data.resize(num_output_elements); + + const int64_t in_features = utils::align_up_4(C_in_per_group * K_h * K_w); + + // Transposed convolution reference implementation. + // For transposed conv, we scatter each input element across the output + // rather than gather. But for the reference we compute it by iterating + // over output positions and finding which input positions contribute. + // + // For each output position (oh, ow), an input position (iy, ix) contributes + // via kernel position (kh, kw) if: + // oh + pad_h - kh * dilation_h == iy * stride_h + // ow + pad_w - kw * dilation_w == ix * stride_w + // i.e., (oh + pad_h - kh * dilation_h) must be divisible by stride_h + // and the quotient must be a valid input index. + for (int64_t n = 0; n < N; ++n) { + for (int64_t out_c = 0; out_c < C_out; ++out_c) { + int64_t group_idx = out_c / C_out_per_group; + int64_t in_c_start = group_idx * C_in_per_group; + + for (int64_t out_h = 0; out_h < H_out; ++out_h) { + for (int64_t out_w = 0; out_w < W_out; ++out_w) { + int32_t int_sum = 0; + int32_t weight_sum = 0; + + for (int64_t kh = 0; kh < K_h; ++kh) { + int64_t h_offset = out_h + pad_h - kh * dilation_h; + if (h_offset < 0 || h_offset % stride_h != 0) { + continue; + } + int64_t iy = h_offset / stride_h; + if (iy >= H_in) { + continue; + } + + for (int64_t kw = 0; kw < K_w; ++kw) { + int64_t w_offset = out_w + pad_w - kw * dilation_w; + if (w_offset < 0 || w_offset % stride_w != 0) { + continue; + } + int64_t ix = w_offset / stride_w; + if (ix >= W_in) { + continue; + } + + for (int64_t ic_local = 0; ic_local < C_in_per_group; + ++ic_local) { + int64_t in_c = in_c_start + ic_local; + + int64_t input_idx = n * (C_in * H_in * W_in) + + in_c * (H_in * W_in) + iy * W_in + ix; + + float quant_input_f = + std::round(input_data[input_idx] / input_scale) + + input_zero_point; + quant_input_f = + std::min(std::max(quant_input_f, -128.0f), 127.0f); + int8_t quantized_input = static_cast(quant_input_f); + + // Weight layout: [C_out, align_up_4(C_in_per_group * K_h * + // K_w)] Inner dimension order: kh, kw, ic_local + int64_t weight_idx = out_c * in_features + + (kh * (K_w * C_in_per_group) + kw * C_in_per_group + + ic_local); + int8_t quantized_weight = weight_data[weight_idx]; + + int_sum += static_cast(quantized_input) * + static_cast(quantized_weight); + + weight_sum += static_cast(quantized_weight); + } + } + } + + int32_t zero_point_correction = input_zero_point * weight_sum; + int32_t accum_adjusted = int_sum - zero_point_correction; + float float_result = + accum_adjusted * input_scale * weight_scales_data[out_c]; + + float_result += bias_data[out_c]; + + float quant_output_f = + std::round(float_result / output_scale) + output_zero_point; + quant_output_f = std::min(std::max(quant_output_f, -128.0f), 127.0f); + int8_t quantized_output = static_cast(quant_output_f); + + float dequant_output = + (static_cast(quantized_output) - output_zero_point) * + output_scale; + + int64_t output_idx = n * (C_out * H_out * W_out) + + out_c * (H_out * W_out) + out_h * W_out + out_w; + ref_data[output_idx] = dequant_output; + } + } + } + } +} + +static void reference_impl(TestCase& test_case) { + conv2d_transposed_q8ta_reference_impl(test_case); +} + +static int64_t quantized_conv2d_transposed_flop_calculator( + const TestCase& test_case) { + int kernel_idx = 9; + + const auto& input_sizes = test_case.inputs()[0].get_tensor_sizes(); + const auto& output_sizes = test_case.outputs()[0].get_tensor_sizes(); + const auto& kernel_sizes = test_case.inputs()[kernel_idx].get_int32_data(); + + int64_t N = input_sizes[0]; + int64_t C_in = input_sizes[1]; + int64_t C_out = output_sizes[1]; + int64_t K_h = kernel_sizes[0]; + int64_t K_w = kernel_sizes[1]; + int64_t H_out = output_sizes[2]; + int64_t W_out = output_sizes[3]; + + int64_t output_elements = N * C_out * H_out * W_out; + int64_t ops_per_output = C_in * K_h * K_w; + + return output_elements * ops_per_output; +} + +int main(int argc, char* argv[]) { + set_debugging(false); + set_print_output(false); +#ifdef DEBUG_MODE + set_print_latencies(true); +#else + set_print_latencies(false); +#endif + set_use_gpu_timestamps(true); + + print_performance_header(); + std::cout + << "Quantized Transposed Conv2d Operation with Output Quantization Prototyping Framework" + << std::endl; + print_separator(); + + ReferenceComputeFunc ref_fn = reference_impl; + + auto results = execute_test_cases( +#ifdef DEBUG_MODE + generate_quantized_conv2d_transposed_easy_cases, +#else + generate_quantized_conv2d_transposed_test_cases, +#endif + quantized_conv2d_transposed_flop_calculator, + "QuantizedTransposedConv2d", +#ifdef DEBUG_MODE + 0, + 1, +#else + 3, + 10, +#endif + ref_fn); + + return 0; +} From 08cb79380d2e50d773d4fb523fbed98c18638300 Mon Sep 17 00:00:00 2001 From: ssjia Date: Mon, 9 Mar 2026 16:45:57 -0700 Subject: [PATCH 5/5] [ET-VK][qlinear] Add bmm support to quantized linear pattern detector Pull Request resolved: https://github.com/pytorch/executorch/pull/18017 Some quantized linear projections (e.g. in EdgeTAM's SpatialPerceiver / mask decoder) decompose as aten.bmm instead of aten.mm. Add aten.bmm.default as an anchor node in the quantized linear pattern detector so these nodes can be fused into custom quantized linear ops. Reject bmm nodes with batch dim > 1 since the custom ops assume a single batch. ghstack-source-id: 349646654 @exported-using-ghexport Differential Revision: [D95807072](https://our.internmc.facebook.com/intern/diff/D95807072/) --- backends/vulkan/patterns/quantized_linear.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/backends/vulkan/patterns/quantized_linear.py b/backends/vulkan/patterns/quantized_linear.py index 85e3476cad3..b9b307e14f1 100644 --- a/backends/vulkan/patterns/quantized_linear.py +++ b/backends/vulkan/patterns/quantized_linear.py @@ -90,6 +90,11 @@ def __init__(self, mm_node: torch.fx.Node) -> None: # noqa: C901 # Identify output node self.output_node = self.anchor_node + # bmm with batch dim > 1 is not supported + is_bmm = self.anchor_node.target == exir_ops.edge.aten.bmm.default + if is_bmm and self.output_node.meta["val"].shape[0] != 1: + return + # Identify primary input node of the anchor. Due to decomposition of aten.linear # there may be a view_copy node between the original input tensor to the linear # op and the actual linear op node. @@ -268,6 +273,7 @@ def is_input_dynamic_perchannel_quantized(self) -> bool: exir_ops.edge.aten.linear.default, exir_ops.edge.aten.mm.default, exir_ops.edge.aten.addmm.default, + exir_ops.edge.aten.bmm.default, }