# Patch summary: adds an aten.index.Tensor operator to the WebGPU backend.
#   - CMakeLists.txt: appends runtime/ops/index/Index.cpp to WEBGPU_SRCS.
#   - Index.cpp: index_impl() validates the argument kinds and byte sizes,
#     builds a 4-binding compute pipeline and queues a single dispatch.
#   - index.wgsl / index_wgsl.h: shader computing out[i] = self[index[i]].
#
# NOTE(review): this copy of the patch was recovered from a whitespace-
# collapsed paste. Line breaks were rebuilt from the '+' markers (the rebuilt
# line counts match the hunk headers); indentation is reconstructed and the
# CMake context indentation in particular should be checked against the tree.
# Restored tokens, to be confirmed against the original commit:
#   - "&params" in the std::memcpy call (had been mangled to a pilcrow);
#   - template / angle-bracket arguments: std::vector<int>,
#     static_cast<uint32_t>, static_cast<double>, and the WGSL
#     var<storage, ...> / var<uniform> / array<...> / vec3<u32> forms.
# NOT restored: the #include targets (nine in Index.cpp, one in index_wgsl.h)
# were lost and cannot be recovered from this text -- fill them in from the
# original commit before applying.
#
# NOTE(review): points to confirm with the author (not changed here):
#   - the shader reads input[u32(i)] with no range check, so negative or
#     too-large indices are not handled explicitly;
#   - the "1D self" assumption is documented but only byte sizes are checked
#     on the host side;
#   - out_numel is narrowed to uint32_t before compute_1d_workgroup_count()
#     sees it.
diff --git a/backends/webgpu/CMakeLists.txt b/backends/webgpu/CMakeLists.txt
index 50aea6a469c..cbe36dade70 100644
--- a/backends/webgpu/CMakeLists.txt
+++ b/backends/webgpu/CMakeLists.txt
@@ -50,6 +50,7 @@ set(WEBGPU_SRCS
     runtime/ops/slice/Slice.cpp
     runtime/ops/permute/Permute.cpp
     runtime/ops/cat/Cat.cpp
+    runtime/ops/index/Index.cpp
 )
 
 add_library(webgpu_backend ${WEBGPU_SRCS})
diff --git a/backends/webgpu/runtime/ops/index/Index.cpp b/backends/webgpu/runtime/ops/index/Index.cpp
new file mode 100644
index 00000000000..04ebcef3a4f
--- /dev/null
+++ b/backends/webgpu/runtime/ops/index/Index.cpp
@@ -0,0 +1,188 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include
+#include
+#include
+#include
+
+#include
+
+#include
+#include
+#include
+#include
+
+namespace executorch::backends::webgpu {
+
+namespace {
+
+struct IndexParams {
+  uint32_t numel;
+  uint32_t _pad[3]; // pad to 16 bytes
+};
+
+// aten.index.Tensor 1D-self gather out[i]=self[index[i]] (mirrors Vulkan).
+void index_impl(WebGPUGraph& graph, const std::vector<int>& args) {
+  // args: [self, indices (Tensor?[] -> ValueList), out].
+  const int self_id = args.at(0);
+  const int list_id = args.at(1);
+  const int out_id = args.at(args.size() - 1);
+
+  if (graph.get_value_type(self_id) != WebGPUGraph::ValueType::Tensor) {
+    throw std::runtime_error("index: self arg is not a tensor");
+  }
+  if (graph.get_value_type(out_id) != WebGPUGraph::ValueType::Tensor) {
+    throw std::runtime_error("index: out arg is not a tensor");
+  }
+  if (graph.get_value_type(list_id) != WebGPUGraph::ValueType::ValueList) {
+    throw std::runtime_error("index: indices arg is not a ValueList");
+  }
+
+  // Exactly one non-Null index tensor (mirror Vulkan IndexTensor.cpp:67-69).
+  const std::vector<int>& ids = graph.get_value_list(list_id);
+  int index_id = -1;
+  for (int id : ids) {
+    if (graph.get_value_type(id) == WebGPUGraph::ValueType::Null) {
+      continue;
+    }
+    if (graph.get_value_type(id) != WebGPUGraph::ValueType::Tensor) {
+      throw std::runtime_error("index: index list element is not a tensor");
+    }
+    if (index_id != -1) {
+      throw std::runtime_error("index: expected exactly one index tensor");
+    }
+    index_id = id;
+  }
+  if (index_id == -1) {
+    throw std::runtime_error("index: no index tensor provided");
+  }
+
+  WGPUDevice device = graph.device();
+
+  const auto& self_tensor = graph.get_tensor(self_id);
+  const auto& index_tensor = graph.get_tensor(index_id);
+  const auto& out_tensor = graph.get_tensor(out_id);
+
+  const size_t out_numel = out_tensor.nbytes / sizeof(float);
+  if (out_tensor.nbytes != out_numel * sizeof(float) ||
+      self_tensor.nbytes % sizeof(float) != 0) {
+    throw std::runtime_error("index: non-fp32 self/out (nbytes != numel * 4)");
+  }
+  // Index is the int32 downcast of the int64 advanced index (downcast_64_bit).
+  const size_t index_numel = index_tensor.nbytes / sizeof(int32_t);
+  if (index_tensor.nbytes != index_numel * sizeof(int32_t)) {
+    throw std::runtime_error("index: index buffer is not int32 (nbytes % 4)");
+  }
+  // out is one self element per index element (row_width == 1, 1D self).
+  if (out_numel != index_numel) {
+    throw std::runtime_error("index: out numel != index numel");
+  }
+
+  uint32_t num_elements = static_cast<uint32_t>(out_numel);
+  uint32_t wg_size = utils::clamp_workgroup_size(device, kIndexWorkgroupSizeX);
+  uint32_t workgroup_count =
+      utils::compute_1d_workgroup_count(device, num_elements, wg_size, "index");
+
+  WGPUConstantEntry wg_size_constant = {};
+  wg_size_constant.key = {"wg_size", WGPU_STRLEN};
+  wg_size_constant.value = static_cast<double>(wg_size);
+
+  IndexParams params = {};
+  params.numel = num_elements;
+
+  WGPUBufferDescriptor uniform_desc = {};
+  uniform_desc.size = sizeof(IndexParams);
+  uniform_desc.usage = WGPUBufferUsage_Uniform | WGPUBufferUsage_CopyDst;
+  uniform_desc.mappedAtCreation = true;
+  WGPUBuffer uniform_buffer = wgpuDeviceCreateBuffer(device, &uniform_desc);
+  std::memcpy(
+      wgpuBufferGetMappedRange(uniform_buffer, 0, sizeof(IndexParams)),
+      &params,
+      sizeof(IndexParams));
+  wgpuBufferUnmap(uniform_buffer);
+  graph.add_uniform_buffer_bytes(sizeof(IndexParams));
+
+  WGPUShaderSourceWGSL wgsl_desc = {};
+  wgsl_desc.chain.sType = WGPUSType_ShaderSourceWGSL;
+  wgsl_desc.code = {kIndexWGSL, WGPU_STRLEN};
+  WGPUShaderModuleDescriptor shader_desc = {};
+  shader_desc.nextInChain = &wgsl_desc.chain;
+  WGPUShaderModule shader = wgpuDeviceCreateShaderModule(device, &shader_desc);
+
+  // self (read), out (read_write), index (read i32), params (uniform).
+  WGPUBindGroupLayoutEntry entries[4] = {};
+  entries[0].binding = 0;
+  entries[0].visibility = WGPUShaderStage_Compute;
+  entries[0].buffer.type = WGPUBufferBindingType_ReadOnlyStorage;
+  entries[1].binding = 1;
+  entries[1].visibility = WGPUShaderStage_Compute;
+  entries[1].buffer.type = WGPUBufferBindingType_Storage;
+  entries[2].binding = 2;
+  entries[2].visibility = WGPUShaderStage_Compute;
+  entries[2].buffer.type = WGPUBufferBindingType_ReadOnlyStorage;
+  entries[3].binding = 3;
+  entries[3].visibility = WGPUShaderStage_Compute;
+  entries[3].buffer.type = WGPUBufferBindingType_Uniform;
+
+  WGPUBindGroupLayoutDescriptor bgl_desc = {};
+  bgl_desc.entryCount = 4;
+  bgl_desc.entries = entries;
+  WGPUBindGroupLayout bgl = wgpuDeviceCreateBindGroupLayout(device, &bgl_desc);
+
+  WGPUPipelineLayoutDescriptor pl_desc = {};
+  pl_desc.bindGroupLayoutCount = 1;
+  pl_desc.bindGroupLayouts = &bgl;
+  WGPUPipelineLayout pipeline_layout =
+      wgpuDeviceCreatePipelineLayout(device, &pl_desc);
+
+  WGPUComputePipelineDescriptor pipeline_desc = {};
+  pipeline_desc.layout = pipeline_layout;
+  pipeline_desc.compute.module = shader;
+  pipeline_desc.compute.entryPoint = {"main", WGPU_STRLEN};
+  pipeline_desc.compute.constantCount = 1;
+  pipeline_desc.compute.constants = &wg_size_constant;
+  WGPUComputePipeline pipeline =
+      wgpuDeviceCreateComputePipeline(device, &pipeline_desc);
+
+  WGPUBindGroupEntry bg_entries[4] = {};
+  bg_entries[0].binding = 0;
+  bg_entries[0].buffer = self_tensor.buffer;
+  bg_entries[0].size = self_tensor.nbytes;
+  bg_entries[1].binding = 1;
+  bg_entries[1].buffer = out_tensor.buffer;
+  bg_entries[1].size = out_tensor.nbytes;
+  bg_entries[2].binding = 2;
+  bg_entries[2].buffer = index_tensor.buffer;
+  bg_entries[2].size = index_tensor.nbytes;
+  bg_entries[3].binding = 3;
+  bg_entries[3].buffer = uniform_buffer;
+  bg_entries[3].size = sizeof(IndexParams);
+
+  WGPUBindGroupDescriptor bg_desc = {};
+  bg_desc.layout = bgl;
+  bg_desc.entryCount = 4;
+  bg_desc.entries = bg_entries;
+  WGPUBindGroup bind_group = wgpuDeviceCreateBindGroup(device, &bg_desc);
+
+  graph.add_dispatch({pipeline, bind_group, workgroup_count});
+
+  wgpuShaderModuleRelease(shader);
+  wgpuBindGroupLayoutRelease(bgl);
+  wgpuPipelineLayoutRelease(pipeline_layout);
+  // The bind group keeps the uniform buffer alive until release.
+  wgpuBufferRelease(uniform_buffer);
+}
+
+} // namespace
+
+WEBGPU_REGISTER_OPERATORS {
+  WEBGPU_REGISTER_OP(aten.index.Tensor, index_impl);
+}
+
+} // namespace executorch::backends::webgpu
diff --git a/backends/webgpu/runtime/ops/index/index.wgsl b/backends/webgpu/runtime/ops/index/index.wgsl
new file mode 100644
index 00000000000..b0fd6df81bf
--- /dev/null
+++ b/backends/webgpu/runtime/ops/index/index.wgsl
@@ -0,0 +1,22 @@
+@group(0) @binding(0) var<storage, read> input: array<f32>;
+@group(0) @binding(1) var<storage, read_write> output: array<f32>;
+@group(0) @binding(2) var<storage, read> index: array<i32>;
+
+struct Params {
+  numel: u32,
+}
+@group(0) @binding(3) var<uniform> params: Params;
+
+override wg_size: u32 = 64;
+
+@compute @workgroup_size(wg_size)
+fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
+  let out_bufi = gid.x;
+  if (out_bufi >= params.numel) {
+    return;
+  }
+
+  // 1D-self gather out[i]=self[index[i]] (mirrors Vulkan index_tensor_buffer.glsl).
+  let i = index[out_bufi];
+  output[out_bufi] = input[u32(i)];
+}
diff --git a/backends/webgpu/runtime/ops/index/index_wgsl.h b/backends/webgpu/runtime/ops/index/index_wgsl.h
new file mode 100644
index 00000000000..839a3b164bb
--- /dev/null
+++ b/backends/webgpu/runtime/ops/index/index_wgsl.h
@@ -0,0 +1,46 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include
+
+namespace executorch::backends::webgpu {
+
+// @generated from index.wgsl - DO NOT EDIT.
+// wgsl-sha256: daed48e60bfcf2b7420d277576d794137d3bff383aef4f68464c98c8a7235c8e
+inline constexpr const char* kIndexWGSL = R"(
+@group(0) @binding(0) var<storage, read> input: array<f32>;
+@group(0) @binding(1) var<storage, read_write> output: array<f32>;
+@group(0) @binding(2) var<storage, read> index: array<i32>;
+
+struct Params {
+  numel: u32,
+}
+@group(0) @binding(3) var<uniform> params: Params;
+
+override wg_size: u32 = 64;
+
+@compute @workgroup_size(wg_size)
+fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
+  let out_bufi = gid.x;
+  if (out_bufi >= params.numel) {
+    return;
+  }
+
+  // 1D-self gather out[i]=self[index[i]] (mirrors Vulkan index_tensor_buffer.glsl).
+  let i = index[out_bufi];
+  output[out_bufi] = input[u32(i)];
+}
+)";
+
+inline constexpr uint32_t kIndexWorkgroupSizeX = 64;
+inline constexpr uint32_t kIndexWorkgroupSizeY = 1;
+inline constexpr uint32_t kIndexWorkgroupSizeZ = 1;
+
+} // namespace executorch::backends::webgpu