From 17a90597c5bd238c501a61237d2e12bf32328510 Mon Sep 17 00:00:00 2001 From: none0663 Date: Tue, 19 May 2026 20:20:04 +0800 Subject: [PATCH 1/3] feat: add SFT entropy logging and validation loss monitoring Add two monitoring features for SFT training to detect overfitting: 1. Training entropy (--log-sft-entropy): - Computes token-level entropy under no_grad to avoid OOM - Logged as train/entropy to TensorBoard/WandB 2. Validation loss (--val-data + --val-interval): - Full DP-parallel val loss computation with dynamic batching - Token-weighted aggregation across ranks (not rank-mean) - CP-correct reduction via get_sum_of_sample_mean - Deadlock-safe: all ranks synchronize before collective ops - Runs initial val before training for baseline - Also logs val/entropy when --log-sft-entropy is set Co-Authored-By: Claude Opus 4.6 (1M context) --- scripts/run-qwen3.5-35B-A3B-sft.sh | 18 ++ slime/backends/megatron_utils/actor.py | 32 +++ slime/backends/megatron_utils/loss.py | 45 +++- slime/backends/megatron_utils/val_loss.py | 280 ++++++++++++++++++++++ slime/ray/actor_group.py | 4 + slime/utils/arguments.py | 53 ++++ train_async.py | 8 + 7 files changed, 431 insertions(+), 9 deletions(-) create mode 100644 slime/backends/megatron_utils/val_loss.py diff --git a/scripts/run-qwen3.5-35B-A3B-sft.sh b/scripts/run-qwen3.5-35B-A3B-sft.sh index 6893133924..33522238cd 100644 --- a/scripts/run-qwen3.5-35B-A3B-sft.sh +++ b/scripts/run-qwen3.5-35B-A3B-sft.sh @@ -101,6 +101,21 @@ WANDB_ARGS=( # --wandb-group qwen3.5-35B-sft ) +TB_ARGS=( + --use-tensorboard + --tb-project-name qwen3.5-35B-A3B-sft +) + +ENTROPY_ARGS=( + --log-sft-entropy +) + +VAL_ARGS=( + # Uncomment to enable val loss monitoring (val-batch-size defaults to 64, val-input-key defaults to "messages") + # --val-data ${BASE_FOLDER}/val_data.jsonl + # --val-interval 10 +) + MISC_ARGS=( # default dropout in megatron is 0.1 --attention-dropout 0.0 @@ -157,6 +172,9 @@ ray job submit --address="http://127.0.0.1:8265" \ ${SFT_ARGS[@]} \ ${OPTIMIZER_ARGS[@]} \ ${WANDB_ARGS[@]} \ + ${TB_ARGS[@]} \ + ${ENTROPY_ARGS[@]} \ + ${VAL_ARGS[@]} \ ${PERF_ARGS[@]} \ ${EVAL_ARGS[@]} \ ${MISC_ARGS[@]} \ diff --git a/slime/backends/megatron_utils/actor.py b/slime/backends/megatron_utils/actor.py index e9887b2310..b8a5a8e03b 100644 --- a/slime/backends/megatron_utils/actor.py +++ b/slime/backends/megatron_utils/actor.py @@ -529,6 +529,38 @@ def train_actor(self, rollout_id: int, rollout_data: RolloutBatch, external_data log_perf_data(rollout_id, self.args) + def compute_val_loss(self, rollout_id: int) -> None: + """Compute validation loss with full DP coordination. + + Called periodically by train_async.py (controlled by --val-interval). + Each DP rank independently tokenizes its shard of val data and runs + forward-only; results are gathered and logged on the source rank. + """ + if self.args.debug_rollout_only: + return + + if not getattr(self.args, "val_data", None): + return + + from .val_loss import ValDataLoader, compute_val_loss + + # Lazy-initialize val data loader (each rank gets its own shard) + if not hasattr(self, "_val_data_loader"): + self._val_data_loader = ValDataLoader( + self.args, + dp_rank=mpu.get_data_parallel_rank(with_context_parallel=False), + dp_size=mpu.get_data_parallel_world_size(with_context_parallel=False), + ) + + if self.args.offload_train: + self.wake_up() + + with torch.no_grad(): + compute_val_loss(self.args, self.model, self._val_data_loader, rollout_id) + + if self.args.offload_train: + self.sleep() + @timer def save_model(self, rollout_id: int, force_sync: bool = False) -> None: if self.args.debug_rollout_only: diff --git a/slime/backends/megatron_utils/loss.py b/slime/backends/megatron_utils/loss.py index 256338fca2..23ee9312c7 100644 --- a/slime/backends/megatron_utils/loss.py +++ b/slime/backends/megatron_utils/loss.py @@ -1087,7 +1087,13 @@ def sft_loss_function( """Compute supervised fine-tuning loss over response tokens. Computes log-probabilities of the ground-truth tokens in the response - segments and returns the negative log-likelihood as the loss. + segments and returns the negative log-likelihood as the loss. Optionally + computes and logs token-level entropy when ``args.log_sft_entropy`` is set. + + Entropy is computed under ``torch.no_grad()`` since it is only used for + logging and does not participate in the loss. This avoids retaining the + extra ``[T, V]`` clone in the autograd graph which would cause OOM for + large-vocabulary models. Args: args: Configuration (passed through to helpers). @@ -1097,12 +1103,15 @@ def sft_loss_function( sum_of_sample_mean: Reduction function that averages per-sample values. Returns: - Tuple of `(loss, metrics)` where `metrics` contains a single detached - scalar "loss". + Tuple of `(loss, metrics)` where `metrics` contains detached scalars + "loss" and optionally "entropy". """ response_lengths = batch["response_lengths"] total_lengths = batch["total_lengths"] + log_entropy = getattr(args, "log_sft_entropy", False) + + # Step 1: compute log_probs for loss (with gradient) _, log_probs_and_entropy = get_log_probs_and_entropy( logits, args=args, @@ -1121,12 +1130,30 @@ def sft_loss_function( if log_probs.numel() == 0: loss += 0 * logits.sum() - return ( - loss, - { - "loss": loss.clone().detach(), - }, - ) + reported_loss = { + "loss": loss.clone().detach(), + } + + # Step 2: compute entropy for logging only (no_grad to avoid OOM) + # The logits.clone() inside calculate_log_probs_and_entropy won't be + # retained in the autograd graph, so it's freed after computation. + if log_entropy: + with torch.no_grad(): + _, entropy_result = get_log_probs_and_entropy( + logits, + args=args, + unconcat_tokens=batch["unconcat_tokens"], + total_lengths=total_lengths, + response_lengths=response_lengths, + with_entropy=True, + max_seq_lens=batch.get("max_seq_lens", None), + ) + entropy = entropy_result["entropy"] + entropy = torch.cat(entropy, dim=0) + mean_entropy = sum_of_sample_mean(entropy) + reported_loss["entropy"] = mean_entropy.detach() + + return loss, reported_loss def loss_function( diff --git a/slime/backends/megatron_utils/val_loss.py b/slime/backends/megatron_utils/val_loss.py new file mode 100644 index 0000000000..9af5037c39 --- /dev/null +++ b/slime/backends/megatron_utils/val_loss.py @@ -0,0 +1,280 @@ +"""Validation loss computation for SFT training. + +Periodically computes validation NLL loss during SFT training with full +Data Parallel coordination. Reuses the existing training infrastructure: + +- slime/utils/data.read_file() — load jsonl/parquet +- slime/utils/mask_utils.MultiTurnLossMaskGenerator — tokenize + loss mask +- slime/backends/megatron_utils/data.get_data_iterator() — dynamic batching +- slime/backends/megatron_utils/model.forward_only() — pipeline-parallel forward +- slime/backends/megatron_utils/loss.get_log_probs_and_entropy() — log_probs +- slime/backends/megatron_utils/cp_utils.get_sum_of_sample_mean() — CP-correct reduction +""" + +import logging +from argparse import Namespace +from collections.abc import Sequence + +import torch +import torch.distributed as dist +from megatron.core import mpu +from megatron.core.distributed import DistributedDataParallel as DDP + +from slime.utils import logging_utils +from slime.utils.data import read_file +from slime.utils.mask_utils import MultiTurnLossMaskGenerator +from slime.utils.metric_utils import compute_rollout_step +from slime.utils.processing_utils import load_tokenizer + +logger = logging.getLogger(__name__) + + +class ValDataLoader: + """Loads and tokenizes validation data, sharded across DP ranks. + + Reuses the same tokenization pipeline as sft_rollout.py + (MultiTurnLossMaskGenerator). Each DP rank keeps an interleaved shard + of the full dataset. Data is tokenized once at init and cached in memory. + """ + + def __init__(self, args: Namespace, dp_rank: int, dp_size: int): + self.args = args + self.dp_rank = dp_rank + self.dp_size = dp_size + self._index = 0 + self._samples = self._load_and_tokenize() + logger.info( + f"ValDataLoader: rank {dp_rank}/{dp_size}, " + f"{len(self._samples)} samples from {args.val_data}" + ) + + def _load_and_tokenize(self) -> list[dict]: + """Load val file, shard by DP rank, tokenize via MultiTurnLossMaskGenerator.""" + tokenizer = load_tokenizer(self.args.hf_checkpoint, trust_remote_code=True) + mask_generator = MultiTurnLossMaskGenerator( + tokenizer, tokenizer_type=self.args.loss_mask_type + ) + val_input_key = getattr(self.args, "val_input_key", "messages") + val_tool_key = getattr(self.args, "val_tool_key", None) + + all_records = list(read_file(self.args.val_data)) + my_records = all_records[self.dp_rank :: self.dp_size] + + samples = [] + for record in my_records: + messages = record[val_input_key] + tools = record.get(val_tool_key) if val_tool_key else None + try: + token_ids, loss_mask = mask_generator.get_loss_mask(messages, tools=tools) + except Exception as e: + logger.debug(f"Skipping val sample: {e}") + continue + + if len(token_ids) != len(loss_mask): + continue + + response_length = mask_generator.get_response_lengths([loss_mask])[0] + if response_length == 0: + continue + + samples.append({ + "token_ids": token_ids, + "loss_mask": loss_mask[-response_length:], + "response_length": response_length, + "total_length": len(token_ids), + }) + + return samples + + def get_batch(self, batch_size: int) -> dict: + """Build a RolloutBatch-compatible dict for get_data_iterator(). + + Sets `dynamic_global_batch_size` so that get_data_iterator treats + the entire val batch as one "rollout step" and applies dynamic + batching (max_tokens_per_gpu) to split into microbatches. + + Returns dict with keys: tokens, loss_masks, response_lengths, + total_lengths, dynamic_global_batch_size. + """ + if not self._samples: + return {} + + device = torch.cuda.current_device() + tokens_list = [] + loss_masks_list = [] + response_lengths = [] + total_lengths = [] + + for _ in range(batch_size): + if self._index >= len(self._samples): + self._index = 0 + + sample = self._samples[self._index] + self._index += 1 + + tokens_list.append( + torch.tensor(sample["token_ids"], dtype=torch.long, device=device) + ) + loss_masks_list.append( + torch.tensor(sample["loss_mask"], dtype=torch.int, device=device) + ) + response_lengths.append(sample["response_length"]) + total_lengths.append(sample["total_length"]) + + # Tell get_data_iterator the effective global batch size so it computes + # num_local_gbs = batch_size, num_steps_per_rollout = 1, and applies + # dynamic batching (max_tokens_per_gpu) to split into microbatches. + return { + "tokens": tokens_list, + "loss_masks": loss_masks_list, + "response_lengths": response_lengths, + "total_lengths": total_lengths, + "dynamic_global_batch_size": batch_size * self.dp_size, + } + + @property + def has_data(self) -> bool: + return len(self._samples) > 0 + + +def compute_val_loss( + args: Namespace, + model: Sequence[DDP], + val_data_loader: ValDataLoader, + rollout_id: int, +) -> None: + """Compute validation loss using the same forward pipeline as training. + + Reuses get_data_iterator (dynamic batch / CP / PP support) and + forward_only (pipeline-parallel forward pass) from the training path. + Uses get_sum_of_sample_mean for CP-correct reduction. + Loss is aggregated across DP ranks with token-weighted mean. + + All ranks MUST call this function together (even if their shard is empty) + to avoid collective deadlocks. + + Args: + args: Runtime arguments (same as training). + model: DDP-wrapped model chunks. + val_data_loader: Initialized ValDataLoader for this DP rank. + rollout_id: Current rollout step (for logging x-axis). + """ + from .cp_utils import get_sum_of_sample_mean + from .data import get_data_iterator + from .loss import get_log_probs_and_entropy + from .model import forward_only + + # --- Synchronize: ensure all ranks agree on whether to proceed --- + # Prevents deadlock if some ranks have empty val data. + has_data = torch.tensor( + [1 if val_data_loader.has_data else 0], + device=torch.cuda.current_device(), + dtype=torch.int, + ) + dist.all_reduce(has_data, op=dist.ReduceOp.MIN) + if has_data.item() == 0: + logger.warning("Some DP ranks have no val data, skipping val loss computation") + return + + val_batch = val_data_loader.get_batch(args.val_batch_size) + + # get_data_iterator handles dynamic batching (max_tokens_per_gpu), + # context parallelism, VPP, etc. — same path as training data. + data_iterator, num_microbatches = get_data_iterator(args, model, val_batch) + + # Temporarily enable entropy in forward_only if log_sft_entropy is set. + # forward_only uses args.use_rollout_entropy to decide whether to compute + # entropy. We use a separate val-specific flag to avoid mutating the RL flag. + log_entropy = getattr(args, "log_sft_entropy", False) + orig_flag = args.use_rollout_entropy + if log_entropy: + args.use_rollout_entropy = True + + # forward_only runs pipeline-parallel forward, collects log_probs + # (and entropy if enabled) on the last PP stage. + # Model is switched to eval mode internally. + result = forward_only( + get_log_probs_and_entropy, + args, + model, + data_iterator, + num_microbatches, + ) + + # Restore immediately — don't leak into RL path + args.use_rollout_entropy = orig_flag + + # --- Compute NLL on last PP stage using CP-correct reduction --- + local_nll = torch.zeros(1, device=torch.cuda.current_device()) + local_tokens = torch.zeros(1, device=torch.cuda.current_device()) + local_entropy = torch.zeros(1, device=torch.cuda.current_device()) + + if mpu.is_pipeline_last_stage() and "log_probs" in result: + log_probs_list = result["log_probs"] + response_lengths = val_batch["response_lengths"] + total_lengths = val_batch["total_lengths"] + loss_masks = val_batch["loss_masks"] + + # Use get_sum_of_sample_mean for CP-correct per-token reduction + # (handles zigzag CP slicing, allgather-CP, etc.) + sum_of_sample_mean = get_sum_of_sample_mean( + total_lengths, + response_lengths, + loss_masks, + calculate_per_token_loss=True, + qkv_format=args.qkv_format, + max_seq_lens=val_batch.get("max_seq_lens", None), + ) + log_probs_cat = torch.cat(log_probs_list, dim=0) + # sum_of_sample_mean with per_token_loss=True sums all masked log_probs + nll = -sum_of_sample_mean(log_probs_cat) + + num_tokens = sum( + torch.clamp_min(m.sum(), 0) for m in loss_masks + ) + + local_nll.fill_(nll.item()) + local_tokens.fill_(num_tokens.item()) + + # Compute entropy if available (same CP-correct reduction) + if log_entropy and "entropy" in result: + entropy_list = result["entropy"] + entropy_cat = torch.cat(entropy_list, dim=0) + entropy_sum = sum_of_sample_mean(entropy_cat) + local_entropy.fill_(entropy_sum.item()) + + # --- Token-weighted aggregation across DP ranks --- + # All-reduce sum: total_nll and total_tokens across all DP ranks. + # This gives correct token-weighted mean regardless of per-rank token count. + dp_group = mpu.get_data_parallel_group(with_context_parallel=True) + dist.all_reduce(local_nll, op=dist.ReduceOp.SUM, group=dp_group) + dist.all_reduce(local_tokens, op=dist.ReduceOp.SUM, group=dp_group) + if log_entropy: + dist.all_reduce(local_entropy, op=dist.ReduceOp.SUM, group=dp_group) + + # --- Log on primary rank --- + if ( + mpu.get_data_parallel_rank(with_context_parallel=True) == 0 + and mpu.get_tensor_model_parallel_rank() == 0 + and mpu.is_pipeline_last_stage() + ): + total_nll = local_nll.item() + total_tokens = local_tokens.item() + + if total_tokens > 0: + val_loss = total_nll / total_tokens + else: + val_loss = 0.0 + + step = compute_rollout_step(args, rollout_id) + log_dict = { + "val/loss": val_loss, + "val/num_tokens": total_tokens, + "val/step": step, + } + + if log_entropy and total_tokens > 0: + log_dict["val/entropy"] = local_entropy.item() / total_tokens + + logging_utils.log(args, log_dict, step_key="val/step") + logger.info(f"val {rollout_id}: {log_dict}") diff --git a/slime/ray/actor_group.py b/slime/ray/actor_group.py index c9ce215558..26478fd090 100644 --- a/slime/ray/actor_group.py +++ b/slime/ray/actor_group.py @@ -147,3 +147,7 @@ def clear_memory(self): def set_rollout_manager(self, rollout_manager): return ray.get([actor.set_rollout_manager.remote(rollout_manager) for actor in self._actor_handlers]) + + def compute_val_loss(self, rollout_id): + """Compute validation loss across all actor ranks.""" + return ray.get([actor.compute_val_loss.remote(rollout_id) for actor in self._actor_handlers]) diff --git a/slime/utils/arguments.py b/slime/utils/arguments.py index e8a1730782..ac687580a5 100644 --- a/slime/utils/arguments.py +++ b/slime/utils/arguments.py @@ -702,6 +702,47 @@ def add_eval_arguments(parser): parser.add_argument("--eval-min-new-tokens", type=int, default=None) parser.add_argument("--eval-max-context-len", type=int, default=None) + # Validation loss arguments (for SFT overfitting monitoring) + parser.add_argument( + "--val-data", + type=str, + default=None, + help=( + "Path to validation data (parquet/jsonl) for loss computation during training. " + "Data should contain multi-turn conversations in the column specified by --val-input-key." + ), + ) + parser.add_argument( + "--val-input-key", + type=str, + default="messages", + help="Column name for conversations in the validation data file.", + ) + parser.add_argument( + "--val-tool-key", + type=str, + default=None, + help=( + "Column name for tool definitions in the validation data file. " + "If None, tools are not passed to the tokenizer (same as SFT rollout when no tools)." + ), + ) + parser.add_argument( + "--val-batch-size", + type=int, + default=64, + help="Number of samples per validation loss computation batch (per DP rank).", + ) + parser.add_argument( + "--val-interval", + type=int, + default=None, + help=( + "Compute validation loss every N rollout steps. " + "If None, validation loss is not computed." + ), + ) + return parser def add_algo_arguments(parser): @@ -888,6 +929,15 @@ def add_algo_arguments(parser): "This is useful for doing special loss mask." ), ) + parser.add_argument( + "--log-sft-entropy", + action="store_true", + default=False, + help=( + "Whether to compute and log token-level entropy during SFT training. " + "When enabled, mean entropy is logged as 'train/entropy' to TensorBoard/WandB." + ), + ) parser.add_argument( "--get-mismatch-metrics", action="store_true", @@ -1670,6 +1720,9 @@ def slime_validate_args(args): if args.eval_interval is not None: assert args.eval_datasets, "Evaluation datasets must be configured when eval_interval is set." + if args.val_interval is not None: + assert getattr(args, "val_data", None), "--val-data is required when --val-interval is set." + if args.save_interval is not None: assert args.save is not None, "'--save' is required when save_interval is set." diff --git a/train_async.py b/train_async.py index 6960bd0558..56d95de542 100644 --- a/train_async.py +++ b/train_async.py @@ -31,6 +31,10 @@ def train(args): if args.check_weight_update_equal: ray.get(rollout_manager.check_weights.remote(action="compare")) + # Run initial val before training to record baseline loss (step 0). + if args.val_interval and getattr(args, "val_data", None): + actor_model.compute_val_loss(rollout_id=0) + # async train loop. rollout_data_next_future = rollout_manager.generate.remote(args.start_rollout_id) for rollout_id in range(args.start_rollout_id, args.num_rollout): @@ -75,6 +79,10 @@ def train(args): if should_run_periodic_action(rollout_id, args.eval_interval, num_rollout_per_epoch): ray.get(rollout_manager.eval.remote(rollout_id)) + # Compute validation loss periodically (for SFT overfitting monitoring) + if args.val_interval and should_run_periodic_action(rollout_id, args.val_interval, num_rollout_per_epoch): + actor_model.compute_val_loss(rollout_id) + ray.get(rollout_manager.dispose.remote()) finish_tracking(args) From e8c1224082be27ee462dc30e9b214457edad53c1 Mon Sep 17 00:00:00 2001 From: none0663 Date: Tue, 19 May 2026 20:42:30 +0800 Subject: [PATCH 2/3] fix: correct CP>1 val/loss underestimation and resume baseline step 1. Divide local_tokens by cp_size before all_reduce to avoid overcounting when loss_masks are replicated across CP ranks. 2. Use max(0, start_rollout_id - 1) for baseline val step to avoid discontinuity and step collision on training resume. Co-Authored-By: Claude Opus 4.6 (1M context) --- slime/backends/megatron_utils/val_loss.py | 6 +++++- train_async.py | 6 ++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/slime/backends/megatron_utils/val_loss.py b/slime/backends/megatron_utils/val_loss.py index 9af5037c39..7c6f716140 100644 --- a/slime/backends/megatron_utils/val_loss.py +++ b/slime/backends/megatron_utils/val_loss.py @@ -234,7 +234,11 @@ def compute_val_loss( ) local_nll.fill_(nll.item()) - local_tokens.fill_(num_tokens.item()) + # Divide by cp_size: each CP rank holds the full loss_mask but only + # computes a partial NLL. When all_reduced across the DP+CP group, + # NLL partials sum correctly, but tokens would be overcounted by cp_size. + cp_size = mpu.get_context_parallel_world_size() + local_tokens.fill_(num_tokens.item() / cp_size) # Compute entropy if available (same CP-correct reduction) if log_entropy and "entropy" in result: diff --git a/train_async.py b/train_async.py index 56d95de542..c296339b9b 100644 --- a/train_async.py +++ b/train_async.py @@ -31,9 +31,11 @@ def train(args): if args.check_weight_update_equal: ray.get(rollout_manager.check_weights.remote(action="compare")) - # Run initial val before training to record baseline loss (step 0). + # Run initial val before training to record baseline loss. + # Use start_rollout_id - 1 so the baseline point sits just before the + # first training step (avoids step collision and discontinuity on resume). if args.val_interval and getattr(args, "val_data", None): - actor_model.compute_val_loss(rollout_id=0) + actor_model.compute_val_loss(rollout_id=max(0, args.start_rollout_id - 1)) # async train loop. rollout_data_next_future = rollout_manager.generate.remote(args.start_rollout_id) From 98813a48896e5a573fb9683f473e8f586ee86dda Mon Sep 17 00:00:00 2001 From: none0663 Date: Tue, 19 May 2026 20:51:30 +0800 Subject: [PATCH 3/3] fix: handle small val dataset and baseline step collision edge cases 1. When val_data has fewer samples than dp_size, replicate to all ranks instead of leaving empty shards (which would skip val entirely). 2. Skip baseline val when it would collide with the first periodic val at the same step (val_interval=1 + start_rollout_id=0). Co-Authored-By: Claude Opus 4.6 (1M context) --- slime/backends/megatron_utils/val_loss.py | 10 +++++++++- train_async.py | 7 ++++++- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/slime/backends/megatron_utils/val_loss.py b/slime/backends/megatron_utils/val_loss.py index 7c6f716140..79c9a94c14 100644 --- a/slime/backends/megatron_utils/val_loss.py +++ b/slime/backends/megatron_utils/val_loss.py @@ -58,7 +58,15 @@ def _load_and_tokenize(self) -> list[dict]: val_tool_key = getattr(self.args, "val_tool_key", None) all_records = list(read_file(self.args.val_data)) - my_records = all_records[self.dp_rank :: self.dp_size] + if len(all_records) < self.dp_size: + # Replicate small datasets so every rank has data (avoids skip). + logger.info( + f"Val data ({len(all_records)} samples) < dp_size ({self.dp_size}), " + f"replicating to all ranks." + ) + my_records = all_records + else: + my_records = all_records[self.dp_rank :: self.dp_size] samples = [] for record in my_records: diff --git a/train_async.py b/train_async.py index c296339b9b..d450b03e2b 100644 --- a/train_async.py +++ b/train_async.py @@ -34,8 +34,13 @@ def train(args): # Run initial val before training to record baseline loss. # Use start_rollout_id - 1 so the baseline point sits just before the # first training step (avoids step collision and discontinuity on resume). + # Skip when val_interval=1 and start=0: the first periodic val fires + # immediately at step 0, so a separate baseline would collide. if args.val_interval and getattr(args, "val_data", None): - actor_model.compute_val_loss(rollout_id=max(0, args.start_rollout_id - 1)) + baseline_id = max(0, args.start_rollout_id - 1) + first_periodic_id = args.val_interval - 1 + args.start_rollout_id + if baseline_id < first_periodic_id: + actor_model.compute_val_loss(rollout_id=baseline_id) # async train loop. rollout_data_next_future = rollout_manager.generate.remote(args.start_rollout_id)