Skip to content

Commit 86cd951

Browse files
committed
Add webgpu backend
1 parent 2d8e667 commit 86cd951

File tree

14 files changed

+1032
-10
lines changed

14 files changed

+1032
-10
lines changed

CMakeLists.txt

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF)
1616
option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
1717
option(MLX_BUILD_METAL "Build metal backend" ON)
1818
option(MLX_BUILD_CPU "Build cpu backend" ON)
19+
option(MLX_BUILD_WEBGPU "Build webgpu backend" OFF)
1920
option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
2021
option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
2122
option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
@@ -52,6 +53,10 @@ if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
5253
endif()
5354
endif()
5455

56+
# The WebGPU and Metal backends are mutually exclusive: abort configuration
# when both options are enabled.
if(MLX_BUILD_WEBGPU AND MLX_BUILD_METAL)
57+
message(FATAL_ERROR "Can not build both webgpu and metal backends.")
58+
endif()
59+
5560
else()
5661
set(MLX_BUILD_METAL OFF)
5762
message(WARNING "MLX is prioritised for Apple silicon systems using macOS.")
@@ -114,6 +119,17 @@ elseif(MLX_BUILD_METAL)
114119
target_link_libraries(mlx PUBLIC ${METAL_LIB} ${FOUNDATION_LIB} ${QUARTZ_LIB})
115120
endif()
116121

122+
# When the WebGPU backend is enabled, fetch the betann library at a pinned
# commit and link it privately into the mlx target.
if(MLX_BUILD_WEBGPU)
123+
FetchContent_Declare(
124+
betann
125+
GIT_REPOSITORY https://github.com/frost-beta/betann.git
126+
GIT_TAG db8cb414c81d05cbdb6827637733f6087e4d2049
127+
EXCLUDE_FROM_ALL)
128+
# Presumably turns off building betann's own tests - confirm against betann's
# CMake options.
set(BETANN_BUILD_TESTS OFF)
129+
FetchContent_MakeAvailable(betann)
130+
target_link_libraries(mlx PRIVATE betann)
131+
endif()
132+
117133
if(WIN32)
118134
if(MSVC)
119135
# GGUF does not build with MSVC.

mlx/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)
4040

4141
# Add exactly one backend directory: Metal takes precedence, then WebGPU, and
# backend/no_metal is used when neither option is enabled.
if(MLX_BUILD_METAL)
4242
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal)
43+
elseif(MLX_BUILD_WEBGPU)
44+
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/webgpu)
4345
else()
4446
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_metal)
4547
endif()

mlx/allocator.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,17 +41,17 @@ size_t CommonAllocator::size(Buffer buffer) const {
4141
return *static_cast<size_t*>(buffer.ptr());
4242
}
4343

44-
Buffer malloc_or_wait(size_t size) {
45-
auto buffer = allocator().malloc(size);
44+
Buffer malloc_or_wait(const Device& device, size_t size) {
45+
auto buffer = allocator().malloc(device, size);
4646

4747
while (size && !buffer.ptr() && scheduler::n_active_tasks() > 0) {
4848
scheduler::wait_for_one();
49-
buffer = allocator().malloc(size);
49+
buffer = allocator().malloc(device, size);
5050
}
5151

5252
// Try swapping if needed
5353
if (size && !buffer.ptr()) {
54-
buffer = allocator().malloc(size, /* allow_swap = */ true);
54+
buffer = allocator().malloc(device, size, /* allow_swap = */ true);
5555
}
5656

5757
if (size && !buffer.ptr()) {

mlx/allocator.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
#include <cstdlib>
66

7+
#include "mlx/device.h"
8+
79
namespace mlx::core::allocator {
810

911
// Simple wrapper around buffer pointers
@@ -34,12 +36,19 @@ void free(Buffer buffer);
3436

3537
// Wait for running tasks to finish and free up memory
3638
// if allocation fails
37-
Buffer malloc_or_wait(size_t size);
39+
Buffer malloc_or_wait(const Device& device, size_t size);
40+
// Backward-compatible overload for callers that do not name a device:
// forwards to the device-aware malloc_or_wait with Device::cpu.
inline Buffer malloc_or_wait(size_t size) {
41+
return malloc_or_wait(Device::cpu, size);
42+
}
3843

3944
class Allocator {
4045
/** Abstract base class for a memory allocator. */
4146
public:
4247
virtual Buffer malloc(size_t size, bool allow_swap = false) = 0;
48+
// Device-aware allocation hook. This default implementation ignores `device`
// and forwards to the device-agnostic malloc(size, allow_swap) overload.
virtual Buffer
49+
malloc(const Device& device, size_t size, bool allow_swap = false) {
50+
return malloc(size, allow_swap);
51+
}
4352
virtual void free(Buffer buffer) = 0;
4453
virtual size_t size(Buffer buffer) const = 0;
4554

mlx/array.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,13 @@ bool array::is_tracer() const {
109109
detail::retain_graph();
110110
}
111111

112+
void array::reset_data_ptr() {
113+
void* data_ptr = buffer().raw_ptr();
114+
auto char_offset = sizeof(char) * itemsize() * array_desc_->offset;
115+
array_desc_->data_ptr =
116+
static_cast<void*>(static_cast<char*>(data_ptr) + char_offset);
117+
}
118+
112119
void array::set_data(allocator::Buffer buffer, Deleter d) {
113120
array_desc_->data = std::make_shared<Data>(buffer, d);
114121
array_desc_->data_ptr = buffer.raw_ptr();
@@ -142,6 +149,7 @@ void array::copy_shared_buffer(
142149
array_desc_->strides = strides;
143150
array_desc_->flags = flags;
144151
array_desc_->data_size = data_size;
152+
array_desc_->offset = offset;
145153
auto char_offset = sizeof(char) * itemsize() * offset;
146154
array_desc_->data_ptr = static_cast<void*>(
147155
static_cast<char*>(other.array_desc_->data_ptr) + char_offset);
@@ -161,6 +169,7 @@ void array::move_shared_buffer(
161169
array_desc_->strides = strides;
162170
array_desc_->flags = flags;
163171
array_desc_->data_size = data_size;
172+
array_desc_->offset = offset;
164173
auto char_offset = sizeof(char) * itemsize() * offset;
165174
auto data_ptr = other.array_desc_->data_ptr;
166175
other.array_desc_->data_ptr = nullptr;

mlx/array.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -401,6 +401,8 @@ class array {
401401
// Check if the array is a tracer array
402402
bool is_tracer() const;
403403

404+
void reset_data_ptr();
405+
404406
void set_data(allocator::Buffer buffer, Deleter d = allocator::free);
405407

406408
void set_data(
@@ -465,6 +467,9 @@ class array {
465467
// The size in elements of the data buffer the array accesses
466468
size_t data_size;
467469

470+
// Offset from the shared data in elements
471+
size_t offset{0};
472+
468473
// Contains useful meta data about the array
469474
Flags flags;
470475

mlx/backend/common/binary.h

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include "mlx/allocator.h"
77
#include "mlx/array.h"
88
#include "mlx/backend/common/utils.h"
9+
#include "mlx/primitives.h"
910

1011
#include "mlx/backend/common/simd/simd.h"
1112

@@ -47,10 +48,14 @@ void set_binary_op_output_data(
4748
bool donate_with_move = false) {
4849
bool b_donatable = is_donatable(b, out);
4950
bool a_donatable = is_donatable(a, out);
51+
const Device& device = out.primitive().device();
5052
switch (bopt) {
5153
case BinaryOpType::ScalarScalar:
5254
out.set_data(
53-
allocator::malloc_or_wait(out.itemsize()), 1, a.strides(), a.flags());
55+
allocator::malloc_or_wait(device, out.itemsize()),
56+
1,
57+
a.strides(),
58+
a.flags());
5459
break;
5560
case BinaryOpType::ScalarVector:
5661
if (b_donatable) {
@@ -61,7 +66,7 @@ void set_binary_op_output_data(
6166
}
6267
} else {
6368
out.set_data(
64-
allocator::malloc_or_wait(b.data_size() * out.itemsize()),
69+
allocator::malloc_or_wait(device, b.data_size() * out.itemsize()),
6570
b.data_size(),
6671
b.strides(),
6772
b.flags());
@@ -76,7 +81,7 @@ void set_binary_op_output_data(
7681
}
7782
} else {
7883
out.set_data(
79-
allocator::malloc_or_wait(a.data_size() * out.itemsize()),
84+
allocator::malloc_or_wait(device, a.data_size() * out.itemsize()),
8085
a.data_size(),
8186
a.strides(),
8287
a.flags());
@@ -97,7 +102,7 @@ void set_binary_op_output_data(
97102
}
98103
} else {
99104
out.set_data(
100-
allocator::malloc_or_wait(a.data_size() * out.itemsize()),
105+
allocator::malloc_or_wait(device, a.data_size() * out.itemsize()),
101106
a.data_size(),
102107
a.strides(),
103108
a.flags());
@@ -118,7 +123,7 @@ void set_binary_op_output_data(
118123
out.copy_shared_buffer(b);
119124
}
120125
} else {
121-
out.set_data(allocator::malloc_or_wait(out.nbytes()));
126+
out.set_data(allocator::malloc_or_wait(device, out.nbytes()));
122127
}
123128
break;
124129
}

mlx/backend/webgpu/CMakeLists.txt

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# Sources of the WebGPU backend. Note that event.cpp is taken from the
# sibling no_metal backend directory rather than from this directory.
target_sources(
2+
mlx
3+
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
4+
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
5+
${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
6+
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
7+
${CMAKE_CURRENT_SOURCE_DIR}/../no_metal/event.cpp)

mlx/backend/webgpu/allocator.cpp

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
// Copyright © 2025 Apple Inc.
2+
3+
#include "mlx/backend/webgpu/allocator.h"
4+
5+
namespace mlx::core {
6+
7+
namespace allocator {
8+
9+
// Global allocator accessor for this backend: forwards to the WebGPU
// allocator.
Allocator& allocator() {
10+
return webgpu::allocator();
11+
}
12+
13+
void* Buffer::raw_ptr() {
14+
return static_cast<webgpu::DoubleBuffer*>(ptr_)->cpu_data();
15+
}
16+
17+
} // namespace allocator
18+
19+
namespace webgpu {
20+
21+
// Creates a CPU-only buffer. Allocates `size + sizeof(size_t)` bytes and
// records `size` in the leading size_t of the allocation.
//
// Throws std::runtime_error when the host allocation fails.
DoubleBuffer::DoubleBuffer(size_t size)
    : cpu_data_(std::malloc(size + sizeof(size_t))) {
  // std::malloc returns nullptr on failure; writing the size header through
  // a null pointer would be undefined behavior.
  if (!cpu_data_) {
    throw std::runtime_error(
        "[webgpu::DoubleBuffer] Unable to allocate CPU memory.");
  }
  *static_cast<size_t*>(cpu_data_) = size;
}
25+
26+
// Creates a GPU-only buffer of `size` bytes on `device`, usable as a storage
// buffer and as a copy source. No CPU allocation is made here.
// NOTE(review): this relies on cpu_data_ having a null default initializer in
// the class declaration (not visible here), since the destructor frees it -
// confirm.
DoubleBuffer::DoubleBuffer(betann::Device& device, size_t size)
27+
: gpu_data_(device.CreateBuffer(
28+
size,
29+
betann::BufferUsage::Storage | betann::BufferUsage::CopySrc)) {}
30+
31+
// Releases the CPU-side allocation. std::free accepts a null pointer, so this
// is also valid for buffers that never received CPU data.
DoubleBuffer::~DoubleBuffer() {
32+
std::free(cpu_data_);
33+
}
34+
35+
// Allocates the CPU-side storage for this buffer and fills it with `size`
// bytes copied from `data`. The allocation is `size + sizeof(size_t)` bytes
// with `size` recorded in its leading size_t.
//
// Precondition (checked by assert): the buffer has no CPU data yet.
// Throws std::runtime_error when the host allocation fails.
void DoubleBuffer::copy_to_cpu(const void* data, size_t size) {
  assert(!cpu_data_);
  void* storage = std::malloc(size + sizeof(size_t));
  // std::malloc returns nullptr on failure; do not write through it.
  if (!storage) {
    throw std::runtime_error(
        "[webgpu::DoubleBuffer::copy_to_cpu] Unable to allocate CPU memory.");
  }
  *static_cast<size_t*>(storage) = size;
  cpu_data_ = storage;
  // memcpy requires valid pointers even for a zero-length copy, so skip it
  // when there is nothing to copy.
  if (size > 0) {
    std::memcpy(cpu_data(), data, size);
  }
}
41+
42+
size_t DoubleBuffer::size() const {
43+
if (cpu_data_)
44+
return *static_cast<size_t*>(cpu_data_);
45+
if (gpu_data_)
46+
return gpu_data_.GetSize();
47+
return 0;
48+
}
49+
50+
// Binds the allocator to the betann device returned for Device::gpu.
WgpuAllocator::WgpuAllocator() : device_(webgpu::device(Device::gpu)) {}
51+
52+
// Allocates a DoubleBuffer of `size` bytes and wraps it in a Buffer. GPU
// devices get a GPU-side buffer; any other device gets a CPU-side one.
// `allow_swap` is not used by this allocator.
Buffer
WgpuAllocator::malloc(const Device& device, size_t size, bool allow_swap) {
  DoubleBuffer* dbuf = (device.type == Device::gpu)
      ? new DoubleBuffer(webgpu::device(device), size)
      : new DoubleBuffer(size);
  return Buffer(dbuf);
}
59+
60+
// Destroys the DoubleBuffer wrapped by `buffer`. Deleting a null pointer is a
// no-op, so an empty Buffer is accepted.
void WgpuAllocator::free(Buffer buffer) {
  auto* dbuf = static_cast<DoubleBuffer*>(buffer.ptr());
  delete dbuf;
}
63+
64+
// Returns the size in bytes of the DoubleBuffer wrapped by `buffer`, or 0
// when `buffer` wraps no allocation.
size_t WgpuAllocator::size(Buffer buffer) const {
  auto* dbuf = static_cast<DoubleBuffer*>(buffer.ptr());
  // An empty Buffer has a null pointer; calling size() through it would be
  // undefined behavior.
  if (!dbuf) {
    return 0;
  }
  return dbuf->size();
}
67+
68+
// Makes sure the DoubleBuffer wrapped by `buffer` has a GPU-side buffer,
// creating one from the CPU-side data when it is missing. Does nothing for an
// empty Buffer, for a buffer that already has GPU data, or for a buffer of
// size 0.
void WgpuAllocator::ensure_gpu_data(Buffer& buffer) {
  auto* dbuf = static_cast<DoubleBuffer*>(buffer.ptr());
  // An empty Buffer has nothing to upload; dereferencing it would be
  // undefined behavior.
  if (!dbuf) {
    return;
  }
  if (dbuf->gpu_data() || dbuf->size() == 0) {
    return;
  }
  // NOTE(review): buffers created here only get Storage usage, whereas
  // buffers allocated directly on the GPU also get CopySrc - confirm this
  // difference is intended.
  dbuf->set_gpu_data(device_.CreateBufferFromData(
      dbuf->cpu_data(), dbuf->size(), betann::BufferUsage::Storage));
}
75+
76+
// Returns the WebGPU allocator, a function-local static constructed on first
// use.
WgpuAllocator& allocator() {
  static WgpuAllocator instance;
  return instance;
}
80+
81+
// Returns the betann device, a function-local static constructed on first
// use. The requested mlx device is ignored: every caller gets the same
// betann device.
betann::Device& device(mlx::core::Device) {
  static betann::Device shared_device;
  return shared_device;
}
85+
86+
} // namespace webgpu
87+
88+
// Stub implementations of the metal:: memory-management API for the WebGPU
// backend. None of them track or limit memory: every query and every setter
// returns 0, and the void functions do nothing.
namespace metal {

// Always reports 0 bytes of active memory.
size_t get_active_memory() {
  return 0;
}

// Always reports a peak of 0 bytes.
size_t get_peak_memory() {
  return 0;
}

// No-op: there is no peak counter to reset.
void reset_peak_memory() {}

// Always reports 0 bytes of cached memory.
size_t get_cache_memory() {
  return 0;
}

// Ignores the requested limit and returns 0.
size_t set_memory_limit(size_t, bool) {
  return 0;
}

// Ignores the requested limit and returns 0.
size_t set_cache_limit(size_t) {
  return 0;
}

// Ignores the requested limit and returns 0.
size_t set_wired_limit(size_t) {
  return 0;
}

// Not implemented for this backend: always throws std::runtime_error.
std::unordered_map<std::string, std::variant<std::string, size_t>>
device_info() {
  throw std::runtime_error("[webgpu::device_info] Not implemented");
}

// No-op: there is no cache to clear.
void clear_cache() {}

} // namespace metal
118+
119+
} // namespace mlx::core

0 commit comments

Comments
 (0)