From be460ad0809d55a4dcecef7a9617dcdfb2779e09 Mon Sep 17 00:00:00 2001 From: Benjamin Chislett Date: Thu, 19 Mar 2026 04:29:55 +0000 Subject: [PATCH 1/3] add tensorboard support Signed-off-by: Benjamin Chislett --- examples/speculative_decoding/eagle_utils.py | 71 ++++++++++++++++--- examples/speculative_decoding/launch_train.sh | 13 +++- examples/speculative_decoding/main.py | 24 ++++++- 3 files changed, 95 insertions(+), 13 deletions(-) diff --git a/examples/speculative_decoding/eagle_utils.py b/examples/speculative_decoding/eagle_utils.py index 8c96a19a76..001363e977 100644 --- a/examples/speculative_decoding/eagle_utils.py +++ b/examples/speculative_decoding/eagle_utils.py @@ -29,6 +29,7 @@ from packaging.version import Version from scripts.ar_validate import validate_ar from torch.utils.data import Dataset +from torch.utils.tensorboard import SummaryWriter from transformers import Trainer, TrainerCallback from transformers.trainer_pt_utils import LabelSmoother @@ -170,34 +171,47 @@ def make_eagle_supervised_data_module( class EagleTrainerWithAccLog(Trainer): """Wrapper around Trainer that logs training accuracy.""" + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.model_accepts_loss_kwargs = False + def compute_loss(self, *args, **kwargs): """Override compute_loss to save train accs in trainer state.""" if not hasattr(self.state, "training_accs"): self.state.training_accs = [] kwargs.pop("num_items_in_batch", None) - loss, outputs = super().compute_loss(return_outputs=True, *args, **kwargs) + return_outputs = kwargs.pop("return_outputs", False) + loss, outputs = super().compute_loss(*args, return_outputs=True, **kwargs) if hasattr(outputs, "train_acc"): self.state.training_accs.append(outputs.train_acc) - return loss + return (loss, outputs) if return_outputs else loss class EagleTrainingPlot(TrainerCallback): """Callback that plot training acc and AR during training.""" - def __init__(self, ar_validate_steps: int = 1000, 
+            # NOTE: This is only an estimate of the real AR.
+ for i, draft_acc in enumerate(average_acc): + for j, step_acc in enumerate(draft_acc): + self.tb_writer.add_scalar( + f"{mode_id}/parallel_{i}_step_{j}_{mode_id}_acc", + step_acc, + state.global_step, ) if self.estimate_ar: - wandb.log({"estimated_training_ar": est_ar}, step=state.global_step) + self.tb_writer.add_scalar(f"{mode_id}/estimated_ar", est_ar, state.global_step) + def on_log(self, args, state, control, **kwargs): + """Log training acc and estimate AR during log step.""" + if not hasattr(state, "training_accs") or len(state.training_accs) == 0: + self.last_seen_step = state.global_step + return control + + if state.global_step != self.last_seen_step: + # Eval mode doesn't increment the global step, so we can use that to detect eval vs training + self._report_stats(state, eval_mode=False, **kwargs) + # reset training_accs + state.training_accs = [] + + self.last_seen_step = state.global_step + return control + + def on_evaluate(self, args, state, control, **kwargs): + """Log eval acc and estimate AR during eval step.""" + if not hasattr(state, "training_accs") or len(state.training_accs) == 0: + return control + + self._report_stats(state, eval_mode=True, **kwargs) # reset training_accs state.training_accs = [] return control @@ -242,6 +289,10 @@ def on_step_end(self, args, state, control, **kwargs): print_rank_0(f"Step {state.global_step} AR: {sum(ars) / len(ars):.4f}") if wandb and is_master(): wandb.log({"validate_ar": sum(ars) / len(ars)}, step=state.global_step) + if self.tb_writer: + self.tb_writer.add_scalar( + "custom/validate_ar", sum(ars) / len(ars), state.global_step + ) except Exception: print_rank_0("AR validation not available.") return control diff --git a/examples/speculative_decoding/launch_train.sh b/examples/speculative_decoding/launch_train.sh index 079c40da71..4f7a5d9219 100755 --- a/examples/speculative_decoding/launch_train.sh +++ b/examples/speculative_decoding/launch_train.sh @@ -118,6 +118,10 @@ while [ $# -gt 0 ]; do if [[ 
"$1" != *=* ]]; then shift; fi MIX_HIDDEN_STATES="${1#*=}" ;; + --tensorboard*) + if [[ "$1" != *=* ]]; then shift; fi + ENABLE_TENSORBOARD="${1#*=}" + ;; *) >&2 printf "Error: Invalid argument ${1#*=}\n" exit 1 @@ -159,7 +163,7 @@ LOG_STEPS=${LOG_STEPS:-100} DRAFT_VOCAB_CACHE=${DRAFT_VOCAB_CACHE:-""} MIX_HIDDEN_STATES=${MIX_HIDDEN_STATES:-"False"} NUM_TTT_STEPS=${NUM_TTT_STEPS:-3} - +ENABLE_TENSORBOARD=${ENABLE_TENSORBOARD:-"False"} if [[ "$MODE" == "eagle3" ]]; then if [[ -n "$EAGLE_CONFIG" ]]; then @@ -216,6 +220,12 @@ else MULTI_NODE_ARGS="" fi +if [[ "$ENABLE_TENSORBOARD" != "False" ]]; then + OBSERVABILITY_ARGS="--report_to tensorboard" +else + OBSERVABILITY_ARGS="" +fi + # Disable tokenizers parallelism to avoid warning export TOKENIZERS_PARALLELISM=False CMD="accelerate launch $MULTI_NODE_ARGS --mixed_precision bf16 ${SCRIPT_DIR}/main.py \ @@ -253,6 +263,7 @@ CMD="accelerate launch $MULTI_NODE_ARGS --mixed_precision bf16 ${SCRIPT_DIR}/mai --cp_size $CP_SIZE \ --dp_shard_size $DP_SHARD_SIZE \ --num_ttt_steps $NUM_TTT_STEPS \ + $OBSERVABILITY_ARGS \ " start_time=$(date +%s) diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py index cd7ef34758..413d5e3a58 100644 --- a/examples/speculative_decoding/main.py +++ b/examples/speculative_decoding/main.py @@ -43,6 +43,8 @@ make_eagle_supervised_data_module, patch_ring_attention_for_ttt, ) +from torch.utils.tensorboard import SummaryWriter +from transformers.integrations import TensorBoardCallback from transformers.trainer_utils import get_last_checkpoint import modelopt.torch.opt as mto @@ -102,7 +104,7 @@ class TrainingArguments(transformers.TrainingArguments): bf16: bool = field(default=True) mode: Literal["eagle3", "medusa"] = "eagle3" estimate_ar: bool = field( - default=False, metadata={"help": "Whether to estimate AR during training for logging."} + default=True, metadata={"help": "Whether to estimate AR during training for logging."} ) ar_validate_steps: int = 
field(default=1000, metadata={"help": "Steps between AR validation."}) disable_tqdm: bool = field(default=False, metadata={"help": "Disable tqdm progress bar."}) @@ -235,11 +237,29 @@ def train(): tokenizer, data_args, train_len=training_args.training_seq_len ) + callbacks = [] + tb_writer = None + if "tensorboard" in training_args.report_to: + log_dir = training_args.output_dir + tb_writer = SummaryWriter(log_dir=log_dir) + if isinstance(training_args.report_to, list): + training_args.report_to.remove("tensorboard") + else: + training_args.report_to = "none" + callbacks.append(TensorBoardCallback(tb_writer=tb_writer)) + callbacks.append( + EagleTrainingPlot( + training_args.ar_validate_steps, + tb_writer=tb_writer, + estimate_ar=training_args.estimate_ar, + ) + ) + trainer = EagleTrainerWithAccLog( model=model, processing_class=tokenizer, args=training_args, - callbacks=[EagleTrainingPlot(training_args.ar_validate_steps, training_args.estimate_ar)], + callbacks=callbacks, **data_module, ) From e1a7720d1577a1bf4dfc4d0b8e08ab7a116dbdee Mon Sep 17 00:00:00 2001 From: Benjamin Chislett Date: Thu, 19 Mar 2026 04:39:19 +0000 Subject: [PATCH 2/3] estimate_ar on by default Signed-off-by: Benjamin Chislett --- examples/speculative_decoding/launch_train.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/speculative_decoding/launch_train.sh b/examples/speculative_decoding/launch_train.sh index 4f7a5d9219..c10dd406e1 100755 --- a/examples/speculative_decoding/launch_train.sh +++ b/examples/speculative_decoding/launch_train.sh @@ -156,7 +156,7 @@ DISABLE_TQDM=${DISABLE_TQDM:-False} VLM_PROCESSOR=${VLM_PROCESSOR:-} VLM_IMG_DIR=${VLM_IMG_DIR:-} AR_VALIDATE_STEPS=${AR_VALIDATE_STEPS:-1000} -ESTIMATE_AR=${ESTIMATE_AR:-False} +ESTIMATE_AR=${ESTIMATE_AR:-True} CP_SIZE=${CP_SIZE:-1} DP_SHARD_SIZE=${DP_SHARD_SIZE:-$((TOTAL_GPU/CP_SIZE))} LOG_STEPS=${LOG_STEPS:-100} From 8a442e4428ac1e709ebed17fd5ba13ff8c25c29a Mon Sep 17 00:00:00 2001 From: Benjamin Chislett 
+If W&B is installed, it will be used automatically for logging. If `--tensorboard` is provided,
+TensorBoard logging is enabled as well, writing event files into the training output directory.