From f79d8fafff1003c412deee82bb2a59b66cd408ad Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Fri, 6 Mar 2026 18:38:08 -0800 Subject: [PATCH] draft --- .Package.swift/backend_mlx/dummy.swift | 0 .Package.swift/backend_mlx_debug/dummy.swift | 0 CMakePresets.json | 3 +- Package.swift | 18 +- backends/mlx/CMakeLists.txt | 16 +- backends/mlx/examples/llm/export_llm_hf.py | 347 ++++++++++++++++++- backends/mlx/llm/cache.py | 14 +- backends/mlx/llm/quantization.py | 6 + backends/mlx/ops.py | 49 ++- backends/mlx/patches/mlx_ios_metal.patch | 196 +++++++++++ backends/mlx/runtime/MLXBackend.cpp | 7 + extension/llm/export/quantize.py | 14 +- scripts/build_apple_frameworks.sh | 72 +++- tools/cmake/preset/apple_common.cmake | 1 + 14 files changed, 699 insertions(+), 44 deletions(-) create mode 100644 .Package.swift/backend_mlx/dummy.swift create mode 100644 .Package.swift/backend_mlx_debug/dummy.swift create mode 100644 backends/mlx/patches/mlx_ios_metal.patch diff --git a/.Package.swift/backend_mlx/dummy.swift b/.Package.swift/backend_mlx/dummy.swift new file mode 100644 index 00000000000..e69de29bb2d diff --git a/.Package.swift/backend_mlx_debug/dummy.swift b/.Package.swift/backend_mlx_debug/dummy.swift new file mode 100644 index 00000000000..e69de29bb2d diff --git a/CMakePresets.json b/CMakePresets.json index fa1d77623d9..c4379aaaece 100644 --- a/CMakePresets.json +++ b/CMakePresets.json @@ -45,7 +45,8 @@ "CMAKE_TOOLCHAIN_FILE": "${sourceDir}/third-party/ios-cmake/ios.toolchain.cmake", "EXECUTORCH_BUILD_PRESET_FILE": "${sourceDir}/tools/cmake/preset/macos.cmake", "PLATFORM": "MAC_ARM64", - "DEPLOYMENT_TARGET": "12.0", + "DEPLOYMENT_TARGET": "14.0", + "CMAKE_OSX_DEPLOYMENT_TARGET": "14.0", "CMAKE_MACOSX_BUNDLE": "OFF" }, "condition": { diff --git a/Package.swift b/Package.swift index 3186284f5f6..195ce27bf0f 100644 --- a/Package.swift +++ b/Package.swift @@ -52,6 +52,13 @@ let products = deliverables([ "sqlite3", ], ], + "backend_mlx": [ + "frameworks": [ + "Metal", + "MetalPerformanceShaders", + ], + "forceLoad": true, + ], "backend_mps": [ "frameworks": [ "Metal", @@ -113,15 +120,20 @@ for (key, value) in products { name: key, path: "cmake-out/\(key).xcframework" )) + let forceLoad = value["forceLoad"] as? Bool ?? false + var linkerSettings: [LinkerSetting] = + (value["frameworks"] as? [String] ?? []).map { .linkedFramework($0) } + + (value["libraries"] as? [String] ?? []).map { .linkedLibrary($0) } + if forceLoad { + linkerSettings.append(.unsafeFlags(["-all_load"])) + } let target: Target = .target( name: "\(key)\(dependencies_suffix)", dependencies: ([key] + (value["targets"] as? [String] ?? []).map { key.hasSuffix(debug_suffix) ? $0 + debug_suffix : $0 }).map { .target(name: $0) }, path: ".Package.swift/\(key)", - linkerSettings: - (value["frameworks"] as? [String] ?? []).map { .linkedFramework($0) } + - (value["libraries"] as? [String] ?? []).map { .linkedLibrary($0) } + linkerSettings: linkerSettings ) packageTargets.append(target) } diff --git a/backends/mlx/CMakeLists.txt b/backends/mlx/CMakeLists.txt index 00e7c497b1c..3ab65e56f56 100644 --- a/backends/mlx/CMakeLists.txt +++ b/backends/mlx/CMakeLists.txt @@ -212,10 +212,24 @@ set(MLX_METAL_JIT CACHE BOOL "Use JIT compiled Metal kernels" ) +# For iOS builds, embed the metallib in the library so apps don't need to +# bundle it separately. This avoids the runtime "metallib not found" issue +# with static xcframeworks. 
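+#
+# Mechanism, in brief: mlx_ios_metal.patch (below) makes the kernels build run
+# `xxd -i mlx.metallib` to turn the compiled metallib into a C header defining
+# `mlx_metallib` / `mlx_metallib_len`, and device.cpp then creates the Metal
+# library from those embedded bytes (dispatch_data_create + newLibrary) before
+# falling back to the usual file-based search.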
+if(CMAKE_SYSTEM_NAME MATCHES "iOS") + set(MLX_EMBED_METALLIB + ON + CACHE BOOL "Embed metallib in static library for iOS" FORCE + ) + message(STATUS "iOS build: embedding metallib in static library") +endif() + # Auto-apply patches to MLX submodule. Each patch is applied idempotently: `git # apply --check` tests whether the patch is still applicable (i.e. not yet # applied), and only then applies it. -set(_mlx_patches "${CMAKE_CURRENT_SOURCE_DIR}/patches/mlx_json.patch") +# Note: mlx_ios_metal.patch includes the json changes from mlx_json.patch +set(_mlx_patches + "${CMAKE_CURRENT_SOURCE_DIR}/patches/mlx_ios_metal.patch" +) foreach(_patch IN LISTS _mlx_patches) if(EXISTS "${_patch}" AND EXISTS "${MLX_SOURCE_DIR}") get_filename_component(_patch_name "${_patch}" NAME) diff --git a/backends/mlx/examples/llm/export_llm_hf.py b/backends/mlx/examples/llm/export_llm_hf.py index 39f13e434be..c5e9233aa87 100644 --- a/backends/mlx/examples/llm/export_llm_hf.py +++ b/backends/mlx/examples/llm/export_llm_hf.py @@ -58,6 +58,8 @@ def _export_with_optimum( no_tie_word_embeddings: bool = False, qlinear_group_size: Optional[int] = None, qembedding_group_size: Optional[int] = None, + multimodal_only: bool = False, + nvfp4_per_tensor_scale: bool = False, ) -> None: import executorch.exir as exir from executorch.backends.mlx import MLXPartitioner @@ -89,20 +91,174 @@ def _export_with_optimum( exportable.model.config, "tie_word_embeddings", False ) and not no_tie_word_embeddings, + skip_incompatible_shapes=True, # Skip vision tower layers with odd shapes + nvfp4_per_tensor_scale=nvfp4_per_tensor_scale, ) logger.info("Exporting model with torch.export...") exported_progs = exportable.export() + if len(exported_progs) == 1: + exported_progs = {"forward": next(iter(exported_progs.values()))} + + # Skip forward if --multimodal-only is set + if multimodal_only and "forward" in exported_progs: + logger.info("Removing 'forward' export (--multimodal-only)") + del exported_progs["forward"] + + # Add multimodal export methods (token_embedding and text_decoder) + # for compatibility with MultimodalRunner + logger.info("Adding multimodal export methods...") + model = exportable.model + max_cache_len = exportable.metadata.get("get_max_seq_len", max_seq_len) + + torch_dtype_map = { + "float32": torch.float32, + "float16": torch.float16, + "bfloat16": torch.bfloat16, + } + torch_dtype = torch_dtype_map.get(dtype_str, torch.bfloat16) + + seq_length = 3 + example_input_ids = torch.zeros((1, seq_length), dtype=torch.long) + example_cache_position = torch.arange(seq_length, dtype=torch.long) + seq_len_dim = torch.export.Dim("seq_length_dim", max=max_cache_len - 1) + + with torch.no_grad(): + # Export token_embedding method + logger.info(" Exporting 'token_embedding' method...") + token_embedding_layer = model.get_input_embeddings() + token_embedding_dynamic_shapes = ({1: seq_len_dim},) + token_embedding_ep = torch.export.export( + token_embedding_layer, + args=(example_input_ids,), + dynamic_shapes=token_embedding_dynamic_shapes, + strict=True, + ) + exported_progs["token_embedding"] = token_embedding_ep + logger.info(" token_embedding export completed") + + # Export text_decoder method + logger.info(" Exporting 'text_decoder' method...") + # Handle nested configs (e.g., Gemma3 has text_config) + if hasattr(model.config, "text_config"): + hidden_size = model.config.text_config.hidden_size + else: + hidden_size = model.config.hidden_size + example_inputs_embeds = torch.zeros( + (1, seq_length, hidden_size), dtype=torch_dtype + ) 
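+        # The dict below maps each kwarg name to {dim_index: Dim}: dim 1 of
+        # inputs_embeds and dim 0 of cache_position share seq_len_dim, so the
+        # exported program keeps the sequence length symbolic (up to
+        # max_cache_len - 1) across both inputs.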
+ text_decoder_dynamic_shapes = { + "inputs_embeds": {1: seq_len_dim}, + "cache_position": {0: seq_len_dim}, + } + + class TextDecoderWrapper(torch.nn.Module): + def __init__(self, exportable_module): + super().__init__() + self.exportable = exportable_module + + def forward(self, inputs_embeds, cache_position): + if hasattr(self.exportable, "cache"): + cache = self.exportable.cache + elif hasattr(self.exportable, "static_cache"): + cache = self.exportable.static_cache + else: + cache = None + + outputs = self.exportable.model( + inputs_embeds=inputs_embeds, + cache_position=cache_position, + past_key_values=cache, + use_cache=True, + ) + return outputs.logits + + text_decoder_wrapper = TextDecoderWrapper(exportable) + text_decoder_ep = torch.export.export( + text_decoder_wrapper, + args=(), + kwargs={ + "inputs_embeds": example_inputs_embeds, + "cache_position": example_cache_position, + }, + dynamic_shapes=text_decoder_dynamic_shapes, + strict=True, + ) + exported_progs["text_decoder"] = text_decoder_ep + logger.info(" text_decoder export completed") + + # Export vision_encoder method (for multimodal models with vision tower) + if hasattr(model, "get_image_features") or hasattr(model, "vision_tower"): + logger.info(" Exporting 'vision_encoder' method...") + + class VisionEncoderWrapper(torch.nn.Module): + def __init__(self, model): + super().__init__() + self.model = model + + def forward(self, input_features): + image_embeds = self.model.get_image_features(input_features) + if isinstance(image_embeds, list): + image_embeds = torch.stack(image_embeds) + return image_embeds + + vision_encoder = VisionEncoderWrapper(model) + + try: + from transformers import AutoProcessor + + processor = AutoProcessor.from_pretrained(model_id) + sample_conversation = [ + { + "role": "user", + "content": [ + { + "type": "image", + "url": "https://llava-vl.github.io/static/images/view.jpg", + }, + ], + }, + ] + processed_inputs = processor.apply_chat_template( + sample_conversation, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt", + ) + if "pixel_values" in processed_inputs: + example_pixel_values = processed_inputs["pixel_values"].to( + dtype=torch_dtype + ) + logger.info( + f" Using pixel_values shape: {example_pixel_values.shape}" + ) + + vision_encoder_ep = torch.export.export( + vision_encoder, + args=(), + kwargs={"input_features": example_pixel_values}, + dynamic_shapes=None, + strict=True, + ) + exported_progs["vision_encoder"] = vision_encoder_ep + logger.info(" vision_encoder export completed") + else: + logger.warning( + " Skipping vision_encoder: processor didn't return pixel_values" + ) + except Exception as e: + logger.warning(f" Skipping vision_encoder export: {e}") + else: + logger.info(" Skipping vision_encoder: model has no vision tower") + logger.info("Delegating to MLX backend...") edge_config = EdgeCompileConfig( _check_ir_validity=False, _skip_dim_order=True, ) - if len(exported_progs) == 1: - exported_progs = {"forward": next(iter(exported_progs.values()))} - edge_program = exir.to_edge_transform_and_lower( exported_progs, transform_passes=get_default_passes(), @@ -134,6 +290,8 @@ def _export_with_custom_components( no_tie_word_embeddings: bool = False, qlinear_group_size: Optional[int] = None, qembedding_group_size: Optional[int] = None, + multimodal_only: bool = False, + nvfp4_per_tensor_scale: bool = False, ) -> None: """ Export using direct HF model with custom MLX components. 
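
How the exported methods fit together, as a minimal sketch (method names and
signatures follow the export code above; the splice step is model-specific and
shown schematically, since the real consumer is the C++ MultimodalRunner):

    import torch

    def prefill(token_embedding, vision_encoder, text_decoder,
                input_ids, pixel_values, image_positions):
        embeds = token_embedding(input_ids)          # (1, seq, hidden)
        image_embeds = vision_encoder(pixel_values)  # model-specific shape
        # Replace placeholder token embeddings with image embeddings
        # (schematic; real runners locate image tokens via the tokenizer).
        for pos, img in zip(image_positions, image_embeds):
            embeds[0, pos] = img
        cache_position = torch.arange(embeds.shape[1])
        return text_decoder(embeds, cache_position)  # logits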
@@ -276,6 +434,8 @@ def _export_with_custom_components( qembedding_group_size=qembedding_group_size, tie_word_embeddings=getattr(model.config, "tie_word_embeddings", False) and not no_tie_word_embeddings, + skip_incompatible_shapes=True, # Skip vision tower layers with odd shapes + nvfp4_per_tensor_scale=nvfp4_per_tensor_scale, ) logger.info("Exporting model with torch.export...") @@ -289,21 +449,163 @@ def _export_with_custom_components( "cache_position": {0: seq_len_dim}, } + exported_programs = {} + with torch.no_grad(): - exported_program = torch.export.export( - exportable, + # 1. Export "forward" method (BC for TextLLMRunner - takes input_ids) + # Skip if --multimodal-only is set (reduces model size ~2x) + if not multimodal_only: + logger.info("Exporting 'forward' method (input_ids -> logits)...") + forward_ep = torch.export.export( + exportable, + args=(), + kwargs={ + "input_ids": example_input_ids, + "cache_position": example_cache_position, + }, + dynamic_shapes=dynamic_shapes, + strict=True, + ) + exported_programs["forward"] = forward_ep + logger.info(" forward export completed") + else: + logger.info("Skipping 'forward' export (--multimodal-only)") + + # 2. Export "token_embedding" method (for MultimodalRunner) + logger.info("Exporting 'token_embedding' method (input_ids -> embeddings)...") + token_embedding_layer = model.get_input_embeddings() + token_embedding_dynamic_shapes = ({1: seq_len_dim},) + token_embedding_ep = torch.export.export( + token_embedding_layer, + args=(example_input_ids,), + dynamic_shapes=token_embedding_dynamic_shapes, + strict=True, + ) + exported_programs["token_embedding"] = token_embedding_ep + logger.info(" token_embedding export completed") + + # 3. Export "text_decoder" method (for MultimodalRunner - takes inputs_embeds) + logger.info("Exporting 'text_decoder' method (inputs_embeds -> logits)...") + # Handle nested configs (e.g., Gemma3 has text_config) + if hasattr(model.config, "text_config"): + hidden_size = model.config.text_config.hidden_size + else: + hidden_size = model.config.hidden_size + example_inputs_embeds = torch.zeros( + (1, seq_length, hidden_size), dtype=torch_dtype + ) + text_decoder_dynamic_shapes = { + "inputs_embeds": {1: seq_len_dim}, + "cache_position": {0: seq_len_dim}, + } + + # Create a wrapper that takes inputs_embeds instead of input_ids + class TextDecoderWrapper(torch.nn.Module): + def __init__(self, exportable_module): + super().__init__() + self.exportable = exportable_module + + def forward(self, inputs_embeds, cache_position): + # Get the cache from the exportable module + if hasattr(self.exportable, "cache"): + cache = self.exportable.cache + elif hasattr(self.exportable, "static_cache"): + cache = self.exportable.static_cache + else: + cache = None + + # Call model with inputs_embeds instead of input_ids + outputs = self.exportable.model( + inputs_embeds=inputs_embeds, + cache_position=cache_position, + past_key_values=cache, + use_cache=True, + ) + return outputs.logits + + text_decoder_wrapper = TextDecoderWrapper(exportable) + text_decoder_ep = torch.export.export( + text_decoder_wrapper, args=(), kwargs={ - "input_ids": example_input_ids, + "inputs_embeds": example_inputs_embeds, "cache_position": example_cache_position, }, - dynamic_shapes=dynamic_shapes, + dynamic_shapes=text_decoder_dynamic_shapes, strict=True, ) + exported_programs["text_decoder"] = text_decoder_ep + logger.info(" text_decoder export completed") + + # 4. 
Export "vision_encoder" method (for multimodal models with vision tower) + if hasattr(model, "get_image_features") or hasattr(model, "vision_tower"): + logger.info("Exporting 'vision_encoder' method (pixel_values -> image_embeds)...") + + class VisionEncoderWrapper(torch.nn.Module): + def __init__(self, model): + super().__init__() + self.model = model + + def forward(self, input_features): + image_embeds = self.model.get_image_features(input_features) + if isinstance(image_embeds, list): + image_embeds = torch.stack(image_embeds) + return image_embeds + + vision_encoder = VisionEncoderWrapper(model) + + # Get example input from processor + try: + from transformers import AutoProcessor + + processor = AutoProcessor.from_pretrained(model_id) + sample_conversation = [ + { + "role": "user", + "content": [ + { + "type": "image", + "url": "https://llava-vl.github.io/static/images/view.jpg", + }, + ], + }, + ] + processed_inputs = processor.apply_chat_template( + sample_conversation, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt", + ) + if "pixel_values" in processed_inputs: + example_pixel_values = processed_inputs["pixel_values"].to( + dtype=torch_dtype + ) + logger.info( + f" Using pixel_values shape: {example_pixel_values.shape}" + ) + + vision_encoder_ep = torch.export.export( + vision_encoder, + args=(), + kwargs={"input_features": example_pixel_values}, + dynamic_shapes=None, # No dynamic shapes for now + strict=True, + ) + exported_programs["vision_encoder"] = vision_encoder_ep + logger.info(" vision_encoder export completed") + else: + logger.warning( + " Skipping vision_encoder: processor didn't return pixel_values" + ) + except Exception as e: + logger.warning(f" Skipping vision_encoder export: {e}") + else: + logger.info(" Skipping vision_encoder: model has no vision tower") logger.info("Export completed successfully") - for sym, constraint in exported_program.range_constraints.items(): - logger.info(f" Range constraint: {sym}: {constraint}") + for name, ep in exported_programs.items(): + logger.info(f" {name}: {len(ep.range_constraints)} range constraints") logger.info("Delegating to MLX backend...") edge_config = EdgeCompileConfig( @@ -311,11 +613,22 @@ def _export_with_custom_components( _skip_dim_order=True, ) + # Build metadata methods for the etLLM app + metadata = { + "get_max_seq_len": effective_cache_len, + "get_max_context_len": effective_cache_len, + "use_kv_cache": True, + "use_sdpa_with_kv_cache": use_custom_sdpa, + "enable_dynamic_shape": True, + } + logger.info(f"Exporting with metadata: {metadata}") + edge_program = exir.to_edge_transform_and_lower( - {"forward": exported_program}, + exported_programs, transform_passes=get_default_passes(), partitioner=[MLXPartitioner()], compile_config=edge_config, + constant_methods=metadata, ) logger.info("Exporting to ExecuTorch...") @@ -351,6 +664,8 @@ def export_llama_hf( no_tie_word_embeddings: bool = False, qlinear_group_size: Optional[int] = None, qembedding_group_size: Optional[int] = None, + multimodal_only: bool = False, + nvfp4_per_tensor_scale: bool = False, ) -> None: """ Export a HuggingFace Llama model to ExecuTorch with MLX backend. 
@@ -382,6 +697,8 @@ def export_llama_hf( no_tie_word_embeddings=no_tie_word_embeddings, qlinear_group_size=qlinear_group_size, qembedding_group_size=qembedding_group_size, + multimodal_only=multimodal_only, + nvfp4_per_tensor_scale=nvfp4_per_tensor_scale, ) else: logger.info("Using optimum-executorch pipeline (no custom components)") @@ -395,6 +712,8 @@ def export_llama_hf( no_tie_word_embeddings=no_tie_word_embeddings, qlinear_group_size=qlinear_group_size, qembedding_group_size=qembedding_group_size, + multimodal_only=multimodal_only, + nvfp4_per_tensor_scale=nvfp4_per_tensor_scale, ) @@ -442,6 +761,12 @@ def main(): default=False, help="Use MLX custom KV cache (mlx::kv_cache_update)", ) + parser.add_argument( + "--multimodal-only", + action="store_true", + default=False, + help="Skip 'forward' export for multimodal models (reduces size ~2x)", + ) args = parser.parse_args() @@ -457,6 +782,8 @@ def main(): no_tie_word_embeddings=args.no_tie_word_embeddings, qlinear_group_size=args.qlinear_group_size, qembedding_group_size=args.qembedding_group_size, + multimodal_only=args.multimodal_only, + nvfp4_per_tensor_scale=getattr(args, "nvfp4_per_tensor_scale", False), ) diff --git a/backends/mlx/llm/cache.py b/backends/mlx/llm/cache.py index 9709980689b..c23a68c4a2c 100644 --- a/backends/mlx/llm/cache.py +++ b/backends/mlx/llm/cache.py @@ -326,14 +326,20 @@ def __init__( device: Device for cache tensors (default: None = CPU) dtype: Data type for cache tensors (default: torch.float32) """ + # Handle nested configs (e.g., Gemma3 has text_config) + if hasattr(config, "text_config"): + text_config = config.text_config + else: + text_config = config + # Resolve dimensions from config BEFORE calling parent - num_layers = config.num_hidden_layers - num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads) + num_layers = text_config.num_hidden_layers + num_heads = getattr(text_config, "num_key_value_heads", text_config.num_attention_heads) head_dim = getattr( - config, "head_dim", config.hidden_size // config.num_attention_heads + text_config, "head_dim", text_config.hidden_size // text_config.num_attention_heads ) actual_max_cache_len = max_cache_len or getattr( - config, "max_position_embeddings", 2048 + text_config, "max_position_embeddings", 2048 ) # Initialize parent StaticCache with required arguments diff --git a/backends/mlx/llm/quantization.py b/backends/mlx/llm/quantization.py index 196e4a9ac1f..9f2c6c52be6 100644 --- a/backends/mlx/llm/quantization.py +++ b/backends/mlx/llm/quantization.py @@ -55,3 +55,9 @@ def add_quantization_args(parser: argparse.ArgumentParser) -> None: help="Disable tying lm_head weights to embedding after quantization, " "even if the model config has tie_word_embeddings=True", ) + parser.add_argument( + "--nvfp4-per-tensor-scale", + action="store_true", + default=False, + help="Enable per-tensor scale for NVFP4 quantization (improves accuracy)", + ) diff --git a/backends/mlx/ops.py b/backends/mlx/ops.py index 439d4569313..0e68ae073d2 100644 --- a/backends/mlx/ops.py +++ b/backends/mlx/ops.py @@ -2377,7 +2377,7 @@ def _convolution_handler(P: MLXProgramBuilder, n: Node) -> Slot: transposed, output_padding, groups) This op appears when PyTorch doesn't decompose to specific conv ops - (e.g. grouped conv_transpose). + (e.g. grouped conv_transpose, or vision models). 
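+
+    Illustration (a minimal sketch with a hypothetical module): a grouped
+    transposed convolution is one way this op reaches an exported graph::
+
+        import torch
+
+        class GroupedDeconv(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.deconv = torch.nn.ConvTranspose1d(4, 4, 3, groups=2)
+
+            def forward(self, x):
+                return self.deconv(x)
+
+        ep = torch.export.export(GroupedDeconv(), (torch.randn(1, 4, 8),))
+        print(ep.graph)  # depending on decompositions, look for
+                         # torch.ops.aten.convolution.default, transposed=True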
""" raw_args = n.args x_node, w_node = raw_args[0], raw_args[1] @@ -2385,11 +2385,6 @@ def _convolution_handler(P: MLXProgramBuilder, n: Node) -> Slot: transposed = raw_args[6] if len(raw_args) > 6 else False groups = raw_args[8] if len(raw_args) > 8 else 1 - if not transposed: - raise ValueError( - "aten.convolution with transposed=False: use aten.conv{1,2,3}d instead" - ) - x_meta = x_node.meta.get("val") if x_meta is None: raise ValueError("aten.convolution: input shape metadata required") @@ -2402,19 +2397,35 @@ def _convolution_handler(P: MLXProgramBuilder, n: Node) -> Slot: raw_args[7] if len(raw_args) > 7 else 0, ndim, 0 ) - return _emit_conv_transpose( - P, - n, - x_node, - w_node, - bias_node, - stride, - padding, - dilation, - output_padding, - groups, - ndim, - ) + if transposed: + # Transposed convolution (conv_transpose) + return _emit_conv_transpose( + P, + n, + x_node, + w_node, + bias_node, + stride, + padding, + dilation, + output_padding, + groups, + ndim, + ) + else: + # Normal convolution - dispatch to conv handler + return _emit_conv( + P, + n, + x_node, + w_node, + bias_node, + stride, + padding, + dilation, + groups, + ndim, + ) @REGISTRY.register(target=[torch.ops.aten.conv_transpose1d.default]) diff --git a/backends/mlx/patches/mlx_ios_metal.patch b/backends/mlx/patches/mlx_ios_metal.patch new file mode 100644 index 00000000000..8d0ab8fb3a0 --- /dev/null +++ b/backends/mlx/patches/mlx_ios_metal.patch @@ -0,0 +1,196 @@ +diff --git a/CMakeLists.txt b/CMakeLists.txt +index 041a476c..afaf7658 100644 +--- a/CMakeLists.txt ++++ b/CMakeLists.txt +@@ -55,7 +55,7 @@ message( + "Building MLX for ${CMAKE_SYSTEM_PROCESSOR} processor on ${CMAKE_SYSTEM_NAME}" + ) + +-if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin") ++if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin" OR ${CMAKE_SYSTEM_NAME} MATCHES "iOS") + if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "x86_64") + if(NOT MLX_ENABLE_X64_MAC) + message( +@@ -190,17 +190,31 @@ if(MLX_BUILD_METAL) + set(METAL_CPP_URL + https://developer.apple.com/metal/cpp/files/metal-cpp_26.zip) + +- if(NOT CMAKE_OSX_DEPLOYMENT_TARGET STREQUAL "") ++ if(NOT CMAKE_OSX_DEPLOYMENT_TARGET STREQUAL "" AND NOT ${CMAKE_SYSTEM_NAME} MATCHES "iOS") + if(${CMAKE_OSX_DEPLOYMENT_TARGET} LESS 14.0) + message(FATAL_ERROR "MLX requires macOS >= 14.0") + endif() + set(XCRUN_FLAGS "-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}") + endif() +- execute_process( +- COMMAND +- zsh "-c" +- "echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal ${XCRUN_FLAGS} -E -x metal -P - | tail -1 | tr -d '\n'" +- OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY) ++ # Detect Metal version - use appropriate SDK for the target platform ++ if(${CMAKE_SYSTEM_NAME} MATCHES "iOS") ++ if(DEPLOYMENT_TARGET) ++ set(METAL_VERSION_FLAGS "-mios-version-min=${DEPLOYMENT_TARGET}") ++ else() ++ set(METAL_VERSION_FLAGS "-mios-version-min=17.0") ++ endif() ++ execute_process( ++ COMMAND ++ zsh "-c" ++ "echo \"__METAL_VERSION__\" | xcrun -sdk iphoneos metal ${METAL_VERSION_FLAGS} -E -x metal -P - | tail -1 | tr -d '\n'" ++ OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY) ++ else() ++ execute_process( ++ COMMAND ++ zsh "-c" ++ "echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal ${XCRUN_FLAGS} -E -x metal -P - | tail -1 | tr -d '\n'" ++ OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY) ++ endif() + FetchContent_Declare(metal_cpp URL ${METAL_CPP_URL}) + FetchContent_MakeAvailable(metal_cpp) + target_include_directories( +@@ -309,13 +323,19 @@ else() + set(MLX_BUILD_ACCELERATE OFF) + 
endif() + +-message(STATUS "Downloading json") +-FetchContent_Declare( +- json +- URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz) +-FetchContent_MakeAvailable(json) +-target_include_directories( +- mlx PRIVATE $) ++# Only fetch json if nlohmann_json target doesn't already exist ++# (ExecuTorch provides its own copy) ++if(NOT TARGET nlohmann_json) ++ message(STATUS "Downloading json") ++ FetchContent_Declare( ++ json ++ URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz) ++ FetchContent_MakeAvailable(json) ++ target_include_directories( ++ mlx PRIVATE $) ++else() ++ message(STATUS "Using existing nlohmann_json target") ++endif() + + add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx) + +diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp +index 15824d6c..babc14bb 100644 +--- a/mlx/backend/metal/device.cpp ++++ b/mlx/backend/metal/device.cpp +@@ -13,6 +13,11 @@ + #include "mlx/backend/metal/utils.h" + #include "mlx/utils.h" + ++#if defined(MLX_EMBED_METALLIB) ++#include "mlx_metallib_embedded.h" ++#include ++#endif ++ + namespace mlx::core::metal { + + namespace { +@@ -134,6 +139,26 @@ std::pair load_swiftpm_library( + } + + MTL::Library* load_default_library(MTL::Device* device) { ++#if defined(MLX_EMBED_METALLIB) ++ // Try embedded metallib first (for iOS static xcframework distribution) ++ { ++ dispatch_data_t data = dispatch_data_create( ++ mlx_metallib, ++ mlx_metallib_len, ++ nullptr, ++ DISPATCH_DATA_DESTRUCTOR_DEFAULT); ++ ++ NS::Error* error = nullptr; ++ MTL::Library* lib = device->newLibrary(data, &error); ++ dispatch_release(data); ++ ++ if (lib) { ++ return lib; ++ } ++ // Fall through to file-based search if embedded load fails ++ } ++#endif ++ + NS::Error* error[5]; + MTL::Library* lib; + // First try the colocated mlx.metallib +diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt +index 8d3d8a19..e6dd17cd 100644 +--- a/mlx/backend/metal/kernels/CMakeLists.txt ++++ b/mlx/backend/metal/kernels/CMakeLists.txt +@@ -9,6 +9,19 @@ set(BASE_HEADERS + logging.h + utils.h) + ++# Determine SDK and deployment target flags for Metal compilation ++if(${CMAKE_SYSTEM_NAME} MATCHES "iOS") ++ set(METAL_SDK "iphoneos") ++ if(DEPLOYMENT_TARGET) ++ set(METAL_TARGET_FLAGS "-mios-version-min=${DEPLOYMENT_TARGET}") ++ else() ++ set(METAL_TARGET_FLAGS "-mios-version-min=17.0") ++ endif() ++else() ++ set(METAL_SDK "macosx") ++ set(METAL_TARGET_FLAGS "") ++endif() ++ + function(build_kernel_base TARGET SRCFILE DEPS) + set(METAL_FLAGS + -x +@@ -24,12 +37,14 @@ function(build_kernel_base TARGET SRCFILE DEPS) + if(CMAKE_BUILD_TYPE STREQUAL "Debug" AND MLX_METAL_VERSION GREATER_EQUAL 320) + set(METAL_FLAGS ${METAL_FLAGS} -fmetal-enable-logging) + endif() +- if(NOT CMAKE_OSX_DEPLOYMENT_TARGET STREQUAL "") ++ if(NOT CMAKE_OSX_DEPLOYMENT_TARGET STREQUAL "" AND NOT ${CMAKE_SYSTEM_NAME} MATCHES "iOS") + set(METAL_FLAGS ${METAL_FLAGS} + "-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}") ++ elseif(METAL_TARGET_FLAGS) ++ set(METAL_FLAGS ${METAL_FLAGS} ${METAL_TARGET_FLAGS}) + endif() + add_custom_command( +- COMMAND xcrun -sdk macosx metal ${METAL_FLAGS} -c ${SRCFILE} ++ COMMAND xcrun -sdk ${METAL_SDK} metal ${METAL_FLAGS} -c ${SRCFILE} + -I${PROJECT_SOURCE_DIR} -o ${TARGET}.air + DEPENDS ${SRCFILE} ${DEPS} ${BASE_HEADERS} + OUTPUT ${TARGET}.air +@@ -176,7 +191,7 @@ endif() + + add_custom_command( + OUTPUT ${MLX_METAL_PATH}/mlx.metallib +- COMMAND xcrun -sdk macosx metallib ${KERNEL_AIR} -o ++ 
COMMAND xcrun -sdk ${METAL_SDK} metallib ${KERNEL_AIR} -o + ${MLX_METAL_PATH}/mlx.metallib + DEPENDS ${KERNEL_AIR} + COMMENT "Building mlx.metallib" +@@ -184,6 +199,24 @@ add_custom_command( + + add_custom_target(mlx-metallib DEPENDS ${MLX_METAL_PATH}/mlx.metallib) + ++# Generate embedded metallib header for iOS builds (no runtime file search needed) ++if(MLX_EMBED_METALLIB) ++ add_custom_command( ++ OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/mlx_metallib_embedded.h ++ COMMAND xxd -i mlx.metallib > ${CMAKE_CURRENT_BINARY_DIR}/mlx_metallib_embedded.h ++ WORKING_DIRECTORY ${MLX_METAL_PATH} ++ DEPENDS mlx-metallib ++ COMMENT "Generating embedded metallib header for iOS" ++ VERBATIM) ++ ++ add_custom_target(mlx-metallib-embedded ++ DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/mlx_metallib_embedded.h) ++ ++ add_dependencies(mlx mlx-metallib-embedded) ++ target_include_directories(mlx PRIVATE ${CMAKE_CURRENT_BINARY_DIR}) ++ target_compile_definitions(mlx PRIVATE MLX_EMBED_METALLIB=1) ++endif() ++ + add_dependencies(mlx mlx-metallib) + + # Install metallib diff --git a/backends/mlx/runtime/MLXBackend.cpp b/backends/mlx/runtime/MLXBackend.cpp index 99e20114ea7..cae9e89f640 100644 --- a/backends/mlx/runtime/MLXBackend.cpp +++ b/backends/mlx/runtime/MLXBackend.cpp @@ -13,7 +13,14 @@ #include #include #include + +// Suppress warnings from ExecuTorch headers that trigger -Wshorten-64-to-32 +// These are in core runtime headers and cannot be fixed in MLX backend +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wshorten-64-to-32" #include +#pragma clang diagnostic pop + #include #include diff --git a/extension/llm/export/quantize.py b/extension/llm/export/quantize.py index fb2678ff60f..ff2bf2412fc 100644 --- a/extension/llm/export/quantize.py +++ b/extension/llm/export/quantize.py @@ -25,7 +25,9 @@ def _make_granularity(group_size: int): return PerAxis(0) if group_size == 0 else PerGroup(group_size) -def _make_linear_config(config_name: str, group_size: int, packing_format=None): +def _make_linear_config( + config_name: str, group_size: int, packing_format=None, use_per_tensor_scale=False +): """Build a TorchAO config for linear layer quantization.""" from torchao.quantization.quant_api import ( Int4WeightOnlyConfig, @@ -39,7 +41,7 @@ def _make_linear_config(config_name: str, group_size: int, packing_format=None): from executorch.extension.llm.export.nvfp4 import ExportableNVFP4Config assert group_size == 16, "NVFP4 requires group_size=16" - return ExportableNVFP4Config(use_per_tensor_scale=False) + return ExportableNVFP4Config(use_per_tensor_scale=use_per_tensor_scale) elif config_name == "4w": if packing_format: return Int4WeightOnlyConfig( @@ -180,6 +182,7 @@ def quantize_model_( qembedding_group_size: Optional[int] = None, tie_word_embeddings: bool = False, skip_incompatible_shapes: bool = False, + nvfp4_per_tensor_scale: bool = False, ) -> None: """Quantize linear and embedding layers in a module in-place. @@ -204,6 +207,8 @@ def quantize_model_( after quantization. skip_incompatible_shapes: If True, silently skip layers with incompatible weight shapes. If False (default), raise RuntimeError. + nvfp4_per_tensor_scale: If True, enable per-tensor scale for NVFP4 + quantization which improves accuracy. 
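+
+    Example (a minimal sketch; assumes a loaded ``nn.Module`` bound to
+    ``model`` and the keyword names documented above)::
+
+        quantize_model_(
+            model,
+            qlinear_config="nvfp4",
+            qlinear_group_size=16,  # NVFP4 requires group_size=16
+            nvfp4_per_tensor_scale=True,
+        )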
""" if not qlinear_config and not qembedding_config: return @@ -256,7 +261,10 @@ def quantize_model_( # Quantize linear layers if qlinear_config: config = _make_linear_config( - qlinear_config, qlinear_group_size, qlinear_packing_format + qlinear_config, + qlinear_group_size, + qlinear_packing_format, + use_per_tensor_scale=nvfp4_per_tensor_scale, ) print( f" Applying {qlinear_config} linear quantization " diff --git a/scripts/build_apple_frameworks.sh b/scripts/build_apple_frameworks.sh index 63fa4cf4545..1ba28c35bf0 100755 --- a/scripts/build_apple_frameworks.sh +++ b/scripts/build_apple_frameworks.sh @@ -15,7 +15,7 @@ PRESETS_RELATIVE_OUT_DIR=("ios" "simulator" "macos") SOURCE_ROOT_DIR=$(git rev-parse --show-toplevel) OUTPUT_DIR="${SOURCE_ROOT_DIR}/cmake-out" -BUCK2=$(python3 "$SOURCE_ROOT_DIR/tools/cmake/resolve_buck.py" --cache_dir="$SOURCE_ROOT_DIR/buck2-bin") +BUCK2=$(python "$SOURCE_ROOT_DIR/tools/cmake/resolve_buck.py" --cache_dir="$SOURCE_ROOT_DIR/buck2-bin") if [[ "$BUCK2" == "buck2" ]]; then BUCK2=$(command -v buck2) fi @@ -103,6 +103,12 @@ FRAMEWORK_BACKEND_MPS="backend_mps:\ libmpsdelegate.a,\ :" +FRAMEWORK_BACKEND_MLX="backend_mlx:\ +libmlxdelegate.a,\ +libmlx.a,\ +:" +# Note: mlx.metallib resource is handled separately after XCFramework creation + FRAMEWORK_BACKEND_XNNPACK="backend_xnnpack:\ libXNNPACK.a,\ libkleidiai.a,\ @@ -138,8 +144,10 @@ usage() { echo "Options:" echo " --Debug Build Debug version." echo " --Release Build Release version." + echo " --ios-only Only build iOS (skip macOS and simulator)." echo " --coreml Only build the Core ML backend." echo " --llm Only build the LLM custom kernels." + echo " --mlx Only build the MLX backend." echo " --mps Only build the Metal Performance Shaders backend." echo " --optimized Only build the Optimized kernels." echo " --quantized Only build the Quantized kernels." @@ -150,6 +158,8 @@ usage() { } CMAKE_OPTIONS_OVERRIDE=() +IOS_ONLY=false + set_cmake_options_override() { local option_name="$1" @@ -158,6 +168,7 @@ set_cmake_options_override() { CMAKE_OPTIONS_OVERRIDE=( "-DEXECUTORCH_BUILD_COREML=OFF" "-DEXECUTORCH_BUILD_KERNELS_LLM=OFF" + "-DEXECUTORCH_BUILD_MLX=OFF" "-DEXECUTORCH_BUILD_MPS=OFF" "-DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=OFF" "-DEXECUTORCH_BUILD_KERNELS_QUANTIZED=OFF" @@ -177,18 +188,20 @@ set_cmake_options_override() { for arg in "$@"; do case $arg in -h|--help) usage ;; - --Release) + --Release|--release) if [[ ! " ${MODES[*]:-} " =~ \bRelease\b ]]; then MODES+=("Release") fi ;; - --Debug) + --Debug|--debug) if [[ ! 
" ${MODES[*]:-} " =~ \bDebug\b ]]; then MODES+=("Debug") fi ;; + --ios-only) IOS_ONLY=true ;; --coreml) set_cmake_options_override "EXECUTORCH_BUILD_COREML";; --llm) set_cmake_options_override "EXECUTORCH_BUILD_KERNELS_LLM" ;; + --mlx) set_cmake_options_override "EXECUTORCH_BUILD_MLX" ;; --mps) set_cmake_options_override "EXECUTORCH_BUILD_MPS" ;; --optimized) set_cmake_options_override "EXECUTORCH_BUILD_KERNELS_OPTIMIZED" ;; --quantized) set_cmake_options_override "EXECUTORCH_BUILD_KERNELS_QUANTIZED" ;; @@ -206,6 +219,13 @@ if [[ ${#MODES[@]} -eq 0 ]]; then MODES=("Release" "Debug") fi +# Filter presets based on --ios-only flag +if [[ "$IOS_ONLY" == "true" ]]; then + echo "iOS-only build: skipping macOS and simulator" + PRESETS=("ios") + PRESETS_RELATIVE_OUT_DIR=("ios") +fi + echo "Building libraries" rm -rf "${OUTPUT_DIR}" @@ -314,6 +334,7 @@ for mode in "${MODES[@]}"; do append_framework_flag "" "$FRAMEWORK_EXECUTORCH_LLM" "$mode" append_framework_flag "" "$FRAMEWORK_THREADPOOL" "$mode" append_framework_flag "EXECUTORCH_BUILD_COREML" "$FRAMEWORK_BACKEND_COREML" "$mode" + append_framework_flag "EXECUTORCH_BUILD_MLX" "$FRAMEWORK_BACKEND_MLX" "$mode" append_framework_flag "EXECUTORCH_BUILD_MPS" "$FRAMEWORK_BACKEND_MPS" "$mode" append_framework_flag "EXECUTORCH_BUILD_XNNPACK" "$FRAMEWORK_BACKEND_XNNPACK" "$mode" append_framework_flag "EXECUTORCH_BUILD_KERNELS_LLM" "$FRAMEWORK_KERNELS_LLM" "$mode" @@ -323,6 +344,51 @@ for mode in "${MODES[@]}"; do cd "${OUTPUT_DIR}" "$SOURCE_ROOT_DIR"/scripts/create_frameworks.sh "${FRAMEWORK_FLAGS[@]}" + + # Bundle mlx.metallib into the MLX XCFramework Resources folder (macOS only) + # For iOS, metallib is embedded in the static library (MLX_EMBED_METALLIB=ON) + # For macOS, we still need to copy it since embedding is not enabled + for cmake_option in "${CMAKE_OPTIONS_OVERRIDE[@]:-}"; do + if [[ "$cmake_option" =~ "-DEXECUTORCH_BUILD_MLX=OFF" ]]; then + echo "Skipping mlx.metallib bundling (MLX disabled)" + continue 2 + fi + done + + xcframework_name="backend_mlx" + if [[ "$mode" != "Release" ]]; then + xcframework_name="backend_mlx_$(echo "$mode" | tr '[:upper:]' '[:lower:]')" + fi + + if [[ -d "${OUTPUT_DIR}/${xcframework_name}.xcframework" ]]; then + echo "Bundling mlx.metallib into ${xcframework_name}.xcframework (macOS slices only)" + for slice_dir in "${OUTPUT_DIR}/${xcframework_name}.xcframework"/*; do + if [[ -d "$slice_dir" && ! 
"$slice_dir" =~ Info.plist ]]; then + slice_name=$(basename "$slice_dir") + + # Skip iOS slices - metallib is embedded in the static library + if [[ "$slice_name" =~ ^ios ]]; then + echo " Skipping $slice_name (metallib embedded in library)" + continue + fi + + # For macOS slices, copy the metallib + metallib_found=false + metallib_path="${OUTPUT_DIR}/macos/backends/mlx/mlx/mlx/backend/metal/kernels/mlx.metallib" + + if [[ -f "$metallib_path" ]]; then + mkdir -p "$slice_dir/Resources" + cp "$metallib_path" "$slice_dir/Resources/mlx.metallib" + echo " Copied mlx.metallib to $slice_dir/Resources/" + metallib_found=true + fi + + if [[ "$metallib_found" == "false" ]]; then + echo " Warning: mlx.metallib not found for slice $slice_name" + fi + fi + done + fi done echo "Cleaning up" diff --git a/tools/cmake/preset/apple_common.cmake b/tools/cmake/preset/apple_common.cmake index 27ec35aa43e..9c843f9985a 100644 --- a/tools/cmake/preset/apple_common.cmake +++ b/tools/cmake/preset/apple_common.cmake @@ -18,6 +18,7 @@ add_compile_options( set_overridable_option(BUILD_TESTING OFF) set_overridable_option(EXECUTORCH_BUILD_XNNPACK ON) set_overridable_option(EXECUTORCH_BUILD_COREML ON) +set_overridable_option(EXECUTORCH_BUILD_MLX ON) set_overridable_option(EXECUTORCH_BUILD_MPS ON) set_overridable_option(EXECUTORCH_XNNPACK_ENABLE_WEIGHT_CACHE ON) set_overridable_option(EXECUTORCH_XNNPACK_SHARED_WORKSPACE ON)