diff --git a/examples/specdec_bench/specdec_bench/utils.py b/examples/specdec_bench/specdec_bench/utils.py index 9a52d0ceac2..73d1e048c80 100644 --- a/examples/specdec_bench/specdec_bench/utils.py +++ b/examples/specdec_bench/specdec_bench/utils.py @@ -196,6 +196,10 @@ def _checkpoint_provenance(model_dir): def _is_sensitive_key(key): + # Engine configs can carry non-string dict keys (e.g. int layer ids in a + # serving_config); those are never sensitive field *names*, so skip them. + if not isinstance(key, str): + return False klow = key.lower() if klow in _SENSITIVE_KEY_ALLOWLIST: return False diff --git a/examples/speculative_decoding/launch_train.sh b/examples/speculative_decoding/launch_train.sh index 41d71d14173..fc623930767 100755 --- a/examples/speculative_decoding/launch_train.sh +++ b/examples/speculative_decoding/launch_train.sh @@ -30,12 +30,14 @@ SCRIPT_DIR="$(dirname "$(readlink -f "$0")")" CONFIG_FILE="" NUM_NODES=1 HEAD_NODE_IP="" +MACHINE_RANK="" EXTRA_ARGS=() while [ $# -gt 0 ]; do case "$1" in --config*) if [[ "$1" != *=* ]]; then shift; fi; CONFIG_FILE="${1#*=}" ;; --num_nodes*) if [[ "$1" != *=* ]]; then shift; fi; NUM_NODES="${1#*=}" ;; --head_node_ip*) if [[ "$1" != *=* ]]; then shift; fi; HEAD_NODE_IP="${1#*=}" ;; + --machine_rank*) if [[ "$1" != *=* ]]; then shift; fi; MACHINE_RANK="${1#*=}" ;; *) EXTRA_ARGS+=("$1") ;; esac shift @@ -59,9 +61,13 @@ fi # Multi-node routing args (accelerate only; training config comes from the YAML) MULTI_NODE_ARGS="" if [[ "$NUM_NODES" != "1" ]]; then + # machine_rank: caller may pass --machine_rank explicitly (needed when the + # SLURM allocation reserves node 0 for something else, e.g. the streaming + # vllm serve, so SLURM_PROCID is offset from accelerate's 0-based rank). + # Default to $SLURM_PROCID for the all-nodes-are-trainers case. 
MULTI_NODE_ARGS="--num_processes $TOTAL_GPU \ --num_machines $NUM_NODES \ - --machine_rank $SLURM_PROCID \ + --machine_rank ${MACHINE_RANK:-$SLURM_PROCID} \ --rdzv_backend c10d \ --main_process_ip $HEAD_NODE_IP \ --main_process_port 29500" diff --git a/modelopt/recipe/config.py b/modelopt/recipe/config.py index 4bf91b52d6f..97d93bbafc6 100644 --- a/modelopt/recipe/config.py +++ b/modelopt/recipe/config.py @@ -178,7 +178,11 @@ class ModelOptDFlashRecipe(ModelOptSpeculativeRecipeBase): @model_validator(mode="after") def _derive_dflash_offline(self) -> ModelOptDFlashRecipe: - self.dflash.dflash_offline = self.data.offline_data_path is not None + # offline (dumped .pt) and streaming (hidden states over HTTP from a vLLM + # serve) both feed pre-computed base hidden states to the DFlash module, so + # both set dflash_offline. Only fully-online training runs the base model. + # Mirrors ModelOptEagleRecipe._derive_eagle_offline. + self.dflash.dflash_offline = self.data.mode != "online" return self diff --git a/modelopt/torch/speculative/config.py b/modelopt/torch/speculative/config.py index 0d9dfc882f5..cfc67f7bb1b 100644 --- a/modelopt/torch/speculative/config.py +++ b/modelopt/torch/speculative/config.py @@ -68,8 +68,10 @@ class DFlashConfig(ModeloptBaseConfig): dflash_offline: bool = ModeloptField( default=False, description=( - "Whether to use detached DFlash (offline training from pre-computed hidden states). " - "Derived by ModelOptDFlashRecipe from data.offline_data_path; not user-configurable." + "Whether the DFlash module consumes pre-computed hidden states (offline from " + "dumped .pt files, or streaming over HTTP from a vLLM serve) instead of running " + "the base model. Derived by ModelOptDFlashRecipe from data.mode (True unless " + "online); not user-configurable." 
), ) diff --git a/tools/launcher/common/eagle3/train_eagle_streaming.sh b/tools/launcher/common/eagle3/train_eagle_streaming.sh index 158bd7a0cf6..4a8dc8bbacf 100755 --- a/tools/launcher/common/eagle3/train_eagle_streaming.sh +++ b/tools/launcher/common/eagle3/train_eagle_streaming.sh @@ -24,12 +24,19 @@ # $SLURM_NODEID: # nodes == 1 -> co-located: vllm serve on $SERVE_GPU, trainer on the rest of # the local GPUs (original single-node behavior). -# nodes >= 2 -> split across nodes: node 0 runs vllm serve on all its GPUs, -# node 1 runs the trainer on all its GPUs. The two roles -# rendezvous through the shared /scratchspace mount (node 0 -# publishes its address; node 1 signals completion). For large -# models whose serve needs a whole node (e.g. Kimi-K2.5 TP=8), -# allocate exactly 2 nodes. +# nodes == 2 -> split: node 0 runs vllm serve on all its GPUs, node 1 runs +# the trainer on all its GPUs. Roles rendezvous through the +# shared /scratchspace mount (node 0 publishes its serve +# address; the trainer signals completion). +# nodes >= 3 -> 1 serve node (node 0) + N trainer nodes (nodes 1..NNODES-1) +# doing multi-node DDP. The head trainer (node 1, accelerate +# machine_rank 0) publishes its IP for accelerate's c10d +# rendezvous; all trainer nodes read both the serve address and +# the head-trainer address from /scratchspace. NOTE: only global +# rank 0 fetches hidden states from the single serve and +# broadcasts to the rest (DataLoaderDispatcher), so the single +# serve is the throughput ceiling — adding trainer nodes scales +# effective batch / compute, not data-production throughput. # # Env vars (required): # HF_MODEL_CKPT Target model path. Used by both vllm serve (as the @@ -56,7 +63,8 @@ # TRAIN_GPUS single-node only: CUDA_VISIBLE_DEVICES for the trainer. # default = all local GPUs except SERVE_GPU. # SERVE_ADVERTISE_IP multi-node only: address node 1 should dial. default is -# node 0's first `hostname -I` IP. 
+# node 0's routable IP (its resolved Slurm node name, else +# its first non-loopback / non-link-local IP). # # All script args are forwarded to launch_train.sh (typically: --config # plus OmegaConf dotlist overrides). @@ -112,7 +120,7 @@ export PATH=$PATH:/workspace/.local/bin ################################################################################################### -trap 'error_handler $0 $LINENO' ERR # ERROR HANDLER +trap 'error_handler $0 $LINENO' ERR if [ -z "$HF_MODEL_CKPT" ]; then echo "ERROR: HF_MODEL_CKPT must be set." >&2; exit 1 @@ -154,11 +162,9 @@ launch_vllm() { # would expose *zero* GPUs (not all), so leave it unset to use the whole node. local -a gpu_env=() [ -n "$cvd" ] && gpu_env=(env "CUDA_VISIBLE_DEVICES=$cvd") - # Optional single-value memory knobs (each a space-free env value, so they - # survive nemo_run's unquoted `export FOO=value`; assembled into --flag value - # pairs here). --cpu-offload-gb spills N GB of weights/GPU to host RAM, the - # key lever for fitting a large model on too-few GPUs (slower, prefill-only - # use tolerates it). --max-model-len / --max-num-seqs trim KV/activation. + # Optional single-value memory knobs (see header), assembled into --flag + # value pairs. Each is a space-free env value so it survives nemo_run's + # unquoted `export FOO=value`. local -a opt_args=() [ -n "${SERVE_CPU_OFFLOAD_GB:-}" ] && opt_args+=(--cpu-offload-gb "$SERVE_CPU_OFFLOAD_GB") [ -n "${SERVE_MAX_MODEL_LEN:-}" ] && opt_args+=(--max-model-len "$SERVE_MAX_MODEL_LEN") @@ -222,28 +228,52 @@ wait_vllm_ready() { # per process; multiple workers would duplicate requests against the server. run_trainer_and_export() { local url="$1" cvd="$2" - echo "Launching trainer (server=${url}, CUDA_VISIBLE_DEVICES=${cvd:-all})..." + # Optional multi-node trainer routing (see dispatch section). Defaults keep + # the original single-trainer-node behavior: no --num_nodes, export on rank 0. 
+ local num_tnodes="${3:-1}" head_ip="${4:-}" mrank="${5:-0}" + echo "Launching trainer (server=${url}, CUDA_VISIBLE_DEVICES=${cvd:-all}, trainer_nodes=${num_tnodes}, machine_rank=${mrank})..." # Empty cvd -> use all GPUs on the node (don't set the var; "" would hide all). local -a gpu_env=() [ -n "$cvd" ] && gpu_env=(env "CUDA_VISIBLE_DEVICES=$cvd") + # Engage accelerate multi-node routing only when >1 trainer node; a single + # trainer node keeps the original invocation (no --num_nodes) verbatim. + local -a mn_args=() + if [ "${num_tnodes}" -gt 1 ]; then + mn_args=(--num_nodes "$num_tnodes" --head_node_ip "$head_ip" --machine_rank "$mrank") + fi "${gpu_env[@]}" bash modules/Model-Optimizer/examples/speculative_decoding/launch_train.sh \ "${SCRIPT_ARGS[@]}" \ + "${mn_args[@]}" \ data.streaming_server_url="$url" \ data.streaming_model_name="$HF_MODEL_CKPT" \ data.streaming_shared_storage_path="$SERVE_SCRATCH" \ training.dataloader_num_workers=0 || { echo "ERROR: trainer failed." >&2; return 1; } + # Export only on the head trainer (machine_rank 0); non-head trainer nodes + # would race writing the same export dir. The export reads the saved + # checkpoint (training.output_dir), not the serve, so it is serve-independent. + if [ "${mrank}" -ne 0 ]; then + echo "machine_rank=${mrank}: training done, skipping export (head trainer handles it)." + return 0 + fi + + # Export the trained draft to HF format. Derive the checkpoint dir from the + # forwarded `training.output_dir=` dotlist (defaulting to the EAGLE + # convention) so EAGLE and DFlash runs each export their own output_dir. + # EXPORT_EXTRA_ARGS lets DFlash on a custom-modeling base (e.g. Kimi) pass + # --trust_remote_code; empty by default so EAGLE behavior is unchanged. 
+ local out_dir + out_dir=$(printf '%s\n' "${SCRIPT_ARGS[@]}" | sed -n 's/^training\.output_dir=//p' | tail -1) + out_dir="${out_dir:-/scratchspace/eagle3}" python3 modules/Model-Optimizer/examples/speculative_decoding/scripts/export_hf_checkpoint.py \ - --model_path /scratchspace/eagle3 \ - --export_path /scratchspace/export + --model_path "$out_dir" \ + --export_path "${EXPORT_PATH:-/scratchspace/export}" \ + ${EXPORT_EXTRA_ARGS:-} } # --------------------------------------------------------------------------- -# Topology dispatch (driven by the Slurm allocation, i.e. the yaml `nodes:`): -# SLURM_NNODES == 1 -> co-located: vllm on $SERVE_GPU, trainer on the rest. -# SLURM_NNODES >= 2 -> split: node 0 serves on all its GPUs, node 1 trains on -# all its GPUs; they rendezvous via /scratchspace. -# nemo_run runs this script once per node, so we branch on $SLURM_NODEID. +# Topology dispatch (see header): nemo_run runs this script once per node, so +# branch on $SLURM_NNODES / $SLURM_NODEID. Per-branch detail in section heads. # --------------------------------------------------------------------------- NNODES="${SLURM_NNODES:-1}" NODEID="${SLURM_NODEID:-0}" @@ -299,27 +329,55 @@ elif [ "$NODEID" -eq 0 ]; then while [ ! -f "$DONE_FILE" ]; do sleep 10; done echo "Training-done sentinel seen; serve node exiting (EXIT trap stops vllm)." -elif [ "$NODEID" -eq 1 ]; then - # ---------------------- multi-node: trainer node ----------------------- - # Release the serve node on any exit (success or failure) so it doesn't hang. - trap 'touch "$DONE_FILE" 2>/dev/null || true' EXIT +elif [ "$NODEID" -ge 1 ]; then + # -------------------- multi-node: trainer node(s) ---------------------- + # Node 0 is the vllm serve; trainer nodes are SLURM nodes 1..NNODES-1, which + # map to 0-based accelerate machine ranks (head trainer = SLURM node 1). 
+ NUM_TRAINER_NODES=$(( NNODES - 1 )) + TRAINER_RANK=$(( NODEID - 1 )) + TRAINER_ADDR_FILE="/scratchspace/.trainer_addr" + + # Only the head trainer (rank 0) signals the serve node to release on exit; + # a non-head node exiting first must NOT tear the serve down early. + if [ "$TRAINER_RANK" -eq 0 ]; then + trap 'touch "$DONE_FILE" 2>/dev/null || true' EXIT + rm -f "$TRAINER_ADDR_FILE" # clear stale rendezvous state + fi - echo "Trainer node waiting (up to ${SERVE_READY_TIMEOUT}s) for the serve address..." + echo "Trainer node (rank ${TRAINER_RANK}/${NUM_TRAINER_NODES}) waiting for the serve address..." for ((i = 0; i < SERVE_READY_TIMEOUT; i++)); do [ -f "$SERVE_ADDR_FILE" ] && break sleep 1 done [ -f "$SERVE_ADDR_FILE" ] || { echo "ERROR: serve node never published its address." >&2; exit 1; } URL="http://$(cat "$SERVE_ADDR_FILE"):${SERVE_PORT}" - wait_vllm_ready "$URL" || exit 1 - run_trainer_and_export "$URL" "" || exit 1 -else - # ------------- multi-node: extra nodes (unused by default) ------------- - echo "Node rank ${NODEID} idle: the default split uses node 0 = vllm serve, node 1 = trainer." - echo "Multi-node *training* (>1 trainer node) is not wired up yet; allocate exactly 2 nodes." - while [ ! -f "$DONE_FILE" ]; do sleep 10; done + if [ "$NUM_TRAINER_NODES" -le 1 ]; then + # Original 1-serve + 1-trainer topology: single-node DDP, unchanged. + run_trainer_and_export "$URL" "" || exit 1 + else + # >1 trainer node: head (rank 0) publishes its routable IP for accelerate's + # c10d rendezvous (port 29500); all trainer nodes read it and join. Reuse + # the serve node's IP-resolution logic (avoid link-local / loopback). 
+ if [ "$TRAINER_RANK" -eq 0 ]; then + head_addr="${TRAINER_ADVERTISE_IP:-}" + [ -z "$head_addr" ] && head_addr=$(getent hosts "${SLURMD_NODENAME:-$(hostname)}" 2>/dev/null | awk '{print $1}' | head -1) + [ -z "$head_addr" ] && head_addr=$(hostname -I | tr ' ' '\n' | grep -vE '^(127\.|169\.254\.|fe80:|::1)' | head -1) + [ -z "$head_addr" ] && head_addr=$(hostname -I | awk '{print $1}') + echo "$head_addr" > "$TRAINER_ADDR_FILE" + echo "Head trainer (rank 0) published ${head_addr} for c10d rendezvous." + else + echo "Trainer rank ${TRAINER_RANK} waiting for head-trainer address..." + for ((i = 0; i < SERVE_READY_TIMEOUT; i++)); do + [ -f "$TRAINER_ADDR_FILE" ] && break + sleep 1 + done + [ -f "$TRAINER_ADDR_FILE" ] || { echo "ERROR: head trainer never published its address." >&2; exit 1; } + fi + HEAD_IP=$(cat "$TRAINER_ADDR_FILE") + run_trainer_and_export "$URL" "" "$NUM_TRAINER_NODES" "$HEAD_IP" "$TRAINER_RANK" || exit 1 + fi fi ################################################################################################### diff --git a/tools/launcher/core.py b/tools/launcher/core.py index aa60bbad9e9..f6ae6493af3 100644 --- a/tools/launcher/core.py +++ b/tools/launcher/core.py @@ -286,6 +286,9 @@ def build_slurm_executor( retries=0, packager=packager, srun_args=slurm_config.srun_args, + # --segment=: pin all nodes into one topology block (one NVL72 / NVLink + # domain). None -> omitted, scheduler places freely (default behavior). + segment=slurm_config.segment, ) return executor diff --git a/tools/launcher/examples/moonshotai/Kimi-K2.5/hf_dflash_dryrun.yaml b/tools/launcher/examples/moonshotai/Kimi-K2.5/hf_dflash_dryrun.yaml new file mode 100644 index 00000000000..b12c3b0f538 --- /dev/null +++ b/tools/launcher/examples/moonshotai/Kimi-K2.5/hf_dflash_dryrun.yaml @@ -0,0 +1,64 @@ +# DFlash dry-run smoke test for Kimi-K2.5 (NVFP4). +# +# Single-task pipeline that exercises the full convert→save→export path WITHOUT +# actually training. 
Uses the same `common/specdec/dflash_online_training.sh` +# entrypoint as a real DFlash run; all dry-run behaviour is expressed as dotlist +# overrides on `main.py` (shared with EAGLE3 — `--dry_run` is mode-agnostic): +# +# --dry_run → main.py skips trainer.train(), saves +# the (untrained) ModelOpt checkpoint +# to training.output_dir right after +# mtsp.convert(model, [("dflash", ...)]) +# data.offline_data_path= → DataArguments derives data.mode from +# the data-source fields, so setting an +# offline path makes mode='offline' → +# use_offline_training=True. Combined +# with use_fake_base_for_offline=true +# this loads a FakeBaseModel (only +# embed_tokens + lm_head), so the ~1T +# MoE base fits on a single GPU. The +# file is never read in --dry_run mode. +# model.trust_remote_code=true → Kimi-K2.5 (deepseek_v3 arch) ships a +# custom modeling file +# dflash.dflash_mask_token_id=163838 → Kimi-K2.5 has no dedicated mask token +# ([EOS]=163585, [PAD]=163839); 163838 is +# a reserved slot used as the DFlash mask +# (matches the real Kimi-K2.5 DFlash run) +# +# The dflash_online_training.sh export block then writes an HF-format DFlash draft +# to /scratchspace/dflash/exported-checkpoint-final with the correct architecture +# (5-layer draft block, block_size=8) but untrained weights — acceptance ~0%, by +# design. Useful for smoke-testing the launcher / convert / export plumbing and +# validating downstream loaders without paying for a real training run. +# +# Usage: +# uv run launch.py --yaml examples/moonshotai/Kimi-K2.5/hf_dflash_dryrun.yaml --yes + +job_name: Kimi-K2.5_DFlash_dryrun +pipeline: + allow_to_fail: false + skip: false + note: + + global_vars: + hf_model: /hf-local/nvidia/Kimi-K2.5-NVFP4/ + + # Convert → save → export (no training). 
+ task_0: + script: common/specdec/dflash_online_training.sh + args: + - --dry_run + - --config modules/Model-Optimizer/modelopt_recipes/general/speculative_decoding/dflash.yaml + - model.model_name_or_path=<> + - model.use_fake_base_for_offline=true + - model.trust_remote_code=true + - data.offline_data_path=/tmp/dryrun-placeholder + - training.output_dir=/scratchspace/dflash + - training.disable_tqdm=true + - dflash.dflash_mask_token_id=163838 + slurm_config: + _factory_: "slurm_factory" + nodes: 1 + ntasks_per_node: 1 + gpus_per_node: 1 + container: nvcr.io/nvidia/tensorrt-llm/release:1.3.0rc10 diff --git a/tools/launcher/examples/moonshotai/Kimi-K2.5/hf_streaming_dflash.yaml b/tools/launcher/examples/moonshotai/Kimi-K2.5/hf_streaming_dflash.yaml new file mode 100644 index 00000000000..ff99ae62c7f --- /dev/null +++ b/tools/launcher/examples/moonshotai/Kimi-K2.5/hf_streaming_dflash.yaml @@ -0,0 +1,131 @@ +# DFlash streaming speculative-decoding training for Kimi-K2.5-NVFP4 on +# GB200/Blackwell (HSG). Sibling of hf_streaming_eagle3.yaml — same vLLM-serve + +# trainer split, same hardware reasoning — but trains a DFlash drafter instead of +# EAGLE3 by pointing the (shared, algorithm-agnostic) streaming script at the +# dflash recipe. +# +# Why GB200: nodes have only 4 GPUs each (vs CW's 8), but 192 GB/GPU and native +# NVFP4. Kimi-K2.5-NVFP4 (~551 GB) fits at TP=4 on ONE node (4 x 192 = 768 GB, +# ~138 GB/GPU of weights) with NO cpu-offload. So: node 0 = vllm serve (TP=4, +# whole node), node 1 = DFlash trainer (fake base), 4 GPUs each, 2 nodes. +# +# How streaming feeds DFlash (vs EAGLE3): the trainer's streaming path was wired +# up for DFlash by deriving dflash_offline from data.mode (modelopt/recipe/config.py +# ModelOptDFlashRecipe._derive_dflash_offline), so data.mode=streaming sets +# dflash_offline=True and the DFlash module consumes the streamed hidden states +# (base_model_outputs) instead of running the fake base. 
The vLLM connector, +# streaming dataset, and offline collator are all algorithm-agnostic: vLLM dumps +# captured layers as [seq, n_captured, hidden]; the dataset splits the LAST +# captured layer into base_model_hidden_states (used for DFlash self-logit +# distillation) and the REST into aux_hidden_states (DFlash's concatenated +# target-layer features). So n_captured must be (num DFlash target layers + 1). +# +# Capture ids (kimi_k25 / deepseek_v3 arch, 61 layers, capture id space 0..60; +# the true final layer is NOT capturable so we use 60 as the base, same as EAGLE3): +# DFlash target layers come from build_target_layer_ids(num_orig=61, num_draft=5) +# = [1,15,30,44,58] (0-based) -> vLLM ids (+1 for the embedding layer) = +# [2,16,31,45,59]. Append base 60. captured = [2,16,31,45,59,60] = 6, so the +# dataset yields 5 aux layers, matching the 5-layer DFlash draft block. +# +# answer_only_loss: forced false here. DFlash's recipe default is true, which +# requires the tokenizer chat template to carry {% generation %} tags so the +# streaming dataset can derive an assistant-token mask; Kimi's template does not, +# and the streaming path (unlike online) does not inject data.chat_template. To +# train assistant-only later, supply a generation-tagged template and flip this on. +# +# Run ON the HSG login node (paramiko can't reach HSG through the sss proxy): +# export SLURM_HOST=localhost SLURM_ACCOUNT=coreai_dlalgo_modelopt \ +# SLURM_PARTITION=batch \ +# SLURM_HF_LOCAL=/lustre/fs1/portfolios/coreai/projects/coreai_dlalgo_modelopt/hf-local \ +# SLURM_JOB_DIR=/lustre/fs1/portfolios/coreai/projects/coreai_dlalgo_modelopt/users/haoguo/experiments \ +# NEMORUN_HOME=$PWD +# uv run launch.py --yaml examples/moonshotai/Kimi-K2.5/hf_streaming_dflash.yaml \ +# identity=$HOME/.ssh/id_ecdsa detach=True --yes +# +# The export lands in /scratchspace/export. To benchmark it, point +# specdec_bench.yaml's --draft_model_dir there (or copy it under /hf-local). 
+ +job_name: Kimi-K2.5-NVFP4_DFlash_streaming +pipeline: + allow_to_fail: false + skip: false + note: + + global_vars: + hf_model: /hf-local/nvidia/Kimi-K2.5-NVFP4 + + # Step 1: Build input conversations (model-agnostic) + task_0: + script: common/eagle3/make_dataset.sh + args: + - -f modules/Model-Optimizer/examples/dataset/example_data_config.yaml + - --full-conversations + slurm_config: + _factory_: "slurm_factory" + nodes: 1 + ntasks_per_node: 1 + # HSG QOS (QOSMinGRES) requires whole-node GPU allocation (4 on GB200), + # so request 4 even though make_dataset is CPU-only. + gpus_per_node: 4 + container: nvcr.io/nvidia/tensorrt-llm/release:1.3.0rc10 + + # Step 2: Streaming DFlash training (node0 serve TP=4 / node1 train), 4 GPU/node. + # Reuses the shared streaming orchestrator (common/eagle3/train_eagle_streaming.sh): + # only the --config recipe (dflash vs eagle3) and EAGLE_CAPTURE_IDS differ. + task_1: + script: common/eagle3/train_eagle_streaming.sh + args: + - --config modules/Model-Optimizer/modelopt_recipes/general/speculative_decoding/dflash.yaml + - model.model_name_or_path=<> + - model.use_fake_base_for_offline=true + - model.trust_remote_code=true + - data.mode=streaming + - data.data_path=/scratchspace/data/train.jsonl + # Keep concurrent in-flight requests low: a 64-wide flood made cold NVFP4 + # MoE kernels/flashinfer autotune stall a worker past vLLM's engine<->worker + # timeout, killing EngineCore (TimeoutError) mid-serve -> 500s -> trainer abort. + - data.streaming_prefetch=8 + - training.output_dir=/scratchspace/dflash + # Must be divisible by dflash_block_size (8). DFlash trains on fixed-length blocks. + - training.training_seq_len=4096 + - training.disable_tqdm=true + - training.num_train_epochs=1 + - training.max_steps=3000 + # See header: Kimi's template lacks {% generation %} tags; train on all tokens. 
+ - training.answer_only_loss=false + # dflash.yaml sets report_to=tensorboard, but the vLLM container has no + # tensorboard -> TensorBoardCallback RuntimeError at trainer init. Disable + # reporting (loss still prints to stdout via logging_steps). + - training.report_to=none + # Kimi-K2.5 has no dedicated mask token ([EOS]=163585, [PAD]=163839); 163838 + # is a reserved slot used as the DFlash mask (matches the real Kimi DFlash run). + - dflash.dflash_mask_token_id=163838 + environment: + - HF_MODEL_CKPT: <> + # No spaces in values: nemo_run emits `export FOO=value` unquoted. + # DFlash target layers (vLLM-indexed) + base 60; see header for derivation. + - EAGLE_CAPTURE_IDS: "[2,16,31,45,59,60]" + - SERVE_TP: "4" + # DFlash on a custom-modeling base (Kimi) needs --trust_remote_code at export. + - EXPORT_EXTRA_ARGS: "--trust_remote_code" + # Kimi-K2.5-NVFP4 ~138 GB weights/GPU at TP=4; GB200 has 184 GB. The model's + # native max_seq_len is 262144, whose KV cache OOMs. Cap context to the + # training seq len and leave headroom for activation spikes. + - SERVE_MAX_MODEL_LEN: "4096" + # Small batches: smaller per-step MoE compute stays under the engine timeout. + - SERVE_MAX_NUM_SEQS: "4" + - SERVE_GPU_MEM_UTIL: "0.8" + - SERVE_READY_TIMEOUT: "2400" + - SERVE_EXTRA_ARGS: "--trust-remote-code" + # The killer was "RPC call to sample_tokens timed out" — a worker stalls on + # the first real serving step (cold NVFP4 MoE kernels) past vLLM's default + # execute-model timeout, so EngineCore dies. Extend the timeouts that govern + # that path (seconds). VLLM_RPC_TIMEOUT (ms) is a different RPC and didn't help. 
+ - VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS: "1200" + - VLLM_ENGINE_ITERATION_TIMEOUT_S: "1200" + slurm_config: + _factory_: "slurm_factory" + nodes: 2 + ntasks_per_node: 1 + gpus_per_node: 4 + container: vllm/vllm-openai:latest diff --git a/tools/launcher/examples/moonshotai/Kimi-K2.5/hf_streaming_dflash_multi_node.yaml b/tools/launcher/examples/moonshotai/Kimi-K2.5/hf_streaming_dflash_multi_node.yaml new file mode 100644 index 00000000000..fb92ba11234 --- /dev/null +++ b/tools/launcher/examples/moonshotai/Kimi-K2.5/hf_streaming_dflash_multi_node.yaml @@ -0,0 +1,133 @@ +# DFlash streaming speculative-decoding training for Kimi-K2.5-NVFP4 on +# GB200/Blackwell (HSG). Sibling of hf_streaming_eagle3.yaml — same vLLM-serve + +# trainer split, same hardware reasoning — but trains a DFlash drafter instead of +# EAGLE3 by pointing the (shared, algorithm-agnostic) streaming script at the +# dflash recipe. +# +# Why GB200: nodes have only 4 GPUs each (vs CW's 8), but 192 GB/GPU and native +# NVFP4. Kimi-K2.5-NVFP4 (~551 GB) fits at TP=4 on ONE node (4 x 192 = 768 GB, +# ~138 GB/GPU of weights) with NO cpu-offload. So: node 0 = vllm serve (TP=4, +# whole node), nodes 1-2 = DFlash trainers (fake base, multi-node DDP), 4 GPUs each, 3 nodes. +# +# How streaming feeds DFlash (vs EAGLE3): the trainer's streaming path was wired +# up for DFlash by deriving dflash_offline from data.mode (modelopt/recipe/config.py +# ModelOptDFlashRecipe._derive_dflash_offline), so data.mode=streaming sets +# dflash_offline=True and the DFlash module consumes the streamed hidden states +# (base_model_outputs) instead of running the fake base. The vLLM connector, +# streaming dataset, and offline collator are all algorithm-agnostic: vLLM dumps +# captured layers as [seq, n_captured, hidden]; the dataset splits the LAST +# captured layer into base_model_hidden_states (used for DFlash self-logit +# distillation) and the REST into aux_hidden_states (DFlash's concatenated +# target-layer features).
So n_captured must be (num DFlash target layers + 1). +# +# Capture ids (kimi_k25 / deepseek_v3 arch, 61 layers, capture id space 0..60; +# the true final layer is NOT capturable so we use 60 as the base, same as EAGLE3): +# DFlash target layers come from build_target_layer_ids(num_orig=61, num_draft=5) +# = [1,15,30,44,58] (0-based) -> vLLM ids (+1 for the embedding layer) = +# [2,16,31,45,59]. Append base 60. captured = [2,16,31,45,59,60] = 6, so the +# dataset yields 5 aux layers, matching the 5-layer DFlash draft block. +# +# answer_only_loss: forced false here. DFlash's recipe default is true, which +# requires the tokenizer chat template to carry {% generation %} tags so the +# streaming dataset can derive an assistant-token mask; Kimi's template does not, +# and the streaming path (unlike online) does not inject data.chat_template. To +# train assistant-only later, supply a generation-tagged template and flip this on. +# +# Run ON the HSG login node (paramiko can't reach HSG through the sss proxy): +# export SLURM_HOST=localhost SLURM_ACCOUNT=coreai_dlalgo_modelopt \ +# SLURM_PARTITION=batch \ +# SLURM_HF_LOCAL=/lustre/fs1/portfolios/coreai/projects/coreai_dlalgo_modelopt/hf-local \ +# SLURM_JOB_DIR=/lustre/fs1/portfolios/coreai/projects/coreai_dlalgo_modelopt/users/haoguo/experiments \ +# NEMORUN_HOME=$PWD +# uv run launch.py --yaml examples/moonshotai/Kimi-K2.5/hf_streaming_dflash_multi_node.yaml \ +# identity=$HOME/.ssh/id_ecdsa detach=True --yes +# +# The export lands in /scratchspace/export. To benchmark it, point +# specdec_bench.yaml's --draft_model_dir there (or copy it under /hf-local).
+ +job_name: Kimi-K2.5-NVFP4_DFlash_streaming_multi_node +pipeline: + allow_to_fail: false + skip: false + note: + + global_vars: + hf_model: /hf-local/nvidia/Kimi-K2.5-NVFP4 + + # Step 1: Build input conversations (model-agnostic) + task_0: + script: common/eagle3/make_dataset.sh + args: + - -f modules/Model-Optimizer/examples/dataset/example_data_config.yaml + - --full-conversations + slurm_config: + _factory_: "slurm_factory" + nodes: 1 + ntasks_per_node: 1 + # HSG QOS (QOSMinGRES) requires whole-node GPU allocation (4 on GB200), + # so request 4 even though make_dataset is CPU-only. + gpus_per_node: 4 + container: nvcr.io/nvidia/tensorrt-llm/release:1.3.0rc10 + + # Step 2: Streaming DFlash training (node0 serve TP=4 / nodes 1-2 train, multi-node DDP), 4 GPU/node. + # Reuses the shared streaming orchestrator (common/eagle3/train_eagle_streaming.sh): + # only the --config recipe (dflash vs eagle3) and EAGLE_CAPTURE_IDS differ. + task_1: + script: common/eagle3/train_eagle_streaming.sh + args: + - --config modules/Model-Optimizer/modelopt_recipes/general/speculative_decoding/dflash.yaml + - model.model_name_or_path=<> + - model.use_fake_base_for_offline=true + - model.trust_remote_code=true + - data.mode=streaming + - data.data_path=/scratchspace/data/train.jsonl + # Keep concurrent in-flight requests low: a 64-wide flood made cold NVFP4 + # MoE kernels/flashinfer autotune stall a worker past vLLM's engine<->worker + # timeout, killing EngineCore (TimeoutError) mid-serve -> 500s -> trainer abort. + - data.streaming_prefetch=8 + - training.output_dir=/scratchspace/dflash + # Must be divisible by dflash_block_size (8). DFlash trains on fixed-length blocks. + - training.training_seq_len=4096 + - training.disable_tqdm=true + - training.num_train_epochs=1 + - training.ar_validate_steps=500000 + - training.max_steps=500 + # See header: Kimi's template lacks {% generation %} tags; train on all tokens.
+ - training.answer_only_loss=false + # dflash.yaml sets report_to=tensorboard, but the vLLM container has no + # tensorboard -> TensorBoardCallback RuntimeError at trainer init. Disable + # reporting (loss still prints to stdout via logging_steps). + - training.report_to=none + # Kimi-K2.5 has no dedicated mask token ([EOS]=163585, [PAD]=163839); 163838 + # is a reserved slot used as the DFlash mask (matches the real Kimi DFlash run). + - dflash.dflash_mask_token_id=163838 + environment: + - HF_MODEL_CKPT: <> + # No spaces in values: nemo_run emits `export FOO=value` unquoted. + # DFlash target layers (vLLM-indexed) + base 60; see header for derivation. + - EAGLE_CAPTURE_IDS: "[2,16,31,45,59,60]" + - SERVE_TP: "4" + # DFlash on a custom-modeling base (Kimi) needs --trust_remote_code at export. + - EXPORT_EXTRA_ARGS: "--trust_remote_code" + # Kimi-K2.5-NVFP4 ~138 GB weights/GPU at TP=4; GB200 has 184 GB. The model's + # native max_seq_len is 262144, whose KV cache OOMs. Cap context to the + # training seq len and leave headroom for activation spikes. + - SERVE_MAX_MODEL_LEN: "4096" + # Small batches: smaller per-step MoE compute stays under the engine timeout. + - SERVE_MAX_NUM_SEQS: "4" + - SERVE_GPU_MEM_UTIL: "0.8" + - SERVE_READY_TIMEOUT: "2400" + - SERVE_EXTRA_ARGS: "--trust-remote-code" + # The killer was "RPC call to sample_tokens timed out" — a worker stalls on + # the first real serving step (cold NVFP4 MoE kernels) past vLLM's default + # execute-model timeout, so EngineCore dies. Extend the timeouts that govern + # that path (seconds). VLLM_RPC_TIMEOUT (ms) is a different RPC and didn't help. 
+ - VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS: "1200" + - VLLM_ENGINE_ITERATION_TIMEOUT_S: "1200" + slurm_config: + _factory_: "slurm_factory" + nodes: 3 + segment: 3 + ntasks_per_node: 1 + gpus_per_node: 4 + container: vllm/vllm-openai:latest diff --git a/tools/launcher/examples/moonshotai/Kimi-K2.5/specdec_bench.yaml b/tools/launcher/examples/moonshotai/Kimi-K2.5/specdec_bench.yaml new file mode 100644 index 00000000000..a943f39c27e --- /dev/null +++ b/tools/launcher/examples/moonshotai/Kimi-K2.5/specdec_bench.yaml @@ -0,0 +1,81 @@ +# DFLASH speculative-decoding benchmark for Kimi-K2.5-NVFP4 via vLLM. +# +# Serves Kimi-K2.5-NVFP4 in-process (no HTTP server — specdec_bench drives an +# AsyncLLM) at TP=4 with expert parallelism, attaches a trained/exported DFLASH +# draft, and benchmarks speculative decoding on MT-Bench. Writes timing.json + +# aa_timing.json + acceptance_rate.json + mtbench.json + specbench_responses.jsonl +# to /scratchspace/specdec_bench/. +# +# Hardware = GB200/Blackwell (HSG), same reasoning as hf_streaming_eagle3.yaml: +# Kimi-K2.5-NVFP4 (~551 GB) needs native NVFP4 + the 192 GB/GPU of GB200; it fits +# at TP=4 on ONE 4-GPU node with no cpu-offload. On CW H100 it has no native FP4 +# and falls back to offload, so the working path is GB200. +# +# DFLASH specifics: +# - draft tokens default to 8 in specdec_bench (matches DFlash block_size=8); +# --draft_length does NOT apply to DFLASH. To override sampling / engine args +# (e.g. speculative_num_draft_tokens, temperature), write a runtime-params +# yaml and add `- --runtime_params ` below — see +# examples/specdec_bench/README.md (runtime_args_long_context.yaml pattern). +# - --draft_model_dir must point at a trained+exported HF-format DFLASH draft +# (e.g. produced by hf_offline_dflash.yaml / a real DFlash run). Edit the path +# below, or override on the CLI: pipeline.task_0.args[0]="--draft_model_dir /hf-local/" +# - Kimi needs --trust_remote_code for both tokenizer and model. 
+# +# NOTE on dataset: uses MT-Bench (the question.jsonl staged under /hf-local), so +# it runs without any data-prep step. To benchmark on SPEED-Bench instead, first +# generate + stage a split: +# python3 examples/specdec_bench/prepare_data.py --dataset speed --config all +# (splits: qualitative, throughput_1k, throughput_16k, ...) then swap the +# `--mtbench` arg for: +# - --dataset speed +# - --dataset_path modules/Model-Optimizer/examples/specdec_bench/data/speed/throughput_16k +# +# NOTE on container: vllm/vllm-openai:latest is x86 and may lack DFLASH support; +# on GB200/aarch64 use an aarch64 vLLM image new enough for DFLASH (validated on +# a 0511 nightly). Override with: pipeline.task_0.slurm_config.container=YOUR_AARCH64_VLLM_IMAGE +# +# Run ON the HSG login node (paramiko can't reach HSG through the sss proxy): +# export SLURM_HOST=localhost SLURM_ACCOUNT=coreai_dlalgo_modelopt \ +# SLURM_PARTITION=batch \ +# SLURM_HF_LOCAL=/lustre/fs1/portfolios/coreai/projects/coreai_dlalgo_modelopt/hf-local \ +# SLURM_JOB_DIR=/lustre/fs1/portfolios/coreai/projects/coreai_dlalgo_modelopt/users/haoguo/experiments \ +# NEMORUN_HOME=$PWD +# uv run launch.py --yaml examples/moonshotai/Kimi-K2.5/specdec_bench.yaml \ +# identity=$HOME/.ssh/id_ecdsa detach=True --yes + +job_name: Kimi-K2.5-NVFP4_DFLASH_specdec_bench + +pipeline: + allow_to_fail: false + skip: false + note: + + global_vars: + hf_model: /hf-local/nvidia/Kimi-K2.5-NVFP4 + + task_0: + script: common/specdec_bench/run.sh + args: + # TODO: point at your trained + exported HF-format DFLASH draft checkpoint.
+ - --draft_model_dir /hf-local/nvidia/Kimi-K2.5-DFlash + - --speculative_algorithm DFLASH + - --engine VLLM + - --mtbench /hf-local/HuggingFaceH4/mt_bench_prompts/raw/question.jsonl + - --tp_size 4 + - --ep_size 4 + - --concurrency 32 + - --output_length 1024 + - --trust_remote_code + - --aa_timing + - --show_progress + - --save_dir /scratchspace/specdec_bench + environment: + - HF_MODEL_CKPT: <> + - HF_LOCAL: /hf-local + slurm_config: + _factory_: "slurm_factory" + nodes: 1 + ntasks_per_node: 1 + gpus_per_node: 4 + container: vllm/vllm-openai:latest diff --git a/tools/launcher/slurm_config.py b/tools/launcher/slurm_config.py index 8ecd51f6f86..0bcfff14ad9 100644 --- a/tools/launcher/slurm_config.py +++ b/tools/launcher/slurm_config.py @@ -48,6 +48,11 @@ class SlurmConfig: gpus_per_node: int = 1 time: str = "04:00:00" local: bool = False + # Slurm --segment=: force the job's nodes into a single topology block. + # On a topology/block cluster (e.g. GB200 NVL72, where one block = one NVLink + # domain) set this to the node count to keep all nodes in one NVL72 so + # inter-node traffic rides NVLink. None = let the scheduler place freely. + segment: Optional[int] = None @run.cli.factory @@ -68,6 +73,7 @@ def slurm_factory( srun_args: list[str] = ["--no-container-mount-home"], array: Optional[str] = None, time: str = "04:00:00", + segment: Optional[int] = None, ) -> SlurmConfig: """Generic Slurm factory — configure via environment variables or CLI overrides.""" return SlurmConfig( @@ -84,4 +90,5 @@ def slurm_factory( srun_args=srun_args, array=array, time=time, + segment=segment, )