Skip to content

Commit c0900db

Browse files
committed
Add webgpu backend
1 parent 1762793 commit c0900db

File tree

11 files changed

+1192
-0
lines changed

11 files changed

+1192
-0
lines changed

CMakeLists.txt

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF)
1616
option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
1717
option(MLX_BUILD_METAL "Build metal backend" ON)
1818
option(MLX_BUILD_CPU "Build cpu backend" ON)
19+
option(MLX_BUILD_WEBGPU "Build webgpu backend" OFF)
1920
option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
2021
option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
2122
option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
@@ -52,6 +53,10 @@ if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
5253
endif()
5354
endif()
5455

56+
if(MLX_BUILD_WEBGPU AND MLX_BUILD_METAL)
57+
message(FATAL_ERROR "Can not build both webgpu and metal backends.")
58+
endif()
59+
5560
else()
5661
set(MLX_BUILD_METAL OFF)
5762
message(WARNING "MLX is prioritised for Apple silicon systems using macOS.")
@@ -114,6 +119,17 @@ elseif(MLX_BUILD_METAL)
114119
target_link_libraries(mlx PUBLIC ${METAL_LIB} ${FOUNDATION_LIB} ${QUARTZ_LIB})
115120
endif()
116121

122+
if(MLX_BUILD_WEBGPU)
123+
FetchContent_Declare(
124+
betann
125+
GIT_REPOSITORY https://github.com/frost-beta/betann.git
126+
GIT_TAG db2d5c9bddb75d0d67f675a68ea79ae0fcba723e
127+
EXCLUDE_FROM_ALL)
128+
set(BETANN_BUILD_TESTS OFF)
129+
FetchContent_MakeAvailable(betann)
130+
target_link_libraries(mlx PRIVATE betann)
131+
endif()
132+
117133
if(WIN32)
118134
if(MSVC)
119135
# GGUF does not build with MSVC.

mlx/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)
4242

4343
if(MLX_BUILD_METAL)
4444
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal)
45+
elseif(MLX_BUILD_WEBGPU)
46+
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/webgpu)
4547
else()
4648
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_metal)
4749
endif()

mlx/array.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,13 @@ bool array::is_tracer() const {
132132
detail::retain_graph();
133133
}
134134

135+
void array::reset_data_ptr() {
136+
void* data_ptr = buffer().raw_ptr();
137+
auto char_offset = sizeof(char) * itemsize() * array_desc_->offset;
138+
array_desc_->data_ptr =
139+
static_cast<void*>(static_cast<char*>(data_ptr) + char_offset);
140+
}
141+
135142
void array::set_data(allocator::Buffer buffer, Deleter d) {
136143
array_desc_->data = std::make_shared<Data>(buffer, d);
137144
array_desc_->data_ptr = buffer.raw_ptr();
@@ -165,6 +172,7 @@ void array::copy_shared_buffer(
165172
array_desc_->strides = strides;
166173
array_desc_->flags = flags;
167174
array_desc_->data_size = data_size;
175+
array_desc_->offset = offset;
168176
auto char_offset = sizeof(char) * itemsize() * offset;
169177
array_desc_->data_ptr = static_cast<void*>(
170178
static_cast<char*>(other.array_desc_->data_ptr) + char_offset);
@@ -184,6 +192,7 @@ void array::move_shared_buffer(
184192
array_desc_->strides = strides;
185193
array_desc_->flags = flags;
186194
array_desc_->data_size = data_size;
195+
array_desc_->offset = offset;
187196
auto char_offset = sizeof(char) * itemsize() * offset;
188197
auto data_ptr = other.array_desc_->data_ptr;
189198
other.array_desc_->data_ptr = nullptr;

mlx/array.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,12 @@ class array {
333333
return array_desc_->data_size;
334334
}
335335

336+
/** The offset (in elements) of the underlying buffer the array points to.
337+
**/
338+
size_t offset() const {
339+
return array_desc_->offset;
340+
}
341+
336342
allocator::Buffer& buffer() {
337343
return array_desc_->data->buffer;
338344
}
@@ -413,6 +419,8 @@ class array {
413419
// Check if the array is a tracer array
414420
bool is_tracer() const;
415421

422+
void reset_data_ptr();
423+
416424
void set_data(allocator::Buffer buffer, Deleter d = allocator::free);
417425

418426
void set_data(
@@ -477,6 +485,9 @@ class array {
477485
// The size in elements of the data buffer the array accesses
478486
size_t data_size;
479487

488+
// Offset from the shared data in elements
489+
size_t offset{0};
490+
480491
// Contains useful meta data about the array
481492
Flags flags;
482493

mlx/backend/webgpu/CMakeLists.txt

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
target_sources(
2+
mlx
3+
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
4+
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
5+
${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
6+
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
7+
${CMAKE_CURRENT_SOURCE_DIR}/../no_metal/event.cpp)

mlx/backend/webgpu/allocator.cpp

Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
// Copyright © 2025 Apple Inc.
2+
3+
#include "mlx/backend/webgpu/allocator.h"
4+
5+
#include "mlx/array.h"
6+
#include "mlx/backend/webgpu/utils.h"
7+
#include "mlx/primitives.h"
8+
9+
namespace mlx::core {
10+
11+
namespace allocator {
12+
13+
Allocator& allocator() {
14+
return webgpu::allocator();
15+
}
16+
17+
void* Buffer::raw_ptr() {
18+
return static_cast<webgpu::DoubleBuffer*>(ptr_)->cpu_data();
19+
}
20+
21+
} // namespace allocator
22+
23+
namespace webgpu {
24+
25+
DoubleBuffer::DoubleBuffer(size_t size)
26+
: size_(size), cpu_data_(std::malloc(size)) {}
27+
28+
DoubleBuffer::DoubleBuffer(betann::Device& device, Dtype dtype, size_t size)
29+
: size_(size),
30+
gpu_data_(device.CreateBuffer(
31+
size * gpu_size_factor(dtype),
32+
betann::BufferUsage::Storage | betann::BufferUsage::CopySrc)) {}
33+
34+
DoubleBuffer::~DoubleBuffer() {
35+
std::free(cpu_data_);
36+
}
37+
38+
WgpuAllocator::WgpuAllocator() : device_(webgpu::device(Device::gpu)) {}
39+
40+
Buffer WgpuAllocator::malloc(size_t size, bool allow_swap) {
41+
return Buffer(new DoubleBuffer(size));
42+
}
43+
44+
void WgpuAllocator::free(Buffer buffer) {
45+
delete static_cast<DoubleBuffer*>(buffer.ptr());
46+
}
47+
48+
size_t WgpuAllocator::size(Buffer buffer) const {
49+
return static_cast<DoubleBuffer*>(buffer.ptr())->size();
50+
}
51+
52+
void WgpuAllocator::ensure_cpu_data(array& arr, const void* data) {
53+
auto* dbuf = static_cast<DoubleBuffer*>(arr.buffer().ptr());
54+
if (dbuf->cpu_data() || dbuf->size() == 0)
55+
return;
56+
void* cpu_data = std::malloc(dbuf->size());
57+
size_t num_elements = dbuf->size() / arr.itemsize();
58+
switch (arr.dtype()) {
59+
case int32:
60+
case uint32:
61+
case float16:
62+
case float32:
63+
std::memcpy(cpu_data, data, dbuf->size());
64+
break;
65+
case bool_:
66+
std::transform(
67+
static_cast<const uint32_t*>(data),
68+
static_cast<const uint32_t*>(data) + num_elements,
69+
static_cast<bool*>(cpu_data),
70+
[](uint32_t e) { return static_cast<bool>(e); });
71+
break;
72+
case uint8:
73+
std::transform(
74+
static_cast<const uint32_t*>(data),
75+
static_cast<const uint32_t*>(data) + num_elements,
76+
static_cast<uint8_t*>(cpu_data),
77+
[](uint32_t e) { return static_cast<uint8_t>(e); });
78+
break;
79+
case uint16:
80+
std::transform(
81+
static_cast<const uint32_t*>(data),
82+
static_cast<const uint32_t*>(data) + num_elements,
83+
static_cast<uint16_t*>(cpu_data),
84+
[](uint32_t e) { return static_cast<uint16_t>(e); });
85+
break;
86+
case int8:
87+
std::transform(
88+
static_cast<const int32_t*>(data),
89+
static_cast<const int32_t*>(data) + num_elements,
90+
static_cast<int8_t*>(cpu_data),
91+
[](int32_t e) { return static_cast<int8_t>(e); });
92+
break;
93+
case int16:
94+
std::transform(
95+
static_cast<const int32_t*>(data),
96+
static_cast<const int32_t*>(data) + num_elements,
97+
static_cast<int16_t*>(cpu_data),
98+
[](int32_t e) { return static_cast<int16_t>(e); });
99+
break;
100+
default:
101+
throw_unsupported_dtype_error(arr.dtype());
102+
}
103+
dbuf->set_cpu_data(cpu_data);
104+
}
105+
106+
void WgpuAllocator::ensure_gpu_data(array& arr) {
107+
auto* dbuf = static_cast<DoubleBuffer*>(arr.buffer().ptr());
108+
if (dbuf->gpu_data() || dbuf->size() == 0)
109+
return;
110+
size_t num_elements = dbuf->size() / arr.itemsize();
111+
switch (arr.dtype()) {
112+
case int32:
113+
case uint32:
114+
case float16:
115+
case float32:
116+
dbuf->set_gpu_data(
117+
device_.CreateBufferFromData(dbuf->cpu_data(), dbuf->size()));
118+
break;
119+
case bool_:
120+
dbuf->set_gpu_data(device_.CreateBufferTransformTo<uint32_t>(
121+
static_cast<bool*>(dbuf->cpu_data()), num_elements));
122+
break;
123+
case uint8:
124+
dbuf->set_gpu_data(device_.CreateBufferTransformTo<uint32_t>(
125+
static_cast<uint8_t*>(dbuf->cpu_data()), num_elements));
126+
break;
127+
case uint16:
128+
dbuf->set_gpu_data(device_.CreateBufferTransformTo<uint32_t>(
129+
static_cast<uint16_t*>(dbuf->cpu_data()), num_elements));
130+
break;
131+
case int8:
132+
dbuf->set_gpu_data(device_.CreateBufferTransformTo<int32_t>(
133+
static_cast<int8_t*>(dbuf->cpu_data()), num_elements));
134+
break;
135+
case int16:
136+
dbuf->set_gpu_data(device_.CreateBufferTransformTo<int32_t>(
137+
static_cast<int16_t*>(dbuf->cpu_data()), num_elements));
138+
break;
139+
default:
140+
throw_unsupported_dtype_error(arr.dtype());
141+
}
142+
}
143+
144+
Buffer WgpuAllocator::malloc_gpu(array& arr) {
145+
return malloc_gpu(arr, arr.nbytes());
146+
}
147+
148+
Buffer WgpuAllocator::malloc_gpu(array& arr, size_t size) {
149+
return Buffer(new DoubleBuffer(device_, arr.dtype(), size));
150+
}
151+
152+
WgpuAllocator& allocator() {
153+
static WgpuAllocator allocator_;
154+
return allocator_;
155+
}
156+
157+
betann::Device& device(mlx::core::Device) {
158+
static betann::Device device;
159+
return device;
160+
}
161+
162+
betann::Device& device(array& arr) {
163+
return device(arr.primitive().device());
164+
}
165+
166+
} // namespace webgpu
167+
168+
namespace metal {
169+
170+
size_t get_active_memory() {
171+
return 0;
172+
}
173+
size_t get_peak_memory() {
174+
return 0;
175+
}
176+
void reset_peak_memory() {}
177+
size_t get_cache_memory() {
178+
return 0;
179+
}
180+
size_t set_memory_limit(size_t, bool) {
181+
return 0;
182+
}
183+
size_t set_cache_limit(size_t) {
184+
return 0;
185+
}
186+
size_t set_wired_limit(size_t) {
187+
return 0;
188+
}
189+
190+
std::unordered_map<std::string, std::variant<std::string, size_t>>
191+
device_info() {
192+
throw std::runtime_error("[webgpu::device_info] Not implemented");
193+
};
194+
195+
void clear_cache() {}
196+
197+
} // namespace metal
198+
199+
} // namespace mlx::core

mlx/backend/webgpu/allocator.h

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
// Copyright © 2025 Apple Inc.
2+
3+
#pragma once
4+
5+
#include "mlx/allocator.h"
6+
#include "mlx/device.h"
7+
8+
#include <betann/betann.h>
9+
10+
namespace mlx::core {
11+
class array;
12+
struct Dtype;
13+
} // namespace mlx::core
14+
15+
namespace mlx::core::webgpu {
16+
17+
using allocator::Buffer;
18+
19+
// Holds data for both CPU and GPU.
20+
class DoubleBuffer {
21+
public:
22+
// Allocates memory in CPU.
23+
explicit DoubleBuffer(size_t size);
24+
// Allocates memory in GPU.
25+
DoubleBuffer(betann::Device& device, Dtype dtype, size_t size);
26+
27+
~DoubleBuffer();
28+
29+
void set_cpu_data(void* data) {
30+
assert(!cpu_data_);
31+
cpu_data_ = data;
32+
}
33+
void set_gpu_data(betann::Buffer buffer) {
34+
gpu_data_ = std::move(buffer);
35+
}
36+
37+
void* cpu_data() const {
38+
return cpu_data_;
39+
}
40+
const betann::Buffer& gpu_data() const {
41+
return gpu_data_;
42+
}
43+
44+
size_t size() const {
45+
return size_;
46+
}
47+
48+
private:
49+
size_t size_;
50+
void* cpu_data_ = nullptr;
51+
betann::Buffer gpu_data_;
52+
};
53+
54+
class WgpuAllocator : public allocator::Allocator {
55+
public:
56+
Buffer malloc(size_t size, bool allow_swap) override;
57+
void free(Buffer buffer) override;
58+
size_t size(Buffer buffer) const override;
59+
60+
void ensure_cpu_data(array& arr, const void* data);
61+
void ensure_gpu_data(array& arr);
62+
Buffer malloc_gpu(array& arr);
63+
Buffer malloc_gpu(array& arr, size_t size);
64+
65+
private:
66+
WgpuAllocator();
67+
friend WgpuAllocator& allocator();
68+
69+
betann::Device& device_;
70+
};
71+
72+
WgpuAllocator& allocator();
73+
74+
betann::Device& device(mlx::core::Device);
75+
betann::Device& device(array& arr);
76+
77+
} // namespace mlx::core::webgpu

0 commit comments

Comments
 (0)