Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions examples/speculative_decoding/launch_train.sh
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,10 @@ while [ $# -gt 0 ]; do
if [[ "$1" != *=* ]]; then shift; fi
MIX_HIDDEN_STATES="${1#*=}"
;;
--disable_torch_compile*)
if [[ "$1" != *=* ]]; then shift; fi
DISABLE_TORCH_COMPILE="${1#*=}"
;;
*)
>&2 printf "Error: Invalid argument ${1#*=}\n"
exit 1
Expand Down Expand Up @@ -158,6 +162,7 @@ DP_SHARD_SIZE=${DP_SHARD_SIZE:-$((TOTAL_GPU/CP_SIZE))}
LOG_STEPS=${LOG_STEPS:-100}
DRAFT_VOCAB_CACHE=${DRAFT_VOCAB_CACHE:-""}
MIX_HIDDEN_STATES=${MIX_HIDDEN_STATES:-"False"}
DISABLE_TORCH_COMPILE=${DISABLE_TORCH_COMPILE:-"False"}
NUM_TTT_STEPS=${NUM_TTT_STEPS:-3}


Expand Down Expand Up @@ -245,6 +250,7 @@ CMD="accelerate launch $MULTI_NODE_ARGS --mixed_precision bf16 ${SCRIPT_DIR}/mai
--estimate_ar $ESTIMATE_AR \
--ar_validate_steps $AR_VALIDATE_STEPS \
--mix_hidden_states $MIX_HIDDEN_STATES \
--disable_torch_compile $DISABLE_TORCH_COMPILE \
$DRAFT_VOCAB_CACHE_ARGS \
$VLM_ARGS \
$OFFLINE_TRAINING_ARGS \
Expand Down
12 changes: 9 additions & 3 deletions examples/speculative_decoding/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,10 @@ class EagleArguments:
default=False,
metadata={"help": "Whether to mix hidden states from previous TTT step."},
)
disable_torch_compile: bool = field(
default=False,
metadata={"help": "Disable torch.compile on eagle forward/loss methods."},
)
num_ttt_steps: int = field(
default=3,
metadata={"help": "Number of train-time-test steps to use during training."},
Expand All @@ -149,9 +153,10 @@ def train():
model_args, data_args, training_args, medusa_args, eagle_args = (
parser.parse_args_into_dataclasses()
)
training_args.parallelism_config = ParallelismConfig(
cp_size=training_args.cp_size, dp_shard_size=training_args.dp_shard_size
)
if training_args.cp_size > 1 or training_args.dp_shard_size > 1:
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note, this is an unrelated bugfix related to #1045 (does not fully solve the issue, just a single-gpu workaround)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As discussed in Slack, this issue is due to a transformers version mismatch. Should be fixed after updating transformers.

training_args.parallelism_config = ParallelismConfig(
cp_size=training_args.cp_size, dp_shard_size=training_args.dp_shard_size
)
if training_args.cp_size > 1:
patch_ring_attention_for_ttt()
# Specific patch to accelerate 1.12.0. Removable after move to 1.13.0
Expand Down Expand Up @@ -212,6 +217,7 @@ def train():
"eagle_decoder_type": eagle_args.eagle_decoder_type,
"eagle_offline": use_offline_training,
"eagle_mix_hidden_states": eagle_args.mix_hidden_states,
"eagle_use_torch_compile": not eagle_args.disable_torch_compile,
"eagle_ttt_steps": eagle_args.num_ttt_steps,
"eagle_architecture_config": custom_config,
}
Expand Down
10 changes: 10 additions & 0 deletions modelopt/torch/speculative/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,3 +110,13 @@ class EagleConfig(ModeloptBaseConfig):
"Whether to mix hidden states of multiple TTT steps. It is a technique to reduce training cost."
),
)

eagle_use_torch_compile: bool = ModeloptField(
default=True,
description="Whether to use torch.compile on eagle forward/loss methods for faster training.",
)

eagle_enable_nvtx: bool = ModeloptField(
default=False,
description="Whether to enable NVTX ranges for profiling eagle forward/loss methods.",
)
2 changes: 2 additions & 0 deletions modelopt/torch/speculative/eagle/eagle_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,5 @@ def modify(
self.eagle_decoder_type = config.eagle_decoder_type
self.eagle_ttt_steps = config.eagle_ttt_steps
self.eagle_mix_hidden_states = config.eagle_mix_hidden_states
self.eagle_use_torch_compile = config.eagle_use_torch_compile
self.eagle_enable_nvtx = config.eagle_enable_nvtx
171 changes: 114 additions & 57 deletions modelopt/torch/speculative/plugins/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,10 @@ def __init__(self, config, decoder_layer_cls, bias=False):
num_layers=self.config.parallel_draft_heads_num_layers,
)

def _maybe_init_rope(self):
    """Lazily create the rotary embedding for llama-type eagle decoders.

    Initialization is deferred until first use so the module is not
    materialized at checkpoint save/load time (avoids meta-tensor errors
    per the original in-line note). No-op when the decoder type is not
    "llama" or when ``rotary_emb`` already exists.
    """
    if self.config.eagle_decoder_type == "llama" and not hasattr(self, "rotary_emb"):
        self.rotary_emb = LlamaRotaryEmbedding(config=self.config)

def _expand_first_attn_in_dim(self, first_layer_attn):
"""Modify qkv projection in first layer to accept 2h hidden size."""
# Find Linear modules to expand
Expand Down Expand Up @@ -372,11 +376,6 @@ def forward(
self._input_embeds = self.layers[0].input_layernorm(inputs_embeds)

if self.config.eagle_decoder_type == "llama":
# Lazy init rope to avoid save/load meta tensor error
Comment thread
benchislett marked this conversation as resolved.
if not hasattr(self, "rotary_emb"):
self.rotary_emb = LlamaRotaryEmbedding(
config=self.config, device=hidden_states.device
)
position_embeddings = self.rotary_emb(hidden_states, position_ids)
else:
position_embeddings = None
Expand Down Expand Up @@ -454,6 +453,23 @@ def _draft_model_config(self):
"""Return the llm config for the draft model."""
return self.eagle_config

def _enable_cp_ttt(self):
if self.training and not self.eagle_mix_hidden_states:
return enable_cp_ttt_patch()
return contextlib.nullcontext()

def _nvtx_range(self, name):
"""Optionally create an NVTX range for the given name when config.eagle_enable_nvtx is set."""
if not self.eagle_enable_nvtx:
return contextlib.nullcontext()
try:
import torch.cuda.nvtx as nvtx

return nvtx.range(name)
except Exception as e:
print(f"Failed to create NVTX range {name}: {e}")
return contextlib.nullcontext()

def get_exporter(self) -> SpeculativeDecodingExporter:
"""Get the exporter for the draft model."""
exporter_cls = (
Expand Down Expand Up @@ -618,8 +634,27 @@ def modify(
# https://github.com/huggingface/transformers/blob/v4.56-release/src/transformers/trainer.py#L566
self.is_quantized = False

if self.eagle_use_torch_compile:
self._activate_torch_compile()

self._cached_attn_blk_masks = {}

def _activate_torch_compile(self):
import torch._dynamo

torch._dynamo.config.suppress_errors = True # Allow fallback to eager mode

compile_targets = [
("_prepare_eagle_inputs", {}),
("_eagle_forward", {"mode": "max-autotune"}),
("_eagle_loss", {"fullgraph": True}),
]
for name, kwargs in compile_targets:
try:
setattr(self, name, torch.compile(getattr(self, name), dynamic=False, **kwargs))
except Exception: # noqa: PERF203
print(f"Disabling torch.compile for {name} due to compilation error.")

def _get_ttt_attention_mask(self, batch_size, seq_length, ttt_step):
# compile and cached flex attention masks in first call
if ttt_step not in self._cached_attn_blk_masks:
Expand Down Expand Up @@ -716,7 +751,20 @@ def _prepare_eagle_inputs(
else:
eagle_position_ids = position_ids.view(-1, seq_length).long()

return eagle_input_embeds, eagle_input_hiddens, eagle_attention_mask, eagle_position_ids
base_model_logits = base_outputs.logits
if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size:
base_model_logits = self._map_logits_to_draft_vocab(base_model_logits)
base_output_predict_tok = base_model_logits.argmax(dim=-1).detach()
base_output_softmax_logits = torch.softmax(base_model_logits, dim=2).detach()

return (
eagle_input_embeds,
eagle_input_hiddens,
eagle_attention_mask,
eagle_position_ids,
base_output_predict_tok,
base_output_softmax_logits,
)

def _compute_ttt_attention_mask(
self, batch_size, seq_length, ttt_step
Expand Down Expand Up @@ -872,15 +920,16 @@ def forward(
base_outputs.logits = self.lm_head(base_outputs.out_hiddens)
past_key_values = None
else:
base_outputs, past_key_values = self._base_model_forward(
input_ids,
attention_mask,
position_ids,
past_key_values,
self.eagle_freeze_base_model,
labels,
**kwargs,
)
with self._nvtx_range("base_model_forward"):
base_outputs, past_key_values = self._base_model_forward(
input_ids,
attention_mask,
position_ids,
past_key_values,
self.eagle_freeze_base_model,
labels,
**kwargs,
)

if not isinstance(past_key_values, Cache):
past_key_values = _get_empty_cache(self._base_llm_config)
Expand All @@ -890,20 +939,27 @@ def forward(

# ====Prepare inputs for the first eagle forward pass====
eagle_loss = None
train_accs = [[] for _ in range(self.eagle_config.parallel_draft_step)]
num_parallel = self.eagle_config.parallel_draft_step
num_ttt = self.eagle_ttt_steps
train_accs = torch.zeros(num_parallel, num_ttt, device=input_ids.device)
Comment thread
benchislett marked this conversation as resolved.
b, seq_length, _ = base_outputs.out_hiddens.shape
(
eagle_input_embeds,
eagle_input_hiddens,
eagle_attn_mask_0,
eagle_position_ids,
) = self._prepare_eagle_inputs(
input_ids,
attention_mask,
position_ids,
eagle_cache,
base_outputs,
)
with self._nvtx_range("prepare_eagle_inputs"):
(
eagle_input_embeds,
eagle_input_hiddens,
eagle_attn_mask_0,
eagle_position_ids,
base_output_predict_tok,
base_output_softmax_logits,
) = self._prepare_eagle_inputs(
input_ids,
attention_mask,
position_ids,
eagle_cache,
base_outputs,
)

self.eagle_module._maybe_init_rope()

Comment thread
benchislett marked this conversation as resolved.
# ====Run eagle forward with extra training-time-test steps====
for ttt_step in range(self.eagle_ttt_steps):
Expand All @@ -913,11 +969,7 @@ def forward(
if self.eagle_mix_hidden_states or ttt_step == 0
else self._get_ttt_attention_mask(b, seq_length, ttt_step)
)
with (
enable_cp_ttt_patch()
if self.training and not self.eagle_mix_hidden_states
else contextlib.nullcontext()
):
with self._enable_cp_ttt(), self._nvtx_range("eagle_forward"):
_, eagle_output_hiddens, eagle_logits, eagle_cache = self._eagle_forward(
eagle_input_hiddens,
eagle_input_embeds,
Expand Down Expand Up @@ -945,23 +997,28 @@ def forward(

for i in range(self.eagle_config.parallel_draft_step):
eagle_logit = eagle_logits[i]
classification_loss, acc = self._eagle_loss(
# base model predict +1 tok, while eagle predict +2
# so we shift base model outputs compared to eagle outputs
# additionally, we mask the first n tok of eagle outputs at nth TTT step
base_outputs.logits[:, 1 + i + ttt_step :],
eagle_logit[:, ttt_step : -(1 + i)],
loss_mask[:, 1 + ttt_step :] if i == 0 else loss_mask[:, 1 + ttt_step : -i],
)
with self._nvtx_range("eagle_loss"):
classification_loss, acc = self._eagle_loss(
# base model predict +1 tok, while eagle predict +2
# so we shift base model outputs compared to eagle outputs
# additionally, we mask the first n tok of eagle outputs at nth TTT step
base_output_softmax_logits[:, 1 + i + ttt_step :],
base_output_predict_tok[:, 1 + i + ttt_step :],
eagle_logit[:, ttt_step : -(1 + i)],
loss_mask[:, 1 + ttt_step :] if i == 0 else loss_mask[:, 1 + ttt_step : -i],
)
# Apply loss decay factor to focus on early steps
classification_loss *= self.eagle_loss_decay_factor ** (ttt_step + i)
eagle_loss = (
classification_loss if eagle_loss is None else eagle_loss + classification_loss
)
train_accs[i].append(acc)
train_accs[i, ttt_step] = acc
if not self.training:
break

# Slice by actual number of steps taken, in case of early return
train_accs = train_accs[:, : ttt_step + 1].tolist()

# Merge base model loss and eagle loss
if base_outputs.loss is None and eagle_loss is None:
loss = None
Expand All @@ -979,27 +1036,23 @@ def forward(

def _eagle_loss(
self,
base_model_logits,
base_output_softmax_logits,
base_output_predict_tok,
eagle_logits,
loss_mask,
):
"""Function for EAGLE loss computing."""
if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size:
base_model_logits = self._map_logits_to_draft_vocab(base_model_logits)
loss_mask = loss_mask[:, : eagle_logits.shape[1], None]
classification_loss = nn.Softmax(dim=2)(base_model_logits) * nn.LogSoftmax(dim=2)(
eagle_logits
)
classification_loss = -torch.sum(torch.sum(loss_mask * classification_loss, 2)) / (
loss_mask.sum() + 1e-5
)
# Compute accuracy
base_predict_tok = base_model_logits.clone().detach().argmax(dim=-1)
eagle_predict_tok = eagle_logits.clone().detach().argmax(dim=-1)
eagle_logsoft = torch.log_softmax(eagle_logits, dim=2)
classification_loss = -torch.sum(
torch.sum(loss_mask * base_output_softmax_logits * eagle_logsoft, 2)
) / (loss_mask.sum() + 1e-5)
# Compute accuracy (returned as tensor to avoid sync; .item() called after TTT loop)
eagle_predict_tok = eagle_logits.detach().argmax(dim=-1)
valid = loss_mask[:, :, 0].bool()
correct = (base_predict_tok == eagle_predict_tok) & valid
correct = (base_output_predict_tok == eagle_predict_tok) & valid
denom = valid.sum().clamp_min(1).float()
accuracy = round(correct.sum().float().div(denom).item(), 3)
accuracy = correct.sum().float() / denom

return classification_loss, accuracy

Expand Down Expand Up @@ -1039,6 +1092,7 @@ def pseudo_speculative_generate(
else:
eagle_input_hidden_states = base_model_hidden_states

self.eagle_module._maybe_init_rope()
draft_tokens = []
for step in range(steps):
b, seq_length = eagle_ids.shape
Expand All @@ -1051,7 +1105,10 @@ def pseudo_speculative_generate(
)

# Use SDPA attention during generation for both stability and performance
with temporary_set_config_value(self.eagle_config, "_attn_implementation", "sdpa"):
with (
temporary_set_config_value(self.eagle_config, "_attn_implementation", "sdpa"),
self._nvtx_range("eagle_forward"),
):
_, eagle_prenorm_h, eagle_logits, _ = self._eagle_forward(
eagle_input_hidden_states,
self._base_model_embeddings(eagle_ids),
Expand Down
Loading