diff --git a/examples/speculative_decoding/README.md b/examples/speculative_decoding/README.md
index 2a29f644..fe74d1a6 100644
--- a/examples/speculative_decoding/README.md
+++ b/examples/speculative_decoding/README.md
@@ -89,7 +89,9 @@ For large models, you can export intermediate hidden states to disk and train on
 
 ### Dumpping Hidden States to Disk
 
-We support two backends for generating base model hidden states. For better effciency, it is recommended to use TRT-LLM:
+We support two backends for generating base model hidden states:
+
+#### TRT-LLM Backend (Recommended for efficiency)
 
 ```bash
 python collect_hidden_states/compute_hidden_states_trtllm.py \
@@ -100,7 +102,9 @@ python collect_hidden_states/compute_hidden_states_trtllm.py \
 
 **NOTE**: TRT-LLM installation needed for the above command.
 
-Alternatively, you can generate the same hidden states with HF:
+#### HuggingFace Backend (Works with all model families)
+
+Alternatively, you can generate hidden states with HuggingFace, which is compatible with any model in Hugging Face format (including models with custom architectures such as Kimi):
 
 ```bash
 python collect_hidden_states/compute_hidden_states_hf.py \
@@ -109,7 +113,7 @@ python collect_hidden_states/compute_hidden_states_hf.py \
 --output-dir $HIDDEN_STATES_DIR
 ```
 
-**NOTE**: See [`run_hf_compute_hiddens_dp.sh`](./collect_hidden_states/run_hf_compute_hiddens_dp.sh) and [`run_trtllm_compute_hiddens_dp.sh`](./collect_hidden_states/run_trtllm_compute_hiddens_dp.sh) for a simple example using data parallelism (DP) to accelerate hidden state generation.
+For large-scale hidden state generation, see [`run_hf_compute_hiddens_dp.sh`](./collect_hidden_states/run_hf_compute_hiddens_dp.sh) and [`run_trtllm_compute_hiddens_dp.sh`](./collect_hidden_states/run_trtllm_compute_hiddens_dp.sh) for examples using data parallelism (DP) to accelerate the process.
 
 ### Train Draft Model with Dumped Hidden States
 
@@ -124,6 +128,49 @@ Once we finish dumping hidden states, launch offline training with an extra `--o
 --offline-data $HIDDEN_STATES_DIR
 ```
 
+### Offline Training with Custom Models (e.g., Kimi)
+
+For models with custom decoder architectures, such as Kimi, follow this offline training workflow:
+
+1. **Prepare your input conversations** in the standard format (`.jsonl` with conversation IDs and content)
+
+2. **Extract hidden states offline** using either backend:
+
+```bash
+# Using HuggingFace backend (works with any HF-compatible model)
+python collect_hidden_states/compute_hidden_states_hf.py \
+    --model moonshotai/Kimi-K2-Instruct \
+    --input-file input_conversations/train.jsonl \
+    --output-dir hidden_states_dir
+```
+
+3. **Create an `eagle_config.json`** for your model (the decoder type itself is passed via `--eagle_decoder_type` in the next step):
+
+```json
+{
+  "num_hidden_layers": 2,
+  "intermediate_size": 8192
+}
+```
+
+4. **Launch offline training** with the `--eagle_decoder_type` parameter:
+
+```bash
+./launch_train.sh --model moonshotai/Kimi-K2-Instruct \
+    --output_dir output_dir \
+    --data input_conversations/train.jsonl \
+    --num_epochs 1 \
+    --eagle_config eagle_config.json \
+    --offline-data hidden_states_dir \
+    --eagle_decoder_type kimik2
+```
+
+Note: The `--eagle_decoder_type` parameter accepts:
+- `llama` (default, for Llama, Mistral, Qwen, Phi, etc.)
+- `kimik2` (for Kimi models like K2 and K2.5)
+
+This workflow is particularly useful for large models or when training resources are limited, as hidden states can be computed once and reused for multiple training runs.
+
 ## Model Validation
 
 For online training checkpoints, we can run in-framework evaluation on MT-bench:
@@ -328,6 +375,8 @@ trainer.save_model("")
 | Mistral | ✅ | ✅ | ✅ |
 | Phi 3 | ✅ | ✅ | ✅ |
 | QWen 1.5,2,2.5,3 | ✅ | ✅ | ✅ |
+| Kimi K2 | ✅ | ✅ | ✅ |
+| Kimi K2.5 | ✅ | ✅ | ✅ |
 
 ## Speculation Module Checkpoints
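The data-parallel helper scripts referenced in this diff (`run_hf_compute_hiddens_dp.sh`, `run_trtllm_compute_hiddens_dp.sh`) work by splitting the input conversations across workers so each GPU dumps hidden states for its own shard. A minimal sketch of that sharding idea — the round-robin scheme and all names here are illustrative assumptions, not the scripts' actual implementation:

```python
import json

def shard_jsonl(records, num_workers):
    """Round-robin a list of JSONL conversation records across DP workers."""
    shards = [[] for _ in range(num_workers)]
    for i, record in enumerate(records):
        shards[i % num_workers].append(record)
    return shards

# Ten toy conversation records, sharded across 4 hypothetical workers.
records = [json.dumps({"conversation_id": i, "conversations": []}) for i in range(10)]
shards = shard_jsonl(records, 4)
print([len(s) for s in shards])  # → [3, 3, 2, 2]
```

Each shard would then be written to its own `.jsonl` and passed as `--input-file` to one `compute_hidden_states_*.py` process; since each conversation is processed independently, the per-worker outputs can simply be collected into a single `--offline-data` directory afterwards.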