[draft / not ready for review] Add prefill/decode multifunction support in ET #16552
Summary:
This diff adds multifunction export support for static Llama models on CoreML. Multifunction models export separate prefill and decode graphs with weight sharing, enabling more efficient autoregressive generation compared to the single-method approach.
### Key Changes
**CoreML Backend Compiler (`coreml_preprocess.py`)**
- Added `MULTIMETHOD_WEIGHT_SHARING_STRATEGY` enum with `NONE` and `POSITIONAL` strategies
- Added `generate_multimethod_weight_sharing_strategy_compile_spec()` to enable weight sharing across methods
- Implemented multifunction CoreML model compilation using `ct.utils.MultiFunctionDescriptor`
- When weight sharing is enabled, weights from the first method are shared positionally with subsequent methods
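The positional sharing described above can be illustrated with a small sketch. This is not the actual `coreml_preprocess.py` code (the function name and data shapes here are hypothetical); it only shows the idea: the i-th weight of each subsequent method is replaced by the i-th weight of the first method, so each shared weight is stored once.

```python
# Illustrative sketch of POSITIONAL weight sharing (hypothetical helper,
# not the real coreml_preprocess.py implementation).

def share_weights_positionally(method_weights):
    """method_weights: {method_name: [weight, ...]} in export order.
    Methods after the first reuse the first method's weights by position."""
    methods = list(method_weights)
    if not methods:
        return {}
    first = method_weights[methods[0]]
    shared = {methods[0]: list(first)}
    for name in methods[1:]:
        weights = method_weights[name]
        if len(weights) != len(first):
            raise ValueError(f"{name}: weight count mismatch")
        # Reuse the same weight objects positionally instead of duplicating.
        shared[name] = list(first)
    return shared

prefill_w = [object(), object()]
decode_w = [object(), object()]
shared = share_weights_positionally({"prefill": prefill_w, "decode": decode_w})
# Every position in "decode" now aliases the corresponding "prefill" weight.
assert all(a is b for a, b in zip(shared["prefill"], shared["decode"]))
```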
**Model Metadata (`model_metadata.h`, `serde_json.mm`)**
- Added `MethodMetadata` struct to store per-method input/output names for multifunction models
- Extended `ModelMetadata` with `methods` map and `default_method` field
- Added `is_multifunction()` helper to detect multifunction models
- Updated JSON serialization to handle the new multifunction metadata format
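To make the metadata change concrete, here is a hypothetical Python mirror of the extended format. The field names (`methods`, `default_method`) follow the diff description, but the exact on-disk JSON schema and the per-method keys are assumptions; the helper mirrors the intent of the C++ `is_multifunction()` check.

```python
# Hypothetical sketch of the extended multifunction metadata; the exact
# JSON schema is an assumption based on the field names described above.
import json

metadata_json = json.dumps({
    "methods": {
        "prefill": {"inputs": ["tokens", "input_pos"], "outputs": ["logits"]},
        "decode":  {"inputs": ["tokens", "input_pos"], "outputs": ["logits"]},
    },
    "default_method": "prefill",
})

def is_multifunction(meta: dict) -> bool:
    # Mirrors the C++ helper's intent: a model is multifunction when a
    # per-method metadata map is present; legacy models have none.
    return bool(meta.get("methods"))

meta = json.loads(metadata_json)
assert is_multifunction(meta)
assert not is_multifunction({})  # legacy single-function metadata
```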
**Runtime Changes (`ETCoreMLModelManager.mm`, `backend_delegate.mm`, `coreml_backend_delegate.mm`)**
- Updated `ETCoreMLModelManager` to set `functionName` on `MLModelConfiguration` only for multifunction models (based on `metadata.is_multifunction()`)
- Legacy single-function models continue to work with `functionName=nil`
- Added method name propagation through the delegate initialization path
- Updated model loading to use per-method input/output names when available
**Export Script (`export_static_llm_coreml.py`)**
- Added `--multifunction` flag to export models with separate prefill (seqlen=input_len) and decode (seqlen=1) methods
- Multifunction mode uses `generate_full_logits=False` for efficiency (only outputs last token logits)
- Single method mode (default) retains `generate_full_logits=True` for lookahead decoding support
- Generates combined metadata with method-specific prefixes (e.g., `decode_input_len`, `prefill_input_len`)
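The prefixing scheme above can be sketched as follows. The helper name and metadata keys are illustrative, not the export script's actual API; the point is only how per-method entries are flattened into one combined map.

```python
# Sketch of combining per-method export metadata under method-name
# prefixes (e.g. prefill_input_len, decode_input_len). Illustrative only.

def combine_method_metadata(per_method: dict) -> dict:
    combined = {}
    for method, meta in per_method.items():
        for key, value in meta.items():
            combined[f"{method}_{key}"] = value
    return combined

combined = combine_method_metadata({
    "prefill": {"input_len": 64, "max_context_len": 1024},
    "decode":  {"input_len": 1,  "max_context_len": 1024},
})
assert combined["prefill_input_len"] == 64
assert combined["decode_input_len"] == 1
```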
**New Runner (`run_static_llm_multifunction.py`)**
- Added dedicated runner for multifunction models
- Handles separate prefill and decode method execution
- Manages cache state transfer between prefill and decode phases
- Supports both 2D (generate_full_logits=False) and 3D (generate_full_logits=True) logits output
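The runner's control flow can be sketched with stub models in place of the ExecuTorch module. All names here are illustrative, not the runner's actual API; the KV cache is assumed to live in the model's mutable state, so only token ids cross the boundary, and the logits handling mirrors the 2D (last-token) vs 3D (full-sequence) cases.

```python
# Toy sketch of the multifunction prefill/decode loop (hypothetical names).

def last_token_logits(logits):
    """logits as nested lists: [batch][vocab] (2D) or [batch][seq][vocab] (3D)."""
    row = logits[0]
    if isinstance(row[0], list):  # 3D: take the final sequence position
        return row[-1]
    return row                    # 2D: already the last token's logits

def argmax(xs):
    return max(range(len(xs)), key=xs.__getitem__)

def generate(prefill, decode, prompt, max_new_tokens):
    # Prefill once over the prompt, then greedily decode one token at a time.
    logits = prefill(prompt)
    tokens = [argmax(last_token_logits(logits))]
    for _ in range(max_new_tokens - 1):
        logits = decode(tokens[-1])
        tokens.append(argmax(last_token_logits(logits)))
    return tokens

# Stub models over a 5-token vocab that always predict (last token + 1) % 5;
# prefill returns 2D logits, decode returns 3D, to exercise both paths.
prefill_stub = lambda toks: [[1.0 if v == (toks[-1] + 1) % 5 else 0.0
                              for v in range(5)]]
decode_stub = lambda tok: [[[1.0 if v == (tok + 1) % 5 else 0.0
                             for v in range(5)]]]
assert generate(prefill_stub, decode_stub, [0, 1, 2], 4) == [3, 4, 0, 1]
```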
**Build System (`CMakeLists.txt`)**
- Fixed installation of CoreML backend headers
**Utilities (`extract_coreml_models.py`)**
- Updated model extraction script to handle multifunction models
**Documentation (`README.md`)**
- Added documentation for both export modes (single method and multifunction)
- Added comprehensive export options reference table
- Added usage examples for both modes
### Usage Examples
**Single Method Export (for lookahead decoding):**
```bash
python examples/apple/coreml/llama/export_static_llm_coreml.py \
--checkpoint $HOME/models/llama1b/llama1b.pth \
--params $HOME/models/llama1b/params.json \
--output static_llm_coreml_model.pte \
--input_len 32 \
--max_context_len 1024
```
**Multifunction Export (separate prefill/decode):**
```bash
python examples/apple/coreml/llama/export_static_llm_coreml.py \
--checkpoint $HOME/models/llama1b/llama1b.pth \
--params $HOME/models/llama1b/params.json \
--output static_llm_coreml_multifunction.pte \
--input_len 64 \
--max_context_len 1024 \
--multifunction
```
**Run Single Method Model (with lookahead):**
```bash
python examples/apple/coreml/llama/run_static_llm.py \
--model static_llm_coreml_model.pte \
--params $HOME/models/llama1b/params.json \
--tokenizer $HOME/models/llama1b/tokenizer.model \
--prompt "Once upon a time" \
--max_new_tokens 100 \
--lookahead
```
**Run Multifunction Model:**
```bash
python examples/apple/coreml/llama/run_static_llm_multifunction.py \
--model static_llm_coreml_multifunction.pte \
--params $HOME/models/llama1b/params.json \
--tokenizer $HOME/models/llama1b/tokenizer.model \
--prompt "Once upon a time" \
--max_new_tokens 100 \
--input_len 64 \
--max_context_len 1024
```
### Mode Comparison
| Feature | Single Method | Multifunction |
|---------|---------------|---------------|
| Sequence length | Fixed (input_len for both prefill & decode) | Separate (input_len for prefill, 1 for decode) |
| Logits output | Full (all tokens) | Last token only |
| Lookahead decoding | ✅ Supported | ❌ Not supported |
| Weight sharing | N/A | ✅ Enabled |
| Generation efficiency | Good with lookahead | Optimized decode step |
Test Plan:
Added a new unit test. Also tested both export modes on Llama 1B:
1. Exported single method model with `--input_len 32 --max_context_len 1024`
2. Exported multifunction model with `--input_len 64 --max_context_len 1024 --multifunction`
3. Ran single method model with `--lookahead` flag
4. Ran multifunction model with matching input_len and max_context_len
5. Verified text generation produces coherent output for both modes
Differential Revision: D90716824
Pulled By: metascroy