From c8facf8acc0333c605ebc68e140f9127fb1c0356 Mon Sep 17 00:00:00 2001 From: Benjamin Chislett Date: Thu, 19 Mar 2026 02:54:47 +0000 Subject: [PATCH 1/2] bucket-based sequence length padding Signed-off-by: Benjamin Chislett --- examples/speculative_decoding/eagle_utils.py | 20 ++++++++--- examples/speculative_decoding/launch_train.sh | 7 +++- examples/speculative_decoding/main.py | 19 ++++++++++- .../torch/speculative/plugins/transformers.py | 16 ++++----- .../utils/plugins/transformers_dataset.py | 33 +++++++++++++++++-- 5 files changed, 79 insertions(+), 16 deletions(-) diff --git a/examples/speculative_decoding/eagle_utils.py b/examples/speculative_decoding/eagle_utils.py index 8c96a19a76..bf9e7ad3b3 100644 --- a/examples/speculative_decoding/eagle_utils.py +++ b/examples/speculative_decoding/eagle_utils.py @@ -40,6 +40,7 @@ LanguageDataCollator, ShardedDataset, VisionLanguageDataCollator, + _get_bucket_size, ) try: @@ -89,8 +90,9 @@ def __getitem__(self, i) -> dict[str, torch.Tensor]: class EagleOfflineDataCollator: """Data collator that truncate or pads data for offline training.""" - def __init__(self, train_len): + def __init__(self, train_len, bucket_granularity=0): self.train_len = train_len + self.bucket_granularity = bucket_granularity def _pad_or_truncate(self, x: torch.Tensor, length: int, dim: int = 0): """Pad or truncate a tensor to length along a given dimension.""" @@ -110,13 +112,19 @@ def _pad_or_truncate(self, x: torch.Tensor, length: int, dim: int = 0): return out def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]: + if self.bucket_granularity > 0: + batch_max = max(item["input_ids"].shape[0] for item in features) + pad_len = _get_bucket_size(batch_max, self.train_len, self.bucket_granularity) + else: + pad_len = self.train_len + base_batch = { - k: torch.stack([self._pad_or_truncate(item[k], self.train_len) for item in features]) + k: torch.stack([self._pad_or_truncate(item[k], pad_len) for item in features]) for k in ["input_ids", "attention_mask", "loss_mask", "labels"] } base_model_outputs = { - k: torch.stack([self._pad_or_truncate(item[k], self.train_len) for item in features]) + k: torch.stack([self._pad_or_truncate(item[k], pad_len) for item in features]) for k in ["base_model_hidden_states", "aux_hidden_states"] } @@ -131,6 +139,7 @@ def make_eagle_supervised_data_module( tokenizer: transformers.PreTrainedTokenizer, data_args, train_len=None, + bucket_granularity=0, ) -> dict: if data_args.offline_data_path is None: train_dataset = ShardedDataset("json", data_files=data_args.data_path) @@ -140,6 +149,7 @@ def make_eagle_supervised_data_module( tokenizer=tokenizer, train_len=train_len, return_labels=True, + bucket_granularity=bucket_granularity, ) else: data_collator = VisionLanguageDataCollator( @@ -159,7 +169,9 @@ def make_eagle_supervised_data_module( raise ValueError(f"No .pt files found in {data_args.offline_data_path}") train_dataset = OfflineSupervisedDataset(dumped_files) - data_collator = EagleOfflineDataCollator(train_len=train_len) + data_collator = EagleOfflineDataCollator( + train_len=train_len, bucket_granularity=bucket_granularity + ) return { "train_dataset": train_dataset, diff --git a/examples/speculative_decoding/launch_train.sh b/examples/speculative_decoding/launch_train.sh index 0ffe17486d..4e699646fb 100755 --- a/examples/speculative_decoding/launch_train.sh +++ b/examples/speculative_decoding/launch_train.sh @@ -110,6 +110,10 @@ while [ $# -gt 0 ]; do if [[ "$1" != *=* ]]; then shift; fi NUM_NODES="${1#*=}" ;; + --bucket_granularity*) + if [[ "$1" != *=* ]]; then shift; fi + BUCKET_GRANULARITY="${1#*=}" + ;; --head_node_ip*) if [[ "$1" != *=* ]]; then shift; fi HEAD_NODE_IP="${1#*=}" @@ -164,7 +168,7 @@ DRAFT_VOCAB_CACHE=${DRAFT_VOCAB_CACHE:-""} MIX_HIDDEN_STATES=${MIX_HIDDEN_STATES:-"False"} DISABLE_TORCH_COMPILE=${DISABLE_TORCH_COMPILE:-"False"} NUM_TTT_STEPS=${NUM_TTT_STEPS:-3} - +BUCKET_GRANULARITY=${BUCKET_GRANULARITY:-512} if [[ "$MODE" == "eagle3" ]]; then if [[ -n "$EAGLE_CONFIG" ]]; then @@ -259,6 +263,7 @@ CMD="accelerate launch $MULTI_NODE_ARGS --mixed_precision bf16 ${SCRIPT_DIR}/mai --cp_size $CP_SIZE \ --dp_shard_size $DP_SHARD_SIZE \ --num_ttt_steps $NUM_TTT_STEPS \ + --bucket_granularity $BUCKET_GRANULARITY \ " start_time=$(date +%s) diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py index 0db3867ccb..ea20e94058 100644 --- a/examples/speculative_decoding/main.py +++ b/examples/speculative_decoding/main.py @@ -111,6 +111,15 @@ class TrainingArguments(transformers.TrainingArguments): ) cp_size: int = field(default=1, metadata={"help": "Context parallelism size."}) dp_shard_size: int = field(default=1, metadata={"help": "Data parallelism shard size."}) + bucket_granularity: int = field( + default=512, + metadata={ + "help": ( + "Pad sequences to the nearest multiple of this value instead of training_seq_len. " + "Set to 0 to disable (always pad to training_seq_len)." + ) + }, + ) @dataclass @@ -237,8 +246,16 @@ def train(): print_rank_0("Loading dataset...") if training_args.mode == "eagle3": + bucket_gran = training_args.bucket_granularity + if bucket_gran > 0 and training_args.cp_size > 1: + from math import lcm + + bucket_gran = lcm(bucket_gran, training_args.cp_size) data_module = make_eagle_supervised_data_module( - tokenizer, data_args, train_len=training_args.training_seq_len + tokenizer, + data_args, + train_len=training_args.training_seq_len, + bucket_granularity=bucket_gran, ) trainer = EagleTrainerWithAccLog( diff --git a/modelopt/torch/speculative/plugins/transformers.py b/modelopt/torch/speculative/plugins/transformers.py index 4be730d371..5bd1c58bc6 100644 --- a/modelopt/torch/speculative/plugins/transformers.py +++ b/modelopt/torch/speculative/plugins/transformers.py @@ -626,6 +626,7 @@ def _activate_torch_compile(self): import torch._dynamo torch._dynamo.config.suppress_errors = True # Allow fallback to eager mode + torch._dynamo.config.cache_size_limit = 2048 # Individual try-catch for each function to maximize torch.compile usage try: @@ -634,24 +635,23 @@ def _activate_torch_compile(self): print("Disabling torch.compile for _prepare_eagle_inputs due to compilation error.") try: - self._eagle_forward = torch.compile( - self._eagle_forward, dynamic=False, mode="max-autotune" - ) + self._eagle_forward = torch.compile(self._eagle_forward, dynamic=False) except Exception: print("Disabling torch.compile for _eagle_forward due to compilation error.") try: - self._eagle_loss = torch.compile(self._eagle_loss, dynamic=False, fullgraph=True) + self._eagle_loss = torch.compile(self._eagle_loss, dynamic=False) except Exception: print("Disabling torch.compile for _eagle_loss due to compilation error.") def _get_ttt_attention_mask(self, batch_size, seq_length, ttt_step): # compile and cached flex attention masks in first call - if ttt_step not in self._cached_attn_blk_masks: - self._cached_attn_blk_masks.update( - {ttt_step: self._compute_ttt_attention_mask(batch_size, seq_length, ttt_step)} + cache_key = (ttt_step, seq_length) + if cache_key not in self._cached_attn_blk_masks: + self._cached_attn_blk_masks[cache_key] = self._compute_ttt_attention_mask( + batch_size, seq_length, ttt_step ) - return self._cached_attn_blk_masks[ttt_step] + return self._cached_attn_blk_masks[cache_key] def _prepare_decoder_attention_mask( self, attention_mask, input_shape, past_key_values_length, device, dtype diff --git a/modelopt/torch/utils/plugins/transformers_dataset.py b/modelopt/torch/utils/plugins/transformers_dataset.py index e147ebf2c2..dc91bec02e 100644 --- a/modelopt/torch/utils/plugins/transformers_dataset.py +++ b/modelopt/torch/utils/plugins/transformers_dataset.py @@ -20,6 +20,7 @@ import os import torch +import torch.nn.functional as F import transformers from datasets import load_dataset from transformers.trainer_pt_utils import LabelSmoother @@ -112,6 +113,13 @@ def _load_dataset(self): self._raw_samples = shard +def _get_bucket_size(seq_len: int, max_len: int, granularity: int) -> int: + """Round seq_len up to the nearest multiple of granularity, capped at max_len.""" + if granularity <= 0: + return max_len + return min(((seq_len + granularity - 1) // granularity) * granularity, max_len) + + class LanguageDataCollator: """Data collator for language modeling tasks. @@ -129,6 +137,7 @@ def __init__( answer_only_loss: bool = False, json_key: str = "text", return_labels: bool = False, + bucket_granularity: int = 0, ): """Initialize the LanguageDataset.""" if not isinstance(tokenizer, transformers.PreTrainedTokenizerBase): @@ -143,6 +152,7 @@ def __init__( self.answer_only_loss = answer_only_loss self.json_key = json_key self.return_labels = return_labels + self.bucket_granularity = bucket_granularity if chat_template is not None: self.tokenizer.chat_template = chat_template @@ -172,16 +182,19 @@ def _post_process_chat_template(self): ) def _process_chat_sample(self, examples: list): + padding = "longest" if self.bucket_granularity > 0 else "max_length" tokenized_examples = self.tokenizer.apply_chat_template( examples, return_tensors="pt", return_dict=True, - padding="max_length", + padding=padding, truncation=True, max_length=self.train_len, add_generation_prompt=self.add_generation_prompt, return_assistant_tokens_mask=self.answer_only_loss, ) + if self.bucket_granularity > 0: + tokenized_examples = self._pad_to_bucket(tokenized_examples) if self.return_labels: input_ids = tokenized_examples["input_ids"] labels = input_ids.new_full(input_ids.shape, IGNORE_TOKEN_ID) @@ -189,14 +202,30 @@ def _process_chat_sample(self, examples: list): tokenized_examples["labels"] = labels return tokenized_examples + def _pad_to_bucket(self, tokenized_examples): + cur_len = tokenized_examples["input_ids"].shape[1] + bucket_len = _get_bucket_size(cur_len, self.train_len, self.bucket_granularity) + pad_size = bucket_len - cur_len + if pad_size > 0: + tokenized_examples["input_ids"] = F.pad( + tokenized_examples["input_ids"], (0, pad_size), value=self.tokenizer.pad_token_id + ) + tokenized_examples["attention_mask"] = F.pad( + tokenized_examples["attention_mask"], (0, pad_size), value=0 + ) + return tokenized_examples + def _process_text_sample(self, examples: list): + padding = "longest" if self.bucket_granularity > 0 else "max_length" tokenized_examples = self.tokenizer( examples, return_tensors="pt", - padding="max_length", + padding=padding, truncation=True, max_length=self.train_len, ) + if self.bucket_granularity > 0: + tokenized_examples = self._pad_to_bucket(tokenized_examples) return tokenized_examples def __call__(self, examples): From 4a645282b5629192f94a2e9d4fb67c84b6c5f6de Mon Sep 17 00:00:00 2001 From: Benjamin Chislett Date: Thu, 19 Mar 2026 03:48:45 +0000 Subject: [PATCH 2/2] fix compile cache thrash in pseudo spec generate Signed-off-by: Benjamin Chislett --- modelopt/torch/speculative/plugins/transformers.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/modelopt/torch/speculative/plugins/transformers.py b/modelopt/torch/speculative/plugins/transformers.py index 5bd1c58bc6..eb4d21d493 100644 --- a/modelopt/torch/speculative/plugins/transformers.py +++ b/modelopt/torch/speculative/plugins/transformers.py @@ -1100,7 +1100,10 @@ def pseudo_speculative_generate( ) # Use SDPA attention during generation for both stability and performance - with temporary_set_config_value(self.eagle_config, "_attn_implementation", "sdpa"): + with ( + temporary_set_config_value(self.eagle_config, "_attn_implementation", "sdpa"), + torch.compiler.set_stance("force_eager"), + ): _, eagle_prenorm_h, eagle_logits, _ = self._eagle_forward( eagle_input_hidden_states, self._base_model_embeddings(eagle_ids),