Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 162 additions & 0 deletions mlx/compile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,14 @@ bool is_ternary(const Primitive& p) {
return typeid(p) == typeid(Select);
}

// Reports whether the typeid of `p` equals that of ExpandDims.
bool is_expand_dims(const Primitive& p) {
  const auto& actual_type = typeid(p);
  return actual_type == typeid(ExpandDims);
}

// Reports whether the typeid of `p` equals that of Squeeze.
bool is_squeeze(const Primitive& p) {
  const auto& actual_type = typeid(p);
  return actual_type == typeid(Squeeze);
}

// Reports whether the typeid of `p` equals that of Broadcast.
bool is_broadcast(const Primitive& p) {
  const auto& actual_type = typeid(p);
  return actual_type == typeid(Broadcast);
}
Expand Down Expand Up @@ -557,6 +565,153 @@ struct VecU64Hash {
}
};

// Move view-only layout ops (currently ExpandDims/Squeeze) in front of
// elementwise ops to keep fusion opportunities.
//
// A pattern `view(op(in_0, ..., in_k))` is rewritten to
// `op(view(in_0), ..., view(in_k))` when all of the following hold:
//   - neither the view nor the op is a graph output or a sibling of one,
//   - neither the view nor the op has siblings,
//   - the op satisfies is_unary / is_binary / is_ternary and its primitive
//     reports the same stream as the view's primitive,
//   - `parents_map` records exactly one entry for the op,
//   - the op has at least one input and every input has the op's shape.
//
// `tape`, `parents_map` and `outputs` are all updated in place; the parents
// map is rebuilt from the new tape on every call. Returns true if at least one
// rewrite was performed, so the caller can iterate until nothing changes.
bool compile_layout_mover(
    std::vector<array>& tape,
    ParentsMap& parents_map,
    std::vector<array>& outputs) {
  // id of a removed array (view or its producer) -> the array replacing it.
  std::unordered_map<uintptr_t, array> old_to_new;
  // id of a removed view -> new nodes to splice into the tape at its position.
  std::unordered_map<uintptr_t, std::vector<array>> inserted_nodes;
  // ids of every array dropped from the tape.
  std::unordered_set<uintptr_t> skip_ids;
  // ids of the graph outputs and of their siblings; these are never rewritten.
  std::unordered_set<uintptr_t> output_ids;
  output_ids.reserve(outputs.size());
  for (auto& o : outputs) {
    output_ids.insert(o.id());
    for (auto& s : o.siblings()) {
      output_ids.insert(s.id());
    }
  }

  // Total number of staged nodes, used to size the new tape.
  std::size_t num_staged = 0;

  // Identify rewrites and stage new nodes.
  for (auto& arr : tape) {
    if (!arr.has_primitive()) {
      continue;
    }
    if (output_ids.find(arr.id()) != output_ids.end()) {
      continue;
    }

    auto& prim = arr.primitive();
    bool is_view = is_expand_dims(prim) || is_squeeze(prim);
    if (!is_view || arr.inputs().size() != 1 || !arr.siblings().empty()) {
      continue;
    }

    auto parent = arr.inputs()[0];
    if (!parent.has_primitive() || !parent.siblings().empty()) {
      continue;
    }
    // Do not rewrite if the producer is a graph output.
    if (output_ids.find(parent.id()) != output_ids.end()) {
      continue;
    }
    auto& pprim = parent.primitive();
    if (!(is_unary(pprim) || is_binary(pprim) || is_ternary(pprim)) ||
        prim.stream() != pprim.stream()) {
      continue;
    }

    // The producer is about to be replaced, so the view must be its only
    // recorded consumer.
    auto pit = parents_map.find(parent.id());
    if (pit == parents_map.end() || pit->second.size() != 1) {
      continue;
    }

    // Require every producer input to already have the producer's shape so
    // that the same view axes are valid on each of them.
    bool shapes_match = !parent.inputs().empty();
    for (auto& in : parent.inputs()) {
      shapes_match &= (in.shape() == parent.shape());
    }
    if (!shapes_match) {
      continue;
    }

    // Apply a fresh copy of the view to each input of the producer.
    std::vector<array> new_inputs;
    new_inputs.reserve(parent.inputs().size());
    for (auto& in : parent.inputs()) {
      Shape out_shape;
      std::shared_ptr<Primitive> vprim;
      if (is_expand_dims(prim)) {
        auto axes = static_cast<const ExpandDims&>(prim).state();
        out_shape = ExpandDims::output_shape(in, axes);
        vprim = std::make_shared<ExpandDims>(prim.stream(), axes);
      } else if (is_squeeze(prim)) {
        auto axes = static_cast<const Squeeze&>(prim).state();
        out_shape = Squeeze::output_shape(in, axes);
        vprim = std::make_shared<Squeeze>(prim.stream(), axes);
      }
      new_inputs.emplace_back(
          out_shape, in.dtype(), std::move(vprim), std::vector<array>{in});
    }

    // Re-create the producer on top of the viewed inputs. It reuses the
    // producer's primitive and takes the shape of the original view.
    array new_parent(
        arr.shape(), parent.dtype(), parent.primitive_ptr(), new_inputs);

    // Both removed nodes map to the new producer.
    old_to_new.insert({arr.id(), new_parent});
    old_to_new.insert({parent.id(), new_parent});

    // Stage the new views followed by the new producer so the splice below
    // keeps inputs ahead of their consumer.
    std::vector<array> staged;
    staged.reserve(new_inputs.size() + 1);
    for (auto& ni : new_inputs) {
      staged.push_back(std::move(ni));
    }
    staged.push_back(std::move(new_parent));
    num_staged += staged.size();
    inserted_nodes.insert({arr.id(), std::move(staged)});
    skip_ids.insert(arr.id());
    skip_ids.insert(parent.id());
  }

  // Helper to rewrite an array if it was replaced.
  auto apply_mapping = [&](array& a) {
    auto it = old_to_new.find(a.id());
    if (it != old_to_new.end()) {
      a = it->second;
    }
  };

  // Build the new tape, inserting staged nodes where we skipped. Staged nodes
  // are keyed by the view's id, so they land at the view's old position; the
  // producer's old slot is simply dropped.
  std::vector<array> new_tape;
  new_tape.reserve(tape.size() + num_staged);
  for (auto& arr : tape) {
    if (skip_ids.find(arr.id()) != skip_ids.end()) {
      auto it = inserted_nodes.find(arr.id());
      if (it != inserted_nodes.end()) {
        for (auto& n : it->second) {
          new_tape.push_back(std::move(n));
        }
      }
      continue;
    }
    new_tape.push_back(std::move(arr));
  }

  // Apply replacements to the inputs of every node on the new tape. This walks
  // a.outputs() rather than only `a`, matching the parents-map rebuild below,
  // so that any other output of a multi-output node does not keep pointing at
  // a replaced array.
  // NOTE(review): relies on copies of `array` sharing the underlying node (as
  // the staging code above already does) -- confirm.
  for (auto& a : new_tape) {
    auto outs = a.outputs();
    for (auto& o : outs) {
      for (auto& in : o.inputs()) {
        apply_mapping(in);
      }
    }
  }

  // Rebuild parents map from scratch to stay consistent.
  ParentsMap new_parents_map;
  for (auto& a : new_tape) {
    auto outs = a.outputs();
    for (auto& o : outs) {
      const int num_inputs = static_cast<int>(o.inputs().size());
      for (int i = 0; i < num_inputs; ++i) {
        new_parents_map[o.inputs()[i].id()].push_back({o, i});
      }
    }
  }
  parents_map = std::move(new_parents_map);

  // Update outputs. Outputs are excluded from rewriting above, so this is
  // purely defensive.
  for (auto& o : outputs) {
    apply_mapping(o);
  }

  tape = std::move(new_tape);
  return !old_to_new.empty();
}

// Simplify the tape. Note that this function modifies, in place, the tape and
// the parents map (to remove orphaned arrays), and potentially the outputs.
void compile_simplify(
Expand Down Expand Up @@ -660,6 +815,13 @@ void compile_simplify(
}
}

// Move simple view-only layout ops ahead of elementwise ops to help fusion.
for (int i = 0; i < max_compile_depth; ++i) {
if (!compile_layout_mover(tape, parents_map, outputs)) {
break;
}
}

std::unordered_map<std::uintptr_t, uint32_t> tape_order;
for (uint32_t i = 0; i < tape.size(); ++i) {
tape_order.insert({tape[i].id(), i});
Expand Down
67 changes: 67 additions & 0 deletions tests/compile_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,55 @@ TEST_CASE("test simplify noops") {
set_compile_mode(CompileMode::enabled);
}

TEST_CASE("test layout mover pushes views across multiple elementwise ops") {
  // Compile without fusion so the rewritten graph can be inspected node by
  // node.
  set_compile_mode(CompileMode::no_fuse);

  // Graph under test: sin(expand_dims(abs(a + b), 0)).
  auto graph = [](const std::vector<array>& args) -> std::vector<array> {
    auto view = expand_dims(abs(args[0] + args[1]), 0);
    return {sin(view)};
  };
  auto lhs = array({1.0f, 2.0f});
  auto rhs = array({3.0f, 4.0f});
  auto result = compile(graph)({lhs, rhs})[0];

  // The output should now be fed directly by Abs, with the view gone from
  // between them.
  auto abs_arr = result.inputs()[0];
  const auto& abs_primitive = abs_arr.primitive();
  CHECK(typeid(abs_primitive) == typeid(Abs));
  CHECK_EQ(abs_arr.inputs().size(), 1);
  CHECK(abs_arr.inputs()[0].has_primitive());

  // Abs should be fed by Add.
  auto add_arr = abs_arr.inputs()[0];
  const auto& add_primitive = add_arr.primitive();
  CHECK(typeid(add_primitive) == typeid(Add));
  CHECK_EQ(add_arr.inputs().size(), 2);

  // Both operands of Add should be ExpandDims nodes.
  CHECK(add_arr.inputs()[0].has_primitive());
  const auto& first_operand_primitive = add_arr.inputs()[0].primitive();
  CHECK(typeid(first_operand_primitive) == typeid(ExpandDims));
  CHECK(add_arr.inputs()[1].has_primitive());
  const auto& second_operand_primitive = add_arr.inputs()[1].primitive();
  CHECK(typeid(second_operand_primitive) == typeid(ExpandDims));

  set_compile_mode(CompileMode::enabled);
}

TEST_CASE("test layout mover skips graph outputs") {
  set_compile_mode(CompileMode::no_fuse);

  // Graph under test: expand_dims(a + b, 0), where the view itself is the
  // graph output.
  auto graph = [](const std::vector<array>& args) -> std::vector<array> {
    auto summed = args[0] + args[1];
    return {expand_dims(summed, 0)};
  };
  auto lhs = array({1.0f, 2.0f});
  auto rhs = array({3.0f, 4.0f});
  auto result = compile(graph)({lhs, rhs})[0];

  // The output must still be the ExpandDims node.
  const auto& result_primitive = result.primitive();
  CHECK(typeid(result_primitive) == typeid(ExpandDims));

  // Its producer must still be Add, with no primitive on either operand.
  auto producer = result.inputs()[0];
  const auto& producer_primitive = producer.primitive();
  CHECK(typeid(producer_primitive) == typeid(Add));
  CHECK_FALSE(producer.inputs()[0].has_primitive());
  CHECK_FALSE(producer.inputs()[1].has_primitive());

  set_compile_mode(CompileMode::enabled);
}

auto add_diff(const std::vector<array>& inputs) {
auto a = inputs[0];
return std::vector<array>{cos(a) + sin(a)};
Expand Down Expand Up @@ -263,6 +312,12 @@ auto unary_fused_3(const std::vector<array>& inputs) {
return std::vector<array>{exp(abs(negative(sum(inputs[0], true))))};
}

auto unary_fused_layout(const std::vector<array>& inputs) {
auto added = inputs[0] + inputs[1];
auto expanded = expand_dims(added, 0);
return std::vector<array>{sin(expanded)};
}

TEST_CASE("test compile unary fused") {
// NB: some of these tests are brittle and may need to be
// updated if we change compile conditions
Expand Down Expand Up @@ -327,6 +382,18 @@ TEST_CASE("test compile unary fused") {
auto out3 = compile(unary_fused_1_diff)({array(1.0)});
CHECK(!out1[0].primitive().is_equivalent(out3[0].primitive()));
}

// Layout ops should not block fusion
{
auto cfun = compile(unary_fused_layout);
auto x = array({1.0f, 2.0f});
auto y = array({3.0f, 4.0f});
auto out = cfun({x, y})[0];
auto& prim = out.primitive();
CHECK_EQ(typeid(prim), typeid(Compiled));
auto expected = unary_fused_layout({x, y})[0];
CHECK(allclose(out, expected).item<bool>());
}
}

// All compilable
Expand Down