// Review preamble (text before the first `diff --git` header).
//
// This patch touches mlx/compile.cpp and tests/compile_tests.cpp.
//
// mlx/compile.cpp
//   * is_expand_dims / is_squeeze: typeid checks against ExpandDims / Squeeze,
//     alongside the existing is_ternary / is_broadcast predicates.
//   * compile_layout_mover(tape, parents_map, outputs): one rewrite pass that
//     moves an ExpandDims/Squeeze node from after an elementwise producer to
//     that producer's inputs. Returns true when at least one rewrite happened.
//       1. Collect the ids of every output and of each output's siblings.
//       2. A tape entry `arr` is a candidate only if all of these hold:
//            - it has a primitive and is not in the output id set;
//            - its primitive is ExpandDims or Squeeze, with exactly one input
//              and no siblings;
//            - its input `parent` has a primitive, has no siblings, and is
//              not in the output id set;
//            - parent's primitive is unary, binary or ternary and is on the
//              same stream as arr's primitive;
//            - parents_map holds exactly one entry for parent;
//            - parent has at least one input, and every input has the same
//              shape as parent.
//       3. For each parent input, build a new array that applies the same
//          view primitive (same axes, arr's stream) to that input. Then build
//          a new parent with arr.shape(), parent's dtype and parent's
//          primitive over those new views.
//       4. Map both arr.id() and parent.id() to the new parent. Stage the new
//          views followed by the new parent under arr.id(), and mark both old
//          ids to be skipped.
//       5. Rebuild the tape: skipped entries are dropped, and the staged
//          nodes are emitted at the position where arr was.
//       6. Remap the inputs of every node in the new tape through the
//          old -> new map (this includes the freshly staged nodes).
//       7. Rebuild parents_map from scratch from outputs() of each tape node,
//          remap `outputs`, and replace `tape` with the new tape.
//   * compile_simplify: calls compile_layout_mover repeatedly, at most
//     max_compile_depth times, stopping as soon as a pass reports no rewrite.
//
// tests/compile_tests.cpp
//   * "test layout mover pushes views across multiple elementwise ops"
//     (no_fuse mode): for sin(expand_dims(abs(x + y), 0)) it checks that the
//     output's input is Abs, whose input is Add, whose two inputs are both
//     ExpandDims.
//   * "test layout mover skips graph outputs" (no_fuse mode): when
//     expand_dims(x + y, 0) is itself the output, the output stays
//     ExpandDims, its input stays Add, and Add's inputs have no primitive.
//   * unary_fused_layout plus a new section in "test compile unary fused":
//     the compiled output's primitive is Compiled and matches the uncompiled
//     result under allclose.
//
// NOTE(review): several declarations below carry no template arguments
//   (e.g. `std::vector& tape`, `std::unordered_map old_to_new`,
//   `static_cast(prim)`, `std::make_shared(...)`, `.item()`). This looks like
//   extraction damage rather than the real patch text -- confirm against the
//   upstream change before applying.
// NOTE(review): `new_tape.reserve(tape.size() + inserted_nodes.size())` adds
//   one slot per rewrite, while a rewrite of an n-input parent changes the
//   tape length by n - 1 (n views + 1 new parent, minus the 2 skipped nodes).
//   The hint is therefore short by one for each ternary parent. It is only a
//   capacity hint, so this is not a correctness issue.
// NOTE(review): `for (int i = 0; i < o.inputs().size(); ++i)` compares an int
//   against the result of size(); presumably a signed/unsigned comparison --
//   confirm against the file's warning policy.
// NOTE(review): the view-building loop creates one new node per entry of
//   parent.inputs(), so a parent that uses the same input twice gets two
//   separate view nodes over that input.
// NOTE(review): both keys inserted into old_to_new (arr.id(), parent.id())
//   are first checked against output_ids, so the final "Update outputs" loop
//   appears unable to match anything with the current guards -- verify
//   whether it is intentional future-proofing.
diff --git a/mlx/compile.cpp b/mlx/compile.cpp index ca5f069937..951ce7f04a 100644 --- a/mlx/compile.cpp +++ b/mlx/compile.cpp @@ -60,6 +60,14 @@ bool is_ternary(const Primitive& p) { return typeid(p) == typeid(Select); } +bool is_expand_dims(const Primitive& p) { + return typeid(p) == typeid(ExpandDims); +} + +bool is_squeeze(const Primitive& p) { + return typeid(p) == typeid(Squeeze); +} + bool is_broadcast(const Primitive& p) { return typeid(p) == typeid(Broadcast); } @@ -557,6 +565,153 @@ struct VecU64Hash { } }; +// Move view-only layout ops (currently ExpandDims/Squeeze) in front of +// elementwise ops to keep fusion opportunities. +bool compile_layout_mover( + std::vector& tape, + ParentsMap& parents_map, + std::vector& outputs) { + std::unordered_map old_to_new; + std::unordered_map> inserted_nodes; + std::unordered_set skip_ids; + std::unordered_set output_ids; + output_ids.reserve(outputs.size()); + for (auto& o : outputs) { + output_ids.insert(o.id()); + for (auto& s : o.siblings()) { + output_ids.insert(s.id()); + } + } + + // Identify rewrites and stage new nodes. + for (auto& arr : tape) { + if (!arr.has_primitive()) { + continue; + } + if (output_ids.find(arr.id()) != output_ids.end()) { + continue; + } + + auto& prim = arr.primitive(); + bool is_view = is_expand_dims(prim) || is_squeeze(prim); + if (!is_view || arr.inputs().size() != 1 || !arr.siblings().empty()) { + continue; + } + + auto parent = arr.inputs()[0]; + if (!parent.has_primitive() || !parent.siblings().empty()) { + continue; + } + // Do not rewrite if the producer is a graph output. 
+ if (output_ids.find(parent.id()) != output_ids.end()) { + continue; + } + auto& pprim = parent.primitive(); + if (!(is_unary(pprim) || is_binary(pprim) || is_ternary(pprim)) || + prim.stream() != pprim.stream()) { + continue; + } + + auto pit = parents_map.find(parent.id()); + if (pit == parents_map.end() || pit->second.size() != 1) { + continue; + } + + bool shapes_match = !parent.inputs().empty(); + for (auto& in : parent.inputs()) { + shapes_match &= (in.shape() == parent.shape()); + } + if (!shapes_match) { + continue; + } + + std::vector new_inputs; + new_inputs.reserve(parent.inputs().size()); + for (auto& in : parent.inputs()) { + Shape out_shape; + std::shared_ptr vprim; + if (is_expand_dims(prim)) { + auto axes = static_cast(prim).state(); + out_shape = ExpandDims::output_shape(in, axes); + vprim = std::make_shared(prim.stream(), axes); + } else if (is_squeeze(prim)) { + auto axes = static_cast(prim).state(); + out_shape = Squeeze::output_shape(in, axes); + vprim = std::make_shared(prim.stream(), axes); + } + new_inputs.emplace_back( + out_shape, in.dtype(), std::move(vprim), std::vector{in}); + } + + array new_parent( + arr.shape(), parent.dtype(), parent.primitive_ptr(), new_inputs); + + old_to_new.insert({arr.id(), new_parent}); + old_to_new.insert({parent.id(), new_parent}); + + std::vector staged; + staged.reserve(new_inputs.size() + 1); + for (auto& ni : new_inputs) { + staged.push_back(std::move(ni)); + } + staged.push_back(std::move(new_parent)); + inserted_nodes.insert({arr.id(), std::move(staged)}); + skip_ids.insert(arr.id()); + skip_ids.insert(parent.id()); + } + + // Helper to rewrite an array if it was replaced. + auto apply_mapping = [&](array& a) { + auto it = old_to_new.find(a.id()); + if (it != old_to_new.end()) { + a = it->second; + } + }; + + // Build the new tape, inserting staged nodes where we skipped. 
+ std::vector new_tape; + new_tape.reserve(tape.size() + inserted_nodes.size()); + for (auto& arr : tape) { + if (skip_ids.find(arr.id()) != skip_ids.end()) { + auto it = inserted_nodes.find(arr.id()); + if (it != inserted_nodes.end()) { + for (auto& n : it->second) { + new_tape.push_back(std::move(n)); + } + } + continue; + } + new_tape.push_back(std::move(arr)); + } + + // Apply replacements to any nodes we kept. + for (auto& a : new_tape) { + for (auto& in : a.inputs()) { + apply_mapping(in); + } + } + + // Rebuild parents map from scratch to stay consistent. + ParentsMap new_parents_map; + for (auto& a : new_tape) { + auto outs = a.outputs(); + for (auto& o : outs) { + for (int i = 0; i < o.inputs().size(); ++i) { + new_parents_map[o.inputs()[i].id()].push_back({o, i}); + } + } + } + parents_map = std::move(new_parents_map); + + // Update outputs. + for (auto& o : outputs) { + apply_mapping(o); + } + + tape = std::move(new_tape); + return !old_to_new.empty(); +} + // Simplify the tape. Note, this function modifies in-place both the tape, // the parents map to remove orphaned arrays, and potentially the outputs void compile_simplify( @@ -660,6 +815,13 @@ void compile_simplify( } } + // Move simple view-only layout ops ahead of elementwise ops to help fusion. 
+ for (int i = 0; i < max_compile_depth; ++i) { + if (!compile_layout_mover(tape, parents_map, outputs)) { + break; + } + } + std::unordered_map tape_order; for (uint32_t i = 0; i < tape.size(); ++i) { tape_order.insert({tape[i].id(), i}); diff --git a/tests/compile_tests.cpp b/tests/compile_tests.cpp index e65cfc76f6..6be7b8a780 100644 --- a/tests/compile_tests.cpp +++ b/tests/compile_tests.cpp @@ -168,6 +168,55 @@ TEST_CASE("test simplify noops") { set_compile_mode(CompileMode::enabled); } +TEST_CASE("test layout mover pushes views across multiple elementwise ops") { + set_compile_mode(CompileMode::no_fuse); + auto fun = [](const std::vector& inputs) -> std::vector { + auto added = inputs[0] + inputs[1]; + auto absed = abs(added); + auto expanded = expand_dims(absed, 0); + return {sin(expanded)}; + }; + auto x = array({1.0f, 2.0f}); + auto y = array({3.0f, 4.0f}); + auto out = compile(fun)({x, y})[0]; + auto abs_node = out.inputs()[0]; + auto& abs_prim = abs_node.primitive(); + CHECK(typeid(abs_prim) == typeid(Abs)); + CHECK_EQ(abs_node.inputs().size(), 1); + CHECK(abs_node.inputs()[0].has_primitive()); + + auto add_node = abs_node.inputs()[0]; + auto& add_prim = add_node.primitive(); + CHECK(typeid(add_prim) == typeid(Add)); + CHECK_EQ(add_node.inputs().size(), 2); + CHECK(add_node.inputs()[0].has_primitive()); + auto& in0_prim = add_node.inputs()[0].primitive(); + CHECK(typeid(in0_prim) == typeid(ExpandDims)); + CHECK(add_node.inputs()[1].has_primitive()); + auto& in1_prim = add_node.inputs()[1].primitive(); + CHECK(typeid(in1_prim) == typeid(ExpandDims)); + set_compile_mode(CompileMode::enabled); +} + +TEST_CASE("test layout mover skips graph outputs") { + set_compile_mode(CompileMode::no_fuse); + auto fun = [](const std::vector& inputs) -> std::vector { + auto added = inputs[0] + inputs[1]; + return {expand_dims(added, 0)}; + }; + auto x = array({1.0f, 2.0f}); + auto y = array({3.0f, 4.0f}); + auto out = compile(fun)({x, y})[0]; + auto& out_prim = 
out.primitive(); + CHECK(typeid(out_prim) == typeid(ExpandDims)); + auto parent = out.inputs()[0]; + auto& parent_prim = parent.primitive(); + CHECK(typeid(parent_prim) == typeid(Add)); + CHECK_FALSE(parent.inputs()[0].has_primitive()); + CHECK_FALSE(parent.inputs()[1].has_primitive()); + set_compile_mode(CompileMode::enabled); +} + auto add_diff(const std::vector& inputs) { auto a = inputs[0]; return std::vector{cos(a) + sin(a)}; @@ -263,6 +312,12 @@ auto unary_fused_3(const std::vector& inputs) { return std::vector{exp(abs(negative(sum(inputs[0], true))))}; } +auto unary_fused_layout(const std::vector& inputs) { + auto added = inputs[0] + inputs[1]; + auto expanded = expand_dims(added, 0); + return std::vector{sin(expanded)}; +} + TEST_CASE("test compile unary fused") { // NB: some of these tests are brittle and may need to be // updated if we change compile conditions @@ -327,6 +382,18 @@ TEST_CASE("test compile unary fused") { auto out3 = compile(unary_fused_1_diff)({array(1.0)}); CHECK(!out1[0].primitive().is_equivalent(out3[0].primitive())); } + + // Layout ops should not block fusion + { + auto cfun = compile(unary_fused_layout); + auto x = array({1.0f, 2.0f}); + auto y = array({3.0f, 4.0f}); + auto out = cfun({x, y})[0]; + auto& prim = out.primitive(); + CHECK_EQ(typeid(prim), typeid(Compiled)); + auto expected = unary_fused_layout({x, y})[0]; + CHECK(allclose(out, expected).item()); + } } // All compilable