From c7bcdd3f110321d48dd414d6f90118e4344e55e0 Mon Sep 17 00:00:00 2001 From: ssjia Date: Tue, 10 Mar 2026 10:01:10 -0700 Subject: [PATCH] [ET-VK] Add ANY_STORAGE support to repeat Add a buffer shader variant for the repeat operator and rewrite the C++ dispatch code to support both buffer and texture storage types. Key changes: - Add repeat_buffer.glsl/.yaml: a new buffer-path compute shader that uses BufferMetadata UBOs and linear index arithmetic to map each output element to its source element via modulo on each dimension. - Rename repeat.glsl/.yaml to repeat_texture.glsl/.yaml (shader variant name repeat_texture3d) and update the push constant from dst_repeats to out_dims, computing out_channel_size directly from output dimensions rather than from input size * repeat count. - Rewrite Repeat.cpp: remove the old check_args() function and pre-computed ivec4 push constants. Add resize_repeat_node() that uses extract_int_or_symint_list() to dynamically compute output sizes from input sizes and repeats, enabling dynamic shape support. Dispatch uses add_storage_type_suffix for shader selection, with the buffer path using meta_ubo UBOs and the texture path using push constants from logical_limits_pc_of/sizes_pc_of. Wire resize_repeat_node into DynamicDispatchNode. - Update op_registry.py: change inputs_storage from ANY_TEXTURE to ANY_STORAGE. - Update test cases: add utils::kBuffer to storage_types for both 2d and 3d repeat test suites. Differential Revision: [D95970170](https://our.internmc.facebook.com/intern/diff/D95970170/) [ghstack-poisoned] --- backends/vulkan/op_registry.py | 2 +- .../runtime/graph/ops/glsl/repeat_buffer.glsl | 51 +++++++ .../runtime/graph/ops/glsl/repeat_buffer.yaml | 13 ++ .../glsl/{repeat.glsl => repeat_texture.glsl} | 8 +- .../glsl/{repeat.yaml => repeat_texture.yaml} | 4 +- .../vulkan/runtime/graph/ops/impl/Repeat.cpp | 141 +++++++----------- backends/vulkan/test/op_tests/cases.py | 4 +- 7 files changed, 125 insertions(+), 98 deletions(-) create mode 100644 backends/vulkan/runtime/graph/ops/glsl/repeat_buffer.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/repeat_buffer.yaml rename backends/vulkan/runtime/graph/ops/glsl/{repeat.glsl => repeat_texture.glsl} (94%) rename backends/vulkan/runtime/graph/ops/glsl/{repeat.yaml => repeat_texture.yaml} (84%) diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index 8ece5903cea..6734ada8f98 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -1343,7 +1343,7 @@ def register_grid_priors(): @update_features(exir_ops.edge.aten.repeat.default) def register_repeat(): return OpFeatures( - inputs_storage=utils.ANY_TEXTURE, + inputs_storage=utils.ANY_STORAGE, inputs_dtypes=utils.FP_INT_BOOL_T, ) diff --git a/backends/vulkan/runtime/graph/ops/glsl/repeat_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/repeat_buffer.glsl new file mode 100644 index 00000000000..be2d87a168f --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/repeat_buffer.glsl @@ -0,0 +1,51 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +${define_required_extensions("buffer", DTYPE)} + +#define PRECISION ${PRECISION} + +#define T ${buffer_scalar_type(DTYPE)} + +${define_active_storage_type("buffer")} + +layout(std430) buffer; + +#include "indexing.glslh" + +${layout_declare_tensor(B, "w", "t_out", DTYPE, "buffer")} +${layout_declare_tensor(B, "r", "t_in", DTYPE, "buffer")} + +${layout_declare_ubo(B, "BufferMetadata", "out_meta")} +${layout_declare_ubo(B, "BufferMetadata", "in_meta")} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + const uint out_bufi = gl_GlobalInvocationID.x; + if (out_of_bounds(out_bufi, out_meta)) { + return; + } + + TensorIndex out_tidx = linear_idx_to_tensor_idx(out_meta, out_bufi); + + TensorIndex in_tidx; + initialize(in_tidx); + + const int n = int_ndim(out_meta); + for (int d = 0; d < n; d++) { + in_tidx.data[div_4(d)][mod_4(d)] = + idx_at(out_tidx, d) % size_at(in_meta, d); + } + + const uint in_bufi = tensor_idx_to_linear_idx(in_meta, in_tidx); + + t_out[out_bufi] = t_in[in_bufi]; +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/repeat_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/repeat_buffer.yaml new file mode 100644 index 00000000000..83d03d00b01 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/repeat_buffer.yaml @@ -0,0 +1,13 @@ +repeat_buffer: + parameter_names_with_default_values: + DTYPE: float + STORAGE: buffer + generate_variant_forall: + DTYPE: + - VALUE: half + - VALUE: float + - VALUE: int32 + - VALUE: int8 + - VALUE: uint8 + shader_variants: + - NAME: repeat_buffer diff --git a/backends/vulkan/runtime/graph/ops/glsl/repeat.glsl b/backends/vulkan/runtime/graph/ops/glsl/repeat_texture.glsl similarity index 94% rename from backends/vulkan/runtime/graph/ops/glsl/repeat.glsl rename to backends/vulkan/runtime/graph/ops/glsl/repeat_texture.glsl index 441cd57c17d..cc0817b229f 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/repeat.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/repeat_texture.glsl @@ -21,8 +21,8 @@ layout(push_constant) uniform restrict Block { ivec4 range; // source tensor sizes in WHCB dims respectively ivec4 src_dims; - // destination tensor repeats in WHCB dims respectively - ivec4 dst_repeats; + // output tensor sizes in WHCB dims respectively + ivec4 out_dims; }; #include "indexing_utils.h" @@ -58,7 +58,7 @@ void main() { // if tensors are channel packed if (packed_dim == C_DIM) { // the output channels in a batch will be channel size * channel repetitions aligned by 4 - const int out_channel_size = alignup4(src_dims.z * dst_repeats.z); + const int out_channel_size = alignup4(out_dims.z); // batch index in the output const int out_pos_batch_index = pos.z / out_channel_size; @@ -76,7 +76,7 @@ void main() { channel_index = (pos.z - (batch_index + batch_repetition_index * src_dims.w) * out_channel_size) % src_dims.z; } else { // the output channels in a batch will be channel size * channel repetitions - const int out_channel_size = src_dims.z * dst_repeats.z; + const int out_channel_size = out_dims.z; // source batch index for based on current output pos batch_index = (pos.z / out_channel_size) % src_dims.w; diff --git a/backends/vulkan/runtime/graph/ops/glsl/repeat.yaml b/backends/vulkan/runtime/graph/ops/glsl/repeat_texture.yaml similarity index 84% rename from backends/vulkan/runtime/graph/ops/glsl/repeat.yaml rename to backends/vulkan/runtime/graph/ops/glsl/repeat_texture.yaml index f40d94142e1..a9e469fc39b 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/repeat.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/repeat_texture.yaml @@ -1,4 +1,4 @@ -repeat: +repeat_texture: parameter_names_with_default_values: DTYPE: float NDIM: 3 @@ -11,4 +11,4 @@ repeat: - VALUE: int8 - VALUE: uint8 shader_variants: - - NAME: repeat + - NAME: repeat_texture3d diff --git a/backends/vulkan/runtime/graph/ops/impl/Repeat.cpp b/backends/vulkan/runtime/graph/ops/impl/Repeat.cpp index 2b42c0bd150..e4d5aad66b3 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Repeat.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Repeat.cpp @@ -16,107 +16,70 @@ namespace vkcompute { -namespace { - -void check_args( - ComputeGraph& graph, - const ValueRef in, - const std::vector& repeats, - const ValueRef out) { - VK_CHECK_COND(graph.packed_dim_of(in) == graph.packed_dim_of(out)); - - VK_CHECK_COND(graph.storage_type_of(in) == graph.storage_type_of(out)); - if (graph.storage_type_of(in) == utils::kTexture2D) { - VK_CHECK_COND(graph.dim_of(in) <= 2); +void resize_repeat_node( + ComputeGraph* graph, + const std::vector& args, + const std::vector& extra_args) { + const ValueRef in = args.at(1).refs.at(0); + const ValueRef out = args.at(0).refs.at(0); + const ValueRef repeats_ref = extra_args.at(0); + + const std::vector in_sizes = graph->sizes_of(in); + const std::vector repeats = + graph->extract_int_or_symint_list(repeats_ref); + + const size_t out_ndim = std::max(in_sizes.size(), repeats.size()); + std::vector out_sizes(out_ndim); + for (size_t i = 0; i < out_ndim; i++) { + const size_t in_offset = i + in_sizes.size() - out_ndim; + const size_t rep_offset = i + repeats.size() - out_ndim; + // Prepend 1s to in_sizes if repeats is longer, and vice versa + const int64_t in_size = + (i >= out_ndim - in_sizes.size()) ? in_sizes[in_offset] : 1; + const int64_t r = + (i >= out_ndim - repeats.size()) ? repeats[rep_offset] : 1; + out_sizes[i] = in_size * r; } - - const int64_t in_dim = graph.dim_of(in); - VK_CHECK_COND( - in_dim <= repeats.size(), - "Input tensor dim size must be not greater than the repeat argument's size"); - - const std::vector in_sizes = graph.sizes_of(in); - const std::vector out_sizes = graph.sizes_of(out); - - VK_CHECK_COND( - dim_at(in_sizes) * dim_at(repeats) == - dim_at(out_sizes), - "Output's width doesn't match input's width * repeat count"); - - VK_CHECK_COND( - dim_at(in_sizes) * dim_at(repeats) == - dim_at(out_sizes), - "Output's height doesn't match input's height * repeat count"); - - VK_CHECK_COND( - dim_at(in_sizes) * dim_at(repeats) == - dim_at(out_sizes), - "Output's channel doesn't match input's channel * repeat count"); - - VK_CHECK_COND( - dim_at(in_sizes) * dim_at(repeats) == - dim_at(out_sizes), - "Output's batch doesn't match input's batch * repeat count"); + graph->virtual_resize(out, out_sizes); } -} // namespace - void add_repeat_node( ComputeGraph& graph, ValueRef in, ValueRef repeats_ref, ValueRef out) { - const std::vector repeats = *(graph.get_int_list(repeats_ref)); - - check_args(graph, in, repeats, out); - - const std::vector in_sizes = graph.sizes_of(in); - const utils::ivec4 src_dims{ - dim_at(in_sizes), - dim_at(in_sizes), - dim_at(in_sizes), - dim_at(in_sizes)}; - const utils::ivec4 dst_repeats{ - dim_at(repeats), - dim_at(repeats), - dim_at(repeats), - dim_at(repeats)}; - std::string kernel_name = "repeat"; kernel_name.reserve(kShaderNameReserve); + add_storage_type_suffix(kernel_name, graph.storage_type_of(out)); add_dtype_suffix(kernel_name, graph.dtype_of(out)); - // A copy of range with the last element set to batch size of the input tensor - const utils::ivec3 wg_size = graph.logical_limits_of(out); - - const auto shader = VK_KERNEL_FROM_STR(kernel_name); - - graph.execute_nodes().emplace_back(new DynamicDispatchNode( - graph, - VK_KERNEL_FROM_STR(kernel_name), - default_pick_global_wg_size, - default_pick_local_wg_size, - // Inputs and Outputs - { - {out, vkapi::kWrite}, - {in, vkapi::kRead}, - }, - // Parameter buffers - {}, - // Push Constants - { - PushConstantDataInfo(&wg_size, sizeof(wg_size), sizeof(utils::ivec4)), - PushConstantDataInfo( - &src_dims, sizeof(src_dims), sizeof(utils::ivec4)), - PushConstantDataInfo( - &dst_repeats, sizeof(dst_repeats), sizeof(utils::ivec4)), - }, - // Specialization Constants - {graph.hashed_layout_of(out), graph.hashed_layout_of(in)}, - // Resize Args - {}, - // Resizing Logic - nullptr)); + if (graph.is_buffer_storage(out)) { + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + default_pick_global_wg_size, + default_pick_local_wg_size, + {{out, vkapi::kWrite}, {in, vkapi::kRead}}, + {graph.meta_ubo(out), graph.meta_ubo(in)}, + {}, + {}, + {repeats_ref}, + resize_repeat_node)); + } else { + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + default_pick_global_wg_size, + default_pick_local_wg_size, + {{out, vkapi::kWrite}, {in, vkapi::kRead}}, + {}, + {graph.logical_limits_pc_of(out), + graph.sizes_pc_of(in), + graph.sizes_pc_of(out)}, + {graph.hashed_layout_of(out), graph.hashed_layout_of(in)}, + {repeats_ref}, + resize_repeat_node)); + } } void repeat(ComputeGraph& graph, const std::vector& args) { diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py index 87a086db831..b48d7f98d98 100644 --- a/backends/vulkan/test/op_tests/cases.py +++ b/backends/vulkan/test/op_tests/cases.py @@ -1308,7 +1308,7 @@ def get_repeat_inputs(): "utils::kHeightPacked", "utils::kChannelsPacked", ] - test_suite_2d.storage_types = ["utils::kTexture3D"] + test_suite_2d.storage_types = ["utils::kTexture3D", "utils::kBuffer"] test_suite_2d.data_gen = "make_seq_tensor" test_suite_2d.dtypes = ["at::kFloat"] test_suite_2d.test_name_suffix = "2d" @@ -1353,7 +1353,7 @@ def get_repeat_inputs(): "utils::kHeightPacked", "utils::kChannelsPacked", ] - test_suite_3d.storage_types = ["utils::kTexture3D"] + test_suite_3d.storage_types = ["utils::kTexture3D", "utils::kBuffer"] test_suite_3d.data_gen = "make_seq_tensor" test_suite_3d.dtypes = ["at::kFloat"] test_suite_3d.test_name_suffix = "3d"