Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/webgpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ set(WEBGPU_SRCS
runtime/ops/slice/Slice.cpp
runtime/ops/permute/Permute.cpp
runtime/ops/cat/Cat.cpp
runtime/ops/index/Index.cpp
)

add_library(webgpu_backend ${WEBGPU_SRCS})
Expand Down
188 changes: 188 additions & 0 deletions backends/webgpu/runtime/ops/index/Index.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/webgpu/runtime/WebGPUGraph.h>
#include <executorch/backends/webgpu/runtime/WebGPUUtils.h>
#include <executorch/backends/webgpu/runtime/ops/OperatorRegistry.h>
#include <executorch/backends/webgpu/runtime/ops/index/index_wgsl.h>

#include <webgpu/webgpu.h>

#include <cstdint>
#include <cstring>
#include <stdexcept>
#include <vector>

namespace executorch::backends::webgpu {

namespace {

// Host-side image of the uniform block bound at @binding(3) of the index
// shader. It is uploaded byte-for-byte (memcpy of sizeof(IndexParams)), so the
// layout below is part of the CPU/GPU contract.
struct IndexParams {
  // Number of output elements; shader invocations at or beyond this exit early.
  uint32_t numel;
  // Unused by the shader; keeps the uploaded block at 16 bytes.
  uint32_t _pad[3]; // pad to 16 bytes
};

// Fail the build if the struct ever stops being the 16-byte block that the
// buffer descriptor and bind-group entry size are derived from.
static_assert(
    sizeof(IndexParams) == 16,
    "IndexParams must be exactly 16 bytes");

// aten.index.Tensor 1D-self gather out[i]=self[index[i]] (mirrors Vulkan).
void index_impl(WebGPUGraph& graph, const std::vector<int>& args) {
// args: [self, indices (Tensor?[] -> ValueList), out].
const int self_id = args.at(0);
const int list_id = args.at(1);
const int out_id = args.at(args.size() - 1);

if (graph.get_value_type(self_id) != WebGPUGraph::ValueType::Tensor) {
throw std::runtime_error("index: self arg is not a tensor");
}
if (graph.get_value_type(out_id) != WebGPUGraph::ValueType::Tensor) {
throw std::runtime_error("index: out arg is not a tensor");
}
if (graph.get_value_type(list_id) != WebGPUGraph::ValueType::ValueList) {
throw std::runtime_error("index: indices arg is not a ValueList");
}

// Exactly one non-Null index tensor (mirror Vulkan IndexTensor.cpp:67-69).
const std::vector<int>& ids = graph.get_value_list(list_id);
int index_id = -1;
for (int id : ids) {
if (graph.get_value_type(id) == WebGPUGraph::ValueType::Null) {
continue;
}
if (graph.get_value_type(id) != WebGPUGraph::ValueType::Tensor) {
throw std::runtime_error("index: index list element is not a tensor");
}
if (index_id != -1) {
throw std::runtime_error("index: expected exactly one index tensor");
}
index_id = id;
}
if (index_id == -1) {
throw std::runtime_error("index: no index tensor provided");
}

WGPUDevice device = graph.device();

const auto& self_tensor = graph.get_tensor(self_id);
const auto& index_tensor = graph.get_tensor(index_id);
const auto& out_tensor = graph.get_tensor(out_id);

const size_t out_numel = out_tensor.nbytes / sizeof(float);
if (out_tensor.nbytes != out_numel * sizeof(float) ||
self_tensor.nbytes % sizeof(float) != 0) {
throw std::runtime_error("index: non-fp32 self/out (nbytes != numel * 4)");
}
// Index is the int32 downcast of the int64 advanced index (downcast_64_bit).
const size_t index_numel = index_tensor.nbytes / sizeof(int32_t);
if (index_tensor.nbytes != index_numel * sizeof(int32_t)) {
throw std::runtime_error("index: index buffer is not int32 (nbytes % 4)");
}
// out is one self element per index element (row_width == 1, 1D self).
if (out_numel != index_numel) {
throw std::runtime_error("index: out numel != index numel");
}

uint32_t num_elements = static_cast<uint32_t>(out_numel);
uint32_t wg_size = utils::clamp_workgroup_size(device, kIndexWorkgroupSizeX);
uint32_t workgroup_count =
utils::compute_1d_workgroup_count(device, num_elements, wg_size, "index");

WGPUConstantEntry wg_size_constant = {};
wg_size_constant.key = {"wg_size", WGPU_STRLEN};
wg_size_constant.value = static_cast<double>(wg_size);

IndexParams params = {};
params.numel = num_elements;

WGPUBufferDescriptor uniform_desc = {};
uniform_desc.size = sizeof(IndexParams);
uniform_desc.usage = WGPUBufferUsage_Uniform | WGPUBufferUsage_CopyDst;
uniform_desc.mappedAtCreation = true;
WGPUBuffer uniform_buffer = wgpuDeviceCreateBuffer(device, &uniform_desc);
std::memcpy(
wgpuBufferGetMappedRange(uniform_buffer, 0, sizeof(IndexParams)),
&params,
sizeof(IndexParams));
wgpuBufferUnmap(uniform_buffer);
graph.add_uniform_buffer_bytes(sizeof(IndexParams));

WGPUShaderSourceWGSL wgsl_desc = {};
wgsl_desc.chain.sType = WGPUSType_ShaderSourceWGSL;
wgsl_desc.code = {kIndexWGSL, WGPU_STRLEN};
WGPUShaderModuleDescriptor shader_desc = {};
shader_desc.nextInChain = &wgsl_desc.chain;
WGPUShaderModule shader = wgpuDeviceCreateShaderModule(device, &shader_desc);

// self (read), out (read_write), index (read i32), params (uniform).
WGPUBindGroupLayoutEntry entries[4] = {};
entries[0].binding = 0;
entries[0].visibility = WGPUShaderStage_Compute;
entries[0].buffer.type = WGPUBufferBindingType_ReadOnlyStorage;
entries[1].binding = 1;
entries[1].visibility = WGPUShaderStage_Compute;
entries[1].buffer.type = WGPUBufferBindingType_Storage;
entries[2].binding = 2;
entries[2].visibility = WGPUShaderStage_Compute;
entries[2].buffer.type = WGPUBufferBindingType_ReadOnlyStorage;
entries[3].binding = 3;
entries[3].visibility = WGPUShaderStage_Compute;
entries[3].buffer.type = WGPUBufferBindingType_Uniform;

WGPUBindGroupLayoutDescriptor bgl_desc = {};
bgl_desc.entryCount = 4;
bgl_desc.entries = entries;
WGPUBindGroupLayout bgl = wgpuDeviceCreateBindGroupLayout(device, &bgl_desc);

WGPUPipelineLayoutDescriptor pl_desc = {};
pl_desc.bindGroupLayoutCount = 1;
pl_desc.bindGroupLayouts = &bgl;
WGPUPipelineLayout pipeline_layout =
wgpuDeviceCreatePipelineLayout(device, &pl_desc);

WGPUComputePipelineDescriptor pipeline_desc = {};
pipeline_desc.layout = pipeline_layout;
pipeline_desc.compute.module = shader;
pipeline_desc.compute.entryPoint = {"main", WGPU_STRLEN};
pipeline_desc.compute.constantCount = 1;
pipeline_desc.compute.constants = &wg_size_constant;
WGPUComputePipeline pipeline =
wgpuDeviceCreateComputePipeline(device, &pipeline_desc);

WGPUBindGroupEntry bg_entries[4] = {};
bg_entries[0].binding = 0;
bg_entries[0].buffer = self_tensor.buffer;
bg_entries[0].size = self_tensor.nbytes;
bg_entries[1].binding = 1;
bg_entries[1].buffer = out_tensor.buffer;
bg_entries[1].size = out_tensor.nbytes;
bg_entries[2].binding = 2;
bg_entries[2].buffer = index_tensor.buffer;
bg_entries[2].size = index_tensor.nbytes;
bg_entries[3].binding = 3;
bg_entries[3].buffer = uniform_buffer;
bg_entries[3].size = sizeof(IndexParams);

WGPUBindGroupDescriptor bg_desc = {};
bg_desc.layout = bgl;
bg_desc.entryCount = 4;
bg_desc.entries = bg_entries;
WGPUBindGroup bind_group = wgpuDeviceCreateBindGroup(device, &bg_desc);

graph.add_dispatch({pipeline, bind_group, workgroup_count});

wgpuShaderModuleRelease(shader);
wgpuBindGroupLayoutRelease(bgl);
wgpuPipelineLayoutRelease(pipeline_layout);
// The bind group keeps the uniform buffer alive until release.
wgpuBufferRelease(uniform_buffer);
}

} // namespace

// Associates the operator name aten.index.Tensor with index_impl.
// NOTE(review): both macros come from OperatorRegistry.h, which is not visible
// here; when and how the registration body runs should be confirmed there.
WEBGPU_REGISTER_OPERATORS {
WEBGPU_REGISTER_OP(aten.index.Tensor, index_impl);
}

} // namespace executorch::backends::webgpu
22 changes: 22 additions & 0 deletions backends/webgpu/runtime/ops/index/index.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// Gather kernel: each invocation writes one element,
// output[k] = input[index[k]].
//
// Bindings (group 0):
//   0: input  - source values, f32, read-only storage
//   1: output - destination values, f32, read-write storage
//   2: index  - gather positions, i32, read-only storage
//   3: params - uniform block holding the output element count
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;
@group(0) @binding(2) var<storage, read> index: array<i32>;

// numel is the upper bound on the invocation id that may write to output.
struct Params {
numel: u32,
}
@group(0) @binding(3) var<uniform> params: Params;

// Workgroup size along X; a pipeline-overridable constant defaulting to 64.
override wg_size: u32 = 64;

@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let out_bufi = gid.x;
// Invocations past the end of the output do nothing.
if (out_bufi >= params.numel) {
return;
}

// 1D-self gather out[i]=self[index[i]] (mirrors Vulkan index_tensor_buffer.glsl).
// NOTE(review): the i32 index is converted with u32() and used directly; no
// wrap-around for negative values and no explicit range check happens here.
// Confirm that callers guarantee 0 <= index < length of input.
let i = index[out_bufi];
output[out_bufi] = input[u32(i)];
}
46 changes: 46 additions & 0 deletions backends/webgpu/runtime/ops/index/index_wgsl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <cstdint>

namespace executorch::backends::webgpu {

// @generated from index.wgsl - DO NOT EDIT.
// wgsl-sha256: daed48e60bfcf2b7420d277576d794137d3bff383aef4f68464c98c8a7235c8e
// WGSL source of the index gather compute shader, embedded as a raw string
// literal. Entry point is `main`; group 0 bindings 0-3 are input (f32),
// output (f32), index (i32) and the Params uniform. The text between the
// delimiters is the shader itself, so it must not be reformatted by hand.
inline constexpr const char* kIndexWGSL = R"(
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;
@group(0) @binding(2) var<storage, read> index: array<i32>;

struct Params {
numel: u32,
}
@group(0) @binding(3) var<uniform> params: Params;

override wg_size: u32 = 64;

@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let out_bufi = gid.x;
if (out_bufi >= params.numel) {
return;
}

// 1D-self gather out[i]=self[index[i]] (mirrors Vulkan index_tensor_buffer.glsl).
let i = index[out_bufi];
output[out_bufi] = input[u32(i)];
}
)";

// Workgroup dimensions declared for the index shader. X equals the default of
// the shader's `wg_size` override (64); Y and Z are 1.
inline constexpr uint32_t kIndexWorkgroupSizeX = 64;
inline constexpr uint32_t kIndexWorkgroupSizeY = 1;
inline constexpr uint32_t kIndexWorkgroupSizeZ = 1;

} // namespace executorch::backends::webgpu
Loading