Skip to content

Commit cb542f2

Browse files
committed
Add webgpu backend
1 parent f5cc1ee commit cb542f2

File tree

10 files changed

+619
-0
lines changed

10 files changed

+619
-0
lines changed

CMakeLists.txt

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF)
1616
option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
1717
option(MLX_BUILD_METAL "Build metal backend" ON)
1818
option(MLX_BUILD_CPU "Build cpu backend" ON)
19+
option(MLX_BUILD_WEBGPU "Build webgpu backend" OFF)
1920
option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
2021
option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
2122
option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
@@ -52,6 +53,10 @@ if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
5253
endif()
5354
endif()
5455

56+
if(MLX_BUILD_WEBGPU AND MLX_BUILD_METAL)
57+
message(FATAL_ERROR "Can not build both webgpu and metal backends.")
58+
endif()
59+
5560
else()
5661
set(MLX_BUILD_METAL OFF)
5762
message(WARNING "MLX is prioritised for Apple silicon systems using macOS.")
@@ -114,6 +119,17 @@ elseif(MLX_BUILD_METAL)
114119
target_link_libraries(mlx PUBLIC ${METAL_LIB} ${FOUNDATION_LIB} ${QUARTZ_LIB})
115120
endif()
116121

122+
if(MLX_BUILD_WEBGPU)
123+
FetchContent_Declare(
124+
betann
125+
GIT_REPOSITORY https://github.com/frost-beta/betann.git
126+
GIT_TAG 68c5546f6d87cf90f236411369d55a9374bf8b73
127+
EXCLUDE_FROM_ALL)
128+
set(BETANN_BUILD_TESTS OFF)
129+
FetchContent_MakeAvailable(betann)
130+
target_link_libraries(mlx PRIVATE betann)
131+
endif()
132+
117133
if(WIN32)
118134
if(MSVC)
119135
# GGUF does not build with MSVC.

examples/cpp/tutorial.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ namespace mx = mlx::core;
1010
void array_basics() {
1111
// Make a scalar array:
1212
mx::array x(1.0);
13+
std::cout << x + x << std::endl;
1314

1415
// Get the value out of it:
1516
auto s = x.item<float>();

mlx/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ endif()
4747

4848
if(MLX_BUILD_METAL)
4949
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal)
50+
elseif(MLX_BUILD_WEBGPU)
51+
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/webgpu)
5052
else()
5153
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_metal)
5254
endif()

mlx/array.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,13 @@ bool array::is_tracer() const {
109109
detail::retain_graph();
110110
}
111111

112+
void array::reset_data_ptr() {
113+
void* data_ptr = buffer().raw_ptr();
114+
auto char_offset = sizeof(char) * itemsize() * array_desc_->offset;
115+
array_desc_->data_ptr =
116+
static_cast<void*>(static_cast<char*>(data_ptr) + char_offset);
117+
}
118+
112119
void array::set_data(allocator::Buffer buffer, Deleter d) {
113120
array_desc_->data = std::make_shared<Data>(buffer, d);
114121
array_desc_->data_ptr = buffer.raw_ptr();
@@ -142,6 +149,7 @@ void array::copy_shared_buffer(
142149
array_desc_->strides = strides;
143150
array_desc_->flags = flags;
144151
array_desc_->data_size = data_size;
152+
array_desc_->offset = offset;
145153
auto char_offset = sizeof(char) * itemsize() * offset;
146154
array_desc_->data_ptr = static_cast<void*>(
147155
static_cast<char*>(other.array_desc_->data_ptr) + char_offset);
@@ -161,6 +169,7 @@ void array::move_shared_buffer(
161169
array_desc_->strides = strides;
162170
array_desc_->flags = flags;
163171
array_desc_->data_size = data_size;
172+
array_desc_->offset = offset;
164173
auto char_offset = sizeof(char) * itemsize() * offset;
165174
auto data_ptr = other.array_desc_->data_ptr;
166175
other.array_desc_->data_ptr = nullptr;

mlx/array.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -401,6 +401,8 @@ class array {
401401
// Check if the array is a tracer array
402402
bool is_tracer() const;
403403

404+
void reset_data_ptr();
405+
404406
void set_data(allocator::Buffer buffer, Deleter d = allocator::free);
405407

406408
void set_data(
@@ -465,6 +467,9 @@ class array {
465467
// The size in elements of the data buffer the array accesses
466468
size_t data_size;
467469

470+
// Offset from the shared data in elements
471+
size_t offset{0};
472+
468473
// Contains useful meta data about the array
469474
Flags flags;
470475

mlx/backend/webgpu/CMakeLists.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# Sources for the webgpu backend. Event support is shared with the
# no_metal backend (event.cpp is reused from there).
target_sources(
  mlx
  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/../no_metal/event.cpp)

mlx/backend/webgpu/allocator.cpp

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
// Copyright © 2025 Apple Inc.
2+
3+
#include <cassert>
#include <cstdlib>
#include <cstring>
#include <new>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <variant>

#include "mlx/backend/webgpu/allocator.h"
4+
5+
namespace mlx::core {
6+
7+
namespace allocator {

// The global allocator for the webgpu backend is the WgpuAllocator
// singleton defined in the webgpu namespace below.
Allocator& allocator() {
  return webgpu::allocator();
}

// Exposes the CPU side of the double buffer. Returns nullptr for a
// GPU-only buffer until data has been copied back (see DoubleBuffer).
void* Buffer::raw_ptr() {
  return static_cast<webgpu::DoubleBuffer*>(ptr_)->cpu_data();
}

} // namespace allocator
18+
19+
namespace webgpu {
20+
21+
DoubleBuffer::DoubleBuffer(size_t size)
22+
: cpu_data_(std::malloc(size + sizeof(size_t))) {
23+
*static_cast<size_t*>(cpu_data_) = size;
24+
}
25+
26+
DoubleBuffer::DoubleBuffer(betann::Device& device, size_t size)
27+
: gpu_data_(device.CreateBuffer(
28+
size,
29+
betann::BufferUsage::Storage | betann::BufferUsage::CopySrc)) {}
30+
31+
DoubleBuffer::~DoubleBuffer() {
32+
std::free(cpu_data_);
33+
}
34+
35+
void DoubleBuffer::copy_to_cpu(const void* data, size_t size) {
36+
assert(!cpu_data_);
37+
cpu_data_ = std::malloc(size + sizeof(size_t));
38+
*static_cast<size_t*>(cpu_data_) = size;
39+
std::memcpy(cpu_data(), data, size);
40+
}
41+
42+
size_t DoubleBuffer::size() const {
43+
if (cpu_data_)
44+
return *static_cast<size_t*>(cpu_data_);
45+
if (gpu_data_)
46+
return gpu_data_.GetSize();
47+
return 0;
48+
}
49+
50+
// Caches the (single) betann device; see webgpu::device() below.
WgpuAllocator::WgpuAllocator() : device_(webgpu::device(Device::gpu)) {}

// Allocates CPU-backed memory. |allow_swap| is part of the Allocator
// interface but has no effect in this backend.
Buffer WgpuAllocator::malloc(size_t size, bool allow_swap) {
  return Buffer(new DoubleBuffer(size));
}

void WgpuAllocator::free(Buffer buffer) {
  delete static_cast<DoubleBuffer*>(buffer.ptr());
}

// Logical payload size of the buffer (CPU or GPU side).
size_t WgpuAllocator::size(Buffer buffer) const {
  return static_cast<DoubleBuffer*>(buffer.ptr())->size();
}

// Allocates GPU-only memory on the cached device.
Buffer WgpuAllocator::gpu_malloc(size_t size) {
  return Buffer(new DoubleBuffer(device_, size));
}
67+
68+
// Lazily uploads a buffer's CPU contents to the GPU. No-op when the buffer
// already has GPU data or is empty.
void WgpuAllocator::ensure_gpu_data(Buffer& buffer) {
  auto* dbuf = static_cast<DoubleBuffer*>(buffer.ptr());
  if (dbuf->gpu_data() || dbuf->size() == 0)
    return;
  dbuf->set_gpu_data(device_.CreateBufferFromData(
      dbuf->cpu_data(), dbuf->size(), betann::BufferUsage::Storage));
}
75+
76+
// Process-wide allocator instance (Meyers singleton: initialization is
// thread-safe since C++11, constructed on first use).
WgpuAllocator& allocator() {
  static WgpuAllocator allocator_;
  return allocator_;
}

// A single betann::Device is shared regardless of the requested mlx device;
// the parameter exists to mirror the metal backend's interface.
betann::Device& device(mlx::core::Device) {
  static betann::Device device;
  return device;
}
85+
86+
} // namespace webgpu
87+
88+
namespace metal {

// Memory-introspection stubs: the webgpu backend does not track memory
// usage yet, so these report zero / act as no-ops while keeping the
// metal-facing API satisfied.
size_t get_active_memory() {
  return 0;
}
size_t get_peak_memory() {
  return 0;
}
void reset_peak_memory() {}
size_t get_cache_memory() {
  return 0;
}
size_t set_memory_limit(size_t, bool) {
  return 0;
}
size_t set_cache_limit(size_t) {
  return 0;
}
size_t set_wired_limit(size_t) {
  return 0;
}

// Device introspection is not implemented for webgpu yet.
std::unordered_map<std::string, std::variant<std::string, size_t>>
device_info() {
  throw std::runtime_error("[webgpu::device_info] Not implemented");
} // NOTE: stray ';' after this body removed (-Wextra-semi).

void clear_cache() {}

} // namespace metal
118+
119+
} // namespace mlx::core

mlx/backend/webgpu/allocator.h

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
// Copyright © 2025 Apple Inc.
2+
3+
#pragma once
4+
5+
#include <cstddef>
#include <utility>

#include "mlx/allocator.h"
#include "mlx/device.h"

#include <betann/betann.h>
9+
10+
namespace mlx::core::webgpu {
11+
12+
using allocator::Buffer;
13+
14+
// Holds data for both CPU and GPU.
15+
class DoubleBuffer {
16+
public:
17+
// Allocates memory in CPU.
18+
explicit DoubleBuffer(size_t size);
19+
// Allocates memory in GPU.
20+
DoubleBuffer(betann::Device& device, size_t size);
21+
22+
~DoubleBuffer();
23+
24+
void copy_to_cpu(const void* data, size_t size);
25+
void set_gpu_data(betann::Buffer buffer) {
26+
gpu_data_ = std::move(buffer);
27+
}
28+
29+
void* cpu_data() const {
30+
return cpu_data_ ? static_cast<size_t*>(cpu_data_) + 1 : nullptr;
31+
}
32+
const betann::Buffer& gpu_data() const {
33+
return gpu_data_;
34+
}
35+
36+
size_t size() const;
37+
38+
private:
39+
void* cpu_data_ = nullptr;
40+
betann::Buffer gpu_data_;
41+
};
42+
43+
// Allocator that hands out DoubleBuffers. CPU memory is allocated eagerly
// through the standard Allocator interface; GPU memory is allocated (or
// uploaded) on demand via gpu_malloc()/ensure_gpu_data().
class WgpuAllocator : public allocator::Allocator {
 public:
  // |allow_swap| is part of the base interface; unused by this backend.
  Buffer malloc(size_t size, bool allow_swap = false) override;
  void free(Buffer buffer) override;
  size_t size(Buffer buffer) const override;

  // Allocates a GPU-only buffer.
  Buffer gpu_malloc(size_t size);
  // Uploads the buffer's CPU data to the GPU if not already resident.
  void ensure_gpu_data(Buffer& buffer);

 private:
  // Construction goes through the allocator() singleton accessor only.
  WgpuAllocator();
  friend WgpuAllocator& allocator();

  // Cached device handle used for all GPU allocations.
  betann::Device& device_;
};
58+
59+
WgpuAllocator& allocator();
60+
61+
betann::Device& device(mlx::core::Device);
62+
63+
} // namespace mlx::core::webgpu

mlx/backend/webgpu/metal.cpp

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
// Copyright © 2023-2024 Apple Inc.
2+
3+
#include <stdexcept>
4+
5+
#include "mlx/backend/metal/metal.h"
6+
#include "mlx/backend/metal/metal_impl.h"
7+
#include "mlx/backend/webgpu/allocator.h"
8+
#include "mlx/primitives.h"
9+
#include "mlx/scheduler.h"
10+
#include "mlx/utils.h"
11+
12+
namespace mlx::core::metal {
13+
14+
// The webgpu backend implements the metal interface; report it available so
// GPU work is routed through this backend.
bool is_available() {
  return true;
}

// No per-stream GPU state is needed in the webgpu backend.
void new_stream(Stream) {}
19+
20+
// Builds the task that evaluates |arr|'s primitive on the GPU. The task:
// waits on cross-stream input events, uploads input data, runs eval_gpu,
// schedules a GPU->CPU read-back of the result, and optionally signals
// |arr|'s event when |signal| is set.
std::function<void()> make_task(array arr, bool signal) {
  return [arr = std::move(arr), signal]() mutable {
    auto s = arr.primitive().stream();
    auto& device = webgpu::device(s.device);

    for (auto& input : arr.inputs()) {
      // Wait only for inputs produced on a different stream.
      if (input.event().valid() &&
          input.event().stream() != arr.primitive().stream()) {
        input.event().wait();
      }
      // Ensure all inputs copy their CPU data to GPU.
      webgpu::allocator().ensure_gpu_data(input.buffer());
    }

    auto outputs = arr.outputs();
    {
      // When tracing, hold an extra reference to the inputs for the
      // duration of eval_gpu so they are not released mid-evaluation.
      std::vector<array> inputs;
      if (arr.is_tracer()) {
        inputs = arr.inputs();
      }

      try {
        arr.primitive().eval_gpu(arr.inputs(), outputs);
      } catch (const std::exception& error) {
        abort_with_exception(error);
      }
    }
    // Keep input/sibling data alive until the async GPU work that may read
    // it has completed (captured by the callbacks below).
    std::vector<std::shared_ptr<array::Data>> buffers;
    for (auto& in : arr.inputs()) {
      buffers.push_back(in.data_shared_ptr());
    }
    for (auto& s : arr.siblings()) {
      buffers.push_back(s.data_shared_ptr());
    }
    if (!arr.is_tracer()) {
      arr.detach();
    }
    for (auto& out : outputs) {
      out.set_status(array::Status::evaluated);
    }

    // Copy data from GPU to CPU.
    // FIXME(zcbenz): Should only do it when necessary.
    if (arr.data_shared_ptr()) {
      auto* dbuf = static_cast<webgpu::DoubleBuffer*>(arr.buffer().ptr());
      if (dbuf->gpu_data() && !dbuf->cpu_data()) {
        device.Flush();
        wgpu::Buffer staging = device.CopyToStagingBuffer(dbuf->gpu_data());
        device.Flush();
        device.ReadStagingBuffer(
            staging,
            // |buffers| is moved into this callback so the data outlives
            // the asynchronous read.
            [arr, dbuf, buffers = std::move(buffers)](
                const void* data) mutable {
              dbuf->copy_to_cpu(data, dbuf->size());
              arr.reset_data_ptr();
            });
      }
    }

    if (signal) {
      device.Flush();
      device.WaitAll();
      arr.event().signal();
    } else {
      // NOTE(review): if the read-back branch above ran, |buffers| was
      // already moved into that callback, so this captures an empty vector
      // and lifetime is held by the read-back lambda instead — confirm
      // this double-move is intended.
      device.OnSubmittedWorkDone([buffers = std::move(buffers)]() {});
    }
  };
}
88+
89+
// Returns a task that blocks until all GPU work submitted on |s|'s device
// has completed, then fulfills the promise |p|.
std::function<void()> make_synchronize_task(
    Stream s,
    std::shared_ptr<std::promise<void>> p) {
  return [s, p = std::move(p)]() {
    auto& device = webgpu::device(s.device);
    device.WaitAll();
    p->set_value();
  };
}
98+
99+
// GPU frame capture is a Metal-only debugging feature; no-op for webgpu.
void start_capture(std::string) {}
void stop_capture() {}
101+
102+
} // namespace mlx::core::metal

0 commit comments

Comments
 (0)