Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion .github/workflows/mlx.yml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,11 @@ jobs:
echo "::endgroup::"

echo "::group::Build test runners"
${CONDA_RUN} cmake --build cmake-out --target op_test_runner multi_thread_test_runner -j$(( $(sysctl -n hw.ncpu) - 1 ))
${CONDA_RUN} cmake --build cmake-out --target op_test_runner multi_thread_test_runner mlx_mutable_state_test -j$(( $(sysctl -n hw.ncpu) - 1 ))
echo "::endgroup::"

echo "::group::Run mutable-state (multi-session) unit test"
./cmake-out/backends/mlx/test/mlx_mutable_state_test
echo "::endgroup::"

echo "::group::Run op unit tests"
Expand Down
10 changes: 10 additions & 0 deletions backends/arm/quantizer/quantization_annotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,6 +631,16 @@ def _get_fixed_qparams_qspec(
if _transpose_dimname is not None:
_one_to_one_shared_input_qspec.add(_transpose_dimname)

# moveaxis/movedim only rearrange dimensions, so they are registered in the
# same shared-input-qspec set as the transpose ops above. Each overload is
# looked up with a None default so an overload that is missing from the
# installed torch build is skipped instead of raising AttributeError.
for _op in (
    getattr(torch.ops.aten.moveaxis, "int", None),
    getattr(torch.ops.aten.moveaxis, "intlist", None),
    getattr(torch.ops.aten.movedim, "int", None),
    getattr(torch.ops.aten.movedim, "intlist", None),
):
    if _op is not None:
        _one_to_one_shared_input_qspec.add(_op)


_one_to_one_shared_input_or_input_act_qspec: set[OpOverload] = {
torch.ops.aten.alias.default,
torch.ops.aten.clone.default,
Expand Down
17 changes: 17 additions & 0 deletions backends/arm/test/ops/test_permute.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,12 @@ def forward(self, x):
return torch.permute(x, self.dims)


class SimpleMoveAxis(torch.nn.Module):
    """Minimal module whose forward is a single ``torch.moveaxis`` call."""

    def forward(self, x):
        """Return ``x`` with dimension 1 moved to the last position."""
        return torch.moveaxis(x, 1, -1)


@common.parametrize(
"test_data", test_data_suite | test_data_suite_fp16 | test_data_suite_bf16
)
Expand Down Expand Up @@ -118,6 +124,17 @@ def test_permute_u55_INT(test_data):
pipeline.run()


def test_moveaxis_u55_INT():
    """Run ``SimpleMoveAxis`` through the Ethos-U55 INT pipeline.

    The pipeline is given a random (1, 4, 5, 6) input, the aten op name
    ``torch.ops.aten.moveaxis.int`` and the permute_copy edge-dialect op
    name. FVP execution is disabled via ``run_on_fvp=False``.
    """
    module = SimpleMoveAxis()
    example_inputs = (torch.rand(1, 4, 5, 6),)
    pipeline = EthosU55PipelineINT[input_t1](
        module,
        example_inputs,
        "torch.ops.aten.moveaxis.int",
        exir_ops="executorch_exir_dialects_edge__ops_aten_permute_copy_default",
        run_on_fvp=False,
    )
    pipeline.run()


@common.parametrize("test_data", test_data_suite_u55_reject)
def test_permute_u55_INT_not_delegated(test_data: torch.Tensor):
test_data, dims = test_data()
Expand Down
35 changes: 35 additions & 0 deletions backends/arm/test/quantizer/test_generic_annotater.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,41 @@ def test_transpose_tosa_INT():
)


def test_moveaxis_movedim_tosa_INT():
    """Check annotation of ``torch.moveaxis`` and ``torch.movedim``.

    Each op is exercised with a single-int and a tuple ``source`` /
    ``destination`` pair on a fresh random (2, 3, 4) input. The cases run in
    the order moveaxis-int, moveaxis-tuple, movedim-int, movedim-tuple.
    """
    dim_cases = (
        (1, -1),
        ((0, 1), (-1, -2)),
    )
    for op in (torch.moveaxis, torch.movedim):
        for source, destination in dim_cases:
            check_annotation(
                SingleOpModel(
                    op,
                    (torch.randn(2, 3, 4),),
                    source=source,
                    destination=destination,
                ),
            )


def test_tile_tosa_INT():
check_annotation(
SingleOpModel(torch.tile, (torch.randn(4, 4),), dims=(2,)),
Expand Down
6 changes: 4 additions & 2 deletions backends/mlx/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -255,8 +255,10 @@ option(ET_MLX_ALLOW_CUSTOM_KERNEL_EXECUTION
ON
)

set(_mlx_backend__srcs ${CMAKE_CURRENT_SOURCE_DIR}/runtime/MLXLoader.cpp
${CMAKE_CURRENT_SOURCE_DIR}/runtime/MLXBackend.cpp
set(_mlx_backend__srcs
${CMAKE_CURRENT_SOURCE_DIR}/runtime/MLXLoader.cpp
${CMAKE_CURRENT_SOURCE_DIR}/runtime/MLXBackend.cpp
${CMAKE_CURRENT_SOURCE_DIR}/runtime/mlx_mutable_state.cpp
)

add_library(mlxdelegate ${_mlx_backend__srcs})
Expand Down
37 changes: 37 additions & 0 deletions backends/mlx/custom_kernel_ops/gated_delta_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,15 @@ def gated_delta_rule(
B, T_len, Hk, Dk = q.shape
Hv, Dv = v.shape[-2:]

# The Metal kernel maps each v-head to its k-head group
# (hk_idx = hv_idx / (Hv / Hk)); mirror that here so the eager reference also
# supports Hk != Hv (GQA) instead of relying on broadcasting, which requires
# Hk == Hv. repeat_interleave on the head dim reproduces that index mapping.
if Hk != Hv:
q = q.repeat_interleave(Hv // Hk, dim=2)
k = k.repeat_interleave(Hv // Hk, dim=2)
Hk = Hv

s = state.clone()

ys = []
Expand Down Expand Up @@ -101,6 +110,7 @@ def gated_delta_rule_fake(
IntOrVid,
MetalKernelNode,
MultiplyNode,
RepeatNode,
ScanNode,
SubtractNode,
SumNode,
Expand Down Expand Up @@ -450,6 +460,33 @@ def _emit_scan(self, P: MLXProgramBuilder, n: Node) -> Slot:
]
)

# GQA: q/k carry Hk heads but the recurrence state/v have Hv heads. Expand
# q/k to Hv (repeat_interleave on the head axis) so the per-step broadcasts
# match, mirroring the Metal kernel's hk_idx = hv_idx / (Hv / Hk).
Hk = int(self.q_node.meta["val"].shape[-2])
Hv = int(self.v_node.meta["val"].shape[-2])
if Hk != Hv:
rep = IntOrVid.from_literal(Hv // Hk)
_, q_exp = P.make_tmp_slot()
P.emit(
RepeatNode(
x=P.slot_to_tid(q_slot),
out=P.slot_to_tid(q_exp),
repeats=rep,
axis=2,
)
)
_, k_exp = P.make_tmp_slot()
P.emit(
RepeatNode(
x=P.slot_to_tid(k_slot),
out=P.slot_to_tid(k_exp),
repeats=rep,
axis=2,
)
)
q_slot, k_slot = q_exp, k_exp

# Carry needs a writable slot. This is node n's persistent output (the
# mutated state), so it must be a node-owned slot — not a temp slot, whose
# id is reclaimed on tmp_scope exit and would be read as dead by a later
Expand Down
5 changes: 2 additions & 3 deletions backends/mlx/custom_kernel_ops/test/test_gated_delta_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,8 @@ def forward(
g: torch.Tensor, # [B, T, Hv]
beta: torch.Tensor, # [B, T, Hv]
) -> torch.Tensor:
if self.head_repeat > 1:
q = q.repeat_interleave(self.head_repeat, dim=2)
k = k.repeat_interleave(self.head_repeat, dim=2)
# Pass native Hk (no repeat_interleave): the op itself must handle
# GQA head expansion (kernel via hk_idx mapping, scan/eager internally).
return torch.ops.mlx.gated_delta_rule(
q, k, v, g, beta, self.state, use_custom_kernel=self.use_custom_kernel
)
Expand Down
59 changes: 59 additions & 0 deletions backends/mlx/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,8 @@
from executorch.exir.dialects._ops import ops as exir_ops
from torch.fx.node import Node

_LEAKY_RELU_DEFAULT_NEGATIVE_SLOPE = 0.01


def require_static_int(value: Any, param_name: str, op_name: str) -> None:
"""
Expand Down Expand Up @@ -2786,6 +2788,63 @@ def _relu_handler(P: MLXProgramBuilder, n: Node) -> Slot:
return out


@REGISTRY.register(target=[torch.ops.aten.leaky_relu.default])
def _leaky_relu_handler(P: MLXProgramBuilder, n: Node) -> Slot:
    """Handle aten.leaky_relu.default - leaky rectified linear unit.

    leaky_relu(x) = x          if x >= 0
                  = slope * x  otherwise

    Implemented as where(x >= 0, x, slope * x) so it stays correct for any
    negative_slope (including values > 1), matching eager PyTorch.

    Args:
        P: Program builder used to read the node's arguments, allocate slots
            and emit nodes.
        n: The FX node for the leaky_relu call. It takes one or two
            positional arguments (input, optional negative_slope) and no
            keyword arguments.

    Returns:
        The slot obtained from ``P.make_or_get_slot(n)``, which receives the
        result of the emitted ``WhereNode``.

    Raises:
        ValueError: If the input node carries no ``"val"`` metadata, since
            the dtype of the lifted constants is taken from it.
    """
    args = P.args(n)
    require_args(args, 1, 2, "aten.leaky_relu")
    require_kwargs(P.kwargs(n), set(), "aten.leaky_relu")

    x = args[0]
    # Fall back to the module-level default when the slope is omitted or None.
    if len(args) > 1 and args[1] is not None:
        negative_slope = float(args[1])
    else:
        negative_slope = _LEAKY_RELU_DEFAULT_NEGATIVE_SLOPE

    x_meta = n.args[0].meta.get("val")
    if x_meta is None:
        raise ValueError("Input tensor metadata not found for leaky_relu")
    dtype = x_meta.dtype

    # Both scalar constants are lifted with the input's dtype.
    zero_slot = emit_lifted_constant(P, 0.0, dtype)
    slope_slot = emit_lifted_constant(P, negative_slope, dtype)

    # cond = (x >= 0)
    _, cond_slot = P.make_tmp_slot()
    P.emit(
        GreaterEqualNode(
            a=P.slot_to_tid(x),
            b=P.slot_to_tid(zero_slot),
            out=P.slot_to_tid(cond_slot),
        )
    )

    # scaled = slope * x
    _, scaled_slot = P.make_tmp_slot()
    P.emit(
        MultiplyNode(
            a=P.slot_to_tid(slope_slot),
            b=P.slot_to_tid(x),
            out=P.slot_to_tid(scaled_slot),
        )
    )

    # out = where(cond, x, scaled)
    out = P.make_or_get_slot(n)
    P.emit(
        WhereNode(
            condition=P.slot_to_tid(cond_slot),
            x=P.slot_to_tid(x),
            y=P.slot_to_tid(scaled_slot),
            out=P.slot_to_tid(out),
        )
    )
    return out


@REGISTRY.register(target=[torch.ops.aten._log_softmax.default])
def _log_softmax_handler(P: MLXProgramBuilder, n: Node) -> Slot:
"""Handle aten._log_softmax.default - log of softmax.
Expand Down
16 changes: 16 additions & 0 deletions backends/mlx/runtime/MLXBackend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "MLXExecutor.h"
#include "MLXInterpreter.h"
#include "MLXLoader.h"
#include "mlx_mutable_state.h"

#include <executorch/runtime/backend/interface.h>
#include <executorch/runtime/core/error.h>
Expand Down Expand Up @@ -277,6 +278,12 @@ class MLXBackend final : public ::executorch::runtime::BackendInterface {
eval(handle->constants.tensors);
}

// Register the handle with the per-session mutable-state manager. This is
// a no-op unless a multi-session owner is active for this load (see
// mlx_mutable_state.h); single-session execution is unaffected.
mutable_state_note_handle(
handle, &handle->program, &handle->mutable_buffers);

} catch (const std::exception& e) {
ET_LOG(Error, "Failed to load MLX program: %s", e.what());
handle->~MLXHandle();
Expand Down Expand Up @@ -366,6 +373,14 @@ class MLXBackend final : public ::executorch::runtime::BackendInterface {
}
}

// Select the active session's mutable buffers (KV cache, recurrent/conv
// state) before running. No-op for single-session handles; weights stay
// shared via ExecutionState::constants.
if (Error rebind_err = mutable_state_rebind_for_execute(h, h->state);
rebind_err != Error::Ok) {
return rebind_err;
}

// Run the MLX program (builds lazy computation graph)
h->interpreter.run(program, h->state, h->stream);

Expand Down Expand Up @@ -431,6 +446,7 @@ class MLXBackend final : public ::executorch::runtime::BackendInterface {
void destroy(DelegateHandle* handle) const override {
std::lock_guard<std::mutex> lock(mlx_global_mutex());
if (handle != nullptr) {
mutable_state_forget_handle(handle);
auto* mlx_handle = static_cast<MLXHandle*>(handle);
mlx_handle->~MLXHandle();
}
Expand Down
Loading
Loading