From b63d290730605979c14bdb8c692d2c900561b9ca Mon Sep 17 00:00:00 2001 From: ssjia Date: Tue, 10 Mar 2026 10:00:54 -0700 Subject: [PATCH] [ET-VK] Add ANY_STORAGE support to expand_copy Add a texture shader variant for expand_copy and a resize function for dynamic shape support. The texture shader maps each output texel coordinate to the corresponding input texel using modulo on the input sizes, matching the semantics of the existing buffer shader. Use meta_ubo() instead of buffer_meta_ubo() so the correct UBO type is selected based on storage type. Use extract_int_or_symint_list() for target sizes to handle symbolic integers. Register as ANY_STORAGE. Differential Revision: [D95970162](https://our.internmc.facebook.com/intern/diff/D95970162/) [ghstack-poisoned] --- backends/vulkan/op_registry.py | 2 +- .../graph/ops/glsl/expand_texture.glsl | 67 +++++++++++++++++++ .../graph/ops/glsl/expand_texture.yaml | 11 +++ .../vulkan/runtime/graph/ops/impl/Expand.cpp | 40 ++++++++--- backends/vulkan/test/op_tests/cases.py | 1 + 5 files changed, 111 insertions(+), 10 deletions(-) create mode 100644 backends/vulkan/runtime/graph/ops/glsl/expand_texture.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/expand_texture.yaml diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index efe2bc98070..f348cbbce9e 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -1085,7 +1085,7 @@ def register_gather(): @update_features(exir_ops.edge.aten.expand_copy.default) def register_expand_copy(): return OpFeatures( - inputs_storage=utils.ANY_BUFFER, + inputs_storage=utils.ANY_STORAGE, inputs_dtypes=utils.FP_INT_BOOL_T, supports_resize=False, supports_highdim=True, diff --git a/backends/vulkan/runtime/graph/ops/glsl/expand_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/expand_texture.glsl new file mode 100644 index 00000000000..d463260dc3c --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/expand_texture.glsl @@ -0,0 +1,67 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +${define_required_extensions("texture3d", DTYPE)} + +#define PRECISION ${PRECISION} + +#define VEC4_T ${texel_load_type(DTYPE, "texture3d")} +#define T ${texel_load_component_type(DTYPE, "texture3d")} + +${define_active_storage_type("texture3d")} + +layout(std430) buffer; + +#include "indexing.glslh" + +${layout_declare_tensor(B, "w", "t_outp", DTYPE, "texture3d")} +${layout_declare_tensor(B, "r", "t_inp", DTYPE, "texture3d")} + +${layout_declare_ubo(B, "TextureMetadata", "outp")} +${layout_declare_ubo(B, "TextureMetadata", "inp")} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + const ivec3 out_pos = ivec3(gl_GlobalInvocationID); + + if (out_of_bounds(out_pos, outp)) { + return; + } + + TensorIndex4D out_tidx = texture_pos_to_tensor4d_idx_simple(outp, out_pos); + + VEC4_T out_texel = VEC4_T(0); + + int limit = min( + 4, outp.sizes[outp.packed_dim] - out_tidx.data[outp.packed_dim]); + for (int comp = 0; comp < 4; comp++) { + if (comp >= limit) { + break; + } + + // Map output tensor index to input tensor index using modulo + TensorIndex4D inp_tidx; + inp_tidx.data.x = out_tidx.data.x % inp.sizes.x; + inp_tidx.data.y = out_tidx.data.y % inp.sizes.y; + inp_tidx.data.z = out_tidx.data.z % inp.sizes.z; + inp_tidx.data.w = out_tidx.data.w % inp.sizes.w; + + TextureElementIndex inp_elem = + tensor4d_idx_to_texture_element_idx_simple(inp, inp_tidx); + + VEC4_T inp_texel = texelFetch(t_inp, inp_elem.pos, 0); + out_texel[comp] = inp_texel[inp_elem.comp]; + + out_tidx.data[outp.packed_dim]++; + } + + imageStore(t_outp, out_pos, out_texel); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/expand_texture.yaml b/backends/vulkan/runtime/graph/ops/glsl/expand_texture.yaml new file mode 100644 index 00000000000..461b39b11bf --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/expand_texture.yaml @@ -0,0 +1,11 @@ +expand_texture: + parameter_names_with_default_values: + DTYPE: float + generate_variant_forall: + DTYPE: + - VALUE: half + - VALUE: float + - VALUE: int32 + - VALUE: uint8 + shader_variants: + - NAME: expand_texture3d diff --git a/backends/vulkan/runtime/graph/ops/impl/Expand.cpp b/backends/vulkan/runtime/graph/ops/impl/Expand.cpp index 1623a26b2a1..b2283630e04 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Expand.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Expand.cpp @@ -16,7 +16,33 @@ namespace vkcompute { -void add_expand_buffer_node( +void resize_expand_node( + ComputeGraph* graph, + const std::vector& args, + const std::vector& extra_args) { + const ValueRef in = args.at(1).refs.at(0); + const ValueRef out = args.at(0).refs.at(0); + const ValueRef size_ref = extra_args.at(0); + + const std::vector in_sizes = graph->sizes_of(in); + const std::vector target_sizes = + graph->extract_int_or_symint_list(size_ref); + + const size_t dim_offset = target_sizes.size() - in_sizes.size(); + std::vector out_sizes(target_sizes.size()); + for (size_t i = 0; i < target_sizes.size(); i++) { + if (target_sizes[i] == -1 && i >= dim_offset) { + out_sizes[i] = in_sizes[i - dim_offset]; + } else if (target_sizes[i] == -1) { + out_sizes[i] = 1; + } else { + out_sizes[i] = target_sizes[i]; + } + } + graph->virtual_resize(out, out_sizes); +} + +void add_expand_node( ComputeGraph& graph, const ValueRef in, const ValueRef size, @@ -27,8 +53,8 @@ void add_expand_buffer_node( add_dtype_suffix(kernel_name, graph.dtype_of(out)); vkapi::ParamsBindList param_buffers = { - graph.buffer_meta_ubo(out), - graph.buffer_meta_ubo(in), + graph.meta_ubo(out), + graph.meta_ubo(in), }; graph.execute_nodes().emplace_back(new DynamicDispatchNode( @@ -46,7 +72,7 @@ void add_expand_buffer_node( // Resize Args {size}, // Resizing Logic - nullptr)); + resize_expand_node)); } void expand(ComputeGraph& graph, const std::vector& args) { @@ -57,11 +83,7 @@ void expand(ComputeGraph& graph, const std::vector& args) { (void)implicit; const ValueRef out = args.at(idx++); - if (graph.is_buffer_storage(out)) { - return add_expand_buffer_node(graph, in, size, out); - } - - VK_THROW("Expand operator only supports buffer storage"); + add_expand_node(graph, in, size, out); } REGISTER_OPERATORS { diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py index 1c4356cab4f..e06b7f3ce6b 100644 --- a/backends/vulkan/test/op_tests/cases.py +++ b/backends/vulkan/test/op_tests/cases.py @@ -1994,6 +1994,7 @@ def get_expand_inputs(): ) test_suite.storage_types = [ "utils::kBuffer", + "utils::kTexture3D", ] test_suite.layouts = [ "utils::kWidthPacked",