Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion backends/vulkan/op_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -1085,7 +1085,7 @@ def register_gather():
@update_features(exir_ops.edge.aten.expand_copy.default)
def register_expand_copy():
return OpFeatures(
inputs_storage=utils.ANY_BUFFER,
inputs_storage=utils.ANY_STORAGE,
inputs_dtypes=utils.FP_INT_BOOL_T,
supports_resize=False,
supports_highdim=True,
Expand Down
67 changes: 67 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/expand_texture.glsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#version 450 core

// NOTE: the ${...} expressions in this file are template placeholders; they
// are presumably expanded by the shader codegen before GLSL compilation
// (TODO confirm against the build tooling).
${define_required_extensions("texture3d", DTYPE)}

#define PRECISION ${PRECISION}

// Presumably the 4-component texel type for DTYPE (VEC4_T) and its scalar
// component type (T). T is not referenced by main() in this file.
#define VEC4_T ${texel_load_type(DTYPE, "texture3d")}
#define T ${texel_load_component_type(DTYPE, "texture3d")}

${define_active_storage_type("texture3d")}

layout(std430) buffer;

#include "indexing.glslh"

// Tensor bindings: t_outp is written with imageStore() and t_inp is read with
// texelFetch() in main().
${layout_declare_tensor(B, "w", "t_outp", DTYPE, "texture3d")}
${layout_declare_tensor(B, "r", "t_inp", DTYPE, "texture3d")}

// Per-tensor metadata blocks; main() reads `sizes` and `packed_dim` from them.
${layout_declare_ubo(B, "TextureMetadata", "outp")}
${layout_declare_ubo(B, "TextureMetadata", "inp")}

// Workgroup dimensions are supplied through specialization constants 0, 1, 2.
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

/*
 * Produces one output texel per invocation.
 *
 * Each component of the output texel corresponds to one element along the
 * output's packed dimension. For every such element the matching input element
 * is located by taking each tensor coordinate modulo the input size in that
 * dimension, and the value is copied into the output texel. Components that
 * fall past the end of the packed dimension are left at zero.
 */
void main() {
  const ivec3 out_pos = ivec3(gl_GlobalInvocationID);
  if (out_of_bounds(out_pos, outp)) {
    return;
  }

  // Tensor index of the first element stored in this output texel.
  TensorIndex4D dst_idx = texture_pos_to_tensor4d_idx_simple(outp, out_pos);

  // Number of components of this texel that map to real tensor elements:
  // at most 4, fewer when the texel straddles the end of the packed dimension.
  const int valid_comps =
      min(4, outp.sizes[outp.packed_dim] - dst_idx.data[outp.packed_dim]);

  VEC4_T result = VEC4_T(0);

  for (int c = 0; c < valid_comps; c++) {
    // Wrap every output coordinate into the input's extent.
    TensorIndex4D src_idx;
    src_idx.data.x = dst_idx.data.x % inp.sizes.x;
    src_idx.data.y = dst_idx.data.y % inp.sizes.y;
    src_idx.data.z = dst_idx.data.z % inp.sizes.z;
    src_idx.data.w = dst_idx.data.w % inp.sizes.w;

    // Locate the input texel and the component within it that holds the
    // source element.
    TextureElementIndex src_elem =
        tensor4d_idx_to_texture_element_idx_simple(inp, src_idx);
    VEC4_T src_texel = texelFetch(t_inp, src_elem.pos, 0);

    result[c] = src_texel[src_elem.comp];

    // Advance to the next element along the packed dimension.
    dst_idx.data[outp.packed_dim]++;
  }

  imageStore(t_outp, out_pos, result);
}
11 changes: 11 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/expand_texture.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Variant configuration for the expand_texture shader.
# NOTE(review): nesting is not visible in this view; DTYPE appears to default
# to float, and the VALUE entries under generate_variant_forall are presumably
# each built as a separate variant of expand_texture3d — confirm against the
# shader codegen.
expand_texture:
parameter_names_with_default_values:
DTYPE: float
generate_variant_forall:
DTYPE:
- VALUE: half
- VALUE: float
- VALUE: int32
- VALUE: uint8
shader_variants:
- NAME: expand_texture3d
40 changes: 31 additions & 9 deletions backends/vulkan/runtime/graph/ops/impl/Expand.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,33 @@

namespace vkcompute {

void add_expand_buffer_node(
/**
 * Resize callback for the expand node: recomputes the output tensor's sizes
 * from the requested size list and applies them via virtual_resize().
 *
 * @param graph      Compute graph that owns the tensors involved.
 * @param args       Argument groups; the first ref of group 0 is the output
 *                   and the first ref of group 1 is the input.
 * @param extra_args Resize arguments; element 0 refers to the requested size
 *                   list (int or symint list).
 *
 * A requested size of -1 keeps the input's size for that dimension. When the
 * -1 sits in a leading dimension that the input does not have, the output
 * size for that dimension is 1.
 */
void resize_expand_node(
    ComputeGraph* graph,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& extra_args) {
  const ValueRef in = args.at(1).refs.at(0);
  const ValueRef out = args.at(0).refs.at(0);
  const ValueRef size_ref = extra_args.at(0);

  const std::vector<int64_t> input_dims = graph->sizes_of(in);
  const std::vector<int64_t> requested_dims =
      graph->extract_int_or_symint_list(size_ref);

  const size_t rank = requested_dims.size();
  // Number of leading output dimensions that have no input counterpart.
  const size_t leading = rank - input_dims.size();

  std::vector<int64_t> resolved_dims;
  resolved_dims.reserve(rank);
  for (size_t d = 0; d < rank; ++d) {
    const int64_t requested = requested_dims[d];
    if (requested != -1) {
      // Explicit size: use it as given.
      resolved_dims.push_back(requested);
      continue;
    }
    // -1 placeholder: inherit from the aligned input dimension if there is
    // one, otherwise fall back to a size of 1.
    resolved_dims.push_back(d >= leading ? input_dims[d - leading] : 1);
  }

  graph->virtual_resize(out, resolved_dims);
}

void add_expand_node(
ComputeGraph& graph,
const ValueRef in,
const ValueRef size,
Expand All @@ -27,8 +53,8 @@ void add_expand_buffer_node(
add_dtype_suffix(kernel_name, graph.dtype_of(out));

vkapi::ParamsBindList param_buffers = {
graph.buffer_meta_ubo(out),
graph.buffer_meta_ubo(in),
graph.meta_ubo(out),
graph.meta_ubo(in),
};

graph.execute_nodes().emplace_back(new DynamicDispatchNode(
Expand All @@ -46,7 +72,7 @@ void add_expand_buffer_node(
// Resize Args
{size},
// Resizing Logic
nullptr));
resize_expand_node));
}

void expand(ComputeGraph& graph, const std::vector<ValueRef>& args) {
Expand All @@ -57,11 +83,7 @@ void expand(ComputeGraph& graph, const std::vector<ValueRef>& args) {
(void)implicit;
const ValueRef out = args.at(idx++);

if (graph.is_buffer_storage(out)) {
return add_expand_buffer_node(graph, in, size, out);
}

VK_THROW("Expand operator only supports buffer storage");
add_expand_node(graph, in, size, out);
}

REGISTER_OPERATORS {
Expand Down
1 change: 1 addition & 0 deletions backends/vulkan/test/op_tests/cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -1994,6 +1994,7 @@ def get_expand_inputs():
)
test_suite.storage_types = [
"utils::kBuffer",
"utils::kTexture3D",
]
test_suite.layouts = [
"utils::kWidthPacked",
Expand Down
Loading